In [None]:
from collections import defaultdict
import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np
from torch.utils.tensorboard import SummaryWriter
from tqdm.notebook import tqdm, trange
from torch import nn

In [None]:
import torch
from transformers import BertModel

# Choose the compute device for the rest of the notebook: the current CUDA
# device index when a GPU is visible, otherwise the string 'cpu'.
if not torch.cuda.is_available():
    print('Failed to find GPU. Will use CPU.')
    device = 'cpu'
else:
    device = torch.cuda.current_device()
    print('Current device:', torch.cuda.get_device_name(device))
    # Release cached allocator blocks left over from earlier runs in this kernel.
    torch.cuda.empty_cache()

In [None]:
from utils.Constants import tokenizer, PRE_TRAINED_MODEL_NAME
from dataLoader.DataLoader import get_data_loader

# Build one loader per split, both with a loader-level batch size of 1.
# NOTE(review): despite its name, `dev_data_loader` is built from the 'train'
# split; the 'test' split goes to `test_data_loader`.
dev_data_loader = get_data_loader(
    data_type='train',
    tokenizer=tokenizer,
    batch_size=1,
)
test_data_loader = get_data_loader(
    data_type='test',
    tokenizer=tokenizer,
    batch_size=1,
)

In [None]:
from models.QABert import QABertTrainer
from models.SelectRanker import SelectRankerTrainer
from models.WhereRanker import WhereRankerTrainer
from models.WhereConditionClassifier import WhereConditionClassifierTrainer
from models.WhereNumberClassifier import WhereNumberClassifierTrainer
from models.AggregationClassifier import AggregationClassifierTrainer

# Load the pre-trained BERT encoder and freeze it: every parameter of `bert`
# has requires_grad switched off, so none of them receives gradients.
bert = BertModel.from_pretrained(PRE_TRAINED_MODEL_NAME)
bert.requires_grad_(False)

# One trainer per sub-task, keyed by name. Every trainer is handed the same
# `device`, the same `dev_data_loader` and the same frozen `bert` instance.
models = {
    'selection_trainer': SelectRankerTrainer(device, dev_data_loader, bert),
    'agg_class_trainer': AggregationClassifierTrainer(
        device, dev_data_loader, bert, use_pretrained=False
    ),
    'where_ranker_trainer': WhereRankerTrainer(device, dev_data_loader, bert),
    'where_cond_class_trainer': WhereConditionClassifierTrainer(
        device, dev_data_loader, bert, use_pretrained=False
    ),
    'where_numb_class_trainer': WhereNumberClassifierTrainer(device, dev_data_loader, bert),
    'qa_trainer': QABertTrainer(device, dev_data_loader, bert, use_pretrained=False),
}

In [None]:
from train.Trainer import train_epoch, save_model, load_model

# To continue from a saved checkpoint instead of the freshly built trainers,
# uncomment the next line:
# models = load_model("./checkpoint/16_Mar_2021_19_21", dev_data_loader, device)

# TensorBoard event writer; its event files go under "runs/".
writer = SummaryWriter(log_dir="runs/")

try:
    # Train all trainers in `models` on `dev_data_loader`, evaluating on
    # `test_data_loader`; metrics are reported through `writer`.
    train_epoch(
        models=models,
        train_data_loader=dev_data_loader,
        eval_data_loader=test_data_loader,
        device=device,
        batch_size=16,
        report_size=8,
        eval_size=64,
        writer=writer,
    )
finally:
    # Push pending events to disk even when training raises or the kernel is
    # interrupted, so the metrics logged so far are not lost.
    writer.flush()

# Only reached when training finished without an exception.
save_model(models, "./checkpoint")

In [None]:
import json

# Collect each trainer's recorded losses as plain Python numbers, keyed as
# '<trainer name>_loss'. Entries exposing `.item()` (e.g. tensors) are unwrapped
# through it; everything else (int, float, bool) is kept as-is. Duck-typing on
# `.item` replaces the former exact-type comparison, which sent subclasses of
# int/float such as bool down the `.item()` path and raised AttributeError.
losses = {
    f'{name}_loss': [
        entry.item() if hasattr(entry, 'item') else entry
        for entry in trainer.losses
    ]
    for name, trainer in models.items()
}

# Persist the loss curves so they can be inspected without re-training.
with open('data/losses.json', 'w') as outfile:
    json.dump(losses, outfile)

In [None]:
# Show the first loss recorded by the QA trainer as a plain number.
# Recorded losses may be plain ints/floats rather than tensors, so only call
# `.item()` when the entry actually provides it.
first_qa_loss = models['qa_trainer'].losses[0]
first_qa_loss.item() if hasattr(first_qa_loss, 'item') else first_qa_loss

In [None]:
from matplotlib.pyplot import plot

# Draw every entry of `losses` as its own labelled line on one set of axes.
# Explicit fig/ax plus axis labels and a title let the figure stand on its own.
fig, ax = plt.subplots()
for key, loss in losses.items():
    ax.plot(loss, label=key)
ax.set_xlabel('Recorded loss index')
ax.set_ylabel('Loss')
ax.set_title('Recorded losses per trainer')
ax.legend()
plt.show()