In [14]:
%load_ext autoreload
%autoreload 2

from utils import Py150kDataset

import wandb
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader


The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [15]:
ds = Py150kDataset("train", "py150k")

One problem is that we need all sequences in a batch to be the same length, but there is a large difference in lengths

In [16]:
max(len(ds[i]) for i in range(100)), min(len(ds[i]) for i in range(100))

(75159, 20)

In [20]:
from utils.dataset import Py150kDataset
from utils.tokenizer import BOS_ID, EOS_ID, PAD_ID
from torch.utils.data import DataLoader, random_split

def collate_fn(batch:list[torch.Tensor], max_len:int=2048):
    batch = [x[:max_len] for x in batch]
    batch = [
        torch.cat([torch.tensor([BOS_ID]), x, torch.tensor([EOS_ID])])
        for x in batch
    ]
    return torch.nn.utils.rnn.pad_sequence(
        batch,
        batch_first=True,
        padding_value=PAD_ID
    )



train_ds, val_ds, _ = random_split(ds, [5000, 5000, len(ds) - 10000])
train_dl = DataLoader(train_ds, batch_size=64, collate_fn=collate_fn)#, prefetch_factor=4, num_workers=8, persistent_workers=True)
val_dl = DataLoader(val_ds, batch_size=64, collate_fn=collate_fn)#, prefetch_factor=4, num_workers=8, persistent_workers=True)

In [65]:
class PyRNN(nn.Module):
    def __init__(self, vocab_size, hidden_size):
        super().__init__()
        self.vocab_size, self.hidden_size = vocab_size, hidden_size
        
        self.embed = nn.Embedding(vocab_size, hidden_size)
        self.rnn = nn.RNN(hidden_size, hidden_size, batch_first=True)
        self.linear = nn.Linear(hidden_size, vocab_size)
        
    def forward(self, x, h=None):
        x = self.embed(x)
        x, h = self.rnn(x) # i.e. 100% teacher forcing
        x = self.linear(x)
        return x, h
            
        
model = PyRNN(len(ds.tokenizer), 128)
out, h = model(next(iter(train_dl)))
out.shape, h.shape

(torch.Size([64, 2050, 376]), torch.Size([1, 64, 128]))

https://wandb.ai/bjarnih/PyGPT

In [66]:
from tqdm import tqdm
import wandb

EPOCHS = 20
LR = 1e-3
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"


# model = PyRNN(len(ds.tokenizer), 128).to(DEVICE)
criterion = nn.CrossEntropyLoss(ignore_index=PAD_ID) # <PAD> tokens do not contribute to the loss
optim = torch.optim.Adam(model.parameters(), lr=LR)

# wandb.init(
#     project="PyGPT",
#     config={
#         "learning_rate": LR,
#         "epochs": EPOCHS,
#         "architecture": model.__class__.__name__,
#         "n_training_examples": len(train_ds),
#         "n_validation_examples": len(val_ds),
#         "parameter_count": sum([p.numel() for p in model.parameters() if p.requires_grad])
#     },
#     group="baseline RNNs"
# )


model.train()
for epoch in range(EPOCHS):
    train_tqdm = tqdm(train_dl, desc=f"Epoch {epoch + 1}/{EPOCHS} Training")
    total_train_loss = 0

    # Training loop
    for batch in train_tqdm:
        batch = batch.to(DEVICE)
        x = batch[..., :-1]
        y = batch[..., 1:]
        y_hat, _ = model(x)
        loss = criterion(y_hat.reshape(-1, len(ds.tokenizer)), y.reshape(-1))

        optim.zero_grad()
        loss.backward()
        optim.step()

        train_loss = loss.detach().cpu().numpy()
        total_train_loss += train_loss
        train_tqdm.set_postfix({"loss": train_loss})

    # wandb.log({"avg_train_loss": total_train_loss / len(train_dl)}, step=epoch)

    # model.eval()
    # total_val_loss = 0
    # with torch.no_grad():
    #     val_tqdm = tqdm(val_dl, desc=f"Epoch {epoch + 1}/{EPOCHS} Validation")
    #     for val_batch in val_tqdm:
    #         val_batch = val_batch.to(DEVICE)
    #         x_val = val_batch[..., :-1]
    #         y_val = val_batch[..., 1:]
    #         y_val_hat, _ = model(x_val)
    #         loss = criterion(y_val_hat.reshape(-1, len(ds.tokenizer)), y_val.reshape(-1))
    #         val_loss = loss.detach().cpu().numpy()
    #         total_val_loss += val_loss
    #         val_tqdm.set_postfix({"val_loss": val_loss})
    
    # model.train()

    # wandb.log({"generated_text": wandb.Html(ds.tokenizer.color_text_html(sample_output))}, step=epoch)
    # wandb.log({"avg_val_loss": total_val_loss / len(val_dl)}, step=epoch)


Epoch 1/20 Training:   0%|          | 0/79 [00:00<?, ?it/s]

Epoch 1/20 Training: 100%|██████████| 79/79 [00:59<00:00,  1.33it/s, loss=4.0909]   
Epoch 2/20 Training: 100%|██████████| 79/79 [00:59<00:00,  1.33it/s, loss=3.4710417]
Epoch 3/20 Training: 100%|██████████| 79/79 [00:59<00:00,  1.33it/s, loss=3.1435838]
Epoch 4/20 Training: 100%|██████████| 79/79 [00:59<00:00,  1.32it/s, loss=2.920678] 
Epoch 5/20 Training: 100%|██████████| 79/79 [00:59<00:00,  1.33it/s, loss=2.7627494]
Epoch 6/20 Training: 100%|██████████| 79/79 [00:58<00:00,  1.35it/s, loss=2.640872] 
Epoch 7/20 Training: 100%|██████████| 79/79 [00:58<00:00,  1.36it/s, loss=2.5459306]
Epoch 8/20 Training: 100%|██████████| 79/79 [01:00<00:00,  1.31it/s, loss=2.4681885]
Epoch 9/20 Training: 100%|██████████| 79/79 [01:00<00:00,  1.31it/s, loss=2.4031415]
Epoch 10/20 Training: 100%|██████████| 79/79 [00:59<00:00,  1.32it/s, loss=2.3481948]
Epoch 11/20 Training: 100%|██████████| 79/79 [00:59<00:00,  1.33it/s, loss=2.301025] 
Epoch 12/20 Training: 100%|██████████| 79/79 [01:01<00:00,  1.2

KeyboardInterrupt: 

In [73]:
sum([p.numel() for p in model.parameters() if p.requires_grad])

129656

In [164]:
from utils.sampling import sample_with_temp, nucleus_sample

@torch.no_grad()
def generate(model, max_len=1000, starting_tokens:list[int]=None)->str:
    model.eval()
    device = next(model.parameters()).device
    xt = torch.tensor([[BOS_ID] + (starting_tokens or [])], device=device)
    ht = torch.randn(1, 1, model.hidden_size, device=device)
    
    tokens = starting_tokens or []
    for _ in range(max_len):
        xt = model.embed(xt)
        xt, ht = model.rnn(xt, ht)
        xt = model.linear(xt)
        xt = nucleus_sample(xt[:,-1,:], nucleus_threshold=0.9)
        # xt = sample_with_temp(xt[:,-1,:], temperature=1.0)
        token = xt.item()
        if token == EOS_ID:
            break
        tokens.append(token)

    model.train()
    return ds.tokenizer.detokenize(tokens)

In [163]:
starting_code = """
import math
def is_prime(n):
    if n < 2:
        return False
    for i in range(2, int(math.sqrt(n)) + 1):
        if n % i == 0:
            return False
    return True

def is_not_prime(n):
"""
starting_tokens = ds.tokenizer.tokenize(starting_code)

code = generate(model, starting_tokens=starting_tokens)
print(ds.tokenizer.color_text_ansi(code))

[48;2;194;224;255m
[48;2;255;218;194mde[48;2;194;255;208mf [48;2;255;194;224mad[48;2;218;255;194md_[48;2;194;224;255mt[48;2;255;218;194mw[48;2;194;255;208mo[48;2;255;194;224m_[48;2;218;255;194mn[48;2;194;224;255mum[48;2;255;218;194mb[48;2;194;255;208mer[48;2;255;194;224ms([48;2;218;255;194ma[48;2;194;224;255m, [48;2;255;218;194mb[48;2;194;255;208m):[48;2;255;194;224m
[48;2;218;255;194m  [48;2;194;224;255m  [48;2;255;218;194m  [48;2;194;255;208m  [48;2;255;194;224mo[48;2;218;255;194mbj[48;2;194;224;255me[48;2;255;218;194mct[48;2;194;255;208m [48;2;255;194;224m= [48;2;218;255;194mt[48;2;194;224;255man[48;2;255;218;194mp[48;2;194;255;208ming[48;2;255;194;224m.[48;2;218;255;194mf[48;2;194;224;255mro[48;2;255;218;194mm[48;2;194;255;208m_[48;2;255;194;224mh[48;2;218;255;194mo[48;2;194;224;255md[48;2;255;218;194mma[48;2;194;255;208mx[48;2;255;194;224mi[48;2;218;255;194mf[48;2;194;224;255mf[48;2;255;218;194mer[48;2;194;255;208mm[48;2;255;194;224

In [165]:
from utils.search import beam_search

tokens = beam_search(model, beam_width=2, max_length=10, starting_tokens=starting_tokens)
code = ds.tokenizer.detokenize(tokens)
print(ds.tokenizer.color_text_ansi(code))

[48;2;194;224;255m
[48;2;255;218;194mde[48;2;194;255;208mf [48;2;255;194;224mad[48;2;218;255;194md_[48;2;194;224;255mt[48;2;255;218;194mw[48;2;194;255;208mo[48;2;255;194;224m_[48;2;218;255;194mn[48;2;194;224;255mum[48;2;255;218;194mb[48;2;194;255;208mer[48;2;255;194;224ms([48;2;218;255;194ma[48;2;194;224;255m, [48;2;255;218;194mb[48;2;194;255;208m):[48;2;255;194;224m
[48;2;218;255;194m  [48;2;194;224;255m  [48;2;255;218;194m  [48;2;194;255;208m  [48;2;255;194;224mo[48;2;218;255;194mbj[48;2;194;224;255me[48;2;255;218;194mct[48;2;194;255;208m [48;2;255;194;224m= [48;2;218;255;194mt[48;2;194;224;255man[48;2;255;218;194mp[48;2;194;255;208ming[48;2;255;194;224m.[48;2;218;255;194mf[48;2;194;224;255mro[48;2;255;218;194mm[48;2;194;255;208m_[48;2;255;194;224mh[48;2;218;255;194mo[48;2;194;224;255md[48;2;255;218;194mma[48;2;194;255;208mx[48;2;255;194;224mi[48;2;218;255;194mf[48;2;194;224;255mf[48;2;255;218;194mer[48;2;194;255;208mm[48;2;255;194;224