## Import Packages

In [1]:
import pandas as pd
import transformers
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModelForSequenceClassification
import os, sys, re, torch, json, glob, argparse, gc, ast, pickle, requests
from sentence_transformers import SentenceTransformer
import numpy as np
from tqdm import tqdm
from ast import literal_eval
from sentence_transformers import SentenceTransformer, util
from tokenizers import AddedToken
from peft import PeftModel, PeftConfig

from torch.utils.data import Dataset, DataLoader

from scripts.formatting_results import *
from scripts.bert_filtering import *
from scripts.negation import *
from scripts.prompting import *
from scripts.utils import *
from scripts.llama_vision_engine import *
# from scripts.llava_med_engine import *


# Pick the accelerator when one is visible, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Drop unreachable Python objects, then release cached CUDA blocks.
gc.collect()
torch.cuda.empty_cache()


## Define Helper Functions

In [2]:
def shard_dict(merged_output, index, num_shards=30):
    """
    Return one shard of a dictionary using stable modulo sharding.

    Keys are sorted, and the key at sorted position ``p`` belongs to shard
    ``p % num_shards``. Every key therefore lands in exactly one shard, and
    the union of shards ``0 .. num_shards - 1`` is the whole dictionary.

    Args:
        merged_output: Dictionary to split; its keys must be mutually sortable.
        index: Zero-based shard to return, ``0 <= index < num_shards``.
        num_shards: Total number of shards.

    Returns:
        A new dict holding only the entries assigned to shard ``index``.

    Raises:
        ValueError: If ``index`` is outside ``[0, num_shards)``.
    """
    # Explicit raise instead of `assert`, which is stripped under `python -O`.
    if not 0 <= index < num_shards:
        raise ValueError(
            f"index must satisfy 0 <= index < num_shards, got index={index}, num_shards={num_shards}"
        )

    keys = sorted(merged_output)  # deterministic order

    # A stride slice selects exactly the positions p with p % num_shards == index.
    return {k: merged_output[k] for k in keys[index::num_shards]}


class PhenoGPT2Dataset(Dataset):
    """
    Thin map-style Dataset over the input dictionary.

    It exists only so a DataLoader can feed records to the processing loop;
    it does not transform the data. Item ``idx`` is the pair
    ``(key, data_input[key])`` for the ``idx``-th key in the dictionary's
    iteration order at construction time.
    """

    def __init__(self, data_input):
        # Keep the mapping itself plus a frozen, indexable list of its keys.
        self.data_input = data_input
        self.keys = [*data_input.keys()]

    def __len__(self):
        """Number of records (one per key)."""
        return len(self.keys)

    def __getitem__(self, idx):
        """Return ``(key, record)`` for position ``idx``."""
        key = self.keys[idx]
        record = self.data_input[key]
        return key, record


def _collate_single(batch):
    # batch is a list of length 1 (batch_size=1)
    return batch[0]


def build_llm(args):
    """
    Load the PhenoGPT2 causal language model and its tokenizer.

    Args:
        args: Namespace providing
            ``model_dir`` - checkpoint directory; when empty,
                ``<cwd>/models/phenogpt2`` is used,
            ``lora`` - truthy when ``model_dir`` holds PEFT/LoRA adapter
                weights that must be applied on top of a base model,
            ``attn_implementation`` - forwarded to ``from_pretrained``.

    Returns:
        ``(model, tokenizer)`` with the model switched to eval mode. For
        models whose ``config.model_type`` is ``'llama'`` the tokenizer's
        chat template is replaced by the Llama-3 style header template below.
    """
    # Model location
    model_id = args.model_dir if args.model_dir else os.path.join(os.getcwd(), 'models', 'phenogpt2')

    # Identical loading options for the plain and the LoRA code paths.
    load_kwargs = dict(
        dtype=torch.bfloat16,
        device_map="auto",
        attn_implementation=args.attn_implementation,
    )

    if args.lora:
        peft_config = PeftConfig.from_pretrained(model_id)

        # Directory containing this file; `__file__` is not defined when the
        # code runs inside a notebook/REPL, so fall back to the working directory.
        try:
            project_root = os.path.dirname(os.path.abspath(__file__))
        except NameError:
            project_root = os.getcwd()

        # Local fallback base model
        hpo_aware_pretrain_dir = os.path.join(project_root, "models", "hpo_aware_pretrain")

        # Use the base model recorded in the adapter config when that path
        # exists locally. A checkpoint is a directory, so `os.path.isfile`
        # could never match it; `os.path.exists` is the intended test.
        recorded_base = peft_config.base_model_name_or_path
        if recorded_base and os.path.exists(recorded_base):
            base_model_name = recorded_base
        else:
            base_model_name = hpo_aware_pretrain_dir

        model = AutoModelForCausalLM.from_pretrained(base_model_name, **load_kwargs)
        model = PeftModel.from_pretrained(model, model_id)
    else:
        model = AutoModelForCausalLM.from_pretrained(model_id, **load_kwargs)

    # Tokenizer
    tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast = True)
    model.eval()

    config = model.config
    if config.model_type == 'llama':
        tokenizer.chat_template = """{% for message in messages %}
    {% if message['role'] == 'system' %}
    <|start_header_id|>system<|end_header_id|>
    {{ message['content'] }}<|eot_id|>
    {% elif message['role'] == 'user' %}
    <|start_header_id|>user<|end_header_id|>
    {{ message['content'] }}<|eot_id|>
    {% elif message['role'] == 'assistant' %}
    <|start_header_id|>assistant<|end_header_id|>
    {{ message['content'] }}<|eot_id|>
    {% endif %}
    {% endfor %}
    {% if add_generation_prompt %}
    <|start_header_id|>assistant<|end_header_id|>
    {% endif %}
    """

    return model, tokenizer


def build_negation(args):
    """
    Load the components of the negation-filtering step.

    When ``args.negation`` is falsy nothing is loaded and three ``None``
    values are returned. Otherwise the tokenizer and causal LM named by
    ``args.negation_model`` are loaded (bfloat16, ``device_map="auto"``,
    ``args.attn_implementation``) together with the
    ``Qwen/Qwen3-Embedding-0.6B`` sentence-embedding model.

    Returns:
        ``(negation_model, negation_tokenizer, emb_model)``
    """
    # Guard clause: negation disabled -> nothing to load.
    if not args.negation:
        return None, None, None

    neg_tok = AutoTokenizer.from_pretrained(args.negation_model, use_fast = True)
    neg_llm = AutoModelForCausalLM.from_pretrained(
        args.negation_model,
        dtype=torch.bfloat16,
        device_map="auto",
        attn_implementation= args.attn_implementation
    )
    embedder = SentenceTransformer("Qwen/Qwen3-Embedding-0.6B")

    return neg_llm, neg_tok, embedder


def infer_modes(args, data_input):
    """
    Decide which modalities to process.

    ``args.text_only`` wins over ``args.vision_only``. When neither flag is
    set, the records are scanned: text is enabled as soon as one record has
    a non-null ``'clinical_note'`` and vision as soon as one record has a
    non-null ``'image'``.

    Returns:
        ``(use_text, use_vision)`` as booleans.
    """
    # Explicit flags short-circuit the scan.
    if args.text_only:
        return True, False
    if args.vision_only:
        return False, True

    has_text = False
    has_image = False
    for record in data_input.values():
        if pd.notnull(record.get('clinical_note')):
            has_text = True
        if pd.notnull(record.get('image')):
            has_image = True
        # Both modalities seen: the remaining records cannot change the answer.
        if has_text and has_image:
            break
    return has_text, has_image


def build_vision(args, use_vision):
    """
    Create the vision generator when image processing is enabled.

    Prints the value of ``use_vision``. When it is falsy, ``None`` is
    returned; otherwise a ``LLaMA_Generator`` is built from
    ``<cwd>/models/llama-vision-phenogpt2``.

    ``args`` is accepted for interface symmetry but is not read here, so
    ``args.vision`` currently has no effect on which vision model is loaded.
    """
    print(f"use_vision: {use_vision}")
    if not use_vision:
        return None

    weights_dir = os.getcwd() + "/models/llama-vision-phenogpt2"
    return LLaMA_Generator(weights_dir)


def _is_missing(value):
    """Return True for ``None`` or a float NaN (``NaN != NaN``)."""
    return value is None or (isinstance(value, float) and value != value)


def process_one_item(index, dt, data_input, args, model, tokenizer, phenogpt2_vision,
                     use_text, use_vision, bert_tokenizer, bert_model,
                     negation, negation_model, negation_tokenizer, emb_model, wc):
    """
    Run phenotype extraction for one input record.

    Text path (``use_text`` and the record has a string ``'clinical_note'``):
    the lower-cased note is optionally split with ``chunking_documents``
    (``wc != 0``). When there is more than one chunk, each is classified with
    ``predict_label`` and only ``'INFORMATIVE'`` chunks are processed.
    Generation is attempted once and retried once with a higher temperature
    and token budget when the output is not valid JSON or has no phenotypes.
    After a failed retry the raw response is stored under
    ``'error_response'``. With ``negation`` enabled the result is
    post-processed by ``negation_detection`` / ``process_negation``.
    Results of several chunks are combined with ``merge_outputs``.

    Vision path (``use_vision`` and the record has an ``'image'``): the
    vision model's description is mapped to HPO ids by the language model.

    Args:
        index: Key of the record in ``data_input``.
        dt: The record itself (read for ``'image'``).
        data_input: Mapping of all records (read for the note and ids).
        args: Unused; kept for interface compatibility.
        model, tokenizer: Extraction LLM and tokenizer.
        phenogpt2_vision: Object exposing ``generate_descriptions`` or None.
        use_text, use_vision: Modalities enabled for this run.
        bert_tokenizer, bert_model: Chunk filter (used when ``wc != 0``).
        negation: Whether to run the negation step.
        negation_model, negation_tokenizer, emb_model: Negation components.
        wc: Word count per chunk; ``0`` processes the whole note at once.

    Returns:
        ``{'text': <dict>, 'image': <dict>}``; a modality that is disabled,
        absent from the record, or failed yields ``{}``.
    """
    result = {'text': {}, 'image': {}}

    # ---- text ----
    # A record may lack a note even when text mode is on (mixed datasets);
    # skip it instead of failing on `.lower()` of None/NaN.
    note = data_input[index].get('clinical_note') if use_text else None
    if isinstance(note, str):
        text = note.lower()
        if wc != 0:
            all_chunks = chunking_documents(text, bert_tokenizer, bert_model, word_count = wc)
        else:
            all_chunks = [text]
        temp_response = {}
        for para_id, chunk in enumerate(all_chunks):
            chunk = chunk.replace("'", "").replace('"', '')
            if len(all_chunks) > 1:
                pred_label = predict_label(bert_tokenizer, bert_model, {"text":chunk})
            else: # in case users only want to use the whole note for testing
                pred_label = 'INFORMATIVE'
            if pred_label != 'INFORMATIVE':
                continue

            # First attempt
            response = generate_output(model, tokenizer, chunk, temperature = 0.3, max_new_tokens = 3000, device = device)
            try:
                final_response, complete_check = valid_json(response)
                phenos = final_response.get("phenotypes", {})
                if not isinstance(phenos, dict) or len(phenos) == 0:
                    raise ValueError("Empty or invalid phenotype dict")
            except Exception:
                # Single retry with a higher temperature and larger budget.
                try:
                    response = generate_output(model, tokenizer, chunk, temperature=0.4, max_new_tokens=5000, device = device)
                    final_response, complete_check = valid_json(response)
                    phenos = final_response.get("phenotypes", {})
                    if not isinstance(phenos, dict) or len(phenos) == 0:
                        raise ValueError("Empty or invalid phenotype dict after retry")
                except Exception as e:
                    print(f"Error: {e}", flush = True)
                    # Keep the raw output so the failure can be inspected later.
                    final_response = {'error_response': response}
                    final_response['pid'] = data_input[index].get('pid', data_input[index].get('pmid', 'unknown'))
                    temp_response[para_id] = final_response
                    continue  # move to the next chunk

            if negation:
                print('Starting detecting negation')
                try:
                    negation_response = negation_detection(negation_model, negation_tokenizer, chunk, final_response, device = device, max_new_tokens = 10000)
                    final_response = process_negation(final_response, negation_response, complete_check, emb_model)
                except Exception as e:
                    # Best effort: keep the unfiltered extraction, but report
                    # the failure instead of swallowing it silently.
                    print(f"Negation error: {e}", flush = True)
                    final_response['filtered_phenotypes'] = {}
            else:
                final_response['filtered_phenotypes'] = {}
            temp_response[para_id] = final_response

        if len(temp_response) > 1: # several chunks produced output: merge them
            result['text'] = merge_outputs(temp_response)
        elif temp_response:
            result['text'] = next(iter(temp_response.values())) # single result
        # else: no chunk produced output -> keep {}

    # ---- vision ----
    image = dt.get('image') if use_vision else None
    if not _is_missing(image):
        vision_phenotypes = phenogpt2_vision.generate_descriptions(image)
        phen2hpo = generate_output(model, tokenizer, vision_phenotypes, temperature = 0.4, max_new_tokens = 1024, device = device)
        phen2hpo = "{'demographics': {'age': '" + phen2hpo
        try:
            parsed = valid_json(phen2hpo)
            # The text path unpacks valid_json() as (parsed, complete_check);
            # take the parsed part so `.get` is not called on the tuple.
            if isinstance(parsed, tuple):
                parsed = parsed[0]
            phenotypes = parsed.get("phenotypes", {})
            result['image'] = {phen: hpo_dict['HPO_ID'] for phen, hpo_dict in phenotypes.items()}
        except Exception as e:
            print(f"Error: {e}", flush = True)
            result['image'] = {}

    return result

## Set Up Your Input here

In [3]:
## Input & Output Directory
input_dir = "/home/nguyenqm/projects/PhenoGPT2/testing/phenotagger/input/GSC+/"
output_dir = "/home/nguyenqm/projects/github/PhenoGPT2/phenogpt2_qwen3_ehr_8b_ft_nofilter/model/evaluations/GSC+_sample"

## Directory to the PhenoGPT2 fine-tuned weights
model_dir = "/home/nguyenqm/projects/github/PhenoGPT2/phenogpt2_qwen3_ehr_8b_ft_nofilter/model"

## Flash Attention gives faster inference and lower GPU memory use, but it may not work on ARM-based systems. Use "sdpa" or "eager" instead
attn_implementation='flash_attention_2'

## Set True when model_dir holds LoRA adapter weights (build_llm then loads them with PEFT on top of a base model); otherwise set False
lora=False

## Specify if you want to remove false positives with the negation step (highly recommended)
negation=True

## Specify the NEGATION model name
negation_model_name = "Qwen/Qwen3-4B-Instruct-2507"

## True: process clinical notes only. Set both text_only and vision_only to False to auto-detect the modalities from the input
text_only=True

## True: process images only (ignored when text_only is True)
vision_only=False

## Specify vision model: llama-vision, qwen-vision, and llava-med
## NOTE: build_vision currently always loads models/llama-vision-phenogpt2 and does not read this value
vision='llama-vision'

## Words per chunk when splitting a note (300 recommended for long clinical notes); 0 disables chunking and processes the whole note at once
wc = 0

## Run/replicate index, used in the output file name phenogpt2_rep{index}.pkl (mostly for SLURM JOB ARRAY)
index = 0

# Bundle the settings in a Namespace so the helper functions can read them as `args.<name>`.
args = argparse.Namespace(
    input=input_dir,
    output=output_dir,
    model_dir=model_dir,
    lora=lora,
    index=index,
    negation=negation,
    negation_model=negation_model_name,
    attn_implementation=attn_implementation,
    text_only=text_only,
    vision_only=vision_only,
    vision=vision,
    wc=wc
)

## Loading Models

In [4]:
## Load PhenoGPT2 and its tokenizer
model, tokenizer = build_llm(args)
## Load the negation model, its tokenizer and the embedding model (all None when negation is off)
negation_model, negation_tokenizer, emb_model = build_negation(args)

## The BERT filter for non-phenotypic chunks is only needed when notes are chunked (wc != 0)
wc = args.wc
bert_tokenizer, bert_model = (
    bert_init(local_dir = "./models/bert_filtering/") if wc != 0 else (None, None)
)

Loading weights:   0%|          | 0/399 [00:00<?, ?it/s]

Loading weights:   0%|          | 0/398 [00:00<?, ?it/s]

Loading weights:   0%|          | 0/310 [00:00<?, ?it/s]

## Start PhenoGPT2

In [5]:
print('start phenogpt2')
output_dir = args.output

print(output_dir)
# Create the output directory (and missing parents) on first use.
if not os.path.exists(output_dir):
    os.makedirs(output_dir, exist_ok = True)

## Process Input:
# presumably returns a dict of records keyed by record id (it is consumed via .keys()/.values() below) - verify against scripts.utils
data_input = read_input(args.input)

# Path of the pickle file the results of this run are written to
out_path = f"{args.output}/phenogpt2_rep{args.index}.pkl"
print(out_path, flush=True)

# Decide which modalities (text / image) this run processes
use_text, use_vision = infer_modes(args, data_input)

# Vision generator, or None when vision is disabled
phenogpt2_vision = build_vision(args, use_vision)

i = args.index
negation = args.negation
# Results accumulator: record key -> output of process_one_item
all_responses = {}

start phenogpt2
/home/nguyenqm/projects/github/PhenoGPT2/phenogpt2_qwen3_ehr_8b_ft_nofilter/model/evaluations/GSC+_sample
/home/nguyenqm/projects/github/PhenoGPT2/phenogpt2_qwen3_ehr_8b_ft_nofilter/model/evaluations/GSC+_sample/phenogpt2_rep0.pkl
use_vision: False


## Running PhenoGPT2

In [8]:
# ----------------------------
# DataLoader wrapper (prefetch)
# ----------------------------
dataset = PhenoGPT2Dataset(data_input)

# Leave one core for the main process; 0 workers means in-process loading.
num_workers = min(8, max(0, (os.cpu_count() or 4) - 1))

loader = DataLoader(
    dataset,
    # Must stay 1: _collate_single returns batch[0], so any larger batch size
    # would silently drop every record of a batch except the first.
    batch_size=1,
    shuffle=False,
    num_workers=num_workers,
    pin_memory=True,
    # Both options are only valid when worker processes are used, so they are
    # derived from the worker count actually requested.
    persistent_workers=num_workers > 0,
    collate_fn=_collate_single,
    prefetch_factor=4 if num_workers > 0 else None,
)

seen=0
for index, dt in tqdm(loader):
    # One record per iteration: `index` is its key, `dt` the record itself.
    result = process_one_item(
        index=index,
        dt=dt,
        data_input=data_input,
        args=args,
        model=model,
        tokenizer=tokenizer,
        phenogpt2_vision=phenogpt2_vision,
        use_text=use_text,
        use_vision=use_vision,
        bert_tokenizer=bert_tokenizer,
        bert_model=bert_model,
        negation=negation,
        negation_model=negation_model,
        negation_tokenizer=negation_tokenizer,
        emb_model=emb_model,
        wc=wc
    )
    all_responses[index] = result
#     if seen <= 10:
#         print(all_responses[index], flush=True)
    seen += 1

In [None]:
# Persist all collected results as one pickle file at out_path.
with open(out_path, 'wb') as handle:
    pickle.dump(all_responses, handle)