In [1]:
import torch
import pytorch_lightning as L
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping

In [2]:
import sys
# Drop any cached copies of the local modules so the imports below re-execute
# data_loading.py / transformer_predictor.py and pick up edits without a
# kernel restart. pop(..., None) is a no-op for a module that is not loaded
# (e.g. on a fresh kernel), so each module is handled independently. Unlike
# the previous bare `except:`, this cannot swallow KeyboardInterrupt.
for _module_name in ('data_loading', 'transformer_predictor'):
    sys.modules.pop(_module_name, None)
del _module_name

from data_loading import load_data, get_loaders, BracketDataset
from transformer_predictor import TransformerPredictor

In [3]:
# Location of the CSV that load_data builds the dataset from.
TRAIN_CSV_PATH = '../Data/train-CoT-Big.csv'
dataset = load_data(TRAIN_CSV_PATH)

Generating sequences:   0%|          | 0/100000 [00:00<?, ?it/s]

Padding sequences:   0%|          | 0/100000 [00:00<?, ?it/s]

In [4]:
# A balanced bracket string; in the recorded run its generated output ends in 'Y'.
balanced_example = '()()(())'
dataset.generate_output(balanced_example)

'()()(()):0&)()(()):1&()(()):0&)(()):1&(()):0&()):1&)):2&):1&:0&Y'

In [5]:
# An unbalanced bracket string: the third character closes a bracket that was
# never opened. In the recorded run its generated output ends in 'N'.
unbalanced_example = '())()()'
dataset.generate_output(unbalanced_example)

'())()():0&))()():1&)()():0&()():#&N'

In [6]:
# Shapes of the two tensors stored for the first sample
# (presumably input and target — TODO confirm against BracketDataset).
shape_0 = dataset[0][0].shape
shape_1 = dataset[0][1].shape
shape_0, shape_1

(torch.Size([2330]), torch.Size([2330]))

In [19]:
BATCH_SIZE = 8

# get_loaders returns three loaders followed by the three underlying splits
# when return_data=True.
(
    train_loader, val_loader, test_loader,
    train_data, val_data, test_data,
) = get_loaders(dataset, batch_size=BATCH_SIZE, return_data=True)

In [20]:
split_sizes = (len(train_data), len(val_data), len(test_data))
print(*split_sizes, len(dataset))

# The three splits should partition the full dataset exactly
# (70000 + 10000 + 20000 = 100000 in the recorded run).
assert sum(split_sizes) == len(dataset), 'train/val/test splits do not cover the dataset'

70000 10000 20000 100000


In [21]:
def get_name_from_config(config):
    """
    Build the string under which models and datasets for ``config`` are saved.

    Only the architecture keys ``model_dim``, ``num_heads`` and ``num_layers``
    are used; other entries (e.g. ``lr``) are ignored. A missing key raises
    ``KeyError``.

    Example: ``{'model_dim': 128, 'num_heads': 8, 'num_layers': 3}``
    -> ``'d_model=128-nhead=8-nlayers=3'``.
    """
    labelled_values = (
        ('d_model', config['model_dim']),
        ('nhead', config['num_heads']),
        ('nlayers', config['num_layers']),
    )
    return '-'.join(f'{label}={value}' for label, value in labelled_values)

In [22]:
# Hyperparameters for the run. The architecture keys (model_dim, num_heads,
# num_layers) also determine the checkpoint name via get_name_from_config.
config = dict(
    model_dim=128,
    num_heads=8,
    num_layers=3,
    lr=1e-3,
    dropout=0.2,
)

In [23]:
# Seed Python / NumPy / torch RNGs so weight initialisation, dropout and
# loader shuffling are reproducible under Restart & Run All.
SEED = 42
L.seed_everything(SEED)

# NOTE(review): 18 is presumably the token vocabulary size produced by
# data_loading — TODO confirm against the tokenizer there.
INPUT_DIM = 18
MAX_EPOCHS = 1

model = TransformerPredictor(
    input_dim=INPUT_DIM,
    model_dim=config['model_dim'],
    num_heads=config['num_heads'],
    num_layers=config['num_layers'],
    lr=config['lr'],
    dropout=config['dropout'],
)
name = get_name_from_config(config)

# NOTE: with MAX_EPOCHS=1 this callback can never stop training early
# (patience=3 needs several validation rounds without improvement); it only
# matters if MAX_EPOCHS is raised.
early_stopping = EarlyStopping(monitor='val_loss', patience=3)
# Writes models/<name>.ckpt (Lightning appends version suffixes for extra top-k files).
model_checkpoint = ModelCheckpoint(monitor='val_loss', save_top_k=5, dirpath='models/', filename=name)
trainer = L.Trainer(max_epochs=MAX_EPOCHS, devices=1, callbacks=[early_stopping, model_checkpoint])

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [24]:
# Train on the training split, validating on the validation split each epoch.
trainer.fit(
    model,
    train_dataloaders=train_loader,
    val_dataloaders=val_loader,
)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name      | Type             | Params
-----------------------------------------------
0 | model     | TransformerLM    | 303 K 
1 | criterion | CrossEntropyLoss | 0     
-----------------------------------------------
303 K     Trainable params
0         Non-trainable params
303 K     Total params
1.214     Total estimated model params size (MB)


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=1` reached.


In [25]:
# Evaluate the in-memory (last-epoch) weights on the held-out test split.
trainer.test(model, dataloaders=test_loader)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Testing: |          | 0/? [00:00<?, ?it/s]

/home2/adyansh/dinner_pool/rsai/lib/python3.10/site-packages/pytorch_lightning/trainer/call.py:54: Detected KeyboardInterrupt, attempting graceful shutdown...


In [None]:
# NOTE(review): this 'retrain-' checkpoint is not the file written by
# model_checkpoint earlier in this notebook (that one is models/<name>.ckpt
# without the prefix) — confirm which model is meant to be inspected.
CHECKPOINT_PATH = 'models/retrain-d_model=128-nhead=8-nlayers=3.ckpt'
model = TransformerPredictor.load_from_checkpoint(CHECKPOINT_PATH)

In [None]:
# Prompt and the character vocabulary used to encode it / decode the completion.
# NOTE(review): this vocabulary has only 6 symbols, while the model built
# above uses input_dim=18 and dataset.generate_output emits ':', digits and
# '#'; confirm this mapping matches data_loading before trusting the decode.
seq = '()&'
ctoi = {c: i for i, c in enumerate('()&YNP')}
itoc = {i: c for c, i in ctoi.items()}

# load_from_checkpoint does not switch the module to eval mode (Lightning's
# docs call model.eval() explicitly before inference). Without this, dropout
# stays active while generating. no_grad skips autograd bookkeeping.
model.eval()
with torch.no_grad():
    src = torch.tensor([ctoi[c] for c in seq], dtype=torch.long, device=model.device).unsqueeze(0)
    out, eos = model.model.complete_sequence(src)

out = out.cpu().tolist()

# Ids outside the 6-symbol vocabulary are shown as '?' instead of raising
# KeyError partway through the printed string.
print(''.join(itoc.get(i, '?') for i in out[0]))

()&((((()))()))(((()(((()()))))))(((())(()(((()((()()))(((())(((&(((((()))()))(((()(((()()))))))(((())(()(((()((()()))(((())(((&((((())()))(((()(((()()))))))(((())(()(((()((()()))(((())(((&(((()()))(((()(((()()))))))(((())(()(((()((()()))(((())(((&(((()))(((()(((()()))))))(((())(()(((()((()()))(((())(((&((())(((()(((()()))))))(((())(()(((()((()()))(((())(((&(()(((()(((()()))))))(((())(()(((()((()()))(((())(((&((((()(((()()))))))(((())(()(((()((()()))(((())(((&(((((((()()))))))(((())(()(((()((()()))(((())(((&(((((((()))))))(((())(()(((()((()()))(((())(((&((((((())))))(((())(()(((()((()()))(((())(((&(((((()))))(((())(()(((()((()()))(((())(((&((((())))(((())(()(((()((()()))(((())(((&(((()))(((())(()(((()((()()))(((())(((&((())(((())(()(((()((()()))(((())(((&(()(((())(()(((()((()()))(((())(((&((((())(()(((()((()()))(((())(((&(((()(()(((()((()()))(((())(((&((((()(((()((()()))(((())(((&(((((((()((()()))(((())(((&(((((((((()()))(((())(((&(((((((((()))(((())(((&((((((((())(((())(((&(((((((()((