In [1]:
import gzip
import os
import re
from collections import Counter, defaultdict
from itertools import islice

import numpy as np
import pandas as pd
import torch.nn as nn
import torch.nn.functional as F
import keras
import torch 
import transformers
from tqdm import trange, tqdm
from torch.optim import Adam
from transformers import BertTokenizer, BertModel, BertForQuestionAnswering
from transformers import AdamW
from transformers import get_cosine_with_hard_restarts_schedule_with_warmup
from transformers import get_linear_schedule_with_warmup

from eval_squad import *
from models.QAModels import *
from models.utils import *
from utils import *

Using TensorFlow backend.


GPU not available, CPU used
GPU not available, CPU used


In [2]:
# run on the GPU whenever PyTorch can see one, otherwise fall back to the CPU
# (on a machine without CUDA this resolves to the CPU, exactly as before)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

## Load QA data into memory

In [3]:
# read the train / dev / test splits of SubjQA (all domains) into memory;
# the three loads happen first, in this order, before any conversion
_subjqa_splits = ('train', 'dev', 'test')
_subjqa_frames = [
    get_data(source='/SubjQA/', split='/' + _split, domain='all')
    for _split in _subjqa_splits
]

# turn every pd.DataFrame into a list of dictionaries (one dict per example)
subjqa_data_train, subjqa_data_dev, subjqa_data_test = [
    convert_df_to_dict(_frame, split=_split)
    for _frame, _split in zip(_subjqa_frames, _subjqa_splits)
]

# the temporaries are not needed by any later cell
del _subjqa_splits, _subjqa_frames

In [4]:
# read the SQuAD train split into memory
squad_data_train = get_data(source='/SQuAD/', split='train')

# NOTE: the SQuAD dev set comes without correct answer spans (i.e., start and end positions),
# since predictions need to be submitted; it is therefore kept aside as test data
squad_data_test = get_data(source='/SQuAD/', split='dev')

## Create train and dev QA examples

In [5]:
# checkpoint used to initialise the QA model (a *cased* BERT-large model fine-tuned on SQuAD)
pretrained_weights = 'bert-large-cased-whole-word-masking-finetuned-squad'

# load the tokenizer from the very same checkpoint, so that vocabulary and casing
# can never drift apart from the model weights (this also settles the cased vs. uncased question)
bert_tokenizer = BertTokenizer.from_pretrained(pretrained_weights)

# upper bound on the input length T: BERT cannot deal with sequences where T > 512
max_seq_length = 512

# define mini-batch size
batch_size = 32

# names of all domains and datasets; a name's list position is its integer id
domains = ['books', 'electronics', 'grocery', 'movies', 'restaurants', 'trustyou', 'tripadvisor', 'all', 'wikipedia']
datasets = ['SQuAD', 'SubjQA']

# name -> id lookup, and its inverse id -> name
domain_to_idx = dict(zip(domains, range(len(domains))))
idx_to_domain = {idx: domain for domain, idx in domain_to_idx.items()}

dataset_to_idx = dict(zip(datasets, range(len(datasets))))
idx_to_dataset = {idx: dataset for dataset, idx in dataset_to_idx.items()}

In [6]:
# build QA examples for the SubjQA train and dev splits
# (both splits are created with is_training=True)
subjqa_examples_train = create_examples(subjqa_data_train, source='SubjQA', is_training=True)
subjqa_examples_dev = create_examples(subjqa_data_dev, source='SubjQA', is_training=True)

In [7]:
# build QA examples from the SQuAD train split ...
squad_examples_train = create_examples(squad_data_train, source='SQuAD', is_training=True)

# ... and carve the dev examples out of the train set only
squad_examples_train, squad_examples_dev = split_into_train_and_dev(squad_examples_train)

In [8]:
# keyword arguments shared by the train and dev feature conversion
_subjqa_feature_kwargs = dict(
    max_seq_length=max_seq_length,
    doc_stride=100,
    max_query_length=50,
    is_training=True,
    domain_to_idx=domain_to_idx,
    dataset_to_idx=dataset_to_idx,
)

# convert SubjQA examples into model input features
subjqa_features_train = convert_examples_to_features(subjqa_examples_train, bert_tokenizer, **_subjqa_feature_kwargs)
subjqa_features_dev = convert_examples_to_features(subjqa_examples_dev, bert_tokenizer, **_subjqa_feature_kwargs)

HBox(children=(IntProgress(value=0, max=15246), HTML(value='')))




HBox(children=(IntProgress(value=0, max=1671), HTML(value='')))




In [9]:
# keyword arguments shared by the train and dev feature conversion
_squad_feature_kwargs = dict(
    max_seq_length=max_seq_length,
    doc_stride=100,
    max_query_length=50,
    is_training=True,
    domain_to_idx=domain_to_idx,
    dataset_to_idx=dataset_to_idx,
)

# convert SQuAD examples into model input features
squad_features_train = convert_examples_to_features(squad_examples_train, bert_tokenizer, **_squad_feature_kwargs)
squad_features_dev = convert_examples_to_features(squad_examples_dev, bert_tokenizer, **_squad_feature_kwargs)

HBox(children=(IntProgress(value=0, max=15228), HTML(value='')))




HBox(children=(IntProgress(value=0, max=3807), HTML(value='')))




In [10]:
# wrap the SubjQA features into tensor datasets (both built with evaluate=False)
subjqa_tensor_dataset_train = create_tensor_dataset(subjqa_features_train, evaluate=False)
subjqa_tensor_dataset_dev = create_tensor_dataset(subjqa_features_dev, evaluate=False)

In [11]:
# wrap the SQuAD features into tensor datasets (both built with evaluate=False)
squad_tensor_dataset_train = create_tensor_dataset(squad_features_train, evaluate=False)
squad_tensor_dataset_dev = create_tensor_dataset(squad_features_dev, evaluate=False)

## Create train and dev dataloaders

In [12]:
# SubjQA dataloaders: the train split is batched with split='train', the dev split with split='eval'
subjqa_train_dl = create_batches(dataset=subjqa_tensor_dataset_train, batch_size=batch_size, split='train')
subjqa_dev_dl = create_batches(dataset=subjqa_tensor_dataset_dev, batch_size=batch_size, split='eval')

In [13]:
# SQuAD dataloaders: the train split is batched with split='train', the dev split with split='eval'
squad_train_dl = create_batches(dataset=squad_tensor_dataset_train, batch_size=batch_size, split='train')
squad_dev_dl = create_batches(dataset=squad_tensor_dataset_dev, batch_size=batch_size, split='eval')

In [14]:
# build the QA model on top of the pretrained checkpoint:
# recurrent QA head with highway connection, single-task setting
model = BertForQA.from_pretrained(
    pretrained_weights,
    qa_head_name='RecurrentQAHead',
    max_seq_length=max_seq_length,
    highway_connection=True,
    multitask=False,
)

# move the model to the selected device (being the last expression, the model is also displayed)
model.to(device)

BertForQA(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(28996, 1024, padding_idx=0)
      (position_embeddings): Embedding(512, 1024)
      (token_type_embeddings): Embedding(2, 1024)
      (LayerNorm): LayerNorm((1024,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=1024, out_features=1024, bias=True)
              (key): Linear(in_features=1024, out_features=1024, bias=True)
              (value): Linear(in_features=1024, out_features=1024, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=1024, out_features=1024, bias=True)
              (LayerNorm): LayerNorm((1024,), eps=1e-12, elementwise_af

In [15]:
# hyperparameters read by `train`
args = dict(
    n_epochs=3,
    lr=1e-3,
    warmup_steps=100,
    max_grad_norm=10,
    squad=True,
)

In [None]:
# fine-tune on SubjQA and evaluate on the SubjQA dev split after every epoch
# NOTE: `train` is defined in a later cell of this notebook; execute that cell before this one
train(
    model=model, tokenizer=bert_tokenizer,
    train_dl=subjqa_train_dl, val_dl=subjqa_dev_dl,
    batch_size=batch_size, args=args,
)

In [20]:
def _qa_step(model, tokenizer, loss_func, batch, device):
    """Run one forward pass over a mini-batch and decode predicted and gold answers.

    Args:
        model: QA model returning ``(start_logits, end_logits)`` when called with
            ``input_ids``, ``attention_masks``, ``token_type_ids`` and ``input_lengths``.
        tokenizer: tokenizer handed to ``get_answers`` for decoding.
        loss_func: loss applied separately to the start and the end logits.
        batch: the 12 tensors yielded by the dataloader; only the first six are used.
        device: device the tensors are moved to before the forward pass.

    Returns:
        Tuple ``(loss, pred_answers, true_answers, n_sequences)`` where ``loss`` is the
        mean of start and end loss (still attached to the graph unless the caller
        disabled autograd) and ``n_sequences`` is the first dimension of the input ids.
    """
    # move every tensor of the mini-batch to the target device
    batch = tuple(t.to(device) for t in batch)

    # unpack inputs from dataloader (the last six tensors are not needed here)
    (b_input_ids, b_attn_masks, b_token_type_ids, b_input_lengths,
     b_start_pos, b_end_pos, _, _, _, _, _, _) = batch

    # sort sequences in batch in decreasing order w.r.t. (original) sequence length;
    # ALL six tensors must be rebound to their sorted versions, otherwise the rows of
    # the token type ids no longer line up with the rows of the (sorted) input ids
    (b_input_ids, b_attn_masks, b_token_type_ids,
     b_input_lengths, b_start_pos, b_end_pos) = sort_batch(
        b_input_ids,
        b_attn_masks,
        b_token_type_ids,
        b_input_lengths,
        b_start_pos,
        b_end_pos,
    )

    # compute start and end logits respectively
    start_logits, end_logits = model(
        input_ids=b_input_ids,
        attention_masks=b_attn_masks,
        token_type_ids=b_token_type_ids,
        input_lengths=b_input_lengths,
    )

    # start and end loss are computed separately and then averaged
    start_loss = loss_func(start_logits, b_start_pos)
    end_loss = loss_func(end_logits, b_end_pos)
    loss = (start_loss + end_loss) / 2

    # decoding answers is only needed for the metrics, so keep it out of the autograd graph
    with torch.no_grad():
        start_log_probas = to_cpu(F.log_softmax(start_logits, dim=1), detach=True, to_numpy=False)
        end_log_probas = to_cpu(F.log_softmax(end_logits, dim=1), detach=True, to_numpy=False)

        pred_answers = get_answers(
            tokenizer=tokenizer,
            b_input_ids=b_input_ids,
            start_logs=start_log_probas,
            end_logs=end_log_probas,
            predictions=True,
        )

        true_answers = get_answers(
            tokenizer=tokenizer,
            b_input_ids=b_input_ids,
            start_logs=b_start_pos,
            end_logs=b_end_pos,
            predictions=False,
        )

    return loss, pred_answers, true_answers, b_input_ids.size(0)


def train(
          model,
          tokenizer,
          train_dl,
          val_dl,
          batch_size,
          args,
):
    """Train a QA model and evaluate it on the validation set after every epoch.

    Args:
        model: QA model; batches are moved to the device its parameters live on.
        tokenizer: tokenizer used to decode answers for exact-match / F1.
        train_dl: dataloader over the training set.
        val_dl: dataloader over the validation set.
        batch_size: mini-batch size (kept for interface compatibility; not used here).
        args: dict with the keys ``n_epochs``, ``lr``, ``warmup_steps``,
            ``max_grad_norm`` and ``squad`` (if true, ``freeze_transformer_layers``
            is applied to the model before training).

    Returns:
        Tuple ``(batch_losses, train_losses, train_accs, train_f1s,
        val_losses, val_accs, val_f1s, model)``: per-batch training losses,
        per-epoch train / validation loss, exact-match (in %) and F1, and the model.
    """
    # total number of training steps (i.e., step = iteration)
    t_total = len(train_dl) * args['n_epochs']

    # use the device the model already lives on instead of relying on a notebook global
    device = next(model.parameters()).device

    if args["squad"]:
        model = freeze_transformer_layers(model)
        print("------ Pre-trained BERT model is frozen -------")

    optimizer = AdamW(
                      model.parameters(),
                      lr=args['lr'],
                      correct_bias=False,
    )

    scheduler = get_linear_schedule_with_warmup(
                                                optimizer,
                                                num_warmup_steps=args["warmup_steps"],
                                                num_training_steps=t_total,
    )

    # store loss and accuracy for plotting
    batch_losses = []
    train_losses, train_accs, train_f1s = [], [], []
    val_losses, val_accs, val_f1s = [], [], []

    loss_func = nn.CrossEntropyLoss()

    for epoch in trange(args['n_epochs'],  desc="Epoch"):

        ### Training ###

        model.train()

        tr_loss, correct_answers, batch_f1 = 0, 0, 0
        nb_tr_examples, nb_tr_steps = 0, 0

        for step, batch in enumerate(tqdm(train_dl, desc="Iteration")):

            # zero-out gradients
            optimizer.zero_grad()

            batch_loss, pred_answers, true_answers, n_sequences = _qa_step(
                model, tokenizer, loss_func, batch, device,
            )

            print("Current batch loss: {}".format(batch_loss.item()))
            print()

            batch_losses.append(batch_loss.item())

            correct_answers += compute_exact_batch(true_answers, pred_answers)
            batch_f1 += compute_f1_batch(true_answers, pred_answers)

            # backpropagate error; a single backward pass through the averaged loss
            # propagates the gradients of both the start and the end loss
            batch_loss.backward()

            # clip gradients
            torch.nn.utils.clip_grad_norm_(model.parameters(), args["max_grad_norm"])

            # update model parameters and take a step using the computed gradient
            optimizer.step()
            scheduler.step()

            tr_loss += batch_loss.item()
            nb_tr_examples += n_sequences
            nb_tr_steps += 1

            # running averages over all batches seen so far in this epoch
            # NOTE(review): F1 is not scaled by 100 although it is printed with '%' -
            # confirm what compute_f1_batch returns
            current_batch_f1 = batch_f1 / nb_tr_examples
            current_batch_acc = 100 * (correct_answers / nb_tr_examples)

            print("Current batch exact-match: {} %".format(current_batch_acc))
            print("Current batch F1: {} %".format(current_batch_f1))
            print()

        # max(..., 1) guards against an empty dataloader
        train_f1 = batch_f1 / max(nb_tr_examples, 1)
        train_loss = tr_loss / max(nb_tr_steps, 1)
        train_exact_match = 100 * (correct_answers / max(nb_tr_examples, 1))

        print("---------- EPOCH {} ----------".format(epoch))
        print("----- Train loss: {} -----".format(train_loss))
        print("----- Train exact-match: {} % -----".format(train_exact_match))
        print("----- Train F1: {} % -----".format(train_f1))
        print()

        train_losses.append(train_loss)
        train_accs.append(train_exact_match)
        train_f1s.append(train_f1)

        ### Validation ###

        # set model to eval mode
        model.eval()

        val_loss_sum, correct_answers_val, batch_f1_val = 0, 0, 0
        nb_val_examples, nb_val_steps = 0, 0

        for batch in val_dl:

            # no gradients are needed for evaluation
            with torch.no_grad():
                batch_loss_val, pred_answers, true_answers, n_sequences = _qa_step(
                    model, tokenizer, loss_func, batch, device,
                )

            print("Current val loss: {}".format(batch_loss_val.item()))

            correct_answers_val += compute_exact_batch(true_answers, pred_answers)
            batch_f1_val += compute_f1_batch(true_answers, pred_answers)

            val_loss_sum += batch_loss_val.item()
            nb_val_examples += n_sequences
            nb_val_steps += 1

        # validation metrics must be derived from the validation accumulators only
        val_f1 = batch_f1_val / max(nb_val_examples, 1)
        val_loss = val_loss_sum / max(nb_val_steps, 1)
        val_exact_match = 100 * (correct_answers_val / max(nb_val_examples, 1))

        print("---------- EPOCH {} ----------".format(epoch))
        print("----- Val loss: {} -----".format(val_loss))
        print("----- Val exact-match: {} % -----".format(val_exact_match))
        print("----- Val F1: {} % -----".format(val_f1))
        print()

        val_losses.append(val_loss)
        val_accs.append(val_exact_match)
        val_f1s.append(val_f1)

    return batch_losses, train_losses, train_accs, train_f1s, val_losses, val_accs, val_f1s, model