In [11]:
import torch

import matplotlib.pyplot as plt
import seaborn as sns

from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers.models.qwen2.modular_qwen2 import Qwen2Attention, Qwen2DecoderLayer
from transformers.models.qwen2.modeling_qwen2 import Qwen2ForCausalLM

In [2]:
DEVICE = 'mps'

In [3]:
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-3B", torch_dtype='bfloat16', device_map="auto")
tokenizer = AutoTokenizer.from_pretrained('Qwen/Qwen2.5-3B')

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [12]:
model: Qwen2ForCausalLM

In [19]:
for idx, layer in model.model.layers.named_children():
	print(idx, layer.__class__.__name__)

0 Qwen2DecoderLayer
1 Qwen2DecoderLayer
2 Qwen2DecoderLayer
3 Qwen2DecoderLayer
4 Qwen2DecoderLayer
5 Qwen2DecoderLayer
6 Qwen2DecoderLayer
7 Qwen2DecoderLayer
8 Qwen2DecoderLayer
9 Qwen2DecoderLayer
10 Qwen2DecoderLayer
11 Qwen2DecoderLayer
12 Qwen2DecoderLayer
13 Qwen2DecoderLayer
14 Qwen2DecoderLayer
15 Qwen2DecoderLayer
16 Qwen2DecoderLayer
17 Qwen2DecoderLayer
18 Qwen2DecoderLayer
19 Qwen2DecoderLayer
20 Qwen2DecoderLayer
21 Qwen2DecoderLayer
22 Qwen2DecoderLayer
23 Qwen2DecoderLayer
24 Qwen2DecoderLayer
25 Qwen2DecoderLayer
26 Qwen2DecoderLayer
27 Qwen2DecoderLayer
28 Qwen2DecoderLayer
29 Qwen2DecoderLayer
30 Qwen2DecoderLayer
31 Qwen2DecoderLayer
32 Qwen2DecoderLayer
33 Qwen2DecoderLayer
34 Qwen2DecoderLayer
35 Qwen2DecoderLayer


In [4]:
from deep_reorder.deep_reorder import register_buffers

register_buffers(model.model)

In [5]:
activation_dict = {}
from deep_reorder.deep_reorder import save_activation_hook


def get_activation(name):
	def hook(model, input, output):
		if name not in activation_dict.keys():
			b, l, d = output[0].shape
			activation_dict[name] = torch.zeros((b, 1, d), dtype=output[0].dtype, device=output[0].device,
			                                    requires_grad=False)
		else:
			activation_dict[name] = torch.hstack([activation_dict[name], output[0]])

	return hook


for i, layer in enumerate(model.model.layers):
	layer.register_forward_hook(save_activation_hook(f"decoder_layer_{i}"))

In [6]:
prompt = "Write the most optimized Python code for generating the first 100 fibbonacci numbers."

In [7]:
# IF USING CHAT TEMPLATE
#     messages = [{"role": "user", "content": prompt}]
#     text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
#     model_inputs = tokenizer([text], return_tensors="pt").to(DEVICE)
# ELSE
model_inputs = tokenizer(prompt, padding=True, padding_side='left', return_tensors='pt').to(DEVICE)

In [8]:
generated_ids = model.generate(model_inputs.input_ids, attention_mask=model_inputs['attention_mask'], max_new_tokens=1,
                               do_sample=False, pad_token_id=tokenizer.pad_token_id)

In [10]:
torch.sum(
	model.model.layers[0].activation_correlations.to(DEVICE) @ model.model.layers[0].linear_positions.to(
		DEVICE)).backward()

In [11]:
model.model.layers[0].linear_positions.grad

tensor([16.7500, -1.8203,  5.4375,  ...,  5.2812, 10.1250, -1.9609])

In [12]:
model.model.layers[0].activation_correlations

torch.Size([2048, 2048])

In [22]:
generated_ids = [output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)]

In [23]:
response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)

In [24]:
print(response)

[':']


In [25]:
activation_dict['decoder_layer_0'].shape

KeyError: 'decoder_layer_0'

In [None]:
fig, ax = plt.subplots(1, figsize=(10, 10))
corr = activation_dict['decoder_layer_0'].squeeze(0).T.corrcoef().float().cpu()
dist = -0.5 * corr + 0.5

sns.heatmap(dist, ax=ax)

In [21]:
batched_corrcoef = torch.func.vmap(torch.corrcoef)
batched_corrcoef(activation_dict['decoder_layer_0'].T).shape

KeyError: 'decoder_layer_0'

In [None]:
batched_corrcoef(torch.transpose(activation_dict['decoder_layer_0'], 1, 2)).shape

In [None]:
activation_dict['decoder_layer_0'].squeeze(0).T == activation_dict['decoder_layer_0'].view(2048, -1)

In [None]:
activation_dict['decoder_layer_0'].view(2048, -1).shape

In [None]:
@torch.compile
def compute_loss(activations: torch.Tensor, linear_positions: torch.Tensor) -> torch.Tensor:
	batch_size, n, d_model = activations.shape

	pass

In [None]:
@torch.compile
def construct_linear_distance_matrix(size: int) -> torch.Tensor:
	mat = torch.arange(size).repeat(size, 1)
	return torch.abs(mat - mat.T)

In [None]:
sns.heatmap(construct_linear_distance_matrix(1024))

In [None]:
corr

In [None]:
positions = torch.linspace(0, 1, 2048)

In [None]:
@torch.compile
def create_distance_matrix(positions: torch.Tensor) -> torch.Tensor:
	pos_i = positions.view(-1, 1)
	pos_j = positions.view(1, -1)
	return torch.abs(pos_i - pos_j)

In [None]:
tmp_pos = torch.rand(16)
sns.heatmap(create_distance_matrix(tmp_pos))

In [None]:
sns.scatterplot(y=tmp_pos, x=torch.arange(0, 16))

# Plotting

In [None]:
to_merge = []
for key in activation_dict.keys():
	if key == 'runs':
		continue
	to_merge.append(activation_dict[key][0].cpu().float().reshape(-1))

In [None]:
activations_tensor = torch.stack(to_merge).T
activations_tensor = activations_tensor / activation_dict['runs']

In [None]:
# TALK ABOUT THE START OF WORLD WAR 1
fig, ax = plt.subplots(1, figsize=(10, 10))
sns.heatmap(activations_tensor, ax=ax, cmap=sns.color_palette("tab20c", 3, as_cmap=True))

In [None]:
@torch.compile
def compute_similarity(x: torch.Tensor) -> torch.Tensor:
	assert x.ndim == 1, "Tensor must be rank 1."

	x = x.view(-1, 1)
	return torch.cov(x, correction=1)

In [None]:
torch.randn((2048, 2048)).cov()

In [None]:
compute_similarity(torch.randn(2048))