In [None]:
%load_ext jupyter_black

In [None]:
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
import numpy as np
import os
import pandas as pd
import pickle
import seaborn as sns
import warnings

from eval_conversations import detect_indirect_q_answer, detect_direct_q_answer
from utils import answer_detect_map, direct_questions, indirect_questions

warnings.filterwarnings("ignore")

In [None]:
# Integer-id -> label maps per demographic attribute.
# `id_to_val` labels the per-group entries of the probe and surprisal results;
# `id_to_val_questions` labels the answer counts of the direct/indirect
# questions, which have an extra "mixed" answer inserted before "none".
id_to_val = {
    demographic: dict(enumerate([*values, "none"]))
    for demographic, values in {
        "gender": ["non-binary", "female", "male"],
        "race": ["black", "white", "asian", "hispanic"],
        "socio-economic status": ["high", "low"],
        "age": ["child", "teenager", "adult", "older adult"],
    }.items()
}

id_to_val_questions = {
    demographic: dict(enumerate([*list(labels.values())[:-1], "mixed", "none"]))
    for demographic, labels in id_to_val.items()
}

# Final results

## Qualitative analysis

In [None]:
# Qualitative analysis: for every model, demographic and group that has a
# "neutral_demo" entry, sample a few conversations and write the model's answer
# to the direct question and to each of the 5 indirect questions at user turn 6.
# Each answer is preceded by a header line "model,demographic,group,<label>",
# where <label> is what detect_direct_q_answer / detect_indirect_q_answer
# returned ("none" when the detector returns a falsy value).
# The file is opened in "w" mode so re-running the cell does not append another
# copy of the samples; a seeded generator drawing without replacement makes the
# output reproducible and free of duplicate samples within a group.
QUAL_SEED = 0  # change to draw a different (but still reproducible) sample
QUAL_SAMPLES_PER_GROUP = 3
QUAL_SEPARATOR = "----------------------------------------------------------------\n"

qual_rng = np.random.default_rng(QUAL_SEED)
with open("qual_analysis.txt", "w") as f:
    for model in ["OLMo-2-1124-7B-Instruct", "gemma-2-9b-it", "Llama-3.1-8B-Instruct"]:
        for demographic in ["age", "gender", "race", "socio-economic status"]:
            with open(f"results/convos_{demographic}_250_{model}.pkl", "rb") as infile:
                convos = pickle.load(infile)
            detect_map = answer_detect_map[demographic]
            for group in convos:
                if "neutral_demo" not in convos[group]:
                    continue
                turn6 = convos[group]["neutral_demo"][6]
                n_convos = len(turn6["direct_question"])
                sample = qual_rng.choice(
                    n_convos, size=min(QUAL_SAMPLES_PER_GROUP, n_convos), replace=False
                )
                for i in sample:
                    direct_answer = turn6["direct_question"][i]
                    # Call the detector once and fall back to "none" on a falsy result.
                    label = (
                        detect_direct_q_answer(detect_map, direct_answer, demographic)
                        or "none"
                    )
                    f.write(f"{model},{demographic},{group},{label} \n")
                    f.write(direct_answer + "\n")
                    f.write(QUAL_SEPARATOR)
                    for j in range(5):
                        indirect_answer = turn6["indirect_question"][j][i]
                        label = (
                            detect_indirect_q_answer(
                                detect_map,
                                indirect_questions[demographic][j],
                                indirect_answer,
                                demographic,
                            )
                            or "none"
                        )
                        f.write(f"{model},{demographic},{group},{label} \n")
                        f.write(indirect_answer + "\n")
                        f.write(QUAL_SEPARATOR)

## Example conversation

In [None]:
# Load the gemma-2-9b-it gender conversations (plain string path: the previous
# f-string had no placeholders).
with open("results/convos_gender_250_gemma-2-9b-it.pkl", "rb") as infile:
    gemma_gender = pickle.load(infile)

In [None]:
# Print conversation number `i` from gemma_gender["male"]["anti_demo"]["female"],
# then, for user turn 6, the answer to the direct question and to each of the
# 5 indirect questions: first the plain entry, then its "mod_" counterpart.
i = 42
example = gemma_gender["male"]["anti_demo"]["female"]
print(example["conversation"][i])
for t in (6,):
    at_turn = example[t]
    print("direct")
    for key in ("direct_question", "mod_direct_question"):
        print(at_turn[key][i])
    for j in range(5):
        print("indirect")
        for key in ("indirect_question", "mod_indirect_question"):
            print(at_turn[key][j][i])

## Conversations

In [None]:
# Load the pickled evaluation results into one dict per model, keyed by
# demographic (files are read per demographic in the order OLMo, Llama, Gemma).
olmo = {}
llama = {}
gemma = {}

for demographic in id_to_val:
    for results_by_demo, model_name in (
        (olmo, "OLMo-2-1124-7B-Instruct"),
        (llama, "Llama-3.1-8B-Instruct"),
        (gemma, "gemma-2-9b-it"),
    ):
        with open(
            f"results/results_{demographic}_250_{model_name}.pkl", "rb"
        ) as infile:
            results_by_demo[demographic] = pickle.load(infile)

In [None]:
def get_df(
    probe, demographic, description, description2="", mod_col=False, turns=(0, 1, 3, 6)
):
    """Flatten one results dict into a long-format DataFrame for plotting.

    For every turn in ``turns`` the following rows are emitted, in this order:
      * one per probe layer x group of ``probe[turn]["probe"]``
        (source "Probe", layer = layer index);
      * one per group (column) of ``surprisal`` and then ``mod_surprisal``
        (source "Surprisal", val = that group's column of values);
      * one per answer label of ``direct_question`` (source "Direct question");
      * one per answer label of ``indirect_question``, summed over its rows
        (source "Indirect question");
      * the same two question blocks again for the ``mod_`` entries.
    All non-probe rows use ``layer == -1``.

    Args:
        probe: mapping ``turn -> dict`` containing the keys listed above.
        demographic: key into ``id_to_val`` (probe/surprisal group labels) and
            ``id_to_val_questions`` (question answer labels).
        description: value of the ``descr`` column.
        description2: if non-empty, added as a ``descr2`` column.
        mod_col: if True, rows built from the ``mod_*`` entries get
            ``mod == "Yes"`` (all others "No"); if False, no ``mod`` column is
            added and those rows get ``description + " mod"`` as ``descr``.
        turns: turns to read from ``probe``; defaults to the previously
            hard-coded 0, 1, 3 and 6.

    Returns:
        DataFrame with columns source, turn, group, val, layer, descr and,
        optionally, descr2 and mod.
    """
    source, turn_col, group, val, layer_n, descr, descr2, mod = ([] for _ in range(8))
    group_labels = id_to_val[demographic]
    answer_labels = id_to_val_questions[demographic]

    # Append one row to every column list.
    def add_row(turn, row_source, layer, label, value, modified):
        source.append(row_source)
        turn_col.append(turn)
        group.append(label)
        val.append(value)
        layer_n.append(layer)
        if mod_col:
            descr.append(description)
            mod.append("Yes" if modified else "No")
        else:
            # Without a dedicated column, mitigated rows are flagged via descr.
            descr.append(description + " mod" if modified else description)
        if description2:
            descr2.append(description2)

    for turn in turns:
        results = probe[turn]
        for layer, layer_scores in enumerate(results["probe"]):
            for idx, score in enumerate(layer_scores):
                add_row(turn, "Probe", layer, group_labels[idx], score, False)
        for key, modified in (("surprisal", False), ("mod_surprisal", True)):
            # Transpose so each entry holds all values of one group.
            for idx, values in enumerate(results[key].T):
                add_row(turn, "Surprisal", -1, group_labels[idx], values, modified)
        for prefix, modified in (("", False), ("mod_", True)):
            for idx, count in enumerate(results[prefix + "direct_question"]):
                add_row(
                    turn, "Direct question", -1, answer_labels[idx], count, modified
                )
            indirect = results[prefix + "indirect_question"]
            # Sum the answer counts over the indirect questions once, not per label.
            totals = indirect.sum(axis=0)
            for idx in range(len(indirect[0])):
                add_row(
                    turn,
                    "Indirect question",
                    -1,
                    answer_labels[idx],
                    totals[idx],
                    modified,
                )

    data = {
        "source": source,
        "turn": turn_col,
        "group": group,
        "val": val,
        "layer": layer_n,
        "descr": descr,
    }
    if description2:
        data["descr2"] = descr2
    if mod_col:
        data["mod"] = mod
    return pd.DataFrame(data)

In [None]:
from scipy.stats import chi2_contingency, ttest_ind

# Global matplotlib font/legend settings for every figure drawn afterwards.
plt.rcParams["font.size"] = 15
plt.rcParams["legend.fontsize"] = 15
plt.rcParams["legend.handlelength"] = 2


def get_line_plot_rq3(df, file_name, mult=1):
    """Plot RQ3 accuracy per attribute for the introduced vs. stereotyped group.

    ``descr`` must look like ``"<from> -> <to>"``. Rows whose ``group`` equals
    ``<from>`` are labelled "Introduction", rows whose ``group`` equals ``<to>``
    "Stereotypes" (the latter wins if both match); all other rows are dropped.
    ``val`` is converted to a percentage of ``250 * mult`` answers.

    One figure per ``descr2`` value (renamed "Attribute") is saved as
    ``<file_name>_<attribute>_shaded.pdf`` with the y-axis fixed to 0-100. When a
    ``mod`` column is present, ``..._shaded_mitigation.pdf`` (mitigated vs.
    unmitigated line styles) is saved first and the plain figure only shows
    unmitigated rows.

    The input DataFrame is left unmodified (it is copied before columns are
    added and ``val`` is rescaled, so repeated calls do not double-scale).
    """
    df = df.copy()
    df[["from", "to"]] = pd.DataFrame(
        df["descr"].str.split(" -> ").tolist(), index=df.index
    )
    df["val"] = df["val"] / (250 * mult) * 100
    df = df.sort_values(by="descr2").rename(columns={"descr2": "Attribute"})
    df["type"] = ""
    df.loc[df["from"] == df["group"], "type"] = "Introduction"
    df.loc[df["to"] == df["group"], "type"] = "Stereotypes"
    df = df.loc[df["type"] != ""]
    df = df.drop(columns=["descr", "from", "to"])
    df = df.rename(columns={"type": "Group", "mod": "Mitigation"})
    for demographic in df["Attribute"].unique():
        df_demo = df.loc[df["Attribute"] == demographic]
        stem = file_name + "_" + demographic.replace(" ", "-")
        if "Mitigation" in df:
            sns.lineplot(
                data=df_demo, x="turn", y="val", hue="Group", style="Mitigation"
            )
            plt.xticks([0, 1, 3, 6], [0, 1, 3, 6])
            plt.xlabel("User turn")
            plt.ylabel("Accuracy")
            plt.savefig(
                stem + "_shaded_mitigation.pdf", pad_inches=0, bbox_inches="tight"
            )
            plt.show()
            sns.lineplot(
                data=df_demo[df_demo["Mitigation"] == "No"],
                x="turn",
                y="val",
                hue="Group",
            )
        else:
            sns.lineplot(data=df_demo, x="turn", y="val", hue="Group")
        plt.xticks([0, 1, 3, 6], [0, 1, 3, 6])
        plt.xlabel("User turn")
        plt.ylabel("Accuracy")
        plt.ylim(bottom=0, top=100)
        plt.savefig(stem + "_shaded.pdf", pad_inches=0, bbox_inches="tight")
        plt.show()


def get_line_plot_rq2(df, file_name, mult=1, descr2_name="Attribute"):
    """Plot, per group, the accuracy difference to the neutral description (RQ2).

    Rows with ``descr == group`` are left-joined on (descr2, group, turn[, mod])
    with the ``descr == "neutral"`` rows (ignoring the "none"/"mixed" labels),
    and ``val`` becomes ``(own - neutral) / (250 * mult) * 100``.

    One figure per group is saved as ``<file_name>_<group>_shaded.pdf`` with
    one line per ``descr2`` value (legend title ``descr2_name``). When a ``mod``
    column is present, ``<file_name>_<group>_shaded_mitigation.pdf`` (mitigated
    vs. unmitigated line styles) is saved first and the plain figure only shows
    unmitigated rows.
    """
    df_neutral = df[df["descr"] == "neutral"]
    df_neutral = df_neutral[~df_neutral["group"].isin(["none", "mixed"])].drop(
        columns=["descr"]
    )
    df = df[df["descr"] == df["group"]].drop(columns=["descr"])
    groups = ["descr2", "group", "turn"]
    if "mod" in df:
        groups.append("mod")
    df = df.merge(df_neutral, how="left", on=groups)
    df["val"] = (df["val_x"] - df["val_y"]) / (250 * mult) * 100
    df = df.drop(columns=["val_x", "val_y"])
    df = df.sort_values(by="descr2").rename(
        columns={"descr2": descr2_name, "mod": "Mitigation"}
    )
    for group in df["group"].unique():
        df_group = df[df["group"] == group]
        if "Mitigation" in df_group:
            sns.lineplot(
                data=df_group, x="turn", y="val", hue=descr2_name, style="Mitigation"
            )
            # Labels and ticks must be set before savefig to end up in the PDF.
            plt.xlabel("User turn")
            plt.xticks([0, 1, 3, 6], [0, 1, 3, 6])
            plt.ylabel("Δ Accuracy")
            plt.savefig(
                file_name + "_" + group + "_shaded_mitigation.pdf",
                pad_inches=0,
                bbox_inches="tight",
            )
            plt.show()
            sns.lineplot(
                data=df_group[df_group["Mitigation"] == "No"],
                x="turn",
                y="val",
                hue=descr2_name,
            )
        else:
            sns.lineplot(data=df_group, x="turn", y="val", hue=descr2_name)
        plt.xlabel("User turn")
        plt.xticks([0, 1, 3, 6], [0, 1, 3, 6])
        plt.ylabel("Δ Accuracy")
        plt.savefig(
            file_name + "_" + group + "_shaded.pdf", pad_inches=0, bbox_inches="tight"
        )
        plt.show()


def get_line_plot(df, file_name, mult=1, descr2_name="Attribute"):
    """Plot accuracy over user turns for rows whose description matches their group.

    Only rows with ``descr == group`` are kept; ``val`` is converted to a
    percentage of ``250 * mult`` answers and one line is drawn per ``descr2``
    value (legend title ``descr2_name``). When a ``mod`` column is present,
    ``<file_name>_shaded_mitigation.pdf`` (mitigated vs. unmitigated line
    styles) is saved first. ``<file_name>_shaded.pdf`` shows the unmitigated
    rows (or all rows when there is no ``mod``) with the y-axis fixed to 0-100.
    """
    df = df[df["descr"] == df["group"]].drop(columns=["descr"])
    df["val"] = df["val"] / (250 * mult) * 100
    df = df.sort_values(by="descr2").rename(
        columns={"descr2": descr2_name, "mod": "Mitigation"}
    )
    if "Mitigation" in df:
        sns.lineplot(data=df, x="turn", y="val", hue=descr2_name, style="Mitigation")
        # Labels and ticks must be set before savefig to end up in the PDF.
        plt.xlabel("User turn")
        plt.xticks([0, 1, 3, 6], [0, 1, 3, 6])
        plt.ylabel("Accuracy")
        plt.savefig(
            file_name + "_shaded_mitigation.pdf", pad_inches=0, bbox_inches="tight"
        )
        plt.show()
        sns.lineplot(
            data=df[df["Mitigation"] == "No"], x="turn", y="val", hue=descr2_name
        )
    else:
        sns.lineplot(data=df, x="turn", y="val", hue=descr2_name)
    plt.xlabel("User turn")
    plt.xticks([0, 1, 3, 6], [0, 1, 3, 6])
    plt.ylabel("Accuracy")
    plt.ylim(bottom=0, top=100)
    plt.savefig(file_name + "_shaded.pdf", pad_inches=0, bbox_inches="tight")
    plt.show()


def get_line_plot_surprisal_rq3(df, file_name):
    """Plot how often the introduced / stereotyped group has the lowest surprisal (RQ3).

    ``descr`` must look like ``"<from> -> <to>"``; a row is labelled
    "Introduction" if its ``group`` equals ``<from>`` and "Stereotypes" if it
    equals ``<to>`` (the latter wins if both match). ``val`` must hold 250
    per-sample surprisal values per row, and a ``mod`` column is required.

    For every sample, a row counts as "lowest" when its value equals the minimum
    over all rows (any group) sharing (turn, descr, Attribute, Mitigation). Only
    labelled rows are plotted, with ``val`` = percentage of the 250 samples in
    which the row is lowest. Per attribute, ``..._shaded_mitigation.pdf`` and
    ``..._shaded.pdf`` (unmitigated only, y-axis 0-100) are saved.

    Returns:
        dict mapping each attribute (``descr2`` value) to its labelled per-row
        frame (still containing ``from``/``to``). The input is not modified.
    """
    sample_cols = [str(i) for i in range(250)]
    df = df.copy()
    df[["from", "to"]] = pd.DataFrame(
        df["descr"].str.split(" -> ").tolist(), index=df.index
    )
    # Positive rescaling only; it does not change which row is lowest.
    df["val"] = df["val"] / 250 * 100
    df = df.sort_values(by="descr2").rename(columns={"descr2": "Attribute"})
    df["type"] = ""
    df.loc[df["from"] == df["group"], "type"] = "Introduction"
    df.loc[df["to"] == df["group"], "type"] = "Stereotypes"
    df = df.rename(columns={"type": "Group", "mod": "Mitigation"})
    dfs = {}
    for demographic in df["Attribute"].unique():
        df_demo = df.loc[df["Attribute"] == demographic].copy()
        df_demo[sample_cols] = pd.DataFrame(
            df_demo["val"].tolist(), index=df_demo.index
        )
        df_demo = df_demo.drop(columns="val")
        # Per-sample minimum over all groups of the same condition; the string
        # alias "min" replaces the deprecated builtin `min` callable.
        group_min = (
            df_demo.groupby(["turn", "descr", "Attribute", "Mitigation"])
            .transform("min")
            .drop(columns=["Group", "group", "to", "from"])
        )
        df_demo[sample_cols] = df_demo[sample_cols] == group_min
        df_demo["val"] = df_demo[sample_cols].sum(axis=1) / 250 * 100
        df_demo = df_demo.drop(columns=sample_cols)
        df_demo = df_demo.sort_values(by="Attribute")
        df_demo = df_demo[df_demo["Group"] != ""]
        dfs[demographic] = df_demo
        df_demo = df_demo.drop(columns=["from", "to"])
        stem = file_name + f"_{demographic.replace(' ', '-')}"

        sns.lineplot(data=df_demo, x="turn", y="val", hue="Group", style="Mitigation")
        plt.xlabel("User turn")
        plt.xticks([0, 1, 3, 6], [0, 1, 3, 6])
        plt.ylabel("% of lowest surprisal values")
        plt.savefig(stem + "_shaded_mitigation.pdf", pad_inches=0, bbox_inches="tight")
        plt.show()
        sns.lineplot(
            data=df_demo[df_demo["Mitigation"] == "No"], x="turn", y="val", hue="Group"
        )
        plt.xlabel("User turn")
        plt.xticks([0, 1, 3, 6], [0, 1, 3, 6])
        plt.ylabel("% of lowest surprisal values")
        plt.ylim(bottom=0, top=100)
        plt.savefig(stem + "_shaded.pdf", pad_inches=0, bbox_inches="tight")
        plt.show()
    return dfs


def _rq2_flag_lowest_surprisal(frame, sample_cols):
    """Expand per-sample surprisal arrays into boolean "is lowest" columns.

    ``frame["val"]`` holds one array of per-sample values per row; it is split
    into the columns ``sample_cols`` and each cell becomes True iff it equals the
    minimum over the rows sharing (turn, descr, descr2, mod). Returns a new frame
    without ``val``; the input is not modified.
    """
    frame = frame.copy()
    frame[sample_cols] = pd.DataFrame(frame["val"].tolist(), index=frame.index)
    frame = frame.drop(columns="val")
    group_min = (
        frame.groupby(["turn", "descr", "descr2", "mod"])
        .transform("min")
        .drop(columns="group")
    )
    frame[sample_cols] = frame[sample_cols] == group_min
    return frame


def _rq2_contingency_row(frame, group, model, turn, mitigation, columns):
    """Return the ``columns`` values (as ints) of the first row matching the keys."""
    mask = (
        (frame["group"] == group)
        & (frame["mod"] == mitigation)
        & (frame["turn"] == turn)
        & (frame["descr2"] == model)
    )
    return frame.loc[mask, columns].astype("int").values[0]


def _rq2_chi2_pvalue(row_a, row_b):
    """Chi-squared p-value of the 2x2 table ``[row_a, row_b]``.

    Returns the sentinel 100 when both rows are identical. Because every row is
    (count, 250 - count), identical rows are also the only way a column can be
    all zero here; scipy's chi2_contingency rejects tables with a zero expected
    frequency.
    """
    if all(row_a == row_b):
        return 100
    return chi2_contingency([row_a, row_b])[1]


def get_line_plot_surprisal_rq2(df, file_name, descr2_name="Attribute"):
    """Plot, per group, the lowest-surprisal rate relative to the neutral description (RQ2).

    Rows whose ``descr`` contains " mod" are dropped. For every sample a row
    counts as "lowest" when its value equals the minimum over the rows sharing
    (turn, descr, descr2, mod). For rows with ``descr == group`` the number of
    such samples is compared with the matching ``descr == "neutral"`` row, and
    ``val`` becomes the difference in percent of 250 samples.

    For each (group, descr2, turn) a chi-squared p-value is printed for the
    unmitigated count vs. the neutral count and, prefixed "MODIFIED", for the
    unmitigated vs. the mitigated count (100 when the two rows are identical).
    Figures are saved as ``<file_name>_<group>_shaded_mitigation.pdf`` and
    ``<file_name>_<group>_shaded.pdf`` (unmitigated only).

    Expects ``val`` to hold 250 values per row and a ``mod`` column with
    "Yes"/"No" values.
    """
    sample_cols = [str(i) for i in range(250)]
    df = df[~df["descr"].str.contains(" mod")]
    df_neutral = _rq2_flag_lowest_surprisal(df[df["descr"] == "neutral"], sample_cols)
    df = _rq2_flag_lowest_surprisal(df[df["descr"] != "neutral"], sample_cols)

    df_neutral = df_neutral.drop(columns=["descr"])
    df_neutral["val"] = df_neutral[sample_cols].sum(axis=1)
    df_neutral = df_neutral.drop(columns=sample_cols)

    df = df[df["descr"] == df["group"]].drop(columns=["descr"])
    df["val"] = df[sample_cols].sum(axis=1)
    df = df.drop(columns=sample_cols)

    # val_x: own-description count, val_y: neutral count (out of 250 samples).
    df_contingency = df.merge(
        df_neutral, how="left", on=["descr2", "group", "turn", "mod"]
    )
    df_contingency["not_val_x"] = 250 - df_contingency["val_x"]
    df_contingency["not_val_y"] = 250 - df_contingency["val_y"]
    df_contingency["val"] = (
        (df_contingency["val_x"] - df_contingency["val_y"]) / 250 * 100
    )
    df = (
        df_contingency.drop(columns=["val_x", "val_y"])
        .sort_values(by="descr2")
        .rename(columns={"descr2": descr2_name, "mod": "Mitigation"})
    )

    for group in df["group"].unique():
        df_group = df[df["group"] == group]
        for model in df_contingency["descr2"].unique():
            for turn in df_contingency["turn"].unique():
                observed = _rq2_contingency_row(
                    df_contingency, group, model, turn, "No", ["val_x", "not_val_x"]
                )
                neutral = _rq2_contingency_row(
                    df_contingency, group, model, turn, "No", ["val_y", "not_val_y"]
                )
                print(
                    group,
                    model,
                    turn,
                    observed,
                    neutral,
                    _rq2_chi2_pvalue(observed, neutral),
                )
                mitigated = _rq2_contingency_row(
                    df_contingency, group, model, turn, "Yes", ["val_x", "not_val_x"]
                )
                print(
                    "MODIFIED",
                    group,
                    model,
                    turn,
                    observed,
                    mitigated,
                    _rq2_chi2_pvalue(observed, mitigated),
                )
        if "Mitigation" in df:
            sns.lineplot(
                data=df_group, x="turn", y="val", hue=descr2_name, style="Mitigation"
            )
            plt.xlabel("User turn")
            plt.xticks([0, 1, 3, 6], [0, 1, 3, 6])
            plt.ylabel("Δ % of lowest surprisal values")
            plt.savefig(
                file_name + "_" + group + "_shaded_mitigation.pdf",
                pad_inches=0,
                bbox_inches="tight",
            )
            plt.show()
            sns.lineplot(
                data=df_group[df_group["Mitigation"] == "No"],
                x="turn",
                y="val",
                hue=descr2_name,
            )
        else:
            sns.lineplot(data=df_group, x="turn", y="val", hue=descr2_name)
        plt.xlabel("User turn")
        plt.xticks([0, 1, 3, 6], [0, 1, 3, 6])
        plt.ylabel("Δ % of lowest surprisal values")
        plt.savefig(
            file_name + "_" + group + "_shaded.pdf", pad_inches=0, bbox_inches="tight"
        )
        plt.show()


def get_line_plot_surprisal(df, file_name, descr2_name="Attribute"):
    """Plot how often a description's own group receives the lowest surprisal.

    Rows whose ``descr`` contains " mod" are dropped. ``val`` must hold 250
    per-sample values per row. For every sample, a row counts as "lowest" when
    its value equals the minimum over the rows sharing (turn, descr, descr2[,
    mod]); ``val`` becomes the percentage of samples in which that holds. Rows
    with ``descr == group`` are then plotted per ``descr2`` value (legend title
    ``descr2_name``); with a ``mod`` column, ``<file_name>_shaded_mitigation.pdf``
    is saved first and ``<file_name>_shaded.pdf`` shows unmitigated rows only.

    Returns:
        The per-row frame for all groups (before the ``descr == group`` filter),
        with ``val`` as described above. The input is not modified.
    """
    sample_cols = [str(i) for i in range(250)]
    # Copy the filtered frame so new columns are not assigned into a slice.
    df = df[~df["descr"].str.contains(" mod")].copy()
    df[sample_cols] = pd.DataFrame(df["val"].tolist(), index=df.index)
    df = df.drop(columns="val")
    keys = ["turn", "descr", "descr2"]
    if "mod" in df:
        keys.append("mod")
    # String alias "min" replaces the deprecated builtin `min` callable.
    group_min = df.groupby(keys).transform("min").drop(columns="group")
    df[sample_cols] = df[sample_cols] == group_min
    df["val"] = df[sample_cols].sum(axis=1) / 250 * 100
    df = df.drop(columns=sample_cols)
    return_df = df
    df = df[df["descr"] == df["group"]].drop(columns=["descr"])
    df = df.sort_values(by="descr2").rename(
        columns={"descr2": descr2_name, "mod": "Mitigation"}
    )
    if "Mitigation" in df:
        sns.lineplot(data=df, x="turn", y="val", hue=descr2_name, style="Mitigation")
        plt.xlabel("User turn")
        plt.xticks([0, 1, 3, 6], [0, 1, 3, 6])
        plt.ylabel("% of lowest surprisal values")
        plt.savefig(
            file_name + "_shaded_mitigation.pdf", pad_inches=0, bbox_inches="tight"
        )
        plt.show()
        sns.lineplot(
            data=df[df["Mitigation"] == "No"], x="turn", y="val", hue=descr2_name
        )
    else:
        sns.lineplot(data=df, x="turn", y="val", hue=descr2_name)
    plt.xlabel("User turn")
    plt.xticks([0, 1, 3, 6], [0, 1, 3, 6])
    plt.ylabel("% of lowest surprisal values")
    plt.savefig(file_name + "_shaded.pdf", pad_inches=0, bbox_inches="tight")
    plt.show()
    return return_df


def get_plots(plot_df, probe_layer, file_name, rq2=False, rq3=False, models=False):
    """Produce the probe, surprisal, direct- and indirect-question plots.

    Args:
        plot_df: concatenated output of ``get_df`` (must contain ``descr2``).
            It is not modified; ``descr2`` values are prettified on a copy.
        probe_layer: only probe rows with ``layer > probe_layer`` are averaged
            (per descr, descr2, group, turn).
        file_name: prefix for all saved PDFs; "probe", "surprisal", "directq"
            or "indirectq" is appended.
        rq2: use the RQ2 variants (difference to the neutral description).
        rq3: use the RQ3 variants (introduction vs. stereotypes); takes
            precedence over ``rq2``.
        models: title the ``descr2`` legend "Model" instead of "Attribute".

    Returns:
        The frame from ``get_line_plot_surprisal`` (default mode) or the dict
        from ``get_line_plot_surprisal_rq3`` (rq3 mode) if it is non-empty,
        otherwise None.
    """
    plot_df = plot_df.copy()
    plot_df["descr2"] = plot_df["descr2"].replace(
        {
            "age": "Age",
            "gender natural": "Gender Natural",
            "gender": "Gender",
            "race": "Race",
            "socio-economic status": "SES",
            "gemma": "Gemma",
            "llama": "Llama",
            "olmo": "OLMo",
        }
    )
    descr2_name = "Model" if models else "Attribute"
    surprisal_result = None

    # Probe accuracies are averaged over the selected layers; mitigation is not
    # plotted for probes.
    probe = plot_df.drop(columns="mod") if "mod" in plot_df else plot_df
    probe = (
        probe.loc[(probe["source"] == "Probe") & (probe["layer"] > probe_layer)]
        .drop(columns=["source", "layer"])
        .groupby(["descr", "descr2", "group", "turn"])
        .mean()
        .reset_index()
    )
    if rq3:
        get_line_plot_rq3(probe, file_name + "probe")
    elif rq2:
        get_line_plot_rq2(probe, file_name + "probe", descr2_name=descr2_name)
    else:
        get_line_plot(probe, file_name + "probe", descr2_name=descr2_name)

    surprisal = plot_df.loc[plot_df["source"] == "Surprisal"].drop(
        columns=["source", "layer"]
    )
    if rq3:
        surprisal_result = get_line_plot_surprisal_rq3(
            surprisal, file_name + "surprisal"
        )
    elif rq2:
        get_line_plot_surprisal_rq2(
            surprisal, file_name + "surprisal", descr2_name=descr2_name
        )
    else:
        surprisal_result = get_line_plot_surprisal(
            surprisal, file_name + "surprisal", descr2_name=descr2_name
        )

    # Indirect-question counts are summed over 5 questions, hence mult=5.
    for source, suffix, mult in (
        ("Direct question", "directq", 1),
        ("Indirect question", "indirectq", 5),
    ):
        subset = plot_df.loc[plot_df["source"] == source].drop(
            columns=["source", "layer"]
        )
        if rq3:
            get_line_plot_rq3(subset, file_name + suffix, mult=mult)
        elif rq2:
            get_line_plot_rq2(
                subset, file_name + suffix, mult=mult, descr2_name=descr2_name
            )
        else:
            get_line_plot(
                subset, file_name + suffix, mult=mult, descr2_name=descr2_name
            )

    if surprisal_result is not None and len(surprisal_result):
        return surprisal_result
    return None


def _teenager_label(value):
    """Map the raw "adolescent" group key to the "teenager" label used in plots."""
    return value.replace("adolescent", "teenager")


def _shift_non_gemma_layers(dfs, model, offset=10):
    """Add ``offset`` to the "layer" column of every frame unless ``model`` is "gemma".

    The frames are modified in place and the same list is returned. The shift
    lets the single cross-model probe threshold (37) correspond to the
    per-model threshold of 27 used for non-gemma models.
    """
    if model != "gemma":
        for frame in dfs:
            frame["layer"] += offset
    return dfs


def _turn_counts(frame, descr, mitigation, columns, turn=6):
    """Return the integer rows of ``columns`` for one description at ``turn``.

    Rows are selected where ``descr_x == descr`` and ``Mitigation == mitigation``.
    """
    return (
        frame.loc[
            (frame["descr_x"] == descr)
            & (frame["Mitigation"] == mitigation)
            & (frame["turn"] == turn),
            columns,
        ]
        .astype("int")
        .values
    )


def _report_chi2(label, first, second):
    """Print ``label``, the first row of both count tables and a chi-squared p-value.

    ``pval`` is the sentinel 100 when the first rows are identical (the test is
    skipped). It is NaN when ``chi2_contingency`` rejects the table, e.g. because
    of a zero expected frequency; in that case ``label`` + "Error" is printed first.
    Raises IndexError if either table is empty.
    """
    from scipy.stats import chi2_contingency

    if np.array_equal(first[0], second[0]):
        pval = 100
    else:
        try:
            pval = chi2_contingency([first, second])[1]
        except ValueError:
            # Previously a bare ``except`` left a stale p-value from the prior test.
            print(*label, "Error")
            pval = float("nan")
    print(*label, first[0], second[0], pval)


def _report_rq3_significance(df_group, rq1_summary, model, n_samples=250):
    """Print chi-squared comparisons of RQ3 rates against RQ1 and against mitigation.

    ``df_group`` is one entry of the RQ3 summaries returned by ``get_plots``.
    ``rq1_summary`` is the RQ1 summary for the same model. ``val`` columns hold
    counts out of ``n_samples``; the complementary "not" counts are derived here.
    """
    intro = df_group[df_group["Group"] == "Introduction"].drop(columns="Group")
    stereo = df_group[df_group["Group"] == "Stereotypes"].drop(columns="Group")
    intro = intro.merge(
        rq1_summary[rq1_summary["descr"] == rq1_summary["group"]],
        left_on=["from", "Attribute", "turn"],
        right_on=["group", "descr2", "turn"],
    )
    stereo = stereo.merge(
        rq1_summary,
        left_on=["to", "Attribute", "from", "turn"],
        right_on=["group", "descr2", "descr", "turn"],
    )
    for frame in (intro, stereo):
        frame["not_val_x"] = n_samples - frame["val_x"]
        frame["not_val_y"] = n_samples - frame["val_y"]

    cols_x = ["val_x", "not_val_x"]
    cols_y = ["val_y", "not_val_y"]
    # NOTE(review): Stereotypes rows are looked up with the Introduction descrs.
    for descr in intro["descr_x"].unique():
        # RQ3 vs RQ1, without mitigation.
        _report_chi2(
            (descr, "Introduction", model),
            _turn_counts(intro, descr, "No", cols_x),
            _turn_counts(intro, descr, "No", cols_y),
        )
        _report_chi2(
            (descr, "Stereotypes", model),
            _turn_counts(stereo, descr, "No", cols_x),
            _turn_counts(stereo, descr, "No", cols_y),
        )
        # RQ3 without vs with mitigation.
        _report_chi2(
            ("MITIGATION", descr, "Introduction", model),
            _turn_counts(intro, descr, "No", cols_x),
            _turn_counts(intro, descr, "Yes", cols_x),
        )
        _report_chi2(
            ("MITIGATION", descr, "Stereotypes", model),
            _turn_counts(stereo, descr, "No", cols_x),
            _turn_counts(stereo, descr, "Yes", cols_x),
        )


def get_new_plots_per_demo(models):
    """Draw the RQ1/RQ2/RQ3 figures and print the RQ3 chi-squared tests.

    Args:
        models: mapping ``{model_name: results}``.
            ``results[demographic][value][condition]`` holds the per-condition
            data consumed by ``get_df``; the conditions used are "neutral_demo",
            "stereo_none", "neutral_natural", "anti_demo" and "anti_natural".
            ``results[demographic]["neutral_none"]`` is the neutral baseline.
            Layers of every model other than "gemma" are shifted by +10 for
            the cross-model plots.

    Side effects:
        Calls ``get_plots`` with file prefixes under ``results/final_results/``.
        Prints one line per chi-squared comparison.

    Returns:
        None. An empty ``models`` mapping is a no-op.
    """
    if not models:
        return

    # Cross-model plots, one set per demographic (keys of the first model).
    first_model = next(iter(models))
    for demographic in models[first_model]:
        dfs_rq1 = []
        dfs_rq2 = []
        for model, data in models.items():
            values = [k for k in data[demographic].keys() if k != "neutral_none"]
            dfs_rq1 += _shift_non_gemma_layers(
                [
                    get_df(
                        data[demographic][value]["neutral_demo"],
                        demographic,
                        _teenager_label(value),
                        model,
                    )
                    for value in values
                ],
                model,
            )
            dfs_rq2 += _shift_non_gemma_layers(
                [
                    get_df(
                        data[demographic][value]["stereo_none"],
                        demographic,
                        _teenager_label(value),
                        model,
                        mod_col=True,
                    )
                    for value in values
                    if "stereo_none" in data[demographic][value]
                ]
                + [
                    get_df(
                        data[demographic]["neutral_none"],
                        demographic,
                        "neutral",
                        model,
                        mod_col=True,
                    )
                ],
                model,
            )
        get_plots(
            pd.concat(dfs_rq2),
            37,
            f"results/final_results/rq2/{demographic}_",
            rq2=True,
            models=True,
        )
        get_plots(
            pd.concat(dfs_rq1),
            37,
            f"results/final_results/rq1/{demographic}_",
            models=True,
        )
        if demographic == "gender":
            natural_dfs = []
            for model, data in models.items():
                # Recompute per model instead of reusing the last model's keys.
                values = [k for k in data[demographic].keys() if k != "neutral_none"]
                natural_dfs += _shift_non_gemma_layers(
                    [
                        get_df(
                            data[demographic][value]["neutral_natural"],
                            demographic,
                            _teenager_label(value),
                            model,
                        )
                        for value in values
                    ],
                    model,
                )
            get_plots(
                pd.concat(natural_dfs),
                37,
                "results/final_results/rq1/gender_natural_",
                models=True,
            )

    # Per-model plots across demographics, plus the RQ3 significance tests.
    for model, data in models.items():
        probe_layer = 27 if model != "gemma" else 37
        dfs_rq1 = []
        dfs_rq3 = []
        for demographic in data:
            values = [k for k in data[demographic].keys() if k != "neutral_none"]
            dfs_rq1 += [
                get_df(
                    data[demographic][value]["neutral_demo"],
                    demographic,
                    _teenager_label(value),
                    demographic,
                )
                for value in values
            ]
            dfs_rq3 += [
                get_df(
                    data[demographic][value]["anti_demo"][val],
                    demographic,
                    _teenager_label(f"{value} -> {val}"),
                    demographic,
                    mod_col=True,
                )
                for value in values
                for val in data[demographic][value]["anti_demo"]
            ]
            if demographic == "gender":
                dfs_rq1 += [
                    get_df(
                        data[demographic][value]["neutral_natural"],
                        demographic,
                        _teenager_label(value),
                        demographic + " natural",
                    )
                    for value in values
                ]
                dfs_rq3 += [
                    get_df(
                        data[demographic][value]["anti_natural"][val],
                        demographic,
                        _teenager_label(f"{value} -> {val}"),
                        demographic + " natural",
                        mod_col=True,
                    )
                    for value in values
                    for val in data[demographic][value]["anti_natural"]
                ]
        rq1_summary = get_plots(
            pd.concat(dfs_rq1), probe_layer, f"results/final_results/rq1/{model}_"
        )
        rq3_summaries = get_plots(
            pd.concat(dfs_rq3),
            probe_layer,
            f"results/final_results/rq3/{model}_",
            rq3=True,
        )
        for group_name in rq3_summaries:
            _report_rq3_significance(rq3_summaries[group_name], rq1_summary, model)

In [None]:
# Render every figure and print the RQ3 significance tests for all three models.
models_by_name = dict(gemma=gemma, llama=llama, olmo=olmo)
get_new_plots_per_demo(models_by_name)

In [None]:
def _rq3_tag_groups(frame):
    """Split "<intro> -> <stereo>" descriptions and tag each row's role.

    Adds "Introduction"/"Stereotypes" columns parsed from ``descr``; each
    description is expected to contain exactly one " -> ". These two columns
    are added to ``frame`` itself. Returns a copy sorted by ``descr2`` with a
    ``type`` column:
      - "Introduction" or "Stereotypes" when ``group`` equals that side;
      - "" otherwise;
      - a Stereotypes match overrides an Introduction match.
    """
    frame[["Introduction", "Stereotypes"]] = pd.DataFrame(
        frame.descr.str.split(" -> ").tolist(), index=frame.index
    )
    frame = frame.sort_values(by="descr2")
    frame["type"] = ""
    frame.loc[frame["Introduction"] == frame["group"], "type"] = "Introduction"
    frame.loc[frame["Stereotypes"] == frame["group"], "type"] = "Stereotypes"
    return frame


def _rq3_question_rates(plot_df, source, denominator):
    """Return the mean percentage (val / denominator * 100) of ``source`` rows per role.

    Rows whose group is neither the introduced nor the stereotyped one are
    dropped. The result has columns turn, descr2, Group, Mitigation,
    Introduction, Stereotypes and val.
    """
    rates = plot_df.loc[plot_df["source"] == source].drop(columns=["source", "layer"])
    rates = _rq3_tag_groups(rates)
    rates["val"] = rates["val"] / denominator * 100
    rates = rates.loc[rates["type"] != ""].drop(columns=["descr"])
    return (
        rates.groupby(["turn", "descr2", "type", "mod", "Introduction", "Stereotypes"])[
            "val"
        ]
        .mean()
        .reset_index()
        .rename(columns={"type": "Group", "mod": "Mitigation"})
    )


def _rq3_cell(value, delta):
    """Format LaTeX math cells "$value (Delta +delta)$", both rounded to one decimal.

    A "+" is prepended only when the rounded delta is positive. Negative
    deltas keep their own "-" sign.
    """
    rounded_delta = delta.astype(float).round(1)
    sign = (rounded_delta > 0).map({True: "+", False: ""})
    return (
        "$"
        + value.astype(float).round(1).astype(str)
        + r" (\Delta "  # raw string: same text, but no invalid "\D" escape
        + sign
        + rounded_delta.astype(str)
        + ")$"
    )


def _rq3_mitigation_cells(frame):
    """Keep the mitigated ("Yes") rows and format them against the "No" rows.

    Formats surprisal, directq and indirectq. The probe column is dropped.
    NOTE(review): "Yes" and "No" rows are paired by position, as in the
    original. This assumes both subsets list the same keys in the same order.
    """
    baseline = frame[frame["Mitigation"] == "No"]
    mitigated = frame[frame["Mitigation"] == "Yes"].copy()
    for attr in ["surprisal", "directq", "indirectq"]:
        delta = pd.Series(
            mitigated[attr].values - baseline[attr].values, index=mitigated.index
        )
        mitigated[attr] = _rq3_cell(mitigated[attr], delta)
    return mitigated.drop(columns="probe")


def get_table_rq3(
    plot_df, probe_layer, file_name, label, caption, comparison_df, mitigation=False
):
    """Write LaTeX tables of RQ3 rates (introduced attribute -> stereotype).

    At the final turn (6) each metric is expressed as a percentage of the 250
    samples:
      - probe values, averaged over layers strictly above ``probe_layer``;
      - the share of samples in which a group has the lowest surprisal within
        its (turn, descr, descr2, mod) cell;
      - direct-question answers;
      - indirect-question answers, with denominator 250 * 5 (presumably five
        indirect questions per sample).
    Rows are split into the "Introduction" group and the "Stereotypes" group.
    Each split is written to its own .tex file.

    Args:
        plot_df: long-format results.
            Columns: source, layer, descr, descr2, group, turn, val, mod.
            ``descr`` looks like "<introduced> -> <stereotype>".
            NOTE: ``plot_df["descr2"]`` is relabelled in place
            (e.g. "age" -> "Age"), exactly as before.
        probe_layer: only probe layers strictly above this value are averaged.
        file_name: output path. Everything from ".tex" on is replaced by
            "_stereo.tex"/"_intro.tex", or "_modified_stereo.tex"/
            "_modified_intro.tex" when ``mitigation`` is True.
        label: forwarded to ``DataFrame.to_latex``.
        caption: forwarded to ``DataFrame.to_latex``.
        comparison_df: RQ1 baseline.
            Columns: turn, descr, group, descr2, plus one column per metric.
            Only used when ``mitigation`` is False, to report deltas against RQ1.
        mitigation: if True, report the mitigated rows with deltas against the
            unmitigated ones instead. No probe column is written in this case.
    """
    n_samples = 250
    sample_cols = [str(i) for i in range(n_samples)]
    plot_df["descr2"] = plot_df["descr2"].replace(
        {
            "age": "Age",
            "gender natural": "Gender Natural",
            "gender": "Gender",
            "race": "Race",
            "socio-economic status": "SES",
        }
    )

    # Probe: average the selected layers first, then average per role.
    probe = plot_df.drop(columns="mod") if "mod" in plot_df else plot_df
    probe = (
        probe.loc[(probe["source"] == "Probe") & (probe["layer"] > probe_layer)]
        .drop(columns=["source", "layer"])
        .groupby(["descr", "descr2", "group", "turn"])
        .mean()
        .reset_index()
    )
    probe = _rq3_tag_groups(probe)
    probe["val"] = probe["val"] / n_samples * 100
    probe = probe.loc[probe["type"] != ""].drop(columns=["descr"])
    probe = (
        probe.groupby(["turn", "descr2", "type", "Introduction", "Stereotypes"])["val"]
        .mean()
        .reset_index()
        .rename(columns={"type": "Group"})
        .rename(columns={"val": "probe"})
    )

    # Surprisal: a sample "counts" for a group when that group has the minimum
    # surprisal among all groups of the same (turn, descr, descr2, mod) cell.
    # Untagged rows must stay in the frame until the minimum has been taken.
    surprisal = plot_df.loc[plot_df["source"] == "Surprisal"].drop(
        columns=["source", "layer"]
    )
    surprisal = _rq3_tag_groups(surprisal)
    surprisal[sample_cols] = pd.DataFrame(surprisal.val.tolist(), index=surprisal.index)
    surprisal = surprisal.drop(columns="val")
    cell_minimum = (
        surprisal.groupby(["turn", "descr", "descr2", "mod"])
        .transform("min")
        .drop(columns=["group", "type", "Stereotypes", "Introduction"])
    )
    surprisal[sample_cols] = surprisal[sample_cols] == cell_minimum
    surprisal["val"] = surprisal[sample_cols].sum(axis=1) / n_samples * 100
    surprisal = surprisal.drop(columns=sample_cols)
    surprisal = surprisal[surprisal["type"] != ""]
    surprisal = (
        surprisal.groupby(
            ["turn", "descr2", "Stereotypes", "Introduction", "type", "mod"]
        )["val"]
        .mean()
        .reset_index()
        .rename(columns={"type": "Group", "mod": "Mitigation"})
        .rename(columns={"val": "surprisal"})
    )

    directq = _rq3_question_rates(plot_df, "Direct question", n_samples)
    indirectq = _rq3_question_rates(plot_df, "Indirect question", n_samples * 5)

    keys = ["turn", "descr2", "Group", "Mitigation", "Introduction", "Stereotypes"]
    table_df = pd.merge(
        directq.rename(columns={"val": "directq"}),
        probe,
        how="outer",
        on=["turn", "descr2", "Group", "Introduction", "Stereotypes"],
    )
    table_df = pd.merge(table_df, indirectq, on=keys).rename(
        columns={"val": "indirectq"}
    )
    table_df = pd.merge(table_df, surprisal, how="outer", on=keys)
    table_df = table_df[table_df["turn"] == 6].drop(columns="turn")
    if not mitigation:
        table_df = table_df[table_df["Mitigation"] == "No"].drop(columns="Mitigation")
    table_df_intro = table_df[table_df["Group"] == "Introduction"].drop(columns="Group")
    table_df_stereo = table_df[table_df["Group"] == "Stereotypes"].drop(columns="Group")

    if not mitigation:
        comparison_df = comparison_df[comparison_df["turn"] == 6].drop(columns="turn")
        table_df_intro = table_df_intro.merge(
            comparison_df[comparison_df["descr"] == comparison_df["group"]],
            left_on=["Introduction", "descr2"],
            right_on=["group", "descr2"],
        )
        table_df_stereo = table_df_stereo.merge(
            comparison_df,
            left_on=["Stereotypes", "descr2", "Introduction"],
            right_on=["group", "descr2", "descr"],
        )
        # "_x" columns come from this table, "_y" columns from the RQ1 baseline.
        for attr in ["probe", "surprisal", "directq", "indirectq"]:
            table_df_stereo[attr] = _rq3_cell(
                table_df_stereo[f"{attr}_x"],
                table_df_stereo[f"{attr}_x"] - table_df_stereo[f"{attr}_y"],
            )
            table_df_intro[attr] = _rq3_cell(
                table_df_intro[f"{attr}_x"],
                table_df_intro[f"{attr}_x"] - table_df_intro[f"{attr}_y"],
            )
    else:
        table_df_stereo = _rq3_mitigation_cells(table_df_stereo)
        table_df_intro = _rq3_mitigation_cells(table_df_intro)

    table_df_intro = table_df_intro.reset_index()
    table_df_intro = table_df_intro[table_df_intro["descr2"] != "Gender Natural"]
    table_df_stereo = table_df_stereo.reset_index()
    table_df_stereo = table_df_stereo[table_df_stereo["descr2"] != "Gender Natural"]

    column_labels = {
        "descr2": "Attribute",
        "probe": "Probe",
        "surprisal": "Surprisal",
        "directq": "Direct question",
        "indirectq": "Indirect questions",
    }
    columns = (
        ["Attribute", "Introduction", "Stereotypes"]
        + ([] if mitigation else ["Probe"])
        + ["Surprisal", "Direct question", "Indirect questions"]
    )
    stem = file_name.split(".tex")[0] + ("_modified" if mitigation else "")
    for frame, part in ((table_df_stereo, "stereo"), (table_df_intro, "intro")):
        frame.rename(columns=column_labels)[columns].to_latex(
            buf=f"{stem}_{part}.tex",
            caption=caption,
            label=label,
            index=False,
        )


def get_table(
    plot_df, probe_layer, file_name, label, caption, rq2=False, mitigation=False
):
    """Aggregate per-condition results into a LaTeX table and write it to disk.

    `plot_df` is a long frame whose ``source`` column selects one of four
    measurements, each stored in ``val``:

    * "Probe": one row per layer; rows with ``layer > probe_layer`` are
      averaged and converted with ``val / 250 * 100``.
    * "Surprisal": ``val`` is a list of 250 scores. For each score position
      the row(s) holding the minimum within its (turn, descr, descr2[, mod])
      block are counted, then converted to a percentage of 250.
    * "Direct question": ``val / 250 * 100``.
    * "Indirect question": ``val / (250 * 5) * 100`` (presumably 250 samples
      times 5 indirect questions — TODO confirm).

    Only rows whose ``group`` equals the condition ``descr`` are tabulated.

    Parameters
    ----------
    plot_df : pandas.DataFrame
        Long-format results (columns used: descr, descr2, group, turn, source,
        layer, val and optionally mod). NOTE: ``descr2`` is rewritten in place
        to display names ("Age", "Gender", "SES", ...).
    probe_layer : int
        Probe values are averaged over layers strictly greater than this.
    file_name : str
        Output ``.tex`` path. With ``mitigation=True`` the table is written to
        ``<stem>_modified.tex`` instead.
    label, caption : str
        Forwarded to ``DataFrame.to_latex``.
    rq2 : bool
        If True, rows with ``descr == "neutral"`` act as a baseline. Each cell
        reports the turn-6 value and its difference from that baseline,
        per group.
    mitigation : bool
        Only meaningful with ``rq2``. Reports ``mod == "Yes"`` rows with their
        difference from the matching ``mod == "No"`` rows, and drops the probe
        column.

    Returns
    -------
    pandas.DataFrame or None
        For ``rq2=False``: the percentages of all four sources per
        (turn, group, descr2, descr), taken before the ``descr == group``
        filter. For ``rq2=True``: None.
    """
    # Display names for the attribute column (mutates the caller's frame).
    plot_df["descr2"] = plot_df["descr2"].replace(
        {
            "age": "Age",
            "gender natural": "Gender Natural",
            "gender": "Gender",
            "race": "Race",
            "socio-economic status": "SES",
        }
    )
    # "mod" is excluded from the probe aggregation below.
    if "mod" in plot_df:
        probe = plot_df.drop(columns="mod")
    else:
        probe = plot_df

    # Column names for the 250 expanded surprisal scores.
    sample_cols = [str(i) for i in range(250)]

    # ---- Probe ----
    probe = (
        probe.loc[(probe["source"] == "Probe") & (probe["layer"] > probe_layer)]
        .drop(columns=["source", "layer"])
        .groupby(["descr", "descr2", "group", "turn"])
        .mean()
        .reset_index()
    )
    if rq2:
        probe_neutral = probe[probe["descr"] == "neutral"]
        probe = probe[probe["descr"] != "neutral"]
    probe["val"] = probe["val"] / 250 * 100
    if not rq2:
        return_probe = probe
    probe = probe[probe["descr"] == probe["group"]].drop(columns=["descr"])

    if rq2:
        # Subtract the neutral baseline; afterwards "val_x" is the absolute
        # value and "val" the delta.
        probe_neutral = probe_neutral[
            ~probe_neutral["group"].isin(["none", "mixed"])
        ].drop(columns=["descr"])
        probe_neutral["val"] = probe_neutral["val"] / 250 * 100
        probe = probe.merge(probe_neutral, how="left", on=["descr2", "group", "turn"])
        probe["val"] = probe["val_x"] - probe["val_y"]
        probe = probe.drop(columns=["val_y"])

    # ---- Surprisal ----
    surprisal = plot_df.loc[(plot_df["source"] == "Surprisal")].drop(
        columns=["source", "layer"]
    )
    if rq2:
        neutral_surprisal = surprisal[surprisal["descr"] == "neutral"]
        surprisal = surprisal[surprisal["descr"] != "neutral"]

    # Expand each row's list of 250 scores into columns "0".."249".
    surprisal[sample_cols] = pd.DataFrame(
        surprisal.val.tolist(), index=surprisal.index
    )
    surprisal = surprisal.drop(columns="val")
    groups = ["turn", "descr", "descr2"]
    if "mod" in surprisal:
        groups.append("mod")
    # True where the row holds the block minimum. The "min" alias is used
    # rather than the builtin `min`, whose use in transform is deprecated.
    surprisal[sample_cols] = surprisal[sample_cols] == surprisal.groupby(
        groups
    ).transform("min").drop(columns="group")

    surprisal["val"] = surprisal[sample_cols].sum(axis=1) / 250 * 100
    surprisal = surprisal.drop(columns=sample_cols)
    if not rq2:
        return_surprisal = surprisal
    surprisal = surprisal[surprisal["descr"] == surprisal["group"]].drop(
        columns=["descr"]
    )

    if rq2:
        neutral_surprisal[sample_cols] = pd.DataFrame(
            neutral_surprisal.val.tolist(), index=neutral_surprisal.index
        )
        neutral_surprisal = neutral_surprisal.drop(columns="val")
        groups = ["turn", "descr", "descr2"]
        if "mod" in neutral_surprisal:
            groups.append("mod")
        neutral_surprisal[sample_cols] = neutral_surprisal[
            sample_cols
        ] == neutral_surprisal.groupby(groups).transform("min").drop(columns="group")
        neutral_surprisal = neutral_surprisal.drop(columns=["descr"])
        neutral_surprisal["val"] = (
            neutral_surprisal[sample_cols].sum(axis=1) / 250 * 100
        )
        neutral_surprisal = neutral_surprisal.drop(columns=sample_cols)
        surprisal = surprisal.merge(
            neutral_surprisal, how="left", on=["descr2", "group", "turn", "mod"]
        )
        surprisal["val"] = surprisal["val_x"] - surprisal["val_y"]
        surprisal = surprisal.drop(columns=["val_y"])

    # ---- Direct question ----
    directq = plot_df.loc[(plot_df["source"] == "Direct question")].drop(
        columns=["source", "layer"]
    )
    if rq2:
        directq_neutral = directq[directq["descr"] == "neutral"]
        directq = directq[directq["descr"] != "neutral"]

    directq["val"] = directq["val"] / 250 * 100
    if not rq2:
        return_directq = directq
    directq = directq[directq["descr"] == directq["group"]].drop(columns=["descr"])

    if rq2:
        directq_neutral = directq_neutral[
            ~directq_neutral["group"].isin(["none", "mixed"])
        ].drop(columns=["descr"])
        directq_neutral["val"] = directq_neutral["val"] / 250 * 100
        directq = directq.merge(
            directq_neutral, how="left", on=["descr2", "group", "turn", "mod"]
        )
        directq["val"] = directq["val_x"] - directq["val_y"]
        directq = directq.drop(columns=["val_y"])

    # ---- Indirect questions ----
    indirectq = plot_df.loc[(plot_df["source"] == "Indirect question")].drop(
        columns=["source", "layer"]
    )
    if rq2:
        indirectq_neutral = indirectq[indirectq["descr"] == "neutral"]
        indirectq = indirectq[indirectq["descr"] != "neutral"]

    indirectq["val"] = indirectq["val"] / (250 * 5) * 100
    if not rq2:
        return_indirectq = indirectq
    indirectq = indirectq[indirectq["descr"] == indirectq["group"]].drop(
        columns=["descr"]
    )

    if rq2:
        indirectq_neutral = indirectq_neutral[
            ~indirectq_neutral["group"].isin(["none", "mixed"])
        ].drop(columns=["descr"])
        indirectq_neutral["val"] = indirectq_neutral["val"] / (250 * 5) * 100
        indirectq = indirectq.merge(
            indirectq_neutral, how="left", on=["descr2", "group", "turn", "mod"]
        )
        indirectq["val"] = indirectq["val_x"] - indirectq["val_y"]
        indirectq = indirectq.drop(columns=["val_y"])

    # Unfiltered (all descr/group pairs) summary returned to the caller.
    if not rq2:
        return_df = return_surprisal.rename(columns={"val": "surprisal"})
        groups = ["turn", "group", "descr2", "descr"]
        return_df = pd.merge(return_probe, return_df, on=groups, how="outer").rename(
            columns={"val": "probe"}
        )
        return_df = pd.merge(return_directq, return_df, on=groups, how="outer").rename(
            columns={
                "val": "directq",
            }
        )
        return_df = pd.merge(
            return_indirectq, return_df, on=groups, how="outer"
        ).rename(columns={"val": "indirectq"})

    # Combine the four sources. In the rq2 path "val_x" (absolute value)
    # becomes old_<source> and "val" (delta) becomes <source>.
    table_df = surprisal.rename(columns={"val": "surprisal", "val_x": "old_surprisal"})
    groups = ["turn", "group", "descr2"]
    table_df = pd.merge(probe, table_df, on=groups, how="outer").rename(
        columns={"val": "probe", "val_x": "old_probe"}
    )
    if "mod" in table_df:
        groups.append("mod")
    table_df = pd.merge(directq, table_df, on=groups).rename(
        columns={"val": "directq", "val_x": "old_directq"}
    )
    table_df = pd.merge(indirectq, table_df, on=groups).rename(
        columns={"val": "indirectq", "val_x": "old_indirectq"}
    )
    table_df = table_df[table_df["turn"].isin([0, 6])]
    if rq2:
        table_df = table_df[table_df["turn"] == 6]
    groups = ["descr2", "turn"]
    if "mod" in table_df:
        groups.append("mod")
    if rq2:
        groups.append("group")
    vals = ["probe", "surprisal", "directq", "indirectq"]
    if rq2:
        vals += ["old_probe", "old_surprisal", "old_directq", "old_indirectq"]
    table_df = table_df.groupby(groups)[vals].mean().reset_index()
    if not rq2:
        # One "$<turn 0> \rightarrow <turn 6>$" cell per metric and attribute.
        table_df = table_df.pivot(
            index=["descr2"],
            columns="turn",
            values=vals,
        )
        table_df.columns = [
            "_".join([col[0], str(col[1])]) for col in table_df.columns.to_flat_index()
        ]
        for attr in vals:
            table_df[attr] = (
                "$"
                + table_df[f"{attr}_0"].astype(float).round(1).astype(str)
                + " \\rightarrow "
                + table_df[f"{attr}_6"].astype(float).round(1).astype(str)
                + "$"
            )
        table_df = table_df.drop(
            columns=[f"{attr}_0" for attr in vals] + [f"{attr}_6" for attr in vals]
        )
        table_df = table_df.reset_index()
    if rq2:
        if not mitigation:
            table_df = table_df[table_df["mod"] == "No"]
            for attr in ["probe", "surprisal", "directq", "indirectq"]:
                sign = table_df[attr].astype(float).round(1) > 0
                sign = sign.map({True: "+", False: ""})
                # Raw string: "\D" is not a valid Python escape sequence.
                table_df[attr] = (
                    "$"
                    + table_df[f"old_{attr}"].astype(float).round(1).astype(str)
                    + r" (\Delta "
                    + sign
                    + table_df[attr].astype(float).round(1).astype(str)
                    + ")$"
                )
        else:
            # NOTE(review): positional subtraction assumes the "Yes" and "No"
            # rows cover the same (descr2, turn, group) keys, which the sorted
            # groupby above then emits in the same order — TODO confirm.
            no_mod_df = table_df[table_df["mod"] == "No"]
            table_df = table_df[table_df["mod"] == "Yes"]
            for attr in ["surprisal", "directq", "indirectq"]:
                table_df[attr] = (
                    table_df[f"old_{attr}"].values - no_mod_df[f"old_{attr}"].values
                )
                sign = table_df[attr].astype(float).round(1) > 0
                sign = sign.map({True: "+", False: ""})
                table_df[attr] = (
                    "$"
                    + table_df[f"old_{attr}"].astype(float).round(1).astype(str)
                    + r" (\Delta "
                    + sign
                    + table_df[attr].astype(float).round(1).astype(str)
                    + ")$"
                )
            table_df = table_df.drop(columns="probe")
        table_df = table_df.drop(
            columns=[
                "old_probe",
                "old_surprisal",
                "old_directq",
                "old_indirectq",
                "turn",
                "mod",
            ]
        )
    # NOTE(review): "mod" appears to be dropped (rq2) or pivoted away (rq1)
    # by this point, so this branch looks unreachable.
    if "mod" in table_df:
        table_df.loc[table_df["mod"] == "Yes", "probe"] = "-"
    table_df = table_df.rename(
        columns={
            "group": "Group",
            "descr2": "Attribute",
            "mod": "Mitigation",
            "probe": "Probe",
            "surprisal": "Surprisal",
            "directq": "Direct question",
            "indirectq": "Indirect questions",
        }
    )
    if not mitigation:
        table_df.to_latex(buf=file_name, caption=caption, label=label, index=False)
    else:
        table_df.to_latex(
            buf=file_name.split(".tex")[0] + "_modified.tex",
            caption=caption,
            label=label,
            index=False,
        )
    if not rq2:
        return return_df


def get_table_per_demo(data, model, probe_layer):
    """Build and write the RQ1, RQ2 and RQ3 LaTeX tables for one model.

    Parameters
    ----------
    data : dict
        Results indexed as ``data[demographic][value][condition]``; every
        demographic also carries a ``"neutral_none"`` entry. The conditions read
        here are "neutral_demo", "stereo_none", "anti_demo" and, for gender,
        "neutral_natural" / "anti_natural".
    model : str
        Model short name, used in output paths (``results/rq*/{model}.tex``),
        table labels (``tab:rq*_{model}``) and as the caption.
    probe_layer : int
        Forwarded to ``get_table`` / ``get_table_rq3``.
    """

    def teenager(label):
        # Raw keys use "adolescent"; the tables report it as "teenager".
        return label.replace("adolescent", "teenager")

    rq1_parts = []
    rq2_parts = []
    rq3_parts = []
    for demographic in data:
        by_value = data[demographic]
        values = [key for key in by_value.keys() if key != "neutral_none"]

        for value in values:
            rq1_parts.append(
                get_df(
                    by_value[value]["neutral_demo"],
                    demographic,
                    teenager(value),
                    demographic,
                )
            )

        for value in values:
            if "stereo_none" not in by_value[value]:
                continue
            rq2_parts.append(
                get_df(
                    by_value[value]["stereo_none"],
                    demographic,
                    teenager(value),
                    demographic,
                    mod_col=True,
                )
            )
        rq2_parts.append(
            get_df(
                by_value["neutral_none"],
                demographic,
                "neutral",
                demographic,
                mod_col=True,
            )
        )

        for value in values:
            for target in by_value[value]["anti_demo"]:
                rq3_parts.append(
                    get_df(
                        by_value[value]["anti_demo"][target],
                        demographic,
                        teenager(f"{value} -> {target}"),
                        demographic,
                        mod_col=True,
                    )
                )

        if demographic != "gender":
            continue
        natural_label = demographic + " natural"
        for value in values:
            rq1_parts.append(
                get_df(
                    by_value[value]["neutral_natural"],
                    demographic,
                    teenager(value),
                    natural_label,
                )
            )
        for value in values:
            for target in by_value[value]["anti_natural"]:
                rq3_parts.append(
                    get_df(
                        by_value[value]["anti_natural"][target],
                        demographic,
                        teenager(f"{value} -> {target}"),
                        natural_label,
                        mod_col=True,
                    )
                )

    rq1_summary = get_table(
        pd.concat(rq1_parts),
        probe_layer,
        f"results/rq1/{model}.tex",
        f"tab:rq1_{model}",
        model,
    )
    # Plain table first, then the mitigation variant.
    for extra in ({}, {"mitigation": True}):
        get_table(
            pd.concat(rq2_parts),
            probe_layer,
            f"results/rq2/{model}.tex",
            f"tab:rq2_{model}",
            model,
            rq2=True,
            **extra,
        )
    for extra in ({}, {"mitigation": True}):
        get_table_rq3(
            pd.concat(rq3_parts),
            probe_layer,
            f"results/rq3/{model}.tex",
            f"tab:rq3_{model}",
            model,
            rq1_summary,
            **extra,
        )

In [None]:
models = {"llama": llama, "olmo": olmo, "gemma": gemma}

# Gemma's probes are averaged above layer 37; Llama's and OLMo's above layer 27.
for model, model_data in models.items():
    get_table_per_demo(model_data, model, 37 if model == "gemma" else 27)

In [None]:
from scipy.stats import chisquare, chi2_contingency, kstest, ttest_ind


def _chi2_pvalue(contingency, rows, val):
    """Chi-squared p-value of "`val`" vs. "not `val`" counts for the two `rows`.

    `contingency` is a descr x group count table. Every group column other
    than `val` is summed into a single "not {val}" column.
    """
    table = contingency.loc[rows, :].copy()
    table[f"not {val}"] = table[[c for c in table if c != val]].sum(axis=1)
    table = table[[val, f"not {val}"]]
    # Drop all-zero columns: chi2_contingency rejects zero expected counts.
    table = table.loc[:, (table != 0).any(axis=0)]
    return chi2_contingency(table)[1]


def get_stats(df, order, rq3=False, mitigation=False, modified=False):
    """Print and return chi-squared p-values comparing conditions in `order`.

    Parameters
    ----------
    df : pandas.DataFrame
        Long frame with columns ``descr`` (condition label), ``group`` and
        ``val`` (count). It is summed into a descr x group contingency table.
    order : list of str
        Condition labels. ``order[0]`` is the reference for rq1/rq2 and is
        never tested itself.
    rq3 : bool
        Only test "source -> target" labels, each against its source condition.
    mitigation : bool
        Only test labels containing "mod".
    modified : bool
        In rq3, test the target value (after "->") instead of the source value.
        Additionally compare every " mod" label with its unmodified label.

    Returns
    -------
    dict
        ``{(reference_label, label): p_value}`` for every test that was run.
    """
    df_contingency = (
        pd.crosstab(df["descr"], df["group"], values=df["val"], aggfunc="sum")
        # descr/group combinations without rows come back as NaN: zero counts.
        .fillna(0)
        .astype("int32")
    )
    pvalues = {}
    for label in order[1:]:
        if rq3 and "->" not in label:
            continue
        if mitigation and "mod" not in label:
            continue
        if rq3:
            compare = label.split(" ->")[0]
            if modified:
                # Explicitly introduced (target) value: "src -> tgt[ mod]" -> "tgt".
                val = label.split("-> ")[1].split(" mod")[0]
            else:
                # Stereotyped (source) value: "src -> tgt" -> "src".
                val = label.split(" ->")[0]
        else:
            compare = order[0]
            val = label.split(" mod")[0]
        pvalue = _chi2_pvalue(df_contingency, [compare, label], val)
        print(label, pvalue)
        pvalues[(compare, label)] = pvalue
        if modified and " mod" in label:
            # Steered condition vs. the same condition without steering.
            unmodified = label.split(" mod")[0]
            pvalue = _chi2_pvalue(df_contingency, [label, unmodified], val)
            print(label, pvalue)
            pvalues[(unmodified, label)] = pvalue
    return pvalues


def get_stat_dfs(
    data,
    demographic,
    setting,
    mitigation=False,
    probe_layer=27,
    modified=False,
    natural_intro=False,
    save_file=None,
):
    """Run the chi-squared tests for one demographic and one research question.

    Parameters
    ----------
    data : dict
        Results indexed as ``data[demographic][value][condition]``, plus a
        ``data[demographic]["neutral_none"]`` entry.
    demographic : str
        Key into `data`.
    setting : {"rq1", "rq2", "rq3"}
        Which comparison to run.
    mitigation : bool
        Only test " mod" conditions. Probe tests are skipped.
    probe_layer : int
        Probe results are averaged over layers strictly above this one.
    modified : bool
        Add a " mod" (steered) variant of every rq2/rq3 condition to the test
        order and compare it with the unmodified condition.
    natural_intro : bool
        Use the "neutral_natural" / "anti_natural" conditions instead of
        "neutral_demo" / "anti_demo" (rq1 and rq3).
    save_file : str, optional
        Accepted for compatibility with existing callers; not used.

    Returns
    -------
    dict
        Maps "probe" (unless `mitigation`), "directq" and "indirectq" to the
        p-value dicts returned by ``get_stats``.

    Raises
    ------
    ValueError
        If `setting` is not one of "rq1", "rq2", "rq3".
    """

    def teenager(label):
        # Raw keys use "adolescent"; results report it as "teenager".
        return label.replace("adolescent", "teenager")

    values = [k for k in data[demographic].keys() if k != "neutral_none"]
    demo_col = "neutral_natural" if natural_intro else "neutral_demo"

    if setting == "rq1":
        dfs = [
            get_df(
                data[demographic]["neutral_none"],
                demographic,
                "no info",
            )
        ] + [
            get_df(
                data[demographic][value][demo_col],
                demographic,
                teenager(value),
            )
            for value in values
        ]
        order = ["no info"] + sorted(teenager(value) for value in values)
    elif setting == "rq2":
        with_stereo = [
            value for value in values if "stereo_none" in data[demographic][value]
        ]
        dfs = [
            get_df(
                data[demographic]["neutral_none"],
                demographic,
                "neutral",
            )
        ] + [
            get_df(
                data[demographic][value]["stereo_none"],
                demographic,
                teenager(value),
            )
            for value in with_stereo
        ]
        order = ["neutral"]
        for val in sorted(teenager(value) for value in with_stereo):
            order.append(val)
            if modified:
                order.append(val + " mod")
    elif setting == "rq3":
        anti_col = "anti_natural" if natural_intro else "anti_demo"
        dfs = [
            get_df(
                data[demographic][value][demo_col],
                demographic,
                teenager(value),
            )
            for value in values
        ] + [
            get_df(
                data[demographic][value][anti_col][val],
                demographic,
                teenager(f"{value} -> {val}"),
            )
            for value in values
            for val in data[demographic][value][anti_col]
        ]
        order = []
        for value in sorted(teenager(value) for value in values):
            order.append(value)
            for val in data[demographic][value.replace("teenager", "adolescent")][
                anti_col
            ]:
                order.append(teenager(f"{value} -> {val}"))
                if modified:
                    order.append(teenager(f"{value} -> {val} mod"))
    else:
        raise ValueError(f"Unknown setting: {setting!r}")

    stat_df = pd.concat(dfs)
    # Keep only the labels in `order` (others become NaN and are dropped).
    stat_df["descr"] = pd.Categorical(stat_df["descr"], categories=order, ordered=True)
    stat_df = stat_df.dropna(subset=["descr"])

    stats_kwargs = {
        "rq3": setting == "rq3",
        "mitigation": mitigation,
        "modified": modified,
    }
    results = {}
    # Probe tests are skipped under mitigation (mirrors get_table, which omits
    # probe values for mitigation tables). The question-based tests always run.
    if not mitigation:
        probe = (
            stat_df.loc[
                (stat_df["source"] == "Probe")
                & (stat_df["turn"] == 6)
                & (stat_df["layer"] > probe_layer)
            ]
            .drop(columns=["source", "turn", "layer"])
            .groupby(["descr", "group"])
            .mean()
            .reset_index()
        )
        results["probe"] = get_stats(probe, order, **stats_kwargs)
    directq = stat_df.loc[
        (stat_df["source"] == "Direct question") & (stat_df["turn"] == 6)
    ].drop(columns=["source", "turn", "layer"])
    results["directq"] = get_stats(directq, order, **stats_kwargs)
    indirectq = stat_df.loc[
        (stat_df["source"] == "Indirect question") & (stat_df["turn"] == 6)
    ].drop(columns=["source", "turn", "layer"])
    results["indirectq"] = get_stats(indirectq, order, **stats_kwargs)
    return results

In [None]:
models = {"llama": llama, "olmo": olmo, "gemma": gemma}

# Significance tests for every model / demographic / research question.
# Steering variants are run for rq2 and rq3; the natural-introduction variants
# only for gender (rq1 and rq3).
for model in models:
    layer = 37 if model == "gemma" else 27
    for demographic in ["age", "gender", "race", "socio-economic status"]:
        for setting in ["rq1", "rq2", "rq3"]:
            stem = f"results/{model}/{demographic}_{setting}"
            get_stat_dfs(
                models[model],
                demographic,
                setting=setting,
                save_file=f"{stem}_new.eps".replace(" ", "-"),
                probe_layer=layer,
            )
            if setting != "rq1":
                get_stat_dfs(
                    models[model],
                    demographic,
                    setting=setting,
                    save_file=f"{stem}_steering_new.eps".replace(" ", "-"),
                    modified=True,
                    probe_layer=layer,
                )
            if demographic != "gender" or setting == "rq2":
                continue
            get_stat_dfs(
                models[model],
                demographic,
                setting=setting,
                save_file=f"{stem}_natural_new.eps",
                natural_intro=True,
                probe_layer=layer,
            )
            if setting == "rq3":
                get_stat_dfs(
                    models[model],
                    demographic,
                    setting=setting,
                    save_file=f"{stem}_natural_steering_new.eps",
                    natural_intro=True,
                    modified=True,
                    probe_layer=layer,
                )

## Probes

In [None]:
def _load_pickle(path):
    """Return the object stored in the pickle file at `path`.

    Only use on locally produced result files: unpickling can execute
    arbitrary code.
    """
    with open(path, "rb") as handle:
        return pickle.load(handle)


llama_probe = _load_pickle("results/Llama-3.1-8B-Instruct_probe_results.pkl")
olmo_probe = _load_pickle("results/OLMo-2-1124-7B-Instruct_probe_results.pkl")
gemma_probe = _load_pickle("results/gemma-2-9b-it_probe_results.pkl")

In [None]:
plt.rcParams.update({"font.size": 30})

# (demographic key, panel title, label the y axis?) in row-major panel order.
panels = [
    ("age", "Age", True),
    ("gender", "Gender", False),
    ("race", "Race", True),
    ("socio-economic status", "Socio-economic Status", False),
]

for name, results in [
    ("llama", llama_probe),
    ("olmo", olmo_probe),
    ("gemma", gemma_probe),
]:
    # Outer keys of `results` become "demographic", inner keys "layer"; each
    # cell is a list of five per-fold accuracies, exploded into "cv"/"acc".
    df = (
        pd.DataFrame(results)
        .T.stack()
        .reset_index()
        .rename(columns={"level_0": "demographic", "level_1": "layer"})
    )
    acc = (
        pd.DataFrame(df[0].tolist(), columns=["1", "2", "3", "4", "5"])
        .stack()
        .reset_index()
        .rename(columns={"level_1": "cv", 0: "acc"})
    )
    acc["acc"] = acc["acc"] * 100
    df = pd.merge(df, acc, left_index=True, right_on=["level_0"]).drop(
        columns=[0, "level_0"]
    )
    # First layer at which any fold reaches 100% accuracy, per demographic.
    print(df[df["acc"] == 100].groupby("demographic")["layer"].min())

    fig, axes = plt.subplots(2, 2, sharex=True, sharey=True, figsize=(12, 10))
    for ax, (key, title, label_y) in zip(axes.flat, panels):
        sns.lineplot(data=df[df["demographic"] == key], x="layer", y="acc", ax=ax)
        # y is shared, so this limit applies to every panel.
        ax.set_ylim(0, 105)
        ax.set_title(title)
        if label_y:
            ax.set_ylabel("Accuracy")
        ax.set_xlabel("Layer")
    fig.savefig(f"results/probes/{name}_probe.pdf", bbox_inches="tight")
    plt.show()

In [None]:
plt.rcParams.update({"font.size": 30})
model_results = {
    "Gemma": gemma_probe,
    "Llama": llama_probe,
    "OLMo": olmo_probe,
}
# Replace each model's raw probe results with a long frame of
# (demographic, layer, cv, acc in percent).
for model, results in list(model_results.items()):
    df = (
        pd.DataFrame(results)
        .T.stack()
        .reset_index()
        .rename(columns={"level_0": "demographic", "level_1": "layer"})
    )
    acc = (
        pd.DataFrame(df[0].tolist(), columns=["1", "2", "3", "4", "5"])
        .stack()
        .reset_index()
        .rename(columns={"level_1": "cv", 0: "acc"})
    )
    acc["acc"] = acc["acc"] * 100
    df = pd.merge(df, acc, left_index=True, right_on=["level_0"]).drop(
        columns=[0, "level_0"]
    )
    print(df[df["acc"] == 100].groupby("demographic")["layer"].min())
    model_results[model] = df

# One row of panels (one per model) per demographic.
for demographic in ["age", "gender", "socio-economic status", "race"]:
    fig, axes = plt.subplots(1, 3, sharex=False, sharey=True, figsize=(18, 5))
    for ax, (model, df) in zip(axes, model_results.items()):
        sns.lineplot(
            data=df[df["demographic"] == demographic], x="layer", y="acc", ax=ax
        )
        # y is shared, so this limit applies to every panel.
        ax.set_ylim(0, 105)
        ax.xaxis.set_major_locator(ticker.MultipleLocator(10))
        ax.set_title(model)
        ax.set_ylabel("Accuracy")
        ax.set_xlabel("Layer")

    slug = demographic.replace(" ", "-")
    fig.savefig(f"results/probes/{slug}_probe.pdf", bbox_inches="tight")
    plt.show()