<a href="https://colab.research.google.com/github/BenWilop/WSG_games/blob/main/playground_WSG_games.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Imports

In [11]:
import dotenv
import wandb
import os

# Pull environment variables from the private env file, then log in to W&B
# with the key found there (os.environ.get yields None if it is absent).
dotenv.load_dotenv(
    dotenv_path=os.path.join("/homes/55/bwilop/wsg/private/", "vscode-ssh.env")
)
api_key = os.environ.get("WANDB_API_KEY")
wandb.login(key=api_key)

# Root directories that the loading cells below read data and checkpoints from.
data_folder, experiment_folder = (
    "/homes/55/bwilop/wsg/data/",
    "/homes/55/bwilop/wsg/experiments/",
)

[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /homes/55/bwilop/.netrc


In [12]:
%load_ext autoreload
%autoreload 2
import torch as t
from torch.utils.data import DataLoader, TensorDataset

# from jaxtyping import Float
import matplotlib.pyplot as plt

from wsg_games.tictactoe.evals import *
from wsg_games.tictactoe.data import *
from wsg_games.tictactoe.game import *

from wsg_games.tictactoe.analysis.analyse_data import *
from wsg_games.tictactoe.analysis.visualize_game import *

from wsg_games.tictactoe.train.create_models import *
from wsg_games.tictactoe.train.save_load_models import *
from wsg_games.tictactoe.train.train import *
from wsg_games.tictactoe.train.finetune import *
from wsg_games.tictactoe.train.pretrain import *

# Run on the GPU when torch can see one, otherwise fall back to the CPU.
if t.cuda.is_available():
    DEVICE = t.device("cuda")
else:
    DEVICE = t.device("cpu")
print(DEVICE)

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload
cuda


# Load Data & Models

In [13]:
# Experiment selection: the project names under which pretrained / finetuned
# checkpoints are looked up, the two model sizes being compared, and the
# run index passed to every loader below.
project_name_pretrain, project_name_finetune = (
    "tictactoe/tictactoe_pretraining5",
    "tictactoe/tictactoe_finetuning5",
)
weak_model_size, strong_model_size = "small", "medium"
index = 2

# Load data
split_data = load_split_data(data_folder + "tictactoe/", device=DEVICE, index=index)
(
    tictactoe_train_data,
    tictactoe_weak_finetune_data,
    tictactoe_val_data,
    tictactoe_test_data,
) = split_data

# Load models
def _load_pretrained(model_size, goal):
    """Load one pretrained checkpoint; only size and goal vary between calls."""
    return load_model(
        project_name_pretrain,
        model_size,
        goal,
        experiment_folder,
        device=DEVICE,
        index=index,
    )


# Weak-sized model pretrained on the weak goal.
weak_model = _load_pretrained(weak_model_size, Goal.WEAK_GOAL)
# Strong-sized model pretrained on the weak goal (the baseline for the PGR below).
strong_baseline_model = _load_pretrained(strong_model_size, Goal.WEAK_GOAL)
# Strong-sized model pretrained on the strong goal.
strong_model = _load_pretrained(strong_model_size, Goal.STRONG_GOAL)

# Finetuned checkpoint, looked up by the (weak size, strong size) pair.
finetuned_model = load_finetuned_model(
    project_name_finetune,
    weak_model_size,
    strong_model_size,
    experiment_folder,
    DEVICE,
    index,
)

# Print evaluations
# quick_evaluation returns a pair; only its first element is kept here and
# treated as the loss (second element is discarded).
weak_loss, _ = quick_evaluation("weak_model", weak_model, tictactoe_test_data)
strong_baseline_loss, _ = quick_evaluation(
    "strong_baseline_model", strong_baseline_model, tictactoe_test_data
)
# Evaluated for its printed report only; the return value is not used.
quick_evaluation("strong_model", strong_model, tictactoe_test_data)
weak_finetuned_loss, _ = quick_evaluation(
    "finetuned_model", finetuned_model, tictactoe_test_data
)

# PGR: share of the loss gap between the weak model and the strong baseline
# that the finetuned model closes.
recovered_gap = weak_loss - weak_finetuned_loss
full_gap = weak_loss - strong_baseline_loss
print("Performance Gap Recovered (PGR): ", recovered_gap / full_gap)

experiment_folder:  /homes/55/bwilop/wsg/experiments/
project_name:  tictactoe/tictactoe_pretraining5
/homes/55/bwilop/wsg/experiments/tictactoe/tictactoe_pretraining5
Loading model from /homes/55/bwilop/wsg/experiments/tictactoe/tictactoe_pretraining5/experiment_2_small_weak_2025-05-16-16-35_ayyrg2xq.pkl
Moving model to device:  cuda
experiment_folder:  /homes/55/bwilop/wsg/experiments/
project_name:  tictactoe/tictactoe_pretraining5
/homes/55/bwilop/wsg/experiments/tictactoe/tictactoe_pretraining5
Loading model from /homes/55/bwilop/wsg/experiments/tictactoe/tictactoe_pretraining5/experiment_2_medium_weak_2025-05-16-16-35_eif67e03.pkl
Moving model to device:  cuda
experiment_folder:  /homes/55/bwilop/wsg/experiments/
project_name:  tictactoe/tictactoe_pretraining5
/homes/55/bwilop/wsg/experiments/tictactoe/tictactoe_pretraining5
Loading model from /homes/55/bwilop/wsg/experiments/tictactoe/tictactoe_pretraining5/experiment_2_medium_strong_2025-05-16-16-44_omyjjb73.pkl
Moving model to

In [14]:
from dictionary_learning.dictionary_learning import CrossCoder
from dictionary_learning.dictionary_learning.cache import PairedActivationCache
import transformer_lens.utils as utils

In [15]:
def get_activations(
    model,
    tokenized_games: t.Tensor,
    target_layer: int,
) -> t.Tensor:
    """Return the ``resid_post`` activations of one layer for a batch of games.

    The model is switched to eval mode (this persists after the call) and run
    once on ``tokenized_games`` with activation caching. The forward pass is
    done under ``torch.no_grad`` so the returned tensor does not keep the
    model's autograd graph alive, and only the requested hook is cached.

    Args:
        model: Object exposing ``eval()`` and a TransformerLens-style
            ``run_with_cache`` (assumed to be a HookedTransformer — confirm).
        tokenized_games: Batch of tokenized games; presumably shaped
            (n_games, game_length) as the original annotation suggested.
        target_layer: Index of the layer whose ``resid_post`` hook is read.

    Returns:
        The tensor cached under the ``resid_post`` hook of ``target_layer``.
    """
    activation_hook_name = utils.get_act_name("resid_post", target_layer)
    model.eval()
    # Activations are only used as data downstream, so skip gradient tracking
    # and cache just the one hook that is actually returned.
    with t.no_grad():
        _, cache = model.run_with_cache(
            tokenized_games, names_filter=activation_hook_name
        )
    return cache[activation_hook_name]

In [16]:
# Train
def get_paired_activations(
    model_1,
    model_2,
    tictactoe_data: TicTacToeData,
    target_layer: int,
):
    """Collect same-layer activations of two models on the same games, paired up.

    Both models are run on ``tictactoe_data.games_data``; each activation
    tensor is passed through ``rearrange`` and the two results are stacked
    along a new dimension 1, so index 0 on that axis belongs to ``model_1``
    and index 1 to ``model_2``.

    Args:
        model_1: First model, forwarded to ``get_activations``.
        model_2: Second model, forwarded to ``get_activations``.
        tictactoe_data: Dataset object whose ``games_data`` holds the
            tokenized games.
        target_layer: Layer index whose activations are extracted from both
            models.

    Returns:
        The stacked tensor of paired activations.
    """
    tokenized_games = tictactoe_data.games_data
    # NOTE(review): `rearrange` is called without a pattern. If this is
    # einops.rearrange it needs one (presumably flattening games and positions
    # into a single sample axis) — confirm which `rearrange` is in scope.
    activations_model_1 = rearrange(
        get_activations(model_1, tokenized_games, target_layer)
    )
    activations_model_2 = rearrange(
        get_activations(model_2, tokenized_games, target_layer)
    )
    paired_activations = t.stack((activations_model_1, activations_model_2), dim=1)
    return paired_activations


# Paired (strong_model, finetuned_model) activations at layer 3 for two splits.
# The test split feeds the training set ("train crosscoder on unknown data"),
# the validation split feeds the validation set.
train_paired_activations, val_paired_activations = (
    get_paired_activations(strong_model, finetuned_model, split, 3)
    for split in (tictactoe_test_data, tictactoe_val_data)
)

print("train_paired_activations: ", train_paired_activations.shape)
print("val_paired_activations: ", val_paired_activations.shape)

# Wrap each tensor so it can be handed to a DataLoader.
train_dataset, validation_dataset = (
    TensorDataset(train_paired_activations),
    TensorDataset(val_paired_activations),
)

train_paired_activations:  torch.Size([220976, 2, 32])
val_paired_activations:  torch.Size([110704, 2, 32])


In [None]:
def get_training_cfg_cross_coder():
    """Return a fresh dict of crosscoder training hyperparameters.

    Keys: ``learning_rate``, ``max_steps``, ``validate_every_n_steps`` and
    ``batch_size``. A new dict is built on every call, so callers may mutate
    the result freely.
    """
    return {
        "learning_rate": 0.001,
        "max_steps": 1_000,
        "validate_every_n_steps": 3,
        "batch_size": 64,
    }


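# TODO: options not yet covered by the config above (presumably arguments for
# the crosscoder trainer — confirm):
#   run_name
#   wandb-entity
#   disable-wandb
#   K= (value still to be chosen)
#   n_workers
#   compile = False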


# def train_crosscoder(activations_model_1, activations_model_2, layer, ):

SyntaxError: invalid syntax (3137496226.py, line 13)