In [None]:
import torch
import numpy as np
from tqdm.autonotebook import tqdm
from utils import train_on_subset, record_output, get_loader
# Fix random seeds for reproducibility. Both libraries are seeded because the
# notebook draws from NumPy (np.random.choice for the subsets) and from
# PyTorch (torch.randn for the dummy scores). torch.manual_seed comes first so
# the cell's last expression returns None and no Generator repr is displayed.
torch.manual_seed(42)
np.random.seed(42)

In [None]:
# 1. draw 10 random subsets of the training set; every subset holds
#    25,000 of the 50,000 sample indices, sampled without replacement
train_set_subsets = [
    np.random.choice(range(50_000), 25_000, replace=False)
    for _ in range(10)
]

In [None]:
# 2. fit one model per training subset and keep what each model
#    outputs on a single, fixed target example
val_loader = get_loader(split="val")
target_example = val_loader.dataset[0]  # first example of the validation split

outputs_per_subset = []
for subset in tqdm(train_set_subsets):
    model = train_on_subset(subset)
    outputs_per_subset.append(record_output(model, target_example))

In [None]:
# 3. turn attribution scores into one predicted model output per subset;
#    this relies on linearity: the prediction for a subset is simply the
#    sum of the attribution scores of the samples it contains
dummy_attribution_scores = torch.randn(50_000)  # placeholder scores for the tutorial

predictions_per_subset = [
    dummy_attribution_scores[subset].sum() for subset in train_set_subsets
]

In [None]:
# 4. score the attribution method: Spearman rank-correlation between the
#    true model outputs and the attribution-based predictions,
#    computed across the subsets
from scipy.stats import spearmanr

LDS = spearmanr(outputs_per_subset, predictions_per_subset)

print(f'LDS: {LDS.correlation:.3f} (p value {LDS.pvalue:.6f})')

In practice, we would compute this rank correlation separately for each of many target examples,
and then average the resulting correlations. The code below implements this many-target version of the LDS efficiently.

In [None]:
from pathlib import Path
import wget
from tqdm import tqdm

def _download_if_missing(url, path):
    """Download `url` to `path` unless a file already exists there; return `path`."""
    if not path.exists():
        wget.download(url, out=str(path), bar=None)
    return path


def eval_correlations(scores, tmp_path, num_val=2_000):
    """Evaluate the many-target LDS of `scores` and return the mean correlation.

    Fetches `mask.npy` and `val_margins.npy` into `tmp_path` (only when they
    are not already there), predicts one output per (mask, target) pair as
    `masks @ scores`, and computes the Spearman rank correlation between
    predictions and margins for each of the first `num_val` targets.

    Args:
        scores: tensor that can be right-multiplied by the float mask matrix;
            the result must be 2-D with at least `num_val` columns, i.e.
            `scores` is expected to be (num train samples, num targets).
        tmp_path: directory (str or Path) used to store the downloaded files;
            it is created if it does not exist.
        num_val: number of target columns to evaluate (default 2,000, the
            value that was previously hard-coded).

    Returns:
        The mean Spearman correlation over the evaluated targets. A one-line
        summary with the mean correlation and mean p-value is also printed.
    """
    # Imported locally so the function does not depend on an earlier
    # notebook cell having brought `spearmanr` into scope.
    from scipy.stats import spearmanr

    masks_url = 'https://www.dropbox.com/s/x76uyen8ffkjfke/mask.npy?dl=1'
    margins_url = 'https://www.dropbox.com/s/q1dxoxw78ct7c27/val_margins.npy?dl=1'

    tmp_dir = Path(tmp_path)
    tmp_dir.mkdir(parents=True, exist_ok=True)

    # Reuse files from a previous call instead of downloading them again.
    masks_path = _download_if_missing(masks_url, tmp_dir.joinpath('mask.npy'))
    # num masks, num train samples
    masks = torch.as_tensor(np.load(masks_path, mmap_mode='r')).float()

    margins_path = _download_if_missing(margins_url, tmp_dir.joinpath('val_margins.npy'))
    # num masks, num val samples (rows are paired with the rows of `preds` below)
    margins = torch.as_tensor(np.load(margins_path, mmap_mode='r'))

    # Linearity: the predicted output for a mask is the sum of the scores of
    # the training samples selected by that mask.
    preds = masks @ scores
    rs = []
    ps = []
    # Iterating a sized range lets tqdm display the total and an ETA.
    for j in tqdm(range(num_val)):
        r, p = spearmanr(preds[:, j], margins[:, j])
        rs.append(r)
        ps.append(p)
    rs, ps = np.array(rs), np.array(ps)
    print(f'Correlation: {rs.mean():.3f} (avg p value {ps.mean():.6f})')
    return rs.mean()