In [2]:
%load_ext autoreload
%autoreload 2

In [52]:
import sys
from pathlib import Path

import torch
import torch.nn.functional as F

from transformers import AutoTokenizer, AutoModelForCausalLM

In [12]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

device(type='cuda')

In [13]:
model = AutoModelForCausalLM.from_pretrained("gpt2").to(device)
tokenizer = AutoTokenizer.from_pretrained("gpt2")

In [19]:
input_text = "This is a positive "
output_text = "movie review."

input_ids = tokenizer(
    input_text, return_tensors="pt", add_special_tokens=True
)["input_ids"].to(device)
output_ids = tokenizer(
    output_text, return_tensors="pt", add_special_tokens=False
)["input_ids"].to(device)

In [27]:
ids = torch.cat([input_ids, output_ids], dim=-1)
ids.shape, ids

(torch.Size([1, 8]),
 tensor([[ 1212,   318,   257,  3967,   220, 41364,  2423,    13]],
        device='cuda:0'))

In [18]:
labels = torch.cat([torch.full_like(input_ids, -100), output_ids], dim=-1)
labels

tensor([[ -100,  -100,  -100,  -100,  -100, 41364,  2423,    13]],
       device='cuda:0')

In [67]:
inputs = dict(input_ids=ids.long(), labels=labels)

with torch.no_grad():
    outputs = model(**inputs, output_hidden_states=True)

outputs.logits.shape, outputs.keys()

(torch.Size([1, 8, 50257]),
 odict_keys(['loss', 'logits', 'past_key_values', 'hidden_states']))

In [68]:
-outputs.loss.item()

-8.197007179260254

In [79]:
logits_flat = outputs.logits[:, -4:-1, :].view(-1, outputs.logits.shape[-1])
labels_flat = labels[:, -3:].view(-1)

print(logits_flat.shape, labels_flat.shape)
labels_flat

torch.Size([3, 50257]) torch.Size([3])


tensor([41364,  2423,    13], device='cuda:0')

In [82]:
logits_flat[range(logits_flat.shape[0]), labels_flat].mean()

tensor(-94.6972, device='cuda:0')

In [90]:
joint_log_prob = logits_flat[range(logits_flat.shape[0]), labels_flat]
normalization = logits_flat.logsumexp(dim=-1)
joint_log_prob, normalization

(tensor([ -72.7772, -107.6214, -103.6930], device='cuda:0'),
 tensor([ -55.9424, -101.5030, -102.0553], device='cuda:0'))

In [91]:
(joint_log_prob - normalization).mean()

tensor(-8.1970, device='cuda:0')

In [87]:
logits_flat.softmax(dim=1)[range(logits_flat.shape[0]), labels_flat].log().mean()

tensor(-8.1970, device='cuda:0')

In [69]:
my_range = range(-(1+output_ids.shape[1]), -1)
my_range = range(5, 8)
my_range

range(5, 8)

In [47]:
required_logits = outputs.logits[:, my_range, output_ids]
required_logits

tensor([[[-112.2360, -108.7026, -120.8452]]], device='cuda:0')

In [48]:
-required_logits.sum()

tensor(341.7838, device='cuda:0')

In [50]:
inputs = dict(input_ids=input_ids, labels=input_ids)

with torch.no_grad():
    outputs = model(**inputs, output_hidden_states=True)

outputs.logits.shape, outputs.keys()

(torch.Size([1, 5, 50257]),
 odict_keys(['loss', 'logits', 'past_key_values', 'hidden_states']))

In [60]:
inputs = dict(input_ids=ids.long(), labels=labels)

with torch.no_grad():
    outputs = model(**inputs, output_hidden_states=True)

outputs.logits.shape, outputs.keys()

(torch.Size([1, 8, 50257]),
 odict_keys(['loss', 'logits', 'past_key_values', 'hidden_states']))

In [61]:
outputs.loss.item()

8.197007179260254

In [65]:
logits = outputs.logits

# Flatten the logits and labels for the loss calculation
logits_flat = logits[:, :-1, :].contiguous().view(-1, logits.size(-1))
# labels_flat = input_ids[:, 1:].contiguous().view(-1)
labels_flat = labels[:, 1:].contiguous().view(-1)

# Compute the loss manually using CrossEntropyLoss
loss = F.cross_entropy(logits_flat, labels_flat)

# Print the loss value
print("Loss:", loss.item())

Loss: 8.197007179260254


In [72]:
logits_flat.shape

torch.Size([7, 50257])

In [73]:
labels_flat.shape

torch.Size([7])