In [1]:
import torch
from transformers import AutoTokenizer
from nnsight import NNsight
import plotly.express as px
import einops

from nway_attention.modules.transformer_models import Transformer

In [2]:
# The model

model = Transformer.from_pretrained("Gusanidas/triCsolOsav")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
model.eval()
print(model.cfg)

Config(d_model=768, debug=True, layer_norm_eps=1e-05, d_vocab=50257, init_range=0.02, n_ctx=128, d_head=64, dt_head=64, d_mlp=2048, n_heads=12, nt_heads=2, n_layers=1, dropout=0.05, mlp_type='all', with_ln=True, order_attn=True, attn_eq=False, window_size=16, look_backward=1, pad_value=0, autopad=True)


In [3]:
# I had to reimplement the model spearating each layer in a different module so that I can use it with nnsight.
from nway_attention.modules.interp_models import TriformerCube as TriformerCubeI

triformerC = TriformerCubeI(model.cfg.to_dict())
triformerC.copy_weights_from(model)
triformerC.eval()
triformerC.to(device)

TriformerCube(
  (embed): Embed()
  (pos_embed): PosEmbed()
  (blocks): ModuleList(
    (0): TriformerCubeBlock(
      (attn): TrittentionCube(
        (A): MH_Linear()
        (B): MH_Linear()
        (C): MH_Linear()
        (D): MH_Linear()
        (E): MH_Linear()
        (V): V_Layer()
        (CScore): TrittentionCScore()
        (CMask): CausalMask()
        (CPattern): TrittentionCPattern()
        (Z): Z_Layer()
        (Result): Result_Layer()
        (Out): Out_Layer()
      )
      (dropout1): Dropout(p=0.05, inplace=False)
      (ln1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (linear1): Linear(in_features=768, out_features=2048, bias=True)
      (dropout): Dropout(p=0.05, inplace=False)
      (linear2): Linear(in_features=2048, out_features=768, bias=True)
      (gelu): GELU(approximate='none')
      (ln2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (dropout2): Dropout(p=0.05, inplace=False)
    )
  )
  (ln_final): LayerNorm((768,), eps=

In [4]:
tokenizer = AutoTokenizer.from_pretrained("gpt2")

tokenizer_config.json:   0%|          | 0.00/26.0 [00:00<?, ?B/s]



config.json:   0%|          | 0.00/665 [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/1.04M [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

In [5]:
# A possible indication that a head is copying from previous text if its attending to the previous token that was the same as the current token, and the one immediately after it.
def induction_head_score(attn_pattern, dif_rep,start=0, end=None):
    b, n, p, p2 = attn_pattern.shape
    head_scores = torch.zeros(12)
    if end is None:
        end = p
    for batch in range(b):
        for h in range(n):
            total, count = 0, 0
            for idxc in range(start, end):
                idxa, idxb = idxc-dif_rep[batch], idxc-dif_rep[batch]+1
                idxab = idxa*p+idxb
                total += idxab == torch.argmax(attn_pattern[batch,h,idxc,:])
                count += b
            head_scores[h] = total/count
    return head_scores

In [6]:
model = NNsight(triformerC)

In [7]:
# 4 strings with different repeating periods
period = [10,12,11,9]
s = ["@.(+<)>^@h<:=,[l)>^@h<:=,[l)>^@h<:=,",">lpr>&y%^#4Rt.t^>&y%^#4Rt.t^>&y%^#",">lpr>&y%^#4Rt.t>&y%^#4Rt.t>&y%^#4R","@.(+<)>^@h<:=[l)>^@h<:=[l)>^@h<:=[l)>"]
x = tokenizer(s, return_tensors = 'pt').input_ids

with model.trace(x) as tracer:
    c_score = model.blocks[0].attn.CMask.output.save()
    c_pattern = model.blocks[0].attn.CPattern.output.save()
    out = model.output.save()


In [8]:

hs = induction_head_score(c_pattern, period, start=12, end=22)

px.imshow(hs.unsqueeze(0), labels = {"x": "Heads"}, x = [f"Head {i}" for i in range(12)], width =1500)

In [9]:
for i, score in enumerate(hs):
    print(f"head {i}, induction score = {score}")

head 0, induction score = 0.10000000149011612
head 1, induction score = 0.22499999403953552
head 2, induction score = 0.22499999403953552
head 3, induction score = 0.02500000037252903
head 4, induction score = 0.0
head 5, induction score = 0.0
head 6, induction score = 0.0
head 7, induction score = 0.0
head 8, induction score = 0.0
head 9, induction score = 0.0
head 10, induction score = 0.0
head 11, induction score = 0.0


In [10]:
def get_accuracy(logits, tokens, start=0):
    predictions = logits.argmax(dim=-1)[:,start:-1]

    correct_predictions = (predictions == tokens[:,start+1:]).float()
    accuracy = correct_predictions.mean().item()

    return accuracy

acc = []
with model.trace() as tracer:

    with tracer.invoke(x) as invoker:
        result_c = model.blocks[0].attn.Result.output.clone().save()
        out = model.output.save()
        rr = get_accuracy(out,x).save()
        acc.append(rr)

    for h in range(12):

        with tracer.invoke(x) as invoker:
            model.blocks[0].attn.Result.output[:,:,h,:] = 0
            out = model.output.save()
            rr = get_accuracy(out,x).save()
            acc.append(rr)

for i, a in enumerate(acc):
    if i==0:
        print(f"Base accuracy = {a}")
    else:
        print(f"Ablating head {i-1}, acc = {a}")

Base accuracy = 0.3828125
Ablating head 0, acc = 0.3828125
Ablating head 1, acc = 0.390625
Ablating head 2, acc = 0.2109375
Ablating head 3, acc = 0.3828125
Ablating head 4, acc = 0.3828125
Ablating head 5, acc = 0.390625
Ablating head 6, acc = 0.3828125
Ablating head 7, acc = 0.375
Ablating head 8, acc = 0.390625
Ablating head 9, acc = 0.3828125
Ablating head 10, acc = 0.3984375
Ablating head 11, acc = 0.375


In [11]:
accuracies = []
with model.trace() as tracer:

    with tracer.invoke(x) as invoker:
        result_c = model.blocks[0].attn.Result.output.clone().save()
        out = model.output.save()
        rr = get_accuracy(out,x).save()
        accuracies.append(rr)

    for h in range(11):

        with tracer.invoke(x) as invoker:
            model.blocks[0].attn.Result.output[:,:,h:h+2,:] = 0
            out = model.output.save()
            acc = get_accuracy(out,x).save()
            accuracies.append(acc)

for i, acc in enumerate(accuracies):
    if i==0:
        print(f"Base accuracy = {acc}")
    else:
        print(f"Ablating head {i-1} and head {i}, acc = {acc}")

Base accuracy = 0.3828125
Ablating head 0 and head 1, acc = 0.375
Ablating head 1 and head 2, acc = 0.0625
Ablating head 2 and head 3, acc = 0.2109375
Ablating head 3 and head 4, acc = 0.3828125
Ablating head 4 and head 5, acc = 0.4140625
Ablating head 5 and head 6, acc = 0.40625
Ablating head 6 and head 7, acc = 0.375
Ablating head 7 and head 8, acc = 0.375
Ablating head 8 and head 9, acc = 0.3984375
Ablating head 9 and head 10, acc = 0.3984375
Ablating head 10 and head 11, acc = 0.390625


In [12]:
def logit_attribution(model, s, start=0):
    tokens = tokenizer(s, return_tensors='pt').input_ids
    with model.trace(tokens):
        embed = model.embed.output.save()
        pos_embed = model.pos_embed.output.save()
        result = model.blocks[0].attn.Result.output.save()
        mlp = model.blocks[0].linear2.output.save()

    W_U_correct_tokens = model.unembed.W_U[:, tokens[0,start+1:]]
    embed_attr = einops.einsum(W_U_correct_tokens, embed[0,start:-1,:], "emb seq, seq emb -> seq")
    pos_embed_attr = einops.einsum(W_U_correct_tokens, pos_embed[0,start:-1,:], "emb seq, seq emb -> seq")
    attn_attributions = einops.einsum(W_U_correct_tokens, result[0,start:-1,:,:], "emb seq, seq nhead emb -> seq nhead")
    mlp_attributions = einops.einsum(W_U_correct_tokens, mlp[0,start:-1,:], "emb seq, seq emb -> seq")
    return torch.concat([embed_attr.unsqueeze(-1), pos_embed_attr.unsqueeze(-1), attn_attributions, mlp_attributions.unsqueeze(-1)], dim=-1)

s = "<}]#]',[;-(+={',[;-(+={',[;-(+={',"
la = logit_attribution(model ,s ,start=8)
col_names = ["embed","pos_embed"] + [f"head_{i}" for i in range(12)]+["mlp"]
px.imshow(la.cpu().detach(),
          x =  col_names,
          width = 1000,
          height = 1000)

In [13]:
s = "We all live in a yellow submarine"
s_tokens = tokenizer.tokenize(s)
la = logit_attribution(model, s,start=0)
col_names = ["embed","pos_embed"] + [f"head_{i}" for i in range(12)]+["mlp"]
px.imshow(la.detach().cpu(),
          x =  col_names,
          y = s_tokens[1:],
          width = 1000,
          height = 1000)