In [None]:
import os

import matplotlib.pyplot as plt
import seaborn as sns
import torch
from transformers import AutoTokenizer

# --- Configuration -----------------------------------------------------------
# The checkpoint location can be overridden with the MODEL_PATH environment
# variable instead of editing this cell; the default is the original local path.
MODEL_PATH = os.environ.get("MODEL_PATH", "/home/zhanghaoyu/models/Llama-3.1-8B-Instruct/")
DEVICE = torch.device("cpu")
DTYPE = torch.float16
# Newly created floating-point tensors also default to DTYPE.
torch.set_default_dtype(DTYPE)
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)

# "quest" loads Quest's LlamaForCausalLM; any other value uses stock Hugging Face.
method = "hf"
token_budget = 1024  # forwarded to quest_init (only used when method == "quest")
topp = None          # forwarded to quest_init (only used when method == "quest")

if method == "quest":
  from quest import LlamaForCausalLM
  model = LlamaForCausalLM.from_pretrained(MODEL_PATH, device_map=DEVICE, torch_dtype=DTYPE, output_attentions=True)

  # Init Quest Controller
  # NOTE(review): max_seq_len presumably bounds prompt + generated tokens; keep it
  # consistent with the max_new_tokens budget passed to generate() — confirm in Quest.
  model.quest_init(page_size=16, max_seq_len=8192, token_budget=token_budget, topp=topp)
else:
  from transformers import LlamaForCausalLM
  # Fused attention kernels (SDPA / FlashAttention) never materialize the attention
  # probability matrix, so request the eager implementation explicitly. Without it,
  # output_attentions either triggers a fallback warning or yields no weights,
  # depending on the installed transformers version.
  model = LlamaForCausalLM.from_pretrained(
    MODEL_PATH,
    device_map=DEVICE,
    torch_dtype=DTYPE,
    attn_implementation="eager",
    output_attentions=True,
  )
  
def plot_attention_heatmap(attentions, token_idx, layer_idx, head_idx, tokens):
  """Plot one head's attention weights for a single generation step.

  Args:
    attentions: ``outputs.attentions`` from ``generate(..., output_attentions=True,
      return_dict_in_generate=True)``, indexed as ``attentions[step][layer][batch, head]``
      to give a (query_len, key_len) matrix. At step 0 (prefill) the matrix is square
      over the prompt; with KV caching, later steps attend from the newest token only,
      so query_len == 1 and key_len grows by one per step.
    token_idx: Generation step to plot.
    layer_idx: Decoder layer index.
    head_idx: Attention head index.
    tokens: Token strings for the full sequence (prompt + generated), e.g. from
      ``tokenizer.convert_ids_to_tokens``.

  Raises:
    ValueError: If ``tokens`` has fewer entries than the key length of the chosen step.
  """
  # Only batch element 0 is plotted; upcast to float32 before handing off to numpy/seaborn.
  attention = attentions[token_idx][layer_idx][0, head_idx].detach().float().cpu().numpy()
  query_len, key_len = attention.shape
  if len(tokens) < key_len:
    raise ValueError(
      f"tokens has {len(tokens)} entries but step {token_idx} attends over {key_len} positions"
    )
  # Columns are every attended position; rows are the positions issuing queries, which
  # are the last `query_len` of those (all of them at prefill, a single one when decoding).
  key_tokens = tokens[:key_len]
  query_tokens = tokens[key_len - query_len:key_len]

  _, ax = plt.subplots(figsize=(10, 8))
  sns.heatmap(attention, xticklabels=key_tokens, yticklabels=query_tokens, cmap="viridis", ax=ax)
  ax.set_title(f"Attention Heatmap (Layer {layer_idx}, Head {head_idx}) of Token {token_idx}")
  ax.set_xlabel("Key token")
  ax.set_ylabel("Query token")
  plt.show()

In [None]:
prompt = "In an animal kingdom, the lion is the king. One day, the lion announces a competition to choose the most hardworking animal. The turtle, rabbit, monkey, zebra, and giraffe all decide to participate. After a day of observation, the lion notices that all the animals are working hard, except for the rabbit, who is sleeping. So why does the lion choose the rabbit as the most hardworking animal?"
# NOTE(review): the raw prompt is fed to an Instruct checkpoint without
# tokenizer.apply_chat_template; fine for attention inspection, but outputs will
# differ from chat-formatted usage.
inputs = tokenizer(prompt, return_tensors="pt").to(DEVICE)
prompt_len = inputs.input_ids.shape[1]
print(f"Input Sequence Length: {prompt_len}")

# Cap the *total* length (prompt + generated) at 8192, matching the max_seq_len given
# to quest_init in the setup cell, rather than allowing 8192 new tokens on top of the prompt.
MAX_TOTAL_LEN = 8192
# Seed so runs are reproducible in case the checkpoint's generation config samples.
SEED = 0
torch.manual_seed(SEED)

# Every decoding step keeps one attention tensor per layer, so memory grows with the
# number of generated tokens.
outputs = model.generate(
  **inputs,
  max_new_tokens=MAX_TOTAL_LEN - prompt_len,
  output_attentions=True,
  return_dict_in_generate=True,
  # presumably the tokenizer defines no pad token; passing one avoids generate()'s warning
  pad_token_id=tokenizer.eos_token_id,
)

generated_ids = outputs.sequences  # (batch_size, prompt_len + num_generated)
# Tuple with one entry per generation step; each entry is a tuple with one tensor per
# layer, shaped (batch_size, num_heads, query_len, key_len). Step 0 is the prefill
# (query_len == key_len == prompt_len); with KV caching later steps have query_len == 1.
attentions = outputs.attentions
generated_tokens = tokenizer.convert_ids_to_tokens(generated_ids[0])
