# Setup

In [None]:
import torch as t
from huggingface_hub import hf_hub_download
import matplotlib.pyplot as plt
import importlib
import numpy as np
import einops
from tqdm import tqdm
from typing import Callable, Optional
import os
import pickle
import json

from circuits.dictionary_learning.dictionary import AutoEncoder, AutoEncoderNew, GatedAutoEncoder, IdentityDict
from circuits.utils import (
    othello_hf_dataset_to_generator,
    get_model,
    get_submodule,
    get_mlp_activations_submodule,
)

from feature_viz_othello_utils import (
    get_acts_IEs_VN,
    plot_lenses,
    plot_mean_metrics,
    plot_top_k_games,
    BoardPlayer,
)

import circuits.utils as utils
import circuits.analysis as analysis
import feature_viz_othello_utils as viz_utils
import circuits.othello_utils as othello_utils
from circuits.othello_engine_utils import to_board_label, to_string, to_int, stoi_indices #to_string: mode_output_vocab to interpretable square index
import circuits.eval_sae_as_classifier as eval_sae

device = 'cuda:0'
# Keyword arguments passed to the `model.trace(...)` calls below.
tracer_kwargs = dict(validate=False, scan=False)
# Root of the chess-gpt-circuits checkout (autoencoders are stored under it).
# Alternative checkout location on the other machine:
# repo_dir = '/share/u/can/chess-gpt-circuits'
repo_dir = "/home/adam/chess-gpt-circuits"

### Load model, submodule, ae

In [None]:
# Checkpoint identifier handed to get_model (presumably a Hugging Face Hub repo id).
model_name = "Baidicoot/Othello-GPT-Transformer-Lens"
# Load the model via the project helper; device placement is up to get_model
# (presumably it moves the model to `device` -- confirm in circuits.utils).
model = get_model(model_name, device)

In [None]:
layer = 5


def get_ae(layer: int, node_type: str = "mlp_neuron"):
    """Return ``(ae, submodule)`` for ``layer``.

    ``ae`` is the dictionary whose ``encode``/``decode`` are applied to
    ``submodule.output`` elsewhere in this notebook.

    Args:
        layer: Transformer layer index.
        node_type: ``"mlp_neuron"`` (default) pairs an ``IdentityDict`` with the
            submodule from ``get_mlp_activations_submodule`` (presumably so that
            raw MLP neurons act as the features). ``"sae_feature"`` loads the
            p-anneal autoencoder (trainer 0) for the submodule from ``get_submodule``.

    When ``{repo_dir}/autoencoders/<ae_group_name>`` does not exist, the archive
    ``<ae_group_name>.zip`` is downloaded from the ``adamkarvonen/othello_saes``
    Hub repo and extracted into ``{repo_dir}/autoencoders``.

    Raises:
        ValueError: for an unknown ``node_type`` or ``ae_type``.
    """
    import zipfile

    if node_type == "sae_feature":
        ae_group_name = 'all_layers_othello_p_anneal_0530_with_lines'
        ae_type = 'p_anneal'
        trainer_id = 0
        ae_path = f'{repo_dir}/autoencoders/{ae_group_name}/layer_{layer}/trainer{trainer_id}'
        submodule = get_submodule(model_name, layer, model)
    elif node_type == "mlp_neuron":
        ae_group_name = 'othello_mlp_acts_identity_aes_lines'  # with_lines
        ae_type = 'identity'
        ae_path = f'{repo_dir}/autoencoders/{ae_group_name}/layer_{layer}'
        submodule = get_mlp_activations_submodule(model_name, layer, model)
    else:
        raise ValueError('Invalid node_type')

    # Download and extract the autoencoder group from the Hub if needed.
    ae_root = f'{repo_dir}/autoencoders'
    if not os.path.exists(f'{ae_root}/{ae_group_name}'):
        # hf_hub_download returns the local path of the downloaded file.
        zip_path = hf_hub_download(
            repo_id='adamkarvonen/othello_saes',
            filename=f'{ae_group_name}.zip',
            local_dir=ae_root,
        )
        # Extract in-process rather than shelling out to `unzip`: no external
        # binary is needed, and a failed extraction raises here instead of
        # surfacing later as a missing ae.pt.
        with zipfile.ZipFile(zip_path) as archive:
            archive.extractall(ae_root)

    # Initialize the autoencoder on the notebook-wide device.
    ae_file = os.path.join(ae_path, 'ae.pt')
    if ae_type in ('standard', 'p_anneal'):
        ae = AutoEncoder.from_pretrained(ae_file, device=device)
    elif ae_type in ('gated', 'gated_anneal'):
        ae = GatedAutoEncoder.from_pretrained(ae_file, device=device)
    elif ae_type == 'standard_new':
        ae = AutoEncoderNew.from_pretrained(ae_file, device=device)
    elif ae_type == 'identity':
        ae = IdentityDict()
    else:
        raise ValueError('Invalid ae_type')

    return ae, submodule


ae, submodule = get_ae(layer)

### Load legal move neurons

In [None]:
# # for each neuron, we need the full hypothesis [feature_idx, valid_square_idx, [config_idxs]]
# import json
# with open(os.path.join(ae_path, 'hpc_hrc_same_square_indexes_dict.json'), 'r') as f:
#     hpc_hrc_same_square_indexes_dict = json.load(f)

# print(hpc_hrc_same_square_indexes_dict.keys())
# print(hpc_hrc_same_square_indexes_dict['high_precision_and_recall'])
# print(hpc_hrc_same_square_indexes_dict['intersection_FSqC'])

In [None]:
# feat_idxs, valid_square_idxs, line_idxs = t.tensor(hpc_hrc_same_square_indexes_dict['intersection_FSqC']).T
# unique_feat_idxs = t.unique(feat_idxs)
# unique_valid_square_idxs = t.unique(valid_square_idxs)
# n_unique_feat_idxs = len(unique_feat_idxs)

# for feat_idx, square_idx in zip(unique_feat_idxs, unique_valid_square_idxs):
#     print(feat_idx, square_idx)
#     config_idxs_per_feat = line_idxs[feat_idxs == feat_idx]
#     feat_idx = int(feat_idx)
#     square_idx = int(square_idx)
#     if feat_idx > 500:
#         break

# print(feat_idx, square_idx, config_idxs_per_feat)

### Dataset

In [None]:
dataset_size = 500

ablation_dataset_name = "othello_ablation_dataset.pkl"

# Build the ablation dataset once and cache it as a pickle; later runs reuse it.
# NOTE(review): the cache file name does not encode dataset_size, so changing
# dataset_size has no effect until the cached file is deleted.
if not os.path.exists(ablation_dataset_name):
    ablation_data = eval_sae.construct_othello_dataset(
        custom_functions=[
            othello_utils.games_batch_to_state_stack_length_lines_mine_BLRRC,
            othello_utils.games_batch_to_valid_moves_BLRRC,
        ],
        n_inputs=dataset_size,
        split="train",
        precompute_dataset=True,
        device=device,
    )
    print("Saving ablation dataset")
    with open(ablation_dataset_name, "wb") as f:
        pickle.dump(ablation_data, f)
else:
    print("Loading ablation dataset")
    # pickle.load can execute arbitrary code: only load a cache this notebook wrote.
    with open(ablation_dataset_name, "rb") as f:
        ablation_data = pickle.load(f)

In [None]:
def check_any_present(train_data, valid_move_idx, line_idxs):
    '''For every position in every game, evaluate whether any of the lines is present.

    Args:
        train_data: dict holding 'games_batch_to_state_stack_length_lines_mine_BLRRC',
            indexed here as (batch, seq, row, col, line).
        valid_move_idx: square index on the 8x8 board, split into (row, col).
        line_idxs: indices along the last (line) axis to test. Duplicates are
            harmless because they only set the same selector entry twice.

    Returns:
        Boolean tensor of shape (batch, seq). An entry is True where the values
        of the selected lines at that square sum to more than zero.
    '''
    r = valid_move_idx // 8
    c = valid_move_idx % 8
    data_BLC = train_data['games_batch_to_state_stack_length_lines_mine_BLRRC'][:, :, r, c, :]
    C = data_BLC.shape[-1]

    # 0/1 selector over the line axis. It must live on the data's device:
    # multiplying a 1-D CPU tensor with a CUDA tensor raises a device-mismatch error.
    lines_C = t.zeros(C, dtype=t.int64, device=data_BLC.device)
    lines_C[line_idxs] = 1

    data_with_lines_BLC = data_BLC * lines_C
    any_line_present_BL = data_with_lines_BLC.sum(dim=-1) > 0
    return any_line_present_BL

def check_valid_move(train_data: dict, valid_move_idx: int) -> t.Tensor:
    '''For every position in every game, evaluate whether there is a valid move on the square.

    Args:
        train_data: dict holding 'games_batch_to_valid_moves_BLRRC', indexed here
            as (batch, seq, row, col[, channel]).
        valid_move_idx: square index on the 8x8 board, split into (row, col).

    Returns:
        Boolean tensor of shape (batch, seq) when the channel axis is a singleton
        or absent.
    '''
    r = valid_move_idx // 8
    c = valid_move_idx % 8
    data = train_data['games_batch_to_valid_moves_BLRRC'][:, :, r, c]
    # Drop only a trailing singleton channel axis. A bare .squeeze() would also
    # collapse the batch/sequence axes when B == 1 or L == 1, which breaks
    # callers that use the result as a (B, L) mask.
    if data.ndim == 3 and data.shape[-1] == 1:
        data = data.squeeze(-1)
    return data.bool()

def game_state_where_line_present(train_data, valid_move_idx, line_idxs):
    """Cut each game off at the first position where any of ``line_idxs`` is present.

    Games in which no selected line is ever present are dropped. Returns a list
    of prefixes of the rows of ``train_data['encoded_inputs']`` (as a tensor on
    the module-level ``device``); each prefix ends at, and includes, the first
    matching position.
    """
    presence_BL = check_any_present(train_data, valid_move_idx, line_idxs)
    games = t.tensor(train_data['encoded_inputs'], device=device)
    prefixes = []
    for game, present_L in zip(games, presence_BL):
        hit_positions = t.where(present_L)[0]
        if hit_positions.numel() == 0:
            continue
        prefixes.append(game[:hit_positions[0] + 1])
    return prefixes

# square_test = check_valid_move(ablation_data, square_idx)
# print(square_test.shape)

# def game_state_my_move_before_line_present(train_data, valid_move_idx, line_idxs):
#     game_state_where_line_present = game_state_where_line_present(train_data, valid_move_idx, line_idxs)
#     return [game[:-2] for game in game_state_where_line_present]

In [None]:
# # # Test
# game_states_line_present = game_state_where_line_present(ablation_data, square_idx, config_idxs_per_feat)
# game = game_states_line_present[1]
# game = t.tensor(game, device=device)
# player = BoardPlayer(game)

# square_idx, config_idxs_per_feat

In [None]:
# player.next()

# Steering

In [None]:
# Get 1 game, where any line is present 

In [None]:
# def steering(model, game_batch, submodule, ae, feat_idx, square_idx, steering_factor, device='cpu'):
    

#     for game, V_square_idx, V_rotated_square_idx, C_square_idx in zip(game_batch):
#         V_square_idx = square_idx
#     V_row = V_square_idx // 8
#     V_col = V_square_idx % 8
#     V_rotated_row = 7 - V_row
#     V_rotated_col = 7 - V_col
#     V_rotated_square_idx = V_rotated_row * 8 + V_rotated_col
#     C_square_idxs = game[-2]
#     # Clean forward pass
#     with t.no_grad(), model.trace(game_batch, **tracer_kwargs):
#         x = submodule.output
#         f = ae.encode(x).save() # shape: [batch_size, seq_len, n_features]
#         logits_clean = model.unembed.output.save() # batch_size x seq_len x vocab_size

#     steering_value = f[:, -1, feat_idx] # Activation value where valid_move is present
    
#     # Steering forward pass
#     with t.no_grad(), model.trace(game_batch, **tracer_kwargs):
#         x = submodule.output
#         f = ae.encode(x).save() # shape: [batch_size, seq_len, n_features]
#         f[:, :, feat_idx] = steering_value * steering_factor
#         submodule.output = ae.decode(f)
#         logits_steer = model.unembed.output.save() # batch_size x seq_len x vocab_size

#     # Logit diffs for t-2
#     arange_batch = t.arange(batch_size, device=device)
#     logit_diff_clean = logits_clean[arange_batch, -3, to_int(V_square_idxs)] - logits_clean[arange_batch, -3, to_int(C_square_idxs)]
#     logit_diff_steer = logits_steer[arange_batch, -3, to_int(V_square_idxs)] - logits_steer[arange_batch, -3, to_int(C_square_idxs)]

#     rotated_logit_diff_clean = logits_clean[arange_batch, -3, to_int(V_rotated_square_idxs)] - logits_clean[arange_batch, -3, to_int(C_square_idxs)]
#     rotated_logit_diff_steer = logits_steer[arange_batch, -3, to_int(V_rotated_square_idxs)] - logits_steer[arange_batch, -3, to_int(C_square_idxs)]
    
#     steer_clean_diff = logit_diff_steer - logit_diff_clean
#     rotated_steer_clean_diff = rotated_logit_diff_steer - rotated_logit_diff_clean
#     return steer_clean_diff, rotated_steer_clean_diff

In [None]:
# def steering(model, game_batch, submodule, ae, feat_idx, square_idx, steering_factor, timestep=-2, device='cpu'):
#     batch_size = len(game_batch)
#     steer_clean_diffs = t.zeros(batch_size, device=device)
#     rotated_steer_clean_diffs = t.zeros(batch_size, device=device)
#     boards_clean = t.zeros(batch_size, 64, device=device)
#     boards_steer = t.zeros(batch_size, 64, device=device)

#     for i, game in tqdm(enumerate(game_batch), desc='Steering Batch', total=batch_size):
#         V_square_idx = square_idx
#         V_row = V_square_idx // 8
#         V_col = V_square_idx % 8
#         V_rotated_row = 7 - V_row
#         V_rotated_col = 7 - V_col
#         V_rotated_square_idx = V_rotated_row * 8 + V_rotated_col
#         C_square_idx = game[-2]

#         # Clean forward pass
#         with t.no_grad(), model.trace(game, **tracer_kwargs):
#             x = submodule.output
#             f = ae.encode(x).save() # shape: [batch_size, seq_len, n_features]
#             logits_clean = model.unembed.output.save() # batch_size x seq_len x vocab_size

#         steering_value = f[:, -1, feat_idx] # Activation value where valid_move is present
        
#         # Steering forward pass
#         with t.no_grad(), model.trace(game, **tracer_kwargs):
#             x = submodule.output
#             f = ae.encode(x).save() # shape: [batch_size, seq_len, n_features]
#             f[:, :, feat_idx] = steering_value * steering_factor
#             submodule.output = ae.decode(f)
#             logits_steer = model.unembed.output.save() # batch_size x seq_len x vocab_size

#         # Logit diffs for t-2
#         logit_diff_clean = logits_clean[:, timestep, to_int(V_square_idx)] - logits_clean[:, timestep, to_int(C_square_idx)]
#         logit_diff_steer = logits_steer[:, timestep, to_int(V_square_idx)] - logits_steer[:, timestep, to_int(C_square_idx)]

#         rotated_logit_diff_clean = logits_clean[:, timestep, to_int(V_rotated_square_idx)] - logits_clean[:, timestep, to_int(C_square_idx)]
#         rotated_logit_diff_steer = logits_steer[:, timestep, to_int(V_rotated_square_idx)] - logits_steer[:, timestep, to_int(C_square_idx)]
    
#         steer_clean_diffs[i] = logit_diff_steer - logit_diff_clean
#         rotated_steer_clean_diffs[i] = rotated_logit_diff_steer - rotated_logit_diff_clean
#         boards_clean[i][stoi_indices] = logits_clean[0, timestep, 1:]
#         boards_steer[i][stoi_indices] = logits_steer[0, timestep, 1:]

#     return steer_clean_diffs, rotated_steer_clean_diffs, boards_clean, boards_steer

In [None]:
# square_idx, config_idxs_per_feat

In [None]:
# game_batch_example = game_state_where_line_present(train_data_example, square_idx, config_idxs_per_feat)
# print('Number of games where line is present:', len(game_batch_example))
# for steering_factor in [0.1, 0.5, 1.0, 2.0, 5.0, 10.0]:
#     steer_clean_diff, rotated_steer_clean_diff, boards_clean, boards_steer = steering(
#         model, 
#         game_batch_example, 
#         submodule, 
#         ae, 
#         feat_idx, 
#         square_idx, 
#         steering_factor,
#         timestep=-1,
#         device=device
#         )
#     steer_clean_diff = steer_clean_diff.mean()
#     rotated_steer_clean_diff = rotated_steer_clean_diff.mean()
#     print(f"Steering factor: {steering_factor}, steer_clean_diff: {steer_clean_diff.item() - rotated_steer_clean_diff.item()}, rotated_steer_clean_diff: {rotated_steer_clean_diff.item()}")
#     # plt.imshow(boards_clean[3].view(8, 8).cpu().numpy() - boards_steer[0].view(8, 8).cpu().numpy())
#     # plt.show()

# Ablation

- Use a single batch of games (its size is set by `batch_size` below).
- Track the number of positions where the condition is present.

- Do a clean forward pass.
- Do an ablated forward pass (zero, mean, or max ablation).

- Compute the difference in IEs between the clean and ablated passes.
- Do activation patching for the HRC and HPC neurons first.
- Test whether activation patching is also feasible for...

In [None]:
def activation_patching(
    model,
    train_data: dict,
    submodule,
    ae,
    feat_idxs: t.Tensor,
    square_idx: int,
    config_idxs_per_feat,
    ablation_method="zero",
    device="cpu",
    filter_for_valid_moves: bool = False,
):
    """Ablate dictionary features and measure the effect on one square's output token.

    Runs ``model`` twice on ``train_data["encoded_inputs"]`` inside
    ``model.trace`` contexts (presumably nnsight):

    1. a clean pass;
    2. a patched pass in which ``submodule.output`` is replaced by ``ae.decode(f)``
       after features ``feat_idxs`` of ``f = ae.encode(submodule.output)`` are
       overwritten with:

       - ``"zero"``: 0;
       - ``"mean"``: each feature's mean over all batch and sequence positions of
         the clean pass;
       - ``"max"``: each feature's max over all batch and sequence positions of
         the clean pass.

    All differences are patched minus clean. Positions are split by a boolean
    mask:

    - ``check_valid_move(train_data, square_idx)`` when ``filter_for_valid_moves``
      is True;
    - otherwise ``check_any_present(train_data, square_idx, config_idxs_per_feat)``.

    Args:
        model: traced model exposing ``trace`` and ``unembed``.
        train_data: batch dict; reads ``"encoded_inputs"`` plus whatever key the
            chosen mask helper reads.
        submodule: model component whose ``output`` is encoded and patched.
        ae: dictionary providing ``encode``/``decode``.
        feat_idxs: feature index or indices to ablate (callers in this notebook
            also pass a single int).
        square_idx: board square index; its token ``to_int(square_idx)`` is the
            one measured.
        config_idxs_per_feat: line indices for ``check_any_present``; ignored
            when ``filter_for_valid_moves`` is True.
        ablation_method: one of ``"mean"``, ``"zero"``, ``"max"``.
        device: NOTE(review): accepted but not used anywhere in this function.
        filter_for_valid_moves: selects the mask helper (see above).

    Returns:
        A 9-tuple. Shapes assume the mask is (B, L). ``N`` is the number of
        masked-in positions and ``M`` the number of masked-out positions.

        1. logit change on the square token, masked in: ``(N,)``
        2. logit change on the square token, masked out: ``(M,)``
        3. the mask itself
        4. softmax-probability change on the square token, masked in: ``(N,)``
        5. softmax-probability change on the square token, masked out: ``(M,)``
        6. PATCHED logits (not differences), masked out: ``(M, V)``.
           Note that "absent" comes before "present" here.
        7. PATCHED logits, masked in: ``(N, V)``
        8. probability change on every other vocab token, masked in: ``(N, V-1)``
        9. probability change on every other vocab token, masked out: ``(M, V-1)``
    """
    allowed_methods = ["mean", "zero", "max"]
    # NOTE(review): assert is stripped under `python -O`; an invalid method would then fall through to zero ablation.
    assert ablation_method in allowed_methods, f"Invalid ablation method. Must be one of {allowed_methods}"
    # NOTE(review): built on the CPU; presumably the tracer moves inputs to the model's device -- confirm.
    game_batch = t.tensor(train_data["encoded_inputs"])

    # Get clean logits and mean submodule activations
    with t.no_grad(), model.trace(game_batch, **tracer_kwargs):
        x = submodule.output
        if ablation_method == "mean":
            f = ae.encode(x)  # shape: [batch_size, seq_len, n_features]
            # print(f.shape)
            # f_mean_clean = einops.reduce(f, 'b l f -> f', 'mean')#.save()
            f_mean_clean = f.mean(dim=(0, 1)).save()
            # print(f_mean_clean.shape)
        elif ablation_method == "max":
            f = ae.encode(x)  # shape: [batch_size, seq_len, n_features]
            # Max over batch, then over sequence: the per-feature max over all positions.
            f_max_clean = f.max(dim=0).values
            f_max_clean = f_max_clean.max(dim=0).values.save()
        logits_clean_BLV = model.unembed.output.save()  # batch_size x seq_len x vocab_size

    # Get patch logits
    with t.no_grad(), model.trace(game_batch, **tracer_kwargs):
        x = submodule.output
        f = ae.encode(x)
        # Overwrite the ablated features at every batch/sequence position.
        if ablation_method == "mean":
            f[:, :, feat_idxs] = f_mean_clean[feat_idxs]
        elif ablation_method == "max":
            f[:, :, feat_idxs] = f_max_clean[feat_idxs]
        else:
            f[:, :, feat_idxs] = 0
        submodule.output = ae.decode(f)
        logits_patch_BLV = model.unembed.output.save()

    B, L , V = logits_clean_BLV.shape

    # Logit change (patched - clean) on the token for square_idx.
    logit_diff_BLV = logits_patch_BLV - logits_clean_BLV
    logit_diff_BL = logit_diff_BLV[:, :, to_int(square_idx)]

    probs_clean_BLV = t.nn.functional.softmax(logits_clean_BLV, dim=-1)
    probs_patch_BLV = t.nn.functional.softmax(logits_patch_BLV, dim=-1)

    # Probability change (patched - clean) on the token for square_idx.
    probs_diff_BLV = probs_patch_BLV - probs_clean_BLV
    probs_diff_BL = probs_diff_BLV[:, :, to_int(square_idx)]

    # Probability change on every vocab entry except the square's token.
    probs_diff_others_BLV = probs_diff_BLV[:, :, t.arange(V) != to_int(square_idx)]


    if filter_for_valid_moves:
        # Filter for valid moves on square_idx
        custom_mask_BL = check_valid_move(train_data, square_idx)
    else:
        # Filter for a list of lines in config_idxs_per_feat
        custom_mask_BL = check_any_present(train_data, square_idx, config_idxs_per_feat)

    # Boolean-mask indexing flattens the selected (batch, seq) positions into one leading axis.
    probs_diff_others_present_BLV = probs_diff_others_BLV[custom_mask_BL]
    probs_diff_others_absent_BLV = probs_diff_others_BLV[~custom_mask_BL]

    logit_diff_line_present_BL = logit_diff_BL[custom_mask_BL]
    logit_diff_line_absent_BL = logit_diff_BL[~custom_mask_BL]

    probs_diff_line_present_BL = probs_diff_BL[custom_mask_BL]
    probs_diff_line_absent_BL = probs_diff_BL[~custom_mask_BL]

    # These are the patched logits themselves, not differences.
    logits_present_BLV = logits_patch_BLV[custom_mask_BL]
    logits_absent_BLV = logits_patch_BLV[~custom_mask_BL]

    return (
        logit_diff_line_present_BL,
        logit_diff_line_absent_BL,
        custom_mask_BL,
        probs_diff_line_present_BL,
        probs_diff_line_absent_BL,
        logits_absent_BLV,
        logits_present_BLV,
        probs_diff_others_present_BLV,
        probs_diff_others_absent_BLV,
    )

In [None]:
batch_size = 400
n_batches = 1


def extract_batch(data: dict, batch_idx: int, batch_size: int) -> dict:
    """Return batch ``batch_idx`` of every entry in ``data``.

    Each value is sliced to ``[batch_idx * batch_size, (batch_idx + 1) * batch_size)``.
    Tensor and numpy-array slices are copied, because plain slices of those are
    views and mutating the batch would otherwise write through to ``data``.
    Other sliceable values (list, tuple, str) already return new objects when
    sliced.
    """
    start_idx = batch_idx * batch_size
    end_idx = start_idx + batch_size
    batch_data = {}
    for key, value in data.items():
        chunk = value[start_idx:end_idx]
        if isinstance(chunk, t.Tensor):
            chunk = chunk.clone()
        elif isinstance(chunk, np.ndarray):
            chunk = chunk.copy()
        batch_data[key] = chunk
    return batch_data


def get_all_neurons_for_square(intersection_FSqC: list, square_of_interest: int) -> list[int]:
    """Return the sorted, de-duplicated neuron indices whose square is ``square_of_interest``.

    ``intersection_FSqC`` is a list of ``[feat_idx, square_idx, config_idxs]``
    entries. Indices are converted with ``int()``, so 0-d tensor entries are
    de-duplicated by value rather than by object identity.
    """
    return sorted(
        {
            int(feat_idx)
            for feat_idx, square_idx, _config_idxs in intersection_FSqC
            if int(square_idx) == square_of_interest
        }
    )


def get_square_for_neuron(intersection_FSqC: list, feat_idx: int) -> int:
    """Return the square of the first ``[feat_idx, square_idx, config_idxs]`` entry matching ``feat_idx``.

    Raises:
        ValueError: if no entry has neuron index ``feat_idx``.
    """
    squares = (
        square
        for neuron, square, _config_idxs in intersection_FSqC
        if neuron == feat_idx
    )
    try:
        return next(squares)
    except StopIteration:
        raise ValueError(f"Neuron {feat_idx} not found in intersection_FSqC") from None


def run_ablations(
    ablation_method: str, feat_idxs: t.Tensor, square_idx: int, config_idxs_per_feat: Optional[t.Tensor] = None
):
    """Ablate ``feat_idxs`` on ``n_batches`` batches of ``ablation_data`` and print summary statistics.

    Reads the module-level ``model``, ``submodule``, ``ae``, ``device``,
    ``ablation_data``, ``batch_size`` and ``n_batches``.

    Args:
        ablation_method: ``"zero"``, ``"mean"`` or ``"max"`` (validated by
            ``activation_patching``).
        feat_idxs: feature index or indices to ablate.
        square_idx: board square whose output token is measured.
        config_idxs_per_feat: line indices defining the condition. When None, the
            condition is instead "``square_idx`` is a valid move".

    Prints, averaged over batches:

    - the mean logit change and mean probability change on the square token
      where the condition is present and where it is absent;
    - the fraction of positions where the condition is present.

    If the condition never holds in any batch, "No valid moves" is printed in
    place of the "present" means; the remaining statistics are still printed.

    Returns:
        The 9-tuple from ``activation_patching`` for the last batch processed.
        These are per-position tensors, not the printed means.

    Raises:
        ValueError: if ``n_batches`` is less than 1.
    """
    if n_batches < 1:
        raise ValueError(f"n_batches must be at least 1, got {n_batches}")

    logit_diffs_line_present = []
    logit_diffs_line_absent = []
    line_masks = []
    probs_diffs_line_present = []
    probs_diffs_line_absent = []

    # Without a line configuration, fall back to the "square is a valid move" condition.
    filter_for_valid_moves = config_idxs_per_feat is None

    for batch_idx in range(n_batches):
        train_data = extract_batch(ablation_data, batch_idx, batch_size)

        batch_results = activation_patching(
            model,
            train_data,
            submodule,
            ae,
            feat_idxs,
            square_idx,
            config_idxs_per_feat,
            ablation_method=ablation_method,
            device=device,
            filter_for_valid_moves=filter_for_valid_moves,
        )
        (
            logit_diff_line_present,
            logit_diff_line_absent,
            line_mask,
            probs_diff_line_present,
            probs_diff_line_absent,
        ) = batch_results[:5]

        # A batch where the condition never holds has no "present" positions to average.
        if logit_diff_line_present.numel() > 0:
            logit_diffs_line_present.append(logit_diff_line_present.mean())
            probs_diffs_line_present.append(probs_diff_line_present.mean())
        probs_diffs_line_absent.append(probs_diff_line_absent.mean())
        logit_diffs_line_absent.append(logit_diff_line_absent.mean())
        line_masks.append(line_mask.float().mean())

    # t.stack raises on an empty list, so only report "present" means when some
    # batch contained the condition (instead of hiding every error behind a bare except).
    condition_seen = len(logit_diffs_line_present) > 0
    if condition_seen:
        print(f"mean logit diff for condition present: {t.stack(logit_diffs_line_present).mean()}")
    else:
        print("No valid moves")
    print(f"mean logit diff for condition absent: {t.stack(logit_diffs_line_absent).mean()}")
    if condition_seen:
        print(
            f"mean probability diff for condition present: {t.stack(probs_diffs_line_present).mean()}"
        )
    print(
        f"mean probability diff for condition absent: {t.stack(probs_diffs_line_absent).mean()}"
    )
    print(f"fraction of condition present: {t.stack(line_masks).mean()}")

    return batch_results

pickle_dict = {}
# Ablation method used for the sweep: 'zero', 'mean' or 'max'.
ablation_method = 'max'

# Neuron hypotheses: str(layer) -> list of [feat_idx, square_idx, length, direction].
# The file is the same for every layer, so read it once.
filename = 'union_neurons-T0.95.json'
with open(filename, 'r') as f:
    union_neurons = json.load(f)

for layer in range(0, 8):
# for layer in [5]:
    pickle_dict[layer] = {}

    # Rebinds the module-level ae/submodule that run_ablations reads.
    ae, submodule = get_ae(layer)

    print(union_neurons.keys())
    print(union_neurons[str(layer)])

    # Group hypotheses by neuron. Each neuron keeps the square of its first entry
    # plus every line index (length * 8 + direction) hypothesised for it.
    neurons = {}
    for feat_idx, square_idx, length, direction in union_neurons[str(layer)]:
        neuron = neurons.setdefault(int(feat_idx), {'square': int(square_idx), 'lines': []})
        neuron['lines'].append(length * 8 + direction)

    print(neurons.keys())
    print(len(neurons))

    all_probs_diff_line_present = []
    all_probs_diff_line_absent = []
    all_probs_diff_others_present = []
    all_probs_diff_others_absent = []

    for i, feat_idx in enumerate(neurons):
        square_idx = neurons[feat_idx]['square']
        config_idxs_per_feat = neurons[feat_idx]['lines']
        print(feat_idx, square_idx, config_idxs_per_feat)

        (
            logit_diff_line_present,
            logit_diff_line_absent,
            line_mask,
            probs_diff_line_present,
            probs_diff_line_absent,
            logits_absent_BLV,
            logits_present_BLV,
            probs_diff_others_present,
            probs_diff_others_absent,
        ) = run_ablations(ablation_method, feat_idx, square_idx, config_idxs_per_feat)

        # Collect both splits for every neuron. The max-ablation analysis reads
        # the "absent" keys, so they must cover all neurons, not just the last one.
        # Keep results on the CPU so that accumulating a whole layer does not hold
        # everything in GPU memory.
        all_probs_diff_line_present.append(probs_diff_line_present.cpu())
        all_probs_diff_line_absent.append(probs_diff_line_absent.cpu())
        all_probs_diff_others_present.append(probs_diff_others_present.cpu())
        all_probs_diff_others_absent.append(probs_diff_others_absent.cpu())

        # if i >= 10:
        #     break

    if len(all_probs_diff_line_present) == 0:
        continue

    pickle_dict[layer] = {
        'probs_diff_line_present': t.cat(all_probs_diff_line_present, dim=0),
        'probs_diff_line_absent': t.cat(all_probs_diff_line_absent, dim=0),
        'probs_diff_others_present': t.cat(all_probs_diff_others_present, dim=0),
        'probs_diff_others_absent': t.cat(all_probs_diff_others_absent, dim=0),
    }

with open(f'all_layers_probs_method_{ablation_method}_diff.pkl', 'wb') as f:
    pickle.dump(pickle_dict, f)

# Analyze every square with all neurons per square
# for square_idx in range(64):
#     square_valid_move_neurons = get_all_neurons_for_square(
#         hpc_hrc_same_square_indexes_dict["intersection_FSqC"], square_idx
#     )

#     print(square_valid_move_neurons)

#     print(f"\nFor square {square_idx}, there are {len(square_valid_move_neurons)} neurons")
#     if len(square_valid_move_neurons) > 0:
#         logits_absent_BLV, logits_present_BLV = run_ablations(t.tensor(square_valid_move_neurons), square_idx)

#     if square_idx >= 0:
#         break

In [None]:
import pickle
import matplotlib.pyplot as plt
import torch as t

# Which sweep's saved results to analyse ('max' is the other option used here).
ablation_method = 'mean'

with open(f'all_layers_probs_method_{ablation_method}_diff.pkl', 'rb') as f:
    pickle_dict = pickle.load(f)

# Max ablation is analysed on positions where the condition is absent; the other
# methods on positions where it is present.
if ablation_method == "max":
    key1, key2 = 'probs_diff_line_absent', 'probs_diff_others_absent'
else:
    key1, key2 = 'probs_diff_line_present', 'probs_diff_others_present'

def get_layers_results(
    layers: list[int],
    results: Optional[dict] = None,
    line_key: Optional[str] = None,
    others_key: Optional[str] = None,
) -> tuple[t.Tensor, t.Tensor]:
    """Concatenate per-layer probability-difference tensors across ``layers``.

    Args:
        layers: layer indices to include. Layers missing from ``results``, or whose
            entry is empty, are skipped.
        results: mapping ``layer -> {key: tensor}``. Defaults to the module-level
            ``pickle_dict``.
        line_key: key of the square-token tensor. Defaults to the module-level ``key1``.
        others_key: key of the other-tokens tensor. Defaults to the module-level ``key2``.

    Returns:
        ``(line_tensor, others_tensor)``, each concatenated along dim 0.

    Raises:
        ValueError: if none of ``layers`` has results. Without this check,
            ``t.cat`` would fail on an empty list.
    """
    if results is None:
        results = pickle_dict
    if line_key is None:
        line_key = key1
    if others_key is None:
        others_key = key2

    line_parts = []
    other_parts = []
    for layer in layers:
        layer_results = results.get(layer)
        if not layer_results:
            continue
        line_parts.append(layer_results[line_key])
        other_parts.append(layer_results[others_key])

    if not line_parts:
        raise ValueError(f"No results found for any of the layers {layers}")

    return t.cat(line_parts, dim=0), t.cat(other_parts, dim=0)

layers = [0, 1, 2, 3, 4, 5, 6]
# layers = [6, 7]
# layers = [7]
# layers = [5]

probs_diff_line_present, probs_diff_others_present = get_layers_results(layers)

# Set to True to normalise the histograms to densities instead of raw counts.
density = False

print(probs_diff_line_present.shape)
print(probs_diff_others_present.shape)

print(probs_diff_line_present.mean())
print(probs_diff_others_present.mean())

plt.xlabel('Token Probability Difference')
plt.ylabel('Count')

# join() yields a well-formed title even when `layers` is empty.
layer_list = ", ".join(str(layer_idx) for layer_idx in layers)
plot_title = (
    f"Histogram of Probability Differences for Layer(s) {layer_list}"
    f"\n\nAverage probability difference for neuron square token: {probs_diff_line_present.mean().item():.4f}"
    f"\nAverage probability difference for all other tokens: {probs_diff_others_present.mean().item():.4f}"
)
plt.title(plot_title)

# Exclude (near-)zero differences from the histograms, presumably so that
# positions the ablation left unchanged do not swamp the bins.
x = plt.hist(
    probs_diff_line_present[probs_diff_line_present.abs() > 1e-8].cpu().numpy().flatten(),
    bins=100,
    alpha=0.5,
    label="Neuron Square Token",
    density=density,
)
plt.hist(
    probs_diff_others_present[probs_diff_others_present.abs() > 1e-8].cpu().numpy().flatten(),
    bins=100,
    alpha=0.5,
    label="All Other Tokens",
    density=density,
)
print(x)
# plt.xlim(-0.3, 0.3)
# plt.ylim(0, 40000)
plt.yscale('log')
plt.legend()
plt.show()

# plt.hist(logits_present_BLV.cpu().numpy().flatten(), bins=100, alpha=0.5, label="present")
# plt.hist(logits_absent_BLV.cpu().numpy().flatten(), bins=100, alpha=0.5, label="absent")