# Import Libraries

In [1]:
import warnings

# Suppress every Python warning raised through the `warnings` module for the
# rest of the session.
warnings.filterwarnings("ignore")

In [2]:
import os
import json
import torch
from datasets import load_dataset
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    TrainingArguments,
    pipeline,
    logging,
)
from peft import LoraConfig
from trl import SFTTrainer


# Import Dataset

In [3]:
# Load the "train" split of the "bhatthars/nbme_patient_notes" dataset.
dataset = load_dataset("bhatthars/nbme_patient_notes",split='train')

In [4]:
# Hold out 20% of the examples for evaluation.
# The split is seeded so the same rows land in train/eval on every
# Restart & Run All; without a seed, positional look-ups made later in the
# notebook (e.g. eval_dataset['Prompt'][1]) point at a different note each run.
split_dataset = dataset.train_test_split(test_size=0.2, seed=42)
train_dataset = split_dataset['train']
eval_dataset = split_dataset['test']

# Load base model and tokenizer using 4-bit quantization

In [5]:
# Identifier of the base checkpoint passed to from_pretrained() for both the
# model and the tokenizer.
base_model = 'meta-llama/Llama-2-7b-chat-hf'

In [6]:
# Quantization settings for loading the base model: 4-bit weights using the
# "nf4" quant type, float16 as the compute dtype, double quantization off.
compute_dtype = torch.float16

quant_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=compute_dtype,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=False,
)

In [7]:
# Load the base checkpoint with the 4-bit quantization config, mapping the
# whole model onto device 0.
# The access token is read from the HF_TOKEN environment variable instead of
# being hardcoded in the notebook, and it is passed as `token` (the
# `use_auth_token` spelling is deprecated), consistent with the tokenizer load.
model = AutoModelForCausalLM.from_pretrained(
    base_model,
    quantization_config=quant_config,
    device_map={"": 0},
    token=os.environ.get("HF_TOKEN"),
)
model.config.use_cache = False
model.config.pretraining_tp = 1

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [8]:
# Load the tokenizer that belongs to the base checkpoint. The access token
# comes from the HF_TOKEN environment variable rather than a literal in the
# notebook, so the credential is never committed with the file.
tokenizer = AutoTokenizer.from_pretrained(
    base_model,
    trust_remote_code=True,
    token=os.environ.get("HF_TOKEN"),
)
# Reuse the end-of-sequence token for padding, and pad on the right.
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"

# Define PEFT (QLoRA) and training arguments

In [9]:
# LoRA adapter settings for a causal-LM task: rank 64, alpha 16, 10% dropout,
# and no bias parameters.
peft_params = LoraConfig(
    task_type="CAUSAL_LM",
    r=64,
    lora_alpha=16,
    lora_dropout=0.1,
    bias="none",
)

In [10]:
# Hyperparameters for the one-epoch fine-tuning run, grouped by concern.
training_params = TrainingArguments(
    # Output location, save/log cadence and reporting backend.
    output_dir="./results",
    save_steps=25,
    logging_steps=25,
    report_to="tensorboard",
    # Run length and batching.
    num_train_epochs=1,
    max_steps=-1,
    per_device_train_batch_size=4,
    gradient_accumulation_steps=1,
    group_by_length=True,
    # Optimizer and learning-rate schedule.
    optim="paged_adamw_32bit",
    learning_rate=2e-4,
    weight_decay=0.001,
    max_grad_norm=0.3,
    lr_scheduler_type="constant",
    # NOTE(review): with lr_scheduler_type="constant" this warmup ratio
    # presumably has no effect -- confirm, and switch to
    # "constant_with_warmup" if a warmup phase is actually intended.
    warmup_ratio=0.03,
    # Both mixed-precision switches are off.
    fp16=False,
    bf16=False,
)

In [None]:
# Assemble the supervised fine-tuning trainer: quantized base model plus the
# LoRA config, trained on the "Prompt" text column without sequence packing.
trainer = SFTTrainer(
    model=model,
    tokenizer=tokenizer,
    peft_config=peft_params,
    args=training_params,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    dataset_text_field="Prompt",
    max_seq_length=None,
    packing=False,
)

# Train

In [12]:
# Run fine-tuning; the returned TrainOutput is displayed as the cell result.
trainer.train()

Step,Training Loss
25,2.3517
50,1.7088
75,1.5284
100,1.4906
125,1.4209
150,1.425
175,1.3629
200,1.4104


TrainOutput(global_step=200, training_loss=1.5873416137695313, metrics={'train_runtime': 550.6409, 'train_samples_per_second': 1.453, 'train_steps_per_second': 0.363, 'total_flos': 1.2677846263922688e+16, 'train_loss': 1.5873416137695313, 'epoch': 1.0})

# Save and Evaluate

In [13]:
# Persist the trained model and the tokenizer to separate folders.
# The tokenizer is saved through the `tokenizer` object that was handed to the
# trainer rather than through `trainer.tokenizer`, because that attribute is
# deprecated and emits a warning each time it is accessed.
trainer.model.save_pretrained("metadata/fine_tuned")
tokenizer.save_pretrained("metadata/tokenft")

Trainer.tokenizer is now deprecated. You should use Trainer.processing_class instead.


('metadata/tokenft/tokenizer_config.json',
 'metadata/tokenft/special_tokens_map.json',
 'metadata/tokenft/tokenizer.json')

In [21]:
# from tensorboard import notebook
# log_dir = "results/runs"
# notebook.start("--logdir {} --port 4000".format(log_dir))


In [14]:
# Evaluate the model
metrics = trainer.evaluate()
print(metrics)

Trainer.tokenizer is now deprecated. You should use Trainer.processing_class instead.
Trainer.tokenizer is now deprecated. You should use Trainer.processing_class instead.


{'eval_loss': 1.3755254745483398, 'eval_runtime': 69.7525, 'eval_samples_per_second': 2.867, 'eval_steps_per_second': 0.358, 'epoch': 1.0}


In [15]:
# Folders written by the save cell; used below to reload the model and
# tokenizer for inference.
fine_tuned_model_path = "metadata/fine_tuned"
fine_tuned_tokenizer_path = "metadata/tokenft"

In [16]:
# Reload the fine-tuned model from disk onto device 0 for inference.
# The access token is taken from the HF_TOKEN environment variable and passed
# as `token`; the previous hardcoded `use_auth_token` literal was both a
# credential in source and a deprecated argument name.
# NOTE(review): the folder was written by trainer.model.save_pretrained(); if
# it only holds adapter weights, this load relies on the PEFT integration to
# fetch the base checkpoint -- confirm.
model = AutoModelForCausalLM.from_pretrained(
    fine_tuned_model_path,
    device_map={"": 0},
    token=os.environ.get("HF_TOKEN"),
)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [17]:
# Reload the tokenizer saved next to the fine-tuned model. The access token is
# read from the HF_TOKEN environment variable instead of a hardcoded literal.
tokenizer = AutoTokenizer.from_pretrained(
    fine_tuned_tokenizer_path,
    trust_remote_code=True,
    token=os.environ.get("HF_TOKEN"),
)

In [18]:
# Read the preprocessed data file. The encoding is stated explicitly so the
# result does not depend on the platform's default locale encoding.
with open("preprocessed_data.json", "r", encoding="utf-8") as file:
    data = json.load(file)

In [19]:
# Entry stored under key '0' of the preprocessed data.
# NOTE(review): `ph` is not referenced again in the visible notebook -- confirm
# whether this cell is still needed.
ph = data['0']

In [20]:
# Display one held-out example in full (prompt followed by its reference answer).
eval_dataset['Prompt'][1]

'<s>[INST]<<SYS>>You are a resourceful medical assistant. Please ensure your answers are unbiased. Make sure the answers are from the text provided.<</SYS>>Patient Note: 17 yo male with no PMH presents with 3-4 months of palpitations\r\n-episodes occur randomly, no associated with activity, associated with shortness of breath and pre-syncope\r\n-no sweating, feeling of impending doom, anxiety with episodes, no diarrhea\r\n-takes adderall a few times that is prescribed to friend, however has been taking this for a year now\r\n-no history of thyroid problems\r\nPMH: none\r\nMeds: adderall (not prescribed)\r\nFamHx: Mom- "thyroid problem"; dad- heart attack at 52\r\nSocial: lives with roomate; 3-4 alcoholic beverages/week, no durgs, sexually active w girlfriend and uses condoms.\nExtract phrases from this text which may help understand the patient\'s medical condition.[/INST]dad- heart attack, Mom- "thyroid problem, episodes, adderall, adderall, shortness of breath, palpitations, 3-4 mont

In [21]:
# Run the fine-tuned model on a single held-out example.
INST_END = "[/INST]"
# Fetch the example once instead of pulling the whole "Prompt" column twice.
sample_text = eval_dataset[1]["Prompt"]
# Cut right after the closing [/INST] tag so the model receives only the
# instruction part, not the reference answer that follows it.
idx = sample_text.index(INST_END) + len(INST_END)

pipe = pipeline(task="text-generation", model=model, tokenizer=tokenizer,max_length=435, truncation = True)

result = pipe(sample_text[:idx])
print(result[0]['generated_text'])

<s>[INST]<<SYS>>You are a resourceful medical assistant. Please ensure your answers are unbiased. Make sure the answers are from the text provided.<</SYS>>Patient Note: 17 yo male with no PMH presents with 3-4 months of palpitations
-episodes occur randomly, no associated with activity, associated with shortness of breath and pre-syncope
-no sweating, feeling of impending doom, anxiety with episodes, no diarrhea
-takes adderall a few times that is prescribed to friend, however has been taking this for a year now
-no history of thyroid problems
PMH: none
Meds: adderall (not prescribed)
FamHx: Mom- "thyroid problem"; dad- heart attack at 52
Social: lives with roomate; 3-4 alcoholic beverages/week, no durgs, sexually active w girlfriend and uses condoms.
Extract phrases from this text which may help understand the patient's medical condition.[/INST]Mom- "thyroid problem", palpitations, no diarrhea, adderall, 17 yo, male, shortness of breath, pre-syncope, 3-4 months, 3-4 months, 3-4 months

# Generate predictions and export them for evaluation

In [1]:
# !pip install streamlit

In [8]:
def generate(patient_note, model_name, tokenizer_version):
    """Extract keyword phrases from a patient note with a text-generation model.

    The prompt is assembled in the same layout as the fine-tuning examples:
    ``<s>[INST]<<SYS>>{system}<</SYS>>{user}[/INST]``. The system block is
    closed with ``<</SYS>>`` and nothing separates the user text from
    ``[/INST]``.

    Args:
        patient_note: User part of the prompt (note text plus the extraction
            instruction), placed between the system block and ``[/INST]``.
        model_name: Model forwarded to ``pipeline`` as ``model``.
        tokenizer_version: Tokenizer forwarded to ``pipeline`` as ``tokenizer``.

    Returns:
        dict with two entries:
            ``response``: the generated text split on ``"[/INST]"``; element 0
                is the prompt part and element 1 is the model's answer.
            ``keywords``: set of the comma-separated pieces of the answer,
                kept exactly as generated (no whitespace trimming), so that
                they remain comparable with reference answers split the same way.
    """
    system_prompt = "You are a resourceful medical assistant. Please ensure your answers are unbiased. Make sure the answers are from the text provided."
    pipe = pipeline(task="text-generation", model=model_name, tokenizer=tokenizer_version, max_length=600)
    result = pipe(f"<s>[INST]<<SYS>>{system_prompt}<</SYS>>{patient_note}[/INST]")

    # The pipeline output starts with the prompt itself, so splitting on the
    # closing tag separates the prompt (index 0) from the answer (index 1).
    s1 = result[0]['generated_text'].split("[/INST]")
    keywords = set(s1[1].split(','))
    return {
        'response': s1,
        'keywords': keywords,
    }

In [1]:
# Run the model over every held-out prompt and collect its keyword lists.
# The "Prompt" column is read once up front; indexing
# eval_dataset['Prompt'][i] inside the loop can re-materialize the whole
# column on every iteration.
predicted = []
for prompt_text in eval_dataset['Prompt']:
    # Keep only the user part: the text between <</SYS>> and [/INST].
    user_prompt = prompt_text.split('[/INST]')[0].split('<</SYS>>')[1]
    result = generate(user_prompt,model,tokenizer)
    # Sets have no stable order; sorting makes the exported rows reproducible.
    predicted.append(sorted(result['keywords']))

In [10]:
# Reference answers: the text that follows the first [/INST] tag in each
# held-out example.
ground_truth =  [ x.split('[/INST]')[1] for x in eval_dataset['Prompt'] ] 

In [11]:
# Remove the "</s>" marker from each reference answer, then split it on commas
# into a list of keyword strings.
lst = [answer.replace("</s>","").split(',') for answer in ground_truth]

In [12]:
import csv
import os

def file_to_csv(lst,filename):
    """Write a list of rows to a CSV file, overwriting any existing file.

    Args:
        lst: Iterable of rows; each row is an iterable of cell values.
        filename: Path of the CSV file to create.

    The file is written as UTF-8 so that non-ASCII characters in the rows do
    not depend on (or fail under) the platform's default encoding.
    """
    # newline='' lets the csv module control line endings itself.
    with open(filename, 'w', newline='', encoding='utf-8') as file:
        writer = csv.writer(file)
        writer.writerows(lst)

    print("Data saved to", filename)

In [13]:
# Export reference and predicted keyword lists, one example per CSV row.
file_to_csv(lst,"Ground_truth.csv")
file_to_csv(predicted,"Predicted.csv")

Data saved to Ground_truth.csv
Data saved to Predicted.csv
