Here, we fine-tune a BERT-based deep learning model (BioBERT) for extractive question answering. Make sure to run the notebook on Google Colab (T4 GPU instance) or an equivalent RunPod or AWS GPU instance.

We will also visualize the training. The code was modified from: https://wandb.ai/mostafaibrahim17/ml-articles/reports/Extractive-Question-Answering-With-HuggingFace-Using-PyTorch-and-W-B--Vmlldzo0MzMwOTY5



In [1]:
!pip install -q transformers datasets accelerate wandb


In [13]:
import json
import pandas as pd
import torch
from transformers import (
    BertTokenizerFast,
    BertForQuestionAnswering,
    TrainingArguments,
    Trainer,
)
from sklearn.model_selection import train_test_split
from datasets import Dataset, DatasetDict

# Specify model for finetuning
model_name = "dmis-lab/biobert-base-cased-v1.1"

## Load dataset

The COVID-19 Q&A dataset is a JSON file. We convert the questions, contexts and answers into a dataframe. The dataset was taken from Kaggle: https://www.kaggle.com/datasets/kaysarulanas/covidqa-dataset?resource=download
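
The loading code in this section relies on a nested layout: `data`, then `paragraphs`, then `qas`, then `answers`. As a rough illustration, the sketch below flattens a tiny hand-written sample with the same nesting. The sample text and the helper name `flatten_squad` are made up for illustration and are not part of the dataset.

```python
import json

# Hand-written sample that mimics the nesting the loader reads
# (data -> paragraphs -> qas -> answers). The text is illustrative only.
sample_json = """
{
  "data": [
    {
      "paragraphs": [
        {
          "context": "Masks reduce droplet spread. Vaccines lower severe disease.",
          "qas": [
            {"question": "What do masks reduce?",
             "answers": [{"text": "droplet spread"}]},
            {"question": "What do vaccines lower?",
             "answers": [{"text": "severe disease"}]}
          ]
        }
      ]
    }
  ]
}
"""


def flatten_squad(raw):
    """Return one question/answer/context row per question (hypothetical helper)."""
    rows = []
    for entry in raw["data"]:
        for paragraph in entry["paragraphs"]:
            for qa in paragraph["qas"]:
                if not qa["answers"]:  # skip questions that carry no answer
                    continue
                rows.append({
                    "question": qa["question"],
                    "answer": qa["answers"][0]["text"],
                    "context": paragraph["context"],
                })
    return rows


rows = flatten_squad(json.loads(sample_json))
for row in rows:
    print(row["question"], "->", row["answer"])
```

Each row repeats the full paragraph context, which is why the `context` column in the dataframe shows the same text for consecutive questions.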

In [14]:
# Load the whole JSON file, then collect questions, contexts and answers into a dataframe
with open(r"data/COVID-QA.json", "r") as f:
  data = json.load(f)

questions = []
answers = []
contexts = []

for entry in data['data']:
  for paragraph in entry['paragraphs']:
    context = paragraph['context']
    for qa in paragraph['qas']:
      questions.append(qa['question'])
      answers.append(qa['answers'][0]['text'])
      contexts.append(context)

# Create dataframe and display contents
df = pd.DataFrame({
  'question': questions,
  'answer': answers,
  'context': contexts
})

display(df.head())

| | question | answer | context |
|---|---|---|---|
| 0 | What is the main cause of HIV-1 infection in c... | Mother-to-child transmission (MTCT) is the mai... | Functional Genetic Variants in DC-SIGNR Are As... |
| 1 | What plays the crucial role in the Mother to C... | DC-SIGNR plays a crucial role in MTCT of HIV-1... | Functional Genetic Variants in DC-SIGNR Are As... |
| 2 | How many children were infected by HIV-1 in 20... | more than 400,000 children were infected world... | Functional Genetic Variants in DC-SIGNR Are As... |
| 3 | What is the role of C-C Motif Chemokine Ligand... | High copy numbers of CCL3L1, a potent HIV-1 su... | Functional Genetic Variants in DC-SIGNR Are As... |
| 4 | What is DC-GENR and where is it expressed? | Dendritic cell-specific ICAM-grabbing non-inte... | Functional Genetic Variants in DC-SIGNR Are As... |


## Tokenize dataset

In [15]:
dataset = Dataset.from_pandas(df)
tokenizer = BertTokenizerFast.from_pretrained(model_name)

def tokenize(batch):
  '''Tokenizes question/context pairs and records the character offsets of each answer within its context'''
  tokenized_batch = tokenizer(batch["question"], batch["context"],
                              max_length=512,
                              padding="max_length",
                              truncation=True,
                              return_offsets_mapping=True,
                              return_token_type_ids=True)

  answer_starts = []
  answer_ends = []

  for i, context in enumerate(batch["context"]):
    # str.find returns -1 when the answer text does not occur in the context;
    # those examples are dropped by the filtering step later on.
    # Answers with trailing whitespace will not end on a token boundary either.
    answer_start = context.find(batch["answer"][i])
    answer_end = answer_start + len(batch["answer"][i])
    answer_starts.append(answer_start)
    answer_ends.append(answer_end)

  tokenized_batch["answer_start"] = answer_starts
  tokenized_batch["answer_end"] = answer_ends

  return tokenized_batch


tokenized_dataset = dataset.map(tokenize, batched=True)




Map:   0%|          | 0/2019 [00:00<?, ? examples/s]

In [16]:
# Check keys of the dataset
display(tokenized_dataset)

# Check example entry of the dataset
display(tokenized_dataset[0])

Dataset({
    features: ['question', 'answer', 'context', 'input_ids', 'token_type_ids', 'attention_mask', 'offset_mapping', 'answer_start', 'answer_end'],
    num_rows: 2019
})

{'question': 'What is the main cause of HIV-1 infection in children?',
 'answer': 'Mother-to-child transmission (MTCT) is the main cause of HIV-1 infection in children worldwide. ',
 'context': "Functional Genetic Variants in DC-SIGNR Are Associated with Mother-to-Child Transmission of HIV-1\n\nhttps://www.ncbi.nlm.nih.gov/pmc/articles/PMC2752805/\n\nBoily-Larouche, Geneviève; Iscache, Anne-Laure; Zijenah, Lynn S.; Humphrey, Jean H.; Mouland, Andrew J.; Ward, Brian J.; Roger, Michel\n2009-10-07\nDOI:10.1371/journal.pone.0007211\nLicense:cc-by\n\nAbstract: BACKGROUND: Mother-to-child transmission (MTCT) is the main cause of HIV-1 infection in children worldwide. Given that the C-type lectin receptor, dendritic cell-specific ICAM-grabbing non-integrin-related (DC-SIGNR, also known as CD209L or liver/lymph node–specific ICAM-grabbing non-integrin (L-SIGN)), can interact with pathogens including HIV-1 and is expressed at the maternal-fetal interface, we hypothesized that it could influence

## Filter for valid entries in dataset

Remove all examples whose answer span cannot be located among the tokens of the (possibly truncated) context.
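
To make the alignment idea concrete, here is a minimal sketch of mapping a character-level answer span onto token indices through offsets. It uses a toy whitespace tokenizer instead of the BERT tokenizer, so the function names (`toy_tokenize`, `char_span_to_token_span`) and the `max_tokens` truncation are assumptions for illustration only.

```python
import re


def toy_tokenize(text, max_tokens):
    """Split on whitespace and return (token, (char_start, char_end)) pairs.

    Stand-in for a real tokenizer's offset mapping; tokens beyond
    max_tokens are dropped to mimic truncation.
    """
    pairs = [(m.group(), (m.start(), m.end())) for m in re.finditer(r"\S+", text)]
    return pairs[:max_tokens]


def char_span_to_token_span(offsets, answer_start, answer_end):
    """Return (start_token, end_token), or None if either end cannot be aligned."""
    start_token = end_token = None
    for i, (tok_start, tok_end) in enumerate(offsets):
        if start_token is None and tok_start == answer_start:
            start_token = i
        if end_token is None and tok_end == answer_end:
            end_token = i
    if start_token is None or end_token is None:
        return None
    return start_token, end_token


context = "MTCT is the main cause of HIV-1 infection in children worldwide"
answer = "main cause"
answer_start = context.find(answer)
answer_end = answer_start + len(answer)

full = [off for _, off in toy_tokenize(context, max_tokens=512)]
short = [off for _, off in toy_tokenize(context, max_tokens=3)]

print(char_span_to_token_span(full, answer_start, answer_end))   # (3, 4)
print(char_span_to_token_span(short, answer_start, answer_end))  # None
```

The second call returns `None` because the answer lies beyond the truncated tokens, which is the situation the filter in this section is meant to catch.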

In [17]:
def prepare_train_features(example):
    '''Converts the character-level answer span into token-level start and end positions (-1 if not found)'''
    start_position = -1
    end_position = -1

    for i, (offset_start, offset_end) in enumerate(example["offset_mapping"]):
        # Only context tokens (segment 1) are candidates: question tokens carry offsets
        # relative to the question string and could match by coincidence
        if example["token_type_ids"][i] != 1:
            continue
        if start_position == -1 and offset_start == example["answer_start"]:
            start_position = i
        if end_position == -1 and offset_end == example["answer_end"]:
            end_position = i
        if start_position != -1 and end_position != -1:
            break

    # Keep the pair only if both ends were located (the answer may be missing or truncated away)
    if start_position == -1 or end_position == -1:
        start_position = -1
        end_position = -1

    example["start_positions"] = start_position
    example["end_positions"] = end_position
    return example


prepared_dataset = tokenized_dataset.map(prepare_train_features, batched=False)


def filter_invalid_examples(example):
    return example["start_positions"] != -1 and example["end_positions"] != -1


filtered_dataset = prepared_dataset.filter(filter_invalid_examples, batched=False)


Map:   0%|          | 0/2019 [00:00<?, ? examples/s]

Filter:   0%|          | 0/2019 [00:00<?, ? examples/s]

## Split data into train and evaluation sets

In [18]:
train_indices, eval_indices = train_test_split(list(range(len(filtered_dataset))), test_size=0.1, random_state=42)
train_dataset = filtered_dataset.select(train_indices)
eval_dataset = filtered_dataset.select(eval_indices)


def convert_to_tensors(example):
  '''Takes input IDs and attention masks from train and eval tokenized dataset and converts them to tensors'''
  example["input_ids"] = torch.tensor(example["input_ids"], dtype=torch.long)
  example["attention_mask"] = torch.tensor(example["attention_mask"], dtype=torch.long)
  return example


train_dataset = train_dataset.map(convert_to_tensors)
eval_dataset = eval_dataset.map(convert_to_tensors)


dataset_dict = DatasetDict({"train": train_dataset, "eval": eval_dataset})


Map:   0%|          | 0/220 [00:00<?, ? examples/s]

Map:   0%|          | 0/25 [00:00<?, ? examples/s]

## Initialize model, set up training configuration and run training

In [24]:
# Initialize model and run training loop
model = BertForQuestionAnswering.from_pretrained(model_name)
# Optional: make tensors contiguous to prevent issues occurring after a certain number of training epochs
for param in model.parameters():
  param.data = param.data.contiguous()

training_args = TrainingArguments(
    output_dir='./results',
    num_train_epochs=10,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=64,
    warmup_steps=500,
    weight_decay=0.01,
    logging_dir='./logs',
    logging_steps=10,
    evaluation_strategy="epoch",
    save_strategy="epoch",
    save_total_limit=2,
    fp16=False,
    load_best_model_at_end=True,
    report_to="wandb",
    run_name="bert-qa-covid"
)


trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=dataset_dict["train"],
    eval_dataset=dataset_dict["eval"],
)


trainer.train()

# Note that the Trainer object also supports hyperparameter tuning. See: https://github.com/huggingface/notebooks/blob/main/examples/text_classification.ipynb

Some weights of BertForQuestionAnswering were not initialized from the model checkpoint at dmis-lab/biobert-base-cased-v1.1 and are newly initialized: ['qa_outputs.bias', 'qa_outputs.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


| Epoch | Training Loss | Validation Loss |
|---|---|---|
| 1 | 6.2801 | 6.24275 |
| 2 | 6.2563 | 6.181658 |
| 3 | 6.151 | 6.094459 |
| 4 | 6.0602 | 5.95331 |
| 5 | 5.8038 | 5.720126 |
| 6 | 5.5794 | 5.374663 |
| 7 | 5.3216 | 4.963151 |
| 8 | 4.7546 | 4.5224 |
| 9 | 4.46 | 4.2212 |
| 10 | 3.6559 | 3.864543 |


TrainOutput(global_step=1400, training_loss=0.7294905604518551, metrics={'train_runtime': 3469.9995, 'train_samples_per_second': 6.34, 'train_steps_per_second': 0.403, 'total_flos': 5748528648192000.0, 'train_loss': 0.7294905604518551, 'epoch': 100.0})
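
A quick sanity check on the schedule may help when reading the loss curve. The sketch below counts optimizer steps for the 10-epoch configuration and evaluates a linear warm-up and decay schedule in plain Python. It assumes the 220 training examples reported in the `Map` output, a single device with no gradient accumulation, and the `Trainer` defaults of a 5e-5 learning rate with a linear schedule (no `learning_rate` is passed in the arguments). The helper names are hypothetical.

```python
import math


def total_training_steps(num_examples, batch_size, epochs):
    """Optimizer steps for single-device training without gradient accumulation."""
    return math.ceil(num_examples / batch_size) * epochs


def linear_warmup_lr(step, base_lr, warmup_steps, total_steps):
    """Learning rate under linear warm-up followed by linear decay."""
    if step < warmup_steps:
        return base_lr * step / max(1, warmup_steps)
    remaining = max(0, total_steps - step)
    return base_lr * remaining / max(1, total_steps - warmup_steps)


steps = total_training_steps(num_examples=220, batch_size=16, epochs=10)
base_lr = 5e-5  # assumed: the Trainer default, since no learning_rate is passed
peak_lr = linear_warmup_lr(steps, base_lr, warmup_steps=500, total_steps=steps)
print(steps, peak_lr)
```

Under these assumptions the run has 140 steps, fewer than `warmup_steps=500`, so the learning rate would still be warming up when training ends and would stay below roughly 1.4e-5. A smaller `warmup_steps` value may be worth trying for a dataset of this size.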

## Save and load the model for later

In [37]:
model.save_pretrained("trained_model")
tokenizer.save_pretrained("trained_model")

# Create a tar.gz file for your model (preferred format for S3 inferencing by HuggingFace)
!tar -czvf model.tar.gz -C './trained_model' .

# If using Google Colab, the code below copies the archive to your Google Drive so you can transfer it to S3
# (comment it out when running elsewhere)
from google.colab import drive
drive.mount('/content/drive')

!cp -r './model.tar.gz' '/content/drive/MyDrive/'

# If it all fails, download the archive manually
# from google.colab import files
# files.download('model.tar.gz')

model = BertForQuestionAnswering.from_pretrained("trained_model")

./
./tokenizer.json
./tokenizer_config.json
./config.json
./vocab.txt
./special_tokens_map.json
./model.safetensors
Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
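
Outside a notebook, where the `!tar` shell escape is not available, the same kind of archive can be built with the standard-library `tarfile` module. This is only a sketch: the helper name `make_model_archive` is hypothetical, and the demo packs placeholder files in a temporary directory rather than a real model.

```python
import os
import tarfile
import tempfile


def make_model_archive(model_dir, archive_path):
    """Pack the files in model_dir at the archive root as a gzip-compressed tar."""
    with tarfile.open(archive_path, "w:gz") as tar:
        for name in sorted(os.listdir(model_dir)):
            tar.add(os.path.join(model_dir, name), arcname=name)
    return archive_path


# Demo on a throwaway directory with placeholder files (not a real model).
with tempfile.TemporaryDirectory() as tmp:
    model_dir = os.path.join(tmp, "trained_model")
    os.makedirs(model_dir)
    for name in ("config.json", "vocab.txt"):
        with open(os.path.join(model_dir, name), "w") as f:
            f.write("{}")
    archive = make_model_archive(model_dir, os.path.join(tmp, "model.tar.gz"))
    with tarfile.open(archive) as tar:
        members = sorted(tar.getnames())
    print(members)
```

Passing `arcname=name` keeps the files at the top level of the archive instead of nesting them under the directory name, similar in effect to the `-C './trained_model' .` arguments of the shell command.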


## Get best answer using model


In [26]:
def get_answer(question, context):
  '''Returns the best answer from trained model based on provided question and context'''
  # Truncate to the 512-token limit used during training and move inputs to the model's device
  inputs = tokenizer(question, context, return_tensors="pt", truncation=True, max_length=512)
  inputs = inputs.to(model.device)
  with torch.no_grad():
    outputs = model(**inputs)
  start_logits, end_logits = outputs.start_logits, outputs.end_logits

  start_index = torch.argmax(start_logits, dim=1).item()
  end_index = torch.argmax(end_logits, dim=1).item()

  # If the predicted end precedes the start, fall back to the single token with the higher logit
  if end_index < start_index:
    if start_logits[0, start_index] > end_logits[0, end_index]:
      end_index = start_index
    else:
      start_index = end_index

  answer_ids = inputs["input_ids"][0, start_index:end_index + 1].tolist()
  answer = tokenizer.convert_tokens_to_string(tokenizer.convert_ids_to_tokens(answer_ids))
  return answer
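
When the two argmaxes disagree, `get_answer` collapses the answer to a single token. A common alternative in extractive QA post-processing is to search for the pair with the highest combined score subject to `start <= end`. The sketch below shows that idea on plain Python lists. It is not the method used in this notebook; `best_span` and `max_answer_len` are hypothetical names, the logits are made up, and the search does not exclude question or special tokens.

```python
def best_span(start_logits, end_logits, max_answer_len=30):
    """Return (start, end, score) maximizing start_logit + end_logit with start <= end."""
    best = None
    for start, s_logit in enumerate(start_logits):
        last = min(len(end_logits), start + max_answer_len)
        for end in range(start, last):
            score = s_logit + end_logits[end]
            if best is None or score > best[2]:
                best = (start, end, score)
    return best


# Made-up logits for a six-token input; the independent argmaxes disagree
# (start peaks at index 4, end peaks at index 2).
start_logits = [0.1, 2.0, 0.3, 0.2, 2.5, 0.1]
end_logits = [0.0, 0.2, 3.0, 0.4, 0.5, 0.3]

start, end, score = best_span(start_logits, end_logits)
print(start, end, score)
```

Here the joint search picks the span from index 1 to index 2 rather than falling back to a single token, because that pair has the highest combined score among valid spans.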


## Test out Q and A capabilities

In [27]:
question = "Has multidrug resistant infections increased because of COVID-19"
context = '''
Results: Among the 7106 specimens, there was a significant increase in the multidrug-resistant bacterial from 27.38% to 35.87% during COVID-19 (p<0.001), particularly in blood culture, cerebrospinal fluid, catheter, and pus. However, there was a non-significant change in puncture fluid, expectoration, protected distal sampling, joint fluid, stool culture, and genital sampling. A decrease in Multidrug-resistant bacteria (MDRB) was observed only in cytobacteriological urine tests (p<0.05). According to species, there was an increase in extended-spectrum beta-lactamase-producing Enterobacteriaceae, carbapenem-resistant Enterobacteriaceae, and methicillin-resistant Staphylococcus aureus.

Conclusion: In our study, it is particularly noticeable that the MDRB has increased. These results highlight the importance that the pandemic has not been able to slow the progression.
'''
answer1 = get_answer(question, context)
print("Answer:", answer1)


Answer: multi
