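The contribution of each neuron in an MLP layer to the classification of each tile can be read as how strongly that neuron activates for the tile. Using the already-trained linear tile-state probe, we compute the cosine similarity between each neuron's input weight vector and the probe direction for each tile.
A positive cosine similarity is taken to mean the neuron contributes to the classification of that tile; a negative one is taken to mean it suppresses that classification.
A neuron is counted toward a tile when its cosine similarity exceeds a threshold. The code below compares the absolute value against `threshold = 0.25`, so strongly suppressing neurons are counted as well.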
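
The two quantities above can be written out explicitly. The notation is introduced here for illustration and does not come from the notebook: $w_n$ is the input weight vector of neuron $n$, $p_t$ is the probe direction for tile $t$, and $\tau$ is the threshold. Under those assumptions, the score and the per-tile count that the code cell below is meant to compute are:

$$
s_{n,t} = \frac{w_n \cdot p_t}{\lVert w_n \rVert \, \lVert p_t \rVert},
\qquad
c_t = \sum_{n} \mathbf{1}\left[\, \lvert s_{n,t} \rvert > \tau \,\right]
$$

With the notebook's `threshold = 0.25`, $c_t$ is the number printed on tile $t$ in each layer's activation map.
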
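- Because the input is the move sequence of the entire game, predicting the class of each tile requires the Transformer to work out every piece flip. We therefore hypothesize that tiles flipped more often are activated in later layers.
- Cosine similarity, however, contains no measure of how many times each tile is flipped: it analyzes the neurons of the model rather than game data, so it cannot confirm a correlation with flip count. A next step is to try an SAE (sparse autoencoder) to measure the flip count of each tile.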
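
One way to get the per-tile flip counts that cosine similarity lacks is to replay games directly. The sketch below is not part of the notebook: `new_board`, `flipped_tiles`, `count_flips`, and the board encoding are hypothetical names chosen for illustration. It assumes moves arrive as `(row, col)` pairs with black moving first, and that a move with no capture for the side to move means that side passed.

```python
# Hypothetical helpers (not from the notebook): replay an Othello game
# and count how many times each tile changes colour.
DIRECTIONS = [(-1, -1), (-1, 0), (-1, 1), (0, -1), (0, 1), (1, -1), (1, 0), (1, 1)]


def new_board():
    """Standard 8x8 start position: 0 = empty, 1 = black, -1 = white."""
    board = [[0] * 8 for _ in range(8)]
    board[3][3], board[4][4] = -1, -1
    board[3][4], board[4][3] = 1, 1
    return board


def flipped_tiles(board, row, col, player):
    """Return the opponent tiles captured if `player` moves at (row, col)."""
    captured = []
    for dr, dc in DIRECTIONS:
        run = []
        r, c = row + dr, col + dc
        while 0 <= r < 8 and 0 <= c < 8 and board[r][c] == -player:
            run.append((r, c))
            r, c = r + dr, c + dc
        # A run only counts if it is closed off by one of the player's own pieces.
        if run and 0 <= r < 8 and 0 <= c < 8 and board[r][c] == player:
            captured.extend(run)
    return captured


def count_flips(moves):
    """Replay `moves` (black first) and return an 8x8 grid of flip counts."""
    board = new_board()
    flips = [[0] * 8 for _ in range(8)]
    player = 1
    for row, col in moves:
        captured = flipped_tiles(board, row, col, player)
        if not captured:
            # Assumption: the side to move had no capture here, so it passed.
            player = -player
            captured = flipped_tiles(board, row, col, player)
        if board[row][col] != 0 or not captured:
            raise ValueError(f"illegal move at {(row, col)}")
        board[row][col] = player
        for r, c in captured:
            board[r][c] = player
            flips[r][c] += 1
        player = -player
    return flips


# Black plays (2, 3) and flips (3, 3); white replies (2, 2) and flips it back.
flips = count_flips([(2, 3), (2, 2)])
print(flips[3][3])
```

Summing these grids over many games gives a per-tile flip frequency that could then be compared, tile by tile, with the per-layer activation maps.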


In [1]:
%load_ext autoreload
%autoreload 2
%reload_ext autoreload
# Imports
import torch
import matplotlib.pyplot as plt
from utils.dataloaders import get_dataloader
import utils.dataloaders
from tqdm import tqdm
import os
import sys
import numpy as np
import seaborn as sns
import analysis
from neel_plotly import line, scatter, imshow, histogram
# GPU acceleration
device='cuda' if torch.cuda.is_available() else 'cpu'
print(device)



cuda


In [2]:
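layers = list(range(8))  # layer indices 0-7
mode = 0  # probe class index into ["empty", "own", "enemy"]
seed = 9999
threshold = 0.25  # minimum |cosine similarity| for a neuron to count toward a tile
save_path = True  # flag: write each activation map to disk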

model_location = f"trained_model_full_{seed}.pkl"
with open(model_location, 'rb') as f:
    othello_gpt=torch.load(f, map_location=device)

for layer in layers:
    # Probe
    probe_path = f"probes/probe_layer_{layer}_{seed}_trimmed.pkl"
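    # Load onto the same device as the model so the tensors can be combined below.
    full_linear_probe = torch.load(probe_path, map_location=device)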
    # print(full_linear_probe)
    my_probe_W = full_linear_probe[f'classifier.{mode}.weight'] # (64, 512)
    my_probe_W = my_probe_W.t()
    # my_probe_W[:, [27, 28, 35, 36]] = 0.
    my_probe_normalised = my_probe_W / my_probe_W.norm(dim=0, keepdim=True) # torch.Size([512, 64])
    my_probe_normalised[my_probe_normalised.isnan()] = 0.

    # Weight
    weight_in_key = f"blocks.{layer}.mlp_sublayer.encode.weight"
    
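    # Clone so that normalising never modifies the model's own parameters
    # (state_dict tensors share storage with the live weights).
    w_in = othello_gpt.state_dict()[weight_in_key].clone()  # indexed as (2048, 512): one row per neuron
    # Unit-normalise each neuron's input weight vector (each row), as cosine similarity requires.
    w_in = w_in / w_in.norm(dim=1, keepdim=True)
    # Cosine similarity of every neuron with every tile's probe direction.
    heatmaps_my = w_in @ my_probe_normalised  # (2048, 64)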
    
    activation_map = (heatmaps_my.abs()>threshold).sum(dim=0)

    activation_map = activation_map.reshape(8, 8).cpu().numpy()

    fig, ax = plt.subplots(figsize=(4, 4))
    cax = ax.matshow(activation_map, cmap="viridis", vmin=0, vmax=10)

    plt.colorbar(cax)

    
    for (i, j), val in np.ndenumerate(activation_map):
        ax.text(j, i, f"{val}", ha='center', va='center', color='white' if val <= 5 else 'black')

    plt.title(f"Activation map of layer {layer+1}")
    
    ax.set_xticks(range(8))
    ax.set_yticks(range(8))
    ax.set_xticklabels([])
    ax.set_yticklabels([])
    
    if save_path:
        # exist_ok=True makes the separate existence check unnecessary.
        os.makedirs("./results/pics/cosim_results", exist_ok=True)
        save_file = f"./results/pics/cosim_results/{seed}_activation_map_layer_{layer+1}.png"
        plt.savefig(save_file, dpi=300, bbox_inches='tight')
        print(f"Saved activation map for layer {layer+1} to {save_file}")

    plt.close(fig)  # Close the figure to free memory



Saved activation map for layer 1 to ./results/pics/cosim_results/9999_activation_map_layer_1.png
Saved activation map for layer 2 to ./results/pics/cosim_results/9999_activation_map_layer_2.png
Saved activation map for layer 3 to ./results/pics/cosim_results/9999_activation_map_layer_3.png
Saved activation map for layer 4 to ./results/pics/cosim_results/9999_activation_map_layer_4.png
Saved activation map for layer 5 to ./results/pics/cosim_results/9999_activation_map_layer_5.png
Saved activation map for layer 6 to ./results/pics/cosim_results/9999_activation_map_layer_6.png
Saved activation map for layer 7 to ./results/pics/cosim_results/9999_activation_map_layer_7.png
Saved activation map for layer 8 to ./results/pics/cosim_results/9999_activation_map_layer_8.png
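
The maps above count neurons by absolute cosine similarity, so contributing and suppressing neurons land in the same total. A minimal sketch of keeping the two signs apart follows. `signed_tile_counts` is a hypothetical helper, not part of the notebook, and the toy input is made up rather than taken from model weights.

```python
import numpy as np


def signed_tile_counts(cos_sim, threshold=0.25):
    """Count neurons per tile above +threshold and below -threshold.

    cos_sim: array of shape (n_neurons, 64), one cosine similarity per neuron and tile.
    Returns two (8, 8) integer arrays: (contributing, suppressing).
    """
    cos_sim = np.asarray(cos_sim)
    contributing = (cos_sim > threshold).sum(axis=0).reshape(8, 8)
    suppressing = (cos_sim < -threshold).sum(axis=0).reshape(8, 8)
    return contributing, suppressing


# Toy input: three made-up neurons, not real model weights.
toy = np.zeros((3, 64))
toy[0, 0] = 0.9    # strongly aligned with tile 0
toy[1, 0] = -0.4   # opposes tile 0
toy[2, 63] = 0.3   # aligned with tile 63, just above the threshold
pos, neg = signed_tile_counts(toy)
print(pos[0, 0], neg[0, 0], pos[7, 7])
```

In the notebook, the same function could be applied to `heatmaps_my.cpu().numpy()` inside the layer loop to draw one map per sign.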
