In [None]:
%reload_ext autoreload
%autoreload 2

In [None]:
import inspect
from dataclasses import dataclass

import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.utils.data import DataLoader


from loader import prepare_data
from model import GPT

In [None]:
@dataclass
class GPTConfig:
    """
    Dataclass containing data, model, and training settings.

    One instance is created by the notebook and handed to ``prepare_data``
    and ``GPT``; the optimizer, data loaders and training loop read their
    hyper-parameters from it as well. Additional attributes (e.g.
    ``vocab_size``) may be attached to the instance after construction.
    """
    # Forwarded to data preparation; presumably the fraction of the data used
    # for training (consumer is outside this file -- TODO confirm). It is a
    # fraction, hence ``float`` rather than ``int``.
    train_frac:float = 0.7
    # No annotation, so this is a plain class attribute shared by all instances
    # rather than a dataclass field (it is not an ``__init__`` argument).
    # Evaluated once, when the class body runs.
    device = "mps" if torch.backends.mps.is_available() else "cpu"
    # Model-shape settings consumed by ``GPT`` / ``prepare_data`` (defined
    # elsewhere; names suggest context length, depth, heads, embedding width).
    block_size:int = 128
    n_layer:int = 8
    n_head:int = 8
    n_embd:int = 32
    # Batch size used for both the training and validation data loaders.
    batch_size:int = 128
    # AdamW settings: decay for tensors with dim >= 2, and the learning rate.
    weight_decay:float = 1e-4
    learning_rate:float = 1e-3
    # Number of passes over the training loader.
    n_epochs:int = 5
    # Seed passed to ``torch.manual_seed``.
    seed:int = 2024

# Instantiate the run configuration with its default values
config = GPTConfig()

# Seed torch's global RNG from the configured seed
torch.manual_seed(config.seed)

# prepare_data returns the two dataset splits plus the tokenizer;
# the tokenizer's vocabulary size is then recorded on the config.
train_dataset, valid_dataset, tokenizer = prepare_data(config)
config.vocab_size = len(tokenizer.vocab)

# Batched loaders: only the training split is shuffled
train_loader = DataLoader(
    dataset=train_dataset,
    batch_size=config.batch_size,
    shuffle=True,
)
valid_loader = DataLoader(
    dataset=valid_dataset,
    batch_size=config.batch_size,
    shuffle=False,
)


In [None]:
# Build the model from the config and place it on the configured device
model = GPT(config)
model.to(config.device)

# Collect every trainable parameter, keyed by its qualified name
param_dict = {
    name: weight
    for name, weight in model.named_parameters()
    if weight.requires_grad
}

# Split by tensor rank: rank >= 2 tensors receive weight decay,
# rank < 2 tensors do not.
decay_params = [w for w in param_dict.values() if w.dim() >= 2]
nodecay_params = [w for w in param_dict.values() if w.dim() < 2]
optim_groups = [
    dict(params=decay_params, weight_decay=config.weight_decay),
    dict(params=nodecay_params, weight_decay=0.0),
]

# Report how many tensors / scalar parameters ended up in each group
num_decay_params = sum(map(torch.Tensor.numel, decay_params))
num_nodecay_params = sum(map(torch.Tensor.numel, nodecay_params))
print(f"Num decayed parameter tensors: {len(decay_params)}. Total decayed parameters: {num_decay_params}.")
print(f"Num non-decayed parameter tensors: {len(nodecay_params)}. Total non-decayed parameters: {num_nodecay_params}.")

# Decide whether to request the fused AdamW kernel: the installed AdamW must
# expose a `fused` argument and the run must target an MPS device.
# NOTE(review): the original author states torch supports fused AdamW on MPS;
# this depends on the installed torch version -- confirm.
fused_avail = 'fused' in inspect.signature(torch.optim.AdamW).parameters
use_fused = ('mps' in config.device) if fused_avail else False
print(f"Using fused AdamW: {use_fused}")

In [None]:

# Forward `fused` only when it was enabled above. Passing the keyword
# unconditionally raises TypeError on torch builds whose AdamW does not accept
# it, which is exactly the case the signature probe is meant to guard against.
extra_args = {"fused": True} if use_fused else {}
optimizer = torch.optim.AdamW(
    optim_groups,
    lr=config.learning_rate,
    betas=(0.9, 0.95),
    eps=1e-8,
    **extra_args,
)

# Note, for simplicity, we will not use a LR scheduler. This may change in the future.

# Training

#@torch.compile
def train_step(net:torch.nn.Module, context:torch.Tensor, targets:torch.Tensor,
               opt:"torch.optim.Optimizer | None"=None):
    """
    Run one optimization step and return the batch loss.

    Gradients are cleared, ``net`` is evaluated on ``context``, the logits are
    flattened to ``(-1, n_classes)`` and the targets to ``(-1,)`` before the
    cross-entropy is computed, then the loss is back-propagated and the
    optimizer is stepped.

    Args:
        net: Module mapping ``context`` to logits whose last dimension indexes
            the classes.
        context: Input batch passed straight to ``net``.
        targets: Class indices with one entry per logit row after flattening.
        opt: Optimizer to use. When omitted, the notebook-level ``optimizer``
            is used, which keeps existing three-argument calls working.

    Returns:
        The cross-entropy loss as a tensor detached from the autograd graph
        (callers only need its value, e.g. via ``.item()``).
    """
    if opt is None:
        # Fall back to the optimizer created earlier in the notebook.
        opt = optimizer
    opt.zero_grad()
    logits = net(context)
    loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
    loss.backward()
    opt.step()
    return loss.detach()

#@torch.compile
def valid_step(net:torch.nn.Module, context:torch.Tensor, targets:torch.Tensor):
    """
    Compute the cross-entropy loss of ``net`` on one batch without updating it.

    The logits returned by ``net(context)`` are flattened to
    ``(-1, n_classes)`` and the targets to ``(-1,)`` before the loss is taken.
    No gradient context is set here; the caller decides whether to wrap the
    call in ``torch.no_grad()``.

    Args:
        net: Module mapping ``context`` to logits whose last dimension indexes
            the classes.
        context: Input batch passed straight to ``net``.
        targets: Class indices with one entry per logit row after flattening.

    Returns:
        The cross-entropy loss tensor for the batch.
    """
    logits = net(context)
    loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
    return loss



In [None]:
# Training loop
# Each epoch runs one full pass over train_loader, then (every `valid_freq`
# epochs) averages the per-batch loss over valid_loader.
print_freq = 100
valid_freq = 1
for epoch in range(config.n_epochs):
    epoch_lr = optimizer.param_groups[0]['lr']
    print(f'Epoch {epoch}/{config.n_epochs}, learning rate {epoch_lr}')

    # Training step
    model.train()
    for step, (xb, yb) in enumerate(train_loader):
        # The model lives on config.device, so the batch must too.
        # `.to` returns the same tensor when it is already on that device.
        xb, yb = xb.to(config.device), yb.to(config.device)
        loss = train_step(model, xb, yb)

        # Log training loss
        if step % print_freq == 0:
            print(f'Train Step {step}/{len(train_loader)} - Loss: {loss.item()}')

    # Validation step
    model.eval()
    if epoch % valid_freq == 0:
        valid_loss = []
        with torch.no_grad():
            for xb, yb in valid_loader:
                xb, yb = xb.to(config.device), yb.to(config.device)
                loss = valid_step(model, xb, yb)
                valid_loss.append(loss.item())

        # Log validation loss (mean of the per-batch losses). Guard against an
        # empty validation loader, which would otherwise divide by zero.
        if valid_loss:
            val_loss = sum(valid_loss) / len(valid_loss)
            print(f'Validation Loss: {val_loss}')
        else:
            print('Validation Loss: n/a (validation loader is empty)')


In [None]:
# Generation
# Generate some random text from the trained model.
max_new_tokens = 10000
# Encode the prompt and place it on the same device as the model.
past_tokens = torch.tensor(tokenizer.encode('PROGRAM:\nHello world!')).to(config.device)
model.eval()
# Sampling needs no gradients; disabling autograd avoids tracking a graph over
# the whole generation. `[None, :]` adds a leading batch dimension of size 1.
with torch.no_grad():
    generated_tokens = model.generate(past_tokens[None, :], max_new_tokens)
generated_text = tokenizer.decode(generated_tokens[0].tolist())

# Write with an explicit encoding so the result does not depend on the
# platform default (which can fail on non-ASCII characters).
with open('generated_text.txt', 'w', encoding='utf-8') as file:
    file.write(generated_text)