In [1]:
# Automatically reload edited modules (e.g. the local `models` / `utils` packages)
# before each cell executes, so code changes are picked up without a kernel restart.
%load_ext autoreload
%autoreload 2

In [2]:
import torch
from models import load_config, model_from_config

config = load_config("transformer_small")
model = model_from_config(config)

# Forward-pass smoke test: the same all-ones token sequence is passed as both
# arguments (the output includes a loss, so the second is presumably the target).
x = torch.ones(1, 512, dtype=torch.long)
model(x, x)

number of parameters: 3.31M


(tensor([[[-0.0364,  0.9826,  0.2212,  ...,  0.3989,  0.3396,  0.6336],
          [-0.4094,  0.5396,  0.7782,  ...,  0.2645, -0.0385,  0.7553],
          [-0.0253,  0.6822,  0.0424,  ...,  0.2293,  0.4163, -0.0330],
          ...,
          [-0.2189,  0.5299,  0.2904,  ..., -0.1705,  0.1550,  0.0511],
          [-0.4040,  0.5109, -0.0689,  ..., -0.1278,  0.4097, -0.0200],
          [-0.3201,  0.6264, -0.0254,  ..., -0.0381,  0.1772, -0.1399]]],
        grad_fn=<UnsafeViewBackward0>),
 tensor(5.6912, grad_fn=<NllLossBackward0>))

In [3]:
# Move the model parameters to the current CUDA device.
model = model.to("cuda")

In [4]:
# Generation smoke test: 32 prompts of 8 zero tokens each, generation length argument 100.
model.generate(torch.zeros(32, 8, dtype=torch.long, device="cuda"), 100)

tensor([[  0,   0,   0,  ..., 169, 391, 175],
        [  0,   0,   0,  ..., 515,  95, 259],
        [  0,   0,   0,  ...,  81,  20, 515],
        ...,
        [  0,   0,   0,  ..., 337, 379, 300],
        [  0,   0,   0,  ..., 188,  58, 515],
        [  0,   0,   0,  ..., 426,  19, 406]], device='cuda:0')

In [8]:
# Smoke test on random token ids: condition on the first quarter of each sequence
# and keep the generated second quarter for comparison. Seeded so the random
# batch (and the sampling that follows) is reproducible on re-run.
torch.manual_seed(0)
VOCAB_SIZE = 570  # exclusive upper bound for random ids; TODO confirm it matches the model vocab
batch = torch.randint(0, VOCAB_SIZE, (32, 512)).cuda()
B, L = batch.shape
context = int(0.25 * L)
pred_length = 2 * context
x = batch[:, :context]
y = batch[:, context:pred_length]

y_hat = model.generate(x, L)[:, context:pred_length]

batch.shape, x.shape, y.shape

(torch.Size([32, 512]), torch.Size([32, 128]), torch.Size([32, 128]))

In [14]:
from utils.metrics import bleu_score, syntax_error_score
# BLEU-4 of the generated quarter against the random reference (0.0 is expected for random ids).
references, hypotheses = y.tolist(), y_hat.tolist()
bleu_score(references, hypotheses, n_gram=4)

0.0

In [16]:
from utils.tokenizer import BOS_ID, BPETokenizer

tokenizer = BPETokenizer.load("py150k_large")

# Prompt each of the B sequences with a single BOS token -> shape (B, 1).
# Note: `torch.tensor([[BOS_ID] * B])` would be shape (1, B), i.e. ONE sequence
# made of B BOS tokens, yielding a single program instead of B.
bos_prompt = torch.full((B, 1), BOS_ID, dtype=torch.long, device="cuda")
y_hat = model.generate(bos_prompt, L)

# Detokenize every generated sequence and score them with syntax_error_score.
programs = [tokenizer.detokenize(gen_seq) for gen_seq in y_hat.tolist()]
syntax_error_score(programs)

0.0

In [None]:
from models import model_from_checkpoint
from utils.tokenizer import BPETokenizer, BOS_ID, EOS_ID, PAD_ID

# Load the trained checkpoint on CPU together with the matching BPE tokenizer.
CHECKPOINT_PATH = "medium-lstm-run/epoch_8.pt"
model = model_from_checkpoint(CHECKPOINT_PATH, device="cpu")
tokenizer = BPETokenizer.load("py150k_large")

In [None]:
from utils.dataset import MemmapDataset

# Peek at the first training example, decoded back to source text.
ds = MemmapDataset("train", "py150k_large")
first_example = ds[0].tolist()
ds.tokenizer.detokenize(first_example)

In [None]:
# Draw 4 samples with no prompt (nucleus threshold 0.5; length argument 100),
# then print each one with ANSI syntax colouring, separated by a blank line.
samples = model.generate(4, 100, nucleus_threshold=0.5)
programs = list(map(tokenizer.detokenize, samples))

for highlighted in map(tokenizer.color_text_ansi, programs[:10]):
    print(highlighted, end="\n\n")

In [None]:
from utils.search import beam_search

code = """class ListV"""

# Complete the prompt with beam search (max_length=100) and print the decoded result.
prompt_tokens = tokenizer.tokenize(code)
tokens = beam_search(model, max_length=100, starting_tokens=prompt_tokens)
print(tokenizer.detokenize(tokens))

In [None]:
# Show the raw token sequences of the samples drawn earlier.
for token_ids in samples:
    print(token_ids)

In [None]:
import torch

code = """class """

# Tokenize the prompt and add a leading batch axis (a batch of one sequence).
prompt_ids = tokenizer.tokenize(code)
tokens = torch.tensor(prompt_ids)[None, :]

pred = model.generate(1, max_len=100, starting_tokens=tokens, nucleus_threshold=0.5)

completion = tokenizer.detokenize(pred[0])
print(tokenizer.color_text_ansi(completion))

In [None]:
# Same completion as above, without ANSI colouring.
plain_completion = tokenizer.detokenize(pred[0])
print(plain_completion)

In [None]:
from torch.utils.data import ConcatDataset, DataLoader, random_split

def collate_fn(batch:list[torch.Tensor], max_len:int=2048):
    """Collate variable-length token sequences into one right-padded int64 tensor.

    Each item is truncated to its first ``max_len`` tokens, wrapped as
    ``[BOS_ID, *tokens, EOS_ID]`` and right-padded with ``PAD_ID`` to the longest
    wrapped item in the batch. BOS/EOS are added after truncation, so rows can
    be up to ``max_len + 2`` tokens long.

    Args:
        batch: token sequences; anything ``torch.as_tensor`` accepts works
            (tensors, NumPy / memmap slices, lists of ints).
        max_len: maximum number of content tokens kept per sequence.

    Returns:
        int64 tensor of shape ``(len(batch), T)`` where ``T`` is the longest
        wrapped sequence length.
    """
    bos = torch.tensor([BOS_ID], dtype=torch.long)
    eos = torch.tensor([EOS_ID], dtype=torch.long)
    wrapped = [
        # as_tensor(...).long() accepts NumPy slices / lists and keeps every
        # piece int64, so torch.cat never sees mixed dtypes.
        torch.cat([bos, torch.as_tensor(x[:max_len]).long(), eos])
        for x in batch
    ]
    return torch.nn.utils.rnn.pad_sequence(wrapped, batch_first=True, padding_value=PAD_ID)


train_ds = MemmapDataset("train", "py150k_large")
val_ds = MemmapDataset("eval", "py150k_large")

# Both loaders share the same batch size and collate function.
loader_kwargs = dict(batch_size=32, collate_fn=collate_fn)
train_dl = DataLoader(train_ds, **loader_kwargs)
val_dl = DataLoader(val_ds, **loader_kwargs)

In [None]:
from utils.metrics import bleu_score, syntax_error_score

# BLEU-4: condition on the first quarter of each validation sequence and score
# the generated second quarter against the reference.
batch = next(iter(val_dl))
B, L = batch.shape
# avoid going to the end of the batch (may be padded)
# NOTE(review): collate_fn pads every row to the longest one, so for short rows
# y (and part of x) may still be mostly PAD_ID — consider masking such rows.
context = L // 4
pred_length = context * 2
x, y = batch[:, :context], batch[:, context:pred_length]
full_generations = model.generate(B, max_len=L, starting_tokens=x, nucleus_threshold=0.5)
y_hat = [seq[context:pred_length] for seq in full_generations]
avg_bleu_score = bleu_score(y.tolist(), y_hat, n_gram=4)

# Syntax score on B samples generated without a prompt.
gen = model.generate(B, max_len=200, nucleus_threshold=0.5)
programs = list(map(tokenizer.detokenize, gen))
avg_syntax_error_score = syntax_error_score(programs)

avg_bleu_score, avg_syntax_error_score

In [None]:
# Display the detokenized programs produced by the most recent generation cell.
programs

In [None]:
# Reference continuation for row 1 of the evaluation batch — the same row as the
# `y_hat[1]` prediction shown in the next cell, so the two can be compared directly.
tokenizer.detokenize(y[1].tolist())

In [None]:
# Model continuation for row 1 of the evaluation batch.
predicted_ids = y_hat[1]
tokenizer.detokenize(predicted_ids)