In [4]:
# In[ ]:
%matplotlib inline

import os
import pickle
import matplotlib.pyplot as plt
import seaborn as sns

# ────────────────────────────────────────────────────────────────────────────────
# CONFIGURATION
# ────────────────────────────────────────────────────────────────────────────────
DATASETS = ["dynahate", "hatecheck"]
TOP_X = 5

def plot_single_lime_words(words_weights, title, top_x=TOP_X):
    """
    Create a horizontal barplot of the top_x words by normalized weight.
    Labels and title are in bold.
    """
    sns.set(style="whitegrid", context="talk", font_scale=1.1)
    top = words_weights[:top_x]
    if not top:
        raise ValueError(f"No words to plot for '{title}'")
    words, weights = zip(*top)

    fig, ax = plt.subplots(figsize=(8, 6))
    palette = sns.color_palette("Blues", len(weights))
    sns.barplot(
        x=list(weights),
        y=list(words),
        palette=palette,
        edgecolor=".2",
        ax=ax
    )
    ax.set_title(title, fontweight="bold", fontsize=16)
    ax.set_xlabel("Normalized Weight", fontweight="bold")
    ax.set_ylabel("Word", fontweight="bold")
    for label in ax.get_xticklabels() + ax.get_yticklabels():
        label.set_fontweight("bold")
    plt.tight_layout()
    return fig

for ds in DATASETS:
    # Paths to pickle files
    PKL_FILES = {
        "positive_pre":  f"./results/positive_words_pre_FT_{ds}.pkl",
        "negative_pre":  f"./results/negative_words_pre_FT_{ds}.pkl",
        "positive_post": f"./results/positive_words_post_FT_{ds}.pkl",
        "negative_post": f"./results/negative_words_post_FT_{ds}.pkl",
    }

    # Output directory for this dataset
    OUTPUT_DIR = f"./results/{ds}/barplots"
    os.makedirs(OUTPUT_DIR, exist_ok=True)

    # ────────────────────────────────────────────────────────────────────────────
    # INDIVIDUAL BARPLOTS
    # ────────────────────────────────────────────────────────────────────────────
    for name, path in PKL_FILES.items():
        if not os.path.isfile(path):
            print(f"Warning: file not found: {path}")
            continue

        with open(path, "rb") as f:
            words_weights = pickle.load(f)

        cls = "Hate" if name.startswith("positive") else "No-Hate"
        stage = "Pre-Fine-Tuning" if "pre" in name else "Post-Fine-Tuning"
        title = f"Top {TOP_X} {cls} Words — {stage} ({ds})"

        fig = plot_single_lime_words(words_weights, title)
        out_path = os.path.join(OUTPUT_DIR, f"{ds}_{name}_top_{TOP_X}.png")
        fig.savefig(out_path, dpi=600)
        plt.show()
        plt.close(fig)
        print(f"Saved plot: {out_path}")

    # ────────────────────────────────────────────────────────────────────────────
    # COMPARISON PLOT FOR POSITIVE PRE vs POST FINE-TUNING
    # ────────────────────────────────────────────────────────────────────────────
    # Load both pickle files
    pre_path  = PKL_FILES["positive_pre"]
    post_path = PKL_FILES["positive_post"]
    if os.path.isfile(pre_path) and os.path.isfile(post_path):
        with open(pre_path,  "rb") as f:
            pos_pre  = pickle.load(f)
        with open(post_path, "rb") as f:
            pos_post = pickle.load(f)

        pos_pre_top  = pos_pre[:TOP_X]
        pos_post_top = pos_post[:TOP_X]

        fig, axes = plt.subplots(1, 2, figsize=(16, 8))
        for ax, data, subtitle in zip(
            axes,
            [pos_pre_top, pos_post_top],
            ["Pre-Fine-Tuning", "Post-Fine-Tuning"]
        ):
            words, weights = zip(*data)
            palette = sns.color_palette("Blues", len(weights))
            sns.barplot(
                x=list(weights),
                y=list(words),
                palette=palette,
                edgecolor=".2",
                ax=ax
            )
            ax.set_title(f"Top {TOP_X} Hate Words — {subtitle} ({ds})", fontweight="bold", fontsize=16)
            ax.set_xlabel("Normalized Weight", fontweight="bold")
            ax.set_ylabel("Word", fontweight="bold")
            for label in ax.get_xticklabels() + ax.get_yticklabels():
                label.set_fontweight("bold")

        plt.tight_layout()
        comparison_path = os.path.join(
            OUTPUT_DIR,
            f"{ds}_positive_pre_post_comparison_top_{TOP_X}.png"
        )
        fig.savefig(comparison_path, dpi=600)
        plt.show()
        plt.close(fig)
        print(f"Saved comparison plot: {comparison_path}")
    else:
        print(f"Skipping comparison for {ds}: missing pickle files")



Passing `palette` without assigning `hue` is deprecated and will be removed in v0.14.0. Assign the `y` variable to `hue` and set `legend=False` for the same effect.

  sns.barplot(


Saved plot: ./results/plots/barplots/positive_pre_top_5.png



Passing `palette` without assigning `hue` is deprecated and will be removed in v0.14.0. Assign the `y` variable to `hue` and set `legend=False` for the same effect.

  sns.barplot(


Saved plot: ./results/plots/barplots/negative_pre_top_5.png



Passing `palette` without assigning `hue` is deprecated and will be removed in v0.14.0. Assign the `y` variable to `hue` and set `legend=False` for the same effect.

  sns.barplot(


Saved plot: ./results/plots/barplots/positive_post_top_5.png



Passing `palette` without assigning `hue` is deprecated and will be removed in v0.14.0. Assign the `y` variable to `hue` and set `legend=False` for the same effect.

  sns.barplot(


Saved plot: ./results/plots/barplots/negative_post_top_5.png



Passing `palette` without assigning `hue` is deprecated and will be removed in v0.14.0. Assign the `y` variable to `hue` and set `legend=False` for the same effect.

  sns.barplot(

Passing `palette` without assigning `hue` is deprecated and will be removed in v0.14.0. Assign the `y` variable to `hue` and set `legend=False` for the same effect.

  sns.barplot(


Saved comparison plot: ./results/plots/barplots/positive_pre_post_comparison_top_5.png
