# 06 — Generate Text
## Sampling with the Trained LSTM

---


## 🎯 Concept Primer

### How Text Generation Works

Generation is an **autoregressive loop**:

```
1. Start with a prompt: "You will rejoice to hear"
2. Feed the prompt through the model → get logits for the next char
3. Pick the next char (greedy: argmax of logits)
4. Append the char to the generated sequence
5. Feed only that new char, plus the carried LSTM states (h, c) → get logits for the next char
6. Repeat steps 3 to 5 until we have 500 characters
```
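The six steps above can be sketched in plain Python. This is a minimal sketch, not the notebook's code: the four-character vocabulary and `stand_in_model` are hypothetical placeholders for the real vocabulary and the trained LSTM. The stand-in re-reads the whole ID list each step, whereas the LSTM instead reuses its carried states.

```python
# Hypothetical 4-character vocabulary, for illustration only.
VOCAB = ['a', 'b', 'c', ' ']
c2ix = {ch: i for i, ch in enumerate(VOCAB)}
ix2c = {i: ch for ch, i in c2ix.items()}


def stand_in_model(ids):
    """Hypothetical stand-in for the LSTM: one score (logit) per vocab entry.
    It always favours the vocab entry that follows the last ID."""
    logits = [0.0] * len(VOCAB)
    logits[(ids[-1] + 1) % len(VOCAB)] = 1.0
    return logits


def generate(prompt, num_chars):
    ids = [c2ix[ch] for ch in prompt]           # 1. start from the prompt
    for _ in range(num_chars):                  # 6. repeat num_chars times
        logits = stand_in_model(ids)            # 2./5. scores for the next char
        next_id = max(range(len(logits)), key=logits.__getitem__)  # 3. greedy pick
        ids.append(next_id)                     # 4. append; it feeds the next step
    return ''.join(ix2c[i] for i in ids)


print(repr(generate("ab", 6)))  # 'abc abc '
```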

### Generation vs. Training

| Aspect | Training | Generation |
|--------|----------|------------|
| **Goal** | Learn from data | Produce new text |
| **Mode** | `model.train()` | `model.eval()` |
| **Gradients** | Needed for backprop | Disabled with `torch.no_grad()` |
| **Input** | Real text batches | Prompt, then the model's own generated chars |
| **Output** | Loss | New characters |
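The two inference switches in the table do different jobs. A tiny `nn.LSTM` (toy sizes chosen for illustration) makes this visible: `eval()` only flips the module's `training` flag, while `torch.no_grad()` is what stops outputs from recording gradients.

```python
import torch
import torch.nn as nn

lstm = nn.LSTM(input_size=8, hidden_size=16, batch_first=True)
x = torch.randn(1, 5, 8)          # [batch, seq_len, features]

lstm.eval()                       # turns train-only behaviour (e.g. dropout) off
print(lstm.training)              # False

out_tracked, _ = lstm(x)          # eval() alone does NOT stop gradient tracking
print(out_tracked.requires_grad)  # True

with torch.no_grad():             # no_grad() stops building the autograd graph
    out_free, _ = lstm(x)
print(out_free.requires_grad)     # False
```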

### Greedy Sampling

**Argmax**: Always pick the most likely character.

```python
logits, states = model(next_input, states)  # forward returns (logits, states); logits: [1, vocab_size]
next_id = torch.argmax(logits).item()
```

**Pros**: Simple, deterministic  
**Cons**: Prone to repetitive loops; the same prompt always produces the same text

**Alternative**: Temperature sampling (adds randomness) — left as an extension.
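To see why greedy decoding tends to loop, here is a sketch with a toy bigram counter standing in for the LSTM (an assumption for illustration; the real model conditions on far more context). Argmax always makes the same choice from the same context, so with this bigram toy, once a character repeats, the whole path repeats.

```python
from collections import Counter, defaultdict

# Toy "model" (an assumption, not the LSTM): count which char follows which.
text = "the cat saw the hat"
successors = defaultdict(Counter)
for cur, nxt in zip(text, text[1:]):
    successors[cur][nxt] += 1


def greedy_generate(prompt, num_chars):
    out = prompt
    for _ in range(num_chars):
        counts = successors[out[-1]]
        # argmax over counts; ties go to the successor seen first in `text`
        out += max(counts, key=counts.get)
    return out


print(repr(greedy_generate("t", 12)))  # the cycle "the ca" keeps repeating
```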

### States in Generation

Unlike training, which processes whole batches of sequences, generation:
- Uses a batch size of 1
- Feeds the prompt once, then one character at a time
- **Carries states** (h, c) from step to step, so the model keeps its context
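The claim that carried states maintain context can be checked directly with PyTorch's `nn.LSTM` (random weights and toy sizes, for illustration). Stepping through a sequence one element at a time while passing `(h, c)` forward reproduces the full-sequence output. Resetting the states every step does not.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
hidden_size = 16
lstm = nn.LSTM(input_size=8, hidden_size=hidden_size, batch_first=True)
x = torch.randn(1, 10, 8)  # [batch=1, seq_len=10, features]


def zero_states():
    return (torch.zeros(1, 1, hidden_size), torch.zeros(1, 1, hidden_size))


with torch.no_grad():
    # (a) whole sequence in one call
    full_out, _ = lstm(x, zero_states())

    # (b) one step at a time, carrying (h, c) forward
    states = zero_states()
    carried = []
    for t in range(x.size(1)):
        out_t, states = lstm(x[:, t:t + 1, :], states)
        carried.append(out_t)
    carried_out = torch.cat(carried, dim=1)

    # (c) one step at a time, but resetting states every step (context is lost)
    reset_out = torch.cat(
        [lstm(x[:, t:t + 1, :], zero_states())[0] for t in range(x.size(1))], dim=1
    )

print(torch.allclose(full_out, carried_out, atol=1e-5))  # carried states match
print(torch.allclose(full_out, reset_out, atol=1e-5))    # reset states do not
```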

### What Breaks If We Skip This?

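- No `eval()` = any dropout or batchnorm layers keep their training behavior (this model has none, but calling `eval()` is still good practice)
- Gradients tracked = every step adds to the autograd graph (the carried states link all 500 steps), wasting memory and time
- Wrong prompt tokenization = a `KeyError` for any character missing from `c2ix`, or gibberish if the mapping differs from the one used in training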
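One way to turn a tokenization problem into a clear error is to check the prompt against the vocabulary before encoding. `encode_prompt` below is a hypothetical helper, and the toy `c2ix` stands in for the one built in Setup.

```python
def encode_prompt(prompt, c2ix):
    """Hypothetical helper: map a prompt to IDs, reporting out-of-vocabulary chars."""
    missing = sorted({ch for ch in prompt if ch not in c2ix})
    if missing:
        raise ValueError(f"Not in the training vocabulary: {missing}")
    return [c2ix[ch] for ch in prompt]


# Toy vocabulary standing in for the real one built in Setup.
c2ix = {ch: i for i, ch in enumerate(sorted(set("you will rejoice")))}
print(encode_prompt("you will", c2ix))
try:
    encode_prompt("You will", c2ix)   # capital 'Y' never appears in the toy text
except ValueError as err:
    print(err)
```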

### Shapes During Generation

| Step | Shape |
|------|-------|
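| **Prompt IDs** | `[1, prompt_length]` |
| **Logits after the prompt** | `[prompt_length, vocab_size]` (flattened by `forward`) |
| **Single char input** | `[1, 1]` |
| **Logits per generated char** | `[1, vocab_size]` |
| **States (h, c)** | `[1, 1, 96]` each (`[num_layers, batch, hidden_size]`) |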
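The shapes in the table can be confirmed with freshly initialised layers of the same sizes as `CharacterLSTM` (embedding 48, hidden 96). `vocab_size = 60` and the prompt IDs are made-up values, and `step` is a hypothetical function that mirrors the model's `forward`.

```python
import torch
import torch.nn as nn

vocab_size = 60        # hypothetical; the real value comes from the Setup cell
embedding = nn.Embedding(vocab_size, 48)
lstm = nn.LSTM(48, 96, batch_first=True)
fc = nn.Linear(96, vocab_size)


def step(x, states):
    """Mirrors CharacterLSTM.forward: returns flattened logits and new states."""
    out, new_states = lstm(embedding(x), states)
    return fc(out).view(-1, vocab_size), new_states


with torch.no_grad():
    states = (torch.zeros(1, 1, 96), torch.zeros(1, 1, 96))  # like init_state(1)
    prompt_ids = torch.tensor([[3, 1, 4, 1, 5]])             # [1, prompt_length=5]
    prompt_logits, states = step(prompt_ids, states)
    print(prompt_logits.shape)   # torch.Size([5, 60]): one row per prompt char
    last_logits = prompt_logits[-1:]                          # [1, 60]
    next_input = torch.argmax(last_logits, dim=-1, keepdim=True)  # [1, 1]
    logits, (h, c) = step(next_input, states)
    print(logits.shape, h.shape, c.shape)  # [1, 60], [1, 1, 96], [1, 1, 96]
```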

---


## ✅ Objectives

By the end of this notebook, you should:

- [ ] Load the trained model weights
- [ ] Set the model to `eval()` mode
- [ ] Define a starting prompt: `"You will rejoice to hear"`
- [ ] Tokenize the prompt to IDs
- [ ] Initialize states for batch size = 1
- [ ] Implement the generation loop to produce 500 characters
- [ ] Decode the IDs back to text and print it

---


## 🎓 Acceptance Criteria

**You pass this notebook when:**

✅ 500 characters of generated text print without errors  
✅ Generated text looks vaguely Frankenstein-ish (Gothic, archaic style)  
✅ You can explain the difference between greedy and temperature sampling

---


## 📝 TODO 0: Setup — Load Data, Model, Weights

**Load vocab mappings, define model, load trained weights**


In [None]:
import torch
import torch.nn as nn

# === Rebuild vocabulary mappings (same slice and sorting as notebook 02) ===
with open('../datasets/frankenstein.txt', 'r', encoding='utf-8') as f:
    frankenstein = f.read()
    
first_letter_text = frankenstein[1380:8230]
tokenized_text = list(first_letter_text)
unique_char_tokens = sorted(set(tokenized_text))
c2ix = {char: idx for idx, char in enumerate(unique_char_tokens)}
ix2c = {idx: char for char, idx in c2ix.items()}
vocab_size = len(c2ix)

print(f"Vocabulary loaded: {vocab_size} unique characters")

# === Define Model (same as before) ===
class CharacterLSTM(nn.Module):
    def __init__(self, vocab_size, embedding_dim=48, hidden_size=96):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.lstm = nn.LSTM(embedding_dim, hidden_size, batch_first=True)
        self.fc = nn.Linear(hidden_size, vocab_size)
        self.hidden_size = hidden_size
    
    def forward(self, x, states):
        embedded = self.embedding(x)
        lstm_out, new_states = self.lstm(embedded, states)
        logits = self.fc(lstm_out)
        logits_flat = logits.view(-1, logits.size(-1))
        return logits_flat, new_states
    
    def init_state(self, batch_size):
        h0 = torch.zeros(1, batch_size, self.hidden_size)
        c0 = torch.zeros(1, batch_size, self.hidden_size)
        return (h0, c0)

# === Instantiate and load trained weights ===
model = CharacterLSTM(vocab_size)
model.load_state_dict(torch.load('trained_lstm_model.pth', map_location='cpu'))  # map_location lets GPU-trained weights load on a CPU-only machine
model.eval()  # Set to evaluation mode

print("Model loaded and set to eval() mode")


## 📝 TODO 1: Define Prompt and Tokenize

**Hint:**  
Convert prompt string → list of char IDs.

**Steps:**
1. Define `starting_prompt = "You will rejoice to hear"`
2. Convert to list of IDs: `[c2ix[char] for char in starting_prompt]`
3. Convert to tensor: `torch.tensor(..., dtype=torch.long).unsqueeze(0)`
   - `unsqueeze(0)` adds batch dimension: `[prompt_length]` → `[1, prompt_length]`


In [None]:
# TODO: Define and tokenize the starting prompt
# starting_prompt = "You will rejoice to hear"
# prompt_ids = [c2ix[char] for char in starting_prompt]
# prompt_tensor = torch.tensor(prompt_ids, dtype=torch.long).unsqueeze(0)  # [1, prompt_length]

starting_prompt = None  # Replace
prompt_tensor = None  # Replace

if starting_prompt and prompt_tensor is not None:
    print(f"Prompt: '{starting_prompt}'")
    print(f"Prompt tensor shape: {prompt_tensor.shape}")


## 📝 TODO 2: Warm Up States with Prompt

**Hint:**  
Feed the prompt through the model to initialize states.

**Steps:**
1. Initialize states: `states = model.init_state(1)`
2. With `torch.no_grad():`
3. Feed prompt: `logits, states = model(prompt_tensor, states)`
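4. Get the last step's logits: `last_logits = logits[-1:]` (shape `[1, vocab_size]`; `forward` returns `[prompt_length, vocab_size]`, and slicing with `-1:` keeps the batch dimension)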

**Why this step?**  
The prompt "primes" the model with context. The resulting states carry memory of "You will rejoice to hear".
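A quick way to see that the warm-up states carry the prompt's context: with toy untrained layers (hypothetical sizes), two prompts that end in the same character still produce different next-character logits, because the states built from the earlier characters differ. `warm_up` is an illustrative helper, not part of the notebook.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
vocab_size, hidden_size = 10, 16      # hypothetical toy sizes
embedding = nn.Embedding(vocab_size, 8)
lstm = nn.LSTM(8, hidden_size, batch_first=True)
fc = nn.Linear(hidden_size, vocab_size)


def warm_up(prompt_ids):
    """Feed a whole prompt from zero states; return last-step logits and final states."""
    x = torch.tensor([prompt_ids])                           # [1, prompt_length]
    states = (torch.zeros(1, 1, hidden_size), torch.zeros(1, 1, hidden_size))
    with torch.no_grad():
        out, states = lstm(embedding(x), states)
        logits = fc(out).view(-1, vocab_size)                # [prompt_length, vocab_size]
    return logits[-1:], states                               # [1, vocab_size]


# Both prompts end in ID 7; only the earlier context differs.
logits_a, _ = warm_up([1, 2, 3, 7])
logits_b, _ = warm_up([4, 5, 6, 7])
print(torch.allclose(logits_a, logits_b))  # the warmed-up states make them differ
```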


In [None]:
# TODO: Feed prompt to warm up states
# states = model.init_state(1)
# 
# with torch.no_grad():
#     logits, states = model(prompt_tensor, states)
#     last_logits = logits[-1:]  # Last time step logits

states = None  # Replace
last_logits = None  # Replace

if states and last_logits is not None:
    print(f"States warmed up. Last logits shape: {last_logits.shape}")


## 📝 TODO 3: Generation Loop

**Hint:**  
Loop 500 times, generating one character per iteration.

**Structure:**
```python
generated_ids = []
num_generated_chars = 500

with torch.no_grad():
    for _ in range(num_generated_chars):
        # 1. Argmax to get next char ID
        next_id = torch.argmax(last_logits).item()
        generated_ids.append(next_id)
        
        # 2. Prepare next input: shape [1, 1]
        next_input = torch.tensor([[next_id]], dtype=torch.long)
        
        # 3. Forward pass
        logits, states = model(next_input, states)
        last_logits = logits[-1:]
```

**Key details:**
- `torch.argmax(last_logits)` picks the most likely char
- `.item()` converts the one-element tensor to a Python int
- `[[next_id]]` creates shape `[1, 1]` (batch, sequence)
- States are carried across iterations and never reset inside the loop
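For reference, TODOs 1 to 4 can be bundled into one function. This is a sketch under stated assumptions: `generate_greedy` is a hypothetical name, the class repeats the Setup cell's `CharacterLSTM` so the block runs on its own, and the usage lines run an untrained model on a toy vocabulary, so the output is meaningless text.

```python
import torch
import torch.nn as nn


class CharacterLSTM(nn.Module):
    """Same architecture as the Setup cell."""
    def __init__(self, vocab_size, embedding_dim=48, hidden_size=96):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.lstm = nn.LSTM(embedding_dim, hidden_size, batch_first=True)
        self.fc = nn.Linear(hidden_size, vocab_size)
        self.hidden_size = hidden_size

    def forward(self, x, states):
        lstm_out, new_states = self.lstm(self.embedding(x), states)
        logits = self.fc(lstm_out)
        return logits.view(-1, logits.size(-1)), new_states

    def init_state(self, batch_size):
        h0 = torch.zeros(1, batch_size, self.hidden_size)
        c0 = torch.zeros(1, batch_size, self.hidden_size)
        return (h0, c0)


def generate_greedy(model, prompt, num_chars, c2ix, ix2c):
    """Hypothetical helper bundling TODOs 1-4: warm up on the prompt, then greedy-decode."""
    model.eval()
    with torch.no_grad():
        prompt_tensor = torch.tensor([[c2ix[ch] for ch in prompt]], dtype=torch.long)
        logits, states = model(prompt_tensor, model.init_state(1))
        last_logits = logits[-1:]                        # [1, vocab_size]
        generated_ids = []
        for _ in range(num_chars):
            next_id = torch.argmax(last_logits).item()   # greedy pick
            generated_ids.append(next_id)
            next_input = torch.tensor([[next_id]], dtype=torch.long)  # [1, 1]
            logits, states = model(next_input, states)   # states carry the context
            last_logits = logits[-1:]
    return prompt + ''.join(ix2c[i] for i in generated_ids)


# Usage with an untrained model and a toy vocabulary.
torch.manual_seed(0)
chars = sorted(set("You will rejoice to hear"))
c2ix = {ch: i for i, ch in enumerate(chars)}
ix2c = {i: ch for ch, i in c2ix.items()}
toy_model = CharacterLSTM(len(chars))
text = generate_greedy(toy_model, "You will", 20, c2ix, ix2c)
print(repr(text))
```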


In [None]:
# TODO: Generation loop
# generated_ids = []
# num_generated_chars = 500
# 
# with torch.no_grad():
#     for _ in range(num_generated_chars):
#         # Get next char ID (greedy sampling)
#         next_id = torch.argmax(last_logits).item()
#         generated_ids.append(next_id)
#         
#         # Prepare next input [1, 1]
#         next_input = torch.tensor([[next_id]], dtype=torch.long)
#         
#         # Forward pass
#         logits, states = model(next_input, states)
#         last_logits = logits[-1:]

generated_ids = []  # Replace with your loop

if generated_ids:
    print(f"Generated {len(generated_ids)} character IDs")


## 📝 TODO 4: Decode and Print Generated Text

**Hint:**  
Convert IDs back to characters using `ix2c`.

**Steps:**
1. Decode: `generated_text = ''.join(ix2c[i] for i in generated_ids)` (using `i` avoids shadowing Python's built-in `id`)
2. Combine with prompt: `full_text = starting_prompt + generated_text`
3. Print the result
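Decoding is the exact inverse of the tokenization in TODO 1, which a round trip makes easy to check. The vocabulary here is built from the prompt alone, purely for illustration; the notebook builds it from the novel slice.

```python
# Toy vocabulary built from the prompt alone (for illustration only).
prompt = "You will rejoice to hear"
chars = sorted(set(prompt))
c2ix = {ch: i for i, ch in enumerate(chars)}
ix2c = {i: ch for ch, i in c2ix.items()}

ids = [c2ix[ch] for ch in prompt]            # encode: chars -> IDs
decoded = ''.join(ix2c[i] for i in ids)      # decode: IDs -> chars
print(decoded == prompt)                     # True: the mappings are exact inverses
```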


In [None]:
# TODO: Decode generated IDs to text
# generated_text = ''.join(ix2c[i] for i in generated_ids)
# full_text = starting_prompt + generated_text

# print("="*80)
# print("GENERATED TEXT (Prompt + 500 chars):")
# print("="*80)
# print(full_text)
# print("="*80)

# Your code here


## 💭 Reflection Prompts

**Write your observations:**

1. **Generated style**: Does the generated text resemble Mary Shelley's style? (sentence structure, word choice, punctuation)

2. **Coherence**: Is the text coherent over short spans? Long spans?

3. **Repetition**: Do you see any repeated phrases or loops?

4. **Greedy vs. Sampling**: What would change if we used temperature sampling instead of argmax?

5. **Prompt influence**: How much does the starting prompt affect the generated text?

6. **Improvements**: What would make the generation better? (More data? Longer training? Larger model?)

---


## 🚀 Extensions to Try

**Want to explore further?**

1. **Temperature Sampling**:
   ```python
   # Instead of argmax:
   probs = torch.softmax(last_logits / temperature, dim=-1)
   next_id = torch.multinomial(probs, num_samples=1).item()
   ```
   - `temperature < 1`: More confident (sharper)
   - `temperature > 1`: More random (flatter)

2. **Longer Generation**: Try 1000 or 2000 characters

3. **Different Prompts**: "I beheld the wretch", "It was a dreary night"

4. **Train on Full Novel**: Remove the slice and train on entire *Frankenstein*

5. **Beam Search**: Keep the k highest-scoring partial sequences (by total log-probability) at each step instead of a single greedy choice
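The temperature extension can be wrapped in a small function. This is a sketch: `sample_with_temperature` is a hypothetical name, the logits are made up, and the optional `torch.Generator` only makes the draws reproducible.

```python
import torch


def sample_with_temperature(last_logits, temperature=1.0, generator=None):
    """Draw one char ID from softmax(last_logits / temperature).

    last_logits: [1, vocab_size] tensor. temperature must be > 0;
    use torch.argmax for greedy decoding instead of temperature=0.
    """
    if temperature <= 0:
        raise ValueError("temperature must be positive")
    probs = torch.softmax(last_logits / temperature, dim=-1)
    return torch.multinomial(probs, num_samples=1, generator=generator).item()


logits = torch.tensor([[2.0, 1.0, 0.1]])   # made-up logits for a 3-char vocabulary
for t in (0.5, 1.0, 2.0):
    probs = torch.softmax(logits / t, dim=-1)
    print(t, [round(p, 3) for p in probs[0].tolist()])  # lower t -> sharper

gen = torch.Generator().manual_seed(0)
draws = [sample_with_temperature(logits, 1.0, gen) for _ in range(10)]
print(draws)  # random IDs in {0, 1, 2}
```

As `temperature` approaches 0, the distribution concentrates on the argmax, so greedy decoding is the limiting case of temperature sampling.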

---


## 📌 Key Takeaways

- ✅ Generation is autoregressive: each char depends on previous chars
- ✅ `model.eval()` and `torch.no_grad()` are essential for inference
- ✅ Greedy sampling (argmax) is simple but can be repetitive
- ✅ States are carried across generation steps to maintain context
- ✅ The prompt "primes" the model with initial context
- ✅ Decoding: IDs → characters using `ix2c`

---

## 🎉 Congratulations!

You've completed the full pipeline:
1. ✅ Loaded and sliced text data
2. ✅ Built character vocabulary
3. ✅ Created Dataset and DataLoader
4. ✅ Defined LSTM architecture
5. ✅ Trained the model
6. ✅ Generated new text

**Next:** Document your learnings in **Notebook 99 (Lab Notes)**!

---

*This is honest work. Now go forth and generate!* 🚀
