# Implementing a GPT-Style Transformer Language Model on IMDB with PyTorch

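This notebook trains a small GPT-style Transformer on the IMDB dataset for language modeling. The model is a stack of `nn.TransformerEncoder` layers run with a causal mask, which makes it behave like a decoder-only (GPT-style) network.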
- Tokenizer: **GPT-2** (pad token set to EOS)
- Objective: next-token prediction (causal LM)
- Includes: training, validation (loss/perplexity), checkpointing, and single/batch text generation.

In [1]:

import math
import random
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from datasets import load_dataset
from transformers import AutoTokenizer

# Reproducibility: seed Python's RNG and PyTorch (CPU and all CUDA devices)
# from a single constant.
SEED = 42
for _seed_fn in (random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
    _seed_fn(SEED)

# Pick the compute device once; the bare expression below displays it.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
device


'cuda'

In [2]:

# GPT-2 tokenizer; when it has no PAD token, reuse EOS as PAD for convenience.
tokenizer = AutoTokenizer.from_pretrained("gpt2")
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# IMDB (25k train / 25k test)
dataset = load_dataset("imdb")

# Display the special tokens and the vocabulary size.
pad_tok, eos_tok = tokenizer.pad_token, tokenizer.eos_token
(pad_tok, eos_tok, tokenizer.vocab_size)


('<|endoftext|>', '<|endoftext|>', 50257)

In [3]:

# Keep the demo light: every review is truncated/padded to MAX_LEN tokens.
# 256 is cheap to train on; raise it to 512 for a longer context.
MAX_LEN = 256

def tokenize_fn(batch):
    """Tokenize a batch of reviews into fixed-length encodings.

    ``batch["text"]`` is passed to the module-level ``tokenizer``. Sequences
    are truncated to at most ``MAX_LEN`` tokens and padded up to ``MAX_LEN``.
    Returns whatever the tokenizer call returns.
    """
    encode_options = {
        "truncation": True,
        "padding": "max_length",
        "max_length": MAX_LEN,
    }
    return tokenizer(batch["text"], **encode_options)

# Tokenize both splits, dropping the raw columns we no longer need.
tokenized = {
    split: dataset[split].map(tokenize_fn, batched=True, remove_columns=["text", "label"])
    for split in ("train", "test")
}
train_ds = tokenized["train"]
val_ds = tokenized["test"]  # the "test" split is used as validation data

# Make PyTorch-friendly: indexing now yields tensors for these two columns.
for split_ds in (train_ds, val_ds):
    split_ds.set_format(type="torch", columns=["input_ids", "attention_mask"])

len(train_ds), len(val_ds), train_ds[0]["input_ids"].shape


(25000, 25000, torch.Size([256]))

In [4]:

BATCH_SIZE = 4

def collate(batch):
    """Stack per-example tensors into batch tensors.

    ``batch`` is a list of dicts holding ``"input_ids"`` and
    ``"attention_mask"`` tensors of equal length. Returns the tuple
    ``(input_ids, attention_mask)``, each stacked along a new leading
    dimension ([B, L] for 1-D inputs).
    """
    def _stack(key):
        return torch.stack([example[key] for example in batch], dim=0)

    return _stack("input_ids"), _stack("attention_mask")

# Both loaders share the batch size and collate function; only training shuffles.
shared_loader_args = {"batch_size": BATCH_SIZE, "collate_fn": collate}
train_loader = DataLoader(train_ds, shuffle=True, **shared_loader_args)
val_loader = DataLoader(val_ds, shuffle=False, **shared_loader_args)

# Sanity check: pull one training batch and display the input_ids shape.
first_batch = next(iter(train_loader))
first_batch[0].shape


torch.Size([4, 256])

In [5]:

class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position information to embeddings, then apply dropout.

    The table is precomputed once for ``max_len`` positions and stored in the
    buffer ``pe`` with shape ``[1, max_len, d_model]``. Even feature columns
    hold sines and odd feature columns hold cosines.

    Args:
        d_model: Size of the feature (last) dimension of the inputs. May be odd.
        dropout: Dropout probability applied after the addition.
        max_len: Number of positions to precompute; longer inputs are rejected.
        batch_first: ``True`` for ``[B, L, D]`` inputs, ``False`` for ``[L, B, D]``.
    """

    def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000, batch_first: bool = True):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        self.batch_first = batch_first

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        angles = position * div_term  # [max_len, ceil(d_model / 2)]
        pe[:, 0::2] = torch.sin(angles)
        # An odd d_model has one fewer odd column than even columns, so trim
        # the cosine block to fit; for an even d_model the slice is a no-op.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])
        pe = pe.unsqueeze(0)  # [1, max_len, d_model]
        self.register_buffer('pe', pe)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``dropout(x + positional_table)``.

        Args:
            x: ``[B, L, D]`` if ``batch_first`` else ``[L, B, D]``.

        Raises:
            ValueError: If the sequence length exceeds the precomputed ``max_len``.
        """
        seq_len = x.size(1) if self.batch_first else x.size(0)
        max_len = self.pe.size(1)
        if seq_len > max_len:
            # Without this check the slice silently stops at max_len and the
            # addition below fails with an opaque broadcasting error.
            raise ValueError(
                f"sequence length {seq_len} exceeds positional table size {max_len}"
            )

        pos = self.pe[:, :seq_len, :]  # [1, L, D]
        if not self.batch_first:
            pos = pos.transpose(0, 1)  # [L, 1, D], broadcasts over the batch
        return self.dropout(x + pos)


class GPTStyleDecoder(nn.Module):
    """GPT-style causal language model built from ``nn.TransformerEncoder`` layers.

    Tokens are embedded, scaled by ``sqrt(embedding_dim)``, given sinusoidal
    positions, passed through the encoder stack under a causal mask, and
    projected to vocabulary logits.

    Args:
        vocab_size: Number of token ids (size of the embedding table and of
            the output projection).
        embedding_dim: Model width ``D``.
        num_layers: Number of stacked encoder layers.
        nhead: Number of attention heads per layer.
        dropout: Dropout probability used by the positional encoder and layers.
        dim_feedforward: Hidden size of each layer's feed-forward sub-block.
    """

    def __init__(self, vocab_size, embedding_dim=512, num_layers=6, nhead=8, dropout=0.1,
                 dim_feedforward=2048):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.pos_encoder = PositionalEncoding(embedding_dim, dropout, batch_first=True)
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=embedding_dim,
            nhead=nhead,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            batch_first=True,
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        self.fc = nn.Linear(embedding_dim, vocab_size)
        self.d_model = embedding_dim

    def forward(self, input_ids, attention_mask=None):
        """Compute next-token logits.

        Args:
            input_ids: ``[B, L]`` token ids.
            attention_mask: Optional ``[B, L]`` mask (1 = token, 0 = pad).

        Returns:
            ``[B, L, V]`` logits; position ``t`` only attends to positions ``<= t``.
        """
        x = self.embedding(input_ids) * (self.d_model ** 0.5)   # [B, L, D]
        x = self.pos_encoder(x)

        seq_len = x.size(1)
        # Causal mask [L, L]: entry (i, j) is True (blocked) when j > i. Built
        # directly as bool from index comparisons, without a float L x L tensor.
        positions = torch.arange(seq_len, device=x.device)
        causal_mask = positions.unsqueeze(0) > positions.unsqueeze(1)
        # src_key_padding_mask: True where we want to mask (i.e., PAD)
        src_key_padding_mask = None if attention_mask is None else (attention_mask == 0)

        x = self.transformer(
            x,
            mask=causal_mask,
            src_key_padding_mask=src_key_padding_mask,
        )  # [B, L, D]

        return self.fc(x)  # [B, L, V]


In [6]:

# Model hyper-parameters
EMBED_DIM = 768
LAYERS = 4
NHEAD = 8
DROPOUT = 0.1

model = GPTStyleDecoder(
    vocab_size=tokenizer.vocab_size,
    embedding_dim=EMBED_DIM,
    num_layers=LAYERS,
    nhead=NHEAD,
    dropout=DROPOUT,
)
model = model.to(device)

optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)
# Targets equal to the PAD id do not contribute to the loss. NOTE: if PAD was
# aliased to EOS when the tokenizer was set up, EOS targets are ignored too.
criterion = nn.CrossEntropyLoss(ignore_index=tokenizer.pad_token_id)

# Parameter count, in millions
n_params = sum(p.numel() for p in model.parameters())
n_params / 1e6


99.300945

In [7]:
import torch, gc

def cleanup_cuda():
    """Run Python garbage collection and, if CUDA is available, empty the CUDA cache.

    When CUDA is available this also synchronizes the device after emptying
    the cache. Returns ``None``.
    """
    gc.collect()
    if not torch.cuda.is_available():
        return
    torch.cuda.empty_cache()
    torch.cuda.synchronize()

cleanup_cuda()


In [8]:
from tqdm import tqdm
import math
import os
import torch

EPOCHS = 3

def run_epoch(loader, train=True):
    """Run one pass over ``loader`` and return ``(avg_loss, perplexity)``.

    Uses the module-level ``model``, ``tokenizer``, ``criterion``,
    ``optimizer`` and ``device``.

    Args:
        loader: Iterable yielding ``(input_ids, attention_mask)`` batches.
        train: If True, put the model in training mode and update weights
            (gradient norm clipped to 1.0). If False, put it in eval mode and
            run without building an autograd graph.

    Returns:
        ``(avg_loss, ppl)``: the mean loss per non-pad target token, and
        ``exp(avg_loss)`` (``inf`` when ``avg_loss >= 20``).
    """
    model.train(train)
    total_loss = 0.0
    total_tokens = 0
    pad_id = tokenizer.pad_token_id

    # Context managers: the progress bar is closed even if a batch raises, and
    # autograd only records a graph when we are going to call backward().
    with tqdm(loader, desc="Training" if train else "Validation") as pbar:
        with torch.set_grad_enabled(train):
            for input_ids, attention_mask in pbar:
                input_ids = input_ids.to(device)
                attention_mask = attention_mask.to(device)

                # Shift for next-token prediction: position t predicts token t+1.
                shift_labels = input_ids[:, 1:].contiguous().view(-1)   # [(B*(L-1))]

                # Count non-pad target tokens for averaging
                not_pad = (shift_labels != pad_id).sum().item()
                if not_pad == 0:
                    # If every target is ignored, a mean-reduced loss is NaN;
                    # stepping on it would corrupt the weights and the totals.
                    continue

                logits = model(input_ids, attention_mask=attention_mask)  # [B, L, V]
                shift_logits = logits[:, :-1, :].contiguous().view(-1, logits.size(-1))  # [(B*(L-1)), V]

                loss = criterion(shift_logits, shift_labels)

                if train:
                    optimizer.zero_grad()
                    loss.backward()
                    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                    optimizer.step()

                batch_loss = loss.item()
                total_loss += batch_loss * not_pad
                total_tokens += not_pad

                pbar.set_postfix(loss=batch_loss, avg_loss=total_loss / total_tokens)

    avg_loss = total_loss / max(total_tokens, 1)
    ppl = math.exp(avg_loss) if avg_loss < 20 else float("inf")
    return avg_loss, ppl

# Alternate one training pass and one validation pass per epoch.
for epoch in range(EPOCHS):
    (train_loss, train_ppl), (val_loss, val_ppl) = (
        run_epoch(train_loader, train=True),
        run_epoch(val_loader, train=False),
    )
    print(f"Epoch {epoch+1}/{EPOCHS} | Train loss {train_loss:.4f}, ppl {train_ppl:.2f} | Val loss {val_loss:.4f}, ppl {val_ppl:.2f}")

# Save final model (weights only, via state_dict)
checkpoint_dir = "checkpoints"
os.makedirs(checkpoint_dir, exist_ok=True)
final_weights = model.state_dict()
torch.save(final_weights, "checkpoints/gptstyle_imdb.pt")
print("Saved to checkpoints/gptstyle_imdb.pt")


Training: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 6250/6250 [06:50<00:00, 15.24it/s, avg_loss=5.6, loss=5.11]
Validation: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 6250/6250 [01:32<00:00, 67.85it/s, avg_loss=5.28, loss=5.3]


Epoch 1/3 | Train loss 5.6047, ppl 271.70 | Val loss 5.2768, ppl 195.74


Training: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 6250/6250 [06:49<00:00, 15.25it/s, avg_loss=5.03, loss=5.16]
Validation: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 6250/6250 [01:32<00:00, 67.88it/s, avg_loss=5.05, loss=5.04]


Epoch 2/3 | Train loss 5.0273, ppl 152.52 | Val loss 5.0541, ppl 156.66


Training: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 6250/6250 [06:49<00:00, 15.25it/s, avg_loss=4.78, loss=4.71]
Validation: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 6250/6250 [01:32<00:00, 67.93it/s, avg_loss=4.96, loss=4.96]


Epoch 3/3 | Train loss 4.7805, ppl 119.16 | Val loss 4.9591, ppl 142.47
Saved to checkpoints/gptstyle_imdb.pt


In [9]:

@torch.no_grad()
def generate(
    model,
    tokenizer,
    prompt: str,
    max_new_tokens: int = 50,
    temperature: float = 1.0,
    top_k: int | None = None,
    device: str = "cpu"
) -> str:
    model.eval()
    # Encode
    input_ids = tokenizer.encode(prompt, add_special_tokens=False)
    if len(input_ids) == 0:
        input_ids = [tokenizer.eos_token_id]
    generated = torch.tensor(input_ids, dtype=torch.long, device=device).unsqueeze(0)  # [1, L]

    for _ in range(max_new_tokens):
        L = generated.size(1)
        causal_mask = torch.triu(torch.ones(L, L, device=device), diagonal=1).bool()
        attn_mask = torch.ones_like(generated, device=device)  # no pads
        logits = model(generated, attention_mask=attn_mask)  # [1, L, V]
        next_logits = logits[:, -1, :]  # [1, V]

        if temperature <= 0:
            next_token = torch.argmax(next_logits, dim=-1)
        else:
            next_logits = next_logits / temperature
            if top_k is not None:
                v, ix = torch.topk(next_logits, k=top_k, dim=-1)
                filtered = torch.full_like(next_logits, -float("inf"))
                filtered.scatter_(1, ix, v)
                next_logits = filtered
            probs = F.softmax(next_logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1).squeeze(1)  # [1]

        generated = torch.cat([generated, next_token.unsqueeze(0)], dim=1)

    return tokenizer.decode(generated[0].tolist(), skip_special_tokens=True)


@torch.no_grad()
def generate_batch(
    model,
    tokenizer,
    prompts: list[str],
    max_new_tokens: int = 50,
    temperature: float = 1.0,
    top_k: int | None = None,
    device: str = "cpu"
) -> list[str]:
    model.eval()
    # Encode & left-pad to same length (we'll use pad_token_id)
    enc = [tokenizer.encode(p, add_special_tokens=False) for p in prompts]
    enc = [e if len(e) > 0 else [tokenizer.eos_token_id] for e in enc]
    max_len = max(len(e) for e in enc)

    B = len(enc)
    generated = torch.full((B, max_len), tokenizer.pad_token_id, dtype=torch.long, device=device)
    for i, e in enumerate(enc):
        generated[i, :len(e)] = torch.tensor(e, device=device)

    # We will treat initial pads as real tokens in attention_mask=1 for simplicity,
    # but a cleaner approach is to keep a true padding mask. We'll build a proper mask.
    attn_mask = (generated != tokenizer.pad_token_id).long()  # [B, L]

    for _ in range(max_new_tokens):
        L = generated.size(1)
        causal_mask = torch.triu(torch.ones(L, L, device=device), diagonal=1).bool()
        logits = model(generated, attention_mask=attn_mask)  # [B, L, V]
        next_logits = logits[:, -1, :]  # [B, V]

        if temperature <= 0:
            next_token = torch.argmax(next_logits, dim=-1)  # [B]
        else:
            next_logits = next_logits / temperature
            if top_k is not None:
                v, ix = torch.topk(next_logits, k=top_k, dim=-1)
                filtered = torch.full_like(next_logits, -float("inf"))
                filtered.scatter_(1, ix, v)
                next_logits = filtered
            probs = F.softmax(next_logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1).squeeze(1)  # [B]

        # Append new step to all sequences
        generated = torch.cat([generated, next_token.unsqueeze(1)], dim=1)
        # Update attention mask (new token is non-pad)
        attn_mask = torch.cat([attn_mask, torch.ones(B, 1, dtype=torch.long, device=device)], dim=1)

    # Decode each
    outputs = []
    for i in range(B):
        outputs.append(tokenizer.decode(generated[i].tolist(), skip_special_tokens=True))
    return outputs


In [10]:

# Restore the saved weights. If you are re-running from scratch and the model
# has not been trained yet, comment out the two loading lines below.
checkpoint_path = "checkpoints/gptstyle_imdb.pt"
state = torch.load(checkpoint_path, map_location=device)
model.load_state_dict(state)

# Low temperature + top-k sampling for a single prompt.
sample_text = generate(
    model,
    tokenizer,
    "Once upon a time",
    max_new_tokens=40,
    temperature=0.2,
    top_k=50,
    device=device,
)
print(sample_text)




Once upon a time, I was a kid in the theater. I was so impressed by the movie, and it was just a good movie. I was really looking forward to seeing it. I was expecting a movie with


In [11]:
# Continue several prompts in one batched call and print each result.
prompts = ["The future of AI is", "In a small town, there was"]
sampling_args = {"max_new_tokens": 40, "temperature": 0.8, "top_k": 50, "device": device}
outs = generate_batch(model, tokenizer, prompts, **sampling_args)
for p, o in zip(prompts, outs):
    print(f"\nPrompt: {p}\n{o}")


Prompt: The T-1000 character is
The T-1000 character is, the whole movie is even better than it. the other movie would be about. I give it the 1. The whole movie a 4. It's a shame that it never got better. It

Prompt: In a small town, there was
In a small town, there was a lot of things that went wrong.<br /><br />I am very glad that this is something that is in a way I can say about this film. It has a great message of the
