In [None]:
from circuits.dictionary_learning.dictionary import AutoEncoder
import pickle
import os
import torch as t
import matplotlib.pyplot as plt
import torch.nn.functional as F



import circuits.analysis as analysis
import circuits.eval_sae_as_classifier as eval_sae


# Device used for tensors created in this notebook (set to 'cuda' to use a GPU).
device = 'cpu'

Key mapping each index along the last dimension of the board-state tensor to a piece type:
* 0 => black king
* 1 => black queen
* 2 => black rook
* 3 => black bishop
* 4 => black knight
* 5 => black pawn
* 6 => empty
* 7 => white pawn
* 8 => white knight
* 9 => white bishop
* 10 => white rook
* 11 => white queen
* 12 => white king

In [None]:
# load SAE
ae_path = '../autoencoders/chess_layer5_large_sweep/ef=16_lr=1e-03_l1=3e-02_layer_5'
ae_path = '../autoencoders/chess_layer5_large_sweep/ef=4_lr=1e-03_l1=1e-01_layer_5'

import circuits.dictionary_learning.dictionary as dictionary

def get_ae(ae_path, device='cuda:0'):
    """Load the trained SAE stored at ``<ae_path>/ae.pt``.

    Paths containing "gated" are loaded with ``dictionary.GatedAutoEncoder``;
    every other path uses the standard ``AutoEncoder``.

    Args:
        ae_path: Directory containing the SAE's ``ae.pt`` checkpoint.
        device: Device passed to ``from_pretrained``. The default 'cuda:0'
            matches the previous hard-coded value; pass the notebook's
            ``device`` to load on a CPU-only machine.

    Returns:
        The loaded autoencoder.
    """
    ae_class = dictionary.GatedAutoEncoder if "gated" in ae_path else AutoEncoder
    return ae_class.from_pretrained(os.path.join(ae_path, 'ae.pt'), device=device)

def _results_to_device(obj, target_device):
    """Recursively move every tensor in nested dicts/lists/tuples to ``target_device``.

    Non-tensor leaves (ints, strings, None, ...) are returned unchanged.
    """
    if isinstance(obj, t.Tensor):
        return obj.to(target_device)
    if isinstance(obj, dict):
        return {k: _results_to_device(v, target_device) for k, v in obj.items()}
    if isinstance(obj, list):
        return [_results_to_device(v, target_device) for v in obj]
    if isinstance(obj, tuple):
        return tuple(_results_to_device(v, target_device) for v in obj)
    return obj


def get_feature_labels(ae_path):
    """Load the cached feature-evaluation results for one SAE and label its features.

    Reads ``indexing_find_dots_indices_n_inputs_1000_results.pkl`` from
    ``ae_path``. It then moves all tensors to the notebook's ``device`` and runs
    ``analysis.analyze_results_dict`` with ``high_threshold=0.98`` and ``mask=True``.

    Uses a self-contained device helper. ``to_device`` is only defined in a
    later cell, so calling it here failed on Restart & Run All.

    Args:
        ae_path: Directory containing the SAE and its results pickle.

    Returns:
        The ``feature_labels`` dict returned by ``analyze_results_dict``.
    """
    # NOTE: pickle.load can execute arbitrary code; only load result files you generated.
    with open(os.path.join(ae_path, 'indexing_find_dots_indices_n_inputs_1000_results.pkl'), 'rb') as f:
        results = pickle.load(f)
    results = _results_to_device(results, device)
    feature_labels, _misc_stats = analysis.analyze_results_dict(results, "", device, high_threshold=0.98, save_results=False, print_results=False, verbose=False, mask=True)
    return feature_labels


In [None]:
import torch.nn.functional as F

# def find_idx(feat_idx: int) -> int:
#     return (feature_labels['alive_features'] == feat_idx).nonzero().item()
# find_idx(1600)

# Key into feature_labels[func_name] selecting which threshold's scores to use.
threshold = 5
func_name = "board_to_piece_state"

linear_probe_path = "../linear_probes/tf_lens_lichess_8layers_ckpt_no_optimizer_chess_piece_probe_layer_5.pth"

# Load the piece-state probe; it is indexed later as [0, :, row, col, piece].
with open(linear_probe_path, "rb") as probe_file:
    state_dict = t.load(probe_file, map_location=device)
print(state_dict.keys())
linear_probe_MDRRC = state_dict["linear_probe"]


def get_average_from_list_of_tuples(l: list[tuple[float, int]]) -> float:
    """Return the mean of the first element of each tuple in ``l``.

    The second element (e.g. a square count) is ignored. An empty list yields
    ``nan`` instead of raising ZeroDivisionError.
    """
    if not l:
        return float('nan')
    return sum(value for value, *_ in l) / len(l)

def get_cos_sims(
    ae, feature_labels, square_threshold: float = 0.95, generator=None
) -> tuple[list[tuple[float, int]], list[tuple[float, int]], list[tuple[float, int]]]:
    """Compare square-classifying SAE features' decoder vectors with the linear probe.

    Scores each alive feature whose ``feature_labels[func_name][threshold]`` row
    has at least one (row, col, piece) cell above ``square_threshold``. Each
    score is a cosine similarity between the feature's decoder column and:

    * true: the sum of unit-normalised probe directions for those cells;
    * random1: the sum of the same number of unit-normalised probe directions
      at uniformly random (row, col, piece) cells;
    * random2: a standard-normal random vector.

    Relies on the notebook globals ``linear_probe_MDRRC``, ``func_name``,
    ``threshold`` and ``device``.

    Args:
        ae: Autoencoder whose ``decoder.weight.data`` has one column per feature.
        feature_labels: Output of ``analysis.analyze_results_dict``.
        square_threshold: Score above which a cell counts as classified.
        generator: Optional ``torch.Generator`` that makes the random baselines
            reproducible. None uses the global RNG.

    Returns:
        ``(cos_sims_true, cos_sims_random1, cos_sims_random2)``. Each is a list
        of ``(cosine_similarity, num_classified_squares)`` tuples, one per
        scored feature. The count is a plain int.
    """
    decoder_weights = ae.decoder.weight.data.to(device)
    # The probe is indexed as [mode, d_model, row, col, piece] below; derive the
    # random-baseline ranges from its shape instead of hard-coding 8/8/13.
    d_model = linear_probe_MDRRC.shape[1]
    n_rows, n_cols, n_pieces = linear_probe_MDRRC.shape[2:5]

    alive_features = feature_labels['alive_features']
    square_scores = feature_labels[func_name][threshold]

    cos_sims_true = []
    cos_sims_random1 = []
    cos_sims_random2 = []

    def unit_probe_direction(row, col, piece):
        # Probe direction for one (row, col, piece) cell, scaled to unit norm.
        direction = linear_probe_MDRRC[0, :, row, col, piece]
        return direction / direction.norm()

    for idx in range(alive_features.shape[0]):
        classified_squares = (square_scores[idx] > square_threshold).nonzero()
        num_classified_squares = classified_squares.shape[0]
        if num_classified_squares == 0:
            continue

        linear_probe_vector = t.zeros(d_model, device=device)
        random_linear_probe_vector = t.zeros(d_model, device=device)
        random_vector = t.randn(d_model, generator=generator).to(device)

        for row, col, piece in classified_squares.tolist():
            linear_probe_vector += unit_probe_direction(row, col, piece)

            # Baseline: one probe direction at a random cell per classified cell.
            random_row = t.randint(0, n_rows, (1,), generator=generator).item()
            random_col = t.randint(0, n_cols, (1,), generator=generator).item()
            random_piece = t.randint(0, n_pieces, (1,), generator=generator).item()
            random_linear_probe_vector += unit_probe_direction(random_row, random_col, random_piece)

        decoder_vector = decoder_weights[:, alive_features[idx]]

        cos_sim_true = F.cosine_similarity(linear_probe_vector, decoder_vector, dim=0)
        cos_sim_random1 = F.cosine_similarity(random_linear_probe_vector, decoder_vector, dim=0)
        cos_sim_random2 = F.cosine_similarity(random_vector, decoder_vector, dim=0)

        cos_sims_true.append((cos_sim_true.item(), num_classified_squares))
        cos_sims_random1.append((cos_sim_random1.item(), num_classified_squares))
        cos_sims_random2.append((cos_sim_random2.item(), num_classified_squares))

    print(len(cos_sims_true))
    return cos_sims_true, cos_sims_random1, cos_sims_random2



# SAE sweep groups to evaluate; each group folder contains one sub-folder per trained SAE.
# (Earlier sweep: ['../autoencoders/chess_layer5_large_sweep/'])
ae_group_paths = [
    '../autoencoders/group-2024-05-14_chess/group-2024-05-14_chess-gated/',
    '../autoencoders/group-2024-05-14_chess/group-2024-05-14_chess-standard/',
    '../autoencoders/group-2024-05-14_chess/group-2024-05-14_chess-p_anneal/',
    '../autoencoders/group-2024-05-14_chess/group-2024-05-14_chess-gated_anneal/',
]

# SAEs with fewer scored (square-classifying) features than this are left out of all_results.
min_scored_features = 30

# all_results[group_path] = (avg_true, avg_random1, avg_random2,
#                            all_true, all_random1, all_random2),
# where each entry is a dict keyed by SAE path.
all_results = {}

for ae_group_path in ae_group_paths:

    folders = eval_sae.get_nested_folders(ae_group_path)
    print(folders)

    average_cos_sim_true = {}
    average_cos_sim_random1 = {}
    average_cos_sim_random2 = {}

    all_cos_sims_true = {}
    all_cos_sims_random1 = {}
    all_cos_sims_random2 = {}

    for autoencoder_path in folders:
        print(autoencoder_path)
        ae = get_ae(autoencoder_path)
        feature_labels = get_feature_labels(autoencoder_path)

        cos_sims_true, cos_sims_random1, cos_sims_random2 = get_cos_sims(ae, feature_labels)

        if len(cos_sims_true) < min_scored_features:
            # Report the skip so it is clear why an SAE is missing from later plots.
            print(f"Skipping {autoencoder_path}: only {len(cos_sims_true)} scored features (< {min_scored_features})")
            continue

        all_cos_sims_true[autoencoder_path] = cos_sims_true
        all_cos_sims_random1[autoencoder_path] = cos_sims_random1
        all_cos_sims_random2[autoencoder_path] = cos_sims_random2

        average_cos_sim_true[autoencoder_path] = get_average_from_list_of_tuples(cos_sims_true)
        average_cos_sim_random1[autoencoder_path] = get_average_from_list_of_tuples(cos_sims_random1)
        average_cos_sim_random2[autoencoder_path] = get_average_from_list_of_tuples(cos_sims_random2)

    all_results[ae_group_path] = (average_cos_sim_true, average_cos_sim_random1, average_cos_sim_random2, all_cos_sims_true, all_cos_sims_random1, all_cos_sims_random2)

In [None]:
import pandas as pd

filename = "processed_results_group-2024-05-14-chess-v2.csv"
top_k_f1 = 20

df = pd.read_csv(filename)
top_k_paths = df.nlargest(top_k_f1, 'best_f1_score_per_square_average')['autoencoder_path']
print(top_k_paths.to_list())

# For each SAE group, keep the first (i.e. highest-F1) ranked path belonging to it.
best_ae_types = {}
for candidate_path in top_k_paths:
    for group_path in ae_group_paths:
        if group_path in candidate_path:
            best_ae_types.setdefault(group_path, candidate_path)

In [None]:
import matplotlib.pyplot as plt

def create_histogram(hist_list: list[float], max_square_count: int, group_name: str, trainer_nums: list[str], title: str):
    """Display a 20-bin histogram of cosine similarities.

    Args:
        hist_list: Cosine similarities to plot.
        max_square_count: Square-count cap used when filtering ``hist_list``
            (shown in the title only).
        group_name: SAE group name, shown in the title.
        trainer_nums: Trainer identifiers the values came from, shown in the title.
        title: Prefix for the plot title.
    """
    _, ax = plt.subplots(figsize=(10, 6))
    ax.hist(hist_list, bins=20, color='blue', alpha=0.7, edgecolor='black')
    ax.set_title(f'{title} Maximum square count: {max_square_count} {group_name} trainers: {trainer_nums}')
    ax.set_xlabel('Cosine Similarity')
    ax.set_ylabel('Frequency')
    ax.grid(True)
    plt.show()


# Indices into each all_results[group] tuple:
#   sort_metric = 0 -> {ae_path: mean true cos sim}, used to rank SAEs
#   metric      = 3 -> {ae_path: [(cos_sim, n_classified_squares), ...]}, plotted
sort_metric = 0
metric = 3
top_k = 1             # number of best-ranked SAEs per group whose features are pooled
max_square_count = 5  # drop features that classify more squares than this

for ae_type in ae_group_paths:
    averages = all_results[ae_type][sort_metric]

    # Running max in the original started at 0, so keep 0 as the floor.
    max_val = max([0, *averages.values()])
    print(max_val)

    sorted_ae_paths = sorted(averages, key=averages.get, reverse=True)
    print(sorted_ae_paths[:top_k])

    trainer_nums = []
    hist_list = []
    for ae_path in sorted_ae_paths[:top_k]:
        trainer_nums.append(ae_path.split("/")[-2])
        hist_list.extend(all_results[ae_type][metric][ae_path])

    # Entries are (cos_sim, n_classified_squares); keep features with few squares.
    if hist_list and isinstance(hist_list[0], tuple):
        hist_list = [cos_sim for cos_sim, n_squares in hist_list if n_squares <= max_square_count]

    # A group whose SAEs were all skipped (or filtered out) has nothing to plot;
    # indexing hist_list[0] previously raised IndexError here.
    if not hist_list:
        print(f"No cosine similarities to plot for {ae_type}; skipping.")
        continue

    create_histogram(hist_list, max_square_count, ae_type.split("/")[-2], trainer_nums, "Best cos sims")

In [None]:
import pandas as pd

filename = "processed_results_group-2024-05-14-chess-v2.csv"

df = pd.read_csv(filename)

top_k_f1 = 20
f1_metric = 'best_f1_score_per_square_average'
# f1_metric = "best_custom_metric_average"

top_k_paths = df.nlargest(top_k_f1, f1_metric)['autoencoder_path']
print(top_k_paths.to_list())

# For each SAE group, keep the first (i.e. highest-F1) ranked path belonging to it.
best_ae_types = {}
for path in top_k_paths:
    for ae_type in ae_group_paths:
        if ae_type in path:
            best_ae_types.setdefault(ae_type, path)

group = 0   # read by the random-baseline cell further down
metric = 3  # index of {ae_path: [(cos_sim, n_classified_squares), ...]} in all_results tuples

# NOTE: max_square_count is defined in the previous histogram cell.
for ae_type, best_path in best_ae_types.items():
    # The best-F1 SAE may be missing from all_results (e.g. skipped for having too
    # few scored features), which previously raised KeyError.
    hist_list = all_results[ae_type][metric].get(best_path, [])

    if hist_list and isinstance(hist_list[0], tuple):
        hist_list = [cos_sim for cos_sim, n_squares in hist_list if n_squares <= max_square_count]

    if not hist_list:
        print(f"No cosine similarities to plot for {best_path}; skipping.")
        continue

    trainer_nums = [best_path.split("/")[-2]]

    create_histogram(hist_list, max_square_count, ae_type.split("/")[-2], trainer_nums, "Cos sim for best board reconstruction")

In [None]:
# Mean cosine similarities for the last SAE processed by the sweep loop (it may have
# been skipped from all_results). Each entry is a (cos_sim, n_classified_squares)
# tuple, so only the first element is averaged; summing the tuples raised TypeError.
for name, cos_sims in (("True", cos_sims_true), ("Random1", cos_sims_random1), ("Random2", cos_sims_random2)):
    values = [cos_sim for cos_sim, *_ in cos_sims]
    mean_cos_sim = sum(values) / len(values) if values else float('nan')
    print(f"Cosine Similarity {name}:", mean_cos_sim)

In [None]:
# load SAE (alternative checkpoint: '../autoencoders/chess_layer5_large_sweep/ef=16_lr=1e-03_l1=3e-02_layer_5')
ae_path = '../autoencoders/chess_layer5_large_sweep/ef=4_lr=1e-03_l1=1e-01_layer_5'
ae = AutoEncoder.from_pretrained(os.path.join(ae_path, 'ae.pt'), device='cuda:0')

# load information about features
# NOTE: pickle.load can execute arbitrary code; only load result files you generated.
results_file = os.path.join(ae_path, 'indexing_find_dots_indices_n_inputs_1000_results.pkl')
with open(results_file, 'rb') as results_fh:
    results = pickle.load(results_fh)

def to_device(d, device=device):
    """Recursively move every tensor in ``d`` to ``device``.

    Dicts, lists and tuples are traversed and rebuilt. Tensors are moved with
    ``.to(device)``. Every other value (ints, strings, None, ...) is returned
    unchanged instead of being silently replaced by ``None``.
    """
    if isinstance(d, t.Tensor):
        return d.to(device)
    if isinstance(d, dict):
        return {k: to_device(v, device) for k, v in d.items()}
    if isinstance(d, list):
        return [to_device(v, device) for v in d]
    if isinstance(d, tuple):
        return tuple(to_device(v, device) for v in d)
    return d
results = to_device(results)

feature_labels, misc_stats = analysis.analyze_results_dict(results, "", device, high_threshold=0.95, save_results=False, print_results=False, verbose=False, mask=True)

# First index into feature_labels['board_to_piece_state']; presumably selects a
# threshold setting (compare feature_labels['thresholds']) — TODO confirm.
best_idx = 1
start = 0
num_to_show = 600

# Clamp so fewer than start + num_to_show alive features no longer raises IndexError.
num_alive = feature_labels['alive_features'].shape[0]
for idx in range(start, min(start + num_to_show, num_alive)):
    feat_idx = feature_labels['alive_features'][idx]

    # Board states that the feature classifies according to Adam's measurements.
    classified = feature_labels['board_to_piece_state'][best_idx][idx] > .95

    print(f"\n{idx}")
    print(classified.nonzero())
    print(classified.sum())

In [None]:
# set idx to be the index of the (alive) feature you want to visualize
idx = 18  # another feature of interest: 1436
feat_idx = feature_labels['alive_features'][idx]
print(feat_idx)

best_idx = 1

classified_mask = feature_labels['board_to_piece_state'][best_idx][idx] > .95
labels = classified_mask.nonzero()

print("Board states that the feature classifies according to Adam's measurements:")
print(labels)

print("Number of such board states:")
print(feature_labels['board_to_piece_state'].shape)
print(classified_mask.sum())

print(feature_labels['thresholds'].shape)
print(feature_labels['thresholds'][best_idx][idx])

In [None]:
import torch.nn.functional as F

linear_probe_path = "../linear_probes/tf_lens_lichess_8layers_ckpt_no_optimizer_chess_piece_probe_layer_5.pth"

with open(linear_probe_path, "rb") as f:
    state_dict = t.load(f, map_location=device)
    print(state_dict.keys())
    linear_probe_MDRRC = state_dict["linear_probe"]

n_rows = linear_probe_MDRRC.shape[2]
linear_probe_vector = t.zeros(linear_probe_MDRRC.shape[1]).to(device)

# Experimental (unnormalised) probe direction. For each classified (row, col, piece)
# label, add:
#   * the probe direction for that label,
#   * the class-1 direction on the same square (black queen per the key above,
#     assuming the probe uses the same encoding),
#   * the same piece's direction one row further (row + 1).
for label in labels:
    print(label)
    row, col, piece = label.tolist()
    linear_probe_vector += linear_probe_MDRRC[0, :, row, col, piece]
    linear_probe_vector += linear_probe_MDRRC[0, :, row, col, 1]
    # Labels on the last row have no row + 1; indexing it raised IndexError.
    if row + 1 < n_rows:
        linear_probe_vector += linear_probe_MDRRC[0, :, row + 1, col, piece]


decoder_weights = ae.decoder.weight.data

d_model, hidden_dim = decoder_weights.shape

print("Hidden Dim:", hidden_dim)

# One batched call replaces two identical Python loops over all hidden_dim columns:
# decoder_weights.T has one row per feature, and the probe vector broadcasts across rows.
cos_sims_all = F.cosine_similarity(decoder_weights.T.to(device), linear_probe_vector.unsqueeze(0), dim=1)
cosine_similarities = list(enumerate(cos_sims_all.tolist()))

# The original running max started at 0, so 0 remains the floor.
max_cosine_similarity = max(0, max(sim for _, sim in cosine_similarities))

# Get the top 10 cosine similarities
top_10_cosine_similarities = sorted(cosine_similarities, key=lambda x: x[1], reverse=True)[:10]

# Calculate the average cosine similarity
average_cosine_similarity = sum(sim for _, sim in cosine_similarities) / len(cosine_similarities)

# Print the top 10 and average cosine similarities
print("Top 10 Cosine Similarities:", top_10_cosine_similarities)
print("Average Cosine Similarity:", average_cosine_similarity)


In [None]:
# TODO: two-color histogram over SAEs — x axis: cosine similarity; y axis: number of features with that cosine similarity.

In [None]:

# Baseline: cosine similarity between independent pairs of standard-normal vectors.
# random_dim is hard-coded rather than read from the model.
num_random_pairs = 1000
random_dim = 512
random_generator = t.Generator().manual_seed(0)  # seeded so the baseline is reproducible

random_vectors_1 = t.randn(num_random_pairs, random_dim, generator=random_generator).to(device)
random_vectors_2 = t.randn(num_random_pairs, random_dim, generator=random_generator).to(device)
random_cos_sims = F.cosine_similarity(random_vectors_1, random_vectors_2, dim=1).tolist()

_, ax = plt.subplots(figsize=(10, 6))
ax.hist(random_cos_sims, bins=20, color='blue', alpha=0.7, edgecolor='black')
# The title previously named an SAE group, but no SAE is involved in this plot.
ax.set_title(f'Histogram of Cosine Similarities between {num_random_pairs} random Gaussian vector pairs (d={random_dim})')
ax.set_xlim(-0.15, 1)
ax.set_xlabel('Cosine Similarity')
ax.set_ylabel('Frequency')
ax.grid(True)
plt.show()