In [None]:
# Mount Google Drive; the fine-tuned model is saved under /content/drive/MyDrive at the end.
from google.colab import drive

DRIVE_MOUNT_POINT = '/content/drive'
drive.mount(DRIVE_MOUNT_POINT)

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


# Setup and imports

In [None]:
! pip install datasets transformers



In [None]:
# Install PyTorch/XLA so the training cells can use Colab's TPU cores.
# VERSION is passed to the setup script's --version flag (presumably the torch/torch_xla release).
VERSION = "1.8.1"
# NOTE(review): the setup script is fetched from the pytorch/xla `master` branch, so this cell
# is not pinned and may break if the script changes or moves — consider pinning a commit.
!curl https://raw.githubusercontent.com/pytorch/xla/master/contrib/scripts/env-setup.py -o pytorch-xla-env-setup.py
!python pytorch-xla-env-setup.py --version $VERSION

In [None]:
# Core data and Hugging Face libraries.
import pandas as pd
import numpy as np  # NOTE(review): np is not referenced in any visible cell
import transformers
# NOTE(review): default_data_collator is imported but not referenced in any visible cell.
from transformers import AutoModelForQuestionAnswering, TrainingArguments, Trainer, default_data_collator, AutoTokenizer

# PyTorch/XLA helpers for multi-process TPU training (installed by the setup cell above).
import torch_xla.core.xla_model as xm
import torch_xla.distributed.parallel_loader as pl  # NOTE(review): pl is not referenced in any visible cell
import torch_xla.distributed.xla_multiprocessing as xmp

# Record which transformers version the notebook ran with.
print(transformers.__version__)

4.18.0


# Data Preprocessing

In [None]:
# Raw COVID-19 QA pairs with Context / Question / Answer columns.
QA_CSV_PATH = 'covid19_dataqa.csv'
df = pd.read_csv(QA_CSV_PATH)

In [None]:
from sklearn.model_selection import train_test_split

# Fixed seed so the train/validation split (and every result downstream) is reproducible.
SPLIT_SEED = 42
train, test = train_test_split(df, test_size=0.3, random_state=SPLIT_SEED)

In [None]:
# Preview the training split (small enough to display in full).
train

Unnamed: 0,Context,Question,Answer
15,The state government will continue with its se...,Who will continue its search to find out wheth...,"{ ""text"": [ ""State government"" ], ""answer_star..."
7,The new delhi traders association (ndta) has s...,What is the new delhi traders association?,"{ ""text"": [ ""Ndta"" ], ""answer_start"": [ 6 ] }"
4,Traders' associations have resumed regular san...,Who has resumed regular sanitization and masking?,"Traders{ ""text"": [ ""Covid cases"" ], ""answer_st..."
10,The new delhi traders association (ndta) has s...,What was atul bhargava concerned about?,"{ ""text"": [ ""Safety"" ], ""answer_start"": [ 9 ] }"
13,Traders at chandni chowk have brought back cov...,Who brought back covid protocols?,"{ ""text"": [ ""Traders"" ], ""answer_start"": [ 12 ] }"
1,Who is working tirelessly with partners to dev...,Research is still ongoing into how much vaccin...,"{ ""text"": [ ""Infection"" ], ""answer_start"": [ 1..."
16,The state government will continue with its se...,Who is monitoring the rate of occurrence of co...,"{ ""text"": [ ""Doctors"" ], ""answer_start"": [ 15 ] }"
3,Traders' associations have resumed regular san...,What is rising in delhi?,"{ ""text"": [ ""Covid cases"" ], ""answer_start"": [..."
2,Who is working tirelessly with partners to dev...,Research is still ongoing into how much vaccin...,"{ ""text"": [ ""Transmission"" ], ""answer_start"": ..."
5,Traders' associations have resumed regular san...,Who has resumed regular sanitization and masking?,"{ ""text"": [ ""Associations"" ], ""answer_start"": ..."


In [None]:
# Preview the held-out split used for validation.
test

Unnamed: 0,Context,Question,Answer
12,Traders at chandni chowk have brought back cov...,What have traders at chandni chowk brought back?,"{ ""text"": [ ""Covid protocols"" ], ""answer_start..."
18,The state government will continue with its se...,Sentinel surveillance is monitoring the rate o...,"{ ""text"": [ ""Laboratories"" ], ""answer_start"": ..."
6,Traders' associations have resumed regular san...,In what city are covid cases rising?,"{ ""text"": [ ""Delhi"" ], ""answer_start"": [ 5 ] }"
17,The state government will continue with its se...,What is monitoring the rate of occurrence of c...,"{ ""text"": [ ""Network"" ], ""answer_start"": [ 16 ] }"
8,The new delhi traders association (ndta) has s...,Over 400 of the ndta have been asked to remain...,"{ ""text"": [ ""Members"" ], ""answer_start"": [ 7 ] }"
0,Who is working tirelessly with partners to dev...,"Who is working tirelessly to develop, manufact...","{ ""text"": [ ""Vaccines"" ], ""answer_start"": [ 0 ] }"


In [None]:
# Persist both splits so `load_dataset('csv', ...)` can read them back.
for csv_path, split_frame in (('train.csv', train), ('test.csv', test)):
    split_frame.to_csv(csv_path, index=False)

In [None]:
from datasets import load_dataset

# Load the CSV splits written above as a DatasetDict with "train" and "validation" splits.
data = load_dataset('csv', data_files={'train': 'train.csv', 'validation': 'test.csv'})

# The source CSV uses capitalised headers ("Context", "Question", "Answer"), while
# `prepare_train_features` reads SQuAD-style "context" / "question" / "answers".
# Only rename headers that are actually present, so this also works if the CSV
# already uses the lowercase names.
SQUAD_COLUMN_NAMES = {'Context': 'context', 'Question': 'question', 'Answer': 'answers'}
for old_name, new_name in SQUAD_COLUMN_NAMES.items():
    if old_name in data['train'].column_names:
        data = data.rename_column(old_name, new_name)



Downloading and preparing dataset csv/default to /root/.cache/huggingface/datasets/csv/default-7a470900ad512580/0.0.0/433e0ccc46f9880962cc2b12065189766fbb2bee57a221866138fb9203c83519...


Downloading data files:   0%|          | 0/2 [00:00<?, ?it/s]

Extracting data files:   0%|          | 0/2 [00:00<?, ?it/s]

Dataset csv downloaded and prepared to /root/.cache/huggingface/datasets/csv/default-7a470900ad512580/0.0.0/433e0ccc46f9880962cc2b12065189766fbb2bee57a221866138fb9203c83519. Subsequent calls will reuse this data.


  0%|          | 0/2 [00:00<?, ?it/s]

In [None]:
# Keep at most MAX_ROWS_PER_SPLIT rows per split for a quick run. `select` rejects
# out-of-range indices, so clamp to each split's size: the 30% validation split
# of this small CSV can have fewer than 10 rows.
MAX_ROWS_PER_SPLIT = 10
for split_name in data:
    n_rows = min(MAX_ROWS_PER_SPLIT, len(data[split_name]))
    data[split_name] = data[split_name].select(range(n_rows))

data

DatasetDict({
    train: Dataset({
        features: ['id', 'title', 'context', 'question', 'answers'],
        num_rows: 10
    })
    validation: Dataset({
        features: ['id', 'title', 'context', 'question', 'answers'],
        num_rows: 10
    })
})

In [None]:
# Splinter checkpoint; the same name is used later to load the QA model.
model_name = "tau/splinter-base"
tokenizer = AutoTokenizer.from_pretrained(model_name)

In [None]:
# Re-display the dataset splits before tokenization.
data

DatasetDict({
    train: Dataset({
        features: ['id', 'title', 'context', 'question', 'answers'],
        num_rows: 10
    })
    validation: Dataset({
        features: ['id', 'title', 'context', 'question', 'answers'],
        num_rows: 10
    })
})

# Tokenization

In [None]:
import json
import re

max_length = 384  # The maximum length of a feature (question and context).
doc_stride = 128  # The allowed overlap between two parts of the context when it has to be split.

pad_on_right = tokenizer.padding_side == "right"


def _parse_answers(raw_answers):
    """Return an answers record as a dict with "text" and "answer_start" lists.

    When the data comes from `load_dataset('csv', ...)`, the answers column holds JSON
    strings such as '{ "text": [ "Ndta" ], "answer_start": [ 6 ] }' rather than dicts.
    Both forms are accepted.
    NOTE(review): malformed entries (the preview shows one starting "Traders{") raise
    json.JSONDecodeError — fix them in the CSV.
    """
    if isinstance(raw_answers, str):
        return json.loads(raw_answers)
    return raw_answers


def _locate_answer(context, answer_text, answer_start):
    """Return (start_char, end_char) of `answer_text` in `context`, or None if absent.

    The stored offset is used when the context actually contains the answer there.
    Otherwise fall back to a case-insensitive search, because offsets in this CSV do
    not always point at the answer (e.g. "Ndta" with answer_start 6 in a context that
    begins "The new delhi traders association (ndta)").
    """
    end_char = answer_start + len(answer_text)
    if context[answer_start:end_char] == answer_text:
        return answer_start, end_char
    match = re.search(re.escape(answer_text), context, flags=re.IGNORECASE)
    if match is None:
        return None
    return match.start(), match.end()


def prepare_train_features(examples):
    """Tokenize a batch of QA examples into features labelled with answer token positions.

    Long contexts are split into overlapping windows (`max_length` tokens with
    `doc_stride` overlap), so one example can yield several features. Each feature gets
    `start_positions` / `end_positions`: the token indices of the answer span, or the
    CLS token index when there is no answer or the answer is not inside that window.

    Args:
        examples: batch dict from `Dataset.map(batched=True)` with "question", "context"
            and "answers" lists. Each answers entry is a dict with "text" and
            "answer_start" lists, or a JSON string encoding one.
            Note: examples["question"] is replaced with left-stripped copies.

    Returns:
        The tokenizer output with "start_positions" / "end_positions" added and the
        overflow-to-sample and offset mappings removed.
    """
    # Some questions have lots of leading whitespace, which is useless and makes context
    # truncation fail (the tokenized question takes up too much space), so strip it.
    examples["question"] = [q.lstrip() for q in examples["question"]]

    # Tokenize with truncation and padding, keeping overflows with a stride: a long
    # context yields several features whose contexts overlap the previous one a little.
    tokenized_examples = tokenizer(
        examples["question" if pad_on_right else "context"],
        examples["context" if pad_on_right else "question"],
        truncation="only_second" if pad_on_right else "only_first",
        max_length=max_length,
        stride=doc_stride,
        return_overflowing_tokens=True,
        return_offsets_mapping=True,
        padding="max_length",
    )

    # Maps each feature back to the example it came from.
    sample_mapping = tokenized_examples.pop("overflow_to_sample_mapping")
    # Maps each token to its (start, end) character offsets in the original text.
    offset_mapping = tokenized_examples.pop("offset_mapping")

    tokenized_examples["start_positions"] = []
    tokenized_examples["end_positions"] = []

    # Sequence id of the context segment (the second segment when padding on the right).
    context_seq_id = 1 if pad_on_right else 0

    for i, offsets in enumerate(offset_mapping):
        # Impossible answers are labelled with the index of the CLS token.
        input_ids = tokenized_examples["input_ids"][i]
        cls_index = input_ids.index(tokenizer.cls_token_id)

        # Which tokens belong to the question and which to the context.
        sequence_ids = tokenized_examples.sequence_ids(i)

        sample_index = sample_mapping[i]
        answers = _parse_answers(examples["answers"][sample_index])

        span = None
        if len(answers["answer_start"]) > 0:
            span = _locate_answer(
                examples["context"][sample_index],
                answers["text"][0],
                answers["answer_start"][0],
            )
        if span is None:
            # No answer given, or the answer text does not occur in the context.
            tokenized_examples["start_positions"].append(cls_index)
            tokenized_examples["end_positions"].append(cls_index)
            continue
        start_char, end_char = span

        # First and last token of the context inside this window.
        token_start_index = 0
        while sequence_ids[token_start_index] != context_seq_id:
            token_start_index += 1

        token_end_index = len(input_ids) - 1
        while sequence_ids[token_end_index] != context_seq_id:
            token_end_index -= 1

        # If the answer lies outside this window, label the feature with the CLS index.
        if not (offsets[token_start_index][0] <= start_char and offsets[token_end_index][1] >= end_char):
            tokenized_examples["start_positions"].append(cls_index)
            tokenized_examples["end_positions"].append(cls_index)
        else:
            # Move token_start_index and token_end_index onto the two ends of the answer.
            # Note: we can step past the last offset if the answer is the last word (edge case).
            while token_start_index < len(offsets) and offsets[token_start_index][0] <= start_char:
                token_start_index += 1
            tokenized_examples["start_positions"].append(token_start_index - 1)
            while offsets[token_end_index][1] >= end_char:
                token_end_index -= 1
            tokenized_examples["end_positions"].append(token_end_index + 1)

    return tokenized_examples

In [None]:
# Tokenize every split and drop the raw text columns, keeping only model inputs and labels.
columns_to_remove = data["train"].column_names
processed_data = data.map(
    prepare_train_features,
    batched=True,
    remove_columns=columns_to_remove,
)
processed_data



  0%|          | 0/1 [00:00<?, ?ba/s]

DatasetDict({
    train: Dataset({
        features: ['input_ids', 'token_type_ids', 'attention_mask', 'start_positions', 'end_positions'],
        num_rows: 10
    })
    validation: Dataset({
        features: ['input_ids', 'token_type_ids', 'attention_mask', 'start_positions', 'end_positions'],
        num_rows: 10
    })
})

# Fine Tuning

In [None]:
# Load Splinter with a question-answering head. The warning below shows the QASS head
# weights are newly initialised, so they are learned during fine-tuning.
model = AutoModelForQuestionAnswering.from_pretrained(model_name)
model.train()  # switch to training mode before handing the model to the workers

# Wrap once here, in the parent process, before xmp.spawn forks the TPU workers.
WRAPPED_MODEL = xmp.MpModelWrapper(model)

Some weights of SplinterForQuestionAnswering were not initialized from the model checkpoint at tau/splinter-base and are newly initialized: ['splinter_qass.start_transform.LayerNorm.bias', 'splinter_qass.query_start_transform.dense.weight', 'splinter_qass.query_start_transform.LayerNorm.weight', 'splinter_qass.query_end_transform.LayerNorm.bias', 'splinter_qass.query_end_transform.dense.bias', 'splinter_qass.end_transform.dense.weight', 'splinter_qass.end_transform.dense.bias', 'splinter_qass.end_transform.LayerNorm.weight', 'splinter_qass.end_classifier.weight', 'splinter_qass.query_start_transform.dense.bias', 'splinter_qass.start_transform.dense.weight', 'splinter_qass.start_transform.LayerNorm.weight', 'splinter_qass.query_end_transform.dense.weight', 'splinter_qass.end_transform.LayerNorm.bias', 'splinter_qass.start_classifier.weight', 'splinter_qass.query_start_transform.LayerNorm.bias', 'splinter_qass.query_end_transform.LayerNorm.weight', 'splinter_qass.start_transform.dense.bi

# Training on TPU

In [None]:
# Per-worker training routine: build the Trainer, train, evaluate each epoch, and save.

def train_loop(model, batch_size=2, num_epochs=15, learning_rate=2e-5,
               output_dir='/content/drive/MyDrive/Q&A/Splinter'):
    """Fine-tune `model` on `processed_data` with the Hugging Face Trainer and save it.

    Called from every TPU worker process (see `_mp_fn`), with `model` already on
    that worker's XLA device.

    Args:
        model: the question-answering model to train.
        batch_size: per-device batch size for training and evaluation.
        num_epochs: number of training epochs.
        learning_rate: optimizer learning rate.
        output_dir: directory the final model is saved to (Google Drive by default).
            Intermediate checkpoints go to the local "Q&A" directory.
    """
    # Every worker runs this function; print the banner once instead of once per core.
    if xm.is_master_ordinal():
        print("Training... ", end="")

    training_args = TrainingArguments(
        "Q&A",
        evaluation_strategy="epoch",
        learning_rate=learning_rate,
        per_device_train_batch_size=batch_size,
        per_device_eval_batch_size=batch_size,
        weight_decay=0.01,
        save_total_limit=3,
        num_train_epochs=num_epochs,
        tpu_num_cores=8,
        logging_steps=1,
    )

    trainer = Trainer(
        model,
        training_args,
        train_dataset=processed_data["train"],
        eval_dataset=processed_data["validation"],
    )

    # The caller already moved the model to this worker's XLA device.
    trainer.place_model_on_device = False
    trainer.train()

    # NOTE(review): deliberately not gated on the master ordinal — the Trainer's TPU save
    # path is believed to synchronise all workers; confirm before restricting it.
    trainer.save_model(output_dir)

In [None]:
def _mp_fn(index):
    """Entry point executed by each TPU worker process launched by `xmp.spawn` below.

    `index` is the worker ordinal supplied by `xmp.spawn`; it is not needed here.
    """
    del index  # unused
    xla_device = xm.xla_device()
    train_loop(WRAPPED_MODEL.to(xla_device))


xmp.spawn(_mp_fn, start_method="fork")

***** Running training *****
  Num examples = 10
  Num Epochs = 15
  Instantaneous batch size per device = 2
  Total train batch size (w. parallel, distributed & accumulation) = 16
  Gradient Accumulation steps = 1
  Total optimization steps = 15


Training... 

Epoch,Training Loss,Validation Loss
1,25.8582,20.52029
2,12.6615,18.528233
3,11.4174,17.624628
4,6.0676,17.258915
5,5.0375,16.786728
6,3.9574,16.678822
7,3.1564,16.836023
8,0.9511,16.94495
9,0.5385,17.016155
10,0.7759,17.079678




Training... 

Epoch,Training Loss,Validation Loss
1,25.8582,20.52029
2,12.6615,18.528233
3,11.4174,17.624628
4,6.0676,17.258915
5,5.0375,16.786728
6,3.9574,16.678822
7,3.1564,16.836023
8,0.9511,16.94495
9,0.5385,17.016155
10,0.7759,17.079678




Training... 

Epoch,Training Loss,Validation Loss
1,25.8582,20.52029
2,12.6615,18.528233
3,11.4174,17.624628
4,6.0676,17.258915
5,5.0375,16.786728
6,3.9574,16.678822
7,3.1564,16.836023
8,0.9511,16.94495
9,0.5385,17.016155
10,0.7759,17.079678




Training... 

Epoch,Training Loss,Validation Loss
1,25.8582,20.52029
2,12.6615,18.528233
3,11.4174,17.624628
4,6.0676,17.258915
5,5.0375,16.786728
6,3.9574,16.678822
7,3.1564,16.836023
8,0.9511,16.94495
9,0.5385,17.016155
10,0.7759,17.079678




Training... 

Epoch,Training Loss,Validation Loss
1,25.8582,20.52029
2,12.6615,18.528233
3,11.4174,17.624628
4,6.0676,17.258915
5,5.0375,16.786728
6,3.9574,16.678822
7,3.1564,16.836023
8,0.9511,16.94495
9,0.5385,17.016155
10,0.7759,17.079678




Training... 

Epoch,Training Loss,Validation Loss
1,25.8582,20.52029
2,12.6615,18.528233
3,11.4174,17.624628
4,6.0676,17.258915
5,5.0375,16.786728
6,3.9574,16.678822
7,3.1564,16.836023
8,0.9511,16.94495
9,0.5385,17.016155
10,0.7759,17.079678




Training... 

Epoch,Training Loss,Validation Loss
1,25.8582,20.52029
2,12.6615,18.528233
3,11.4174,17.624628
4,6.0676,17.258915
5,5.0375,16.786728
6,3.9574,16.678822
7,3.1564,16.836023
8,0.9511,16.94495
9,0.5385,17.016155
10,0.7759,17.079678




Training... 

Epoch,Training Loss,Validation Loss
1,25.8582,20.52029
2,12.6615,18.528233
3,11.4174,17.624628
4,6.0676,17.258915
5,5.0375,16.786728
6,3.9574,16.678822
7,3.1564,16.836023
8,0.9511,16.94495
9,0.5385,17.016155
10,0.7759,17.079678


***** Running Evaluation *****
  Num examples = 10
  Batch size = 2
***** Running Evaluation *****
  Num examples = 10
  Batch size = 2
***** Running Evaluation *****
  Num examples = 10
  Batch size = 2
***** Running Evaluation *****
  Num examples = 10
  Batch size = 2
***** Running Evaluation *****
  Num examples = 10
  Batch size = 2
***** Running Evaluation *****
  Num examples = 10
  Batch size = 2
***** Running Evaluation *****
  Num examples = 10
  Batch size = 2
***** Running Evaluation *****
  Num examples = 10
  Batch size = 2
***** Running Evaluation *****
  Num examples = 10
  Batch size = 2
***** Running Evaluation *****
  Num examples = 10
  Batch size = 2
***** Running Evaluation *****
  Num examples = 10
  Batch size = 2
***** Running Evaluation *****
  Num examples = 10
  Batch size = 2
***** Running Evaluation *****
  Num examples = 10
  Batch size = 2
***** Running Evaluation *****
  Num examples = 10
  Batch size = 2
***** Running Evaluation *****
  Num examples = 