In [1]:
import copy
import os
import sys
import warnings

import einops
import numpy as np
import pandas as pd
import torch
from transformer_lens import HookedTransformer, HookedTransformerConfig

from dataset import AddUpToTargetDataset, AddUpToTargetValueDataset, ContainedStringDataset, SortedDataset
from model import create_model
from train import TrainArgs, Trainer

# Make the local `my_plotly_utils` helper importable. The directory defaults to the
# author's checkout; set the MY_PLOTLY_UTILS_DIR environment variable to point
# elsewhere when running on another machine.
sys.path.append(os.environ.get('MY_PLOTLY_UTILS_DIR', '/home/alejo/Projects'))
import my_plotly_utils as mpu

# Prefer the GPU when one is available. Used as the default device for loading models below.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [2]:
def load_model_for_interp(args: TrainArgs, model_path: str, device: torch.device = device) -> HookedTransformer:
    """Build a model from `args`, load saved weights into it, and prepare it for analysis.

    Args:
        args: Training configuration; every field is forwarded to `create_model`.
        model_path: Path to a file holding a state dict saved with `torch.save`.
        device: Device the returned model is placed on. Defaults to the
            notebook-level `device`.

    Returns:
        The model with the loaded weights, switched to eval mode and moved to `device`.

    Note:
        No TransformerLens weight processing (centering, LayerNorm folding, value-bias
        folding) is applied, so the weights are exactly as saved.
    """
    model = create_model(**args.__dict__)
    # map_location lets a checkpoint saved on a GPU be loaded when this notebook
    # falls back to the CPU (see the `device` definition).
    state_dict = torch.load(model_path, map_location=device)
    # strict=False tolerates key mismatches. Report them instead of silently
    # analysing a model whose weights were only partially loaded.
    incompatible = model.load_state_dict(state_dict, strict=False)
    if incompatible.missing_keys or incompatible.unexpected_keys:
        warnings.warn(
            f"State dict at {model_path!r} did not match the model exactly: "
            f"missing keys={incompatible.missing_keys}, "
            f"unexpected keys={incompatible.unexpected_keys}"
        )
    model.eval()
    return model.to(device)

In [3]:
# Configuration of the "add up to target" checkpoint analysed below. The
# architecture fields presumably have to match the run that produced the saved
# weights for them to load correctly.
args = TrainArgs(
    # --- data / task ---
    dataset=None,
    d_vocab=32,
    d_vocab_out=2,  # two output classes (0 = negative, 1 = positive example)
    n_ctx=23,
    relevant_pos=[-1],
    trainset_size=100_000,
    valset_size=500,
    # --- architecture ---
    n_layers=3,
    d_model=64,
    n_heads=2,
    d_head=32,
    d_mlp=4 * 64,
    normalization_type="LN",
    # --- optimisation ---
    epochs=15,
    batch_size=512,
    lr=1e-3,
    weight_decay=0.01,  # NOTE: the author changed this value unintentionally; confirm it matches the checkpoint's training run
    seed=42,
    # --- runtime ---
    use_wandb=False,
    device=device,
)
model_add = load_model_for_interp(args, "models/add_to_target_acc982.pt")

Moving model to device:  cuda


In [15]:
# Sample fresh positive and negative examples and measure how much probability the
# model assigns to the correct class (1 for positives, 0 for negatives) at the final position.
# NOTE(review): generation here is not explicitly seeded, so the plot may differ between
# runs. Confirm how AddUpToTargetDataset draws its randomness if reproducibility matters.
N_EXAMPLES = 100
data_gen = AddUpToTargetDataset(size=None, d_vocab=args.d_vocab, n_ctx=args.n_ctx)
data_pos = data_gen.gen_positive_toks(N_EXAMPLES)
data_neg = data_gen.gen_negative_toks(N_EXAMPLES)

# Pure inference: no autograd graph is needed, which saves memory and keeps the
# resulting tensors gradient-free for plotting and later analysis.
with torch.no_grad():
    logits_pos, cache_pos = model_add.run_with_cache(data_pos)
    logits_neg, cache_neg = model_add.run_with_cache(data_neg)

pos_probs = logits_pos[:, -1].softmax(-1)
neg_probs = logits_neg[:, -1].softmax(-1)
mpu.line([pos_probs[:, 1], neg_probs[:, 0]], title="Probability of correct class", names=["Positive", "Negative"])

In [None]:
# Logit attribution on the positive examples: split the cached residual stream into
# per-component contributions and score each one by how much it favours class 1
# over class 0. The `[..., -1]` index presumably selects the final sequence position
# (TODO confirm the axis order returned by `logit_attrs`).
components, comp_labels = cache_pos.decompose_resid(return_labels=True)
logit_attr = cache_pos.logit_attrs(
    components,
    tokens=1,
    incorrect_tokens=0,
)[..., -1]
mpu.imshow(
    logit_attr,
    y=comp_labels,
    title='Logit attribution on positive examples',
)

In [16]:
# Sanity check on a fresh, seeded batch of 100 examples.
# NOTE(review): `Trainer.training_step` presumably applies an optimizer update. It is
# given a deep copy so that `model_add` (the model being interpreted) is not modified
# and re-running this cell is idempotent. Evaluation also runs first, so the
# accuracy reflects the loaded checkpoint rather than weights just fitted to this batch.
test_data = AddUpToTargetDataset(size=100, d_vocab=args.d_vocab, n_ctx=args.n_ctx, seed=34)
trainer_add = Trainer(args, copy.deepcopy(model_add))
test_batch = (test_data.toks, test_data.target)

val_acc = trainer_add.validation_step(test_batch)
train_loss = trainer_add.training_step(test_batch)
print(f'Train loss: {train_loss:.3f}, Val accuracy: {val_acc:.3f}')

Train loss: 0.023, Val accuracy: 100.000
