In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import random
import json

import torch
import matplotlib.pyplot as plt
import numpy as np
from transformers import AutoTokenizer, AutoModelForCausalLM
from IPython.display import HTML

from llm_ol.experiments.llm.templates import PROMPT_TEMPLATE, RESPONSE_TEMPLATE

In [None]:
# Tokenizer and model weights are both read from the same local checkpoint directory.
model_id = "out/experiments/finetune/v4/train/checkpoint-final/merged"

tokenizer = AutoTokenizer.from_pretrained(model_id)

# "auto" defers both the dtype and the device placement to transformers.
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype="auto",
    device_map="auto",
)

In [None]:
# Load the test dataset: a JSONL file with one JSON object per line.
# The encoding is pinned to UTF-8 so non-ASCII text does not depend on the
# platform default, and blank lines (e.g. a trailing newline at EOF) are
# skipped because json.loads would raise on them.
with open("out/experiments/llm/v2/test_dataset.jsonl", encoding="utf-8") as f:
    examples = [json.loads(line) for line in f if line.strip()]

In [None]:
# Choose which example to visualise. Uncomment the next two lines to sample one
# at random instead of using the pinned index.
# example_idx = random.randint(0, len(examples) - 1)
# print(f"Example index: {example_idx}")
example_idx = 75689
example = examples[example_idx]

# Render the user prompt and the reference response, then format them as a
# two-turn chat using the tokenizer's chat template.
prompt = PROMPT_TEMPLATE.render(title=example["title"], abstract=example["abstract"])
response = RESPONSE_TEMPLATE.render(paths=example["paths"])
messages = [
    {"role": "user", "content": prompt},
    {"role": "assistant", "content": response},
]
text = tokenizer.apply_chat_template(messages, tokenize=False)

# `text` was produced by the chat template, which is expected to have written
# the special tokens (BOS etc.) into the string already. Re-tokenising with the
# default add_special_tokens=True would prepend them a second time, so it is
# disabled here, as the Hugging Face chat-templating docs recommend.
# NOTE(review): if this tokenizer's template does not emit BOS itself, confirm
# that dropping the automatic one is still what you want.
# Offsets are kept so every token can be mapped back to its characters in `text`.
inputs = tokenizer(
    text,
    return_tensors="pt",
    return_offsets_mapping=True,
    add_special_tokens=False,
).to(model.device)

# `inputs` has already been moved to the model's device above.
input_ids = inputs.input_ids

inst_end = [733, 28748, 16289, 28793]  # _[/INST]

def find_index(list_, sublist):
    """Return the first position where `sublist` occurs contiguously in `list_`.

    Every window of `len(sublist)` items is compared against `sublist`, from
    left to right, and the start index of the first equal window is returned.

    Raises:
        ValueError: if no window of `list_` equals `sublist`.
    """
    width = len(sublist)
    matches = (
        start
        for start in range(len(list_) - width + 1)
        if list_[start : start + width] == sublist
    )
    first = next(matches, None)
    if first is None:
        raise ValueError(f"Sublist {sublist} not found in list")
    return first


# Index (into input_ids) of the first token after the "[/INST]" marker, i.e.
# the first token of the assistant response.
resp_start_idx = find_index(input_ids[0].tolist(), inst_end) + len(inst_end)

# Forward pass to compute logits
with torch.no_grad():
    outputs = model(input_ids=input_ids)

# Compute per-token loss.
# Logits and labels are shifted by one: column j of `logits` / `loss` scores
# the prediction of token j + 1, so `loss` has one column fewer than input_ids.
logits = outputs.logits[:, :-1]
labels = input_ids[:, 1:]
loss = torch.nn.functional.cross_entropy(
    logits.view(-1, logits.shape[-1]), labels.view(-1), reduction="none"
)
loss = loss.view(labels.shape)

# Ignore loss for prompt. Because of the shift, the loss of token
# `resp_start_idx` (the first response token) is stored in column
# `resp_start_idx - 1`; the mask must stop before that column, otherwise the
# first response token is silently zeroed together with the prompt.
loss[:, : resp_start_idx - 1] = 0

# Normalize loss values
print(loss.max())
# normalized_loss = (loss - torch.min(loss)) / (torch.max(loss) - torch.min(loss))
# Fixed divisor rather than per-example min-max; the result is used as a
# colour alpha downstream.
normalized_loss = loss / 25
normalized_loss = normalized_loss.cpu()[0].tolist()
# Token 0 has no preceding context and therefore no loss; pad with 0 so that
# normalized_loss[i] lines up with token i.
normalized_loss = [0] + normalized_loss  # Add loss for first token

# Static wrapper around the per-token spans. The charset is declared so that
# the saved file is decoded as UTF-8, matching the encoding used when writing it.
html_pre = """
<!DOCTYPE html>
<html>
<head>
<meta charset="utf-8">
</head>
<body>
<div style="font-family: monospace; width: 1000px; background-color: white; padding: 10px; color: black;">
"""
html_post = """
</div>
</body>
</html>"""

# One <span> per token, shaded red with the token's normalised loss as alpha.
# Spans are collected in a list and joined once instead of growing a string.
spans = []
for i, (color, (start, end)) in enumerate(
    zip(normalized_loss, inputs["offset_mapping"][0].tolist())
):
    # Characters of `text` covered by this token.
    chars = text[start:end]
    if i == resp_start_idx:
        # Visually separate the response from the prompt with a line break.
        chars = "\n" + chars
    # escape ("&" first, so the entities produced below are not re-escaped)
    chars = chars.replace("&", "&amp;").replace("<", "&lt;").replace(">", "&gt;")
    # replace newlines with <br>
    chars = chars.replace("\n", "<br>")
    spans.append(
        f'<span style="background-color: rgba(255, 0, 0, {color});">{chars}</span>'
    )
html_body = "".join(spans)

html = HTML(html_pre + html_body + html_post)
display(html)

# Write as UTF-8 explicitly: titles/abstracts may contain non-ASCII characters,
# and the platform default encoding may be unable to represent them.
with open(f"out/graphs/loss_masked_{example_idx}.html", "w", encoding="utf-8") as f:
    f.write(html.data)

Example indices of interest (candidate values for `example_idx`): 3855, 105707, 51421, 75689, 86575