In [None]:
#####################################################################
# Mandatory : Run this cell and restart the notebook kernel right after
#####################################################################
!pip install torch==1.11 transformers

In [None]:
# Work from /content, the directory the following clone cell writes into.
# NOTE(review): Colab-specific path — skip this cell when running locally.
%cd /content

In [None]:
%%bash
#####################################################################
# Only needed on Google Colab: clone the repository (branch `model`)
# into /content/deepqa, replacing any previous copy.
#####################################################################
# Abort on the first failure so that a failed `cd` can never make the
# `rm -rf` below run in some other working directory.
set -euo pipefail
cd /content
rm -rf deepqa
git clone -b model https://github.com/PaulBeuran/deepqa.git

In [None]:
#####################################################################
# Only use on Googgle Colab, uncomment if necessary
#####################################################################
%cd deepqa/notebooks/

In [None]:
%%bash
# Download the SQuAD v1.1 archive and extract it into ./data
# (train-v1.1.json and dev-v1.1.json are read from there later).
# -f: fail on HTTP errors instead of saving an error page as the zip;
# -L: follow redirects; set -e: stop before unzip if the download failed.
set -euo pipefail
curl -fLO https://data.deepai.org/squad1.1.zip
unzip -o squad1.1.zip -d data

In [1]:
import sys
import os
import pathlib

import json

import pandas as pd
import numpy as np
import torch
import transformers

from tqdm import tqdm

# When a GPU is present, make newly created float tensors default to CUDA.
# NOTE(review): torch.set_default_tensor_type is deprecated in later PyTorch
# releases; it is fine for the pinned torch==1.11.
if torch.cuda.is_available():
    torch.set_default_tensor_type('torch.cuda.FloatTensor')

# The `deepqa` package is resolved from the directory two levels above the
# working directory (e.g. /content when running from /content/deepqa/notebooks).
deepqa_lib_path = str(pathlib.Path(os.getcwd()).parent.parent.absolute())
# Remove earlier copies first, so re-running this cell keeps exactly one
# entry, still at the front of the module search path.
sys.path[:] = [entry for entry in sys.path if entry != deepqa_lib_path]
sys.path.insert(0, deepqa_lib_path)

from deepqa import preprocessing, tokenizer, module, model, wrapper, loss, metrics, utils

2022-08-02 01:13:48.034702: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory
2022-08-02 01:13:48.034772: I tensorflow/stream_executor/cuda/cudart_stub.cc:29] Ignore above cudart dlerror if you do not have a GPU set up on your machine.


In [2]:
def _load_squad_articles(json_path):
    """Read a SQuAD JSON file and return the value under its top-level "data" key."""
    with open(json_path, "rb") as squad_file:
        return json.load(squad_file)["data"]


# Read train and dev data
train_data = _load_squad_articles("data/train-v1.1.json")
dev_data = _load_squad_articles("data/dev-v1.1.json")

In [3]:
# Convert each split's nested SQuAD records into tabular form.
# NOTE(review): `tcqa` presumably stands for (titles, contexts, queries,
# answers) — confirm against preprocessing.tabularize_squad11_data.
train_tcqa, dev_tcqa = (
    preprocessing.tabularize_squad11_data(split_articles)
    for split_articles in (train_data, dev_data)
)

In [4]:
# Skip the first tabular field of each split and wrap the rest in a QADataset.
# NOTE(review): the dropped field is presumably the article titles — confirm.
train_dataset, dev_dataset = [
    utils.QADataset(*split_table[1:])
    for split_table in (train_tcqa, dev_tcqa)
]

In [5]:
# Hugging Face checkpoint for the tokenizer; it should match the checkpoint
# the word encoder is loaded from, so token ids index the same vocabulary.
word_tokenizer_encoder_path = "bert-base-uncased"

hf_auto_tokenizer = tokenizer.HFAutoTokenizer(
    word_tokenizer_encoder_path,
    True,  # NOTE(review): meaning of this positional flag is not visible here — confirm
)

In [6]:
# Tokenize the training split. The trailing 256, 64, 32 are passed
# positionally — NOTE(review): presumably max context / query / answer token
# lengths; confirm against HFAutoTokenizer.tokenize_qa_data.
# Split-prefixed names keep the dev cell from silently overwriting them.
train_contexts_tokens, train_queries_tokens, train_answers_tokens_range = (
    hf_auto_tokenizer.tokenize_qa_data(
        train_dataset.contexts,
        train_dataset.queries,
        train_dataset.answers,
        256, 64, 32
    )
)

train_dict_dataset = utils.QADictDataset(train_contexts_tokens,
                                         train_queries_tokens,
                                         train_answers_tokens_range)

# NOTE(review): DataLoader defaults are batch_size=1 and shuffle=False, i.e.
# one example per step in file order (87599 steps per epoch in the run output).
# Consider batching/shuffling; with the CUDA default tensor type, shuffling
# may additionally need a CUDA `generator` — verify before enabling.
train_loader = torch.utils.data.DataLoader(train_dict_dataset)

In [7]:
# Tokenize the dev split with the same positional limits as the training
# split (NOTE(review): presumably max context / query / answer lengths).
# Split-prefixed names avoid clobbering the training tokens' variables.
dev_contexts_tokens, dev_queries_tokens, dev_answers_tokens_range = (
    hf_auto_tokenizer.tokenize_qa_data(
        dev_dataset.contexts,
        dev_dataset.queries,
        dev_dataset.answers,
        256, 64, 32
    )
)

dev_dict_dataset = utils.QADictDataset(dev_contexts_tokens,
                                       dev_queries_tokens,
                                       dev_answers_tokens_range)

# Default DataLoader settings: batch_size=1, no shuffling (fine for evaluation).
dev_loader = torch.utils.data.DataLoader(dev_dict_dataset)

In [19]:
# Seed torch's RNG before building the model so any random parameter
# initialisation below is reproducible across runs.
SEED = 42
torch.manual_seed(SEED)

char_vocab_len = 1372  # NOTE(review): presumably the character vocabulary size — confirm
char_embedding_size = 16
# outs_channels = [32, 64, ..., 192] paired with kernel_sizes = [1, ..., 6].
char_cnn = module.CharCNN(char_embedding_size,
                          outs_channels=32 * np.arange(1, 7),
                          kernel_sizes=np.arange(1, 7))

# Reuse the tokenizer's checkpoint name so the encoder and tokenizer cannot
# drift apart (their vocabularies must agree).
hf_word_encoder = module.HFAutoWordEncoder(word_tokenizer_encoder_path,
                                           char_vocab_len, char_embedding_size, char_cnn)
contextual_embedding_size = 512
# NOTE(review): the trailing 0.5 is presumably a dropout rate — confirm.
bidaf = model.BiDAF(hf_word_encoder, contextual_embedding_size, 0.5)

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.transform.dense.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.decoder.weight', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.bias', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [20]:
# Train BiDAF with loss.bi_cross_entropy, passing dev_loader as the
# validation iterator.
# NOTE(review): the saved output shows every batch dict being printed,
# presumably a leftover debug print inside the deepqa training loop — confirm.
n_training_epochs = 16

bidaf_trainer = wrapper.QATrainWrapper(bidaf)
bidaf_trainer.train(
    train_loader,
    epochs=n_training_epochs,
    loss=loss.bi_cross_entropy,
    val_data_iter=dev_loader,
)

  0%|          | 0/87599 [00:00<?, ?it/s]

{'input_ids': tensor([[  101,  6549,  2135,  1010,  1996,  2082,  2038,  1037,  3234,  2839,
          1012, 10234,  1996,  2364,  2311,  1005,  1055,  2751,  8514,  2003,
          1037,  3585,  6231,  1997,  1996,  6261,  2984,  1012,  3202,  1999,
          2392,  1997,  1996,  2364,  2311,  1998,  5307,  2009,  1010,  2003,
          1037,  6967,  6231,  1997,  4828,  2007,  2608,  2039, 14995,  6924,
          2007,  1996,  5722,  1000,  2310,  3490,  2618,  4748,  2033, 18168,
          5267,  1000,  1012,  2279,  2000,  1996,  2364,  2311,  2003,  1996,
         13546,  1997,  1996,  6730,  2540,  1012,  3202,  2369,  1996, 13546,
          2003,  1996, 24665, 23052,  1010,  1037, 14042,  2173,  1997,  7083,
          1998,  9185,  1012,  2009,  2003,  1037, 15059,  1997,  1996, 24665,
         23052,  2012, 10223, 26371,  1010,  2605,  2073,  1996,  6261,  2984,
         22353,  2135,  2596,  2000,  3002, 16595,  9648,  4674,  2061, 12083,
          9711,  2271,  1999,  8517,  

[1/16 - Train] Loss: 11.041, Overlap F1: 0.0:   0%|          | 1/87599 [00:03<86:36:50,  3.56s/it]

{'input_ids': tensor([[  101,  6549,  2135,  1010,  1996,  2082,  2038,  1037,  3234,  2839,
          1012, 10234,  1996,  2364,  2311,  1005,  1055,  2751,  8514,  2003,
          1037,  3585,  6231,  1997,  1996,  6261,  2984,  1012,  3202,  1999,
          2392,  1997,  1996,  2364,  2311,  1998,  5307,  2009,  1010,  2003,
          1037,  6967,  6231,  1997,  4828,  2007,  2608,  2039, 14995,  6924,
          2007,  1996,  5722,  1000,  2310,  3490,  2618,  4748,  2033, 18168,
          5267,  1000,  1012,  2279,  2000,  1996,  2364,  2311,  2003,  1996,
         13546,  1997,  1996,  6730,  2540,  1012,  3202,  2369,  1996, 13546,
          2003,  1996, 24665, 23052,  1010,  1037, 14042,  2173,  1997,  7083,
          1998,  9185,  1012,  2009,  2003,  1037, 15059,  1997,  1996, 24665,
         23052,  2012, 10223, 26371,  1010,  2605,  2073,  1996,  6261,  2984,
         22353,  2135,  2596,  2000,  3002, 16595,  9648,  4674,  2061, 12083,
          9711,  2271,  1999,  8517,  

[1/16 - Train] Loss: 11.19, Overlap F1: 0.0:   0%|          | 2/87599 [00:06<80:19:54,  3.30s/it] 

{'input_ids': tensor([[  101,  6549,  2135,  1010,  1996,  2082,  2038,  1037,  3234,  2839,
          1012, 10234,  1996,  2364,  2311,  1005,  1055,  2751,  8514,  2003,
          1037,  3585,  6231,  1997,  1996,  6261,  2984,  1012,  3202,  1999,
          2392,  1997,  1996,  2364,  2311,  1998,  5307,  2009,  1010,  2003,
          1037,  6967,  6231,  1997,  4828,  2007,  2608,  2039, 14995,  6924,
          2007,  1996,  5722,  1000,  2310,  3490,  2618,  4748,  2033, 18168,
          5267,  1000,  1012,  2279,  2000,  1996,  2364,  2311,  2003,  1996,
         13546,  1997,  1996,  6730,  2540,  1012,  3202,  2369,  1996, 13546,
          2003,  1996, 24665, 23052,  1010,  1037, 14042,  2173,  1997,  7083,
          1998,  9185,  1012,  2009,  2003,  1037, 15059,  1997,  1996, 24665,
         23052,  2012, 10223, 26371,  1010,  2605,  2073,  1996,  6261,  2984,
         22353,  2135,  2596,  2000,  3002, 16595,  9648,  4674,  2061, 12083,
          9711,  2271,  1999,  8517,  

[1/16 - Train] Loss: 11.19, Overlap F1: 0.0:   0%|          | 2/87599 [00:09<116:12:30,  4.78s/it]


KeyboardInterrupt: 