<a href="https://colab.research.google.com/github/BenWilop/WSG_games/blob/main/playground_WSG_games.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Imports

In [1]:
import dotenv
import wandb
import os

# The W&B key is read from a .env file so it is never written in the notebook.
env_file = os.path.join("/homes/55/bwilop/wsg/private/", "vscode-ssh.env")
dotenv.load_dotenv(env_file)
api_key = os.environ.get("WANDB_API_KEY")
wandb.login(key=api_key)
# NOTE: spelled "ENTITIY"; the training cell further down refers to this exact name.
WANDB_ENTITIY = "benwilop-rwth-aachen-university"

# Root folders; every other path in the notebook is derived from these.
data_folder = "/homes/55/bwilop/wsg/data/"
experiment_folder = "/homes/55/bwilop/wsg/experiments/"
crosscoder_folder = f"{experiment_folder}tictactoe/crosscoder/"

[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /homes/55/bwilop/.netrc
[34m[1mwandb[0m: Currently logged in as: [33mbenwilop[0m ([33mbenwilop-rwth-aachen-university[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [2]:
%load_ext autoreload
%autoreload 2
import json
import torch as t
from torch.utils.data import DataLoader, TensorDataset
import datetime

# from jaxtyping import Float
import matplotlib.pyplot as plt

from wsg_games.tictactoe.evals import *
from wsg_games.tictactoe.data import *
from wsg_games.tictactoe.game import *

from wsg_games.tictactoe.analysis.analyse_data import *
from wsg_games.tictactoe.analysis.visualize_game import *

from wsg_games.tictactoe.train.create_models import *
from wsg_games.tictactoe.train.save_load_models import *
from wsg_games.tictactoe.train.train import *
from wsg_games.tictactoe.train.finetune import *
from wsg_games.tictactoe.train.pretrain import *

# Use the GPU when torch can see one, otherwise fall back to the CPU.
DEVICE = t.device("cuda") if t.cuda.is_available() else t.device("cpu")
print(DEVICE)

cuda


# Load Data & Models

In [3]:
# Experiment selection: project folders, the two model sizes and the run index.
project_name_pretrain = "tictactoe/tictactoe_pretraining5"
project_name_finetune = "tictactoe/tictactoe_finetuning5"
weak_model_size = "small"
strong_model_size = "medium"
index = 2

# Load data
split_data = load_split_data(data_folder + "tictactoe/", device=DEVICE, index=index)
(
    tictactoe_train_data,
    tictactoe_weak_finetune_data,
    tictactoe_val_data,
    tictactoe_test_data,
) = split_data


# Load models
def _load_pretrained(model_size, goal):
    """Load one pretrained checkpoint of the selected experiment onto DEVICE."""
    return load_model(
        project_name_pretrain,
        model_size,
        goal,
        experiment_folder,
        device=DEVICE,
        index=index,
    )


weak_model = _load_pretrained(weak_model_size, Goal.WEAK_GOAL)
strong_baseline_model = _load_pretrained(strong_model_size, Goal.WEAK_GOAL)
strong_model = _load_pretrained(strong_model_size, Goal.STRONG_GOAL)
finetuned_model = load_finetuned_model(
    project_name_finetune,
    weak_model_size,
    strong_model_size,
    experiment_folder,
    DEVICE,
    index,
)

# Print evaluations
# Only the first element of each quick_evaluation result is kept; it is used as the loss below.
weak_loss, _ = quick_evaluation("weak_model", weak_model, tictactoe_test_data)
strong_baseline_loss, _ = quick_evaluation(
    "strong_baseline_model", strong_baseline_model, tictactoe_test_data
)
quick_evaluation("strong_model", strong_model, tictactoe_test_data)
weak_finetuned_loss, _ = quick_evaluation(
    "finetuned_model", finetuned_model, tictactoe_test_data
)

# Share of the weak -> strong-baseline loss gap that the finetuned model closes.
recovered_gap = weak_loss - weak_finetuned_loss
full_gap = weak_loss - strong_baseline_loss
print("Performance Gap Recovered (PGR): ", recovered_gap / full_gap)

experiment_folder:  /homes/55/bwilop/wsg/experiments/
project_name:  tictactoe/tictactoe_pretraining5
/homes/55/bwilop/wsg/experiments/tictactoe/tictactoe_pretraining5
Loading model from /homes/55/bwilop/wsg/experiments/tictactoe/tictactoe_pretraining5/experiment_2_small_weak_2025-05-16-16-35_ayyrg2xq.pkl
Moving model to device:  cuda
experiment_folder:  /homes/55/bwilop/wsg/experiments/
project_name:  tictactoe/tictactoe_pretraining5
/homes/55/bwilop/wsg/experiments/tictactoe/tictactoe_pretraining5
Loading model from /homes/55/bwilop/wsg/experiments/tictactoe/tictactoe_pretraining5/experiment_2_medium_weak_2025-05-16-16-35_eif67e03.pkl
Moving model to device:  cuda
experiment_folder:  /homes/55/bwilop/wsg/experiments/
project_name:  tictactoe/tictactoe_pretraining5
/homes/55/bwilop/wsg/experiments/tictactoe/tictactoe_pretraining5
Loading model from /homes/55/bwilop/wsg/experiments/tictactoe/tictactoe_pretraining5/experiment_2_medium_strong_2025-05-16-16-44_omyjjb73.pkl
Moving model to

In [4]:
from dictionary_learning.dictionary_learning import CrossCoder
from dictionary_learning.dictionary_learning.trainers.crosscoder import (
    CrossCoderTrainer,
)
from dictionary_learning.dictionary_learning.training import trainSAE
from dictionary_learning.dictionary_learning.cache import (
    PairedActivationCache,
    ActivationCache,
    ActivationShard,
)
import transformer_lens.utils as utils

In [5]:
def get_activations(
    model,
    tokenized_games: Float[Tensor, "n_games game_length"],
    layer_i: int,
) -> t.Tensor:
    """Return the cached ``resid_post`` activation of block ``layer_i``.

    Side effect: ``model`` is switched to eval mode and left that way.

    Args:
        model: Object exposing ``eval()`` and ``run_with_cache()``
            (presumably a transformer_lens HookedTransformer -- TODO confirm).
        tokenized_games: Batch handed unchanged to ``model.run_with_cache``.
        layer_i: Index of the block whose residual-stream output is wanted.

    Returns:
        The cache entry stored under the ``resid_post`` hook name of
        ``layer_i``; presumably shaped (n_games, game_length, d_model) --
        TODO confirm.
    """
    hook_name = utils.get_act_name("resid_post", layer_i)
    model.eval()
    # run_with_cache yields (model output, cache); only the cache is needed.
    cache = model.run_with_cache(tokenized_games)[1]
    return cache[hook_name]

In [6]:
# Sanity check: show the hook name that get_act_name builds for "resid_post" at layer 1.
activation_hook_name = utils.get_act_name("resid_post", 1)
print(activation_hook_name)

blocks.1.hook_resid_post


In [7]:
@t.no_grad()
def create_data_shards(
    games_data: Float[Tensor, "n_games game_length"],
    model: HookedTransformer,
    store_dir: str,
    batch_size: int = 64,
    shard_size: int = 10**6,
    max_total_tokens: int = 10**8,
    overwrite: bool = False,
) -> None:
    """Collect per-layer ``resid_post`` activations and store them as shards.

    For every layer of ``model`` a sub-directory ``<store_dir>/layer_<i>_out``
    is created. Games are pushed through the model in batches, the activations
    of each layer are flattened to one row per token and buffered on the CPU,
    and a shard is written via ``ActivationCache.collate_store_shards`` each
    time the buffer exceeds ``shard_size`` rows. Whatever is left in the buffer
    at the end is written as a final, smaller shard. A ``config.json``
    describing the run is written into every layer directory.

    Args:
        games_data: Tokenized games; iterated in batches with a DataLoader.
        model: Model to read activations from. It is switched to eval mode.
        store_dir: Root directory that receives the per-layer sub-directories.
        batch_size: Number of games per forward pass.
        shard_size: Buffered-row threshold above which a shard is written.
        max_total_tokens: Collection stops once the rows already written in
            full shards exceed this number.
        overwrite: If False and ``shard_0.memmap`` already exists for the
            first layer, nothing is computed and the function returns early.

    Returns:
        None. Results are written to disk.
    """
    io: str = "out"
    n_layers = model.cfg.n_layers
    submodule_names = [f"layer_{layer_i}" for layer_i in range(n_layers)]
    hook_names = [
        utils.get_act_name("resid_post", layer_i) for layer_i in range(n_layers)
    ]

    # Use a distinct loop name so the `store_dir` argument keeps pointing at the root.
    store_dirs = [os.path.join(store_dir, f"{name}_{io}") for name in submodule_names]
    for layer_dir in store_dirs:
        os.makedirs(layer_dir, exist_ok=True)

    # Check if shards already exist
    if os.path.exists(os.path.join(store_dirs[0], "shard_0.memmap")):
        print(f"Shards already exist in {store_dir}")
        if not overwrite:
            print("Set overwrite=True to overwrite existing shards.")
            return
        print("Overwriting existing shards...")

    def _store_shard(shard_index: int, buffered: list) -> None:
        """Write the buffered activations of all layers as shard ``shard_index``."""
        ActivationCache.collate_store_shards(
            store_dirs,
            shard_index,
            buffered,
            submodule_names,
            shuffle_shards=True,
            io=io,
            multiprocessing=False,
        )

    dataloader = DataLoader(games_data, batch_size=batch_size)
    activation_cache = [[] for _ in submodule_names]
    total_size = 0
    current_size = 0
    shard_count = 0

    model.eval()
    print("Collecting activations...")
    for games in tqdm(dataloader, desc="Collecting activations"):
        # One forward pass serves every layer; the cache holds all hooks.
        _, cache = model.run_with_cache(games)
        for layer_i, hook_name in enumerate(hook_names):
            layer_activations = cache[hook_name]
            # Collapse all leading dims into one row per token: (B x T) x D
            flat_activations = layer_activations.reshape(
                -1, layer_activations.shape[-1]
            )
            activation_cache[layer_i].append(flat_activations.cpu())

        current_size += activation_cache[0][-1].shape[0]
        if current_size > shard_size:
            print(f"Storing shard {shard_count}...", flush=True)
            _store_shard(shard_count, activation_cache)
            shard_count += 1
            total_size += current_size
            current_size = 0
            activation_cache = [[] for _ in submodule_names]

        if total_size > max_total_tokens:
            print("Max total tokens reached. Stopping collection.")
            break

    if current_size > 0:
        _store_shard(shard_count, activation_cache)
        # The trailing partial shard is on disk too, so it has to be counted;
        # otherwise config.json under-reports shard_count and total_size.
        shard_count += 1
        total_size += current_size

    # store configs
    run_config = {
        "batch_size": batch_size,
        "context_len": -1,
        "shard_size": shard_size,
        "d_model": model.cfg.d_model,
        "shuffle_shards": True,
        "io": io,
        "total_size": total_size,
        "shard_count": shard_count,
        "store_tokens": False,
    }
    for layer_dir in store_dirs:
        with open(os.path.join(layer_dir, "config.json"), "w") as f:
            json.dump(run_config, f)
    ActivationCache.cleanup_multiprocessing()
    print(f"Finished collecting activations. Total size: {total_size}")

In [8]:
def get_activations_path(
    model_goal: Goal | None,
    weak_model_size: str | None,
    model_size: str,
    index: int,
    crosscoder_folder: str,
    train_val: str,
) -> str:
    """Build the directory path under which a model's activations are stored.

    The model is identified either by ``model_goal`` (pretrained model) or by
    ``weak_model_size`` (model finetuned through a weak model of that size);
    exactly one of the two must be non-None.

    Args:
        model_goal: ``Goal.WEAK_GOAL`` or ``Goal.STRONG_GOAL``, or None.
        weak_model_size: Size name of the weak model, or None.
        model_size: Size name of the model the activations come from.
        index: Experiment index, used as the directory-name prefix.
        crosscoder_folder: Folder that contains the ``activations`` directory.
        train_val: Split name appended at the end of the directory name.

    Returns:
        ``<crosscoder_folder>/activations/<index>_<model_size>_<tag>_<train_val>``
        where ``<tag>`` is ``finetuned_through_<weak_model_size>`` or
        ``str(model_goal)``.

    Raises:
        AssertionError: If both or neither of the identifiers are given.
        ValueError: If no weak size applies and ``model_goal`` is not one of
            the two supported goals.
    """
    # Exactly one of the two identifiers has to be provided.
    assert (model_goal is None) != (weak_model_size is None)

    if weak_model_size:
        model_tag = "finetuned_through_" + weak_model_size
    else:
        if model_goal not in (Goal.WEAK_GOAL, Goal.STRONG_GOAL):
            raise ValueError(f"Invalid activations model goal: {model_goal}")
        model_tag = str(model_goal)

    directory_name = f"{index}_{model_size}_{model_tag}_" + train_val
    return os.path.join(crosscoder_folder, "activations", directory_name)


def compute_activations(
    model_goal: Goal | None,
    project_name_pretrain: str | None,
    weak_model_size: str | None,
    project_name_finetune: str | None,
    model_size: str,
    index: int,
    crosscoder_folder: str,
    tictactoe_test_data: Float[Tensor, "n_games game_length"],
    tictactoe_val_data: Float[Tensor, "n_games game_length"],
    experiment_folder: str,
) -> None:
    """Load one model and store its activations for the "train" and "val" splits.

    The model is either a pretrained one (``project_name_pretrain`` together
    with ``model_goal``) or a finetuned one (``project_name_finetune`` together
    with ``weak_model_size``); exactly one of these pairs must be complete.

    Args:
        model_goal: Goal of the pretrained model, or None for a finetuned one.
        project_name_pretrain: Project of the pretrained model, or None.
        weak_model_size: Size of the weak model used for finetuning, or None.
        project_name_finetune: Project of the finetuned model, or None.
        model_size: Size name of the model to load.
        index: Experiment index.
        crosscoder_folder: Root folder for the activation shards.
        tictactoe_test_data: Games whose activations are stored as "train".
        tictactoe_val_data: Games whose activations are stored as "val".
        experiment_folder: Folder the model checkpoints are loaded from.

    Raises:
        AssertionError: If not exactly one model description is complete.
    """
    # Either finetuned or pretrained
    is_finetuned = not (project_name_finetune is None or weak_model_size is None)
    is_pretrained = not (project_name_pretrain is None or model_goal is None)
    assert is_finetuned != is_pretrained, (
        "Finetuned XOR pretrained model must be provided."
    )

    # Models
    if is_finetuned:
        model = load_finetuned_model(
            project_name_finetune,
            weak_model_size,
            model_size,
            experiment_folder,
            DEVICE,
            index,
        )
    else:
        model = load_model(
            project_name_pretrain,
            model_size,
            model_goal,
            experiment_folder,
            device=DEVICE,
            index=index,
        )

    # Run
    # NOTE(review): the "train" shards are built from the *test* games --
    # confirm this is intended.
    games_by_split = {
        "train": tictactoe_test_data,
        "val": tictactoe_val_data,
    }
    for train_val, games_data in games_by_split.items():
        create_data_shards(
            games_data,
            model,
            store_dir=get_activations_path(
                model_goal,
                weak_model_size,
                model_size,
                index,
                crosscoder_folder,
                train_val,
            ),
            batch_size=64,
            shard_size=10**5,
            max_total_tokens=10**10,
            overwrite=False,
        )

In [9]:
# Arguments that are identical for both models.
shared_activation_kwargs = dict(
    model_size=strong_model_size,
    index=index,
    crosscoder_folder=crosscoder_folder,
    tictactoe_test_data=tictactoe_test_data.games_data,
    tictactoe_val_data=tictactoe_val_data.games_data,
    experiment_folder=experiment_folder,
)

# Strong
compute_activations(
    model_goal=Goal.STRONG_GOAL,
    project_name_pretrain=project_name_pretrain,
    weak_model_size=None,
    project_name_finetune=None,
    **shared_activation_kwargs,
)

# Finetuned
compute_activations(
    model_goal=None,
    project_name_pretrain=None,
    weak_model_size=weak_model_size,
    project_name_finetune=project_name_finetune,
    **shared_activation_kwargs,
)

experiment_folder:  /homes/55/bwilop/wsg/experiments/
project_name:  tictactoe/tictactoe_pretraining5
/homes/55/bwilop/wsg/experiments/tictactoe/tictactoe_pretraining5
Loading model from /homes/55/bwilop/wsg/experiments/tictactoe/tictactoe_pretraining5/experiment_2_medium_strong_2025-05-16-16-44_omyjjb73.pkl
Moving model to device:  cuda
Shards already exist in /homes/55/bwilop/wsg/experiments/tictactoe/crosscoder/activations/2_medium_strong_train/layer_3_out
Set overwrite=True to overwrite existing shards.
Shards already exist in /homes/55/bwilop/wsg/experiments/tictactoe/crosscoder/activations/2_medium_strong_val/layer_3_out
Set overwrite=True to overwrite existing shards.
Moving model to device:  cuda
Shards already exist in /homes/55/bwilop/wsg/experiments/tictactoe/crosscoder/activations/2_medium_finetuned_through_small_train/layer_3_out
Set overwrite=True to overwrite existing shards.
Shards already exist in /homes/55/bwilop/wsg/experiments/tictactoe/crosscoder/activations/2_medi

In [12]:
def get_training_cfg_cross_coder():
    """Return a new dict holding the default crosscoder training settings.

    Keys (as consumed by ``train_crosscoder``): ``learning_rate``,
    ``max_steps``, ``validate_every_n_steps``, ``batch_size``,
    ``expansion_factor`` (dictionary size = factor * activation dim),
    ``resample_steps`` and ``mu`` (passed on as the L1 penalty).
    """
    return dict(
        learning_rate=1e-3,
        max_steps=1000,
        validate_every_n_steps=3,
        batch_size=64,
        expansion_factor=32,
        resample_steps=None,  # int | None
        mu=1e-1,
    )


# run_name
# wandb-entity
# disable-wandb
# K=
# n_workers
# compile = False


def train_crosscoder(
    model_1_name: str,
    model_2_name: str,
    index: int,
    train_activations_stor_dir_model_1: str,
    val_activations_stor_dir_model_1: str,
    train_activations_stor_dir_model_2: str,
    val_activations_stor_dir_model_2: str,
    layer: int,
    training_cfg_cross_coder: dict,
    wandb_entitiy: str,
) -> None:
    """Train a CrossCoder on paired activations of two models.

    Args:
        model_1_name: Label of the first model; used only in the run name.
        model_2_name: Label of the second model; used only in the run name.
        index: Experiment index; used only in the run name.
        train_activations_stor_dir_model_1: Activation directory of model 1
            for training.
        val_activations_stor_dir_model_1: Activation directory of model 1
            for validation.
        train_activations_stor_dir_model_2: Activation directory of model 2
            for training.
        val_activations_stor_dir_model_2: Activation directory of model 2
            for validation.
        layer: Layer the activations were taken from; forwarded to the trainer
            config and included in the W&B run name.
        training_cfg_cross_coder: Dict with the keys ``batch_size``,
            ``expansion_factor``, ``mu``, ``learning_rate``,
            ``resample_steps``, ``validate_every_n_steps`` and ``max_steps``.
        wandb_entitiy: W&B entity the run is logged under.

    Returns:
        None. Progress is logged to W&B; ``trainSAE`` is given
        ``save_dir="checkpoints"``.
    """
    # Imported locally: the notebook imports the `datetime` *module*, so a bare
    # `datetime.now()` only works if some star import rebinds the name to the class.
    from datetime import datetime as _datetime

    # Data (not loaded in memory yet)
    train_dataset = PairedActivationCache(
        train_activations_stor_dir_model_1,
        train_activations_stor_dir_model_2,
    )
    dataloader = DataLoader(
        train_dataset,
        batch_size=training_cfg_cross_coder["batch_size"],
        shuffle=True,
        num_workers=1,
        pin_memory=True,
    )
    print(f"Training on {len(train_dataset)} token activations.")
    val_dataset = PairedActivationCache(
        val_activations_stor_dir_model_1,
        val_activations_stor_dir_model_2,
    )
    validation_dataloader = DataLoader(
        val_dataset,
        batch_size=1000,
        shuffle=False,
        num_workers=1,
        pin_memory=True,
    )
    print(f"Validating on {len(val_dataset)} token activations.")

    # Training config
    # Second axis of one paired item is taken as the activation width
    # (presumably an item is (2, d_model) -- TODO confirm).
    activation_dim = train_dataset[0].shape[1]
    dictionary_size = training_cfg_cross_coder["expansion_factor"] * activation_dim
    print(f"Activation dim: {activation_dim}")
    print(f"Dictionary size: {dictionary_size}")
    mu = training_cfg_cross_coder["mu"]
    lr = training_cfg_cross_coder["learning_rate"]
    timestamp = _datetime.now().strftime("%Y-%m-%d-%H-%M")
    experiment_name = f"experiment_{index}_{model_1_name}_{model_2_name}_{timestamp}"
    trainer_cfg = {
        "trainer": CrossCoderTrainer,
        "dict_class": CrossCoder,
        "activation_dim": activation_dim,
        "dict_size": dictionary_size,
        "lr": lr,
        "resample_steps": training_cfg_cross_coder["resample_steps"],
        "device": str(DEVICE),
        "warmup_steps": 1000,
        "layer": layer,
        "lm_name": experiment_name,
        "compile": True,
        "wandb_name": experiment_name + f"L{layer}-mu{mu:.1e}-lr{lr:.0e}",
        "l1_penalty": mu,
        "dict_class_kwargs": {
            "same_init_for_all_layers": True,
            "norm_init_scale": 0.005,
            "init_with_transpose": True,
            "encoder_layers": None,
        },
        "pretrained_ae": None,
    }

    # train the sparse autoencoder (SAE)
    # Close any W&B run that is still open before training starts.
    wandb.finish()
    trainSAE(
        data=dataloader,
        trainer_config=trainer_cfg,
        validate_every_n_steps=training_cfg_cross_coder["validate_every_n_steps"],
        validation_data=validation_dataloader,
        use_wandb=True,
        wandb_entity=wandb_entitiy,
        wandb_project="crosscoder",
        log_steps=50,
        save_dir="checkpoints",
        steps=training_cfg_cross_coder["max_steps"],
        save_steps=None,
    )

In [None]:
layer = 3

model_1_name = "strong_model"
model_2_name = "finetuned_model"

# Model 1: pretrained strong model. Model 2: strong-size model finetuned through the weak one.
train_activations_stor_dir_model_1, val_activations_stor_dir_model_1 = (
    get_activations_path(
        Goal.STRONG_GOAL, None, strong_model_size, index, crosscoder_folder, split
    )
    for split in ("train", "val")
)
train_activations_stor_dir_model_2, val_activations_stor_dir_model_2 = (
    get_activations_path(
        None, weak_model_size, strong_model_size, index, crosscoder_folder, split
    )
    for split in ("train", "val")
)

# Sub-directory that holds the shards of the selected layer.
layer_subdir = f"/layer_{layer}_out"

training_cfg_cross_coder = get_training_cfg_cross_coder()
train_crosscoder(
    model_1_name=model_1_name,
    model_2_name=model_2_name,
    index=index,
    train_activations_stor_dir_model_1=train_activations_stor_dir_model_1 + layer_subdir,
    val_activations_stor_dir_model_1=val_activations_stor_dir_model_1 + layer_subdir,
    train_activations_stor_dir_model_2=train_activations_stor_dir_model_2 + layer_subdir,
    val_activations_stor_dir_model_2=val_activations_stor_dir_model_2 + layer_subdir,
    layer=layer,
    training_cfg_cross_coder=training_cfg_cross_coder,
    wandb_entitiy=WANDB_ENTITIY,
)

Training on 200704 token activations.
Validating on 100352 token activations.
Activation dim: 32
Dictionary size: 1024


  warn(f"Error saving config: {e}")
  0%|          | 1/1000 [00:06<1:43:31,  6.22s/it]

Validating at step 3


100%|██████████| 101/101 [00:08<00:00, 11.61it/s]
  0%|          | 4/1000 [00:14<58:39,  3.53s/it]  

Validating at step 6


