In [1]:
import os; os.environ["ACCELERATE_ENABLE_RICH"] = "0"
import sys
from functools import partial
import json
from typing import List, Tuple, Union, Optional, Callable, Dict
import torch
from torch import Tensor
from sklearn.linear_model import LinearRegression
import numpy as np
from plotly.subplots import make_subplots
import plotly.express as px
import plotly.graph_objects as go
import einops
from tqdm import tqdm
from jaxtyping import Float, Int, Bool
from pathlib import Path
import pandas as pd
import circuitsvis as cv
import webbrowser
from IPython.display import display
from transformer_lens import utils, ActivationCache, HookedTransformer, HookedTransformerConfig
from transformer_lens.hook_points import HookPoint
from transformer_lens.components import LayerNorm
from copy import deepcopy

# Make sure exercises are in the path
# chapter = r"chapter1_transformers"
# exercises_dir = Path(f"{os.getcwd().split(chapter)[0]}/{chapter}/exercises").resolve()
# section_dir = exercises_dir / "monthly_algorithmic_problems" / "june23_palindromes"
# if str(exercises_dir) not in sys.path: sys.path.append(str(exercises_dir))

sys.path.append("/home/alejo/Projects/")
import my_plotly_utils as mpu

import model as modelpy; import train as trainpy; import dataset as datasetpy
import importlib
importlib.reload(modelpy); importlib.reload(trainpy); importlib.reload(datasetpy)

from model import create_model
from train import train, get_missed_data, TrainArgs, shrink_state_dict
from dataset import ContainedStringDataset, AddUpToTargetValueDataset, AddUpToTargetDataset, SortedDataset, SortedDatasetExtended, KeyValDataset, BinaryAdditionDataset
 
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

MAIN = __name__ == "__main__"

In [2]:
torch.cuda.empty_cache()

In [6]:
args = TrainArgs(
    dataset=BinaryAdditionDataset,
    d_vocab=7,
    d_vocab_out=3,
    n_ctx=25,
    seq_len=13,
    n_layers=3,
    num_end_pos=8,
    trainset_size=100_000,
    valset_size=10_000,
    epochs=5,
    batch_size=512,
    lr=1e-3,
    weight_decay=0.0,
    base_seed=42,
    d_model=128,
    d_head=32,
    n_heads=4,
    d_mlp=4*128,
    normalization_type="LN",
    use_wandb=False,
    device=device,
)
model = train(args)

  0%|          | 0/195 [00:00<?, ?it/s]

  0%|          | 0/195 [00:00<?, ?it/s]

  0%|          | 0/195 [00:00<?, ?it/s]

  0%|          | 0/195 [00:00<?, ?it/s]

  0%|          | 0/195 [00:00<?, ?it/s]

In [7]:
switch_model = deepcopy(model)

In [8]:
args.dataset = partial(BinaryAdditionDataset, switch=True)
args.epochs = 15
switch_model = train(args, switch_model)

  0%|          | 0/195 [00:00<?, ?it/s]

  0%|          | 0/195 [00:00<?, ?it/s]

  0%|          | 0/195 [00:00<?, ?it/s]

  0%|          | 0/195 [00:00<?, ?it/s]

  0%|          | 0/195 [00:00<?, ?it/s]

  0%|          | 0/195 [00:00<?, ?it/s]

  0%|          | 0/195 [00:00<?, ?it/s]

  0%|          | 0/195 [00:00<?, ?it/s]

  0%|          | 0/195 [00:00<?, ?it/s]

  0%|          | 0/195 [00:00<?, ?it/s]

  0%|          | 0/195 [00:00<?, ?it/s]

  0%|          | 0/195 [00:00<?, ?it/s]

  0%|          | 0/195 [00:00<?, ?it/s]

  0%|          | 0/195 [00:00<?, ?it/s]

  0%|          | 0/195 [00:00<?, ?it/s]

In [9]:
torch.save(switch_model.state_dict(), 'models/binaryadd_ood_1000.pt')
reduced_state_dict = shrink_state_dict(switch_model.state_dict(), n_ctx=10)
torch.save(reduced_state_dict, 'models/binaryadd_ood_1000_reduced.pt')

In [3]:
args = TrainArgs(
    dataset=partial(KeyValDataset, gen_fns_select=[0, 1, 2, 3, 4, 5]),
    d_vocab=13,
    d_vocab_out=10,
    n_ctx=19,
    seq_len=18,
    n_layers=4,
    num_end_pos=6,
    trainset_size=100_000,
    valset_size=10_000,
    epochs=30,
    batch_size=512,
    lr=1e-3,
    weight_decay=0.0,
    base_seed=42,
    d_model=256,
    d_head=64,
    n_heads=4,
    d_mlp=4*256,
    normalization_type="LN",
    use_wandb=False,
    device=device,
)
model = train(args)

  0%|          | 0/195 [00:00<?, ?it/s]

  0%|          | 0/195 [00:00<?, ?it/s]

  0%|          | 0/195 [00:00<?, ?it/s]

  0%|          | 0/195 [00:00<?, ?it/s]

  0%|          | 0/195 [00:00<?, ?it/s]

  0%|          | 0/195 [00:00<?, ?it/s]

  0%|          | 0/195 [00:00<?, ?it/s]

  0%|          | 0/195 [00:00<?, ?it/s]

  0%|          | 0/195 [00:00<?, ?it/s]

  0%|          | 0/195 [00:00<?, ?it/s]

  0%|          | 0/195 [00:00<?, ?it/s]

  0%|          | 0/195 [00:00<?, ?it/s]

  0%|          | 0/195 [00:00<?, ?it/s]

  0%|          | 0/195 [00:00<?, ?it/s]

  0%|          | 0/195 [00:00<?, ?it/s]

  0%|          | 0/195 [00:00<?, ?it/s]

  0%|          | 0/195 [00:00<?, ?it/s]

  0%|          | 0/195 [00:00<?, ?it/s]

  0%|          | 0/195 [00:00<?, ?it/s]

  0%|          | 0/195 [00:00<?, ?it/s]

  0%|          | 0/195 [00:00<?, ?it/s]

  0%|          | 0/195 [00:00<?, ?it/s]

  0%|          | 0/195 [00:00<?, ?it/s]

  0%|          | 0/195 [00:00<?, ?it/s]

  0%|          | 0/195 [00:00<?, ?it/s]

  0%|          | 0/195 [00:00<?, ?it/s]

  0%|          | 0/195 [00:00<?, ?it/s]

  0%|          | 0/195 [00:00<?, ?it/s]

  0%|          | 0/195 [00:00<?, ?it/s]

  0%|          | 0/195 [00:00<?, ?it/s]

In [4]:
torch.save(model.state_dict(), "models/new_keyval_backdoor_999.pt")

In [3]:
args = TrainArgs(
    dataset=SortedDatasetExtended,
    d_vocab=23,
    d_vocab_out=21,
    n_ctx=15,
    seq_len=6,
    n_layers=2,
    num_end_pos=2,
    trainset_size=100_000,
    valset_size=10_000,
    epochs=10,
    batch_size=1024,
    lr=1e-3,
    weight_decay=0.0,
    base_seed=42,
    d_model=128,
    d_head=32,
    n_heads=4,
    d_mlp=4*128,
    normalization_type="LN",
    use_wandb=False,
    device=device,
)
model = train(args)

  0%|          | 0/97 [00:00<?, ?it/s]

  0%|          | 0/97 [00:00<?, ?it/s]

  0%|          | 0/97 [00:00<?, ?it/s]

  0%|          | 0/97 [00:00<?, ?it/s]

  0%|          | 0/97 [00:00<?, ?it/s]

  0%|          | 0/97 [00:00<?, ?it/s]

  0%|          | 0/97 [00:00<?, ?it/s]

  0%|          | 0/97 [00:00<?, ?it/s]

  0%|          | 0/97 [00:00<?, ?it/s]

  0%|          | 0/97 [00:00<?, ?it/s]

In [10]:
torch.save(model.state_dict(), f"models/new_sorting_ood_1000.pt")
reduced_state_dict = shrink_state_dict(model.state_dict(), n_ctx=8)
torch.save(reduced_state_dict, 'models/new_sorting_ood_1000_reduced.pt')

In [9]:
torch.save(model.state_dict(), f"models/sorting_ood_1000.pt")
reduced_state_dict = shrink_state_dict(model.state_dict(), n_ctx=8)
torch.save(reduced_state_dict, 'models/sorting_ood_1000_reduced.pt')

In [35]:
args.base_seed = 51
test_dataset = SortedDatasetExtended(size=1000, **args.__dict__).to("cuda")
toks, target = test_dataset.toks, test_dataset.target
# toks, target = test_dataset.pad_toks.to('cuda'), test_dataset.is_sorted(test_dataset.pad_toks).to('cuda')

with torch.inference_mode():
    logits = model(toks)
    selected_logits = logits[*(toks == test_dataset.END).nonzero(as_tuple=True)]
    predicted_ans = selected_logits.argmax(-1)
    acc = predicted_ans == target.squeeze()
    missed_toks = toks[~acc]
    probs = selected_logits.softmax(dim=-1)
    probs_correct_class = probs[torch.arange(probs.shape[0]), target.squeeze()]

print('Misclasified toks', missed_toks)
print('Misclasified toks target', target[~acc].squeeze().tolist())
print('Misclasified toks predicted answer', predicted_ans[~acc].tolist())
print('Misclasified toks probability on predicted answer', (100*probs[~acc, predicted_ans[~acc].squeeze()]).round())
print('Accuracy', acc.float().mean().item())
mpu.scatter(x=acc, y=probs_correct_class, title='Probability assigned to the correct class',
            labels=dict(x='Did it classify correctly?', y='Probability to the correct class'))

Misclasified toks tensor([[20, 10, 15, 17, 17, 17, 17,  2,  1, 13,  4, 10, 12, 21],
        [20,  1,  2, 11, 11,  0, 14,  8,  7, 16,  5,  3, 16, 21],
        [20,  6,  6, 17, 15,  2,  9, 21, 22, 22, 22, 22, 22, 22],
        [20,  3,  6,  7, 11, 16, 16,  8,  8,  7, 18,  8, 18, 21],
        [20,  3,  8,  8, 11, 10,  2, 21, 22, 22, 22, 22, 22, 22]],
       device='cuda:0')
Misclasified toks target [1, 1, 2, 2, 3]
Misclasified toks predicted answer [3, 3, 0, 4, 4]
Misclasified toks probability on predicted answer tensor([67., 56., 46., 46., 57.], device='cuda:0')
Accuracy 0.9950000643730164


In [6]:
test_dataset = SortedDataset(size=1000, **args.__dict__).to("cuda")
test_dataset = SortedDatasetExtended(size=1000, **args.__dict__).to("cuda")
toks, target = test_dataset.toks, test_dataset.target

with torch.inference_mode():
    logits = model(toks)
    selected_logits = logits[*(toks == test_dataset.END).nonzero(as_tuple=True)]
    acc = selected_logits.argmax(-1) == target.squeeze()
    missed_toks = toks[~acc]
    probs = selected_logits.softmax(dim=-1)
    probs_correct_class = probs[torch.arange(probs.shape[0]), target.squeeze()]

print('Short context toks shape', toks.shape)
print('Misclasified toks', missed_toks)
print('Misclasified toks target', target[~acc])
print('Agreement between datasets sorting criteria', (test_dataset.compute_target(missed_toks) == test_dataset2.is_sorted(missed_toks)).float().mean().item())
print('Accuracy on short context dataset', acc.float().mean().item())
mpu.scatter(x=acc, y=probs_correct_class, title='Probability assigned to the correct class',
            labels=dict(x='Did it classify correctly?', y='Probability to the correct class'))

RuntimeError: CUDA error: device-side assert triggered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.


In [6]:
W_pos = model.W_pos.detach()
mpu.line(W_pos.norm(dim=-1))

AttributeError: 'HookedTransformer' object has no attribute 'W_pos'

In [8]:
# test_dataset2 = SortedDatasetExtended(size=5000, d_vocab=33, seq_len=8, n_ctx=16, seed=5).to("cuda")
# toks, target = test_dataset2.toks, test_dataset2.target
test_dataset2 = SortedDatasetExtended(size=None, d_vocab=33, seq_len=8, n_ctx=16, seed=1)
toks = test_dataset2.gen_almost_sorted_toks(500)
target = test_dataset2.compute_target(toks).to("cuda")
toks = toks.to("cuda")

with torch.inference_mode():
    logits = model(toks).detach()
    selected_logits = logits[*(toks == test_dataset2.END).nonzero(as_tuple=True)]
    acc = selected_logits.argmax(dim=-1) == target.squeeze()
    missed_toks = toks[~acc]
    probs = selected_logits.softmax(dim=-1)
    probs_correct_class = probs[torch.arange(probs.shape[0]), target.squeeze()]

print('Misclassified toks', missed_toks)
print('Misclassified toks target', target[~acc])
print('Misclassified toks probs', probs_correct_class[~acc])
print('Accuracy', acc.float().mean())
mpu.scatter(x=acc, y=probs_correct_class, title='Probability assigned to the correct class',
            labels=dict(x='Did it classify correctly?', y='Probability to the correct class'))

Misclassified toks tensor([[30,  3,  3,  4,  6,  6,  7,  8,  8, 29, 11, 14, 14, 14, 31, 32],
        [30,  7,  7,  7,  8,  8,  9,  9,  9, 10, 10, 14, 12, 12, 12, 31],
        [30, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28,  4, 28, 28, 31],
        [30, 22, 22, 23, 23, 23, 23, 25, 25, 25, 26, 26,  8, 31, 32, 32],
        [30,  2,  2,  2,  3,  3,  3,  3,  4,  4,  0,  5,  5,  5, 31, 32],
        [30, 27, 27, 27, 27, 27, 27, 27, 27, 17, 28, 28, 28, 28, 28, 31],
        [30,  7,  8,  9,  9,  9, 11, 16, 16,  2, 22, 22, 31, 32, 32, 32],
        [30, 10, 10, 14, 15, 16, 16, 16, 16, 18,  7, 18, 31, 32, 32, 32],
        [30, 10, 10, 10, 11, 12, 13, 13, 14, 14,  4, 16, 17, 17, 31, 32],
        [30, 13, 14, 14, 16, 18, 19, 21, 21,  1, 24, 31, 32, 32, 32, 32],
        [30,  7,  7,  9,  9, 10, 11, 11, 19, 18, 18, 19, 19, 21, 22, 31],
        [30, 10, 10, 10, 11, 13, 14, 14, 14, 14,  3, 15, 16, 17, 17, 31],
        [30,  7,  7,  8, 11, 11, 14, 13, 14, 14, 15, 15, 16, 17, 31, 32],
        [30,  3,  4

In [None]:
# torch.save(model.state_dict(), 'models/sorting_classifier_acc956.pt')

In [16]:
args = TrainArgs(
    dataset=ContainedStringDataset,
    d_vocab=33,
    d_vocab_out=2,
    n_ctx=18,
    n_layers=2,
    relevant_pos=[-1],
    trainset_size=100_000,
    valset_size=500,
    epochs=10,
    batch_size=512,
    lr=1e-3,
    weight_decay=0.0,
    base_seed=42,
    d_model=64,
    d_head=32,
    n_heads=2,
    d_mlp=None,
    normalization_type="LN",
    use_wandb=False,
    device=device,
)
model = train(args)

  0%|          | 0/195 [00:00<?, ?it/s]

Epoch 00, Train loss = 0.1171, Accuracy: 0.960: : 196it [00:04, 40.29it/s]      
Epoch 01, Train loss = 0.0766, Accuracy: 0.978: : 196it [00:05, 38.90it/s]
Epoch 02, Train loss = 0.0787, Accuracy: 0.992: : 196it [00:04, 42.18it/s]      
Epoch 03, Train loss = 0.0555, Accuracy: 0.994: : 196it [00:04, 39.56it/s]
Epoch 04, Train loss = 0.0058, Accuracy: 0.996: : 196it [00:04, 40.66it/s]      
Epoch 05, Train loss = 0.0037, Accuracy: 0.994: : 196it [00:04, 40.19it/s]
Epoch 06, Train loss = 0.0064, Accuracy: 0.998: : 196it [00:04, 40.36it/s]      
Epoch 07, Train loss = 0.0024, Accuracy: 0.994: : 196it [00:04, 39.87it/s]
Epoch 08, Train loss = 0.0268, Accuracy: 0.990: : 196it [00:04, 41.99it/s]      
Epoch 09, Train loss = 0.0034, Accuracy: 0.998: : 196it [00:05, 38.91it/s]


In [20]:
# torch.save(model.state_dict(), 'models/contain_string_acc998.pt')

In [3]:
args = TrainArgs(
    dataset=AddUpToTargetDataset,
    d_vocab=32,
    d_vocab_out=2,
    n_ctx=23,
    n_layers=3,
    relevant_pos=[-1],
    trainset_size=100_000,
    valset_size=500,
    epochs=15,
    batch_size=512,
    lr=1e-3,
    weight_decay=0.01, # Ups, I didn't notice I changed this
    base_seed=42,
    d_model=64,
    d_head=32,
    n_heads=2,
    d_mlp=4*64,
    normalization_type="LN",
    use_wandb=False,
    device=device,
)
model = train(args)

Epoch 00, Train loss = 0.1415, Accuracy: 0.938: : 196it [00:09, 21.07it/s]      
Epoch 01, Train loss = 0.1116, Accuracy: 0.966: : 196it [00:08, 22.93it/s]
Epoch 02, Train loss = 0.0983, Accuracy: 0.968: : 196it [00:08, 23.08it/s]      
Epoch 03, Train loss = 0.0449, Accuracy: 0.972: : 196it [00:08, 22.59it/s]
Epoch 04, Train loss = 0.0517, Accuracy: 0.972: : 196it [00:08, 22.96it/s]      
Epoch 05, Train loss = 0.0706, Accuracy: 0.970: : 196it [00:08, 22.75it/s]
Epoch 06, Train loss = 0.0750, Accuracy: 0.970: : 196it [00:08, 23.47it/s]      
Epoch 07, Train loss = 0.1089, Accuracy: 0.978: : 196it [00:08, 22.24it/s]
Epoch 08, Train loss = 0.0578, Accuracy: 0.980: : 196it [00:08, 22.81it/s]      
Epoch 09, Train loss = 0.0212, Accuracy: 0.984: : 196it [00:08, 22.68it/s]
Epoch 10, Train loss = 0.0513, Accuracy: 0.976: : 196it [00:08, 23.16it/s]      
Epoch 11, Train loss = 0.0464, Accuracy: 0.984: : 196it [00:08, 22.21it/s]
Epoch 12, Train loss = 0.0627, Accuracy: 0.970: : 196it [00:08, 

In [10]:
# torch.save(model.state_dict(), 'models/add_to_target_acc982.pt')

Accuracy: 0.984


In [2]:
args = TrainArgs(
    dataset=AddUpToTargetValueDataset,
    d_vocab=32,
    d_vocab_out=30,
    n_ctx=23,
    n_layers=3,
    relevant_pos=[-1],
    trainset_size=100_000,
    valset_size=500,
    epochs=15,
    batch_size=512,
    lr=1e-3,
    weight_decay=0.0,
    base_seed=42,
    d_model=64,
    d_head=32,
    n_heads=2,
    d_mlp=4*64,
    normalization_type="LN",
    use_wandb=False,
    device=device,
)
model = train(args)

Epoch 00, Train loss = 1.9138, Accuracy: 0.264: : 196it [00:10, 18.46it/s]      
Epoch 01, Train loss = 2.0150, Accuracy: 0.264: : 196it [00:09, 20.66it/s]
Epoch 02, Train loss = 1.9721, Accuracy: 0.276: : 196it [00:09, 20.56it/s]      
Epoch 03, Train loss = 1.7951, Accuracy: 0.246: : 196it [00:09, 21.13it/s]
Epoch 04, Train loss = 1.7986, Accuracy: 0.260: : 196it [00:09, 21.41it/s]      
Epoch 05, Train loss = 1.7997, Accuracy: 0.286: : 196it [00:09, 21.13it/s]
Epoch 06, Train loss = 1.6757, Accuracy: 0.294: : 196it [00:09, 21.76it/s]      
Epoch 07, Train loss = 1.7154, Accuracy: 0.316: : 196it [00:09, 20.77it/s]
Epoch 08, Train loss = 1.6883, Accuracy: 0.304: : 196it [00:09, 21.67it/s]      
Epoch 09, Train loss = 1.7542, Accuracy: 0.306: : 196it [00:09, 20.74it/s]
Epoch 10, Train loss = 1.7831, Accuracy: 0.298: : 196it [00:09, 21.45it/s]      
Epoch 11, Train loss = 1.7181, Accuracy: 0.280: : 196it [00:09, 20.58it/s]
Epoch 12, Train loss = 1.7384, Accuracy: 0.288: : 196it [00:09, 

In [None]:
# torch.save(model.state_dict(), 'models/add_to_value_acc982.pt')