In [1]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel
from datasets import load_dataset
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
from peft import get_peft_model, LoraConfig, PeftModel
from random import sample
import pandas as pd
import re
import os

# Run on the GPU when one is visible to torch, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Directory holding the PEFT adapter that is applied on top of the base model.
adapter_path = "../../reverseKL-KD/llama8b-LoRA-IS"

# Root folder for the generated heatmaps; per-layer subfolders are created under it.
save_dir = "heatmap-reverse/"

# exist_ok=True makes the cell idempotent and avoids the check-then-create race.
os.makedirs(save_dir, exist_ok=True)

In [2]:
# The tokenizer is read from the same local checkpoint as the base weights.
student_tokenizer = AutoTokenizer.from_pretrained("../../Meta-Llama-3.1-8B")

# Load the base model in bfloat16 on `device`, attach the adapter found at
# `adapter_path`, then fold the adapter into the base weights so the result
# is a plain causal-LM model again.
student_trained_model = PeftModel.from_pretrained(
    AutoModelForCausalLM.from_pretrained(
        "../../Meta-Llama-3.1-8B",
        device_map=device,
        torch_dtype=torch.bfloat16,
    ),
    adapter_path,  # testing trained
).merge_and_unload()

# Switch to inference mode; as the cell's last expression this also displays
# the module tree.
student_trained_model.eval()

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(128256, 4096)
    (layers): ModuleList(
      (0-31): 32 x LlamaDecoderLayer(
        (self_attn): LlamaSdpaAttention(
          (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (k_proj): Linear(in_features=4096, out_features=1024, bias=False)
          (v_proj): Linear(in_features=4096, out_features=1024, bias=False)
          (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (rotary_emb): LlamaRotaryEmbedding()
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)
          (up_proj): Linear(in_features=4096, out_features=14336, bias=False)
          (down_proj): Linear(in_features=14336, out_features=4096, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm((4096,), eps=1e-05)
        (post_attention_layernorm): LlamaRMSNorm((4096,), eps=1e-05)
      )
    )
    (n

In [3]:
# Ask the model to return per-layer attention weights from its forward pass.
student_trained_model.config.output_attentions = True

# Reuse the EOS token for padding instead of registering a brand-new "<pad>"
# token. A newly added token gets an id past the end of the model's embedding
# table unless the embeddings are resized as well, so any padded batch would
# index out of range.
if student_tokenizer.pad_token is None:
    student_tokenizer.pad_token = student_tokenizer.eos_token

# Keep generation in sync with the tokenizer's padding id.
student_trained_model.generation_config.pad_token_id = student_tokenizer.pad_token_id

In [4]:
def visualize_llama_attention_all_heads(model, tokenizer, prompt, layer_num, save_dir):
    """
    Save one attention heatmap per head for a single layer of a LLaMA model.

    The prompt is tokenized, run through the model once with attention outputs
    enabled, and every head of the requested layer is written to
    ``<save_dir>/layer_<layer_num>_head_<head_num>.png``.

    Args:
        model: The LLaMA language model.
        tokenizer: The tokenizer associated with the model.
        prompt: Input prompt string.
        layer_num: The layer number to visualize (0-indexed; negative values
            index from the last layer).
        save_dir: Directory where the images will be saved. It is created if
            it does not exist.

    Returns:
        None. The heatmaps are written to disk as a side effect.

    Raises:
        RuntimeError: If the forward pass returns no attention weights.
        IndexError: If ``layer_num`` is outside the range of returned layers.
    """
    os.makedirs(save_dir, exist_ok=True)

    # Put the inputs on whatever device the model lives on rather than relying
    # on a notebook-level global.
    model_device = next(model.parameters()).device
    inputs = tokenizer(prompt, return_tensors='pt').to(model_device)
    input_ids = inputs['input_ids']

    # Request attentions for this call only, without mutating the caller's
    # model configuration.
    with torch.no_grad():
        outputs = model(**inputs, output_attentions=True)

    attentions = outputs.attentions  # one tensor per layer
    if attentions is None:
        raise RuntimeError(
            "The model did not return attention weights; the active attention "
            "implementation may not support output_attentions."
        )

    num_layers = len(attentions)
    if not -num_layers <= layer_num < num_layers:
        raise IndexError(
            f"layer_num {layer_num} is out of range for a model that returned "
            f"{num_layers} attention layers"
        )

    # Batch element 0 -> (num_heads, seq_len, seq_len). NumPy has no bfloat16,
    # so cast to float32 once for all heads before converting.
    layer_attn = attentions[layer_num][0].to(dtype=torch.float32).cpu().numpy()
    tokens = tokenizer.convert_ids_to_tokens(input_ids[0])

    for head_num, head_attn in enumerate(layer_attn):
        fig, ax = plt.subplots(figsize=(10, 8))
        sns.heatmap(
            head_attn,
            xticklabels=tokens,
            yticklabels=tokens,
            cmap='viridis',
            ax=ax
        )
        ax.set_title(f'Attention Weights - Layer {layer_num}, Head {head_num}')
        ax.set_xlabel('Key Tokens')
        ax.set_ylabel('Query Tokens')
        ax.tick_params(axis='x', labelrotation=90)
        ax.tick_params(axis='y', labelrotation=0)
        fig.tight_layout()

        filepath = os.path.join(save_dir, f'layer_{layer_num}_head_{head_num}.png')
        fig.savefig(filepath)
        plt.close(fig)  # release the figure so memory does not grow per head


In [5]:
prompt = "Janet’s ducks lay 16 eggs per day. She eats three for breakfast every morning and bakes muffins for her friends every day with four. She sells the remainder at the farmers' market daily for $2 per fresh duck egg. How much in dollars does she make every day at the farmers' market?"

# Take the layer count from the model config so the final decoder layer is
# included; a hard-coded range(31) stops one short on a 32-layer model.
num_layers = student_trained_model.config.num_hidden_layers

for layer in range(num_layers):
    # Each layer gets its own subfolder containing one heatmap per head.
    visualize_llama_attention_all_heads(
        student_trained_model,
        student_tokenizer,
        prompt,
        layer,
        os.path.join(save_dir, f"layer{layer}"),
    )

We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.43. Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/v4.41.3/en/internal/generation_utils#transformers.Cache)
