In [None]:
import argparse
import itertools
import os
import pickle
from collections.abc import Generator
from dataclasses import dataclass, field
from datetime import datetime
from typing import Any, List, Optional, Tuple, Union
from scipy.spatial.distance import pdist, squareform
from fastcluster import linkage

import numpy as np
import numpy.typing as npt
import torch
import torch.multiprocessing as mp
import torch.nn as nn
import torch.optim as optim
from matplotlib import pyplot as plt
from scipy.stats import ortho_group
from torchtyping import TensorType
from tqdm import tqdm

# Load computed matrices

In [None]:
import seaborn as sns

# Build the plotting palette from seaborn's "magma_r" colours.
# (Alternative left for reference: the seaborn default, sns.color_palette().)
palette = sns.color_palette("magma_r")
display(palette)  # show the full palette in the cell output before thinning it

# Keep only the entries at positions 0, 1, 3 and 5.
palette = [palette[pos] for pos in (0, 1, 3, 5)]

# Use matplotlib's default style sheet for every figure drawn afterwards.
plt.style.use("default")

In [None]:
# Load all matrices for subplots.
#
# Every file is loaded with map_location="cpu" so that deserialisation does not
# require the device the tensors were saved from; the result is CPU-resident
# without a separate .cpu() call.
pre = "metrics_22"
p_name = "1"

append = "^p"  # "^p" if lp^p, else ""

# Toy-data matrices: "gf" files feed the ground-truth-vs-SAE panels, "ff" files
# the SAE-vs-SAE panels (see the plotting cells that consume these names).
sorted_gf_cosim = torch.load(f"{pre}/gfl{p_name}{append}.pt", map_location="cpu")
sorted_ff_cosim = torch.load(f"{pre}/ffl{p_name}{append}.pt", map_location="cpu")
sorted_jaccard = torch.load(f"{pre}/gfl{p_name}{append}_jaccard.pt", map_location="cpu")
sorted_auto_jaccard = torch.load(f"{pre}/ffl{p_name}{append}_jaccard.pt", map_location="cpu")
sorted_gf_corr = torch.load(f"{pre}/gfl{p_name}{append}_corr.pt", map_location="cpu")
sorted_ff_corr = torch.load(f"{pre}/ffl{p_name}{append}_corr.pt", map_location="cpu")

# The language-model matrices are read from a different directory and with an
# empty suffix, so `append` and `pre` are re-bound here.
# NOTE: later cells read `append` and `p_name` when naming their output files,
# so they see the values left behind by this cell (append == "").
append = ""

# NOTE(review): absolute, machine-specific path — consider a configurable data dir.
pre = "/root/sparsify/sorted_matrices"
sorted_gpt2_cosim = torch.load(f"{pre}/gpt2_sorted_cosim_matrix_l{p_name}{append}.pt", map_location="cpu")
sorted_gpt2_jaccard = torch.load(f"{pre}/gpt2_sorted_jaccard_l{p_name}{append}.pt", map_location="cpu")
sorted_gpt2_corr = torch.load(f"{pre}/gpt2_sorted_corr_l{p_name}{append}.pt", map_location="cpu")

sorted_pythia14m_cosim = torch.load(f"{pre}/pythia14m_sorted_cosim_matrix_l{p_name}{append}.pt", map_location="cpu")
sorted_pythia14m_jaccard = torch.load(f"{pre}/pythia14m_sorted_jaccard_l{p_name}{append}.pt", map_location="cpu")
sorted_pythia14m_corr = torch.load(f"{pre}/pythia14m_sorted_corr_l{p_name}{append}.pt", map_location="cpu")

In [None]:
import matplotlib.pyplot as plt


def plot_similarity_grid(matrices, col_titles, out_path):
    """Draw a 2x3 grid of similarity heatmaps with one shared colorbar and save it.

    Args:
        matrices: Six 2-D matrices in row-major panel order. The first three
            fill the "Cosine Similarity" row, the last three the
            "Correlation" row.
        col_titles: Three column titles, shown above the top row of panels.
        out_path: Destination passed to ``Figure.savefig``.

    Returns:
        The matplotlib ``Figure`` that was drawn and saved.

    Raises:
        ValueError: If ``matrices`` does not hold exactly six entries or
            ``col_titles`` does not hold exactly three.
    """
    if len(matrices) != 6:
        raise ValueError(f"expected 6 matrices for a 2x3 grid, got {len(matrices)}")
    if len(col_titles) != 3:
        raise ValueError(f"expected 3 column titles, got {len(col_titles)}")

    row_titles = ["Cosine Similarity", "Correlation"]
    # Per-panel y-axis label, in the same row-major order as `matrices`.
    y_axis_labels = ["Ground Truth", "SAE", "SAE"] * 2

    fig, axs = plt.subplots(2, 3, figsize=(10, 10), dpi=300)

    # Column titles go on the top row only.
    for ax, col_title in zip(axs[0], col_titles):
        ax.set_title(col_title)

    for ax, matrix, y_label in zip(axs.flat, matrices, y_axis_labels):
        n_rows, n_cols = matrix.shape
        # aspect = n_cols / n_rows rescales the pixels so that a non-square
        # matrix still fills a square panel.
        im = ax.imshow(matrix, cmap="PiYG", vmin=-1, vmax=1, aspect=n_cols / n_rows)
        ax.set_xlabel("SAE")
        ax.set_ylabel(y_label)

        # Four ticks per axis at 0, 1/3, 2/3 and the full size, labelled with
        # the truncated integer value of each position.
        x_ticks = [0, n_cols / 3, 2 * n_cols / 3, n_cols]
        y_ticks = [0, n_rows / 3, 2 * n_rows / 3, n_rows]
        ax.set_xticks(x_ticks)
        ax.set_xticklabels([f"{int(tick)}" for tick in x_ticks])
        ax.set_yticks(y_ticks)
        ax.set_yticklabels([f"{int(tick)}" for tick in y_ticks])

    # Row titles are placed as free-floating figure text at the left edge
    # (y = 0.72 for the first row, y = 0.28 for the second).
    for i, label in enumerate(row_titles):
        fig.text(0.0, 1/2 - (2*i-1)*0.22, label, va='center', ha='center', rotation='vertical', size='large')

    # All panels share the same colour limits, so a single colorbar driven by
    # the last image is placed in a manually positioned axes to the right.
    fig.subplots_adjust(right=0.8, wspace=0.5, hspace=0.5)
    cbar_ax = fig.add_axes([1, 0.15, 0.03, 0.7])  # Adjust these values as needed for your layout
    cbar = fig.colorbar(im, cax=cbar_ax)
    cbar.ax.set_ylabel('Cosine Similarity / Correlation', rotation=270, labelpad=15)

    fig.tight_layout()
    fig.savefig(out_path, bbox_inches='tight')
    return fig


# Number of leading rows/columns of each language-model matrix that is shown.
LM_FEATURES_SHOWN = 100

# The two toy-data columns are identical in both figures; only the third
# (language-model) column and the output filename change.
toy_col_titles = ["Toy Data Ground Truth vs SAE Features", "Toy Data SAE Features"]
lm_figures = [
    ("GPT2 SAE Features", sorted_gpt2_cosim, sorted_gpt2_corr,
     f'2by3_l{p_name}{append}.png'),
    ("Pythia14m SAE Features", sorted_pythia14m_cosim, sorted_pythia14m_corr,
     f'2by3_l{p_name}{append}_pythia14m.png'),
]

for lm_title, lm_cosim, lm_corr, out_path in lm_figures:
    plot_similarity_grid(
        matrices=[
            sorted_gf_cosim,
            sorted_ff_cosim,
            lm_cosim[:LM_FEATURES_SHOWN, :LM_FEATURES_SHOWN],
            sorted_gf_corr,
            sorted_ff_corr,
            lm_corr[:LM_FEATURES_SHOWN, :LM_FEATURES_SHOWN],
        ],
        col_titles=[*toy_col_titles, lm_title],
        out_path=out_path,
    )

In [None]:
# Preview the SAE-vs-SAE Jaccard index matrix (nothing is written to disk here).
fig, ax = plt.subplots()
# NOTE(review): colour limits stay at [-1, 1]; a Jaccard index is non-negative,
# so presumably only half of the diverging colormap is used — confirm intent.
im = ax.imshow(sorted_auto_jaccard, cmap="PiYG", vmin=-1, vmax=1)
cbar = fig.colorbar(im, ax=ax)
# The colorbar names the metric actually plotted, matching the title.
cbar.ax.set_ylabel('Jaccard Index', rotation=270, labelpad=15)
ax.set_xlabel("Sorted SAE features")
ax.set_ylabel("Sorted SAE features")
ax.set_title(f"SAE Features Jaccard Index (L{p_name})");

In [None]:
# Plot the SAE-vs-SAE Jaccard index matrix, save the figure, and persist the
# matrix itself under metrics_<seed>/.
# NOTE(review): `cfg` is not defined in any cell visible in this notebook —
# confirm it is provided before this cell runs.
fig, ax = plt.subplots()
im = ax.imshow(sorted_auto_jaccard, cmap="PiYG", vmin=-1, vmax=1)
cbar = fig.colorbar(im, ax=ax)
# The colorbar names the metric actually plotted, matching the title.
cbar.ax.set_ylabel('Jaccard Index', rotation=270, labelpad=15)
ax.set_xlabel("Sorted SAE features")
ax.set_ylabel("Sorted SAE features")
ax.set_title(f"SAE Features Jaccard Index (L{p_name})")

# Make sure the image directory exists before writing into it.
image_dir = f"images_{cfg.seed}"
os.makedirs(image_dir, exist_ok=True)
fig.savefig(f"{image_dir}/ffl{p_name}_jaccard.png")

# Create the same (relative) directory that torch.save writes to, so the save
# succeeds regardless of the current working directory.
pre = f"metrics_{cfg.seed}"
os.makedirs(pre, exist_ok=True)
torch.save(sorted_auto_jaccard, f"{pre}/ffl{p_name}_jaccard.pt")