In [None]:
import os
import pickle
import json

import torch
import einops
from tqdm import tqdm
import matplotlib.pyplot as plt
import torch.nn.functional as F

from circuits.dictionary_learning.dictionary import AutoEncoder, AutoEncoderNew, GatedAutoEncoder
import circuits.utils as utils
import circuits.analysis as analysis
import circuits.eval_sae_as_classifier as eval_sae


In [None]:
def get_results_filename(othello: bool) -> str:
    """Return the name of the pickled results file for the selected game.

    Args:
        othello: Truthy for the Othello results file, falsy for the chess one.

    Returns:
        The bare filename (no directory component).
    """
    return (
        'indexing_None_n_inputs_1000_results.pkl'
        if othello
        else 'indexing_find_dots_indices_n_inputs_1000_results.pkl'
    )

def get_layer_data(base_path: str, othello: bool):
    """Load the pickled results for one autoencoder and derive its feature labels.

    The results file name is chosen by ``get_results_filename(othello)`` and is
    looked up inside ``base_path``. The loaded object is moved to the
    module-level ``device`` and passed to ``analysis.analyze_results_dict``
    with a significance threshold of 10 and all saving/printing disabled.

    Args:
        base_path: Directory containing the results pickle.
        othello: Selects the Othello (True) or chess (False) results file.

    Returns:
        The first element returned by ``analysis.analyze_results_dict``
        (callers in this notebook index it like a dict, e.g.
        ``feature_labels['alive_features']``).

    Raises:
        FileNotFoundError: If the results pickle does not exist.
    """
    results_path = os.path.join(base_path, get_results_filename(othello))

    # NOTE(security): pickle.load can execute arbitrary code; only point this
    # at results files produced by a trusted run.
    with open(results_path, "rb") as f:
        results = pickle.load(f)
    results = utils.to_device(results, device)

    # Only the labels are needed here; the accompanying stats are discarded.
    feature_labels, _ = analysis.analyze_results_dict(
        results,
        "",
        device,
        save_results=False,
        verbose=False,
        print_results=False,
        significance_threshold=10,
    )

    return feature_labels


def get_dataset(othello: bool, context_length: int):
    """Build the training-split game generator for Othello or chess.

    Args:
        othello: Truthy selects the Othello dataset (non-streaming); falsy
            selects the chess dataset (streaming), which additionally needs
            the ``models/meta.pkl`` metadata file.
        context_length: Forwarded unchanged to the dataset helper.

    Returns:
        Whatever the corresponding ``utils.*_hf_dataset_to_generator`` helper
        returns.
    """
    if othello:
        return utils.othello_hf_dataset_to_generator(
            "adamkarvonen/othello_45MB_games",
            context_length=context_length,
            split="train",
            streaming=False,
        )

    # Chess path: the generator helper needs the pickled metadata.
    with open("models/meta.pkl", "rb") as meta_file:
        meta = pickle.load(meta_file)

    return utils.chess_hf_dataset_to_generator(
        "adamkarvonen/chess_sae_text",
        meta,
        context_length=context_length,
        split="train",
        streaming=True,
    )


def rc_to_square_notation(row, col):
    """Convert a (row, col) board index into letter+number square notation.

    The column selects a letter from A-H and the number is ``8 - row``, so
    row 0 maps to rank 8 and row 7 maps to rank 1.
    """
    file_letter = "ABCDEFGH"[col]
    rank_number = 8 - row
    return f"{file_letter}{rank_number}"


def collect_activations(
    games_bL: torch.Tensor,
    context_length: int,
    ae_bundle: utils.AutoEncoderBundle,
    feature_labels: dict,
    device,
    total_games: int,
    batch_size: int,
) -> torch.Tensor:
    """Collect activations for every game and flatten them per feature.

    Games are processed in slices of ``batch_size``; each slice is passed to
    ``eval_sae.collect_activations_batch`` together with
    ``feature_labels['alive_features']`` and the result is written into a
    preallocated float32 buffer.

    Args:
        games_bL: Games to process; sliced along its first axis.
        context_length: Size of the second axis of the activation buffer.
        ae_bundle: Bundle forwarded to ``eval_sae.collect_activations_batch``.
        feature_labels: Must contain the key ``'alive_features'``.
        device: Device on which the activation buffer is allocated.
        total_games: Size of the first axis of the activation buffer.
        batch_size: Number of games per call to the batch helper.

    Returns:
        A 2-D tensor with one row per alive feature and
        ``total_games * context_length`` columns.
    """
    alive_features_F = feature_labels['alive_features']
    num_alive = alive_features_F.shape[0]

    activations_bLF = torch.zeros(
        (total_games, context_length, num_alive), dtype=torch.float32, device=device
    )

    for start in tqdm(range(0, total_games, batch_size)):
        stop = start + batch_size
        # The helper also returns the tokens, which are not needed here.
        activations_FBL, _ = eval_sae.collect_activations_batch(
            ae_bundle, games_bL[start:stop], alive_features_F
        )
        activations_bLF[start:stop] = einops.rearrange(activations_FBL, "F B L -> B L F")

    # Merge the game and position axes so each feature becomes one long row.
    return einops.rearrange(activations_bLF, "B L F -> F (B L)")


# --- Notebook configuration -------------------------------------------------
# Key into the feature-label dicts; consumed by compare_feature_labels.
func_name = "games_batch_to_valid_moves_BLRRC"
# Selects the Othello results file / dataset in the helpers above.
othello = True
# Module-level device; get_layer_data reads this global directly.
device = "cuda"
# device = 'cpu'
# Disable autograd globally for the whole notebook.
torch.set_grad_enabled(False)

layer1 = 5
layer2 = 5

ae_path1 = f"../autoencoders/othello_mlp_acts_identity_aes/layer_{layer1}/"
# ae_path2 = f"../autoencoders/othello_mlp_acts_identity_aes/layer_{layer2}/"
ae_path2 = f"../autoencoders/all_layers_othello_p_anneal_0530/layer_{layer2}/trainer0/"

# Sanity check: each path must point at the directory of its configured layer.
assert(f"layer_{layer1}" in ae_path1)
assert(f"layer_{layer2}" in ae_path2)


# The bundles are created with an empty data source and batch_size=1.
ae_bundle1 = utils.get_ae_bundle(
    ae_path1, device, data=[], batch_size=1
)
ae_bundle2 = utils.get_ae_bundle(
    ae_path2, device, data=[], batch_size=1
)

# Drop the bundles' buffers and release cached CUDA memory
# (presumably to free GPU memory; the buffers are not used in this notebook).
# NOTE(review): torch.cuda.empty_cache() is called even if device is 'cpu'.
ae_bundle1.buffer = None
ae_bundle2.buffer = None
torch.cuda.empty_cache()

# context_length is taken from bundle 1 only.
# NOTE(review): assumes bundle 2 uses the same context length - confirm.
context_length = ae_bundle1.context_length
ae1 = ae_bundle1.ae
ae2 = ae_bundle2.ae

data = get_dataset(othello, context_length)

# Index into the first axis of the label tensors in compare_feature_labels.
threshold1 = 3
threshold2 = 3

feature_labels1 = get_layer_data(ae_path1, othello)
feature_labels2 = get_layer_data(ae_path2, othello)

In [None]:
# Load precomputed index lists from JSON (path is relative to the working directory).
json_filename = "hpc_hrc_same_square_indexes_dict.json"

with open(json_filename, "r") as f:
    json_data = json.load(f)

# Show which index groups the file provides.
print(json_data.keys())

# Indices from the 'high_precision_and_recall' group; later cells use them to
# select rows of the MLP weight matrix and of activations1.
hpr_indices = json_data['high_precision_and_recall']
hpr_indices_tensor = torch.tensor(hpr_indices)

In [None]:
def get_all_cos_sims_for_single_feature(x_vector_D: torch.Tensor, y_vectors_FD: torch.Tensor) -> torch.Tensor:
    """Cosine similarity between one vector and every row of a matrix.

    Args:
        x_vector_D: A single vector; it is broadcast against ``y_vectors_FD``.
        y_vectors_FD: Matrix whose rows are compared with ``x_vector_D``.

    Returns:
        A tensor with one cosine similarity per row of ``y_vectors_FD``
        (not a Python list).
    """
    # Reduce over the last axis explicitly instead of relying on the default
    # dim=1, which only coincides with the last axis for 2-D inputs.
    return F.cosine_similarity(x_vector_D, y_vectors_FD, dim=-1)

def get_all_max_cos_sims(x_vectors_FD: torch.Tensor, y_vectors_FD: torch.Tensor) -> torch.Tensor:
    """For each row of ``x_vectors_FD``, the best cosine similarity to any row of ``y_vectors_FD``.

    Both matrices are L2-normalised along their last axis, so a single matrix
    product yields every pairwise cosine similarity at once instead of looping
    over the rows of ``x_vectors_FD`` in Python.

    Args:
        x_vectors_FD: Query vectors, one per row.
        y_vectors_FD: Candidate vectors, one per row; must have at least one row.

    Returns:
        A 1-D tensor with one maximum cosine similarity per row of
        ``x_vectors_FD``. An ``x_vectors_FD`` with zero rows yields an empty
        tensor.
    """
    # eps matches the default used by F.cosine_similarity, so zero vectors
    # produce a similarity of 0 rather than NaN.
    x_unit_FD = F.normalize(x_vectors_FD, dim=-1, eps=1e-8)
    y_unit_FD = F.normalize(y_vectors_FD, dim=-1, eps=1e-8)

    pairwise_cos_sims = x_unit_FD @ y_unit_FD.T
    return pairwise_cos_sims.max(dim=1).values

# MLP output weights of the configured layer; rows are selected below with the
# neuron indices loaded from JSON.
x_vectors_FD = ae_bundle2.model.blocks[layer1].mlp.W_out
filtered_x_vectors_FD = x_vectors_FD[hpr_indices_tensor]

# The comparison below uses the SAE encoder weights (as the plot title says).
# The decoder weights are only loaded so their shape can be printed; the
# previous decoder-derived y_vectors_FD was overwritten before use.
y_vectors_DF = ae_bundle2.ae.decoder.weight
y_vectors_FD = ae_bundle2.ae.encoder.weight
print(x_vectors_FD.shape)
print(y_vectors_DF.shape)
print(y_vectors_FD.shape)

all_mlp_sims = get_all_max_cos_sims(x_vectors_FD, y_vectors_FD)

hpr_cos_sims = get_all_max_cos_sims(filtered_x_vectors_FD, y_vectors_FD)

plt.title("MLP Decoder vs SAE Encoder")
plt.xlabel("Cosine Similarity")
plt.ylabel("Density")

# Legend counts are derived from the tensors so they stay correct if the model
# width or the index list changes.
# Histogram over every MLP neuron, semi-transparent so the overlay stays visible.
plt.hist(
    all_mlp_sims.cpu().numpy(),
    bins=20,
    alpha=0.5,
    label=f'All {x_vectors_FD.shape[0]} MLP Neurons',
    density=True,
)

# Histogram over the selected neurons, drawn on top of the first one.
plt.hist(
    hpr_cos_sims.cpu().numpy(),
    bins=20,
    alpha=0.5,
    label=f'{filtered_x_vectors_FD.shape[0]} Legal Move MLP Neurons',
    density=True,
)

plt.legend()
plt.savefig("mlp_ae_encoder_cos_sims.png")
plt.show()

In [None]:
# Number of games pulled from the generator per step, and overall.
batch_size = 50
total_games = batch_size * 10

games_bL = torch.zeros((total_games, context_length), dtype=torch.long, device=device)

# Fill the game buffer one batch at a time from the dataset generator.
for batch_start in range(0, total_games, batch_size):
    game_batch_BL = torch.tensor(
        [next(data) for _ in range(batch_size)], device=device
    )
    games_bL[batch_start:batch_start + batch_size] = game_batch_BL

# Run the same games through both autoencoder bundles.
activations1 = collect_activations(
    games_bL, context_length, ae_bundle1, feature_labels1, device, total_games, batch_size
)
activations2 = collect_activations(
    games_bL, context_length, ae_bundle2, feature_labels2, device, total_games, batch_size
)

In [None]:
def pearson_corr(x, y):
    """Pearson correlation along the last axis, with broadcasting.

    Both inputs are mean-centred over their last axis; the correlation is the
    summed product of the centred values divided by the square root of the
    product of their summed squares. Wherever that denominator is exactly
    zero (at least one input is constant), the result is 0 instead of NaN.

    Args:
        x: Tensor whose last axis holds the samples.
        y: Tensor broadcastable against ``x``, samples on the last axis.

    Returns:
        Tensor of correlations with the last axis reduced away.
    """
    x_centered = x - x.mean(dim=-1, keepdim=True)
    y_centered = y - y.mean(dim=-1, keepdim=True)

    numerator = torch.sum(x_centered * y_centered, dim=-1)
    x_sum_sq = torch.sum(x_centered * x_centered, dim=-1)
    y_sum_sq = torch.sum(y_centered * y_centered, dim=-1)
    denominator = torch.sqrt(x_sum_sq * y_sum_sq)

    # Zero-variance inputs would give 0/0; report them as uncorrelated.
    return torch.where(denominator == 0, torch.zeros_like(numerator), numerator / denominator)

# def get_correlation_for_activation(x_activations_Fb: torch.Tensor, x_activation_index: int, y_activations_Fb: torch.Tensor) -> torch.Tensor:
#     x_activation_b = x_activations_Fb[x_activation_index]
#     correlations = pearson_corr(x_activation_b.unsqueeze(0), y_activations_Fb)
#     return correlations

# def get_correlation_per_feature(activations1: torch.Tensor, activations2: torch.Tensor) -> list[float]:
#     correlations = pearson_corr(activations1.unsqueeze(1), activations2.unsqueeze(0))
#     best_correlations, _ = torch.max(correlations, dim=1)
#     best_correlations = best_correlations[best_correlations != 0]
#     return best_correlations.tolist()

def get_correlation_for_activation(x_activations_Fb: torch.Tensor, x_activation_index: int, y_activations_Fb: torch.Tensor) -> torch.Tensor:
    """Correlate one row of ``x_activations_Fb`` with every row of ``y_activations_Fb``.

    Args:
        x_activations_Fb: Activation matrix to pick the query row from.
        x_activation_index: Index of the query row.
        y_activations_Fb: Activation matrix whose rows are the candidates.

    Returns:
        The result of ``pearson_corr`` between the selected row (given a
        leading axis of size 1) and ``y_activations_Fb``.
    """
    query_row_1b = x_activations_Fb[x_activation_index].unsqueeze(0)
    return pearson_corr(query_row_1b, y_activations_Fb)

def get_correlation_per_feature(activations1: torch.Tensor, activations2: torch.Tensor, batch_size: int = 1000) -> list[float]:
    """Best Pearson correlation of each ``activations1`` row against all ``activations2`` rows.

    Rows of ``activations1`` are processed in chunks of ``batch_size``. Each
    chunk is broadcast against every row of ``activations2`` inside
    ``pearson_corr``, so the intermediate tensors grow with
    ``batch_size * activations2.shape[0] * samples``; keep ``batch_size``
    small for large inputs.

    Args:
        activations1: Query activations, one feature per row.
        activations2: Candidate activations, one feature per row.
        batch_size: Number of query rows correlated per step; must be positive.

    Returns:
        The per-row maximum correlations as Python floats. Rows whose maximum
        is exactly 0 are dropped, so the list can be shorter than
        ``activations1.shape[0]`` and is not index-aligned with it.

    Raises:
        ValueError: If ``batch_size`` is not positive.
    """
    if batch_size <= 0:
        # A negative step would otherwise yield an empty range and silently
        # return an empty list.
        raise ValueError(f"batch_size must be a positive integer, got {batch_size}")

    num_activations1 = activations1.shape[0]
    best_correlations = []

    for start_idx in tqdm(range(0, num_activations1, batch_size)):
        # Slicing past the end is clamped, so the last chunk may be shorter.
        chunk = activations1[start_idx:start_idx + batch_size]

        correlations = pearson_corr(chunk.unsqueeze(1), activations2.unsqueeze(0))
        batch_best_correlations, _ = torch.max(correlations, dim=1)
        nonzero_best = batch_best_correlations[batch_best_correlations != 0]
        best_correlations.extend(nonzero_best.tolist())

    return best_correlations


In [None]:


# Select the rows of activations1 at the indices loaded from JSON.
# NOTE(review): collect_activations allocates one row per entry of
# feature_labels1['alive_features'], so these positions only equal neuron ids
# if every neuron is alive - confirm.
hpr_mlp_activations1 = activations1[hpr_indices]
print(hpr_mlp_activations1.shape)

In [None]:
# Same computation as best_correlations_hpr below, but one row at a time.
# NOTE(review): best_correlations_hpr2 is not read by any later visible cell.
best_correlations_hpr2 = get_correlation_per_feature(hpr_mlp_activations1, activations2, batch_size=1)

In [None]:

# Max correlation against activations2 for the selected rows only ...
best_correlations_hpr = get_correlation_per_feature(hpr_mlp_activations1, activations2, batch_size=4)

# ... and for every row of activations1.
best_correlations_all = get_correlation_per_feature(activations1, activations2, batch_size=4)

In [None]:


# NOTE(review): the title hard-codes "Layer 5" and the legend hard-codes the
# counts 2048 / 150; get_correlation_per_feature drops entries that are exactly
# 0, so the histograms may contain fewer values than the labels state.
plt.title("Max Activation Correlation between Layer 5 MLP Neurons and SAE Features")
plt.xlabel("Pearson Correlation")
plt.ylabel("Density")

# Histogram of the best correlations over all rows of activations1,
# semi-transparent so the overlay stays visible.
plt.hist(best_correlations_all, bins=20, alpha=0.5, label='All 2048 MLP Neurons', density=True)

# Histogram of the best correlations for the selected (hpr) rows, drawn on top.
plt.hist(best_correlations_hpr, bins=20, alpha=0.5, label='150 Legal Move MLP Neurons', density=True)


plt.legend()

plt.savefig("mlp_ae_activation_correlation.png")

plt.show()


In [None]:


def get_correlations_per_labeled_feature(
    activations1: torch.Tensor,
    activations2: torch.Tensor,
    features1_F: torch.Tensor,
    features2_F: torch.Tensor,
    verbose: bool = False,
) -> tuple[list[float], list[float]]:
    """Correlate each labeled feature of set 1 with all features of set 2.

    For every index in ``features1_F[0]`` the up-to-5 highest correlations
    against the rows of ``activations2`` are examined. The overall best value
    is recorded, together with the best value among those top entries whose
    index also appears in ``features2_F[0]`` (0.0 if none of them does).
    Features whose best correlation is exactly 0.0 are skipped in both lists.

    Args:
        activations1: Activations the query rows are taken from.
        activations2: Activations the query rows are correlated against.
        features1_F: Indexable whose first element holds the query feature
            indices (the callers pass the tuple returned by ``torch.where``).
        features2_F: Same layout; feature indices of ``activations2`` that
            carry the same label.
        verbose: Print each feature and its top correlations.

    Returns:
        ``(best_correlations_same_label, best_correlations)``, two lists of
        equal length.
    """
    best_correlations_same_label = []
    best_correlations = []

    # torch.topk raises if k exceeds the number of candidates.
    k = min(5, activations2.shape[0])
    if k == 0:
        return best_correlations_same_label, best_correlations

    same_label_ids = set(features2_F[0].tolist())

    for f in features1_F[0]:
        if verbose:
            print(f"\n\n\n\nFeature {f.item()}")
        correlations = get_correlation_for_activation(activations1, f, activations2)
        values, indices = torch.topk(correlations, k)

        best_correlation = values[0].item()
        # None means "no same-label match seen yet"; a plain 0.0 sentinel would
        # be confused with a genuine correlation of exactly 0.0.
        best_same_label_correlation = None
        for index, value in zip(indices, values):
            if verbose:
                print(f"Feature Index: {index}, Value: {value.item():.2f}")
            # topk is sorted descending, so the first match is the best one.
            if best_same_label_correlation is None and index.item() in same_label_ids:
                best_same_label_correlation = value.item()

        if best_correlation != 0.0:
            best_correlations.append(best_correlation)
            best_correlations_same_label.append(
                0.0 if best_same_label_correlation is None else best_same_label_correlation
            )

    return best_correlations_same_label, best_correlations


def compare_feature_labels(
    feature_labels1: dict,
    feature_labels2: dict,
    activations1: torch.Tensor,
    activations2: torch.Tensor,
    threshold1: int,
    threshold2: int,
    func_name: str,
    verbose: bool = False,
    squares: tuple[tuple[int, int], ...] = ((0, 4),),
) -> tuple[list[float], list[float], list[float], list[float]]:
    """Compare activation correlations between features that share a board-square label.

    For each requested square, the features labeled 1 at ``[:, row, col, 0]``
    of the thresholded label tensors are looked up in both label sets, and
    ``get_correlations_per_labeled_feature`` is run in both directions
    (set 1 against set 2, and set 2 against set 1).

    Args:
        feature_labels1: Label dict for the first feature set; must contain ``func_name``.
        feature_labels2: Label dict for the second feature set; must contain ``func_name``.
        activations1: Activations of the first feature set.
        activations2: Activations of the second feature set.
        threshold1: Index into the first axis of ``feature_labels1[func_name]``.
        threshold2: Index into the first axis of ``feature_labels2[func_name]``.
        func_name: Key selecting the label tensor in both dicts.
        verbose: Print per-square and per-feature details.
        squares: ``(row, col)`` squares to analyse. Defaults to the single
            square ``(0, 4)``, which is what the previous hard-coded filter
            inside the 64-square loop selected.

    Returns:
        ``(best_correlations1, best_correlations2,
        best_correlation_same_label1, best_correlation_same_label2)``.
    """
    labels1_TFRRC = feature_labels1[func_name]
    labels2_TFRRC = feature_labels2[func_name]

    counts1_T = einops.reduce(labels1_TFRRC, "T F R1 R2 C -> T", "sum")
    counts2_T = einops.reduce(labels2_TFRRC, "T F R1 R2 C -> T", "sum")

    print(f"Counts1: {counts1_T}")
    print(f"Counts2: {counts2_T}")

    labels1_FRRC = labels1_TFRRC[threshold1]
    labels2_FRRC = labels2_TFRRC[threshold2]

    best_correlations1 = []
    best_correlation_same_label1 = []
    best_correlations2 = []
    best_correlation_same_label2 = []

    for r, c in tqdm(squares):
        # Only index 0 of the last label axis is compared.
        features1_F = torch.where(labels1_FRRC[:, r, c, 0] == 1)
        features2_F = torch.where(labels2_FRRC[:, r, c, 0] == 1)

        # Direction 1 -> 2.
        correlations_same_label1, correlations1 = get_correlations_per_labeled_feature(
            activations1, activations2, features1_F, features2_F, verbose=verbose
        )
        best_correlation_same_label1.extend(correlations_same_label1)
        best_correlations1.extend(correlations1)

        # Direction 2 -> 1.
        correlations_same_label2, correlations2 = get_correlations_per_labeled_feature(
            activations2, activations1, features2_F, features1_F, verbose=verbose
        )
        best_correlation_same_label2.extend(correlations_same_label2)
        best_correlations2.extend(correlations2)

        if verbose:
            print(r, c)
            print(f"Features1: {features1_F}")
            print(f"Features2: {features2_F}")
            print(f"Correlations1: {correlations1}")
            print(f"Correlations2: {correlations2}")

    return best_correlations1, best_correlations2, best_correlation_same_label1, best_correlation_same_label2


In [None]:
# Run the labeled-feature comparison with verbose printing, using the
# thresholds and label key configured at the top of the notebook.
best_correlations1, best_correlations2, best_correlations_same_label1, best_correlations_same_label2 = compare_feature_labels(
    feature_labels1, feature_labels2, activations1, activations2, threshold1, threshold2, func_name, verbose=True
)