In [1]:
from transformers_fixed import AutoModelForCausalLM, AutoTokenizer
import torch 
from sklearn.decomposition import PCA
from sklearn.preprocessing import normalize
import pandas as pd
import numpy as np
import plotly.express as px
import matplotlib.pyplot as plt
import seaborn as sns
import pickle
from tqdm.notebook import tqdm

from bertviz import head_view
from bertviz.neuron_view import show
from datasets import load_dataset

%load_ext autoreload
%autoreload 2

In [2]:


# model_name = "EleutherAI/llemma_7b" # Support
model_name = 'huggyllama/llama-7b' # Support

tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.pad_token = "[PAD]"
tokenizer.padding_side = "left"
# model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto", torch_dtype=torch.float16)
# model.eval()

mistral_name = 'mistralai/Mistral-7B-v0.1' # Currently out of support :(
mistral_tokenizer = AutoTokenizer.from_pretrained(mistral_name)
mistral_tokenizer.pad_token = "[PAD]"
mistral_tokenizer.padding_side = "left"
mistral_model = AutoModelForCausalLM.from_pretrained(mistral_name, device_map="auto", torch_dtype=torch.float16)
mistral_model.eval()

1


# def ret_queries(sent, model=model):
#     """
#     return (n_layers, n_heads, emb_query)
#     """
#     inputs = tokenizer(sent, return_tensors="pt").to("cuda")
#     # splitted_text = tokenizer.convert_ids_to_tokens(inputs['input_ids'][0])
#     model.forward(**inputs)
#     qk_dict = model.get_qk_dict()
#     queries = torch.stack([qk['query'] for l, qk in qk_dict.items()]).squeeze()

#     return queries

# def ret_keys(sent, model=model):
#     """
#     return (n_layers, n_heads, emb_query)
#     """
#     inputs = tokenizer(sent, return_tensors="pt").to("cuda")
#     # splitted_text = tokenizer.convert_ids_to_tokens(inputs['input_ids'][0])
#     model.forward(**inputs)
#     qk_dict = model.get_qk_dict()
#     queries = torch.stack([qk['key'] for l, qk in qk_dict.items()]).squeeze()

#     return queries


# def find_token_idx(tok: str, sent: str):
#     inp = tokenizer(sent, return_tensors="pt").to("cuda")
#     tok_list = tokenizer.convert_ids_to_tokens(inp['input_ids'][0])
#     # print(tok_list)
#     return tok_list.index(tok)
# inp


# model.device

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

1

In [3]:
def generate(sents, model=mistral_model, tokenizer=mistral_tokenizer):
    inputs = tokenizer(sents, return_tensors="pt", padding=True).to(model.device)

    output_sequences = model.generate(**inputs, max_new_tokens=200, use_cache=True, do_sample=True, top_k=10, eos_token_id=tokenizer.eos_token_id).to('cpu')

    return tokenizer.batch_decode(output_sequences, skip_special_tokens=True)

In [4]:
import json

with open('./squad/train-v2.0.json', 'rt') as f:
    ds = json.load(f)

# Parsing json to DataFrame

res = []
for curr in ds['data']:
    title = curr['title']
    for par in curr['paragraphs']:
        context = par['context']
        for qa in par['qas']:
            question = qa['question']
            id = qa['id']
            ans = qa['answers']
            is_imposs = qa['is_impossible']

            res.append([title, context, question, ans, id, is_imposs])
        
df = pd.DataFrame(res, columns=['title', 'context', 'question', 'answer', 'id', 'is_impossible'])

# Filtering "nice" contexts and questions

# Context has only one '.' symbol 
sub_df = df.loc[df['context'].isin(list(filter(lambda x: x.count('.') == 1, df['context'].unique()))), :]

# Question has only one '?' symbol
sub_df = sub_df.loc[sub_df['question'].apply(lambda x: x.count('?')) == 1, :]

# Handcrafted feature (?): Context should has at least 1 impossible question and at lest 4 possible questions
count_df = sub_df.groupby('context').agg({'is_impossible': ['sum', 'count']})
nice_cont = count_df.loc[(count_df['is_impossible']['sum'] >= 1) & (count_df['is_impossible']['count'] - count_df['is_impossible']['sum'] >= 4) , :].index
print(len(nice_cont))

# Filter such "nice" contexts
squad = df.loc[df['context'].isin(nice_cont), :]
squad

86


Unnamed: 0,title,context,question,answer,id,is_impossible
2375,The_Legend_of_Zelda:_Twilight_Princess,A CD containing 20 musical selections from the...,What company included the soundtrack as a rewa...,"[{'text': 'GameStop', 'answer_start': 71}]",56cda64a62d2951400fa67be,False
2376,The_Legend_of_Zelda:_Twilight_Princess,A CD containing 20 musical selections from the...,How many tracks were recorded on the preorder CD?,"[{'text': '20', 'answer_start': 16}]",56cda64a62d2951400fa67bf,False
2377,The_Legend_of_Zelda:_Twilight_Princess,A CD containing 20 musical selections from the...,In what areas is the content of the GameStop b...,"[{'text': 'Japan, Europe, and Australia', 'ans...",56cda64a62d2951400fa67c0,False
2378,The_Legend_of_Zelda:_Twilight_Princess,A CD containing 20 musical selections from the...,What was included as a Gamestop preorder item?,"[{'text': 'CD', 'answer_start': 2}]",56d13400e7d4791d00901fdd,False
2379,The_Legend_of_Zelda:_Twilight_Princess,A CD containing 20 musical selections from the...,What company included the soundtrack as a rewa...,[],5a8dbd49df8bba001a0f9bb5,True
...,...,...,...,...,...,...
128102,Anthropology,Anthropology of development tends to view deve...,What does a lot of planned development apparen...,"[{'text': 'fail', 'answer_start': 527}]",5733cd1c4776f4190066127e,False
128103,Anthropology,Anthropology of development tends to view deve...,What tends to view development from a positive...,[],5ad2ee84604f3c001a3fd9ef,True
128104,Anthropology,Anthropology of development tends to view deve...,What field of anthropology has a goal to elevi...,[],5ad2ee84604f3c001a3fd9f0,True
128105,Anthropology,Anthropology of development tends to view deve...,What looks for the connections between plans a...,[],5ad2ee84604f3c001a3fd9f1,True


In [5]:
# mnli = load_dataset('glue', 'mnli')
# mnli = mnli['train'].to_pandas()

# group = mnli.groupby('premise').agg({'hypothesis': 'count', 'label': pd.Series.nunique})
# cont1 = group.loc[(group['hypothesis'] > 3) & (group['label'] == 3), :].index

# mnli_filtred = mnli.loc[mnli['premise'].isin(cont1), :]
# print(mnli_filtred.shape, mnli.shape)

# mnli = mnli_filtred

In [6]:
# aa = (mnli['premise'] + ' ' + mnli['hypothesis']).apply(lambda s: tokenizer(s)['input_ids'].__len__())

In [7]:
# plt.hist(aa, bins=50)

In [8]:
# (aa < 15).mean()

In [9]:
# mnli_part = mnli.loc[aa.between(15, 20), :]

In [10]:
# with open('mnli_part_sent.pkl', 'wb') as f:
#     pickle.dump(mnli_part_sent, f)

In [11]:
# mnli_part_sent = mnli_part['premise'] + ' ' + mnli_part['hypothesis']
# mnli_part_sent = mnli_part_sent.to_list()


with open('mnli_part_sent.pkl', 'rb') as f:
    mnli_part_sent = pickle.load(f)

mnli_part_sent[:4], len(mnli_part_sent)

(['How do you know? All this is their information again. This information belongs to them.',
  'Issues in Data Synthesis. Problems in data synthesis.',
  'well you see that on television also You can see that on television, as well.',
  'The other men shuffled. The other men were shuffled around.'],
 33971)

In [12]:
import random

mnli_part_sent_10k = random.sample(mnli_part_sent, 10_000)

In [13]:
len(mnli_part_sent_10k)

10000

In [14]:
# real_att = []
# all_tokens = []
# bs = 32
# for i in tqdm(range(0, len(mnli_part_sent_10k), bs)):
#     curr_sents = mnli_part_sent[i:i+bs]
#     inputs = tokenizer(curr_sents, return_tensors="pt", padding=True).to("cuda")
#     res = model.forward(**inputs, output_attentions=True)

#     all_tokens.append(inputs['input_ids'].cpu())
#     real_att.append(torch.stack(res.attentions).detach().cpu())
#     torch.cuda.empty_cache()


In [15]:
# real_att = list(filter(lambda s: s.shape[-1] == 20, real_att))
# real_att_merge = torch.concat(real_att, dim=1)
# del real_att

In [16]:
# real_att_merge.shape

In [17]:
# real_att_merge = real_att_merge.cpu()

In [18]:
# torch.save(real_att_merge, 'real_att_merge.pt')

In [19]:
real_att_merge = torch.load('real_att_merge.pt')[:, :10_000, ...]

In [20]:
# all_tokens = list(filter(lambda x: x.shape[1] == 20, all_tokens))

In [21]:
# all_tokens_merge = torch.concat(all_tokens, dim=0)
# del all_tokens

In [22]:
# all_tokens_merge.shape

In [23]:
# all_tokens_merge = all_tokens_merge.cpu()

In [24]:
# torch.save(all_tokens_merge, 'all_tokens_merge.pt')

In [25]:
all_tokens_merge = torch.load('all_tokens_merge.pt')[:10_000, ...]

In [26]:
def templ(att, tokens, tokenizer=tokenizer):
    mask = tokens != 0
    tokens = tokens[mask]
    tokens = tokenizer.convert_ids_to_tokens(tokens)
    tokens = list(map(lambda x: x.replace('▁', ''), tokens))

    output = ['<start>']
    for i in range(len(tokens)):
        for j in range(0, i):
            output.append(f"{tokens[i]} -> {tokens[j]}\t{att[i, j]}")
        output.append('')

    output.append('<end>')
    output = '\n'.join(output)

    return output
    

In [27]:
tokenizer('Gods')

{'input_ids': [1, 4177, 29879], 'attention_mask': [1, 1, 1]}

In [28]:
map_att = (real_att_merge * 100).to(torch.int32)

In [29]:
layer_n = 31
head_n = 29
i = 500

templ(map_att[layer_n, i, head_n, ...], all_tokens_merge[i, ...])
tokenizer(templ(map_att[layer_n, i, head_n, ...], all_tokens_merge[i, ...]))['input_ids'].__len__()

715

In [30]:
prompt_exp = """We're studying self-attention heads in transformer large language model. Each head looks for some particular relationship between tokens in a short document. Look at the attention matrix for the part of the document and summarize in a single sentence what the head is looking for. Don't list examples of words.

The attention matrix format is tokenA -> tokenB<tab>attention_score. Attention score values range from 0 to 100. A head finding what it's looking for is represented by a non-zero attention score. The higher the attention score, the stronger the match."""

In [31]:
few_shot = {
    templ(map_att[10, 3, 10, ...], all_tokens_merge[3, ...]): 'related to first subject',
    templ(map_att[15, 130, 15, ...], all_tokens_merge[100, ...]): 'related to capital article',
    templ(map_att[31, 500, 29, ...], all_tokens_merge[500, ...]): 'related to the main characters'
}

In [32]:
def make_full_prompt_exp(curr_attn, curr_tokens, few_shot, all_attn=map_att, all_tokens=all_tokens_merge):
    res = ''
    res += prompt_exp + '\n\n'

    for i, kv in enumerate(few_shot.items(), 1):
        k, v = kv
        res += f'Head {i}\nAttention scores:\n'
        res += k + '\n\n'
        res += f'Explanation of head {i} behavior: the main thing this head does is find relationship {v}.\n\n'

    
    res += f'Head {i+1}\nAttention scores:\n'
    res += templ(curr_attn, curr_tokens) + '\n\n'
    res += f'Explanation of head {i+1} behavior: the main thing this head does is find relationship '

    return res




In [33]:
exp_temp = make_full_prompt_exp(map_att[23, 11, 11, ...], all_tokens_merge[11, ...], few_shot)
# print(exp_temp)

In [34]:
prompt_sim = """We're studying self-attention heads in transformer large language model. Each head looks for some particular relationship between tokens in a short document. Look at an explanation of what the head does, and try to predict how it will split attention between tokens.

The attention matrix format is tokenA -> tokenB<tab>attention_score. Attention score values range from 0 to 100, "unknown" indicates an unknown attention score. A head finding what it's looking for is represented by a non-zero attention score. The higher the attention score, the stronger the match."""

In [35]:
def make_full_prompt_sim(curr_exp, curr_tokens, few_shot=few_shot, all_attn=map_att, all_tokens=all_tokens_merge):
    res = ''
    res += prompt_sim + '\n\n'

    for i, kv in enumerate(few_shot.items(), 1):
        k, v = kv
        res += f'Head {i}\n'
        res += f'Explanation of head {i} behavior: the main thing this head does is find relationship {v}.\n'
        res += 'Attention scores:\n'
        res += k + '\n\n'

    curr_attn = ['unknown'] * len(curr_tokens)
    curr_attn = [curr_attn] * len(curr_tokens)
    curr_attn = np.array(curr_attn)
    res += f'Head {i+1}\n'
    res += f'Explanation of head {i+1} behavior: the main thing this head does is find relationship {curr_exp}.\n'

    res += 'Attention scores:\n'
    res += templ(curr_attn, curr_tokens)

    return res




In [36]:
sim_temp = make_full_prompt_sim('sdfssd sdfsd', all_tokens_merge[10, ...])

In [37]:
res = generate([exp_temp])

Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


In [38]:
print(res[0])

We're studying self-attention heads in transformer large language model. Each head looks for some particular relationship between tokens in a short document. Look at the attention matrix for the part of the document and summarize in a single sentence what the head is looking for. Don't list examples of words.

The attention matrix format is tokenA -> tokenB<tab>attention_score. Attention score values range from 0 to 100. A head finding what it's looking for is represented by a non-zero attention score. The higher the attention score, the stronger the match.

Head 1
Attention scores:
<start>

The ->  	5

other ->  	5
other -> The	5

men ->  	0
men -> The	0
men -> other	0

sh ->  	0
sh -> The	0
sh -> other	0
sh -> men	99

uff ->  	0
uff -> The	0
uff -> other	0
uff -> men	92
uff -> sh	6

led ->  	0
led -> The	0
led -> other	0
led -> men	91
led -> sh	2
led -> uff	1

. ->  	0
. -> The	0
. -> other	0
. -> men	94
. -> sh	1
. -> uff	0
. -> led	2

The ->  	0
The -> The	0
The -> other	0
The -> m

In [40]:
a = generate([sim_temp])

Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


In [43]:
print(a[0])

We're studying self-attention heads in transformer large language model. Each head looks for some particular relationship between tokens in a short document. Look at an explanation of what the head does, and try to predict how it will split attention between tokens.

The attention matrix format is tokenA -> tokenB<tab>attention_score. Attention score values range from 0 to 100, "unknown" indicates an unknown attention score. A head finding what it's looking for is represented by a non-zero attention score. The higher the attention score, the stronger the match.

Head 1
Explanation of head 1 behavior: the main thing this head does is find relationship related to first subject.
Attention scores:
<start>

The ->  	5

other ->  	5
other -> The	5

men ->  	0
men -> The	0
men -> other	0

sh ->  	0
sh -> The	0
sh -> other	0
sh -> men	99

uff ->  	0
uff -> The	0
uff -> other	0
uff -> men	92
uff -> sh	6

led ->  	0
led -> The	0
led -> other	0
led -> men	91
led -> sh	2
led -> uff	1

. ->  	0
. -