## Preprocessing

In [6]:
import json
import re

dataset = "mimic3"
task = "mortality"
# Directory that receives the converted fine-tuning files; the test-data cell
# below also writes here, so avoid rebinding this name in later cells.
output_path = "llm_finetune_data_multitask"

# ==========================================
# Train data: rewrite the instruction header of every source record and emit
# reasoning-style fine-tuning examples.
ori_path = f"llm_finetune_data_ulti/{dataset}_{task}_train_0_notes_checkpoint.jsonl"

with open(ori_path, "r") as f:
    data = [json.loads(line) for line in f]
t = 0  # count of records whose patient ID was found (the test cell keeps incrementing it)

# Instruction headers. `instruction_prev` is the text searched for in the source
# inputs; it is replaced by one of the `instruction_new_*` variants.
# `instruction_prev_prev` is not referenced in this cell.
instruction_prev_prev = "\nGiven the following task description, patient EHR context, similar patients, and retrieved medical knowledge, Please provide a step-by-step reasoning process that leads to the prediction outcome based on the patient's context and relevant medical knowledge.\nAfter the reasoning process, provide the prediction label (0/1)."
instruction_prev = "\nGiven the following task description, patient EHR context, similar patients, and retrieved medical knowledge..."
instruction_new_reason = "\n[Reasoning] Given the following task description, patient EHR context, similar patients, and retrieved medical knowledge, Please provide a step-by-step reasoning process that leads to the prediction outcome based on the patient's context and relevant medical knowledge.\nAfter the reasoning process, provide the prediction label (0/1)."
instruction_new_pred = "\n[Label Prediction] Given the following task description, patient EHR context, similar patients, and retrieved medical knowledge, Please directly predict the label (0/1).\n"

label_pred_data = []
reasoning_data = []
# Captures the first whitespace-free token after "Patient ID:" that follows the
# "# Patient EHR Context #" marker (DOTALL lets the gap span newlines).
patient_id_pattern = re.compile(
    r"# Patient EHR Context #.*?Patient ID:\s*([^\s]+)",
    re.DOTALL
)

context_path = f"patient_context/base_context/patient_contexts_{dataset}_{task}_notes.json"  # not read in this cell
patient_data_path = f"ehr_prepare/pateint_{dataset}_{task}_physician_summary.json"
with open(patient_data_path) as f:
    patient_data = json.load(f)  # patient_id -> record with a 'label' entry

for item in data:
    # Reasoning examples keep the original target and need no label lookup,
    # so every record yields one.
    reason_input = item["input"].replace(instruction_prev, instruction_new_reason)
    reasoning_data.append({"input": reason_input, "output": item["output"]})

    match = patient_id_pattern.search(reason_input)
    if match is None:
        # Without an ID there is no ground-truth label. Skip the label example
        # instead of reusing the previous record's patient_id (which silently
        # attached the wrong label, or raised NameError on the first record).
        print('couldnt find')
        continue
    patient_id = match.group(1)
    t += 1

    # Label-only example uses the label-prediction instruction header.
    pred_input = item["input"].replace(instruction_prev, instruction_new_pred)
    label_pred_data.append(
        {"input": pred_input, "output": "# Prediction # " + str(patient_data[patient_id]['label'])}
    )

# NOTE(review): only reasoning_data is written; label_pred_data is built but not
# saved here (and the test cell resets it) — confirm whether the multitask train
# file should also contain the label-prediction examples.
with open(f"{output_path}/{dataset}_{task}_train_notes.jsonl", "w") as f:
    for example in reasoning_data:
        f.write(json.dumps(example) + "\n")
        


In [None]:
# reasoning_data

In [7]:
 # test data

# Test data: label-only examples (the instruction header is swapped for the
# label-prediction variant and the target is "# Prediction # <label>").
# Relies on dataset, task, output_path, instruction_prev, instruction_new_pred,
# patient_id_pattern, patient_data and t from the train-data cell above.
ori_path = f"llm_finetune_data_ulti/{dataset}_{task}_test_0_notes_checkpoint.jsonl"

with open(ori_path, "r") as f:
    data = [json.loads(line) for line in f]

label_pred_data = []

for item in data:
    input_new = item["input"].replace(instruction_prev, instruction_new_pred)
    match = patient_id_pattern.search(input_new)
    if match is None:
        # No ID means no ground-truth label: skip the record rather than reuse
        # the previous record's patient_id, which wrote a wrong label.
        print('couldnt find')
        continue
    patient_id = match.group(1)
    t += 1

    output_new = "# Prediction # " + str(patient_data[patient_id]['label'])
    label_pred_data.append({"input": input_new, "output": output_new})

with open(f"{output_path}/{dataset}_{task}_test_notes.jsonl", "w") as f:
    for item in label_pred_data:
        f.write(json.dumps(item) + "\n")

In [5]:
import json

# Convert a multitask JSONL file into an Alpaca-style JSON list
# (instruction / input / label / Reasoning) for the fine-tuned model.
# Distinct names are used for the output path and the per-line record so this
# cell does not clobber `output_path` / `data`, which the preprocessing cells use
# (re-running those after this cell would otherwise write into the JSON filename).
input_path = "../llm_finetune_data_multitask/mimic3_readmission30_test_notes.jsonl"
alpaca_output_path = "alpaca_readmission30_notes_test.json"

alpaca_data = []

with open(input_path, "r") as infile:
    for line in infile:
        record = json.loads(line)
        # Remove every "# Reasoning #\n\n" header occurrence from the target text.
        prediction_part = record["output"].strip().replace("# Reasoning #\n\n", "")

        # Label = text after the last "# Prediction #" tag; whole text if absent.
        if "# Prediction #" in prediction_part:
            output = prediction_part.split("# Prediction #")[-1].strip()
        else:
            output = prediction_part

        # NOTE(review): the "# Reasoning #\n\n" form was already stripped above, so
        # this only matches other spacings (e.g. "# Reasoning # text"), and the
        # extracted text still includes any "# Prediction #" tail — confirm intent
        # before using this on reasoning-style files.
        if "# Reasoning #" in prediction_part:
            reasoning = prediction_part.split("# Reasoning #")[1].strip()
        else:
            reasoning = ""

        alpaca_data.append({
            "instruction": "Predict hospital readmission within 30 days based on EHR context.",
            "input": record["input"],
            "label": output,
            "Reasoning": reasoning,
        })

with open(alpaca_output_path, "w") as outfile:
    json.dump(alpaca_data, outfile, indent=2)


## Inference

In [6]:
# Unsloth is imported first, as in the original cell (presumably so its patches
# apply before peft/transformers are used — confirm against Unsloth docs).
from unsloth import FastLanguageModel
import torch
from peft import PeftModel
from transformers import TextStreamer
import json
import csv
import re
import os
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score
import pandas as pd

# ---- configuration ---------------------------------------------------------
# 16384 exceeds the base model's 8192 limit; the load log shows Unsloth applying
# RoPE scaling (factor 2.0) to reach it.
max_seq_length = 16384
dtype = None
load_in_4bit = True
lora_path = "lora_model_readmission30_notes"

# Prompt template: instruction, input, label, reasoning (label/reasoning are left
# empty at inference time so the model fills them in).
# NOTE(review): the header text talks about in-hospital death while the adapter
# and data are readmission30 — keep it identical to the fine-tuning prompt, but
# confirm which wording the LoRA was trained with.
alpaca_prompt = """Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request. Give label 0 if the patient will live, or 1 if he will die in hospital. Then give the reason behind it in 2 lines

### Instruction:
{}

### Input:
{}

### label:
{}

### Reasoning:
{}
"""

# ---- test records (Alpaca JSON written by the conversion cell) -------------
with open('alpaca_readmission30_notes_test.json', 'r') as records_fh:
    admission = json.load(records_fh)

# ---- 4-bit base model + LoRA adapter, switched to inference mode -----------
base_model, tokenizer = FastLanguageModel.from_pretrained(
    model_name="unsloth/llama-3-8b-bnb-4bit",
    max_seq_length=max_seq_length,
    dtype=dtype,
    load_in_4bit=load_in_4bit,
)
model = PeftModel.from_pretrained(base_model, model_id=lora_path, adapter_name="default")
FastLanguageModel.for_inference(model)
model.eval()

# ---- results CSV: truncate and write the header; the loop appends rows -----
csv_file = "predictions_vs_ground_truth_test_readmission30_notes.csv"
fieldnames = ["patient_id", "ground_truth", "prediction", "match", "reasoning"]
with open(csv_file, "w", newline="") as header_fh:
    csv.DictWriter(header_fh, fieldnames=fieldnames).writeheader()


==((====))==  Unsloth 2025.5.1: Fast Llama patching. Transformers: 4.51.3.
   \\   /|    Tesla V100-SXM2-32GB. Num GPUs = 1. Max memory: 31.733 GB. Platform: Linux.
O^O/ \_/ \    Torch: 2.5.1+cu124. CUDA: 7.0. CUDA Toolkit: 12.4. Triton: 3.1.0
\        /    Bfloat16 = FALSE. FA [Xformers = 0.0.28.post3. FA2 = False]
 "-____-"     Free license: http://github.com/unslothai/unsloth
Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!


Unsloth: unsloth/llama-3-8b-bnb-4bit can only handle sequence lengths of at most 8192.
But with kaiokendev's RoPE scaling of 2.0, it can be magically be extended to 16384!


In [11]:
import re
import csv
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score

# Generate a prediction for every Alpaca test record, parse the label from the
# generated "### Reasoning:" section, append rows to `csv_file` (header written by
# the setup cell) and report metrics on the parsed labels.
# Uses admission, alpaca_prompt, tokenizer, model and csv_file from the setup cell.
SAVE_EVERY = 3  # number of buffered rows per append to csv_file

results = []
y_true = []
y_pred = []


def _flush_results(rows):
    """Append buffered result dicts to `csv_file`; does nothing for an empty list."""
    if not rows:
        return
    with open(csv_file, "a", newline="") as f:
        csv.DictWriter(f, fieldnames=list(rows[0].keys())).writerows(rows)


for i, entry in enumerate(admission):
    input_text = entry["input"]
    GT = entry["label"]

    # Label and reasoning slots are left empty so the model completes them.
    prompt = alpaca_prompt.format(entry["instruction"], input_text, "", "")

    match = re.search(r"# Patient EHR Context #\n\nPatient ID:\s*([^\s\n]+)", input_text)
    patient_id = match.group(1) if match else "UNKNOWN"

    try:
        inputs = tokenizer([prompt], return_tensors="pt").to("cuda")
        outputs = model.generate(**inputs, max_new_tokens=1000)
        prediction = tokenizer.decode(outputs[0], skip_special_tokens=True)
    except Exception as e:
        # Best-effort: one failing record should not abort the whole run.
        print(f"[{i}] Skipped due to generation error: {e}")
        continue

    # The decoded text includes the prompt; keep what follows the last header.
    reasoning_section = prediction.split("### Reasoning:")[-1].strip()
    label_match = re.search(r'# Prediction #\s*\n\s*(\d)', reasoning_section)

    if label_match:
        pred_label = label_match.group(1)
        try:
            # Cast both before appending so y_true / y_pred never desynchronize.
            gt_int, pred_int = int(GT), int(pred_label)
        except (TypeError, ValueError) as e:
            print(f"[{i}] Label casting failed: {e}")
            continue
        y_true.append(gt_int)
        y_pred.append(pred_int)
    else:
        # NOTE(review): unparsed outputs are written to the CSV with prediction "0"
        # (so CSV-based analysis counts them as negatives) but are excluded from
        # y_true / y_pred, so the CSV and the metrics below can disagree.
        pred_label = "0"
        print(f"[{i}] Could not parse label. Saved for manual review.")

    result = {
        "patient_id": patient_id,
        "ground_truth": GT,
        "prediction": pred_label,
        "match": str(GT) == str(pred_label),
        "reasoning": reasoning_section
    }
    print(result)
    results.append(result)

    if len(results) >= SAVE_EVERY:
        _flush_results(results)
        results = []

# Flush the remainder. Previously rows buffered before a trailing `continue`
# (generation/cast failure on the last entries) were never written.
_flush_results(results)
results = []

# === METRICS (entries with unparsable labels were never added) ===
valid_indices = [i for i, pred in enumerate(y_pred) if pred in [0, 1]]
y_true_valid = [y_true[i] for i in valid_indices]
y_pred_valid = [y_pred[i] for i in valid_indices]

if y_true_valid:
    accuracy = accuracy_score(y_true_valid, y_pred_valid)
    precision = precision_score(y_true_valid, y_pred_valid, zero_division=0)
    recall = recall_score(y_true_valid, y_pred_valid, zero_division=0)
    f1 = f1_score(y_true_valid, y_pred_valid, zero_division=0)
    try:
        # Computed on hard 0/1 labels, so this equals (TPR + TNR) / 2.
        roc_auc = roc_auc_score(y_true_valid, y_pred_valid)
    except ValueError:
        roc_auc = "N/A (only one class present)"

    print(f"Accuracy:  {accuracy:.4f}")
    print(f"Precision: {precision:.4f}")
    print(f"Recall:    {recall:.4f}")
    print(f"F1 Score:  {f1:.4f}")
    print(f"ROC AUC:   {roc_auc}")
else:
    print("No valid predictions to compute metrics.")


In [None]:
# import re

# def parse_output(text):
#     prediction_label = None
#     prediction_reasoning = None

#     # Extract Prediction Label
#     match_label = re.search(r"#### Prediction Label:\s*(\d+)", text)
#     if match_label:
#         prediction_label = int(match_label.group(1))

#     # Extract Prediction Reasoning
#     match_reasoning = re.search(r"#### Prediction Reasoning:\s*(.*)", text, re.DOTALL)
#     if match_reasoning:
#         reasoning = match_reasoning.group(1).strip()
#         # If there's any section after Reasoning, clip it
#         end_match = re.search(r"#### ", reasoning)
#         if end_match:
#             reasoning = reasoning[:end_match.start()].strip()
#         prediction_reasoning = reasoning

#     return {
#         "label": prediction_label,
#         "reasoning": prediction_reasoning
#     }

# # Example usage
# # print(parsed)


## Performance Analysis

In [42]:
from sklearn.metrics import classification_report, confusion_matrix, accuracy_score
import pandas as pd

# Score a saved predictions CSV (columns as written by the inference loop:
# patient_id, ground_truth, prediction, match, reasoning).
# NOTE(review): this reads the train/mortality results, while the inference cell
# writes predictions_vs_ground_truth_test_readmission30_notes.csv — confirm which
# file is meant to be analysed here.
df_train_LLM = pd.read_csv('predictions_vs_ground_truth_train_mortality_notes.csv')
y_true, y_pred, reasoning = (
    df_train_LLM[column] for column in ("ground_truth", "prediction", "reasoning")
)

accuracy, conf_matrix, class_report = (
    accuracy_score(y_true, y_pred),
    confusion_matrix(y_true, y_pred),
    classification_report(y_true, y_pred, output_dict=True),
)

report = dict(
    accuracy=accuracy,
    confusion_matrix=conf_matrix,
    classification_report=class_report,
)
report

{'accuracy': 0.6140380407338832,
 'confusion_matrix': array([[3357, 1879],
        [ 414,  291]]),
 'classification_report': {'0': {'precision': 0.8902147971360382,
   'recall': 0.6411382734912147,
   'f1-score': 0.7454202287110026,
   'support': 5236.0},
  '1': {'precision': 0.13410138248847928,
   'recall': 0.4127659574468085,
   'f1-score': 0.20243478260869566,
   'support': 705.0},
  'accuracy': 0.6140380407338832,
  'macro avg': {'precision': 0.5121580898122587,
   'recall': 0.5269521154690116,
   'f1-score': 0.4739275056598491,
   'support': 5941.0},
  'weighted avg': {'precision': 0.8004891689040016,
   'recall': 0.6140380407338832,
   'f1-score': 0.6809858339117891,
   'support': 5941.0}}}

In [43]:
from sklearn.metrics import classification_report, roc_auc_score, average_precision_score

# Ranking metrics on the y_true / y_pred series from the previous cell.
# Caveat: both scorers expect continuous scores, but y_pred holds hard 0/1 labels.
# With a single threshold ROC-AUC reduces to (sensitivity + specificity) / 2, and
# the "AUC-PR" below is average precision over that one operating point, not the
# area under a full precision-recall curve.
auc_roc = roc_auc_score(y_true, y_pred)
print(f"AUC-ROC: {auc_roc:.4f}")

auc_pr = average_precision_score(y_true, y_pred)  # average precision
print(f"AUC-PR: {auc_pr:.4f}")


AUC-ROC: 0.5270
AUC-PR: 0.1250
