# Setup

In [None]:
import torch as t
from huggingface_hub import hf_hub_download
import matplotlib.pyplot as plt
import importlib
import numpy as np
import einops
from tqdm import tqdm
from typing import Callable
import os
import pickle

from circuits.dictionary_learning.dictionary import AutoEncoder, AutoEncoderNew, GatedAutoEncoder, IdentityDict
from circuits.utils import (
    othello_hf_dataset_to_generator,
    get_model,
    get_submodule,
    get_mlp_activations_submodule,
)

from feature_viz_othello_utils import (
    get_acts_IEs_VN,
    plot_lenses,
    plot_mean_metrics,
    plot_top_k_games,
    BoardPlayer,
)

import circuits.utils as utils
import circuits.analysis as analysis
import feature_viz_othello_utils as viz_utils
from circuits.othello_utils import games_batch_to_state_stack_length_lines_mine_BLRCC
from circuits.othello_engine_utils import to_board_label, to_string, to_int, stoi_indices #to_string: mode_output_vocab to interpretable square index
import circuits.eval_sae_as_classifier as eval_sae

device = 'cuda:0'
# Keyword arguments passed to every model.trace(...) call in this notebook.
# NOTE(review): these presumably turn off nnsight's validation/scanning pass — confirm in the nnsight docs.
tracer_kwargs = {'validate' : False, 'scan' : False}
# Root of the chess-gpt-circuits checkout (another machine used '/share/u/can/chess-gpt-circuits').
# Set the CIRCUITS_REPO_DIR environment variable to override it instead of editing this cell.
repo_dir = os.environ.get('CIRCUITS_REPO_DIR', '/home/adam/chess-gpt-circuits')

### Load model, submodule, ae

In [None]:
# Hugging Face repo id of the Othello-GPT checkpoint; get_model (project helper from
# circuits.utils) presumably loads it onto `device` — confirm there.
model_name = "Baidicoot/Othello-GPT-Transformer-Lens"
model = get_model(model_name, device)

In [None]:
layer = 5
# node_type = "sae_feature"
node_type = "mlp_neuron"


if node_type == "sae_feature":
    ae_group_name = 'all_layers_othello_p_anneal_0530_with_lines'
    ae_type = 'p_anneal'
    trainer_id = 0
    ae_path = f'{repo_dir}/autoencoders/{ae_group_name}/layer_{layer}/trainer{trainer_id}'
    submodule = get_submodule(model_name, layer, model)
elif node_type == "mlp_neuron":
    # NOTE(review): the 'identity' dictionary presumably passes the MLP activations through
    # unchanged, so each "feature" is one MLP neuron — confirm in IdentityDict.
    ae_group_name = 'othello_mlp_acts_identity_aes_lines' # with_lines
    ae_type = 'identity'
    ae_path = f'{repo_dir}/autoencoders/{ae_group_name}/layer_{layer}'
    submodule = get_mlp_activations_submodule(model_name, layer, model)
else:
    raise ValueError('Invalid node_type')

# Download and extract the dictionaries from Hugging Face if they are not present locally.
autoencoders_dir = f'{repo_dir}/autoencoders'
if not os.path.exists(f'{autoencoders_dir}/{ae_group_name}'):
    zip_path = hf_hub_download(
        repo_id='adamkarvonen/othello_saes',
        filename=f'{ae_group_name}.zip',
        local_dir=autoencoders_dir,
    )
    # Extract with the standard library instead of shelling out to `unzip`: a missing
    # binary or a corrupt archive now raises here rather than surfacing later as a
    # confusing file-not-found error (os.system's exit status was ignored).
    import zipfile
    with zipfile.ZipFile(zip_path) as archive:
        archive.extractall(autoencoders_dir)

# Initialize the dictionary used to encode/decode the submodule's activations.
pretrained_ae_classes = {
    'standard': AutoEncoder,
    'p_anneal': AutoEncoder,
    'gated': GatedAutoEncoder,
    'gated_anneal': GatedAutoEncoder,
    'standard_new': AutoEncoderNew,
}
if ae_type == 'identity':
    ae = IdentityDict()
elif ae_type in pretrained_ae_classes:
    # Load onto the notebook-wide `device` rather than a separately hard-coded 'cuda:0'.
    ae = pretrained_ae_classes[ae_type].from_pretrained(os.path.join(ae_path, 'ae.pt'), device=device)
else:
    raise ValueError('Invalid ae_type')

### Load legal move neurons

In [None]:
import json

# Hypotheses stored next to the dictionary. The 'intersection_FSqC' entry holds
# [feature_idx, valid_square_idx, config_idx] triples (unpacked in the next cell).
hypotheses_file = os.path.join(ae_path, 'hpc_hrc_same_square_indexes_dict.json')
with open(hypotheses_file, 'r') as f:
    hpc_hrc_same_square_indexes_dict = json.load(f)

print(hpc_hrc_same_square_indexes_dict.keys())

In [None]:
# Each hypothesis is a [feature_idx, valid_square_idx, line_config_idx] triple.
feat_idxs, valid_square_idxs, line_idxs = t.tensor(hpc_hrc_same_square_indexes_dict['intersection_FSqC']).T
unique_feat_idxs = t.unique(feat_idxs)
unique_valid_square_idxs = t.unique(valid_square_idxs)
n_unique_feat_idxs = len(unique_feat_idxs)

# Pick the first feature whose index exceeds `min_feat_idx` (t.unique returns sorted
# values); fall back to the largest feature index if none does.
min_feat_idx = 500
candidate_feat_idxs = unique_feat_idxs[unique_feat_idxs > min_feat_idx]
feat_idx = int(candidate_feat_idxs[0] if len(candidate_feat_idxs) > 0 else unique_feat_idxs[-1])

# Take the square from this feature's own hypotheses. The sorted unique features and
# sorted unique squares are unrelated lists, so they must not be paired by position.
# If the feature has several valid squares, the first one listed is used.
feat_mask = feat_idxs == feat_idx
square_idx = int(valid_square_idxs[feat_mask][0])
# Only the line configurations hypothesised for this (feature, square) pair.
config_idxs_per_feat = line_idxs[feat_mask & (valid_square_idxs == square_idx)]

print(feat_idx, square_idx, config_idxs_per_feat)

### Dataset

In [None]:

dataset_size = 500

# The cache file name includes dataset_size, so changing the size builds a new dataset
# instead of silently reusing a cached one of a different size.
ablation_dataset_name = f"othello_ablation_dataset_{dataset_size}.pkl"

if os.path.exists(ablation_dataset_name):
    print("Loading ablation dataset")
    # NOTE: pickle.load can execute arbitrary code; only load cache files this notebook wrote.
    with open(ablation_dataset_name, "rb") as f:
        ablation_data = pickle.load(f)
else:
    ablation_data = eval_sae.construct_othello_dataset(
        custom_functions=[games_batch_to_state_stack_length_lines_mine_BLRCC],
        n_inputs=dataset_size,
        split='train',
        precompute_dataset = True,
        device=device,
    )
    print("Saving ablation dataset")
    with open(ablation_dataset_name, "wb") as f:
        pickle.dump(ablation_data, f)


In [None]:
def check_any_present(train_data, valid_move_idx, line_idxs):
    '''For every position in every game, evaluate whether any of the lines is present.

    Args:
        train_data: dict whose 'games_batch_to_state_stack_length_lines_mine_BLRCC' entry
            is a 5-D tensor indexed [game, position, row, col, line].
        valid_move_idx: board square index; row = idx // 8, col = idx % 8.
        line_idxs: indices into the last (line) dimension, as a tensor or a sequence of
            ints. They are moved to the data's device, because indexing with an index
            tensor on another device (e.g. CPU indices, CUDA data) fails for gather.

    Returns:
        Boolean tensor of shape (n_games, n_positions): True where any selected line is set.
    '''
    r = valid_move_idx // 8
    c = valid_move_idx % 8
    data = train_data['games_batch_to_state_stack_length_lines_mine_BLRCC'][:, :, r, c]
    line_idxs = t.as_tensor(line_idxs, dtype=t.long, device=data.device).reshape(-1)
    # data[..., idxs][g, p, k] == data[g, p, idxs[k]], i.e. the same selection as a gather.
    selected = data[..., line_idxs]
    return t.any(selected, dim=-1)

def game_state_where_line_present(train_data, valid_move_idx, line_idxs):
    '''Return game prefixes that end at the first position where any of the lines is present.

    Args:
        train_data: dict with 'encoded_inputs' (the games) and the line-presence tensor
            read by check_any_present.
        valid_move_idx: board square index passed to check_any_present.
        line_idxs: line indices passed to check_any_present.

    Returns:
        List of 1-D tensors on the notebook-wide `device`. Each holds one game's inputs up
        to and including the first position where a line is present. Games in which no
        line is ever present are skipped, so the list may be shorter than the batch.
    '''
    any_line_present = check_any_present(train_data, valid_move_idx, line_idxs)
    # as_tensor avoids a redundant copy (and a copy-construct warning) when
    # encoded_inputs is already a tensor.
    enc_inputs = t.as_tensor(train_data['encoded_inputs'], device=device)
    prefixes = []  # distinct name: the original local shadowed this function
    for game, line_present in zip(enc_inputs, any_line_present):
        hits = t.nonzero(line_present)
        if len(hits) > 0:
            first_occurrence = int(hits[0])
            prefixes.append(game[:first_occurrence + 1])
    return prefixes

# def game_state_my_move_before_line_present(train_data, valid_move_idx, line_idxs):
#     game_state_where_line_present = game_state_where_line_present(train_data, valid_move_idx, line_idxs)
#     return [game[:-2] for game in game_state_where_line_present]

In [None]:
# # Test
game_states_line_present = game_state_where_line_present(ablation_data, square_idx, config_idxs_per_feat)
game = game_states_line_present[1]
game = t.tensor(game, device=device)
player = BoardPlayer(game)

square_idx, config_idxs_per_feat

In [None]:
# Advance the BoardPlayer created in the previous cell.
# NOTE(review): BoardPlayer comes from feature_viz_othello_utils; presumably next() shows the next move's board — confirm there.
player.next()

# Steering

In [None]:
# Get 1 game, where any line is present 

In [None]:
# def steering(model, game_batch, submodule, ae, feat_idx, square_idx, steering_factor, device='cpu'):
    

#     for game, V_square_idx, V_rotated_square_idx, C_square_idx in zip(game_batch):
#         V_square_idx = square_idx
#     V_row = V_square_idx // 8
#     V_col = V_square_idx % 8
#     V_rotated_row = 7 - V_row
#     V_rotated_col = 7 - V_col
#     V_rotated_square_idx = V_rotated_row * 8 + V_rotated_col
#     C_square_idxs = game[-2]
#     # Clean forward pass
#     with t.no_grad(), model.trace(game_batch, **tracer_kwargs):
#         x = submodule.output
#         f = ae.encode(x).save() # shape: [batch_size, seq_len, n_features]
#         logits_clean = model.unembed.output.save() # batch_size x seq_len x vocab_size

#     steering_value = f[:, -1, feat_idx] # Activation value where valid_move is present
    
#     # Steering forward pass
#     with t.no_grad(), model.trace(game_batch, **tracer_kwargs):
#         x = submodule.output
#         f = ae.encode(x).save() # shape: [batch_size, seq_len, n_features]
#         f[:, :, feat_idx] = steering_value * steering_factor
#         submodule.output = ae.decode(f)
#         logits_steer = model.unembed.output.save() # batch_size x seq_len x vocab_size

#     # Logit diffs for t-2
#     arange_batch = t.arange(batch_size, device=device)
#     logit_diff_clean = logits_clean[arange_batch, -3, to_int(V_square_idxs)] - logits_clean[arange_batch, -3, to_int(C_square_idxs)]
#     logit_diff_steer = logits_steer[arange_batch, -3, to_int(V_square_idxs)] - logits_steer[arange_batch, -3, to_int(C_square_idxs)]

#     rotated_logit_diff_clean = logits_clean[arange_batch, -3, to_int(V_rotated_square_idxs)] - logits_clean[arange_batch, -3, to_int(C_square_idxs)]
#     rotated_logit_diff_steer = logits_steer[arange_batch, -3, to_int(V_rotated_square_idxs)] - logits_steer[arange_batch, -3, to_int(C_square_idxs)]
    
#     steer_clean_diff = logit_diff_steer - logit_diff_clean
#     rotated_steer_clean_diff = rotated_logit_diff_steer - rotated_logit_diff_clean
#     return steer_clean_diff, rotated_steer_clean_diff

In [None]:
def steering(model, game_batch, submodule, ae, feat_idx, square_idx, steering_factor, timestep=-2, device='cpu'):
    '''Steer one dictionary feature game by game and measure the effect on move logits.

    Two traced forward passes are run for each game in ``game_batch``:

    1. A clean pass records ``f = ae.encode(submodule.output)`` and the unembed logits.
    2. In a steered pass, feature ``feat_idx`` is overwritten at every position with
       ``steering_value * steering_factor`` and the submodule output is replaced by
       ``ae.decode(f)``. Here ``steering_value`` is the clean activation of the feature
       at the last position.

    Args:
        model, submodule, ae: traced model, hooked submodule, and the dictionary used to
            encode/decode the submodule output.
        game_batch: sequence of games (presumably encoded move sequences); each one is
            traced on its own.
        feat_idx: index of the feature to steer.
        square_idx: board square V, row-major (row = idx // 8, col = idx % 8).
        steering_factor: multiplier applied to the clean last-position activation.
        timestep: sequence position at which logits are compared.
        device: device of the returned tensors.

    Returns:
        steer_clean_diffs: (batch_size,) change (steered minus clean) in
            logit(V) - logit(C), where C is derived from ``game[-2]``.
        rotated_steer_clean_diffs: (batch_size,) the same with V replaced by its point
            reflection, (row, col) -> (7 - row, 7 - col).
        boards_clean, boards_steer: (batch_size, 64) logits at ``timestep`` of the first
            batch row; vocab entries 1: are written to board positions ``stoi_indices``.
    '''
    batch_size = len(game_batch)
    steer_clean_diffs = t.zeros(batch_size, device=device)
    rotated_steer_clean_diffs = t.zeros(batch_size, device=device)
    boards_clean = t.zeros(batch_size, 64, device=device)
    boards_steer = t.zeros(batch_size, 64, device=device)

    # V and its point-reflected square do not depend on the game, so compute their
    # logit indices once instead of on every iteration.
    V_square_idx = square_idx
    V_row = V_square_idx // 8
    V_col = V_square_idx % 8
    V_rotated_square_idx = (7 - V_row) * 8 + (7 - V_col)
    V_logit_idx = to_int(V_square_idx)
    V_rotated_logit_idx = to_int(V_rotated_square_idx)

    for i, game in tqdm(enumerate(game_batch), desc='Steering Batch', total=batch_size):
        # NOTE(review): `game` holds encoded inputs, so game[-2] may already be a
        # vocabulary index; confirm that passing it through to_int is intended.
        C_square_idx = game[-2]
        C_logit_idx = to_int(C_square_idx)

        # Clean forward pass
        with t.no_grad(), model.trace(game, **tracer_kwargs):
            x = submodule.output
            f = ae.encode(x).save() # shape: [batch_size, seq_len, n_features]
            logits_clean = model.unembed.output.save() # batch_size x seq_len x vocab_size

        steering_value = f[:, -1, feat_idx] # Activation value where valid_move is present

        # Steering forward pass
        with t.no_grad(), model.trace(game, **tracer_kwargs):
            x = submodule.output
            f = ae.encode(x).save() # shape: [batch_size, seq_len, n_features]
            f[:, :, feat_idx] = steering_value * steering_factor
            submodule.output = ae.decode(f)
            logits_steer = model.unembed.output.save() # batch_size x seq_len x vocab_size

        # Logit differences logit(V) - logit(C) at `timestep`
        logit_diff_clean = logits_clean[:, timestep, V_logit_idx] - logits_clean[:, timestep, C_logit_idx]
        logit_diff_steer = logits_steer[:, timestep, V_logit_idx] - logits_steer[:, timestep, C_logit_idx]

        rotated_logit_diff_clean = logits_clean[:, timestep, V_rotated_logit_idx] - logits_clean[:, timestep, C_logit_idx]
        rotated_logit_diff_steer = logits_steer[:, timestep, V_rotated_logit_idx] - logits_steer[:, timestep, C_logit_idx]

        steer_clean_diffs[i] = logit_diff_steer - logit_diff_clean
        rotated_steer_clean_diffs[i] = rotated_logit_diff_steer - rotated_logit_diff_clean
        # Token 0 is skipped; the remaining vocab logits are placed on their board squares.
        boards_clean[i][stoi_indices] = logits_clean[0, timestep, 1:]
        boards_steer[i][stoi_indices] = logits_steer[0, timestep, 1:]

    return steer_clean_diffs, rotated_steer_clean_diffs, boards_clean, boards_steer

In [None]:
# Display the selected square and its line-configuration indices (cell output).
square_idx, config_idxs_per_feat

In [None]:
# Game prefixes from the ablation dataset that end where one of the feature's lines first
# appears. (`train_data_example` was never defined in this notebook.)
game_batch_example = game_state_where_line_present(ablation_data, square_idx, config_idxs_per_feat)
print('Number of games where line is present:', len(game_batch_example))
if len(game_batch_example) == 0:
    # An empty batch would only produce nan means below.
    print('No game has any of the lines present; skipping the steering sweep.')
else:
    for steering_factor in [0.1, 0.5, 1.0, 2.0, 5.0, 10.0]:
        steer_clean_diff, rotated_steer_clean_diff, boards_clean, boards_steer = steering(
            model,
            game_batch_example,
            submodule,
            ae,
            feat_idx,
            square_idx,
            steering_factor,
            timestep=-1,
            device=device
            )
        steer_clean_diff = steer_clean_diff.mean()
        rotated_steer_clean_diff = rotated_steer_clean_diff.mean()
        # Report both effects and their difference; the rotated square is a control.
        print(
            f"Steering factor: {steering_factor}, "
            f"steer_clean_diff: {steer_clean_diff.item()}, "
            f"rotated_steer_clean_diff: {rotated_steer_clean_diff.item()}, "
            f"steer - rotated: {steer_clean_diff.item() - rotated_steer_clean_diff.item()}"
        )
        # plt.imshow(boards_clean[0].view(8, 8).cpu().numpy() - boards_steer[0].view(8, 8).cpu().numpy())
        # plt.show()

# Ablation

Plan:
- Use a single batch of 100 games.
- Track how many times the condition (any of the feature's lines present) holds.
- Do a clean forward pass.
- Do a mean-ablation / zero-ablation forward pass.
- Compute the difference in IEs.
- Do activation patching for the HRC and HPC neurons first.
- Test whether activation patching is also feasible for...

In [None]:
def activation_patching(model, train_data, submodule, ae, feat_idx, square_idx, config_idxs_per_feat, ablation_method='zero', device='cpu'):
    '''Ablate one dictionary feature and measure the change in the logit for ``square_idx``.

    A clean forward pass over all games in ``train_data['encoded_inputs']`` is followed by
    a patched pass. In the patched pass, feature ``feat_idx`` is set at every position to
    zero (``'zero'``) or to its mean over all games and positions of the clean pass
    (``'mean'``), and the submodule output is replaced by ``ae.decode(f)``.

    Args:
        model, submodule, ae: traced model, hooked submodule, and dictionary.
        train_data: dict with 'encoded_inputs' and the line-presence tensor read by
            check_any_present.
        feat_idx: feature to ablate.
        square_idx: board square whose move logit (via to_int) is compared.
        config_idxs_per_feat: line indices passed to check_any_present.
        ablation_method: 'zero' or 'mean'.
        device: currently unused; the inputs are passed to model.trace as built.

    Returns:
        logit_diff_line_present: 1-D tensor of (patched - clean) logits for ``square_idx``
            at positions where any of the lines is present.
        logit_diff_line_absent: the same at the remaining positions.
        line_mask: the (n_games, n_positions) boolean mask from check_any_present.
    '''
    assert ablation_method in ['mean', 'zero'], "Invalid ablation method. Must be one of ['mean', 'zero']"
    # as_tensor avoids a redundant copy when encoded_inputs is already a tensor.
    game_batch = t.as_tensor(train_data['encoded_inputs'])

    # Get clean logits and mean submodule activations
    with t.no_grad(), model.trace(game_batch, **tracer_kwargs):
        x = submodule.output
        if ablation_method == 'mean':
            f = ae.encode(x) # shape: [batch_size, seq_len, n_features]
            # Save the mean so it is still available once this trace exits, like every
            # other value read outside a trace in this notebook. The patched pass below
            # needs it.
            f_mean_clean = f.mean(dim=(0, 1)).save()
        logits_clean = model.unembed.output.save() # batch_size x seq_len x vocab_size

    # Get patch logits
    with t.no_grad(), model.trace(game_batch, **tracer_kwargs):
        x = submodule.output
        f = ae.encode(x)
        if ablation_method == 'mean':
            f[:, :, feat_idx] = f_mean_clean[feat_idx]
        else:
            f[:, :, feat_idx] = 0
        submodule.output = ae.decode(f)
        logits_patch = model.unembed.output.save()

    logit_diff = logits_patch - logits_clean
    logit_diff = logit_diff[:, :, to_int(square_idx)]

    line_mask = check_any_present(train_data, square_idx, config_idxs_per_feat)
    # Index with a mask on the same device as the logits.
    present = line_mask.to(logit_diff.device)
    logit_diff_line_present = logit_diff[present]
    logit_diff_line_absent = logit_diff[~present]
    return logit_diff_line_present, logit_diff_line_absent, line_mask

In [None]:
# Display the feature, square and line-configuration indices used for ablation (cell output).
feat_idx, square_idx, config_idxs_per_feat

In [None]:
# Inspect the dataset: its keys, the encoded games, and the shape of the
# line-presence tensor read by check_any_present.
lines_key = 'games_batch_to_state_stack_length_lines_mine_BLRCC'
print(ablation_data.keys())
print(ablation_data['encoded_inputs'])
print(ablation_data[lines_key].shape)

In [None]:
# Games per activation-patching batch, and how many consecutive batches to run.
batch_size, n_batches = 10, 40


def extract_batch(data, batch_idx, batch_size):
    '''Return batch ``batch_idx`` of every entry in ``data`` as a new dict.

    Each value is sliced to rows [batch_idx * batch_size, (batch_idx + 1) * batch_size).
    Tensors are cloned, and other values with a ``copy()`` method (e.g. numpy arrays,
    whose slices are views) are copied, so the batch shares no memory with ``data``.
    Lists, tuples and other sliceable values are returned as their slice.
    '''
    start_idx = batch_idx * batch_size
    end_idx = start_idx + batch_size
    batch_data = {}
    for key, value in data.items():
        chunk = value[start_idx:end_idx]
        if isinstance(chunk, t.Tensor):
            chunk = chunk.clone()
        elif not isinstance(chunk, list) and hasattr(chunk, 'copy'):
            # Slices of numpy arrays are views and have no .clone().
            chunk = chunk.copy()
        batch_data[key] = chunk
    return batch_data

# Run activation patching over consecutive batches and pool the per-position results,
# so every position counts equally. A mean of per-batch means over-weights batches with
# few line-present positions, and batches with no such positions would contribute nan.
logit_diffs_line_present = []
logit_diffs_line_absent = []
line_masks = []

# Don't request batches past the end of the dataset (they would be empty).
n_games = len(ablation_data['encoded_inputs'])
n_batches_run = min(n_batches, (n_games + batch_size - 1) // batch_size)

for batch_idx in range(n_batches_run):
    train_data = extract_batch(ablation_data, batch_idx, batch_size)

    logit_diff_line_present, logit_diff_line_absent, line_mask = activation_patching(
        model,
        train_data,
        submodule,
        ae,
        feat_idx,
        square_idx,
        config_idxs_per_feat,
        ablation_method='zero',
        device=device,
    )

    logit_diffs_line_present.append(logit_diff_line_present.flatten())
    logit_diffs_line_absent.append(logit_diff_line_absent.flatten())
    line_masks.append(line_mask.flatten().float())

if not line_masks:
    raise ValueError('No batches were run; check batch_size, n_batches and the dataset size.')

# The mean of an empty tensor is nan, which signals that no such positions were seen.
print(f'mean logit diff for line present: {t.cat(logit_diffs_line_present).mean()}')
print(f'mean logit diff for line absent: {t.cat(logit_diffs_line_absent).mean()}')
print(f'fraction of line present: {t.cat(line_masks).mean()}')