<a href="https://colab.research.google.com/github/JasonLUrquhart/Applied-Data-Science-Capstone/blob/master/AISC_Workshop1.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [0]:
# This notebook is part of the workshop "modern natural language processing" run 
# by Aggregate Intellect Inc. (https://ai.science), and is released under the
# "Creative Commons Attribution-NonCommercial-ShareAlike (CC BY-NC-SA)" license.
# This material can be altered and distributed for non-commercial use with 
# reference to Aggregate Intellect Inc. as the original owner, and any material 
# generated from it must be released under similar terms. 
# (https://creativecommons.org/licenses/by-nc-sa/4.0/)

# Huggingface's BERT

Huggingface, a company focused on social chatbots, provides PyTorch implementations of several cutting-edge models!

We'll be using [their implementation of BERT](https://github.com/huggingface/pytorch-pretrained-BERT) for this workshop.




In [0]:
!git clone https://github.com/huggingface/pytorch-pretrained-BERT.git
%cd pytorch-pretrained-BERT
!pip install .
%cd ..

In [0]:
import json
import logging
import math
import collections
from io import open
import pickle
import pprint

import numpy as np
import pandas as pd
import torch
from torch.utils.data import (DataLoader, RandomSampler, SequentialSampler,
                              TensorDataset)

from tqdm import tqdm, trange, tqdm_notebook
from time import sleep

from pytorch_pretrained_bert import BertForQuestionAnswering
from pytorch_pretrained_bert.tokenization import BasicTokenizer, whitespace_tokenize, BertTokenizer
from pytorch_pretrained_bert.file_utils import WEIGHTS_NAME, CONFIG_NAME
from pytorch_pretrained_bert.optimization import BertAdam, WarmupLinearSchedule

### Import some functions from Huggingface's GitHub

The source we are loading as a module provides classes to represent training examples and model features, as well as data preprocessing functions.


In [0]:
import types
import urllib.request

# URL of the raw .py file used to read the SQuAD dataset and convert it
# to feature vectors
source = ('https://raw.githubusercontent.com/huggingface/pytorch-pretrained-BERT/'
          'master/examples/run_squad_dataset_utils.py')
modulesource = urllib.request.urlopen(source).read()

def makemodule(modulesource, sourcestr, modname=None):
    """Compile downloaded source code and execute it inside a fresh module object."""
    if not modname:
        modname = 'newmodulename'
    codeobj = compile(modulesource, sourcestr, 'exec')
    # Create an empty module to hold the executed definitions
    newmodule = types.ModuleType(modname)
    exec(codeobj, newmodule.__dict__)
    return newmodule

data_functs = makemodule(modulesource, source)

### Create a tokenizer to process raw text


In [0]:
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased", do_lower_case=True)

# The SQuAD Dataset

Stanford Question Answering Dataset (SQuAD) is a reading comprehension dataset, consisting of questions posed by crowdworkers on a set of Wikipedia articles, where the answer to every question is a segment of text, or span, from the corresponding reading passage.

You can find more information about the dataset and approaches to reading comprehension [here](https://rajpurkar.github.io/SQuAD-explorer/).

### Download train/test examples

In [0]:
!wget 'https://drive.google.com/uc?id=1OBjbuy9lIxcKrnqAq6n-m8S7T_GBw2h0' -O 'train-v1.1.json'
!wget 'https://drive.google.com/uc?id=1TtsI3_Jm2OQuJ_nOi2_KoBP9RdcRAs9r' -O 'eval-v1.1.json'

In [0]:
filename = 'train-v1.1.json'
train_examples = data_functs.read_squad_examples(
                   filename, 
                   True, 
                   False)

eval_filename = 'eval-v1.1.json'
eval_examples = data_functs.read_squad_examples(
                  eval_filename,
                  False,
                  False)

### Some notes on tokenizing

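BERT uses WordPiece tokenization, which splits words missing from the vocabulary into smaller known pieces (continuation pieces are prefixed with `##`), so a broad range of rare words can still be represented.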

In [0]:
print(train_examples[8635].question_text)
print(tokenizer.tokenize(train_examples[8635].question_text))
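
To make the idea concrete, here is a minimal sketch of greedy longest-match-first subword splitting, the general strategy behind WordPiece. The function name `wordpiece_tokenize` and the tiny `toy_vocab` are hypothetical and only for illustration; the real `BertTokenizer` also normalizes text and uses a vocabulary of roughly 30,000 entries.

```python
def wordpiece_tokenize(word, vocab, unk_token="[UNK]"):
    """Split one lowercase word into subword pieces by greedy longest match."""
    pieces = []
    start = 0
    while start < len(word):
        end = len(word)
        current = None
        # Try the longest remaining substring first, then shrink from the right.
        while start < end:
            candidate = word[start:end]
            if start > 0:
                candidate = "##" + candidate  # marks a word-internal piece
            if candidate in vocab:
                current = candidate
                break
            end -= 1
        if current is None:
            return [unk_token]  # no piece matches, so the whole word is unknown
        pieces.append(current)
        start = end
    return pieces


# Hypothetical toy vocabulary, far smaller than BERT's real one.
toy_vocab = {"un", "##afford", "##able", "play", "##ing", "the"}

print(wordpiece_tokenize("unaffordable", toy_vocab))  # ['un', '##afford', '##able']
print(wordpiece_tokenize("playing", toy_vocab))       # ['play', '##ing']
print(wordpiece_tokenize("xyz", toy_vocab))           # ['[UNK]']
```

A word is only mapped to `[UNK]` when no sequence of vocabulary pieces covers it, which is why rare words usually survive as several pieces.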

## Exploring the SQuAD dataset

We've defined a type to represent our training examples.

It contains:

1.   the question and answer text
2.   the tokenized document we need to comprehend
3.   positions of the answer in the document.
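
SQuAD stores each answer as a character offset into the passage, while the example type stores word positions. The sketch below shows one way to project a character span onto whitespace-separated words; it follows the same idea as `read_squad_examples`, but the function name `char_span_to_word_span` and the sample sentence usage are illustrative assumptions.

```python
def char_span_to_word_span(context, answer_text, answer_start):
    """Map a character-based answer span to whitespace-token positions."""
    doc_tokens = []
    char_to_word_offset = []
    prev_is_whitespace = True
    for c in context:
        if c.isspace():
            prev_is_whitespace = True
        else:
            if prev_is_whitespace:
                doc_tokens.append(c)   # start a new word
            else:
                doc_tokens[-1] += c    # extend the current word
            prev_is_whitespace = False
        # Record which word each character belongs to.
        char_to_word_offset.append(len(doc_tokens) - 1)
    start = char_to_word_offset[answer_start]
    end = char_to_word_offset[answer_start + len(answer_text) - 1]
    return doc_tokens, start, end


context = "The leader was John Smith (1895-1943)."
answer = "John Smith"
tokens, start, end = char_span_to_word_span(context, answer, context.index(answer))
print((start, end))           # (3, 4)
print(tokens[start:end + 1])  # ['John', 'Smith']
```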



In [0]:
# Anatomy of a SQuAD example
ex = train_examples[0]
print(f"Document: {ex.doc_tokens}")
print(f"Question: {ex.question_text}")
print(f"Answer: {ex.orig_answer_text}")
print(f"Answer token position range: {(ex.start_position, ex.end_position)}")

In [0]:
import matplotlib.pyplot as plt
plt.hist([len(ex.doc_tokens) for ex in train_examples])
plt.title("Length of SQuAD Training Documents (Tokenized)")
plt.show()


In [0]:
plt.hist([len(tokenizer.tokenize(ex.question_text)) for ex in train_examples])
plt.title("Length of SQuAD Training Questions (Tokenized)")
plt.show()


In [0]:
max_seq_length=324
doc_stride=128
max_query_length=64
predict_batch_size=128
local_rank = -1
n_best_size=20
max_answer_length=30
is_training=False

# Making Predictions

## Importing our Fine-Tuned BERT for SQuAD

In [0]:
model = BertForQuestionAnswering.from_pretrained('bert-large-uncased-whole-word-masking-finetuned-squad')

## Set up the model to run on our GPU and look at the architecture





In [0]:
n_gpu = torch.cuda.device_count()
# Fall back to the CPU when no GPU is available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

## Converting text to input features

We need to transform our tokens into numeric features for our network to consume!

Tokens are mapped to integers, which are mapped to embeddings in the first layer of our network.
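
As a rough illustration of that two-step mapping, the sketch below uses plain Python lists. The vocabulary, its ids, and the 4-dimensional random embedding table are hypothetical stand-ins; in BERT the ids come from the pretrained vocabulary file and the embedding table is a learned layer.

```python
import random

# Hypothetical toy vocabulary; the ids do not match BERT's real vocabulary.
vocab = {"[PAD]": 0, "[UNK]": 1, "[CLS]": 2, "[SEP]": 3,
         "who": 4, "wrote": 5, "hamlet": 6}


def tokens_to_ids(tokens, vocab):
    """Look up each token's integer id, falling back to [UNK]."""
    return [vocab.get(tok, vocab["[UNK]"]) for tok in tokens]


# One row of random numbers per vocabulary entry stands in for learned weights.
random.seed(0)
embedding_dim = 4
embedding_table = [[random.uniform(-1, 1) for _ in range(embedding_dim)]
                   for _ in vocab]

tokens = ["[CLS]", "who", "wrote", "hamlet", "[SEP]"]
input_ids = tokens_to_ids(tokens, vocab)
# The "embedding layer" is simply a row lookup indexed by token id.
embeddings = [embedding_table[i] for i in input_ids]

print(input_ids)                            # [2, 4, 5, 6, 3]
print(len(embeddings), len(embeddings[0]))  # 5 4
```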

In [0]:
help(data_functs.convert_examples_to_features)

In [0]:
eval_features = data_functs.convert_examples_to_features(
            examples=eval_examples,
            tokenizer=tokenizer,
            max_seq_length=max_seq_length,
            doc_stride=doc_stride,
            max_query_length=max_query_length,
            is_training=is_training)

In [0]:
eval_features = eval_features[:1000]

In [0]:
print("***** Running predictions *****")
print(f"Num orig examples = {len(eval_examples)}")
print(f"Num split examples = {len(eval_features)}")
print(f"Batch size = {predict_batch_size}")

all_input_ids = torch.tensor([f.input_ids for f in eval_features], dtype=torch.long)
all_input_mask = torch.tensor([f.input_mask for f in eval_features], dtype=torch.long)
all_segment_ids = torch.tensor([f.segment_ids for f in eval_features], dtype=torch.long)
all_example_index = torch.arange(all_input_ids.size(0), dtype=torch.long)


eval_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_example_index)

# Run prediction for full data
eval_sampler = SequentialSampler(eval_data)
eval_dataloader = DataLoader(eval_data, sampler=eval_sampler, batch_size=predict_batch_size)

model.eval()
all_results = []
for input_ids, input_mask, segment_ids, example_indices in \
      tqdm_notebook(eval_dataloader, desc="Evaluating", disable=local_rank not in [-1, 0]):
    sleep(0.0001) # sleep to help tqdm progress bar behave

    input_ids = input_ids.to(device)
    input_mask = input_mask.to(device)
    segment_ids = segment_ids.to(device)
    with torch.no_grad():
        batch_start_logits, batch_end_logits = model(input_ids, segment_ids, input_mask)
        for i, example_index in enumerate(example_indices):
            start_logits = batch_start_logits[i].detach().tolist()
            end_logits = batch_end_logits[i].detach().tolist()
            eval_feature = eval_features[example_index.item()]
            unique_id = int(eval_feature.unique_id)
            all_results.append(data_functs.RawResult(
                                   unique_id=unique_id,
                                   start_logits=start_logits,
                                   end_logits=end_logits)
                              )

## Raw Results

Our model produces two sets of logits per input: one scores each token as the start of the answer, the other scores each token as the end.

In [0]:
pprint.pprint(all_results[0].__repr__())

In [0]:
pprint.pprint(np.argmax(all_results[0].start_logits))
pprint.pprint(np.argmax(all_results[0].end_logits))
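
Taking the two argmax values independently can yield an end position that comes before the start. A simplified way to avoid this, sketched below, is to score every valid (start, end) pair jointly and keep the best one no longer than `max_answer_length`. The function `best_span` and the sample logits are hypothetical; the library's `write_predictions` performs a more elaborate search and also maps tokens back to the original text.

```python
def best_span(start_logits, end_logits, max_answer_length):
    """Return (start, end, score) maximizing start_logit + end_logit."""
    best = (None, None, float("-inf"))
    for s, s_logit in enumerate(start_logits):
        # Only consider ends at or after the start, within the length limit.
        last_end = min(s + max_answer_length, len(end_logits))
        for e in range(s, last_end):
            score = s_logit + end_logits[e]
            if score > best[2]:
                best = (s, e, score)
    return best


# Made-up logits: the independent argmaxes are start=3 and end=2 (invalid).
start_logits = [0.1, 2.0, 0.3, 5.0, 0.2]
end_logits = [0.0, 0.5, 4.0, 1.0, 3.0]

print(best_span(start_logits, end_logits, max_answer_length=3))  # (3, 4, 8.0)
```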

## Transform our output logits into responses



In [0]:
help(data_functs.write_predictions)

In [0]:
data_functs.write_predictions(
    eval_examples, 
    eval_features, 
    all_results,
    n_best_size, 
    max_answer_length,
    True,  # do_lower_case: matches the uncased tokenizer created earlier
    "predictions.json",
    "n_best.json", 
    "null_log_odds", 
    False,
    False, 
    0.0
)

In [0]:
print(eval_examples[0].__repr__())

## Create a map of question IDs to associate results 

In [0]:
questions = {}
for i in eval_examples:
  if i.qas_id not in questions:
    questions[i.qas_id] = i

## Explore our results

In [0]:
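with open('predictions.json', 'r') as reader:
    data = json.load(reader)

for qas_id, predicted_answer in data.items():
    question = questions[qas_id]
    print(f"Question: {question.question_text}\n"
          f"Document: {' '.join(question.doc_tokens)}\n"
          f"Predicted Answer: {predicted_answer}\n")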

## Understanding how we extract features from text

In [0]:
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Load SQuAD dataset. """

from __future__ import absolute_import, division, print_function

import json
import logging
import math
import collections
from io import open

from pytorch_pretrained_bert.tokenization import BasicTokenizer, whitespace_tokenize

logger = logging.getLogger(__name__)


class SquadExample(object):
    """
    A single training/test example for the Squad dataset.
    For examples without an answer, the start and end position are -1.
    """

    def __init__(self,
                 qas_id,
                 question_text,
                 doc_tokens,
                 orig_answer_text=None,
                 start_position=None,
                 end_position=None,
                 is_impossible=None):
        self.qas_id = qas_id
        self.question_text = question_text
        self.doc_tokens = doc_tokens
        self.orig_answer_text = orig_answer_text
        self.start_position = start_position
        self.end_position = end_position
        self.is_impossible = is_impossible

    def __str__(self):
        return self.__repr__()

    def __repr__(self):
        s = ""
        s += "qas_id: %s" % (self.qas_id)
        s += ", question_text: %s" % (
            self.question_text)
        s += ", doc_tokens: [%s]" % (" ".join(self.doc_tokens))
        if self.start_position is not None:
            s += ", start_position: %d" % (self.start_position)
        if self.end_position is not None:
            s += ", end_position: %d" % (self.end_position)
        if self.is_impossible:
            s += ", is_impossible: %r" % (self.is_impossible)
        return s


class InputFeatures(object):
    """A single set of features of data."""

    def __init__(self,
                 unique_id,
                 example_index,
                 doc_span_index,
                 tokens,
                 token_to_orig_map,
                 token_is_max_context,
                 input_ids,
                 input_mask,
                 segment_ids,
                 start_position=None,
                 end_position=None,
                 is_impossible=None):
        self.unique_id = unique_id
        self.example_index = example_index
        self.doc_span_index = doc_span_index
        self.tokens = tokens
        self.token_to_orig_map = token_to_orig_map
        self.token_is_max_context = token_is_max_context
        self.input_ids = input_ids
        self.input_mask = input_mask
        self.segment_ids = segment_ids
        self.start_position = start_position
        self.end_position = end_position
        self.is_impossible = is_impossible


def read_squad_examples(input_file, is_training, version_2_with_negative):
    """Read a SQuAD json file into a list of SquadExample."""
    with open(input_file, "r", encoding='utf-8') as reader:
        input_data = json.load(reader)["data"]

    def is_whitespace(c):
        if c == " " or c == "\t" or c == "\r" or c == "\n" or ord(c) == 0x202F:
            return True
        return False

    examples = []
    for entry in input_data:
        for paragraph in entry["paragraphs"]:
            paragraph_text = paragraph["context"]
            doc_tokens = []
            char_to_word_offset = []
            prev_is_whitespace = True
            for c in paragraph_text:
                if is_whitespace(c):
                    prev_is_whitespace = True
                else:
                    if prev_is_whitespace:
                        doc_tokens.append(c)
                    else:
                        doc_tokens[-1] += c
                    prev_is_whitespace = False
                char_to_word_offset.append(len(doc_tokens) - 1)

            for qa in paragraph["qas"]:
                qas_id = qa["id"]
                question_text = qa["question"]
                start_position = None
                end_position = None
                orig_answer_text = None
                is_impossible = False
                if is_training:
                    if version_2_with_negative:
                        is_impossible = qa["is_impossible"]
                    if (len(qa["answers"]) != 1) and (not is_impossible):
                        raise ValueError(
                            "For training, each question should have exactly 1 answer.")
                    if not is_impossible:
                        answer = qa["answers"][0]
                        orig_answer_text = answer["text"]
                        answer_offset = answer["answer_start"]
                        answer_length = len(orig_answer_text)
                        start_position = char_to_word_offset[answer_offset]
                        end_position = char_to_word_offset[answer_offset + answer_length - 1]
                        
                        # Only add answers where the text can be exactly recovered from the
                        # document. If this CAN'T happen it's likely due to weird Unicode
                        # stuff so we will just skip the example.
                        #
                        # Note that this means for training mode, every example is NOT
                        # guaranteed to be preserved.
                        actual_text = " ".join(doc_tokens[start_position:(end_position + 1)])
                        cleaned_answer_text = " ".join(
                            whitespace_tokenize(orig_answer_text))
                        if actual_text.find(cleaned_answer_text) == -1:
                            logger.warning("Could not find answer: '%s' vs. '%s'",
                                           actual_text, cleaned_answer_text)
                            continue
                    else:
                        start_position = -1
                        end_position = -1
                        orig_answer_text = ""

                example = SquadExample(
                    qas_id=qas_id,
                    question_text=question_text,
                    doc_tokens=doc_tokens,
                    orig_answer_text=orig_answer_text,
                    start_position=start_position,
                    end_position=end_position,
                    is_impossible=is_impossible)
                examples.append(example)
    return examples


def convert_examples_to_features(examples, tokenizer, max_seq_length,
                                 doc_stride, max_query_length, is_training):
    """Converts a list of `SquadExample`s into a list of `InputFeatures`."""

    unique_id = 1000000000

    features = []
    for (example_index, example) in enumerate(examples):
        query_tokens = tokenizer.tokenize(example.question_text)

        if len(query_tokens) > max_query_length:
            query_tokens = query_tokens[0:max_query_length]

        tok_to_orig_index = []
        orig_to_tok_index = []
        all_doc_tokens = []
        for (i, token) in enumerate(example.doc_tokens):
            orig_to_tok_index.append(len(all_doc_tokens))
            sub_tokens = tokenizer.tokenize(token)
            for sub_token in sub_tokens:
                tok_to_orig_index.append(i)
                all_doc_tokens.append(sub_token)

        tok_start_position = None
        tok_end_position = None
        if is_training and example.is_impossible:
            tok_start_position = -1
            tok_end_position = -1
        if is_training and not example.is_impossible:
            tok_start_position = orig_to_tok_index[example.start_position]
            if example.end_position < len(example.doc_tokens) - 1:
                tok_end_position = orig_to_tok_index[example.end_position + 1] - 1
            else:
                tok_end_position = len(all_doc_tokens) - 1
            (tok_start_position, tok_end_position) = _improve_answer_span(
                all_doc_tokens, tok_start_position, tok_end_position, tokenizer,
                example.orig_answer_text)

        # The -3 accounts for [CLS], [SEP] and [SEP]
        max_tokens_for_doc = max_seq_length - len(query_tokens) - 3

        # We can have documents that are longer than the maximum sequence length.
        # To deal with this we do a sliding window approach, where we take chunks
        # of up to our max length with a stride of `doc_stride`.
        _DocSpan = collections.namedtuple(  # pylint: disable=invalid-name
            "DocSpan", ["start", "length"])
        doc_spans = []
        start_offset = 0
        while start_offset < len(all_doc_tokens):
            length = len(all_doc_tokens) - start_offset
            if length > max_tokens_for_doc:
                length = max_tokens_for_doc
            doc_spans.append(_DocSpan(start=start_offset, length=length))
            if start_offset + length == len(all_doc_tokens):
                break
            start_offset += min(length, doc_stride)

        for (doc_span_index, doc_span) in enumerate(doc_spans):
            tokens = []
            token_to_orig_map = {}
            token_is_max_context = {}
            segment_ids = []
            tokens.append("[CLS]")
            segment_ids.append(0)
            for token in query_tokens:
                tokens.append(token)
                segment_ids.append(0)
            tokens.append("[SEP]")
            segment_ids.append(0)

            for i in range(doc_span.length):
                split_token_index = doc_span.start + i
                token_to_orig_map[len(tokens)] = tok_to_orig_index[split_token_index]

                is_max_context = _check_is_max_context(doc_spans, doc_span_index,
                                                       split_token_index)
                token_is_max_context[len(tokens)] = is_max_context
                tokens.append(all_doc_tokens[split_token_index])
                segment_ids.append(1)
            tokens.append("[SEP]")
            segment_ids.append(1)

            input_ids = tokenizer.convert_tokens_to_ids(tokens)

            # The mask has 1 for real tokens and 0 for padding tokens. Only real
            # tokens are attended to.
            input_mask = [1] * len(input_ids)

            # Zero-pad up to the sequence length.
            while len(input_ids) < max_seq_length:
                input_ids.append(0)
                input_mask.append(0)
                segment_ids.append(0)

            assert len(input_ids) == max_seq_length
            assert len(input_mask) == max_seq_length
            assert len(segment_ids) == max_seq_length

            start_position = None
            end_position = None
            if is_training and not example.is_impossible:
                # For training, if our document chunk does not contain an annotation
                # we throw it out, since there is nothing to predict.
                doc_start = doc_span.start
                doc_end = doc_span.start + doc_span.length - 1
                out_of_span = False
                if not (tok_start_position >= doc_start and
                        tok_end_position <= doc_end):
                    out_of_span = True
                if out_of_span:
                    start_position = 0
                    end_position = 0
                else:
                    doc_offset = len(query_tokens) + 2
                    start_position = tok_start_position - doc_start + doc_offset
                    end_position = tok_end_position - doc_start + doc_offset
            if is_training and example.is_impossible:
                start_position = 0
                end_position = 0
            if example_index < 20:
                logger.info("*** Example ***")
                logger.info("unique_id: %s" % (unique_id))
                logger.info("example_index: %s" % (example_index))
                logger.info("doc_span_index: %s" % (doc_span_index))
                logger.info("tokens: %s" % " ".join(tokens))
                logger.info("token_to_orig_map: %s" % " ".join([
                    "%d:%d" % (x, y) for (x, y) in token_to_orig_map.items()]))
                logger.info("token_is_max_context: %s" % " ".join([
                    "%d:%s" % (x, y) for (x, y) in token_is_max_context.items()
                ]))
                logger.info("input_ids: %s" % " ".join([str(x) for x in input_ids]))
                logger.info(
                    "input_mask: %s" % " ".join([str(x) for x in input_mask]))
                logger.info(
                    "segment_ids: %s" % " ".join([str(x) for x in segment_ids]))
                if is_training and example.is_impossible:
                    logger.info("impossible example")
                if is_training and not example.is_impossible:
                    answer_text = " ".join(tokens[start_position:(end_position + 1)])
                    logger.info("start_position: %d" % (start_position))
                    logger.info("end_position: %d" % (end_position))
                    logger.info(
                        "answer: %s" % (answer_text))

            features.append(
                InputFeatures(
                    unique_id=unique_id,
                    example_index=example_index,
                    doc_span_index=doc_span_index,
                    tokens=tokens,
                    token_to_orig_map=token_to_orig_map,
                    token_is_max_context=token_is_max_context,
                    input_ids=input_ids,
                    input_mask=input_mask,
                    segment_ids=segment_ids,
                    start_position=start_position,
                    end_position=end_position,
                    is_impossible=example.is_impossible))
            unique_id += 1

    return features


def _improve_answer_span(doc_tokens, input_start, input_end, tokenizer,
                         orig_answer_text):
    """Returns tokenized answer spans that better match the annotated answer."""

    # The SQuAD annotations are character based. We first project them to
    # whitespace-tokenized words. But then after WordPiece tokenization, we can
    # often find a "better match". For example:
    #
    #   Question: What year was John Smith born?
    #   Context: The leader was John Smith (1895-1943).
    #   Answer: 1895
    #
    # The original whitespace-tokenized answer will be "(1895-1943).". However
    # after tokenization, our tokens will be "( 1895 - 1943 ) .". So we can match
    # the exact answer, 1895.
    #
    # However, this is not always possible. Consider the following:
    #
    #   Question: What country is the top exporter of electronics?
    #   Context: The Japanese electronics industry is the largest in the world.
    #   Answer: Japan
    #
    # In this case, the annotator chose "Japan" as a character sub-span of
    # the word "Japanese". Since our WordPiece tokenizer does not split
    # "Japanese", we just use "Japanese" as the annotation. This is fairly rare
    # in SQuAD, but does happen.
    tok_answer_text = " ".join(tokenizer.tokenize(orig_answer_text))

    for new_start in range(input_start, input_end + 1):
        for new_end in range(input_end, new_start - 1, -1):
            text_span = " ".join(doc_tokens[new_start:(new_end + 1)])
            if text_span == tok_answer_text:
                return (new_start, new_end)

    return (input_start, input_end)


def _check_is_max_context(doc_spans, cur_span_index, position):
    """Check if this is the 'max context' doc span for the token."""

    # Because of the sliding window approach taken to scoring documents, a single
    # token can appear in multiple documents. E.g.
    #  Doc: the man went to the store and bought a gallon of milk
    #  Span A: the man went to the
    #  Span B: to the store and bought
    #  Span C: and bought a gallon of
    #  ...
    #
    # Now the word 'bought' will have two scores from spans B and C. We only
    # want to consider the score with "maximum context", which we define as
    # the *minimum* of its left and right context (the *sum* of left and
    # right context will always be the same, of course).
    #
    # In the example the maximum context for 'bought' would be span C since
    # it has 1 left context and 3 right context, while span B has 4 left context
    # and 0 right context.
    best_score = None
    best_span_index = None
    for (span_index, doc_span) in enumerate(doc_spans):
        end = doc_span.start + doc_span.length - 1
        if position < doc_span.start:
            continue
        if position > end:
            continue
        num_left_context = position - doc_span.start
        num_right_context = end - position
        score = min(num_left_context, num_right_context) + 0.01 * doc_span.length
        if best_score is None or score > best_score:
            best_score = score
            best_span_index = span_index

    return cur_span_index == best_span_index


RawResult = collections.namedtuple("RawResult",
                                   ["unique_id", "start_logits", "end_logits"])
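
The sliding-window logic above is easier to follow on a small input. The sketch below reproduces the span-building loop and the "max context" scoring on the 12-word sentence from the code comments, using a window of 5 tokens and a stride of 3. The function names `make_doc_spans` and `max_context_span` are hypothetical simplifications of the code in this cell.

```python
def make_doc_spans(num_doc_tokens, max_tokens_for_doc, doc_stride):
    """Return (start, length) windows that cover a long document."""
    spans = []
    start = 0
    while start < num_doc_tokens:
        length = min(num_doc_tokens - start, max_tokens_for_doc)
        spans.append((start, length))
        if start + length == num_doc_tokens:
            break  # the last window reached the end of the document
        start += min(length, doc_stride)
    return spans


def max_context_span(spans, position):
    """Return the index of the span that gives `position` the most context."""
    best_score, best_index = None, None
    for index, (start, length) in enumerate(spans):
        end = start + length - 1
        if position < start or position > end:
            continue  # this span does not contain the token
        # Context is limited by the smaller of the left and right sides.
        score = min(position - start, end - position) + 0.01 * length
        if best_score is None or score > best_score:
            best_score, best_index = score, index
    return best_index


doc = "the man went to the store and bought a gallon of milk".split()
spans = make_doc_spans(len(doc), max_tokens_for_doc=5, doc_stride=3)
print(spans)  # [(0, 5), (3, 5), (6, 5), (9, 3)]

# 'bought' (position 7) appears in the spans at index 1 and 2.
print(max_context_span(spans, doc.index("bought")))  # 2
```

Because consecutive windows overlap, a token near a window edge usually appears again nearer the middle of a neighboring window, and that is the copy whose score is kept.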

In [0]:
do_train = True # Are we training?
local_rank = -1  
version_2_with_negative=False # Version 2 of SQuAD has "impossible questions" to consider
do_lower_case = True # Do we need to force data to lowercase?
max_seq_length=324 # Maximum number of tokens per input (question + document chunk + special tokens)
doc_stride=128 # How many tokens the sliding window advances between chunks of a long document
max_query_length=64 # Questions longer than this many tokens are truncated
train_batch_size=32
gradient_accumulation_steps=1
num_train_epochs=1.0
learning_rate=5e-5
loss_scale=0
warmup_proportion=0.1

In [0]:
!wget 'https://drive.google.com/uc?id=1w0yYOeytZOFv_SuydI5xS_pjlFLd--bT' -O 'train-v1.1_aisc_extract'

In [0]:
# Transform examples into input features
print("converting features")
cached_train_features_file = 'train-v1.1' + '_aisc_extract'
try:
    # Reuse cached features when a readable cache file exists
    with open("/content/gdrive/My Drive/" + cached_train_features_file, "rb") as reader:
        train_features = pickle.load(reader)
except Exception:
    print("failed to load cached features; converting from scratch")
    train_features = convert_examples_to_features(
        examples=train_examples,
        tokenizer=tokenizer,
        max_seq_length=max_seq_length,
        doc_stride=doc_stride,
        max_query_length=max_query_length,
        is_training=True)

    logger.info("  Saving train features into cached file %s", cached_train_features_file)
    with open(cached_train_features_file, "wb") as writer:
        pickle.dump(train_features, writer)

In [0]:
all_input_ids = torch.tensor([f.input_ids for f in train_features], dtype=torch.long)
all_input_mask = torch.tensor([f.input_mask for f in train_features], dtype=torch.long)
all_segment_ids = torch.tensor([f.segment_ids for f in train_features], dtype=torch.long)
all_start_positions = torch.tensor([f.start_position for f in train_features], dtype=torch.long)
all_end_positions = torch.tensor([f.end_position for f in train_features], dtype=torch.long)
train_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids,
                           all_start_positions, all_end_positions)

print("setting up data loader")
train_sampler = RandomSampler(train_data)
train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=train_batch_size)
num_train_optimization_steps = int(len(train_dataloader) // gradient_accumulation_steps * num_train_epochs)

# Prepare optimizer
param_optimizer = list(model.named_parameters())

# Hack to remove the pooler, which is not used:
# it produces None gradients, which break apex
param_optimizer = [n for n in param_optimizer if 'pooler' not in n[0]]
no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
optimizer_grouped_parameters = [
    {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
    {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
]

optimizer = BertAdam(optimizer_grouped_parameters,
                     lr=learning_rate,
                     warmup=warmup_proportion,
                     t_total=num_train_optimization_steps)

global_step = 0
logger.setLevel(0)
print("***** Running training *****")
print(f"  Num orig examples = {len(train_examples)}")
print(f"  Num split examples = {len(train_features)}")
print(f"  Batch size = {train_batch_size}")
print(f"  Num steps = {num_train_optimization_steps}")
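
The `warmup` and `t_total` arguments control how the learning rate changes over training. As a rough sketch of the idea, assuming a linear ramp-up followed by a linear decay, the multiplier applied to `learning_rate` could be computed as below. The function `warmup_linear_multiplier` is a hypothetical illustration and not necessarily the exact rule `BertAdam` applies internally.

```python
def warmup_linear_multiplier(step, total_steps, warmup_proportion):
    """Scale factor for the base learning rate at a given optimization step."""
    progress = step / total_steps
    if progress < warmup_proportion:
        # Ramp up linearly from 0 to 1 during the warmup phase.
        return progress / warmup_proportion
    # Then decay linearly from 1 back down to 0 at the final step.
    return max((1.0 - progress) / (1.0 - warmup_proportion), 0.0)


base_lr = 5e-5
for step in (0, 50, 100, 550, 1000):
    factor = warmup_linear_multiplier(step, total_steps=1000, warmup_proportion=0.1)
    print(step, base_lr * factor)
```

With `warmup_proportion=0.1`, the first 10% of steps use a reduced learning rate, which tends to keep early updates from disturbing the pretrained weights too much.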

In [0]:
print("training")
model.train()
for epoch in trange(int(num_train_epochs), desc="Epoch"):
    for step, batch in enumerate(tqdm_notebook(train_dataloader, desc="Iteration", disable=local_rank not in [-1, 0])):
        # Move every tensor in the batch to the model's device
        batch = tuple(t.to(device) for t in batch)
        input_ids, input_mask, segment_ids, start_positions, end_positions = batch
        loss = model(input_ids, segment_ids, input_mask, start_positions, end_positions)
        if n_gpu > 1:
            loss = loss.mean() # mean() to average on multi-gpu.
        if gradient_accumulation_steps > 1:
            loss = loss / gradient_accumulation_steps

        loss.backward()
        if (step + 1) % gradient_accumulation_steps == 0:
            optimizer.step()
            optimizer.zero_grad()
            global_step += 1