In [1]:
import circuitsvis as cv
from transformer_lens import HookedTransformer, HookedTransformerConfig, FactoredMatrix, ActivationCache
import pickle as pkl
import torch

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
from dataclasses import dataclass
@dataclass
class TrainConfig:
    """Data/training settings for the list-sorting task.

    Presumably this is the type of the ``train_cfg`` object stored in the
    pickled checkpoint (the notebook reads the same attribute names from it).
    Fields that are not read anywhere in this notebook are described from their
    names only and should be confirmed against the training code.
    """
    # Number of values in each unsorted list.
    list_length: int = 5
    # Presumably the number of training samples — TODO confirm.
    num_samples: int = 10000
    # Presumably the training batch size — TODO confirm.
    batch_size: int = 100
    # Lower bound given to torch.randint when sampling list values (inclusive).
    min_value: int = 0
    # Upper bound given to torch.randint when sampling list values (exclusive).
    max_value: int = 64
    # Token id placed at the start of every sequence.
    bos_token: int = 65
    # Token id placed between the unsorted list and its sorted copy.
    mid_token: int = 66
    # Presumably the RNG seed for test data — TODO confirm.
    test_seed: int = 62
    # Presumably the RNG seed for training data — TODO confirm.
    train_seed: int = 42
    # Presumably toggles variable-length lists — TODO confirm.
    var_length: bool = False

In [3]:
def load_model(pth):
    """Rebuild a ``HookedTransformer`` from a pickled checkpoint file.

    The checkpoint is expected to be a mapping with the keys ``'cfg'`` (model
    config), ``'train_cfg'`` (training config) and ``'model'`` (state dict).

    NOTE: unpickling can execute arbitrary code, so only load checkpoints from
    a trusted source.

    Args:
        pth: path of the pickle file to read.

    Returns:
        A ``(model, cfg, train_cfg)`` tuple, the model having its weights loaded.
    """
    with open(pth, 'rb') as handle:
        checkpoint = pkl.load(handle)

    model_cfg = checkpoint['cfg']
    data_cfg = checkpoint['train_cfg']

    # Instantiate from the stored config, then restore the saved weights.
    network = HookedTransformer(model_cfg)
    network.load_state_dict(checkpoint['model'])
    return network, model_cfg, data_cfg

In [4]:
# Load the trained model together with its model config and training config
# from a checkpoint located one directory above the notebook.
model, cfg, train_cfg =load_model('../model-5.pkl')

In [5]:
def generate_batches(num_lists=1):
    """Sample random lists and lay them out as model input sequences.

    Every row has the form ``[BOS, x_1 .. x_n, MID, sorted(x_1 .. x_n)]`` where
    ``n = train_cfg.list_length`` and the values are drawn with
    ``torch.randint`` from ``[train_cfg.min_value, train_cfg.max_value)``.

    The function reads the module-level ``train_cfg`` and draws from torch's
    global RNG, so call ``torch.manual_seed`` beforehand for reproducible data.

    Args:
        num_lists: number of sequences (rows) to generate. Defaults to 1, which
            is what the function always produced before this parameter existed.

    Returns:
        An integer tensor of shape ``(num_lists, 2 * list_length + 2)``.

    Raises:
        ValueError: if ``num_lists`` is smaller than 1.
    """
    if num_lists < 1:
        raise ValueError("num_lists must be at least 1")

    sampled = torch.randint(
        train_cfg.min_value,
        train_cfg.max_value,
        (num_lists, train_cfg.list_length),
    ).tolist()

    rows = [
        [train_cfg.bos_token, *entry, train_cfg.mid_token, *sorted(entry)]
        for entry in sampled
    ]
    return torch.tensor(rows)

In [6]:
def id2token(x):
    """Return the display label for a token id.

    The BOS and MID ids (taken from the module-level ``train_cfg``) map to
    ``"|BOS|"`` and ``"|MID|"``; any other id is rendered with ``str``.
    """
    if x == train_cfg.bos_token:
        return "|BOS|"
    if x == train_cfg.mid_token:
        return "|MID|"
    return str(x)
    

In [7]:
def visualize_attn(exp_input, layer=0):
    """Run the model on one sequence and render an attention pattern.

    Uses the module-level ``model``, ``cv`` (circuitsvis) and ``id2token``.

    Args:
        exp_input: token-id tensor with a leading batch dimension, e.g. shape
            ``(1, 12)``. Only row 0 is used for the token labels, and the cache
            is requested with ``remove_batch_dim=True``, so a single sequence
            is expected.
        layer: index of the transformer block whose attention pattern is
            shown. Defaults to 0, the layer that was previously hard-coded.

    Returns:
        ``(viz, attention_pattern)``: the object returned by
        ``cv.attention.attention_patterns`` and the cached attention pattern
        it was built from.
    """
    # Only the activation cache is needed here; the logits are discarded.
    _, cache_model = model.run_with_cache(exp_input, remove_batch_dim=True)

    attention_pattern = cache_model["pattern", layer, "attn"]
    tokens_input = [id2token(token_id) for token_id in exp_input.tolist()[0]]

    viz = cv.attention.attention_patterns(tokens=tokens_input, attention=attention_pattern)
    return viz, attention_pattern

In [8]:
def accuracy(batch):
    """Fraction of predicted positions that match the sorted input list.

    Each row of ``batch`` is read as ``[BOS, unsorted..., MID, predicted...]``;
    the list length is derived from the row width. A predicted token counts as
    correct when it equals the token at the same position of the sorted
    unsorted-part.

    Args:
        batch: 2-D tensor-like object supporting ``.shape``, slicing and
            ``.tolist()``.

    Returns:
        Number of correct positions divided by the total number of list
        elements over all rows.
    """
    n_items = int(batch.shape[1] / 2 - 1)
    inputs = batch[:, 1:n_items + 1].tolist()
    outputs = batch[:, n_items + 2:].tolist()

    hits = 0
    total = 0
    for row_in, row_out in zip(inputs, outputs):
        target = sorted(row_in)
        hits += sum(1 for got, want in zip(row_out, target) if got == want)
        total += len(row_in)

    return hits / total

# Analysis
Number tokens (tokens that are not MID or BOS) attend mainly to the smallest token larger than themselves. The exception is the last token, where the pattern is not as clear. This is probably because the prediction at that position did not contribute to the training loss (for a cleaner pattern one could probably add an EOS token). We can also see that the two tokens that should follow the predicted token have a high attention score.

In [9]:
# Seed torch's global RNG so that Restart & Run All reproduces the same example batch.
# NOTE(review): train_cfg.test_seed is reused here for convenience — confirm it is the intended seed.
torch.manual_seed(train_cfg.test_seed)
batch = generate_batches()

In [10]:
# Render the attention pattern for the random example; `attn` keeps the raw pattern tensor.
viz, attn = visualize_attn(batch)

In [11]:
viz

In [12]:
# Hand-picked unsorted list used as the reference input for the experiments in this notebook.
custom_list = [6, 2, 59, 50, 20]

# Prompt = BOS + unsorted list + MID; the model is asked to continue with list_length tokens.
custom_input = torch.tensor([[train_cfg.bos_token, *custom_list, train_cfg.mid_token]])
custom_input_pred = model.generate(
    custom_input,
    max_new_tokens=train_cfg.list_length,
    stop_at_eos=False,
    # Greedy decoding: generate() samples by default, which makes the analysis irreproducible.
    do_sample=False,
).clone()  # clone() kept from the original; presumably yields a regular (non-inference) tensor — TODO confirm

100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:00<00:00, 488.91it/s]


In [13]:
# Attention visualization (v) and raw pattern (a) for the prompt plus the model's continuation.
v,a = visualize_attn(custom_input_pred)

In [14]:
v

## MID token injections in list

When a MID token is injected into the unsorted list, the model is still able to sort the list correctly. However, since one number token is missing, the largest token is copied to the end of the list.
The model copies the largest token in the list and does not predict any MID tokens.

In [20]:
# Work on a copy so custom_input stays intact, then overwrite the list element
# at index 4 (the value 50) with the MID token.
mid_injections = custom_input.clone()
mid_injections[0, 4] = train_cfg.mid_token  # use the configured id instead of a hard-coded 66

In [21]:
# Let the model continue the MID-injected prompt with list_length tokens.
injection_pred = model.generate(
    mid_injections,
    max_new_tokens=train_cfg.list_length,
    stop_at_eos=False,
    # Greedy decoding: generate() samples by default, which makes the analysis irreproducible.
    do_sample=False,
).clone()  # clone() kept from the original; presumably yields a regular (non-inference) tensor — TODO confirm

100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:00<00:00, 534.10it/s]


In [22]:
# Attention visualization (v) and raw pattern (a) for the MID-injected sequence.
v,a = visualize_attn(injection_pred)

In [23]:
v

## MID token removal

Even with the MID token removed, the model is able to sort the list, suggesting that the MID token itself is not the important part; rather, the model has learned that the sorted list should begin after index 6 and proceeds to output it. The attention pattern at the MID token position is similar with or without the actual MID token.

In [24]:
# Work on a copy so custom_input stays intact, then replace the MID token
# (last position of the prompt, index 6) with an arbitrary number token (27).
mid_removal = custom_input.clone()
mid_removal[0, 6] = 27

In [25]:
mid_removal

tensor([[65,  6,  2, 59, 50, 20, 27]])

In [26]:
# Let the model continue the prompt whose MID token was replaced, for list_length tokens.
removal_pred = model.generate(
    mid_removal,
    max_new_tokens=train_cfg.list_length,
    stop_at_eos=False,
    # Greedy decoding: generate() samples by default, which makes the analysis irreproducible.
    do_sample=False,
).clone()  # clone() kept from the original; presumably yields a regular (non-inference) tensor — TODO confirm

100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:00<00:00, 591.56it/s]


In [27]:
# Attention visualization (v) and raw pattern (a) for the sequence without a MID token.
v,a = visualize_attn(removal_pred)

In [28]:
v

## Testing a shorter list
When provided with a list shorter than those it was trained on, the model copies the smallest numbers up to the token position where the MID token would normally be (index 7). It then sorts the list as if the copied tokens were a part of it. Since the MID token sits in a position where we would normally see a "normal" (number) token, it also copies the largest token to the end of the list (similar to what happened when we injected MID tokens into the list). It is hard to get a clean picture with a list length of only 5.

In [39]:
# A list with only 3 values, i.e. shorter than train_cfg.list_length (5).
shorter_list = [6, 2, 59]

# Same prompt layout as before: BOS + unsorted list + MID (5 tokens in total here).
shorter_input = torch.tensor([[train_cfg.bos_token, *shorter_list, train_cfg.mid_token]])

In [48]:
# Generate until the sequence reaches the full-length layout of 2 * list_length + 2
# tokens (BOS + list + MID + sorted list = 12), i.e. 7 new tokens for this 5-token prompt.
shorter_pred = model.generate(
    shorter_input,
    max_new_tokens=2 * train_cfg.list_length + 2 - shorter_input.shape[1],
    stop_at_eos=False,
    # Greedy decoding: generate() samples by default, which makes the analysis irreproducible.
    do_sample=False,
).clone()  # clone() kept from the original; presumably yields a regular (non-inference) tensor — TODO confirm

100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 484.81it/s]


In [49]:
shorter_pred.shape

torch.Size([1, 12])

In [50]:
# Attention visualization (v) and raw pattern (a) for the shorter-list sequence.
v,a = visualize_attn(shorter_pred)

In [51]:
v

# BOS token injection
Similar results to the MID token injection: the model is still able to sort the list, but the largest token is copied to the end.

Interestingly, the attention pattern of the BOS token is a lot smoother and does not really attend to any token in particular, which is to be expected since BOS is always at the beginning of the sequence during training.

In [80]:
# Work on a copy so custom_input stays intact, then overwrite the list element
# at index 5 (the value 20) with a second BOS token.
bos_injection = custom_input.clone()
bos_injection[0, 5] = train_cfg.bos_token

In [81]:
bos_injection

tensor([[65,  6,  2, 59, 50, 65, 66]])

In [82]:
# Let the model continue the BOS-injected prompt with list_length tokens
# (the configured length replaces the hard-coded 5).
bos_pred = model.generate(
    bos_injection,
    max_new_tokens=train_cfg.list_length,
    stop_at_eos=False,
    # Greedy decoding: generate() samples by default, which makes the analysis irreproducible.
    do_sample=False,
).clone()  # clone() kept from the original; presumably yields a regular (non-inference) tensor — TODO confirm

100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:00<00:00, 605.36it/s]


In [83]:
# Attention visualization (v) and raw pattern (a) for the BOS-injected sequence.
v,a = visualize_attn(bos_pred)

In [84]:
v

## BOS removal
Removing the BOS token has no effect on the model's ability to sort the list, and overall the attention pattern looks similar to when it is present (that is, only a few tokens actually attend to it).

In [85]:
# Work on a copy so custom_input stays intact, then replace the leading BOS
# token (index 0) with an arbitrary number token (12).
bos_removal = custom_input.clone()
bos_removal[0, 0] = 12

# Continue the prompt with list_length tokens (the configured length replaces the hard-coded 5).
bos_removal_pred = model.generate(
    bos_removal,
    max_new_tokens=train_cfg.list_length,
    stop_at_eos=False,
    # Greedy decoding: generate() samples by default, which makes the analysis irreproducible.
    do_sample=False,
).clone()  # clone() kept from the original; presumably yields a regular (non-inference) tensor — TODO confirm

100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:00<00:00, 569.83it/s]


In [86]:
# Attention visualization (v) and raw pattern (a) for the sequence without a leading BOS token.
v,a = visualize_attn(bos_removal_pred)
# Bare last expression: the notebook displays the visualization.
v

## Toy only BOS token

In [90]:
# Prompt consisting of nothing but the BOS token; the model generates the remaining
# 2 * list_length + 1 positions of a full-length sequence on its own.
toy_input = torch.tensor([[train_cfg.bos_token]])
toy_pred = model.generate(
    toy_input,
    max_new_tokens=train_cfg.list_length * 2 + 1,
    stop_at_eos=False,
    # Greedy decoding: generate() samples by default, which makes the analysis irreproducible.
    do_sample=False,
).clone()  # clone() kept from the original; presumably yields a regular (non-inference) tensor — TODO confirm

100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 11/11 [00:00<00:00, 433.19it/s]


Seems to really like the number 38

In [93]:
toy_pred

tensor([[65, 38, 38, 38, 38, 38, 38, 38, 38, 38, 38, 38]], device='cuda:0')

In [94]:
# Attention visualization (v) and raw pattern (a) for the BOS-only generation.
v, a = visualize_attn(toy_pred)

In [95]:
v

## Questions:  
- Does appending EOS tokens in training make the model behave differently under MID token injections?
- What effect do EOS tokens have on sorting shorter lists? Better generalization?