# Clinical Question Answering LLM

In [1]:
# Install the Hugging Face Transformers library, which provides access to the BERT NLP model
!pip install transformers

# Install PyTorch - a deep learning framework that enables building and training neural networks.
!pip install torch



In [2]:
from transformers import BertForQuestionAnswering, Trainer, TrainingArguments
import torch
from transformers import BertTokenizer
import numpy as np
import pandas as pd

**BERT large model (uncased) whole word masking finetuned on SQuAD**

**Pretrained model** on English language using a masked language modeling (**MLM**) objective.

This model is **uncased**: it does not make a difference between english and English.

This model was trained with whole word masking: all of the tokens corresponding to a word are masked at once, while the overall masking rate remains the same.
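
A minimal sketch of how whole word masking can choose its targets, assuming WordPiece output in which a token starting with `##` continues the previous word (like the `'nl', '##p'` split printed later in this notebook). The helpers `group_wordpieces` and `whole_word_mask` are hypothetical illustrations, not the original pretraining code: every piece of a chosen word is masked together, and words that would push the count past the budget are skipped.

```python
import random


def group_wordpieces(tokens):
    """Group WordPiece tokens into words; a '##' token continues the previous word."""
    groups = []
    for i, tok in enumerate(tokens):
        if tok.startswith("##") and groups:
            groups[-1].append(i)
        else:
            groups.append([i])
    return groups


def whole_word_mask(tokens, mask_rate=0.15, seed=0, mask_token="[MASK]"):
    """Mask whole words (all of their pieces) until about mask_rate of the tokens are covered."""
    rng = random.Random(seed)
    groups = group_wordpieces(tokens)
    rng.shuffle(groups)
    budget = max(1, round(len(tokens) * mask_rate))
    masked = list(tokens)
    covered = 0
    for group in groups:
        # Skip words that would push the number of masked tokens past the budget
        if covered + len(group) > budget:
            continue
        for i in group:
            masked[i] = mask_token
        covered += len(group)
    return masked


tokens = ["the", "nl", "##p", "project", "is", "on", "the", "medical", "domain"]
print(group_wordpieces(tokens)[:3])  # [[0], [1, 2], [3]]
print(whole_word_mask(tokens, mask_rate=0.3))
```

Whichever words are drawn, `nl` and `##p` are always masked or kept together.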

It was pretrained with two objectives:

Masked language modeling (MLM): taking a sentence, the model randomly masks 15% of the words in the input, then runs the entire masked sentence through the model and has to predict the masked words. This is different from traditional recurrent neural networks (RNNs) that usually see the words one after the other, or from autoregressive models like GPT which internally mask the future tokens. It allows the model to learn a bidirectional representation of the sentence.
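
To make the contrast with autoregressive models concrete, here is a small NumPy illustration, assuming a convention where 1 means "may attend to". It is a conceptual sketch only, not how BERT or GPT implementations store their masks (BERT's `attention_mask` input, for example, mainly marks padding): a bidirectional encoder lets every position see the whole sentence, while a causal mask hides future tokens.

```python
import numpy as np


def bidirectional_mask(seq_len):
    """Every position may attend to every position (1 = allowed), as in a BERT-style encoder."""
    return np.ones((seq_len, seq_len), dtype=np.int64)


def causal_mask(seq_len):
    """Position i may attend only to positions 0..i, hiding future tokens (GPT-style)."""
    return np.tril(np.ones((seq_len, seq_len), dtype=np.int64))


print(bidirectional_mask(4))
print(causal_mask(4))
```

Row i lists what token i can see: in the causal mask the first token sees only itself, so right-hand context never reaches its representation.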

Next sentence prediction (NSP): the model concatenates two masked sentences as inputs during pretraining. Sometimes they correspond to sentences that were next to each other in the original text, sometimes not. The model then has to predict whether the two sentences followed each other or not.
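
The NSP training pairs can be sketched in plain Python. `make_nsp_pairs` is a hypothetical, simplified helper: it draws the random "not next" sentence from the same list, whereas the original BERT data pipeline, as we understand it, samples that segment from a different document.

```python
import random


def make_nsp_pairs(sentences, seed=0):
    """Build (sentence_a, sentence_b, is_next) examples for next sentence prediction.

    With probability 0.5, B is the sentence that really follows A (is_next=1);
    otherwise B is another sentence drawn at random (is_next=0).
    """
    rng = random.Random(seed)
    pairs = []
    for i in range(len(sentences) - 1):
        a = sentences[i]
        if rng.random() < 0.5:
            pairs.append((a, sentences[i + 1], 1))
        else:
            # Any sentence except the true successor (simplification: same document)
            candidates = [s for j, s in enumerate(sentences) if j != i + 1]
            pairs.append((a, rng.choice(candidates), 0))
    return pairs


corpus = [
    "An angiography was performed by the cardiology team.",
    "The results indicated significant blockages.",
    "John Doe was prescribed medications.",
    "A follow-up appointment was scheduled.",
]
for a, b, is_next in make_nsp_pairs(corpus):
    print(is_next, "|", a, "->", b)
```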

This model has the following configuration:

- 24 layers
- 1024 hidden dimension
- 16 attention heads
- 336M parameters
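
As a sanity check on these figures, the parameter count can be estimated from the configuration. The sketch assumes values not stated above: a vocabulary of 30,522 WordPiece entries (the text below rounds this to 30,000), a feed-forward inner size of 4,096, 512 position embeddings and 2 segment types. Counting the encoder plus pooler lands at about 335M; the remaining gap to 336M is plausibly the pretraining heads (such as the MLM transform layer), which the sketch leaves out. It also shows that each of the 16 heads works on 1024 / 16 = 64 dimensions.

```python
def bert_param_count(layers=24, hidden=1024, heads=16, intermediate=4096,
                     vocab=30522, max_positions=512, type_vocab=2, with_pooler=True):
    """Estimate BERT encoder parameters (weights, biases and LayerNorm gamma/beta)."""
    assert hidden % heads == 0, "hidden size must split evenly across heads"
    # Token, position and segment embeddings plus one LayerNorm
    embeddings = (vocab + max_positions + type_vocab) * hidden + 2 * hidden
    # Self-attention: Q, K, V and output projections, each hidden x hidden plus bias
    attention = 4 * (hidden * hidden + hidden)
    # Feed-forward: hidden -> intermediate -> hidden, with biases
    feed_forward = (hidden * intermediate + intermediate) + (intermediate * hidden + hidden)
    # Two LayerNorms per layer (after attention and after feed-forward)
    layer_norms = 2 * (2 * hidden)
    per_layer = attention + feed_forward + layer_norms
    pooler = hidden * hidden + hidden if with_pooler else 0
    return embeddings + layers * per_layer + pooler


total = bert_param_count()
print(f"{total:,} parameters (~{total / 1e6:.0f}M)")  # 335,141,888 parameters (~335M)
print("head size:", 1024 // 16)  # head size: 64
```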


**Training procedure**

**Preprocessing**

The texts are lowercased and tokenized using WordPiece and a vocabulary size of 30,000. The inputs of the model are then of the form:

[CLS] Sentence A [SEP] Sentence B [SEP]

With probability 0.5, sentence A and sentence B correspond to two consecutive sentences in the original corpus; in the other cases, sentence B is another random sentence from the corpus. Note that what is considered a sentence here is a consecutive span of text usually longer than a single sentence. The only constraint is that the two "sentences" together have a combined length of less than 512 tokens.
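
A toy sketch of this preprocessing, assuming a tiny hypothetical vocabulary in place of the real roughly 30,000-entry `vocab.txt`: greedy longest-match-first WordPiece splitting of each word, followed by packing two segments into `[CLS] A [SEP] B [SEP]` with segment ids. The real tokenizer also splits punctuation before WordPiece, which this sketch skips. The split of `nlp` into `nl` + `##p` mirrors the tokens printed by `bert_question_answer` later in this notebook.

```python
def wordpiece(word, vocab, unk="[UNK]"):
    """Greedy longest-match-first WordPiece split of one lowercased word."""
    pieces, start = [], 0
    while start < len(word):
        end, piece = len(word), None
        # Try the longest remaining substring first and shrink until it is in the vocabulary
        while start < end:
            candidate = word[start:end]
            if start > 0:
                candidate = "##" + candidate  # pieces after the first carry the ## prefix
            if candidate in vocab:
                piece = candidate
                break
            end -= 1
        if piece is None:
            return [unk]  # no valid split: the whole word maps to [UNK]
        pieces.append(piece)
        start = end
    return pieces


def tokenize(text, vocab):
    """Lowercase, split on whitespace, then apply WordPiece to each word."""
    return [p for w in text.lower().split() for p in wordpiece(w, vocab)]


def pack_pair(tokens_a, tokens_b, max_len=512):
    """Build [CLS] A [SEP] B [SEP] and segment ids (0 for A and its [SEP], 1 for B and its [SEP])."""
    tokens = ["[CLS]"] + tokens_a + ["[SEP]"] + tokens_b + ["[SEP]"]
    if len(tokens) > max_len:
        raise ValueError(f"pair has {len(tokens)} tokens, limit is {max_len}")
    segment_ids = [0] * (len(tokens_a) + 2) + [1] * (len(tokens_b) + 1)
    return tokens, segment_ids


vocab = {"what", "is", "the", "nl", "##p", "project"}  # hypothetical toy vocabulary
tokens, segments = pack_pair(tokenize("What is the NLP project", vocab), tokenize("the project", vocab))
print(tokens)    # ['[CLS]', 'what', 'is', 'the', 'nl', '##p', 'project', '[SEP]', 'the', 'project', '[SEP]']
print(segments)  # [0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1]
```

The segment ids follow the same convention the notebook's `bert_question_answer` builds by hand: 0 up to and including the first `[SEP]`, 1 afterwards.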

The details of the masking procedure for each sentence are the following:

- 15% of the tokens are masked.
- In 80% of the cases, the masked tokens are replaced by [MASK].
- In 10% of the cases, the masked tokens are replaced by a random token different from the one they replace.
- In the remaining 10% of the cases, the masked tokens are left as is.

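
The 80/10/10 rule can be sketched as follows, assuming integer token ids, the `[MASK]` id 103 and vocabulary size 30,522 of the uncased BERT vocabulary, and the common PyTorch convention of label `-100` for positions with no prediction target. `mask_tokens` is a hypothetical helper, not the original pretraining code, and it ignores whole word grouping.

```python
import random


def mask_tokens(token_ids, vocab_size, mask_id, mask_rate=0.15, special_ids=(), seed=0):
    """Apply BERT-style MLM corruption and return (inputs, labels).

    labels holds the original id at each chosen position and -100 elsewhere.
    """
    rng = random.Random(seed)
    inputs = list(token_ids)
    labels = [-100] * len(token_ids)
    # Special tokens such as [CLS] and [SEP] are never chosen
    candidates = [i for i, t in enumerate(token_ids) if t not in special_ids]
    num_to_mask = min(len(candidates), max(1, round(len(candidates) * mask_rate)))
    for i in rng.sample(candidates, num_to_mask):
        labels[i] = token_ids[i]  # the model must recover the original token here
        roll = rng.random()
        if roll < 0.8:
            inputs[i] = mask_id  # 80%: replace with [MASK]
        elif roll < 0.9:
            # 10%: replace with a random token different from the original
            replacement = rng.randrange(vocab_size)
            while replacement == token_ids[i]:
                replacement = rng.randrange(vocab_size)
            inputs[i] = replacement
        # Remaining 10%: keep the original token
    return inputs, labels


# Token ids of the question in the test example later in this notebook
ids = [101, 2054, 2003, 1996, 5884, 1997, 1996, 17953, 2361, 2622, 102]
inputs, labels = mask_tokens(ids, vocab_size=30522, mask_id=103, special_ids={101, 102})
print(inputs)
print(labels)
```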
**Pretraining**

The model was trained on 4 cloud TPUs in Pod configuration (16 TPU chips total) for one million steps with a batch size of 256. The sequence length was limited to 128 tokens for 90% of the steps and 512 for the remaining 10%. The optimizer used is Adam with a learning rate of 1e-4, $\beta_1 = 0.9$ and $\beta_2 = 0.999$, a weight decay of 0.01, learning rate warmup for 10,000 steps and linear decay of the learning rate after.
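
The learning-rate schedule described above (linear warmup over 10,000 steps, then linear decay) can be written as a plain function. This sketch assumes the decay reaches zero at the final, one-millionth step; the Transformers library ships a similar helper, `get_linear_schedule_with_warmup`, for use with a PyTorch optimizer.

```python
def bert_learning_rate(step, peak_lr=1e-4, warmup_steps=10_000, total_steps=1_000_000):
    """Linear warmup from 0 to peak_lr, then linear decay back to 0 at total_steps."""
    if step < warmup_steps:
        return peak_lr * step / warmup_steps
    remaining = max(0, total_steps - step)
    return peak_lr * remaining / (total_steps - warmup_steps)


for step in (0, 5_000, 10_000, 505_000, 1_000_000):
    print(f"step {step:>9,}: lr = {bert_learning_rate(step):.2e}")
```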

**Fine-tuning**

After pretraining, this model was fine-tuned on the SQuAD dataset with one of the Hugging Face Transformers example fine-tuning scripts. To reproduce the training, you may use the following command:

python -m torch.distributed.launch --nproc_per_node=8 ./examples/question-answering/run_qa.py \
    --model_name_or_path bert-large-uncased-whole-word-masking \
    --dataset_name squad \
    --do_train \
    --do_eval \
    --learning_rate 3e-5 \
    --num_train_epochs 2 \
    --max_seq_length 384 \
    --doc_stride 128 \
    --output_dir ./examples/models/wwm_uncased_finetuned_squad/ \
    --per_device_eval_batch_size=3 \
    --per_device_train_batch_size=3
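
The `--max_seq_length 384` and `--doc_stride 128` flags control how contexts longer than the model window are handled: the passage is cut into overlapping chunks and each chunk is paired with the question. Below is a pure-Python sketch of that windowing, assuming (as we understand `run_qa.py`) that `doc_stride` is the number of tokens consecutive chunks share; `sliding_windows` is a hypothetical name. The notebook's `bert_question_answer` does not do this and simply truncates input to `max_len` tokens, so an answer beyond that point cannot be found.

```python
def sliding_windows(passage_tokens, question_len, max_seq_length=384, doc_stride=128):
    """Split passage tokens into overlapping (start_offset, chunk) windows.

    Each chunk leaves room for the question plus [CLS] and two [SEP] tokens,
    and consecutive chunks share doc_stride tokens.
    """
    window = max_seq_length - question_len - 3
    if window <= doc_stride:
        raise ValueError("question is too long for this max_seq_length/doc_stride")
    step = window - doc_stride
    chunks = []
    start = 0
    while True:
        chunks.append((start, passage_tokens[start:start + window]))
        if start + window >= len(passage_tokens):
            break  # this chunk reaches the end of the passage
        start += step
    return chunks


passage_tokens = list(range(1000))  # stand-in token ids for a long passage
for start, chunk in sliding_windows(passage_tokens, question_len=20):
    print(start, len(chunk))
```

An answer is then predicted per chunk and the best-scoring span across chunks is kept.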

**Evaluation results**

The results obtained are the following:

f1 = 93.15
exact_match = 86.91
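
These are the standard SQuAD metrics. A sketch of how they are computed for one prediction, following the normalization used by the SQuAD v1.1 evaluation script (lowercase, strip punctuation and the articles a/an/the, collapse whitespace). The reported scores average these per-question values over the evaluation set, taking the best match when several reference answers exist, and express them as percentages.

```python
import collections
import re
import string


def normalize_answer(text):
    """Lowercase, drop punctuation and the articles a/an/the, and collapse whitespace."""
    exclude = set(string.punctuation)
    text = "".join(ch for ch in text.lower() if ch not in exclude)
    text = re.sub(r"\b(a|an|the)\b", " ", text)
    return " ".join(text.split())


def exact_match(prediction, truth):
    """1.0 if the normalized strings are identical, else 0.0."""
    return float(normalize_answer(prediction) == normalize_answer(truth))


def f1_score(prediction, truth):
    """Harmonic mean of token-level precision and recall after normalization."""
    pred_tokens = normalize_answer(prediction).split()
    truth_tokens = normalize_answer(truth).split()
    common = collections.Counter(pred_tokens) & collections.Counter(truth_tokens)
    num_same = sum(common.values())
    if num_same == 0:
        return 0.0
    precision = num_same / len(pred_tokens)
    recall = num_same / len(truth_tokens)
    return 2 * precision * recall / (precision + recall)


print(exact_match("The medical domain", "medical domain"))  # 1.0
print(round(f1_score("medical domain", "the medical domain of NLP"), 3))  # 0.667
```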



In [3]:
# Initializing the model and tokenizer
model = BertForQuestionAnswering.from_pretrained('bert-large-uncased-whole-word-masking-finetuned-squad')

tokenizer_for_bert = BertTokenizer.from_pretrained('bert-large-uncased-whole-word-masking-finetuned-squad')

Some weights of the model checkpoint at bert-large-uncased-whole-word-masking-finetuned-squad were not used when initializing BertForQuestionAnswering: ['bert.pooler.dense.bias', 'bert.pooler.dense.weight']
- This IS expected if you are initializing BertForQuestionAnswering from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForQuestionAnswering from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).





In [None]:
data = pd.read_csv("/content/QAMedical.csv")  # Use pd.read_csv to read the CSV file

In [None]:
data.head(5)

|   | qtype | Question | Answer |
|---|---|---|---|
| 0 | susceptibility | Who is at risk for Lymphocytic Choriomeningiti... | LCMV infections can occur after exposure to fr... |
| 1 | symptoms | What are the symptoms of Lymphocytic Choriomen... | LCMV is most commonly recognized as causing ne... |
| 2 | susceptibility | Who is at risk for Lymphocytic Choriomeningiti... | Individuals of all ages who come into contact ... |
| 3 | exams and tests | How to diagnose Lymphocytic Choriomeningitis (... | During the first phase of the disease, the mos... |
| 4 | treatment | What are the treatments for Lymphocytic Chorio... | Aseptic meningitis, encephalitis, or meningoen... |


In [None]:
for i in range(3):
    # Select by column name rather than position so the code survives column reordering
    question = data["Question"][i]
    print(question)

Who is at risk for Lymphocytic Choriomeningitis (LCM)? ?
What are the symptoms of Lymphocytic Choriomeningitis (LCM) ?
Who is at risk for Lymphocytic Choriomeningitis (LCM)? ?
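
Some questions in the dataset end with a doubled or space-separated question mark, as the output above shows. A small regex cleanup can normalize them; `clean_question` is a hypothetical helper, and whether it changes the model's answers has not been checked here.

```python
import re


def clean_question(text):
    """Collapse repeated question marks and drop a space before the final '?'."""
    # "(LCM)? ?" -> "(LCM)?"
    text = re.sub(r"\s*\?(?:\s*\?)+", "?", text)
    # "(LCM) ?" -> "(LCM)?"
    text = re.sub(r"\s+\?", "?", text)
    return text.strip()


print(clean_question("Who is at risk for Lymphocytic Choriomeningitis (LCM)? ?"))
```

To apply it to the whole column of the DataFrame loaded above: `data["Question"] = data["Question"].map(clean_question)`.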


In [None]:
for i in range(3):
    # The "Answer" column is used as the passage the model searches for an answer
    passage = data["Answer"][i]
    print(passage + "\n")

LCMV infections can occur after exposure to fresh urine, droppings, saliva, or nesting materials from infected rodents.  Transmission may also occur when these materials are directly introduced into broken skin, the nose, the eyes, or the mouth, or presumably, via the bite of an infected rodent. Person-to-person transmission has not been reported, with the exception of vertical transmission from infected mother to fetus, and rarely, through organ transplantation.

LCMV is most commonly recognized as causing neurological disease, as its name implies, though infection without symptoms or mild febrile illnesses are more common clinical manifestations. 
                
For infected persons who do become ill, onset of symptoms usually occurs 8-13 days after exposure to the virus as part of a biphasic febrile illness. This initial phase, which may last as long as a week, typically begins with any or all of the following symptoms: fever, malaise, lack of appetite, muscle aches, headache, nau

In [None]:
def bert_question_answer(question, passage, max_len=500):

    # Tokenize the question and passage as a pair;
    # encode() adds the special [CLS] and [SEP] tokens
    input_ids = tokenizer_for_bert.encode(question, passage, max_length=max_len, truncation=True)
    print("input_ids", input_ids)

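    # Count the tokens in the 1st segment (question) and the 2nd segment (passage that contains the answer);
    # the first [SEP] token (id 102 in this vocabulary) closes the question
    sep_index = input_ids.index(tokenizer_for_bert.sep_token_id)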
    print("sep_index", sep_index)
    len_question = sep_index + 1
    print("len_question",len_question)
    len_passage = len(input_ids)- len_question
    print("len_passage",len_passage)

    #Need to separate question and passage
    #Segment ids will be 0 for question and 1 for passage
    segment_ids =  [0]*len_question + [1]*(len_passage)
    print("segment_ids", segment_ids)


    #Converting token ids to tokens
    tokens = tokenizer_for_bert.convert_ids_to_tokens(input_ids)

    print("tokens: ", tokens)

    # Get start and end scores for the answer with a single forward pass
    # (inputs converted to torch tensors; no gradients are needed for inference)
    with torch.no_grad():
        outputs = model(torch.tensor([input_ids]), token_type_ids=torch.tensor([segment_ids]))
    start_token_scores = outputs.start_logits
    print("start_token_scores", start_token_scores)
    end_token_scores = outputs.end_logits
    print("end_token_scores", end_token_scores)

    #Converting scores tensors to numpy arrays
    start_token_scores = start_token_scores.detach().numpy().flatten()
    print("start_token_scores", start_token_scores)
    end_token_scores = end_token_scores.detach().numpy().flatten()
    print("end_token_scores", end_token_scores)

    #Getting start and end index of answer based on highest scores
    answer_start_index = np.argmax(start_token_scores)
    print("answer_start_index", answer_start_index)
    answer_end_index = np.argmax(end_token_scores)
    print("answer_end_index", answer_end_index)

    #Getting scores for start and end token of the answer
    start_token_score = np.round(start_token_scores[answer_start_index], 2)
    print("start_token_score", start_token_score)
    end_token_score = np.round(end_token_scores[answer_end_index], 2)
    print("end_token_score", end_token_score)

    # Merge WordPiece subwords (tokens starting with ##) back into full words,
    # since the tokenizer splits words that are not in its vocabulary
    answer = tokens[answer_start_index]
    for i in range(answer_start_index + 1, answer_end_index + 1):
        if tokens[i][0:2] == '##':
            answer += tokens[i][2:]
        else:
            answer += ' ' + tokens[i]

    # No valid answer in the passage: low start score, span starting in the question
    # (including [CLS]), end before start, or a span that is only [SEP]
    if (start_token_score < 0) or (answer_start_index < len_question) or (answer_end_index < answer_start_index) or (answer == '[SEP]'):
        answer = "Sorry, I was unable to find an answer in the passage."

    return ("answer_start_index: ",answer_start_index, "answer_end_index: ",answer_end_index,"start_token_score: ", start_token_score,"end_token_score: ", end_token_score, "answer:", answer)

#Testing function
bert_question_answer("What is the domain of the NLP Project", "The case study of Natural Language Processing. On the medical domain, Bert model question answering along with fine tuning")


input_ids [101, 2054, 2003, 1996, 5884, 1997, 1996, 17953, 2361, 2622, 102, 1996, 2553, 2817, 1997, 3019, 2653, 6364, 1012, 2006, 1996, 2966, 5884, 1010, 14324, 2944, 3160, 10739, 2247, 2007, 2986, 17372, 102]
sep_index 10
len_question 11
len_passage 22
segment_ids [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]
tokens:  ['[CLS]', 'what', 'is', 'the', 'domain', 'of', 'the', 'nl', '##p', 'project', '[SEP]', 'the', 'case', 'study', 'of', 'natural', 'language', 'processing', '.', 'on', 'the', 'medical', 'domain', ',', 'bert', 'model', 'question', 'answering', 'along', 'with', 'fine', 'tuning', '[SEP]']
start_token_scores tensor([[-5.6137, -5.2203, -8.4065, -6.9295, -8.3026, -8.4705, -7.7444, -6.5232,
         -9.4839, -9.9740, -5.6137,  1.5292,  1.2963, -0.8007, -3.7327,  2.5807,
         -0.9019, -1.5288, -5.0339,  0.4595,  1.9068,  3.3825, -2.3997, -5.6213,
         -2.7057, -3.3010, -3.2343, -4.1743, -7.0653, -7.5564, -3.9293, -5.5871

('answer_start_index: ',
 21,
 'answer_end_index: ',
 22,
 'start_token_score: ',
 3.38,
 'end_token_score: ',
 4.55,
 'answer:',
 'medical domain')
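
`bert_question_answer` picks the start and end with two independent `np.argmax` calls, which can produce an end before the start or a span inside the question. A sketch of joint span selection, similar in spirit to the post-processing in the Transformers QA examples (which also cap the answer length): search every valid (start, end) pair and maximize the summed scores. `best_span` and its 30-token default cap are assumptions for illustration.

```python
import numpy as np


def best_span(start_scores, end_scores, first_index, last_index=None, max_answer_len=30):
    """Return (start, end, score) maximizing start_scores[start] + end_scores[end]
    over first_index <= start <= end < last_index with end - start < max_answer_len."""
    start_scores = np.asarray(start_scores, dtype=float)
    end_scores = np.asarray(end_scores, dtype=float)
    if last_index is None:
        last_index = len(start_scores)
    best = (None, None, -np.inf)
    for s in range(first_index, last_index):
        for e in range(s, min(last_index, s + max_answer_len)):
            score = float(start_scores[s] + end_scores[e])
            if score > best[2]:
                best = (s, e, score)
    return best


# Independent argmax would pick start=2 and end=1 here, an invalid span
start_scores = [0.0, 1.0, 9.0, 0.0]
end_scores = [0.0, 8.0, 1.0, 0.0]
print(best_span(start_scores, end_scores, first_index=1))  # (2, 2, 10.0)
```

Inside `bert_question_answer`, a call such as `best_span(start_token_scores, end_token_scores, first_index=len_question, last_index=len(input_ids) - 1)` would replace the two argmax lines and keep the answer out of the question and off the final `[SEP]`.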

In [None]:
#@markdown ---
question= "What did cardiology team perform?" #@param {type:"string"}
passage = "To determine the extent of the heart damage and identify any blockages in the coronary arteries, an angiography was performed by the cardiology team. The results indicated significant blockages that required immediate intervention. John Doe was prescribed medications to manage his condition and was advised to follow a healthy diet regimen. The cardiology team emphasized the importance of lifestyle changes to prevent future cardiac events. A follow-up appointment was scheduled to monitor his recovery and adjust medications as needed. John Doe was instructed to maintain a record of his blood pressure and report any recurrent symptoms. His treatment plan also included regular physical activity and stress management techniques to support his overall heart health." #@param {type:"string"}
#@markdown ---

ans = bert_question_answer(question, passage)
print(ans)



input_ids [101, 2054, 2106, 4003, 20569, 2136, 4685, 1029, 102, 2000, 5646, 1996, 6698, 1997, 1996, 2540, 4053, 1998, 6709, 2151, 3796, 13923, 1999, 1996, 21887, 2854, 28915, 1010, 2019, 17076, 26535, 2001, 2864, 2011, 1996, 4003, 20569, 2136, 1012, 1996, 3463, 5393, 3278, 3796, 13923, 2008, 3223, 6234, 8830, 1012, 2198, 18629, 2001, 16250, 20992, 2000, 6133, 2010, 4650, 1998, 2001, 9449, 2000, 3582, 1037, 7965, 8738, 6939, 2078, 1012, 1996, 4003, 20569, 2136, 13155, 1996, 5197, 1997, 9580, 3431, 2000, 4652, 2925, 15050, 2824, 1012, 1037, 3582, 1011, 2039, 6098, 2001, 5115, 2000, 8080, 2010, 7233, 1998, 14171, 20992, 2004, 2734, 1012, 2198, 18629, 2001, 10290, 2000, 5441, 1037, 2501, 1997, 2010, 2668, 3778, 1998, 3189, 2151, 28667, 29264, 8030, 1012, 2010, 3949, 2933, 2036, 2443, 3180, 3558, 4023, 1998, 6911, 2968, 5461, 2000, 2490, 2010, 3452, 2540, 2740, 1012, 102]
sep_index 8
len_question 9
len_passage 133
segment_ids [0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1

In [None]:
question = data["Question"][0]
passage = data["Answer"][0]
ans = bert_question_answer(question, passage)
print("question ",question)
print("ans ",ans)

input_ids [101, 2040, 2003, 2012, 3891, 2005, 1048, 24335, 8458, 10085, 21252, 16480, 9488, 3549, 2075, 13706, 1006, 29215, 2213, 1007, 1029, 1029, 102, 29215, 2213, 2615, 15245, 2064, 5258, 2044, 7524, 2000, 4840, 17996, 1010, 7510, 2015, 1010, 26308, 1010, 2030, 21016, 4475, 2013, 10372, 28156, 1012, 6726, 2089, 2036, 5258, 2043, 2122, 4475, 2024, 3495, 3107, 2046, 3714, 3096, 1010, 1996, 4451, 1010, 1996, 2159, 1010, 2030, 1996, 2677, 1010, 2030, 10712, 1010, 3081, 1996, 6805, 1997, 2019, 10372, 8469, 3372, 1012, 2711, 1011, 2000, 1011, 2711, 6726, 2038, 2025, 2042, 2988, 1010, 2007, 1996, 6453, 1997, 7471, 6726, 2013, 10372, 2388, 2000, 10768, 5809, 1010, 1998, 6524, 1010, 2083, 5812, 22291, 3370, 1012, 102]
sep_index 22
len_question 23
len_passage 93
segment_ids [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,