# 05 — Train the LSTM
## Training Loop for 5 Epochs

---


## 🎯 Concept Primer

### The Training Loop

Training a neural network follows this pattern:

```
FOR each epoch:
    FOR each batch in dataloader:
        1. Zero gradients (clear previous batch's gradients)
        2. Initialize states (h0, c0) for this batch
        3. Forward pass: get predictions
        4. Compute loss (how wrong are we?)
        5. Backward pass: compute gradients
        6. Optimizer step: update weights
    
    Print epoch loss
```

### Loss Function: CrossEntropyLoss

**CrossEntropyLoss** measures how far our predictions are from the true labels.

**Input shapes:**
- Logits: `[B*T, vocab_size]` — raw scores for each character
- Labels: `[B*T]` — true character IDs

**Lower loss = better predictions**
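To make "lower loss = better predictions" concrete, here is a minimal sketch comparing a uniform guess with a confident, correct prediction. The sizes and label values are made up for illustration and are not the notebook's real data.

```python
import math
import torch
import torch.nn as nn

criterion = nn.CrossEntropyLoss()

# Hypothetical toy sizes: 6 positions stand in for B*T, 5 for vocab_size.
num_positions, toy_vocab_size = 6, 5
labels = torch.tensor([0, 1, 2, 3, 4, 0])                      # [B*T] true IDs

# All-zero logits mean every character is considered equally likely.
uniform_logits = torch.zeros(num_positions, toy_vocab_size)    # [B*T, vocab_size]
uniform_loss = criterion(uniform_logits, labels)

# Put a large raw score on the correct ID at every position.
confident_logits = torch.zeros(num_positions, toy_vocab_size)
confident_logits[torch.arange(num_positions), labels] = 10.0
confident_loss = criterion(confident_logits, labels)

print(uniform_loss.shape)  # torch.Size([]): the loss is a scalar
print(f"uniform guess:      {uniform_loss.item():.4f}")
print(f"ln(vocab_size):     {math.log(toy_vocab_size):.4f}")
print(f"confident, correct: {confident_loss.item():.4f}")
```

A uniform guess scores exactly `ln(vocab_size)`, so an untrained model should start somewhere near that value, and training should push the loss below it.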

### Optimizer: Adam

**Adam** is an adaptive learning rate optimizer. It:
- Adjusts learning rate per parameter
- Uses momentum for smoother updates
- Works well with default settings

**Learning rate = 0.015**: Scales the size of each weight update. This is higher than Adam's default of 0.001.
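The per-parameter adaptation can be seen in a single step. The sketch below uses a hypothetical three-element parameter and a toy loss, not the LSTM. On the very first step, Adam's bias-corrected update works out to roughly `lr * sign(gradient)`, so each entry moves by about the same amount even though the gradients differ greatly in size.

```python
import torch
from torch.optim import Adam

lr = 0.015
# Hypothetical toy parameter, chosen so the gradients differ a lot in magnitude.
w = torch.tensor([1.0, -2.0, 0.5], requires_grad=True)
optimizer = Adam([w], lr=lr)

before = w.detach().clone()

optimizer.zero_grad()
loss = 100.0 * (w ** 2).sum()   # gradient is 200 * w
loss.backward()
optimizer.step()

step = w.detach() - before

print("gradient:", w.grad)      # magnitudes 200, 400, 100
print("update:  ", step)        # each entry is close to -lr * sign(gradient)
```

Later steps depend on the running averages Adam keeps, so the update is no longer this simple, but `lr` still bounds the rough scale of each change.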

### Why Re-initialize States Per Batch?

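Two options:
1. **Re-init per batch** (simpler): Each batch starts from zero states and is treated as independent
2. **Carry states across batches** (complex): Requires detaching the states from the previous batch's computation graph, and only makes sense when consecutive batches hold consecutive text

We use option 1 for simplicity. It also matches our `DataLoader`, which uses `shuffle=True`, so consecutive batches are unrelated pieces of text.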
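For comparison, option 2 might look roughly like the sketch below. It uses a bare `nn.LSTM` with toy sizes, random inputs, and a placeholder loss, all of which are assumptions for illustration. This notebook does not train this way.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Toy sizes for illustration only (not the notebook's 48/96).
input_size, hidden_size = 4, 8
batch_size, seq_len = 2, 5
lstm = nn.LSTM(input_size, hidden_size, batch_first=True)

# Start from zeros once, then carry the states from batch to batch.
states = (
    torch.zeros(1, batch_size, hidden_size),
    torch.zeros(1, batch_size, hidden_size),
)

for step in range(3):
    x = torch.randn(batch_size, seq_len, input_size)  # stand-in for consecutive chunks

    # Detach so backward() stops at this batch instead of reaching into earlier ones.
    states = tuple(s.detach() for s in states)

    out, states = lstm(x, states)
    loss = out.pow(2).mean()                          # placeholder loss

    lstm.zero_grad()
    loss.backward()

print(states[0].shape)  # torch.Size([1, 2, 8])
```

Carrying states also requires every batch to have the same size and the batches to arrive in text order, which is part of why re-initializing is the simpler choice here.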

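### What Breaks If We Skip a Step?

- No `zero_grad()` = gradients from earlier batches add to the current ones
- Wrong label shape = loss computation raises a shape error
- No `backward()` = no gradients are computed, so the optimizer has nothing to apply
- No optimizer step = weights never change, so loss never improves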
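The `zero_grad()` point is easy to verify on a single hypothetical weight. In this small sketch the true gradient is always 3, so any other value in `.grad` comes from accumulation.

```python
import torch

w = torch.tensor([2.0], requires_grad=True)  # hypothetical single weight

def toy_loss():
    # d(loss)/dw is always 3, whatever the value of w.
    return (3.0 * w).sum()

toy_loss().backward()
first = w.grad.clone()

# Second backward without clearing: the new gradient is added to the old one.
toy_loss().backward()
accumulated = w.grad.clone()

# Clear the stored gradient (optimizer.zero_grad() does this for all parameters).
w.grad.zero_()
toy_loss().backward()
after_reset = w.grad.clone()

print(first, accumulated, after_reset)
```

The middle value is twice the true gradient, which is what every batch after the first would see if the reset were skipped.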

### Shapes During Training

| Step | Shape |
|------|-------|
| **Batch features** | `[36, 48]` |
| **Batch labels** | `[36, 48]` |
| **Labels flattened** | `[36*48] = [1728]` |
| **Logits** | `[1728, vocab_size]` |
| **Loss** | scalar |
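The table can be checked by pushing random IDs through layers of the same sizes as the model in this series. The `vocab_size` of 60 below is an assumption for illustration, since the real value depends on the text.

```python
import torch
import torch.nn as nn

B, T = 36, 48
vocab_size = 60                  # assumed here; the real value comes from the data
embedding_dim, hidden_size = 48, 96

embedding = nn.Embedding(vocab_size, embedding_dim)
lstm = nn.LSTM(embedding_dim, hidden_size, batch_first=True)
fc = nn.Linear(hidden_size, vocab_size)

# Random IDs stand in for a real batch.
features = torch.randint(0, vocab_size, (B, T))
labels = torch.randint(0, vocab_size, (B, T))
states = (torch.zeros(1, B, hidden_size), torch.zeros(1, B, hidden_size))

embedded = embedding(features)                  # [36, 48, 48]
lstm_out, _ = lstm(embedded, states)            # [36, 48, 96]
logits = fc(lstm_out).view(-1, vocab_size)      # [1728, 60]
labels_flat = labels.view(-1)                   # [1728]
loss = nn.CrossEntropyLoss()(logits, labels_flat)

for name, tensor in [("features", features), ("embedded", embedded),
                     ("lstm_out", lstm_out), ("logits", logits),
                     ("labels_flat", labels_flat), ("loss", loss)]:
    print(f"{name:12s} {tuple(tensor.shape)}")
```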

---


## ✅ Objectives

By the end of this notebook, you should:

- [ ] Load all components from previous notebooks (data, model)
- [ ] Instantiate `nn.CrossEntropyLoss()` as the loss function
- [ ] Create `Adam` optimizer with `lr=0.015`
- [ ] Implement the training loop for 5 epochs
- [ ] Print loss per epoch and observe it decreasing
- [ ] Save the trained model weights

---


## 🎓 Acceptance Criteria

**You pass this notebook when:**

✅ Training runs for 5 epochs without errors  
✅ Loss prints after each epoch  
✅ Loss generally trends downward (not always monotonic, but overall lower)  
✅ Model weights are saved to disk

---


## 📝 TODO 0: Setup — Imports and Data/Model Loading

**Load everything from previous notebooks**


In [None]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torch.optim import Adam

# === Load Data (from notebooks 01-02) ===
with open('../datasets/frankenstein.txt', 'r', encoding='utf-8') as f:
    frankenstein = f.read()
    
first_letter_text = frankenstein[1380:8230]
tokenized_text = list(first_letter_text)
unique_char_tokens = sorted(set(tokenized_text))
c2ix = {char: idx for idx, char in enumerate(unique_char_tokens)}
ix2c = {idx: char for char, idx in c2ix.items()}
vocab_size = len(c2ix)
tokenized_id_text = [c2ix[char] for char in tokenized_text]

print(f"Data loaded: {len(tokenized_id_text)} IDs, vocab_size={vocab_size}")

# === Define Dataset (from notebook 03) ===
class TextDataset(Dataset):
    def __init__(self, tokenized_ids, seq_length):
        self.ids = tokenized_ids
        self.seq_length = seq_length
    
    def __len__(self):
        return len(self.ids) - self.seq_length
    
    def __getitem__(self, idx):
        features = self.ids[idx : idx + self.seq_length]
        labels = self.ids[idx + 1 : idx + self.seq_length + 1]
        return (
            torch.tensor(features, dtype=torch.long),
            torch.tensor(labels, dtype=torch.long)
        )

# === Create Dataset & DataLoader ===
dataset = TextDataset(tokenized_id_text, seq_length=48)
dataloader = DataLoader(dataset, batch_size=36, shuffle=True)

print(f"Dataset: {len(dataset)} samples")
print(f"DataLoader: {len(dataloader)} batches per epoch")

# === Define Model (from notebook 04) ===
class CharacterLSTM(nn.Module):
    def __init__(self, vocab_size, embedding_dim=48, hidden_size=96):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.lstm = nn.LSTM(embedding_dim, hidden_size, batch_first=True)
        self.fc = nn.Linear(hidden_size, vocab_size)
        self.hidden_size = hidden_size
    
    def forward(self, x, states):
        embedded = self.embedding(x)
        lstm_out, new_states = self.lstm(embedded, states)
        logits = self.fc(lstm_out)
        logits_flat = logits.view(-1, logits.size(-1))
        return logits_flat, new_states
    
    def init_state(self, batch_size):
        h0 = torch.zeros(1, batch_size, self.hidden_size)
        c0 = torch.zeros(1, batch_size, self.hidden_size)
        return (h0, c0)

# === Instantiate Model ===
lstm_model = CharacterLSTM(vocab_size)
print(f"\nModel instantiated with {sum(p.numel() for p in lstm_model.parameters()):,} parameters")


## 📝 TODO 1: Define Loss and Optimizer

**Hint:**  
Use `nn.CrossEntropyLoss()` and `Adam(model.parameters(), lr=0.015)`.

**Steps:**
1. Create loss function: `criterion = nn.CrossEntropyLoss()`
2. Create optimizer: `optimizer = Adam(lstm_model.parameters(), lr=0.015)`

**Why CrossEntropyLoss?**  
It combines log-softmax and negative log-likelihood in one call, so the model outputs raw logits with no softmax layer of its own. That fits classification tasks such as predicting the next character.

**Why lr=0.015?**  
Experimentation shows this works well for this small dataset. Too high = unstable, too low = slow learning.
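The "softmax + negative log likelihood" description can be confirmed numerically. This is a minimal sketch on random toy logits (the sizes are arbitrary) comparing `nn.CrossEntropyLoss` with the two steps done separately.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

# Arbitrary toy sizes: 8 positions, 5 possible characters.
logits = torch.randn(8, 5)
labels = torch.randint(0, 5, (8,))

# One call, as used in this notebook.
ce = nn.CrossEntropyLoss()(logits, labels)

# The same computation in two explicit steps.
log_probs = F.log_softmax(logits, dim=1)
two_step = F.nll_loss(log_probs, labels)

# And by hand: minus the mean log-probability given to the true character.
by_hand = -log_probs[torch.arange(8), labels].mean()

print(ce.item(), two_step.item(), by_hand.item())
```

All three values agree up to floating-point error, which is why the model's `forward` returns logits rather than probabilities.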


In [None]:
# TODO: Define loss function and optimizer
# criterion = nn.CrossEntropyLoss()
# optimizer = Adam(lstm_model.parameters(), lr=0.015)

criterion = None  # Replace this line
optimizer = None  # Replace this line

if criterion is not None and optimizer is not None:
    print("Loss function: CrossEntropyLoss")
    print("Optimizer: Adam with lr=0.015")


## 📝 TODO 2: Implement the Training Loop

**Hint:**  
Nested loops: outer for epochs, inner for batches.

**Structure:**
```python
for epoch in range(num_epochs):
    epoch_loss = 0
    
    for batch_features, batch_labels in dataloader:
        # 1. Zero gradients
        # 2. Initialize states
        # 3. Forward pass
        # 4. Compute loss (flatten labels first!)
        # 5. Backward
        # 6. Optimizer step
        # 7. Accumulate loss
    
    # Print average epoch loss
```

**Key details:**
- **Flatten labels**: `batch_labels.view(-1)` → `[B*T]`
- **Get batch size**: `batch_features.size(0)` (the last batch can be smaller than 36, so don't hard-code it)
- **Accumulate loss**: Use `loss.item()` to get scalar value


In [None]:
# TODO: Training loop for 5 epochs

num_epochs = 5

# for epoch in range(num_epochs):
#     epoch_loss = 0
#     
#     for batch_features, batch_labels in dataloader:
#         # Get batch size
#         batch_size = batch_features.size(0)
#         
#         # 1. Zero gradients
#         optimizer.zero_grad()
#         
#         # 2. Initialize states for this batch
#         states = lstm_model.init_state(batch_size)
#         
#         # 3. Forward pass
#         logits, new_states = lstm_model(batch_features, states)
#         
#         # 4. Compute loss
#         # Flatten labels to [B*T]
#         labels_flat = batch_labels.view(-1)
#         loss = criterion(logits, labels_flat)
#         
#         # 5. Backward pass
#         loss.backward()
#         
#         # 6. Optimizer step
#         optimizer.step()
#         
#         # 7. Accumulate loss
#         epoch_loss += loss.item()
#     
#     # Print average loss for this epoch
#     avg_loss = epoch_loss / len(dataloader)
#     print(f"Epoch {epoch+1}/{num_epochs}, Loss: {avg_loss:.4f}")

# Your code here
print("\nTraining complete!")


## 📝 TODO 3: Save the Trained Model

**Hint:**  
Use `torch.save(model.state_dict(), path)`.

**Steps:**
1. Save model weights: `torch.save(lstm_model.state_dict(), 'trained_lstm_model.pth')`
2. Print confirmation

**Why save?**  
So you can load the trained weights in notebook 06 for text generation without retraining.
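The save-and-reload round trip can be sketched on a small stand-in model. The `nn.Linear` layer and the temporary file name below are placeholders for illustration. With the real model you would build a `CharacterLSTM` with the same arguments before loading.

```python
import os
import tempfile
import torch
import torch.nn as nn

model = nn.Linear(4, 2)   # stand-in for lstm_model

with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, "demo_weights.pth")   # hypothetical file name

    # Save only the weights, not the whole Python object.
    torch.save(model.state_dict(), path)

    # Loading needs a model with the same architecture to receive the weights.
    restored = nn.Linear(4, 2)
    restored.load_state_dict(torch.load(path))
    restored.eval()   # switch to inference mode before generating

same_weights = torch.equal(model.weight, restored.weight)
same_bias = torch.equal(model.bias, restored.bias)
print(same_weights, same_bias)
```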


In [None]:
# TODO: Save the trained model
# torch.save(lstm_model.state_dict(), 'trained_lstm_model.pth')
# print("Model saved to 'trained_lstm_model.pth'")

# Your code here


## 💭 Reflection Prompts

**Write your observations:**

1. **Loss trend**: Did your loss decrease over epochs? By how much?

2. **What is the loss measuring**: In plain words, what does the loss value represent?

3. **Why zero_grad()**: What would happen if you forgot to call `optimizer.zero_grad()`?

4. **Batch size impact**: How would changing batch_size from 36 to 12 or 72 affect training?

5. **Learning rate**: What would happen with lr=0.0001 (too small) or lr=0.5 (too large)?

6. **Epoch count**: Is 5 epochs enough? How would you decide?

---


## 🚀 Next Steps

Once you've completed training and loss is decreasing:

➡️ **Move to Notebook 06**: Generate text with the trained model

---

## 📌 Key Takeaways

- ✅ Training loop: zero_grad → init states → forward → loss → backward → step
- ✅ CrossEntropyLoss measures prediction error
- ✅ Adam optimizer updates weights to minimize loss
- ✅ States are re-initialized per batch for simplicity
- ✅ Labels must be flattened to `[B*T]` to match logits of shape `[B*T, vocab_size]`
- ✅ Loss should trend downward (not always monotonic)
- ✅ Save model weights to reuse them later

---

*Next up: Using the trained model to generate Frankenstein-style text!*
