## Hacking 1

In [1]:
import torch
from transformers import pipeline
from torch.nn import functional as F
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt
import os

from transformers import LlamaForCausalLM, PreTrainedTokenizerFast, LlamaConfig
from transformers import AutoModelForCausalLM, AutoTokenizer

# Run on the GPU when one is available; otherwise fall back to CPU so the
# notebook still runs (slowly) on machines without CUDA.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [2]:
# Hugging Face Hub identifier passed to from_pretrained for both the model
# and the tokenizer.
model_id: str = "meta-llama/Llama-3.2-1B"

In [3]:
# Load the tokenizer and the causal-LM weights for `model_id` (default dtype),
# then move the model to `device`.
# Author's note: passing torch_dtype=torch.float16 to from_pretrained saves a
# lot of memory and seemed to work, but numerical stability was not confirmed.
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)
model = model.to(device)

In [None]:
# Sanity check: recompute the next-token cross-entropy by hand and print it
# next to the loss the model returns when given labels=input_ids.
text = "The capital of France is Paris"
inputs = tokenizer(text, return_tensors="pt").to(device)
input_ids = inputs["input_ids"]

with torch.no_grad():
    outputs = model(input_ids, labels=input_ids)

# log_softmax is numerically safer than log(softmax(...)): very small
# probabilities no longer underflow to 0 and become -inf.
log_probs = F.log_softmax(outputs.logits, dim=-1)
# Logits at position t predict token t+1, so drop the last prediction and the
# first target. gather() picks each target's log-prob directly instead of
# building a (batch, seq, vocab) one-hot tensor and summing over it.
targets = input_ids[:, 1:].unsqueeze(-1)
correct_next_token_log_probs = log_probs[:, :-1].gather(-1, targets).squeeze(-1)
correct_next_token_probs = correct_next_token_log_probs.exp()
my_loss = -correct_next_token_log_probs.mean()
print(my_loss.item(), outputs.loss.item())

In [None]:
# One forward/backward pass with autograd enabled so trainable parameters get
# a populated .grad. Nothing here zeroes gradients, so re-running this cell
# accumulates into the existing .grad tensors.
ids_on_device = input_ids.to(device)
out = model(ids_on_device, labels=ids_on_device)
out.loss.backward()

In [None]:
# Trainable parameters keyed by their dotted name from named_parameters();
# the plotting cells look entries up as 'model.layers.<i>.<tensor_name>'.
filtered_params = dict(
    named for named in model.named_parameters() if named[1].requires_grad
)

In [None]:
# Adam over every model parameter. Adam keeps two state tensors per parameter,
# so optimizer memory is roughly twice the model's parameter memory.
lr = 1e-4
optimizer = torch.optim.Adam(params=model.parameters(), lr=lr)

In [None]:
# Ten optimisation steps on the single example sentence in `inputs`.
# The loss covers every next-token prediction in the sentence, not only the
# final "Paris" token.
model.train()  # idempotent, so switching mode once before the loop suffices
for step in range(10):
    optimizer.zero_grad()
    outputs = model(**inputs, labels=inputs['input_ids'])
    loss = outputs.loss
    loss.backward()
    print(loss.item())

    # Rescale gradients in place so their global L2 norm is at most 1.0.
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

    optimizer.step()

In [None]:
# Output directory for the saved figures.
save_dir = 'may_12_3'
os.makedirs(save_dir, exist_ok=True)

# Non-overlapping 64x64 pooling windows used to downsample large matrices
# into small heatmap panels.
kernel_size = 64
stride = kernel_size
avg_pool = torch.nn.AvgPool2d(kernel_size, stride)
max_pool = torch.nn.MaxPool2d(kernel_size, stride)

# Per-decoder-layer weight names, as suffixes of 'model.layers.<i>.'.
tensor_names = [
    'self_attn.q_proj.weight',
    'self_attn.k_proj.weight',
    'self_attn.v_proj.weight',
    'self_attn.o_proj.weight',
    'mlp.gate_proj.weight',
    'mlp.up_proj.weight',
    'mlp.down_proj.weight',
]

In [None]:
# Gradient heatmaps: one panel per (weight type, decoder layer). Each gradient
# matrix is shrunk with `max_pool`, which keeps the largest *signed* value per
# window, so large negative gradients are not visible in these panels.
#
# The colour range is computed in this cell so it also works when the
# notebook is run top to bottom. It is taken from the parameter *values* (not
# the gradients) and then multiplied by GRAD_DISPLAY_SCALE, a hand-picked
# factor that presumably brings it down to gradient magnitudes.
global_min = 0
global_max = 0
with torch.no_grad():
    for v in filtered_params.values():
        v_min, v_max = torch.aminmax(v)
        global_min = min(global_min, v_min.item())
        global_max = max(global_max, v_max.item())

GRAD_DISPLAY_SCALE = 1e-4
num_layers = model.config.num_hidden_layers
num_rows = len(tensor_names)

# No fixed figure number: each run gets its own figure instead of stacking
# new axes on top of a previously created figure 0.
fig = plt.figure(figsize=(16, 9), facecolor='k')
for layer_num in range(num_layers):
    for tensor_index, tensor_name in enumerate(tensor_names):
        g = filtered_params[f'model.layers.{layer_num}.{tensor_name}'].grad.detach().cpu()
        # (1, H, W) -> (1, H // kernel_size, W // kernel_size)
        g_pooled = max_pool(g.unsqueeze(0))

        ax = fig.add_subplot(num_rows, num_layers, tensor_index * num_layers + layer_num + 1)
        ax.imshow(
            g_pooled[0],
            vmin=global_min * GRAD_DISPLAY_SCALE,
            vmax=global_max * GRAD_DISPLAY_SCALE,
        )
        ax.axis('off')
fig.savefig(save_dir + '/grads_max_pooled_global_norm_0001.png', dpi=150, facecolor='k')

In [None]:
# Global colour range for the gradient heatmaps.
# NOTE: this scans parameter *values*, not their gradients; the plotting cells
# compensate by scaling the range by 1e-4. Starting both bounds at 0 means the
# range always includes 0.
global_min = 0
global_max = 0
# no_grad: these reductions are for display only, so skip building autograd
# nodes. aminmax gives both extremes in one reduction instead of calling
# max() and min() twice each (every .item() is a device sync).
with torch.no_grad():
    for v in filtered_params.values():
        v_min, v_max = torch.aminmax(v)
        global_min = min(global_min, v_min.item())
        global_max = max(global_max, v_max.item())

In [None]:
# Gradient heatmaps: one panel per (weight type, decoder layer). Each gradient
# matrix is shrunk with `max_pool`, which keeps the largest *signed* value per
# window, so large negative gradients are not visible in these panels.
# Requires global_min / global_max from the range cell. That range comes from
# parameter values, so it is multiplied by GRAD_DISPLAY_SCALE, a hand-picked
# factor that presumably brings it down to gradient magnitudes.
GRAD_DISPLAY_SCALE = 1e-4
num_layers = model.config.num_hidden_layers
num_rows = len(tensor_names)

# No fixed figure number: each run gets its own figure instead of stacking
# new axes on top of a previously created figure 0.
fig = plt.figure(figsize=(16, 9), facecolor='k')
for layer_num in range(num_layers):
    for tensor_index, tensor_name in enumerate(tensor_names):
        g = filtered_params[f'model.layers.{layer_num}.{tensor_name}'].grad.detach().cpu()
        # (1, H, W) -> (1, H // kernel_size, W // kernel_size)
        g_pooled = max_pool(g.unsqueeze(0))

        ax = fig.add_subplot(num_rows, num_layers, tensor_index * num_layers + layer_num + 1)
        ax.imshow(
            g_pooled[0],
            vmin=global_min * GRAD_DISPLAY_SCALE,
            vmax=global_max * GRAD_DISPLAY_SCALE,
        )
        ax.axis('off')
fig.savefig(save_dir + '/grads_max_pooled_global_norm_0001.png', dpi=150, facecolor='k')