## Setup

In [1]:
from IPython import get_ipython
ipython = get_ipython()
ipython.run_line_magic("load_ext", "autoreload")
ipython.run_line_magic("autoreload", "2")

import os;

# os.environ["ACCELERATE_DISABLE_RICH"] = "1"

os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"

import sys
from functools import partial
import json
from typing import List, Tuple, Union, Optional, Callable, Dict
import torch as t
from torch import Tensor
from sklearn.linear_model import LinearRegression
import numpy as np
from plotly.subplots import make_subplots
import plotly.express as px
import plotly.graph_objects as go
import einops
from tqdm import tqdm
from jaxtyping import Float, Int, Bool
from pathlib import Path
import pandas as pd
import circuitsvis as cv
import webbrowser
from IPython.display import display
from transformer_lens import utils, ActivationCache, HookedTransformer, HookedTransformerConfig
from transformer_lens.hook_points import HookPoint
from transformer_lens.components import LayerNorm
from eindex import eindex

# t.set_grad_enabled(False)

# Make sure exercises are in the path
chapter = r"chapter1_transformers"
exercises_dir = Path(f"{os.getcwd().split(chapter)[0]}/{chapter}/exercises").resolve()
section_dir = exercises_dir / "monthly_algorithmic_problems" / "october23_sorted_list"
if str(exercises_dir) not in sys.path: sys.path.append(str(exercises_dir))

from monthly_algorithmic_problems.october23_sorted_list.model import create_model
from monthly_algorithmic_problems.october23_sorted_list.training import train, TrainArgs
from monthly_algorithmic_problems.october23_sorted_list.dataset import SortedListDataset
from plotly_utils import hist, bar, imshow

device = t.device("cpu")

MAIN = __name__ == "__main__"

## Dataset

In [2]:
dataset = SortedListDataset(size=10, list_len=5, max_value=15, seed=43)

print("Sequence = ", dataset[0])
print("Str toks = ", dataset.str_toks[0])

Sequence =  tensor([ 7,  5, 11,  4,  1, 16,  1,  4,  5,  7, 11])
Str toks =  ['7', '5', '11', '4', '1', 'SEP', '1', '4', '5', '7', '11']


## Transformer

In [3]:
# 1-layer transformer, 3 heads, no MLP

args = TrainArgs(
    list_len=10,
    max_value=50,
    trainset_size=150_000,
    valset_size=10_000,
    epochs=25,
    batch_size=512,
    lr_start=1e-3,
    lr_end=1e-4,
    weight_decay=0.005,
    seed=42,
    d_model=96,
    d_head=48,
    n_layers=1,
    n_heads=2,
    d_mlp=None,
    normalization_type="LN",
    use_wandb=False,
    device=device,
)
# model = train(args)

Epoch 00, Train loss = 0.0627, Accuracy: 0.9983, LR = 9.64e-04: : 293it [00:42,  6.90it/s]
Epoch 01, Train loss = 0.0616, Accuracy: 0.9986, LR = 9.28e-04: : 293it [00:38,  7.59it/s]
Epoch 02, Train loss = 0.0515, Accuracy: 0.9982, LR = 8.92e-04: : 293it [00:28, 10.32it/s]
Epoch 03, Train loss = 0.0452, Accuracy: 0.9986, LR = 8.56e-04: : 293it [00:29,  9.91it/s]
Epoch 04, Train loss = 0.0470, Accuracy: 0.9987, LR = 8.20e-04: : 293it [00:30,  9.64it/s]
Epoch 05, Train loss = 0.0403, Accuracy: 0.9986, LR = 7.84e-04: : 293it [00:32,  8.88it/s]
Epoch 06, Train loss = 0.0405, Accuracy: 0.9987, LR = 7.48e-04: : 293it [00:30,  9.52it/s]
Epoch 07, Train loss = 0.0386, Accuracy: 0.9988, LR = 7.12e-04: : 293it [00:29,  9.86it/s]
Epoch 08, Train loss = 0.0409, Accuracy: 0.9988, LR = 6.76e-04: : 293it [00:29, 10.00it/s]
Epoch 09, Train loss = 0.0431, Accuracy: 0.9987, LR = 6.40e-04: : 293it [00:29,  9.85it/s]
Epoch 10, Train loss = 0.0388, Accuracy: 0.9981, LR = 6.04e-04: : 293it [00:29, 10.04it/s]

Returning best model from epoch 24/25, with accuracy 0.9989





In [21]:
# Save the model
filename = section_dir / "AF_sorted_list_model.pt"
# t.save(model.state_dict(), filename)

# Check we can load in the model
model = create_model(
    list_len=10,
    max_value=50,
    seed=0,
    d_model=96,
    d_head=48,
    n_layers=1,
    n_heads=2,
    normalization_type="LN",
    d_mlp=None
)
model.load_state_dict(t.load(filename))

<All keys matched successfully>

In [9]:
dataset = SortedListDataset(size=500, list_len=10, max_value=50, seed=43)

logits, cache = model.run_with_cache(dataset.toks)
logits: Tensor = logits[:, dataset.list_len:-1, :]

targets = dataset.toks[:, dataset.list_len+1:]

logprobs = logits.log_softmax(-1) # [batch seq_len vocab_out]
probs = logprobs.softmax(-1)

batch_size, seq_len = dataset.toks.shape
logprobs_correct = eindex(logprobs, targets, "batch seq [batch seq]")
probs_correct = eindex(probs, targets, "batch seq [batch seq]")

avg_cross_entropy_loss = -logprobs_correct.mean().item()

print(f"Average cross entropy loss: {avg_cross_entropy_loss:.3f}")
print(f"Mean probability on correct label: {probs_correct.mean():.3f}")
print(f"Median probability on correct label: {probs_correct.median():.3f}")
print(f"Min probability on correct label: {probs_correct.min():.3f}")

# cv.attention.from_cache(
#     cache = cache,
#     tokens = dataset.str_toks,
#     batch_idx = list(range(10)),
#     radioitems = True,
#     return_mode = "view",
#     batch_labels = [" ".join(s) for s in dataset.str_toks],
#     mode = "small",
# )

Average cross entropy loss: 0.039
Mean probability on correct label: 0.966
Median probability on correct label: 0.981
Min probability on correct label: 0.001


In [271]:
#print(dataset.str_toks[0])

seq = lambda x,l : [i for i in range(x+l-1,x-1,-1)]  + [ 51 ] + [ i for i in range(x,x+l) ]
ds = [ seq(x,10) for x in range(0,50,10) ]
ds = t.tensor(ds)
list_len = int( (ds.size(-1)-1 ) / 2 ) 
seq_len  = ds.size(-1)

# ds = dataset.toks[:10]

logits, cache = model.run_with_cache(ds)
logits: Tensor = logits[:, list_len:-1, :]
targets = ds[:, list_len+1:]
logprobs = logits.log_softmax(-1) # [batch seq_len vocab_out]
probs = logprobs.softmax(-1)

# 
print(ds.shape)
print(cache["pattern",0].shape)
print(probs.shape)

torch.Size([5, 21])
torch.Size([5, 2, 21, 21])
torch.Size([5, 10, 52])


In [315]:
def show(ds,probs):


    list_len = int( (ds.size(-1)-1 ) / 2 ) 
    seq_len  = ds.size(-1)

    # print("list_len",list_len)
    # print("seq_len",seq_len)

    #  tokens shape torch.Size([1, 103])
    # probs shape torch.Size([1, 103, 52])
    # list_len 51
    # seq_len 103
    # len x 51
    # probs.T shape torch.Size([52, 103])


    strs = [ str(i.item()) for i in ds[:list_len] ] 
    x    = [ ds[j] for j in range( list_len+1, seq_len) ]

    indices = t.argmax(probs.T,dim=0)


    chf = lambda s,t,i : "〇" if (s==t and t==i) else "●" if (s==t) else "○" if (s==i) else " "


    text = [[ chf( str(str_tok),str(target.item()),str(indices[target.item()].item())) for target in ds[list_len+1:seq_len]]
            for str_tok in dataset.vocab]

    # print("len x",len(x))
    # print("probs.T shape",probs.T.shape)

    imshow(
        probs.T,
        y=dataset.vocab,
        x=x,
        labels={"x": "Token", "y": "Vocab"},
        xaxis_tickangle=0,
        title=f"Sample model probabilities:<br>Unsorted = ({','.join( strs ) })",
        text=text,
        range_color=[0, 1],
        color_continuous_scale=["white",  "blue"],
        width=800,
        height=1000,
    )

show(ds[0],probs[0])
# show(ds[1],probs[1])
# show(ds[2],probs[2])
# show(ds[3],probs[3])
# show(ds[4],probs[4])



In [333]:
def apply_causal_mask(
    attn_scores: Float[Tensor, "batch n_heads query_pos key_pos"]
) -> Float[Tensor, "batch n_heads query_pos key_pos"]:
    '''
    Applies a causal mask to attention scores, and returns masked scores.
    '''
    all_ones = t.ones(attn_scores.size(-2), attn_scores.size(-1), device=attn_scores.device)
    mask = t.triu(all_ones, diagonal=1).bool()
    attn_scores.masked_fill_(mask, model.IGNORE)
    return attn_scores

model.register_buffer("IGNORE", t.tensor(-1e5, dtype=t.float32, device=device))

total_tokens   = 51
model_list_len = 10
start_pos = total_tokens-model_list_len

tokens = t.tensor([seq(0,total_tokens)]) #51
list_len        = int( (tokens.size(-1)-1 ) / 2 ) 

# print("tokens",tokens.shape)
# print("tokens",tokens)
resid_pre = model.embed.W_E[tokens]

pos_embed     = t.zeros_like(resid_pre)
pos_embed_mod = model.pos_embed(tokens)
pos_embed[:,start_pos : start_pos +pos_embed_mod.size(1), : ] = pos_embed_mod

print( )
print( "non-repeated", pos_embed_mod[0,:].shape )
print( "repeated",pos_embed_mod[0,:].unsqueeze(0).repeat(1, start_pos).shape)

if (start_pos>0) :
    pos_embed[:, 0 : start_pos, : ] = pos_embed_mod[0,:].unsqueeze(0).repeat(1, start_pos)

resid_pre = resid_pre + pos_embed


# reside-pre torch.Size([1, 21, 96])
# pos_embed torch.Size([1, 21, 96])


 #+ model.pos_embed(tokens)

print( "resid_pre", resid_pre.shape)
print( "pos_embed", model.pos_embed(tokens).shape)



# print('resid_pre', t.sum( resid_pre - cache["resid_pre",0][0,:] ) )

normalized = model.blocks[0].ln1(resid_pre) 

# print('normalized', t.sum( normalized - cache["normalized",0][0,:] ) )

q = einops.einsum(
            normalized, model.W_Q[0],
            "batch posn d_model, nheads d_model d_head -> batch posn nheads d_head",
        ) + model.b_Q[0]

# print("q comp", t.sum( q - cache["q",0][0,:] ) )


k = einops.einsum(
    normalized, model.W_K[0],
    "batch posn d_model, nheads d_model d_head -> batch posn nheads d_head",
) + model.b_K[0]

# print("k comp", t.sum( k - cache["k",0][0,:] ) )


v = einops.einsum(
    normalized, model.W_V[0],
    "batch posn d_model, nheads d_model d_head -> batch posn nheads d_head",
) + model.b_V[0]

attn_scores = einops.einsum(
    q, k,
    "batch posn_Q nheads d_head, batch posn_K nheads d_head -> batch nheads posn_Q posn_K",
)
attn_scores_masked = apply_causal_mask(attn_scores / model.cfg.d_head ** 0.5)
attn_pattern = attn_scores_masked.softmax(-1)


z = einops.einsum(
    v, attn_pattern,
    "batch posn_K nheads d_head, batch nheads posn_Q posn_K -> batch posn_Q nheads d_head",
)

# Calculate output (by applying matrix W_O and summing over heads, then adding bias b_O)
attn_out = einops.einsum(
    z, model.W_O[0],
    "batch posn_Q nheads d_head, nheads d_head d_model -> batch posn_Q d_model",
) + model.b_O[0]

resid_post = attn_out + resid_pre

# print("resid_post", t.sum( resid_post - cache["resid_post",0][0,:] ) )

logits = model.unembed(model.ln_final(resid_post))


logits = logits[:, list_len:-1, :]
logprobs = logits.log_softmax(-1) # [batch seq_len vocab_out]
token_probs = logprobs.softmax(-1)


print("tokens shape",tokens.shape)
print("probs shape",probs.shape)

# torch.Size([5, 21])
# torch.Size([5, 10, 52])
show(tokens[0],token_probs[0])


non-repeated torch.Size([21, 96])


RuntimeError: Number of dimensions of repeat dims can not be smaller than number of dimensions of tensor

In [322]:
import plotly.subplots as sp
import plotly.graph_objects as go


def visualize(attention_matrix,labels_tensor) :

    # # Assuming attention_matrix is your tensor
    # attention_matrix = cache["pattern",0]  # Replace this with your actual tensor

    # # Assuming labels_tensor is your tensor of shape [5, 21]
    # labels_tensor = ds # Replace this with your actual tensor


    pi = lambda i,j : l[j] if ( i==(len(l)-1)/2 ) else " "
    pj = lambda i,j : l[i] if ( j==(len(l)-1)/2 ) else " "
    p  = lambda i,j : pi(i,j) + " " + pj(i,j)

    # colors = [ "rgb(165,0,38)","rgb(215,48,39)","rgb(244,109,67)","rgb(253,174,97)","rgb(254,224,144)","rgb(224,243,248)","rgb(171,217,233)","rgb(116,173,209)","rgb(69,117,180)","rgb(49,54,149)"]
    # colorscale = [ [ i/90, colors[i]] for i in range(len(colors)) ]
    # colorscale[-1][0] = 1.0

    # print(colorscale)


    i = 0
    for calc_idx in range( attention_matrix.shape[0] ):
        for head_idx in range( attention_matrix.shape[1]):
            i = i + 1
            l=labels_tensor[calc_idx,:].detach().cpu().numpy().astype(str).tolist()
            text = [ [ p(i,j) for i in range(len(l)) ] for j in range(len(l))   ]

            indices = t.argmax(attention_matrix[calc_idx, head_idx], dim=-1)
            indices = indices.cpu().numpy().tolist()

            for idx,i in enumerate(range(len(indices)-1,-1,-1)):
                text[idx][indices[i]] = "〇"

            fig = go.Figure(data=go.Heatmap(
                z=np.flip( t.log(attention_matrix[calc_idx, head_idx]).detach().cpu().numpy(), 0 ),
                text = text,
                texttemplate="%{text}",
                colorscale='Viridis',
                # colorscale=colorscale,
                type = 'heatmap',
                colorbar=dict(len=0.2, y=(calc_idx + 0.5) / 5),
                textfont={"size":8},
                showscale=True,
            ))
            
            # fig.add_trace(heatmap, row=i , col= 1)

            # Update layout for better readability
            fig.update_layout(
                height=1200,
                width=1200,
                title_text= f'Calculation {calc_idx}, Head {head_idx}',
                showlegend=True,
                # xaxis=dict(scaleanchor="y", scaleratio=1),  # Equal scaling
                # yaxis=dict(scaleanchor="x", scaleratio=1),  # Equal scaling
            )
            fig.show() #UNCOMMENT TO SHOW EACH HEATMAP

#visualize( cache["pattern",0],ds)
visualize( attn_pattern,tokens)