In [1]:
import sys
sys.path.append("src")
import numpy as np
import random
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

import scfg as cfg
from utils.checkpoints import Checkpoints
from asdl.parser import parse as parse_asdl
from data.conala import ConalaDataset
from model.copy_lstm import EncoderLSTM, DecoderLSTM, EncoderDecoder
from main import calculate_errors, calculate_loss, train_epoch, evaluate

In [2]:
# Seed the Python, NumPy and PyTorch RNGs. Which generator ConalaDataset's
# shuffling draws from is not visible here, so all three are seeded to keep
# runs repeatable after a kernel restart.
random.seed(cfg.seed)
np.random.seed(cfg.seed)
torch.manual_seed(cfg.seed)

# Special tokens handed to both datasets.
special_tokens = ["[PAD]", "[UNK]", "[SOS]", "[EOS]", "[SOA]", "[EOA]"]

# load Python ASDL grammar
grammar = parse_asdl("src/asdl/Python.asdl")

# load CoNaLa intent-snippet pairs and map them to tensors
# FIXME: train_ds and dev_ds are using different word and action mappings.
# This causes dev performance to become completely nonsense.
# NOTE(review): the dev set below is built with train_ds's action and intent
# vocabularies, so this FIXME may only apply to a cache file written before
# that change -- confirm, and delete the cache to force a rebuild if unsure.
dataset_cache = cfg.model_dir / "dataset_cache.pt"
if dataset_cache.exists():
	# torch.load unpickles arbitrary Python objects: only load a cache file
	# that this notebook wrote itself.
	train_ds, dev_ds = torch.load(dataset_cache)
	print("Loaded dataset cache")
else:
	train_ds = ConalaDataset(
		"./data/conala-train.json", grammar=grammar, special_tokens=special_tokens
	)
	# Reuse the training vocabularies so both splits share one id mapping.
	dev_ds = ConalaDataset(
		"./data/conala-dev.json",
		grammar=grammar,
		special_tokens=special_tokens,
		action_vocab=train_ds.action_vocab,
		intent_vocab=train_ds.intent_vocab,
		shuffle=False,
	)
	torch.save((train_ds, dev_ds), dataset_cache)

Loaded dataset cache


In [3]:
# Sanity check of the intent vocabulary: lower-cased text -> ids -> tokens.
sample_intent = "Fastest Way to Drop Duplicated Index in a Pandas DataFrame"
ids = train_ds.convert_intent_ids(sample_intent.lower())
print(ids)
tokens = train_ds.convert_ids_intent(ids)
print(tokens)

# Same check for the action vocabulary: ids -> actions -> ids.
sample_action_ids = [6, 7, 10, 5, 0]
actions = train_ds.convert_ids_action(sample_action_ids)
print(actions)
ids = train_ds.convert_action_ids(actions)
print(ids)

[829, 293, 14, 371, 544, 42, 9, 7, 43, 24]
['fastest', 'way', 'to', 'drop', 'duplicated', 'index', 'in', 'a', 'pandas', 'dataframe']
[('Reduce',), ('ApplyConstr', 'Load'), ('ApplyConstr', 'Constant'), '[EOA]', '[PAD]']
[6, 7, 10, 5, 0]


In [5]:
# Build the network. The encoder is sized by the intent vocabulary and the
# decoder by the action vocabulary; every other constructor argument comes
# from the cfg module.
encoder = EncoderLSTM(vocab_size=train_ds.intent_vocab_size, device=cfg.device, **cfg.EncoderLSTM)
decoder = DecoderLSTM(action_size=train_ds.action_vocab_size, device=cfg.device, **cfg.DecoderLSTM)
model = EncoderDecoder(encoder, decoder)

# Adam over all parameters of the combined model.
optimizer = optim.Adam(model.parameters(), lr=cfg.learning_rate)

In [6]:
def run_epoch(model, ds, optimizer, training=True):
    """Run one pass over ``ds`` and return the loss of every batch.

    Args:
        model: Module called as ``model(inputs, labels, input_mask,
            label_mask, input_lens, label_lens)``; it must return logits
            whose last dimension indexes the classes scored by the loss.
        ds: Dataset handed to ``load``. Batches are shuffled only when
            ``training`` is true.
        optimizer: Optimizer stepped after each batch when ``training`` is
            true. It is not touched otherwise.
        training: When true, the model is put in train mode and its
            parameters are updated. When false, the model is put in eval
            mode and the pass runs without gradient tracking.

    Returns:
        list[float]: one cross-entropy loss value per batch.
    """
    # Padding id: load() pads with pad_sequence's default fill value of 0.
    pad_id = 0

    # Select train or eval mode explicitly, so layers such as dropout are
    # switched off while validating.
    model.train(training)
    # NOTE(review): padded label positions are included in the loss. Passing
    # ignore_index=pad_id would exclude them, but that changes the objective
    # and the reported numbers -- confirm before changing.
    celoss = nn.CrossEntropyLoss()
    losses = []
    # No autograd graph is needed when parameters are not being updated.
    with torch.set_grad_enabled(training):
        for batch in load(ds, training):
            inputs, labels = batch
            # Swap the first two axes of each batch tensor produced by load().
            inputs = inputs.transpose(1, 0)
            labels = labels.transpose(1, 0)
            input_mask = inputs == pad_id
            label_mask = labels == pad_id
            # Count of non-padding positions per row, kept on the CPU.
            input_lens = (~input_mask).to(torch.long).sum(dim=1).cpu()
            label_lens = (~label_mask).to(torch.long).sum(dim=1).cpu()

            if training:
                optimizer.zero_grad()
            logits = model(inputs, labels, input_mask, label_mask, input_lens, label_lens)
            # Flatten all positions so each one is scored as its own sample.
            loss = celoss(logits.reshape(-1, logits.shape[-1]), labels.reshape(-1))
            losses.append(loss.item())
            if training:
                loss.backward()
                optimizer.step()
    return losses
def load(ds, shuffle=True):
    """Wrap ``ds`` in a DataLoader that yields padded ``(inputs, labels)`` batches.

    Each dataset item is unpacked as an ``(input, label)`` pair. Within a
    batch, inputs and labels are each padded with ``pad_sequence`` (default
    settings) and moved to ``cfg.device``.

    Args:
        ds: Dataset of ``(input, label)`` tensor pairs.
        shuffle: Forwarded to the DataLoader.

    Returns:
        torch.utils.data.DataLoader with batch size ``cfg.batch_size``.
    """
    pad = torch.nn.utils.rnn.pad_sequence

    def _pad_batch(data):
        # Split the pairs into two parallel lists, then pad each list.
        sources = [source for source, _ in data]
        padded_sources = pad(sources).to(cfg.device)
        targets = [target for _, target in data]
        padded_targets = pad(targets).to(cfg.device)
        return (padded_sources, padded_targets)

    loader = torch.utils.data.DataLoader(
        ds, batch_size=cfg.batch_size, shuffle=shuffle, collate_fn=_pad_batch
    )
    return loader

In [7]:
# Ten epochs: one pass over the training set, then a pass over the dev set
# with training=False.
for epoch in range(10):
	train_losses = run_epoch(model, train_ds, optimizer)
	print(f"Training loss: {np.mean(train_losses)}")

	valid_losses = run_epoch(model, dev_ds, optimizer, training=False)
	print(f"Valid loss: {np.mean(valid_losses)}")

Training loss: 0.8714710238600979
Valid loss: 0.3985938012599945
Training loss: 0.3149361566585653
Valid loss: 0.2877434434990088
Training loss: 0.2226134584355755
Valid loss: 0.24412120406826338
Training loss: 0.1758715825922349
Valid loss: 0.2355222647388776
Training loss: 0.14425291239964863
Valid loss: 0.20250578299164773
Training loss: 0.11360694370976016
Valid loss: 0.14760115842024485
