In [None]:
%load_ext autoreload
%autoreload 2

from utils import Py150kDataset

import wandb
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader


In [None]:
# Full training split of the corpus; the train/val subsets used below are carved out of it with random_split.
# NOTE(review): positional args presumably mean (split, data directory) — confirm against utils.Py150kDataset.
ds = Py150kDataset("train", "py150k")

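One problem is that all sequences in a batch must have the same length, yet sample lengths vary widely, as the next cell shows. `collate_fn` below deals with this by truncating long samples and padding short ones with `<PAD>` tokens.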

In [None]:
# Peek at the spread of sample lengths over the first N_PEEK samples.
# Each sample is loaded once (the original indexed ds twice per sample: once for max, once for min).
N_PEEK = 100
peek_lengths = [len(ds[i]) for i in range(N_PEEK)]
max(peek_lengths), min(peek_lengths)

In [None]:
# The padding token and its id: collate_fn pads with PAD_ID and the loss ignores it.
tokenizer = ds.tokenizer
tokenizer.PAD, tokenizer.PAD_ID

In [None]:
from utils.dataset import Py150kDataset
from torch.utils.data import DataLoader, random_split

def collate_fn(batch: list[torch.Tensor], max_len: int = 2048, tokenizer=None) -> torch.Tensor:
    """Wrap each sample in BOS/EOS, truncate, and right-pad into one batch tensor.

    Args:
        batch: 1-D tensors of token ids, one per sample.
        max_len: maximum number of *content* tokens kept per sample. A returned row
            can therefore be up to ``max_len + 1`` long for a truncated sample
            (BOS + content) or ``max_len + 2`` for a complete one (BOS + content + EOS).
        tokenizer: object exposing ``BOS_ID``, ``EOS_ID`` and ``PAD_ID``.
            Defaults to the notebook-level ``ds.tokenizer``.

    Returns:
        A ``(len(batch), L)`` tensor, where ``L`` is the longest wrapped sample.
        Shorter rows are right-padded with ``PAD_ID``.
    """
    if tokenizer is None:
        tokenizer = ds.tokenizer
    bos = torch.tensor([tokenizer.BOS_ID])
    eos = torch.tensor([tokenizer.EOS_ID])

    wrapped = []
    for x in batch:
        if len(x) > max_len:
            # Truncated sample: the file does not actually end here, so no EOS is
            # appended (otherwise the model learns to emit EOS in mid-file).
            wrapped.append(torch.cat([bos, x[:max_len]]))
        else:
            wrapped.append(torch.cat([bos, x, eos]))

    return torch.nn.utils.rnn.pad_sequence(
        wrapped,
        batch_first=True,
        padding_value=tokenizer.PAD_ID,
    )



# Fixed seed so the train/val split is identical across kernel restarts.
SPLIT_SEED = 0
train_ds, val_ds, _ = random_split(
    ds,
    [50000, 5000, len(ds) - 55000],  # the remainder of ds is left unused
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

# Only the training loader shuffles, so each epoch sees batches in a new order;
# validation keeps a fixed order so its loss is comparable across epochs.
loader_kwargs = dict(
    batch_size=64,
    collate_fn=collate_fn,
    prefetch_factor=4,
    num_workers=8,
    persistent_workers=True,
)
train_dl = DataLoader(train_ds, shuffle=True, **loader_kwargs)
val_dl = DataLoader(val_ds, shuffle=False, **loader_kwargs)

In [None]:
class PyRNN(nn.Module):
    """Single-layer ``nn.RNN`` language model over token ids.

    Tokens are embedded to ``hidden_size``, run through an ``nn.RNN`` with the same
    hidden size, and projected back to ``vocab_size`` logits.
    """

    def __init__(self, vocab_size, hidden_size):
        super().__init__()
        self.vocab_size, self.hidden_size = vocab_size, hidden_size

        self.embed = nn.Embedding(vocab_size, hidden_size)
        self.rnn = nn.RNN(hidden_size, hidden_size, batch_first=True)
        self.linear = nn.Linear(hidden_size, vocab_size)

    def forward(self, x):
        """Return next-token logits for every position of ``x``.

        Args:
            x: ``(B, T)`` tensor of token ids.

        Returns:
            ``(B, T, vocab_size)`` logits. The whole input sequence is fed in at once,
            i.e. 100% teacher forcing. The initial hidden state is the ``nn.RNN``
            default, which is zeros.
        """
        x = self.embed(x)
        x, _ = self.rnn(x)
        return self.linear(x)

    @torch.no_grad()
    def generate(self, tokenizer, max_len=1000):
        """Greedily decode up to ``max_len`` tokens, starting from BOS.

        Args:
            tokenizer: object exposing ``BOS_ID``, ``EOS_ID`` and ``detokenize``.
            max_len: maximum number of tokens to produce.

        Returns:
            ``tokenizer.detokenize`` applied to the generated ids. Generation stops
            after the first EOS token, and that EOS is included in the ids.

        The module's train/eval mode is restored on exit, even if generation raises.
        """
        was_training = self.training
        self.eval()
        try:
            device = next(self.parameters()).device
            xt = torch.tensor([[tokenizer.BOS_ID]], device=device)  # (1, 1)
            # Random initial hidden state: with greedy argmax decoding this is the only
            # source of variation between samples. NOTE: forward() starts from the
            # nn.RNN default (zeros), so this differs from what training sees.
            ht = torch.randn(1, 1, self.hidden_size, device=device)

            tokens = []
            for _ in range(max_len):
                out, ht = self.rnn(self.embed(xt), ht)
                logits = self.linear(out.squeeze(1))
                xt = logits.argmax(dim=-1, keepdim=True)

                token = xt.item()
                tokens.append(token)
                if token == tokenizer.EOS_ID:
                    break
        finally:
            self.train(was_training)
        # Use the tokenizer that was passed in, not a notebook global.
        return tokenizer.detokenize(tokens)

    def train_step(self, x, y, teacher_forcing=0.5):
        """Unroll the RNN one step at a time with stochastic teacher forcing.

        This is much slower than ``forward`` because it loops over time in Python.

        Args:
            x: ``(B, T)`` token ids. Only the first column ``x[:, 0]`` is fed in; the
                rest of ``x`` only determines ``T``.
            y: ``(B, T)`` targets. Assumes ``y[:, i]`` is the ground-truth token that
                follows input step ``i``, e.g. ``x = batch[..., :-1]``,
                ``y = batch[..., 1:]``.
            teacher_forcing: probability of feeding ``y[:, i]`` as the next input
                instead of the model's own argmax prediction. It is drawn once per
                time step for the whole batch.

        Returns:
            ``(B, T, vocab_size)`` logits.
        """
        B, T = x.shape
        xt = x[:, [0]]
        ht = torch.zeros(1, B, self.hidden_size, device=x.device)

        outputs = []
        for i in range(T):
            out, ht = self.rnn(self.embed(xt), ht)
            logits = self.linear(out.squeeze(1))
            outputs.append(logits)

            if torch.rand(1) < teacher_forcing:
                xt = y[:, [i]]  # teacher forcing: feed the ground-truth next token
            else:
                xt = logits.argmax(dim=-1, keepdim=True)  # feed the model's own prediction

        return torch.stack(outputs, dim=1)
        
# Smoke test: push one training batch through a small model and show the logits shape.
model = PyRNN(vocab_size=len(ds.tokenizer), hidden_size=128)
model(next(iter(train_dl))).shape

Training runs are logged to Weights & Biases: https://wandb.ai/bjarnih/PyGPT

In [None]:
from tqdm.auto import tqdm
import wandb

EPOCHS = 10
LR = 3e-4
HIDDEN_SIZE = 1024
# Set to e.g. 1.0 to clip the global gradient norm (vanilla RNNs are prone to exploding gradients).
GRAD_CLIP = None
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"


def lm_loss(model, batch, criterion):
    """Next-token loss of ``model`` on a padded batch of token ids.

    Inputs are ``batch[..., :-1]`` and targets are ``batch[..., 1:]``. Logits and
    targets are flattened so that ``criterion`` sees ``(N, vocab)`` vs ``(N,)``.
    """
    x, y = batch[..., :-1], batch[..., 1:]
    logits = model(x)
    return criterion(logits.reshape(-1, logits.size(-1)), y.reshape(-1))


model = PyRNN(len(ds.tokenizer), HIDDEN_SIZE).to(DEVICE)
criterion = nn.CrossEntropyLoss(ignore_index=ds.tokenizer.PAD_ID)  # <PAD> tokens do not contribute to the loss
optim = torch.optim.Adam(model.parameters(), lr=LR)

wandb.init(
    project="PyGPT",
    config={
        "learning_rate": LR,
        "epochs": EPOCHS,
        "architecture": model.__class__.__name__,
        "hidden_size": HIDDEN_SIZE,
        "grad_clip": GRAD_CLIP,
        "n_training_examples": len(train_ds),
        "n_validation_examples": len(val_ds),
        "parameter_count": sum(p.numel() for p in model.parameters() if p.requires_grad),
    },
    group="baseline RNNs",
)


model.train()
for epoch in range(EPOCHS):
    # --- training ---
    total_train_loss = 0.0
    train_tqdm = tqdm(train_dl, desc=f"Epoch {epoch + 1}/{EPOCHS} Training")
    for batch in train_tqdm:
        loss = lm_loss(model, batch.to(DEVICE), criterion)

        optim.zero_grad()
        loss.backward()
        if GRAD_CLIP is not None:
            torch.nn.utils.clip_grad_norm_(model.parameters(), GRAD_CLIP)
        optim.step()

        train_loss = loss.item()
        total_train_loss += train_loss
        train_tqdm.set_postfix({"loss": train_loss})

    wandb.log({"avg_train_loss": total_train_loss / len(train_dl)}, step=epoch)

    # --- validation ---
    model.eval()
    total_val_loss = 0.0
    with torch.no_grad():
        val_tqdm = tqdm(val_dl, desc=f"Epoch {epoch + 1}/{EPOCHS} Validation")
        for val_batch in val_tqdm:
            # .item() so the progress bar shows a number, not a device tensor repr
            val_loss = lm_loss(model, val_batch.to(DEVICE), criterion).item()
            total_val_loss += val_loss
            val_tqdm.set_postfix({"val_loss": val_loss})

    model.train()
    sample_output = model.generate(ds.tokenizer, max_len=1000)

    wandb.log({"generated_text": wandb.Html(ds.tokenizer.color_text_html(sample_output))}, step=epoch)
    wandb.log({"avg_val_loss": total_val_loss / len(val_dl)}, step=epoch)


In [None]:
# wandb.finish() # if we want to finish the run

In [None]:
# Long greedy sample from the trained model, rendered with the tokenizer's ANSI colorizer.
gen = model.generate(tokenizer=ds.tokenizer, max_len=10000)
print(ds.tokenizer.color_text_ansi(gen))

In [None]:
gen  # raw value returned by model.generate (tokenizer.detokenize output), shown via its repr

In [None]:
# Prints the HTML markup as text; wrap it in IPython.display.HTML to render it instead.
html_markup = ds.tokenizer.color_text_html(gen)
print(html_markup)