In [1]:
import sys
# Presumably makes the in-repo `new_semantic_parsing` package (imported in the next
# cell) importable without installing it.
# NOTE(review): '..' is resolved against the kernel's current working directory, so this
# assumes the notebook is started from a directory one level below the repo root — confirm.
sys.path.append('..')

# Reload all (non-excluded) modules before executing each cell, so edits to the
# project package are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import shutil

import torch
import torch.nn as nn
import torch.nn.functional as F

import numpy as np

import transformers
import tokenizers

from new_semantic_parsing import EncoderDecoderWPointerModel
from new_semantic_parsing import TopSchemaTokenizer
from new_semantic_parsing.utils import DataCollator, InputDataClass

In [3]:
PRETRAINED_TOKENIZER = 'distilbert-base-uncased'
tokenizer = transformers.AutoTokenizer.from_pretrained(PRETRAINED_TOKENIZER)

# Target-side vocabulary: the bracket markup, the 'IN:'/'SL:' prefixes, and the
# intent/slot labels (two of which, DATE_TIME_DEPARTURE and GET_ESTIMATED_ARRIVAL,
# do not occur in the toy examples below).
# NOTE(review): str hashes are randomised per process, so this set's iteration order can
# change between kernel restarts. The schema ids in the encoding output below follow
# neither insertion nor sorted order, which suggests TopSchemaTokenizer enumerates the
# set as given — if so, schema ids are not reproducible across sessions. Consider passing
# a deterministic order (TODO confirm which input types TopSchemaTokenizer accepts).
vocab = {
    '[', ']', 'IN:', 'SL:',
    'GET_DIRECTIONS', 'DESTINATION',
    'DATE_TIME_DEPARTURE', 'GET_ESTIMATED_ARRIVAL',
}
schema_tokenizer = TopSchemaTokenizer(vocab, tokenizer)

In [4]:
# Seed torch *before* building the model. Weight initialisation presumably draws from
# torch's global RNG, while TrainingArguments(seed=...) further down is only applied by
# the Trainer — i.e. after these weights already exist. Without this line the
# initial weights (and therefore the whole run) differ on every Restart & Run All.
torch.manual_seed(42)

model = EncoderDecoderWPointerModel.from_parameters(
    layers=3, hidden=128, heads=4,
    # source side uses the pretrained wordpiece vocabulary, target side the schema vocabulary
    src_vocab_size=tokenizer.vocab_size,
    tgt_vocab_size=schema_tokenizer.vocab_size,
)

In [5]:
# Two toy utterances and their bracketed parses. The slot values ('Lowell',
# 'Mountain View') appear verbatim in the source text, presumably so the pointer
# decoder can copy them (TODO confirm).
source_texts = [
    'Directions to Lowell',
    'Get directions to Mountain View',
]
schema_texts = [
    '[IN:GET_DIRECTIONS Directions to [SL:DESTINATION Lowell]]',
    '[IN:GET_DIRECTIONS Get directions to [SL:DESTINATION Mountain View]]'
]

source_encoding = tokenizer.batch_encode_plus(source_texts, pad_to_max_length=True)
source_ids = source_encoding['input_ids']

# The schema tokenizer also receives the source ids — presumably to map copied
# words onto source positions (TODO confirm in TopSchemaTokenizer).
schema_batch = schema_tokenizer.batch_encode_plus(
    schema_texts,
    source_ids,
    pad_to_max_length=True,
    return_tensors='pt',
)

print(source_ids)
print(schema_batch)

[[101, 7826, 2000, 15521, 102, 0, 0], [101, 2131, 7826, 2000, 3137, 3193, 102]]
{'input_ids': tensor([[ 7,  3,  6, 10, 11,  7,  1,  5, 12,  4,  4,  0,  0],
        [ 7,  3,  6, 10, 11, 12,  7,  1,  5, 13, 14,  4,  4]]), 'attention_mask': tensor([[1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 0.],
        [1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]])}


In [6]:
source_ids = torch.as_tensor(source_ids, dtype=torch.long)

# 1.0 for ordinary wordpieces, 0.0 for padding and the [CLS]/[SEP] markers.
# NOTE(review): MockDataset passes this tensor as the encoder `attention_mask`, so
# [CLS]/[SEP] are hidden from attention as well — confirm that is intended rather
# than a plain padding mask (it may really be meant as a copy/pointer mask).
special_token_ids = (tokenizer.pad_token_id, tokenizer.cls_token_id, tokenizer.sep_token_id)
is_regular_token = torch.ones_like(source_ids, dtype=torch.bool)
for special_id in special_token_ids:
    is_regular_token &= source_ids != special_id
source_ids_mask = is_regular_token.float()
source_ids_mask

tensor([[0., 1., 1., 1., 0., 0., 0.],
        [0., 1., 1., 1., 1., 1., 0.]])

In [7]:
class MockDataset(torch.utils.data.Dataset):
    """In-memory dataset over the toy examples encoded in the cells above.

    Each item pairs the encoded source sentence with its encoded schema. The schema ids
    are used both as decoder inputs and as labels; whether the model shifts them
    internally is not visible here (TODO confirm in EncoderDecoderWPointerModel).
    """

    def __len__(self):
        # Derive the size from the data instead of hard-coding it, so adding or removing
        # examples in source_texts/schema_texts cannot silently desync the dataset.
        return len(source_ids)

    def __getitem__(self, i):
        return InputDataClass(**{
            'input_ids': source_ids[i],
            # NOTE(review): source_ids_mask also zeroes [CLS]/[SEP], not just padding —
            # confirm the model expects that as its encoder attention mask.
            'attention_mask': source_ids_mask[i],
            'decoder_input_ids': schema_batch['input_ids'][i],
            'decoder_attention_mask': schema_batch['attention_mask'][i],
            'labels': schema_batch['input_ids'][i],
        })

def compute_metrics(eval_prediction: 'transformers.EvalPrediction', pad_token_id=None):
    """Token-level accuracy of greedy (argmax) predictions against the labels.

    Args:
        eval_prediction: object exposing ``predictions`` (scores whose last axis is the
            output vocabulary) and ``label_ids`` (presumably the same shape minus that
            last axis).
        pad_token_id: optional label id to exclude from scoring, e.g. the schema padding
            id. Positions labelled -100 are always excluded: -100 is the conventional
            ignore index (PyTorch CrossEntropyLoss default, also used by transformers
            when padding labels).

    Returns:
        ``{'accuracy': float}``. The value is NaN when no position is left to score.
    """
    predicted_ids = np.argmax(eval_prediction.predictions, axis=-1).reshape(-1)
    label_ids = np.asarray(eval_prediction.label_ids).reshape(-1)

    # Score only real target tokens. Padded positions would otherwise inflate the score
    # (pad predicted as pad) or deflate it (-100 can never equal an argmax).
    scored = label_ids != -100
    if pad_token_id is not None:
        scored &= label_ids != pad_token_id

    if not scored.any():
        return {'accuracy': float('nan')}

    accuracy = float(np.mean(predicted_ids[scored] == label_ids[scored]))
    return {
        'accuracy': accuracy,
    }

In [12]:
# Memorisation smoke test: many epochs on the mock examples, evaluated on the same data.
# `seed` is consumed by the Trainer; it does not retroactively affect anything that was
# randomised before this cell ran.
train_args = transformers.TrainingArguments(
    output_dir='output_debug',
    do_train=True,
    num_train_epochs=100,
    seed=42,
)

trainer = transformers.Trainer(
    model=model,
    args=train_args,
    train_dataset=MockDataset(),
    data_collator=DataCollator(),
    eval_dataset=MockDataset(),
    compute_metrics=compute_metrics,
)

# A trick to reduce the amount of logging: make the trainer believe it is not the
# local master process.
# NOTE(review): this depends on Trainer consulting `is_local_master()`. It may also gate
# other master-only work, and newer transformers releases renamed the method (which
# would make this a silent no-op) — confirm against the installed version.
trainer.is_local_master = lambda: False

In [13]:
# Fit the model; the returned TrainOutput (global step, mean training loss) is the cell result.
train_output = trainer.train()
train_output

TrainOutput(global_step=100, training_loss=0.18632038921117783)

In [14]:
# Evaluate on the same examples used for training: a memorisation check,
# not a generalisation estimate.
eval_metrics = trainer.evaluate()
eval_metrics

HBox(children=(FloatProgress(value=0.0, description='Evaluation', max=1.0, style=ProgressStyle(description_wid…


{"eval_loss": 0.07112160325050354, "eval_accuracy": 1.0, "epoch": 100.0, "step": 100}


{'eval_loss': 0.07112160325050354, 'eval_accuracy': 1.0, 'epoch': 100.0}

In [17]:
# Remove the artefacts written by this debug run. Unlike `!rm -r`, this works on any
# platform and is a no-op (instead of printing an error) when a directory is already
# gone, so the cell can be re-run safely. The output directory is taken from the
# training arguments, so it cannot drift from the path the Trainer actually wrote to.
shutil.rmtree(train_args.output_dir, ignore_errors=True)
# 'runs' is presumably the root of TrainingArguments' default TensorBoard logging_dir.
shutil.rmtree('runs', ignore_errors=True)