# Simulation Experiment

> In this notebook we demonstrate how to train the MixMIL model on data simulated as specified in the paper, in the binomial likelihood setting.

In [2]:
import numpy as np
import scipy.stats as st
import torch
from sklearn.metrics import roc_auc_score

from mixmil import MixMIL
from mixmil.data import load_data
import pandas as pd

  from .autonotebook import tqdm as notebook_tqdm


## Utility Functions

In [3]:
import numpy as np
from sklearn.metrics import roc_auc_score
import scipy.stats as st

def _as_numpy(values):
    """Return `values` as a NumPy array, detaching torch-like tensors and moving them to the CPU first."""
    if hasattr(values, "detach"):
        values = values.detach().cpu().numpy()
    return np.asarray(values)


def _calc_metrics(u_pred, w_pred, u, w, top_quantile=0.90):
    """
    Calculate the bag-level rank correlation and the instance-retrieval AUC.

    All array arguments may be NumPy arrays or torch tensors; they are converted
    to NumPy arrays before any metric is computed.

    Parameters:
    - u_pred (array-like): Predicted bag-level values.
    - w_pred (array-like): Predicted instance weights, one score per instance.
    - u (array-like): Simulated (true) bag-level values.
    - w (array-like): Simulated (true) instance weights.
    - top_quantile (float): Quantile of `w` above which an instance counts as a
      "top" instance (positive class) for the retrieval AUC. Defaults to 0.90.

    Returns:
    - rho_bag (float): Spearman's rank correlation between `u_pred` and `u`.
    - auc_instance (float): ROC AUC of `w_pred` for retrieving the instances whose
      true weight exceeds the `top_quantile` quantile of `w`.
    """
    u_pred, w_pred, u, w = (_as_numpy(values) for values in (u_pred, w_pred, u, w))

    rho_bag = st.spearmanr(u_pred, u).correlation  # bag level correlation
    # Binary relevance labels: 1 for instances strictly above the quantile threshold
    is_top_instance = (w > np.quantile(w, top_quantile)).astype(int).ravel()
    auc_instance = roc_auc_score(is_top_instance, w_pred)  # instance-retrieval AUC
    return rho_bag, auc_instance


def calc_metrics(model, X, u, w):
    """
    Calculate test-split metrics for a model, aggregated over its outputs.

    Parameters:
    - model: Trained model exposing `predict` and `get_weights`; both are called on
      `X["test"]` and their results are converted with `.cpu().numpy()`.
      NOTE(review): the first element returned by `get_weights` is taken to be the
      instance weights; confirm against the model API.
    - X (dict): Input data; only the "test" entry is used.
    - u (dict): True bag-level values; only the "test" entry is used.
    - w (dict): True instance weights; only the "test" entry is used.

    Returns:
    - res_dict (dict): For predictions with an output axis (2-D or higher):
        - 'rho_bag' (float): Mean bag-level correlation over outputs.
        - 'rho_bag_err' (float): `np.std` (default ddof=0) of the per-output
          correlations divided by sqrt(number of outputs).
        - 'auc_instance' (float): Mean instance-retrieval AUC over outputs.
        - 'auc_instance_err' (float): `np.std` (default ddof=0) of the per-output
          AUCs divided by sqrt(number of outputs).
      For predictions without an output axis (1-D), only 'rho_bag' and
      'auc_instance' are returned, computed once on the full arrays.
    """
    u_pred = model.predict(X["test"]).cpu().numpy()
    w_pred = model.get_weights(X["test"])[0].cpu().numpy()

    # 1-D predictions have no output axis, so indexing shape[1] would raise;
    # count them as zero outputs so they take the single-output path below.
    P = u_pred.shape[1] if u_pred.ndim > 1 else 0
    if P > 0:
        rho_bag, auc_instance = [], []
        for i in range(P):
            # The last axis indexes the outputs; evaluate each output on its own
            _rho_bag, _auc_instance = _calc_metrics(
                u_pred[..., i], w_pred[..., i].ravel(), u["test"][..., i], w["test"][..., i].ravel()
            )
            rho_bag.append(_rho_bag)
            auc_instance.append(_auc_instance)

        res_dict = {
            "rho_bag": np.mean(rho_bag),
            "rho_bag_err": np.std(rho_bag) / np.sqrt(P),
            "auc_instance": np.mean(auc_instance),
            "auc_instance_err": np.std(auc_instance) / np.sqrt(P),
        }
    else:
        rho_bag, auc_instance = _calc_metrics(u_pred, w_pred.ravel(), u["test"], w["test"].ravel())
        res_dict = {"rho_bag": rho_bag, "auc_instance": auc_instance}
    return res_dict

def print_metrics(prefix, metrics):
    """
    Write a metrics dictionary to stdout, one entry per line.

    The output is a header line "<prefix> metrics:", then one "<name>: <value>"
    line per entry with the value formatted to four decimal places, then a
    blank line.

    Parameters:
    - prefix (str): Label placed in the header line to identify this set of metrics.
    - metrics (dict): Mapping from metric name to a numeric value.

    Returns:
    - None: The function only prints.
    """
    header = f"{prefix} metrics:"
    print(header)
    for name in metrics:
        value = metrics[name]
        line = f"{name}: {value:.4f}"
        print(line)
    # Trailing blank line separates consecutive reports
    print()

## Training

Train the model on simulated data using a binomial likelihood.

In [5]:
# Use the first CUDA device when one is available, otherwise stay on the CPU
device = "cpu"
if torch.cuda.is_available():
    device = "cuda:0"

# Simulate data as described in the paper. load_data returns, in order:
#   X: embeddings, F: fixed effects, Y: labels,
#   u: simulated bag predictions, w: simulated instance weights
# P: number of outputs, simulated from the same embeddings X
X, F, Y, u, w = load_data(P=10, seed=0)

# Build the model from the training split first, then place it on the target device
model = MixMIL.init_with_mean_model(
    X["train"],
    F["train"],
    Y["train"],
    likelihood="binomial",
    n_trials=2,
)
model = model.to(device)

# Move every split of the embeddings, fixed effects and labels to the same device
# (u and w are left as returned by load_data)
X = {split: tensor.to(device) for split, tensor in X.items()}
F = {split: tensor.to(device) for split, tensor in F.items()}
Y = {split: tensor.to(device) for split, tensor in Y.items()}

# Report the metrics before training
print_metrics("[START]", calc_metrics(model, X, u, w))
# Fit model in parallel to each output separately
model.train(X["train"], F["train"], Y["train"], n_epochs=2_000)
# Report the metrics again after training
print_metrics("[END]", calc_metrics(model, X, u, w))

GLMM Init: 100%|██████████| 10/10 [00:07<00:00,  1.33it/s]


[START] metrics:
rho_bag: 0.6004
rho_bag_err: 0.0131
auc_instance: 0.5000
auc_instance_err: 0.0000



Epoch: 100%|██████████| 2000/2000 [04:30<00:00,  7.39it/s]

[END] metrics:
rho_bag: 0.8316
rho_bag_err: 0.0111
auc_instance: 0.9362
auc_instance_err: 0.0078




