# Simulation Experiment

In [48]:
import numpy as np
import scipy.stats as st
import torch
from sklearn.metrics import roc_auc_score

from mixmil import MixMIL
from mixmil.data import load_data
import pandas as pd

## Utility Functions

In [50]:
def _calc_metrics(u_pred, w_pred, u, w, top_quantile=0.90):
    """Compute bag- and instance-level metrics for a single output.

    Parameters
    ----------
    u_pred : array-like
        Predicted bag-level values.
    w_pred : array-like
        Predicted instance scores; must have one entry per element of the
        flattened ``w`` (callers pass it already raveled).
    u : array-like
        Simulated bag-level values, aligned with ``u_pred``.
    w : array-like
        Simulated instance weights. May be a NumPy array or a CPU torch
        tensor; it is flattened internally.
    top_quantile : float, optional
        Instances whose true weight is strictly above this quantile of ``w``
        are the positives for the retrieval AUC. Defaults to 0.90.

    Returns
    -------
    tuple of (float, float)
        ``(rho_bag, auc_instance)``: the Spearman rank correlation between
        ``u_pred`` and ``u``, and the ROC AUC of ``w_pred`` for recovering the
        top-weight instances.
    """
    rho_bag = st.spearmanr(u_pred, u).correlation  # bag level correlation
    # np.asarray accepts both NumPy arrays and CPU torch tensors; the former
    # `.long()` call only existed on torch tensors.
    w_true = np.asarray(w)
    is_top_instance = (w_true > np.quantile(w_true, top_quantile)).astype(np.int64).ravel()
    auc_instance = roc_auc_score(is_top_instance, w_pred)  # instance-retrieval AUC
    return rho_bag, auc_instance


def calc_metrics(model, X, u, w):
    """Score a fitted model on the test split against the simulated ground truth.

    Parameters
    ----------
    model
        Fitted model exposing ``predict(X)`` and ``get_weights(X)``. The
        result of ``predict`` and the first element of ``get_weights`` are
        converted with ``.cpu().numpy()``, so both must be torch tensors.
    X : dict
        Embeddings keyed by split; only ``X["test"]`` is used.
    u : dict
        Simulated bag-level values keyed by split; only ``u["test"]`` is used.
    w : dict
        Simulated instance weights keyed by split; only ``w["test"]`` is used.

    Returns
    -------
    dict
        If the bag predictions are 2-D (outputs along axis 1 / the last axis
        of the targets): the mean ``rho_bag`` and ``auc_instance`` across
        outputs plus their standard errors ``rho_bag_err`` and
        ``auc_instance_err``. If the bag predictions are 1-D (single output):
        only ``rho_bag`` and ``auc_instance``.

    Raises
    ------
    ValueError
        If the 2-D predictions have zero output columns.
    """
    u_pred = model.predict(X["test"]).cpu().numpy()
    w_pred = model.get_weights(X["test"])[0].cpu().numpy()
    u_true, w_true = u["test"], w["test"]

    if u_pred.ndim == 1:
        # A 1-D prediction has no output axis: score it directly. Checking
        # `shape[1] > 0` cannot reach this case (it raises IndexError on 1-D
        # input and is always true for 2-D input).
        rho_bag, auc_instance = _calc_metrics(u_pred, w_pred.ravel(), u_true, w_true.ravel())
        return {"rho_bag": rho_bag, "auc_instance": auc_instance}

    n_outputs = u_pred.shape[1]
    if n_outputs == 0:
        raise ValueError("calc_metrics: predictions have no output columns")

    # Score each output independently; assumes the output index is the last
    # axis of w_pred, u_true and w_true as well (TODO confirm for new models).
    per_output = [
        _calc_metrics(
            u_pred[..., i],
            w_pred[..., i].ravel(),
            u_true[..., i],
            w_true[..., i].ravel(),
        )
        for i in range(n_outputs)
    ]
    rho_bag = [rho for rho, _ in per_output]
    auc_instance = [auc for _, auc in per_output]

    # Standard error of the mean across outputs (np.std uses ddof=0 by default).
    sem_scale = np.sqrt(n_outputs)
    return {
        "rho_bag": np.mean(rho_bag),
        "rho_bag_err": np.std(rho_bag) / sem_scale,
        "auc_instance": np.mean(auc_instance),
        "auc_instance_err": np.std(auc_instance) / sem_scale,
    }

In [52]:
def print_metrics(prefix, metrics):
    """Print a ``"<prefix> metrics:"`` header followed by one line per metric.

    Each value is rendered with four decimal places, and the listing is
    terminated by an empty line.

    Parameters
    ----------
    prefix : str
        Label placed before ``metrics:`` in the header (e.g. ``"[START]"``).
    metrics : dict
        Mapping from metric name to a numeric value.
    """
    print("{} metrics:".format(prefix))
    render = "{}: {:.4f}".format
    for name, value in metrics.items():
        print(render(name, value))
    print()

## Training

In [56]:
# Use the first GPU when CUDA is available; otherwise run on the CPU.
if torch.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"

# Simulate data as described in the paper. load_data returns dicts keyed by
# split: embeddings (X), fixed effects (F), labels (Y), simulated bag
# predictions (u) and simulated instance weights (w). P is the number of
# outputs, all simulated from the same embeddings X.
X, F, Y, u, w = load_data(P=10, seed=0)

# Initialise the model from the (still CPU-resident) training split, then
# move it to the chosen device.
model = MixMIL.init_with_mean_model(
    X["train"],
    F["train"],
    Y["train"],
    likelihood="binomial",
    n_trials=2,
).to(device)

# Move the model inputs to the same device. u and w are left as loaded; they
# are only consumed by the evaluation helpers.
X, F, Y = (
    {split: tensor.to(device) for split, tensor in data.items()}
    for data in (X, F, Y)
)

print_metrics("[START]", calc_metrics(model, X, u, w))
# Fit the model to every output separately, in parallel.
model.train(X["train"], F["train"], Y["train"], n_epochs=2_000)
print_metrics("[END]", calc_metrics(model, X, u, w))

GLMM Init:   0%|          | 0/10 [00:00<?, ?it/s]

[START] metrics:
rho_bag: 0.6004
rho_bag_err: 0.0131
auc_instance: 0.5000
auc_instance_err: 0.0000



Epoch:   0%|          | 0/2000 [00:00<?, ?it/s]

[END] metrics:
rho_bag: 0.8318
rho_bag_err: 0.0112
auc_instance: 0.9361
auc_instance_err: 0.0078

