# Simple Language Model (Predict next char)
## 1. Create Model

In [1]:
import torch
import torch.nn as nn
from transformer import TransformerEncoder

torch.Size([2, 10, 512])


In [2]:
class LanguageModel(nn.Module):
    def __init__(self, vocab_size, d_model, n_heads, d_ff, n_layers, max_seq_len, droput=0.1):
        super().__init__()
        self.encoder = TransformerEncoder(vocab_size, d_model, n_heads, d_ff, n_layers, max_seq_len, droput)
        self.output_proj = nn.Linear(d_model, vocab_size)
    
    def forward(self, x):
        # x: (batch, seq_len) - token IDs
        x = self.encoder(x) # (batch, seq_len, d_model)
        logits = self.output_proj(x) # (batch, seq_len, vocab_size)
        return logits
        

## 2. Prepare simple dataset

In [3]:
# text = """To be or not to be, that is the question.
# Whether 'tis nobler in the mind to suffer
# The slings and arrows of outrageous fortune,
# Or to take arms against a sea of troubles."""

text = """
To be, or not to be, that is the question:
Whether 'tis nobler in the mind to suffer
The slings and arrows of outrageous fortune,
Or to take arms against a sea of troubles
And by opposing end them. To die—to sleep,
No more; and by a sleep to say we end
The heart-ache and the thousand natural shocks
That flesh is heir to: 'tis a consummation
Devoutly to be wish'd. To die, to sleep;
To sleep, perchance to dream—ay, there's the rub:
For in that sleep of death what dreams may come,
When we have shuffled off this mortal coil,
Must give us pause—there's the respect
That makes calamity of so long life.
"""


# Create vocabulary (character-level)
chars = sorted(list(set(text)))
vocab_size = len(chars)
char_to_idx = {ch: idx for idx, ch in enumerate(chars)}

print(f"Vocab size: {vocab_size}")
print(f"Characters: {chars}")

# Encode text
encoded = torch.tensor([char_to_idx[ch] for ch in text])
print(f"Text length: {len(encoded)}")

Vocab size: 40
Characters: ['\n', ' ', "'", ',', '-', '.', ':', ';', 'A', 'D', 'F', 'M', 'N', 'O', 'T', 'W', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'y', '—']
Text length: 604


## 3. Create Dataset & DataLoader

In [4]:
from torch.utils.data import Dataset, DataLoader
from typing import Any

class CharDataset(Dataset):
    def __init__(self, encoded_text, seq_len) -> None:
        super().__init__()
        self.data = encoded_text
        self.seq_len = seq_len
    
    def __len__(self):
        return len(self.data) - self.seq_len
    
    def __getitem__(self, idx) -> Any:
        # Input: seq_len tokens
        # Output: next seq_len tokens (shifted by 1)
        
        x = self.data[idx : idx + self.seq_len]
        y = self.data[idx + 1: idx + self.seq_len + 1]
        return x, y

# Create dataset
seq_len = 32
batch_size = 16
dataset = CharDataset(encoded_text=encoded, seq_len=seq_len)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# Check one batch
x, y = next(iter(dataloader))
x2, y2 = next(iter(dataloader))
print(f"Input Shape: {x.shape}")
print(f"Output Shape: {y.shape}")
print(f"x1, y1: {x}\n{y}")
print(f"x2, y2: {x2}\n{y2}")

Input Shape: torch.Size([16, 32])
Output Shape: torch.Size([16, 32])
x1, y1: tensor([[ 1, 14, 29,  1, 19, 24, 20,  3,  1, 34, 29,  1, 33, 26, 20, 20, 30,  7,
          0, 14, 29,  1, 33, 26, 20, 20, 30,  3,  1, 30, 20, 32],
        [23, 29, 35, 33, 16, 28, 19,  1, 28, 16, 34, 35, 32, 16, 26,  1, 33, 23,
         29, 18, 25, 33,  0, 14, 23, 16, 34,  1, 21, 26, 20, 33],
        [ 1, 30, 20, 32, 18, 23, 16, 28, 18, 20,  1, 34, 29,  1, 19, 32, 20, 16,
         27, 39, 16, 38,  3,  1, 34, 23, 20, 32, 20,  2, 33,  1],
        [ 1, 28, 29, 17, 26, 20, 32,  1, 24, 28,  1, 34, 23, 20,  1, 27, 24, 28,
         19,  1, 34, 29,  1, 33, 35, 21, 21, 20, 32,  0, 14, 23],
        [22,  1, 20, 28, 19,  1, 34, 23, 20, 27,  5,  1, 14, 29,  1, 19, 24, 20,
         39, 34, 29,  1, 33, 26, 20, 20, 30,  3,  0, 12, 29,  1],
        [23, 20, 16, 32, 34,  4, 16, 18, 23, 20,  1, 16, 28, 19,  1, 34, 23, 20,
          1, 34, 23, 29, 35, 33, 16, 28, 19,  1, 28, 16, 34, 35],
        [33,  1, 30, 16, 35, 33, 20, 39, 

## 4. Training Loop

In [5]:
from torch.optim import Adam

vocab_size = len(chars)
d_model = 128
n_heads = 4
d_ff = 512
n_layers = 4
max_seq_len = 128
learning_rate = 3e-4
n_epochs = 50

# Create model
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = LanguageModel(vocab_size=vocab_size, d_model=d_model, n_heads=n_heads, d_ff=d_ff, n_layers=n_layers, max_seq_len=max_seq_len)
model = model.to(device)

# Loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = Adam(model.parameters(), lr=learning_rate)

# Training loop
model.train()
for epoch in range(n_epochs):
    total_loss = 0

    for batch_idx, (x, y) in enumerate(dataloader):
        x, y = x.to(device), y.to(device)
        
        # Forward pass
        logits = model(x) # (batch_size, seq_len, vocab_size)

        # Reshape for cross-entropy
        # CrossEntropyLoss expects: (batch * seq_len, vocab_size) and (batch_size * seq_len)
        # logits = logits.view(batch_size * seq_len, vocab_size)
        logits = logits.view(-1, vocab_size)
        y = y.view(-1)

        # Calculate loss
        loss = criterion(logits, y)

        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        total_loss += loss.item()
    
    avg_loss = total_loss / len(dataloader)
    print(f"Epoch {epoch + 1}/{n_epochs}, Loss: {avg_loss:.4f}")

Epoch 1/50, Loss: 2.8121
Epoch 2/50, Loss: 2.1679
Epoch 3/50, Loss: 1.9153
Epoch 4/50, Loss: 1.7190
Epoch 5/50, Loss: 1.5180
Epoch 6/50, Loss: 1.3436
Epoch 7/50, Loss: 1.2062
Epoch 8/50, Loss: 1.1128
Epoch 9/50, Loss: 1.0397
Epoch 10/50, Loss: 0.9917
Epoch 11/50, Loss: 0.9542
Epoch 12/50, Loss: 0.9399
Epoch 13/50, Loss: 0.9156
Epoch 14/50, Loss: 0.9026
Epoch 15/50, Loss: 0.8865
Epoch 16/50, Loss: 0.8782
Epoch 17/50, Loss: 0.8677
Epoch 18/50, Loss: 0.8599
Epoch 19/50, Loss: 0.8567
Epoch 20/50, Loss: 0.8459
Epoch 21/50, Loss: 0.8429
Epoch 22/50, Loss: 0.8430
Epoch 23/50, Loss: 0.8347
Epoch 24/50, Loss: 0.8278
Epoch 25/50, Loss: 0.8262
Epoch 26/50, Loss: 0.8198
Epoch 27/50, Loss: 0.8202
Epoch 28/50, Loss: 0.8165
Epoch 29/50, Loss: 0.8152
Epoch 30/50, Loss: 0.8109
Epoch 31/50, Loss: 0.8106
Epoch 32/50, Loss: 0.8061
Epoch 33/50, Loss: 0.8049
Epoch 34/50, Loss: 0.8013
Epoch 35/50, Loss: 0.8040
Epoch 36/50, Loss: 0.7952
Epoch 37/50, Loss: 0.7977
Epoch 38/50, Loss: 0.7948
Epoch 39/50, Loss: 0.

## 5. Generate Text

In [11]:
def generate(model, start_text, max_new_tokens=100, temperature=1.0):
    model.eval()
    
    # Encode starting text
    context = [char_to_idx[ch] for ch in start_text]
    context = torch.tensor(context, dtype=torch.long).unsqueeze(0).to(device)
    
    generated = list(start_text)

    with torch.no_grad():
        for _ in range(max_new_tokens):
            # Get predictions
            logits = model(context) # (1, seq_len, vocab_size)

            # Get logits for last position
            logits = logits[0, -1, :] / temperature # (vocab_size)

            # Sample from distribution
            probs = torch.softmax(logits, dim=-1)
            next_idx = torch.multinomial(probs, num_samples=1).item()
            
            # Append to context
            context = torch.cat([context, torch.tensor([[next_idx]], device=device)], dim=1)

            # Add to generated text
            generated.append(chars[next_idx])
    
    return ''.join(generated)

# Generate
start = "love me"
generated_text = generate(model, start, max_new_tokens=100, temperature=0.8)
print("\nGenerated text:")
print(generated_text)


Generated text:
love mevortomorrtortormortorrmormrmormorrmormrrmoatormormormormirmioatarmormormormormoamormoimormmoutarmorm


In [7]:
def generate(model, start_text, max_new_tokens=100, temperature=1.0, top_k=None):
    model.eval()
    
    context = [char_to_idx[ch] for ch in start_text]
    context = torch.tensor(context, dtype=torch.long).unsqueeze(0).to(device)
    
    generated = list(start_text)
    
    with torch.no_grad():
        for _ in range(max_new_tokens):
            # Limit context to max_seq_len
            context_input = context if context.size(1) <= max_seq_len else context[:, -max_seq_len:]
            
            logits = model(context_input)[0, -1, :] / temperature
            
            # Top-k sampling (prevents low-probability garbage)
            if top_k is not None:
                top_k_logits, top_k_indices = torch.topk(logits, top_k)
                logits = torch.full_like(logits, float('-inf'))
                logits[top_k_indices] = top_k_logits
            
            probs = torch.softmax(logits, dim=-1)
            next_idx = torch.multinomial(probs, num_samples=1).item()
            
            context = torch.cat([context, torch.tensor([[next_idx]], device=device)], dim=1)
            generated.append(chars[next_idx])
    
    return ''.join(generated)

# Try different settings
print(generate(model, "To be", max_new_tokens=100, temperature=0.8, top_k=5))
print(generate(model, "To be", max_new_tokens=100, temperature=1.2, top_k=10))

To be To be To be To be so be so be s
To beso bes
To bes
To beTo bes
To beso beseTo bes
To beso beseTo be
To beuto theto theto thetoistheto thetois thon the o theto theton theto theto theton the, s tho n
Whe, th
