In [1]:
# Automatically re-import edited modules before executing each cell.
%load_ext autoreload
%autoreload 2

import os
# Expose only physical GPU 1 to this process. This is set before torch is imported
# (next cell) because CUDA reads the variable when it initializes.
os.environ["CUDA_VISIBLE_DEVICES"] = "1"

In [2]:
import warnings
import captum
import matplotlib.pyplot as plt
import math
import networkx as nx
import numpy as np
import os.path
import pandas as pd
import seaborn as sns
import time
import torch
import torch.nn as nn
import torch.nn.functional as F

from captum.metrics import infidelity, infidelity_perturb_func_decorator
from collections import namedtuple
from copy import deepcopy
from sklearn.metrics import mean_squared_error
from torch.utils.data import Dataset, DataLoader
from tqdm.auto import tqdm

from txai.datasets import SyntheticCancellationDataset
from txai.experiments import (
    GenericExperimenter,
    TabularInfidelityMetric,
)

# Run on the (single visible) GPU when available, otherwise fall back to the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

## Experiment Utilities

In [3]:
CONTINUOUS_DIM = 8
TOTAL_DIM = 2 * CONTINUOUS_DIM
# Five data seeds (42..46), each paired with a disjoint block of five network seeds
# starting at 101, 106, 111, 116 and 121.
DATA_SEEDS = list(range(42, 47))
NETWORK_SEEDS = list(range(101, 126, 5))
NUM_EPOCHS = 2000
Experiment = namedtuple(
    'Experiment',
    ['name', 'continuous_dim', 'cancellation_likelihood', 'nn_dims', 'activation_fun'],
)

# Hand-constructed ReLU networks: (continuous_dim, cancellation_likelihood) per experiment.
# nn_dims is filled in by the experiment loop from continuous_dim.
EXPERIMENTS_MANUAL = [
    Experiment(
        name=f"relu-{dim}-dim",
        continuous_dim=dim,
        cancellation_likelihood=likelihood,
        nn_dims=None,
        activation_fun=nn.ReLU,
    )
    for dim, likelihood in ((2, 0.3), (3, 0.25), (4, 0.2), (5, 0.15), (100, 0.05), (1000, 0.01))
]

# Trained networks: every activation crossed with every (dim, likelihood) pair.
# Layout is input width 2*dim, two hidden layers of width 8*dim, and one regression output.
EXPERIMENTS = [
    Experiment(
        name=f"{act_name}-{dim}-dim",
        continuous_dim=dim,
        cancellation_likelihood=likelihood,
        nn_dims=[2 * dim, 8 * dim, 8 * dim, 1],
        activation_fun=act_cls,
    )
    for act_name, act_cls in (("relu", nn.ReLU), ("gelu", nn.GELU))
    for dim, likelihood in ((2, 0.3), (3, 0.25), (4, 0.2), (5, 0.15))
]

In [4]:
def construct_nn(nn_dims, activation_fun):
    """Build a fully connected regression network and move it to ``DEVICE``.

    Args:
        nn_dims: layer widths ``[input, hidden..., output]``; one ``nn.Linear`` is
            created for every consecutive pair of widths.
        activation_fun: activation *class* (e.g. ``nn.ReLU``). A fresh instance is
            placed after every linear layer except the last one.

    Returns:
        The ``nn.Sequential`` model, already on ``DEVICE``.
    """
    modules = []
    for fan_in, fan_out in zip(nn_dims[:-1], nn_dims[1:]):
        modules.extend((nn.Linear(fan_in, fan_out), activation_fun()))
    # Regression output: keep the final layer linear by dropping its trailing activation.
    del modules[-1:]
    return nn.Sequential(*modules).to(DEVICE)

def train_nn(model, train_dl, num_epochs=None):
    """Train ``model`` on ``train_dl`` with MSE loss and AdamW (PyTorch default hyper-parameters).

    Args:
        model: regression network whose output has a trailing dimension of size 1
            (it is squeezed before the loss).
        train_dl: iterable of ``(x, y)`` batches; both are moved to ``DEVICE``.
        num_epochs: number of passes over ``train_dl``; ``None`` (default) uses the
            module-level ``NUM_EPOCHS`` at call time.

    Returns:
        list[float]: the summed per-batch training loss of every epoch. Existing callers
        that ignore the return value are unaffected.
    """
    if num_epochs is None:
        num_epochs = NUM_EPOCHS

    loss_fun = nn.MSELoss()
    opt = torch.optim.AdamW(model.parameters())
    model.train()

    losses = []
    for _ in tqdm(range(num_epochs), leave=False):
        total_loss = 0.0
        # Iterate the loader lazily instead of materializing every batch each epoch.
        for x, y in train_dl:
            x, y = x.to(DEVICE), y.to(DEVICE)
            opt.zero_grad()
            out = model(x)
            loss = loss_fun(out.squeeze(-1), y.float())
            loss.backward()
            opt.step()
            total_loss += loss.item()
        losses.append(total_loss)
    # Leave no stale gradients on the trained model.
    opt.zero_grad()
    return losses

def eval_nn(model, test_dataset):
    """Compute the RMSE of ``model`` on a dataset.

    Args:
        model: regression network. It is switched to eval mode and left there.
        test_dataset: object exposing ``samples`` (a sequence of equally shaped tensors,
            stacked into one batch) and ``labels`` (anything ``torch.tensor`` accepts,
            one label per sample).

    Returns:
        float: mean over output columns of the per-column RMSE. For a single-output
        model this is just the RMSE.

    Raises:
        ValueError: if predictions and labels do not have matching shapes.
    """
    model.eval()

    # Evaluate on whichever device the model's parameters live on. construct_manual_relu_nn
    # does not move its model to DEVICE, and mixing devices raises inside the forward pass.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")

    # Evaluation needs no autograd graph.
    with torch.no_grad():
        predictions = model(torch.stack(test_dataset.samples).to(device))
    # .cpu() is required before .numpy() when the model runs on a GPU.
    predictions = predictions.cpu().numpy().astype(np.float64)
    labels = torch.tensor(test_dataset.labels).unsqueeze(-1).detach().cpu().numpy().astype(np.float64)

    # Align to (N, outputs) so a 1-D prediction vector cannot silently broadcast to (N, N).
    predictions = predictions.reshape(predictions.shape[0], -1)
    labels = labels.reshape(labels.shape[0], -1)
    if predictions.shape != labels.shape:
        raise ValueError(
            f"predictions shape {predictions.shape} does not match labels shape {labels.shape}"
        )

    # Same quantity as sklearn's mean_squared_error(..., squared=False) with
    # multioutput='uniform_average'; that argument is no longer available in recent sklearn.
    per_output_rmse = np.sqrt(np.mean(np.square(labels - predictions), axis=0))
    return float(np.mean(per_output_rmse))

In [5]:
def explain_cafe_const(c, total_dim):
    """Return a CAFE attribution function with hyper-parameter ``c`` fixed.

    The returned closure takes ``(model, x)``. It explains ``x`` in chunks of 50 rows
    against an all-zero reference of width ``total_dim`` (target output 0), and returns
    ``s_feat_plus - s_feat_minus`` for every row, concatenated in input order.

    NOTE(review): ``CafeUpgradeExplainer`` is not imported in any cell visible here.
    Confirm it is brought into scope before this closure is called.
    """
    def explain_cafe(model, x):
        explainer = CafeUpgradeExplainer(model, c=c)
        chunk_attributions = []
        # Explain 50 rows at a time, presumably to bound memory use (TODO confirm).
        for chunk in torch.split(x, 50):
            (positive, negative), _ = explainer.attribute(chunk, ref=torch.zeros(total_dim), target=0)
            chunk_attributions.append(positive - negative)
        return torch.cat(chunk_attributions)

    return explain_cafe

def explain_cafe_const_smooth(c, total_dim):
    """Return a smooth-CAFE attribution function with hyper-parameter ``c`` fixed.

    The returned closure takes ``(model, x)``. It explains the whole batch at once
    against an all-zero ``(1, total_dim)`` reference (target output 0) and returns
    ``s_feat_plus - s_feat_minus``.

    NOTE(review): ``CafeSmoothExplainer`` is not imported in any cell visible here.
    Confirm it is brought into scope before this closure is called.
    """
    def explain_cafe(model, x):
        explainer = CafeSmoothExplainer(model, c=c)
        reference = torch.zeros(1, total_dim)
        (positive, negative), _ = explainer.attribute(x, ref=reference, target=0)
        return positive - negative

    return explain_cafe

def explain_ixg(model, x):
    """Input x Gradient attributions for ``x`` computed with Captum's ``InputXGradient``."""
    with warnings.catch_warnings():
        # Suppress annoying warnings about setting activation hooks & time consuming operations
        # IMPORTANT: Re-enable warnings when changing code to ensure things still work
        warnings.simplefilter('ignore')
        explainer = captum.attr.InputXGradient(model)
        return explainer.attribute(x)

def explain_lrp(model, x):
    """Layer-wise Relevance Propagation attributions for ``x`` computed with Captum's ``LRP``."""
    with warnings.catch_warnings():
        # Suppress annoying warnings about setting activation hooks & time consuming operations
        # IMPORTANT: Re-enable warnings when changing code to ensure things still work
        warnings.simplefilter('ignore')
        explainer = captum.attr.LRP(model)
        return explainer.attribute(x)

def explain_dl_const(total_dim, multiply_by_inputs=True):
    """Return a DeepLIFT attribution function that uses an all-zero baseline.

    Args:
        total_dim: number of input features (width of the zero baseline).
        multiply_by_inputs: forwarded to ``captum.attr.DeepLift``.

    Returns:
        A closure ``(model, x) -> attributions`` with the same shape as ``x``.
    """
    def explain_dl(model, x):
        # Create the baseline on the inputs' device. A CPU baseline combined with CUDA
        # inputs raises a device-mismatch error.
        baselines = torch.zeros(1, total_dim, device=x.device)
        with warnings.catch_warnings():
            # Suppress annoying warnings about setting activation hooks & time consuming operations
            # IMPORTANT: Re-enable warnings when changing code to ensure things still work
            warnings.simplefilter('ignore')
            dl_captum = captum.attr.DeepLift(model, multiply_by_inputs=multiply_by_inputs)
            # The convergence delta was never used, so it is not requested.
            attribution = dl_captum.attribute(x, baselines=baselines)
        return attribution

    return explain_dl

def explain_ig_const(total_dim, multiply_by_inputs=True):
    """Return an Integrated Gradients attribution function that uses an all-zero baseline.

    Args:
        total_dim: number of input features (width of the zero baseline).
        multiply_by_inputs: forwarded to ``captum.attr.IntegratedGradients``.

    Returns:
        A closure ``(model, x) -> attributions`` with the same shape as ``x``.
    """
    def explain_ig(model, x):
        # Create the baseline on the inputs' device. A CPU baseline combined with CUDA
        # inputs raises a device-mismatch error.
        baselines = torch.zeros(1, total_dim, device=x.device)
        with warnings.catch_warnings():
            # Suppress annoying warnings about setting activation hooks & time consuming operations
            # IMPORTANT: Re-enable warnings when changing code to ensure things still work
            warnings.simplefilter('ignore')
            ig_captum = captum.attr.IntegratedGradients(model, multiply_by_inputs=multiply_by_inputs)
            # The convergence delta was never used, so it is not requested.
            attribution = ig_captum.attribute(x, baselines=baselines)
        return attribution

    return explain_ig

def explain_sg_const(multiply_by_inputs=True, abs_gradients=False):
    """Return a SmoothGrad attribution function (Captum ``NoiseTunnel`` over ``Saliency``).

    Args:
        multiply_by_inputs: if True, the smoothed gradients are multiplied element-wise
            by the inputs (SmoothGrad x Input).
        abs_gradients: forwarded to ``Saliency.attribute`` as ``abs``. Captum defaults to
            ``abs=True``, which discards the gradient sign. This notebook compares against
            signed ground-truth attributions (weights are drawn from (-1, 1)), so the sign
            is kept by default.

    Returns:
        A closure ``(model, x) -> attributions`` with the same shape as ``x``.
    """
    def explain_sg(model, x):
        with warnings.catch_warnings():
            # Suppress annoying warnings about setting activation hooks & time consuming operations
            # IMPORTANT: Re-enable warnings when changing code to ensure things still work
            warnings.simplefilter('ignore')
            nt_captum = captum.attr.NoiseTunnel(captum.attr.Saliency(model))
            # Extra keyword arguments are forwarded by NoiseTunnel to Saliency.attribute.
            attribution = nt_captum.attribute(x, abs=abs_gradients)
            if multiply_by_inputs:
                attribution *= x
        return attribution

    return explain_sg

def explain_gs_const(total_dim, multiply_by_inputs=True):
    """Return a Gradient SHAP attribution function that uses an all-zero baseline.

    Args:
        total_dim: number of input features (width of the zero baseline).
        multiply_by_inputs: forwarded to ``captum.attr.GradientShap``.

    Returns:
        A closure ``(model, x) -> attributions`` with the same shape as ``x``.
    """
    def explain_gs(model, x):
        # Create the baseline on the inputs' device. A CPU baseline combined with CUDA
        # inputs raises a device-mismatch error.
        baselines = torch.zeros(1, total_dim, device=x.device)
        with warnings.catch_warnings():
            # Suppress annoying warnings about setting activation hooks & time consuming operations
            # IMPORTANT: Re-enable warnings when changing code to ensure things still work
            warnings.simplefilter('ignore')
            gs_captum = captum.attr.GradientShap(model, multiply_by_inputs=multiply_by_inputs)
            attribution = gs_captum.attribute(x, baselines=baselines)
        return attribution

    return explain_gs

def explain_ks_const(total_dim, multiply_by_inputs=False):
    """Return a Kernel SHAP attribution function that uses an all-zero baseline.

    Args:
        total_dim: number of input features (width of the zero baseline).
        multiply_by_inputs: if True, attributions are multiplied element-wise by the inputs.

    Returns:
        A closure ``(model, x) -> attributions``.
    """
    def explain_ks(model, x):
        # Create the baseline on the inputs' device. A CPU baseline combined with CUDA
        # inputs raises a device-mismatch error.
        baselines = torch.zeros(1, total_dim, device=x.device)
        with warnings.catch_warnings():
            # Suppress annoying warnings about setting activation hooks & time consuming operations
            # IMPORTANT: Re-enable warnings when changing code to ensure things still work
            warnings.simplefilter('ignore')
            ks_captum = captum.attr.KernelShap(model)
            attribution = ks_captum.attribute(x, baselines=baselines)
            if multiply_by_inputs:
                attribution *= x
        return attribution

    return explain_ks

def explain_svs_const(total_dim, multiply_by_inputs=False):
    """Return a Shapley Value Sampling attribution function that uses an all-zero baseline.

    Args:
        total_dim: number of input features (width of the zero baseline).
        multiply_by_inputs: if True, attributions are multiplied element-wise by the inputs.

    Returns:
        A closure ``(model, x) -> attributions`` for output index 0.
    """
    def explain_svs(model, x):
        # Create the baseline on the inputs' device. A CPU baseline combined with CUDA
        # inputs raises a device-mismatch error.
        baselines = torch.zeros(1, total_dim, device=x.device)
        with warnings.catch_warnings():
            # Suppress annoying warnings about setting activation hooks & time consuming operations
            # IMPORTANT: Re-enable warnings when changing code to ensure things still work
            warnings.simplefilter('ignore')
            svs_captum = captum.attr.ShapleyValueSampling(model)
            attribution = svs_captum.attribute(x, baselines=baselines, target=0)
            if multiply_by_inputs:
                attribution *= x
        return attribution

    return explain_svs

def explain_lime_const(total_dim, multiply_by_inputs=False):
    """Return a LIME attribution function that uses an all-zero baseline.

    Args:
        total_dim: number of input features (width of the zero baseline).
        multiply_by_inputs: if True, attributions are multiplied element-wise by the inputs.

    Returns:
        A closure ``(model, x) -> attributions``. It explains ``x`` in chunks of 50 rows
        and concatenates the results in input order.
    """
    def explain_lime(model, x):
        # One baseline for all chunks, created on the inputs' device. A CPU baseline
        # combined with CUDA inputs raises a device-mismatch error.
        baselines = torch.zeros(1, total_dim, device=x.device)
        with warnings.catch_warnings():
            # Suppress annoying warnings about setting activation hooks & time consuming operations
            # IMPORTANT: Re-enable warnings when changing code to ensure things still work
            warnings.simplefilter('ignore')
            lime_captum = captum.attr.Lime(model)
            all_attrs = []
            for x_part in torch.split(x, 50):
                attribution = lime_captum.attribute(x_part, baselines=baselines)
                if multiply_by_inputs:
                    attribution *= x_part
                all_attrs.append(attribution)
            attributions = torch.cat(all_attrs)
        return attributions

    return explain_lime

## Manually Constructed Model Test

In [6]:
def construct_manual_relu_nn(continuous_dim, weights, cancellation_offset=50.):
    """Build a ReLU network whose fixed weights reproduce the synthetic cancellation target.

    The input is taken to be ``[x_1..x_d, c_1..c_d]``: ``d = continuous_dim`` continuous
    features followed by ``d`` cancellation features. NOTE(review): the ``c_i`` are assumed
    to be 0/1 indicators; confirm against SyntheticCancellationDataset.

    With ``k = cancellation_offset`` the network computes

        sum_i ReLU(ReLU(w_i x_i) - k c_i) - ReLU(ReLU(-w_i x_i) - k c_i)

    When every ``c_i = 0`` this equals ``sum_i w_i x_i``. A term with ``c_i = 1`` is zeroed
    out as long as ``|w_i x_i| <= k``. Beyond that bound it leaks
    ``sign(w_i x_i) * (|w_i x_i| - k)``.

    Args:
        continuous_dim: number of continuous features ``d``.
        weights: 1-D tensor holding the ``d`` per-feature weights ``w_i``.
        cancellation_offset: ``k``; should exceed the largest expected ``|w_i x_i|``.
            Defaults to 50, the value previously hard-coded.

    Returns:
        ``nn.Sequential(Linear, ReLU, Linear, ReLU, Linear)`` with bias-free layers and
        fixed weights. The module is not moved to ``DEVICE``; callers must do that if needed.
    """
    d = continuous_dim

    # Layer 1 (then ReLU): [w_1 x_1, .., w_d x_d, -w_1 x_1, .., -w_d x_d, c_1, .., c_d]
    l1 = nn.Linear(d * 2, d * 3, bias=False)
    w = torch.zeros((d * 3, d * 2))
    w[0:d, 0:d] = torch.diag_embed(weights)
    w[d:d * 2, 0:d] = torch.diag_embed(-weights)
    w[d * 2:d * 3, d:d * 2] = torch.eye(d)
    l1.weight = nn.Parameter(w)

    # Layer 2 (then ReLU): [ReLU(w_i x_i) - k c_i for each i] followed by
    # [ReLU(-w_i x_i) - k c_i for each i]. The large negative offset switches a
    # feature's contribution off when its cancellation feature is set.
    l2 = nn.Linear(d * 3, d * 2, bias=False)
    w = torch.zeros((d * 2, d * 3))
    w[0:d, 0:d] = torch.eye(d)
    w[0:d, d * 2:d * 3] = -cancellation_offset * torch.eye(d)
    w[d:d * 2, d:d * 2] = torch.eye(d)
    w[d:d * 2, d * 2:d * 3] = -cancellation_offset * torch.eye(d)
    l2.weight = nn.Parameter(w)

    # Layer 3: sum of the positive-part branch minus sum of the negative-part branch,
    # i.e. sum_i (ReLU(z_i) - ReLU(-z_i)) = sum_i z_i for uncancelled features.
    l3 = nn.Linear(d * 2, 1, bias=False)
    w = torch.zeros((1, d * 2))
    w[0, 0:d] = 1.
    w[0, d:d * 2] = -1.
    l3.weight = nn.Parameter(w)
    return nn.Sequential(l1, nn.ReLU(), l2, nn.ReLU(), l3)

In [None]:
for experiment_id, experiment in enumerate(EXPERIMENTS_MANUAL):
    experiment_name = experiment.name
    if 'relu' not in experiment_name:
        # Only ReLU is supported for manually constructed networks
        continue
    continuous_dim = experiment.continuous_dim
    total_dim = continuous_dim * 2
    cancellation_likelihood = experiment.cancellation_likelihood
    # Layer widths of the hand-built network, recorded so the printed parameters describe
    # the model that is actually evaluated. namedtuple._replace already returns a copy.
    nn_dims = [continuous_dim * 2, continuous_dim * 3, continuous_dim * 2, 1]
    adjusted_experiment = experiment._replace(nn_dims=nn_dims)

    print(f"———————————————[ Running experiment #{experiment_id + 1}/{len(EXPERIMENTS_MANUAL)} ]———————————————")
    print(f"———————[ Experiment parameters ]———————")
    print(adjusted_experiment)
    print()

    # Note: Whether to multiply attributions returned by each method by the inputs was
    #       chosen according to which of the two variants exhibited better performance.
    methods = [
        ("Gradient x Input", explain_ixg),
        ("LRP", explain_lrp),
        ("DeepLIFT Rescale", explain_dl_const(total_dim)),
        # ("Integrated Gradients", explain_ig_const(total_dim)),
        # ("SmoothGrad (Multiplicative)", explain_sg_const()),
        ("Gradient SHAP", explain_gs_const(total_dim)),
        ("Kernel SHAP", explain_ks_const(total_dim)),
        ("Shapley Value Sampling", explain_svs_const(total_dim)),
        # ("LIME", explain_lime_const(total_dim)),
        # ("CAFE (c = 0.0)", explain_cafe_const(0.0, total_dim)),
        # ("CAFE (c = 0.25)", explain_cafe_const(0.25, total_dim)),
        # ("CAFE (c = 0.5)", explain_cafe_const(0.5, total_dim)),
        # ("CAFE (c = 0.75)", explain_cafe_const(0.75, total_dim)),
        # ("CAFE (c = 1.0)", explain_cafe_const(1.0, total_dim)),
        ("Cafe Smooth (c = 0.0)", explain_cafe_const_smooth(0.0, total_dim)),
        ("Cafe Smooth (c = 0.25)", explain_cafe_const_smooth(0.25, total_dim)),
        ("Cafe Smooth (c = 0.5)", explain_cafe_const_smooth(0.5, total_dim)),
        ("Cafe Smooth (c = 0.75)", explain_cafe_const_smooth(0.75, total_dim)),
        ("Cafe Smooth (c = 1.0)", explain_cafe_const_smooth(1.0, total_dim)),
    ]

    result_headers = ["Validation RMSE", "Test RMSE"] + [m[0] for m in methods]
    results = []
    runtime_headers = [f"{m[0]} Runtime" for m in methods]
    runtimes = []

    for data_seed, nn_start_seed in tqdm(list(zip(DATA_SEEDS, NETWORK_SEEDS)), leave=False):
        current_results = []
        current_runtimes = []

        dataset = SyntheticCancellationDataset.generate(
            10000,
            seed=data_seed,
            continuous_dim=continuous_dim,
            cancellation_likelihood=cancellation_likelihood,
            standard_dev=10.,
            weight_range=(-1., 1.),
        )

        # The hand-built network is never trained, so the 60% training split is unused.
        _, tmp_val_dataset = dataset.split(0.6)
        val_dataset, test_dataset = tmp_val_dataset.split(0.5)

        # construct_manual_relu_nn overwrites every (bias-free) weight matrix with fixed
        # values, so the seed cannot change the model and one construction suffices.
        # Move it to DEVICE explicitly: the explanation inputs below are placed there.
        torch.manual_seed(nn_start_seed)
        model = construct_manual_relu_nn(continuous_dim, dataset.weights).to(DEVICE)

        current_results.append(eval_nn(model, val_dataset))
        current_results.append(eval_nn(model, test_dataset))

        model.eval()
        model.zero_grad()

        # The ground truth is the same for every method: convert it to NumPy once (on the CPU).
        ground_truth = torch.stack(test_dataset.ground_truth_attributions).detach().cpu().numpy()

        for method_name, method_fun in tqdm(methods, leave=False):
            torch.manual_seed(data_seed)
            np.random.seed(data_seed)
            exp_model = deepcopy(model)
            # Build the inputs before starting the clock so only the explainer is timed.
            test_inputs = torch.stack(test_dataset.samples).to(DEVICE)
            start_time = time.perf_counter()
            all_attributions = method_fun(exp_model, test_inputs)
            end_time = time.perf_counter()
            current_runtimes.append(end_time - start_time)
            # Mean over features of the per-feature RMSE. This is what the removed
            # `squared=False` + multioutput='uniform_average' combination computed.
            per_feature_mse = mean_squared_error(
                ground_truth,
                all_attributions.detach().cpu().numpy(),
                multioutput='raw_values',
            )
            current_results.append(np.sqrt(per_feature_mse).mean())

        results.append(current_results)
        runtimes.append(current_runtimes)

    result_array = np.array(results)
    print(f"———————[ Experiment results ]———————")
    for header, mean, std in zip(result_headers, result_array.mean(axis=0), result_array.std(axis=0)):
        print(f"{header}: {mean:.2f}±{std:.3f}")
    print()

    runtime_array = np.array(runtimes)
    print(f"———————[ Experiment runtimes ]———————")
    for header, mean, std in zip(runtime_headers, runtime_array.mean(axis=0), runtime_array.std(axis=0)):
        print(f"{header}: {mean:.3f}±{std:.4f}")
    print()
    print()

## Ground-Truth Test

In [None]:
# Trained models are cached under models/. Create the directory so torch.save does not
# fail with FileNotFoundError on a fresh checkout.
os.makedirs("models", exist_ok=True)

for experiment_id, experiment in enumerate(EXPERIMENTS):
    experiment_name = experiment.name
    continuous_dim = experiment.continuous_dim
    total_dim = continuous_dim * 2
    cancellation_likelihood = experiment.cancellation_likelihood
    nn_dims = experiment.nn_dims
    activation_fun = experiment.activation_fun

    print(f"———————————————[ Running experiment #{experiment_id + 1}/{len(EXPERIMENTS)} ]———————————————")
    print(f"———————[ Experiment parameters ]———————")
    print(experiment)
    print()

    # Note: Whether to multiply attributions returned by each method by the inputs was
    #       chosen according to which of the two variants exhibited better performance.
    methods = [
        ("Gradient x Input", explain_ixg),
        ("LRP", explain_lrp),
        ("DeepLIFT Rescale", explain_dl_const(total_dim)),
        # ("Integrated Gradients", explain_ig_const(total_dim)),
        # ("SmoothGrad (Multiplicative)", explain_sg_const()),
        ("Gradient SHAP", explain_gs_const(total_dim)),
        ("Kernel SHAP", explain_ks_const(total_dim)),
        ("Shapley Value Sampling", explain_svs_const(total_dim)),
        # ("LIME", explain_lime_const(total_dim)),
        ("CAFE (c = 0.0)", explain_cafe_const(0.0, total_dim)),
        ("CAFE (c = 0.25)", explain_cafe_const(0.25, total_dim)),
        ("CAFE (c = 0.5)", explain_cafe_const(0.5, total_dim)),
        ("CAFE (c = 0.75)", explain_cafe_const(0.75, total_dim)),
        ("CAFE (c = 1.0)", explain_cafe_const(1.0, total_dim)),
        ("Cafe Smooth (c = 0.0)", explain_cafe_const_smooth(0.0, total_dim)),
        ("Cafe Smooth (c = 0.25)", explain_cafe_const_smooth(0.25, total_dim)),
        ("Cafe Smooth (c = 0.5)", explain_cafe_const_smooth(0.5, total_dim)),
        ("Cafe Smooth (c = 0.75)", explain_cafe_const_smooth(0.75, total_dim)),
        ("Cafe Smooth (c = 1.0)", explain_cafe_const_smooth(1.0, total_dim)),
    ]

    result_headers = ["Validation RMSE", "Test RMSE"] + [m[0] for m in methods]
    results = []

    for data_seed, nn_start_seed in tqdm(list(zip(DATA_SEEDS, NETWORK_SEEDS)), leave=False):
        current_results = []

        dataset = SyntheticCancellationDataset.generate(
            10000,
            seed=data_seed,
            continuous_dim=continuous_dim,
            cancellation_likelihood=cancellation_likelihood,
            standard_dev=10.,
            weight_range=(-1., 1.),
        )

        train_dataset, tmp_val_dataset = dataset.split(0.6)
        val_dataset, test_dataset = tmp_val_dataset.split(0.5)
        train_dl = DataLoader(
            dataset=train_dataset,
            batch_size=64,
            shuffle=False,
        )

        # Train (or load) one network per seed and keep the one with the lowest validation RMSE.
        selected_model = None
        best_rmse = np.inf
        for nn_seed in tqdm(range(nn_start_seed, nn_start_seed + 5), leave=False):
            torch.manual_seed(nn_seed)
            model = construct_nn(nn_dims, activation_fun)
            checkpoint_path = f"models/{experiment_name}-d{data_seed}-nn{nn_seed}.pth"
            if os.path.isfile(checkpoint_path):
                # Load the saved checkpoint; map_location lets checkpoints written on a GPU
                # machine load on a CPU-only one (and vice versa).
                model.load_state_dict(torch.load(checkpoint_path, map_location=DEVICE))
            else:
                # Train and save the model
                train_nn(model, train_dl)
                torch.save(model.state_dict(), checkpoint_path)
            val_rmse = eval_nn(model, val_dataset)
            if val_rmse < best_rmse:
                selected_model = model
                best_rmse = val_rmse
        model = selected_model

        current_results.append(eval_nn(model, val_dataset))
        current_results.append(eval_nn(model, test_dataset))

        model.eval()
        model.zero_grad()

        # The ground truth is the same for every method: convert it to NumPy once (on the CPU).
        ground_truth = torch.stack(test_dataset.ground_truth_attributions).detach().cpu().numpy()

        for method_name, method_fun in tqdm(methods, leave=False):
            torch.manual_seed(data_seed)
            np.random.seed(data_seed)
            exp_model = deepcopy(model)
            all_attributions = method_fun(exp_model, torch.stack(test_dataset.samples).to(DEVICE))
            # Mean over features of the per-feature RMSE. This is what the removed
            # `squared=False` + multioutput='uniform_average' combination computed.
            per_feature_mse = mean_squared_error(
                ground_truth,
                all_attributions.detach().cpu().numpy(),
                multioutput='raw_values',
            )
            current_results.append(np.sqrt(per_feature_mse).mean())

        results.append(current_results)

    result_array = np.array(results)
    print(f"———————[ Experiment results ]———————")
    for header, mean, std in zip(result_headers, result_array.mean(axis=0), result_array.std(axis=0)):
        print(f"{header}: {mean:.2f}±{std:.3f}")
    print()
    print()