## Import packages

In [None]:
import pickle

## Load logits data (EHR model and MAS)

In [None]:
outs_dir = "response/mimic-iv_outcome_test_RETAIN_llama-3.3-70b-instruct_PubMed_MSD/agent_logits_outs.pkl"

# NOTE(review): `val_outs_dir` was previously used below without ever being
# defined in this notebook (it only worked with a stale kernel variable).
# It is derived here from the test path by swapping the split name —
# TODO confirm the validation outputs actually live at this path.
val_outs_dir = outs_dir.replace("_test_", "_val_")
if val_outs_dir == outs_dir:
    # Guard against silently "validating" on the test split.
    raise ValueError(f"Could not derive a validation path from {outs_dir!r}")


def _load_outs(path):
    """Load a pickled outputs dict and return its (preds, labels, model) lists.

    The pickle is expected to hold a dict with "preds", "labels" and "model"
    entries of equal length.

    NOTE: pickle.load can execute arbitrary code; only load files this
    project produced yourself.
    """
    with open(path, "rb") as f:
        data = pickle.load(f)
    out_preds, out_labels, out_models = data["preds"], data["labels"], data["model"]
    if not (len(out_preds) == len(out_labels) == len(out_models)):
        raise ValueError(
            f"{path}: length mismatch "
            f"(preds={len(out_preds)}, labels={len(out_labels)}, model={len(out_models)})"
        )
    return out_preds, out_labels, out_models


# Test split: LLM (MAS) predictions, labels and baseline EHR-model predictions.
preds, labels, models = _load_outs(outs_dir)
# Validation split: used later to fit the fusion model.
val_preds, val_labels, val_models = _load_outs(val_outs_dir)

print(len(preds), len(labels), len(models))


466 466 466


## Metrics utilities

In [None]:
import torch
from torchmetrics import AUROC, Accuracy, AveragePrecision
from torchmetrics.classification import BinaryF1Score
import numpy as np
from sklearn import metrics as sklearn_metrics


def minpse(preds, labels):
    """Return the best "min(precision, recall)" over the precision-recall curve.

    For every operating point returned by sklearn's precision_recall_curve,
    take the smaller of precision and recall; the score is the largest of
    those values.

    Args:
        preds: Positive-class scores.
        labels: Ground-truth binary labels aligned with `preds`.
    """
    curve_precision, curve_recall, _thresholds = sklearn_metrics.precision_recall_curve(
        labels, preds
    )
    # map(min, a, b) applies Python's built-in min pairwise, exactly like
    # min(p, r) for p, r in zip(a, b).
    return np.max(list(map(min, curve_precision, curve_recall)))


def get_binary_metrics(preds, labels, threshold=0.5):
    """Compute binary-classification metrics for one set of predictions.

    Args:
        preds: Positive-class scores (array-like or tensor). Accuracy and F1
            binarize them at `threshold`; AUROC, AUPRC and minpse use the raw
            scores. Presumably probabilities in [0, 1] — TODO confirm.
        labels: Ground-truth labels aligned with `preds`; cast to int
            (expected to be 0/1).
        threshold: Decision threshold for accuracy and F1. Defaults to 0.5,
            the value previously hard-coded, so existing calls are unchanged.

    Returns:
        dict with keys "auprc", "auroc", "minpse", "accuracy", "f1".
    """
    # torch.as_tensor (unlike torch.tensor) accepts an existing tensor
    # without emitting a copy-construct warning.
    preds = torch.as_tensor(preds, dtype=torch.float32)
    # Convert through float32 first (as before) so float labels like 1.0
    # cast cleanly to int.
    labels = torch.as_tensor(labels, dtype=torch.float32).type(torch.int)

    accuracy = Accuracy(task="binary", threshold=threshold)
    auroc = AUROC(task="binary")
    auprc = AveragePrecision(task="binary")
    # Use the same threshold as accuracy so the two thresholded metrics agree.
    f1 = BinaryF1Score(threshold=threshold)

    for metric in (accuracy, auroc, auprc, f1):
        metric.update(preds, labels)

    return {
        "auprc": auprc.compute().item(),
        "auroc": auroc.compute().item(),
        # sklearn-based; hand it plain NumPy arrays rather than tensors.
        "minpse": minpse(preds.numpy(), labels.numpy()),
        "accuracy": accuracy.compute().item(),
        "f1": f1.compute().item(),
    }


def bootstrap(preds, labels, K=100, seed=42):
    """Bootstrap resampling for binary classification metrics. Resample K times.

    Each resample draws len(preds) indices uniformly with replacement and
    applies the same indices to `preds` and `labels`, so pairs stay aligned.

    Args:
        preds: Predictions; any array-like (converted with np.asarray).
        labels: Labels aligned with `preds`; any array-like.
        K: Number of bootstrap resamples.
        seed: Seed for the resampling RNG.

    Returns:
        list of K (preds_sample, labels_sample) tuples of NumPy arrays.

    Raises:
        ValueError: if `preds` and `labels` differ in length.
    """
    preds = np.asarray(preds)
    labels = np.asarray(labels)
    if len(preds) != len(labels):
        raise ValueError(
            f"preds and labels must have the same length ({len(preds)} != {len(labels)})"
        )
    n = len(preds)
    # A private RandomState seeded like the former np.random.seed(seed) call
    # yields the identical index sequence, without resetting the global NumPy
    # RNG for the rest of the notebook.
    rng = np.random.RandomState(seed)
    index_draws = (rng.choice(n, n, replace=True) for _ in range(K))
    return [(preds[idx], labels[idx]) for idx in index_draws]


def bootstrap_metrics(samples):
    """Summarize binary metrics across bootstrap resamples.

    Runs get_binary_metrics on every (preds, labels) pair in `samples` and
    reduces each metric to its mean and standard deviation (np.std, ddof=0).

    Returns:
        dict mapping metric name -> {"mean": ..., "std": ...}, with keys in
        the order auprc, auroc, minpse, accuracy, f1.
    """
    per_metric = {
        name: [] for name in ("auprc", "auroc", "minpse", "accuracy", "f1")
    }
    for sample_preds, sample_labels in samples:
        scores = get_binary_metrics(sample_preds, sample_labels)
        for name, value in scores.items():
            per_metric[name].append(value)

    summary = {}
    for name, values in per_metric.items():
        values_arr = np.array(values)
        summary[name] = {"mean": np.mean(values_arr), "std": np.std(values_arr)}
    return summary


def run_bootstrap(preds, labels, K=100, seed=42):
    """Resample (preds, labels) K times and return per-metric mean/std."""
    return bootstrap_metrics(bootstrap(preds, labels, K=K, seed=seed))

  from .autonotebook import tqdm as notebook_tqdm


## Results (fusion fit on the validation set, CV tuning, bootstrap evaluation)

In [None]:
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import GridSearchCV

# Test split: baseline EHR-model scores, LLM scores, labels.
p_base = np.array(models)
p_llm = np.array(preds)
y_true = np.array(labels)

# Validation split: same three arrays, used only to fit the fusion model.
val_p_base = np.array(val_models)
val_p_llm = np.array(val_preds)
val_y_true = np.array(val_labels)

# ====== Fusion =======
# The fusion model is fit on the *validation* split and evaluated on the
# test split. Each sample's features are [baseline score, LLM score].
X_train = np.stack([val_p_base, val_p_llm], axis=1)
y_train = val_y_true

X_test = np.stack([p_base, p_llm], axis=1)
y_test = y_true


# Step 1: tune the inverse regularization strength C with 3-fold CV on the
# validation split.
# NOTE(review): GridSearchCV's default scoring is the estimator's own score
# (accuracy for LogisticRegression); since the fused probabilities are judged
# mainly by AUROC/AUPRC, consider scoring="roc_auc". The recorded output picks
# C=100, the upper edge of the grid, so the grid may also need extending.
param_grid = {'C': [0.01, 0.1, 1, 10, 100]}
clf = LogisticRegression(max_iter=1000)
grid = GridSearchCV(clf, param_grid, cv=3)
grid.fit(X_train, y_train)
print("Best params:", grid.best_params_)

# Step 2: GridSearchCV (refit=True by default) has already refit a clone with
# the best C on the whole validation split; reuse it rather than training a
# second, identical model.
best_model = grid.best_estimator_

# Step 3: fused positive-class probability on the test split.
p_fused = best_model.predict_proba(X_test)[:, 1]


def show_metrics(name, preds, labels):
    """Print bootstrap mean ± std (in percent) of every metric under a header.

    Args:
        name: Title printed in the section header.
        preds: Scores passed to run_bootstrap (default K and seed).
        labels: Labels aligned with `preds`.
    """
    report_lines = [
        f"{metric}: {stats['mean']*100:.2f} ± {stats['std']*100:.2f}"
        for metric, stats in run_bootstrap(preds, labels).items()
    ]
    print(f"\n▶ {name}")
    for line in report_lines:
        print(line)

# Bootstrap evaluation on the test split (y_test is y_true): baseline EHR model,
# LR fusion of baseline + LLM, and raw LLM probability.
for title, scores in (
    ("Baseline", p_base),
    ("Fusion (baseline + llm)", p_fused),
    ("LLM Prob", p_llm),
):
    show_metrics(title, scores, y_test)



Best params: {'C': 100}

▶ Baseline
auprc: 54.21 ± 4.13
auroc: 63.89 ± 2.94
minpse: 55.44 ± 3.40
accuracy: 63.34 ± 2.31
f1: 45.58 ± 3.59

▶ Fusion (baseline + llm)
auprc: 54.33 ± 4.13
auroc: 64.30 ± 2.88
minpse: 55.73 ± 3.34
accuracy: 62.89 ± 2.15
f1: 47.65 ± 3.41

▶ LLM Prob
auprc: 54.00 ± 4.01
auroc: 64.06 ± 2.79
minpse: 54.54 ± 3.44
accuracy: 60.71 ± 2.26
f1: 56.61 ± 3.11
