In [1]:
import csv
import os
import random

import numpy as np
from tqdm import tqdm

import torch
import torch.nn as nn
from torch import optim
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler
from torch import Tensor
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset

import torchtext
from torchtext.vocab import build_vocab_from_iterator

from TorchTransformer import *
from evaluation import *

In [2]:
# GLOBAL VARIABLES
# Special-token indices. They must match the order of `specials` passed to
# build_vocab_from_iterator in create_vocab (special_first=True):
# "<PAD>" -> 0, "<SOS>" -> 1, "<EOS>" -> 2.
PAD_IDX = 0
SOS_TOKEN = 1
EOS_TOKEN = 2
# Pairs with MAX_LENGTH or more tokens on either side are filtered out, and
# encoded rows hold MAX_LENGTH + 2 positions to make room for <SOS>/<EOS>.
MAX_LENGTH = 50

# Seed every RNG library the notebook imports (e.g. parameter initialisation
# and dropout draw from torch) so Restart & Run All reproduces the run, up to
# any nondeterministic GPU kernels.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

cuda


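Select which dataset to run experiments on by setting `DATASET` to the dataset's file-name suffix (e.g. `"20k"` loads `data_datasets/dket_train_20k.csv` and `data_datasets/dket_validation_20k.csv`).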

In [None]:
# Suffix of the dataset files to load:
# data_datasets/dket_train_<DATASET>.csv and data_datasets/dket_validation_<DATASET>.csv.
DATASET = "20k"

In [3]:
def yield_tokens(lines):
    """Lazily produce the whitespace-split token list of each line in `lines`."""
    yield from (line.split() for line in lines)

def create_vocab(lang):
    """Build a torchtext vocabulary from an iterable of whitespace-tokenised sentences.

    The special tokens are placed first, so "<PAD>", "<SOS>" and "<EOS>" get
    indices 0, 1 and 2 (matching PAD_IDX / SOS_TOKEN / EOS_TOKEN).
    """
    specials = ["<PAD>", "<SOS>", "<EOS>"]
    token_stream = yield_tokens(lang)
    vocab = build_vocab_from_iterator(token_stream, specials=specials, special_first=True)
    # Out-of-vocabulary lookups return -1.
    # NOTE(review): -1 is not a valid embedding index. This is safe here only because
    # the vocabularies are built from every sentence later encoded; consider an
    # <UNK> special if unseen text is ever fed in (it would shift the vocab size).
    vocab.set_default_index(-1)
    return vocab

def read_data(dataset="all", data_dir="data_datasets"):
    """Read a tab-separated file of sentence pairs.

    Args:
        dataset: file stem; the file read is ``<data_dir>/<dataset>.csv``.
        data_dir: directory containing the dataset files. Defaults to the
            ``data_datasets`` folder the notebook has always used.

    Returns:
        ``(lang1, lang2, lines)``: the first-column strings, the second-column
        strings, and the raw rows as lists of strings.
    """
    print("Reading lines...")

    path = f"{data_dir}/{dataset}.csv"
    # The csv module requires files opened with newline=""; an explicit encoding
    # avoids depending on the platform's default locale.
    with open(path, newline="", encoding="utf-8") as csvfile:
        lines = list(csv.reader(csvfile, delimiter="\t"))

    # Show one example row as a sanity check (skipped for an empty file instead
    # of raising IndexError).
    if lines:
        print(lines[0])

    lang1 = [row[0] for row in lines]
    lang2 = [row[1] for row in lines]

    return lang1, lang2, lines

def add_sequence_tokens(sentence):
    """Return `sentence` re-joined with single spaces and wrapped in <SOS> ... <EOS>."""
    tokens = sentence.split()
    wrapped = ["<SOS>", *tokens, "<EOS>"]
    # Exactly two markers must have been added.
    assert len(wrapped) == len(tokens) + 2
    return " ".join(wrapped)

def add_sequence_tokens_dataset(data):
    """Apply add_sequence_tokens to every sentence in `data`, returning a new list."""
    return [add_sequence_tokens(sentence) for sentence in data]

def tensor_from_sentence(sentence, vocab):
    """Encode the whitespace tokens of `sentence` as a 1-D int64 tensor on `device`."""
    token_ids = [vocab[token] for token in sentence.split()]
    return torch.tensor(token_ids, device=device, dtype=torch.long)

def filter_pairs(pairs, MAX_LENGTH):
    """Keep only pairs whose two sides both have fewer than MAX_LENGTH tokens.

    Args:
        pairs: iterable of two-element ``(source, target)`` string sequences.
        MAX_LENGTH: exclusive upper bound on whitespace-token count per side.

    Returns:
        ``(lang1, lang2, filtered_pairs)``: the kept source strings, the kept
        target strings, and the kept pairs as ``[source, target]`` lists.
    """
    filtered_pairs = [
        [s1, s2]
        for s1, s2 in pairs
        if len(s1.split()) < MAX_LENGTH and len(s2.split()) < MAX_LENGTH
    ]

    lang1 = [s1 for s1, _ in filtered_pairs]
    lang2 = [s2 for _, s2 in filtered_pairs]
    # Show one surviving example as a sanity check; skip it when nothing
    # survived instead of raising IndexError.
    if filtered_pairs:
        print(filtered_pairs[0])
    return lang1, lang2, filtered_pairs

def pad_dataset(data, target_length):
    """Right-pad each sentence with "<PAD>" tokens up to `target_length` tokens.

    Sentences already at or beyond `target_length` are only re-joined with
    single spaces; nothing is truncated.
    """
    def _pad(tokens):
        # A negative shortfall multiplies to an empty list, i.e. no padding.
        return tokens + ["<PAD>"] * (target_length - len(tokens))

    return [" ".join(_pad(sentence.split())) for sentence in data]

def create_dataset(src_lang, src_vocab, trg_lang, trg_vocab, MAX_LENGTH):
    """Encode parallel sentences as a TensorDataset of padded index rows.

    Both sides are padded to MAX_LENGTH tokens with "<PAD>", encoded with their
    vocabulary, and stacked into (num_sentences, MAX_LENGTH) tensors on `device`.
    Sentences longer than MAX_LENGTH are not truncated, so they make torch.stack fail.
    """
    padded_src = pad_dataset(src_lang, MAX_LENGTH)
    padded_trg = pad_dataset(trg_lang, MAX_LENGTH)

    src_rows, trg_rows = [], []
    for src_sentence, trg_sentence in zip(padded_src, padded_trg):
        src_rows.append(tensor_from_sentence(src_sentence, src_vocab).to(device))
        trg_rows.append(tensor_from_sentence(trg_sentence, trg_vocab).to(device))

    return TensorDataset(torch.stack(src_rows), torch.stack(trg_rows))

def create_dataset_new(src_lang, src_vocab, trg_lang, trg_vocab, MAX_LENGTH):
    """Like create_dataset, but first wraps every sentence in <SOS> ... <EOS>.

    Rows are padded to MAX_LENGTH + 2 so the two added markers fit for any
    sentence of at most MAX_LENGTH tokens. Longer sentences are not truncated
    and would make the final torch.stack fail.

    Returns:
        TensorDataset of (source, target) index tensors, each of shape
        (num_sentences, MAX_LENGTH + 2).
    """
    # Delegate padding/encoding/stacking to create_dataset so the two builders
    # cannot drift apart. The only differences are the markers and the two
    # extra positions they occupy.
    return create_dataset(
        add_sequence_tokens_dataset(src_lang), src_vocab,
        add_sequence_tokens_dataset(trg_lang), trg_vocab,
        MAX_LENGTH + 2,
    )

In [4]:
# Load the raw (text, logic) pairs for the selected dataset.
_, _, train_pairs = read_data("dket_train_" + DATASET)
_, _, val_pairs = read_data("dket_validation_" + DATASET)

# Drop pairs where either side has MAX_LENGTH or more whitespace tokens.
train_text, train_logic, _ = filter_pairs(train_pairs, MAX_LENGTH)
val_text, val_logic, _ = filter_pairs(val_pairs, MAX_LENGTH)

# NOTE(review): the vocabularies include validation text, so no validation token
# is out-of-vocabulary (create_vocab maps unknown tokens to -1, which would break
# the embedding lookup). Keep this in mind when interpreting validation scores.
text_vocab = create_vocab(train_text + val_text)
logic_vocab = create_vocab(train_logic + val_logic)

print(len(train_text), len(train_logic))
print(len(val_text), len(val_logic))
# Previously `print(len(text_vocab), print(len(logic_vocab)))`, which printed the
# logic vocab size on its own line and then "<text size> None".
print(len(text_vocab), len(logic_vocab))


Reading lines...
['kernel summary of trunk forgive also principal of string or of fever .', 'kernel summary of trunk := E forgive . ( principal of string U principal of fever )']
Reading lines...
['every rural guilty kernel forgive or plug no summary of stateful trunk or of principal .', 'rural guilty kernel := ! E ( forgive ^ plug ) . ( summary of stateful trunk U summary of principal )']
Reading lines...
["personx returns to personx's work xintent to keep their job", 'person (x) & returns to (x,a) & work (a) -> to keep their job (x)']
["personx returns to personx's work xintent to keep their job", 'person (x) & returns to (x,a) & work (a) -> to keep their job (x)']
['kernel summary of trunk forgive also principal of string or of fever .', 'kernel summary of trunk := E forgive . ( principal of string U principal of fever )']
['every rural guilty kernel forgive or plug no summary of stateful trunk or of principal .', 'rural guilty kernel := ! E ( forgive ^ plug ) . ( summary of statefu

In [5]:
# Encode every pair as fixed-length index rows: sentences are wrapped in
# <SOS>/<EOS> and padded to MAX_LENGTH + 2 by create_dataset_new.
train_dataset = create_dataset_new(train_text, text_vocab, train_logic, logic_vocab, MAX_LENGTH)
val_dataset = create_dataset_new(val_text, text_vocab, val_logic, logic_vocab, MAX_LENGTH)

BATCH_SIZE = 128
# Reshuffle the training set each epoch so batches are not always the same
# examples in file order. Validation order does not affect its metrics.
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE)


In [6]:
# Vocabulary sizes come straight from the vocabularies built above.
src_vocab_size, trg_vocab_size = len(text_vocab), len(logic_vocab)
embed_dim = 512

# Encoded rows are MAX_LENGTH + 2 long (<SOS>/<EOS> added), so the model's
# maximum sequence length must cover that padded length.
transformer = Transformer(
    src_vocab_size,
    trg_vocab_size,
    embed_size=embed_dim,
    max_length=MAX_LENGTH + 2,
    dropout=0.1,
    pad_idx=PAD_IDX,
).to(device)

def count_parameters(model):
    """Return the number of scalar weights in `model` that require gradients."""
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total

# Report the model size (trainable parameters only) with thousands separators.
print(f'The model has {count_parameters(transformer):,} trainable parameters')

The model has 82,126,440 trainable parameters


In [7]:
class Scheduler(_LRScheduler):
    """Warm-up / inverse-square-root schedule from "Attention Is All You Need".

    lr(step) = dim_embed^-0.5 * min(step^-0.5, step * warmup_steps^-1.5)

    The learning rate grows linearly for `warmup_steps` steps, then decays as
    step^-0.5. The same value is applied to every parameter group.
    """

    def __init__(self,
                 optimizer: Optimizer,
                 dim_embed: int,
                 warmup_steps: int,
                 last_epoch: int=-1,
                 verbose: bool=False) -> None:

        self.dim_embed = dim_embed
        self.warmup_steps = warmup_steps
        self.num_param_groups = len(optimizer.param_groups)

        # Recent PyTorch releases deprecate `verbose` and warn whenever it is passed
        # explicitly, even as False. Omitting it is equivalent to False on older
        # releases, so only forward it when the caller actually asked for it.
        if verbose:
            super().__init__(optimizer, last_epoch, verbose)
        else:
            super().__init__(optimizer, last_epoch)

    def get_lr(self) -> list:
        """Return the scheduled learning rate, once per parameter group."""
        # NOTE: relies on the base scheduler incrementing _step_count before it
        # calls get_lr, so the first value uses step 1. calc_lr also clamps at 1.
        lr = calc_lr(self._step_count, self.dim_embed, self.warmup_steps)
        return [lr] * self.num_param_groups


def calc_lr(step, dim_embed, warmup_steps):
    """Learning rate at `step` for the warm-up / inverse-square-root schedule.

    Steps below 1 are treated as step 1. Without this, 0 ** -0.5 would raise
    ZeroDivisionError.
    """
    step = max(step, 1)
    return dim_embed**(-0.5) * min(step**(-0.5), step * warmup_steps**(-1.5))

class TranslationLoss(nn.Module):
    """Token-level cross-entropy over sequence outputs that ignores padding.

    Logits of shape (..., vocab) and labels with the matching leading shape are
    flattened so every position is one classification example. Positions whose
    label equals PAD_IDX do not contribute to the loss.
    """

    def __init__(self, label_smoothing: float=0.0) -> None:
        super().__init__()
        self.loss_func = nn.CrossEntropyLoss(
            ignore_index=PAD_IDX,
            label_smoothing=label_smoothing,
        )

    def forward(self, logits: Tensor, labels: Tensor) -> Tensor:
        flat_logits = logits.reshape(-1, logits.shape[-1])
        flat_labels = labels.long().reshape(-1)
        return self.loss_func(flat_logits, flat_labels)

In [8]:
def train(model: nn.Module,
          loader: DataLoader,
          loss_func: torch.nn.Module,
          optimizer: torch.optim.Optimizer,
          scheduler: torch.optim.lr_scheduler._LRScheduler) -> float:
    """Run one teacher-forced training epoch and return the mean batch loss.

    Args:
        model: called as ``model(src, trg_in)``. It is expected to return logits
            aligned with ``trg_in`` (TODO confirm against TorchTransformer).
        loader: yields ``(source, target)`` index tensors of shape
            (batch, seq_len). Assumed to start with <SOS>, as produced by
            create_dataset_new.
        loss_func: maps ``(logits, labels)`` to a scalar loss.
        optimizer: stepped once per batch.
        scheduler: stepped once per batch, i.e. per optimizer step rather than per
            epoch. May be None.

    Returns:
        The average of the per-batch losses over the epoch.

    Raises:
        ValueError: if ``loader`` yields no batches.
    """
    model.train()  # training mode (e.g. dropout active)

    total_loss = 0.0
    # Counted while iterating, so loaders without __len__ (iterable datasets) work too.
    num_batches = 0

    for source, target in tqdm(loader):
        # Encoder input: source without its first position (the leading <SOS>).
        # Decoder input: target without its last position, so each position
        # predicts the next token of target[:, 1:].
        logits = model(source[:, 1:], target[:, :-1])

        # Labels: target without its leading <SOS>.
        loss = loss_func(logits, target[:, 1:])
        total_loss += loss.item()
        num_batches += 1

        # back-prop
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Per-batch learning-rate schedule (warm-up is counted in optimizer steps).
        if scheduler is not None:
            scheduler.step()

    if num_batches == 0:
        raise ValueError("train(): loader yielded no batches")
    return total_loss / num_batches

In [9]:
NUM_EPOCHS = 30
MODEL_PATH = "./models/" + "dket_" + DATASET + ".pt"

# Adam settings from "Attention Is All You Need". The scheduler overwrites the
# learning rate on every step, so Adam's default lr is never used.
optimizer = optim.Adam(transformer.parameters(),
                       betas=(0.9, 0.98),
                       eps=1.0e-9)
criterion = TranslationLoss(label_smoothing=0.1)
scheduler = Scheduler(optimizer=optimizer, dim_embed=embed_dim, warmup_steps=4000)
print("Training", DATASET, ":")
best_res = float("inf")
best_state = None
for epoch in range(NUM_EPOCHS):
    res = train(transformer, train_loader, criterion, optimizer, scheduler)
    print(res)
    if res < best_res:
        best_res = res
        # Snapshot on CPU so the copy does not double the model's GPU footprint.
        best_state = {k: v.detach().cpu().clone() for k, v in transformer.state_dict().items()}
    else:
        # Stop at the first epoch whose training loss does not improve.
        # NOTE(review): early stopping uses training loss; val_loader is never used.
        break
print("------Training complete-----")

# When training stops early, the model has already been updated by the worse
# epoch. Restore the best snapshot so the saved (and later evaluated) weights
# are the best ones.
if best_state is not None:
    transformer.load_state_dict(best_state)

os.makedirs(os.path.dirname(MODEL_PATH), exist_ok=True)
torch.save(transformer.state_dict(), MODEL_PATH)


Training 20k :


100%|█████████████████████████████████████████| 313/313 [01:03<00:00,  4.94it/s]


6.815534914644381


100%|█████████████████████████████████████████| 313/313 [01:03<00:00,  4.95it/s]


4.636190796050782


100%|█████████████████████████████████████████| 313/313 [01:03<00:00,  4.94it/s]


3.991013415705282


100%|█████████████████████████████████████████| 313/313 [01:03<00:00,  4.93it/s]


3.5188538670158995


100%|█████████████████████████████████████████| 313/313 [01:03<00:00,  4.94it/s]


2.8477525474926155


100%|█████████████████████████████████████████| 313/313 [01:03<00:00,  4.93it/s]


2.130113551030144


100%|█████████████████████████████████████████| 313/313 [01:03<00:00,  4.94it/s]


1.6220525190853083


100%|█████████████████████████████████████████| 313/313 [01:03<00:00,  4.94it/s]


1.4759323494122052


100%|█████████████████████████████████████████| 313/313 [01:03<00:00,  4.94it/s]


1.436278932391645


100%|█████████████████████████████████████████| 313/313 [01:03<00:00,  4.95it/s]


1.4175962857164133


100%|█████████████████████████████████████████| 313/313 [01:03<00:00,  4.95it/s]


1.4064271370062051


100%|█████████████████████████████████████████| 313/313 [01:03<00:00,  4.95it/s]


1.3995328242786396


100%|█████████████████████████████████████████| 313/313 [01:03<00:00,  4.94it/s]


1.39464743335407


100%|█████████████████████████████████████████| 313/313 [01:03<00:00,  4.94it/s]


1.3895857197027237


100%|█████████████████████████████████████████| 313/313 [01:03<00:00,  4.95it/s]


1.3839232750213184


100%|█████████████████████████████████████████| 313/313 [01:02<00:00,  5.03it/s]


1.3797697057358373


100%|█████████████████████████████████████████| 313/313 [01:01<00:00,  5.07it/s]


1.376134115667008


100%|█████████████████████████████████████████| 313/313 [01:01<00:00,  5.05it/s]


1.3734247585455068


100%|█████████████████████████████████████████| 313/313 [01:01<00:00,  5.07it/s]


1.3710491600128027


100%|█████████████████████████████████████████| 313/313 [01:01<00:00,  5.09it/s]


1.3688881701935594


100%|█████████████████████████████████████████| 313/313 [01:01<00:00,  5.08it/s]


1.3676086313808307


100%|█████████████████████████████████████████| 313/313 [01:01<00:00,  5.07it/s]


1.3662923730600376


100%|█████████████████████████████████████████| 313/313 [01:01<00:00,  5.07it/s]


1.3647111174397575


100%|█████████████████████████████████████████| 313/313 [01:01<00:00,  5.07it/s]


1.3636334132843506


100%|█████████████████████████████████████████| 313/313 [01:01<00:00,  5.09it/s]


1.3629326131016302


100%|█████████████████████████████████████████| 313/313 [01:01<00:00,  5.09it/s]


1.3619346409179152


100%|█████████████████████████████████████████| 313/313 [01:01<00:00,  5.06it/s]


1.3610689354399903


100%|█████████████████████████████████████████| 313/313 [01:01<00:00,  5.09it/s]


1.3601902612862877


100%|█████████████████████████████████████████| 313/313 [01:01<00:00,  5.09it/s]


1.3598189258727784


100%|█████████████████████████████████████████| 313/313 [01:01<00:00,  5.08it/s]


1.3590542218936519
------Training complete-----


In [13]:
def evaluation(model, val_dataset, src_vocab, trg_vocab):
    """Score the model on `val_dataset` with teacher forcing, one example at a time.

    For each ``(source, target)`` pair the model receives ``source[1:]`` and the
    gold decoder prefix ``target[:-1]``. Its per-position argmax tokens are read
    up to and including the first predicted "<EOS>". These are compared with the
    gold target tokens (padding and the leading <SOS> removed) using
    ``average_formula_accuracy``, ``average_token_accuracy`` and ``average_ld``.

    NOTE(review): the decoder always sees the *gold* prefix, so this measures
    teacher-forced next-token accuracy, not free-running greedy decoding. The
    numbers are likely optimistic for real translation. TODO: add autoregressive
    decoding once the Transformer inference API is confirmed.

    Args:
        model: called as ``model(src, trg_in)`` with a batch of size 1.
        val_dataset: indexable pairs of 1-D index tensors, e.g. the
            TensorDataset from create_dataset_new.
        src_vocab: source vocabulary. Kept for interface compatibility; unused.
        trg_vocab: target vocabulary used to map indices back to tokens.

    Returns:
        ``(formula_accuracy, token_accuracy, average_edit_distance)``.
    """
    trg_itos = trg_vocab.get_itos()

    model.eval()

    reference_tokens = []
    predicted_tokens = []
    # Inference only: without no_grad every forward pass would record an autograd graph.
    with torch.no_grad():
        for source, target in tqdm(val_dataset):
            logits = model(source[1:].unsqueeze(0), target[:-1].unsqueeze(0))

            # Gold tokens: drop padding, then the leading <SOS>.
            golden = [trg_itos[t] for t in target.tolist() if t != PAD_IDX][1:]
            reference_tokens.append(golden)

            # Same row selection as the previous `logits.tolist()[-1]`. Taking the
            # argmax on the tensor (first max wins on ties, like np.argmax) avoids
            # converting the whole logit matrix to Python lists.
            predicted_ids = logits[-1].argmax(dim=-1).tolist()
            target_tokens = []
            for idx in predicted_ids:
                guess = trg_itos[idx]
                target_tokens.append(guess)
                if guess == "<EOS>":
                    break
            predicted_tokens.append(target_tokens)

    formula = average_formula_accuracy(predicted_tokens, reference_tokens, write_results=False)
    token = average_token_accuracy(predicted_tokens, reference_tokens)
    edit_distance = average_ld(predicted_tokens, reference_tokens)

    return formula, token, edit_distance


In [14]:
print("Evaluating model:")
formula, token, edit_distance = evaluation(transformer, val_dataset, text_vocab, logic_vocab)

# One line per metric, in the same order and format as before.
metric_report = (
    ("average formula accuracy: ", formula),
    ("average token accuracy: ", token),
    ("average edit distance: ", edit_distance),
)
for metric_label, metric_value in metric_report:
    print(metric_label, metric_value)
print("-----------------------")


Evaluating model:


100%|█████████████████████████████████████| 60000/60000 [49:27<00:00, 20.22it/s]
60000it [00:15, 3813.93it/s]

average formula accuracy:  0.9488333333333333
average token accuracy:  0.9963426528984816
average edit distance:  0.054966666666666664
-----------------------



