In [1]:
import os

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

model_id = "meta-llama/Meta-Llama-3-8B-Instruct"

# Read the Hugging Face access token from the environment rather than hardcoding it:
# a token written into a notebook leaks to anyone who can read the file or its history.
# With HF_TOKEN unset, `token=None` lets the hub client fall back to any cached login.
# NOTE(review): the token that used to be hardcoded here should be revoked.
hf_token = os.environ.get("HF_TOKEN")

model = AutoModelForCausalLM.from_pretrained(
    model_id,
    token=hf_token,
    # Later cells read `output.attentions`; the eager implementation materialises
    # the per-head attention weights.
    attn_implementation="eager",
)

tokenizer = AutoTokenizer.from_pretrained(model_id, token=hf_token)

KeyboardInterrupt: 

In [None]:
# Number of transformer layers and residual-stream width of the loaded model.
cfg = model.config
layer_count, dim = cfg.num_hidden_layers, cfg.hidden_size

layer_count, dim

In [None]:
from datasets import load_dataset, Dataset, VerificationMode

# Load a single parquet shard of the dataset as the 'train' split, skipping
# checksum/split verification.
SHARD_FILE = 'data/train-00004-of-00005-36531985f2e6c8ce.parquet'
dataset = load_dataset(
    "abokbot/wikipedia-first-paragraph",
    data_files=SHARD_FILE,
    split='train',
    verification_mode=VerificationMode.NO_CHECKS,
)

dataset

In [None]:
# Keep only the 'text' column of the first MAX_PARAGRAPHS rows.
# Guarded so re-running the cell is a no-op: after the first run `dataset` is
# already a plain list, and slicing it with ['text'] would raise TypeError.
MAX_PARAGRAPHS = 100_000
if not isinstance(dataset, list):
    dataset = dataset[:MAX_PARAGRAPHS]['text']

In [None]:
import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np

# Apply the seaborn theme *before* plotting; calling it after the plot (as
# before) only affects figures created later.
sns.set_theme()

# Token count per paragraph, as returned by `tokenizer.encode` with its default
# settings (so any special tokens it adds are counted too).
text_tokens_length = np.array([len(tokenizer.encode(text)) for text in dataset])

fig, ax = plt.subplots()
sns.histplot(text_tokens_length, bins=100, ax=ax)
ax.set(title="Token Length Distribution", xlabel="tokens per paragraph", ylabel="count")

# Indices of the `process_count` longest paragraphs (ascending by length).
process_count = 50
top_indices = np.argsort(text_tokens_length)[-process_count:]

In [None]:
import sys

# Make the local `llm_wizard` package importable. Guarded so re-running the
# cell does not append a duplicate sys.path entry each time.
PROJECT_ROOT = '/workspace'
if PROJECT_ROOT not in sys.path:
    sys.path.append(PROJECT_ROOT)

In [None]:
# While editing llm_wizard, reload it without restarting the kernel with:
#   importlib.reload(sys.modules['llm_wizard.util.prompt'])
from llm_wizard.util.prompt import ChatTemplates, SystemTemplate, UserTemplate, AssistantTemplate, ChatTokenizer

# Summarisation chat: system instruction, a user turn with a `{paragraph}`
# placeholder, and a pre-written assistant prefix.
# NOTE(review): there is no space/newline between "paragraph:" and the inserted text.
summary_turns = [
    SystemTemplate(template="You are a helpful assistant."),
    UserTemplate(template="Summarize the following paragraph:{paragraph}"),
    AssistantTemplate(template="Sure! Here is a summary of the paragraph: "),
]
templates = ChatTemplates(templates=summary_turns)

chat_tokenizer = ChatTokenizer(tokenizer=tokenizer)
piecewise_length = chat_tokenizer.tokenize_piecewise(templates)

In [None]:
# One gradient-free forward pass over the summarisation prompt for paragraph 10,
# returning attention weights alongside the logits.
sample_paragraph = dataset[10]
with torch.no_grad():
    prompt = templates.format(paragraph=sample_paragraph)
    input_ids = chat_tokenizer.tokenize_prompt(prompt, add_generation_prompt=False)
    output = model(
        input_ids=input_ids,
        output_attentions=True,
    )

In [None]:
# Punctuation / whitespace strings whose tokens delimit segments in the
# attention analysis below.
separators = [',', '.', '!', '?', ';', ':', '\n', '"', "'", '(', ')', '[', ']', '{', '}', ' ,', ' .', ' !', ' ?', ' ;', ' :', ' \n', ' "', ' \'', ' (', ' )', ' [', ' ]', ' {', ' }', '.\n', '\n\n', ':\n']

# Map each separator string to its token id. `add_special_tokens=False` avoids
# relying on a BOS token being prepended (the old `encode(s)[1]` did). Strings
# that do not encode to exactly one token are reported instead of being silently
# truncated to their first piece.
# Kept as a list (not a set): np.isin treats a set as a single object element.
separator_tokens = []
for separator in separators:
    token_ids = tokenizer.encode(separator, add_special_tokens=False)
    if len(token_ids) != 1:
        print(f"skipping separator {separator!r}: encodes to {len(token_ids)} tokens {token_ids}")
        continue
    if token_ids[0] not in separator_tokens:
        separator_tokens.append(token_ids[0])

# separator_tokens

In [None]:
# A multi-turn casual conversation used as the prompt for the attention analysis.
# Each turn is a template with no placeholders, so `conversation.format()` takes
# no arguments.
conversation = ChatTemplates(templates=[
    SystemTemplate(template="You are the user's friend"),
    UserTemplate(template="Hey man"),
    AssistantTemplate(template="Hey! How's it going?"),
    UserTemplate(template="It's going good, anything interesting"),
    AssistantTemplate(template="Not much, just hanging out. What about you? Any plans for today?"),
    UserTemplate(template="Not really, just some coding"),
    AssistantTemplate(template="Nice! What are you working on?"),
    UserTemplate(template="Working on some ai side projects of mine, what about u"),
    AssistantTemplate(template="Just finished reading this wild sci-fi novel about AI taking over Mars. Really got me thinking."),
    UserTemplate(template="Oh cool cool."),
    AssistantTemplate(template="Been watching any good shows or movies lately?"),
    UserTemplate(template="Yeah I like the three body problem series"),
    AssistantTemplate(template="I've heard of that! What's it about?"),
    UserTemplate(template="It's about a first contact scenario with an alien civilization"),
])


# Tokenise the conversation and print each separator-terminated segment, to
# sanity-check the segmentation used by the block-attention analysis.
prompt = conversation.format()
input_ids = chat_tokenizer.tokenize_prompt(prompt, add_generation_prompt=False)
# Sequence positions of separator tokens.
# assumes input_ids is a (1, seq_len) batch — TODO confirm tokenize_prompt's return shape
separator_indices = np.isin(input_ids, separator_tokens).nonzero()[1]

# Segment k covers (previous separator, this separator], so it ends with its
# separator. Position 0 (presumably BOS) is skipped. The trailing segment after
# the last separator is printed too; the analysis loop's segment lengths include it.
segment_starts = [1] + [idx + 1 for idx in separator_indices]
segment_ends = [idx + 1 for idx in separator_indices] + [len(input_ids[0])]
for start, end in zip(segment_starts, segment_ends):
    if start < end:
        print(repr(tokenizer.decode(input_ids[0][start:end])))

In [None]:
# Gradient-free forward pass over the conversation, keeping every layer's
# attention weights and hidden states.
with torch.no_grad():
    forward_kwargs = dict(output_attentions=True, output_hidden_states=True)
    output = model(input_ids=input_ids, **forward_kwargs)

In [None]:
# Scratch check: for a mean-centred vector x of length n,
#   rsqrt(mean(x**2)) / sqrt(n) == 1 / ||x||_2
# so despite their names, std1 and std2 are both the inverse L2 norm.
a = torch.tensor([1, 2, 3, 4, 7], dtype=torch.float32)
a -= a.mean()

numel = a.numel()
var = a.pow(2).mean()
std1 = torch.rsqrt(var) / numel ** 0.5
std2 = torch.norm(a).reciprocal()

std1, std2

In [None]:
# Per-layer summary of the hidden states from the conversation forward pass:
# the mean activation and the mean per-token L2 norm. Iterates over the tuple
# actually returned instead of hardcoding 33 entries.
means = []
norms = []
for layer_hidden in output.hidden_states:
    tokens = layer_hidden[0]  # first (only) sequence in the batch
    means.append(tokens.mean().item())
    norms.append(tokens.norm(dim=-1).mean().item())

fig, ax = plt.subplots()
ax.plot(means, label="mean activation")
ax.plot(norms, label="mean token L2 norm")
ax.set(xlabel="hidden_states index", ylabel="value", title="Hidden-state statistics by layer")
ax.legend();

In [None]:
import os

from tqdm import tqdm, trange

# Block-attention maps. For each prompt, attention is averaged inside every pair
# of separator-delimited segments, per layer and head, log-scaled, and saved as
# PNG images under FIGURE_DIR.
# Fixes vs. the previous version of this cell:
#   - the `propmt` typo meant a stale `prompt` from an earlier cell was analysed;
#   - the loop index was overwritten with 0 and the loop `break`-ed after one item;
#   - dataset[0] was written to the .txt even when the conversation was analysed;
#   - layer/head counts were hardcoded to 32 and FIGURE_DIR was never created.
FIGURE_DIR = 'figures'
N_LAYERS = model.config.num_hidden_layers
N_HEADS = model.config.num_attention_heads
# Value stored for segment pairs that are empty after trimming.
EMPTY_BLOCK_LOG_ATTENTION = -13
# Log-attention range used for the per-head figure.
CLIP_LOW, CLIP_HIGH = -13, -4.5
# True: analyse `conversation` once (files saved under index 0, as before).
# False: analyse each paragraph in `top_indices` with the summarisation `templates`.
ANALYZE_CONVERSATION = True


def _segment_log_attention(attentions, separator_indices, n_layers, n_heads):
    """Log of the mean attention between every pair of segments.

    Args:
        attentions: per-layer attention tensors as returned by the model;
            indexed here as attentions[layer][batch][head] -> (query, key) matrix.
        separator_indices: sequence positions of separator tokens.
        n_layers, n_heads: how many layers / heads to read.

    Returns:
        ndarray of shape (n_layers, n_heads, n_segments, n_segments), where
        n_segments = len(separator_indices) + 1.
    """
    n_segments = len(separator_indices) + 1
    result = np.zeros((n_layers, n_heads, n_segments, n_segments))
    for layer in range(n_layers):
        layer_attention = attentions[layer][0]  # first (only) item in the batch
        for head in range(n_heads):
            # Drop position 0 (presumably BOS) on both axes. Splitting the trimmed
            # matrix at the untrimmed separator positions makes every segment end
            # with (and include) its separator token.
            attention = layer_attention[head][1:, 1:].float().cpu().numpy()
            for i, row in enumerate(np.split(attention, separator_indices, axis=0)):
                for j, block in enumerate(np.split(row, separator_indices, axis=1)):
                    # Also drop the first query row / key column of each block.
                    block = block[1:, 1:]
                    if block.size == 0:
                        result[layer, head, i, j] = EMPTY_BLOCK_LOG_ATTENTION
                    else:
                        # Epsilon keeps the log finite for all-zero blocks
                        # (e.g. above the causal diagonal).
                        result[layer, head, i, j] = np.log(block.mean() + 1e-14)
    return result


def _save_block_attention_figures(input_ids, attentions, tag):
    """Compute block-attention maps for one prompt and save them as PNGs.

    Writes `{tag}_mean.png` and `{tag}_std.png` (mean / std over heads, layers
    stacked vertically) and `{tag}.png` (every layer x head), all in FIGURE_DIR.

    Returns:
        The clipped per-head map, shape (N_LAYERS * seq_len, N_HEADS * seq_len).
    """
    separator_indices = np.isin(input_ids, separator_tokens).nonzero()[1]
    total_size = len(input_ids[0])
    # Token length of each segment; used to stretch segment values back onto
    # a token-sized grid.
    separator_lengths = np.diff(separator_indices, prepend=0, append=total_size)

    layered = _segment_log_attention(attentions, separator_indices, N_LAYERS, N_HEADS)
    layered = np.repeat(layered, separator_lengths, axis=2)
    layered = np.repeat(layered, separator_lengths, axis=3)

    # Statistics over heads, clipped to the 5th-95th percentile for contrast.
    mean_map = layered.mean(axis=1)
    std_map = layered.std(axis=1)
    mean_low, mean_high = np.quantile(mean_map, [0.05, 0.95])
    std_low, std_high = np.quantile(std_map, [0.05, 0.95])
    mean_map = np.clip(mean_map, mean_low, mean_high)
    std_map = np.clip(std_map, std_low, std_high)
    print(mean_high, mean_low, std_high, std_low)

    # (layer, i, j) -> (layer * i, j)
    plt.imsave(os.path.join(FIGURE_DIR, f'{tag}_mean.png'),
               arr=mean_map.reshape(N_LAYERS * total_size, total_size), cmap='crest', format='png')
    plt.imsave(os.path.join(FIGURE_DIR, f'{tag}_std.png'),
               arr=std_map.reshape(N_LAYERS * total_size, total_size), cmap='crest', format='png')

    # (layer, head, i, j) -> (layer * i, head * j)
    per_head = (
        np.clip(layered, CLIP_LOW, CLIP_HIGH)
        .transpose(0, 2, 1, 3)
        .reshape(N_LAYERS * total_size, N_HEADS * total_size)
    )
    plt.imsave(os.path.join(FIGURE_DIR, f'{tag}.png'), arr=per_head, cmap='crest', format='png')
    return per_head


os.makedirs(FIGURE_DIR, exist_ok=True)

# Each job: (file-name index, prompt, text to record; None = decode the prompt tokens).
if ANALYZE_CONVERSATION:
    jobs = [(0, conversation.format(), None)]
else:
    jobs = [(index, templates.format(paragraph=dataset[index]), dataset[index]) for index in top_indices]

for paragraph_index, prompt, source_text in tqdm(jobs):
    with torch.no_grad():
        input_ids = chat_tokenizer.tokenize_prompt(prompt, add_generation_prompt=False)
        output = model(input_ids=input_ids, output_attentions=True)

    block_attentions = _save_block_attention_figures(input_ids, output.attentions, paragraph_index)

    # Record the text behind the figures: the raw paragraph, or the decoded
    # prompt tokens for the conversation.
    if source_text is None:
        source_text = tokenizer.decode(input_ids[0])
    with open(os.path.join(FIGURE_DIR, f'{paragraph_index}.txt'), 'w', encoding='utf-8') as f:
        f.write(source_text)

In [None]:
# Scratch check that plt.imsave writes a 5x20 array with a named colormap.
# Seeded so the written image is reproducible across runs.
scratch_rng = np.random.default_rng(0)
plt.imsave("1.png", scratch_rng.random((5, 20)), cmap='coolwarm', format='png')

In [None]:
# Softmax of a small logit vector. Subtracting the max before exponentiating
# gives the same result mathematically but avoids overflow for large logits
# (np.exp overflows float64 above ~709).
a = np.array([7, 9, 10, 10.3, 10.5])

shifted = np.exp(a - a.max())
b = shifted / shifted.sum()

b

In [None]:
# Colour scale for the per-head block-attention figure saved by the figure loop.
# That figure is written with the 'crest' colormap, so the same one is used here;
# the old 'viridis' colorbar did not match it. Values are log-attention clipped
# to [-13, -4.5] (low -> high). Like plt.imsave, imshow scales colours to the
# data's own min/max.
fig, ax = plt.subplots()
image = ax.imshow(block_attentions, cmap='crest')
fig.colorbar(image, ax=ax, label="log attention")
ax.set(title="Per-head block attention (log scale)");

In [None]:
# Display the index of the last prompt processed by the figure loop; it is the
# file-name prefix of the saved figures.
paragraph_index

In [None]:
# Print the separator-delimited segments of one summarisation prompt, using the
# same segmentation as the attention analysis:
#   - each segment ends with (and includes) its separator;
#   - position 0 (presumably BOS) is skipped;
#   - the first and trailing segments are included.
# Previously the first and last segments were dropped, separators were cut off,
# and the builtin `slice` was shadowed.
PARAGRAPH_INDEX = 975
prompt = templates.format(paragraph=dataset[PARAGRAPH_INDEX])
input_ids = chat_tokenizer.tokenize_prompt(prompt, add_generation_prompt=False)
# assumes input_ids is a (1, seq_len) batch — TODO confirm tokenize_prompt's return shape
separator_indices = [int(idx) for idx in np.isin(input_ids, separator_tokens).nonzero()[1]]

segment_starts = [1] + [idx + 1 for idx in separator_indices]
segment_ends = [idx + 1 for idx in separator_indices] + [len(input_ids[0])]
for start, end in zip(segment_starts, segment_ends):
    if start < end:
        segment_ids = input_ids[0][start:end]
        print(repr(tokenizer.decode(segment_ids)))