In [None]:
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import torch.nn as nn
import torch.nn.functional as F
from datasets import load_dataset
from tqdm.notebook import tqdm
from safetensors.torch import safe_open
from peft import PeftModel, LoraConfig
import copy
import matplotlib.pyplot as plt
import seaborn as sns


# Directory passed as the adapter location when the PEFT model is built below.
model_path = "../ai-KD/llama8b-LoRA-IS"

# Run on the GPU when one is visible to torch, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [None]:
# Load the base tokenizer and model.
base_model_path = "../Meta-Llama-3.1-8B-Instruct"

tokenizer = AutoTokenizer.from_pretrained(base_model_path)

# VERY IMPORTANT: keep bf16 on all training tasks to reduce VRAM usage.
# attn_implementation="eager" is required for this notebook: the fused
# attention backends (sdpa / flash-attention) do not materialize the attention
# probabilities, so `output_attentions=True` cannot be honoured with them.
model = AutoModelForCausalLM.from_pretrained(
    base_model_path,
    torch_dtype=torch.bfloat16,
    device_map=device,
    attn_implementation="eager",
)

# Register a dedicated padding token and make generation use it.
# NOTE(review): the model's embedding matrix is not resized here. If "<pad>"
# is a new vocabulary entry, its id may fall outside the embedding table;
# confirm whether `model.resize_token_embeddings(len(tokenizer))` is needed
# (and whether the adapter was trained with resized embeddings).
tokenizer.add_special_tokens({"pad_token":"<pad>"})
model.generation_config.pad_token_id = tokenizer.pad_token_id

In [None]:
# Attach the LoRA adapter stored at `model_path` to the base model.
# `PeftModel.from_pretrained` reads the adapter config and weights from the
# directory; the bare `PeftModel(...)` constructor expects a PeftConfig object
# as its second argument, not a path, and fails when given a string.
model = PeftModel.from_pretrained(model, model_path).to(device)

# Fold the adapter weights into the base weights and drop the PEFT wrapper so
# the rest of the notebook works with a plain transformers model.
model = model.merge_and_unload()

In [None]:
# Ask the model to return attention weights from its forward passes.
# NOTE(review): depending on the transformers version, attention weights are
# only available with the eager attention implementation — confirm the model
# was loaded with a backend that supports this flag.
model.config.output_attentions = True

In [None]:
# Word problem used as the probe prompt for the attention visualisation.
input_text = "Janet's ducks lay 16 eggs per day. She eats three for breakfast every morning and bakes muffins for her friends every day with four. She sells the remainder at the farmers' market daily for $2 per fresh duck egg. How much in dollars does she make every day at the farmers' market?"

# Tokenize into PyTorch tensors, then move the encoded batch to the model's device.
encoded_prompt = tokenizer(input_text, return_tensors="pt")
inputs = encoded_prompt.to(device)

In [None]:
# Inference-only forward pass: disable autograd so no graph is kept alive for
# the full model, and request attention weights explicitly so this cell does
# not depend on an earlier config mutation.
with torch.no_grad():
    outputs = model(**inputs, output_attentions=True)

# Causal-LM outputs expose the per-layer attention weights as `attentions`;
# `encoder_attentions` only exists on encoder-decoder outputs.
attentions = outputs.attentions
if attentions is None:
    raise RuntimeError(
        "The model returned no attention weights; load it with "
        "attn_implementation='eager' so they can be produced."
    )

In [None]:
# `attentions` is indexed by layer; take the final entry, then its first
# element along the leading axis.
# NOTE(review): presumably that leading axis is the batch (single prompt here),
# leaving a (heads, seq, seq) tensor — confirm against the model output.
final_layer = attentions[-1]
last_layer_attention = final_layer[0]

In [None]:
# Token strings for the first (only) sequence in the batch; used as the tick
# labels of the heatmaps.
prompt_ids = inputs['input_ids'][0]
tokens = tokenizer.convert_ids_to_tokens(prompt_ids)

In [None]:
# Size of the leading axis of the selected attention tensor; the plotting loop
# draws one heatmap per index along it.
num_heads = last_layer_attention.size(0)

In [None]:
# Convert once, outside the loop. NumPy has no bfloat16 dtype and cannot read
# GPU memory, so the tensor is upcast to float32 and copied to the CPU first.
attention_by_head = last_layer_attention.detach().float().cpu().numpy()

# One heatmap per attention head: rows and columns are both labelled with the
# prompt tokens.
for head in range(num_heads):
    attention = attention_by_head[head]
    fig, ax = plt.subplots(figsize=(10, 8))
    sns.heatmap(attention, xticklabels=tokens, yticklabels=tokens, cmap='viridis', ax=ax)
    ax.set_title(f'Attention Heatmap - Head {head+1}')
    ax.set_xlabel('Input Tokens')
    ax.set_ylabel('Input Tokens')
    plt.show()
    # Release the figure so looping over many heads does not accumulate
    # open figures in memory.
    plt.close(fig)