# 🚀 GPT

In this notebook, we'll walk through the steps required to train your own GPT model on the wine review dataset.

The code is adapted from the excellent [GPT tutorial](https://keras.io/examples/generative/text_generation_with_miniature_gpt/) created by Apoorv Nandan, available on the Keras website.

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import numpy as np
import json
import re
import string
from IPython.display import display, HTML
import random

## 0. Parameters <a name="parameters"></a>

In [None]:
# --- Vocabulary and sequence geometry ---
VOCAB_SIZE = 10000  # total ids, including the <pad> (0) and <unk> (1) entries
MAX_LEN = 80  # context length (tokens) used for training windows and generation

# --- Model dimensions ---
EMBEDDING_DIM = 256
# NOTE(review): KEY_DIM is not read anywhere in the visible notebook;
# nn.MultiheadAttention uses EMBEDDING_DIM // N_HEADS as its per-head size.
KEY_DIM = 256
N_HEADS = 2
FEED_FORWARD_DIM = 256

# --- Training ---
VALIDATION_SPLIT = 0.2
SEED = 42
LOAD_MODEL = False  # NOTE(review): not read anywhere in the visible notebook.
BATCH_SIZE = 32
EPOCHS = 5

if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
else:
    DEVICE = torch.device("cpu")

# Seed the torch, Python and NumPy RNGs (in that order) for reproducible runs.
for _seed_fn in (torch.manual_seed, random.seed, np.random.seed):
    _seed_fn(SEED)
del _seed_fn

## 1. Load the data <a name="load"></a>

In [None]:
# Read the raw review dump. JSON text is UTF-8, so decode it explicitly instead
# of relying on the platform's default (locale-dependent) encoding.
with open("/app/data/wine-reviews/winemag-data-130k-v2.json", encoding="utf-8") as f:
    wine_data = json.load(f)

# %%
# Filter dataset: keep only reviews where every field spliced into the text is
# present and non-null (a missing key is skipped rather than raising KeyError),
# then flatten each into one
# "wine review : <country> : <province> : <variety> : <description>" string.
filtered_data = [
    "wine review : "
    + x["country"]
    + " : "
    + x["province"]
    + " : "
    + x["variety"]
    + " : "
    + x["description"]
    for x in wine_data
    if all(x.get(k) is not None for k in ["country", "province", "variety", "description"])
]

print(f"{len(filtered_data)} entries loaded")

In [None]:
# Inspect one formatted review (index 25 is just an arbitrary sample).
sample_idx = 25
example = filtered_data[sample_idx]
print(example)

## 2. Tokenize the data <a name="tokenize"></a>

In [None]:
# Translation table wrapping each isolated character in spaces: every ASCII
# punctuation mark plus comma, space, apostrophe and newline.
_PAD_TABLE = str.maketrans({ch: f" {ch} " for ch in string.punctuation + ", '\n"})


def pad_punctuation(s):
    """Surround punctuation with spaces so each mark becomes its own word.

    Every ASCII punctuation character, space and newline in ``s`` is wrapped in
    single spaces, then runs of spaces are collapsed to one space. The result
    may keep a leading or trailing space; callers split on whitespace.
    """
    spaced = s.translate(_PAD_TABLE)
    return re.sub(" +", " ", spaced)

# Normalise every review: lower-case it, then isolate its punctuation marks.
text_data = [pad_punctuation(review.lower()) for review in filtered_data]

# Same sample index as the raw example, now in tokenizer-ready form.
example_data = text_data[25]
print(example_data)

In [None]:
# %%
# Build vocabulary
# Rank words by corpus frequency and keep the VOCAB_SIZE - 2 most common ones;
# ids 0 and 1 are reserved for the <pad> and <unk> markers.
from collections import Counter

counter = Counter(word for line in text_data for word in line.split())

most_common = counter.most_common(VOCAB_SIZE - 2)
itos = ["<pad>", "<unk>"]  # index -> string
itos.extend(word for word, _count in most_common)
stoi = {word: idx for idx, word in enumerate(itos)}  # string -> index

In [None]:
def text_to_tokens(text):
    """Map each whitespace-separated word of ``text`` to its id in ``stoi``.

    Words not in the vocabulary are mapped to the ``<unk>`` id.
    """
    token_ids = []
    for word in text.split():
        token_ids.append(stoi.get(word, stoi["<unk>"]))
    return token_ids

tokenized_data = list(map(text_to_tokens, text_data))  # corpus as lists of vocab ids


## 3. Create the Training Set <a name="create"></a>

In [None]:
# Create Dataset
class WineDataset(Dataset):
    """Next-token-prediction pairs built from pre-tokenized reviews.

    Item ``idx`` is ``(x, y)`` where ``y`` is ``x`` shifted left by one token.
    Reviews longer than ``max_len + 1`` tokens are truncated; shorter ones are
    right-padded with id 0 in both tensors up to ``max_len``.
    """

    def __init__(self, tokenized_data, max_len):
        self.data = tokenized_data
        self.max_len = max_len

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        window = self.data[idx][: self.max_len + 1]
        inputs, targets = window[:-1], window[1:]
        missing = self.max_len - len(inputs)
        if missing > 0:
            filler = [0] * missing
            inputs, targets = inputs + filler, targets + filler
        return torch.tensor(inputs), torch.tensor(targets)

dataset = WineDataset(tokenized_data, MAX_LEN)

# Hold out VALIDATION_SPLIT of the reviews. No generator is passed, so the
# split draws from torch's global RNG (seeded in the parameters cell).
n_total = len(dataset)
train_size = int(n_total * (1 - VALIDATION_SPLIT))
val_size = n_total - train_size
train_ds, val_ds = torch.utils.data.random_split(dataset, [train_size, val_size])

train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(dataset=val_ds, batch_size=BATCH_SIZE)

## 5. Create the causal attention mask function <a name="causal"></a>

In [None]:
def causal_mask(seq_len):
    """Return a ``[seq_len, seq_len]`` boolean mask hiding future positions.

    ``mask[i, j]`` is True exactly when ``j > i``. For a boolean ``attn_mask``,
    nn.MultiheadAttention treats True entries as "may not attend".
    """
    positions = torch.arange(seq_len)
    # Row vector of key indices j compared against column vector of query indices i.
    return positions.unsqueeze(0) > positions.unsqueeze(1)

## 6. Create a Transformer Block layer <a name="transformer"></a>

In [None]:
class TransformerBlock(nn.Module):
    """Post-norm decoder block: causal self-attention, then a feed-forward net.

    Args:
        embed_dim: model width, shared by input and output.
        num_heads: number of attention heads.
        ff_dim: hidden width of the two-layer feed-forward sub-network.
        dropout: rate used inside attention and on both residual branches.
    """

    def __init__(self, embed_dim, num_heads, ff_dim, dropout=0.1):
        super().__init__()
        # Creation order matters: under a fixed seed it fixes the initial
        # weights, and the attribute names become the state_dict keys.
        self.attn = nn.MultiheadAttention(
            embed_dim, num_heads, dropout=dropout, batch_first=True
        )
        self.ln1 = nn.LayerNorm(embed_dim)
        expand = nn.Linear(embed_dim, ff_dim)
        contract = nn.Linear(ff_dim, embed_dim)
        self.ff = nn.Sequential(expand, nn.ReLU(), contract)
        self.ln2 = nn.LayerNorm(embed_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Apply the block to ``x`` ([batch, seq_len, embed_dim], batch first).

        Returns:
            ``(output, attn_weights)``: output has the same shape as ``x``;
            attn_weights is what nn.MultiheadAttention returns by default.
        """
        future = causal_mask(x.size(1)).to(x.device)
        attended, attn_weights = self.attn(x, x, x, attn_mask=future)
        h = self.ln1(x + self.dropout(attended))
        return self.ln2(h + self.dropout(self.ff(h))), attn_weights

## 7. Create the Token and Position Embedding <a name="embedder"></a>

In [None]:
class TokenAndPositionEmbedding(nn.Module):
    """Sum of a learned token embedding and a learned absolute-position embedding."""

    def __init__(self, vocab_size, max_len, embed_dim):
        super().__init__()
        self.token_emb = nn.Embedding(vocab_size, embed_dim)  # one row per token id
        self.pos_emb = nn.Embedding(max_len, embed_dim)  # one row per position < max_len

    def forward(self, x):
        """Embed integer token ids ``x`` (assumed [batch, seq_len]).

        seq_len must not exceed ``max_len``, or the position lookup fails.
        """
        seq_len = x.shape[1]
        pos_vectors = self.pos_emb(torch.arange(seq_len, device=x.device))
        # [seq_len, dim] -> [1, seq_len, dim] so it broadcasts over the batch.
        return self.token_emb(x) + pos_vectors.unsqueeze(0)

## 8. Build the Transformer model <a name="transformer_decoder"></a>

In [None]:
class GPT(nn.Module):
    """Decoder-only language model with a single transformer block.

    Token ids are embedded (token + position), passed through one
    TransformerBlock, and projected back to one score per vocabulary entry.
    """

    def __init__(self, vocab_size, max_len, embed_dim, num_heads, ff_dim):
        super().__init__()
        # Order (embedding, transformer, output head) fixes seeded initialisation.
        self.embedding = TokenAndPositionEmbedding(vocab_size, max_len, embed_dim)
        self.transformer = TransformerBlock(embed_dim, num_heads, ff_dim)
        self.fc_out = nn.Linear(embed_dim, vocab_size)

    def forward(self, x):
        """Return ``(logits, attention_weights)`` for token ids ``x``."""
        hidden, attention = self.transformer(self.embedding(x))
        return self.fc_out(hidden), attention

model = GPT(
    VOCAB_SIZE, MAX_LEN, EMBEDDING_DIM, N_HEADS, FEED_FORWARD_DIM
).to(DEVICE)
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-3)
# No ignore_index is set, so padding targets (id 0) also contribute to the loss.
criterion = nn.CrossEntropyLoss()

## 9. Train the Transformer <a name="train"></a>

In [None]:
def train_epoch(model, loader, optimizer, criterion):
    """Run one optimisation pass over ``loader`` and return the mean training loss.

    Args:
        model: module returning ``(logits, extra)``; logits' last dim is the
            vocabulary size.
        loader: iterable of ``(x, y)`` token-id batches.
        optimizer: optimizer over ``model``'s parameters.
        criterion: loss on flattened ``(N, vocab)`` logits and ``(N,)`` targets.

    Returns:
        Mean of the per-batch losses weighted by batch size, so a short final
        batch does not count as much as a full one. ``nan`` if ``loader``
        yields no batches.
    """
    model.train()
    # Send data to wherever the model lives rather than a global device.
    device = next(model.parameters()).device
    total_loss = 0.0
    n_seqs = 0
    for x, y in loader:
        x, y = x.to(device), y.to(device)
        optimizer.zero_grad()
        logits, _ = model(x)
        # Flatten using the model's own vocab width; reshaping against an
        # external constant could silently mis-group a differently sized model.
        loss = criterion(logits.reshape(-1, logits.size(-1)), y.reshape(-1))
        loss.backward()
        optimizer.step()
        batch = x.size(0)
        total_loss += loss.item() * batch
        n_seqs += batch
    return total_loss / n_seqs if n_seqs else float("nan")

def evaluate(model, loader, criterion):
    """Return the mean loss of ``model`` over ``loader`` without updating it.

    Args:
        model: module returning ``(logits, extra)``; logits' last dim is the
            vocabulary size.
        loader: iterable of ``(x, y)`` token-id batches.
        criterion: loss on flattened ``(N, vocab)`` logits and ``(N,)`` targets.

    Returns:
        Mean of the per-batch losses weighted by batch size (so a short final
        batch is not over-weighted). ``nan`` if ``loader`` yields no batches.
    """
    model.eval()
    # Send data to wherever the model lives rather than a global device.
    device = next(model.parameters()).device
    total_loss = 0.0
    n_seqs = 0
    with torch.no_grad():
        for x, y in loader:
            x, y = x.to(device), y.to(device)
            logits, _ = model(x)
            # Flatten using the model's own vocab width (see train_epoch).
            loss = criterion(logits.reshape(-1, logits.size(-1)), y.reshape(-1))
            batch = x.size(0)
            total_loss += loss.item() * batch
            n_seqs += batch
    return total_loss / n_seqs if n_seqs else float("nan")

# %%
# One training pass followed by one validation pass per epoch; log both losses.
for epoch in range(EPOCHS):
    losses = (
        train_epoch(model, train_loader, optimizer, criterion),
        evaluate(model, val_loader, criterion),
    )
    train_loss, val_loss = losses
    print(f"Epoch {epoch+1}/{EPOCHS} - train_loss: {train_loss:.4f}, val_loss: {val_loss:.4f}")


## 10. Generate text using the Transformer <a name="generate"></a>

In [None]:
def sample_from_logits(logits, temperature=1.0):
    """Draw one token id from ``logits`` after temperature scaling.

    Args:
        logits: unnormalised scores, one per vocabulary entry (expected 1-D).
        temperature: > 1 flattens and < 1 sharpens the distribution;
            0 means greedy decoding (argmax).

    Returns:
        The chosen token id as a Python int.

    Raises:
        ValueError: if ``temperature`` is negative (it would silently invert
            the distribution).
    """
    if temperature < 0:
        raise ValueError(f"temperature must be >= 0, got {temperature}")
    if temperature == 0:
        # logits / 0 gives inf/nan probabilities that torch.multinomial rejects;
        # the T -> 0 limit of sampling is simply the most likely token.
        return logits.argmax(dim=-1).item()
    probs = F.softmax(logits / temperature, dim=-1)
    return torch.multinomial(probs, num_samples=1).item()

def generate_text(model, start_prompt, max_tokens=80, temperature=1.0):
    """Autoregressively extend ``start_prompt`` by up to ``max_tokens`` words.

    The prompt is split on whitespace and mapped to vocabulary ids (unknown
    words become ``<unk>``). At each step the model sees at most the last
    MAX_LEN tokens. Generation stops early when id 0 (``<pad>``) is sampled.
    Leaves ``model`` in eval mode.

    Returns:
        The prompt words plus the generated words, joined with single spaces.

    Raises:
        ValueError: if ``start_prompt`` contains no words (the model needs at
            least one token of context).
    """
    model.eval()
    generated = text_to_tokens(start_prompt)
    if not generated:
        raise ValueError("start_prompt must contain at least one word")
    device = next(model.parameters()).device
    # Pure inference: without no_grad every step would record an autograd graph.
    with torch.no_grad():
        for _ in range(max_tokens):
            context = torch.tensor(generated[-MAX_LEN:], device=device).unsqueeze(0)
            logits, _attn = model(context)
            next_token = sample_from_logits(logits[0, -1], temperature)
            if next_token == 0:  # <pad> doubles as the end-of-text signal
                break
            generated.append(next_token)
    return " ".join(itos[t] for t in generated)

In [None]:
# Sample a few reviews: one at temperature 1.0, two sharper ones at 0.5.
for prompt, temp in [
    ("wine review : us", 1.0),
    ("wine review : italy", 0.5),
    ("wine review : germany", 0.5),
]:
    print(generate_text(model, prompt, max_tokens=50, temperature=temp))