In [None]:
# %%
import numpy as np
import matplotlib.pyplot as plt
from task import get_acts, get_acts_pca
from days_of_week_task import DaysOfWeekTask
from months_of_year_task import MonthsOfYearTask
import os
from adjustText import adjust_text
import dill as pickle
import matplotlib
from matplotlib.lines import Line2D
from sklearn.decomposition import PCA
import scipy.stats
import pandas as pd
from utils import BASE_DIR
import torch

# Every figure saved by the cells below is written into this directory.
os.makedirs("figs/paper_plots", exist_ok=True)

# Autograd is switched off globally for the rest of the session.
torch.set_grad_enabled(False)

# Handed to the task / model constructors in the cells below.
device, dtype = "cpu", "float32"

In [None]:
# Two PCA plots: Mistral weekdays (left panel) next to Llama months (right panel).

s = 2  # marker size passed to every scatter call in this cell

fig = plt.figure(figsize=(2.75, 1.25))
ax1, ax2 = (plt.subplot(1, 2, panel) for panel in (1, 2))

# Shared look for both panels: no ticks, no top/right border, and the two
# remaining spines drawn through the data origin.
for ax in (ax1, ax2):
    ax.tick_params(axis="both", which="major", labelsize=8)
    for side in ("top", "right"):
        ax.spines[side].set_visible(False)
    for side in ("left", "bottom"):
        ax.spines[side].set_position("zero")
    ax.set_xticks([])
    ax.set_yticks([])


# Left plot
task = DaysOfWeekTask(device, "mistral", dtype=dtype)
problems = task.generate_problems()
tokens = task.allowable_tokens
# Only the first element of what get_acts_pca returns is used; it is indexed as
# [problem, component] below (presumably the 2-D PCA projection — TODO confirm).
acts = get_acts_pca(task, layer=30, token=task.a_token, pca_k=2)[0]

texts = []
for token_index, token in enumerate(tokens):
    # Rows of the problems whose info[0] equals this token's index.
    indices = [
        row for row, problem in enumerate(problems) if problem.info[0] == token_index
    ]
    ax1.scatter(acts[indices, 0], acts[indices, 1], label=token, s=s)

    # Label the group with the first three characters of the token, anchored
    # at the group's first point.
    anchor = indices[0]
    texts.append(ax1.text(acts[anchor, 0], acts[anchor, 1], token[:3], fontsize=8))

adjust_text(texts, ax=ax1, force_text=(0.5, 1))

# Hand-tuned nudges applied on top of adjust_text's automatic placement.
# Move thursday up
thursday_text = texts[3]
thursday_pos = thursday_text.get_position()
thursday_text.set_position(thursday_pos + np.array([0, 1.2]))

# Move sunday down
sunday_text = texts[6]
sunday_pos = sunday_text.get_position()
sunday_text.set_position(sunday_pos + np.array([0, -0.5]))

ax1.set_xlim(-8, 8)
ax1.set_ylim(-8, 8)

# Right plot
task = MonthsOfYearTask(device, "llama", dtype=dtype)
problems = task.generate_problems()
tokens = task.allowable_tokens
acts = get_acts_pca(task, layer=3, token=task.a_token, pca_k=2)[0]
# Twelve evenly spaced rainbow colours, one per token index.
colorwheel = plt.cm.rainbow(np.linspace(0, 1 - 1 / 12, 12))

texts = []
for token_index, token in enumerate(tokens):
    indices = [
        row for row, problem in enumerate(problems) if problem.info[0] == token_index
    ]
    ax2.scatter(acts[indices, 0], acts[indices, 1], s=s, color=colorwheel[token_index])

    anchor = indices[0]
    texts.append(ax2.text(acts[anchor, 0], acts[anchor, 1], token[:3], fontsize=8))

adjust_text(texts, ax=ax2, force_text=(0.25, 0.5))

ax2.set_xlim(-0.6, 0.6)
ax2.set_ylim(-0.6, 0.6)

# Thin vertical rule between the panels, in figure coordinates. Its x position
# is taken from ax1's position *before* tight_layout runs below.
x_line = ax1.get_position().x1 + 0.025
separator = Line2D(
    [x_line, x_line],
    [0.05, 0.95],
    transform=fig.transFigure,
    color="grey",
    linewidth=0.5,
)
fig.add_artist(separator)

plt.tight_layout(pad=1)

plt.show()

# NOTE(review): 0 is not a documented value for bbox_inches; being falsy it
# presumably disables bounding-box cropping — confirm this is intended.
fig.savefig("figs/paper_plots/paper_pcas.pdf", bbox_inches=0)

In [None]:
font = {"size": 6}

matplotlib.rc("font", **font)

width = 0.4  # line width of the mean curves

# One pickled results file per panel, in the same order as `titles`.
layer_averages_files = [
    "figs/mistral_days_of_week/rotation_probing/a_cos_sin_all_layers_token_15_mean_logit_diffs.pkl",
    "figs/mistral_months_of_year/rotation_probing/a_cos_sin_all_layers_token_13_mean_logit_diffs.pkl",
    "figs/llama_days_of_week/rotation_probing/a_cos_sin_all_layers_token_14_mean_logit_diffs.pkl",
    "figs/llama_months_of_year/rotation_probing/a_cos_sin_all_layers_token_11_mean_logit_diffs.pkl",
]

fig, axs = plt.subplots(2, 2, figsize=(3.3, 2), sharex=True)
axs = axs.flatten()

max_layer = 25  # at most this many leading entries are drawn per curve

titles = ["Mistral Weekdays", "Mistral Months", "Llama Weekdays", "Llama Months"]

# Only records whose second field equals this value are kept.
intervention_pca_k = 5

# Position inside each record -> name of the intervention whose samples are
# stored at that position.
index_to_name = {
    2: "no_replace",
    3: "replace_circle",
    4: "replace_pca",
    5: "replace_all",
    6: "average_ablate",
    7: "zero_circle",
    8: "zero_everything_but_circle",
}

# Interventions that are actually drawn, with their line labels. The two
# "zero_*" interventions are aggregated below but never plotted.
plot_dict = {
    "no_replace": "No-op",
    "replace_all": "Patch layer",
    "replace_circle": "Patch circle",
    "replace_pca": "Patch pca",
    "average_ablate": "Average ablate",
}


def mean_confidence_interval(data, confidence=0.96):
    """Return the mean of ``data`` and a Student-t confidence interval for it.

    Args:
        data: sequence of numeric samples.
        confidence: two-sided confidence level of the interval.

    Returns:
        Tuple ``(mean, lower, upper)`` where ``lower``/``upper`` are the mean
        minus/plus the half-width ``sem * t.ppf((1 + confidence) / 2, n - 1)``.
    """
    a = 1.0 * np.array(data)
    n = len(a)
    m, se = np.mean(a), scipy.stats.sem(a)
    h = se * scipy.stats.t.ppf((1 + confidence) / 2.0, n - 1)
    return m, m - h, m + h


for i, layer_averages_file in enumerate(layer_averages_files):
    # NOTE(review): unpickling can execute arbitrary code; only load result
    # files produced by this project.
    # The context manager closes the handle that the bare open() call leaked.
    with open(layer_averages_file, "rb") as results_file:
        layer_averages = pickle.load(results_file)

    layers = []
    # Keep every second record, starting with the second one.
    # NOTE(review): the reason for skipping the others is not visible here.
    layer_averages = layer_averages[1::2]

    average_dict = {key: [] for key in index_to_name.values()}
    ci_dict = {key: [] for key in index_to_name.values()}

    for t in layer_averages:
        if t[1] == intervention_pca_k:
            # t[0] becomes the x coordinate (presumably the layer index).
            layers.append(t[0])
            for index in range(2, 9):
                mean, lower, upper = mean_confidence_interval(t[index])
                average_dict[index_to_name[index]].append(mean)
                ci_dict[index_to_name[index]].append((lower, upper))

    for key, label in plot_dict.items():
        axs[i].plot(
            layers[:max_layer],
            average_dict[key][:max_layer],
            label=label,
            linewidth=width,
            alpha=0.9,
        )

        # Plot ci
        lower = [ci[0] for ci in ci_dict[key][:max_layer]]
        upper = [ci[1] for ci in ci_dict[key][:max_layer]]
        axs[i].fill_between(layers[:max_layer], lower, upper, alpha=0.6)

    axs[i].set_title(titles[i])


def format_subplot(ax, grid_x=True):
    """Apply the shared panel styling to ``ax``.

    Hides the top and right spines and switches on a dashed, semi-transparent
    grid.

    Args:
        ax: axes-like object exposing ``spines`` and ``grid``.
        grid_x: when true the grid is drawn with the default axis setting;
            when false it is restricted to ``axis="y"``.
    """
    for side in ("top", "right"):
        ax.spines[side].set_visible(False)

    grid_style = {"linestyle": "--", "alpha": 0.4}
    if not grid_x:
        grid_style["axis"] = "y"
    ax.grid(**grid_style)


for ax in axs:
    format_subplot(ax)

xlabel = fig.supxlabel("Layer")
ylabel = fig.supylabel("Average logit diff")

# Move xlabel up
xlabel.set_position((0.55, 0.06))

# Move ylabel right
ylabel.set_position((0.01, 0.5))

# Hand-built legend proxies: one short, marker-less line per plotted
# intervention. Colours are the first five tab10 entries, which match the
# curves as long as the default property cycle is in effect.
colorwheel = plt.cm.tab10(np.linspace(0, 1, 10))
linesize = 0.01
markersize = 0
legend_labels = ["No-op", "Patch layer", "Patch circle", "Patch PCA", "Average ablate"]
legend_elements = [
    Line2D(
        [0, linesize],
        [0, 0],
        marker="o",
        color=colorwheel[color_index],
        label=legend_label,
        markerfacecolor=colorwheel[color_index],
        markersize=markersize,
    )
    for color_index, legend_label in enumerate(legend_labels)
]

leg = fig.legend(
    handles=legend_elements,
    loc="upper center",
    ncol=3,
    bbox_to_anchor=(0.52, 0.06),
    labelspacing=0,
    handletextpad=0.3,
    columnspacing=1,
    handlelength=0.8,
)
# `Legend.legendHandles` was deprecated in Matplotlib 3.7 and removed in 3.9 in
# favour of `legend_handles`; fall back to the old name on older releases.
legend_lines = leg.legend_handles if hasattr(leg, "legend_handles") else leg.legendHandles
for legobj in legend_lines:
    legobj.set_linewidth(1.5)

# Grey rules separating the four panels, positioned in figure coordinates.
fig.add_artist(
    Line2D(
        [0.53, 0.53], [0.15, 0.99], transform=fig.transFigure, color="grey", linewidth=1
    )
)
fig.add_artist(
    Line2D([0.1, 1], [0.58, 0.58], transform=fig.transFigure, color="grey", linewidth=1)
)


plt.tight_layout(pad=0.7)

# Write the PDF before the figure is shown and closed, so the save never
# operates on a figure pyplot has already torn down.
fig.savefig(
    "figs/paper_plots/combined_intervention.pdf",
    bbox_inches="tight",
)

plt.show()
plt.close()

In [None]:
from intervene_in_middle_of_circle import get_points

# from matplotlib.colors import tab20

font = {"size": 5}

matplotlib.rc("font", **font)

s = 0.1  # marker size of the dense point cloud


task = DaysOfWeekTask(device, model_name="mistral", dtype=dtype)
# The values below are only used to build the path of the saved logits file.
layer = 5
token = task.a_token
durations = range(2, 6)
circle_letter = "a"
pca_k = 5
circle_size = len(task.allowable_tokens)


# One panel per duration (2, 3, 4 and 5).
fig, axs = plt.subplots(2, 2, figsize=(2.2, 2.2))

for ax, duration in zip(axs.flatten(), durations):
    filename = f"figs/{task.name}/varying_circle/logits_{layer}_{token}_{pca_k}_{duration}_{circle_letter}.npy"
    all_logits = np.load(filename)
    # NOTE(review): get_points() is called once per panel; it could be hoisted
    # if it is deterministic — not verifiable from this file.
    all_points, angles, radius_vals = get_points()
    # Index of the largest logit in each row of the loaded array.
    best_a = np.argmax(all_logits, axis=1)
    for i in range(circle_size):
        winners = best_a == i
        ax.scatter(
            all_points[winners, 0],
            all_points[winners, 1],
            label=task.allowable_tokens[i],
            s=s,
        )
    ax.set_title(f"Task Duration = {duration} Days")
    ax.set_xlim(-2, 2)
    ax.set_ylim(-2, 2)
    ax.set_aspect("equal", adjustable="box")

# The first panel's handles and labels are reused for the shared legend.
handles, labels = axs[0][0].get_legend_handles_labels()

fig = plt.gcf()

fig.tight_layout()

# Put legend below figure
lgnd = fig.legend(
    handles,
    labels,
    loc="upper center",
    ncol=4,
    bbox_to_anchor=(0.5, 0.02),
    fontsize=5,
    frameon=False,
    columnspacing=0,
)
# Enlarge the legend markers, since the plotted points (s=0.1) are too small to
# read. `Legend.legendHandles` was deprecated in Matplotlib 3.7 and removed in
# 3.9 in favour of `legend_handles`; the public set_sizes() replaces writing to
# the private `_sizes` attribute.
legend_markers = (
    lgnd.legend_handles if hasattr(lgnd, "legend_handles") else lgnd.legendHandles
)
for legend_marker in legend_markers[:circle_size]:
    legend_marker.set_sizes([10])

plt.show()

fig.savefig(
    "figs/paper_plots/mistral_highest_logit_day_after_intervention.pdf",
    bbox_extra_artists=(lgnd,),
    bbox_inches="tight",
)

In [None]:
font = {"size": 7}

matplotlib.rc("font", **font)

s = 2  # marker size

fig = plt.figure(figsize=(1.65, 1.5))
ax = plt.gca()

task = DaysOfWeekTask(device, model_name="mistral", dtype=dtype)
acts = get_acts(task, layer_fetch=25, token_fetch=task.before_c_token)

problems = task.generate_problems()
# The first three entries of each problem's `info`, one array per position.
a = np.array([problem.info[0] for problem in problems])
b = np.array([problem.info[1] for problem in problems])
c = np.array([problem.info[2] for problem in problems])

# Design matrix with 14 indicator columns: a == 0..6 followed by b == 1..7.
explaining_variables = np.array(
    [a == value for value in range(7)] + [b == value for value in range(1, 8)]
).T

print(explaining_variables.shape)

# Regress the activations on the indicators and keep what they do not explain.
# Whenever every a lies in 0..6 and every b in 1..7, each indicator group sums
# to the all-ones column, so the matrix is rank deficient. rcond=None selects
# NumPy's documented machine-precision cutoff explicitly, which avoids the
# FutureWarning and the version-dependent default on NumPy < 2.0.
least_squares_sol = np.linalg.lstsq(explaining_variables, acts, rcond=None)[0]

residuals = acts - explaining_variables @ least_squares_sol

pca = PCA(n_components=2)
pca.fit(residuals)

print(pca.explained_variance_ratio_)

projected = pca.transform(residuals)

# One colour per value of c, drawn in the plane of the first two components.
for day_of_week in range(7):
    selected = c == day_of_week
    ax.plot(
        projected[selected, 0],
        projected[selected, 1],
        "o",
        label=task.allowable_tokens[day_of_week],
        markersize=s,
    )

ax.tick_params(axis="both", which="major", labelsize=5)
ax.spines["top"].set_visible(False)
ax.spines["right"].set_visible(False)
# Draw the remaining two spines through the data origin.
ax.spines["left"].set_position("zero")
ax.spines["bottom"].set_position("zero")

# Hand-placed label coordinates (data units), one per token index.
positions = [
    [-4.5, 0.5],
    [-3, 2.2],
    [0.8, 2.3],
    [2.5, 0.7],
    [2.3, -1.0],
    [0.5, -2.6],
    [-3.5, -2.6],
]

# Add text to plot
texts = []
for i in range(7):
    x, y = positions[i]
    texts.append(ax.text(x, y, f"{task.allowable_tokens[i][:3]}", fontsize=6))

plt.show()

fig.savefig(
    "figs/paper_plots/mistral_residuals_c_pca.pdf",
    bbox_inches="tight",
)

In [None]:
# Performance table: for every (task, model) pair, print how many rows of the
# saved results have best_token equal to ground_truth.

for task_name in ("days_of_week", "months_of_year"):
    for model_name in ("mistral", "llama"):
        results_path = f"{BASE_DIR}/{model_name}_{task_name}/results.csv"
        results = pd.read_csv(results_path, skipinitialspace=True)
        number_correct = results["best_token"] == results["ground_truth"]
        print(task_name, model_name, np.sum(number_correct))

# GPT 2
from transformer_lens import HookedTransformer

model = HookedTransformer.from_pretrained("gpt2", device=device, dtype=dtype)

# Same count as the table above, but computed by running GPT-2 directly on
# every generated problem.
for task_name, task_cls in (
    ("days_of_week", DaysOfWeekTask),
    ("months_of_year", MonthsOfYearTask),
):
    task = task_cls(device, model_name="gpt2", dtype=dtype)
    problems = task.generate_problems()
    # Despite the name, these are the values returned by to_single_token for
    # each allowed answer; they are used to index into the logits below.
    answer_logits = [model.to_single_token(token) for token in task.allowable_tokens]
    num_correct = 0
    for problem in problems:
        # [0][-1] presumably selects the first batch item and the final
        # position — TODO confirm against the model's output layout.
        logits = model(problem.prompt).cpu()[0][-1]
        # Best-scoring answer, restricted to the allowed answers.
        top_from_answer_logits = np.argmax(logits[answer_logits])
        correct_answer = problem.info[2]
        num_correct += 1 if top_from_answer_logits == correct_answer else 0
    print(task_name, "gpt2", num_correct)