# 04 — LSTM Model Architecture
## Build the Character-Level Language Model

---


## 🎯 Concept Primer

### LSTM Architecture Overview

Our model has **three layers**:

```
Input (char IDs) [B, T]
        ↓
1. EMBEDDING: IDs → Dense Vectors
   [B, T] → [B, T, embedding_dim]
        ↓
2. LSTM: Process Sequence with Memory
   [B, T, embedding_dim] → [B, T, hidden_size]
   (also updates hidden state h and cell state c)
        ↓
3. LINEAR: Project to Vocabulary Size
   [B, T, hidden_size] → [B, T, vocab_size]
        ↓
Output (logits) [B*T, vocab_size] (reshaped for loss)
```

### Why Embedding?

**Character IDs are categorical** (0, 1, 2, ..., vocab_size-1). They have no inherent order or relationship.

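**Embeddings** replace each sparse ID with a dense, trainable vector. Because the vectors are learned, characters that play similar roles (e.g., vowels vs. consonants, or digits) can end up with similar vectors.

An embedding is a per-character lookup. The same character always gets the same vector, whatever its position, so positional and contextual patterns are learned by the LSTM layer, not by the embedding.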
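The lookup-table view can be checked directly. Indexing `nn.Embedding` with IDs gives the same result as multiplying one-hot vectors by its weight matrix, only without building the one-hot vectors. The sizes match this notebook's defaults (60 characters, 48 dimensions), and the IDs are arbitrary:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
vocab_size, embedding_dim = 60, 48
embedding = nn.Embedding(vocab_size, embedding_dim)

ids = torch.tensor([[3, 7, 7, 42]])  # [B=1, T=4] arbitrary character IDs

# Direct lookup: row i of embedding.weight for each ID i
looked_up = embedding(ids)  # [1, 4, 48]

# Equivalent (but wasteful) one-hot matrix multiply
one_hot = F.one_hot(ids, num_classes=vocab_size).float()  # [1, 4, 60]
via_matmul = one_hot @ embedding.weight                   # [1, 4, 48]

print(looked_up.shape)                                # torch.Size([1, 4, 48])
print(torch.allclose(looked_up, via_matmul))          # True
print(torch.equal(looked_up[0, 1], looked_up[0, 2]))  # True: same ID, same vector
```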

### LSTM Gates Refresher

- **Forget gate**: What to remove from cell state
- **Input gate**: What new info to store
- **Output gate**: What to output from cell state

These gates let the LSTM keep or discard information across many time steps, which is what allows it to capture long-range dependencies (e.g., matching quotes, sentence structure).
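The sketch below makes the gates concrete. It computes one LSTM step by hand from an `nn.LSTMCell`'s weights and compares the result with the cell's own output. It relies on PyTorch's documented weight layout, where the four gate blocks are stacked in the order input, forget, cell candidate (`g`), output. The input is random and stands in for one embedded character.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
embedding_dim, hidden_size = 48, 96
cell = nn.LSTMCell(embedding_dim, hidden_size)

x = torch.randn(1, embedding_dim)   # one embedded character (random stand-in)
h = torch.zeros(1, hidden_size)     # previous hidden state
c = torch.zeros(1, hidden_size)     # previous cell state

# All four gate pre-activations at once: [1, 4 * hidden_size]
gates = x @ cell.weight_ih.T + cell.bias_ih + h @ cell.weight_hh.T + cell.bias_hh
i, f, g, o = gates.chunk(4, dim=1)  # PyTorch order: input, forget, cell candidate, output

i = torch.sigmoid(i)  # input gate: how much new info to write
f = torch.sigmoid(f)  # forget gate: how much of the old cell state to keep
g = torch.tanh(g)     # candidate values to write
o = torch.sigmoid(o)  # output gate: how much of the cell to expose as h

c_new = f * c + i * g          # update long-term memory
h_new = o * torch.tanh(c_new)  # new short-term output

h_ref, c_ref = cell(x, (h, c))
print(torch.allclose(h_new, h_ref, atol=1e-5), torch.allclose(c_new, c_ref, atol=1e-5))
```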

### Hidden State vs. Cell State

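- **Cell state (`c`)**: Long-term memory (the "conveyor belt"), changed only through the forget and input gates
- **Hidden state (`h`)**: Short-term output at each time step (what we pass to the next layer)

The states returned by `nn.LSTM` both have shape `[num_layers * num_directions, batch_size, hidden_size]`, i.e. `[1, B, H]` for our single-layer, unidirectional LSTM.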
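The difference is easy to see in what `nn.LSTM` returns. Its output holds `h` for every time step, but `(h, c)` are returned only for the final step. This sketch uses random tensors in place of real embeddings, with this notebook's sizes:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
B, T, E, H = 36, 48, 48, 96
lstm = nn.LSTM(E, H, batch_first=True)

x = torch.randn(B, T, E)   # stand-in for embedded characters
out, (h_n, c_n) = lstm(x)  # states default to zeros when omitted

print(out.shape)  # torch.Size([36, 48, 96]): h at every time step
print(h_n.shape)  # torch.Size([1, 36, 96]): h at the last step only
print(c_n.shape)  # torch.Size([1, 36, 96]): c at the last step only

# For a single-layer, unidirectional LSTM, h_n is the last time slice of out
print(torch.allclose(out[:, -1, :], h_n[0]))
```

The cell state never appears in `out`. It stays inside the LSTM and is only exposed through the returned state tuple.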

### What Breaks If We Skip This?

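- No embedding = the LSTM sees raw integer IDs, whose numeric order is meaningless, so it cannot easily learn character relationships
- Wrong shapes = runtime errors during the forward pass
- No state initialization = nothing to pass as `states` to `forward` (PyTorch's `nn.LSTM` falls back to zeros when given `None`, but an explicit `init_state` keeps the starting state visible and reusable)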

### Shapes Summary

| Component | Input | Output |
|-----------|-------|--------|
| **Embedding** | `[B, T]` | `[B, T, E]` |
| **LSTM** | `[B, T, E]` | `[B, T, H]` (+ states) |
| **Linear** | `[B, T, H]` | `[B, T, V]` |
| **Reshape** | `[B, T, V]` | `[B*T, V]` |

Where:
- B = batch_size (36)
- T = seq_length (48)
- E = embedding_dim (48)
- H = hidden_size (96)
- V = vocab_size (~50-80)
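You can confirm the table by pushing random IDs through standalone copies of the three layers. The sizes are the ones listed above. `V = 60` is an assumed vocabulary size from the 50-80 range.

```python
import torch
import torch.nn as nn

B, T, E, H, V = 36, 48, 48, 96, 60  # V = 60 is an assumed vocab size

embedding = nn.Embedding(V, E)
lstm = nn.LSTM(E, H, batch_first=True)
fc = nn.Linear(H, V)

x = torch.randint(0, V, (B, T))  # [B, T] character IDs
emb = embedding(x)               # [B, T, E]
out, (h, c) = lstm(emb)          # [B, T, H], states [1, B, H]
logits = fc(out)                 # [B, T, V]
flat = logits.reshape(-1, V)     # [B*T, V]

for name, t in [("input", x), ("embedding", emb), ("lstm", out),
                ("linear", logits), ("reshape", flat), ("h", h), ("c", c)]:
    print(f"{name:>9}: {tuple(t.shape)}")
```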

---


## ✅ Objectives

By the end of this notebook, you should:

- [ ] Define `CharacterLSTM` class inheriting from `nn.Module`
- [ ] Implement `__init__` with Embedding, LSTM, and Linear layers
- [ ] Implement `forward(x, states)` method that processes a batch
- [ ] Implement `init_state(batch_size)` to create initial (h0, c0)
- [ ] Instantiate the model and print its architecture
- [ ] Test forward pass with a fake batch to verify shapes

---


## 🎓 Acceptance Criteria

**You pass this notebook when:**

✅ `print(model)` shows all three layers  
✅ Forward pass on fake batch `[36, 48]` returns logits `[36*48, vocab_size]`  
✅ Forward pass also returns updated states `(h, c)`, each with shape `[1, 36, 96]`  
✅ You can explain each layer's purpose

---


## 📝 TODO 0: Imports

**Import PyTorch modules**


```python
import torch
import torch.nn as nn

# vocab_size comes from the vocabulary built in earlier notebooks.
# It is hard-coded here so this notebook runs on its own; load the real value in practice.
vocab_size = 60  # Approximate; adjust after running notebook 02

print(f"PyTorch version: {torch.__version__}")
print(f"Using vocab_size: {vocab_size}")
```

```
PyTorch version: 2.8.0
Using vocab_size: 60
```


## 📝 TODO 1: Define Class and `__init__`

**Hint:**  
Define layers as `self` attributes.

**Steps:**
1. Class `CharacterLSTM(nn.Module)`
2. `__init__(self, vocab_size, embedding_dim=48, hidden_size=96)`
3. Call `super().__init__()`
4. Create three layers:
   - `self.embedding = nn.Embedding(vocab_size, embedding_dim)`
   - `self.lstm = nn.LSTM(embedding_dim, hidden_size, batch_first=True)`
   - `self.fc = nn.Linear(hidden_size, vocab_size)`
5. Store `self.hidden_size` for later use

**Why `batch_first=True`?**  
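Makes the input and output tensors `[B, T, ...]` instead of the default `[T, B, ...]`, matching how our batches are laid out. It does not affect the hidden and cell states, which stay `[num_layers, B, H]` either way.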
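The sketch below shows that `batch_first` only changes tensor layout. It runs the same random batch through two LSTMs that differ only in that flag. The weights are copied from one to the other so the outputs can be compared directly.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
B, T, E, H = 36, 48, 48, 96

lstm_bf = nn.LSTM(E, H, batch_first=True)
lstm_tf = nn.LSTM(E, H, batch_first=False)
lstm_tf.load_state_dict(lstm_bf.state_dict())  # same weights, different layout

x = torch.randn(B, T, E)                         # batch-first input
out_bf, (h_bf, _) = lstm_bf(x)                   # out_bf: [B, T, H]
out_tf, (h_tf, _) = lstm_tf(x.transpose(0, 1))   # expects [T, B, E]; out_tf: [T, B, H]

print(out_bf.shape, out_tf.shape)  # layouts differ
print(h_bf.shape, h_tf.shape)      # state shapes are identical
print(torch.allclose(out_bf, out_tf.transpose(0, 1), atol=1e-5))
```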

**Hyperparameters:**
- `embedding_dim=48`: Dense vector size for each character
- `hidden_size=96`: Size of the LSTM's hidden and cell state vectors (its memory capacity)


```python
# TODO: Define CharacterLSTM with __init__

class CharacterLSTM(nn.Module):
    def __init__(self, vocab_size, embedding_dim=48, hidden_size=96):
        super().__init__()  # initialize nn.Module internals before adding layers

        # Char IDs -> dense vectors: [B, T] -> [B, T, E]
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        # Sequence processing with memory: [B, T, E] -> [B, T, H]
        self.lstm = nn.LSTM(embedding_dim, hidden_size, batch_first=True)
        # Per-step scores over the vocabulary: [B, T, H] -> [B, T, V]
        self.fc = nn.Linear(hidden_size, vocab_size)

        # Stored so init_state can build correctly sized (h0, c0)
        self.hidden_size = hidden_size

    # We'll add forward and init_state next
```


## 📝 TODO 2: Implement `forward` Method

**Hint:**  
Chain the three layers together.

**Steps:**
1. `def forward(self, x, states):`
   - `x`: input batch `[B, T]` of character IDs
   - `states`: tuple `(h, c)` of hidden/cell states
2. **Embedding**: `embedded = self.embedding(x)`  → `[B, T, E]`
3. **LSTM**: `lstm_out, new_states = self.lstm(embedded, states)`  → `[B, T, H]`
4. **Linear**: `logits = self.fc(lstm_out)`  → `[B, T, V]`
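5. **Reshape**: `logits_flat = logits.view(-1, logits.size(-1))`  → `[B*T, V]` (use `logits.size(-1)` rather than a global `vocab_size`, so the method depends only on its own layers)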
6. **Return**: `(logits_flat, new_states)`

**Why reshape?**  
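`nn.CrossEntropyLoss` expects the class scores in dimension 1: either 2D logits `[N, C]` (N = number of samples, C = number of classes) or `[N, C, d1, ...]`. Our logits are `[B, T, V]` with the classes in the last dimension, so flattening to `[B*T, V]` (and the targets to `[B*T]`) turns every position into one sample.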
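The sketch below uses random logits and targets as stand-ins for real model output and next-character labels. It shows the flattening and checks that it gives the same mean loss as the other layout `CrossEntropyLoss` accepts, `[B, V, T]` with classes in dimension 1:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
B, T, V = 36, 48, 60
logits = torch.randn(B, T, V)          # stand-in for model output before flattening
targets = torch.randint(0, V, (B, T))  # next-character IDs

loss_fn = nn.CrossEntropyLoss()

# Flatten: every (batch, time) position becomes one sample
loss_flat = loss_fn(logits.reshape(-1, V), targets.reshape(-1))  # [B*T, V] vs [B*T]

# Equivalent: keep 3D but move classes to dim 1
loss_3d = loss_fn(logits.permute(0, 2, 1), targets)              # [B, V, T] vs [B, T]

print(loss_flat.item(), loss_3d.item())
print(torch.allclose(loss_flat, loss_3d))
```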

**Why return states?**  
During generation, we'll need to carry states across time steps.
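To see why the states matter, the sketch below feeds one random sequence to a bare `nn.LSTM` in two halves. The states returned by the first call are passed into the second. Chunked processing with carried states matches processing the whole sequence at once, while restarting the second half from zeros does not.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
E, H = 48, 96
lstm = nn.LSTM(E, H, batch_first=True)
x = torch.randn(1, 10, E)  # one sequence of 10 embedded characters (random stand-in)

with torch.no_grad():
    full_out, _ = lstm(x)                   # all 10 steps in one call

    first_out, states = lstm(x[:, :5])      # steps 0-4
    second_out, _ = lstm(x[:, 5:], states)  # steps 5-9, carrying states
    fresh_out, _ = lstm(x[:, 5:])           # steps 5-9, restarting from zeros

chunked = torch.cat([first_out, second_out], dim=1)
print(torch.allclose(full_out, chunked, atol=1e-5))           # True: states carried
print(torch.allclose(full_out[:, 5:], fresh_out, atol=1e-5))  # False: context lost
```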


```python
# TODO: Add forward method to CharacterLSTM

class CharacterLSTM(nn.Module):
    def __init__(self, vocab_size, embedding_dim=48, hidden_size=96):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.lstm = nn.LSTM(embedding_dim, hidden_size, batch_first=True)
        self.fc = nn.Linear(hidden_size, vocab_size)
        self.hidden_size = hidden_size

    def forward(self, x, states):
        embedded = self.embedding(x)                        # [B, T] -> [B, T, E]
        lstm_out, new_states = self.lstm(embedded, states)  # [B, T, E] -> [B, T, H]
        logits = self.fc(lstm_out)                          # [B, T, H] -> [B, T, V]
        logits_flat = logits.view(-1, logits.size(-1))      # [B, T, V] -> [B*T, V] for CrossEntropyLoss
        return logits_flat, new_states                      # new_states = (h, c), each [1, B, H]

    # We'll add init_state next
```


## 📝 TODO 3: Implement `init_state` Method

**Hint:**  
Create zero tensors for initial hidden and cell states.

**Steps:**
1. `def init_state(self, batch_size):`
2. Create `h0 = torch.zeros(1, batch_size, self.hidden_size)`
   - Shape: `[num_layers, batch_size, hidden_size]`
   - We have 1 LSTM layer, so first dim = 1
3. Create `c0 = torch.zeros(1, batch_size, self.hidden_size)`
4. Return `(h0, c0)`

**Why zeros?**  
At the start of training/generation, we have no prior context, so we initialize to zeros.
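Zeros are also what `nn.LSTM` uses when no states are passed at all. The sketch below checks that explicit zero states, as `init_state` builds them, and omitted states give the same output, using this notebook's sizes:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
B, T, E, H = 36, 48, 48, 96
lstm = nn.LSTM(E, H, batch_first=True)
x = torch.randn(B, T, E)  # stand-in for embedded characters

h0 = torch.zeros(1, B, H)
c0 = torch.zeros(1, B, H)

out_explicit, _ = lstm(x, (h0, c0))  # explicit zero states, as in init_state
out_default, _ = lstm(x)             # no states: PyTorch fills in zeros

print(torch.allclose(out_explicit, out_default))
```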

**Why shape `[1, B, H]`?**  
PyTorch LSTM expects states with shape `[num_layers * num_directions, batch, hidden_size]`.  
We have 1 layer and it is unidirectional → first dim = 1.
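The hard-coded `1` only holds for a single layer. It also produces CPU tensors even if the model has been moved to a GPU. One possible variant is sketched below as a hypothetical `CharacterLSTMStacked` class, which is not part of this notebook's required solution. Its `init_state` reads the layer count from `self.lstm.num_layers` and the device from the model's own parameters.

```python
import torch
import torch.nn as nn

class CharacterLSTMStacked(nn.Module):
    """Hypothetical variant of CharacterLSTM with a configurable number of LSTM layers."""

    def __init__(self, vocab_size, embedding_dim=48, hidden_size=96, num_layers=2):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.lstm = nn.LSTM(embedding_dim, hidden_size, num_layers=num_layers,
                            batch_first=True)
        self.fc = nn.Linear(hidden_size, vocab_size)
        self.hidden_size = hidden_size

    def forward(self, x, states):
        lstm_out, new_states = self.lstm(self.embedding(x), states)
        logits = self.fc(lstm_out)
        return logits.reshape(-1, logits.size(-1)), new_states

    def init_state(self, batch_size):
        # First dim = num_layers (unidirectional, so no extra factor of 2)
        shape = (self.lstm.num_layers, batch_size, self.hidden_size)
        # Create the states on the same device as the model's weights
        device = next(self.parameters()).device
        return (torch.zeros(shape, device=device), torch.zeros(shape, device=device))


model2 = CharacterLSTMStacked(vocab_size=60, num_layers=2)
states = model2.init_state(36)
logits, (h, c) = model2(torch.randint(0, 60, (36, 48)), states)
print(states[0].shape, logits.shape, h.shape)
```

With two layers, both states gain a leading dimension of 2, while the logits keep the same `[B*T, V]` shape.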


```python
# TODO: Complete CharacterLSTM with init_state

class CharacterLSTM(nn.Module):
    def __init__(self, vocab_size, embedding_dim=48, hidden_size=96):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.lstm = nn.LSTM(embedding_dim, hidden_size, batch_first=True)
        self.fc = nn.Linear(hidden_size, vocab_size)
        self.hidden_size = hidden_size

    def forward(self, x, states):
        embedded = self.embedding(x)
        lstm_out, new_states = self.lstm(embedded, states)
        logits = self.fc(lstm_out)
        logits_flat = logits.view(-1, logits.size(-1))
        return logits_flat, new_states

    def init_state(self, batch_size):
        # Zero hidden state (h0) and cell state (c0), each [num_layers=1, B, H]
        h0 = torch.zeros(1, batch_size, self.hidden_size)
        c0 = torch.zeros(1, batch_size, self.hidden_size)
        return (h0, c0)
```


## 📝 TODO 4: Instantiate the Model

**Steps:**
1. Create `model = CharacterLSTM(vocab_size)`
2. Print the model architecture with `print(model)`


```python
# Instantiate the model with the default embedding_dim=48 and hidden_size=96
model = CharacterLSTM(vocab_size)

print(model)
print(f"\nTotal parameters: {sum(p.numel() for p in model.parameters()):,}")
```

```
CharacterLSTM(
  (embedding): Embedding(60, 48)
  (lstm): LSTM(48, 96, batch_first=True)
  (fc): Linear(in_features=96, out_features=60, bias=True)
)

Total parameters: 64,764
```
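The total can be reproduced by hand. An `nn.LSTM` layer has four gates. Each gate has an input weight matrix (`H x E`), a recurrent weight matrix (`H x H`) and two bias vectors, because PyTorch stores `bias_ih` and `bias_hh` separately. That gives `4·H·(E + H) + 2·4·H` parameters per layer. With this notebook's sizes:

```python
import torch.nn as nn

V, E, H = 60, 48, 96

embedding_params = V * E                   # lookup table: 60 * 48
lstm_params = 4 * H * (E + H) + 2 * 4 * H  # 4 gates: W_ih, W_hh, b_ih, b_hh
linear_params = H * V + V                  # weight + bias

print(embedding_params, lstm_params, linear_params)    # 2880 56064 5820
print(embedding_params + lstm_params + linear_params)  # 64764

# Cross-check against PyTorch's own count for each layer
layers = [nn.Embedding(V, E), nn.LSTM(E, H, batch_first=True), nn.Linear(H, V)]
per_layer = [sum(p.numel() for p in layer.parameters()) for layer in layers]
print(per_layer)  # [2880, 56064, 5820]
```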


## 📝 TODO 5: Test Forward Pass

**Hint:**  
Create fake input and states, then call `model(x, states)`.

**Steps:**
1. Create fake input: `fake_batch = torch.randint(0, vocab_size, (36, 48))`
   - Random IDs, shape `[36, 48]`
2. Initialize states: `states = model.init_state(36)`
3. Forward pass: `logits, new_states = model(fake_batch, states)`
4. Print shapes:
   - `logits.shape` should be `[36*48, vocab_size]` = `[1728, vocab_size]`
   - `new_states[0].shape` (h) should be `[1, 36, 96]`
   - `new_states[1].shape` (c) should be `[1, 36, 96]`


```python
# Fake batch of 36 sequences, each 48 random character IDs
fake_batch = torch.randint(0, vocab_size, (36, 48))
states = model.init_state(36)  # zero (h0, c0), each [1, 36, 96]
logits, new_states = model(fake_batch, states)

print(f"Input shape: {fake_batch.shape}")
print(f"Logits shape: {logits.shape}")
print(f"New hidden state shape: {new_states[0].shape}")
print(f"New cell state shape: {new_states[1].shape}")
```

```
Input shape: torch.Size([36, 48])
Logits shape: torch.Size([1728, 60])
New hidden state shape: torch.Size([1, 36, 96])
New cell state shape: torch.Size([1, 36, 96])
```
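The acceptance criteria can also be checked with assertions rather than by eye. The sketch below repeats the class and model so it runs on its own. The final loss check is an extra, commonly used sanity check and not one of this notebook's criteria. An untrained model predicts roughly uniformly, so its cross-entropy on random targets should be close to ln(V), about 4.09 for V = 60.

```python
import math
import torch
import torch.nn as nn

class CharacterLSTM(nn.Module):
    # Same architecture as the cells above, repeated so this block runs standalone
    def __init__(self, vocab_size, embedding_dim=48, hidden_size=96):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.lstm = nn.LSTM(embedding_dim, hidden_size, batch_first=True)
        self.fc = nn.Linear(hidden_size, vocab_size)
        self.hidden_size = hidden_size

    def forward(self, x, states):
        lstm_out, new_states = self.lstm(self.embedding(x), states)
        logits = self.fc(lstm_out)
        return logits.reshape(-1, logits.size(-1)), new_states

    def init_state(self, batch_size):
        return (torch.zeros(1, batch_size, self.hidden_size),
                torch.zeros(1, batch_size, self.hidden_size))

torch.manual_seed(0)
vocab_size = 60
model = CharacterLSTM(vocab_size)
fake_batch = torch.randint(0, vocab_size, (36, 48))
fake_targets = torch.randint(0, vocab_size, (36, 48))  # random "next characters"

logits, (h, c) = model(fake_batch, model.init_state(36))

assert logits.shape == (36 * 48, vocab_size)
assert h.shape == (1, 36, 96) and c.shape == (1, 36, 96)

loss = nn.CrossEntropyLoss()(logits, fake_targets.reshape(-1))
print(f"initial loss {loss.item():.3f} vs ln(V) = {math.log(vocab_size):.3f}")
```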


## 💭 Reflection Prompts

**Write your observations:**

1. **Three layers**: What is the purpose of each layer (Embedding, LSTM, Linear)?

2. **LSTM gates**: In your own words, what do the forget, input, and output gates control?

3. **Hidden vs. Cell**: What's the difference between hidden state (h) and cell state (c)?

4. **Why reshape**: Why do we reshape from `[B, T, V]` to `[B*T, V]` before returning?

5. **Parameter count**: How many parameters does your model have? Where do most come from?

6. **batch_first**: What would happen if we set `batch_first=False` in the LSTM?

---

## 🚀 Next Steps

Once you've completed all TODOs and verified forward pass shapes:

➡️ **Move to Notebook 05**: Train the LSTM model

---

## 📌 Key Takeaways

- ✅ `nn.Embedding` converts sparse IDs to dense vectors
- ✅ `nn.LSTM` processes sequences with memory (h and c states)
- ✅ `nn.Linear` projects to vocabulary size for predictions
- ✅ Forward returns both logits and updated states
- ✅ `init_state` provides zero states for the start of a new sequence; during generation, the returned states are carried forward
- ✅ Reshaping to `[B*T, V]` prepares for CrossEntropyLoss

---

*Next up: Training the model with real data and watching loss decrease!*
