In [None]:
# Detect whether the notebook runs inside Google Colab.
# Only a failed import means "not in Colab"; anything else (e.g. a
# KeyboardInterrupt during the import) is allowed to propagate.
try:
    import google.colab # type: ignore
    IN_COLAB = True
except ImportError:
    IN_COLAB = False

import os, sys

if IN_COLAB:
    # Code to download the necessary files (e.g. solutions, test funcs)
    if not os.path.exists("chapter1_transformers"):
        !curl -o /content/main.zip https://codeload.github.com/callummcdougall/ARENA_2.0/zip/refs/heads/main
        !unzip /content/main.zip 'ARENA_2.0-main/chapter1_transformers/exercises/*'
        sys.path.append("/content/ARENA_2.0-main/chapter1_transformers/exercises")
        os.remove("/content/main.zip")
        os.rename("ARENA_2.0-main/chapter1_transformers", "chapter1_transformers")
        os.rmdir("ARENA_2.0-main")

         # Install packages
        %pip install einops
        %pip install jaxtyping
        %pip install transformer_lens
        %pip install git+https://github.com/callummcdougall/CircuitsVis.git#subdirectory=python
        %pip install s3fs
        %pip install omegaconf
        %pip install git+https://github.com/CindyXWu/devinterp-automata.git
        %pip install torch-ema

        !curl -o /content/main.zip https://codeload.github.com/CindyXWu/devinterp-automata/zip/refs/heads/main
        !unzip -o /content/main.zip -d /content/

        sys.path.append("/content/devinterp-automata/")
        os.remove("/content/main.zip")

        os.chdir("chapter1_transformers/exercises")
else:
    from IPython import get_ipython
    ipython = get_ipython()
    ipython.run_line_magic("load_ext", "autoreload")
    ipython.run_line_magic("autoreload", "2")

    CHAPTER = r"chapter1_transformers"
    CHAPTER_DIR = r"./" if CHAPTER in os.listdir() else os.getcwd().split(CHAPTER)[0]
    EXERCISES_DIR = CHAPTER_DIR + f"{CHAPTER}/exercises"
    sys.path.append(EXERCISES_DIR)

In [None]:
from dotenv.main import load_dotenv
import plotly.express as px
from typing import List, Union, Optional, Dict, Tuple
from jaxtyping import Int, Float

import torch
from torch import Tensor
import torch.nn as nn
import torch.nn.functional as F

from pathlib import Path
import numpy as np
import pandas as pd
import einops
import re
import functools
from tqdm import tqdm
from IPython.display import display
from transformer_lens.hook_points import HookPoint
from transformer_lens import utils, HookedTransformer, HookedTransformerConfig, FactoredMatrix, ActivationCache
from transformer_lens.utils import to_numpy

# Plotting
import circuitsvis as cv
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots
from sklearn.metrics import confusion_matrix
import seaborn as sns
import matplotlib.pyplot as plt

# For Dashiell's groups code
from copy import deepcopy
from functools import reduce
from itertools import product
import math
import numpy as np
from operator import mul
import torch

# Analysis-only notebook: turn autograd off globally.
torch.set_grad_enabled(False)
# Use the GPU when one is visible, otherwise the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# True when this code runs as the top-level module.
MAIN = "__main__" == __name__

import wandb
from pathlib import Path
import os
import yaml
import s3fs
from omegaconf import OmegaConf

from di_automata.config_setup import *
from di_automata.constructors import (
    construct_model,
    create_dataloader_hf,
)
from di_automata.tasks.data_utils import take_n
import plotly.io as pio

# AWS
# Credentials come from the environment (optionally populated from a .env file);
# os.getenv yields None for a variable that is not set.
load_dotenv()
AWS_KEY = os.getenv("AWS_KEY")
AWS_SECRET = os.getenv("AWS_SECRET")
s3 = s3fs.S3FileSystem(key=AWS_KEY, secret=AWS_SECRET)

In [None]:
from di_automata.interp_utils import (
    imshow_attention,
    line,
    scatter,
    imshow,
    reorder_list_in_plotly_way,
    get_pca,
    get_vars,
    plot_tensor_heatmap,
    get_activations,
    LN_hook_names,
    get_ln_fit,
    cos_sim_with_MLP_weights,
    avg_squared_cos_sim,
    hook_fn_display_attn_patterns,
    hook_fn_patch_qk,
)

from di_automata.tasks.dashiell_groups import (
    DihedralElement,
    DihedralIrrep, 
    ProductDihedralIrrep,
    dihedral_conjugacy_classes, 
    generate_subgroup,
    actions_to_labels,
    get_all_bits,
    dihedral_fourier,
    get_fourier_spectrum,
    analyse_power,
)

## Dihedral representations

In [None]:
# Build the group via DihedralElement.full_group for n = 5
# (presumably all 10 elements of the dihedral group — confirm against di_automata).
group = DihedralElement.full_group(5)
# Lookup from a pair (a, b), with a in 0..4 and b in {0, 1}, to the class index a + 5 * b.
# NOTE(review): presumably a is the rotation and b the reflection flag — confirm.
translation = {
    (a, b): a + 5 * b
    for b in (0, 1)
    for a in range(5)
}

In [None]:
# Show what dihedral_conjugacy_classes returns for n = 5.
dihedral_conjugacy_classes(5)

In [None]:
# Display the group object built above.
group

In [None]:
# Dimension of the irrep labelled (0, 0) for n = 5.
DihedralIrrep(5, (0, 0)).dim

In [None]:
# Alternating -1 / +1 test function with 10 entries (one per class index 0..9).
dihedral_fn = torch.tensor([-1, 1] * 5)

This sequence is the 'anti' of the parity irreducible representation of the dihedral group

In [None]:
# Sum of the elementwise product of dihedral_fn with the tensors of the (0, 1) irrep.
(dihedral_fn * DihedralIrrep(5, (0, 1)).tensors()).sum()

In [None]:
def dihedral_fourier(fn_values, n):
    """Project `fn_values` onto every irrep listed by `dihedral_conjugacy_classes(n)`.

    NOTE: this definition shadows the `dihedral_fourier` imported from
    `di_automata.tasks.dashiell_groups` earlier in the notebook.

    Args:
        fn_values: values to transform; multiplied elementwise (with broadcasting)
            against each irrep's `.tensors()`.
        n: group parameter forwarded to `DihedralIrrep` and `dihedral_conjugacy_classes`.

    Returns:
        Dict keyed by conjugacy-class label, each value being
        `(fn_values * irrep).sum(dim=0)`.
    """
    spectrum = {}
    for conj in dihedral_conjugacy_classes(n):
        irrep_tensor = DihedralIrrep(n, conj).tensors()
        spectrum[conj] = (fn_values * irrep_tensor).sum(dim=0)
    return spectrum

In [None]:
# Tensors of four irreps for n = 5, in the order (0,0), (0,1), (1,0), (2,0).
rep_list = [
    DihedralIrrep(5, conj).tensors()
    for conj in ((0, 0), (0, 1), (1, 0), (2, 0))
]
for rep in rep_list:
    print(rep)

In [None]:
# Smoke test: transform a random (10, 1) column with the notebook-local dihedral_fourier.
dihedral_fourier(torch.randn((10, 1)), 5)

In [None]:
# Inspect the tensors of the (0, 1) irrep for n = 5.
DihedralIrrep(5, (0, 1)).tensors()

In [None]:
def imshow_attention(tensor, renderer=None, xaxis="", yaxis="", **kwargs):
    """Show `tensor` as a plotly heatmap on a diverging RdBu scale centred at 0.

    NOTE: shadows the `imshow_attention` imported from `di_automata.interp_utils`.

    Args:
        tensor: data to plot; converted with `utils.to_numpy`.
        renderer: passed to the figure's `.show()`.
        xaxis, yaxis: axis labels.
        **kwargs: forwarded to `px.imshow`.
    """
    figure = px.imshow(
        utils.to_numpy(tensor),
        color_continuous_midpoint=0.0,
        color_continuous_scale="RdBu",
        labels={"x": xaxis, "y": yaxis},
        **kwargs,
    )
    figure.show(renderer)

def line(tensor, renderer=None, xaxis="", yaxis="", **kwargs):
    """Show `tensor` as a plotly line chart.

    NOTE: shadows the `line` imported from `di_automata.interp_utils`.

    Args:
        tensor: data to plot; converted with `utils.to_numpy`.
        renderer: passed to the figure's `.show()`.
        xaxis, yaxis: axis labels.
        **kwargs: forwarded to `px.line`.
    """
    axis_labels = {"x": xaxis, "y": yaxis}
    figure = px.line(utils.to_numpy(tensor), labels=axis_labels, **kwargs)
    figure.show(renderer)

def scatter(x, y, xaxis="", yaxis="", caxis="", renderer=None, **kwargs):
    """Show a plotly scatter plot of `y` against `x`.

    NOTE: shadows the `scatter` imported from `di_automata.interp_utils`.

    Args:
        x, y: coordinates; each converted with `utils.to_numpy`.
        xaxis, yaxis, caxis: labels for the x axis, y axis and colour bar.
        renderer: passed to the figure's `.show()`.
        **kwargs: forwarded to `px.scatter`.
    """
    x_np = utils.to_numpy(x)
    y_np = utils.to_numpy(y)
    axis_labels = {"x": xaxis, "y": yaxis, "color": caxis}
    figure = px.scatter(x=x_np, y=y_np, labels=axis_labels, **kwargs)
    figure.show(renderer)


In [None]:
# Root of the devinterp-automata checkout: the unpacked archive on Colab,
# the parent directory when running locally.
if IN_COLAB:
    DI_ROOT = Path("/content/devinterp-automata-main/")
else:
    DI_ROOT = Path("../")
config_file_path = DI_ROOT / "scripts/configs/slt_config.yaml"
slt_config = OmegaConf.load(config_file_path)

# The task config is a YAML file named after the dataset type chosen in slt_config.
with open(DI_ROOT / f"scripts/configs/task_config/{slt_config.dataset_type}.yaml", mode="r") as file:
    task_config = yaml.safe_load(file)

In [None]:
# Allow new configuration values to be added
OmegaConf.set_struct(slt_config, False)
# Because we are in Colab and not VSCode, here is where you want to edit your config values
for _key, _value in (
    ("task_config", task_config),
    ("lr", 0.0005),
    ("num_training_iter", 100000),
    ("n_layers", 3),
):
    slt_config[_key] = _value

# Round-trip through the PostRunSLTConfig Pydantic model for dynamic type
# validation - NECESSARY DO NOT SKIP - then back to an OmegaConf object,
# which the rest of the notebook expects.
pydantic_config = PostRunSLTConfig(**slt_config)
slt_config = OmegaConf.create(pydantic_config.model_dump())

print(task_config["dataset_type"])

In [None]:
# Re-binds `device` (first set in the import cell), now pinned to GPU index 0 when available.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

# Run path and name for easy referral later
run_name = slt_config.run_name
run_path = f"{slt_config.entity_name}/{slt_config.wandb_project_name}"
print(run_name)

In [None]:
# Get run information
api = wandb.Api(timeout=3000)
run_list = api.runs(
    path=run_path,
    filters={
        "display_name": run_name,
        "state": "finished",
        },
    order="created_at", # Default descending order so backwards in time  (NOTE(review): confirm sort direction against the wandb docs)
)
assert run_list, f"Specified run {run_name} does not exist"
# Several finished runs can share a display name; run_idx picks one of them.
run_api = run_list[slt_config.run_idx]
# Fall back to the bare attribute if calling history() fails, but let
# KeyboardInterrupt / SystemExit propagate instead of swallowing them.
try:
    history = run_api.history()
except Exception:
    history = run_api.history
loss_history = history["Train Loss"]
accuracy_history = history["Train Acc"]
steps = history["_step"]
# Timestamp string stored in the run config; used below to build artifact and file names.
# NOTE: this module-level name would shadow the stdlib `time` module if it were imported.
time = run_api.config["time"]

In [None]:
def get_config() -> MainConfig:
    """Download the run's config artifact and load it.

    WandB also logs the config automatically for each run, but it doesn't log
    enums correctly, so the manually saved `config.yaml` artifact is used instead.

    Relies on the notebook-level `api`, `run_path`, `run_name` and `time`.

    Returns:
        The object produced by `OmegaConf.load` for the downloaded `config.yaml`
        (annotated as `MainConfig`; no conversion to that class happens here).
    """
    config_artifact = api.artifact(f"{run_path}/config:{run_name}_{time}")
    download_dir = Path(config_artifact.download())
    return OmegaConf.load(download_dir / "config.yaml")

### Make dataloaders

In [None]:
config = get_config()

# Set total number of unique samples seen (n). If this is not done it will break LLC estimator.
# The value recorded with the run is copied into both places slt_config keeps it.
_run_num_samples = config.rlct_config.sgld_kwargs.num_samples
slt_config.rlct_config.sgld_kwargs.num_samples = _run_num_samples
slt_config.rlct_config.num_samples = _run_num_samples
# Reuse the model architecture settings saved with the run.
slt_config.nano_gpt_config = config.nano_gpt_config

### Restoring logits and checkpointed states

In [None]:
def restore_state_single_cp(cp_idx: int) -> dict:
    """Restore model state from a single checkpoint.
    Used in _load_logits_states() and _calculate_rlct().

    Args:
        idx_cp: index of checkpoint.

    Returns:
        model state dictionary.
    """
    idx = cp_idx * config.rlct_config.ed_config.eval_frequency * slt_config.skip_cps
    print(f"Getting checkpoint {idx}")
    print(config.model_save_method)
    match config.model_save_method:
        case "wandb":
            artifact = api.artifact(f"{run_path}/states:idx{idx}_{run_name}_{time}")
            data_dir = artifact.download()
            state_path = Path(data_dir) / f"states_{idx}.torch"
            states = torch.load(state_path)
        case "aws":
            with s3.open(f'{config.aws_bucket}/{run_name}_{time}/states_{idx}.pth', mode='rb') as file:
                states = torch.load(file, map_location=device)
    return states["model"]


def load_logits_single_cp(cp_idx: int) -> None:
    """Load just a single cp. 
    This function is designed to be called in multithreading and is called by the above function.
    """
    idx = cp_idx * config.rlct_config.ed_config.eval_frequency * slt_config.skip_cps

    try:
        match config.model_save_method:
            case "wandb":
                artifact = api.artifact(f"{run_path}/logits:logits_cp_{idx}_{run_name}_{time}")
                data_dir = artifact.download()
                logit_path = Path(data_dir) / f"logits_cp_{idx}.torch"
                return torch.load(logit_path)
            case "aws":
                with s3.open(f'{config.aws_bucket}/{run_name}_{time}/logits_cp_{idx}.pth', mode='rb') as file:
                    return torch.load(file)
                
    except Exception as e:
        print(f"Error fetching logits at step {idx}: {e}")

In [None]:
# Logits for this run live in a sibling "di_automata" folder next to the working directory.
current_directory = Path().absolute()
logits_file_path = current_directory.parent / "di_automata" / f"logits_{run_name}_{time}"
print(logits_file_path)

The `ed_loader` below is created with `deterministic=True` (seed=42), so the same data can be recreated on every run.

In [None]:
# Loader built in deterministic mode so repeated runs see the same batches.
ed_loader = create_dataloader_hf(config, deterministic=True) # Make sure deterministic to see same data
# Second loader for evaluation, with an explicit batch size of 256.
eval_loader = create_dataloader_hf(config, batch_size=256)

### Functions to display attention

In [None]:
def display_layer_heads(att, batch_idx=0):
    """Render the attention patterns of one batch element with circuitsvis.

    Reads the notebook-level `inputs` and `labels` (they must be defined before
    this is called). Head names are fixed to "L0H0".."L0H3".

    Args:
        att: attention tensor indexed by batch on its first axis.
        batch_idx: which batch element to show.
    """
    head_names = [f"L0H{i}" for i in range(4)]
    sample_tokens = list_of_strings(inputs[batch_idx,...])
    pattern_view = cv.attention.attention_patterns(
        tokens=sample_tokens,
        attention=att[batch_idx,...],
        attention_head_names=head_names,
    )
    display(pattern_view)
    # 0 is toggle action
    # 1 is drive action
    print(inputs[batch_idx,...])
    print(labels[batch_idx,...])


def list_of_strings(tensor):
    """Convert a tensor into a (nested) list of strings, one per element.

    The tensor is detached and moved to the CPU first, so CUDA tensors and
    tensors that require grad are accepted as well.

    Args:
        tensor: a torch tensor of any shape.

    Returns:
        A nested Python list mirroring the tensor's shape, with every element
        converted to `str`.
    """
    return tensor.detach().cpu().numpy().astype(str).tolist()


def display_layer_heads_batch(att: torch.Tensor, cache: ActivationCache, toks: list[str]):
    """Render attention for the first 10 batch elements from an activation cache.

    TODO: refactor

    Args:
        att: currently unused.
        cache: activation cache passed to `cv.attention.from_cache`.
        toks: tokens passed to `cv.attention.from_cache`.

    NOTE(review): the `batch_labels` lambda refers to `format_sequence` and
    `dataset`, neither of which is defined in the visible part of this notebook;
    they must exist by the time the lambda is invoked.
    """
    cv.attention.from_cache(
      cache = cache,
      tokens = toks,
      batch_idx = list(range(10)),
      attention_type = "info-weighted",
      radioitems = True,
      return_mode = "view",
      batch_labels = lambda batch_idx, str_tok_list: format_sequence(str_tok_list, dataset.str_tok_labels[batch_idx]),
      mode = "small",
    )

# Get checkpoints

In [None]:
# Checkpoint indices analysed below; load_form() indexes into this list by form number.
cp_idxs = [20, 220, 520, 855, 1150, 1500]

In [None]:
def _restore_model(cp_idx):
    """Fetch the saved state for `cp_idx` and build a fresh, not-yet-loaded model."""
    state = restore_state_single_cp(cp_idx)
    model, _ = construct_model(config)
    return state, model


# Pre-form
cp_idx_0 = 20
state_0, model_0 = _restore_model(cp_idx_0)
model_0.load_state_dict(state_0)

# Form 1
cp_idx_1 = 220
state_1, model_1 = _restore_model(cp_idx_1)
model_1.load_state_dict(state_1)

# Form 2
cp_idx_2 = 520
state_2, model_2 = _restore_model(cp_idx_2)
model_2.load_state_dict(state_2)

# Form 3
cp_idx_3 = 855
state_3, model_3 = _restore_model(cp_idx_3)
model_3.load_state_dict(state_3)

# Form 4
cp_idx_4 = 1150
state_4, model_4 = _restore_model(cp_idx_4)
model_4.load_state_dict(state_4)

# End
cp_idx_5 = 1500
state_5, model_5 = _restore_model(cp_idx_5)
model_5.load_state_dict(state_5)

In [None]:
# Pass data through
# Take the first batch of the deterministic loader.
for data in take_n(ed_loader, 1):
    inputs, labels = data["input_ids"], data["label_ids"]
    break

# Run the same batch through every checkpointed model, keeping logits and caches.
(
    (logits_0, cache_0),
    (logits_1, cache_1),
    (logits_2, cache_2),
    (logits_3, cache_3),
    (logits_4, cache_4),
    (logits_5, cache_5),
) = (
    model.run_with_cache(inputs)
    for model in (model_0, model_1, model_2, model_3, model_4, model_5)
)

### Get logits for forms

In [None]:
import pickle
from einops import rearrange

# Saved analysis outputs for this run live under ../di_automata/ed_data/<run>_<time>/.
base_path = Path.cwd().parent
pca_path = base_path / "di_automata" / "ed_data" / f"{run_name}_{time}" / "pca"
# The PCA object was serialised with torch, so torch.load is used to read it back.
with open(pca_path, mode='rb') as f:
    pca = torch.load(f)
pca_vectors = pca.components_

In [None]:
def load_form(form_num: int):
    """Reconstruct the logits of a saved form from its first three PCA coordinates.

    Reads the notebook-level `base_path`, `run_name`, `time`, `cp_idxs`,
    `pca_vectors` and `config`.

    Args:
        form_num: position in `cp_idxs` of the checkpoint whose form file is read.

    Returns:
        The reconstruction, rearranged from a flat '(n b c s)' vector to 'n b s c'.
    """
    form_path = base_path / Path(f"di_automata/ed_data/{run_name}_{time}/forms/form_{cp_idxs[form_num]}")
    with open(form_path, 'rb') as form_file:
        form = pickle.load(form_file)

    # Weighted sum of the first three principal directions, accumulated in order.
    form_logits = pca_vectors[0,:] * form[0]
    for component in (1, 2):
        form_logits = form_logits + pca_vectors[component,:] * form[component]

    return rearrange(
        form_logits,
        '(n b c s) -> n b s c',
        n=config.rlct_config.ed_config.batches_per_checkpoint,
        b=config.dataloader_config.train_bs,
        c=config.tflens_config.d_vocab_out,
        s=config.task_config.length,
    )

Get forms.
Each form's logits are stored as a flat vector of length 150 (number of evaluation batches) x 64 (batch size) x 25 (sequence length) x 10 (number of classes) = 2.4e6, which `load_form` reshapes into a 4-D array.

In [None]:
# Reconstructed logits for forms 1-4, converted from numpy to torch tensors.
form_1_logits, form_2_logits, form_3_logits, form_4_logits = (
    torch.from_numpy(load_form(form_num)) for form_num in (1, 2, 3, 4)
)

## Fourier transform of form logits

In [None]:
def get_labels(form_logits):
    """Get labels for ed dataset used to create form logits.

    A fresh deterministic loader is built from the notebook-level `config`, and
    one batch of labels is collected per entry along the first axis of `form_logits`.

    Args:
        form_logits: only `form_logits.shape[0]` (the number of batches) is used.

    Returns:
        The collected label batches stacked along a new leading axis.
    """
    loader = create_dataloader_hf(config, deterministic=True)
    n_batches = form_logits.shape[0]
    label_batches = [batch["label_ids"] for batch in take_n(loader, n_batches)]
    return torch.stack(label_batches)

In [None]:
# Labels for as many deterministic batches as form_1_logits has along its first axis.
ed_labels = get_labels(form_1_logits)

In [None]:
# Fourier power analysis of the form 1 logits (form=True).
analyse_power(*get_fourier_spectrum(form_1_logits, ed_labels, form=True))

In [None]:
# Fourier power analysis of the form 2 logits (form=True).
analyse_power(*get_fourier_spectrum(form_2_logits, ed_labels, form=True))

In [None]:
# Fourier power analysis of the form 3 logits (form=True).
analyse_power(*get_fourier_spectrum(form_3_logits, ed_labels, form=True))

In [None]:
# Fourier power analysis of the form 4 logits (form=True).
analyse_power(*get_fourier_spectrum(form_4_logits, ed_labels, form=True))

These unfortunately aren't that informative - it may be that I need more PCA directions.

## Fourier transform of closest checkpoint logits


In [None]:
# Pass data through
# Take the first batch of the evaluation loader.
for data in take_n(eval_loader, 1):
    eval_inputs, eval_labels = data["input_ids"], data["label_ids"]
    break

# Run that batch through the models for forms 1-4 and the final checkpoint.
(
    (logits_eval_1, cache_eval_1),
    (logits_eval_2, cache_eval_2),
    (logits_eval_3, cache_eval_3),
    (logits_eval_4, cache_eval_4),
    (logits_eval_5, cache_eval_5),
) = (
    model.run_with_cache(eval_inputs)
    for model in (model_1, model_2, model_3, model_4, model_5)
)

In [None]:
# Check the shape of the logits produced for the evaluation batch.
logits_eval_1.shape

In [None]:
# Where labels are class 0
mask = (eval_labels == 0)
# Average the logits over the first two axes, then reorder the 10 class entries
# as [0, 5, 1, 6, 2, 7, 3, 8, 4, 9] (interleaving classes 0-4 with classes 5-9).
mean_vals_1 = logits_eval_1.mean(dim=(0, 1))[torch.arange(10).reshape(2, 5).T.reshape(-1)]

Split up the Fourier transform by even or odd label. In reality, to fully carve this up we should have a grid where rows correspond to labels and columns to predictions. This is also done below.

In [None]:
# Positions whose label is one of the classes 0-4 count as "even"; everything else is "odd".
values = torch.tensor([0., 1., 2., 3., 4.], dtype=torch.float32)
even_mask = torch.zeros(eval_labels.shape, dtype=torch.bool)
for value in values:
    even_mask |= (eval_labels == value)

# Transform of the overall mean logits.
transform_all = dihedral_fourier(mean_vals_1.to('cpu').unsqueeze(1), 5)

# Transform of the mean logits at even-labelled positions.
even_mean_1 = logits_eval_1[even_mask].mean(dim=0)
even_transform = dihedral_fourier(even_mean_1.to('cpu').unsqueeze(1), 5)

# Transform of the mean logits at odd-labelled positions.
# Only this last expression is displayed as the cell output.
odd_mean_1 = logits_eval_1[~even_mask].mean(dim=0)
odd_transform = dihedral_fourier(odd_mean_1.to('cpu').unsqueeze(1), 5)
odd_transform

In [None]:
# Fourier spectrum of the form 1 model's logits on the evaluation batch.
transforms_1, means_1 = get_fourier_spectrum(logits_eval_1, eval_labels)

In [None]:
# Power analysis for the form 1 model.
analyse_power(transforms_1, means_1)

In [None]:
# Fourier spectrum of the form 2 model's logits on the evaluation batch.
transforms_2, means_2 = get_fourier_spectrum(logits_eval_2, eval_labels)

In [None]:
# Power analysis for the form 2 model.
analyse_power(transforms_2, means_2)

In [None]:
# Spectrum and power analysis for the form 3 model.
analyse_power(*get_fourier_spectrum(logits_eval_3, eval_labels))

In [None]:
# Spectrum and power analysis for the form 4 model.
analyse_power(*get_fourier_spectrum(logits_eval_4, eval_labels))

In [None]:
# Spectrum and power analysis for the end-of-training model.
analyse_power(*get_fourier_spectrum(logits_eval_5, eval_labels))

These logit results don't quite make sense - the zero power on the parity (0,1) mode for all checkpoints doesn't check out with its ability to learn the parity feature.

# Trial inputs

In [None]:
# Two constant action sequences of length 25: all ones and all zeros.
all1 = torch.ones(25, dtype=torch.int32)
all0 = torch.zeros(25, dtype=torch.int32)
# Hand-built labels: a running count mod 5 for the all-ones sequence,
# and an alternation between 5 and 0 for the all-zeros sequence.
all1_label = all1.cumsum(dim=0) % 5
all0_label = torch.tensor([5, 0] * 12 + [5], dtype=torch.int32)
print(all1_label)
print(all0_label)

# All zeros except a single one at one position
all_zero_except1 = all0.clone()
all_zero_except1[8] = 1
all_zero_except1_label = actions_to_labels(all_zero_except1, translation, dtype="int")
print(all_zero_except1_label)


# All ones except a single zero at one position
all_one_except1 = all1.clone()
all_one_except1[8] = 0
all_one_except1_label = actions_to_labels(all_one_except1, translation, dtype="int")
print(all_one_except1_label)

### Early in training

In [None]:
# Constant sequences through the early-training model.
l_all1_0, cache_all1_0 = model_0.run_with_cache(all1)
l_all0_0, cache_all0_0 = model_0.run_with_cache(all0)
# Most likely class per position, squeezed and moved to the CPU.
pred_all1_0 = l_all1_0.argmax(dim=-1).squeeze().cpu()
pred_all0_0 = l_all0_0.argmax(dim=-1).squeeze().cpu()
print(pred_all0_0.shape)
print("all one labels", all1_label)
print("predicted all ones", pred_all1_0)
print("all zero labels", all0_label)
print("predicted all zeros", pred_all0_0)

# Sequences with a single flipped action.
l_all_zero_except1_0, c_zero_except1_0 = model_0.run_with_cache(all_zero_except1)
l_all_one_except1_0, c_all_one_except1_0 = model_0.run_with_cache(all_one_except1)
pred_all_zero_except1_0 = l_all_zero_except1_0.argmax(dim=-1).squeeze().cpu()
pred_all_one_except1_0 = l_all_one_except1_0.argmax(dim=-1).squeeze().cpu()
print(all_zero_except1_label)
print("predicted all 0s except 1", pred_all_zero_except1_0)
print(all_one_except1_label)
print("predicted all 1s except 1", pred_all_one_except1_0)

# Stack label / prediction rows (alternating) for the four trial inputs and plot them.
tensor_matrix = torch.stack([
    all1_label, pred_all1_0,
    all0_label, pred_all0_0,
    all_zero_except1_label, pred_all_zero_except1_0,
    all_one_except1_label, pred_all_one_except1_0,
])
plot_tensor_heatmap(tensor_matrix)

### Form 1

In [None]:
# Constant sequences through the form 1 model.
l_all1_1, cache_all1_1 = model_1.run_with_cache(all1)
l_all0_1, cache_all0_1 = model_1.run_with_cache(all0)
# Most likely class per position, squeezed and moved to the CPU.
pred_all1_1 = l_all1_1.argmax(dim=-1).squeeze().cpu()
pred_all0_1 = l_all0_1.argmax(dim=-1).squeeze().cpu()
print(pred_all0_1.shape)
print("all one labels", all1_label)
print("predicted all ones", pred_all1_1)
print("all zero labels", all0_label)
print("predicted all zeros", pred_all0_1)

# Sequences with a single flipped action.
l_all_zero_except1_1, c_zero_except1_1 = model_1.run_with_cache(all_zero_except1)
l_all_one_except1_1, c_all_one_except1_1 = model_1.run_with_cache(all_one_except1)
pred_all_zero_except1_1 = l_all_zero_except1_1.argmax(dim=-1).squeeze().cpu()
pred_all_one_except1_1 = l_all_one_except1_1.argmax(dim=-1).squeeze().cpu()
print(all_zero_except1_label)
print("predicted all 0s except 1", pred_all_zero_except1_1)
print(all_one_except1_label)
print("predicted all 1s except 1", pred_all_one_except1_1)

# Stack label / prediction rows (alternating) for the four trial inputs and plot them.
tensor_matrix = torch.stack([
    all1_label, pred_all1_1,
    all0_label, pred_all0_1,
    all_zero_except1_label, pred_all_zero_except1_1,
    all_one_except1_label, pred_all_one_except1_1,
])
plot_tensor_heatmap(tensor_matrix)

### Form 2

In [None]:
# Constant sequences through the form 2 model.
l_all1_2, cache_all1_2 = model_2.run_with_cache(all1)
l_all0_2, cache_all0_2 = model_2.run_with_cache(all0)
# Most likely class per position, squeezed and moved to the CPU.
pred_all1_2 = l_all1_2.argmax(dim=-1).squeeze().cpu()
pred_all0_2 = l_all0_2.argmax(dim=-1).squeeze().cpu()
print(pred_all0_2.shape)
print("all one labels", all1_label)
print("predicted all ones", pred_all1_2)
print("all zero labels", all0_label)
print("predicted all zeros", pred_all0_2)

# Sequences with a single flipped action.
l_all_zero_except1_2, c_zero_except1_2 = model_2.run_with_cache(all_zero_except1)
l_all_one_except1_2, c_all_one_except1_2 = model_2.run_with_cache(all_one_except1)
pred_all_zero_except1_2 = l_all_zero_except1_2.argmax(dim=-1).squeeze().cpu()
pred_all_one_except1_2 = l_all_one_except1_2.argmax(dim=-1).squeeze().cpu()
print(all_zero_except1_label)
print("predicted all 0s except 1", pred_all_zero_except1_2)
print(all_one_except1_label)
print("predicted all 1s except 1", pred_all_one_except1_2)

# Stack label / prediction rows (alternating) for the four trial inputs and plot them.
tensor_matrix = torch.stack([
    all1_label, pred_all1_2,
    all0_label, pred_all0_2,
    all_zero_except1_label, pred_all_zero_except1_2,
    all_one_except1_label, pred_all_one_except1_2,
])
plot_tensor_heatmap(tensor_matrix)

### Form 3

In [None]:
# Constant sequences through the form 3 model.
l_all1_3, cache_all1_3 = model_3.run_with_cache(all1)
l_all0_3, cache_all0_3 = model_3.run_with_cache(all0)
# Most likely class per position, squeezed and moved to the CPU.
pred_all1_3 = l_all1_3.argmax(dim=-1).squeeze().cpu()
pred_all0_3 = l_all0_3.argmax(dim=-1).squeeze().cpu()
print(pred_all0_3.shape)
print("all one labels", all1_label)
print("predicted all 1s", pred_all1_3)
print("all zero labels", all0_label)
print("predicted all 0s", pred_all0_3)

# Sequences with a single flipped action.
l_all_zero_except1_3, c_zero_except1_3 = model_3.run_with_cache(all_zero_except1)
l_all_one_except1_3, c_all_one_except1_3 = model_3.run_with_cache(all_one_except1)
pred_all_zero_except1_3 = l_all_zero_except1_3.argmax(dim=-1).squeeze().cpu()
pred_all_one_except1_3 = l_all_one_except1_3.argmax(dim=-1).squeeze().cpu()
print(all_zero_except1_label)
print("predicted all 0s except 1", pred_all_zero_except1_3)
print(all_one_except1_label)
print("predicted all 1s except 1", pred_all_one_except1_3)

# Stack label / prediction rows (alternating) for the four trial inputs and plot them.
tensor_matrix = torch.stack([
    all1_label, pred_all1_3,
    all0_label, pred_all0_3,
    all_zero_except1_label, pred_all_zero_except1_3,
    all_one_except1_label, pred_all_one_except1_3,
])
plot_tensor_heatmap(tensor_matrix)

### Form 4

In [None]:
# Constant sequences through the form 4 model.
l_all1_4, cache_all1_4 = model_4.run_with_cache(all1)
l_all0_4, cache_all0_4 = model_4.run_with_cache(all0)
# Most likely class per position, squeezed and moved to the CPU.
pred_all1_4 = l_all1_4.argmax(dim=-1).squeeze().cpu()
pred_all0_4 = l_all0_4.argmax(dim=-1).squeeze().cpu()
print(pred_all0_4.shape)
print("all one labels", all1_label)
print("predicted all 1s", pred_all1_4)
print("all zero labels", all0_label)
print("predicted all 0s", pred_all0_4)

# Sequences with a single flipped action.
l_all_zero_except1_4, c_zero_except1_4 = model_4.run_with_cache(all_zero_except1)
l_all_one_except1_4, c_all_one_except1_4 = model_4.run_with_cache(all_one_except1)
pred_all_zero_except1_4 = l_all_zero_except1_4.argmax(dim=-1).squeeze().cpu()
pred_all_one_except1_4 = l_all_one_except1_4.argmax(dim=-1).squeeze().cpu()
print(all_zero_except1_label)
print("predicted all 0s except 1", pred_all_zero_except1_4)
print(all_one_except1_label)
print("predicted all 1s except 1", pred_all_one_except1_4)

# Stack label / prediction rows (alternating) for the four trial inputs and plot them.
tensor_matrix = torch.stack([
    all1_label, pred_all1_4,
    all0_label, pred_all0_4,
    all_zero_except1_label, pred_all_zero_except1_4,
    all_one_except1_label, pred_all_one_except1_4,
])
plot_tensor_heatmap(tensor_matrix)

### End of training

In [None]:
# Constant sequences through the end-of-training model.
l_all1_5, cache_all1_5 = model_5.run_with_cache(all1)
l_all0_5, cache_all0_5 = model_5.run_with_cache(all0)
# Most likely class per position, squeezed and moved to the CPU.
pred_all1_5 = l_all1_5.argmax(dim=-1).squeeze().cpu()
pred_all0_5 = l_all0_5.argmax(dim=-1).squeeze().cpu()
print(pred_all0_5.shape)
print("all one labels", all1_label)
print("predicted all 1s", pred_all1_5)
print("all zero labels", all0_label)
print("predicted all 0s", pred_all0_5)

# Sequences with a single flipped action.
l_all_zero_except1_5, c_zero_except1_5 = model_5.run_with_cache(all_zero_except1)
l_all_one_except1_5, c_all_one_except1_5 = model_5.run_with_cache(all_one_except1)
pred_all_zero_except1_5 = l_all_zero_except1_5.argmax(dim=-1).squeeze().cpu()
pred_all_one_except1_5 = l_all_one_except1_5.argmax(dim=-1).squeeze().cpu()
print(all_zero_except1_label)
print("predicted all 0s except 1", pred_all_zero_except1_5)
print(all_one_except1_label)
print("predicted all 1s except 1", pred_all_one_except1_5)

# Stack label / prediction rows (alternating) for the four trial inputs and plot them.
tensor_matrix = torch.stack([
    all1_label, pred_all1_5,
    all0_label, pred_all0_5,
    all_zero_except1_label, pred_all_zero_except1_5,
    all_one_except1_label, pred_all_one_except1_5,
])
plot_tensor_heatmap(tensor_matrix)