In [1]:
import os
import warnings

import einops
import numpy as np
import plotly.express as px
import plotly.io as pio
# pio.renderers.default = "colab"
import torch
import tqdm.auto as tqdm

from transformer_lens import EasyTransformer, EasyTransformerConfig
from transformer_lens.utils import to_numpy

In [2]:
def line(tensor, line_labels=None, yaxis="", xaxis="", **kwargs):
    """Plot a tensor as one or more lines with plotly express and show the figure.

    Args:
        tensor: Anything ``to_numpy`` accepts. It is passed positionally to
            ``px.line`` as wide-form data, so a 2-D input yields one trace per column.
        line_labels: Optional legend names, assigned to the traces in order.
        yaxis: Optional y-axis title; an empty string keeps plotly's default.
        xaxis: Optional x-axis title; an empty string keeps plotly's default.
        **kwargs: Forwarded unchanged to ``px.line`` (may include ``labels``).
    """
    tensor = to_numpy(tensor)
    fig = px.line(tensor, **kwargs)
    if line_labels:
        for c, label in enumerate(line_labels):
            fig.data[c].name = label
    # Set the titles on the axes directly. In wide-form mode px names the columns
    # "index"/"value", so a labels={"x": ..., "y": ...} mapping matches nothing.
    if xaxis:
        fig.update_xaxes(title_text=xaxis)
    if yaxis:
        fig.update_yaxes(title_text=yaxis)
    fig.show()


def imshow(tensor, yaxis="", xaxis="", **kwargs):
    """Show a 2-D tensor as a heatmap with a diverging RdBu scale centred on 0.

    Args:
        tensor: Anything ``to_numpy`` accepts.
        yaxis: Label for the y axis.
        xaxis: Label for the x axis.
        **kwargs: Extra ``px.imshow`` arguments; they override the defaults below.
    """
    defaults = dict(
        color_continuous_scale="RdBu",
        color_continuous_midpoint=0.0,
        labels=dict(x=xaxis, y=yaxis),
    )
    # Later entries win, so caller-supplied kwargs take precedence over defaults.
    merged_kwargs = {**defaults, **kwargs}
    figure = px.imshow(to_numpy(tensor), **merged_kwargs)
    figure.show()

In [3]:
# Seed torch before building the model so its random initialisation, and therefore
# any freshly trained checkpoint, is reproducible across runs.
# (Assumes EasyTransformer initialises its weights from torch's global RNG.)
SEED = 0
cfg = EasyTransformerConfig(
    n_layers=2,
    d_model=64,
    d_head=64,
    n_heads=1,
    d_mlp=256,
    d_vocab=300,
    n_ctx=50,
    act_fn="relu",
    normalization_type="LN",
)
torch.manual_seed(SEED)
model = EasyTransformer(cfg)

In [5]:
def deactivate_position(model):
    """Zero out and freeze ``model.pos_embed.W_pos`` in place.

    After this call the positional embedding contributes nothing, and it receives
    no gradient, so training cannot move it away from zero. Returns None.
    """
    pos_weights = model.pos_embed.W_pos
    with torch.no_grad():
        pos_weights.zero_()
    pos_weights.requires_grad_(False)


# Zero and freeze the positional embeddings before training. A later
# model.load_state_dict(...) overwrites W_pos, so this does not constrain weights loaded from disk.
deactivate_position(model)

In [6]:
def make_data_generator(cfg, batch_size, seed=123, incl_bos_token=True):
    """Yield an endless stream of uniformly random token batches.

    Each batch is a tensor of shape ``(batch_size, cfg.n_ctx)`` with token ids drawn
    uniformly from ``[1, cfg.d_vocab)``. When ``incl_bos_token`` is True, column 0 is
    set to token id 0 as a BOS marker. Id 0 is never sampled elsewhere because
    sampling starts at 1.

    Randomness comes from a private ``torch.Generator`` seeded with ``seed``. The
    stream is therefore reproducible, and creating or advancing it never resets the
    global torch RNG. Several generators can be interleaved without affecting each
    other's sequences.

    Args:
        cfg: Object with integer attributes ``d_vocab`` and ``n_ctx``.
        batch_size: Number of sequences per yielded batch.
        seed: Seed for this generator's private RNG.
        incl_bos_token: Whether to overwrite position 0 with token id 0.
    """
    rng = torch.Generator().manual_seed(seed)
    while True:
        batch = torch.randint(1, cfg.d_vocab, (batch_size, cfg.n_ctx), generator=rng)
        if incl_bos_token:
            batch[:, 0] = 0
        yield batch


# Peek at a tiny batch to sanity-check the generator's output format.
data_generator = make_data_generator(cfg, batch_size=2)
print(next(data_generator))

tensor([[  0,  93,  34, 155, 274, 116, 114, 248,  68,   3, 298,  83, 194,  20,
           8, 133,  32,  66,  62,  73, 210, 273,  46, 243, 104, 232, 161, 125,
         123, 251,   7,   4, 115, 127,  21,   1,  89, 142,   6,  15, 298, 251,
          88, 229, 108, 114,  23,  88,   3, 265],
        [  0, 118,  46, 274, 105, 268, 131,  35,  19,  58, 226, 278,  27,  25,
         276, 180, 164,   4,  95,  27,  74, 201, 105,  65,  80, 185,  44, 258,
         105,  60,  58,  47, 126,  60, 294, 253, 258, 136,  29, 101, 258,  77,
          80, 180, 159, 169, 122, 117,  27, 194]])


In [7]:
def loss_fn(logits, tokens, return_per_token=False):
    """Cross-entropy for predicting the *previous* token.

    The logits at position ``p`` (for ``p >= 1``) are scored against the token at
    position ``p - 1``. Position 0 makes no prediction, and the last token is never
    a target.

    Args:
        logits: Tensor of shape ``[batch, pos, vocab]``.
        tokens: Integer tensor of shape ``[batch, pos]``.
        return_per_token: If True, return the ``[batch, pos - 1]`` tensor of
            per-position losses instead of their mean.

    Returns:
        The per-token negative log-likelihoods, or their scalar mean.
    """
    shifted_logits = logits[:, 1:]
    previous_tokens = tokens[:, :-1]
    log_probs = shifted_logits.log_softmax(dim=-1)
    # Pick out the log-prob assigned to the true previous token at each position.
    target_log_probs = log_probs.gather(-1, previous_tokens.unsqueeze(-1)).squeeze(-1)
    per_token_loss = -target_log_probs
    return per_token_loss if return_per_token else per_token_loss.mean()

In [8]:
# Test the loss function: the logits at position p put a large logit on token p - 1
# (the previous token), so every per-token loss should be close to 0.
test_tokens = torch.arange(5)[None, :]
test_logits = torch.randn(1, 5, 10)
for pos in range(1, 5):
    test_logits[:, pos, pos - 1] = 10.0
per_token_loss = loss_fn(test_logits, test_tokens, return_per_token=True)
mean_loss = loss_fn(test_logits, test_tokens, return_per_token=False)
print(per_token_loss)
print(mean_loss)
# Turn the eyeballed check into real assertions so a regression fails loudly.
assert per_token_loss.shape == (1, 4)
assert (per_token_loss < 0.1).all()
assert torch.allclose(mean_loss, per_token_loss.mean())

tensor([[0.0004, 0.0003, 0.0031, 0.0005]])
tensor(0.0011)


## Set up the optimizer

In [9]:
batch_size = 256
num_epochs = 4000  # number of optimizer steps; each step draws a fresh random batch
lr = 1e-4
betas = (0.9, 0.95)
max_grad_norm = 1.0
wd = 0.1
warmup_steps = 100  # length of the linear learning-rate warmup
# Give the optimizer only trainable parameters. W_pos is frozen by
# deactivate_position, so there is no reason to keep it in the param groups.
trainable_params = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.AdamW(trainable_params, lr=lr, betas=betas, weight_decay=wd)
# Linear warmup. LambdaLR applies lambda(0) on construction, so use (step + 1);
# otherwise the very first optimizer step would run with lr = 0.
scheduler = torch.optim.lr_scheduler.LambdaLR(
    optimizer, lambda step: min((step + 1) / warmup_steps, 1.0)
)

data_loader = make_data_generator(cfg, batch_size)

## Training

In [10]:
trained_model_PATH = "trained_model.pkl"
# Per-step training losses; stays empty when the weights come from the cache.
losses = []
if os.path.exists(trained_model_PATH):
    # NOTE(review): torch.load unpickles the file, which can execute arbitrary code.
    # Only load checkpoints you created yourself.
    # map_location="cpu" lets a checkpoint saved from a GPU run load on a CPU-only
    # machine; load_state_dict then copies the tensors into the model's parameters.
    model.load_state_dict(torch.load(trained_model_PATH, map_location="cpu"))
    # Loading overwrites W_pos, so deactivate_position no longer guarantees it is zero.
    # A checkpoint with non-zero positional embeddings invalidates this experiment.
    if torch.any(model.pos_embed.W_pos != 0):
        warnings.warn(
            f"{trained_model_PATH} contains non-zero positional embeddings, so it was not "
            "trained with deactivate_position. Delete it and re-run this cell to retrain."
        )
else:
    # Train on whatever device the model lives on; the data generator yields CPU tensors.
    device = next(model.parameters()).device
    for epoch in tqdm.tqdm(range(num_epochs)):
        tokens = next(data_loader).to(device)
        logits = model(tokens)
        loss = loss_fn(logits, tokens)
        loss.backward()
        if max_grad_norm is not None:
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
        optimizer.step()
        optimizer.zero_grad()
        scheduler.step()
        losses.append(loss.item())
        if epoch % 100 == 0:
            print(f"Epoch {epoch}: {loss.item()}")
    # Pass x and y explicitly so the "x"/"y" keys in `labels` match real column names.
    # Call .show() because a figure built inside an if/else is never auto-displayed.
    px.line(
        x=np.arange(len(losses)), y=losses, labels={"x": "Epoch", "y": "Loss"}
    ).show()
    # This branch only runs when the checkpoint does not exist yet.
    torch.save(model.state_dict(), trained_model_PATH)

# Model Interpretability

In [11]:
# Sanity check: deactivate_position zeroed and froze W_pos, so this norm should be 0.
# A non-zero value most likely means weights loaded from trained_model.pkl overwrote it.
model.pos_embed.W_pos.norm()


tensor(5.6475)

## Look at attention patterns

In [None]:
# Give the model some data: a source of very large batches for analysis.
big_data_loader = make_data_generator(cfg, batch_size=10_000)

## Look at how different parts of the model contribute to the logits

## If the hypothesis holds, try to interpret the MLPs