<a href="https://colab.research.google.com/github/guyez/NLP/blob/main/DistilBertQA_train.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
'''
Question Answering (QA) System using NLP with SQuAD
EE562 Group3 Project
Megha Chandra Nandyala
Amisha Himanshu Somaiya


APPROACH                   : DistilBERT Pretrained + Additional Head

ADDITIONAL HEAD            : 768->512->32->2
ACTIVATION FUNCTION        : GeLU and GeLU_new

Reason for Pretrained      : Pretraining on large corpus ensures strong understanding of contextual information even before fine-tuning
Reason for Additional Head : Additional head is required for QA task, since DistilBERT is pre-trained for general NLP tasks

HYPERPARAMETER TRIALS FOR BEST FINE-TUNING :  (3 hours per training on T4 GPU, Colab)
(Epochs, Learning Rate, Batch Size, Weight Decay):
3, 5.00E-05, 32, 0.01
5, 5.00E-05, 32, 0.01
7, 5.00E-05, 32, 0.01
3, 5.00E-05, 32, 0.01
3, 2.00E-05, 32, 0.01
3, 5.00E-04, 32, 0.01
3, 5.00E-05, 8,  0.01
3, 5.00E-05, 16, 0.01
3, 5.00E-05, 32, 0.01
3, 5.00E-05, 32, 0.001
3, 5.00E-05, 32, 0.01
3, 5.00E-05, 32, 0.1

REFERENCES :
https://arxiv.org/abs/1810.04805
https://arxiv.org/abs/1910.01108
https://rajpurkar.github.io/SQuAD-explorer/
https://huggingface.co/models
https://huggingface.co/nlpunibo
https://huggingface.co/docs/transformers/model_doc/auto
https://huggingface.co/docs/transformers/main_classes/data_collator
https://discuss.huggingface.co/t/squad-bert-why-max-length-384-by-default-and-not-512/11693



'''




# Install required packages
!pip install transformers
!pip install datasets
!pip install accelerate -U

# Import required libraries
import numpy as np
import pandas as pd
import torch
import json
import sys
import time
import datetime
import random
import collections
from pathlib import Path
import transformers
import datasets
from datasets import load_dataset
# Connect Drive
from google.colab import drive
drive.mount("/content/drive")

Collecting datasets
  Downloading datasets-2.15.0-py3-none-any.whl (521 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m521.2/521.2 kB[0m [31m7.8 MB/s[0m eta [36m0:00:00[0m
Collecting pyarrow-hotfix (from datasets)
  Downloading pyarrow_hotfix-0.6-py3-none-any.whl (7.9 kB)
Collecting dill<0.3.8,>=0.3.0 (from datasets)
  Downloading dill-0.3.7-py3-none-any.whl (115 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m115.3/115.3 kB[0m [31m15.2 MB/s[0m eta [36m0:00:00[0m
Collecting multiprocess (from datasets)
  Downloading multiprocess-0.70.15-py310-none-any.whl (134 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m134.8/134.8 kB[0m [31m13.3 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: pyarrow-hotfix, dill, multiprocess, datasets
Successfully installed datasets-2.15.0 dill-0.3.7 multiprocess-0.70.15 pyarrow-hotfix-0.6
Collecting accelerate
  Downloading accelerate-0.25.0-py3-none-any.whl (265 kB)
[2K     

In [2]:
#connect GPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
!nvidia-smi

Tue Dec 12 06:37:29 2023       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 525.105.17   Driver Version: 525.105.17   CUDA Version: 12.0     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla T4            Off  | 00000000:00:04.0 Off |                    0 |
| N/A   39C    P8    10W /  70W |      3MiB / 15360MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [9]:
#Path to get SQuAD
FOLDER_NAME = "Question-Answering-SQUAD-main/data"
JSON_TEST_FILE = "training_set.json"
data_path = "drive/My Drive/EE562_Group3_Project/distilled_bert/" + FOLDER_NAME +"/"
file_path = data_path + JSON_TEST_FILE
checkpoint_path = data_path

In [11]:
'''
Load SQuAD version 1.1 since our system is closed-domain.
Future Work is to add system functionality to work with open questions in SQuAD version 2.
'''
class LoadData():
    def __init__(self,
                 path_to_json_file: str,
                 checkpoint_path: str,
                 train_file: str = 'train.json',
                 val_file: str = 'val.json') -> None:

        self.path_to_json_file = path_to_json_file  #specify paths to laoad
        self.checkpoint_path = checkpoint_path
        self.train_file = train_file
        self.val_file = val_file
        self.data = self.load_data()

    def load_data(self): #load data
        with open(self.path_to_json_file, 'r') as f:
            squad_data = json.load(f)
        version = squad_data.get("version", "")
        train_data, val_data, errors = self.load_squad_data(squad_data)
        with open(Path(self.checkpoint_path) / Path(self.train_file), 'w') as file:
            json.dump({"data": train_data}, file)
        with open(Path(self.checkpoint_path) / Path(self.val_file), 'w') as file:
            json.dump({"data": val_data}, file)
        return squad_data

    def load_squad_data(self, data, split=0.2):
        errors = 0
        flattened_data_train = [] #initialize empty lists for training and validation data
        flattened_data_val = []
        train_range = int(len(data['data']) * (1 - split)) #calculate split index
        for i, article in enumerate(data["data"]):
            title = article.get("title", "").strip() #eseparate title of the article
            for paragraph in article["paragraphs"]:
                context = paragraph["context"].strip()  #context of paragraph
                for qa in paragraph["qas"]:
                    question = qa["question"].strip() #separate the question
                    id_ = qa["id"]
                    answer_starts = [answer["answer_start"] for answer in qa["answers"]] #answer start position
                    answers = [answer["text"].strip() for answer in qa["answers"]] #answer text
                    flattened_data = {"title": title,  #store all of the above in a dictionary for current record
                                      "context": context,
                                      "question": question,
                                      "id": id_,
                                      "answers": {
                                          "answer_start": answer_starts,
                                          "text": answers}
                                      }
                    if i <= train_range:
                        flattened_data_train.append(flattened_data)  #repeat for all
                    else:
                        flattened_data_val.append(flattened_data) #repeat for val
        return flattened_data_train, flattened_data_val, errors
_ = LoadData(file_path, checkpoint_path)

In [12]:
#load data
from datasets import load_dataset
train_data = load_dataset('json', data_files=data_path+"train.json", field='data')
val_data = load_dataset('json', data_files=data_path+"val.json", field='data')

Downloading data files:   0%|          | 0/1 [00:00<?, ?it/s]

Extracting data files:   0%|          | 0/1 [00:00<?, ?it/s]

Generating train split: 0 examples [00:00, ? examples/s]

Downloading data files:   0%|          | 0/1 [00:00<?, ?it/s]

Extracting data files:   0%|          | 0/1 [00:00<?, ?it/s]

Generating train split: 0 examples [00:00, ? examples/s]

In [13]:
def get_text(answer: list) -> str:
    return answer[0] #Extract only the text from answers.text column

def get_json_data(json_path: str) -> dict:
     with open(json_path, 'r') as f:
        json_data = json.load(f)
     return json_data

train_dataframe = pd.json_normalize(get_json_data(data_path + "train.json"), record_path='data') #load data and normalize
train_dataframe["answers.text"] = train_dataframe["answers.text"].apply(get_text)
val_dataframe = pd.json_normalize(get_json_data(data_path + "val.json"), record_path='data')
val_dataframe["answers.text"] = val_dataframe["answers.text"].apply(get_text)
train_dataframe


Unnamed: 0,title,context,question,id,answers.answer_start,answers.text
0,University_of_Notre_Dame,"Architecturally, the school has a Catholic cha...",To whom did the Virgin Mary allegedly appear i...,5733be284776f41900661182,[515],Saint Bernadette Soubirous
1,University_of_Notre_Dame,"Architecturally, the school has a Catholic cha...",What is in front of the Notre Dame Main Building?,5733be284776f4190066117f,[188],a copper statue of Christ
2,University_of_Notre_Dame,"Architecturally, the school has a Catholic cha...",The Basilica of the Sacred heart at Notre Dame...,5733be284776f41900661180,[279],the Main Building
3,University_of_Notre_Dame,"Architecturally, the school has a Catholic cha...",What is the Grotto at Notre Dame?,5733be284776f41900661181,[381],a Marian place of prayer and reflection
4,University_of_Notre_Dame,"Architecturally, the school has a Catholic cha...",What sits on top of the Main Building at Notre...,5733be284776f4190066117e,[92],a golden statue of the Virgin Mary
...,...,...,...,...,...,...
69387,Empiricism,John Dewey (1859–1952) modified James' pragmat...,Who came up with 'instrumentalism'?,572b459134ae481900dead71,[0],John Dewey
69388,Empiricism,John Dewey (1859–1952) modified James' pragmat...,What did Dewey think about reality?,572b459134ae481900dead72,[317],reality is determined by past experience
69389,Empiricism,John Dewey (1859–1952) modified James' pragmat...,When was Dewey born?,572b459134ae481900dead73,[12],1859
69390,Empiricism,John Dewey (1859–1952) modified James' pragmat...,When did Dewey die?,572b459134ae481900dead74,[17],1952


In [14]:
'''
Tokenization steps using NLTK punkt in ML model are performed by huggingface AutoTokenizer here
'''
from transformers import AutoTokenizer
import transformers
model_checkpoint = "distilbert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
assert isinstance(tokenizer, transformers.PreTrainedTokenizerFast)

tokenizer_config.json:   0%|          | 0.00/28.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/483 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

In [15]:
'''
https://discuss.huggingface.co/t/squad-bert-why-max-length-384-by-default-and-not-512/11693
'''
max_length = 384 #max length as per model capability
doc_stride = 128 #authorized overlap between two parts of the context when splitting it is needed.
pad_on_right = tokenizer.padding_side == "right"   #regular model with padding on right

In [16]:
def prepare_train_features(examples: collections.OrderedDict or dict) -> transformers.tokenization_utils_base.BatchEncoding:
    '''
    Tokenize examples with truncation and padding, keeping overflows using a stride.
    This results in multiple features when a context is long, each with overlapping context.
    '''
    question_key = "question" if pad_on_right else "context"
    context_key = "context" if pad_on_right else "question"
    tokenized_examples = tokenizer(
        examples[question_key],
        examples[context_key],
        truncation="only_second" if pad_on_right else "only_first",
        max_length=max_length,
        stride=doc_stride,
        return_overflowing_tokens=True,
        return_offsets_mapping=True,
        padding="max_length",
    )
    sample_mapping = tokenized_examples.pop("overflow_to_sample_mapping")
    offset_mapping = tokenized_examples.pop("offset_mapping")
    tokenized_examples["start_positions"] = []
    tokenized_examples["end_positions"] = []
    for i, offsets in enumerate(offset_mapping):
        input_ids = tokenized_examples["input_ids"][i]
        cls_index = input_ids.index(tokenizer.cls_token_id)
        sequence_ids = tokenized_examples.sequence_ids(i)
        sample_index = sample_mapping[i]
        answers = examples["answers"][sample_index]
        start_char = answers["answer_start"][0]
        end_char = start_char + len(answers["text"][0])
        token_start_index = next(i for i, sid in enumerate(sequence_ids) if sid == (1 if pad_on_right else 0))
        token_end_index = next(i for i, sid in enumerate(sequence_ids[::-1]) if sid == (1 if pad_on_right else 0))
        token_start_index = min(token_start_index, len(offsets) - 1)
        token_end_index = len(offsets) - 1 - token_end_index
        if not (offsets[token_start_index][0] <= start_char and offsets[token_end_index][1] >= end_char):
            tokenized_examples["start_positions"].append(cls_index)
            tokenized_examples["end_positions"].append(cls_index)
        else:
            while token_start_index < len(offsets) and offsets[token_start_index][0] <= start_char:
                token_start_index += 1
            tokenized_examples["start_positions"].append(token_start_index - 1)

            while offsets[token_end_index][1] >= end_char:
                token_end_index -= 1
            tokenized_examples["end_positions"].append(token_end_index + 1)
    return tokenized_examples

In [17]:
# Features generation
features = prepare_train_features(train_data['train'][:5])
train_tokenized_datasets = train_data.map(prepare_train_features, batched=True, remove_columns=train_data['train'].column_names)

Map:   0%|          | 0/69392 [00:00<?, ? examples/s]

In [18]:
import math
from transformers.modeling_outputs import QuestionAnsweringModelOutput
from transformers import DistilBertPreTrainedModel, DistilBertModel
from torch import nn
import torch

def gelu(x):
    '''
    Original Implementation of the GELU activation function in Google BERT repo when initially created.
    '''
    return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))

def gelu_new(x):
    '''
    Implementation of the GELU activation function currently in Google BERT repo.
    '''
    return 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0))))

class DistilBertForQuestionAnswering(DistilBertPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.distilbert = DistilBertModel(config)    #distilbert
        self.qa_outputs_0 = nn.Linear(config.dim, 512)  #additional head
        self.qa_outputs_1 = nn.Linear(512, 32)
        self.qa_outputs = nn.Linear(32, config.num_labels)
        assert config.num_labels == 2
        self.dropout = nn.Dropout(config.qa_dropout) #dropout
        self.LayerNorm = nn.LayerNorm(normalized_shape=[384, 2])  #output
        self.init_weights()
    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        head_mask=None,
        inputs_embeds=None,
        start_positions=None,
        end_positions=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        distilbert_output = self.distilbert(  #forward pass
            input_ids=input_ids,
            attention_mask=attention_mask,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        hidden_states = distilbert_output.last_hidden_state  #last hidden state
        hidden_states = self.dropout(hidden_states)  #dropout
        logits = gelu_new(self.qa_outputs_0(hidden_states))  #apply activation function
        logits = gelu_new(self.qa_outputs_1(logits))
        logits = self.qa_outputs(logits)
        logits = self.LayerNorm(logits) #layer normalization
        start_logits, end_logits = logits.split(1, dim=-1)
        start_logits = start_logits.squeeze(-1)  # (bs, max_query_len)
        end_logits = end_logits.squeeze(-1)  # (bs, max_query_len)
        total_loss = None #initialize
        if start_positions is not None and end_positions is not None: #calculate loss if start and end position are valid
            ignored_index = start_logits.size(1)
            start_positions.clamp_(0, ignored_index)
            end_positions.clamp_(0, ignored_index)
            loss_fct = nn.CrossEntropyLoss(ignore_index=ignored_index)
            start_loss = loss_fct(start_logits, start_positions)
            end_loss = loss_fct(end_logits, end_positions)
            total_loss = (start_loss + end_loss) / 2
        if not return_dict:
            output = (start_logits, end_logits) + distilbert_output[1:]
            return ((total_loss,) + output) if total_loss is not None else output
        return QuestionAnsweringModelOutput( #return output
            loss=total_loss,
            start_logits=start_logits,
            end_logits=end_logits,
            hidden_states=distilbert_output.hidden_states,
            attentions=distilbert_output.attentions
        )

In [19]:
#instantiate the model
from transformers import TrainingArguments, Trainer
model = DistilBertForQuestionAnswering.from_pretrained(model_checkpoint)
if torch.cuda.is_available():
  model.cuda()

model.safetensors:   0%|          | 0.00/268M [00:00<?, ?B/s]

Some weights of DistilBertForQuestionAnswering were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['qa_outputs.weight', 'qa_outputs_1.weight', 'LayerNorm.weight', 'qa_outputs_0.bias', 'qa_outputs.bias', 'LayerNorm.bias', 'qa_outputs_1.bias', 'qa_outputs_0.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [20]:
# training hyperparameters
batch_size = 32
args = TrainingArguments(
    output_dir='./results',
    save_total_limit=5,
    evaluation_strategy="epoch",
    learning_rate=5e-5,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    num_train_epochs=3,
    weight_decay=0.01,
    logging_dir='./logs',
    label_names=["start_positions", "end_positions"]
)

In [21]:
#use data collator to handle different types of data and batch them properly during training
from transformers import default_data_collator
data_collator = default_data_collator

In [22]:
'''
If there is more than 1 prediction for a query then the query with higher similarity score is outputted as the prediction
'''
import collections
from tqdm import tqdm, tqdm_notebook
def postprocess_qa_predictions(examples: datasets.arrow_dataset.Dataset,
                               features: datasets.arrow_dataset.Dataset,
                               raw_predictions: tuple,
                               n_best_size: int = 20,
                               max_answer_length: int = 50) -> collections.OrderedDict:
    '''
    Function used to select the best answer from the raw predictions
    '''
    all_start_logits, all_end_logits = raw_predictions  #unpack

    example_id_to_index = {k: i for i, k in enumerate(examples["id"])} #map id to index
    features_per_example = collections.defaultdict(list)
    for i, feature in enumerate(features): #group features
        features_per_example[example_id_to_index[feature["example_id"]]].append(i)

    predictions = collections.OrderedDict() #dictionary to store final predictions
    for example_index, example in enumerate(tqdm(examples)):
        feature_indices = features_per_example[example_index]
        valid_answers = []
        context = example["context"] #get context
        for feature_index in feature_indices:
            start_logits = all_start_logits[feature_index]  #get start and end logits for current feature
            end_logits = all_end_logits[feature_index]
            offset_mapping = features[feature_index]["offset_mapping"] #get offset mapping for current feature
            cls_index = features[feature_index]["input_ids"].index(tokenizer.cls_token_id)
            feature_null_score = start_logits[cls_index] + end_logits[cls_index]  #calculate feature null score
            start_indexes = np.argsort(start_logits)[-1 : -n_best_size - 1 : -1].tolist() #indices of top-n start and end logits
            end_indexes = np.argsort(end_logits)[-1 : -n_best_size - 1 : -1].tolist()
            for start_index in start_indexes:
                for end_index in end_indexes:
                    if (        #invalid indices or null offset mappings check
                        start_index >= len(offset_mapping)
                        or end_index >= len(offset_mapping)
                        or offset_mapping[start_index] is None
                        or offset_mapping[end_index] is None
                    ):
                        continue
                    if end_index < start_index or end_index - start_index + 1 > max_answer_length: #invalid answer span or exceeding maximum answer length check
                        continue
                    start_char = offset_mapping[start_index][0]
                    end_char = offset_mapping[end_index][1]
                    valid_answers.append(  #list of valid answers
                        {
                            "score": start_logits[start_index] + end_logits[end_index],
                            "text": context[start_char: end_char]
                        }
                    )

        if len(valid_answers) > 0:
            best_answer = sorted(valid_answers, key=lambda x: x["score"], reverse=True)[0] #answer with highest score is best answer
        else:
            best_answer = {"text": "", "score": 0.0}  #if no valid answers then empty string
        predictions[example["id"]] = best_answer["text"]
    return predictions

In [23]:
def prepare_validation_features(examples: collections.OrderedDict or dict) -> transformers.tokenization_utils_base.BatchEncoding:
    '''
    To check a given span is inside the context (and not the question) and to get back the text inside.
    '''
    tokenized_examples = tokenizer(                          #tokenize
        examples["question" if pad_on_right else "context"],
        examples["context" if pad_on_right else "question"],
        truncation="only_second" if pad_on_right else "only_first",
        max_length=max_length,
        stride=doc_stride,
        return_overflowing_tokens=True,
        return_offsets_mapping=True,
        padding="max_length",
    )
    sample_mapping = tokenized_examples.pop("overflow_to_sample_mapping")
    offset_mapping = tokenized_examples["offset_mapping"]
    tokenized_examples["start_positions"] = []  #lists to store start and end positions
    tokenized_examples["end_positions"] = []
    for i, offsets in enumerate(offset_mapping):

        input_ids = tokenized_examples["input_ids"][i]
        cls_index = input_ids.index(tokenizer.cls_token_id) # CLS index
        sequence_ids = tokenized_examples.sequence_ids(i)   #sequence ids from tokenized examples
        sample_index = sample_mapping[i]
        answers = examples["answers"][sample_index]
        start_char = answers["answer_start"][0]    #start and end character positions from answers
        end_char = start_char + len(answers["text"][0])
        token_start_index = 0   #token start and end indices
        while sequence_ids[token_start_index] != (1 if pad_on_right else 0):
            token_start_index += 1
        token_end_index = len(input_ids) - 1
        while sequence_ids[token_end_index] != (1 if pad_on_right else 0):
            token_end_index -= 1
        if not (offsets[token_start_index][0] <= start_char and offsets[token_end_index][1] >= end_char):
            tokenized_examples["start_positions"].append(cls_index)
            tokenized_examples["end_positions"].append(cls_index)
        else:
            while token_start_index < len(offsets) and offsets[token_start_index][0] <= start_char:
                token_start_index += 1
            tokenized_examples["start_positions"].append(token_start_index - 1)
            while offsets[token_end_index][1] >= end_char:
                token_end_index -= 1
            tokenized_examples["end_positions"].append(token_end_index + 1)
    tokenized_examples["example_id"] = []
    for i in range(len(tokenized_examples["input_ids"])):
        sequence_ids = tokenized_examples.sequence_ids(i)
        context_index = 1 if pad_on_right else 0
        sample_index = sample_mapping[i]
        tokenized_examples["example_id"].append(examples["id"][sample_index])
        tokenized_examples["offset_mapping"][i] = [     #modify offset mappings based on context_index
            (o if sequence_ids[k] == context_index else None)
            for k, o in enumerate(tokenized_examples["offset_mapping"][i])
        ]
    return tokenized_examples

In [24]:
validation_features = val_data['train'].map(prepare_validation_features, batched=True, remove_columns=val_data['train'].column_names)

Map:   0%|          | 0/18207 [00:00<?, ? examples/s]

In [25]:
metric = datasets.load_metric("squad")

  metric = datasets.load_metric("squad")


Downloading builder script:   0%|          | 0.00/1.72k [00:00<?, ?B/s]

Downloading extra modules:   0%|          | 0.00/1.11k [00:00<?, ?B/s]

In [26]:
def compute_metrics(pred: transformers.trainer_utils.EvalPrediction) -> dict:
    # The Trainer hides the columns that are not used by the model (here example_id and offset_mapping which we will need for our post-processing), so we set them back
    validation_features.set_format(type=validation_features.format["type"], columns=list(validation_features.features.keys()))
    # To get the final predictions we can apply our post-processing function to our raw predictions
    final_predictions = postprocess_qa_predictions(val_data['train'], validation_features, pred.predictions)
    # We just need to format predictions and labels a bit as metric expects a list of dictionaries and not one big dictionary
    formatted_predictions = [{"id": k, "prediction_text": v} for k, v in final_predictions.items()]
    references = [{"id": ex["id"], "answers": ex["answers"]} for ex in val_data["train"]]
    validation_features.set_format(type=validation_features.format["type"], columns=['attention_mask', 'end_positions', 'input_ids', 'start_positions'])
    metrics = metric.compute(predictions=formatted_predictions, references=references)
    return metrics

In [27]:
#instantiate Trainer and begin training
trainer = Trainer(
    model,
    args,
    compute_metrics=compute_metrics,
    train_dataset=train_tokenized_datasets["train"],
    eval_dataset=validation_features,
    data_collator=data_collator,
    tokenizer=tokenizer
)
trainer.train()

Epoch,Training Loss,Validation Loss,Exact Match,F1
1,1.2794,1.298775,58.807052,74.397433


100%|██████████| 18207/18207 [00:58<00:00, 310.77it/s]


Epoch,Training Loss,Validation Loss,Exact Match,F1
1,1.2794,1.298775,58.807052,74.397433
2,0.872,1.264256,59.916516,75.594801
3,0.5943,1.352394,60.279014,76.038806


100%|██████████| 18207/18207 [00:56<00:00, 322.22it/s]
100%|██████████| 18207/18207 [00:59<00:00, 306.64it/s]


TrainOutput(global_step=6576, training_loss=1.006303279649312, metrics={'train_runtime': 8052.4091, 'train_samples_per_second': 26.126, 'train_steps_per_second': 0.817, 'total_flos': 2.081373570268877e+16, 'train_loss': 1.006303279649312, 'epoch': 3.0})

In [28]:
#save trained model
trainer.save_model(data_path + "test-squad-trained")