In [1]:
import circuitsvis as cv
from transformer_lens import HookedTransformer, HookedTransformerConfig, FactoredMatrix, ActivationCache
import pickle as pkl
import torch

In [2]:
from dataclasses import dataclass
@dataclass
class TrainConfig:
    """Settings of the list-sorting training run stored in the checkpoint.

    Field names, order and defaults mirror the training-time class. Presumably
    this class must exist in the notebook namespace before ``load_model`` runs
    so the pickled ``train_cfg`` object can be restored (verify).
    """

    list_length: int = 5          # number of values in each list to sort
    num_samples: int = 10000      # presumably the training-set size; unused here
    batch_size: int = 100         # presumably the training batch size; unused here
    min_value: int = 0            # low bound passed to torch.randint (inclusive)
    max_value: int = 64           # high bound passed to torch.randint (exclusive)
    bos_token: int = 65           # token id placed at the start of every sequence
    mid_token: int = 66           # separator between unsorted and sorted halves
    test_seed: int = 62           # presumably the RNG seed for evaluation data
    train_seed: int = 42          # presumably the RNG seed for training data
    var_length: bool = False      # not consulted by the code in this notebook

In [3]:
def load_model(pth):
    """Load a pickled checkpoint and rebuild the HookedTransformer it describes.

    Args:
        pth: Path to a pickle file holding a dict with keys ``'cfg'`` (passed to
            ``HookedTransformer``), ``'train_cfg'`` and ``'model'`` (a state
            dict loaded into the rebuilt model).

    Returns:
        ``(model, cfg, train_cfg)``: the model with its weights loaded, the
        model config, and the training config stored alongside it.

    Warning:
        ``pickle.load`` can execute arbitrary code. Only load checkpoints from
        a trusted source.
    """
    # Keep the file open only while reading; building the model does not need it.
    with open(pth, 'rb') as f:
        checkpoint = pkl.load(f)

    cfg = checkpoint['cfg']
    train_cfg = checkpoint['train_cfg']
    model = HookedTransformer(cfg)
    model.load_state_dict(checkpoint['model'])
    return model, cfg, train_cfg

In [6]:
# Restore the trained model together with the configs saved in the checkpoint.
model, cfg, train_cfg = load_model(pth='../model-10l.pkl')

In [7]:
# Inspect the training configuration restored from the checkpoint.
train_cfg

TrainConfig(list_length=10, num_samples=20000, batch_size=100, min_value=0, max_value=64, bos_token=65, mid_token=66, test_seed=62, train_seed=42, var_length=False)

In [8]:
def generate_batches(batch_size=1, cfg=None, generator=None):
    """Build random sorting examples laid out as ``[BOS, xs, MID, sorted(xs)]``.

    Args:
        batch_size: Number of sequences (rows) to generate. Defaults to 1.
        cfg: Config providing ``min_value``, ``max_value``, ``list_length``,
            ``bos_token`` and ``mid_token``. Defaults to the notebook-level
            ``train_cfg``.
        generator: Optional ``torch.Generator`` for reproducible sampling.
            ``None`` uses torch's global RNG.

    Returns:
        int64 tensor of shape ``(batch_size, 2 * list_length + 2)``. List values
        are drawn from ``[min_value, max_value)`` because ``torch.randint``'s
        upper bound is exclusive.
    """
    if cfg is None:
        cfg = train_cfg
    # NOTE: cfg.var_length is not consulted; every list has cfg.list_length items.
    values = torch.randint(
        cfg.min_value, cfg.max_value, (batch_size, cfg.list_length), generator=generator
    )
    bos = torch.full((batch_size, 1), cfg.bos_token, dtype=values.dtype)
    mid = torch.full((batch_size, 1), cfg.mid_token, dtype=values.dtype)
    return torch.cat([bos, values, mid, torch.sort(values, dim=1).values], dim=1)

In [9]:
def id2token(x, cfg=None):
    """Map a token id to a printable label for attention plots.

    Args:
        x: Token id, compared with ``==`` against the config's special tokens.
        cfg: Config providing ``bos_token`` and ``mid_token``. Defaults to the
            notebook-level ``train_cfg``.

    Returns:
        ``"|BOS|"`` for the BOS token, ``"|MID|"`` for the separator token,
        otherwise ``str(x)``.
    """
    if cfg is None:
        cfg = train_cfg
    if x == cfg.bos_token:
        return "|BOS|"
    if x == cfg.mid_token:
        return "|MID|"
    return str(x)
    

In [10]:
def visualize_attn(exp_input, layer=0):
    """Run the notebook's ``model`` on one sequence and render its attention.

    Args:
        exp_input: Token tensor with a leading batch dimension of size 1,
            i.e. shape ``(1, seq_len)``. For this task
            ``seq_len == 2 * list_length + 2`` (BOS, list, MID, sorted list).
        layer: Index of the transformer block whose attention pattern is
            shown. Defaults to 0 (the first block).

    Returns:
        ``(viz, attention_pattern)``: the circuitsvis widget and the cached
        attention pattern of ``layer`` with the batch dimension removed.
    """
    # remove_batch_dim=True drops the batch axis from every cached activation,
    # so this assumes a batch of exactly one sequence.
    _, cache = model.run_with_cache(exp_input, remove_batch_dim=True)

    attention_pattern = cache["pattern", layer, "attn"]
    tokens = [id2token(t) for t in exp_input[0].tolist()]

    viz = cv.attention.attention_patterns(tokens=tokens, attention=attention_pattern)
    return viz, attention_pattern

In [11]:
def accuracy(batch):
    """Fraction of output positions holding the correct sorted value.

    Each row is expected to follow the ``[BOS, x_1..x_L, MID, y_1..y_L]``
    layout, with ``L = seq_len // 2 - 1``. Position ``y_i`` counts as correct
    when it equals the ``i``-th smallest of ``x_1..x_L``. For an odd
    ``seq_len`` the extra trailing column is ignored.

    Args:
        batch: Integer tensor (or array-like accepted by ``torch.as_tensor``)
            of shape ``(n_rows, seq_len)``, or a single 1-D sequence.

    Returns:
        Number of correct positions divided by ``n_rows * L``, as a Python float.

    Raises:
        ValueError: If the batch contains no list elements to score.
    """
    batch = torch.as_tensor(batch)
    if batch.dim() == 1:
        batch = batch.unsqueeze(0)

    list_length = batch.shape[1] // 2 - 1
    unsorted = batch[:, 1:list_length + 1]
    predicted = batch[:, list_length + 2:2 * list_length + 2]

    total = unsorted.numel()
    if total == 0:
        raise ValueError(
            f"accuracy() needs at least one list element; got batch of shape {tuple(batch.shape)}"
        )

    expected = torch.sort(unsorted, dim=1).values
    correct = (predicted == expected).sum().item()
    # Divide Python ints so the result is exact (e.g. 4/10 == 0.4) rather than
    # going through a float32 tensor mean.
    return correct / total

In [12]:
# Seed torch's global RNG so the random example, and therefore its attention
# plot, is identical every time this cell is re-run.
torch.manual_seed(train_cfg.test_seed)
batch = generate_batches()

In [13]:
# Attention pattern for the randomly generated `batch`.
viz, attn = visualize_attn(exp_input=batch)

In [14]:
# Render the attention-pattern widget for the random example.
viz

In [15]:
# Hand-crafted example: a run of consecutive values (20..25) mixed with outliers.
probe_values = [6, 2, 59, 50, 20, 21, 22, 23, 24, 25]
assert len(probe_values) == train_cfg.list_length, "example must match the trained list length"
prompt = torch.tensor([[train_cfg.bos_token, *probe_values, train_cfg.mid_token]])
# Greedy decoding: HookedTransformer.generate samples by default (do_sample=True),
# which would make the "sorted" continuation, and its accuracy, random per run.
# verbose=False suppresses the per-token progress bar.
artificial_input = model.generate(
    prompt,
    max_new_tokens=train_cfg.list_length,
    stop_at_eos=False,
    do_sample=False,
    verbose=False,
).clone().ravel()  # flattened to 1-D: prompt tokens followed by generated tokens

  0%|          | 0/10 [00:00<?, ?it/s]

In [16]:
# Show the prompt plus the model's continuation as a single-row batch.
artificial_input.unsqueeze(0)

tensor([[65,  6,  2, 59, 50, 20, 21, 22, 23, 24, 25, 66,  2,  6, 19, 20, 21, 22,
         23, 24, 50, 59]])

In [17]:
# Attention pattern for the hand-crafted example; visualize_attn expects a batch axis.
v, a = visualize_attn(exp_input=artificial_input.unsqueeze(0))

In [18]:
# Render the attention-pattern widget for the hand-crafted example.
v

In [19]:
# Share of generated positions that match the correctly sorted hand-crafted list.
accuracy(batch=artificial_input.unsqueeze(0))

0.4