In [None]:
import sys
sys.path.append('../..')

from src.index_files import *

import seaborn as sb
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from wikipediaapi import Wikipedia

# Inputs for the attention study:
# - the QuALITY dev split (QualityDataset presumably comes from the src.index_files star import);
# - the English Wikipedia "Python (programming language)" page, flattened into one string.
dataset = QualityDataset(split='dev')

wiki_wiki = Wikipedia('MyProjectName (merlin@example.com)', 'en')
page_py = wiki_wiki.page('Python_(programming_language)')
# Concatenate the full_text() of each entry in page_py.sections, separated by blank lines.
python_page = '\n\n'.join(section.full_text() for section in page_py.sections)

In [None]:
# Load Llama-2-7B quantised to 4 bit (bitsandbytes, bf16 compute) on the first GPU, in eval mode.
# attn_implementation="eager" is requested explicitly because the rest of the notebook calls the
# model with output_attentions=True. Fused kernels (SDPA / FlashAttention) do not materialise the
# attention probabilities; depending on the transformers version they warn and fall back, or fail.
quantization_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.bfloat16)

model_name = "meta-llama/Llama-2-7b-hf"
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    quantization_config=quantization_config,
    device_map="cuda:0",
    attn_implementation="eager",
).eval()
tokenizer = AutoTokenizer.from_pretrained(model_name)

In [None]:
# Take the third dev example and extract its article text plus its questions/answers.
quality_example = dataset.data[2]
article = dataset.get_article(quality_example)
questions, answers = dataset.get_questions_and_answers(quality_example)
questions

In [None]:
# Keep only the first 3102 tokens of the article, decode them back to text,
# and print the character length of that text followed by the text itself.
article_token_ids = tokenizer.encode(article, return_tensors='pt')
chunked_article = article_token_ids[:, :3102]
result = tokenizer.decode(chunked_article[0].tolist())
print(len(result), result)

In [None]:
# Show the same 10040-character article prefix that the prompt-template cell inserts as the story.
article_preview = article[:10040]
print(article_preview)

In [None]:
# Zero-shot QA prompt: the story is placed between two copies of the "answer concisely"
# instruction, followed by the question and an open "Answer:" slot.
template = '''You are given a story and a question. Answer the question as concisely as you can, using a single phrase if possible. Do not provide any explanation.

Story: {context}

Now, answer the question based on the story as concisely as you can, using a single phrase if possible. Do not provide any explanation.

Question: {input}

Answer:'''

question = "Why does the Skipper stop abruptly after he says \"when you're running a blockade\"?"
text = template.format(input=question, context=article[:10040])

In [None]:
# A single forward pass (no generation) over the first 13129 characters of the Python Wikipedia
# page, asking the model for its attention maps. The chunked_article defined here is also what the
# later decode cells read. NOTE: the QA prompt `text` is not used as input in this cell.
with torch.no_grad():
    python_page_prefix = python_page[:13129]
    chunked_article = tokenizer.encode(python_page_prefix, return_tensors='pt').to(model.device)
    output = model(chunked_article, output_attentions=True)

In [None]:
# For every layer: drop the leading batch dimension, copy to CPU, and average over dim 0.
# NOTE(review): assumes each entry of output.attentions is (batch=1, heads, seq, seq), so dim 0
# after the squeeze is the head axis — confirm against the transformers version in use.
attn = []
avg_attn = []
for layer_attn in output.attentions:
    per_layer = layer_attn.squeeze(0).cpu()
    attn.append(per_layer)
    avg_attn.append(per_layer.mean(0))
avg_attn[0].shape

In [None]:
# Shape of the head-averaged maps stacked along a new leading layer axis.
torch.stack(avg_attn).shape

In [None]:
# Shape after additionally averaging over the layer axis.
torch.stack(avg_attn).mean(0).shape

In [None]:
# Attention averaged over heads and layers. Row/column 0 is dropped, presumably so the first
# position (often an attention sink) does not dominate the colour scale.
# Rows are query positions and columns are key positions (same convention as avg_attn[layer][tid, :]).
fig, ax = plt.subplots()
sb.heatmap(torch.stack(avg_attn).mean(0)[1:, 1:], ax=ax)
ax.set(title='Attention averaged over heads and layers (position 0 excluded)',
       xlabel='key position', ylabel='query position');

In [None]:
import os

import cv2

# Render one head-averaged attention heatmap per layer for the last (up to) 1000 positions.
# Save each as figs/<layer>.png, then stitch the frames into video.avi at 1 frame per second.
fig_dir = 'figs'
os.makedirs(fig_dir, exist_ok=True)  # savefig fails if the directory is missing

seq_len = attn[0].shape[-1]
length = min(1000, seq_len)      # clamp so short inputs never give a negative start index
start_id = seq_len - length
window = slice(start_id, start_id + length)

for layer in tqdm(range(len(attn))):
    # Use a fresh figure per layer. Drawing through pyplot's current axes would pile every heatmap
    # onto the same axes (growing memory, slower each iteration), and the figures were never closed.
    fig, ax = plt.subplots()
    sb.heatmap(avg_attn[layer][window, window], cbar=False, ax=ax)
    ax.set_title(f'layer {layer}')
    fig.savefig(os.path.join(fig_dir, f'{layer}.png'))
    plt.close(fig)

frame_paths = [os.path.join(fig_dir, f'{l}.png') for l in range(len(attn))]
frame = cv2.imread(frame_paths[0])
if frame is None:  # cv2.imread signals failure by returning None, not by raising
    raise FileNotFoundError(frame_paths[0])
height, width = frame.shape[:2]

# fourcc=0 and fps=1, as before; frames all share the size of the first one.
video = cv2.VideoWriter('video.avi', 0, 1, (width, height))
try:
    for path in frame_paths:
        video.write(cv2.imread(path))
finally:
    # No cv2 window is ever opened here, so cv2.destroyAllWindows() is not called;
    # it raises on headless OpenCV builds.
    video.release()

In [None]:
# Zoom into a 50x50 window of the last layer's head-averaged attention, starting 1000
# positions before the end of the sequence. start_id and layer are reused by later cells.
layer = -1
length = 50
start_id = attn[0].shape[-1] - 1000
window = slice(start_id, start_id + length)
sb.heatmap(avg_attn[layer][window, window])

In [None]:
# For every layer and every query position tid in [3000, 3116), collect the key positions whose
# head-averaged attention is above 10x the mean of that row (the row is sliced to keys 0..tid).
#   token_focus[tid] -> one array of selected key positions per layer, in layer order
#   used_tokens[tid] -> union of those key positions over all layers
focus_start, focus_end = 3000, 3116
# Clamp to the sequence length: indexing past the last row would raise IndexError.
focus_end = min(focus_end, avg_attn[0].shape[-1])

token_focus = defaultdict(list)
used_tokens = defaultdict(set)
for layer in tqdm(range(len(attn))):
    layer_attn = avg_attn[layer]
    for tid in range(focus_start, focus_end):
        token_attn = layer_attn[tid, :tid + 1]
        threshold = token_attn.mean() * 10
        # Indices where the row exceeds the threshold, as an int64 numpy array.
        temp_used_tokens = torch.nonzero(token_attn > threshold).flatten().numpy()
        token_focus[tid].append(temp_used_tokens)
        used_tokens[tid].update(temp_used_tokens.tolist())

In [None]:
# Position 3173 is outside the range scanned by the token-focus cell (3000-3115). Indexing the
# defaultdict with [] would silently insert an empty set; .get leaves used_tokens untouched.
used_tokens.get(3173, set())

In [None]:
# Heatmap (layers x important key positions) of how strongly the token at check_token_id attends,
# head-averaged, to each key position collected for it in used_tokens.
# Columns for key positions 0 and 2 are left at zero, presumably because these special/sink
# positions would swamp the colour scale.
check_token_id = 3113
important_ids = sorted(used_tokens.get(check_token_id, ()))  # .get: don't insert into the defaultdict
if not important_ids:
    # An empty selection would otherwise fail later on a zero-width figure size.
    raise ValueError(f'no important tokens recorded for position {check_token_id}; '
                     'is it inside the range scanned by the token-focus cell?')

key_index = torch.tensor(important_ids)
used_tokens_mat = np.zeros((len(attn), len(important_ids)))
for layer in range(len(attn)):
    used_tokens_mat[layer] = avg_attn[layer][check_token_id, key_index].float().numpy()
skip_cols = [col for col, tok in enumerate(important_ids) if tok in (0, 2)]
used_tokens_mat[:, skip_cols] = 0

print(tokenizer.decode(chunked_article[0, check_token_id:check_token_id+20]))
# Derive the label from the data so the title cannot go stale when check_token_id changes.
focus_token = tokenizer.decode(chunked_article[0, check_token_id]).strip()

fig, ax = plt.subplots(figsize=(used_tokens_mat.shape[1] / 2, used_tokens_mat.shape[0] / 2))
sb.heatmap(used_tokens_mat, xticklabels=important_ids, ax=ax)
ax.set_xlabel('token id')
ax.set_ylabel('layer id')
ax.set_title(f'Attention distribution for "{focus_token}" at pos {check_token_id} over important tokens')
fig.savefig('attn.png')

In [None]:
def _group_positions(positions, max_gap=10):
    """Group ascending positions into inclusive (start, end) runs.

    Consecutive members of a run differ by at most ``max_gap``.

    Args:
        positions: ascending sequence of ints.
        max_gap: largest allowed difference between neighbours inside one run.

    Returns:
        A list of (start, end) tuples; an empty list for empty input.
    """
    spans = []
    for pos in positions:
        if spans and pos - spans[-1][1] <= max_gap:
            spans[-1] = (spans[-1][0], pos)
        else:
            spans.append((pos, pos))
    return spans


# Print the text of each run of positions that check_token_id attends to.
# Fixes to the old loop: the last id is no longer merged into the previous span regardless of
# the gap, a single id is printed, and runs always start at their own first id.
sorted_used_tokens = sorted(used_tokens.get(check_token_id, ()))
for start_sent_id, end_sent_id in _group_positions(sorted_used_tokens):
    span_text = tokenizer.decode(chunked_article[0, start_sent_id:end_sent_id + 1])
    if start_sent_id == end_sent_id:
        print(start_sent_id, ':', span_text)
    else:
        print(start_sent_id, '---', end_sent_id, ':', span_text)

In [None]:
# Query positions that currently have entries in token_focus. This is the scanned range plus any
# key that a plain token_focus[...] lookup has inserted, since token_focus is a defaultdict.
token_focus.keys()

In [None]:
# Position 3221 is outside the range scanned by the token-focus cell. .get avoids silently
# inserting an empty list into the token_focus defaultdict.
token_focus.get(3221, [])

In [None]:
# How much attention later positions pay to one token of the zoom window (12 positions past
# start_id). layer is pinned to the last layer, which is the value the earlier per-layer loops
# leave behind, so this cell no longer depends on leftover loop state.
layer = len(avg_attn) - 1
token_id = 12 + start_id
print(token_id)
avg_attn[layer][token_id:, token_id][:20]

In [None]:
# Attention that token_id pays to itself and every earlier position (its row over keys 0..token_id).
avg_attn[layer][token_id][: token_id + 1]

In [None]:
# tokenizer.decode(chunked_article[0, token_id])
# Decode a fixed slice of the model input (token positions 1876 through 3199).
tokenizer.decode(chunked_article[0][1876:3200])

In [None]:
# Decode from 8 positions past start_id up to 100 positions after token_id.
# article_w_answer is never defined in this notebook; the tokenised model input is chunked_article.
last_token_id = 8 + start_id
tokenizer.decode(chunked_article[0, last_token_id:token_id + 100].tolist())

In [None]:
# The token 10 positions after token_id (uses chunked_article; article_w_answer was never defined).
tokenizer.decode(chunked_article[0, token_id + 10])

In [None]:
# The last 30 attention weights of the query 10 positions after token_id (its row ends at itself).
query_id = token_id + 10
avg_attn[layer][query_id, : query_id + 1][-30:]

In [None]:
# The key 25 positions before that query, i.e. one of the positions in the 30-weight tail above.
# Uses chunked_article; article_w_answer was never defined.
tokenizer.decode(chunked_article[0, token_id + 10 - 25])

In [None]:
# Echo the position currently under inspection.
token_id

In [None]:
# Decode token positions 808-1199 of the model input (chunked_article; article_w_answer was never defined).
tokenizer.decode(chunked_article[0, 808:1200].tolist())