# Logit Comparator for HuggingFace and TransformerLens Outputs
This notebook is a quick and dirty tool for comparing the logit outputs of a HuggingFace model and the corresponding TransformerLens model using several metrics. It is intended to help debug the TransformerLens version of a model, for example to catch bugs in its implementation. If you find a discrepancy, please open an issue on the [GitHub repository](https://github.com/TransformerLensOrg/TransformerLens).

In [None]:
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformer_lens import HookedTransformer
import torch
import torch.nn.functional as F

if torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cuda" if torch.cuda.is_available() else "cpu"

torch.set_grad_enabled(False)

## Comparator Setup

In [None]:
model_name = "EleutherAI/pythia-2.8b"  # Any model supported by both HF transformers and TransformerLens
sentence = "The quick brown fox"

In [None]:
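from huggingface_hub import login
# Only needed for gated models: paste your HuggingFace access token below,
# or skip this cell for public models such as Pythia.
login(token="")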

## Get Transformers Logits

In [None]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

def load_model(model_name="gpt2"):
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(model_name)
    return model, tokenizer

def get_logits(model, tokenizer, sentence, device):
    # Tokenize the input sentence
    inputs = tokenizer(sentence, return_tensors="pt")
    
    # Move inputs to the device
    inputs = {k: v.to(device) for k, v in inputs.items()}
    
    # Run a single forward pass (no text generation) to get the logits
    with torch.no_grad():
        outputs = model(**inputs)
    
    # Logits for every position: [batch, seq_len, d_vocab]
    logits = outputs.logits
    
    return logits

hf_model, tokenizer = load_model(model_name)
hf_model = hf_model.to(device)

# Keep only the logits at the final position: [batch, d_vocab]
hf_logits = get_logits(hf_model, tokenizer, sentence, device)[:, -1, :]

## Get TransformerLens Logits

In [None]:
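tl_model = HookedTransformer.from_pretrained_no_processing(model_name, device=device)
# Reuse the HF tokenizer's IDs so both models see identical input (e.g. the same BOS handling)
tokens = tokenizer(sentence, return_tensors="pt")["input_ids"].to(device)
tl_logits = tl_model(tokens)[:, -1, :]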

## Compare Logit Distributions
Several metrics are used below to compare the two models' final-position logits. There are not yet standard thresholds for what counts as a "good" match, so establishing benchmarks is still a work in progress.
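If you want every metric in one place, for example to loop over several prompts or models, they can be wrapped in a single helper. The sketch below is not part of TransformerLens: `compare_logits` is a hypothetical name, and it assumes both inputs are `[batch, d_vocab]` tensors like `hf_logits` and `tl_logits`. It moves the tensors to the CPU in float64 (MPS does not support float64) so the metrics themselves are not limited by the models' precision.

```python
import torch
import torch.nn.functional as F


def compare_logits(logits_a: torch.Tensor, logits_b: torch.Tensor,
                   rtol: float = 1e-5, atol: float = 1e-3) -> dict:
    """Hypothetical helper: summarise how closely two [batch, d_vocab] logit tensors agree."""
    if logits_a.shape != logits_b.shape:
        raise ValueError(f"Shape mismatch: {tuple(logits_a.shape)} vs {tuple(logits_b.shape)}")
    # Upcast on the CPU so the comparison is not itself limited by float32/float16 rounding.
    a = logits_a.detach().cpu().double()
    b = logits_b.detach().cpu().double()
    log_pa = F.log_softmax(a, dim=-1)
    log_pb = F.log_softmax(b, dim=-1)
    return {
        "allclose": torch.allclose(a, b, rtol=rtol, atol=atol),
        "mse": F.mse_loss(a, b).item(),
        "max_abs_diff": (a - b).abs().max().item(),
        "cosine_sim": F.cosine_similarity(a, b, dim=-1).mean().item(),
        # F.kl_div(input, target) returns KL(target || input), so this is KL(A || B).
        "kl_a_b": F.kl_div(log_pb, log_pa, reduction="batchmean", log_target=True).item(),
    }
```

In this notebook it would be called as `compare_logits(tl_logits, hf_logits)`, which returns a plain dictionary that is easy to print or collect into a table.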

### Shape

In [None]:
print(f"HF Logits Shape: {hf_logits.shape}")
print(f"TL Logits Shape: {tl_logits.shape}")

### Tensor Comparison

In [None]:
are_close = torch.allclose(tl_logits, hf_logits, rtol=1e-5, atol=1e-3)
print(f"Are the logits close? {are_close}")
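`torch.allclose(input, other, rtol, atol)` passes only if every element satisfies `|input - other| <= atol + rtol * |other|`, so a single out-of-tolerance logit is enough to make it return `False`. When that happens, it helps to see which vocabulary entries fail. A minimal sketch that reproduces the same elementwise test and lists the worst offenders; `allclose_violations` is a hypothetical helper, and NaNs are not reported.

```python
import torch


def allclose_violations(actual: torch.Tensor, expected: torch.Tensor,
                        rtol: float = 1e-5, atol: float = 1e-3, k: int = 5):
    """Return up to k (flat_index, actual, expected) tuples that fail torch.allclose's test."""
    diff = (actual - expected).abs()
    tol = atol + rtol * expected.abs()  # same tolerance formula as torch.allclose
    excess = (diff - tol).flatten()     # > 0 means the element is out of tolerance
    # Sort by how far each element exceeds its tolerance, worst first.
    order = torch.argsort(excess, descending=True)
    bad = [i for i in order[:k].tolist() if excess[i] > 0]
    flat_a, flat_e = actual.flatten(), expected.flatten()
    return [(i, flat_a[i].item(), flat_e[i].item()) for i in bad]
```

Since `tl_logits` and `hf_logits` have shape `[1, d_vocab]`, the flat indices returned by `allclose_violations(tl_logits, hf_logits)` are token IDs and can be passed to `tokenizer.decode`.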

### Mean Squared Error

In [None]:
# Compare the logits with MSE
mse = F.mse_loss(hf_logits, tl_logits)
print(f"MSE: {mse}")

### Maximum Absolute Difference

In [None]:
max_diff = torch.max(torch.abs(tl_logits - hf_logits))
print(f"Max Diff: {max_diff}")
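The maximum absolute difference is easier to interpret if you also know where it occurs. A small sketch, assuming 2D `[batch, d_vocab]` logits; `max_diff_location` is a hypothetical helper name.

```python
import torch


def max_diff_location(logits_a: torch.Tensor, logits_b: torch.Tensor):
    """Return (batch_idx, vocab_idx, abs_diff) for the largest elementwise difference."""
    diff = (logits_a - logits_b).abs()
    flat_idx = int(diff.argmax())  # argmax over the flattened tensor
    batch_idx, vocab_idx = divmod(flat_idx, diff.shape[-1])
    return batch_idx, vocab_idx, diff[batch_idx, vocab_idx].item()
```

For example, `_, vocab_idx, _ = max_diff_location(tl_logits, hf_logits)` followed by `tokenizer.decode([vocab_idx])` shows which token the two models disagree on most.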

### Cosine Similarity

In [None]:
cosine_sim = F.cosine_similarity(tl_logits, hf_logits, dim=-1).mean()
print(f"Cosine Sim: {cosine_sim}")
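One caveat for the raw-logit metrics (allclose, MSE, max difference, cosine similarity): softmax is unchanged when the same constant is added to every logit at a position, so two implementations can produce identical probabilities while these metrics report a large gap. As we understand it, TransformerLens's `center_unembed` weight processing (applied by `from_pretrained` but not by `from_pretrained_no_processing`) shifts logits in exactly this way. The sketch below demonstrates the effect on synthetic logits and shows one possible mitigation, subtracting each row's mean before comparing; the `center` helper is our own illustration, not a library function.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(1, 8)
shifted = logits + 3.0  # add the same constant to every logit

# The two softmax distributions are identical (up to float error)...
same_probs = torch.allclose(F.softmax(logits, dim=-1), F.softmax(shifted, dim=-1), atol=1e-6)

# ...but raw-logit metrics report a large discrepancy.
cos_raw = F.cosine_similarity(logits, shifted, dim=-1).mean().item()
max_diff_raw = (logits - shifted).abs().max().item()


def center(x: torch.Tensor) -> torch.Tensor:
    """Subtract each row's mean logit, removing any per-position constant offset."""
    return x - x.mean(dim=-1, keepdim=True)


# After centering, the shifted logits match the originals again.
cos_centered = F.cosine_similarity(center(logits), center(shifted), dim=-1).mean().item()
```

KL divergence is computed from the softmax outputs, so it is unaffected by such a shift.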

### KL Divergence

In [None]:
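def kl_div(logits1: torch.Tensor, logits2: torch.Tensor) -> torch.Tensor:
    """KL(P1 || P2), where P1 = softmax(logits1) and P2 = softmax(logits2)."""
    log_p1 = F.log_softmax(logits1, dim=-1)
    log_p2 = F.log_softmax(logits2, dim=-1)
    # F.kl_div(input, target) returns KL(target || input), so P2 goes in `input`;
    # log_target=True avoids taking the log of probabilities that underflowed to 0
    return F.kl_div(log_p2, log_p1, reduction="batchmean", log_target=True)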

kl_tl_hf = kl_div(tl_logits, hf_logits)
kl_hf_tl = kl_div(hf_logits, tl_logits)
print(f"KL(TL||HF): {kl_tl_hf}")
print(f"KL(HF||TL): {kl_hf_tl}")
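The argument order of `F.kl_div` is easy to get backwards: it expects `input` as log-probabilities of the distribution in the *second* slot of the KL divergence, and returns KL(target || input). A quick sanity check on random distributions, comparing PyTorch against the textbook definition $\mathrm{KL}(P \,\|\, Q) = \sum_i p_i (\log p_i - \log q_i)$:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
log_p = F.log_softmax(torch.randn(1, 6), dim=-1)  # log P
log_q = F.log_softmax(torch.randn(1, 6), dim=-1)  # log Q

# Textbook definitions, averaged over the batch dimension like reduction="batchmean".
kl_pq = (log_p.exp() * (log_p - log_q)).sum(dim=-1).mean()  # KL(P || Q)
kl_qp = (log_q.exp() * (log_q - log_p)).sum(dim=-1).mean()  # KL(Q || P)

# F.kl_div(input, target) returns KL(target || input), with `input` given as log-probs.
kl_pq_torch = F.kl_div(log_q, log_p, reduction="batchmean", log_target=True)

# Passing log P as `input` and Q as `target` (probabilities) gives KL(Q || P) instead.
kl_qp_torch = F.kl_div(log_p, log_q.exp(), reduction="batchmean")
```

Because KL divergence is asymmetric, KL(TL||HF) and KL(HF||TL) generally differ; both should be close to zero when the implementations agree.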