In [1]:
import circuitsvis as cv
from transformer_lens import HookedTransformer, HookedTransformerConfig, FactoredMatrix, ActivationCache
import pickle as pkl
import torch

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
from dataclasses import dataclass
@dataclass
class TrainConfig:
    list_length: int = 5
    num_samples: int = 10000
    batch_size: int = 100
    min_value: int = 0
    max_value: int = 64
    bos_token: int = 65
    mid_token: int = 66
    test_seed: int = 62
    train_seed: int = 42
    var_length: bool = False

In [3]:
def load_model(pth):
    with open(pth, 'rb') as f:
        obj = pkl.load(f)
        cfg = obj['cfg']
        train_cfg = obj['train_cfg']
        model = HookedTransformer(cfg)

    model.load_state_dict(obj['model'])
    return model, cfg, train_cfg

In [4]:
model, cfg, train_cfg =load_model('../model-10l-test.pkl')

In [5]:
def generate_batches():
    train_cfg
    random_list = torch.randint(train_cfg.min_value, train_cfg.max_value, (1, train_cfg.list_length)).tolist()
    lists = []
    for entry in random_list:
        entry = [train_cfg.bos_token] + entry + [train_cfg.mid_token] + sorted(entry)
        lists.append(entry)
    batches = torch.tensor(lists)
    return batches
    batches = torch.split(batches, 1)
    return batches

In [6]:
def id2token(x):
    if x==train_cfg.bos_token:
        return "|BOS|"
    elif x== train_cfg.mid_token:
        return "|MID|"
    else:
        return str(x)
    

In [7]:
def visualize_attn(exp_input):
    # exp_input -> shape (1, 12)
    logits, cache_model = model.run_with_cache(exp_input, remove_batch_dim=True) 
    #preds = logits[:, LIST_LENGTH+1 : -1].argmax(-1)

    # Get attention pattern and plot it
    attention_pattern = cache_model["pattern", 0, "attn"]
    tokens_input = list(map(id2token, exp_input.tolist()[0]))
    
    viz = cv.attention.attention_patterns(tokens=tokens_input, attention=attention_pattern)
    return viz, attention_pattern

In [8]:
def accuracy(batch):
    list_length = int(batch.shape[1] /2 - 1)
    unsorted = batch[:, 1:list_length+1].tolist()
    predicted = batch[:, list_length+2:].tolist()
    
    tot_correct = 0
    tot_el = 0
    for (un, pred) in zip(unsorted, predicted):
        tot_correct += sum([x == y for x,y in zip(pred, sorted(un))])
        tot_el += len(un)
        
    return tot_correct/tot_el

# Analysis

In [9]:
batch = generate_batches()

In [10]:
viz, attn = visualize_attn(batch)

In [11]:
viz

In [12]:
custom_list  = [6, 2, 59, 50, 20, 21, 22, 23, 24, 25]

custom_input = torch.tensor([[train_cfg.bos_token, *custom_list, train_cfg.mid_token]])
custom_input_pred  = model.generate(custom_input,
                                  max_new_tokens=train_cfg.list_length,
                                  stop_at_eos=False
                                  ).clone()

100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 10/10 [00:00<00:00, 635.68it/s]


In [13]:
v,a = visualize_attn(custom_input_pred)

In [14]:
v

# MID injection
Injection of the MID token into the input leads to the model inserting new tokens and completly forgetting others. This complete breakdown of the model not something that was seen in the 5 list length model. This behaviour was observed in multiple models that where trained for different lengths suggesting list length leads the model to be more sensitive.

In [15]:
mid_insertion  = [46, 57, 10, 20, 15, 14, 36, 35, 37, 25]

mid_insertion[9] = train_cfg.mid_token

mid_input = torch.tensor([[train_cfg.bos_token, *mid_insertion, train_cfg.mid_token]])
mid_pred  = model.generate(mid_input,
                                  max_new_tokens=train_cfg.list_length,
                                  stop_at_eos=False
                                  ).clone()

100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 10/10 [00:00<00:00, 606.57it/s]


In [16]:
mid_pred

tensor([[65, 46, 57, 10, 20, 15, 14, 36, 35, 37, 66, 66, 10, 14, 15, 20, 23, 27,
         35, 36, 37, 57]], device='cuda:0')

In [17]:
v, a = visualize_attn(mid_pred)
v

In [24]:
sorted(mid_insertion)

[10, 14, 15, 20, 35, 36, 37, 46, 57, 66]

# MID removal
Removing the mid token leads to the model skipping, inserting and copying tokens, which ofcourse means it is unable to perform the task. Since we saw a similiar breakdown when injecting the MID token into the input, it is clear that this model relies more on the actual MID token and not the positional embedding of the MID token in training.

In [21]:
mid_removal  = [46, 57, 10, 20, 15, 14, 36, 35, 37, 25]

mid_rem_in = torch.tensor([[train_cfg.bos_token, *mid_removal, train_cfg.mid_token]])

mid_rem_in[0, -1] = 33
print(mid_rem_in)

rem_pred  = model.generate(mid_rem_in,
                                  max_new_tokens=train_cfg.list_length,
                                  stop_at_eos=False
                                  ).clone()

tensor([[65, 46, 57, 10, 20, 15, 14, 36, 35, 37, 25, 33]])


100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 10/10 [00:00<00:00, 900.16it/s]


In [22]:
v, a = visualize_attn(rem_pred)
v

In [20]:
sorted(mid_removal)

[10, 14, 15, 20, 25, 35, 36, 37, 46, 57]