## Dataset Description

The dataset used in this project is the *Medical Question Answering Dataset* ([MedQuAD](https://github.com/abachaa/MedQuAD/tree/master)). It contains medical question-answer pairs along with additional information, such as the question type, the question *focus*, and its UMLS (Unified Medical Language System) details: the Concept Unique Identifier (*CUI*), the Semantic *Type*, and the Semantic *Group*.

For details on how the data was collected and constructed, refer to this [paper](https://bmcbioinformatics.biomedcentral.com/articles/10.1186/s12859-019-3119-4).

The extracted data is stored in CSV format with the following features:

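- **Focus**: the question focus, i.e. the medical topic the question is about (e.g. *Breast Cancer*)
- **CUI**: the UMLS Concept Unique Identifier of the focus
- **SemanticType**: the UMLS semantic type of the focus (e.g. `T019`)
- **SemanticGroup**: the UMLS semantic group of the focus (e.g. `Disorders`)
- **Question**: the question text
- **Answer**: the answer text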
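The column list above can be checked when the CSV is loaded, so a renamed or missing column fails early instead of deep inside preprocessing. A minimal sketch, assuming the CSV header uses exactly these six names; `load_medquad` and `REQUIRED_COLUMNS` are hypothetical names introduced for illustration, and the one-row CSV uses a placeholder answer.

```python
import io

import pandas as pd

# Column names from the feature list (assumed to match the CSV header exactly)
REQUIRED_COLUMNS = ["Focus", "CUI", "SemanticType", "SemanticGroup", "Question", "Answer"]


def load_medquad(path_or_buffer) -> pd.DataFrame:
    """Hypothetical loader: read the CSV and fail fast if an expected column is missing."""
    df = pd.read_csv(path_or_buffer)
    missing = [col for col in REQUIRED_COLUMNS if col not in df.columns]
    if missing:
        raise ValueError(f"MedQuAD CSV is missing columns: {missing}")
    return df[REQUIRED_COLUMNS]


# A one-row in-memory CSV standing in for MedQuAD.csv (the answer text is a placeholder)
sample_csv = io.StringIO(
    "Focus,CUI,SemanticType,SemanticGroup,Question,Answer\n"
    "21-hydroxylase deficiency,C1291314,T019,Disorders,"
    "What causes 21-hydroxylase deficiency ?,placeholder answer text\n"
)
sample_df = load_medquad(sample_csv)
print(sample_df.shape)
```

In the notebook itself this would be called as `load_medquad("MedQuAD.csv")`.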

## Information

Healthcare professionals often have to consult medical literature and documents to answer medical queries. Medical databases and search engines are powerful sources of up-to-date medical knowledge. However, the existing documentation is vast, which makes it difficult for professionals to retrieve answers quickly in a clinical setting. Search engines and information retrieval systems return a list of documents rather than answers. Question-answering systems instead return short sentences or paragraphs in response to medical queries. Their main advantage is that they can generate answers and provide hints within a few seconds.

### Problem Statement

### Import required packages

```python
!pip -q install -U accelerate
!pip -q install -U transformers
!pip -q install torch
!pip -q install gradio  # needed for the demo interface at the end of the notebook
```

```python
import os
import re
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
import torch
from transformers import GPT2Tokenizer, GPT2LMHeadModel, TextDataset, DataCollatorForLanguageModeling
from transformers import Trainer, TrainingArguments

import warnings
warnings.filterwarnings('ignore')
```

```python
df = pd.read_csv("MedQuAD.csv")
df.shape
```

```text
(16412, 6)
```

### Pre-processing and EDA

```python
print(df.isna().sum())
```

```text
Focus             14
CUI              565
SemanticType     597
SemanticGroup    565
Question           0
Answer             5
dtype: int64
```
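Raw counts are easier to judge next to the share of rows they represent; with the counts above, `CUI` is missing in about 3.4% of the 16,412 rows. A small sketch of the percentage view, shown on a made-up toy DataFrame so it runs on its own; on the real data it would be applied to `df` directly.

```python
import numpy as np
import pandas as pd

# Toy stand-in for the MedQuAD DataFrame: one missing CUI and one missing Answer
toy = pd.DataFrame({
    "Focus": ["Stroke", "Gout", "Gout", "Psoriasis"],
    "CUI": ["C1", None, "C3", "C4"],
    "Answer": ["a", "b", np.nan, "d"],
})

# Percentage of missing values per column, sorted from most to least missing
missing_pct = toy.isna().mean().mul(100).sort_values(ascending=False)
print(missing_pct)
```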


```python
df_cleaned = df.drop_duplicates(subset=['Question', 'Answer'], keep='last')
df_cleaned.reset_index(drop=True, inplace=True)
print(df.shape)
print(df_cleaned.shape)
```

```text
(16412, 6)
(16364, 6)
```
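`drop_duplicates(subset=['Question', 'Answer'], keep='last')` treats two rows as duplicates when both text fields match, even if other columns such as `Focus` differ, and keeps the last occurrence; in the output above this removes 16,412 - 16,364 = 48 rows. The toy example below (made-up rows) shows that behaviour in miniature.

```python
import pandas as pd

toy = pd.DataFrame({
    "Focus": ["Gout", "Gout (alt)", "Stroke"],
    "Question": ["What is gout?", "What is gout?", "What is a stroke?"],
    "Answer": ["Gout is ...", "Gout is ...", "A stroke is ..."],
})

# Rows 0 and 1 share Question and Answer, so only the last of them ("Gout (alt)") survives
deduped = toy.drop_duplicates(subset=["Question", "Answer"], keep="last").reset_index(drop=True)
print(deduped)
```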


```python
# Number of records per category in the Focus column
df_cleaned['Focus'].value_counts()
```

```text
Breast Cancer                                        53
Prostate Cancer                                      43
Stroke                                               35
Skin Cancer                                          34
Alzheimer's Disease                                  30
                                                     ..
Pediatric ulcerative colitis                          1
Duodenal ulcer due to antral G-cell hyperfunction     1
Duodenal atresia                                      1
Pediatric Crohn's disease                             1
Medium-chain 3-ketoacyl-coa thiolase deficiency       1
Name: Focus, Length: 5126, dtype: int64
```

```python
# Displaying the distinct categories of Focus column and the number of records belonging to each category
# (Top 100 only)
df_cleaned['Focus'].value_counts().nlargest(100)
```

```text
Breast Cancer                                     53
Prostate Cancer                                   43
Stroke                                            35
Skin Cancer                                       34
Alzheimer's Disease                               30
                                                  ..
Poland syndrome                                   11
Opitz G/BBB syndrome                              11
Polycythemia Vera                                 11
Diabetic Kidney Disease                           10
What I need to know about Gestational Diabetes    10
Name: Focus, Length: 100, dtype: int64
```
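Restricting the data to the 100 most frequent `Focus` values keeps only a slice of it: the split step below reports 1,532 rows for those categories out of 16,364 after de-duplication, roughly 9.4%. A hypothetical helper, `top_n_coverage`, computes that share directly from `value_counts()`; it is shown on a toy Series.

```python
import pandas as pd


def top_n_coverage(series: pd.Series, n: int) -> float:
    """Fraction of all rows (missing values included) that fall in the n most frequent categories."""
    counts = series.value_counts()          # NaN is excluded from the counts by default
    return counts.nlargest(n).sum() / len(series)


focus = pd.Series(["Stroke", "Stroke", "Stroke", "Gout", "Gout", "Psoriasis", None])
print(top_n_coverage(focus, 2))  # the top 2 categories cover 5 of the 7 rows
```

On the notebook data this would be `top_n_coverage(df_cleaned['Focus'], 100)`.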

```python
# Top 100 Focus category names
top_focus_cat = list(df_cleaned['Focus'].value_counts().nlargest(100).index)
top_focus_cat
```

```text
['Breast Cancer',
 'Prostate Cancer',
 'Stroke',
 'Skin Cancer',
 "Alzheimer's Disease",
 'Lung Cancer',
 'Colorectal Cancer',
 'High Blood Cholesterol',
 'Heart Failure',
 'Heart Attack',
 'High Blood Pressure',
 "Parkinson's Disease",
 'Leukemia',
 'Osteoporosis',
 'Shingles',
 'Hemochromatosis',
 'Age-related Macular Degeneration',
 'Diabetes',
 'Gum (Periodontal) Disease',
 'Diabetic Retinopathy',
 'Psoriasis',
 'Kidney Disease',
 'Balance Problems',
 'COPD',
 'Cataract',
 'Dry Mouth',
 'Medicare and Continuing Care',
 'Prescription and Illicit Drug Abuse',
 'Gout',
 'Wilson Disease',
 'Glaucoma',
 'Osteoarthritis',
 'Short Bowel Syndrome',
 'Endometrial Cancer',
 'Narcolepsy',
 'Problems with Taste',
 'Rheumatoid Arthritis',
 'Neuroblastoma',
 'Urinary Tract Infections in Children',
 'Surviving Cancer',
 'Peripheral Arterial Disease (P.A.D.)',
 'Problems with Smell',
 'Anxiety Disorders',
 'Kidney Dysplasia',
 'Dry Eye',
 'Pituitary Tumors',
 'Diabetic Neuropathies: The Nerve Dama
```

### Create Training and Validation Sets

```python
# Keep only the most frequent Focus categories
num_select_focus_cat = 100
num_train_samples_per_focus = 4   # training rows drawn per category
num_val_samples_per_focus = 1     # validation rows drawn per category

selected_focus_cat = list(df_cleaned['Focus'].value_counts().nlargest(num_select_focus_cat).index)
df_selected_focus_cat = df_cleaned[df_cleaned['Focus'].isin(selected_focus_cat)].reset_index(drop=True)

# Randomly draw training rows per category (no seed is set, so the draw changes between runs)
df_train = df_selected_focus_cat.groupby('Focus').apply(lambda x: x.sample(n=num_train_samples_per_focus))
print(df_selected_focus_cat.shape, df_train.shape)
```

```text
(1532, 6) (400, 6)
```
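Because `x.sample(n=...)` is called without a seed, every run draws a different training set. `DataFrameGroupBy.sample` (pandas 1.1 and later) performs the same per-group draw, accepts `random_state`, and keeps the original row index and the `Focus` column; recent pandas versions also warn that `groupby(...).apply` operating on the grouping column is deprecated. A minimal sketch on a toy frame, with an arbitrary seed of 42:

```python
import pandas as pd

toy = pd.DataFrame({
    "Focus": ["Gout"] * 5 + ["Stroke"] * 5,
    "Question": [f"q{i}" for i in range(10)],
})

# Draw 4 rows per Focus value; the fixed seed makes the draw repeatable
train_a = toy.groupby("Focus").sample(n=4, random_state=42)
train_b = toy.groupby("Focus").sample(n=4, random_state=42)

print(train_a["Focus"].value_counts().to_dict())
```

Since the original index is preserved, the remaining rows for validation could then be taken with `toy.drop(train_a.index)`.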


```python
df_train.reset_index(drop=True, inplace=True)
df_train.head()
```

| | Focus | CUI | SemanticType | SemanticGroup | Question | Answer |
|---|---|---|---|---|---|---|
| 0 | 21-hydroxylase deficiency | C1291314 | T019 | Disorders | What are the treatments for 21-hydroxylase def... | These resources address the diagnosis or manag... |
| 1 | 21-hydroxylase deficiency | C1291314 | T019 | Disorders | What causes 21-hydroxylase deficiency ? | What causes salt-wasting, simple virilizing, a... |
| 2 | 21-hydroxylase deficiency | C1291314 | T019 | Disorders | How many people are affected by 21-hydroxylase... | The classic forms of 21-hydroxylase deficiency... |
| 3 | 21-hydroxylase deficiency | C1291314 | T019 | Disorders | What is (are) 21-hydroxylase deficiency ? | 21-hydroxylase deficiency is an inherited diso... |
| 4 | Abdominal Adhesions | C0549357 | T020 | Disorders | What causes Abdominal Adhesions ? | Abdominal surgery is the most frequent cause o... |


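```python
# Anti-join: keep the rows of df_selected_focus_cat that were not drawn into df_train
# (the merge matches on all shared columns)
df_filter_train = (
    pd.merge(df_selected_focus_cat, df_train, indicator=True, how='outer')
    .query('_merge == "left_only"')
    .drop('_merge', axis=1)
)

# Draw the validation rows per category from the remaining data
df_val = df_filter_train.groupby('Focus').apply(lambda x: x.sample(n=num_val_samples_per_focus))
df_val.reset_index(drop=True, inplace=True)
```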
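The outer merge with `indicator=True` works as an anti-join: rows flagged `left_only` exist in `df_selected_focus_cat` but not in `df_train`. The merge matches on every shared column, and pandas matches missing key values to each other, unlike SQL joins. The question "What causes 21-hydroxylase deficiency ?" appears in the `df_train.head()` output above and again in the `df_val.head()` output below; both displays are truncated, so they cannot show whether the full answers differ, and an explicit leakage check is worthwhile. A small sketch on toy frames; `has_overlap` is a hypothetical helper.

```python
import pandas as pd


def has_overlap(train: pd.DataFrame, val: pd.DataFrame, keys=("Question", "Answer")) -> bool:
    """Return True if any key combination appears in both splits (a sign of leakage)."""
    cols = list(keys)
    common = train[cols].merge(val[cols], on=cols, how="inner")
    return not common.empty


full = pd.DataFrame({
    "Focus": ["Gout", "Gout", "Gout"],
    "Question": ["q1", "q2", "q3"],
    "Answer": ["a1", "a2", "a3"],
})
train = full.iloc[:2]

# Same anti-join pattern as in the notebook: keep rows present only in `full`
rest = (
    pd.merge(full, train, indicator=True, how="outer")
    .query('_merge == "left_only"')
    .drop("_merge", axis=1)
)
print(rest["Question"].tolist())
print(has_overlap(train, rest))
```

On the notebook data the check would be `has_overlap(df_train, df_val)`, which should return `False`.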

### Pre-process `Question` and `Answer` text



```python
# Combine questions and answers for the train and val data:
# sequence = '<question>' + question + '<answer>' + answer
df_train['Sequence'] = '<question>' + df_train['Question'] + '<answer>' + df_train['Answer']
df_val['Sequence'] = '<question>' + df_val['Question'] + '<answer>' + df_val['Answer']
```
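If a selected row had a missing `Answer` (the raw data has 5 such rows), the `+` concatenation would yield `NaN` for that row's `Sequence`, and the later `file.write(... + "\n")` would raise a `TypeError`. A defensive sketch, assuming it is acceptable to drop such rows; `build_sequences` is a hypothetical helper that uses the same `<question>`/`<answer>` template.

```python
import numpy as np
import pandas as pd


def build_sequences(df: pd.DataFrame) -> pd.Series:
    """Hypothetical helper: build '<question>Q<answer>A' strings, skipping rows with missing text."""
    complete = df.dropna(subset=["Question", "Answer"])
    return "<question>" + complete["Question"] + "<answer>" + complete["Answer"]


toy = pd.DataFrame({
    "Question": ["What causes gout ?", "What is (are) Stroke ?"],
    "Answer": ["Gout is caused by ...", np.nan],
})

seqs = build_sequences(toy)
print(seqs.tolist())
```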

```python
df_val.head()
```

| | Focus | CUI | SemanticType | SemanticGroup | Question | Answer | Sequence |
|---|---|---|---|---|---|---|---|
| 0 | 21-hydroxylase deficiency | C1291314 | T019 | Disorders | What causes 21-hydroxylase deficiency ? | What causes salt-wasting, simple virilizing, a... | `<question>What causes 21-hydroxylase deficienc...` |
| 1 | Abdominal Adhesions | C0549357 | T020 | Disorders | What are the symptoms of Abdominal Adhesions ? | A complete intestinal obstruction is life thre... | `<question>What are the symptoms of Abdominal A...` |
| 2 | Adrenal Insufficiency and Addison's Disease | C0405580 | T019 | Disorders | What to do for Adrenal Insufficiency and Addis... | Some people with Addisons disease who are aldo... | `<question>What to do for Adrenal Insufficiency...` |
| 3 | Age-related Macular Degeneration | C0242383 | T047 | Disorders | What are the symptoms of Age-related Macular D... | An early symptom of wet AMD is that straight l... | `<question>What are the symptoms of Age-related...` |
| 4 | Alagille Syndrome | C0085280 | T019 | Disorders | How to prevent Alagille Syndrome ? | Scientists have not yet found a way to prevent... | `<question>How to prevent Alagille Syndrome ?<a...` |


```python
# Save the training and validation data as separate text files, one sequence per line
# (/content/ is the working directory on Google Colab)
train_file = "/content/train.txt"
val_file = "/content/val.txt"

with open(train_file, "w", encoding="utf-8") as f:
    for sequence in df_train['Sequence']:
        f.write(sequence + "\n")

with open(val_file, "w", encoding="utf-8") as f:
    for sequence in df_val['Sequence']:
        f.write(sequence + "\n")
```

**Load pre-trained GPT2Tokenizer**



```python
# Set up the tokenizer
checkpoint = "gpt2"
tokenizer = GPT2Tokenizer.from_pretrained(checkpoint)
```


**Tokenize the train and validation data and build `TextDataset` objects**



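```python
# TextDataset (deprecated in recent transformers releases) reads the whole file,
# tokenizes it, and splits the token stream into consecutive blocks of block_size tokens.

# Tokenize train text
train_dataset = TextDataset(tokenizer=tokenizer, file_path=train_file, block_size=512)

# Tokenize validation text
val_dataset = TextDataset(tokenizer=tokenizer, file_path=val_file, block_size=512)
```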
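`TextDataset` does not keep one question-answer pair per example: it tokenizes the whole file as one stream and cuts the token ids into consecutive blocks of `block_size`, dropping a final remainder shorter than a block. A pair can therefore be split across two blocks, and a file that tokenizes to fewer than 512 tokens yields an empty dataset. The pure-Python sketch below mimics that chunking on made-up token ids; `chunk_tokens` is a hypothetical name and no tokenizer is involved.

```python
from typing import List


def chunk_tokens(token_ids: List[int], block_size: int) -> List[List[int]]:
    """Split a token stream into consecutive full blocks, discarding the incomplete tail."""
    return [
        token_ids[i : i + block_size]
        for i in range(0, len(token_ids) - block_size + 1, block_size)
    ]


stream = list(range(10))          # stand-in for the tokenized training file
blocks = chunk_tokens(stream, 4)
print(blocks)  # [[0, 1, 2, 3], [4, 5, 6, 7]]; ids 8 and 9 are dropped
```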

**Create a DataCollator object**

```python
# Create a data collator; mlm=False selects causal (next-token) language modeling
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False, return_tensors="pt")
```
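With `mlm=False`, the collator builds causal language-modeling batches: `labels` is a copy of `input_ids` (padding positions, if the tokenizer has a pad token, are set to -100 so the loss ignores them; GPT-2 has none by default), and `GPT2LMHeadModel` shifts the labels internally so that position *t* is trained to predict token *t + 1*. A pure-Python sketch of that shift with made-up token ids; `causal_lm_targets` is a hypothetical helper.

```python
from typing import List, Optional, Tuple

IGNORE_INDEX = -100  # label value skipped by PyTorch's cross-entropy loss (its default ignore_index)


def causal_lm_targets(input_ids: List[int], pad_id: Optional[int] = None) -> List[Tuple[int, int]]:
    """Pair each input token with the label it is trained to predict after the internal shift."""
    # Labels start as a copy of input_ids; pad positions (if a pad id exists) become IGNORE_INDEX
    labels = [IGNORE_INDEX if pad_id is not None and tok == pad_id else tok for tok in input_ids]
    # The prediction at position t is compared with label t + 1
    return list(zip(input_ids[:-1], labels[1:]))


pairs = causal_lm_targets([464, 3290, 318, 0], pad_id=0)
print(pairs)  # [(464, 3290), (3290, 318), (318, -100)]
```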

```python
# Set up the model
model = GPT2LMHeadModel.from_pretrained(checkpoint)
```


**Fine-tune GPT2 Model**



```python
# Set up the training arguments
model_output_path = "./output"

training_args = TrainingArguments(
    output_dir=model_output_path,
    overwrite_output_dir=True,
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    num_train_epochs=50,
    save_steps=1_000,
    # save_total_limit=2,
    logging_dir='./logs',
    # No evaluation strategy is set, so val_dataset is not evaluated during training;
    # call trainer.evaluate() after training to obtain the validation loss.
)
```
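The number of optimizer steps follows from the dataset size: with `per_device_train_batch_size = 4` on one device and no gradient accumulation, each epoch takes ceil(num_blocks / 4) steps, since the last partial batch is kept by default, and 50 epochs take 50 times that. The block count depends on how many tokens the training file produces, so this sketch leaves it as a parameter; `total_training_steps` is a hypothetical helper and the value 60 is arbitrary.

```python
import math


def total_training_steps(num_blocks: int, batch_size: int = 4, epochs: int = 50, num_devices: int = 1) -> int:
    """Optimizer steps without gradient accumulation; a final partial batch still counts as a step."""
    steps_per_epoch = math.ceil(num_blocks / (batch_size * num_devices))
    return steps_per_epoch * epochs


print(total_training_steps(num_blocks=60))  # 15 steps per epoch * 50 epochs = 750
```

In the notebook, `len(train_dataset)` gives the actual block count to pass in.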

```python
# Train the model
trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
)
trainer.train()

# Save the model
trainer.save_model(model_output_path)

# Save the tokenizer
tokenizer.save_pretrained(model_output_path)
```

```text
('./output/tokenizer_config.json',
 './output/special_tokens_map.json',
 './output/vocab.json',
 './output/merges.txt',
 './output/added_tokens.json')
```
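Because no evaluation strategy is configured in `TrainingArguments`, the validation set is not scored during `trainer.train()`. Calling `trainer.evaluate()` afterwards returns a metrics dict whose `eval_loss` is the mean token-level cross-entropy in nats, and perplexity is its exponential. A minimal sketch of the conversion; `perplexity` is a hypothetical helper and the loss values are illustrative, not results from this notebook.

```python
import math


def perplexity(eval_loss: float) -> float:
    """Perplexity = exp(mean cross-entropy loss in nats)."""
    return math.exp(eval_loss)


# In the notebook: metrics = trainer.evaluate(); perplexity(metrics["eval_loss"])
print(perplexity(0.0))                       # 1.0
print(round(perplexity(math.log(20.0)), 6))  # 20.0
```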

**Test the model with user input prompts**



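```python
def generate_response(model, tokenizer, prompt, max_length=200):
    """Generate a continuation of `prompt` (greedy decoding, since no sampling arguments are passed)."""
    # Encode the prompt; 'pt' returns a PyTorch tensor
    input_ids = tokenizer.encode(prompt, return_tensors="pt")

    # A single unpadded prompt, so every position is attended to.
    # GPT-2 has no pad token, so the end-of-text token id is used as pad_token_id.
    attention_mask = torch.ones_like(input_ids)
    pad_token_id = tokenizer.eos_token_id

    output = model.generate(
        input_ids,
        max_length=max_length,          # total length, including the prompt tokens
        num_return_sequences=1,
        attention_mask=attention_mask,
        pad_token_id=pad_token_id,
    )

    return tokenizer.decode(output[0], skip_special_tokens=True)
```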
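The model was fine-tuned on strings of the form `<question>...<answer>...`, but `generate_response` passes the raw user prompt, and `skip_special_tokens=True` does not strip the markers because they are ordinary text rather than registered special tokens. One option, sketched with the hypothetical helpers `format_prompt` and `extract_answer`, is to wrap the prompt in the training template and keep only the text after `<answer>`, cut at the next `<question>` if the model starts a new pair. Whether this improves answers for this checkpoint is an untested assumption.

```python
QUESTION_TAG = "<question>"
ANSWER_TAG = "<answer>"


def format_prompt(question: str) -> str:
    """Wrap a user question in the same template used for the training sequences."""
    return f"{QUESTION_TAG}{question.strip()}{ANSWER_TAG}"


def extract_answer(generated: str) -> str:
    """Return the text after the first <answer> tag, stopping at any following <question> tag."""
    _, sep, rest = generated.partition(ANSWER_TAG)
    if not sep:                       # no answer tag: return the text unchanged
        return generated.strip()
    return rest.split(QUESTION_TAG, 1)[0].strip()


prompt = format_prompt("What causes Gout ?")
print(prompt)  # <question>What causes Gout ?<answer>
fake_output = prompt + "Gout is caused by ... <question>What is Stroke ?"
print(extract_answer(fake_output))  # Gout is caused by ...
```

Usage: `extract_answer(generate_response(my_model, my_tokenizer, format_prompt(question)))`.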


```python
# Load the fine-tuned model and tokenizer
my_model = GPT2LMHeadModel.from_pretrained(model_output_path)
my_tokenizer = GPT2Tokenizer.from_pretrained(model_output_path)
```

```python
# Testing with a sample prompt 1
prompt = "What are the top 5 symptoms of Breast Cancer?"
response = generate_response(my_model, my_tokenizer, prompt)
response
```

```text
'What are the top 5 symptoms of Breast Cancer??<answer>What are the signs and symptoms of breast cancer? Many breast cancer patients have no signs or symptoms. However, signs and symptoms of early breast cancer may include the following: - Fever. The usual symptoms of a cold or the chill of the night can be a sign of early breast cancer. - Trouble chewing, swallowing, or speaking. - Feeling tired. A common feeling during the day is that something is hard, like a fallen stone. - Feeling nervous. A feeling of nervousness or nervousness is usually a sign of early breast cancer. Getting breast cancer is very difficult because of the many genes that cause cancer. If a woman has a mutated gene, the chances of her developing breast cancer are very low. If a woman has a mutated gene, the chances of her developing breast cancer are very high. Getting breast cancer is very difficult because of the many genes that cause cancer. If a woman has a mutated gene, the'
```
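The answer above shows a typical failure of greedy decoding: whole sentences repeat until `max_length` is reached. `model.generate` accepts standard arguments that often reduce this, such as nucleus sampling (`do_sample`, `top_p`, `temperature`), `no_repeat_ngram_size`, and `repetition_penalty`. The values below are untuned assumptions, not settings validated on this model; `GENERATION_KWARGS` is a hypothetical name.

```python
# Keyword arguments for transformers' `generate`; all values are untuned assumptions
GENERATION_KWARGS = {
    "max_new_tokens": 150,        # budget for the answer only, excluding the prompt tokens
    "do_sample": True,            # sample instead of greedy decoding
    "top_p": 0.9,                 # nucleus sampling: smallest token set with 90% probability mass
    "temperature": 0.7,           # values below 1.0 sharpen the distribution
    "no_repeat_ngram_size": 3,    # never repeat a 3-gram already in the output
    "repetition_penalty": 1.2,    # down-weight tokens that already appeared
}

# Sketch of use inside generate_response (drop max_length when passing max_new_tokens):
# output = model.generate(input_ids, attention_mask=attention_mask,
#                         pad_token_id=pad_token_id, **GENERATION_KWARGS)
```

Sampling makes the output non-deterministic; calling `torch.manual_seed(...)` before `generate` makes a run repeatable.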

```python
import gradio as gr

# Simple web UI: the text box input is passed to generate_response as the prompt
def predict(prompt):
    return generate_response(my_model, my_tokenizer, prompt)

iface = gr.Interface(fn=predict, inputs="text", outputs="text")
iface.launch()
```



