## 🔍 How Loss Is Calculated While Training Transformers (Step-by-Step)

Training a transformer boils down to **comparing predicted token probabilities with the correct target tokens** at every position in the sequence.  
Below is a complete walkthrough of how this works mathematically and in practice.

---

### 📌 Example Input

**Input Tokens (B = 2, T = 4):**

```
[[128, 576, 320, 547],
 [657, 943, 633, 547]]
```

**Target Tokens (next token for each position):**

```
[[576, 320, 547, 897],
 [943, 633, 547, 210]]
```

Here:
- **B = 2** → batch size
- **T = 4** → sequence length
- **V = 50,257** → vocabulary size for GPT-2
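The relationship between the two arrays can be sketched in plain Python. The helper `make_input_target` below is hypothetical (it is not part of the original notebook): it assumes the training text is one long list of token IDs and cuts it into non-overlapping windows of length `T`, taking each target window as the input window shifted one token to the right. Feeding it the five tokens of the first example row reproduces the first input/target pair.

```python
def make_input_target(stream, T):
    """Hypothetical helper: split a token stream into (inputs, targets) windows.

    Each target window is the input window shifted left by one token,
    so every window needs T + 1 consecutive tokens.
    """
    inputs, targets = [], []
    # Stride T gives non-overlapping windows; stop while T + 1 tokens remain
    for start in range(0, len(stream) - T, T):
        inputs.append(stream[start:start + T])
        targets.append(stream[start + 1:start + T + 1])
    return inputs, targets

# Token stream that reproduces the first example row above
stream = [128, 576, 320, 547, 897]
inputs, targets = make_input_target(stream, T=4)
print(inputs)   # [[128, 576, 320, 547]]
print(targets)  # [[576, 320, 547, 897]]
```

Non-overlapping windows (stride `T`) are only one choice; a smaller stride yields more, overlapping training examples from the same text.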

---

#### 🧠 Step 1: Forward Pass → Logits

When the input goes through the model, the transformer outputs **logits**, not probabilities.

**Logits shape:** 
[Batch_size, Time_steps, Vocabulary_size] = (B, T, V)

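Each of the 4 positions in each sequence gets a vector of V raw scores (logits), one per vocabulary token. These are not yet probabilities: softmax turns each vector into a probability distribution over the vocabulary.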
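To make the `(B, T, V)` layout concrete, here is a dependency-free sketch. It assumes a toy vocabulary of `V = 5` (instead of 50,257) and uses random Gaussian numbers in place of a real model's output, so the values are illustrative only.

```python
import random

B, T, V = 2, 4, 5          # toy vocabulary size V = 5; GPT-2 uses V = 50,257
rng = random.Random(0)     # seeded so the sketch is reproducible

# Stand-in for a model's output: one raw score (logit) per vocabulary entry,
# for every position of every sequence -> nested lists of shape (B, T, V)
logits = [[[rng.gauss(0.0, 1.0) for _ in range(V)] for _ in range(T)]
          for _ in range(B)]

print(len(logits), len(logits[0]), len(logits[0][0]))  # 2 4 5

# Logits are unnormalized scores: they can be negative and need not sum to 1
print(logits[0][0])
```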

---

#### 🧮 Step 2: Convert Logits → Probabilities (Softmax)

Softmax is applied along the vocabulary dimension (the last dimension):

`probs = softmax(logits, dim=-1)`

At each position this turns the V logits into V positive probabilities that sum to 1, so every token in the vocabulary competes for that position.
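A minimal plain-Python softmax for one position, assuming a toy 3-token vocabulary. Subtracting the row maximum before exponentiating does not change the result but keeps `math.exp` from overflowing on large logits; this stabilization is a common implementation detail rather than something the text above specifies.

```python
import math

def softmax(scores):
    """Numerically stable softmax over one row of logits."""
    m = max(scores)                       # shift by the max so exp() cannot overflow
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

row = [2.0, 1.0, 0.1]                     # logits for one position, toy V = 3
probs = softmax(row)
print([round(p, 3) for p in probs])       # [0.659, 0.242, 0.099]
print(round(sum(probs), 6))               # 1.0
```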

---

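#### 🔁 Step 3: Reshape (Flatten) Logits and Target IDs

Logits shape: `(B, T, V)` → flattened to `(B*T, V)`

`new_logits = logits.flatten(start_dim=0, end_dim=1)`

Target IDs shape: `(B, T)` → flattened to `(B*T,)`

Each row of the flattened tensors now corresponds to one (sequence, position) pair, so logits row `i` lines up with target `i`.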
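The flattening can be sketched with nested Python lists, using toy sizes `B = 2, T = 4, V = 3` and made-up logits whose values encode their batch, time and vocab index. It shows the row-major ordering that a contiguous PyTorch `flatten(start_dim=0, end_dim=1)` produces: position `(b, t)` ends up at row `b * T + t`, which is why logits and targets stay aligned after both are flattened.

```python
# Toy (B, T, V) logits where value = 100*b + 10*t + v, and (B, T) targets
B, T, V = 2, 4, 3
logits = [[[float(100 * b + 10 * t + v) for v in range(V)] for t in range(T)]
          for b in range(B)]
targets = [[1, 2, 0, 1], [2, 2, 1, 0]]

# Flatten the first two dims; position (b, t) lands at row b * T + t
flat_logits = [logits[b][t] for b in range(B) for t in range(T)]    # (B*T, V)
flat_targets = [targets[b][t] for b in range(B) for t in range(T)]  # (B*T,)

print(len(flat_logits), len(flat_logits[0]))  # 8 3
print(flat_targets)                            # [1, 2, 0, 1, 2, 2, 1, 0]
# Row 5 is batch 1, time step 1 (5 = 1 * 4 + 1)
print(flat_logits[5])                          # [110.0, 111.0, 112.0]
```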

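#### Now we look up the probability of each target ID at each context position / time step

-> gather the probability the model assigned to the correct target ID at every position: shape (B*T, 1)

-> take the natural log of each of these probabilities

-> average the log-probabilities over all B*T positions

-> negate the average: this is the loss (the mean negative log-likelihood)

The order matters: averaging the probabilities first and then taking the log gives a different, incorrect value.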
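The four steps can be checked on a toy example in plain Python. `nll_from_probs` is a hypothetical helper, and the two probability rows (toy vocabulary of 4 tokens) are made up for illustration. The last lines show why the order matters: averaging the probabilities before taking the log gives a different number.

```python
import math

def nll_from_probs(prob_rows, targets):
    """Mean negative log-likelihood of the target tokens.

    prob_rows: N probability rows, each summing to 1 -> shape (N, V)
    targets:   N target token IDs                  -> shape (N,)
    """
    # 1. gather: the probability of the correct token in each row
    target_probs = [row[t] for row, t in zip(prob_rows, targets)]
    # 2. natural log of each probability
    log_probs = [math.log(p) for p in target_probs]
    # 3. average, then 4. negate
    return -sum(log_probs) / len(log_probs)

# Two flattened positions over a toy vocabulary of 4 tokens
probs = [[0.7, 0.1, 0.1, 0.1],
         [0.25, 0.25, 0.25, 0.25]]
targets = [0, 3]

loss = nll_from_probs(probs, targets)
print(round(loss, 4))         # 0.8715

# Wrong order: averaging the probabilities first gives a different value
picked = [row[t] for row, t in zip(probs, targets)]
wrong_order = -math.log(sum(picked) / len(picked))
print(round(wrong_order, 4))  # 0.7444
```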

```python
import torch

torch.manual_seed(42)
# Random token IDs in [0, 50257) stand in for tokenized text, shape (B, T) = (2, 4)
input_tokens = torch.randint(0, 50257, (2, 4))

# Targets are the inputs shifted left by one position
target_tokens = torch.zeros_like(input_tokens)
target_tokens[:, :-1] = input_tokens[:, 1:]
# The last target is the token right after the window; it is unknown here, so draw it at random
target_tokens[:, -1] = torch.randint(0, 50257, (2,))

print(f"Shape of input tokens : {input_tokens.shape}")
print(f"Input tokens : {input_tokens}")
print(f"Target tokens : {target_tokens}")
```

```
Shape of input tokens : torch.Size([2, 4])
Input tokens : tensor([[11486, 31563,  6140, 17682],
        [13134, 22911, 20243, 43382]])
Target tokens : tensor([[31563,  6140, 17682, 18369],
        [22911, 20243, 43382, 45413]])
```


Now that we have `input_tokens` and `target_tokens`, we pass the input tokens to the LLM, which returns logits of shape `[2, 4, 50257]`. Here the LLM is simulated with random numbers.

```python
# Stand-in for the LLM: random values in [0, 1) play the role of logits, shape (B, T, V)
logits = torch.rand((2, 4, 50257))
print(f"Shape of logits returned by LLM : {logits.shape}")
```

```
Shape of logits returned by LLM : torch.Size([2, 4, 50257])
```


```python
import torch.nn.functional as F

# Softmax over the last dim (the vocabulary) gives each token's probability at each position
logits_softmax = F.softmax(logits, dim=-1)
print(f"Shape after applying softmax : {logits_softmax.shape}")
```

```
Shape after applying softmax : torch.Size([2, 4, 50257])
```


```python
# Predicted token at each position: the ID with the highest probability
predicted_token_ids = torch.argmax(logits_softmax, dim=-1, keepdim=True)
print(predicted_token_ids)
```

```
tensor([[[17322],
         [34935],
         [32629],
         [18804]],

        [[ 2298],
         [40166],
         [ 7312],
         [27669]]])
```


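For example, in the first sequence (batch index 0) at position 0, the token with the highest probability has ID 17322. It does not match the target ID (31563) because these logits are random rather than produced by a trained model.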
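Because softmax is monotonic (a larger logit always yields a larger probability), the argmax of the probabilities equals the argmax of the raw logits, so predictions can be read directly from the logits. A small plain-Python check, with made-up logits for a single position:

```python
import math

def softmax(scores):
    """Numerically stable softmax over one row of logits."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def argmax(values):
    """Index of the largest value (the first one if there are ties)."""
    return max(range(len(values)), key=values.__getitem__)

logits_row = [0.3, 2.5, -1.0, 2.4]   # toy logits for one position
print(argmax(logits_row))            # 1
print(argmax(softmax(logits_row)))   # 1: softmax preserves the ordering
```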

```python
# To compute the loss we need the target IDs: take the probability assigned
# to each target ID, then compute the negative log-likelihood:
# take log -> calculate mean -> multiply by -1
print(f"Target tokens shape : {target_tokens.shape}")
print(f"Probabilities shape : {logits_softmax.shape}")
```

```
Target tokens shape : torch.Size([2, 4])
Probabilities shape : torch.Size([2, 4, 50257])
```


```python
# Flatten batch and time: targets (B, T) -> (B*T,), probabilities (B, T, V) -> (B*T, V)
targets_flatten = target_tokens.flatten(start_dim=0)
logits_flatten = logits_softmax.flatten(start_dim=0, end_dim=1)

print(f"Shape of targets flatten : {targets_flatten.shape}")
print(f"Shape of logits flatten : {logits_flatten.shape}")
```

```
Shape of targets flatten : torch.Size([8])
Shape of logits flatten : torch.Size([8, 50257])
```


```python
# For each row i, pick column targets_flatten[i]: the probability of the correct token
target_probabilities = logits_flatten.gather(1, targets_flatten.unsqueeze(1))
target_probabilities
```

```
tensor([[1.1652e-05],
        [2.0611e-05],
        [2.5886e-05],
        [2.8628e-05],
        [1.3507e-05],
        [1.5790e-05],
        [2.0037e-05],
        [2.1681e-05]])
```

```python
# Negative log of each target probability (despite its name, log_probs holds -log p)
log_probs = -1 * torch.log(target_probabilities)
log_probs.shape
```

```
torch.Size([8, 1])
```

```python
# Mean of the negative logs over all B*T = 8 positions (same as log_probs.mean())
log_probs_mean = log_probs.sum() / log_probs.shape[0]

print(f"Loss : {log_probs_mean}")
```

```
Loss : 10.874761581420898
```


#### **Using torch.nn.functional.cross_entropy()**

```python
# F.cross_entropy does the softmax, log, gather and mean in one call.
# It expects raw logits (not softmax probabilities) plus the target IDs.
print(f"Shape of target tokens : {target_tokens.shape}")
print(f"Shape of predicted output : {logits.shape}")
```

```
Shape of target tokens : torch.Size([2, 4])
Shape of predicted output : torch.Size([2, 4, 50257])
```


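```python
# Flatten to (B*T, V) logits and (B*T,) targets, then compute the mean cross-entropy
logits_flatten = logits.flatten(start_dim=0, end_dim=1)
targets_flatten = target_tokens.flatten(start_dim=0)

loss = torch.nn.functional.cross_entropy(logits_flatten, targets_flatten)

print(f"Loss : {loss}")
```

```
Loss : 10.874760627746582
```

This matches the manual result up to float32 rounding in the last digits, since the fused log-softmax route performs the operations in a different order than softmax followed by log.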
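Because `F.cross_entropy` takes raw logits, it can compute log-probabilities directly with the log-sum-exp trick instead of taking `log(softmax(...))`, which avoids overflow and underflow. The sketch below mirrors that idea in plain Python; it illustrates the math and is not PyTorch's actual implementation. The helper names and toy logits are assumptions for illustration.

```python
import math

def log_softmax(scores):
    """log(softmax(scores)) for one row, via the log-sum-exp trick."""
    m = max(scores)
    lse = m + math.log(sum(math.exp(s - m) for s in scores))
    return [s - lse for s in scores]

def cross_entropy(logit_rows, targets):
    """Mean cross-entropy from raw logits (N, V) and target IDs (N,)."""
    losses = [-log_softmax(row)[t] for row, t in zip(logit_rows, targets)]
    return sum(losses) / len(losses)

rows = [[2.0, 1.0, 0.1],
        [0.5, 0.5, 0.5]]
targets = [0, 2]
print(round(cross_entropy(rows, targets), 4))  # 0.7578

# Extreme logits: math.exp(1000.0) would raise OverflowError in an unshifted
# softmax, but the log-sum-exp version stays finite
big = [[1000.0, 0.0]]
print(cross_entropy(big, [0]))  # 0.0
print(cross_entropy(big, [1]))  # 1000.0
```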


### **Perplexity Score**

Perplexity is a widely used metric to evaluate language models.  
It provides a more interpretable measure of how *uncertain* the model is when predicting the next token.

---

#### **Definition**

$$
\text{Perplexity} = e^{\text{loss}}
$$
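Here `loss` is the mean cross-entropy per token measured with the natural log, so perplexity is the exponential of the average negative log-probability of the target tokens. A plain-Python sketch (the `perplexity` helper and the toy probabilities are illustrative assumptions): a model that spreads probability uniformly over GPT-2's 50,257 tokens has perplexity equal to the vocabulary size, while a model that gives the correct tokens high probability has perplexity close to 1.

```python
import math

def perplexity(target_probs):
    """exp of the mean negative natural-log probability of the target tokens."""
    mean_nll = -sum(math.log(p) for p in target_probs) / len(target_probs)
    return math.exp(mean_nll)

V = 50257                               # GPT-2 vocabulary size
uniform = [1.0 / V] * 8                 # every target gets probability 1/V
print(round(perplexity(uniform)))       # 50257
print(round(math.log(V), 3))            # 10.825, the matching loss

confident = [0.9, 0.8, 0.95, 0.7]       # toy probabilities of the correct tokens
print(round(perplexity(confident), 2))  # 1.2
```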

- Lower perplexity → **model is more confident and better at prediction**
- Higher perplexity → **model is unsure and closer to guessing randomly**

---

#### **Example**

The cross-entropy loss computed above is about 10.87.


Then perplexity is computed as:

$$
\text{Perplexity} = e^{10.87} \approx 52575
$$

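This value means:

- The model is as uncertain as if it had to choose the next token **uniformly at random**
  from **~52,575 equally likely candidates**.
- That is slightly more than GPT-2's entire vocabulary: a model that assigns probability $1/50257$ to every token has loss $\ln(50257) \approx 10.82$ and perplexity exactly 50,257. The random logits used here are therefore no better than (in fact marginally worse than) uniform guessing, as expected from an untrained model.
- High perplexity = poor prediction quality.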

---

#### **Interpretation**

- **Perplexity ↓** → Model is learning better token distributions
- **Perplexity ↑** → Model is confused, or the dataset/model capacity is insufficient
- Well-trained models (like trained GPTs) reach far lower perplexities on test sets, often in the low double digits, although the exact value depends heavily on the dataset and tokenizer.

---


