### Cross-entropy loss from scratch in neural networks

#### Sample data

In [1]:
# Goal: implement by hand the cross-entropy loss used to train GPT-2 (or any LLM, for that matter).
# On a small text sample we compute both the forward pass and the backward pass from scratch.
import tiktoken

# GPT-2's byte-pair-encoding tokenizer
tokenizer = tiktoken.get_encoding("gpt2")

# The sample sentence we will turn into training examples
text = "Hi, how are you? I am learning cross entropy loss and how it is related to entropy."

# text -> list of integer token ids
tokens = tokenizer.encode(text)

print(f"Tokenized text: {tokens}")
print(f"Number of tokens: {len(tokens)}")

Tokenized text: [17250, 11, 703, 389, 345, 30, 314, 716, 4673, 3272, 40709, 2994, 290, 703, 340, 318, 3519, 284, 40709, 13]
Number of tokens: 20


#### Creating input and output tensors

In [2]:
# Neural networks are trained on mini-batches, so build one batch of `batch_size` examples.
# Each example is a window of `context_length` tokens plus its next-token targets.
import torch

# Convert the token list into a 1-D tensor of token ids
tokenized = torch.tensor(tokens)

batch_size = 2
context_length = 4

# Seed before sampling so the chosen windows (and every output below) are reproducible
# on Restart & Run All. The later torch.manual_seed(123) only runs after this sampling.
torch.manual_seed(123)

# Each window needs context_length input tokens plus one more token for the shifted target.
assert len(tokenized) > context_length, "need more than context_length tokens to build a batch"

# Sample random start positions. randint's upper bound is exclusive, so the largest start
# is len(tokenized) - context_length - 1. That keeps the target slice i+1 : i+context_length+1 in range.
ix = torch.randint(len(tokenized) - context_length, (batch_size,))
# Inputs are tokens[i : i+context_length]. Targets are the same window shifted one position to the right.
input_ids = torch.stack([tokenized[i:i + context_length] for i in ix])
output_ids = torch.stack([tokenized[i + 1:i + context_length + 1] for i in ix])

print(f"Input: {input_ids}, shape: {input_ids.shape}")
print(f"\nOutput: {output_ids}, shape: {output_ids.shape}")

Input: tensor([[  318,  3519,   284, 40709],
        [40709,  2994,   290,   703]]), shape: torch.Size([2, 4])

Output: tensor([[ 3519,   284, 40709,    13],
        [ 2994,   290,   703,   340]]), shape: torch.Size([2, 4])


In [3]:
# How to read the input and target tensors: within each row, the target at position t is
# the token that follows input positions 0..t.
# Worked example (the randomly sampled batch above may start at a different position):
# take the window starting at index 1 of `tokens` = [17250, 11, 703, 389, 345, 30, ...].
# Then input = [11, 703, 389, 345] and target = [703, 389, 345, 30], which encodes four
# next-token predictions:
# When input is 11, output is 703
# When input is [11, 703], the output is 389
# When input is [11, 703, 389], output is 345
# When input is [11, 703, 389, 345], output is 30
# So a single row of length context_length yields context_length training examples.

In [4]:
# A toy "model": an embedding table followed by one weight matrix that projects back to the vocabulary.
import torch.nn as nn

torch.manual_seed(123)

# The GPT-2 tokenizer has 50257 tokens. Read the vocab size from the tokenizer instead of hard-coding it.
vocab_size = tokenizer.n_vocab
embedding_dim = 100  # small embedding size, enough for this demo

# Embedding table: one embedding_dim-sized vector per token id
embedding = nn.Embedding(vocab_size, embedding_dim)
# Forward pass through the embedding
embedding_output = embedding(input_ids)  # [batch_size, context_length] --> [batch_size, context_length, embedding_dim]
print(f"Embedding output shape: {embedding_output.shape}")

# Output projection (embedding_dim -> vocab_size), initialised with unscaled standard-normal values
W = torch.randn((embedding_dim, vocab_size))

# Forward pass through the network: one score (logit) per vocabulary token at every position
logits = embedding_output @ W  # [batch_size, context_length, embedding_dim] --> [batch_size, context_length, vocab_size]
print(f"Logits shape: {logits.shape}")

Embedding output shape: torch.Size([2, 4, 100])
Logits shape: torch.Size([2, 4, 50257])


In [5]:
# Now the cross entropy loss steps
import torch.nn.functional as F

# 1. Apply softmax over the vocabulary dimension.
# This turns each position's logits into a probability distribution.
probs = F.softmax(logits, dim=-1)
# Checking that we indeed got a probability distribution (every entry should be 1)
print(f"Sum along the last dimension: {torch.sum(probs, dim=-1, keepdim=True)}")

# 2. Pick out the probability the model assigns to the correct next token at each position.
# torch.gather needs an index with the same number of dims as `probs`, so add a trailing axis.
# Put it on a NEW name rather than overwriting `output_ids`; otherwise every re-run of this
# cell would add another dimension.
# reshape(...) also restores the [batch_size, context_length] layout if output_ids was flattened.
target_index = output_ids.reshape(probs.shape[:-1]).unsqueeze(-1)  # [batch_size, context_length, 1]
correct_id_probs = torch.gather(probs, dim=-1, index=target_index)
print(f"Probabilities of the correct target tokens: {correct_id_probs}")

Sum along the last dimension: tensor([[[1.0000],
         [1.0000],
         [1.0000],
         [1.0000]],

        [[1.0000],
         [1.0000],
         [1.0000],
         [1.0000]]], grad_fn=<SumBackward1>)
Output index with the highest probabilities: tensor([[[7.4362e-26],
         [3.1464e-18],
         [3.3123e-21],
         [5.2197e-21]],

        [[7.9674e-20],
         [4.7586e-18],
         [1.8158e-23],
         [1.3934e-21]]], grad_fn=<GatherBackward0>)


In [6]:
# Now calculating the loss
# Cross entropy = mean over all (batch, position) pairs of -log p(correct token).
# Mathematically this equals -torch.log(correct_id_probs).mean(), but taking log() of softmax
# outputs is fragile. In float32 a probability underflows to exactly 0 once the logit gap exceeds
# roughly 100, and log(0) = -inf. The probabilities above are already as small as ~1e-26.
# Computing log-probabilities directly as logits - logsumexp(logits) (the log-sum-exp trick)
# gives the same value without that failure mode.
log_probs = logits - torch.logsumexp(logits, dim=-1, keepdim=True)
# Works whether output_ids is [batch, context], [batch, context, 1] or already flattened
target_index = output_ids.reshape(logits.shape[:-1]).unsqueeze(-1)
correct_id_log_probs = torch.gather(log_probs, dim=-1, index=target_index)
cross_entropy_loss = -correct_id_log_probs.mean()
# For reference, predicting uniformly over the vocabulary would give ln(vocab_size) ≈ 10.82.
# A much larger value here likely reflects the large-magnitude random logits of the untrained toy model.
print(f"Cross entropy loss: {cross_entropy_loss}")

Cross entropy loss: 47.03350830078125


In [7]:
# Implementing the backward pass
# For cross entropy averaged over n = batch_size * context_length predictions, the gradient
# with respect to the logits is (softmax(logits) - one_hot(target)) / n.
n = batch_size * context_length
# Flatten the targets into a local name instead of overwriting `output_ids`.
# That way earlier cells can still be re-run against the original batch layout.
targets = output_ids.reshape(-1)

with torch.no_grad():
    dlogits = F.softmax(logits, dim=-1).reshape(n, -1)  # [n, vocab]
    dlogits[torch.arange(n), targets] -= 1  # subtract the one-hot target at each row
    dlogits /= n  # the loss is a mean over n predictions
    dlogits = dlogits.view(batch_size, context_length, -1)  # back to the shape of logits

#### Sanity check using PyTorch's `loss.backward()`

In [8]:
# `logits` is an intermediate (non-leaf) tensor, so autograd does not keep its .grad by default.
# retain_grad() must be called before backward() so that logits.grad can be compared with dlogits.
logits.retain_grad()

In [9]:
# Reference gradient from autograd. Apply PyTorch's built-in cross entropy to:
#  - the logits flattened to [batch_size * context_length, vocab]
#  - the matching flat target ids
# Run this cell once only: backward() frees the graph that produced `logits`,
# so re-running it tries to backpropagate through freed buffers and fails.
flat_logits = logits.reshape(-1, logits.shape[-1])
flat_targets = output_ids.reshape(-1)
loss = F.cross_entropy(flat_logits, flat_targets)

loss.backward()

In [10]:
# Sanity check: the hand-derived gradient must have exactly the shape of autograd's gradient
assert dlogits.size() == logits.grad.size(), "Shapes of logits and dlogits should match"

In [14]:
# Elementwise comparison of the manual gradient with autograd's, within allclose's default tolerances
dlogits.allclose(logits.grad)

True

In [None]:
# Getting an approximate match
# Exact elementwise equality is usually too strict for two float computations that take
# different code paths. So report exact equality, allclose, and the largest absolute difference.
exact_match = torch.equal(dlogits, logits.grad)
approx_match = torch.allclose(dlogits, logits.grad)
max_abs_diff = (dlogits - logits.grad).abs().max().item()
print(f"exact: {exact_match} | approximate: {approx_match} | max abs diff: {max_abs_diff:.3e}")