<a href="https://colab.research.google.com/github/BenWilop/WSG_games/blob/main/playground_WSG_games.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Imports

In [1]:
import dotenv
import wandb
import os

# Authenticate with Weights & Biases; the API key lives in a private .env file.
dotenv.load_dotenv(os.path.join("/homes/55/bwilop/wsg/private/", "vscode-ssh.env"))
api_key = os.environ.get("WANDB_API_KEY")
wandb.login(key=api_key)
WANDB_ENTITIY = "benwilop-rwth-aachen-university"  # (sic) name kept for compatibility

# Storage roots on the cluster.
data_folder = "/homes/55/bwilop/wsg/data/"
experiment_folder = "/homes/55/bwilop/wsg/experiments/"
crosscoder_folder = f"{experiment_folder}tictactoe/crosscoder/"

[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /homes/55/bwilop/.netrc
[34m[1mwandb[0m: Currently logged in as: [33mbenwilop[0m ([33mbenwilop-rwth-aachen-university[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [2]:
%load_ext autoreload
%autoreload 2
import json
import torch as t
from torch.utils.data import DataLoader, TensorDataset
import datetime

# from jaxtyping import Float
import matplotlib.pyplot as plt

from wsg_games.tictactoe.evals import *
from wsg_games.tictactoe.data import *
from wsg_games.tictactoe.game import *

from wsg_games.tictactoe.analysis.analyse_data import *
from wsg_games.tictactoe.analysis.visualize_game import *

from wsg_games.tictactoe.train.create_models import *
from wsg_games.tictactoe.train.save_load_models import *
from wsg_games.tictactoe.train.train import *
from wsg_games.tictactoe.train.finetune import *
from wsg_games.tictactoe.train.pretrain import *

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
DEVICE = t.device("cuda") if t.cuda.is_available() else t.device("cpu")
print(DEVICE)

cuda


# Load Data & Models

In [3]:
# Which pretraining / finetuning runs, model sizes and data index to analyse.
project_name_pretrain = "tictactoe/tictactoe_pretraining5"
project_name_finetune = "tictactoe/tictactoe_finetuning5"
weak_model_size = "small"
strong_model_size = "medium"
index = 2

# Data splits for this index.
tictactoe_train_data, tictactoe_weak_finetune_data, tictactoe_val_data, tictactoe_test_data = (
    load_split_data(data_folder + "tictactoe/", device=DEVICE, index=index)
)

# Pretrained models: (size, goal) = (weak, WEAK), (strong, WEAK) as baseline, (strong, STRONG).
weak_model = load_model(
    project_name_pretrain, weak_model_size, Goal.WEAK_GOAL, experiment_folder,
    device=DEVICE, index=index,
)
strong_baseline_model = load_model(
    project_name_pretrain, strong_model_size, Goal.WEAK_GOAL, experiment_folder,
    device=DEVICE, index=index,
)
strong_model = load_model(
    project_name_pretrain, strong_model_size, Goal.STRONG_GOAL, experiment_folder,
    device=DEVICE, index=index,
)
# Strong-size model from the finetuning project, keyed by the weak model size.
finetuned_model = load_finetuned_model(
    project_name_finetune, weak_model_size, strong_model_size, experiment_folder, DEVICE, index
)

# Test-set evaluation; quick_evaluation returns a pair whose first element is the loss.
weak_loss, _ = quick_evaluation("weak_model", weak_model, tictactoe_test_data)
strong_baseline_loss, _ = quick_evaluation(
    "strong_baseline_model", strong_baseline_model, tictactoe_test_data
)
quick_evaluation("strong_model", strong_model, tictactoe_test_data)
weak_finetuned_loss, _ = quick_evaluation("finetuned_model", finetuned_model, tictactoe_test_data)

# PGR: share of the weak -> strong-baseline loss gap closed by the finetuned model.
print(
    "Performance Gap Recovered (PGR): ",
    (weak_loss - weak_finetuned_loss) / (weak_loss - strong_baseline_loss),
)

experiment_folder:  /homes/55/bwilop/wsg/experiments/
project_name:  tictactoe/tictactoe_pretraining5
/homes/55/bwilop/wsg/experiments/tictactoe/tictactoe_pretraining5
Loading model from /homes/55/bwilop/wsg/experiments/tictactoe/tictactoe_pretraining5/experiment_2_small_weak_2025-05-16-16-35_ayyrg2xq.pkl
Moving model to device:  cuda
experiment_folder:  /homes/55/bwilop/wsg/experiments/
project_name:  tictactoe/tictactoe_pretraining5
/homes/55/bwilop/wsg/experiments/tictactoe/tictactoe_pretraining5
Loading model from /homes/55/bwilop/wsg/experiments/tictactoe/tictactoe_pretraining5/experiment_2_medium_weak_2025-05-16-16-35_eif67e03.pkl
Moving model to device:  cuda
experiment_folder:  /homes/55/bwilop/wsg/experiments/
project_name:  tictactoe/tictactoe_pretraining5
/homes/55/bwilop/wsg/experiments/tictactoe/tictactoe_pretraining5
Loading model from /homes/55/bwilop/wsg/experiments/tictactoe/tictactoe_pretraining5/experiment_2_medium_strong_2025-05-16-16-44_omyjjb73.pkl
Moving model to

In [None]:
from dictionary_learning.dictionary_learning import CrossCoder

from dictionary_learning.dictionary_learning.trainers.crosscoder import (
    CrossCoderTrainer,
)
from dictionary_learning.dictionary_learning.training import trainSAE
from dictionary_learning.dictionary_learning.cache import (
    PairedActivationCache,
    ActivationCache,
    ActivationShard,
)
import transformer_lens.utils as utils

In [5]:
def get_activations(
    model,
    tokenized_games: Float[Tensor, "n_games game_length"],
    layer_i: int,
) -> t.Tensor:
    """Return the residual-stream output (``resid_post``) of block ``layer_i``.

    Runs one forward pass of ``model`` on ``tokenized_games`` and caches only
    the single hook that is returned, not every intermediate activation.

    Args:
        model: a TransformerLens model exposing ``run_with_cache``.
        tokenized_games: token ids of the games, one game per row.
        layer_i: index of the transformer block whose output is returned.

    Returns:
        The cached ``blocks.{layer_i}.hook_resid_post`` tensor; per
        TransformerLens convention presumably (n_games, game_length, d_model).

    Side effect: puts ``model`` into eval mode.
    """
    activation_hook_name = utils.get_act_name("resid_post", layer_i)
    model.eval()
    # No gradients are needed for reading activations, and filtering the cache
    # avoids storing every hook of the forward pass.
    with t.no_grad():
        _, cache = model.run_with_cache(
            tokenized_games,
            names_filter=lambda name: name == activation_hook_name,
        )
    return cache[activation_hook_name]

In [6]:
# Show the TransformerLens hook name used for the residual-stream output of block 1.
print(activation_hook_name := utils.get_act_name("resid_post", 1))

blocks.1.hook_resid_post


In [7]:
@t.no_grad()
def create_data_shards(
    games_data: Float[Tensor, "n_games game_length"],
    model: HookedTransformer,
    store_dir: str,
    batch_size: int = 64,
    shard_size: int = 10**6,
    max_total_tokens: int = 10**8,
    overwrite: bool = False,
) -> None:
    """Cache every layer's ``resid_post`` activations of ``model`` as on-disk shards.

    For each layer ``i``, a directory ``<store_dir>/layer_<i>_out`` is created.
    It is filled with shards written by ``ActivationCache.collate_store_shards``
    and a ``config.json`` that summarises them.

    Args:
        games_data: tokenized games; split along dim 0 into batches.
        model: TransformerLens model whose residual stream is cached.
        store_dir: root directory; one sub-directory per layer is created in it.
        batch_size: games per forward pass.
        shard_size: a shard is flushed once more than this many token
            activations have accumulated.
        max_total_tokens: stop collecting once more than this many token
            activations have been flushed to disk.
        overwrite: if False and ``shard_0.memmap`` already exists for layer 0,
            print a notice and return without collecting.
    """
    io: str = "out"
    n_layers = model.cfg.n_layers
    submodule_names = [f"layer_{layer_i}" for layer_i in range(n_layers)]
    hook_names = [utils.get_act_name("resid_post", layer_i) for layer_i in range(n_layers)]
    wanted_hooks = set(hook_names)
    layer_dirs = [os.path.join(store_dir, f"{name}_{io}") for name in submodule_names]
    for layer_dir in layer_dirs:
        os.makedirs(layer_dir, exist_ok=True)

    # Check if shards already exist (layer 0's first shard is the marker).
    if os.path.exists(os.path.join(layer_dirs[0], "shard_0.memmap")):
        # Report the root directory, not a per-layer sub-directory.
        print(f"Shards already exist in {store_dir}")
        if not overwrite:
            print("Set overwrite=True to overwrite existing shards.")
            return
        print("Overwriting existing shards...")

    def _store_shard(shard_idx: int, per_layer_acts: list) -> None:
        """Write the buffered activations of all layers as shard ``shard_idx``."""
        ActivationCache.collate_store_shards(
            layer_dirs,
            shard_idx,
            per_layer_acts,
            submodule_names,
            shuffle_shards=True,
            io=io,
            multiprocessing=False,
        )

    model.eval()
    dataloader = DataLoader(games_data, batch_size=batch_size)
    activation_cache = [[] for _ in submodule_names]
    total_size = 0  # token activations already flushed to disk
    current_size = 0  # token activations buffered in activation_cache
    shard_count = 0  # shards written so far

    print("Collecting activations...")
    for games in tqdm(dataloader, desc="Collecting activations"):
        # One forward pass per batch caching only the resid_post hooks,
        # instead of one full forward pass per layer.
        _, cache = model.run_with_cache(
            games, names_filter=lambda name: name in wanted_hooks
        )
        for layer_i, hook_name in enumerate(hook_names):
            acts = cache[hook_name]
            # Flatten all leading dims: (batch, pos, d_model) -> (batch * pos, d_model).
            activation_cache[layer_i].append(acts.reshape(-1, acts.shape[-1]).cpu())

        current_size += activation_cache[0][-1].shape[0]
        if current_size > shard_size:
            print(f"Storing shard {shard_count}...", flush=True)
            _store_shard(shard_count, activation_cache)
            shard_count += 1
            total_size += current_size
            current_size = 0
            activation_cache = [[] for _ in submodule_names]

        if total_size > max_total_tokens:
            print("Max total tokens reached. Stopping collection.")
            break

    if current_size > 0:
        # Flush the trailing partial shard AND count it, so config.json's
        # shard_count / total_size describe every shard written to disk.
        _store_shard(shard_count, activation_cache)
        shard_count += 1
        total_size += current_size

    # Store one config per layer directory.
    for layer_dir in layer_dirs:
        with open(os.path.join(layer_dir, "config.json"), "w") as f:
            json.dump(
                {
                    "batch_size": batch_size,
                    "context_len": -1,
                    "shard_size": shard_size,
                    "d_model": model.cfg.d_model,
                    "shuffle_shards": True,
                    "io": io,
                    "total_size": total_size,
                    "shard_count": shard_count,
                    "store_tokens": False,
                },
                f,
            )
    ActivationCache.cleanup_multiprocessing()
    print(f"Finished collecting activations. Total size: {total_size}")

In [8]:
def get_activations_path(
    model_goal: Goal | None,
    weak_model_size: str | None,
    model_size: str,
    index: int,
    crosscoder_folder: str,
    train_val: str,
) -> str:
    """Build the activation-cache directory for one model and one split.

    Exactly one of ``model_goal`` (pretrained model) or ``weak_model_size``
    (finetuned model) must be given. The directory is
    ``<crosscoder_folder>/activations/<index>_<model_size>_<tag>_<train_val>``.
    Here ``tag`` is ``finetuned_through_<weak_model_size>`` or ``str(model_goal)``.

    Raises:
        AssertionError: if both or neither of the selectors are given.
        ValueError: if ``model_goal`` is neither WEAK_GOAL nor STRONG_GOAL.
    """
    # Exactly one selector may be None.
    assert (model_goal is None) != (weak_model_size is None)

    if weak_model_size:
        model_tag = "finetuned_through_" + weak_model_size
    elif model_goal in (Goal.WEAK_GOAL, Goal.STRONG_GOAL):
        model_tag = str(model_goal)
    else:
        raise ValueError(f"Invalid activations model goal: {model_goal}")

    run_dir = f"{index}_{model_size}_{model_tag}_" + train_val
    return os.path.join(crosscoder_folder, "activations", run_dir)


def compute_activations(
    model_goal: Goal | None,
    project_name_pretrain: str | None,
    weak_model_size: str | None,
    project_name_finetune: str | None,
    model_size: str,
    index: int,
    crosscoder_folder: str,
    tictactoe_test_data: Float[Tensor, "n_games game_length"],
    tictactoe_val_data: Float[Tensor, "n_games game_length"],
    experiment_folder: str,
) -> None:
    """Cache one model's activations for the crosscoder's train and val splits.

    Exactly one model must be described:
    - a finetuned model, via ``project_name_finetune`` and ``weak_model_size``; or
    - a pretrained model, via ``project_name_pretrain`` and ``model_goal``.

    The model is loaded onto ``DEVICE``. Its activations are written with
    ``create_data_shards`` into the directories given by ``get_activations_path``.

    Note: the "train" split is built from ``tictactoe_test_data``.
    """
    is_finetuned = project_name_finetune is not None and weak_model_size is not None
    is_pretrained = project_name_pretrain is not None and model_goal is not None
    assert is_finetuned != is_pretrained, (
        "Finetuned XOR pretrained model must be provided."
    )

    if is_finetuned:
        model = load_finetuned_model(
            project_name_finetune,
            weak_model_size,
            model_size,
            experiment_folder,
            DEVICE,
            index,
        )
    else:
        model = load_model(
            project_name_pretrain,
            model_size,
            model_goal,
            experiment_folder,
            device=DEVICE,
            index=index,
        )

    # Split name -> games used to collect activations for that split.
    splits = {"train": tictactoe_test_data, "val": tictactoe_val_data}
    for train_val, games_data in splits.items():
        target_dir = get_activations_path(
            model_goal, weak_model_size, model_size, index, crosscoder_folder, train_val
        )
        create_data_shards(
            games_data,
            model,
            store_dir=target_dir,
            batch_size=64,
            shard_size=10**5,
            max_total_tokens=10**10,
            overwrite=False,
        )

In [9]:
# Strong
compute_activations(
    Goal.STRONG_GOAL,
    project_name_pretrain,
    None,
    None,
    strong_model_size,
    index,
    crosscoder_folder,
    tictactoe_test_data.games_data,
    tictactoe_val_data.games_data,
    experiment_folder,
)

# Finetuned
compute_activations(
    None,
    None,
    weak_model_size,
    project_name_finetune,
    strong_model_size,
    index,
    crosscoder_folder,
    tictactoe_test_data.games_data,
    tictactoe_val_data.games_data,
    experiment_folder,
)

experiment_folder:  /homes/55/bwilop/wsg/experiments/
project_name:  tictactoe/tictactoe_pretraining5
/homes/55/bwilop/wsg/experiments/tictactoe/tictactoe_pretraining5
Loading model from /homes/55/bwilop/wsg/experiments/tictactoe/tictactoe_pretraining5/experiment_2_medium_strong_2025-05-16-16-44_omyjjb73.pkl
Moving model to device:  cuda
Shards already exist in /homes/55/bwilop/wsg/experiments/tictactoe/crosscoder/activations/2_medium_strong_train/layer_3_out
Set overwrite=True to overwrite existing shards.
Shards already exist in /homes/55/bwilop/wsg/experiments/tictactoe/crosscoder/activations/2_medium_strong_val/layer_3_out
Set overwrite=True to overwrite existing shards.
Moving model to device:  cuda
Shards already exist in /homes/55/bwilop/wsg/experiments/tictactoe/crosscoder/activations/2_medium_finetuned_through_small_train/layer_3_out
Set overwrite=True to overwrite existing shards.
Shards already exist in /homes/55/bwilop/wsg/experiments/tictactoe/crosscoder/activations/2_medi

In [10]:
def multi_epoch_dataloader_iterator(dataloader: DataLoader, total_steps_to_yield: int):
    """Yield batches from ``dataloader``, cycling over it until enough steps are produced.

    Exactly ``total_steps_to_yield`` batches are yielded, unless the loader turns
    out to be empty. Each pass re-iterates ``dataloader``, so a DataLoader built
    with ``shuffle=True`` draws a new order every epoch.

    Args:
        dataloader: any re-iterable source of batches (``len`` is optional).
        total_steps_to_yield: number of batches to produce in total.
    """
    if total_steps_to_yield == 0:
        return

    # An empty sized loader can never make progress; warn and stop early.
    try:
        known_empty = len(dataloader) == 0
    except TypeError:  # iterable without __len__
        known_empty = False
    if known_empty and total_steps_to_yield > 0:
        print(
            "Warning: DataLoader is empty, but total_steps_to_yield > 0. No steps will run."
        )
        return

    remaining = total_steps_to_yield
    while remaining > 0:
        seen_this_epoch = 0
        for batch in dataloader:  # a shuffling DataLoader reshuffles on each new iteration
            if remaining <= 0:
                return
            yield batch
            remaining -= 1
            seen_this_epoch += 1

        # Guard against an infinite loop if a whole pass produced nothing.
        if seen_this_epoch == 0 and remaining > 0:
            print("Warning: DataLoader became empty before all steps were yielded.")
            return

In [None]:
def get_training_cfg_cross_coder():
    """Return a fresh dict with the default crosscoder training hyperparameters.

    Keys: learning_rate, max_steps, validate_every_n_steps, batch_size,
    expansion_factor (dictionary size = expansion_factor * activation dim),
    resample_steps (int or None) and mu (passed on as the L1 penalty).
    """
    return dict(
        learning_rate=1e-3,
        max_steps=10_000_000,
        validate_every_n_steps=10_000,
        batch_size=64,
        expansion_factor=32,
        resample_steps=None,
        mu=1e-1,
    )


# run_name
# wandb-entity
# disable-wandb
# K=
# n_workers
# compile = False


def train_crosscoder(
    model_1_name: str,
    model_2_name: str,
    index: int,
    train_activations_stor_dir_model_1: str,
    val_activations_stor_dir_model_1: str,
    train_activations_stor_dir_model_2: str,
    val_activations_stor_dir_model_2: str,
    layer: int,
    training_cfg_cross_coder: dict,
    wandb_entity: str,
):
    """Train a CrossCoder on paired cached activations of two models.

    Args:
        model_1_name, model_2_name: labels used in the experiment / wandb name.
        index: experiment index, also used in the experiment name.
        *_stor_dir_model_*: activation-cache directories (one layer each) for
            the train and val splits of both models.
        layer: layer the activations come from (recorded in the trainer config).
        training_cfg_cross_coder: hyperparameters as returned by
            ``get_training_cfg_cross_coder``. An optional ``val_batch_size``
            key (default 1000) sets the validation batch size.
        wandb_entity: W&B entity to log to (project "crosscoder").

    Returns:
        Whatever ``trainSAE`` returns (the trained dictionary). Previously this
        was discarded.

    Side effects: finishes any active wandb run, logs to wandb, and writes
    checkpoints under the module-level ``crosscoder_folder``.
    """
    # Import the class explicitly. The notebook's `import datetime` binds the
    # module, so a bare `datetime.now()` only works if a star-import happens
    # to rebind `datetime` to the class.
    from datetime import datetime as _datetime

    # Data (memory-mapped activation caches, not loaded into memory yet)
    train_dataset = PairedActivationCache(
        train_activations_stor_dir_model_1,
        train_activations_stor_dir_model_2,
    )
    train_dataloader = DataLoader(
        train_dataset,
        batch_size=training_cfg_cross_coder["batch_size"],
        shuffle=True,
        num_workers=1,
        pin_memory=True,
    )
    print(f"Training on {len(train_dataset)} token activations.")
    val_dataset = PairedActivationCache(
        val_activations_stor_dir_model_1,
        val_activations_stor_dir_model_2,
    )
    val_dataloader = DataLoader(
        val_dataset,
        batch_size=training_cfg_cross_coder.get("val_batch_size", 1000),
        shuffle=False,
        num_workers=1,
        pin_memory=True,
    )
    print(f"Validating on {len(val_dataset)} token activations.")

    # Training config
    # assumes each paired item is (2, d_model), i.e. model-1/model-2 stacked — TODO confirm
    activation_dim = train_dataset[0].shape[1]
    dictionary_size = training_cfg_cross_coder["expansion_factor"] * activation_dim
    print(f"Activation dim: {activation_dim}")
    print(f"Dictionary size: {dictionary_size}")
    mu = training_cfg_cross_coder["mu"]
    lr = training_cfg_cross_coder["learning_rate"]
    timestamp = _datetime.now().strftime("%Y-%m-%d-%H-%M")
    experiment_name = f"experiment_{index}_{model_1_name}_{model_2_name}_{timestamp}"
    trainer_cfg = {
        "trainer": CrossCoderTrainer,
        "dict_class": CrossCoder,
        "activation_dim": activation_dim,
        "dict_size": dictionary_size,
        "lr": lr,
        "resample_steps": training_cfg_cross_coder["resample_steps"],
        "device": str(DEVICE),
        "warmup_steps": 1000,
        "layer": layer,
        "lm_name": experiment_name,
        "compile": True,
        # Separator added so the timestamp and the layer tag don't run together.
        "wandb_name": f"{experiment_name}_L{layer}-mu{mu:.1e}-lr{lr:.0e}",
        "l1_penalty": mu,
        "dict_class_kwargs": {
            "same_init_for_all_layers": True,
            "norm_init_scale": 0.005,
            "init_with_transpose": True,
            "encoder_layers": None,
        },
        "pretrained_ae": None,
    }

    # Close any wandb run left open by a previous cell before trainSAE starts its own.
    wandb.finish()
    max_steps = training_cfg_cross_coder["max_steps"]
    # Cycle over the training data so trainSAE can run for max_steps regardless of dataset size.
    multi_epoch_train_dataloader = multi_epoch_dataloader_iterator(
        train_dataloader, max_steps
    )
    return trainSAE(
        data=multi_epoch_train_dataloader,
        trainer_config=trainer_cfg,
        validate_every_n_steps=training_cfg_cross_coder["validate_every_n_steps"],
        validation_data=val_dataloader,
        use_wandb=True,
        wandb_entity=wandb_entity,
        wandb_project="crosscoder",
        log_steps=50,
        save_dir=os.path.join(crosscoder_folder, "checkpoints", experiment_name),
        steps=max_steps,
        save_steps=None,
    )

In [None]:
# Crosscoder between the strong pretrained model (model 1) and the finetuned model (model 2).
layer = 3

model_1_name = "strong_model"
model_2_name = "finetuned_model"

# Activation-cache roots written by compute_activations; per-layer sub-dirs are "layer_<i>_out".
train_activations_stor_dir_model_1 = get_activations_path(
    model_goal=Goal.STRONG_GOAL,
    weak_model_size=None,
    model_size=strong_model_size,
    index=index,
    crosscoder_folder=crosscoder_folder,
    train_val="train",
)
val_activations_stor_dir_model_1 = get_activations_path(
    model_goal=Goal.STRONG_GOAL,
    weak_model_size=None,
    model_size=strong_model_size,
    index=index,
    crosscoder_folder=crosscoder_folder,
    train_val="val",
)
train_activations_stor_dir_model_2 = get_activations_path(
    model_goal=None,
    weak_model_size=weak_model_size,
    model_size=strong_model_size,
    index=index,
    crosscoder_folder=crosscoder_folder,
    train_val="train",
)
val_activations_stor_dir_model_2 = get_activations_path(
    model_goal=None,
    weak_model_size=weak_model_size,
    model_size=strong_model_size,
    index=index,
    crosscoder_folder=crosscoder_folder,
    train_val="val",
)
training_cfg_cross_coder = get_training_cfg_cross_coder()
# Long-running: uncomment to (re)train the crosscoder.
# train_crosscoder(
#     model_1_name,
#     model_2_name,
#     index,
#     train_activations_stor_dir_model_1 + f"/layer_{layer}_out",
#     val_activations_stor_dir_model_1 + f"/layer_{layer}_out",
#     train_activations_stor_dir_model_2 + f"/layer_{layer}_out",
#     val_activations_stor_dir_model_2 + f"/layer_{layer}_out",
#     layer,
#     training_cfg_cross_coder,
#     WANDB_ENTITIY,
# )

Training on 200704 token activations.
Validating on 100352 token activations.
Activation dim: 32
Dictionary size: 1024


  warn(f"Error saving config: {e}")
  0%|          | 9990/10000000 [00:58<15:17:19, 181.51it/s]

Validating at step 10000


100%|██████████| 101/101 [00:07<00:00, 14.04it/s]
  0%|          | 19987/10000000 [02:01<14:50:24, 186.81it/s]

Validating at step 20000


100%|██████████| 101/101 [00:07<00:00, 13.63it/s]
  0%|          | 29994/10000000 [03:03<15:22:07, 180.20it/s]

Validating at step 30000


100%|██████████| 101/101 [00:07<00:00, 13.64it/s]
  0%|          | 40000/10000000 [04:05<14:36:12, 189.45it/s]

Validating at step 40000


100%|██████████| 101/101 [00:07<00:00, 14.29it/s]
  0%|          | 49982/10000000 [05:08<14:58:30, 184.57it/s]

Validating at step 50000


100%|██████████| 101/101 [00:07<00:00, 13.46it/s]
  1%|          | 59989/10000000 [06:10<15:17:08, 180.63it/s]

Validating at step 60000


100%|██████████| 101/101 [00:07<00:00, 13.77it/s]
  1%|          | 69996/10000000 [07:12<15:10:36, 181.75it/s]

Validating at step 70000


100%|██████████| 101/101 [00:07<00:00, 14.40it/s]
  1%|          | 79998/10000000 [08:16<14:23:30, 191.47it/s]

Validating at step 80000


100%|██████████| 101/101 [00:07<00:00, 13.63it/s]
  1%|          | 89995/10000000 [09:21<14:31:37, 189.49it/s]

Validating at step 90000


100%|██████████| 101/101 [00:07<00:00, 14.34it/s]
  1%|          | 99985/10000000 [10:23<14:53:42, 184.62it/s]

Validating at step 100000


100%|██████████| 101/101 [00:07<00:00, 13.69it/s]
  1%|          | 109987/10000000 [11:28<14:57:59, 183.56it/s]

Validating at step 110000


100%|██████████| 101/101 [00:07<00:00, 13.66it/s]
  1%|          | 119999/10000000 [12:30<18:49:17, 145.81it/s]

Validating at step 120000


100%|██████████| 101/101 [00:06<00:00, 14.44it/s]
  1%|▏         | 129983/10000000 [13:34<15:26:02, 177.64it/s]

Validating at step 130000


100%|██████████| 101/101 [00:07<00:00, 13.59it/s]
  1%|▏         | 139981/10000000 [14:37<14:33:15, 188.19it/s]

Validating at step 140000


100%|██████████| 101/101 [00:07<00:00, 13.85it/s]
  1%|▏         | 149995/10000000 [15:43<15:12:56, 179.82it/s]

Validating at step 150000


100%|██████████| 101/101 [00:07<00:00, 13.61it/s]
  2%|▏         | 159994/10000000 [16:46<18:03:08, 151.41it/s]

Validating at step 160000


100%|██████████| 101/101 [00:07<00:00, 13.76it/s]
  2%|▏         | 169996/10000000 [17:51<14:58:56, 182.25it/s]

Validating at step 170000


100%|██████████| 101/101 [00:07<00:00, 13.75it/s]
  2%|▏         | 179986/10000000 [18:54<14:26:31, 188.88it/s]

Validating at step 180000


100%|██████████| 101/101 [00:07<00:00, 12.93it/s]
  2%|▏         | 189992/10000000 [19:57<15:02:39, 181.13it/s]

Validating at step 190000


100%|██████████| 101/101 [00:07<00:00, 13.66it/s]
  2%|▏         | 199998/10000000 [21:02<14:56:32, 182.18it/s]

Validating at step 200000


100%|██████████| 101/101 [00:07<00:00, 13.51it/s]
  2%|▏         | 209993/10000000 [22:04<14:47:56, 183.76it/s]

Validating at step 210000


100%|██████████| 101/101 [00:07<00:00, 13.67it/s]
  2%|▏         | 219994/10000000 [23:08<14:58:11, 181.48it/s]

Validating at step 220000


100%|██████████| 101/101 [00:07<00:00, 13.64it/s]
  2%|▏         | 229991/10000000 [24:11<14:54:41, 182.00it/s]

Validating at step 230000


100%|██████████| 101/101 [00:07<00:00, 12.94it/s]
  2%|▏         | 239987/10000000 [25:15<18:43:28, 144.79it/s]

Validating at step 240000


100%|██████████| 101/101 [00:07<00:00, 13.20it/s]
  2%|▏         | 249991/10000000 [26:19<14:57:39, 181.03it/s]

Validating at step 250000


100%|██████████| 101/101 [00:07<00:00, 12.80it/s]
  3%|▎         | 259993/10000000 [27:21<14:32:28, 186.06it/s]

Validating at step 260000


100%|██████████| 101/101 [00:07<00:00, 12.82it/s]
  3%|▎         | 269982/10000000 [28:25<14:40:49, 184.11it/s]

Validating at step 270000


100%|██████████| 101/101 [00:07<00:00, 12.73it/s]
  3%|▎         | 279994/10000000 [29:29<14:34:26, 185.26it/s]

Validating at step 280000


100%|██████████| 101/101 [00:07<00:00, 12.95it/s]
  3%|▎         | 289990/10000000 [30:34<14:49:25, 181.95it/s]

Validating at step 290000


100%|██████████| 101/101 [00:07<00:00, 14.08it/s]
  3%|▎         | 299983/10000000 [31:40<14:14:22, 189.22it/s]

Validating at step 300000


100%|██████████| 101/101 [00:07<00:00, 13.25it/s]
  3%|▎         | 309997/10000000 [32:43<14:35:42, 184.42it/s]

Validating at step 310000


100%|██████████| 101/101 [00:07<00:00, 13.53it/s]
  3%|▎         | 319999/10000000 [33:46<15:59:59, 168.06it/s]

Validating at step 320000


100%|██████████| 101/101 [00:07<00:00, 14.10it/s]
  3%|▎         | 329990/10000000 [34:57<15:01:56, 178.69it/s]

Validating at step 330000


100%|██████████| 101/101 [00:07<00:00, 13.82it/s]
  3%|▎         | 339990/10000000 [36:04<14:29:49, 185.10it/s]

Validating at step 340000


100%|██████████| 101/101 [00:07<00:00, 14.34it/s]
  3%|▎         | 349985/10000000 [37:08<14:44:22, 181.86it/s]

Validating at step 350000


100%|██████████| 101/101 [00:07<00:00, 13.45it/s]
  4%|▎         | 359995/10000000 [38:10<14:39:44, 182.63it/s]

Validating at step 360000


100%|██████████| 101/101 [00:07<00:00, 13.64it/s]
  4%|▎         | 369998/10000000 [39:13<14:21:04, 186.40it/s]

Validating at step 370000


100%|██████████| 101/101 [00:07<00:00, 13.40it/s]
  4%|▍         | 379990/10000000 [40:20<14:19:23, 186.57it/s]

Validating at step 380000


100%|██████████| 101/101 [00:07<00:00, 13.67it/s]
  4%|▍         | 389986/10000000 [41:24<15:58:04, 167.18it/s]

Validating at step 390000


100%|██████████| 101/101 [00:07<00:00, 13.70it/s]
  4%|▍         | 399994/10000000 [42:30<14:27:23, 184.46it/s]

Validating at step 400000


100%|██████████| 101/101 [00:07<00:00, 13.89it/s]
  4%|▍         | 409995/10000000 [43:32<14:17:01, 186.50it/s]

Validating at step 410000


100%|██████████| 101/101 [00:07<00:00, 13.66it/s]
  4%|▍         | 419999/10000000 [44:35<14:46:09, 180.18it/s]

Validating at step 420000


100%|██████████| 101/101 [00:07<00:00, 13.63it/s]
  4%|▍         | 429982/10000000 [45:40<14:58:18, 177.56it/s]

Validating at step 430000


100%|██████████| 101/101 [00:06<00:00, 14.50it/s]
  4%|▍         | 439994/10000000 [46:42<14:21:53, 184.86it/s]

Validating at step 440000


100%|██████████| 101/101 [00:07<00:00, 13.49it/s]
  4%|▍         | 449998/10000000 [47:46<14:26:57, 183.59it/s]

Validating at step 450000


100%|██████████| 101/101 [00:07<00:00, 13.89it/s]
  5%|▍         | 459981/10000000 [48:48<14:03:18, 188.54it/s]

Validating at step 460000


100%|██████████| 101/101 [00:07<00:00, 13.70it/s]
  5%|▍         | 469989/10000000 [49:51<14:28:49, 182.81it/s]

Validating at step 470000


100%|██████████| 101/101 [00:07<00:00, 13.09it/s]
  5%|▍         | 479999/10000000 [50:54<14:22:35, 183.94it/s]

Validating at step 480000


100%|██████████| 101/101 [00:07<00:00, 13.67it/s]
  5%|▍         | 489985/10000000 [51:57<14:34:42, 181.21it/s]

Validating at step 490000


100%|██████████| 101/101 [00:07<00:00, 13.88it/s]
  5%|▍         | 499987/10000000 [53:01<14:09:01, 186.49it/s]

Validating at step 500000


100%|██████████| 101/101 [00:07<00:00, 13.55it/s]
  5%|▌         | 509991/10000000 [54:03<13:57:30, 188.85it/s]

Validating at step 510000


100%|██████████| 101/101 [00:07<00:00, 13.74it/s]
  5%|▌         | 519999/10000000 [55:05<13:51:19, 190.06it/s]

Validating at step 520000


100%|██████████| 101/101 [00:07<00:00, 13.89it/s]
  5%|▌         | 529993/10000000 [56:07<23:13:49, 113.24it/s]

Validating at step 530000


100%|██████████| 101/101 [00:07<00:00, 14.08it/s]
  5%|▌         | 539984/10000000 [57:13<14:28:39, 181.51it/s]

Validating at step 540000


100%|██████████| 101/101 [00:07<00:00, 13.52it/s]
  5%|▌         | 549985/10000000 [58:16<13:40:00, 192.07it/s]

Validating at step 550000


100%|██████████| 101/101 [00:07<00:00, 13.82it/s]
  6%|▌         | 559995/10000000 [59:18<14:34:47, 179.85it/s]

Validating at step 560000


100%|██████████| 101/101 [00:07<00:00, 13.67it/s]
  6%|▌         | 569986/10000000 [1:00:21<13:55:51, 188.03it/s]

Validating at step 570000


100%|██████████| 101/101 [00:07<00:00, 13.77it/s]
  6%|▌         | 579998/10000000 [1:01:24<14:25:25, 181.41it/s]

Validating at step 580000


100%|██████████| 101/101 [00:07<00:00, 14.26it/s]
  6%|▌         | 589988/10000000 [1:02:26<14:42:32, 177.71it/s]

Validating at step 590000


100%|██████████| 101/101 [00:07<00:00, 13.75it/s]
  6%|▌         | 599998/10000000 [1:03:30<14:19:07, 182.36it/s]

Validating at step 600000


100%|██████████| 101/101 [00:07<00:00, 13.98it/s]
  6%|▌         | 609994/10000000 [1:04:34<13:36:50, 191.59it/s]

Validating at step 610000


100%|██████████| 101/101 [00:07<00:00, 14.30it/s]
  6%|▌         | 619992/10000000 [1:05:37<14:22:29, 181.26it/s]

Validating at step 620000


100%|██████████| 101/101 [00:07<00:00, 13.75it/s]
  6%|▋         | 629985/10000000 [1:06:40<14:02:17, 185.41it/s]

Validating at step 630000


100%|██████████| 101/101 [00:07<00:00, 13.03it/s]
  6%|▋         | 639997/10000000 [1:07:45<14:23:48, 180.59it/s]

Validating at step 640000


100%|██████████| 101/101 [00:07<00:00, 13.81it/s]
  6%|▋         | 649983/10000000 [1:08:48<13:53:05, 187.05it/s]

Validating at step 650000


100%|██████████| 101/101 [00:07<00:00, 14.06it/s]
  7%|▋         | 659995/10000000 [1:09:51<14:23:00, 180.38it/s]

Validating at step 660000


100%|██████████| 101/101 [00:06<00:00, 14.45it/s]
  7%|▋         | 670000/10000000 [1:10:53<14:10:56, 182.74it/s]

Validating at step 670000


100%|██████████| 101/101 [00:07<00:00, 14.10it/s]
  7%|▋         | 680000/10000000 [1:11:55<13:23:10, 193.40it/s]

Validating at step 680000


100%|██████████| 101/101 [00:07<00:00, 13.58it/s]
  7%|▋         | 689995/10000000 [1:12:59<17:40:09, 146.36it/s]

Validating at step 690000


100%|██████████| 101/101 [00:07<00:00, 14.16it/s]
  7%|▋         | 699990/10000000 [1:14:01<13:47:46, 187.25it/s]

Validating at step 700000


100%|██████████| 101/101 [00:07<00:00, 13.69it/s]
  7%|▋         | 709994/10000000 [1:15:05<17:58:53, 143.51it/s]

Validating at step 710000


100%|██████████| 101/101 [00:07<00:00, 14.30it/s]
  7%|▋         | 719994/10000000 [1:16:07<16:22:28, 157.42it/s]

Validating at step 720000


100%|██████████| 101/101 [00:09<00:00, 10.55it/s]
  7%|▋         | 729986/10000000 [1:17:11<13:25:17, 191.86it/s]

Validating at step 730000


100%|██████████| 101/101 [00:07<00:00, 13.56it/s]
  7%|▋         | 739997/10000000 [1:18:14<17:30:01, 146.98it/s]

Validating at step 740000


100%|██████████| 101/101 [00:07<00:00, 13.71it/s]
  7%|▋         | 749985/10000000 [1:19:17<14:07:34, 181.89it/s]

Validating at step 750000


100%|██████████| 101/101 [00:07<00:00, 13.96it/s]
  8%|▊         | 759996/10000000 [1:20:19<13:26:48, 190.87it/s]

Validating at step 760000


100%|██████████| 101/101 [00:07<00:00, 13.63it/s]
  8%|▊         | 769994/10000000 [1:21:22<13:57:00, 183.79it/s]

Validating at step 770000


100%|██████████| 101/101 [00:07<00:00, 14.23it/s]
  8%|▊         | 779987/10000000 [1:22:30<14:11:49, 180.40it/s]

Validating at step 780000


100%|██████████| 101/101 [00:07<00:00, 13.58it/s]
  8%|▊         | 790000/10000000 [1:23:36<17:25:45, 146.78it/s]

Validating at step 790000


100%|██████████| 101/101 [00:07<00:00, 13.63it/s]
  8%|▊         | 799983/10000000 [1:24:41<13:28:02, 189.76it/s]

Validating at step 800000


100%|██████████| 101/101 [00:07<00:00, 13.57it/s]
  8%|▊         | 809999/10000000 [1:25:43<13:37:15, 187.42it/s]

Validating at step 810000


100%|██████████| 101/101 [00:07<00:00, 13.95it/s]
  8%|▊         | 819986/10000000 [1:26:44<13:40:13, 186.54it/s]

Validating at step 820000


100%|██████████| 101/101 [00:07<00:00, 13.85it/s]
  8%|▊         | 829990/10000000 [1:27:47<13:51:37, 183.78it/s]

Validating at step 830000


100%|██████████| 101/101 [00:07<00:00, 13.76it/s]
  8%|▊         | 839997/10000000 [1:28:51<14:36:02, 174.27it/s]

Validating at step 840000


100%|██████████| 101/101 [00:07<00:00, 14.41it/s]
  8%|▊         | 849994/10000000 [1:29:55<14:49:44, 171.40it/s]

Validating at step 850000


100%|██████████| 101/101 [00:07<00:00, 13.84it/s]
  9%|▊         | 859995/10000000 [1:30:58<14:46:14, 171.89it/s]

Validating at step 860000


100%|██████████| 101/101 [00:07<00:00, 13.82it/s]
  9%|▊         | 869985/10000000 [1:32:02<14:16:30, 177.66it/s]

Validating at step 870000


100%|██████████| 101/101 [00:07<00:00, 13.48it/s]
  9%|▉         | 880000/10000000 [1:33:03<13:46:42, 183.86it/s]

Validating at step 880000


100%|██████████| 101/101 [00:07<00:00, 13.82it/s]
  9%|▉         | 889982/10000000 [1:34:06<13:33:32, 186.63it/s]

Validating at step 890000


100%|██████████| 101/101 [00:07<00:00, 14.28it/s]
  9%|▉         | 899990/10000000 [1:35:08<13:22:40, 188.95it/s]

Validating at step 900000


100%|██████████| 101/101 [00:07<00:00, 13.43it/s]
  9%|▉         | 909983/10000000 [1:36:11<13:13:15, 190.98it/s]

Validating at step 910000


100%|██████████| 101/101 [00:07<00:00, 13.79it/s]
  9%|▉         | 920000/10000000 [1:37:13<13:29:43, 186.90it/s]

Validating at step 920000


100%|██████████| 101/101 [00:06<00:00, 14.45it/s]
  9%|▉         | 929995/10000000 [1:38:14<13:25:38, 187.64it/s]

Validating at step 930000


100%|██████████| 101/101 [00:07<00:00, 13.65it/s]
  9%|▉         | 940000/10000000 [1:39:26<13:55:08, 180.81it/s]

Validating at step 940000


100%|██████████| 101/101 [00:07<00:00, 13.64it/s]
  9%|▉         | 949988/10000000 [1:40:30<13:19:34, 188.64it/s]

Validating at step 950000


100%|██████████| 101/101 [00:07<00:00, 13.65it/s]
 10%|▉         | 959999/10000000 [1:41:33<13:24:22, 187.31it/s]

Validating at step 960000


100%|██████████| 101/101 [00:07<00:00, 13.66it/s]
 10%|▉         | 970000/10000000 [1:42:36<15:45:13, 159.22it/s]

Validating at step 970000


100%|██████████| 101/101 [00:07<00:00, 13.86it/s]
 10%|▉         | 979991/10000000 [1:43:38<13:43:46, 182.49it/s]

Validating at step 980000


100%|██████████| 101/101 [00:07<00:00, 14.24it/s]
 10%|▉         | 989988/10000000 [1:44:41<13:11:34, 189.70it/s]

Validating at step 990000


100%|██████████| 101/101 [00:07<00:00, 13.54it/s]
 10%|▉         | 999983/10000000 [1:45:43<13:47:00, 181.38it/s]

Validating at step 1000000


100%|██████████| 101/101 [00:07<00:00, 13.55it/s]
 10%|█         | 1009992/10000000 [1:46:48<13:25:50, 185.93it/s]

Validating at step 1010000


100%|██████████| 101/101 [00:07<00:00, 13.69it/s]
 10%|█         | 1019997/10000000 [1:47:52<13:02:22, 191.30it/s]

Validating at step 1020000


100%|██████████| 101/101 [00:07<00:00, 13.78it/s]
 10%|█         | 1029999/10000000 [1:48:54<13:23:45, 186.00it/s]

Validating at step 1030000


100%|██████████| 101/101 [00:07<00:00, 14.33it/s]
 10%|█         | 1039991/10000000 [1:49:55<13:15:27, 187.73it/s]

Validating at step 1040000


100%|██████████| 101/101 [00:07<00:00, 13.59it/s]
 10%|█         | 1049982/10000000 [1:51:00<13:48:39, 180.01it/s]

Validating at step 1050000


100%|██████████| 101/101 [00:07<00:00, 13.69it/s]
 11%|█         | 1059988/10000000 [1:52:02<18:55:43, 131.19it/s]

Validating at step 1060000


100%|██████████| 101/101 [00:07<00:00, 13.84it/s]
 11%|█         | 1062128/10000000 [1:52:21<16:32:24, 150.10it/s]

In [None]:
@dataclass
class CrosscoderMetrics:
    """Container for crosscoder analysis results (work in progress).

    Only the field layout and method signatures exist so far. ``__init__``
    records ``save_dir``; every other method is an unimplemented stub and
    raises ``NotImplementedError``.
    """

    save_dir: str
    config: dict
    crosscoder: CrossCoder
    delta_norms: Float[Tensor, "n_activations"]
    beta_reconstruction_model1: Float[Tensor, "n_activations"]
    beta_reconstruction_model2: Float[Tensor, "n_activations"]
    beta_error_model1: Float[Tensor, "n_activations"]
    beta_error_model2: Float[Tensor, "n_activations"]
    nu_reconstruction: Float[Tensor, "n_activations"]
    nu_epsilon: Float[Tensor, "n_activations"]

    def __init__(self, save_dir: str) -> None:
        # `self` was missing from the signature, so any construction raised a
        # TypeError. Because __init__ is defined here, @dataclass does not
        # generate one; the remaining fields stay unset until implemented.
        self.save_dir = save_dir

    # Save + Load
    def save(self, save_dir: str) -> None:
        """Persist the metrics to ``save_dir`` (not implemented yet)."""
        raise NotImplementedError

    @staticmethod
    def load(save_dir: str) -> "CrosscoderMetrics":
        """Load metrics previously saved to ``save_dir`` (not implemented yet)."""
        raise NotImplementedError

    @staticmethod
    def load_model(save_dir: str) -> CrossCoder:
        """Load the trained crosscoder stored in ``save_dir`` (not implemented yet)."""
        raise NotImplementedError

    @staticmethod
    def load_config(save_dir: str) -> dict:
        """Load the training config stored in ``save_dir`` (not implemented yet)."""
        raise NotImplementedError

    def compute_delta_norms(self) -> Float[Tensor, "n_activations"]:
        """Compute the ``delta_norms`` metric (not implemented yet)."""
        raise NotImplementedError

    def compute_beta(
        self, model_i: int, val_data_loader: DataLoader
    ) -> Float[Tensor, "n_activations"]:
        """Compute the beta metric for model ``model_i`` on validation data (not implemented yet)."""
        raise NotImplementedError

    def compute_nu(
        self,
        beta_model_1: Float[Tensor, "n_activations"],
        beta_model_2: Float[Tensor, "n_activations"],
    ) -> Float[Tensor, "n_activations"]:
        """Combine the two models' betas into the nu metric (not implemented yet)."""
        raise NotImplementedError

    def plot_delta_norms(self):
        """Plot delta norms (not implemented yet).

        Intended to return a matplotlib object that can be embedded as a
        subplot of another figure.
        """
        raise NotImplementedError

    def plot_betas(self):
        """Plot the beta metrics (not implemented yet)."""
        raise NotImplementedError

    def plot_nu(self):
        """Plot the nu metrics (not implemented yet)."""
        raise NotImplementedError