In [27]:
import os
os.environ['CUDA_DEVICE_ORDER']='PCI_BUS_ID'
os.environ['CUDA_VISIBLE_DEVICES']='2,3'
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig
import pickle 
import string
from vllm import LLM, SamplingParams
from collections import OrderedDict
from sklearn.neighbors import NearestNeighbors
import pandas as pd
import gc
import re
import numpy as np
from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import accuracy_score, roc_auc_score
from sklearn.metrics import precision_score, recall_score, f1_score
import pprint
import openai
from transformers import BertTokenizer, BertModel
from tqdm import tqdm
import torch
import torch.nn as nn
import torch.optim as optim
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
import json

# Data

In [48]:
data = pd.read_csv("/home/srinivasb/snap/snapd-desktop-integration/current/Documents/headneck/manuscript/data/499_llama_labeled_H&N_patients_temporal_data.csv")
data = data[['MRN', 'pre_completion_notes', 'pre_completion_ED_flag_LLAMA',
             '0_3_months_notes', '0_3_months_ED_flag_LLAMA', '3_6_months_notes', '3_6_months_ED_flag_LLAMA']]
data.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 499 entries, 0 to 498
Data columns (total 7 columns):
 #   Column                        Non-Null Count  Dtype 
---  ------                        --------------  ----- 
 0   MRN                           499 non-null    int64 
 1   pre_completion_notes          499 non-null    object
 2   pre_completion_ED_flag_LLAMA  499 non-null    object
 3   0_3_months_notes              499 non-null    object
 4   0_3_months_ED_flag_LLAMA      499 non-null    object
 5   3_6_months_notes              499 non-null    object
 6   3_6_months_ED_flag_LLAMA      499 non-null    object
dtypes: int64(1), object(6)
memory usage: 27.4+ KB


In [54]:
import pandas as pd

# List of ED flag columns
ed_flag_cols = [
    # 'pre_completion_ED_flag_LLAMA',
    # '0_3_months_ED_flag_LLAMA',
    '3_6_months_ED_flag_LLAMA'
]

# Count rows where any of the ED flag columns has the value "ED visit"
rows_with_ed = data[ed_flag_cols].apply(lambda row: 'ED Visit' in row.values, axis=1)

# Number of such rows
num_rows = rows_with_ed.sum()
print(num_rows)

104


In [29]:
data.head()

Unnamed: 0,MRN,pre_completion_notes,pre_completion_ED_flag_LLAMA,0_3_months_notes,0_3_months_ED_flag_LLAMA,3_6_months_notes,3_6_months_ED_flag_LLAMA
0,193402,DEPARTMENT OF ...,ED Visit,Swallowing Therapy#7 swallow therapy: S: He ...,No ED Visit,OTOLARYNGOLOGY- HEAD AND NECK SURGERY FACIAL ...,No ED Visit
1,532380,Head and Neck Surgery - NEW PATIENT EVALUATION...,No ED Visit,INTERVAL HISTORY AND PHYSICAL and RISKS / BENE...,No ED Visit,Head and Neck Surgery - ESTABLISHED PATIENT VI...,No ED Visit
2,909785,"July 28, 2017 Trudell Mixon 301 Butte St. ...",ED Visit,Chief Complaint: SCC of lower gingiva Histo...,No ED Visit,Chief Complaint: SCC of lower gingiva Histo...,No ED Visit
3,1133331,Immunization Data from STOR which was last u...,No ED Visit,"Primary Diagnosis: MIBC, prior malignancies (...",No ED Visit,"Coddington, Steven Douglas 80 year old...",No ED Visit
4,1597647,Subjective Chief Complaint Patient presents...,ED Visit,INTERVAL HISTORY AND PHYSICAL and RISKS / BENE...,ED Visit,4/2/2013 CC: diarrhea and recent colitis ...,ED Visit


In [30]:
with open('/home/srinivasb/snap/snapd-desktop-integration/current/Documents/headneck/manuscript/data/embeddings/pre_completion_notes_embeddings.pkl', 'rb') as f:
    pre_completion_embeddings = pickle.load(f)

with open('/home/srinivasb/snap/snapd-desktop-integration/current/Documents/headneck/manuscript/data/embeddings/0_3_months_notes_embeddings.pkl', 'rb') as f:
    zero_3_months_embeddings = pickle.load(f)

with open('/home/srinivasb/snap/snapd-desktop-integration/current/Documents/headneck/manuscript/data/embeddings/3_6_months_notes_embeddings.pkl', 'rb') as f:
    three_6_months_embeddings = pickle.load(f)

data['pre_completion_notes_embedding'] = list(pre_completion_embeddings)
data['0_3_months_notes_embedding'] = list(zero_3_months_embeddings)
data['3_6_months_notes_embedding'] = list(three_6_months_embeddings)
X_pre = np.stack(data['pre_completion_notes_embedding'].values)
X_0_3 = np.stack(data['0_3_months_notes_embedding'].values)
X_3_6 = np.stack(data['3_6_months_notes_embedding'].values)

X = np.stack([X_pre, X_0_3, X_3_6], axis=1)

y = np.array([
    data['pre_completion_ED_flag_LLAMA'].apply(lambda x: 0 if x == "ED Visit" else 1),
    data['0_3_months_ED_flag_LLAMA'].apply(lambda x: 0 if x == "ED Visit" else 1),
    data['3_6_months_ED_flag_LLAMA'].apply(lambda x: 0 if x == "ED Visit" else 1),
]).T

In [31]:
# Remove rows where all values in y are 0
rows_with_zeros = np.all(y == 0, axis=1)

indices_to_remove = np.where(rows_with_zeros)[0]

num_to_remove = int(len(indices_to_remove) * 1.0)
random_indices_to_remove = np.random.choice(indices_to_remove, size=num_to_remove, replace=False)

X = np.delete(X, random_indices_to_remove, axis=0)
y = np.delete(y, random_indices_to_remove, axis=0)
X.shape, y.shape

((481, 3, 1024), (481, 3))

In [32]:
data = data.drop(index=random_indices_to_remove).reset_index(drop=True)
data.to_csv('data/481_H&N_patients_temporal_data.csv', index=False)

# Each TimeStep

In [33]:
data['pre_completion_ED_flag_LLAMA'] = data['pre_completion_ED_flag_LLAMA'].apply(lambda x: 0 if x == "ED Visit" else 1)
data['0_3_months_ED_flag_LLAMA'] = data['0_3_months_ED_flag_LLAMA'].apply(lambda x: 0 if x == "ED Visit" else 1)
data['3_6_months_ED_flag_LLAMA'] = data['3_6_months_ED_flag_LLAMA'].apply(lambda x: 0 if x == "ED Visit" else 1)

In [8]:
llm = LLM("gradientai/Llama-3-8B-Instruct-262k", tensor_parallel_size=2, download_dir="/data/users/srinivasb/vllm-llama3-download-dir")
sampling_params = SamplingParams(temperature=0, max_tokens=20, logprobs=10)

INFO 04-22 12:56:38 config.py:890] Defaulting to use mp for distributed inference
INFO 04-22 12:56:38 config.py:999] Chunked prefill is enabled with max_num_batched_tokens=512.
INFO 04-22 12:56:38 llm_engine.py:213] Initializing an LLM engine (v0.6.0) with config: model='gradientai/Llama-3-8B-Instruct-262k', speculative_config=None, tokenizer='gradientai/Llama-3-8B-Instruct-262k', skip_tokenizer_init=False, tokenizer_mode=auto, revision=None, override_neuron_config=None, rope_scaling=None, rope_theta=None, tokenizer_revision=None, trust_remote_code=False, dtype=torch.bfloat16, max_seq_len=262144, download_dir='/data/users/srinivasb/vllm-llama3-download-dir', load_format=LoadFormat.AUTO, tensor_parallel_size=2, pipeline_parallel_size=1, disable_custom_all_reduce=False, quantization=None, enforce_eager=False, kv_cache_dtype=auto, quantization_param_path=None, device_config=cuda, decoding_config=DecodingConfig(guided_decoding_backend='outlines'), observability_config=ObservabilityConfig(o

Loading safetensors checkpoint shards:   0% Completed | 0/4 [00:00<?, ?it/s]


INFO 04-22 12:56:46 model_runner.py:926] Loading model weights took 7.5435 GB
[1;36m(VllmWorkerProcess pid=1649746)[0;0m INFO 04-22 12:56:46 model_runner.py:926] Loading model weights took 7.5435 GB
INFO 04-22 12:56:48 distributed_gpu_executor.py:57] # GPU blocks: 34270, # CPU blocks: 4096
[1;36m(VllmWorkerProcess pid=1649746)[0;0m INFO 04-22 12:56:54 model_runner.py:1217] Capturing the model for CUDA graphs. This may lead to unexpected consequences if the model is not static. To run the model in eager mode, set 'enforce_eager=True' or use '--enforce-eager' in the CLI.
[1;36m(VllmWorkerProcess pid=1649746)[0;0m INFO 04-22 12:56:54 model_runner.py:1221] CUDA graphs can take additional 1~3 GiB memory per GPU. If you are running out of memory, consider decreasing `gpu_memory_utilization` or enforcing eager mode. You can also reduce the `max_num_seqs` as needed to decrease memory usage.
INFO 04-22 12:56:54 model_runner.py:1217] Capturing the model for CUDA graphs. This may lead to un

# Timepoint 0
---

In [9]:
X = np.array(data['pre_completion_notes'].tolist())
y = data['pre_completion_ED_flag_LLAMA'].values

In [10]:
X.shape, y.shape

((481,), (481,))

In [11]:
from sklearn.model_selection import KFold

skf = KFold(n_splits=5, shuffle=True, random_state=42)

folds = []

for fold, (train_idx, val_idx) in enumerate(skf.split(X)):
    X_tr, X_val = X[train_idx], X[val_idx]
    y_tr, y_val = y[train_idx], y[val_idx]
    
    folds.append({
        'fold': fold,
        'X_train': X_tr,
        'y_train': y_tr,
        'X_val': X_val,
        'y_val': y_val
    })

In [None]:
results_folds = []
auc_data_folds = []
y_val_folds_cleaned = []

max_words = 80000
indices_to_remove = []

for i, fold in enumerate(folds):
    X_val = fold['X_val']
    y_val = fold['y_val']
    
    print(f"Fold {i}")
    
    results = []
    auc_data = []
    cleaned_y_val = []
    
    for i, temp in enumerate(X_val):
        # Pre-filter long notes
        if isinstance(temp, str) and len(temp.split()) > max_words:
            continue  # Skip long notes
        
        # Append valid label
        cleaned_y_val.append(y_val[i])
        
        # Construct the prompt
        text_prompt = '''<|start_header_id|>user<|end_header_id|>You are an oncologist at a major cancer hospital, tasked with predicting hospital emergency department (ED) visits for patients.  
        I am going to provide you with clinical notes for a head and neck cancer patient collected until primary treatment completion date. Here are the notes: ''' 
        text_prompt += str(temp) 
        text_prompt += '''\nPlease analyze notes carefully. Based on this analysis, will the patient have an ED visit to the hospital? Respond with either 'POSITIVE' (if the patient is likely to have ED visit) or 'NEGATIVE' (if the patient is unlikely to have ED visit). Please respond with 'POSITIVE' or 'NEGATIVE' only.''' 
        text_prompt += "<|eot_id|><|start_header_id|>ANSWER: " 

        torch.cuda.empty_cache()
        output = llm.generate(text_prompt, sampling_params)
        del text_prompt
        
        if output and output[0].outputs:
            res = output[0].outputs[0].text
            print(output[0].outputs[0])
            results.append(res)
        else:
            print("Error: LLM output is empty or improperly structured.", i)

        # AUC elements ***********************************************
        correct_answer_token = output[0].outputs[0].token_ids[0]
        wrong_answer_tokens_func = lambda correct_answer_tokens: 27592 if correct_answer_tokens == 85165 else 85165
        wrong_answer_token = wrong_answer_tokens_func(correct_answer_token)

        # Logit for the wrong answer
        all_logprobs = output[0].outputs[0].logprobs
        for logprob_dict in all_logprobs:
            if wrong_answer_token in logprob_dict:
                wrong_answer_logit = logprob_dict[wrong_answer_token].logprob
            if correct_answer_token in logprob_dict:
                correct_answer_logit = logprob_dict[correct_answer_token].logprob

        new_entry = {'correct_logit': correct_answer_logit, 'wrong_logit': wrong_answer_logit}
        auc_data.append(new_entry)
        # ************************************************************

    # Append results and AUC data for this fold
    results_folds.append(results)
    auc_data_folds.append(auc_data)
    y_val_folds_cleaned.append(cleaned_y_val)


In [None]:
# Metrics for each fold
accuracies = []
precisions = []
recalls = []
f1s = []
aucs = []

for fold_idx in range(5):
    preds = []
    results = results_folds[fold_idx]
    labels = y_val_folds_cleaned[fold_idx]
    auc_data = auc_data_folds[fold_idx]

    df_test = pd.DataFrame()
    df_test['label'] = labels[:len(results)]

    for num, i in enumerate(results):
        if 'POS' in i:
            preds.append(1)
        elif 'NEG' in i:
            preds.append(0)
        else:
            preds.append(2)
            print(f"⚠️ Fold {fold_idx} – Uncertain prediction at index {num}")

    df_test['prediction'] = preds

    # Filter out uncertain predictions
    df_new = df_test[df_test['prediction'] != 2].reset_index(drop=True)

    # Skip this fold if no valid predictions
    if len(df_new) == 0:
        print(f"⚠️ Skipping fold {fold_idx}: No valid predictions.")
        continue

    # Classification metrics
    acc = accuracy_score(df_new['label'], df_new['prediction'])
    prec = precision_score(df_new['label'], df_new['prediction'], zero_division=0)
    rec = recall_score(df_new['label'], df_new['prediction'], zero_division=0)
    f1 = f1_score(df_new['label'], df_new['prediction'], zero_division=0)

    accuracies.append(acc)
    precisions.append(prec)
    recalls.append(rec)
    f1s.append(f1)

    # AUC calculation
    true_labels = []
    predicted_probs = []

    for data, label, pred in zip(auc_data, df_new['label'], df_new['prediction']):
        if pred == 0:
            logit_0 = data['correct_logit']
            logit_1 = data['wrong_logit']
        else:
            logit_0 = data['wrong_logit']
            logit_1 = data['correct_logit']

        logits = np.array([logit_0, logit_1])
        probs = softmax(logits)

        true_labels.append(label)
        predicted_probs.append(probs[1])  # Probability for class 1

    try:
        auc = roc_auc_score(true_labels, predicted_probs)
        aucs.append(auc)
    except:
        print(f"⚠️ Fold {fold_idx} – AUC could not be computed.")

# ----------------------------
# Final Report
# ----------------------------
def summarize(metric_list):
    return f"{np.mean(metric_list):.4f} ± {np.std(metric_list):.4f}"

print("\n==============================")
print("✅ 5-Fold Cross Validation Summary")
print("==============================")
print("Accuracy :", summarize(accuracies))
print("Precision:", summarize(precisions))
print("Recall   :", summarize(recalls))
print("F1 Score :", summarize(f1s))
print("AUC      :", summarize(aucs))


✅ 5-Fold Cross Validation Summary
Accuracy : 0.6255 ± 0.0852
Precision: 0.3491 ± 0.0657
Recall   : 0.4823 ± 0.0827
F1 Score : 0.4035 ± 0.0685
AUC      : 0.5632 ± 0.0365


# Timepoint 1
---

In [21]:
X = np.array(data['0_3_months_notes'].tolist())
y = data['0_3_months_ED_flag_LLAMA'].values

In [22]:
from sklearn.model_selection import KFold

skf = KFold(n_splits=5, shuffle=True, random_state=42)

folds = []

for fold, (train_idx, val_idx) in enumerate(skf.split(X)):
    X_tr, X_val = X[train_idx], X[val_idx]
    y_tr, y_val = y[train_idx], y[val_idx]
    
    folds.append({
        'fold': fold,
        'X_train': X_tr,
        'y_train': y_tr,
        'X_val': X_val,
        'y_val': y_val
    })

In [None]:
results_folds = []
auc_data_folds = []
y_val_folds_cleaned = []

max_words = 80000
indices_to_remove = []

for i, fold in enumerate(folds):
    X_val = fold['X_val']
    y_val = fold['y_val']
    
    print(f"Fold {i}")
    
    results = []
    auc_data = []
    cleaned_y_val = []
    
    for i, temp in enumerate(X_val):
        # Pre-filter long notes
        if isinstance(temp, str) and len(temp.split()) > max_words:
            continue  # Skip long notes
        
        # Append valid label
        cleaned_y_val.append(y_val[i])
        
        # Construct the prompt
        text_prompt = '''<|start_header_id|>user<|end_header_id|>You are an oncologist at a major cancer hospital, tasked with predicting hospital emergency department (ED) visits for patients.  
        I am going to provide you with clinical notes for a head and neck cancer patient collected until primary treatment completion date. Here are the notes: ''' 
        text_prompt += str(temp) 
        text_prompt += '''\nPlease analyze notes carefully. Based on this analysis, will the patient have an ED visit to the hospital? Respond with either 'POSITIVE' (if the patient is likely to have ED visit) or 'NEGATIVE' (if the patient is unlikely to have ED visit). Please respond with 'POSITIVE' or 'NEGATIVE' only.''' 
        text_prompt += "<|eot_id|><|start_header_id|>ANSWER: " 

        torch.cuda.empty_cache()
        output = llm.generate(text_prompt, sampling_params)
        del text_prompt
        
        if output and output[0].outputs:
            res = output[0].outputs[0].text
            print(output[0].outputs[0])
            results.append(res)
        else:
            print("Error: LLM output is empty or improperly structured.", i)

        # AUC elements ***********************************************
        correct_answer_token = output[0].outputs[0].token_ids[0]
        wrong_answer_tokens_func = lambda correct_answer_tokens: 27592 if correct_answer_tokens == 85165 else 85165
        wrong_answer_token = wrong_answer_tokens_func(correct_answer_token)

        # Logit for the wrong answer
        all_logprobs = output[0].outputs[0].logprobs
        for logprob_dict in all_logprobs:
            if wrong_answer_token in logprob_dict:
                wrong_answer_logit = logprob_dict[wrong_answer_token].logprob
            if correct_answer_token in logprob_dict:
                correct_answer_logit = logprob_dict[correct_answer_token].logprob

        new_entry = {'correct_logit': correct_answer_logit, 'wrong_logit': wrong_answer_logit}
        auc_data.append(new_entry)
        # ************************************************************

    # Append results and AUC data for this fold
    results_folds.append(results)
    auc_data_folds.append(auc_data)
    y_val_folds_cleaned.append(cleaned_y_val)


In [None]:
# Metrics for each fold
accuracies = []
precisions = []
recalls = []
f1s = []
aucs = []

for fold_idx in range(5):
    preds = []
    results = results_folds[fold_idx]
    labels = y_val_folds_cleaned[fold_idx]
    auc_data = auc_data_folds[fold_idx]

    df_test = pd.DataFrame()
    df_test['label'] = labels[:len(results)]

    for num, i in enumerate(results):
        if 'POS' in i:
            preds.append(1)
        elif 'NEG' in i:
            preds.append(0)
        else:
            preds.append(2)
            print(f"⚠️ Fold {fold_idx} – Uncertain prediction at index {num}")

    df_test['prediction'] = preds

    # Filter out uncertain predictions
    df_new = df_test[df_test['prediction'] != 2].reset_index(drop=True)

    # Skip this fold if no valid predictions
    if len(df_new) == 0:
        print(f"⚠️ Skipping fold {fold_idx}: No valid predictions.")
        continue

    # Classification metrics
    acc = accuracy_score(df_new['label'], df_new['prediction'])
    prec = precision_score(df_new['label'], df_new['prediction'], zero_division=0)
    rec = recall_score(df_new['label'], df_new['prediction'], zero_division=0)
    f1 = f1_score(df_new['label'], df_new['prediction'], zero_division=0)

    accuracies.append(acc)
    precisions.append(prec)
    recalls.append(rec)
    f1s.append(f1)

    # AUC calculation
    true_labels = []
    predicted_probs = []

    for data, label, pred in zip(auc_data, df_new['label'], df_new['prediction']):
        if pred == 0:
            logit_0 = data['correct_logit']
            logit_1 = data['wrong_logit']
        else:
            logit_0 = data['wrong_logit']
            logit_1 = data['correct_logit']

        logits = np.array([logit_0, logit_1])
        probs = softmax(logits)

        true_labels.append(label)
        predicted_probs.append(probs[1])  # Probability for class 1

    try:
        auc = roc_auc_score(true_labels, predicted_probs)
        aucs.append(auc)
    except:
        print(f"⚠️ Fold {fold_idx} – AUC could not be computed.")

# ----------------------------
# Final Report
# ----------------------------
def summarize(metric_list):
    return f"{np.mean(metric_list):.4f} ± {np.std(metric_list):.4f}"

print("\n==============================")
print("✅ 5-Fold Cross Validation Summary")
print("==============================")
print("Accuracy :", summarize(accuracies))
print("Precision:", summarize(precisions))
print("Recall   :", summarize(recalls))
print("F1 Score :", summarize(f1s))
print("AUC      :", summarize(aucs))


✅ 5-Fold Cross Validation Summary
Accuracy : 0.2634 ± 0.0306
Precision: 0.6834 ± 0.1153
Recall   : 0.1708 ± 0.0127
F1 Score : 0.2716 ± 0.0187
AUC      : 0.3155 ± 0.0289


# Timepoint 2
---

In [77]:
X = np.array(data['3_6_months_notes'].tolist())
y = data['3_6_months_ED_flag_LLAMA'].values

In [78]:
from sklearn.model_selection import KFold

skf = KFold(n_splits=5, shuffle=True, random_state=42)

folds = []

for fold, (train_idx, val_idx) in enumerate(skf.split(X)):
    X_tr, X_val = X[train_idx], X[val_idx]
    y_tr, y_val = y[train_idx], y[val_idx]
    
    folds.append({
        'fold': fold,
        'X_train': X_tr,
        'y_train': y_tr,
        'X_val': X_val,
        'y_val': y_val
    })

In [None]:
results_folds = []
auc_data_folds = []
y_val_folds_cleaned = []

max_words = 80000
indices_to_remove = []

for i, fold in enumerate(folds):
    X_val = fold['X_val']
    y_val = fold['y_val']
    
    print(f"Fold {i}")
    
    results = []
    auc_data = []
    cleaned_y_val = []
    
    for i, temp in enumerate(X_val):
        # Pre-filter long notes
        if isinstance(temp, str) and len(temp.split()) > max_words:
            continue  # Skip long notes
        
        # Append valid label
        cleaned_y_val.append(y_val[i])
        
        # Construct the prompt
        text_prompt = '''<|start_header_id|>user<|end_header_id|>You are an oncologist at a major cancer hospital, tasked with predicting hospital emergency department (ED) visits for patients.  
        I am going to provide you with clinical notes for a head and neck cancer patient collected until primary treatment completion date. Here are the notes: ''' 
        text_prompt += str(temp) 
        text_prompt += '''\nPlease analyze notes carefully. Based on this analysis, will the patient have an ED visit to the hospital? Respond with either 'POSITIVE' (if the patient is likely to have ED visit) or 'NEGATIVE' (if the patient is unlikely to have ED visit). Please respond with 'POSITIVE' or 'NEGATIVE' only.''' 
        text_prompt += "<|eot_id|><|start_header_id|>ANSWER: " 

        torch.cuda.empty_cache()
        output = llm.generate(text_prompt, sampling_params)
        del text_prompt
        
        if output and output[0].outputs:
            res = output[0].outputs[0].text
            print(output[0].outputs[0])
            results.append(res)
        else:
            print("Error: LLM output is empty or improperly structured.", i)

        # AUC elements ***********************************************
        correct_answer_token = output[0].outputs[0].token_ids[0]
        wrong_answer_tokens_func = lambda correct_answer_tokens: 27592 if correct_answer_tokens == 85165 else 85165
        wrong_answer_token = wrong_answer_tokens_func(correct_answer_token)

        # Logit for the wrong answer
        all_logprobs = output[0].outputs[0].logprobs
        for logprob_dict in all_logprobs:
            if wrong_answer_token in logprob_dict:
                wrong_answer_logit = logprob_dict[wrong_answer_token].logprob
            if correct_answer_token in logprob_dict:
                correct_answer_logit = logprob_dict[correct_answer_token].logprob

        new_entry = {'correct_logit': correct_answer_logit, 'wrong_logit': wrong_answer_logit}
        auc_data.append(new_entry)
        # ************************************************************

    # Append results and AUC data for this fold
    results_folds.append(results)
    auc_data_folds.append(auc_data)
    y_val_folds_cleaned.append(cleaned_y_val)


In [None]:
# Metrics for each fold
accuracies = []
precisions = []
recalls = []
f1s = []
aucs = []

for fold_idx in range(5):
    preds = []
    results = results_folds[fold_idx]
    labels = y_val_folds_cleaned[fold_idx]
    auc_data = auc_data_folds[fold_idx]

    df_test = pd.DataFrame()
    df_test['label'] = labels[:len(results)]

    for num, i in enumerate(results):
        if 'POS' in i:
            preds.append(1)
        elif 'NEG' in i:
            preds.append(0)
        else:
            preds.append(2)
            print(f"⚠️ Fold {fold_idx} – Uncertain prediction at index {num}")

    df_test['prediction'] = preds

    # Filter out uncertain predictions
    df_new = df_test[df_test['prediction'] != 2].reset_index(drop=True)

    # Skip this fold if no valid predictions
    if len(df_new) == 0:
        print(f"⚠️ Skipping fold {fold_idx}: No valid predictions.")
        continue

    # Classification metrics
    acc = accuracy_score(df_new['label'], df_new['prediction'])
    prec = precision_score(df_new['label'], df_new['prediction'], zero_division=0)
    rec = recall_score(df_new['label'], df_new['prediction'], zero_division=0)
    f1 = f1_score(df_new['label'], df_new['prediction'], zero_division=0)

    accuracies.append(acc)
    precisions.append(prec)
    recalls.append(rec)
    f1s.append(f1)

    # AUC calculation
    true_labels = []
    predicted_probs = []

    for data, label, pred in zip(auc_data, df_new['label'], df_new['prediction']):
        if pred == 0:
            logit_0 = data['correct_logit']
            logit_1 = data['wrong_logit']
        else:
            logit_0 = data['wrong_logit']
            logit_1 = data['correct_logit']

        logits = np.array([logit_0, logit_1])
        probs = softmax(logits)

        true_labels.append(label)
        predicted_probs.append(probs[1])  # Probability for class 1

    try:
        auc = roc_auc_score(true_labels, predicted_probs)
        aucs.append(auc)
    except:
        print(f"⚠️ Fold {fold_idx} – AUC could not be computed.")

# ----------------------------
# Final Report
# ----------------------------
def summarize(metric_list):
    return f"{np.mean(metric_list):.4f} ± {np.std(metric_list):.4f}"

print("\n==============================")
print("✅ 5-Fold Cross Validation Summary")
print("==============================")
print("Accuracy :", summarize(accuracies))
print("Precision:", summarize(precisions))
print("Recall   :", summarize(recalls))
print("F1 Score :", summarize(f1s))
print("AUC      :", summarize(aucs))


✅ 5-Fold Cross Validation Summary
Accuracy : 0.2988 ± 0.0456
Precision: 0.8102 ± 0.0428
Recall   : 0.2442 ± 0.0334
F1 Score : 0.3735 ± 0.0385
AUC      : 0.3735 ± 0.0493


# All Timepoints
---

In [None]:
notes_melted = pd.melt(
    data,
    id_vars=['MRN'],
    value_vars=['pre_completion_notes', '0_3_months_notes', '3_6_months_notes'],
    var_name='timepoint',
    value_name='notes'
)

flags_melted = pd.melt(
    data,
    id_vars=['MRN'],
    value_vars=['pre_completion_ED_flag_LLAMA', '0_3_months_ED_flag_LLAMA', '3_6_months_ED_flag_LLAMA'],
    var_name='timepoint',
    value_name='ed_flag'
)

notes_melted['timepoint'] = notes_melted['timepoint'].str.replace('_notes', '')
flags_melted['timepoint'] = flags_melted['timepoint'].str.replace('_ED_flag_LLAMA', '')

merged_df = pd.merge(notes_melted, flags_melted, on=['MRN', 'timepoint'])
merged_df.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 1443 entries, 0 to 1442
Data columns (total 4 columns):
 #   Column     Non-Null Count  Dtype 
---  ------     --------------  ----- 
 0   MRN        1443 non-null   int64 
 1   timepoint  1443 non-null   object
 2   notes      1443 non-null   object
 3   ed_flag    1443 non-null   int64 
dtypes: int64(2), object(2)
memory usage: 45.2+ KB


In [37]:
X = np.array(merged_df['notes'].tolist())
y = merged_df['ed_flag'].values

In [38]:
from sklearn.model_selection import KFold

skf = KFold(n_splits=5, shuffle=True, random_state=42)

folds = []

for fold, (train_idx, val_idx) in enumerate(skf.split(X)):
    X_tr, X_val = X[train_idx], X[val_idx]
    y_tr, y_val = y[train_idx], y[val_idx]
    
    folds.append({
        'fold': fold,
        'X_train': X_tr,
        'y_train': y_tr,
        'X_val': X_val,
        'y_val': y_val
    })

In [None]:
results_folds = []
auc_data_folds = []
y_val_folds_cleaned = []

max_words = 80000
indices_to_remove = []

for i, fold in enumerate(folds):
    X_val = fold['X_val']
    y_val = fold['y_val']
    
    print(f"Fold {i}")
    
    results = []
    auc_data = []
    cleaned_y_val = []
    
    for i, temp in enumerate(X_val):
        # Pre-filter long notes
        if isinstance(temp, str) and len(temp.split()) > max_words:
            continue  # Skip long notes
        
        # Append valid label
        cleaned_y_val.append(y_val[i])
        
        # Construct the prompt
        text_prompt = '''<|start_header_id|>user<|end_header_id|>You are an oncologist at a major cancer hospital, tasked with predicting hospital emergency department (ED) visits for patients.  
        I am going to provide you with clinical notes for a head and neck cancer patient collected until primary treatment completion date. Here are the notes: ''' 
        text_prompt += str(temp) 
        text_prompt += '''\nPlease analyze notes carefully. Based on this analysis, will the patient have an ED visit to the hospital? Respond with either 'POSITIVE' (if the patient is likely to have ED visit) or 'NEGATIVE' (if the patient is unlikely to have ED visit). Please respond with 'POSITIVE' or 'NEGATIVE' only.''' 
        text_prompt += "<|eot_id|><|start_header_id|>ANSWER: " 

        torch.cuda.empty_cache()
        output = llm.generate(text_prompt, sampling_params)
        del text_prompt
        
        if output and output[0].outputs:
            res = output[0].outputs[0].text
            print(output[0].outputs[0])
            results.append(res)
        else:
            print("Error: LLM output is empty or improperly structured.", i)

        # AUC elements ***********************************************
        correct_answer_token = output[0].outputs[0].token_ids[0]
        wrong_answer_tokens_func = lambda correct_answer_tokens: 27592 if correct_answer_tokens == 85165 else 85165
        wrong_answer_token = wrong_answer_tokens_func(correct_answer_token)

        # Logit for the wrong answer
        all_logprobs = output[0].outputs[0].logprobs
        for logprob_dict in all_logprobs:
            if wrong_answer_token in logprob_dict:
                wrong_answer_logit = logprob_dict[wrong_answer_token].logprob
            if correct_answer_token in logprob_dict:
                correct_answer_logit = logprob_dict[correct_answer_token].logprob

        new_entry = {'correct_logit': correct_answer_logit, 'wrong_logit': wrong_answer_logit}
        auc_data.append(new_entry)
        # ************************************************************

    # Append results and AUC data for this fold
    results_folds.append(results)
    auc_data_folds.append(auc_data)
    y_val_folds_cleaned.append(cleaned_y_val)


In [None]:
# Metrics for each fold
accuracies = []
precisions = []
recalls = []
f1s = []
aucs = []

for fold_idx in range(5):
    preds = []
    results = results_folds[fold_idx]
    labels = y_val_folds_cleaned[fold_idx]
    auc_data = auc_data_folds[fold_idx]

    df_test = pd.DataFrame()
    df_test['label'] = labels[:len(results)]

    for num, i in enumerate(results):
        if 'POS' in i:
            preds.append(1)
        elif 'NEG' in i:
            preds.append(0)
        else:
            preds.append(2)
            print(f"⚠️ Fold {fold_idx} – Uncertain prediction at index {num}")

    df_test['prediction'] = preds

    # Filter out uncertain predictions
    df_new = df_test[df_test['prediction'] != 2].reset_index(drop=True)

    # Skip this fold if no valid predictions
    if len(df_new) == 0:
        print(f"⚠️ Skipping fold {fold_idx}: No valid predictions.")
        continue

    # Classification metrics
    acc = accuracy_score(df_new['label'], df_new['prediction'])
    prec = precision_score(df_new['label'], df_new['prediction'], zero_division=0)
    rec = recall_score(df_new['label'], df_new['prediction'], zero_division=0)
    f1 = f1_score(df_new['label'], df_new['prediction'], zero_division=0)

    accuracies.append(acc)
    precisions.append(prec)
    recalls.append(rec)
    f1s.append(f1)

    # AUC calculation
    true_labels = []
    predicted_probs = []

    for data, label, pred in zip(auc_data, df_new['label'], df_new['prediction']):
        if pred == 0:
            logit_0 = data['correct_logit']
            logit_1 = data['wrong_logit']
        else:
            logit_0 = data['wrong_logit']
            logit_1 = data['correct_logit']

        logits = np.array([logit_0, logit_1])
        probs = softmax(logits)

        true_labels.append(label)
        predicted_probs.append(probs[1])  # Probability for class 1

    try:
        auc = roc_auc_score(true_labels, predicted_probs)
        aucs.append(auc)
    except:
        print(f"⚠️ Fold {fold_idx} – AUC could not be computed.")

# ----------------------------
# Final Report
# ----------------------------
def summarize(metric_list):
    return f"{np.mean(metric_list):.4f} ± {np.std(metric_list):.4f}"

print("\n==============================")
print("✅ 5-Fold Cross Validation Summary")
print("==============================")
print("Accuracy :", summarize(accuracies))
print("Precision:", summarize(precisions))
print("Recall   :", summarize(recalls))
print("F1 Score :", summarize(f1s))
print("AUC      :", summarize(aucs))


✅ 5-Fold Cross Validation Summary
Accuracy : 0.3623 ± 0.0416
Precision: 0.6131 ± 0.0909
Recall   : 0.2548 ± 0.0274
F1 Score : 0.3585 ± 0.0379
AUC      : 0.3604 ± 0.0473
