In [None]:
from collections import defaultdict
import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np
from tqdm.notebook import tqdm, trange
from torch import nn

In [None]:
import torch

# Select where the trainers will run.
# With a visible GPU, `device` is the integer index returned by
# torch.cuda.current_device(); otherwise it is the string 'cpu'.
if not torch.cuda.is_available():
    print('Failed to find GPU. Will use CPU.')
    device = 'cpu'
else:
    device = torch.cuda.current_device()
    gpu_name = torch.cuda.get_device_name(device)
    print('Current device:', gpu_name)

In [None]:
from utils.Constants import tokenizer
from dataLoader.DataLoader import get_data_loader
from dataLoader.DataLoaderUtils import get_question_answers_for_where_value_def_length, get_question_answers_def_length

# Loader over the 'dev' split with one example per batch; it is shared by
# every trainer constructed below and by the training call.
dev_data_loader = get_data_loader(
    data_type='dev',
    tokenizer=tokenizer,
    batch_size=1,
)

In [None]:
from models.QABert import QABertTrainer
from models.SelectRanker import SelectRankerTrainer
from models.WhereRanker import WhereRankerTrainer
from models.WhereConditionClassifier import WhereConditionClassifierTrainer
from models.WhereNumberClassifier import WhereNumberClassifierTrainer
from models.AggregationClassifier import AggregationClassifierTrainer
from train.Trainer import train_epoch, save_model, load_model

# (key, trainer class) pairs, in construction order. Each trainer is built
# with the selected device and the dev loader; `models` maps key -> trainer.
TRAINER_SPECS = (
    ('selection_trainer', SelectRankerTrainer),
    ('agg_class_trainer', AggregationClassifierTrainer),
    ('where_ranker_trainer', WhereRankerTrainer),
    ('where_cond_class_trainer', WhereConditionClassifierTrainer),
    ('where_numb_class_trainer', WhereNumberClassifierTrainer),
    ('qa_trainer', QABertTrainer),
)

models = {
    trainer_key: trainer_cls(device, dev_data_loader)
    for trainer_key, trainer_cls in TRAINER_SPECS
}

In [None]:
# Train all sub-models for one epoch and write a checkpoint.
#
# RESUME_CHECKPOINT: None trains the freshly constructed `models`; set it to a
#   saved checkpoint directory (e.g. "./checkpoint/16_Mar_2021_19_21") to
#   replace `models` with load_model(...) before training.
# CHECKPOINT_DIR: directory passed to save_model after the epoch.
RESUME_CHECKPOINT = None
CHECKPOINT_DIR = "./checkpoint"

if RESUME_CHECKPOINT is not None:
    models = load_model(RESUME_CHECKPOINT, dev_data_loader, device)

# NOTE(review): the trainers are built with, and trained on, the *dev* loader.
# Confirm this is intentional (e.g. a smoke run), not a missing train split.
train_epoch(models, dev_data_loader, device)
save_model(models, CHECKPOINT_DIR)