# P4: RNN/LSTM Text Generation

**Objective:** Train a small LSTM to generate text from a sample dataset (IMDB or a small text corpus).

In [None]:
import tensorflow as tf
from tensorflow.keras import layers, models, preprocessing
# Skeleton cell: students are expected to prepare the tokenized sequences and
# train an LSTM themselves; this cell only prints a reminder of the steps.
# NOTE(review): this cell imports TensorFlow/Keras, while the worked
# character-LSTM example further down is written with PyTorch.
print('Prepare dataset, tokenize, build embedding + LSTM, then sample generation.')

In [None]:
# Practical 4: Character LSTM (PyTorch)
import torch, torch.nn as nn, torch.optim as optim
import random, math

# Toy corpus: a single sentence repeated 200 times.
text = 'To be, or not to be, that is the question.' * 200

# Character-level vocabulary, sorted so that the id assignment is deterministic.
chars = sorted(set(text))

# Lookup tables: character -> integer id, and integer id -> character.
stoi = {symbol: index for index, symbol in enumerate(chars)}
itos = dict(enumerate(chars))

# The whole corpus encoded as a 1-D tensor of character ids.
data = torch.tensor([stoi[symbol] for symbol in text], dtype=torch.long)

# Number of characters in each training window.
block = 64


def get_batch(sz=64):
    """Draw a random mini-batch of training windows from ``data``.

    Each input row is ``block`` consecutive character ids; the matching
    target row is the same window shifted one position to the right, so the
    target at every position is the next character.

    Args:
        sz: Number of windows in the batch.

    Returns:
        A pair ``(X, Y)`` of tensors, each of shape ``(sz, block)``.

    Note:
        The upper bound passed to ``torch.randint`` is exclusive, so the
        largest start index drawn is ``len(data) - block - 2``; the final
        character of ``data`` is therefore never used as a target.
    """
    starts = torch.randint(0, len(data) - block - 1, (sz,)).tolist()
    inputs = []
    targets = []
    for start in starts:
        stop = start + block
        inputs.append(data[start:stop])
        targets.append(data[start + 1:stop + 1])
    return torch.stack(inputs), torch.stack(targets)

class CharLSTM(nn.Module):
    """Character-level language model: embedding -> LSTM -> linear head.

    Args:
        vocab: Number of distinct token ids (size of the embedding table and
            of the output logits).
        hidden: Size of the LSTM hidden state.
        emb_dim: Width of the token embeddings fed to the LSTM. Defaults to
            64, the value that was previously hard-coded.
    """

    def __init__(self, vocab, hidden=128, emb_dim=64):
        super().__init__()
        # Layers are created in a fixed order (embedding, LSTM, head) so that
        # seeded parameter initialization stays reproducible.
        self.emb = nn.Embedding(vocab, emb_dim)
        self.lstm = nn.LSTM(emb_dim, hidden, batch_first=True)
        self.fc = nn.Linear(hidden, vocab)

    def forward(self, x, h=None):
        """Run the model over a batch of id sequences.

        Args:
            x: Integer tensor of token ids, laid out ``(batch, seq)`` because
                the LSTM is built with ``batch_first=True``.
            h: Optional LSTM state returned by a previous call; ``None``
                lets the LSTM start from its default initial state.

        Returns:
            A pair ``(logits, state)`` where ``logits`` has one vocabulary-
            sized score vector per input position and ``state`` is the LSTM
            state to pass into the next call.
        """
        x = self.emb(x)
        out, h = self.lstm(x, h)
        return self.fc(out), h

# One output class per distinct character in the corpus.
vocab = len(chars)
model = CharLSTM(vocab)

# Adam with a 1e-3 learning rate, trained against a cross-entropy loss over
# the next-character logits.
opt = optim.Adam(
    model.parameters(),
    lr=1e-3,
)
crit = nn.CrossEntropyLoss()

# Train for 200 optimizer steps, each on a freshly sampled mini-batch.
for step in range(200):
    X, Y = get_batch()
    logits, _ = model(X)

    # Flatten logits and targets so every position in every window is
    # scored as one classification example.
    loss = crit(logits.view(-1, vocab), Y.view(-1))

    # Standard update: clear old gradients, backpropagate, apply the step.
    opt.zero_grad()
    loss.backward()
    opt.step()

    # Report the loss every 50 steps.
    if step % 50 == 0:
        print('step', step, 'loss', loss.item())

def sample(prefix='To ', length=120):
    """Generate text from the trained model, one character at a time.

    The prefix is fed through the model once to build up the LSTM state;
    after that, each newly sampled character is fed back in together with
    the carried state to produce the next one.

    Args:
        prefix: Non-empty seed string. Every character must be a key of
            ``stoi``.
        length: Number of characters to generate after the prefix.

    Returns:
        ``prefix`` followed by ``length`` sampled characters.

    Raises:
        ValueError: If ``prefix`` is empty (there would be no last position
            to read next-character logits from).
        KeyError: If ``prefix`` contains a character missing from ``stoi``.
    """
    if not prefix:
        raise ValueError('prefix must contain at least one character')

    # Remember the current mode so sampling does not leave the shared model
    # stuck in eval mode for any training that follows.
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            h = None
            x = torch.tensor([[stoi[c] for c in prefix]], dtype=torch.long)
            pieces = [prefix]
            for _ in range(length):
                logits, h = model(x, h)
                # Only the last position predicts the next character.
                p = torch.softmax(logits[:, -1, :], dim=-1)
                ix = torch.multinomial(p, num_samples=1)
                pieces.append(itos[ix.item()])
                # Feed just the sampled character back in; the context is
                # carried by the LSTM state ``h``.
                x = ix.view(1, 1)
    finally:
        model.train(was_training)
    return ''.join(pieces)

# Generate one sample using the default prefix and length, and print it.
print(sample())