In [1]:
import os, torch, logging
from datasets import load_dataset, load_from_disk
import transformers
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, HfArgumentParser, TrainingArguments, pipeline
from transformers import LongT5ForConditionalGeneration
from peft import LoraConfig, PeftModel, get_peft_model
from trl import SFTTrainer
from torch.utils.data import Dataset
import datasets
import pandas as pd
import json
from sklearn.model_selection import train_test_split

In [2]:
# Tokenizer of the Llama-2 7B chat model.
# NOTE(review): `max_length` and `truncate` are forwarded as extra keyword
# arguments; `truncate` does not look like a documented tokenizer option (the
# call-time flag is `truncation`) — confirm these have the intended effect.
tokenizer = AutoTokenizer.from_pretrained(
    "meta-llama/Llama-2-7b-chat-hf",
    use_fast=True,
    max_length=4000,
    truncate=True,
)
#model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-chat-hf")

# Pad on the right-hand side and reuse the end-of-sequence token as padding.
tokenizer.padding_side = "right"
tokenizer.pad_token = tokenizer.eos_token

# Data

In [3]:
def apply_prompt(txt, path, triggers):
    """Build one supervised training example for a single EMR file.

    Args:
        txt: File name of the EMR, relative to ``path``.
        path: Directory that contains the EMR text files.
        triggers: Already formatted trigger text placed after ``[/INST]``
            as the expected answer.

    Returns:
        The full prompt string: system instruction, the EMR content read
        from disk (UTF-8), and the expected triggers, closed by ``</s>``.

    NOTE(review): the reference Llama-2 chat format closes the system block
    with ``<</SYS>>``; this template repeats ``<<SYS>>``. The inference
    templates in this notebook use the same spelling, so change them together
    if this is corrected.
    """
    system_msg = "You're a doctor and you were given the following EMR by another doctor. You need to find some clinical triggers in this EMR that you think are important and need more infomation."

    emr_file = os.path.join(path, txt)
    with open(emr_file, "r", encoding="utf-8") as handle:
        emr_text = handle.read()

    # Adjacent literals are concatenated into a single template string.
    layout = (
        "<s>[INST] <<SYS>>\n{inst}\n<<SYS>>\n\n"
        "Please find the clinical triggers for the following EMR.\n\n"
        "EMR: \"{auxiliary}\"\n\n"
        "Clinical triggers found:\n[/INST]\n{trigs}\n</s>\n"
    )
    return layout.format(inst=system_msg, auxiliary=emr_text, trigs=triggers)

In [4]:
# Load the annotations: a mapping whose keys are later used as EMR file names
# and whose values are iterables of trigger strings.
with open("trigger_dataset_1k/dataset.json") as dataset_file:
    dataset = json.load(dataset_file)

# Render each record's triggers as a "- item" bullet list, one per line,
# after trimming surrounding spaces and periods from every trigger.
new_data = {
    name: "\n".join("- " + trig.strip(" .") for trig in trig_list)
    for name, trig_list in dataset.items()
}

In [8]:
path = "trigger_dataset_1k/data"
# One training prompt per annotated EMR: the key names the file under `path`,
# the value is the bullet-formatted trigger text.
txts = [apply_prompt(name, path, trig_txt) for name, trig_txt in new_data.items()]

In [9]:
# Hold out 20% of the prompts for evaluation. The split is seeded so that
# re-running the notebook reproduces the same train/eval partition (and thus
# comparable loss curves) instead of a new random one on every execution.
train_txts, eval_txts = train_test_split(txts, test_size=0.2, random_state=42)

# Wrap each list in a single-column ("text") frame ...
train_data = pd.DataFrame.from_dict({'text': train_txts})
eval_data = pd.DataFrame.from_dict({'text': eval_txts})

# ... and convert to Hugging Face datasets, the format handed to the trainer.
train_dataset = datasets.Dataset.from_pandas(train_data)
eval_dataset = datasets.Dataset.from_pandas(eval_data)

In [10]:
# Sanity check: show the first training prompt exactly as the model will see it.
print(train_dataset[0]['text'])

<s>[INST] <<SYS>>
You're a doctor and you were given the following EMR by another doctor. You need to find some clinical triggers in this EMR that you think are important and need more infomation.
<<SYS>>

Please find the clinical triggers for the following EMR.

EMR: "Admission Date :
2019-06-25
Discharge Date :
2019-07-01
Date of Birth :
1950-04-19
Sex :
M
Service :
CARDIOTHORACIC
Allergies :
Penicillins / Norvasc / Verapamil
Attending : Christopher Q Thompson , M.D.
Chief Complaint :
preop w/ for knee surgery revealed cardiac dz , asymptomatic
Major Surgical or Invasive Procedure :
CABGx4 ( 06-26 )
History of Present Illness :
69yoM with OA having workup for knee replacements found to be in Afib , had stress test that was positive followed by cardiac catheterization which revealed severe 3VD .
Then referred for CABG
Past Medical History :
HTN
chol
AFib
OA needs bilat arthroplasty
CRI
Social History :
Lives with wife .
Remote Timothy , quit 1990 . + ETOH ( 1-2 drinks / Rivero )
Famil

# Training

In [11]:
# Base model to fine-tune. It is loaded here so the notebook runs top to bottom
# on a fresh kernel; before, `model` was only referenced by a commented-out
# line and this cell raised NameError unless the kernel still held old state.
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-chat-hf")

# LoRA adapter hyper-parameters for causal language modelling.
peft_config = LoraConfig(
    lora_alpha=16,
    lora_dropout=0.1,
    r=64,
    bias="none",
    task_type="CAUSAL_LM"
)

# Attach the adapter described by `peft_config` to the base model.
model.add_adapter(peft_config)

In [12]:
# Training configuration.
# `output_dir` now matches the directory the later cells read from and write to
# ("results/llama2/7b-trigger-5e": the train log and `checkpoint-55`), so a
# full re-run finds its own checkpoints. The f-string prefixes were dropped
# because the strings contain no placeholders.
training_args = TrainingArguments(
    output_dir="results/llama2/7b-trigger-5e",
    overwrite_output_dir=True,
    num_train_epochs=5,
    # 1 sequence per step x 16 accumulation steps = 16 sequences per optimizer update (per device).
    per_device_train_batch_size=1,
    per_device_eval_batch_size=1,
    warmup_steps=10,
    weight_decay=1e-6,
    learning_rate=2e-4,
    logging_dir="results/logs",
    # NOTE(review): much larger than the library default epsilon (1e-8) — confirm this is intentional.
    adam_epsilon=1e-3,
    max_grad_norm=1.0,
    # Evaluate, checkpoint and log once per epoch.
    evaluation_strategy="epoch",
    save_strategy="epoch",
    logging_strategy="epoch",
    disable_tqdm=False,
    gradient_accumulation_steps=16,
    #bf16=True,
)

In [13]:
# Maximum sequence length handed to the trainer.
# NOTE(review): the expected triggers sit at the very end of each prompt, so if
# longer examples are truncated to this length the answer may be cut off —
# confirm against the token lengths of the EMRs.
args_max_length = 1100

# Supervised fine-tuning on the "text" column of both datasets.
trainer = SFTTrainer(
    model=model,
    args=training_args,
    tokenizer=tokenizer,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    dataset_text_field="text",
    max_seq_length=args_max_length,
)

Map:   0%|          | 0/183 [00:00<?, ? examples/s]

Map:   0%|          | 0/46 [00:00<?, ? examples/s]

In [14]:
# Run fine-tuning; the returned value is kept so it can be written to disk afterwards.
train_results = trainer.train()

You're using a LlamaTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


Epoch,Training Loss,Validation Loss
0,2.3527,2.235887
1,2.3337,2.215906
2,2.1202,2.197095
3,2.2966,2.184992
4,2.0214,2.180815




In [56]:
# Persist the value returned by `trainer.train()` in the run directory.
# The directory is created first: opening a file for writing raises
# FileNotFoundError when its parent directory does not exist.
train_log_path = "results/llama2/7b-trigger-5e/train_log.json"
os.makedirs(os.path.dirname(train_log_path), exist_ok=True)
with open(train_log_path, "w", encoding="utf-8") as f:
    json.dump(train_results, f)

# Data Processing

In [3]:
# Inference setup: reload the weights saved at `checkpoint-55` and wrap them,
# together with the tokenizer, in a text-generation pipeline on the GPU.
device = torch.device("cuda")
model = AutoModelForCausalLM.from_pretrained("results/llama2/7b-trigger-5e/checkpoint-55")

text_gen = pipeline(
    task="text-generation",
    model=model,
    tokenizer=tokenizer,
    device=device,
    # NOTE(review): presumably bounds prompt plus generated tokens — confirm long EMRs still fit.
    max_length=4000,
)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [6]:
# Paths of all entries in the 3k7 data directory, skipping any whose name
# contains "ipynb".
txts_list = [
    "trigger_dataset_3k7/data/" + entry
    for entry in os.listdir("trigger_dataset_3k7/data")
    if "ipynb" not in entry
]
txts_list

['trigger_dataset_3k7/data/record-30.txt',
 'trigger_dataset_3k7/data/0109.txt',
 'trigger_dataset_3k7/data/record-55.txt',
 'trigger_dataset_3k7/data/record-28.txt',
 'trigger_dataset_3k7/data/0245.txt',
 'trigger_dataset_3k7/data/0313_p2.txt',
 'trigger_dataset_3k7/data/0301.txt',
 'trigger_dataset_3k7/data/0361.txt',
 'trigger_dataset_3k7/data/0005.txt',
 'trigger_dataset_3k7/data/0125.txt',
 'trigger_dataset_3k7/data/0466.txt',
 'trigger_dataset_3k7/data/record-176.txt',
 'trigger_dataset_3k7/data/0045.txt',
 'trigger_dataset_3k7/data/0149_p1.txt',
 'trigger_dataset_3k7/data/0181.txt',
 'trigger_dataset_3k7/data/0460.txt',
 'trigger_dataset_3k7/data/0309.txt',
 'trigger_dataset_3k7/data/record-140.txt',
 'trigger_dataset_3k7/data/record-175.txt',
 'trigger_dataset_3k7/data/0061_p2.txt',
 'trigger_dataset_3k7/data/record-141_p1.txt',
 'trigger_dataset_3k7/data/record-84.txt',
 'trigger_dataset_3k7/data/0001.txt',
 'trigger_dataset_3k7/data/0033.txt',
 'trigger_dataset_3k7/data/recor

In [7]:
# Inference-time system instruction and prompt template. Unlike the training
# prompt, the template stops right after "[/INST]" so the model writes the
# triggers itself.
inst = "You're a doctor and you were given the following EMR by another doctor. You need to find 5 clinical triggers in this EMR that you think are the most important and need more infomation."
template = "<s>[INST] <<SYS>>\n{inst}\n<<SYS>>\n\nPlease find the clinical triggers for the following EMR.\n\nEMR: \"{auxiliary}\"\n\nClinical triggers found:\n[/INST]"

# Fill the template once per EMR file.
prompts = []
for emr_path in txts_list:
    with open(emr_path, "r", encoding="utf-8") as emr_file:
        prompts.append(template.format(inst=inst, auxiliary=emr_file.read()))

# Show the first prompt as a sanity check.
print(prompts[0])

<s>[INST] <<SYS>>
You're a doctor and you were given the following EMR by another doctor. You need to find 5 clinical triggers in this EMR that you think are the most important and need more infomation.
<<SYS>>

Please find the clinical triggers for the following EMR.

EMR: "Admission Date :
2015-09-16
Discharge Date :
2015-09-21
Date of Birth :
1962-09-26
Sex :
M
Service :
CARDIOTHOR
REASON FOR ADMISSION :
Mitral valve repair .
HISTORY OF PRESENT ILLNESS :
The patient is a 52 year old Czech speaking gentleman with a history of severe mitral regurgitation .
For the past several years , the patient has continued to have significant exercise intolerance which has limited his ability to work .
He reports severe dyspnea on exertion .
The patient had a recent admission earlier in the month and he was discharged on 2015-09-10 , after undergoing cardiac catheterization .
This was significant for demonstrating no flow-limited disease in the coronaries but several fistulae
from the left anterio

In [4]:
from tqdm import tqdm

# Read the PDF listing, one entry per line.
with open("mimic-s2orc/pdf_list.txt", "r") as listing:
    pdfs_list = listing.readlines()

# Turn every line that mentions ".pdf" into a directory path: strip surrounding
# whitespace, remove the ".pdf" text and prefix the dataset root.
dirs_list = []
for entry in pdfs_list:
    if ".pdf" not in entry:
        continue
    dirs_list.append("mimic-s2orc/" + entry.strip().replace(".pdf", ""))

# Inference-time system instruction and prompt template (ends at "[/INST]").
inst = "You're a doctor and you were given the following EMR by another doctor. You need to find 5 clinical triggers in this EMR that you think are the most important and need more infomation."
template = "<s>[INST] <<SYS>>\n{inst}\n<<SYS>>\n\nPlease find the clinical triggers for the following EMR.\n\nEMR: \"{auxiliary}\"\n\nClinical triggers found:\n[/INST]"

In [5]:
# Preview the two directories (indices 53 and 54) that the generation loop in the next cell is restricted to.
dirs_list[53:55]

['mimic-s2orc/pdf/record-25/5849/26804946',
 'mimic-s2orc/pdf/record-25/5849/31443997']

In [None]:
def _num_files_to_process(num_files):
    """Return how many of a directory's text files are sent to the model.

    Fewer than 4 files: all of them; 4-6 files: all but two; 7-9 files: all
    but three; 10 or more: exactly 7.
    """
    if num_files < 4:
        return num_files
    if num_files < 7:
        return num_files - 2
    if num_files < 10:
        return num_files - 3
    return 7


# Generate triggers for a subset of the text files in each selected directory
# and write the pipeline output to a mirrored "pdf_trigger" directory tree.
for dir_path in tqdm(dirs_list[53:55]):
    out_dir = dir_path.replace("/pdf/", "/pdf_trigger/")
    os.makedirs(out_dir, exist_ok=True)

    # `os.listdir` returns entries in arbitrary order, so sort them: otherwise
    # the `[:n]` slice below would pick a different subset from run to run.
    # `endswith` keeps real ".txt" files only, not names that merely contain ".txt".
    txts = sorted(name for name in os.listdir(dir_path) if name.endswith(".txt"))
    n = _num_files_to_process(len(txts))

    for txt in txts[:n]:
        with open(os.path.join(dir_path, txt), "r", encoding="utf-8") as f:
            content = f.read()
        prompt = template.format(inst=inst, auxiliary=content)
        out = text_gen(prompt)
        # Store the generated text of the first returned sequence under the same file name.
        out_path = os.path.join(out_dir, txt)
        with open(out_path, "w", encoding="utf-8") as f:
            f.write(out[0]['generated_text'])