In [None]:
import numpy as np
import pandas as pd
import torch

import utils
import saliency

In [None]:
def load_model(model_id, name="", subdir="1"):
    """Load a stored experiment run and return its model together with the stored data.

    Args:
        model_id: run id passed to ``utils.load_experiment_data`` as ``mlrun_id``.
        name: unused; kept so existing calls keep working.
        subdir: passed to ``utils.load_experiment_data`` as ``default_id``.

    Returns:
        dict with the keys "feature_names", "target_names", "data", "target", "pred",
        "args" and "model" taken from the experiment data.

    Raises:
        KeyError: if the experiment data lacks one of those keys.
    """
    exp_dict = utils.load_experiment_data(mlrun_id=model_id, default_id=subdir)
    keys = ("feature_names", "target_names", "data", "target", "pred", "args", "model")
    return {key: exp_dict[key] for key in keys}

In [None]:
# Experiment run ids, passed as `mlrun_id` to utils.load_experiment_data (only labor_id is loaded below).
labor_id = "b71e463e5252431b915954bb91d3d2e4"
phase_id = "e075eaaf10fe48f2aa3d738cc4495fd7"
icp_id = "f0d90ecf970f467fa23204e6f4490bd3"

In [None]:
# Load the labor model and put it on the GPU in inference mode.
torch.backends.cudnn.benchmark = True
exp_dict = utils.load_experiment_data(default_id="1", mlrun_id=labor_id)
model = exp_dict["model"]
feature_names = exp_dict["feature_names"]
model.cuda()
model.eval()

# Output/target channel to analyse: index 1 for the "long" phase, 0 for anything else.
phase = "long"
phase_idx = 0 if phase != "long" else 1

# Extra keyword arguments forwarded to saliency.get_sal_list.
sal_kwargs = dict()

In [None]:
# Per-patient saliency maps for output channel `phase_idx`. eval_model later slices this list by batch
# position, so it presumably follows the validation-dataset order (ds=None) — TODO confirm.
# NOTE(review): the positional 1.0 / False arguments are defined by saliency.get_sal_list; confirm meaning.
sal = saliency.get_sal_list(model, phase_idx, 1.0, False, ds=None, ig=False, **sal_kwargs)

In [None]:
# Median of the training data; used below as the replacement value for "removed" features.
median_train = saliency.calc_median(model.train_dataloader().dataset)
median_train.shape

In [None]:
# Scratch check: cumulative sums of a toy saliency vector (cf. get_top_cumsum_idcs below).
sorted(torch.tensor([0.2, 0.6, 0.15, 0.05]).cumsum(dim=0))

In [None]:
import pandas as pd

def get_top_cumsum_idcs(sals, frac, reverse=False):
    """Select saliency entries by their cumulative share of the total importance.

    Values are normalised by the sum of their absolute values and sorted ascending. Then:

    * ``reverse=False``: keep the largest entries, stopping once the running total (from the
      top) reaches ``frac``. For non-negative inputs their share sums to at least ``frac``.
    * ``reverse=True``: keep the smallest entries whose ascending cumulative share is at most ``frac``.
    * ``frac >= 1``: every entry is returned in both modes.

    Args:
        sals: 1-D saliency values (list, numpy array or CPU torch tensor).
        frac: requested share of total importance.
        reverse: select least-important instead of most-important entries.

    Returns:
        (sort_vals, sort_idcs): normalised values of the selected entries in descending order,
        and their positions in ``sals``.
    """
    sal_series = pd.Series(np.asarray(sals, dtype=float)).sort_values(ascending=True)
    sal_series /= sal_series.abs().sum()
    if frac >= 1.0:
        # Everything is requested. Comparing against the float cumsum would drop zero entries
        # (cumsum > 0) or the last entry when round-off pushes the cumsum above 1.0.
        chosen_vals = sal_series
    elif reverse:
        chosen_vals = sal_series[sal_series.cumsum() <= frac]
    else:
        chosen_vals = sal_series[sal_series.cumsum() > 1 - frac]
    chosen_vals = chosen_vals.sort_values(ascending=False)
    return chosen_vals.to_numpy(), chosen_vals.index.to_numpy()

In [None]:
# Sanity check of get_top_cumsum_idcs on a toy saliency vector.
array = [0.2, 0.6, 0.15, 0.05]
print(sorted(array, reverse=True))  # reference: values in descending order
for frac, rev in ((0.61, False), (1.0, True)):
    print(get_top_cumsum_idcs(array, frac, reverse=rev))

In [None]:
from sklearn.metrics import average_precision_score

In [None]:
import numpy as np
import sklearn
from sklearn.metrics import average_precision_score


def perturb_batch(batch_data, batch_sal, lens, median, noise_std, top_n, perturb_frac, reverse, average_over_patient=False):
    """Perturb the most (or least) salient input entries of every patient in a batch, in place.

    For each patient, data and saliency are cut to the patient's length. Entries to perturb
    are chosen from |saliency|:

    * ``top_n`` (takes precedence when truthy): the ``top_n`` largest entries, or the smallest
      when ``reverse``.
    * otherwise ``perturb_frac``: the entries ``get_top_cumsum_idcs`` selects for that share of
      total saliency.

    With ``average_over_patient``, features are ranked by their mean |saliency| over time, and
    each selected feature is perturbed at every time step. Selected entries are first replaced
    by ``median`` (if given), then Gaussian noise with std ``noise_std`` is added (if non-zero).

    Args:
        batch_data: tensor indexed as ``batch_data[patient][time, feature]``; modified in place.
        batch_sal: per-patient saliency of shape (time, feature), aligned with ``batch_data``.
        lens: valid number of time steps per patient.
        median: replacement values, or None to skip replacement. A 1-D value is treated as one
            (1, feature) row. If it has fewer rows than the patient has steps, it is tiled with
            ``.repeat(num_steps, 1)``.
        noise_std: std of the additive noise; 0.0 disables noise.
        top_n: number of entries to perturb (bools are accepted, True == 1).
        perturb_frac: share of total saliency to perturb (used when ``top_n`` is falsy).
        reverse: perturb least-salient instead of most-salient entries.
        average_over_patient: rank whole features instead of single (time, feature) entries.

    Returns:
        (batch_data, removed_feat_numbers): the batch (modified in place). In ``perturb_frac``
        mode, also the number of selected entries per patient (features when averaging);
        otherwise an empty list.
    """
    if not top_n and not perturb_frac:
        return batch_data, []

    # Normalise the replacement values once; a 1-D median becomes a single (1, feature) row.
    median_t = None
    if median is not None:
        median_t = torch.as_tensor(median)
        if median_t.dim() == 1:
            median_t = median_t.unsqueeze(0)

    removed_feat_numbers = []
    for pat_data, pat_sal, pat_len in zip(batch_data, batch_sal, lens):
        # Cut to the valid length. pat_data is a view, so the writes below land in batch_data.
        pat_data = pat_data[:pat_len]
        pat_sal = torch.as_tensor(pat_sal)[:pat_len]
        num_steps, num_feats = pat_sal.shape[0], pat_sal.shape[1]

        if average_over_patient:
            # One ranking per feature: mean |saliency| over time, normalised to sum 1.
            rank_sal = pat_sal.abs().mean(dim=0)
            rank_sal = rank_sal / rank_sal.sum()
        else:
            rank_sal = pat_sal
        flat_sal = rank_sal.flatten().abs()

        if top_n:
            # int(): top_n may arrive as a bool, which torch.topk does not accept as k.
            k = int(min(top_n, flat_sal.numel()))
            noise_idcs = torch.topk(flat_sal, k, largest=not reverse).indices
        else:
            _, noise_idcs = get_top_cumsum_idcs(flat_sal, perturb_frac, reverse=reverse)
            noise_idcs = torch.as_tensor(noise_idcs, dtype=torch.long)
            removed_feat_numbers.append(len(noise_idcs))

        if average_over_patient:
            # Selected idcs are feature idcs: expand to flat (time * num_feats + feat) idcs at every step.
            step_offsets = torch.arange(num_steps).unsqueeze(1) * num_feats
            noise_idcs = (step_offsets + noise_idcs.unsqueeze(0)).flatten()
        time_idcs = noise_idcs // num_feats
        feat_idcs = noise_idcs % num_feats

        if median_t is not None:
            # Local per-patient copy, so the tiling does not leak into later patients.
            pat_median = median_t.repeat(num_steps, 1) if median_t.shape[0] < num_steps else median_t
            pat_data[time_idcs, feat_idcs] = pat_median[time_idcs, feat_idcs].to(pat_data.dtype)
        if noise_std != 0.0:
            noise_tensor = torch.zeros(len(time_idcs)).normal_(std=noise_std)
            pat_data[time_idcs, feat_idcs] += noise_tensor
    return batch_data, removed_feat_numbers


def eval_model(model, sal, ds=None, dl=None, batch_size=512, median=None, noise_std=0.0, top_n=0, perturb_frac=0.0, reverse=False, phase_idx=None):
    """Average precision of ``model`` after perturbing its inputs according to ``sal``.

    Each batch is perturbed with ``perturb_batch`` and passed through the model under
    ``torch.no_grad()``. Outputs and targets of channel ``phase_idx`` are cut to each patient's
    length, steps with NaN targets are dropped, and the sigmoid of the outputs is scored
    against the targets.

    Args:
        model: model to evaluate. Must expose ``val_dataloader()`` when neither ``ds`` nor ``dl``
            is given; inputs are moved to the device of its first parameter.
        sal: per-patient saliency maps, aligned by position with the dataset. The loader must
            therefore iterate in dataset order (no shuffling).
        ds: dataset to wrap in a DataLoader when ``dl`` is None (default: validation dataset).
        dl: ready-made DataLoader; takes precedence over ``ds``/``batch_size``.
        batch_size: batch size for the DataLoader built from ``ds``.
        median, noise_std, top_n, perturb_frac, reverse: forwarded to ``perturb_batch``.
        phase_idx: output/target channel to score. Defaults to the notebook-level ``phase_idx``.

    Returns:
        (score, val_preds): average precision over all non-NaN target steps, and the sigmoid of
        the model outputs at those steps as a numpy array.

    Raises:
        ValueError: if ``sal`` has no saliency map for some patient in the loader.
    """
    if phase_idx is None:
        # Backward-compatible default: the notebook-level ``phase_idx``.
        phase_idx = globals()["phase_idx"]
    if dl is None:
        if ds is None:
            ds = model.val_dataloader().dataset
        dl = torch.utils.data.DataLoader(ds, batch_size=batch_size)
    device = next(model.parameters()).device

    loader_batch_size = dl.batch_size
    val_preds, val_targets, total_removed_feats = [], [], []
    for batch_idx, (pat_data, pat_target, _, lens) in enumerate(dl):
        # Saliency is matched to patients by position in the dataset.
        start = batch_idx * loader_batch_size
        batch_sal = sal[start:start + loader_batch_size]
        if len(batch_sal) != len(lens):
            raise ValueError(f"sal has {len(batch_sal)} saliency maps for batch {batch_idx} with {len(lens)} patients")
        pat_data, removed_feat_numbers = perturb_batch(pat_data, batch_sal, lens, median, noise_std, top_n, perturb_frac, reverse)
        total_removed_feats.extend(removed_feat_numbers)
        with torch.no_grad():
            pat_pred = model(pat_data.to(device, non_blocking=True))
        # select phase channel
        pat_target = pat_target[:, :, phase_idx].cpu()
        pat_pred = pat_pred[:, :, phase_idx].cpu()
        # keep only the valid steps of each patient
        for pat_idx, len_ in enumerate(lens):
            val_preds.append(pat_pred[pat_idx, :len_].flatten())
            val_targets.append(pat_target[pat_idx, :len_].flatten())
    if total_removed_feats:
        print("Removed on average: ", np.array(total_removed_feats).mean())

    preds = torch.cat(val_preds)
    targets = torch.cat(val_targets)
    keep = ~torch.isnan(targets)
    probs = torch.sigmoid(preds[keep]).numpy()
    score = average_precision_score(targets[keep].numpy(), probs)
    return score, probs

In [None]:
from tqdm import tqdm


def calc_progression(model, sal, steps, use_top_n, eval_args, reverse=False, ds=None, dl=None, batch_size=128):
    """Perturbation curve: how the model's outputs and score change as more inputs are perturbed.

    An unperturbed baseline is evaluated first. Then ``eval_model`` runs once per perturbation
    level: ``top_n`` = 1..steps when ``use_top_n``, otherwise ``perturb_frac`` over
    ``steps`` equal steps up to 1.0.

    Args:
        model, sal: forwarded to ``eval_model``.
        steps: number of perturbation levels.
        use_top_n: perturb a fixed number of entries instead of a saliency fraction.
        eval_args: extra ``eval_model`` kwargs (e.g. "median", "noise_std"); not modified.
            Any "top_n"/"perturb_frac" entries are overridden per step.
        reverse: perturb least-salient entries first (LeRF) instead of most-salient (MoRF).
        ds, dl, batch_size: data source; a DataLoader is built only when ``dl`` is None.

    Returns:
        (perturb_params, score_dict). ``perturb_params`` is ``[0.0, *levels]``. Every array in
        ``score_dict`` is aligned with it and holds differences to the baseline:
        "logit"/"abs_logit" are the mean (absolute) change of the sigmoid outputs returned by
        ``eval_model``; "ap_score"/"abs_ap_score" are the (absolute) change of average precision.
    """
    if dl is None:
        if ds is None:
            ds = model.val_dataloader().dataset
        dl = torch.utils.data.DataLoader(ds, batch_size=batch_size)
    base_score, base_preds = eval_model(model, sal, dl=dl, median=None, noise_std=0.0, top_n=0, perturb_frac=0.0)

    if use_top_n:
        perturb_range = list(range(1, steps + 1))
    else:
        perturb_range = list(np.linspace(1 / steps, 1, steps))

    score_dict = {"logit": [0.0],
                  "abs_logit": [0.0],
                  "ap_score": [0.0],
                  "abs_ap_score": [0.0]}
    for val in tqdm(perturb_range):
        # Fresh kwargs per step: never mutate the caller's dict. Reset the other selection mode,
        # because a truthy top_n would silently override perturb_frac.
        if use_top_n:
            step_args = {**eval_args, "top_n": val, "perturb_frac": 0.0}
        else:
            step_args = {**eval_args, "top_n": 0, "perturb_frac": val}
        ap_score, val_preds = eval_model(model, sal, dl=dl, reverse=reverse, **step_args)
        pred_diff = val_preds - base_preds
        score_dict["logit"].append(pred_diff.mean())
        score_dict["abs_logit"].append(np.abs(pred_diff).mean())
        score_dict["ap_score"].append(ap_score - base_score)
        score_dict["abs_ap_score"].append(abs(ap_score - base_score))
    return np.array([0.0] + perturb_range), {key: np.array(vals) for key, vals in score_dict.items()}


def calc_MoRF(model, sal, steps, use_top_n, eval_args):
    """Most-Relevant-First curve: perturb entries in order of decreasing saliency.

    Delegates to ``calc_progression`` with ``reverse=False`` and returns its ``(params, score_dict)``.
    """
    return calc_progression(model=model, sal=sal, steps=steps, use_top_n=use_top_n,
                            eval_args=eval_args, reverse=False)


def calc_LeRF(model, sal, steps, use_top_n, eval_args):
    """Least-Relevant-First curve: perturb entries in order of increasing saliency.

    Delegates to ``calc_progression`` with ``reverse=True`` and returns its ``(params, score_dict)``.
    """
    return calc_progression(model=model, sal=sal, steps=steps, use_top_n=use_top_n,
                            eval_args=eval_args, reverse=True)


def calc_ABPC(model, sal, steps, use_top_n, eval_args):
    """Area Between Perturbation Curves: the LeRF curve minus the MoRF curve, per score key.

    Returns:
        (params, morf_dict, lerf_dict, abpc_dict). ``params`` are the perturbation levels
        (taken from the LeRF run), and ``abpc_dict[key] = lerf_dict[key] - morf_dict[key]``.
    """
    _, morf_dict = calc_MoRF(model, sal, steps, use_top_n, eval_args)
    params, lerf_dict = calc_LeRF(model, sal, steps, use_top_n, eval_args)

    abpc_dict = {}
    for key, morf_curve in morf_dict.items():
        abpc_dict[key] = lerf_dict[key] - morf_curve
    return params, morf_dict, lerf_dict, abpc_dict


In [None]:
# ABPC: replace selected entries by the training median, selecting the top-n entries (n = 1..steps).
eval_args = dict(median=median_train, noise_std=0.0)
steps = 5
use_top_n = True

params, morf_dict, lerf_dict, abpc_dict = calc_ABPC(model, sal, steps, use_top_n, eval_args)
for key, curve in abpc_dict.items():
    # curve per level, then its mean over all points (including the unperturbed 0 point)
    print(key, np.round(curve, 3), round(np.mean(curve), 3))

In [None]:
# ABPC: add Gaussian noise (std 1.0) to the top-n entries (n = 1..steps), no median replacement.
eval_args = dict(median=None, noise_std=1.0)
steps = 5
use_top_n = True

params, morf_dict, lerf_dict, abpc_dict = calc_ABPC(model, sal, steps, use_top_n, eval_args)
for key, curve in abpc_dict.items():
    print(key, np.round(curve, 3), round(np.mean(curve), 3))

In [None]:
# Scratch check: shape of a tensor filled in place with .normal_, as perturb_batch does for its noise.
torch.zeros(1, 1).normal_(std=1.0).shape

In [None]:
# ABPC: add Gaussian noise (std 3.0) to the top-n entries (n = 1..steps), no median replacement.
eval_args = dict(median=None, noise_std=3.0)
steps = 5
use_top_n = True

params, morf_dict, lerf_dict, abpc_dict = calc_ABPC(model, sal, steps, use_top_n, eval_args)
for key, curve in abpc_dict.items():
    print(key, np.round(curve, 3), round(np.mean(curve), 3))

In [None]:
# ABPC: replace by the training median, selecting entries by share of total saliency (1/steps .. 1.0).
eval_args = dict(median=median_train, noise_std=0.0)
steps = 5
use_top_n = False

params, morf_dict, lerf_dict, abpc_dict = calc_ABPC(model, sal, steps, use_top_n, eval_args)
for key, curve in abpc_dict.items():
    print(key, np.round(curve, 3), round(np.mean(curve), 3))

In [None]:
# ABPC: add Gaussian noise (std 3.0), selecting entries by share of total saliency (1/steps .. 1.0).
eval_args = dict(median=None, noise_std=3.0)
steps = 5
use_top_n = False

params, morf_dict, lerf_dict, abpc_dict = calc_ABPC(model, sal, steps, use_top_n, eval_args)
for key, curve in abpc_dict.items():
    print(key, np.round(curve, 3), round(np.mean(curve), 3))

In [None]:
# MoRF and LeRF curves separately: median replacement, top-n selection (n = 1..steps).
eval_args = dict(median=median_train, noise_std=0.0)
steps = 5
use_top_n = True

for curve_name, curve_fn in (("MoRF", calc_MoRF), ("LeRF", calc_LeRF)):
    print(curve_name)
    params, score_dict = curve_fn(model, sal, steps, use_top_n, eval_args)
    for key, scores in score_dict.items():
        print(key, np.round(scores, 3))

In [None]:
# MoRF and LeRF curves separately: median replacement, selection by share of total saliency.
eval_args = dict(median=median_train, noise_std=0.0)
steps = 5
use_top_n = False

for curve_name, curve_fn in (("MoRF", calc_MoRF), ("LeRF", calc_LeRF)):
    print(curve_name)
    params, score_dict = curve_fn(model, sal, steps, use_top_n, eval_args)
    for key, scores in score_dict.items():
        print(key, np.round(scores, 3))

In [None]:
# Evaluation: perturb the selected features by adding Gaussian noise

In [None]:
noise_std = 1.0
top_n = 0
# Shared settings: Gaussian noise on the selected entries, no median replacement, fraction-based selection.
noise_kwargs = dict(noise_std=noise_std, batch_size=128, median=None, top_n=top_n)
score_baseline, preds_baseline = eval_model(model, sal, perturb_frac=0.0, **noise_kwargs)  # unperturbed
eval_score, eval_preds = eval_model(model, sal, perturb_frac=0.02, **noise_kwargs)
fully_random, fully_random_preds = eval_model(model, sal, perturb_frac=1.0, **noise_kwargs)  # all entries
print(score_baseline)
print(eval_score)
print(fully_random)

In [None]:
import matplotlib.pyplot as plt

# Distribution of predicted probabilities: unperturbed vs. partially vs. fully perturbed inputs.
fig, ax = plt.subplots()
ax.hist(preds_baseline, bins=1000, label="baseline")
ax.hist(eval_preds, bins=1000, color="red", alpha=0.8, label="partially perturbed")
ax.hist(fully_random_preds, bins=1000, color="orange", alpha=0.5, label="fully perturbed")
ax.set(title="Gaussian-noise perturbation: predicted probabilities", xlabel="predicted probability", ylabel="count")
ax.legend()
print(preds_baseline.mean())
print(eval_preds.mean())
print(fully_random_preds.mean())
plt.show()

# Per-step change of the predicted probability relative to the baseline.
# Both histograms are translucent so neither hides the other.
fig, ax = plt.subplots()
ax.hist(eval_preds - preds_baseline, bins=1000, color="red", alpha=0.8, label="partially perturbed")
ax.hist(fully_random_preds - preds_baseline, bins=1000, color="orange", alpha=0.5, label="fully perturbed")
ax.set(title="Gaussian-noise perturbation: change vs. baseline", xlabel="change in predicted probability", ylabel="count")
ax.legend()
print(np.abs(eval_preds - preds_baseline).mean())
print(np.abs(fully_random_preds - preds_baseline).mean())
plt.show()

In [None]:
# Evaluation: replace the selected entries by the training median (fraction-based selection).
median_kwargs = dict(batch_size=128, median=median_train, top_n=0)
score_baseline, preds_baseline = eval_model(model, sal, perturb_frac=0.0, **median_kwargs)  # unperturbed
eval_score, eval_preds = eval_model(model, sal, perturb_frac=0.5, **median_kwargs)
fully_random, fully_random_preds = eval_model(model, sal, perturb_frac=1.0, **median_kwargs)  # all entries
print(score_baseline)
print(eval_score)
print(fully_random)

In [None]:
import matplotlib.pyplot as plt

# Distribution of predicted probabilities: unperturbed vs. partially vs. fully perturbed inputs.
fig, ax = plt.subplots()
ax.hist(preds_baseline, bins=1000, label="baseline")
ax.hist(eval_preds, bins=1000, color="red", alpha=0.8, label="partially perturbed")
ax.hist(fully_random_preds, bins=1000, color="orange", alpha=0.5, label="fully perturbed")
ax.set(title="Median perturbation: predicted probabilities", xlabel="predicted probability", ylabel="count")
ax.legend()
print(preds_baseline.mean())
print(eval_preds.mean())
print(fully_random_preds.mean())
plt.show()

# Per-step change of the predicted probability relative to the baseline.
# Both histograms are translucent so neither hides the other.
fig, ax = plt.subplots()
ax.hist(eval_preds - preds_baseline, bins=1000, color="red", alpha=0.8, label="partially perturbed")
ax.hist(fully_random_preds - preds_baseline, bins=1000, color="orange", alpha=0.5, label="fully perturbed")
ax.set(title="Median perturbation: change vs. baseline", xlabel="change in predicted probability", ylabel="count")
ax.legend()
print(np.abs(eval_preds - preds_baseline).mean())
print(np.abs(fully_random_preds - preds_baseline).mean())
plt.show()

In [None]:
# Same saliency ranking for all patients. Per patient: mean |saliency| over time, normalised to
# sum 1. Then average over patients and renormalise.
per_pat_sal = torch.stack([torch.as_tensor(pat_sal).abs().mean(dim=0) for pat_sal in sal])
per_pat_sal = per_pat_sal / per_pat_sal.sum(dim=1, keepdim=True)
uniform_sal = per_pat_sal.mean(dim=0)
uniform_sal = (uniform_sal / uniform_sal.sum()).numpy()
# Bring into the per-patient (time, feature) layout that eval_model expects.
uniform_sal_list = [np.tile(uniform_sal, (len(pat_sal), 1)) for pat_sal in sal]

In [None]:
# Evaluation with the patient-independent saliency: replace selected entries by the training median.
# top_n must stay 0: a truthy top_n takes precedence over perturb_frac in perturb_batch,
# so the fractions below would otherwise be ignored.
uniform_median_kwargs = dict(batch_size=128, median=median_train, top_n=0)
score_baseline, preds_baseline = eval_model(model, uniform_sal_list, perturb_frac=0.0, **uniform_median_kwargs)
eval_score, eval_preds = eval_model(model, uniform_sal_list, perturb_frac=0.02, **uniform_median_kwargs)
fully_random, fully_random_preds = eval_model(model, uniform_sal_list, perturb_frac=1.0, **uniform_median_kwargs)
print(score_baseline)
print(eval_score)
print(fully_random)

In [None]:
import matplotlib.pyplot as plt

# Distribution of predicted probabilities: unperturbed vs. partially vs. fully perturbed inputs.
fig, ax = plt.subplots()
ax.hist(preds_baseline, bins=1000, label="baseline")
ax.hist(eval_preds, bins=1000, color="red", alpha=0.8, label="partially perturbed")
ax.hist(fully_random_preds, bins=1000, color="orange", alpha=0.5, label="fully perturbed")
ax.set(title="Uniform saliency, median perturbation: predicted probabilities", xlabel="predicted probability", ylabel="count")
ax.legend()
print(preds_baseline.mean())
print(eval_preds.mean())
print(fully_random_preds.mean())
plt.show()

# Per-step change of the predicted probability relative to the baseline.
# Both histograms are translucent so neither hides the other.
fig, ax = plt.subplots()
ax.hist(eval_preds - preds_baseline, bins=1000, color="red", alpha=0.8, label="partially perturbed")
ax.hist(fully_random_preds - preds_baseline, bins=1000, color="orange", alpha=0.5, label="fully perturbed")
ax.set(title="Uniform saliency, median perturbation: change vs. baseline", xlabel="change in predicted probability", ylabel="count")
ax.legend()
print(np.abs(eval_preds - preds_baseline).mean())
print(np.abs(fully_random_preds - preds_baseline).mean())
plt.show()

In [None]:
# Evaluation with the patient-independent saliency: add Gaussian noise to the selected entries.
# top_n must stay 0: a truthy top_n takes precedence over perturb_frac in perturb_batch,
# so the fractions below would otherwise be ignored.
noise_std = 1.0
uniform_noise_kwargs = dict(noise_std=noise_std, batch_size=128, median=None, top_n=0)
score_baseline, preds_baseline = eval_model(model, uniform_sal_list, perturb_frac=0.0, **uniform_noise_kwargs)
eval_score, eval_preds = eval_model(model, uniform_sal_list, perturb_frac=0.02, **uniform_noise_kwargs)
fully_random, fully_random_preds = eval_model(model, uniform_sal_list, perturb_frac=1.0, **uniform_noise_kwargs)
print(score_baseline)
print(eval_score)
print(fully_random)

In [None]:
import matplotlib.pyplot as plt

# Distribution of predicted probabilities: unperturbed vs. partially vs. fully perturbed inputs.
fig, ax = plt.subplots()
ax.hist(preds_baseline, bins=1000, label="baseline")
ax.hist(eval_preds, bins=1000, color="red", alpha=0.8, label="partially perturbed")
ax.hist(fully_random_preds, bins=1000, color="orange", alpha=0.5, label="fully perturbed")
ax.set(title="Uniform saliency, Gaussian-noise perturbation: predicted probabilities", xlabel="predicted probability", ylabel="count")
ax.legend()
print(preds_baseline.mean())
print(eval_preds.mean())
print(fully_random_preds.mean())
plt.show()

# Per-step change of the predicted probability relative to the baseline.
# Both histograms are translucent so neither hides the other.
fig, ax = plt.subplots()
ax.hist(eval_preds - preds_baseline, bins=1000, color="red", alpha=0.8, label="partially perturbed")
ax.hist(fully_random_preds - preds_baseline, bins=1000, color="orange", alpha=0.5, label="fully perturbed")
ax.set(title="Uniform saliency, Gaussian-noise perturbation: change vs. baseline", xlabel="change in predicted probability", ylabel="count")
ax.legend()
print(np.abs(eval_preds - preds_baseline).mean())
print(np.abs(fully_random_preds - preds_baseline).mean())
plt.show()