## 🔍 How Loss Is Calculated While Training Transformers (Step-by-Step)

Training a transformer language model boils down to **comparing the predicted next-token probabilities with the correct target tokens** at every position in the sequence.  
Below is a step-by-step walkthrough of how this works, first conceptually and then in a PyTorch notebook.

---

### 📌 Example Input

**Input Tokens (B = 2, T = 4):**

```text
[[128, 576, 320, 547],
 [657, 943, 633, 547]]
```

**Target Tokens (the next token for each position):**

```text
[[576, 320, 547, 897],
 [943, 633, 547, 210]]
```

Here:
- **B = 2** → batch size
- **T = 4** → sequence length
- **V = 50,257** → vocabulary size for GPT-2
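A minimal sketch of how such input/target pairs are usually built, assuming each row was cut from a longer token stream of `T + 1` IDs (the `streams` below are hypothetical, assembled from the example above). Inputs drop the last ID and targets drop the first, so the target at position `t` is the token that follows input position `t`.

```python
# Hypothetical token streams: each row holds T + 1 = 5 consecutive token IDs
streams = [
    [128, 576, 320, 547, 897],
    [657, 943, 633, 547, 210],
]

# Inputs drop the last ID; targets drop the first ID (a shift by one position)
inputs = [row[:-1] for row in streams]
targets = [row[1:] for row in streams]

print(inputs)   # [[128, 576, 320, 547], [657, 943, 633, 547]]
print(targets)  # [[576, 320, 547, 897], [943, 633, 547, 210]]
```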

---

#### 🧠 Step 1: Forward Pass → Logits

Passing the input through the model produces **logits**: raw, unnormalized scores that can be any real number. They are not probabilities.

**Logits shape:** `[batch_size, time_steps, vocab_size] = (B, T, V)`

Each of the T = 4 positions in each of the B = 2 sequences gets a vector of V = 50,257 logits, one score per vocabulary token.
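To make the `(B, T, V)` layout concrete, here is a plain-Python sketch with a toy vocabulary of `V = 6`. `fake_forward` is a hypothetical stand-in for the model that just returns Gaussian random logits (the notebook below uses `torch.rand` for the same purpose).

```python
import random

B, T, V = 2, 4, 6  # toy sizes; GPT-2 would use V = 50_257


def fake_forward(batch, vocab_size, seed=0):
    """Hypothetical stand-in for a model: random logits of shape (B, T, vocab_size)."""
    rng = random.Random(seed)
    return [
        [[rng.gauss(0.0, 1.0) for _ in range(vocab_size)] for _ in seq]
        for seq in batch
    ]


batch = [[128, 576, 320, 547], [657, 943, 633, 547]]  # the example input, shape (B, T)
logits = fake_forward(batch, V)

print(len(logits), len(logits[0]), len(logits[0][0]))  # 2 4 6
# logits[b][t] holds V raw scores for position t of sequence b; they need not sum to 1
print(logits[0][0])
```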

---

#### 🧮 Step 2: Convert Logits → Probabilities (Softmax)

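Softmax is applied along the vocabulary (last) dimension:

`probs = softmax(logits, dim=-1)`

At every position it exponentiates the V logits and divides each by their sum, so all entries are positive and sum to 1. The result is a probability distribution over all vocabulary tokens competing to be the next token at that position.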
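A from-scratch sketch of softmax for a single position, using made-up logits for a 3-token vocabulary. Subtracting the maximum logit first is a common numerical-stability trick (an implementation detail added here); it does not change the result because the constant cancels in the ratio.

```python
import math


def softmax(logits):
    """Softmax over one position's logits (a list of V floats)."""
    m = max(logits)  # subtracting the max avoids overflow in exp and leaves the result unchanged
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


probs = softmax([2.0, 1.0, 0.1])
print([round(p, 3) for p in probs])    # [0.659, 0.242, 0.099]
print(math.isclose(sum(probs), 1.0))   # True
```

The largest logit always gets the largest probability, since softmax preserves the ordering of its inputs.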

---

#### 🔁 Step 3: Reshape (Flatten) Logits and Target IDs
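To pair every prediction with its target, merge the batch and time dimensions so there is one row per position (this is also the `(N, C)` input / `(N,)` target layout that `F.cross_entropy` accepts):

- Logits: `(B, T, V)` → `(B*T, V)`, e.g. `new_logits = logits.flatten(start_dim=0, end_dim=1)`
- Target IDs: `(B, T)` → `(B*T,)`, e.g. `target_ids = target_tokens.flatten()`

With B = 2 and T = 4 that gives 8 rows, and row `b*T + t` of the flattened logits still lines up with entry `b*T + t` of the flattened targets.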
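A plain-Python sketch of the flattening with toy sizes (`B = 2`, `T = 3`, `V = 4`) and hypothetical values, mirroring what `flatten(start_dim=0, end_dim=1)` does in PyTorch: rows are laid out sequence by sequence, so position `(b, t)` lands at row `b*T + t`.

```python
B, T, V = 2, 3, 4  # toy sizes

# Toy "logits" where each entry encodes its own (b, t, v) position, so the mapping is visible
logits = [[[100 * b + 10 * t + v for v in range(V)] for t in range(T)] for b in range(B)]
targets = [[1, 3, 0], [2, 2, 1]]  # hypothetical target IDs, shape (B, T)

# (B, T, V) -> (B*T, V) and (B, T) -> (B*T,), in row-major order like torch.flatten
flat_logits = [vec for seq in logits for vec in seq]
flat_targets = [tok for seq in targets for tok in seq]

print(len(flat_logits), len(flat_logits[0]), len(flat_targets))  # 6 4 6

# Row b*T + t of the flattened data is position (b, t) of the original data
b, t = 1, 2
assert flat_logits[b * T + t] == logits[b][t]
assert flat_targets[b * T + t] == targets[b][t]
```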

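#### Step 4: Target Probabilities → Loss

At each position (time step), the loss only looks at the probability the model assigned to the correct target ID:

1. Gather these probabilities → shape `(B*T, 1)`.
2. Take the natural log of each one and negate it (the per-token negative log-likelihood).
3. Average over all `B*T` positions. This average is the cross-entropy loss.

The order matters: the loss is the average of the negative logs, not the negative log of the averaged probability. A confident correct prediction (probability near 1) contributes almost 0, while a near-zero probability on the correct token is penalized heavily.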
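The three steps in a plain-Python sketch with made-up probabilities (4 positions, 3-token vocabulary). The last lines show why the order matters: averaging the probabilities before taking the log gives a different, smaller number.

```python
import math

# Hypothetical softmax outputs for B*T = 4 positions over a 3-token vocabulary (rows sum to 1)
probs = [
    [0.7, 0.2, 0.1],
    [0.1, 0.8, 0.1],
    [0.3, 0.3, 0.4],
    [0.5, 0.25, 0.25],
]
target_ids = [0, 1, 2, 0]

# 1. Pick the probability of the correct token at each position (what gather does)
picked = [row[t] for row, t in zip(probs, target_ids)]  # [0.7, 0.8, 0.4, 0.5]

# 2. Negative log of each probability
nll = [-math.log(p) for p in picked]

# 3. Average over positions -> cross-entropy loss
loss = sum(nll) / len(nll)
print(round(loss, 4))   # 0.5473

# Wrong order (average first, then -log) gives a smaller value
wrong = -math.log(sum(picked) / len(picked))
print(round(wrong, 4))  # 0.5108
```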

```python
import torch

torch.manual_seed(42)
# Random token IDs standing in for a tokenized batch: B = 2 sequences, T = 4 tokens, GPT-2 vocab
input_tokens = torch.randint(0, 50257, (2, 4))

# Targets are the inputs shifted left by one position...
target_tokens = torch.zeros_like(input_tokens)
target_tokens[:, :-1] = input_tokens[:, 1:]
# ...plus a random "next token" for the last position
target_tokens[:, -1] = torch.randint(0, 50257, (2,))

print(f"Input tokens shape : {input_tokens.shape}")
print(f"Input tokens : {input_tokens}")
print(f"Target tokens : {target_tokens}")
```

```text
Input tokens shape : torch.Size([2, 4])
Input tokens : tensor([[11486, 31563,  6140, 17682],
        [13134, 22911, 20243, 43382]])
Target tokens : tensor([[31563,  6140, 17682, 18369],
        [22911, 20243, 43382, 45413]])
```


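With `input_tokens` and `target_tokens` ready, the next step is to feed the input tokens to the LLM, which returns logits of shape `[2, 4, 50257]`. Since no real model is used here, random numbers stand in for that output: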

```python
# Stand-in for the model's logits, shape (B, T, V); no real model is run here
predicted_output = torch.rand((2, 4, 50257))
predicted_output.shape
```

```text
torch.Size([2, 4, 50257])
```

```python
import torch.nn.functional as F

# Softmax over the last (vocabulary) dimension turns each position's scores into probabilities
predicted_output_softmax = F.softmax(predicted_output, dim=-1)
print(predicted_output_softmax.shape)
```

```text
torch.Size([2, 4, 50257])
```


```python
# Sanity check: the probabilities at each position should sum to 1
predicted_output_softmax[0, 0, :].sum()
```

```text
tensor(1.0000)
```

```python
# The token ID with the highest probability is the model's (greedy) prediction at each position
predicted_tokens = torch.argmax(predicted_output_softmax, dim=-1, keepdim=True)
predicted_tokens
```

```text
tensor([[[17322],
         [34935],
         [32629],
         [18804]],

        [[ 2298],
         [40166],
         [ 7312],
         [27669]]])
```

```python
# For training we don't use these argmax predictions; instead we read off the probability
# that each position's distribution assigns to its target token
target_tokens.shape
```

```text
torch.Size([2, 4])
```

```python
target_tokens
```

```text
tensor([[31563,  6140, 17682, 18369],
        [22911, 20243, 43382, 45413]])
```

```python
# Note: despite the name, these are post-softmax probabilities, not raw logits
logits = predicted_output_softmax
logits.shape
```

```text
torch.Size([2, 4, 50257])
```

```python
# Merge batch and time: (B, T, V) -> (B*T, V)
new_logits = logits.flatten(start_dim=0, end_dim=1)
new_logits.shape
```

```text
torch.Size([8, 50257])
```

```python
# Flatten the targets the same way: (B, T) -> (B*T,)
target_ids = target_tokens.flatten(start_dim=0)
target_ids.shape
```

```text
torch.Size([8])
```

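```python
# For each row i, pick new_logits[i, target_ids[i]]: the probability of the correct next token
preds_prob = new_logits.gather(1, target_ids.unsqueeze(1))  # shape (B*T, 1)
preds_prob
```

```text
tensor([[1.1652e-05],
        [2.0611e-05],
        [2.5886e-05],
        [2.8628e-05],
        [1.3507e-05],
        [1.5790e-05],
        [2.0037e-05],
        [2.1681e-05]])
```

All values lie within a small factor of 1/50,257 ≈ 2.0e-05, because random logits in [0, 1) give an almost flat distribution over the vocabulary.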

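Now take the negative log of **each** target probability and average over the `B*T = 8` positions:

```python
# Negative log-likelihood per position (torch.log is the natural log)
nll = -torch.log(preds_prob)  # shape (B*T, 1)

# Average over all B*T positions: this is the cross-entropy loss
manual_loss = nll.mean()
manual_loss
```

This is the loss we want to minimize. For the eight probabilities above it comes to about 10.87, which matches (up to floating-point rounding) the `F.cross_entropy` result computed below. Doing the steps in the other order, i.e. averaging the probabilities first (1.9724e-05) and then taking the negative log, gives 10.8337, which is *not* the cross-entropy loss.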
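For reference, the same computation as a formula (notation added here: $y_{b,t}$ is the target ID at position $t$ of sequence $b$, and $p_{b,t}(\cdot)$ is the softmax output at that position):

$$
\mathcal{L} = -\frac{1}{B\,T}\sum_{b=1}^{B}\sum_{t=1}^{T}\log p_{b,t}\!\left(y_{b,t}\right)
$$

Because $-\log$ is convex, Jensen's inequality gives

$$
-\log\!\left(\frac{1}{B\,T}\sum_{b=1}^{B}\sum_{t=1}^{T} p_{b,t}\!\left(y_{b,t}\right)\right) \;\le\; \mathcal{L},
$$

which is why the negative log of the averaged probability (10.8337) undershoots the cross-entropy loss (10.8748).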

```python
# F.cross_entropy does all of the above in one call: log-softmax, pick the target, negate, average.
# It expects the raw logits (predicted_output), not the softmax output.
print(f"Shape of target tokens : {target_tokens.shape}")
print(f"Shape of predicted output : {predicted_output.shape}")
```

```text
Shape of target tokens : torch.Size([2, 4])
Shape of predicted output : torch.Size([2, 4, 50257])
```


```python
predicted_output.flatten(start_dim=0, end_dim=1).shape
```

```text
torch.Size([8, 50257])
```

```python
target_tokens
```

```text
tensor([[31563,  6140, 17682, 18369],
        [22911, 20243, 43382, 45413]])
```

```python
tks = target_tokens.flatten(start_dim=0)
print(tks)
```

```text
tensor([31563,  6140, 17682, 18369, 22911, 20243, 43382, 45413])
```


```python
# Reshape the raw logits (not the softmax output) to (B*T, V) and the targets to (B*T,)
logits_flatten = predicted_output.reshape(-1, predicted_output.size(-1))  # (B*T, V)
targets_flatten = target_tokens.reshape(-1)                              # (B*T,)

loss = torch.nn.functional.cross_entropy(logits_flatten, targets_flatten)
loss
```

```text
tensor(10.8748)
```
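What `F.cross_entropy` computes can be sketched in plain Python. This is the standard log-sum-exp formulation, not PyTorch's actual source: for each row, `-log(softmax(z)[y])` simplifies to `logsumexp(z) - z[y]`, which avoids forming tiny probabilities explicitly. The final lines illustrate why the raw logits, not the softmax output, should be passed in.

```python
import math


def cross_entropy_from_logits(logits_rows, target_ids):
    """Mean cross-entropy computed directly from raw logits.

    For one row, -log softmax(z)[y] simplifies to logsumexp(z) - z[y].
    """
    total = 0.0
    for z, y in zip(logits_rows, target_ids):
        m = max(z)  # stabilizes logsumexp
        logsumexp = m + math.log(sum(math.exp(v - m) for v in z))
        total += logsumexp - z[y]
    return total / len(target_ids)


def softmax(z):
    m = max(z)
    exps = [math.exp(v - m) for v in z]
    s = sum(exps)
    return [e / s for e in exps]


rows = [[2.0, 1.0, 0.1], [0.5, 2.5, -1.0]]  # hypothetical raw logits, shape (B*T, V)
targets = [0, 1]

direct = cross_entropy_from_logits(rows, targets)
via_probs = sum(-math.log(softmax(z)[y]) for z, y in zip(rows, targets)) / len(targets)
print(math.isclose(direct, via_probs))  # True

# Pitfall: feeding softmax *outputs* in as if they were logits applies softmax twice
double = cross_entropy_from_logits([softmax(z) for z in rows], targets)
print(double > direct)  # True here: the second softmax flattens the distribution
```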

```python
# Perplexity = exp(loss), using the loss rounded to 10.87
torch.exp(torch.tensor(10.87))
```

```text
tensor(52575.2031)
```

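#### **Perplexity Score**

Perplexity expresses the same loss in a more interpretable form: perplexity = exp(loss).

Lower perplexity = better next-token prediction.

For example, with loss ≈ 10.87 the perplexity is about 52,575. Roughly, the model is as uncertain as if it were picking the next token uniformly at random from about 52,575 candidates. That is slightly more than GPT-2's 50,257-token vocabulary, so these random logits do a little worse than a perfectly uniform guess.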
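A short worked example of the numbers above, plus a uniform-guessing baseline (an added reference point, not part of the original notebook): if the model spread probability evenly over all V tokens, every target would get p = 1/V, so the loss would be log V and the perplexity exactly V.

```python
import math

V = 50_257  # GPT-2 vocabulary size

# Perplexity is exp(mean cross-entropy loss)
print(round(math.exp(10.87)))  # 52575

# Baseline: uniform predictions give p = 1/V at every position,
# so the loss is -log(1/V) = log(V) and the perplexity is exactly V
uniform_loss = math.log(V)
print(round(uniform_loss, 4))         # 10.8249
print(round(math.exp(uniform_loss)))  # 50257
```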