In [1]:
import os, torch, logging
from datasets import load_dataset, load_from_disk
import transformers
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, HfArgumentParser, TrainingArguments, pipeline
from transformers import LongT5ForConditionalGeneration
from peft import LoraConfig, PeftModel, get_peft_model
from trl import SFTTrainer
from torch.utils.data import Dataset
import datasets
import pandas as pd

In [2]:
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf", use_fast=True)
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

# Data

In [3]:
args_max_length = 512
args_target_max_length = 184
args_input_eos = True

In [4]:
data_path = "discq/data/clinical_qg_one_sentence"
data = load_from_disk(data_path)

raw_train_dataset = data['train']
raw_eval_dataset = data['validation']

train_df = pd.DataFrame(raw_train_dataset)
eval_df = pd.DataFrame(raw_eval_dataset)

PROMPTS = {
    "after_reading_what_question": """<<SYS>>\nYou're a doctor and you were given the following EMR by another doctor. You have some questions for the doctor who gave that EMR to you to get more details about the patient. \n<</SYS>>\n\nGiven the EMR: \"{text}\"\nAfter reading the above EMR, what question do you have about "{trigger}"?\nQuestion:""",
}

template = PROMPTS["after_reading_what_question"]
column_names = raw_eval_dataset.column_names
padding = "max_length"

In [5]:
def convert_data(df):
    txts = []
    for idx, row in df.iterrows():
        txt = "<s>[INST] " + template.format(text=row["text"].strip(), trigger=row["trigger"]) + "\n[/INST]\n" + row['question'] + " </s>"
        txts.append(txt)
    return pd.DataFrame.from_dict({'text': txts})

In [6]:
train_data = convert_data(train_df)
eval_data = convert_data(eval_df)

train_dataset = datasets.Dataset.from_pandas(train_data)
eval_dataset = datasets.Dataset.from_pandas(eval_data)

In [7]:
print(train_dataset['text'][0])

<s>[INST] <<SYS>>
You're a doctor and you were given the following EMR by another doctor. You have some questions for the doctor who gave that EMR to you to get more details about the patient. 
<</SYS>>

Given the EMR: "Mr. Hohlt is a 76 - year-old man with a history of hypertension , hyperlipidemia and prior myocardial infarction with a percutaneous transluminal coronary angioplasty of the right coronary artery in 2005 at the Cambridge Health Alliance ."
After reading the above EMR, what question do you have about "prior myocardial infarction"?
Question:
[/INST]
When was this? </s>


# Training

In [8]:
peft_config = LoraConfig(
    lora_alpha=16,
    lora_dropout=0.1,
    r=64,
    bias="none",
    task_type="CAUSAL_LM"
)

model.add_adapter(peft_config)
#lora_model = get_peft_model(model, peft_config)

In [9]:
training_args = TrainingArguments(
    output_dir=f"results/llama2/7b",
    overwrite_output_dir=True,
    num_train_epochs=5,
    per_device_train_batch_size=1,
    per_device_eval_batch_size=1,
    warmup_steps=10,
    weight_decay=1e-6,
    learning_rate=2e-4,
    logging_dir=f"results/logs",
    adam_epsilon=1e-3,
    max_grad_norm=1.0,
    evaluation_strategy="epoch",
    save_strategy="epoch",
    logging_strategy="epoch",
    disable_tqdm=False,
    gradient_accumulation_steps=16,
    #bf16=True,
)

In [10]:
trainer = SFTTrainer(
    model=model, 
    tokenizer=tokenizer,
    args=training_args, 
    train_dataset=train_dataset,
    eval_dataset=eval_dataset, 
    dataset_text_field="text",
    max_seq_length = args_max_length,
)

Map:   0%|          | 0/1398 [00:00<?, ? examples/s]

Map:   0%|          | 0/193 [00:00<?, ? examples/s]

In [11]:
train_results = trainer.train()

You're using a LlamaTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


Epoch,Training Loss,Validation Loss
0,3.0055,1.755797
1,1.1123,0.779789
2,0.7828,0.658465
3,0.7185,0.63029
4,0.6978,0.625987




# Test

In [16]:
model = trainer.model
model.eval()

LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(32000, 4096)
    (layers): ModuleList(
      (0-31): 32 x LlamaDecoderLayer(
        (self_attn): LlamaAttention(
          (q_proj): Linear(
            in_features=4096, out_features=4096, bias=False
            (lora_dropout): ModuleDict(
              (default): Dropout(p=0.1, inplace=False)
            )
            (lora_A): ModuleDict(
              (default): Linear(in_features=4096, out_features=64, bias=False)
            )
            (lora_B): ModuleDict(
              (default): Linear(in_features=64, out_features=4096, bias=False)
            )
            (lora_embedding_A): ParameterDict()
            (lora_embedding_B): ParameterDict()
          )
          (k_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (v_proj): Linear(
            in_features=4096, out_features=4096, bias=False
            (lora_dropout): ModuleDict(
              (default): Dropout(p=0.1, inplace=Fals

In [17]:
device = torch.device("cuda")
model = model.to(device)
output_path = "results/llama2_finetuned_generated.txt"
results = []
baseline = pd.read_csv("baselines/baseline_with_context.csv", sep='\t')
prompts = []
sys_msg = "<<SYS>>\nYou're a doctor and you were given the following EMR by another doctor. You have some questions for the doctor who gave that EMR to you to get more details about the patient. \n<</SYS>>\n\n"


for index, row in baseline.iterrows():
    context = row["context"].strip()
    trigger = row["trigger"]
    context = context.strip().rstrip(".")
    prompt = f"<s>[INST] {sys_msg}Given the EMR: \"{context}\"\nAfter reading the above EMR, what question do you have about \"{trigger}\"?\nQuestion:\n[/INST]"
    prompts.append(prompt)

In [18]:
text_gen = pipeline(
    task="text-generation", 
    model=model, 
    tokenizer=tokenizer, 
    max_length=500,
    device=device,
    do_sample=True,
    temperature=0.5,
    top_p=0.5,
)

In [19]:
print(prompts[4])

<s>[INST] <<SYS>>
You're a doctor and you were given the following EMR by another doctor. You have some questions for the doctor who gave that EMR to you to get more details about the patient. 
<</SYS>>

Given the EMR: "He was given Zosyn for empiric treatment . Gram stain and culture of the ascites failed to identify any organism . He was weaned from oxygen requirement and was continued on normal saline boluses and. intravenous albumin for treatment of prerenal azotemia . Somatostatin and midodrine were also given for this condition and he also received packed red blood cells . On postoperative day number ten , the patient returned to the floor with improved renal function and without oxygen requirement . The somatostatin and the midodrine were discontinued and the regimen of Lasix and spironolactone was started for diuresis "
After reading the above EMR, what question do you have about "prerenal azotemia"?
Question:
[/INST]


In [20]:
outputs = text_gen(prompts, batch_size=64)

In [26]:
print(outputs[1][0]['generated_text'])

<s>[INST] <<SYS>>
You're a doctor and you were given the following EMR by another doctor. You have some questions for the doctor who gave that EMR to you to get more details about the patient. 
<</SYS>>

Given the EMR: "He was given Unasyn for empiric antibiotic treatment . At the time , blood , urine and bile was sent for culture , which were all negative . A chest x-ray was also negative for any pneumonia . Postoperatively , Mr. Cordano also suffered from low urine output ranges between 25 to 30 cc per hour . Fractional excretion of sodium was less than 1% suggesting a prerenal etiology . The sodium creatinine peaked at 2.0 and the patient was treated with boluses of intravenous fluid and intravenous albumin infusion was also administered . He responded well to this regimen and made adequate urine output "
After reading the above EMR, what question do you have about "low urine output ranges between 25 to 30 cc per hour ."?
Question:
[/INST]
How long did the patient have low urine out

In [21]:
for i, out in enumerate(outputs):
    txt_file = f"results/llama2/7b/generated/{i}.txt"
    with open(txt_file, "w", encoding="utf-8") as f:
        f.write(out[0]['generated_text'])

In [23]:
output = text_gen(prompts[1002])
print(output[0]['generated_text'])

<s>[INST] <<SYS>>
You're a doctor and you were given the following EMR by another doctor. You have some questions for the doctor who gave that EMR to you to get more details about the patient. 
<</SYS>>

Given the EMR: "T99.9  HR 70AF  BP 120/69  RR 20 O2sat 94%RA. Neuro A&Ox3 MAE , non focal exam. Pulm CTAB. CV irreg irreg , sternum stable incision CDI. Abdm soft , NT/ND/+BS. Ext warm 2+ edema. Pertinent Results :. 2019-06-25 04:10 PM. GLUCOSE - 97 UREA N - 44 * CREAT - 2.0 * SODIUM - 143 POTASSIUM - 4.1 CHLORIDE - 107 TOTAL CO2 - 27 ANION GAP - 13. 2019-06-25 04:10 PM. ALT(SGPT) - 19 AST(SGOT) - 19 LD(LDH) - 153 ALK PHOS - 58 TOT BILI - 0.4. 2019-06-25 04:10 PM. ALBUMIN - 4.3. 2019-06-25 04:10 PM. %HbA1c - 5.7. 2019-06-25 04:10 PM. TSH - 1.2"
After reading the above EMR, what question do you have about "UREA N - 44 * CREAT - 2.0 *"?
Question:
[/INST]  Given the EMR, what is the patient's urine output? 


In [24]:
print(outputs[1002][0]['generated_text'])

<s>[INST] <<SYS>>
You're a doctor and you were given the following EMR by another doctor. You have some questions for the doctor who gave that EMR to you to get more details about the patient. 
<</SYS>>

Given the EMR: "T99.9  HR 70AF  BP 120/69  RR 20 O2sat 94%RA. Neuro A&Ox3 MAE , non focal exam. Pulm CTAB. CV irreg irreg , sternum stable incision CDI. Abdm soft , NT/ND/+BS. Ext warm 2+ edema. Pertinent Results :. 2019-06-25 04:10 PM. GLUCOSE - 97 UREA N - 44 * CREAT - 2.0 * SODIUM - 143 POTASSIUM - 4.1 CHLORIDE - 107 TOTAL CO2 - 27 ANION GAP - 13. 2019-06-25 04:10 PM. ALT(SGPT) - 19 AST(SGOT) - 19 LD(LDH) - 153 ALK PHOS - 58 TOT BILI - 0.4. 2019-06-25 04:10 PM. ALBUMIN - 4.3. 2019-06-25 04:10 PM. %HbA1c - 5.7. 2019-06-25 04:10 PM. TSH - 1.2"
After reading the above EMR, what question do you have about "UREA N - 44 * CREAT - 2.0 *"?
Question:
[/INST]  Given the EMR, what is the patient's urine output? 
