----------
# **DATASET PREPARATION**
----------

## **- Load**

In [1]:
import pandas as pd

df = pd.read_csv('MedQUAD.csv')
df['Context'] = df['Answer']
df.drop('qtype', axis=1, inplace=True)
df.head()

Unnamed: 0,Question,Answer,Context
0,Who is at risk for Lymphocytic Choriomeningiti...,LCMV infections can occur after exposure to fr...,LCMV infections can occur after exposure to fr...
1,What are the symptoms of Lymphocytic Choriomen...,LCMV is most commonly recognized as causing ne...,LCMV is most commonly recognized as causing ne...
2,Who is at risk for Lymphocytic Choriomeningiti...,Individuals of all ages who come into contact ...,Individuals of all ages who come into contact ...
3,How to diagnose Lymphocytic Choriomeningitis (...,"During the first phase of the disease, the mos...","During the first phase of the disease, the mos..."
4,What are the treatments for Lymphocytic Chorio...,"Aseptic meningitis, encephalitis, or meningoen...","Aseptic meningitis, encephalitis, or meningoen..."


## **- Split**

In [2]:
from sklearn.model_selection import train_test_split

train_df, test_valid_df = train_test_split(df, test_size=0.3, random_state=42)
valid_df, test_df = train_test_split(test_valid_df, test_size=0.33, random_state=42)

print(f"Train set: {len(train_df)} samples")
print(f"Test set: {len(valid_df)} samples")
print(f"Validation set: {len(test_df)} samples")

Train set: 11484 samples
Test set: 3298 samples
Validation set: 1625 samples


In [3]:
train_df.to_csv('train_df.csv', index=False)

In [4]:
train_df

Unnamed: 0,Question,Answer,Context
3588,How to prevent Urinary Tract Infections ?,Changing some of these daily habits may help p...,Changing some of these daily habits may help p...
8658,What is (are) neuroferritinopathy ?,Neuroferritinopathy is a disorder in which iro...,Neuroferritinopathy is a disorder in which iro...
796,What is the outlook for Cerebro-Oculo-Facio-Sk...,COFS is a fatal disease. Most children do not ...,COFS is a fatal disease. Most children do not ...
12750,What are the symptoms of Keratolytic winter er...,What are the signs and symptoms of Keratolytic...,What are the signs and symptoms of Keratolytic...
5861,Is steatocystoma multiplex inherited ?,When steatocystoma multiplex is caused by muta...,When steatocystoma multiplex is caused by muta...
...,...,...,...
11284,What are the symptoms of Lipoic acid synthetas...,What are the signs and symptoms of Lipoic acid...,What are the signs and symptoms of Lipoic acid...
11964,What is (are) Cerebellar degeneration ?,Cerebellar degeneration refers to the deterior...,Cerebellar degeneration refers to the deterior...
5390,What is (are) Rashes ?,A rash is an area of irritated or swollen skin...,A rash is an area of irritated or swollen skin...
860,What is the outlook for Syringomyelia ?,"Symptoms usually begin in young adulthood, wit...","Symptoms usually begin in young adulthood, wit..."


## **- Pre-processing**

In [5]:
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("albert/albert-base-v2")

In [6]:
def preprocess(row):
    questions = [q.strip() for q in row["Question"]]
    contexts = row["Context"]
    answers = row["Answer"]
    
    inputs = tokenizer(
        questions,
        contexts,
        max_length=512,
        truncation="only_second",
        return_overflowing_tokens=True,
        return_offsets_mapping=True,    
        padding="max_length",
    )

    offset_mapping = inputs.pop("offset_mapping")
    sample_map = inputs.pop("overflow_to_sample_mapping")
    start_positions = []
    end_positions = []

    for i, offset in enumerate(offset_mapping):
        sample_idx = sample_map[i]
        answer_text = answers[sample_idx]
        start_char = contexts[sample_idx].find(answer_text)
        end_char = start_char + len(answer_text)
        sequence_ids = inputs.sequence_ids(i)
        # Find the start and end of the context
        idx = 0
        while sequence_ids[idx] != 1:
            idx += 1
        context_start = idx
        while sequence_ids[idx] == 1:
            idx += 1
        context_end = idx - 1

        # If the answer is not fully inside the context, label is (0, 0)
        if offset[context_start][0] > start_char or offset[context_end][1] < end_char:
            start_positions.append(0)
            end_positions.append(0)
        else:
            # Otherwise it's the start and end token positions
            idx = context_start
            while idx <= context_end and offset[idx][0] <= start_char:
                idx += 1
            start_positions.append(idx - 1)

            idx = context_end
            while idx >= context_start and offset[idx][1] >= end_char:
                idx -= 1
            end_positions.append(idx + 1)
            
    inputs["start_positions"] = start_positions
    inputs["end_positions"] = end_positions
    return inputs


In [7]:
from datasets import Dataset

# Preprocess train dataset
train_dataset = Dataset.from_pandas(train_df)
train_dataset = train_dataset.map(
    preprocess,
    batched=True,
    remove_columns=train_dataset.column_names,
)

# Preprocess valid dataset
valid_dataset = Dataset.from_pandas(valid_df)
valid_dataset = valid_dataset.map(
    preprocess,
    batched=True,
    remove_columns=valid_dataset.column_names,
)

# Preprocess test dataset
test_dataset = Dataset.from_pandas(test_df)
test_dataset = test_dataset.map(
    preprocess,
    batched=True,
    remove_columns=test_dataset.column_names,
)

# Print total number of samples
print(f"Train set: {len(train_df)} samples, Preprocessed: {len(train_dataset)} samples")
print(f"Valid set: {len(valid_df)} samples, Preprocessed: {len(valid_dataset)} samples")
print(f"Test set: {len(test_df)} samples, Preprocessed: {len(test_dataset)} samples")


Map:   0%|          | 0/11484 [00:00<?, ? examples/s]

Map:   0%|          | 0/3298 [00:00<?, ? examples/s]

Map:   0%|          | 0/1625 [00:00<?, ? examples/s]

Train set: 11484 samples, Preprocessed: 13497 samples
Valid set: 3298 samples, Preprocessed: 3870 samples
Test set: 1625 samples, Preprocessed: 1850 samples


In [10]:
print(train_dataset)
print(valid_dataset)
print(test_dataset)

Dataset({
    features: ['input_ids', 'token_type_ids', 'attention_mask', 'start_positions', 'end_positions'],
    num_rows: 13497
})
Dataset({
    features: ['input_ids', 'token_type_ids', 'attention_mask', 'start_positions', 'end_positions'],
    num_rows: 3870
})
Dataset({
    features: ['input_ids', 'token_type_ids', 'attention_mask', 'start_positions', 'end_positions'],
    num_rows: 1850
})


In [8]:
datasets = [("Train Data", train_dataset), ("Validation Data", valid_dataset), ("Test Data", test_dataset)]

for dataset_name, dataset in datasets:
    print(dataset_name + ":")
    print("Sample 1:")
    print("Inputs:", dataset[0])
    print()

Train Data:
Sample 1:
Inputs: {'input_ids': [2, 184, 20, 2501, 287, 6478, 622, 10020, 18460, 13, 60, 3, 4226, 109, 16, 158, 1954, 16760, 123, 448, 2501, 287, 6478, 622, 10020, 18460, 13, 5, 11267, 18, 6, 9, 13, 8, 11500, 37, 431, 20, 97, 75, 568, 14, 9594, 9, 364, 378, 11500, 37, 431, 20, 97, 20, 643, 10955, 37, 1017, 77, 14, 13, 4221, 11631, 9, 48, 1424, 25, 127, 681, 75, 21, 26050, 1018, 9, 11500, 37, 431, 20, 97, 75, 568, 14, 9594, 9, 364, 378, 11500, 37, 431, 20, 97, 20, 643, 10955, 37, 1017, 77, 14, 13, 4221, 11631, 9, 48, 1424, 25, 127, 681, 75, 21, 26050, 1018, 9, 13, 8, 2610, 7503, 16, 6250, 18, 15, 1118, 308, 9, 6250, 18, 92, 448, 15017, 10955, 37, 14, 287, 6478, 622, 329, 9, 308, 25, 246, 9, 127, 7714, 148, 378, 1131, 20, 2610, 490, 20, 970, 15, 469, 8, 11792, 7574, 16, 6250, 206, 208, 9, 13, 5, 3220, 148, 376, 20, 2610, 787, 308, 185, 16, 1200, 2039, 9, 26, 823, 15, 100, 42, 57, 13400, 2990, 54, 582, 2515, 15, 42, 378, 52, 2610, 48, 212, 6250, 9, 1349, 154, 853, 781, 11747, 

In [9]:
datasets = [("Train Data", train_dataset), ("Validation Data", valid_dataset), ("Test Data", test_dataset)]

for dataset_name, dataset in datasets:
    print(dataset_name + ":")
    print("Sample 1:")
    decoded_input = tokenizer.decode(dataset[0]["input_ids"], skip_special_tokens=False)
    print("Inputs:", decoded_input)
    print()


Train Data:
Sample 1:
Inputs: [CLS] how to prevent urinary tract infections?[SEP] changing some of these daily habits may help prevent urinary tract infections (utis). - wipe from front to back after using the toilet. women should wipe from front to back to keep bacteria from getting into the urethra. this step is most important after a bowel movement. wipe from front to back after using the toilet. women should wipe from front to back to keep bacteria from getting into the urethra. this step is most important after a bowel movement. - drink lots of fluids, especially water. fluids can help flush bacteria from the urinary system. water is best. most healthy people should try to drink six to eight, 8-ounce glasses of fluid each day. (some people need to drink less water because of certain conditions. for example, if you have kidney failure or heart disease, you should not drink this much fluid. ask your health care provider how much fluid is healthy for you.) drink lots of fluids, espec

---
# **FINE TUNING & TRAIN**
---

In [8]:
import torch
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

device(type='cuda')

In [9]:
from transformers import AlbertForQuestionAnswering

model = AlbertForQuestionAnswering.from_pretrained("twmkn9/albert-base-v2-squad2")
model.to(device)

Some weights of the model checkpoint at twmkn9/albert-base-v2-squad2 were not used when initializing AlbertForQuestionAnswering: ['albert.pooler.bias', 'albert.pooler.weight']
- This IS expected if you are initializing AlbertForQuestionAnswering from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing AlbertForQuestionAnswering from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


AlbertForQuestionAnswering(
  (albert): AlbertModel(
    (embeddings): AlbertEmbeddings(
      (word_embeddings): Embedding(30000, 128, padding_idx=0)
      (position_embeddings): Embedding(512, 128)
      (token_type_embeddings): Embedding(2, 128)
      (LayerNorm): LayerNorm((128,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0, inplace=False)
    )
    (encoder): AlbertTransformer(
      (embedding_hidden_mapping_in): Linear(in_features=128, out_features=768, bias=True)
      (albert_layer_groups): ModuleList(
        (0): AlbertLayerGroup(
          (albert_layers): ModuleList(
            (0): AlbertLayer(
              (full_layer_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
              (attention): AlbertAttention(
                (query): Linear(in_features=768, out_features=768, bias=True)
                (key): Linear(in_features=768, out_features=768, bias=True)
                (value): Linear(in_features=768, out_features=768, bias

In [10]:
from transformers import Trainer, TrainingArguments, get_linear_schedule_with_warmup
import torch.optim as optim

def fine_tune_and_save(model, args, train_dataset, valid_dataset, test_dataset, tokenizer, save_path):
    optimizer = optim.AdamW(
        model.parameters(),
        lr=args.learning_rate,
        weight_decay=args.weight_decay,
    )

    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=args.warmup_steps,
        num_training_steps=len(train_dataset) // args.per_device_train_batch_size * args.num_train_epochs,
    )

    trainer = Trainer(
        model=model,
        args=args,
        train_dataset=train_dataset,
        eval_dataset=valid_dataset,
        tokenizer=tokenizer,
        optimizers=(optimizer, scheduler),
    )

    trainer.train()
    test_results = trainer.evaluate(test_dataset)
    trainer.save_model(save_path)

    return trainer, test_results

In [None]:
args_1 = TrainingArguments(
    "./models/model_1/results",
    eval_strategy="epoch",
    logging_strategy="epoch",
    save_strategy="epoch",
    learning_rate=3e-5,
    num_train_epochs=3,
    weight_decay=0.001,
    warmup_steps=500,
    gradient_accumulation_steps=1,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    group_by_length=True,
    fp16=True,
)

save_path_1 = "./models/model_1"
trainer_1, test_results_1 = fine_tune_and_save(model, args_1, train_dataset, valid_dataset, test_dataset, tokenizer, save_path_1)

In [None]:
args_2 = TrainingArguments(
    "./models/model_2/results",
    eval_strategy="epoch",
    logging_strategy="epoch",
    save_strategy="epoch",
    learning_rate=5e-5,
    num_train_epochs=3,
    weight_decay=0.001,
    warmup_steps=500,
    gradient_accumulation_steps=1,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    group_by_length=True,
    fp16=True,
)

save_path_2 = "./models/model_2"
trainer_2, test_results_2 = fine_tune_and_save(model, args_2, train_dataset, valid_dataset, test_dataset, tokenizer, save_path_2)

---
# **EVALUATION**
---

In [None]:
import matplotlib.pyplot as plt

def plot_losses(trainer, test_results, save_path):
    train_losses = []
    valid_losses = []
    epochs_train = []
    epochs_valid = []

    for log in trainer.state.log_history:
        if 'loss' in log:
            train_losses.append(log['loss'])
            epochs_train.append(log['epoch'])
        elif 'eval_loss' in log:
            valid_losses.append(log['eval_loss'])
            epochs_valid.append(log['epoch'])

    test_loss = test_results["eval_loss"]

    plt.figure(figsize=(12, 6))
    plt.plot(epochs_train, train_losses, label="Train Loss")
    plt.plot(epochs_valid, valid_losses, label="Valid Loss")
    plt.axhline(y=test_loss, color='r', linestyle='-', label="Test Loss") 
    plt.xlabel("Epoch")
    plt.ylabel("Loss")
    plt.legend()
    plt.savefig(save_path)
    plt.show()


In [None]:
# Plot losses for model 1
plot_losses(trainer_1, test_results_1,save_path_1)

In [None]:
# Plot losses for model 2
plot_losses(trainer_2, test_results_2, save_path_2)

---
# **PREDICTION**
---

In [30]:
from transformers import pipeline, AutoModelForQuestionAnswering
import torch

save_path_1 = "./models/model_1"
save_path_2 = "./models/model_2"

model = AutoModelForQuestionAnswering.from_pretrained(save_path_1)
tokenizer=AutoTokenizer.from_pretrained(save_path_1)

In [33]:
qa_pipeline = pipeline(
    "question-answering",
    model=model,
    tokenizer=tokenizer
)
sample_test_df = test_df.sample(n=3, random_state=32)

for index, row in sample_test_df.iterrows():
    context = row['Context']
    question = row['Question']
    print(f"Context: {context}")
    
    result = qa_pipeline(question=question, context=context)
    
    print(f"Question: {question}")
    print(f"Answer: {result['answer']}")
    print(f"Score: {result['score']}")
    print()


Context: What are the signs and symptoms of Ichthyosis hystrix gravior? The Human Phenotype Ontology provides the following list of signs and symptoms for Ichthyosis hystrix gravior. If the information is available, the table below includes how often the symptom is seen in people with this condition. You can use the MedlinePlus Medical Dictionary to look up the definitions for these medical terms. Signs and Symptoms Approximate number of patients (when available) Autosomal dominant inheritance - Ichthyosis - The Human Phenotype Ontology (HPO) has collected information on how often a sign or symptom occurs in a condition. Much of this information comes from Orphanet, a European rare disease database. The frequency of a sign or symptom is usually listed as a rough estimate of the percentage of patients who have that feature. The frequency may also be listed as a fraction. The first number of the fraction is how many people had the symptom, and the second number is the total number of peo

In [36]:
def get_answer(question, context, model, tokenizer, max_length):
    if len(question) + len(context) + 4 > max_length:
        return None, None

    inputs = tokenizer(question, context, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)

    answer_start_index = outputs.start_logits.argmax()
    answer_end_index = outputs.end_logits.argmax()

    predict_answer_tokens = inputs.input_ids[0, answer_start_index : answer_end_index + 1]
    predicted_answer = tokenizer.decode(predict_answer_tokens, skip_special_tokens=True)

    confidence_score_start = torch.nn.functional.softmax(outputs.start_logits, dim=1)[0, answer_start_index].item()
    confidence_score_end = torch.nn.functional.softmax(outputs.end_logits, dim=1)[0, answer_end_index].item()

    combined_confidence_score = (confidence_score_start + confidence_score_end) / 2

    return predicted_answer, combined_confidence_score

sample_test_df = test_df.sample(n=3, random_state=32)

max_length = model.config.max_position_embeddings 

for index, row in sample_test_df.iterrows():
    question = row['Question']
    context = row['Context']
    
    predicted_answer, confidence_score = None, None
    while not predicted_answer:
        predicted_answer, confidence_score = get_answer(question, context, model, tokenizer, max_length)
        if not predicted_answer:
            row = test_df.sample(n=1).iloc[0]
            question = row['Question']
            context = row['Context']

    print(f"Question: {question}")
    print(f"Context: {context}")
    print(f"Predicted Answer: {predicted_answer}")
    print(f"Combined Confidence Score: {confidence_score}")
    print()



Question: How many people are affected by hypokalemic periodic paralysis ?
Context: Although its exact prevalence is unknown, hypokalemic periodic paralysis is estimated to affect 1 in 100,000 people. Men tend to experience symptoms of this condition more often than women.
Predicted Answer: although its exact prevalence is unknown, hypokalemic periodic paralysis is estimated to affect 1 in 100,000 people. men tend to experience symptoms of this condition more often than women.
Combined Confidence Score: 0.9999983310699463

Question: How many people are affected by fumarase deficiency ?
Context: Fumarase deficiency is a very rare disorder. Approximately 100 affected individuals have been reported worldwide. Several were born in an isolated religious community in the southwestern United States.
Predicted Answer: fumarase deficiency is a very rare disorder. approximately 100 affected individuals have been reported worldwide. several were born in an isolated religious community in the sout