In [238]:
import os
import torch
import json
import re
import warnings
import pandas as pd
warnings.filterwarnings('ignore')
os.environ['HF_HOME'] = '/raid/abhilash/huggingfacecache/huggingface/hub/'

from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, AutoModel
from fuzzywuzzy import fuzz
from sentence_transformers import SentenceTransformer, util
from tqdm import tqdm
from transformers.utils import logging
logging.set_verbosity_error() 

In [47]:
# ! ls /raid/abhilash/huggingfacecache/huggingface/hub

### Define Model

In [6]:
'''
Not using it for now
'''
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16
)

In [182]:
model_name = "BioMistral/BioMistral-7B"  # Specify the name or path of the Mistral model
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    # quantization_config=bnb_config,
    device_map="auto",
    trust_remote_code=True,
    cache_dir="/raid/abhilash/huggingfacecache/huggingface/hub/"
)

# Load the tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_name, add_bos_token=True, trust_remote_code=True)

# tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir="/raid/abhilash/huggingfacecache/huggingface/hub/")
# model = AutoModel.from_pretrained(model_name)

In [183]:
model

MistralForCausalLM(
  (model): MistralModel(
    (embed_tokens): Embedding(32000, 4096)
    (layers): ModuleList(
      (0): MistralDecoderLayer(
        (self_attn): MistralAttention(
          (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (k_proj): Linear(in_features=4096, out_features=1024, bias=False)
          (v_proj): Linear(in_features=4096, out_features=1024, bias=False)
          (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (rotary_emb): MistralRotaryEmbedding()
        )
        (mlp): MistralMLP(
          (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)
          (up_proj): Linear(in_features=4096, out_features=14336, bias=False)
          (down_proj): Linear(in_features=14336, out_features=4096, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): MistralRMSNorm()
        (post_attention_layernorm): MistralRMSNorm()
      )
      (1): MistralDecoderLayer(
        (self

In [21]:
# Get the number of parameters
total_params = sum(p.numel() for p in model.parameters())
print(f"Total number of parameters: {total_params/1e9} billion")

Total number of parameters: 7.241732096 billion


### Read and Process Data

In [44]:
path = '../MedQA/questions'
us_dir = 'US/metamap_extracted_phrases'
taiwan_dir = 'Taiwan/metamap'

us_path = os.path.join(path, us_dir)
taiwan_path = os.path.join(path, taiwan_dir)

In [45]:
us_path

'../MedQA/questions/US/metamap_extracted_phrases'

In [46]:
taiwan_path

'../MedQA/questions/Taiwan/metamap'

In [51]:
def read_json(filepath):
    data = []
    with open(filepath, 'r', encoding='utf-8') as file:
        for line in file:
            json_object = json.loads(line.strip())
            data.append(json_object)
    return data

In [54]:
'''
US QA
'''
us_train_path = os.path.join(us_path, 'train', 'phrases_train.jsonl')  
us_dev_path = os.path.join(us_path, 'dev', 'phrases_dev.jsonl')
us_test_path = os.path.join(us_path, 'test', 'phrases_test.jsonl')

us_train_qa = read_json(us_train_path)
us_dev_qa = read_json(us_dev_path)
us_test_qa = read_json(us_test_path)

'''
Taiwan QA
'''
taiwan_train_path = os.path.join(taiwan_path, 'train', 'tw_train.jsonl')  
taiwan_dev_path = os.path.join(taiwan_path, 'dev', 'tw_dev.jsonl')
taiwan_test_path = os.path.join(taiwan_path, 'test', 'tw_test.jsonl')

taiwan_train_qa = read_json(taiwan_train_path)
taiwan_dev_qa = read_json(taiwan_dev_path)
taiwan_test_qa = read_json(taiwan_test_path)

In [58]:
us_train_qa[0]

{'question': 'A 23-year-old pregnant woman at 22 weeks gestation presents with burning upon urination. She states it started 1 day ago and has been worsening despite drinking more water and taking cranberry extract. She otherwise feels well and is followed by a doctor for her pregnancy. Her temperature is 97.7°F (36.5°C), blood pressure is 122/77 mmHg, pulse is 80/min, respirations are 19/min, and oxygen saturation is 98% on room air. Physical exam is notable for an absence of costovertebral angle tenderness and a gravid uterus. Which of the following is the best treatment for this patient?',
 'answer': 'Nitrofurantoin',
 'options': {'A': 'Ampicillin',
  'B': 'Ceftriaxone',
  'C': 'Ciprofloxacin',
  'D': 'Doxycycline',
  'E': 'Nitrofurantoin'},
 'meta_info': 'step2&3',
 'answer_idx': 'E',
 'metamap_phrases': ['23 year old',
  'weeks presents',
  'burning',
  'urination',
  'states',
  'started 1 day',
  'worsening',
  'cranberry',
  'well',
  'followed by',
  'doctor',
  'pregnancy',
 

In [194]:
def format_options(options):
    option_f = []
    for k, v in options.items():
        option_f.append(f"\"{v}\"")
    return ', '.join(option_f)

def prompt_generator(question, options, option_flag=True):
    if option_flag:
        question += " The right choice out of " + options + " is " 
    else:
        question += " The right answer is "
    return question

### Perform QA

In [184]:
question = us_train_qa[4]['question']
answer = us_train_qa[4]['answer']
options = us_train_qa[4]['options']

In [267]:
def get_predictions(inp_prompt, device=0):
    '''
    tokenize
    '''
    model_input = tokenizer(inp_prompt, return_tensors="pt").to("cuda:"+str(device))
    model.eval()
    with torch.no_grad():
        output = model.generate(**model_input, max_new_tokens=200, repetition_penalty=1.15)
        answer = tokenizer.decode(output[0], skip_special_tokens=True)
    prediction = answer.split('is')[-1]
    prediction = re.sub(r'[^a-zA-Z0-9 ,:.\']', '', prediction).strip()
    return answer, prediction

In [151]:
'''
with options
'''
options_f = format_options(options)
inp_prompt = prompt_generator(question, options_f)
final_ans, prediction = get_predictions(inp_prompt)
print(f"QUESTION:: \n {inp_prompt}")
print(f"\nRESPONSE :: \n {prediction}")
print(f"\nGROUND-TRUTH :: \n {answer}")

Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


QUESTION:: 
 A 20-year-old woman presents with menorrhagia for the past several years. She says that her menses “have always been heavy”, and she has experienced easy bruising for as long as she can remember. Family history is significant for her mother, who had similar problems with bruising easily. The patient's vital signs include: heart rate 98/min, respiratory rate 14/min, temperature 36.1°C (96.9°F), and blood pressure 110/87 mm Hg. Physical examination is unremarkable. Laboratory tests show the following: platelet count 200,000/mm3, PT 12 seconds, and PTT 43 seconds. Which of the following is the most likely cause of this patient’s symptoms? The right choice out of "Factor V Leiden", "Hemophilia A", "Lupus anticoagulant", "Protein C deficiency", "Von Willebrand disease" is 

RESPONSE :: 
 Factor V Leiden

GROUND-TRUTH :: 
 Von Willebrand disease


In [152]:
'''
without options
'''
inp_prompt = prompt_generator(question, options=None, option_flag=False)
final_ans, prediction = get_predictions(inp_prompt)
print(f"QUESTION:: \n {inp_prompt}")
print(f"\nRESPONSE :: \n {prediction}")
print(f"\nGROUND-TRUTH :: \n {answer}")

Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


QUESTION:: 
 A 20-year-old woman presents with menorrhagia for the past several years. She says that her menses “have always been heavy”, and she has experienced easy bruising for as long as she can remember. Family history is significant for her mother, who had similar problems with bruising easily. The patient's vital signs include: heart rate 98/min, respiratory rate 14/min, temperature 36.1°C (96.9°F), and blood pressure 110/87 mm Hg. Physical examination is unremarkable. Laboratory tests show the following: platelet count 200,000/mm3, PT 12 seconds, and PTT 43 seconds. Which of the following is the most likely cause of this patient’s symptoms? The right answer is 

RESPONSE :: 
 thalassemia trait.

GROUND-TRUTH :: 
 Von Willebrand disease


In [156]:
prediction

'thalassemia trait.'

In [157]:
answer

'Von Willebrand disease'

In [174]:
'''
Fuzzywuzzy similarity
'''
fuzz.ratio(prediction.lower(), answer.lower()) / 100.0

0.24

In [181]:
'''
Semantic Similarity
'''
ss_model = SentenceTransformer('all-MiniLM-L6-v2')
embedding1 = ss_model.encode(prediction, convert_to_tensor=True)
embedding2 = ss_model.encode(answer, convert_to_tensor=True)
cosine_scores = util.pytorch_cos_sim(embedding1, embedding2)
cosine_scores

tensor([[0.4840]], device='cuda:0')

### Evaluation

To determine whether the answer is correct or not, we do,

if (semantic_similarity + fuzzy_score)/2 > 0.5 -> correct answer else incorrect answer

In [273]:
device=7

In [197]:
'''
define sentence transformer
'''
ss_model = SentenceTransformer('all-MiniLM-L6-v2')

In [269]:
def get_fuzzy_score(prediction, answer):
    return fuzz.ratio(prediction.lower(), answer.lower()) / 100.0

def get_semantic_similarity(prediction, answer):
    embedding1 = ss_model.encode(prediction, convert_to_tensor=True)
    embedding2 = ss_model.encode(answer, convert_to_tensor=True)
    cosine_scores = util.pytorch_cos_sim(embedding1, embedding2)
    return cosine_scores

def evaluate_predictions(dataset, device=0):
    correct_w = []
    correct_wo = []
    pred_wo = []
    gt = []
    pred_w = []
    qstn = []
    
    for ind in tqdm(range(len(dataset))):
        
        question = dataset[ind]['question']
        answer = dataset[ind]['answer']
        options = dataset[ind]['options']
        
        '''
        generating with options
        '''
        options_f = format_options(options)
        inp_prompt = prompt_generator(question, options_f)
        final_ans, prediction = get_predictions(inp_prompt, device=device)
        
        pred_w.append(prediction)
        gt.append(answer)
        qstn.append(inp_prompt)
        
        fuzzy_score = get_fuzzy_score(prediction, answer)
        ss_score = get_semantic_similarity(prediction, answer)
        
        if (fuzzy_score+ss_score)/2 > 0.5:
            correct_w.append(1)
            # print(f"WITH OPTION: \nCorrect for prediction = {prediction} and ground-truth = {answer}")
        else:
            correct_w.append(0)
            
        '''
        generating without options
        '''
        inp_prompt_ = prompt_generator(question, options=None, option_flag=False)
        final_ans_, prediction_ = get_predictions(inp_prompt_, device=device)
        
        pred_wo.append(prediction_)
        
        fuzzy_score = get_fuzzy_score(prediction_, answer)
        ss_score = get_semantic_similarity(prediction_, answer)
        
        if (fuzzy_score+ss_score)/2 > 0.5:
            correct_wo.append(1)
            # print(f"WITHOUT OPTION: \nCorrect for prediction = {prediction_} and ground-truth = {answer}")
        else:
            correct_wo.append(0)
            
    accuracy_w = (sum(correct_w)/len(dataset))*100
    accuracy_wo = (sum(correct_wo)/len(dataset))*100
    
    return {'question':qstn,
            'gt':gt,
            'preds_W':pred_w,
            'correct_inds_W':correct_w,
            'preds_WO':pred_wo,
            'correct_inds_WO':correct_wo
           }, accuracy_w, accuracy_wo

In [265]:
cols = ['Question', 'Answer', 'Prediction With options', 'Is Correct', 'Prediction W/O options', 'Is Correct WO']

In [274]:
'''
us test dataset
'''
metrics, acc_w, acc_wo = evaluate_predictions(us_test_qa, device=device)
print(f"\nTEST DATASET ACCURACY::\nWith options = {acc_w}\nWithout options = {acc_wo}")

df = pd.DataFrame(metrics)
df.columns=cols
df.to_csv('us_test_evaluation.csv', index=False)

100%|██████████████████████████████████████████████████████████████████████████████████████████████| 1273/1273 [46:09<00:00,  2.18s/it]


VAL DATASET ACCURACY::
With options = 20.816967792615866
Without options = 6.912804399057344





In [275]:
'''
us dev dataset
'''
metrics, acc_w, acc_wo = evaluate_predictions(us_dev_qa, device=device)
print(f"\nVAL DATASET ACCURACY::\nWith options = {acc_w}\nWithout options = {acc_wo}")

df = pd.DataFrame(metrics)
df.columns=cols
df.to_csv('us_dev_evaluation.csv', index=False)

100%|██████████████████████████████████████████████████████████████████████████████████████████████| 1272/1272 [45:59<00:00,  2.17s/it]


VAL DATASET ACCURACY::
With options = 23.5062893081761
Without options = 8.647798742138365



