In [1]:
import pandas as pd
import transformers
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModelForSequenceClassification
import os, sys, re, torch, json, glob, argparse, gc, ast, pickle, requests
import numpy as np
from tqdm import tqdm
from ast import literal_eval
from sentence_transformers import SentenceTransformer, util
from tokenizers import AddedToken
from peft import PeftModel, PeftConfig
from scripts.formatting_results import *
from scripts.bert_filtering import *
from scripts.negation import *
from scripts.prompting import *
from scripts.utils import *
from scripts.llama_vision_engine import *
#from scripts.llava_med_engine import *
device = "cuda" if torch.cuda.is_available() else "cpu"
gc.collect()
torch.cuda.empty_cache()

In [2]:
## Provide the model directory here
#model_id = './models/phenogpt2/'
model_id = '/home/nguyenqm/projects/github/PhenoGPT2/phenogpt2_EHR_L318B_text_FPFF/model'
## Set True if your model is ft with LoRA, otherwise False
lora = False
## Provide the input dictionary file
input_dir = './data/example/text_examples.json'
#input_dir = './data/example/vision_examples.json'
data_input = read_input(input_dir)
## Replication (you can try to run three times for a sample then combine them but default to be 0)
i = 0
## Whether you want to remove negated findings (note that higher recall when negation = False)
negation = False
## If you don't want to split then keep 'wc' as 0; otherwise provide word size you want for each chunk
wc = 0
if wc != 0: 
    bert_tokenizer, bert_model = bert_init(local_dir = "./models/bert_filtering/")

In [3]:
# Determine processing mode
use_text = True
use_vision = False
vision_model = 'llama-vision'
###
# Vision model setup (only if vision is enabled)
print(f"use_vision: {use_vision}")
if use_vision:
    #base_ckpt = "/mnt/isilon/wang_lab/shared/LlaMA3.2-vision-instruct"
    phenogpt2_vision = LLaMA_Generator(os.getcwd() + "/models/llama-vision-phenogpt2/")#, base_ckpt)
#     vision_model = vision_model.lower() if vision_model else "llama-vision"
#     if vision_model == "llava-med":
#         phenogpt2_vision = LLaVA_Generator(os.path.join(os.getcwd(), "llava-med-phenogpt2"))
#     elif vision_model == "llama-vision":
#         phenogpt2_vision = LLaMA_Generator(os.path.join(os.getcwd(), "llama-vision-phenogpt2"))
#     else:
#         raise ValueError(f"Unsupported vision model '{vision_model}'. Use 'llava-med' or 'llama-vision'.")

use_vision: False


In [4]:
if lora:
    peft_config = PeftConfig.from_pretrained(model_id)
    base_model_name = peft_config.base_model_name_or_path or "./models/hpo_aware_pretrain/"
    model = AutoModelForCausalLM.from_pretrained(
        base_model_name,
        torch_dtype=torch.bfloat16,
        device_map="auto",
    )
    model = PeftModel.from_pretrained(model, model_id)
else: # either full finetuning or merged LoRA
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=torch.bfloat16,
        device_map="auto",
    )
tokenizer_id = model_id
tokenizer = AutoTokenizer.from_pretrained(tokenizer_id, use_fast = True)
model.eval()

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(146672, 4096, padding_idx=128256)
    (layers): ModuleList(
      (0-31): 32 x LlamaDecoderLayer(
        (self_attn): LlamaAttention(
          (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (k_proj): Linear(in_features=4096, out_features=1024, bias=False)
          (v_proj): Linear(in_features=4096, out_features=1024, bias=False)
          (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)
          (up_proj): Linear(in_features=4096, out_features=14336, bias=False)
          (down_proj): Linear(in_features=14336, out_features=4096, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm((4096,), eps=1e-05)
        (post_attention_layernorm): LlamaRMSNorm((4096,), eps=1e-05)
      )
    )
    (norm): LlamaRMSNorm((4096,), eps

In [6]:
output = 'example_testing'
output_dir = os.getcwd() + f"/data/results/{output}/"

if not os.path.exists(output_dir):
    os.makedirs(output_dir, exist_ok = True)
print(output_dir)

/mnt/isilon/wang_lab/quan/projects/github_official/PhenoGPT2/data/results/example_testing/


In [2]:
all_responses = {}
for index, dt in tqdm(data_input.items()):
    all_responses[index] = {}
    if use_text:
        text = data_input[index]['clinical_note'].lower()
        if wc != 0:
            all_chunks = chunking_documents(text, bert_tokenizer, bert_model, word_count = wc)
        else:
            all_chunks = [text]
        temp_response = {}
        for para_id, chunk in enumerate(all_chunks):
            if len(all_chunks) > 1:
                pred_label = predict_label(bert_tokenizer, bert_model, {"text":chunk})
            else: # in case users only want to use the whole note for testing
                pred_label = 'INFORMATIVE'
            if pred_label == 'INFORMATIVE':
                raw_response = generate_output(model, tokenizer, chunk, temperature = 0.4, max_new_tokens = 5000, device = device)
                response = "{'demographics': {'age': '" + raw_response
                # Try first attempt
                try:
                    final_response = fix_and_parse_json(response)
                    phenos = final_response.get("phenotypes", {})
                    if not isinstance(phenos, dict) or len(phenos) == 0:
                        raise ValueError("Empty or invalid phenotype dict")
                except Exception:
                    # Retry with alternative prompt
                    try:
                        raw_response = generate_output(model, tokenizer, chunk, temperature=0.4, max_new_tokens=6000, alternative_prompt=True, device = device)
                        response = "{'demographics': {'age': '" + raw_response
                        final_response = fix_and_parse_json(response)
                        phenos = final_response.get("phenotypes", {})
                        if not isinstance(phenos, dict) or len(phenos) == 0:
                            raise ValueError("Empty or invalid phenotype dict after retry")
                    except Exception:
                        final_response = {'error_response': response}
                        final_response['pid'] = data_input[index].get('pid', data_input[index].get('pmid', 'unknown'))
                        temp_response[para_id] = final_response
                        continue  # move to the next item
                example_removed = ['cleft palate', 'seizures', 'dev delay'] ## here are phenotypes in one-shot in alternative prompt
                temp_phenotypes = {k:v for k,v in final_response['phenotypes'].items() if (k not in example_removed) or (k in chunk)}
                final_response['phenotypes'] = temp_phenotypes
                if negation:
                    phenotypes = list(final_response['phenotypes'].keys())
                    phenotypes = [p.lower() for p in phenotypes]
                    positive_phenotypes = remove_negation(model, tokenizer, chunk, phenotypes, device = device)
                    try:
                        phen_dict = {x:y for x,y in final_response['phenotypes'].items() if x in positive_phenotypes and "HP:" in y['HPO_ID']}
                    except:
                        phen_dict = {x:y for x,y in final_response['phenotypes'].items() if x in positive_phenotypes}
                else:
                    phen_dict = {}
                final_response['filtered_phenotypes'] = phen_dict
                if 'pid' in data_input[index]:
                    final_response['pid'] = data_input[index]['pid']
                else:
                    final_response['pid'] = data_input[index]['pmid']
                if 'demographics' in final_response.keys():
                    if ('age' not in final_response['demographics'].keys()) or (final_response['demographics']['age'] == '10-year-old' and '10-year-old' not in chunk):
                        final_response['demographics']['age'] = 'unknown'
                    if 'sex' not in final_response['demographics'].keys():
                        final_response['demographics']['sex'] = 'unknown'
                    if 'ethnicity' not in final_response['demographics'].keys() or (final_response['demographics']['ethnicity'].lower() in ['vietnamese', 'vietnam'] and ('vietnamese' not in chunk or 'vietnam' not in chunk)):
                        final_response['demographics']['ethnicity'] = 'unknown'
                    if 'race' not in final_response['demographics'].keys():
                        final_response['demographics']['race'] = 'unknown'
                temp_response[para_id] = final_response
        all_responses[index]['text'] = merge_outputs(temp_response)
    else:
        all_responses[index]['text'] = {}
    if use_vision:
        vision_phenotypes = phenogpt2_vision.generate_descriptions(dt['image'])
        phen2hpo = generate_output(model, tokenizer, vision_phenotypes, temperature = 0.4, max_new_tokens = 1024, device = device)
        phen2hpo = "{'demographics': {'age': '" + phen2hpo
        phen2hpo = fix_and_parse_json(phen2hpo)
        phen2hpo = phen2hpo.get("phenotypes", {})
        try:
            phen2hpo = {phen:hpo_dict['HPO_ID'] for phen,hpo_dict in phen2hpo.items()}
        except:
            phen2hpo = {}
        all_responses[index]['image'] = phen2hpo
    else:
        all_responses[index]['image'] = {}

In [10]:
with open(f'{output_dir}phenogpt2_rep{i}.json', 'w') as f:
    json.dump(all_responses, f, indent = 2)