### Try with question answering using BERT to perform extractive question answering

In [178]:
## based on HuggingFace documentations and tutorials

In [1]:
import pandas as pd
import numpy as np
import os
import ast
import evaluate
import collections
import evaluate
from transformers import AutoTokenizer, AutoModelForQuestionAnswering
from datasets import Dataset, load_dataset
from tqdm.auto import tqdm

### Load our datasets

In [2]:
abs_path = os.path.abspath('../../')
path_to_data = 'data/processed'

In [3]:
data_files = {
    "train": os.path.join(abs_path, path_to_data, 'train.csv'),
    "validation":  os.path.join(abs_path, path_to_data, 'valid.csv'),
    "test": os.path.join(abs_path, path_to_data, 'test.csv')
}

In [4]:
dataset = load_dataset("csv", data_files=data_files)

In [5]:
dataset

DatasetDict({
    train: Dataset({
        features: ['question', 'context', 'answer', 'voted_label', 'start_end_indexes'],
        num_rows: 11137
    })
    validation: Dataset({
        features: ['question', 'context', 'answer', 'voted_label', 'start_end_indexes'],
        num_rows: 1376
    })
    test: Dataset({
        features: ['question', 'context', 'answer', 'voted_label', 'start_end_indexes'],
        num_rows: 1314
    })
})

In [6]:
def convert_to_tuples(example):
    """
    when saved as dataframes, tuples were converted to string so we need to revert that
    """
    example['start_end_indexes'] = ast.literal_eval(example['start_end_indexes'])
    return example

In [179]:
def format_answer(example):
    """
    It was easier to format the dataset as expected by the model
    """
    example['answers'] = {'text': [example['answer']], 'answer_start': [example['start_end_indexes'][0]]}
    return example

In [8]:
dataset = dataset.map(convert_to_tuples)
dataset = dataset.map(format_answer)

In [9]:
dataset['train'] = dataset['train'].add_column('id', [str(i)*3 for i in range(len(dataset['train']))])
dataset['validation'] = dataset['validation'].add_column('id', [str(i)*3 for i in range(len(dataset['validation']))])
dataset['test'] = dataset['test'].add_column('id', [str(i)*3 for i in range(len(dataset['test']))])

## Let's preprocess the data with tokenizer

In [74]:
model_checkpoint = "distilbert/distilbert-base-uncased" # most common arquitecture for extractive question answering, pretrained with distillation technique
#"distilbert-base-cased-distilled-squad"

In [75]:
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)

In [76]:
tokenizer.is_fast

True

In [77]:
max_length = 500
stride = 130

In [78]:
question = dataset["train"][0]["question"]
context = dataset["train"][0]["context"]

In [79]:
inputs = tokenizer(question, context)

In [80]:
tokenizer.decode(inputs["input_ids"])

"[CLS] what is ( are ) glaucoma? [SEP] glaucoma is a group of diseases that can damage the eye ' s optic nerve and result in vision loss and blindness. the most common form of the disease is open - angle glaucoma. with early treatment, you can often protect your eyes against serious vision loss. ( watch the video to learn more about glaucoma. to enlarge the video, click the brackets in the lower right - hand corner. to reduce the video, press the escape ( esc ) button on your keyboard. ) see this graphic for a quick overview of glaucoma, including how many people it affects, whos at risk, what to do if you have it, and how to learn more. see a glossary of glaucoma terms. [SEP]"

In [180]:
def preprocess_training_examples(examples):
    """
    Function to preprocess training examples,
    Tokenization, splitting, and start_end index identification
    """
    questions = [q.strip() for q in examples["question"]]
    inputs = tokenizer(
        questions,
        examples["context"],
        max_length=max_length,
        truncation="only_second",
        stride=stride,
        return_overflowing_tokens=True,
        return_offsets_mapping=True,
        padding="max_length",
    )

    offset_mapping = inputs.pop("offset_mapping")
    sample_map = inputs.pop("overflow_to_sample_mapping")
    answers = examples["answers"]
    start_positions = []
    end_positions = []

    for i, offset in enumerate(offset_mapping):
        sample_idx = sample_map[i]
        answer = answers[sample_idx]
        start_char = answer["answer_start"][0]
        end_char = answer["answer_start"][0] + len(answer["text"][0])
        sequence_ids = inputs.sequence_ids(i)

        # Find the start and end of the context
        idx = 0
        while sequence_ids[idx] != 1:
            idx += 1
        context_start = idx
        while sequence_ids[idx] == 1:
            idx += 1
        context_end = idx - 1

        # If the answer is not fully inside the context, label is (0, 0)
        if offset[context_start][0] > start_char or offset[context_end][1] < end_char or answer=='no-answer':
            start_positions.append(0)
            end_positions.append(0)
        else:
            # Otherwise it's the start and end token positions
            idx = context_start
            while idx <= context_end and offset[idx][0] <= start_char:
                idx += 1
            start_positions.append(idx - 1)

            idx = context_end
            while idx >= context_start and offset[idx][1] >= end_char:
                idx -= 1
            end_positions.append(idx + 1)

    inputs["start_positions"] = start_positions
    inputs["end_positions"] = end_positions
    return inputs

### Processing training dataset

In [82]:
train_dataset = dataset['train'].map(
    preprocess_training_examples,
    batched=True,
    remove_columns=dataset["train"].column_names
)

Map:   0%|          | 0/11137 [00:00<?, ? examples/s]

In [83]:
train_dataset

Dataset({
    features: ['input_ids', 'attention_mask', 'start_positions', 'end_positions'],
    num_rows: 11559
})

In [181]:
def preprocess_validation_examples(examples):
    """
    Processing for validation set, 
    A little bit different from  training
    """
    questions = [q.strip() for q in examples["question"]]
    inputs = tokenizer(
        questions,
        examples["context"],
        max_length=max_length,
        truncation="only_second",
        stride=stride,
        return_overflowing_tokens=True,
        return_offsets_mapping=True,
        padding="max_length",
    )

    sample_map = inputs.pop("overflow_to_sample_mapping")
    example_ids = []

    for i in range(len(inputs["input_ids"])):
        sample_idx = sample_map[i]
        example_ids.append(examples["id"][sample_idx])

        sequence_ids = inputs.sequence_ids(i)
        offset = inputs["offset_mapping"][i]
        inputs["offset_mapping"][i] = [
            o if sequence_ids[k] == 1 else None for k, o in enumerate(offset)
        ]

    inputs["example_id"] = example_ids
    return inputs

### Processing validation dataset

In [85]:
validation_dataset = dataset['validation'].map(
    preprocess_validation_examples,
    batched=True,
    remove_columns=dataset["validation"].column_names
)

In [182]:
####
### Evaluation metric to run on the validation dataset
####

metric = evaluate.load("squad")
n_best = 50
max_answer_length = 512

def compute_metrics(start_logits, end_logits, features, examples):
    example_to_features = collections.defaultdict(list)
    for idx, feature in enumerate(features):
        example_to_features[feature["example_id"]].append(idx)

    predicted_answers = []
    for example in tqdm(examples):
        example_id = example["id"]
        context = example["context"]
        answers = []

        # Loop through all features associated with that example
        for feature_index in example_to_features[example_id]:
            start_logit = start_logits[feature_index]
            end_logit = end_logits[feature_index]
            offsets = features[feature_index]["offset_mapping"]

            start_indexes = np.argsort(start_logit)[-1 : -n_best - 1 : -1].tolist()
            end_indexes = np.argsort(end_logit)[-1 : -n_best - 1 : -1].tolist()
            for start_index in start_indexes:
                for end_index in end_indexes:
                    # Skip answers that are not fully in the context
                    if offsets[start_index] is None or offsets[end_index] is None:
                        continue
                    # Skip answers with a length that is either < 0 or > max_answer_length
                    if (
                        end_index < start_index
                        or end_index - start_index + 1 > max_answer_length
                    ):
                        continue

                    answer = {
                        "text": context[offsets[start_index][0] : offsets[end_index][1]],
                        "logit_score": start_logit[start_index] + end_logit[end_index],
                    }
                    answers.append(answer)

        # Select the answer with the best score
        if len(answers) > 0:
            best_answer = max(answers, key=lambda x: x["logit_score"])
            predicted_answers.append(
                {"id": example_id, "prediction_text": best_answer["text"]}
            )
        else:
            predicted_answers.append({"id": example_id, "prediction_text": "no-answer"})

    theoretical_answers = [{"id": ex["id"], "answers": ex["answers"]} for ex in examples]
    return metric.compute(predictions=predicted_answers, references=theoretical_answers)

### Perform Fine tuning

In [87]:
import torch
from torch.utils.data import DataLoader
from transformers import default_data_collator
from torch.optim import AdamW
from accelerate import Accelerator
from transformers import get_scheduler
from tqdm.auto import tqdm

In [88]:
model = AutoModelForQuestionAnswering.from_pretrained(model_checkpoint)

Some weights of DistilBertForQuestionAnswering were not initialized from the model checkpoint at distilbert/distilbert-base-uncased and are newly initialized: ['qa_outputs.bias', 'qa_outputs.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [89]:
train_dataset.set_format("torch")
validation_set = validation_dataset.remove_columns(["example_id", "offset_mapping"])
validation_set.set_format("torch")

In [90]:
train_dataloader = DataLoader(
    train_dataset,
    shuffle=True,
    collate_fn=default_data_collator,
    batch_size=8,
)
eval_dataloader = DataLoader(
    validation_set, collate_fn=default_data_collator, batch_size=8
)

In [91]:
optimizer = AdamW(model.parameters(), lr=2e-5)

In [92]:
accelerator = Accelerator(mixed_precision ='fp16', gradient_accumulation_steps=1)
model, optimizer, train_dataloader, eval_dataloader = accelerator.prepare(
    model, optimizer, train_dataloader, eval_dataloader
)

In [97]:
num_train_epochs = 2
num_update_steps_per_epoch = len(train_dataloader)
num_training_steps = num_train_epochs * num_update_steps_per_epoch

lr_scheduler = get_scheduler(
    "linear",
    optimizer=optimizer,
    num_warmup_steps=0,
    num_training_steps=num_training_steps,
)

In [99]:
output_dir = os.path.join(abs_path,'models', "bert-finetuned-medicalbot-accelerate_2")
output_dir

'/mnt/d/projects/medical_assistant_bot_assignment/models/bert-finetuned-medicalbot-accelerate_2'

### Run fine tuning for 2-3 epochs

In [102]:
progress_bar = tqdm(range(num_training_steps))

best_f1_score = 0
for epoch in range(num_train_epochs):
    # Training
    model.train()
    for step, batch in enumerate(train_dataloader):
        outputs = model(**batch)
        loss = outputs.loss
        accelerator.backward(loss)

        optimizer.step()
        lr_scheduler.step()
        optimizer.zero_grad()
        progress_bar.update(1)

    # Evaluation
    model.eval()
    start_logits = []
    end_logits = []
    accelerator.print("Evaluation!")
    for batch in tqdm(eval_dataloader):
        with torch.no_grad():
            outputs = model(**batch)

        start_logits.append(accelerator.gather(outputs.start_logits).cpu().numpy())
        end_logits.append(accelerator.gather(outputs.end_logits).cpu().numpy())

    start_logits = np.concatenate(start_logits)
    end_logits = np.concatenate(end_logits)
    start_logits = start_logits[: len(validation_dataset)]
    end_logits = end_logits[: len(validation_dataset)]

    metrics = compute_metrics(
        start_logits, end_logits, validation_dataset, dataset["validation"]
    )
    print(f"epoch {epoch}:", metrics)

    # Save best
    if metrics['f1'] > best_f1_score:
        print('new best model saved')
        best_f1_score = metrics['f1']
        accelerator.wait_for_everyone()
        unwrapped_model = accelerator.unwrap_model(model)
        unwrapped_model.save_pretrained(output_dir, save_function=accelerator.save)
    

  0%|          | 0/2890 [00:00<?, ?it/s]

Evaluation!


  0%|          | 0/180 [00:00<?, ?it/s]

  0%|          | 0/1376 [00:00<?, ?it/s]

epoch 0: {'exact_match': 40.18895348837209, 'f1': 70.19901438065072}
new best model saved
Evaluation!


  0%|          | 0/180 [00:00<?, ?it/s]

  0%|          | 0/1376 [00:00<?, ?it/s]

epoch 1: {'exact_match': 40.18895348837209, 'f1': 70.19901438065072}


### Let's try with some questions

In [104]:
from transformers import pipeline

In [140]:
test_dataset = pd.read_csv(os.path.join(abs_path, 'data/processed', 'test.csv'))

In [119]:
# Replace this with your own checkpoint
model_checkpoint = os.path.join(abs_path, "models/bert-finetuned-medicalbot-accelerate_2/")
question_answerer = pipeline("question-answering", model=model_checkpoint, tokenizer=tokenizer, device='cuda')

In [158]:
test_dataset.iloc[[961,985,41,80,112,146]]

Unnamed: 0,question,context,answer,voted_label,start_end_indexes
961,How many people are affected by Ochoa syndrome?,Ochoa syndrome is a rare disorder. About 150 c...,About 150 cases have been reported in the medi...,ochoa_related,"(35, 61)"
985,What are the genetic changes related to Danon ...,Danon disease is caused by mutations in the LA...,Danon disease is caused by mutations in the LA...,danon_related,"(0, 0)"
41,Who is at risk for Neuroblastoma??,The risk factors for neuroblastoma are not known.,no-answer,neuroblastoma_outlook,"(0, 0)"
80,What is (are) Pemphigus?,Pemphigus is an autoimmune disorder. If you ha...,Pemphigus is an autoimmune disorder. If you ha...,pemphigus_chronic,"(0, 315)"
112,Do you have information about MRSA,Summary : MRSA stands for methicillin-resistan...,Summary : MRSA stands for methicillin-resistan...,mrsa,"(0, 0)"
146,What is (are) Heart Disease in Women?,"In the United States, 1 in 4 women dies from h...","In the United States, 1 in 4 women dies from h...",women_heart,"(0, 0)"


In [160]:
question = "How many people are affected by Ochoa syndrome?	"
context = "Ochoa syndrome is a rare disorder. About 150 cases have been reported in the medical literature."
question_answerer(question=question, context=context)['answer']

'About 150 cases have been reported in the medical literature.'

In [162]:
question = "What is (are) Heart Disease in Women?"
context = "In the United States, 1 in 4 women dies from heart disease. The most common cause of heart disease in both men and women is narrowing or blockage of the coronary arteries, the blood vessels that supply blood to the heart itself. This is called coronary artery disease, and it happens slowly over time. It's the major reason people have heart attacks.    Heart diseases that affect women more than men include       - Coronary microvascular disease (MVD) - a problem that affects the heart's tiny arteries    - Broken heart syndrome - extreme emotional stress leading to severe but often short-term heart muscle failure       The older a woman gets, the more likely she is to get heart disease. But women of all ages should be concerned about heart disease. All women can take steps to prevent it by practicing healthy lifestyle habits.     NIH: National Heart, Lung, and Blood Institute"
question_answerer(question=question, context=context)['answer']

'In the United States, 1 in 4 women dies from heart disease.'

In [173]:
question = "Is allergic asthma inherited?"
context = "Allergic asthma can be passed through generations in families, but the inheritance pattern is unknown. People with mutations in one or more of the associated genes inherit an increased risk of allergic asthma, not the condition itself. Because allergic asthma is a complex condition influenced by genetic and environmental factors, not all people with a mutation in an asthma-associated gene will develop the disorder."
question_answerer(question=question, context=context)['answer']

'Allergic asthma can be passed through generations in families, but the inheritance'

In [177]:
question = "What is allergic asthma?"
context = "Because allergic asthma is a complex condition influenced by genetic and environmental factors, not all people with a mutation in an asthma-associated gene will develop the disorder."
question_answerer(question=question, context=context)['answer']

'Because allergic'

In [186]:
dataset['train']['question']

['What is (are) Glaucoma?',
 'Who is at risk for Glaucoma??',
 'How to prevent Glaucoma?',
 'What are the symptoms of Glaucoma?',
 'What are the treatments for Glaucoma?',
 'what research (or clinical trials) is being done for Glaucoma?',
 'What is (are) High Blood Pressure?',
 'What are the treatments for High Blood Pressure?',
 'How to prevent High Blood Pressure?',
 "What is (are) Paget's Disease of Bone?",
 "What are the symptoms of Paget's Disease of Bone?",
 "How to diagnose Paget's Disease of Bone?",
 "What are the complications of Paget's Disease of Bone?",
 "What are the treatments for Paget's Disease of Bone?",
 'What is (are) Urinary Tract Infections?',
 'What are the symptoms of Urinary Tract Infections?',
 'Who is at risk for Urinary Tract Infections??',
 'What are the treatments for Urinary Tract Infections?',
 'What is (are) Alcohol Use and Older Adults?',
 'What are the symptoms of Alcohol Use and Older Adults?',
 'What are the treatments for Alcohol Use and Older Adult