<a href="https://colab.research.google.com/github/IVN-RIN/bio-med-BIT/blob/main/notebooks/BioBIT_Question_Answering.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# **BioBIT Fine-Tuning Experiment For <u>Question Answering</u>**

*Tommaso M Buonocore, University of Pavia, 2022*

*Last edited: 28/11/2022*

*Related paper: [Localising In-Domain Adaptation of Transformer-Based Biomedical Language Models](https://www.medrxiv.org/content/XXXXXX)*

# Initialization

Short string describing the current run

In [None]:
# Label for this run; it is stored in session_info['experiment_name'] by the "Session info" cell.
experiment_name = "bioasq4a_base_QA"

## Imports

In [None]:
%%capture
# If running on colab, install first
!pip3 install datasets transformers evaluate seqeval

# Google Colab only
from IPython.display import display, HTML
from google.colab import files

# General
import random
import pandas as pd
import numpy as np
from torch import cuda
import os
from io import StringIO
import json
from uuid import uuid4
import time

# HuggingFace Transformers
import transformers
from transformers import AutoTokenizer, AutoModelForTokenClassification, DataCollatorForTokenClassification, TrainingArguments, Trainer, EarlyStoppingCallback, set_seed
from datasets import load_dataset, load_metric, ClassLabel, Sequence, DatasetDict, Features, Value, Sequence, ClassLabel, Dataset
from evaluate import load

# Raw dataset: a single JSON file holding the train/test/dev splits.
input_file_path = "BioASQ_4b_splitted.json"

# Use the CUDA GPU when torch can see one, otherwise fall back to the CPU.
if cuda.is_available():
  device = 'cuda'
else:
  device = 'cpu'

## Session info

In [None]:
# Collect metadata about this session (public IP info, GPUs, start time, run label)
# so that it can be saved next to the results at the end of the notebook.
try:
  # The original command repeated the word "curl", which made curl also try to
  # fetch a host literally named "curl". `-s` silences curl's progress meter.
  session_info = json.loads(os.popen("curl -s ipinfo.io").read())
except json.JSONDecodeError:
  # No network or an unexpected reply: keep going with an empty record instead of aborting the run.
  session_info = {}
if device == 'cuda':
  # nvidia-smi prints a CSV with one header line followed by one row per GPU.
  gpu_info = pd.read_csv(StringIO(os.popen("nvidia-smi --query-gpu=gpu_name,memory.total --format=csv").read()), names=["name", "memory"], header=0)
  session_info['gpus'] = [{'name': row["name"], 'memory': row["memory"]} for _, row in gpu_info.iterrows()]
else:
  session_info['gpus'] = []
session_info['time_start'] = time.strftime("%H:%M:%S", time.localtime())
session_info['experiment_name'] = experiment_name
# Last expression of the cell: the notebook displays the collected info.
session_info

# Data Preprocessing

Input: json file with train, test, dev splits

Expected output (one record per example):

      {
        'answers': {
            'answer_start': [(int)] --> the char position where the answer starts in the context, provided as an array
            'text': [(str)] --> the answer as a string, provided as an array
        },
        'context': (str) --> the context as a string
        'id': (str) --> unique id
        'question': (str) --> the question as a string
        'title': (str) --> the title of the source document as a string
      }


The questions have been machine-translated, so the original character indices of the answers are lost. We recover them by searching for the answer inside the context, taking the first match when there are several.

FUTURE IMPLEMENTATION: if original_char_index is available --> get the match that is closer to original_char_index

In [None]:
def get_answer_start(answer, context):
  """Return the character offset of `answer` inside `context`, or None.

  Matching is attempted in three passes, stopping at the first hit:
    1. exact match;
    2. `answer` with its first character lowercased (automatic translation
       can generate an uppercase first letter);
    3. case-insensitive match (both strings lowercased).
  When several occurrences exist, the first one is returned.

  Args:
    answer: the answer text to locate.
    context: the text that is expected to contain the answer.

  Returns:
    The index of the first match, or None when no pass succeeds (in which
    case a warning and the context are printed).
  """
  # `str.find` signals "not found" with -1, so no exception handling is needed
  # and unrelated errors (or a KeyboardInterrupt) are no longer swallowed.
  idx = context.find(answer)
  if idx == -1:
    # `answer[:1]` (rather than `answer[0]`) is also safe for an empty string.
    idx = context.find(answer[:1].lower() + answer[1:])
  if idx == -1:
    idx = context.lower().find(answer.lower())
  if idx == -1:
    # give up
    print("Warning: char index not found in provided context for '"+answer+"'")
    print(context)
    return None
  return idx

Parse the raw json to be compliant with the schema above. No info about documents provided --> replace document name with random ids

In [None]:
import json
from uuid import uuid4
import os

ENCODING = 'utf-8'

def parse(file_path):
    """Convert the raw multi-split JSON into one SQuAD-style JSON file per split.

    The input file must contain the keys "train", "test" and "dev", each a
    list of examples with 'answer', 'context' and 'question' fields. For every
    split a file named `<input stem>-<split><input extension>` is written next
    to the input, shaped as `{'data': [record, ...]}`.

    Args:
        file_path: path of the raw JSON file.
    """
    # Context managers guarantee both files are closed even if parsing fails.
    with open(file_path, encoding=ENCODING) as f:
        json_data = json.load(f)

    root, ext = os.path.splitext(file_path)
    for split in ["train", "test", "dev"]:
        records = [
            {
                'answers': {
                    # NOTE(review): get_answer_start returns None when the answer cannot be
                    # located, so 'answer_start' may be [None]; consumers must cope with that.
                    'answer_start': [get_answer_start(example['answer'], example['context'])],
                    'text': [example['answer']]
                },
                'context': example['context'],
                # No document information is available: ids and titles are random UUIDs.
                'id': str(uuid4()),
                'question': example['question'],
                'title': str(uuid4())
            }
            for example in json_data[split]
        ]
        # Save the split as a new json
        with open(root + '-' + split + ext, 'w', encoding=ENCODING) as outfile:
            json.dump({'data': records}, outfile)

Load parsed json as a Huggingface dataset

In [None]:
# Write the per-split JSON files, then load them as a Hugging Face dataset
# with "train", "test" and "dev" splits (records live under the "data" field).
parse(input_file_path)
parsed_prefix = os.path.splitext(input_file_path)[0]
dataset = load_dataset(
    "json",
    data_files={split: f"{parsed_prefix}-{split}.json" for split in ("train", "test", "dev")},
    field="data",
)

In [None]:
dataset

# Training

In [None]:
# Mount Google Drive at /content/gdrive; the local entries of `model_checkpoints` are paths under this mount point.
from google.colab import drive
drive.mount('/content/gdrive')

In [None]:
# Checkpoints to fine-tune: one model id (baseline) plus local paths on the mounted Drive.
# Commented-out entries are alternatives that are currently disabled.
model_checkpoints = [
    "dbmdz/bert-base-italian-xxl-cased", # Baseline
    "/content/gdrive/MyDrive/Colab Environments/biobert_models/bio-full", # BioBERT
    #"/content/gdrive/MyDrive/Colab Environments/biobert_models/med-reg-v3", # Best model w/o corpus augmentation
    #"/content/gdrive/MyDrive/Colab Environments/biobert_models/med-reg-v12", # Best ER model w/ corpus augmentation
    "/content/gdrive/MyDrive/Colab Environments/biobert_models/med-reg-v3-enriched" # Best MIXOUT model w/ corpus augmentation
    ]
# Random seeds: the training loop fine-tunes every checkpoint once per enabled seed.
seeds = [
    #3407, 
    #6, 
    #11, 
    61, 
    #39
    ]

# These can be changed according to the downstream dataset. The only important thing is that they remain consistent for *ALL* the models.
# NOTE(review): the training loop in this notebook builds TrainerQA with only `checkpoint`, `seed` and `output_dir`,
# so these four values are not forwarded and the class defaults (BATCH_SIZE, LEARNING_RATE, EPOCHS, WEIGHT_DECAY)
# are what actually apply. Note learning_rate here is 3e-5 while LEARNING_RATE is 1e-5 - confirm which is intended.
batch_size = 10
learning_rate = 3e-5
epochs=20
weight_decay=0.01

## Define metrics

In [None]:
#metric = load_metric("squad")
#tokenizer = AutoTokenizer.from_pretrained(CHECKPOINT)

def compute_metrics(p):
  """Compute token-position accuracies for extractive QA predictions.

  Args:
    p: a pair `(logits, labels)`. `logits` must stack into an array of shape
      (2, n_examples, seq_len) and `labels` into shape (2, n_examples).
      The two leading entries are assumed to be ordered (start, end), as the
      Trainer is expected to provide them for a QA model - TODO confirm.

  Returns:
    A dict with:
      'acc': fraction of examples whose predicted start AND end tokens both
        equal the labels (strict accuracy);
      'acc_no_cls': same, restricted to examples whose start label is not 0
        (0 is the position used for spans that do not contain the answer);
        NaN when there is no such example;
      'acc_approx': fraction of examples whose total start+end position error
        is within `tolerance` tokens and whose predicted end is not before
        the predicted start.
  """
  logits, labels = p
  predictions = np.argmax(np.asarray(logits), axis=2)  # row 0: start tokens, row 1: end tokens
  labels = np.asarray(labels)
  diff = predictions - labels
  tot = diff.shape[1]
  tolerance = 10

  # A span is impossible only when it ends before it starts; end == start is a
  # valid one-token answer, so the comparison must be inclusive.
  legit_predictions = predictions[1] >= predictions[0]
  labels_cls = labels[0] != 0

  # Start and end are checked separately: summing the signed differences would
  # let opposite errors (e.g. +2 on start, -2 on end) cancel into a false match.
  exact_match = np.all(diff == 0, axis=0)  # a.k.a. strict accuracy
  approx_match = np.sum(np.abs(diff), axis=0) <= tolerance

  # accuracy = exact match of end tokens and start tokens
  accuracy = np.sum(exact_match) / tot
  # accuracy no cls = exact match without considering [CLS] labels. Useful when contexts longer than the
  # model input are split by the doc stride: sub-contexts without the answer are labelled with position 0.
  n_answerable = np.sum(labels_cls)
  if n_answerable:
    accuracy_not_cls = np.sum(np.logical_and(exact_match, labels_cls)) / n_answerable
  else:
    # Avoid a 0/0 division (and its RuntimeWarning) when no example is answerable.
    accuracy_not_cls = float("nan")
  # accuracy_approx = predicted start/end tokens are close to the actual ones. Predictions that pass the
  # proximity check but end before they start are filtered out because such a span is impossible
  # (exact matches never have this problem).
  accuracy_approx = np.sum(np.logical_and(approx_match, legit_predictions)) / tot

  return {'acc': accuracy, 'acc_no_cls': accuracy_not_cls, 'acc_approx': accuracy_approx}

## QA Trainer Class Definition

In [None]:
from torch import argmax, softmax
from transformers import AutoTokenizer, AutoModelForQuestionAnswering, TrainingArguments, default_data_collator, Trainer
import re
from datasets import load_metric
import numpy as np

# ==== PREDEFINED SETTINGS ====
# Default values for TrainerQA / ModelQA; constructor and method arguments override them.
SEED = 666  # default seed for TrainingArguments when the caller does not pass one
CHECKPOINT = 'mrm8488/bert-italian-finedtuned-squadv1-it-alfa'  # default checkpoint for TrainerQA
BATCH_SIZE = 10  # used for both the train and the eval per-device batch size
MAX_LENGTH = 384  # The maximum length of a feature (question and context)
DOC_STRIDE = 128  # The authorized overlap between two parts of the context when splitting is needed.
LEARNING_RATE = 1e-5
EPOCHS = 20  # passed as num_train_epochs
PATIENCE = 4  # early_stopping_patience handed to EarlyStoppingCallback
WEIGHT_DECAY = 0.01
OUTPUT_DIR = "model"  # default output_dir for TrainingArguments
METRIC_BEST = "acc"  # metric_for_best_model; "acc" is one of the keys returned by compute_metrics

In [None]:
# ==== TRAINER QA ====
class TrainerQA(object):
    """Utilities to fine-tune a Transformer checkpoint for extractive QA.

    Typical use: build the object, call `load_dataset` with a dataset that has
    "train" and "dev" splits, then call `train`.
    """

    def __init__(self,
                 checkpoint=CHECKPOINT,
                 batch_size=BATCH_SIZE,
                 epochs=EPOCHS,
                 seed=SEED,
                 learning_rate=LEARNING_RATE,
                 weight_decay=WEIGHT_DECAY,
                 output_dir=OUTPUT_DIR):
        """Load tokenizer and model from `checkpoint` and build the training arguments.

        Evaluation, logging and checkpoint saving all happen once per epoch, at
        most 5 checkpoints are kept, and the best model according to
        METRIC_BEST is reloaded when training ends.
        """
        self._tokenizer = AutoTokenizer.from_pretrained(checkpoint)
        self._model = AutoModelForQuestionAnswering.from_pretrained(checkpoint)
        self._args = TrainingArguments(
            output_dir=output_dir,
            evaluation_strategy="epoch",
            logging_strategy="epoch",
            save_strategy="epoch",
            save_total_limit=5,
            load_best_model_at_end=True,
            metric_for_best_model=METRIC_BEST,
            learning_rate=learning_rate,
            per_device_train_batch_size=batch_size,
            per_device_eval_batch_size=batch_size,
            num_train_epochs=epochs,
            weight_decay=weight_decay,
            seed=seed
        )
        self._data_collator = default_data_collator
        self._tokenized_datasets = None
        self.__debug_mode = False

    def __prepare_train_features(self, examples, max_length=MAX_LENGTH, doc_stride=DOC_STRIDE):
        """Tokenize a batch of examples and label each feature with its answer span.

        Args:
            examples: batch with "question", "context" and "answers" columns;
                every `answers` entry holds 'answer_start' and 'text' lists.
            max_length: maximum number of tokens per feature (question + context).
            doc_stride: token overlap between consecutive features of a long context.

        Returns:
            The tokenizer output with added "start_positions" and "end_positions"
            (token indices; the CLS index when the feature does not contain the answer).
        """
        pad_on_right = self._tokenizer.padding_side == "right"
        # The context is sequence 1 when the question comes first, sequence 0 otherwise.
        context_sequence_id = 1 if pad_on_right else 0
        # Tokenize our examples with truncation and padding, but keep the overflows using a stride. This results
        # in one example possibly giving several features when a context is long, each of those features having a
        # context that overlaps a bit the context of the previous feature.
        tokenized_examples = self._tokenizer(
            examples["question" if pad_on_right else "context"],
            examples["context" if pad_on_right else "question"],
            truncation="only_second" if pad_on_right else "only_first",
            max_length=max_length,
            stride=doc_stride,
            return_overflowing_tokens=True,
            return_offsets_mapping=True,
            padding="max_length",
        )

        # Since one example might give us several features if it has a long context, we need a map from a feature to
        # its corresponding example. This key gives us just that.
        sample_mapping = tokenized_examples.pop("overflow_to_sample_mapping")
        # The offset mappings will give us a map from token to character position in the original context. This will
        # help us compute the start_positions and end_positions.
        offset_mapping = tokenized_examples.pop("offset_mapping")

        # Let's label those examples!
        tokenized_examples["start_positions"] = []
        tokenized_examples["end_positions"] = []

        for i, offsets in enumerate(offset_mapping):
            # We will label impossible answers with the index of the CLS token.
            input_ids = tokenized_examples["input_ids"][i]
            cls_index = input_ids.index(self._tokenizer.cls_token_id)

            # Grab the sequence corresponding to that example (to know what is the context and what is the question).
            sequence_ids = tokenized_examples.sequence_ids(i)

            # One example can give several spans, this is the index of the example containing this span of text.
            sample_index = sample_mapping[i]
            answers = examples["answers"][sample_index]
            answer_starts = answers["answer_start"]
            # If no answer is given, or its character offset could not be located (None), set the cls_index
            # as answer. Without the None check the arithmetic below would raise a TypeError.
            if len(answer_starts) == 0 or answer_starts[0] is None:
                tokenized_examples["start_positions"].append(cls_index)
                tokenized_examples["end_positions"].append(cls_index)
                continue

            # Start/end character index of the answer in the text.
            start_char = answer_starts[0]
            end_char = start_char + len(answers["text"][0])

            # Start token index of the current span in the text.
            token_start_index = 0
            while sequence_ids[token_start_index] != context_sequence_id:
                token_start_index += 1

            # End token index of the current span in the text.
            token_end_index = len(input_ids) - 1
            while sequence_ids[token_end_index] != context_sequence_id:
                token_end_index -= 1

            # Detect if the answer is out of the span (in which case this feature is labeled with the CLS index).
            if not (offsets[token_start_index][0] <= start_char and offsets[token_end_index][1] >= end_char):
                tokenized_examples["start_positions"].append(cls_index)
                tokenized_examples["end_positions"].append(cls_index)
            else:
                # Otherwise move the token_start_index and token_end_index to the two ends of the answer.
                # Note: we could go after the last offset if the answer is the last word (edge case).
                while token_start_index < len(offsets) and offsets[token_start_index][0] <= start_char:
                    token_start_index += 1
                tokenized_examples["start_positions"].append(token_start_index - 1)
                while offsets[token_end_index][1] >= end_char:
                    token_end_index -= 1
                tokenized_examples["end_positions"].append(token_end_index + 1)

        return tokenized_examples

    def load_dataset(self, dataset):
        """Tokenize every split of `dataset` and keep the result for `train`.

        The original columns (taken from the "train" split) are dropped, so
        only the tokenizer outputs and the span labels remain.
        """
        self._tokenized_datasets = dataset.map(self.__prepare_train_features, batched=True, remove_columns=dataset["train"].column_names)

    def train(self, save_path):
        """Fine-tune on the "train" split, evaluating on "dev", then save the model.

        Training stops early when the monitored metric has not improved for
        PATIENCE evaluations. `load_dataset` must have been called first.

        Args:
            save_path: directory where the final model is saved.
        """
        trainer = Trainer(
            self._model,
            self._args,
            train_dataset=self._tokenized_datasets["train"],
            eval_dataset=self._tokenized_datasets["dev"],
            # Use the collator stored at construction time (same object as default_data_collator).
            data_collator=self._data_collator,
            tokenizer=self._tokenizer,
            compute_metrics=compute_metrics,
            callbacks=[EarlyStoppingCallback(early_stopping_patience=PATIENCE)]
        )
        trainer.train()
        trainer.save_model(save_path)

In [None]:
# ==== MODEL QA ====
class ModelQA(object):
    """Utilities to query a fine-tuned extractive QA Transformer stored on disk."""

    def __init__(self, model_path):
        """Load tokenizer and model from the local directory `model_path`."""
        self._tokenizer = AutoTokenizer.from_pretrained(model_path, local_files_only=True)
        self._model = AutoModelForQuestionAnswering.from_pretrained(model_path, local_files_only=True)
        self._shared_context = None
        self.__debug_mode = False

    @property
    def tokenizer(self):
        return self._tokenizer

    @property
    def model(self):
        return self._model

    @property
    def shared_context(self):
        return self._shared_context

    @shared_context.setter
    def shared_context(self, value):
        self._shared_context = value

    # Methods
    def toggle_debug(self):
        """Flip the internal debug flag."""
        self.__debug_mode = not self.__debug_mode

    # Private helpers
    def _resolve_context(self, context):
        """Return `context`, falling back to the shared context when it is None."""
        if context is not None:
            return context
        if self._shared_context is None:
            raise Exception("Shared context not provided")
        return self._shared_context

    def _forward(self, inputs):
        """Run the model on `inputs` without tracking gradients (inference only)."""
        from torch import no_grad  # local import keeps this class self-contained
        with no_grad():
            return self._model(**inputs)

    def _extract_answer(self, input_ids, outputs):
        """Decode the most likely answer span of a single-example model output.

        Returns a tuple `(answer, [start_prob, end_prob], [start, end])`, where
        `start` is the argmax of the start logits and `end` is the argmax of the
        end logits plus one (exclusive slice bound).
        """
        answer_start_scores = outputs.start_logits
        answer_end_scores = outputs.end_logits
        answer_start = int(argmax(answer_start_scores))  # most likely beginning of the answer
        answer_end = int(argmax(answer_end_scores)) + 1  # most likely end of the answer
        # WARNING: convert_tokens_to_string automatically adds " " after special tokens! e.g. 1.2 mV --> 1. 2 mV
        answer = self._tokenizer.convert_tokens_to_string(self._tokenizer.convert_ids_to_tokens(input_ids[answer_start:answer_end]))
        # Probability assigned to the chosen start / end tokens.
        start_prob = softmax(answer_start_scores, dim=1)[0].max().item()
        end_prob = softmax(answer_end_scores, dim=1)[0].max().item()
        return answer, [start_prob, end_prob], [answer_start, answer_end]

    @staticmethod
    def _print_answer(question, answer, probs):
        """Print one question/answer pair with its start/end token probabilities."""
        print("-" * 30)
        print(f"Question: {question}")
        print(f"Answer: {answer}")
        print(f'Start token probability: {probs[0]:.2f}')
        print(f'End token probability: {probs[1]:.2f}')

    def ask_question_stride(self, question="", context=None, print_results=True, max_length=MAX_LENGTH, doc_stride=DOC_STRIDE):
        """Answer `question` on a long context by splitting it into overlapping windows.

        Args:
            question: the question text.
            context: the context text; the shared context is used when None.
            print_results: when True, print one answer per window; when False,
                collect them in the returned dict.
            max_length: maximum number of tokens per window (question + context).
            doc_stride: token overlap between consecutive windows.

        Returns:
            A dict with parallel lists 'answer', 'probs' and 'indices', one entry
            per window. The lists are left empty when `print_results` is True.

        Raises:
            Exception: if `context` is None and no shared context is set.
        """
        context = self._resolve_context(context)
        pad_on_right = self._tokenizer.padding_side == "right"
        # Tokenize with truncation and padding, but keep the overflows using a stride. One long context
        # therefore yields several features, each overlapping a bit with the previous one.
        tokenized_examples = self._tokenizer(
            question if pad_on_right else context,
            context if pad_on_right else question,
            truncation="only_second" if pad_on_right else "only_first",
            max_length=max_length,
            stride=doc_stride,
            add_special_tokens=True,
            return_overflowing_tokens=True,
            return_offsets_mapping=True,
            return_tensors="pt",
            padding="max_length",
        )
        # Only forward what the model consumes; not every tokenizer emits token_type_ids.
        model_keys = [key for key in ('input_ids', 'token_type_ids', 'attention_mask') if key in tokenized_examples]

        answers = {'answer': [], 'probs': [], 'indices': []}
        n_windows = len(tokenized_examples['input_ids'])
        for i in range(n_windows):
            # Slicing with i:i+1 keeps the batch dimension. The original code called
            # `stack`, which is never imported and raised a NameError.
            inputs = {key: tokenized_examples[key][i:i + 1] for key in model_keys}
            input_ids = inputs["input_ids"].tolist()[0]
            answer, probs, indices = self._extract_answer(input_ids, self._forward(inputs))

            if print_results:
                self._print_answer(question, answer, probs)
            else:
                answers["answer"].append(answer)
                answers["probs"].append(probs)
                answers["indices"].append(indices)
        return answers

    def ask_question(self, question="", context=None, print_results=True, return_topK = 0):
        """Answer `question` on `context` with a single forward pass.

        The input is truncated by the tokenizer when it is too long.

        Args:
            question: the question text.
            context: the context text; the shared context is used when None.
            print_results: when True, print the answer and return None; when
                False, return the result dict.
            return_topK: accepted for interface compatibility only. Top-k span
                extraction is not implemented, so 'top_k' is always None.

        Returns:
            None when `print_results` is True, otherwise a dict with 'answer',
            'probs' ([start_prob, end_prob]), 'indices' ([start, end]) and 'top_k'.

        Raises:
            Exception: if `context` is None and no shared context is set.
        """
        context = self._resolve_context(context)

        inputs = self._tokenizer(question, context, add_special_tokens=True, return_tensors="pt", truncation=True)
        input_ids = inputs["input_ids"].tolist()[0]
        answer, probs, indices = self._extract_answer(input_ids, self._forward(inputs))
        topk_indices = None

        if print_results:
            self._print_answer(question, answer, probs)
        else:
            return {'answer': answer, 'probs': probs, 'indices': indices, 'top_k': topk_indices}

    #   tokens --> lower text and remove punctuation (ENGLISH!), articles and extra whitespace
    #   precision = 1.0 * num_same / len(prediction_tokens)
    #   recall = 1.0 * num_same / len(ground_truth_tokens)
    #   f1 = (2 * precision * recall) / (precision + recall)
    #   exact_match = prediction == ground_truth
    #   warning: iterate over single prediction-reference pairs, do not pass the whole test set to compute! Flawed somewhere
    def validate(self, dataset):
        """Score the model on `dataset` with the "squad" metric.

        Every example is answered with `ask_question`. Whitespace that follows
        "<digit>." or "<digit>," and precedes a digit is removed from the
        prediction, undoing the space the tokenizer inserts inside numbers.

        Args:
            dataset: iterable of examples with 'id', 'question', 'context' and 'answers'.

        Returns:
            A dict with 'predictions', 'references' and 'metric' (the metric output).
        """
        squad_metric = load("squad")
        formatted_predictions = []
        references = []
        for ex in dataset:
            raw_answer = self.ask_question(ex['question'], ex['context'], print_results=False)['answer']
            formatted_predictions.append({
                "id": str(ex['id']),
                "prediction_text": re.sub("(?<=\\d(\\.|\\,))\\s+(?=\\d)", "", raw_answer),
            })
            references.append({"id": str(ex['id']), "answers": ex['answers']})
        return {
            'predictions': formatted_predictions,
            'references': references,
            'metric': squad_metric.compute(predictions=formatted_predictions, references=references)
        }


    ########DEBUG

    # #strict accuracy, lenient accuracy, mean reciprocal rank
    # def validate2(self, dataset):
    #     metric = load_metric("squad")

    #     #1) strict accuracy (if tol = 0)
    #     def idx_match(reference, prediction,tolerance=0):
    #       ref_start = reference["answer_start"][0]
    #       ref_end = ref_start+len(reference["text"][0])
    #       pred_start = prediction[0]
    #       pred_end = prediction[1]
    #       tokens_diff = (ref_start-pred_start)+(ref_end-pred_end)
    #       match = abs(tokens_diff)<=tolerance
    #     matches = [idx_match(self.ask_question(ex['question'], ex['context'], print_results=False)['indices'],
    #                          ex['answers'],
    #                          tolerance=0) for ex in dataset]
    #     strict_acc = np.sum(matches)/len(matches)

    #     #2) lenient accuracy
    #     #if there's a match in at least one of the top k cases
    #     def idx_match_k(reference, top_ks, tolerance=0):
    #       ref_start = reference["answer_start"][0]
    #       ref_end = ref_start+len(reference["text"][0])
    #       match = False
    #       for top_k in top_ks:
    #         pred_start = top_k[0]
    #         pred_end = top_k[1]
    #         tokens_diff = (ref_start-pred_start)+(ref_end-pred_end)
    #         if abs(tokens_diff)<=tolerance:
    #           match = True
    #       return match
    #     matches = [idx_match_k(self.ask_question(ex['question'], ex['context'], print_results=False, return_topK=5)['top_k'],
    #                          ex['answers'],
    #                          tolerance=0) for ex in dataset]
    #     lenient_acc = np.sum(matches)/len(matches)

    #     #3) MRR
        
    #     def idx_match_mrr(reference, top_ks, tolerance=0):
    #       ref_start = reference["answer_start"][0]
    #       ref_end = ref_start+len(reference["text"][0])
    #       match_rank = 0
    #       for i in range(len(top_ks)):
    #         top_k = top_ks[i]
    #         pred_start = top_k[0]
    #         pred_end = top_k[1]
    #         tokens_diff = (ref_start-pred_start)+(ref_end-pred_end)
    #         if abs(tokens_diff)<=tolerance:
    #           match_rank = i
    #           break
    #       if match_rank==0:
    #         return 0
    #       else:
    #         return 1/match_rank
    #     matches = [idx_match_mrr(self.ask_question(ex['question'], ex['context'], print_results=False, return_topK=5)['top_k'],
    #                          ex['answers'],
    #                          tolerance=0) for ex in dataset]
    #     mrr = np.sum(matches)/len(matches)       

    #     return {
    #         'strict_acc': strict_acc,
    #         'lenient_acc': lenient_acc,
    #         'mrr': metric.compute(predictions=predictions, references=references)
    #     }
    ########DEBUG

## Training Loop

In [None]:
import shutil  

# Fine-tune every checkpoint once per seed, score it on the test split and
# export one CSV of results per checkpoint.
for model_checkpoint in model_checkpoints:
  checkpoint_name = os.path.basename(model_checkpoint)
  checkpoint_dir = f"/content/{checkpoint_name}_ft_QA"
  results_path = f'/content/results_{checkpoint_name}.csv'
  result_columns = ['f1', 'exact_match', 'seed']
  # Rows are collected in a list: DataFrame.append was removed in pandas 2.0.
  rows = []
  for seed in seeds:
    # Seed must be set before creating the model, otherwise the random head will be initialized in a different way every time and the results will not be replicable
    # From now on, the seed is set for *all* the random processes, including numpy, sklearn, etc...not only for transformers!
    set_seed(seed)

    output_dir = f"{checkpoint_dir}/{seed}"
    # NOTE(review): only checkpoint, seed and output_dir are passed here, so batch size, learning rate,
    # epochs and weight decay come from the TrainerQA defaults - confirm this is intended.
    trainer = TrainerQA(checkpoint=model_checkpoint, seed=seed, output_dir=output_dir)
    trainer.load_dataset(dataset)

    # Train the model
    trainer.train(output_dir)

    # Collect results on test set
    model = ModelQA(model_path=output_dir)
    results = model.validate(dataset['test'])
    rows.append({'f1': results['metric']['f1'], 'exact_match': results['metric']['exact_match'], 'seed': seed})
    df_results = pd.DataFrame(rows, columns=result_columns)
    display(df_results)

  # Built again here so that it exists even when `seeds` is empty.
  df_results = pd.DataFrame(rows, columns=result_columns)
  display(df_results)
  df_results.to_csv(results_path)
  files.download(results_path)

  #Free up memory for next checkpoint iteration
  shutil.rmtree(checkpoint_dir, ignore_errors=True)

Finalize session info and download

In [None]:
# Record what was run and when it finished, then save the session metadata as JSON and download it.
session_info.update({
    'checkpoints': [os.path.basename(c) for c in model_checkpoints],
    'seeds': seeds,
    'training_arguments': [],
    'time_end': time.strftime("%H:%M:%S", time.localtime()),
})

session_info_path = '/content/session_info.json'
with open(session_info_path, "w") as outfile:
    json.dump(session_info, outfile, indent=4)
files.download(session_info_path)

# Interactive (optional)

Remove `%%script echo skipping` to run

In [None]:
%%script echo skipping

import ipywidgets as widgets
def ask(context, question):
  """Answer `question` about `context` with the global `model` (a ModelQA).

  Prints the number of tokens of the encoded (question, context) pair. Inputs
  longer than 512 tokens are answered window by window with
  `ask_question_stride`, shorter ones with a single `ask_question` call.

  Returns:
    A tuple `(answer, probs)` taken from the model result.
  """
  context_size = len(model.tokenizer(question, context)["input_ids"])
  print(context_size)
  if context_size > 512:
    answer = model.ask_question_stride(question, context, print_results=False)
  else:
    # ask_question takes (question, context); the original call passed them swapped.
    answer = model.ask_question(question, context, print_results=False)
  return (answer["answer"], answer["probs"])
# Text boxes plus a run button: `ask` is called with their contents on demand.
int_widget = widgets.interact_manual(ask,
                                     context=widgets.Textarea('Context', layout=widgets.Layout(width='70%', height='100px')),
                                     question=widgets.Text('Question', layout=widgets.Layout(width='70%')))

# Unassign runtime to avoid wasting compute units

In [None]:
# Release the Colab runtime once the notebook has finished, so it stops consuming compute units.
from google.colab import runtime
runtime.unassign()