## Token Losses for Interpretability
* Pass in relevant data samples and compute which tokens show a large loss deviation between the ascent-optimized and descent-optimized models. The highlighted output can be seen in the appendix of the paper.

In [None]:
import torch
import torch.nn.functional as F

def get_token_losses(model, input_ids, device, pad_token_id):
    """
    Return the next-token negative log-likelihood of each token in one sequence.

    The model is moved to `device` and put in eval mode, then run once with
    gradient tracking disabled. The logits at position t are scored against
    the token at position t + 1, so the first token receives no loss.

    Args:
        model: callable returning an object with a `.logits` attribute when
            invoked as `model(input_ids, attention_mask=...)`.
        input_ids: token id tensor, expected shape (1, seq_len).
        device: device the model and inputs are moved to.
        pad_token_id: id whose positions are zeroed in the attention mask.

    Returns:
        1D tensor of length (seq_len - 1) holding the per-token losses.
    """
    model.to(device)
    model.eval()
    ids = input_ids.to(device)

    # 1 where a real token is present, 0 at padding positions.
    mask = (ids != pad_token_id).long().to(device)

    with torch.no_grad():
        logits = model(ids, attention_mask=mask).logits
        # Drop the last prediction and the first label so that each
        # remaining logit row lines up with the token that follows it.
        predictions = logits[:, :-1, :]
        targets = ids[:, 1:]
        target_log_probs = torch.gather(
            F.log_softmax(predictions, dim=-1),
            -1,
            targets.unsqueeze(-1),
        ).squeeze(-1)
        nll = -target_log_probs

    # Remove the leading batch dimension of size 1.
    return nll.squeeze(0)


In [None]:
from transformers import GPT2Tokenizer, GPT2LMHeadModel
import torch
import os

def _load_wiki_model(state_dict, device='cuda'):
    """Build the base wiki model, load `state_dict` into it, and return it on `device` in eval mode."""
    model = GPT2LMHeadModel.from_pretrained('../out/wiki_model')
    model.load_state_dict(state_dict)
    model.to(device)
    model.eval()
    return model


# Checkpoints for the fine-tuned and the unlearned variant of the same base model.
# map_location='cpu' keeps the checkpoint tensors in host memory: load_state_dict
# copies them into the model, so restoring them onto the GPU as well would only
# hold a second copy of the weights there for as long as ckpt_ft / ckpt_unl live.
# NOTE: torch.load unpickles the file; only load checkpoints from a trusted source.
ckpt_ft = torch.load(
    '../out/wiki_models_finetuned/fisher_regularized_models/ancient_rome_finetuned_fisher.pt',
    map_location='cpu',
)
ckpt_unl = torch.load(
    '../out/wiki_models_unlearned/fisher_regularized_models/ancient_rome_unlearned_fisher.pt',
    map_location='cpu',
)

# GPT-2 ships without a padding token; reuse the end-of-sequence id for padding.
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
tokenizer.pad_token_id = tokenizer.eos_token_id

model_ft = _load_wiki_model(ckpt_ft['model'])
model_unl = _load_wiki_model(ckpt_unl['model'])
print('Model loaded')


In [None]:
# Text excerpt that both models are scored on.
sample_text = """
The following outline is provided as an overview of and topical guide to ancient Rome: Ancient Rome â€“ former civilization that thrived on the Italian Peninsula as early as the 8th century BC. Located along the Mediterranean Sea and
centered on the city of Rome, it expanded to become one of the largest empires in the ancient world"""
sample_ids = tokenizer.encode(sample_text, return_tensors='pt')

# Read from the fine-tuned checkpoint; not used further in this cell.
prompt_length = ckpt_ft['prompt_length']

# Per-token losses of the same ids under the fine-tuned model, then the unlearned one.
losses_ft, losses_unl = (
    get_token_losses(scored_model, sample_ids, 'cuda', tokenizer.pad_token_id)
    for scored_model in (model_ft, model_unl)
)


In [None]:
import torch
import numpy as np
from transformers import GPT2Tokenizer

# Load tokenizer
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
tokenizer.pad_token_id = tokenizer.eos_token_id

# 30 highlight macros from strong to light blue: \hlA .. \hlZ, then \hlAA .. \hlAD
hl_macros = [
    f'\\hl{chr(c)}' for c in range(ord('A'), ord('Z') + 1)
] + ['\\hlAA', '\\hlAB', '\\hlAC', '\\hlAD']  # Total = 30

# Number of tokens to highlight. Capped at the number of macros so that every
# highlighted rank is guaranteed to have a macro to map to.
num_highlighted = min(25, len(hl_macros))

# Characters with special meaning in LaTeX. They are replaced in a single
# translate() pass so the backslashes introduced by one escape are never
# re-escaped by another.
_LATEX_ESCAPES = str.maketrans({
    '\\': '\\textbackslash{}',
    '{': '\\{',
    '}': '\\}',
    '%': '\\%',
    '&': '\\&',
    '#': '\\#',
    '$': '\\$',
    '_': '\\_',
    '^': '\\textasciicircum{}',
    '~': '\\textasciitilde{}',
})

# Compute per-token loss difference
token_diff = np.abs(losses_ft.cpu().numpy() - losses_unl.cpu().numpy())
input_ids = sample_ids.squeeze(0)
# Get the top 25 most changed tokens (largest absolute difference first)
top_token_indices = token_diff.argsort()[::-1][:num_highlighted]

# Create mapping from token index to LaTeX macro. Loss entry i scores the
# prediction of token i + 1 (the first token has no loss), hence the offset.
highlight_rank = {int(i) + 1: hl_macros[rank] for rank, i in enumerate(top_token_indices)}

# Highlighted token sequence
highlighted_tokens = []
for i, tid in enumerate(input_ids.tolist()):
    # Skip padding before doing any decoding work for it.
    if tid == tokenizer.pad_token_id:
        continue
    token_str = tokenizer.decode([tid]).translate(_LATEX_ESCAPES)
    macro = highlight_rank.get(i)
    if macro is not None:
        highlighted_tokens.append(f"{macro}{{{token_str}}}")
    else:
        highlighted_tokens.append(token_str)

# Output LaTeX-friendly string
highlighted_sentence = "".join(highlighted_tokens)

print(highlighted_sentence)
