In [92]:
# Enable IPython's autoreload so edited Python modules are re-imported before each cell runs.
# NOTE(review): no local/project modules are imported in the visible cells, so this may be unnecessary.
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [93]:
import torch
import torch.nn.functional as F

from transformers import AutoTokenizer, AutoModelForCausalLM

In [94]:
# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
device

device(type='cuda')

In [95]:
# Load the pretrained "gpt2" tokenizer and causal-LM model, then move the model onto `device`.
tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")
model = model.to(device)

In [96]:
# Prompt and target continuation; only the continuation will be scored by the loss.
# NOTE(review): the trailing space in `input_text` becomes its own token (id 220 in the `ids`
# output below), and `output_text` then starts with a space-less "movie". GPT-2 rarely sees
# this split, which likely explains the tiny probability of the first target token. Consider
# input_text="This is a positive" / output_text=" movie review." for a natural tokenization.
input_text = "This is a positive "
output_text = "movie review."

# Prompt tokens (special tokens allowed) and continuation tokens (none added, so the
# target is exactly the tokenization of `output_text`), both as (1, n) tensors on `device`.
prompt_encoding = tokenizer(input_text, return_tensors="pt", add_special_tokens=True)
input_ids = prompt_encoding.input_ids.to(device)

target_encoding = tokenizer(output_text, return_tensors="pt", add_special_tokens=False)
output_ids = target_encoding.input_ids.to(device)

In [97]:
# The full sequence fed to the model: prompt tokens followed by continuation tokens.
ids = torch.cat((input_ids, output_ids), dim=-1)
(ids.shape, ids)

(torch.Size([1, 8]),
 tensor([[ 1212,   318,   257,  3967,   220, 41364,  2423,    13]],
        device='cuda:0'))

In [98]:
# Replace every prompt position with -100 (PyTorch CrossEntropyLoss's default ignore_index)
# so only the continuation tokens contribute to the loss.
prompt_mask = torch.full_like(input_ids, -100)
labels = torch.cat((prompt_mask, output_ids), dim=-1)
labels

tensor([[ -100,  -100,  -100,  -100,  -100, 41364,  2423,    13]],
       device='cuda:0')

In [100]:
inputs = {"input_ids": ids.long(), "labels": labels}

# Forward pass without building an autograd graph. Passing `labels` makes the model
# also return `loss` alongside the logits.
# NOTE(review): hidden states are requested but not used in the visible cells.
with torch.no_grad():
    outputs = model(output_hidden_states=True, **inputs)

(outputs.logits.shape, outputs.keys())

(torch.Size([1, 8, 50257]),
 odict_keys(['loss', 'logits', 'past_key_values', 'hidden_states']))

In [101]:
# Negate the loss (a mean negative log-likelihood) to get the mean per-token log-likelihood.
(-outputs.loss).item()

-8.197007179260254

In [102]:
# Reproduce the causal-LM loss alignment by hand. The logits at position t score token t + 1,
# so drop the last logit position and the first label. Then keep only the positions whose
# label is not the -100 ignore marker, i.e. the continuation tokens. Unlike fixed slice
# bounds, this works for any continuation length and for batches.
shift_logits = outputs.logits[:, :-1, :]
shift_labels = labels[:, 1:]
target_mask = shift_labels != -100

logits_flat = shift_logits[target_mask]  # (n_targets, vocab_size)
labels_flat = shift_labels[target_mask]  # (n_targets,)
assert labels_flat.numel() > 0, "no target tokens: every label is -100"

print(logits_flat.shape, labels_flat.shape)
labels_flat

torch.Size([3, 50257]) torch.Size([3])


tensor([41364,  2423,    13], device='cuda:0')

In [108]:
# The HF loss is the *mean* negative log-likelihood over the non-ignored target tokens,
# so -loss is the average per-token log-likelihood of the continuation. The joint
# (sequence) log-likelihood is the *sum* of the per-token terms, i.e. -loss * n_targets.
# Corresponds well with this: https://huggingface.co/blog/evaluating-mmlu-leaderboard
rows = torch.arange(logits_flat.shape[0], device=logits_flat.device)
# log_softmax is computed stably in log space. softmax().log() would give -inf if a
# target probability underflowed to 0.
token_log_probs = logits_flat.log_softmax(dim=-1)[rows, labels_flat]
probabilities = token_log_probs.exp()
print(probabilities)
token_log_probs.mean().item()

tensor([4.8833e-08, 2.2020e-03, 1.9442e-01], device='cuda:0')


-8.197006225585938

In [104]:
# Split log-softmax into its two terms: log p(y) = logit_y - logsumexp(all logits).
# `numerator` picks each row's target logit; `normalization` is each row's log-partition.
numerator = torch.gather(logits_flat, -1, labels_flat.unsqueeze(-1)).squeeze(-1)
normalization = torch.logsumexp(logits_flat, dim=-1)
(numerator, normalization)

(tensor([ -72.7772, -107.6214, -103.6930], device='cuda:0'),
 tensor([ -55.9424, -101.5030, -102.0553], device='cuda:0'))

In [106]:
# Mean per-token log-probability from the explicit terms; matches -outputs.loss up to rounding.
torch.mean(numerator - normalization).item()

-8.197005271911621