In [28]:
%load_ext autoreload
%autoreload 2

from utils import BPETokenizer

import torch
import torch.nn as nn
import torch.nn.functional as F

import wandb

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [29]:
train_files = open("data/PY150K/python100k_train.txt", "r", encoding="utf-8").read().split("\n")[:-1] # remove the last empty line
train_texts = [open("data/PY150K/" + path).read() for path in train_files]

In [3]:
tokenizer = BPETokenizer() # ensure that all unary tokens in our training data exist in our vocabulary
tokenizer.fit("\n".join(train_texts[:1000]), iterations=100)
tokenizer.save("py150k_0")
# tokenizer = BPETokenizer.load("py150k_new")

In [4]:
tokenizer.tokenize("þla"), tokenizer.detokenize(tokenizer.tokenize("þla"))

([3, 82, 71], '<unk>la')

In [6]:
tokenizer.print_tokens(open("data/PY150K/" + train_files[8]).read())

[48;2;194;224;255mi[48;2;255;218;194mmp[48;2;194;255;208mor[48;2;255;194;224mt [48;2;218;255;194mb[48;2;194;224;255mo[48;2;255;218;194mto[48;2;194;255;208m
[48;2;255;194;224mi[48;2;218;255;194mmp[48;2;194;224;255mor[48;2;255;218;194mt [48;2;194;255;208mb[48;2;255;194;224mo[48;2;218;255;194mto[48;2;194;224;255m.[48;2;255;218;194ms[48;2;194;255;208m3[48;2;255;194;224m.[48;2;218;255;194mcon[48;2;194;224;255mn[48;2;255;218;194mec[48;2;194;255;208mtion[48;2;255;194;224m

[48;2;218;255;194mf[48;2;194;224;255mro[48;2;255;218;194mm[48;2;194;255;208m [48;2;255;194;224md[48;2;218;255;194mj[48;2;194;224;255man[48;2;255;218;194mg[48;2;194;255;208mo[48;2;255;194;224m.[48;2;218;255;194mcon[48;2;194;224;255mf [48;2;255;218;194mi[48;2;194;255;208mmp[48;2;255;194;224mor[48;2;218;255;194mt [48;2;194;224;255mse[48;2;255;218;194mt[48;2;194;255;208mti[48;2;255;194;224mn[48;2;218;255;194mg[48;2;194;224;255ms[48;2;255;218;194m

[48;2;194;255;208mi[48;2;255;194

In [8]:
from utils.dataset import Py150kDataset

ds = Py150kDataset("train", "py150k_new")
ds[1337]

tensor([ 79, 197, 157, 168, 164, 175,   5,  76, 181,  83,   6, 155,  75,  93,
         72,  85,  90,  20, 150, 235,  76,  71, 185, 194,  47,  50,  85,  77,
        187,  88,   6,  79, 197, 157, 168,  47,  50,  85,  77, 187,  88,   5,
         76, 181,  83,   6, 155,  75,  93,  72,  85,  90,  20, 196, 197, 151,
        169,  90, 194, 237,  74, 228,  75, 194,  44, 157, 209,  90, 164,  84,
         77,   6,  79, 197, 157, 168,  44, 157, 209,  90, 164,  84,  77, 176,
         73,  82, 180, 160,  50,  85,  77,  44,  79, 174,  14,  47,  50,  85,
         77, 187,  88,  15,  32,   5,   4, 189,   8,   5,   4,  39,  84,   6,
         79, 197, 174, 175,  84,  90, 178,  79, 151,   6,  85, 167,  47,  50,
         85,  77, 187,  88,   6, 159, 178,   6, 148,  84,  74, 160, 175,  89,
         89,  71, 187, 160, 203,   6,  71,   6, 192, 174,  20,   5,   4, 189,
          8, 176,   4, 204,  69,  69, 150,  79, 190,  69, 242, 152, 192, 174,
         69, 229, 159,   6,  15,  32,   5,   4,   4, 189,   8,  

One problem is that we need all sequences in a batch to be the same length, but there is a large difference in lengths

In [21]:
max(len(ds[i]) for i in range(100)), min(len(ds[i]) for i in range(100))

(72967, 23)

In [22]:
tokenizer.chr_to_ids[" "], tokenizer.PAD

(6, '<pad>')

In [23]:
from utils.dataset import Py150kDataset
from torch.utils.data import DataLoader, random_split

def collate_fn(batch:list[torch.tensor], max_len:int=1000):
    batch = [x[:max_len] for x in batch]
    return torch.nn.utils.rnn.pad_sequence(
        batch,
        batch_first=True,
        padding_value=tokenizer.PAD_ID
    )



train_ds = Py150kDataset("train", "py150k_new")
small_ds, _ = random_split(train_ds, [1, len(train_ds) - 1])

train_dl = DataLoader(small_ds, batch_size=32, collate_fn=collate_fn)

For example purposes this will be a many-to-one encoder-decoder architecture. Our transformer atleast will probably be decoder only.

In [24]:
class PyRNN(nn.Module):
    def __init__(self, vocab_size, hidden_size):
        super().__init__()
        self.vocab_size, self.hidden_size = vocab_size, hidden_size
        
        self.embed = nn.Embedding(vocab_size, hidden_size)
        self.rnn = nn.RNN(hidden_size, hidden_size, batch_first=True)
        self.linear = nn.Linear(hidden_size, vocab_size)
        
    def forward(self, x):
        x = self.embed(x)
        x, _ = self.rnn(x)
        x = self.linear(x)
        return x

    def train_step(self, x, y, teacher_forcing=0.5):
        B, T = x.shape
                
        xt = x[:, [0]]
        ht = torch.zeros(1, B, self.hidden_size, device=x.device)
        
        o = []     
        for i in range(T):
            xt = self.embed(xt)
            xt, ht = self.rnn(xt, ht)
            xt = self.linear(xt.squeeze(1))
            ot = F.softmax(xt, dim=-1)
            o.append(ot)
            
            if torch.rand(1) < teacher_forcing:
                xt = y[:, [i]] # put the correct token in the next step
            else:
                # sample from the distribution
                xt = torch.argmax(ot, dim=-1, keepdims=True) # put the predicted token in the next step
                
        return torch.stack(o, dim=1)
            
        
model = PyRNN(len(tokenizer), 128)
model(next(iter(train_dl))).shape

torch.Size([1, 217, 244])

https://wandb.ai/bjarnih/PyGPT

In [25]:
from tqdm import tqdm
import wandb

EPOCHS = 1000
LR = 1e-3
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

wandb.init(
    # Set the project where this run will be logged
    project="PyGPT",
    # Track hyperparameters and run metadata
    config={
        "learning_rate": LR,
        "epochs": EPOCHS,
        "architecture": "many-to-one RNN",
        "dataset": "small subset of PY150k",
    },
)

model = PyRNN(len(tokenizer), 128).to(DEVICE)
optim = torch.optim.Adam(model.parameters(), lr=LR)

model.train()
for i in range(EPOCHS):
    dl_tqdm = tqdm(train_dl)
    for batch in dl_tqdm:
        batch = batch.to(DEVICE)
        x = batch[..., :-1]
        y = batch[..., 1:]
        
        y_hat = model.train_step(x, y, teacher_forcing=0.5)
        loss = F.cross_entropy(y_hat.reshape(-1, len(tokenizer)), y.reshape(-1))

        optim.zero_grad()
        loss.backward()
        optim.step()

        wandb.log({"train_loss": loss.detach().numpy()})
        dl_tqdm.set_postfix({"loss": loss.detach().numpy()})

0,1
train_loss,█▇▅▃▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
train_loss,5.12538


100%|██████████| 1/1 [00:00<00:00, 21.72it/s, loss=5.497308]
100%|██████████| 1/1 [00:00<00:00, 23.36it/s, loss=5.4969783]
100%|██████████| 1/1 [00:00<00:00, 23.87it/s, loss=5.4965234]
100%|██████████| 1/1 [00:00<00:00, 23.11it/s, loss=5.496264]
100%|██████████| 1/1 [00:00<00:00, 23.51it/s, loss=5.4961433]
100%|██████████| 1/1 [00:00<00:00, 22.69it/s, loss=5.495267]
100%|██████████| 1/1 [00:00<00:00, 22.73it/s, loss=5.494068]
100%|██████████| 1/1 [00:00<00:00, 23.23it/s, loss=5.493568]
100%|██████████| 1/1 [00:00<00:00, 23.39it/s, loss=5.4930725]
100%|██████████| 1/1 [00:00<00:00, 22.66it/s, loss=5.491974]
100%|██████████| 1/1 [00:00<00:00, 23.27it/s, loss=5.4901247]
100%|██████████| 1/1 [00:00<00:00, 22.49it/s, loss=5.4856973]
100%|██████████| 1/1 [00:00<00:00, 21.27it/s, loss=5.483435]
100%|██████████| 1/1 [00:00<00:00, 22.81it/s, loss=5.4767175]
100%|██████████| 1/1 [00:00<00:00, 22.48it/s, loss=5.4754615]
100%|██████████| 1/1 [00:00<00:00, 21.55it/s, loss=5.470697]
100%|██████████|

KeyboardInterrupt: 

In [26]:
def generate(model, tokenizer, device="cpu", max_len=100):
    model.eval()
    x = torch.tensor([9], device=device).unsqueeze(0)
    ht = torch.zeros(1, 1, model.hidden_size, device=device)
    
    o = []
    for i in range(max_len):
        xt = model.embed(x)
        xt, ht = model.rnn(xt, ht)
        xt = model.linear(xt.squeeze(1))
        ot = F.softmax(xt, dim=-1)
        o.append(ot)
        x = torch.argmax(ot, dim=-1, keepdims=True)
    
    model.train()
    return torch.stack(o, dim=1)

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
text = generate(PyRNN(len(tokenizer), 128).to(DEVICE), tokenizer, DEVICE, 100)

In [27]:
text_tokens = text[0].argmax(dim=-1).cpu().numpy()
tokenizer.detokenize(text_tokens)

'\x93E"\x88_""argro[name^deYd \x96se\\ke*ème""ateS\xa0\x80ser[name^deYd \x96se\\ke*ème""ateS\xa0\x80ser[name^deYd \x96se\\ke*ème""ateS\xa0\x80ser[name^deYd \x96se\\ke*ème""ateS\xa0\x80ser[name^deYd \x96se\\ke*ème""ateS'