In [None]:
# Imports
import torch as t
import matplotlib.pyplot as plt
import numpy as np
import einops
from collections import defaultdict
from sklearn.model_selection import train_test_split
from sklearn.linear_model import Lasso
from sklearn.tree import DecisionTreeRegressor
from sklearn.multioutput import MultiOutputRegressor
from sklearn.metrics import mean_squared_error, r2_score
from sklearn.tree import export_text
from sklearn.tree import DecisionTreeClassifier
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score
from sklearn.multioutput import MultiOutputClassifier
import sklearn

from circuits.utils import othello_hf_dataset_to_generator, get_model, get_aes
import circuits.othello_utils as othello_utils
from circuits.eval_sae_as_classifier import construct_othello_dataset

# Setup
# Prefer the first CUDA device when one is visible; otherwise stay on the CPU.
if t.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"

# Turn autograd off globally for the rest of the session.
t.set_grad_enabled(False)

# Model and data loading
def load_model_and_data(model_name, batch_size, n_batches):
    """Load the model and build the training-split Othello dataset.

    Args:
        model_name: Identifier forwarded to ``get_model``.
        batch_size: Number of inputs per batch.
        n_batches: Number of batches; ``batch_size * n_batches`` inputs are
            requested from ``construct_othello_dataset``.

    Returns:
        ``(model, data)`` where ``model`` is the object returned by
        ``get_model`` and ``data`` is whatever ``construct_othello_dataset``
        returns for the two classifier-input builders below.

    Note:
        Both the model and the dataset are placed on the module-level
        ``device``.
    """
    model = get_model(model_name, device)

    # Two alternative encodings of each game are requested side by side.
    input_builders = [
        othello_utils.games_batch_to_classifier_input_BLC,
        othello_utils.games_batch_to_classifier_input_board_state_BLC,
    ]
    total_inputs = batch_size * n_batches

    data = construct_othello_dataset(
        custom_functions=input_builders,
        n_inputs=total_inputs,
        split="train",
        device=device,
    )
    return model, data

# Cache Neuron Activations
def cache_neuron_activations(model, data, layers, batch_size, n_batches):
    """Collect post-activation MLP outputs for the requested layers.

    Runs ``n_batches`` consecutive slices of ``data["encoded_inputs"]``
    through ``model`` inside its ``trace`` context and saves
    ``model.blocks[layer].mlp.hook_post.output`` for every layer in ``layers``.

    Args:
        model: Object exposing ``trace(...)`` and ``blocks[layer].mlp.hook_post``.
        data: Mapping with an ``"encoded_inputs"`` entry that supports slicing.
        layers: Iterable of block indices to record.
        batch_size: Number of inputs per traced batch.
        n_batches: Number of batches to run.

    Returns:
        A ``defaultdict`` mapping each layer to one tensor in which the
        per-batch results have been merged along the leading axis.

    Note:
        ``t.stack`` needs every batch to yield the same shape, so all
        ``n_batches`` slices must be full.
    """
    acts_by_layer = defaultdict(list)

    for batch_idx in range(n_batches):
        start = batch_idx * batch_size
        stop = (batch_idx + 1) * batch_size
        tokens = t.tensor(data["encoded_inputs"][start:stop], device=device)

        with t.no_grad(), model.trace(tokens, scan=False, validate=False):
            for layer in layers:
                saved = model.blocks[layer].mlp.hook_post.output.save()
                acts_by_layer[layer].append(saved)

    # Fold the (n_batches, batch, ...) stack into a single leading axis.
    for layer in acts_by_layer:
        stacked = t.stack(acts_by_layer[layer])
        acts_by_layer[layer] = einops.rearrange(stacked, "n b l c -> (n b) l c")

    return acts_by_layer

def calculate_binary_activations(neuron_acts, threshold=0.1):
    """Binarize activations relative to each neuron's own maximum.

    For every layer, a neuron is marked active (1) at a given (batch, position)
    when its activation is strictly greater than ``threshold`` times the
    largest activation that neuron reaches anywhere in that layer's tensor.

    Args:
        neuron_acts: Mapping from layer to a 3-D tensor indexed as
            (batch, position, neuron).
        threshold: Fraction of the per-neuron maximum used as the cut-off.

    Returns:
        Dict with the same keys as ``neuron_acts``; each value is an int32
        tensor of 0/1 with the same shape as the corresponding input.

    Note:
        The neuron count and device are taken from each input tensor, so any
        MLP width works and no module-level ``device`` is needed.
    """
    binary_acts = {}

    for layer, neuron_acts_BLD in neuron_acts.items():
        # Per-neuron maximum over every batch element and position.
        max_activations_D = neuron_acts_BLD.amax(dim=(0, 1))

        # Broadcasts the (D,) cut-off against the (B, L, D) activations.
        binary_acts[layer] = (neuron_acts_BLD > (threshold * max_activations_D)).int()

    return binary_acts

# Prepare data for modeling
def prepare_data(games_BLC, mlp_acts_BLD):
    """Flatten games and activations to 2-D arrays and split them 80/20.

    Args:
        games_BLC: 3-D tensor of model inputs; its first two axes are merged
            into rows.
        mlp_acts_BLD: 3-D tensor of targets, flattened the same way.

    Returns:
        ``X_train, X_test, y_train, y_test`` as returned by
        ``train_test_split`` with ``test_size=0.2`` and ``random_state=42``.
        The fixed seed makes repeated calls on the same inputs give the same
        split.
    """

    def _to_rows(tensor, pattern):
        # Merge the first two axes, then move to host memory as a NumPy array.
        return einops.rearrange(tensor, pattern).cpu().numpy()

    features = _to_rows(games_BLC, "b l c -> (b l) c")
    targets = _to_rows(mlp_acts_BLD, "b l d -> (b l) d")
    return train_test_split(features, targets, test_size=0.2, random_state=42)

# Train and evaluate models
def train_and_evaluate(model, X_train, X_test, y_train, y_test):
    """Fit ``model`` on the training split and score it on the test split.

    Args:
        model: Estimator exposing ``fit`` and ``predict``; it is fitted in place.
        X_train, y_train: Training features and targets.
        X_test, y_test: Held-out features and targets.

    Returns:
        ``(model, mse, r2)``: the same estimator object that was passed in,
        followed by ``mean_squared_error`` and ``r2_score`` on the test split.
    """
    model.fit(X_train, y_train)
    held_out_pred = model.predict(X_test)
    return (
        model,
        mean_squared_error(y_test, held_out_pred),
        r2_score(y_test, held_out_pred),
    )

# Calculate metrics for all neurons
def calculate_neuron_metrics(model, X, y):
    """Score a fitted multi-output regressor one target column at a time.

    Args:
        model: Fitted estimator whose ``predict(X)`` returns a 2-D array with
            one column per column of ``y``.
        X: Feature matrix.
        y: 2-D target array; ``y.shape[1]`` is the number of neurons.

    Returns:
        ``(mse_list, r2_list)``: two lists with one entry per column of ``y``.
    """
    predictions = model.predict(X)

    def _score(col):
        target, estimate = y[:, col], predictions[:, col]
        return mean_squared_error(target, estimate), r2_score(target, estimate)

    per_neuron = [_score(col) for col in range(y.shape[1])]
    mse_list = [mse for mse, _ in per_neuron]
    r2_list = [r2 for _, r2 in per_neuron]
    return mse_list, r2_list

def calculate_binary_metrics(model, X, y):
    """Score a fitted multi-output classifier one target column at a time.

    Args:
        model: Fitted estimator whose ``predict(X)`` returns a 2-D array with
            one column per column of ``y``.
        X: Feature matrix.
        y: 2-D array of binary targets; ``y.shape[1]`` is the number of neurons.

    Returns:
        ``(accuracy_list, precision_list, recall_list, f1_list)``: four lists
        with one entry per column of ``y``. Precision, recall and F1 are
        computed with ``zero_division=0``.
    """
    predictions = model.predict(X)

    def _score(col):
        target, guess = y[:, col], predictions[:, col]
        return (
            accuracy_score(target, guess),
            precision_score(target, guess, zero_division=0),
            recall_score(target, guess, zero_division=0),
            f1_score(target, guess, zero_division=0),
        )

    per_neuron = [_score(col) for col in range(y.shape[1])]

    accuracy_list = [scores[0] for scores in per_neuron]
    precision_list = [scores[1] for scores in per_neuron]
    recall_list = [scores[2] for scores in per_neuron]
    f1_list = [scores[3] for scores in per_neuron]
    return accuracy_list, precision_list, recall_list, f1_list

# Print decision tree rules
def print_decision_tree_rules(model, feature_names, neuron_index, max_depth=None):
    """Print the text rules of the decision tree fitted for one neuron.

    Args:
        model: Fitted multi-output wrapper; ``model.estimators_[neuron_index]``
            is the tree for that neuron.
        feature_names: One name per input feature, or ``None`` to use
            ``feature_0 ... feature_{n-1}``.
        neuron_index: Index into ``model.estimators_``.
        max_depth: Maximum depth to print. ``None`` (the default) prints the
            whole tree.

    Returns:
        None; the rules are written to stdout.
    """
    tree = model.estimators_[neuron_index]
    if feature_names is None:
        feature_names = [f"feature_{i}" for i in range(tree.n_features_in_)]
    if max_depth is None:
        # export_text documents max_depth as an int, so None is never forwarded;
        # the tree's own depth means nothing gets truncated.
        max_depth = tree.get_depth()
    tree_rules = export_text(tree, feature_names=feature_names, max_depth=max_depth)
    print(f"Decision Tree Rules for Neuron {neuron_index}:")
    print(tree_rules)

def rc_to_square_notation(row, col):
    """Format a (row, col) pair as a letter-plus-number square label.

    The letter is ``"ABCDEFGH"[row]`` and the number is ``8 - col``, so
    ``(0, 0)`` maps to ``"A8"`` and ``(7, 7)`` maps to ``"H1"``.

    NOTE(review): the letter is taken from the row, not the column - confirm
    this matches the board orientation used by the dataset.
    """
    return f"{'ABCDEFGH'[row]}{8 - col}"

def idx_to_square_notation(idx):
    """Convert a flat index on an 8-wide board to a square label.

    ``idx // 8`` is used as the row and ``idx % 8`` as the column; formatting
    is delegated to ``rc_to_square_notation``.
    """
    return rc_to_square_notation(idx // 8, idx % 8)

In [None]:

# Run configuration. batch_size * n_batches inputs are requested from the dataset.
model_name = "Baidicoot/Othello-GPT-Transformer-Lens"
batch_size = 10
n_batches = 2
# Indices of the transformer blocks whose MLP activations are cached.
layers = list(range(8))

# Load the model and data, cache raw MLP activations per layer, then derive the
# thresholded 0/1 version (default threshold) used by the classifier cells below.
model, data = load_model_and_data(model_name, batch_size, n_batches)
neuron_acts = cache_neuron_activations(model, data, layers, batch_size, n_batches)
binary_acts = calculate_binary_activations(neuron_acts)
# Alternative input encoding (currently disabled):
# games_BLC = data[othello_utils.games_batch_to_classifier_input_BLC.__name__]

# The dataset is keyed by the name of the function that built each encoding;
# the board-state encoding is the one used as model input in the cells below.
games_BLC = data[othello_utils.games_batch_to_classifier_input_board_state_BLC.__name__]



In [None]:


# Per-layer regression fits: predict every neuron's raw activation from the
# game encoding. `results[layer]` holds the fitted models and their test scores.
# NOTE: X_train / X_test / y_train / y_test remain in the notebook namespace
# after the loop (holding the last layer's split) and a later cell reads X_train.
results = {}
for layer in layers:
    print(f"Layer {layer}")
    X_train, X_test, y_train, y_test = prepare_data(games_BLC, neuron_acts[layer])
    
    # Linear Regression (Lasso)
    # A single Lasso is fitted on the full 2-D target matrix.
    lasso_model, lasso_mse, lasso_r2 = train_and_evaluate(
        Lasso(alpha=0.005), X_train, X_test, y_train, y_test
    )
    
    # Decision Tree
    # MultiOutputRegressor fits one depth-5 tree per target column; the
    # per-neuron trees are available afterwards as `dt_model.estimators_`.
    dt_model, dt_mse, dt_r2 = train_and_evaluate(
        MultiOutputRegressor(DecisionTreeRegressor(random_state=42, max_depth=5)),
        X_train, X_test, y_train, y_test
    )
    
    # mse / r2 here are the aggregate scores returned by train_and_evaluate.
    results[layer] = {
        'lasso': {'model': lasso_model, 'mse': lasso_mse, 'r2': lasso_r2},
        'decision_tree': {'model': dt_model, 'mse': dt_mse, 'r2': dt_r2}
    }



In [None]:


# Per-layer classification fits: predict each neuron's thresholded 0/1
# activation from the game encoding with one depth-5 tree per neuron.
binary_results = {}

for layer in layers:
    print(f"Layer {layer}")
    X_binary_train, X_binary_test, y_binary_train, y_binary_test = prepare_data(games_BLC, binary_acts[layer])

    # Binary Decision Tree
    dt_model = MultiOutputClassifier(DecisionTreeClassifier(random_state=42, max_depth=5))
    dt_model.fit(X_binary_train, y_binary_train)
    y_pred = dt_model.predict(X_binary_test)

    # Calculate metrics
    # With 2-D label matrices accuracy_score is subset accuracy: a row counts
    # as correct only when every neuron in it is predicted correctly.
    accuracy = accuracy_score(y_binary_test, y_pred)
    # zero_division=0 matches calculate_binary_metrics: neurons with no
    # positive labels/predictions score 0 without emitting a warning per call.
    precision = precision_score(y_binary_test, y_pred, average='weighted', zero_division=0)
    recall = recall_score(y_binary_test, y_pred, average='weighted', zero_division=0)
    f1 = f1_score(y_binary_test, y_pred, average='weighted', zero_division=0)

    binary_results[layer] = {
        'decision_tree': {'model': dt_model, 'accuracy': accuracy, 'precision': precision, 'recall': recall, 'f1': f1}
    }

In [None]:

# Function to view metrics for a specific layer
def view_layer_metrics(layer: int, threshold: float, games_BLC: t.Tensor, neuron_acts: dict, results: dict, binary_results: dict, binary_acts: dict = None):
    """Print regression and binary-classification metrics for one layer.

    The test split is rebuilt with ``prepare_data``, which uses a fixed seed,
    so it is the same split the stored models were evaluated on as long as the
    inputs are unchanged.

    Args:
        layer: Layer to report on; must be a key of ``results``.
        threshold: Cut-off used to count "good" neurons, applied to per-neuron
            R2 (regressors) and per-neuron F1 (binary classifier).
        games_BLC: Model-input tensor passed to ``prepare_data``.
        neuron_acts: Mapping layer -> raw activation tensor.
        results: Mapping layer -> ``{'lasso': {...}, 'decision_tree': {...}}``
            with ``'model'``, ``'mse'`` and ``'r2'`` entries.
        binary_results: Mapping layer -> ``{'decision_tree': {'model': ...}}``.
        binary_acts: Mapping layer -> binarized activation tensor. When omitted,
            the notebook-level ``binary_acts`` variable is used, which is what
            this function read implicitly before the parameter existed.

    Returns:
        None; everything is printed.
    """
    if layer not in results:
        print(f"Layer {layer} not found in results.")
        return

    print(f"\n\nMetrics for Layer {layer}:")
    print("Linear Regression (Lasso):")
    print(f"  MSE: {results[layer]['lasso']['mse']}")
    print(f"  R2: {results[layer]['lasso']['r2']}")
    print("Decision Tree:")
    print(f"  MSE: {results[layer]['decision_tree']['mse']}")
    print(f"  R2: {results[layer]['decision_tree']['r2']}")

    X_train, X_test, y_train, y_test = prepare_data(games_BLC, neuron_acts[layer])

    # Calculate and print neuron-specific metrics
    lasso_mse, lasso_r2 = calculate_neuron_metrics(results[layer]['lasso']['model'], X_test, y_test)
    dt_mse, dt_r2 = calculate_neuron_metrics(results[layer]['decision_tree']['model'], X_test, y_test)

    print("\nNeuron-specific metrics:")
    print(f"Lasso - Mean MSE: {np.mean(lasso_mse)}, Mean R2: {np.mean(lasso_r2)}")
    print(f"Decision Tree - Mean MSE: {np.mean(dt_mse)}, Mean R2: {np.mean(dt_r2)}")

    linear_r2_tensor = t.tensor(lasso_r2)
    dt_r2_tensor = t.tensor(dt_r2)

    good_lasso_r2 = (linear_r2_tensor > threshold).sum().item()
    good_dt_r2 = (dt_r2_tensor > threshold).sum().item()
    print(f"Number of neurons with R2 > {threshold} (Lasso): {good_lasso_r2}")
    print(f"Number of neurons with R2 > {threshold} (Decision Tree): {good_dt_r2}")

    if binary_acts is None:
        # Backward-compatible fallback for callers that rely on the
        # notebook-level variable instead of passing it explicitly.
        binary_acts = globals()["binary_acts"]

    X_binary_train, X_binary_test, y_binary_train, y_binary_test = prepare_data(games_BLC, binary_acts[layer])
    accuracy, precision, recall, f1 = calculate_binary_metrics(binary_results[layer]['decision_tree']['model'], X_binary_test, y_binary_test)

    print("\nBinary Decision Tree Metrics:")
    print(f"  Accuracy: {np.mean(accuracy)}")
    print(f"  Precision: {np.mean(precision)}")
    print(f"  Recall: {np.mean(recall)}")
    print(f"  F1: {np.mean(f1)}")

    good_f1 = (t.tensor(f1) > threshold).sum().item()
    print(f"Number of neurons with F1 > {threshold} {good_f1}")

# Report every layer, counting neurons whose per-neuron R2 / F1 exceeds 0.9.
for layer in layers:
    view_layer_metrics(layer, 0.9, games_BLC, neuron_acts, results, binary_results, binary_acts=binary_acts)

In [None]:
# Inspect the regression tree fitted for a single neuron.
layer = 1
neuron_idx = 421
layer_dt = results[layer]['decision_tree']['model']

# One name per input channel. The count comes from the tensor the models were
# trained on, not from a loop variable left behind by an earlier cell.
n_features = games_BLC.shape[-1]

feature_names = []

# NOTE(review): the naming assumes channels 0-63 and 64-127 each map onto the
# 64 board squares - confirm against the encoding selected for `games_BLC`.
for i in range(n_features):
    if i < 64:
        square = idx_to_square_notation(i)
        feature_names.append(f"Input_{square}")
    elif i < 128:
        j = i - 64
        square = idx_to_square_notation(j)
        feature_names.append(f"Occupied_{square}")
    else:
        feature_names.append(f"Output_{i}")

print_decision_tree_rules(layer_dt, feature_names=feature_names, neuron_index=neuron_idx, max_depth=5)

