In [1]:
import os

# Route HTTP/HTTPS traffic through this proxy. The variables are set before
# the remaining imports below are executed.
os.environ['http_proxy'] = 'http://10.176.58.101:7890'
os.environ['https_proxy'] = 'http://10.176.58.101:7890'

import sys
# Add the project source directory to the import path (presumably what makes
# the `model.attention` and `config` imports at the end of this cell
# resolvable -- verify on a fresh environment).
sys.path.append('/remote-home1/jxwang/project/lorsa/src')

import torch
from transformers import GPTNeoXForCausalLM, GPTNeoXTokenizerFast
from transformer_lens import HookedTransformer
from transformer_lens.HookedTransformerConfig import HookedTransformerConfig
from transformer_lens.components import Attention
from datasets import load_from_disk
from torch.utils.data import DataLoader
import torch.nn as nn
import torch.nn.init as init
import torch.optim as optim
import torch.nn.functional as F
from torch.optim.lr_scheduler import LambdaLR
# Turn autograd off globally: nothing below this line tracks gradients unless
# it is re-enabled explicitly.
torch.set_grad_enabled(False)

import copy

from tqdm import tqdm

import pysvelte
import numpy as np
import einops

from model.attention import LowRankSparseAttention
from config import LoRSATrainConfig, LoRSAConfig

In [42]:
from config import LoRSATrainConfig, LoRSAConfig

# Build the LoRSA configuration from three keyword groups; the resulting
# LoRSAConfig(...) call receives exactly the same keyword arguments, in the
# same order, as a single flat call would.

# meta head config
meta_head_kwargs = dict(
    d_qk_head=16,
    d_ov_head=1,
    n_qk_heads=96,
    n_ov_heads=1536,
    rotary_scale=1,
    rotary_dim=16,
    device="cuda",
    virtual_kv_num=0,
)

# selection mode and its k (semantics are defined by LoRSAConfig)
selection_kwargs = dict(
    mode='top_k',
    top_k=64,
)

# orig head config
orig_head_kwargs = dict(
    model_name='EleutherAI/pythia-160m',
    layer=5,
    max_length=256,
    prepend_bos=True,
)

cfg = LoRSAConfig(**meta_head_kwargs, **selection_kwargs, **orig_head_kwargs)

In [44]:
# Load the Hugging Face weights and tokenizer, then wrap them in a
# TransformerLens HookedTransformer placed on the configured device.
hf_model = GPTNeoXForCausalLM.from_pretrained(cfg.model_name)
tokenizer = GPTNeoXTokenizerFast.from_pretrained(cfg.model_name)
model = HookedTransformer.from_pretrained(
    cfg.model_name,
    hf_model=hf_model,
    tokenizer=tokenizer,
    device=cfg.device,
)

# Inference only: eval mode and every parameter frozen.
model.eval()
for param in model.parameters():
    param.requires_grad_(False)

# get original attention block
orig_attn = model.blocks[cfg.layer].attn

# load dataset
dataset_path = '/remote-home1/share/research/mechinterp/gpt2-dictionary/data/openwebtext'
dataset = load_from_disk(dataset_path)

Loaded pretrained model EleutherAI/pythia-160m into HookedTransformer


Loading dataset from disk:   0%|          | 0/80 [00:00<?, ?it/s]

In [45]:
# Update the LoRSA config from the loaded transformer's config
# (see LoRSAConfig.update_from_model_config for what is copied over).
cfg.update_from_model_config(model.cfg)

# initialize lorsa
lorsa = LowRankSparseAttention(cfg).to(cfg.device)

lorsa_path = '/remote-home1/jxwang/project/lorsa/result/pythia-160m-sah/L5A-d16&1-n96&1536-k64.pth'

# The checkpoint is consumed as a plain state dict, so restrict unpickling to
# tensors/containers with weights_only=True. This avoids executing arbitrary
# pickled code from the file and silences the torch.load FutureWarning that
# the default (weights_only=False) emits.
state_dict = torch.load(lorsa_path, map_location=cfg.device, weights_only=True)

# Last expression of the cell: displays the key-matching report.
lorsa.load_state_dict(state_dict)

  state_dict = torch.load(lorsa_path, map_location=cfg.device)


<All keys matched successfully>

In [46]:
# Prompt text analysed by the cells below: a multi-paragraph news excerpt that
# is cut off mid-sentence (it ends with "Police, who").
sentence = '''An Anglican priest charged last month with sexually assaulting five youths in Edmonton during the 1980s is facing more charges after police say more victims came forward.

In February, Edmonton police charged Gordon William Dominey, 63, with five counts of sexual assault and five counts of gross indecency in connection with incidents that allegedly occurred at the Edmonton Youth Development Centre from 1985 to 1989.

He now faces 18 sexual assault charges and nine gross indecency charges in relation to the alleged assaults, which are reported to have happened at the incarceration facility.

Dominey was employed at the centre at the time.

The new charges come after four more men came forward in the last month to report allegations of sexual assault at the facility during the same time period, Edmonton police say.

Dominey transferred from the Diocese of Edmonton to the Diocese of New Westminster in Vancouver in July of 1990.

He was the priest-in-charge at St. Catherine's Anglican Church in North Vancouver from the summer of 2015 until he was arrested in Coquitlam on Feb. 4, when the Bishop of New Westminster, Melissa Skelton, placed Dominey on administrative leave, according to a release.

Police, who'''

In [115]:
# Index of the LoRSA head inspected in this cell: it selects along the head
# axis of `v` and of `out` below.
head_index = 6

# Hook names for the configured layer. Only `hook_in_name` is used in this
# cell; `hook_out_name` is defined but not referenced here.
hook_in_name = f'blocks.{cfg.layer}.ln1.hook_normalized'
hook_out_name = f'blocks.{cfg.layer}.hook_attn_out'

# Restrict the activation cache to the single hook that is read back below.
names_filter = [hook_in_name] 

tokens = model.to_tokens(sentence, prepend_bos=cfg.prepend_bos).to(cfg.device)

# Run the base model once and keep only the cached input activation.
_, cache = model.run_with_cache(tokens, names_filter=names_filter)
hook_in = cache[hook_in_name]

# LoRSA queries/keys/values and the attention pattern derived from q and k.
q, k, v = lorsa.cal_q_k_v(hook_in)
pattern = lorsa.cal_pattern(q, k)

# Move the head axis in front of key_pos, pick the chosen head, and drop the
# trailing d_head axis (squeeze only removes it when its size is 1).
v_ = einops.rearrange(
    v, "batch key_pos head_index d_head -> batch head_index key_pos d_head"
)[:, head_index].squeeze(dim=-1) # Shape: (batch_size, key_pos)
# NOTE(review): the pattern row is selected with `head_index // cfg.n_qk_heads`
# here, while the cell that prints `pattern` uses `head_index % cfg.n_qk_heads`.
# For head_index = 6 and n_qk_heads = 96 these give 0 and 6 respectively.
# Confirm against LowRankSparseAttention which mapping from value/output heads
# to pattern heads is the intended one.
pattern_ = pattern[:, head_index // cfg.n_qk_heads] # Shape: (batch_size, query_pos, key_pos)
# Per-key contribution: pattern weight times the scalar value at each key.
z_k = pattern_[:, :, :] * v_[:, None, :] # Shape: (batch_size, query_pos, key_pos)
# Sum the contributions over key positions.
z = z_k.sum(dim=-1) # Shape: (batch_size, query_pos)
# Alias of z_k (same tensor object, no copy).
attn_scores = z_k

# Output of the LoRSA module computed from the same input, then the slice for
# the chosen head (axis 2) and its L2 norm over the last axis.
out = lorsa.cal_out_with_h(hook_in)
out_head = out[:, :, head_index, :]
# Ask the CUDA caching allocator to release unused cached memory.
torch.cuda.empty_cache()
out_head_l2 = torch.linalg.vector_norm(out_head, dim=-1)

In [93]:
# Recompute q/k/v and the pattern, then let the module itself produce z so it
# can be compared with the hand-computed value for the same head.
qq, kk, vv = lorsa.cal_q_k_v(hook_in)
pp = lorsa.cal_pattern(qq, kk)
zz = lorsa.cal_z_with_h(vv, pp)
print(zz.shape)

# Keep only the inspected head and the first element of the last axis.
zz = zz[:, :, head_index][..., 0]
print(zz.shape)

torch.Size([1, 256, 1536, 1])
torch.Size([1, 256])


In [66]:
# Five query positions with the largest per-head output norm (batch item 0).
# The topk result is computed once, printed, and unpacked into the two names
# used by later cells.
top_k_result = torch.topk(out_head_l2[0], 5)
top_k_values, top_k_indices = top_k_result
print(top_k_result)
# z at those same query positions.
print(z[0, top_k_indices])

torch.return_types.topk(
values=tensor([1.8675, 1.8079, 1.6792, 1.6186, 1.4801], device='cuda:0'),
indices=tensor([183,  73, 187, 181, 232], device='cuda:0'))
tensor([0.0503, 0.0457, 0.0505, 0.0508, 0.0551], device='cuda:0')


In [39]:
# Context window around the strongest positions, followed by the tokens at the
# 1st and 4th ranked positions.
# `top_k_indices` comes from torch.topk on the 1-D tensor `out_head_l2[0]`, so
# it is 1-D and must be indexed with a single subscript; the previous 2-D
# indexing (`top_k_indices[0, 0]`) only worked with stale kernel state and
# raises IndexError on a fresh run.
print(model.to_string(tokens[0, 180:185]))
print(model.to_string(tokens[0, top_k_indices[0]]))
print(model.to_string(tokens[0, top_k_indices[3]]))

 Westminster in Vancouver in July
 in
 in


In [31]:
# Context window around the 2nd ranked position, followed by that token.
# `top_k_indices` is 1-D (torch.topk over `out_head_l2[0]`), so it takes a
# single subscript; `top_k_indices[0, 1]` raises IndexError on a fresh run.
print(model.to_string(tokens[0, 70:75]))
print(model.to_string(tokens[0, top_k_indices[1]]))

 Edmonton Youth Development Centre from
 Centre


In [54]:
# Value of the inspected head at every key position (batch item 0, first
# element of the last axis).
v_head_per_key = v[0, :, head_index, 0]
print(v_head_per_key)

tensor([ 4.5033e-02, -1.2535e+00,  1.0859e-01,  2.8544e-01,  5.8323e-01,
        -3.6914e-01,  1.1403e+00,  7.7837e-01, -2.0396e+00, -1.9188e+00,
        -1.2604e+00,  6.1392e-02,  1.4914e-01,  9.7172e-01,  3.5199e+00,
         1.9477e+00,  1.8159e-01,  5.0853e-01, -1.2331e-01, -1.9612e-01,
         2.3607e-01, -1.0120e+00, -7.9854e-01, -4.0562e-01,  8.5122e-01,
        -5.3008e-02, -1.0932e+00, -5.6544e-01, -2.9361e-01,  6.8002e-01,
        -2.8451e-02,  2.7002e-01, -6.8879e-01, -3.6795e-01,  4.5346e-01,
         6.9142e-01,  2.8832e+00,  2.5041e+00,  5.9836e-01,  1.3247e-01,
        -7.8375e-02, -7.2326e-01, -1.5743e-01,  6.6646e-01,  7.4840e-01,
         3.0497e-01,  1.0895e+00,  9.9439e-01, -6.5711e-01, -1.1201e+00,
         2.3233e-01, -2.3013e+00, -1.9486e+00, -1.3800e+00, -5.6802e-01,
        -4.8294e-01,  1.1002e-01, -2.1244e+00, -1.1643e+00, -1.2620e+00,
        -2.0436e+00, -1.0400e+00, -1.4681e+00, -1.4411e-01, -1.2542e+00,
        -9.8134e-02, -7.9384e-01, -1.0335e+00,  5.4

In [53]:
# Attention weight that every query position puts on key position 0, for the
# pattern head chosen by `head_index % cfg.n_qk_heads` (batch item 0).
# NOTE(review): the analysis cell that builds `pattern_` selects the pattern
# head with `head_index // cfg.n_qk_heads` instead of `%`; the two disagree
# for most head indices. Confirm the intended mapping against
# LowRankSparseAttention and use it in both places.
print(pattern[0, head_index%cfg.n_qk_heads, :, 0])

tensor([1.0000, 1.0000, 1.0000, 0.9983, 0.9989, 0.9996, 0.9995, 0.9998, 0.9999,
        0.9999, 0.9988, 0.9999, 0.9989, 0.9998, 0.9993, 0.9997, 0.9998, 1.0000,
        0.9994, 0.9992, 0.9979, 1.0000, 0.9987, 0.9974, 0.9988, 0.9997, 0.9999,
        0.9994, 1.0000, 0.9997, 0.9667, 0.9999, 0.9993, 0.9997, 0.9876, 0.9976,
        0.9928, 0.9972, 0.9951, 0.9447, 0.9593, 0.9986, 0.9991, 0.9972, 0.9856,
        0.9954, 0.9901, 0.9873, 0.9937, 0.9989, 0.9983, 0.9895, 0.9824, 0.9930,
        0.9857, 0.9937, 0.9877, 0.9976, 0.9983, 0.9992, 0.9918, 0.9975, 0.9992,
        0.9974, 0.9984, 0.9993, 0.9973, 0.9992, 0.9997, 0.9938, 0.9941, 0.9993,
        0.9997, 0.9992, 0.9997, 0.9992, 0.9961, 0.9671, 0.9956, 0.9998, 0.9989,
        0.9956, 0.9987, 0.9929, 0.9835, 0.9952, 0.9927, 0.9782, 0.9974, 0.9252,
        0.9838, 0.9952, 0.9988, 0.9903, 0.9890, 0.9995, 0.9935, 0.9987, 0.9936,
        0.9995, 0.9933, 0.9974, 0.9939, 0.9978, 0.9989, 0.9962, 0.9996, 0.9998,
        0.9999, 0.9871, 0.9998, 0.9997, 

In [51]:
# Contribution of key position 0 to every query position for the inspected
# head (batch item 0).
# `z` is already summed over key positions and is 2-D (batch, query_pos), so
# `z[0, :, 0]` raises IndexError on a fresh run; the per-key values live in
# `z_k`, which has shape (batch, query_pos, key_pos).
print(z_k[0, :, 0])

tensor([0.0450, 0.0450, 0.0450, 0.0450, 0.0450, 0.0450, 0.0450, 0.0450, 0.0450,
        0.0450, 0.0450, 0.0450, 0.0450, 0.0450, 0.0450, 0.0450, 0.0450, 0.0450,
        0.0450, 0.0450, 0.0449, 0.0450, 0.0450, 0.0449, 0.0450, 0.0450, 0.0450,
        0.0450, 0.0450, 0.0450, 0.0435, 0.0450, 0.0450, 0.0450, 0.0445, 0.0449,
        0.0447, 0.0449, 0.0448, 0.0425, 0.0432, 0.0450, 0.0450, 0.0449, 0.0444,
        0.0448, 0.0446, 0.0445, 0.0447, 0.0450, 0.0450, 0.0446, 0.0442, 0.0447,
        0.0444, 0.0447, 0.0445, 0.0449, 0.0450, 0.0450, 0.0447, 0.0449, 0.0450,
        0.0449, 0.0450, 0.0450, 0.0449, 0.0450, 0.0450, 0.0448, 0.0448, 0.0450,
        0.0450, 0.0450, 0.0450, 0.0450, 0.0449, 0.0436, 0.0448, 0.0450, 0.0450,
        0.0448, 0.0450, 0.0447, 0.0443, 0.0448, 0.0447, 0.0440, 0.0449, 0.0417,
        0.0443, 0.0448, 0.0450, 0.0446, 0.0445, 0.0450, 0.0447, 0.0450, 0.0447,
        0.0450, 0.0447, 0.0449, 0.0448, 0.0449, 0.0450, 0.0449, 0.0450, 0.0450,
        0.0450, 0.0445, 0.0450, 0.0450, 

In [48]:
# Recompute the module output with mode='default', then take the inspected
# head (axis 2) and its L2 norm over the last axis. This overwrites `out`,
# `out_head` and `out_head_l2` from any earlier cell.
out = lorsa.cal_out_with_h(hook_in, mode='default')
out_head = out[:, :, head_index]
# Ask the CUDA caching allocator to release unused cached memory.
torch.cuda.empty_cache()
out_head_l2 = torch.linalg.vector_norm(out_head, ord=2, dim=-1)

In [50]:
# Show the shape of the per-position output norms, then the values themselves
# (one item per line, same text as two separate print calls).
print(out_head_l2.shape, out_head_l2, sep="\n")

torch.Size([1, 256])
tensor([[4.5002e-02, 7.6897e-03, 3.2408e-02, 4.6751e-02, 8.7810e-02, 6.3880e-02,
         1.6500e-01, 9.6477e-02, 1.0406e-02, 1.9297e-01, 2.7617e-01, 1.0722e-01,
         2.9472e-02, 2.9734e-02, 1.7108e-01, 5.8062e-01, 3.4421e-01, 1.7653e-01,
         2.7053e-01, 5.8655e-01, 3.0750e-01, 6.6732e-02, 1.0903e-02, 5.0320e-02,
         1.2463e-01, 9.1403e-02, 1.8163e-02, 7.0092e-02, 9.6777e-02, 3.3715e-02,
         1.7489e-02, 9.4482e-02, 2.0362e-01, 3.5242e-02, 3.6481e-02, 9.9814e-02,
         1.8027e-01, 8.5264e-01, 2.6575e-01, 7.5803e-02, 1.4275e-01, 3.6921e-02,
         1.7107e-01, 3.1248e-01, 2.4772e-01, 1.5634e-01, 3.6048e-01, 2.0048e-01,
         1.3860e-01, 4.0982e-02, 4.1067e-03, 3.0419e-03, 1.2599e-01, 6.5150e-01,
         3.4511e-01, 4.2831e-01, 4.4679e-01, 7.5554e-02, 2.1146e-01, 6.0813e-02,
         7.5260e-01, 5.6774e-01, 3.4643e-01, 5.8420e-01, 3.6671e-01, 2.5031e-01,
         4.6488e-01, 3.0179e-01, 1.5844e-01, 5.9522e-02, 1.2039e-01, 2.1348e-01,
       