In [4]:
from trainer import sequence_accuracy
from data import Data
import pandas as pd
import numpy as np
import os
import torch
from config import SkanformerTestConfig
from tokenizer import Tokenizer
from fn_utils import get_model

In [5]:
# Special tokens and their corresponding ids.
# The ids are the integers 0..4; `special_symbols` lists the five token strings
# in the same order (by name: <S>=BOS, <PAD>, </S>=EOS, <UNK>, <SEP>).
special_symbols = ['<S>', '<PAD>', '</S>', '<UNK>', '<SEP>']
BOS_IDX = 0
PAD_IDX = 1
EOS_IDX = 2
UNK_IDX = 3
SEP_IDX = 4

In [13]:
# Evaluation configuration. Arguments are grouped by concern; the meaning of
# each field is defined by SkanformerTestConfig (not visible in this notebook).
config = SkanformerTestConfig(
    # Run identity, locations and device.
    # NOTE(review): absolute, machine-specific paths — adjust for your environment.
    model_name="skanformer_64x2",
    root_dir="/pscratch/sd/r/ritesh11/SYMBA_test/Skanformer",
    data_dir="/pscratch/sd/r/ritesh11/SYMBA_test/SYMBA_test",
    device='cuda',
    seed=42,
    # Network hyperparameters.
    embedding_size=512,
    nhead=8,
    num_layers=3,
    d_ff=4096,
    ff_dims=[4096],
    dropout=0.1,
    # Source / target maximum lengths.
    src_max_len=280,
    tgt_max_len=323,
    # Vocabulary sizes are not known yet; they are assigned on `config`
    # after the vocabularies have been built.
    src_voc_size=None,
    tgt_voc_size=None,
    # Options that are also handed to the Tokenizer constructor.
    index_pool_size=200,
    momentum_pool_size=200,
    to_replace=False,
    # NOTE(review): presumably selects prefix vs. infix notation — confirm.
    is_prefix=False,
)

In [14]:
# Load the three CSV splits from the configured data directory.
# os.path.join inserts the path separator when `config.data_dir` has no
# trailing one; plain string concatenation would instead produce a path such
# as ".../SYMBA_testtrain.csv". It is equally correct if a trailing separator
# is already present.
df_train = pd.read_csv(os.path.join(config.data_dir, "train.csv"))
df_test = pd.read_csv(os.path.join(config.data_dir, "test.csv"))
df_valid = pd.read_csv(os.path.join(config.data_dir, "valid.csv"))

# One frame holding every split (train, valid, test order) with a fresh
# 0..n-1 index; this combined frame is what the tokenizer is constructed from.
df = pd.concat([df_train, df_valid, df_test]).reset_index(drop=True)

In [15]:
# Tokenizer over the combined frame, parameterised from the config and the
# special-token definitions.
tokenizer = Tokenizer(
    df,
    config.index_pool_size,
    config.momentum_pool_size,
    special_symbols,
    UNK_IDX,
    config.to_replace,
)

# Source vocabulary (built with the configured seed) and its inverse
# id -> token lookup, obtained by flipping the token -> id mapping.
src_vocab = tokenizer.build_src_vocab(config.seed)
src_itos = {idx: token for token, idx in src_vocab.get_stoi().items()}

# Target vocabulary and its inverse id -> token lookup.
tgt_vocab = tokenizer.build_tgt_vocab()
tgt_itos = {idx: token for token, idx in tgt_vocab.get_stoi().items()}

Processing source vocab: 100%|██████████| 15552/15552 [00:00<00:00, 19815.61it/s]
Processing target vocab: 100%|██████████| 15552/15552 [00:01<00:00, 14384.89it/s]


In [16]:
# Fill in the vocabulary sizes that were left as None when `config` was created.
config.src_voc_size, config.tgt_voc_size = len(src_vocab), len(tgt_vocab)

In [17]:
# Instantiate the model described by `config` (vocabulary sizes must already be set).
# NOTE(review): `model` is not referenced by any later cell visible here — the
# evaluation call receives `config`, not `model`. Confirm whether this cell is
# still needed (e.g. for side effects of get_model) before removing it.
model = get_model(config)

In [18]:
# Build the datasets from the raw frames, tokenizer and vocabularies.
# Note the positional order of the frames: train, test, valid.
datasets = Data.get_data(
    df_train,
    df_test,
    df_valid,
    config,
    tokenizer,
    src_vocab,
    tgt_vocab,
)

# The result is indexed by split name; only the test split is used below.
test_ds = datasets['test']

In [19]:
# Run the sequence-accuracy evaluation over the entire test split
# (test_size is set to len(test_ds)).
# NOTE(review): the structure of the returned `preds` is defined by
# trainer.sequence_accuracy — confirm before relying on it.
preds = sequence_accuracy(
    config,
    test_ds,
    tgt_itos,
    test_size=len(test_ds),
)

Using epoch 19 model for predictions.


Seq_Acc_Cal: 100%|██████████| 1555/1555 [14:29<00:00,  1.79it/s, seq_accuracy=0.894]
