In [1]:
import os
from transformers import BertTokenizer, BertModel, AdamW, get_linear_schedule_with_warmup, BertConfig
from sentence_transformers import SentenceTransformer
import re
import nltk
import timeit
import os
import numpy as np
from numpy import dot
from numpy.linalg import norm
import json

In [2]:
# Sentence-embedding model used to encode trigger strings for the similarity search.
# The checkpoint is not a native sentence-transformers model, so the library wraps
# it with mean pooling (see the log message printed by this cell).
BIOBERT_CHECKPOINT = "dmis-lab/biobert-base-cased-v1.2"
model = SentenceTransformer(BIOBERT_CHECKPOINT)

No sentence-transformers model found with name /ukp-storage-1/pham/.cache/torch/sentence_transformers/dmis-lab_biobert-base-cased-v1.2. Creating a new one with MEAN pooling.


In [3]:
# Walk the PDF trigger tree and keep every file whose full path contains
# "_triggers.txt"; the order is whatever os.walk yields.
txts = [
    candidate
    for folder, _subfolders, filenames in os.walk("trigger_dataset_3k7/pdf_trigger")
    for candidate in (os.path.join(folder, fname) for fname in filenames)
    if "_triggers.txt" in candidate
]
txts

['trigger_dataset_3k7/pdf_trigger/0309/4210/36258995/36258995_0_triggers.txt',
 'trigger_dataset_3k7/pdf_trigger/0309/4210/36258995/36258995_2_triggers.txt',
 'trigger_dataset_3k7/pdf_trigger/0309/4210/35295623/35295623_1_triggers.txt',
 'trigger_dataset_3k7/pdf_trigger/0309/4210/35295623/35295623_0_triggers.txt',
 'trigger_dataset_3k7/pdf_trigger/0309/4210/29763019/29763019_1_triggers.txt',
 'trigger_dataset_3k7/pdf_trigger/0309/4210/29763019/29763019_0_triggers.txt',
 'trigger_dataset_3k7/pdf_trigger/0309/42731/33279945/33279945_1_triggers.txt',
 'trigger_dataset_3k7/pdf_trigger/0309/42731/33279945/33279945_3_triggers.txt',
 'trigger_dataset_3k7/pdf_trigger/0309/42731/33127438/33127438_1_triggers.txt',
 'trigger_dataset_3k7/pdf_trigger/0309/42731/33127438/33127438_0_triggers.txt',
 'trigger_dataset_3k7/pdf_trigger/0309/42731/29627355/29627355_1_triggers.txt',
 'trigger_dataset_3k7/pdf_trigger/0309/42731/29627355/29627355_0_triggers.txt',
 'trigger_dataset_3k7/pdf_trigger/0309/7907/89

In [4]:
def process_trigger_text(txt):
    """Read a trigger file and split every line into a trigger and its context.

    Each non-blank line is split at its FIRST colon only, so a context that
    itself contains colons (times, ratios, ...) is kept intact. The first two
    characters of the trigger part are dropped (presumably a list marker such
    as "- " -- confirm against the data files).

    Args:
        txt: Path of the UTF-8 encoded trigger file.

    Returns:
        A ``(triggers, contexts)`` tuple of two parallel lists of stripped
        strings. Blank lines are skipped; a line without a colon yields an
        empty context.
    """
    triggers = []
    contexts = []
    with open(txt, "r", encoding="utf-8") as f:
        for raw_line in f:
            line = raw_line.strip()
            if not line:
                # Blank lines carry no trigger.
                continue
            # partition never raises: a missing colon leaves the context empty.
            trigger_part, _sep, context_part = line.partition(":")
            triggers.append(trigger_part[2:].strip())
            contexts.append(context_part.strip())
    return triggers, contexts

In [5]:
# EHR-side trigger files: every entry of the triggers directory whose name
# contains the substring "txt".
txts_2 = list(
    filter(lambda entry: "txt" in entry, os.listdir("trigger_dataset_3k7/triggers"))
)
txts_2

['0125.txt',
 '0393_p2.txt',
 'record-26.txt',
 '0245.txt',
 'record-29.txt',
 'record-18.txt',
 'record-52.txt',
 '0466.txt',
 '0177.txt',
 'record-27.txt',
 '0113.txt',
 '0165.txt',
 'record-48.txt',
 '0065.txt',
 '0365.txt',
 'record-124.txt',
 '0313_p1.txt',
 '0421.txt',
 '0397.txt',
 'record-142.txt',
 '0389.txt',
 '0273.txt',
 '0129.txt',
 '0061_p2.txt',
 'record-106.txt',
 '0121.txt',
 '0189.txt',
 '0061_p1.txt',
 '0401.txt',
 '0442.txt',
 '0393_p1.txt',
 'record-13.txt',
 '0093.txt',
 '0293.txt',
 '0041.txt',
 '0427.txt',
 'record-68.txt',
 '0221.txt',
 '0285.txt',
 'record-55.txt',
 '0005.txt',
 '0045.txt',
 '0021.txt',
 'record-123.txt',
 '0325.txt',
 'record-176.txt',
 '0433.txt',
 '0237_p2.txt',
 'record-84.txt',
 '0033.txt',
 'record-19.txt',
 '0412.txt',
 'record-73.txt',
 'record-141_p2.txt',
 '0097.txt',
 '0381.txt',
 'record-37_p1.txt',
 '0149_p2.txt',
 'record-32.txt',
 'record-140.txt',
 '0053.txt',
 'record-108.txt',
 'record-82.txt',
 'record-17.txt',
 'record-141_

In [7]:
def cosine_sim(a, b):
    """Return the cosine similarity of ``a`` and ``b``.

    Computed as the product ``a @ b.T`` divided by the product of the two norms.
    """
    numerator = a @ b.T
    denominator = norm(a) * norm(b)
    return numerator / denominator
    
def _read_trigger_lines(path):
    """Return the non-empty trigger strings stored one per line in ``path``."""
    with open(path, "r", encoding="utf-8") as f:
        # The first two characters of every line are dropped before stripping.
        candidates = [line[2:].strip() for line in f]
    # Blank lines would otherwise be embedded as empty strings.
    return [trig for trig in candidates if trig]


def get_similarity_score(ehr, trig_list):
    """Pick the EHR trigger that is, on average, most similar to the PDF triggers.

    All triggers are embedded with the module-level ``model``. For every EHR
    trigger the cosine similarity to every PDF trigger is computed; the EHR
    trigger with the highest mean similarity wins, and the PDF trigger closest
    to it is reported together with the file it came from.

    Args:
        ehr: Path of the EHR trigger file.
        trig_list: Paths of the PDF trigger files to compare against.

    Returns:
        ``(best_trig, relate_trig, pdf)``: the selected EHR trigger, its most
        similar PDF trigger, and the path from ``trig_list`` that contains
        that PDF trigger.

    Raises:
        ValueError: If either side provides no triggers to compare.
    """
    ehr_trigs = _read_trigger_lines(ehr)

    pdf_trigs = []
    pdf_sources = []
    for txt in trig_list:
        file_trigs = _read_trigger_lines(txt)
        pdf_trigs += file_trigs
        # Remember which file each trigger came from, so the lookup works for
        # any number of triggers per file (no fixed "5 per file" assumption).
        pdf_sources += [txt] * len(file_trigs)

    if not ehr_trigs or not pdf_trigs:
        raise ValueError(
            "No triggers to compare for {!r}: {} EHR trigger(s), {} PDF trigger(s)".format(
                ehr, len(ehr_trigs), len(pdf_trigs)
            )
        )

    ehr_embeds = np.asarray(model.encode(ehr_trigs, device="cuda"))
    pdf_embeds = np.asarray(model.encode(pdf_trigs, device="cuda"))

    # Full (n_ehr, n_pdf) cosine-similarity matrix in one shot.
    norms = np.outer(norm(ehr_embeds, axis=1), norm(pdf_embeds, axis=1))
    sims = (ehr_embeds @ pdf_embeds.T) / norms

    best_idx = int(np.argmax(sims.mean(axis=1)))
    pdf_idx = int(np.argmax(sims[best_idx]))

    best_trig = ehr_trigs[best_idx]
    relate_trig = pdf_trigs[pdf_idx]
    pdf = pdf_sources[pdf_idx]
    return best_trig, relate_trig, pdf

In [8]:
# For every EHR trigger file, collect the PDF trigger files stored under the
# directory named after the record, pick the best-matching trigger pair and
# save the result next to the trigger file as JSON.
for txt2 in txts_2:
    ehr_trig = "trigger_dataset_3k7/triggers/" + txt2
    if "_" in txt2:
        # Turn a trailing "_<part>.txt" into "/" so "0393_p2.txt" -> "/0393/".
        # The dot is escaped and the pattern anchored so that only a literal
        # ".txt" extension at the end of the name is matched.
        kw = "/" + re.sub(r"_[a-z0-9]+\.txt$", "/", txt2)
    else:
        kw = "/" + txt2.replace(".txt", "/")
    # Substring match on the directory component of the PDF trigger paths.
    trig_list = [txt for txt in txts if kw in txt]
    best_trig, relate_trig, pdf = get_similarity_score(ehr_trig, trig_list)
    trig_dict = {
        "trigger": best_trig,
        "related_trigger": relate_trig,
        "context_pdf": pdf,
    }
    path = "trigger_dataset_3k7/triggers/" + txt2.replace(".txt", ".json")
    with open(path, "w", encoding="utf-8") as f:
        json.dump(trig_dict, f)

# Generate EHR context

In [6]:
import transformers
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
import pandas as pd
import torch
import os
import json
from tqdm import tqdm

In [2]:
# Load the tokenizer and the causal-LM weights from the same chat checkpoint.
LLAMA_CHECKPOINT = "meta-llama/Llama-2-7b-chat-hf"
llama_tokenizer = AutoTokenizer.from_pretrained(LLAMA_CHECKPOINT, use_fast=True)
llama_model = AutoModelForCausalLM.from_pretrained(LLAMA_CHECKPOINT)

# Pad on the right, reusing the end-of-sequence token as the padding token.
llama_tokenizer.padding_side = "right"
llama_tokenizer.pad_token = llama_tokenizer.eos_token

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [3]:
# Put the model on the GPU in evaluation mode, then wrap model and tokenizer
# in a text-generation pipeline.
device = torch.device("cuda")
llama_model = llama_model.to(device)
llama_model.eval()

generation_settings = {
    "task": "text-generation",
    "model": llama_model,
    "tokenizer": llama_tokenizer,
    "max_length": 4000,  # passed to the pipeline as its max_length setting
    "device": device,
}
text_gen = pipeline(**generation_settings)

In [59]:
# System message for the EHR-context extraction step; {trigger} is filled per record.
inst = "You're a doctor and you were given the following EMR by another doctor. You think that {trigger} is important."
# Llama-2 chat prompt: the system message sits between "<<SYS>>" and the
# closing tag "<</SYS>>" (the closing tag was previously written as a second
# "<<SYS>>"). Placeholders: {inst}, {trigger}, {document}.
# NOTE(review): the literal "<s>" may duplicate the BOS token if the tokenizer
# also adds one -- confirm against the tokenizer settings.
template = "<s>[INST] <<SYS>>\n{inst}\n<</SYS>>\n\nGiven the following EMR, please find some content related to {trigger}.\n\nEMR: \"{document}\"\n\nContent found:\n[/INST]"

In [40]:
# Split the listing of the triggers directory by the substring found in each
# file name: "json" for match results, "txt" for trigger files.
files = os.listdir("trigger_dataset_3k7/triggers")
json_files = list(filter(lambda name: "json" in name, files))
txt_files = list(filter(lambda name: "txt" in name, files))

In [61]:
def get_context(path, max_length, head=True):
    """Return the longest run of whole lines from a file that fits a token budget.

    Lines are taken from the start of the file (``head=True``) or from its end
    (``head=False``), always in their original order. Each line is measured by
    the number of tokens ``llama_tokenizer.tokenize`` produces for the stripped
    line. Lines are added while the running total stays within ``max_length``.

    Both directions apply the budget the same way: if the very first candidate
    line already exceeds ``max_length``, the returned text is empty.

    Args:
        path: Path of the UTF-8 text file to read.
        max_length: Maximum total number of tokens of the returned lines.
        head: Take lines from the beginning when true, from the end otherwise.

    Returns:
        ``(text, path)``: the concatenated selected lines and the input path.
    """
    with open(path, "r", encoding="utf-8") as f:
        lines = f.readlines()

    # Scan from whichever end we are keeping.
    candidates = lines if head else lines[::-1]

    kept = []
    used_tokens = 0
    for line in candidates:
        line_tokens = len(llama_tokenizer.tokenize(line.strip()))
        if used_tokens + line_tokens > max_length:
            # Stop at the first line that does not fit; later lines are not
            # tokenized at all.
            break
        used_tokens += line_tokens
        kept.append(line)

    if not head:
        # Restore document order for the tail selection.
        kept.reverse()
    return "".join(kept), path

In [65]:
# For each matched trigger, prompt the model for EMR content related to the
# trigger and save the raw generation (prompt + completion) under ehr_context/.
for j in tqdm(json_files):
    json_path = "trigger_dataset_3k7/triggers/" + j
    txt_path = "trigger_dataset_3k7/data/" + j.replace(".json", ".txt")
    with open(json_path, "r", encoding="utf-8") as f:
        data = json.load(f)
    # Files whose name contains "p2" are read from the end of the document,
    # all others from the start. The file name lives in the loop variable `j`
    # (the name `t` used here before is never defined in this notebook).
    head = "p2" not in j
    doc, _ = get_context(txt_path, 3200, head=head)
    # Keep only the part of the stored trigger before its first colon.
    trig_1 = data["trigger"].split(":")[0].strip()
    instruct = inst.format(trigger=trig_1)
    prompt = template.format(inst=instruct, trigger=trig_1, document=doc)
    out = text_gen(prompt)
    out_path = txt_path.replace("/data/", "/ehr_context/")
    # Make sure the output directory exists before writing.
    os.makedirs(os.path.dirname(out_path), exist_ok=True)

    with open(out_path, "w", encoding="utf-8") as f:
        f.write(out[0]['generated_text'])

100%|█████████████████████████████████████████████████████████████████████████████████| 122/122 [42:43<00:00, 21.01s/it]


# Question Generation

In [66]:
# System message for the question-generation step; {trigger} and {auxiliary}
# are filled per record.
inst = "You're a doctor and you were given the following EMR by another doctor. You have some questions for the doctor who gave that EMR to you to get more details about the patient, \nYou have read an external text that contain additional information related to {trigger}: \n\"{auxiliary}\""
# Llama-2 chat prompt: the system message sits between "<<SYS>>" and the
# closing tag "<</SYS>>" (the closing tag was previously written as a second
# "<<SYS>>"). Placeholders: {inst}, {document}, {trigger}.
# NOTE(review): the literal "<s>" may duplicate the BOS token if the tokenizer
# also adds one -- confirm against the tokenizer settings.
template = "<s>[INST] <<SYS>>\n{inst}\n<</SYS>>\n\nGiven the context in the EMR: \n\"{document}\"\n\nAfter reading the above EMR, what question do you have about \"{trigger}\"?\nQuestion:\n[/INST]"

In [67]:
def get_text(path):
    """Return the model completion stored in a saved generation file.

    The file is expected to hold a prompt ending in ``[/INST]`` followed by the
    generated text. Everything after the FIRST ``[/INST]`` is returned, so a
    completion that itself contains the marker is not cut short.

    Args:
        path: Path of the UTF-8 text file holding prompt + completion.

    Returns:
        The text after the first ``[/INST]`` marker, stripped of surrounding
        whitespace.

    Raises:
        ValueError: If the file does not contain the ``[/INST]`` marker.
    """
    marker = "[/INST]"
    with open(path, "r", encoding="utf-8") as f:
        doc = f.read()
    _prompt, found, completion = doc.partition(marker)
    if not found:
        raise ValueError("No {!r} marker found in {!r}".format(marker, path))
    return completion.strip()

In [68]:
# Output directory for the generated questions, and the match files to process.
root = "results/llama2/7b-chat/generated_additional/"
files = os.listdir("trigger_dataset_3k7/triggers")
json_files = list(filter(lambda name: "json" in name, files))

# For each match file: build the question prompt from the trigger, the saved
# external context and the saved EHR context, then store the raw generation.
for j in tqdm(json_files):
    record_txt = j.replace(".json", ".txt")

    with open("trigger_dataset_3k7/triggers/" + j, "r", encoding="utf-8") as f:
        record = json.load(f)
    # Keep only the part of the stored trigger before its first colon.
    trigger = record["trigger"].split(":")[0].strip()

    auxiliary = get_text("trigger_dataset_3k7/external_context/" + record_txt)
    document = get_text("trigger_dataset_3k7/ehr_context/" + record_txt)

    prompt = template.format(
        inst=inst.format(trigger=trigger, auxiliary=auxiliary),
        trigger=trigger,
        document=document,
    )
    generation = text_gen(prompt)

    with open(root + record_txt, "w", encoding="utf-8") as f:
        f.write(generation[0]['generated_text'])

100%|█████████████████████████████████████████████████████████████████████████████████| 122/122 [28:53<00:00, 14.21s/it]
