In [1]:
pip install transformers==3.0.2

Note: you may need to restart the kernel to use updated packages.


In [2]:
import sys
sys.path.append("../")

import warnings
warnings.filterwarnings("ignore")

import os.path
import numpy as np
import pandas as pd
import re

from newsqa import NewsQaExample, NewsQaModel, create_dataset, getprediction
import utils

from transformers import BertTokenizer, BertForQuestionAnswering
from transformers import DistilBertTokenizer, DistilBertForQuestionAnswering

import torch
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
torch.manual_seed(0)

<torch._C.Generator at 0x7fe26a574890>

In [3]:
# Loading the data
NEWS_STORIES = utils.open_pickle('../data/news_stories.pkl')
data = pd.read_csv('../data/newsqa-dataset-cleaned.csv')
total_examples = len(data)

In [7]:
NEWS_STORIES[row['42d01e187213e86f5fe617fe32e716ff7fa3afc4']]

NameError: name 'row' is not defined

In [5]:
data

Unnamed: 0,story_id,question,answer_char_ranges,validated_answers,start_idx,end_idx
0,42d01e187213e86f5fe617fe32e716ff7fa3afc4,What was the amount of children murdered?,294:297|None|None,"{""none"": 1, ""294:297"": 2}",293,300
1,c48228a52f26aca65c31fad273e66164f047f292,Where was one employee killed?,34:60|1610:1618|34:60,,34,59
2,c65ed85800e4535f4bbbfa2c34d7d9630358d303,who did say South Africa did not issue a visa ...,103:127|114:127|839:853,"{""839:853"": 1, ""103:127"": 2}",103,126
3,0cf66b646e9b32076513c050edf32a799200c3c2,How many years old was the businessman?,538:550|538:550,,530,549
4,13012604e3203c18df09289dfedd14cde67cf40b,What frightened the families?,690:742|688:791|630:646,"{""688:791"": 2, ""690:742"": 1}",682,782
...,...,...,...,...,...,...
87805,5e7c990b12d43b077d476413a16c05fad2398c35,what does Soufan's book argue against?,2682:2806|2700:2806|2709:2840,"{""2709:2840"": 2}",2698,2791
87806,4424c8580952975a3e367176a215c78711246bdd,is toyota under fire issues on sticking gas pe...,None|None,,-1,-1
87807,7b2b414d8cbc968f4df05bcefb2f9f0fd3052083,what are the men being detained for,2386:2435|1146:1167|None,"{""2386:2435"": 2}",2408,2434
87808,4566e90ca5e65f0323c41319030ca4349357cd67,In what year didIvory Coast exit in group stag...,None|1260:1265|1260:1265,,1257,1265


In [8]:
def get_examples():
    '''
    Return a list of NewsQaExample objects
    '''
    # If a pickle file exists for examples, read the file
    # if os.path.isfile('../data/examples.pkl'):
        # return utils.open_pickle('../data/examples.pkl')
    
    examples = []

    for idx, row in data.iterrows():
        ex = NewsQaExample(NEWS_STORIES[row['story_id']], row['question'], row['start_idx'], row['end_idx'])
        examples.append(ex)
        print(ex)
        utils.drawProgressBar(idx + 1, total_examples)
        break
    print('\n')
    # Saving examples to a pickle file
    # utils.save_pickle('../data/examples.pkl', examples)
    
    return examples

# def get_examples():
#     '''
#     Return a list of NewsQaExample objects
#     '''
#     # If a pickle file exists for examples, read the file
#     if os.path.isfile('../data/examples.pkl'):
#         return utils.open_pickle('../data/examples.pkl')
    
#     examples = []

#     for idx, row in data.iterrows():
#         ex = NewsQaExample(NEWS_STORIES[row['story_id']], row['question'], row['start_idx'], row['end_idx'])
#         examples.append(ex)
#         # print(ex)
#         utils.drawProgressBar(idx + 1, total_examples)
#         # break
#     print('\n')
#     # Saving examples to a pickle file
#     utils.save_pickle('../data/examples.pkl', examples)
    
#     return examples

examples = get_examples()

text: NEW DELHI, India (CNN) -- A high court in northern India on Friday acquitted a wealthy businessman facing the death sentence for the killing of a teen in a case dubbed "the house of horrors."

Moninder Singh Pandher was sentenced to death by a lower court in February.

The teen was one of 19 victims -- children and young women -- in one of the most gruesome serial killings in India in recent years.

The Allahabad high court has acquitted Moninder Singh Pandher, his lawyer Sikandar B. Kochar told CNN.

Pandher and his domestic employee Surinder Koli were sentenced to death in February by a lower court for the rape and murder of the 14-year-old.

The high court upheld Koli's death sentence, Kochar said.

The two were arrested two years ago after body parts packed in plastic bags were found near their home in Noida, a New Delhi suburb. Their home was later dubbed a "house of horrors" by the Indian media.

Pandher was not named a main suspect by investigators initially, but was summo

In [15]:
examples

[text: NEW DELHI, India (CNN) -- A high court in northern India on Friday acquitted a wealthy businessman facing the death sentence for the killing of a teen in a case dubbed "the house of horrors."
 
 Moninder Singh Pandher was sentenced to death by a lower court in February.
 
 The teen was one of 19 victims -- children and young women -- in one of the most gruesome serial killings in India in recent years.
 
 The Allahabad high court has acquitted Moninder Singh Pandher, his lawyer Sikandar B. Kochar told CNN.
 
 Pandher and his domestic employee Surinder Koli were sentenced to death in February by a lower court for the rape and murder of the 14-year-old.
 
 The high court upheld Koli's death sentence, Kochar said.
 
 The two were arrested two years ago after body parts packed in plastic bags were found near their home in Noida, a New Delhi suburb. Their home was later dubbed a "house of horrors" by the Indian media.
 
 Pandher was not named a main suspect by investigators initially

In [1]:
def get_datasets(examples, tokenizer_name):
    
    model_name = tokenizer_name.split('-')[0]
    
    if os.path.isfile('../data/dataset_' + model_name + '.pkl'):
        return utils.open_pickle('../data/dataset_' + model_name + '.pkl')
    
    features = []
    labels = []
    
    if tokenizer_name == 'bert-large-uncased-whole-word-masking-finetuned-squad':
        tokenizer = BertTokenizer.from_pretrained(tokenizer_name)
    
    if tokenizer_name == 'distilbert-base-uncased-distilled-squad':
        tokenizer = DistilBertTokenizer.from_pretrained(tokenizer_name)
    
    print("Getting input features:")
    for idx, ex in enumerate(examples):
        input_features = ex.encode_plus(tokenizer, pad = True)
        features.append(input_features)
        labels.append(ex.get_label())
        utils.drawProgressBar(idx + 1, total_examples)
    
    # Getting TensorDataset
    train_set, val_set, test_set, feature_idx_map = create_dataset(features, labels, model = model_name)
    # Saving the dataset in a file
    utils.save_pickle('../data/dataset_' + model_name + '.pkl', (train_set, val_set, test_set, feature_idx_map))
    
    return (train_set, val_set, test_set, feature_idx_map)

In [2]:
def get_dataloaders(train_set, val_set, test_set, batch_size):
   
    train_loader = DataLoader(train_set, batch_size = BATCH_SIZE, 
                          sampler = RandomSampler(train_set))

    val_loader = DataLoader(val_set, batch_size = BATCH_SIZE, 
                            sampler = SequentialSampler(val_set))

    test_loader = DataLoader(test_set, batch_size = BATCH_SIZE, 
                             sampler = SequentialSampler(test_set))
    
    return train_loader, val_loader, test_loader

In [3]:
def finetune_model(model_name, train_loader, val_loader, feature_idx_map, device, 
                   epochs = 1, learning_rate = 1e-5):
    
    if model_name == 'bert-large-uncased-whole-word-masking-finetuned-squad':
        model = BertForQuestionAnswering.from_pretrained(model_name)
        # Freezing bert parameters
        for param in model.bert.parameters():
            param.requires_grad = False
    
    if model_name == 'distilbert-base-uncased-distilled-squad':
        model = DistilBertForQuestionAnswering.from_pretrained(model_name)
        # Freezing distilbert parameters
        for param in model.distilbert.parameters():
            param.requires_grad = False
        
    short_name = model_name.split('-')[0]
    
    newsqa_model = NewsQaModel(model)
    newsqa_model.train(train_loader, val_loader, feature_idx_map, device, 
                       num_epochs = epochs, lr = learning_rate, 
                       filename = '../data/' + short_name + '_model.pt')
    
    return newsqa_model

In [7]:
# Get a list of NewsQaExample objects
examples = get_examples()

In [8]:
# Defining model name
bert_model_name = 'bert-large-uncased-whole-word-masking-finetuned-squad'

In [9]:
# Getting the training, validation and test sets
bert_datasets = get_datasets(examples, bert_model_name)
bert_train_set, bert_val_set, bert_test_set, bert_feature_idx_map = bert_datasets

In [10]:
# Getting data loaders
BATCH_SIZE = 32

bert_loaders = get_dataloaders(bert_train_set, bert_val_set, bert_test_set, batch_size = BATCH_SIZE)
bert_train_loader, bert_val_loader, bert_test_loader = bert_loaders

In [12]:
EPOCHS = 5
LEARNING_RATE = 0.001

bert_model = finetune_model(bert_model_name, bert_train_loader, bert_val_loader, bert_feature_idx_map,device, epochs = EPOCHS, learning_rate = LEARNING_RATE)

Epoch 1/5:
Validation accuracy increased from 0.0000 to 0.6174, saving to models/bert.pt



Epoch 2/5:
Validation accuracy increased from 0.6174 to 0.6542, saving to models/bert.pt



Epoch 3/5:
Validation accuracy increased from 0.6542 to 0.6641, saving to models/bert.pt



Epoch 4/5:
Validation accuracy increased from 0.6641 to 0.6643, saving to models/bert.pt



Epoch 5/5:
Validation accuracy increased from 0.6643 to 0.6673, saving to models/bert.pt


In [14]:
# Evaluation the performance on test set
bert_model.load('../data/bert_model.pt')
bert_eval_metrics = bert_model.evaluate(bert_test_loader, bert_feature_idx_map, device)

loss: 1.3887	f1:0.5313	acc:0.6750
