In [1]:
import random
import numpy as np
import csv

from tqdm import tqdm

import torch
import torch.nn as nn
from torch import optim
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler
from torch import Tensor
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset

import torchtext
from torchtext.vocab import build_vocab_from_iterator

from TorchTransformer import *
from evaluation import *

In [2]:
# GLOBAL VARIABLES
PAD_IDX = 0
SOS_TOKEN = 1
EOS_TOKEN = 2
MAX_LENGTH = 50

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

cuda


Set which dataset you want to run the experiment on with the DATSET variable, the name should be the same as the filename found in the atomic_datasets folder without the fileformat.

In [None]:
DATASET = "all_dataset"

In [3]:
def yield_tokens(lines):
    for line in lines:
        yield line.split()

def create_vocab(lang):
    vocab = build_vocab_from_iterator(yield_tokens(lang), specials=["<PAD>", "<SOS>", "<EOS>"], special_first=True)
    vocab.set_default_index(-1)
    return vocab

def read_data(dataset="all"):
    print("Reading lines...")

    # Read the file and split into lines
    lines = []
    with open('atomic_datasets/' + dataset + '.csv') as csvfile:
        reader = csv.reader(csvfile, delimiter='\t')
        for row in reader:
            lines.append(row)
    
    lang1 = list(map(lambda x: x[0], lines))
    lang2 = list(map(lambda x: x[1], lines))

    return lang1, lang2, lines

def add_sequence_tokens(sentence):
    new_sentence = sentence.split()
    n_tokens = len(new_sentence)
    new_sentence.insert(0, "<SOS>")
    new_sentence.append("<EOS>")
    assert n_tokens + 2 == len(new_sentence)
    return " ".join(new_sentence)

def add_sequence_tokens_dataset(data):
    new_data = []
    for d in data:
        new_data.append(add_sequence_tokens(d))
    return new_data

def tensor_from_sentence(sentence, vocab):
    encoded = [vocab[word] for word in sentence.split()]
    return torch.tensor(encoded, dtype=torch.long, device=device)

def filter_pairs(pairs, MAX_LENGTH):
    filtered_pairs = []
    for s1, s2 in pairs:
        if len(s1.split()) < MAX_LENGTH and len(s2.split()) < MAX_LENGTH:
            filtered_pairs.append([s1, s2])
            
    lang1 = list(map(lambda x: x[0], filtered_pairs))
    lang2 = list(map(lambda x: x[1], filtered_pairs))
    return lang1, lang2, filtered_pairs

def pad_dataset(data, target_length):
    new_data = [] 
    for d in data:
        x = d.split()
        x = x + ["<PAD>"] * (target_length - len(x))
        new_data.append(" ".join(x))
    return new_data

def create_dataset(src_lang, src_vocab, trg_lang, trg_vocab, MAX_LENGTH):
    src_list = []
    trg_list = []
    
    src_lang = pad_dataset(src_lang, MAX_LENGTH)
    trg_lang = pad_dataset(trg_lang, MAX_LENGTH)
    
    for src, trg in zip(src_lang, trg_lang):
        src_list.append(tensor_from_sentence(src, src_vocab).to(device))
        trg_list.append(tensor_from_sentence(trg, trg_vocab).to(device))
    src_tensors = torch.stack(src_list)
    trg_tensors = torch.stack(trg_list)
    return TensorDataset(src_tensors, trg_tensors)

def create_dataset_new(src_lang, src_vocab, trg_lang, trg_vocab, MAX_LENGTH):
    src_list = []
    trg_list = []
    
    src_lang = add_sequence_tokens_dataset(src_lang)
    trg_lang = add_sequence_tokens_dataset(trg_lang)
    
    src_lang = pad_dataset(src_lang, MAX_LENGTH + 2)
    trg_lang = pad_dataset(trg_lang, MAX_LENGTH + 2)
    
    for src, trg in zip(src_lang, trg_lang):
        src_list.append(tensor_from_sentence(src, src_vocab).to(device))
        trg_list.append(tensor_from_sentence(trg, trg_vocab).to(device))
    src_tensors = torch.stack(src_list)
    trg_tensors = torch.stack(trg_list)
    return TensorDataset(src_tensors, trg_tensors)

In [4]:
_, _, atomic_pairs = read_data(DATASET)
text, logic, _ = filter_pairs(atomic_pairs, MAX_LENGTH)
text_vocab = create_vocab(text)
logic_vocab = create_vocab(logic)

print(len(text), len(logic))
print(len(text_vocab), print(len(logic_vocab)))

Reading lines...
["personx returns to personx's work xintent to keep their job", 'person (x) & returns to (x,z) & work (z) -> to keep their job (x)']
["personx returns to personx's work xintent to keep their job", 'person (x) & returns to (x,z) & work (z) -> to keep their job (x)']
571793 571793
25943
25950 None


In [5]:
total_dataset = create_dataset_new(text, text_vocab, logic, logic_vocab, MAX_LENGTH)
train_set, val_30k = torch.utils.data.random_split(total_dataset, [len(total_dataset) - 30000, 30000])
train_20k, _ = torch.utils.data.random_split(train_set, [20000, len(train_set) - 20000])
train_10k, _ = torch.utils.data.random_split(train_set, [10000, len(train_set) - 10000])
train_5k, _ = torch.utils.data.random_split(train_set, [5000, len(train_set) - 5000])
train_2k, _ = torch.utils.data.random_split(train_set, [2000, len(train_set) - 2000])


In [6]:
loader_val_30k = DataLoader(val_30k, batch_size=128)
loader_20k = DataLoader(train_20k, batch_size=128)
loader_10k = DataLoader(train_10k, batch_size=128)
loader_5k = DataLoader(train_5k, batch_size=128)
loader_2k = DataLoader(train_2k, batch_size=128)

training_loaders = [loader_2k,
                    loader_5k,
                   loader_10k,
                   loader_20k]

In [7]:
src_vocab_size = len(text_vocab)
trg_vocab_size = len(logic_vocab)
embed_dim=512
transformer2k = Transformer(src_vocab_size,
                          trg_vocab_size,
                          embed_size=embed_dim,
                          max_length=MAX_LENGTH+2,
                          dropout=0.1,
                          pad_idx=PAD_IDX).to(device)
transformer5k = Transformer(src_vocab_size,
                          trg_vocab_size,
                          embed_size=embed_dim,
                          max_length=MAX_LENGTH+2,
                          dropout=0.1,
                          pad_idx=PAD_IDX).to(device)
transformer10k = Transformer(src_vocab_size,
                          trg_vocab_size,
                          embed_size=embed_dim,
                          max_length=MAX_LENGTH+2,
                          dropout=0.1,
                          pad_idx=PAD_IDX).to(device)
transformer20k = Transformer(src_vocab_size,
                          trg_vocab_size,
                          embed_size=embed_dim,
                          max_length=MAX_LENGTH+2,
                          dropout=0.1,
                          pad_idx=PAD_IDX).to(device)

transformers = [transformer2k,
                transformer5k,
                transformer10k,
                transformer20k]

transformer_names = [DATASET + "_2k",
                     DATASET + "_5k",
                     DATASET + "_10k",
                     DATASET + "_20k"]

def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

print(f'The model has {count_parameters(transformer2k):,} trainable parameters')

The model has 84,071,767 trainable parameters


In [8]:
class Scheduler(_LRScheduler):
    def __init__(self, 
                 optimizer: Optimizer,
                 dim_embed: int,
                 warmup_steps: int,
                 last_epoch: int=-1,
                 verbose: bool=False) -> None:

        self.dim_embed = dim_embed
        self.warmup_steps = warmup_steps
        self.num_param_groups = len(optimizer.param_groups)

        super().__init__(optimizer, last_epoch, verbose)
        
    def get_lr(self) -> float:
        lr = calc_lr(self._step_count, self.dim_embed, self.warmup_steps)
        return [lr] * self.num_param_groups


def calc_lr(step, dim_embed, warmup_steps):
    return dim_embed**(-0.5) * min(step**(-0.5), step * warmup_steps**(-1.5))

class TranslationLoss(nn.Module):
    def __init__(self, label_smoothing: float=0.0) -> None:
        super().__init__()
        self.loss_func = nn.CrossEntropyLoss(ignore_index    = PAD_IDX,
                                             label_smoothing = label_smoothing)

    def forward(self, logits: Tensor, labels: Tensor) -> Tensor:
        vocab_size = logits.shape[-1]
        logits = logits.reshape(-1, vocab_size)
        labels = labels.reshape(-1).long()
        return self.loss_func(logits, labels)

In [9]:
def train(model: nn.Module,
          loader: DataLoader,
          loss_func: torch.nn.Module,
          optimizer: torch.optim.Optimizer,
          scheduler: torch.optim.lr_scheduler._LRScheduler) -> float:

    model.train() # train mode
    
    total_loss = 0
    num_batches = len(loader)

    for source, target in tqdm(loader):
        # feed forward
        logits = model(source[:, 1:], target[:, :-1]) #input lacking EOS

        # loss calculation
        loss = loss_func(logits, target[:, 1:]) #labels lacking SOS
        total_loss += loss.item()

        # back-prop
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # learning rate scheduler
        if scheduler is not None:
            scheduler.step()

    # average training loss
    avg_loss = total_loss / num_batches
    return avg_loss

In [10]:
for transformer, data_loader, name in zip(transformers, training_loaders, transformer_names):
    optimizer = optim.Adam(transformer.parameters(),
                       betas=(0.9, 0.98),
                       eps=1.0e-9)
    criterion = TranslationLoss(label_smoothing=0.1)
    scheduler = Scheduler(optimizer=optimizer, dim_embed=embed_dim, warmup_steps=4000)
    print("Training", name, ":")
    best_res = float("inf")
    for e in range(30):
        res = train(transformer, data_loader, criterion, optimizer, scheduler)
        print(res)
        if res < best_res:
            best_res = res
        else:
            break
    print("------Training complete-----")
    
    torch.save(transformer.state_dict(), "./models/" + name + ".pt")


Training all_2k :


100%|███████████████████████████████████████████| 16/16 [00:14<00:00,  1.07it/s]


10.148587822914124


100%|███████████████████████████████████████████| 16/16 [00:03<00:00,  5.17it/s]


9.099200248718262


100%|███████████████████████████████████████████| 16/16 [00:03<00:00,  5.17it/s]


8.21500739455223


100%|███████████████████████████████████████████| 16/16 [00:03<00:00,  5.17it/s]


7.734296560287476


100%|███████████████████████████████████████████| 16/16 [00:03<00:00,  5.17it/s]


7.424902617931366


100%|███████████████████████████████████████████| 16/16 [00:03<00:00,  5.16it/s]


7.002890646457672


100%|███████████████████████████████████████████| 16/16 [00:03<00:00,  5.16it/s]


6.296623736619949


100%|███████████████████████████████████████████| 16/16 [00:03<00:00,  5.15it/s]


5.614317715167999


100%|███████████████████████████████████████████| 16/16 [00:03<00:00,  5.15it/s]


5.107887238264084


100%|███████████████████████████████████████████| 16/16 [00:03<00:00,  5.14it/s]


4.7389858067035675


100%|███████████████████████████████████████████| 16/16 [00:03<00:00,  5.14it/s]


4.4936904311180115


100%|███████████████████████████████████████████| 16/16 [00:03<00:00,  5.14it/s]


4.313005596399307


100%|███████████████████████████████████████████| 16/16 [00:03<00:00,  5.13it/s]


4.161639928817749


100%|███████████████████████████████████████████| 16/16 [00:03<00:00,  5.14it/s]


4.030516102910042


100%|███████████████████████████████████████████| 16/16 [00:03<00:00,  5.13it/s]


3.9141005128622055


100%|███████████████████████████████████████████| 16/16 [00:03<00:00,  5.13it/s]


3.806545630097389


100%|███████████████████████████████████████████| 16/16 [00:03<00:00,  5.12it/s]


3.7112599164247513


100%|███████████████████████████████████████████| 16/16 [00:03<00:00,  5.11it/s]


3.6260686814785004


100%|███████████████████████████████████████████| 16/16 [00:03<00:00,  5.12it/s]


3.5469915121793747


100%|███████████████████████████████████████████| 16/16 [00:03<00:00,  5.11it/s]


3.477140963077545


100%|███████████████████████████████████████████| 16/16 [00:03<00:00,  5.11it/s]


3.4136456698179245


100%|███████████████████████████████████████████| 16/16 [00:03<00:00,  5.11it/s]


3.355578362941742


100%|███████████████████████████████████████████| 16/16 [00:03<00:00,  5.11it/s]


3.303591251373291


100%|███████████████████████████████████████████| 16/16 [00:03<00:00,  5.11it/s]


3.247339442372322


100%|███████████████████████████████████████████| 16/16 [00:03<00:00,  5.10it/s]


3.1934486031532288


100%|███████████████████████████████████████████| 16/16 [00:03<00:00,  5.11it/s]


3.139032542705536


100%|███████████████████████████████████████████| 16/16 [00:03<00:00,  5.10it/s]


3.088616192340851


100%|███████████████████████████████████████████| 16/16 [00:03<00:00,  5.10it/s]


3.0366455763578415


100%|███████████████████████████████████████████| 16/16 [00:03<00:00,  5.10it/s]


2.9813808649778366


100%|███████████████████████████████████████████| 16/16 [00:03<00:00,  5.10it/s]


2.9220018833875656
------Training complete-----
Training all_5k :


100%|███████████████████████████████████████████| 40/40 [00:07<00:00,  5.09it/s]


9.36887047290802


100%|███████████████████████████████████████████| 40/40 [00:07<00:00,  5.10it/s]


7.7140900015830995


100%|███████████████████████████████████████████| 40/40 [00:07<00:00,  5.09it/s]


6.516689586639404


100%|███████████████████████████████████████████| 40/40 [00:07<00:00,  5.09it/s]


5.08208976984024


100%|███████████████████████████████████████████| 40/40 [00:07<00:00,  5.08it/s]


4.404511547088623


100%|███████████████████████████████████████████| 40/40 [00:07<00:00,  5.08it/s]


4.054335874319077


100%|███████████████████████████████████████████| 40/40 [00:07<00:00,  5.08it/s]


3.8003475069999695


100%|███████████████████████████████████████████| 40/40 [00:07<00:00,  5.08it/s]


3.6051990032196044


100%|███████████████████████████████████████████| 40/40 [00:07<00:00,  5.07it/s]


3.461033356189728


100%|███████████████████████████████████████████| 40/40 [00:07<00:00,  5.07it/s]


3.350733518600464


100%|███████████████████████████████████████████| 40/40 [00:07<00:00,  5.07it/s]


3.253147131204605


100%|███████████████████████████████████████████| 40/40 [00:07<00:00,  5.07it/s]


3.145345610380173


100%|███████████████████████████████████████████| 40/40 [00:07<00:00,  5.05it/s]


3.0396419405937194


100%|███████████████████████████████████████████| 40/40 [00:07<00:00,  5.05it/s]


2.9402716934680937


100%|███████████████████████████████████████████| 40/40 [00:07<00:00,  5.06it/s]


2.843238526582718


100%|███████████████████████████████████████████| 40/40 [00:07<00:00,  5.04it/s]


2.753033608198166


100%|███████████████████████████████████████████| 40/40 [00:07<00:00,  5.04it/s]


2.6711892008781435


100%|███████████████████████████████████████████| 40/40 [00:07<00:00,  5.05it/s]


2.5930818021297455


100%|███████████████████████████████████████████| 40/40 [00:07<00:00,  5.06it/s]


2.5120399713516237


100%|███████████████████████████████████████████| 40/40 [00:07<00:00,  5.05it/s]


2.428777125477791


100%|███████████████████████████████████████████| 40/40 [00:07<00:00,  5.05it/s]


2.35033301115036


100%|███████████████████████████████████████████| 40/40 [00:07<00:00,  5.05it/s]


2.2747700572013856


100%|███████████████████████████████████████████| 40/40 [00:07<00:00,  5.05it/s]


2.2024526327848433


100%|███████████████████████████████████████████| 40/40 [00:07<00:00,  5.05it/s]


2.130171298980713


100%|███████████████████████████████████████████| 40/40 [00:07<00:00,  5.05it/s]


2.0553923189640044


100%|███████████████████████████████████████████| 40/40 [00:07<00:00,  5.05it/s]


1.9885954856872559


100%|███████████████████████████████████████████| 40/40 [00:07<00:00,  5.00it/s]


1.9220196843147277


100%|███████████████████████████████████████████| 40/40 [00:07<00:00,  5.10it/s]


1.8575576305389405


100%|███████████████████████████████████████████| 40/40 [00:07<00:00,  5.11it/s]


1.7970307648181916


100%|███████████████████████████████████████████| 40/40 [00:07<00:00,  5.11it/s]


1.7371515393257142
------Training complete-----
Training all_10k :


100%|███████████████████████████████████████████| 79/79 [00:15<00:00,  5.06it/s]


8.519396631023552


100%|███████████████████████████████████████████| 79/79 [00:15<00:00,  5.05it/s]


5.746402812909476


100%|███████████████████████████████████████████| 79/79 [00:15<00:00,  5.05it/s]


4.236686631094051


100%|███████████████████████████████████████████| 79/79 [00:15<00:00,  5.05it/s]


3.7181610004811345


100%|███████████████████████████████████████████| 79/79 [00:15<00:00,  5.05it/s]


3.4179586579528034


100%|███████████████████████████████████████████| 79/79 [00:15<00:00,  5.05it/s]


3.1908141782012165


100%|███████████████████████████████████████████| 79/79 [00:15<00:00,  5.05it/s]


2.989384313172932


100%|███████████████████████████████████████████| 79/79 [00:15<00:00,  5.05it/s]


2.808180389525015


100%|███████████████████████████████████████████| 79/79 [00:15<00:00,  5.05it/s]


2.641065244433246


100%|███████████████████████████████████████████| 79/79 [00:15<00:00,  5.05it/s]


2.4760788452776175


100%|███████████████████████████████████████████| 79/79 [00:15<00:00,  5.05it/s]


2.3212369786033147


100%|███████████████████████████████████████████| 79/79 [00:15<00:00,  5.05it/s]


2.177481705629373


100%|███████████████████████████████████████████| 79/79 [00:15<00:00,  5.05it/s]


2.0453022416633897


100%|███████████████████████████████████████████| 79/79 [00:15<00:00,  5.06it/s]


1.9294219273555129


100%|███████████████████████████████████████████| 79/79 [00:15<00:00,  5.06it/s]


1.8268735514411443


100%|███████████████████████████████████████████| 79/79 [00:15<00:00,  5.06it/s]


1.7425004092952874


100%|███████████████████████████████████████████| 79/79 [00:15<00:00,  5.06it/s]


1.6673842698712893


100%|███████████████████████████████████████████| 79/79 [00:15<00:00,  5.06it/s]


1.6047959312607971


100%|███████████████████████████████████████████| 79/79 [00:15<00:00,  5.06it/s]


1.551860762547843


100%|███████████████████████████████████████████| 79/79 [00:15<00:00,  5.06it/s]


1.510283796093132


100%|███████████████████████████████████████████| 79/79 [00:15<00:00,  5.06it/s]


1.4740986069546471


100%|███████████████████████████████████████████| 79/79 [00:15<00:00,  5.03it/s]


1.4455520337141012


100%|███████████████████████████████████████████| 79/79 [00:15<00:00,  5.06it/s]


1.4229220906390418


100%|███████████████████████████████████████████| 79/79 [00:15<00:00,  5.06it/s]


1.4041222107561329


100%|███████████████████████████████████████████| 79/79 [00:15<00:00,  5.06it/s]


1.3908566689189477


100%|███████████████████████████████████████████| 79/79 [00:15<00:00,  5.06it/s]


1.3827855722813667


100%|███████████████████████████████████████████| 79/79 [00:15<00:00,  5.06it/s]


1.3754343609266644


100%|███████████████████████████████████████████| 79/79 [00:15<00:00,  5.06it/s]


1.373536877994296


100%|███████████████████████████████████████████| 79/79 [00:15<00:00,  5.07it/s]


1.3715233229383637


100%|███████████████████████████████████████████| 79/79 [00:15<00:00,  5.07it/s]


1.370408046094677
------Training complete-----
Training all_20k :


100%|█████████████████████████████████████████| 157/157 [00:31<00:00,  5.03it/s]


7.12200946564887


100%|█████████████████████████████████████████| 157/157 [00:31<00:00,  5.02it/s]


4.003927965832364


100%|█████████████████████████████████████████| 157/157 [00:31<00:00,  4.99it/s]


3.3561640985452446


100%|█████████████████████████████████████████| 157/157 [00:31<00:00,  4.98it/s]


2.9987285850913663


100%|█████████████████████████████████████████| 157/157 [00:31<00:00,  4.98it/s]


2.7166697098191377


100%|█████████████████████████████████████████| 157/157 [00:31<00:00,  5.01it/s]


2.47628172795484


100%|█████████████████████████████████████████| 157/157 [00:31<00:00,  5.03it/s]


2.2339039775216656


100%|█████████████████████████████████████████| 157/157 [00:31<00:00,  5.03it/s]


1.985795756054532


100%|█████████████████████████████████████████| 157/157 [00:31<00:00,  5.03it/s]


1.7964606361024698


100%|█████████████████████████████████████████| 157/157 [00:31<00:00,  5.03it/s]


1.6629075335848862


100%|█████████████████████████████████████████| 157/157 [00:31<00:00,  5.03it/s]


1.5732078590210836


100%|█████████████████████████████████████████| 157/157 [00:31<00:00,  5.03it/s]


1.5118560092464375


100%|█████████████████████████████████████████| 157/157 [00:31<00:00,  5.03it/s]


1.4687758585449997


100%|█████████████████████████████████████████| 157/157 [00:31<00:00,  5.04it/s]


1.4386598342543195


100%|█████████████████████████████████████████| 157/157 [00:31<00:00,  5.01it/s]


1.417622746935316


100%|█████████████████████████████████████████| 157/157 [00:31<00:00,  5.00it/s]


1.4029819380705524


100%|█████████████████████████████████████████| 157/157 [00:31<00:00,  5.00it/s]


1.3917973702120934


100%|█████████████████████████████████████████| 157/157 [00:31<00:00,  5.00it/s]


1.3843259310266773


100%|█████████████████████████████████████████| 157/157 [00:31<00:00,  5.00it/s]


1.3791786811913653


100%|█████████████████████████████████████████| 157/157 [00:31<00:00,  5.00it/s]


1.3771297498873085


100%|█████████████████████████████████████████| 157/157 [00:31<00:00,  5.00it/s]


1.3763046188718955


100%|█████████████████████████████████████████| 157/157 [00:31<00:00,  5.01it/s]


1.3768620673258594
------Training complete-----


In [14]:
def evaluation(model, val_dataset, src_vocab, trg_vocab):
    src_itos = src_vocab.get_itos()
    trg_itos = trg_vocab.get_itos()
    
    model.eval()
    
    reference_tokens = []
    predicted_tokens = []
    for source, target in tqdm(val_dataset):
        logits = model(source[1:].unsqueeze(0), target[:-1].unsqueeze(0)) #input lacking EOS
        golden = [trg_itos[t] for t in target if t != PAD_IDX][1:]
        reference_tokens.append(golden)
        target_tokens = []
        for word in logits.tolist()[-1]:
            guess = trg_itos[np.argmax(word)]
            target_tokens.append(guess)
            if guess == "<EOS>":
                break
        predicted_tokens.append(target_tokens)
    
    formula = average_formula_accuracy(predicted_tokens, reference_tokens, write_results=True)
    token = average_token_accuracy(predicted_tokens, reference_tokens)
    edit_distance = average_ld(predicted_tokens, reference_tokens)

    return formula, token, edit_distance


In [15]:
for transformer in transformers:
    print("Evaluating model:")
    formula, token, edit_distance = evaluation(transformer, val_30k, text_vocab, logic_vocab)
    print("average formula accuracy: ", formula)
    print("average token accuracy: ", token)
    print("average edit distance: ", edit_distance)
    print("-----------------------")


Evaluating model:


100%|█████████████████████████████████████| 30000/30000 [26:02<00:00, 19.20it/s]
30000it [00:05, 5472.27it/s]


average formula accuracy:  3.3333333333333335e-05
average token accuracy:  0.7383611677865762
average edit distance:  4.211566666666666
-----------------------
Evaluating model:


100%|█████████████████████████████████████| 30000/30000 [26:13<00:00, 19.06it/s]
30000it [00:05, 5481.69it/s]


average formula accuracy:  0.09436666666666667
average token accuracy:  0.8785327710245566
average edit distance:  1.9728666666666668
-----------------------
Evaluating model:


100%|█████████████████████████████████████| 30000/30000 [26:45<00:00, 18.68it/s]
30000it [00:05, 5351.08it/s]


average formula accuracy:  0.6861333333333334
average token accuracy:  0.9749083070138329
average edit distance:  0.3900666666666667
-----------------------
Evaluating model:


100%|█████████████████████████████████████| 30000/30000 [27:46<00:00, 18.01it/s]
30000it [00:05, 5333.35it/s]

average formula accuracy:  0.8076666666666666
average token accuracy:  0.9859253554849771
average edit distance:  0.22463333333333332
-----------------------



