In [23]:
import json
import math

import torch
import torch.utils.data as data
import numpy as np

from squad.squad_dataset import SquadDataset
from training.model_trainer import Trainer
from training.utils import *
from model.QANet import QANet

training_set = SquadDataset("./preprocessed_dataset/training_set_features.npz", False)
validation_set = SquadDataset("./preprocessed_dataset/validation_set_features.npz", False)
validation_set_eval = json.load(open("./preprocessed_dataset/validation_set_eval.json"))

In [15]:
training_set_loader = data.DataLoader(
    training_set, 
    batch_size=32,
    shuffle=False
)

validation_set_loader = data.DataLoader(
    validation_set, 
    batch_size=32,
    shuffle=False
)

In [4]:
word_emb = torch.from_numpy(np.load('./preprocessed_dataset/glove_embeddings.npz')['emb_mat']).double()
chr_emb = torch.from_numpy(np.load('./preprocessed_dataset/char_embeddings.npz')['emb_mat']).double()

In [5]:
config = json.load(open("./config.json"))
model = QANet(word_emb, chr_emb, config).double().to("cuda:0")

In [9]:
ema = EMA(model, 0.9999)
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-3, betas=(0.8, 0.999), eps=1e-8, weight_decay=3e-7)
cr = 1.0 / math.log(1000)
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda ee: cr * math.log(ee + 1) if ee < 1000 else 1)

QANet(
  (input_embedding_layer): InputEmbeddingLayer(
    (word_embed): Embedding(83941, 300)
    (char_embed): CharacterEmbedding(
      (char_embeddings): Embedding(1309, 200)
      (conv1): Conv2d(200, 200, kernel_size=(1, 5), stride=(1, 1))
    )
    (highway): Highway(
      (t_gates): ModuleList(
        (0): Linear(in_features=500, out_features=500, bias=True)
        (1): Linear(in_features=500, out_features=500, bias=True)
      )
      (h_gates): ModuleList(
        (0): Linear(in_features=500, out_features=500, bias=True)
        (1): Linear(in_features=500, out_features=500, bias=True)
      )
    )
  )
  (embedding_encoder_layer): EncoderEmbeddingLayer(
    (conv1d): Reshape1Dconv(
      (conv1d): Conv1d(500, 128, kernel_size=(1,), stride=(1,))
    )
    (c_pos_encoder): PositionalEncoder(
      (dropout): Dropout(p=0.0, inplace=False)
    )
    (q_pos_encoder): PositionalEncoder(
      (dropout): Dropout(p=0.0, inplace=False)
    )
    (encoder_blocks): ModuleList(
     

In [None]:
model.train()
trainer = Trainer(model, optimizer, "cuda:0", scheduler, 5.0, ema, 1, "model_checkpoints")
trainer.train(10, training_set_loader, validation_set_loader, validation_set_eval)

100%|█| 82240/82240 [16:19<00:00, 83.98it/s, avg_loss=6.47, batch_ce=4.92, epoch=1, learning_rate=[0.001], time=F


Validating...


100%|█████████████████████████████████████| 5136/5136 [00:18<00:00, 272.79it/s, batch_ce=5.13, time=Feb-06_17-03]


{'EM': 21.904205607476637, 'F1': 32.5156667452638}
Saving checkpoint: model_checkpoints/epoch01_f1_32.51567_em_21.90421.pth.tar ...


100%|█| 82240/82240 [16:18<00:00, 84.02it/s, avg_loss=5.08, batch_ce=3.87, epoch=2, learning_rate=[0.001], time=F


Validating...


100%|█████████████████████████████████████| 5136/5136 [00:18<00:00, 273.31it/s, batch_ce=4.02, time=Feb-06_17-20]


{'EM': 31.230529595015575, 'F1': 43.096097242628474}
Saving checkpoint: model_checkpoints/epoch02_f1_43.09610_em_31.23053.pth.tar ...


100%|█| 82240/82240 [16:18<00:00, 84.01it/s, avg_loss=4.08, batch_ce=1.81, epoch=3, learning_rate=[0.001], time=F


Validating...


100%|█████████████████████████████████████| 5136/5136 [00:18<00:00, 273.43it/s, batch_ce=2.96, time=Feb-06_17-36]


{'EM': 45.40498442367601, 'F1': 59.52451417759945}
Saving checkpoint: model_checkpoints/epoch03_f1_59.52451_em_45.40498.pth.tar ...


100%|█| 82240/82240 [16:18<00:00, 84.03it/s, avg_loss=3.29, batch_ce=1.34, epoch=4, learning_rate=[0.001], time=F


Validating...


100%|█████████████████████████████████████| 5136/5136 [00:18<00:00, 273.06it/s, batch_ce=2.83, time=Feb-06_17-53]


{'EM': 50.95404984423676, 'F1': 65.65826486682144}
Saving checkpoint: model_checkpoints/epoch04_f1_65.65826_em_50.95405.pth.tar ...


100%|█| 82240/82240 [16:19<00:00, 83.98it/s, avg_loss=2.93, batch_ce=1.32, epoch=5, learning_rate=[0.001], time=F


Validating...


100%|█████████████████████████████████████| 5136/5136 [00:18<00:00, 273.56it/s, batch_ce=2.65, time=Feb-06_18-10]


{'EM': 53.07632398753894, 'F1': 67.54032156767454}
Saving checkpoint: model_checkpoints/epoch05_f1_67.54032_em_53.07632.pth.tar ...


100%|█| 82240/82240 [16:19<00:00, 83.99it/s, avg_loss=2.7, batch_ce=0.972, epoch=6, learning_rate=[0.001], time=F


Validating...


100%|█████████████████████████████████████| 5136/5136 [00:18<00:00, 273.43it/s, batch_ce=2.61, time=Feb-06_18-26]


{'EM': 53.81619937694704, 'F1': 68.65606202942605}
Saving checkpoint: model_checkpoints/epoch06_f1_68.65606_em_53.81620.pth.tar ...


100%|█| 82240/82240 [16:18<00:00, 84.01it/s, avg_loss=2.53, batch_ce=1.01, epoch=7, learning_rate=[0.001], time=F


Validating...


100%|█████████████████████████████████████| 5136/5136 [00:18<00:00, 273.45it/s, batch_ce=2.51, time=Feb-06_18-43]


{'EM': 54.61448598130841, 'F1': 69.31418849278492}
Saving checkpoint: model_checkpoints/epoch07_f1_69.31419_em_54.61449.pth.tar ...


100%|█| 82240/82240 [16:18<00:00, 84.02it/s, avg_loss=2.39, batch_ce=0.882, epoch=8, learning_rate=[0.001], time=


Validating...


100%|█████████████████████████████████████| 5136/5136 [00:18<00:00, 273.78it/s, batch_ce=2.68, time=Feb-06_19-00]


{'EM': 54.94548286604361, 'F1': 69.64133319474541}
Saving checkpoint: model_checkpoints/epoch08_f1_69.64133_em_54.94548.pth.tar ...


100%|█| 82240/82240 [16:18<00:00, 84.01it/s, avg_loss=2.28, batch_ce=0.954, epoch=9, learning_rate=[0.001], time=


Validating...


100%|█████████████████████████████████████| 5136/5136 [00:18<00:00, 273.44it/s, batch_ce=2.41, time=Feb-06_19-16]


{'EM': 54.94548286604361, 'F1': 69.47861040226148}
Saving checkpoint: model_checkpoints/epoch09_f1_69.47861_em_54.94548.pth.tar ...


 93%|▉| 76512/82240 [15:10<01:08, 83.90it/s, avg_loss=2.18, batch_ce=0.9, epoch=10, learning_rate=[0.001], time=F