In [None]:
try:
    import google.colab # type: ignore
    IN_COLAB = True
except:
    IN_COLAB = False

import os, sys

if IN_COLAB:
    # Code to download the necessary files (e.g. solutions, test funcs)
    if not os.path.exists("chapter1_transformers"):
        !curl -o /content/main.zip https://codeload.github.com/callummcdougall/ARENA_2.0/zip/refs/heads/main
        !unzip /content/main.zip 'ARENA_2.0-main/chapter1_transformers/exercises/*'
        sys.path.append("/content/ARENA_2.0-main/chapter1_transformers/exercises")
        os.remove("/content/main.zip")
        os.rename("ARENA_2.0-main/chapter1_transformers", "chapter1_transformers")
        os.rmdir("ARENA_2.0-main")

         # Install packages
        %pip install einops
        %pip install jaxtyping
        %pip install transformer_lens
        %pip install git+https://github.com/callummcdougall/CircuitsVis.git#subdirectory=python
        %pip install s3fs
        %pip install omegaconf
        %pip install git+https://github.com/CindyXWu/devinterp-automata.git
        %pip install torch-ema

        !curl -o /content/main.zip https://codeload.github.com/CindyXWu/devinterp-automata/zip/refs/heads/main
        !unzip -o /content/main.zip -d /content/

        sys.path.append("/content/devinterp-automata/")
        os.remove("/content/main.zip")

        os.chdir("chapter1_transformers/exercises")
else:
    from IPython import get_ipython
    ipython = get_ipython()
    ipython.run_line_magic("load_ext", "autoreload")
    ipython.run_line_magic("autoreload", "2")

    CHAPTER = r"chapter1_transformers"
    CHAPTER_DIR = r"./" if CHAPTER in os.listdir() else os.getcwd().split(CHAPTER)[0]
    EXERCISES_DIR = CHAPTER_DIR + f"{CHAPTER}/exercises"
    sys.path.append(EXERCISES_DIR)

In [None]:
from dotenv import load_dotenv
import plotly.express as px
from typing import List, Union, Optional, Dict, Tuple
from jaxtyping import Int, Float

import torch
from torch import Tensor
import torch.nn as nn
import torch.nn.functional as F

from pathlib import Path
import numpy as np
import pandas as pd
import einops
import re
import functools
from tqdm import tqdm
from IPython.display import display
import webbrowser
import gdown
from transformer_lens.hook_points import HookPoint
from transformer_lens import utils, HookedTransformer, HookedTransformerConfig, FactoredMatrix, ActivationCache
from transformer_lens.utils import to_numpy

import circuitsvis as cv
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots
from sklearn.metrics import confusion_matrix
import seaborn as sns
import matplotlib.pyplot as plt

# For Dashiell's groups code
from copy import deepcopy
from functools import reduce
from itertools import product
import math
import numpy as np
from operator import mul
import torch

# Disable autograd globally for the rest of the notebook.
torch.set_grad_enabled(False)
# Use CUDA when it is available, otherwise the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# True only when this code runs as the top-level module.
MAIN = (__name__ == "__main__")

import wandb
from pathlib import Path
import os
import yaml
import s3fs
from omegaconf import OmegaConf

from di_automata.config_setup import *
from di_automata.constructors import (
    construct_model,
    create_dataloader_hf,
)
from di_automata.tasks.data_utils import take_n
import plotly.io as pio

# AWS
# Credentials come from environment variables (load_dotenv is called first so a .env file can supply them).
load_dotenv()
AWS_KEY = os.getenv("AWS_KEY")
AWS_SECRET = os.getenv("AWS_SECRET")
# Filesystem handle used later to read checkpoints stored on S3.
s3 = s3fs.S3FileSystem(
    key=AWS_KEY,
    secret=AWS_SECRET,
)

In [None]:
from di_automata.interp_utils import (
    imshow_attention,
    line,
    scatter,
    imshow,
    reorder_list_in_plotly_way,
    get_pca,
    get_vars,
    plot_tensor_heatmap,
    get_activations,
    LN_hook_names,
    get_ln_fit,
    cos_sim_with_MLP_weights,
    avg_squared_cos_sim,
    hook_fn_display_attn_patterns,
    hook_fn_patch_qk,
)

from di_automata.tasks.dashiell_groups import (
    DihedralElement,
    DihedralIrrep, 
    ProductDihedralIrrep,
    dihedral_conjugacy_classes, 
    generate_subgroup,
    actions_to_labels,
    get_all_bits,
    dihedral_fourier,
    get_fourier_spectrum,
    analyse_power,
)

In [None]:
# Root of the devinterp-automata checkout: the unzipped archive on Colab, the parent directory otherwise.
if IN_COLAB:
    DI_ROOT = Path("/content/devinterp-automata-main/")
else:
    DI_ROOT = Path("../")

# Main SLT config.
config_file_path = DI_ROOT / "scripts" / "configs" / "slt_config.yaml"
slt_config = OmegaConf.load(config_file_path)

# Task-specific config, selected by the dataset type named in the SLT config.
with open(DI_ROOT / "scripts" / "configs" / "task_config" / f"{slt_config.dataset_type}.yaml", mode="r") as file:
    task_config = yaml.safe_load(file)

In [None]:
# Merge the task config into the SLT config, then round-trip it through the Pydantic model.
OmegaConf.set_struct(slt_config, False) # Allow new configuration values to be added
# Because we are in Colab and not VSCode, here is where you want to edit your config values
task_config["n"] = 4
slt_config["task_config"] = task_config

# Convert OmegaConf object to the PostRunSLTConfig Pydantic model for dynamic type validation - NECESSARY DO NOT SKIP
pydantic_config = PostRunSLTConfig(**slt_config)
# Convert back to OmegaConf object for compatibility with existing code
slt_config = OmegaConf.create(pydantic_config.model_dump())

# Show which dataset type the task config describes.
print(task_config["dataset_type"])

In [None]:
# Run path and name for easy referral later
run_name = slt_config.run_name
run_path = f"{slt_config.entity_name}/{slt_config.wandb_project_name}"
print(run_name)

# Pin to the first CUDA device when available, otherwise use the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
# Get run information
api = wandb.Api(timeout=3000)
run_list = api.runs(
    path=run_path,
    filters={
        "display_name": run_name,
        "state": "finished",
        },
    order="created_at", # Default descending order so backwards in time
)
assert run_list, f"Specified run {run_name} does not exist"
# Several finished runs can share a display name; run_idx picks one of them.
run_api = run_list[slt_config.run_idx]
# `history` is handled both as a method and as a plain attribute. Only call it when it
# is callable, so that genuine failures (network, auth, ...) propagate instead of being
# swallowed and leaving `history` bound to an unusable method object.
history = run_api.history() if callable(run_api.history) else run_api.history
loss_history = history["Train Loss"]
accuracy_history = history["Train Acc"]
steps = history["_step"]
# Timestamp stored in the run config; used below to build artifact and checkpoint names.
time = run_api.config["time"]

In [None]:
def get_config() -> MainConfig:
    """Manually get the config of the selected run from its W&B artifact.

    WandB also logs the config automatically for each run, but it doesn't log enums correctly.

    Uses the notebook globals `api`, `run_path`, `run_name` and `time` to build the
    artifact name `{run_path}/config:{run_name}_{time}`.

    Returns:
        The object returned by `OmegaConf.load` for the artifact's `config.yaml`.
    """
    artifact = api.artifact(f"{run_path}/config:{run_name}_{time}")
    # download() returns the local directory the artifact files were written to.
    config_path = Path(artifact.download()) / "config.yaml"
    return OmegaConf.load(config_path)

In [None]:
# Config the run was actually trained with.
config = get_config()

# Set total number of unique samples seen (n). If this is not done it will break LLC estimator.
# Both places in the SLT config that hold the sample count take the value from the run config.
slt_config.rlct_config.sgld_kwargs.num_samples = config.rlct_config.sgld_kwargs.num_samples
slt_config.rlct_config.num_samples = config.rlct_config.sgld_kwargs.num_samples
# Reuse the run's model configuration as well.
slt_config.nano_gpt_config = config.nano_gpt_config

In [None]:
def restore_state_single_cp(cp_idx: int) -> dict:
    """Restore model state from a single checkpoint.
    Used in _load_logits_states() and _calculate_rlct().

    Args:
        idx_cp: index of checkpoint.

    Returns:
        model state dictionary.
    """
    idx = cp_idx * config.rlct_config.ed_config.eval_frequency * slt_config.skip_cps
    print(f"Getting checkpoint {idx}")
    print(config.model_save_method)
    match config.model_save_method:
        case "wandb":
            artifact = api.artifact(f"{run_path}/states:idx{idx}_{run_name}_{time}")
            data_dir = artifact.download()
            state_path = Path(data_dir) / f"states_{idx}.torch"
            states = torch.load(state_path)
        case "aws":
            with s3.open(f'{config.aws_bucket}/{run_name}_{time}/states_{idx}.pth', mode='rb') as file:
                states = torch.load(file, map_location=device)
    return states["model"]

In [None]:
# Path of this run's logits file, located under the parent of the working directory.
current_directory = Path.cwd()
logits_file_path = current_directory.parent / "di_automata" / f"logits_{run_name}_{time}"
print(logits_file_path)

In [None]:
# Dataloader built from the run's own config, used below to draw evaluation batches.
ed_loader = create_dataloader_hf(config, deterministic=True) # Make sure deterministic to see same data

In [None]:
def display_layer_heads(att, batch_idx=0):
    """Display the attention patterns of one batch element with CircuitsVis.

    Reads the notebook globals `inputs` and `labels` and prints the selected
    input and label sequences underneath the visualisation.

    Args:
        att: attention patterns, indexed as `att[batch_idx, ...]`; the first
            axis of that slice is taken to be the head axis
            (assumes att is (batch, head, dest, src) — TODO confirm).
        batch_idx: which batch element to display.
    """
    sample_attention = att[batch_idx, ...]
    # Name every head that is actually present instead of assuming exactly four.
    head_names = [f"L0H{i}" for i in range(sample_attention.shape[0])]
    display(cv.attention.attention_patterns(
        tokens=list_of_strings(inputs[batch_idx, ...]),
        attention=sample_attention,
        attention_head_names=head_names,
    ))
    # 0 is toggle action
    # 1 is drive action
    print(inputs[batch_idx, ...])
    print(labels[batch_idx, ...])

def list_of_strings(tensor):
    """Convert a tensor into a (nested) list of strings, one string per element.

    Args:
        tensor: a torch tensor on any device; it may be tracking gradients.

    Returns:
        A Python list with the same nesting as the tensor's shape, holding the
        string form of every element.
    """
    # detach() and cpu() make the numpy conversion valid for grad-tracking and CUDA tensors too.
    return tensor.detach().cpu().numpy().astype(str).tolist()


def display_layer_heads_batch(att: torch.Tensor, cache: ActivationCache, toks: list[str]):
  """Call `cv.attention.from_cache` for batch indices 0-9 of `cache`.

  TODO: refactor

  Args:
    att: not used in the body.
    cache: activation cache forwarded to `cv.attention.from_cache`.
    toks: tokens forwarded to `cv.attention.from_cache`.

  The value returned by `from_cache` is discarded.

  NOTE(review): the `batch_labels` lambda refers to `format_sequence` and
  `dataset`, neither of which is defined in the visible part of this
  notebook — confirm they exist before calling this function.
  """
  cv.attention.from_cache(
    cache = cache,
    tokens = toks,
    batch_idx = list(range(10)),
    attention_type = "info-weighted",
    radioitems = True,
    return_mode = "view",
    batch_labels = lambda batch_idx, str_tok_list: format_sequence(str_tok_list, dataset.str_tok_labels[batch_idx]),
    mode = "small",
  )

Set up Dashiell's group code, which is used to compute labels for example action sequences.

In [None]:
GROUP_SIZE = 4
group = DihedralElement.full_group(GROUP_SIZE)
# Integer label for every (a, b) key, a in 0..GROUP_SIZE-1 and b in {0, 1}:
# label = a + GROUP_SIZE * b. Keys are looked up with `DihedralElement.sigma`.
translation = {
    (a, b): a + GROUP_SIZE * b
    for b in range(2)
    for a in range(GROUP_SIZE)
}


def actions_to_labels(tensor, dtype: str):
    """Map a sequence of binary actions to the label of the group state after each action.

    Action 0 (toggle) is `DihedralElement(0, 1, GROUP_SIZE)`; every other value is
    treated as action 1 (drive), `DihedralElement(1, 0, GROUP_SIZE)`. The elements
    are multiplied cumulatively and each running product is converted to an integer
    label through `translation[state.sigma]`.

    NOTE: this definition replaces the `actions_to_labels` imported from
    `di_automata.tasks.dashiell_groups` earlier in the notebook.

    Args:
        tensor: iterable of actions, e.g. a 1-D tensor of 0s and 1s.
        dtype: "float" yields a torch.float32 result; any other value yields torch.int32.

    Returns:
        1-D tensor holding one label per input action.
    """
    # Imported here because `accumulate` is not among the notebook's visible imports.
    from itertools import accumulate

    toggle = DihedralElement(0, 1, GROUP_SIZE)
    drive = DihedralElement(1, 0, GROUP_SIZE)
    action_list = [toggle if action == 0 else drive for action in tensor]
    states = accumulate(action_list, mul)
    # Kept in a separate name so the `dtype` argument is not rebound to a different type.
    torch_dtype = torch.float32 if dtype == "float" else torch.int32
    return torch.tensor([translation[s.sigma] for s in states], dtype=torch_dtype)

In [None]:
# Quick check: show the labels produced for a short hand-written action sequence.
actions_to_labels(torch.tensor([0,1,1,0,1,0,0,0,1,1,1]), dtype="int")

# Get checkpoints

In [None]:
def _load_model_at_checkpoint(cp_idx: int):
    """Build a fresh model from `config` and load the weights saved at checkpoint `cp_idx`.

    Returns:
        (state, model): the restored state dict and the model it was loaded into.
    """
    state = restore_state_single_cp(cp_idx)
    model, _ = construct_model(config)
    model.load_state_dict(state)
    return state, model


# Pre-form
cp_idx_0 = 20
state_0, model_0 = _load_model_at_checkpoint(cp_idx_0)

# Form 1
cp_idx_1 = 500
state_1, model_1 = _load_model_at_checkpoint(cp_idx_1)

# Form 2
cp_idx_2 = 1400
state_2, model_2 = _load_model_at_checkpoint(cp_idx_2)

In [None]:
# Pass data through
# Take the first item produced by take_n(ed_loader, 1) and keep its inputs and labels
# as notebook globals (display_layer_heads reads them).
for data in take_n(ed_loader, 1):
    inputs = data["input_ids"]
    labels = data["label_ids"]
    break

# Run the same batch through each of the three checkpoints, keeping logits and activation caches.
logits_0, cache_0 = model_0.run_with_cache(inputs)
logits_1, cache_1 = model_1.run_with_cache(inputs)
logits_2, cache_2 = model_2.run_with_cache(inputs)

# Behavioural evaluations

Make a dataset with only cyclic behaviour, i.e. sequences consisting only of 1s, and see how the models perform on it.

In [None]:
# Constant action sequences of length 25 and their expected labels.
all1 = torch.ones(25, dtype=torch.int32)
all0 = torch.zeros(25, dtype=torch.int32)
all1_label = all1.cumsum(dim=0) % 4
all0_label = torch.tensor([4, 0] * 12 + [4], dtype=torch.int32)
print(all1_label)
print(all0_label)

# All zeros except a single one at one position
all_zero_except1 = all0.clone()
all_zero_except1[8] = 1
all_zero_except1_label = actions_to_labels(all_zero_except1, dtype="int")
print(all_zero_except1_label)


# All ones except a single zero at one position
all_one_except1 = all1.clone()
all_one_except1[8] = 0
all_one_except1_label = actions_to_labels(all_one_except1, dtype="int")
print(all_one_except1_label)

## Early in training
Model_0 has seen 400 iterations. It is striking that it has already learned parity (or at least how to predict on a sequence of all zeros).

In [None]:
# Constant sequences: run model_0, then take the argmax over the last axis as the prediction.
l_all1_0, cache_all1_0 = model_0.run_with_cache(all1)
l_all0_0, cache_all0_0 = model_0.run_with_cache(all0)
pred_all1_0 = l_all1_0.argmax(dim=-1).squeeze().cpu()
pred_all0_0 = l_all0_0.argmax(dim=-1).squeeze().cpu()
print(pred_all0_0.shape)
print("all one labels", all1_label)
print("predicted all ones", pred_all1_0)
print("all zero labels", all0_label)
print("predicted all zeros", pred_all0_0)

# Sequences with a single flipped action: same procedure.
l_all_zero_except1_0, c_zero_except1_0 = model_0.run_with_cache(all_zero_except1)
l_all_one_except1_0, c_all_one_except1_0 = model_0.run_with_cache(all_one_except1)
pred_all_zero_except1_0 = l_all_zero_except1_0.argmax(dim=-1).squeeze().cpu()
pred_all_one_except1_0 = l_all_one_except1_0.argmax(dim=-1).squeeze().cpu()
print(all_zero_except1_label)
print("predicted all 0s except 1", pred_all_zero_except1_0)
print(all_one_except1_label)
print("predicted all 1s except 1", pred_all_one_except1_0)

# Stack each label row directly above its prediction row and plot the result as a heatmap.
tensor_matrix = torch.stack([
    all1_label, pred_all1_0,
    all0_label, pred_all0_0,
    all_zero_except1_label, pred_all_zero_except1_0,
    all_one_except1_label, pred_all_one_except1_0,
])
plot_tensor_heatmap(tensor_matrix)

## Form 1

In [None]:
# Constant sequences: run model_1, then take the argmax over the last axis as the prediction.
l_all1_1, cache_all1_1 = model_1.run_with_cache(all1)
l_all0_1, cache_all0_1 = model_1.run_with_cache(all0)
pred_all1_1 = l_all1_1.argmax(dim=-1).squeeze().cpu()
pred_all0_1 = l_all0_1.argmax(dim=-1).squeeze().cpu()
print(pred_all0_1.shape)
print("all one labels", all1_label)
print("predicted all ones", pred_all1_1)
print("all zero labels", all0_label)
print("predicted all zeros", pred_all0_1)

# Sequences with a single flipped action: same procedure.
l_all_zero_except1_1, c_zero_except1_1 = model_1.run_with_cache(all_zero_except1)
l_all_one_except1_1, c_all_one_except1_1 = model_1.run_with_cache(all_one_except1)
pred_all_zero_except1_1 = l_all_zero_except1_1.argmax(dim=-1).squeeze().cpu()
pred_all_one_except1_1 = l_all_one_except1_1.argmax(dim=-1).squeeze().cpu()
print(all_zero_except1_label)
print("predicted all 0s except 1", pred_all_zero_except1_1)
print(all_one_except1_label)
print("predicted all 1s except 1", pred_all_one_except1_1)

# Stack each label row directly above its prediction row and plot the result as a heatmap.
tensor_matrix = torch.stack([
    all1_label, pred_all1_1,
    all0_label, pred_all0_1,
    all_zero_except1_label, pred_all_zero_except1_1,
    all_one_except1_label, pred_all_one_except1_1,
])
plot_tensor_heatmap(tensor_matrix)

## Form 2

In [None]:
# Constant sequences: run model_2, then take the argmax over the last axis as the prediction.
l_all1_2, cache_all1_2 = model_2.run_with_cache(all1)
l_all0_2, cache_all0_2 = model_2.run_with_cache(all0)
pred_all1_2 = l_all1_2.argmax(dim=-1).squeeze().cpu()
pred_all0_2 = l_all0_2.argmax(dim=-1).squeeze().cpu()
print("all one labels", all1_label)
print("predicted all ones", pred_all1_2)
print("all zero labels", all0_label)
print("predicted all zeros", pred_all0_2)

# Sequences with a single flipped action: same procedure.
l_all_zero_except1_2, c_zero_except1_2 = model_2.run_with_cache(all_zero_except1)
l_all_one_except1_2, c_all_one_except1_2 = model_2.run_with_cache(all_one_except1)
pred_all_zero_except1_2 = l_all_zero_except1_2.argmax(dim=-1).squeeze().cpu()
pred_all_one_except1_2 = l_all_one_except1_2.argmax(dim=-1).squeeze().cpu()
print(all_zero_except1_label)
print("predicted all 0s except 1", pred_all_zero_except1_2)
print(all_one_except1_label)
print("predicted all 1s except 1", pred_all_one_except1_2)

# Stack each label row directly above its prediction row and plot the result as a heatmap.
tensor_matrix = torch.stack([
    all1_label, pred_all1_2,
    all0_label, pred_all0_2,
    all_zero_except1_label, pred_all_zero_except1_2,
    all_one_except1_label, pred_all_one_except1_2,
])
plot_tensor_heatmap(tensor_matrix)