# Load data

In [2]:
import torchvision.datasets as dsets
import torchvision.transforms as transforms

In [3]:
# Transformations
# Per-channel (mean, std) pair handed to Normalize in both pipelines.
norm_mean = (0.4914, 0.4822, 0.4465)
norm_std = (0.2023, 0.1994, 0.2010)

# Training pipeline: random crop + horizontal flip, then tensor conversion
# and normalization.
train_steps = [
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(norm_mean, norm_std),
]
transform_train = transforms.Compose(train_steps)

# Evaluation pipeline: tensor conversion and normalization only.
test_steps = [
    transforms.ToTensor(),
    transforms.Normalize(norm_mean, norm_std),
]
transform_test = transforms.Compose(test_steps)

# Load train set
train_set = dsets.CIFAR10(
    '../', train=True, download=True, transform=transform_train)

# Load test set (using as validation)
val_set = dsets.CIFAR10(
    '../', train=False, download=True, transform=transform_test)

Files already downloaded and verified
Files already downloaded and verified


In [4]:
import sys
from pathlib import Path

print(sys.path)

# Make the repository root importable without hard-coding a user-specific
# absolute path.
# NOTE(review): assumes the kernel's working directory is the repository's
# `notebooks/` folder, so the package lives one level up -- confirm.
REPO_ROOT = Path.cwd().resolve().parent
if str(REPO_ROOT) not in sys.path:
    sys.path.append(str(REPO_ROOT))

['c:\\Users\\bryan\\fastshap\\notebooks', 'C:\\Program Files\\WindowsApps\\PythonSoftwareFoundation.Python.3.10_3.10.3056.0_x64__qbz5n2kfra8p0\\python310.zip', 'C:\\Program Files\\WindowsApps\\PythonSoftwareFoundation.Python.3.10_3.10.3056.0_x64__qbz5n2kfra8p0\\DLLs', 'C:\\Program Files\\WindowsApps\\PythonSoftwareFoundation.Python.3.10_3.10.3056.0_x64__qbz5n2kfra8p0\\lib', 'C:\\Users\\bryan\\AppData\\Local\\Microsoft\\WindowsApps\\PythonSoftwareFoundation.Python.3.10_qbz5n2kfra8p0', '', 'C:\\Users\\bryan\\AppData\\Local\\Packages\\PythonSoftwareFoundation.Python.3.10_qbz5n2kfra8p0\\LocalCache\\local-packages\\Python310\\site-packages', 'C:\\Users\\bryan\\AppData\\Local\\Packages\\PythonSoftwareFoundation.Python.3.10_qbz5n2kfra8p0\\LocalCache\\local-packages\\Python310\\site-packages\\win32', 'C:\\Users\\bryan\\AppData\\Local\\Packages\\PythonSoftwareFoundation.Python.3.10_qbz5n2kfra8p0\\LocalCache\\local-packages\\Python310\\site-packages\\win32\\lib', 'C:\\Users\\bryan\\AppData\\Loca

# Train model with missingness

In [5]:
import torch
import torch.nn as nn
import os.path
from resnet import ResNet18
from fastshap import ImageSurrogate
from fastshap.utils import MaskLayer2d, KLDivLoss, DatasetInputOnly

In [6]:
# Select device: use the GPU when one is available, otherwise fall back to the
# CPU so the notebook still runs on machines without CUDA.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [7]:
# Reuse a saved model when a checkpoint file exists; otherwise build a new one.
model_checkpoint_found = os.path.isfile('cifar missingness.pt')

if model_checkpoint_found:
    print('Loading saved model')
    model = torch.load('cifar missingness.pt').to(device)
else:
    # Mask layer (value=0, mask appended) followed by a ResNet18 built with
    # in_channels=4.
    mask_layer = MaskLayer2d(value=0, append=True)
    backbone = ResNet18(in_channels=4, num_classes=10)
    model = nn.Sequential(mask_layer, backbone).to(device)

# Set up surrogate wrapper (although this is not a surrogate model)
imputer = ImageSurrogate(model, width=32, height=32, superpixel_size=2)

if not model_checkpoint_found:
    # Train through the wrapper using the CIFAR train / validation sets.
    model_train_options = dict(
        batch_size=256,
        max_epochs=100,
        loss_fn=nn.CrossEntropyLoss(),
        lookback=10,
        bar=True,
        verbose=True,
    )
    imputer.train(train_set, val_set, **model_train_options)

    # Save from the CPU, then move the model back to the working device.
    model.cpu()
    torch.save(model, 'cifar missingness.pt')
    model.to(device)

Loading saved model


# Train FastSHAP

In [8]:
from unet import UNet
from fastshap import FastSHAP

In [9]:
# Reuse a saved explainer when a checkpoint file exists; otherwise build one.
explainer_checkpoint_found = os.path.isfile('cifar missingness explainer.pt')

if explainer_checkpoint_found:
    print('Loading saved explainer model')
    explainer = torch.load('cifar missingness explainer.pt').to(device)
else:
    # Set up explainer model
    explainer = UNet(n_classes=10, num_down=2, num_up=1, num_convs=3).to(device)

# Set up FastSHAP object (needed whether the explainer was loaded or created)
fastshap = FastSHAP(explainer, imputer, link=nn.LogSoftmax(dim=1))

if not explainer_checkpoint_found:
    # The explainer is trained on the CIFAR train / validation sets wrapped
    # in DatasetInputOnly.
    fastshap_train = DatasetInputOnly(train_set)
    fastshap_val = DatasetInputOnly(val_set)

    explainer_train_options = dict(
        batch_size=128,
        num_samples=2,
        max_epochs=200,
        eff_lambda=1e-2,
        validation_samples=1,
        lookback=10,
        bar=True,
        verbose=True,
    )
    fastshap.train(fastshap_train, fastshap_val, **explainer_train_options)

    # Save from the CPU, then move the explainer back to the working device.
    explainer.cpu()
    torch.save(explainer, 'cifar missingness explainer.pt')
    explainer.to(device)

Loading saved explainer model


# Visualize results

In [10]:
import numpy as np
import matplotlib.pyplot as plt
import torch.nn.functional as F

In [155]:
def calculate_exclusion(values, samples, targets, surrogate=None, run_device=None):
    """Compute and print the average exclusion AUC of a set of attributions.

    For each sample, the players are ranked by their attribution for the
    sample's target class (highest first). The k top-ranked players are then
    masked out (mask entry set to 0) for every k = 0, 1, ..., num_players, the
    surrogate is queried on each masked input, and the 0/1 top-1 correctness
    is integrated over the removed fraction k / num_players with the
    trapezoidal rule.

    Selection is rank based, so step k removes exactly k players: step 0
    evaluates the unmasked input and tied attribution values never remove
    more players than intended.

    Args:
        values: Tensor (or array) indexed as ``values[i, class]``; each
            ``values[i, class]`` must hold ``num_players`` attributions.
        samples: Indexable collection of input tensors, one per sample,
            without a batch dimension.
        targets: Indexable collection of integer class labels.
        surrogate: Callable ``surrogate(x, S)`` exposing ``num_players``.
            Defaults to the notebook-level ``imputer``.
        run_device: Device used for the forward passes. Defaults to the
            notebook-level ``device``.

    Returns:
        The average exclusion AUC as a float (``nan`` when ``values`` is
        empty). The value is also printed.

    Raises:
        ValueError: If a sample's attribution map does not contain exactly
            ``num_players`` entries.
    """
    surrogate = imputer if surrogate is None else surrogate
    run_device = device if run_device is None else run_device
    num_players = surrogate.num_players

    # Row k of the mask matrix built below describes "the top k players".
    step_counts = torch.arange(num_players + 1).unsqueeze(1)
    removed_fractions = torch.linspace(0, 1, num_players + 1)

    exclusion_auc = []
    # Pure evaluation: no autograd graph is needed.
    # NOTE(review): all masks of a sample are evaluated in one batch; this
    # presumes the wrapped model is in eval mode -- confirm.
    with torch.no_grad():
        for i in range(len(values)):
            target = int(targets[i])
            attributions = torch.as_tensor(values[i, target]).detach().reshape(-1).cpu()
            if attributions.numel() != num_players:
                raise ValueError(
                    'expected {} attributions per sample, got {}'.format(
                        num_players, attributions.numel()))

            # ranks[j] is the position of player j in the descending ordering.
            order = torch.argsort(attributions, descending=True)
            ranks = torch.empty_like(order)
            ranks[order] = torch.arange(num_players)
            top_k = ranks.unsqueeze(0) < step_counts  # (num_players + 1, num_players)

            # Exclusion: the top-k players are masked out, the rest are kept.
            S = (~top_k).to(device=run_device, dtype=torch.float32)

            x = samples[i].to(run_device)
            x_batch = x.unsqueeze(0).expand(S.shape[0], *x.shape)

            # argmax of the raw output equals argmax of its softmax.
            predicted = surrogate(x_batch, S).argmax(dim=1).cpu()
            correct = (predicted == target).float()

            exclusion_auc.append(torch.trapz(correct, removed_fractions))

    # Compute average exclusion AUC across all instances
    if exclusion_auc:
        average_auc = torch.stack(exclusion_auc).mean().item()
    else:
        average_auc = float('nan')
    print("Average exclusion AUC:", average_auc)
    return average_auc

In [156]:
def calculate_inclusion(values, samples, targets, surrogate=None, run_device=None):
    """Compute and print the average inclusion AUC of a set of attributions.

    For each sample, the players are ranked by their attribution for the
    sample's target class (highest first). Only the k top-ranked players are
    kept (all other mask entries set to 0) for every k = 0, 1, ...,
    num_players, the surrogate is queried on each masked input, and the 0/1
    top-1 correctness is integrated over the included fraction
    k / num_players with the trapezoidal rule.

    Selection is rank based, so step k keeps exactly k players: step 0
    evaluates the fully masked input and tied attribution values never
    include more players than intended.

    Args:
        values: Tensor (or array) indexed as ``values[i, class]``; each
            ``values[i, class]`` must hold ``num_players`` attributions.
        samples: Indexable collection of input tensors, one per sample,
            without a batch dimension.
        targets: Indexable collection of integer class labels.
        surrogate: Callable ``surrogate(x, S)`` exposing ``num_players``.
            Defaults to the notebook-level ``imputer``.
        run_device: Device used for the forward passes. Defaults to the
            notebook-level ``device``.

    Returns:
        The average inclusion AUC as a float (``nan`` when ``values`` is
        empty). The value is also printed.

    Raises:
        ValueError: If a sample's attribution map does not contain exactly
            ``num_players`` entries.
    """
    surrogate = imputer if surrogate is None else surrogate
    run_device = device if run_device is None else run_device
    num_players = surrogate.num_players

    # Row k of the mask matrix built below describes "the top k players".
    step_counts = torch.arange(num_players + 1).unsqueeze(1)
    included_fractions = torch.linspace(0, 1, num_players + 1)

    inclusion_auc = []
    # Pure evaluation: no autograd graph is needed.
    # NOTE(review): all masks of a sample are evaluated in one batch; this
    # presumes the wrapped model is in eval mode -- confirm.
    with torch.no_grad():
        for i in range(len(values)):
            target = int(targets[i])
            attributions = torch.as_tensor(values[i, target]).detach().reshape(-1).cpu()
            if attributions.numel() != num_players:
                raise ValueError(
                    'expected {} attributions per sample, got {}'.format(
                        num_players, attributions.numel()))

            # ranks[j] is the position of player j in the descending ordering.
            order = torch.argsort(attributions, descending=True)
            ranks = torch.empty_like(order)
            ranks[order] = torch.arange(num_players)
            top_k = ranks.unsqueeze(0) < step_counts  # (num_players + 1, num_players)

            # Inclusion: only the top-k players are kept, the rest are masked.
            S = top_k.to(device=run_device, dtype=torch.float32)

            x = samples[i].to(run_device)
            x_batch = x.unsqueeze(0).expand(S.shape[0], *x.shape)

            # argmax of the raw output equals argmax of its softmax.
            predicted = surrogate(x_batch, S).argmax(dim=1).cpu()
            correct = (predicted == target).float()

            inclusion_auc.append(torch.trapz(correct, included_fractions))

    # Compute average inclusion AUC across all instances
    if inclusion_auc:
        average_auc = torch.stack(inclusion_auc).mean().item()
    else:
        average_auc = float('nan')
    print("Average inclusion AUC:", average_auc)
    return average_auc

In [157]:
# Quantitative evaluation runs on the full validation set; the qualitative
# plot further down uses one randomly chosen image per class.
dset = val_set
samples = np.array(dset.data)
targets = np.array(dset.targets)
num_classes = int(targets.max()) + 1

# Materialise every transformed validation image and its label once.
x, y = zip(*list(dset))
x = torch.stack(x)

# FastSHAP attributions for every validation image.
values_val = torch.tensor(fastshap.shap_values(x.to(device)))
print("Shape of SHAP values tensor:", values_val.shape)

# Top-1 accuracy of the imputer model as the highest-ranked features are
# progressively removed (exclusion) or added back (inclusion), averaged over
# all validation images.
calculate_exclusion(values_val, x, y)
calculate_inclusion(values_val, x, y)

# Select one image from each class for the plot. The grid below has exactly
# num_classes rows, so it must be fed num_classes images, not the whole set.
inds_lists = [np.where(targets == cat)[0] for cat in range(num_classes)]
inds = [int(np.random.choice(cat_inds)) for cat_inds in inds_lists]
x_vis = x[inds]
y_vis = [y[ind] for ind in inds]
values = values_val[inds].numpy()

# Get predictions for the selected images with every player present
# (one all-ones mask row per image).
pred = imputer(
    x_vis.to(device),
    torch.ones(len(inds), imputer.num_players, device=device)
).softmax(dim=1).cpu().data.numpy()

classes = ['Airplane', 'Car', 'Bird', 'Cat', 'Deer', 'Dog', 'Frog', 'Horse', 'Ship', 'Truck']
# Same constants as the Normalize transform, used to undo it for display.
mean = np.array([0.4914, 0.4822, 0.4465])[:, np.newaxis, np.newaxis]
std = np.array([0.2023, 0.1994, 0.2010])[:, np.newaxis, np.newaxis]

fig, axarr = plt.subplots(num_classes, num_classes + 1, figsize=(22, 20))

for row in range(num_classes):
    # Image
    im = x_vis[row].numpy() * std + mean
    im = im.transpose(1, 2, 0).astype(float)
    im = np.clip(im, a_min=0, a_max=1)
    axarr[row, 0].imshow(im, vmin=0, vmax=1)
    axarr[row, 0].set_xticks([])
    axarr[row, 0].set_yticks([])
    axarr[row, 0].set_ylabel('{}'.format(classes[y_vis[row]]), fontsize=14)

    # Explanations: one symmetric colour scale per image, shared by all classes
    m = np.abs(values[row]).max()
    for col in range(num_classes):
        axarr[row, col + 1].imshow(values[row, col], cmap='seismic', vmin=-m, vmax=m)
        axarr[row, col + 1].set_xticks([])
        axarr[row, col + 1].set_yticks([])
        # The predicted probability of the true class is shown in bold.
        if col == y_vis[row]:
            axarr[row, col + 1].set_xlabel('{:.2f}'.format(pred[row, col]), fontsize=12, fontweight='bold')
        else:
            axarr[row, col + 1].set_xlabel('{:.2f}'.format(pred[row, col]), fontsize=12)

        # Class labels: column col shows the attribution for class col.
        if row == 0:
            axarr[row, col + 1].set_title('{}'.format(classes[col]), fontsize=14)

plt.tight_layout()
plt.show()

Shape of SHAP values tensor: torch.Size([10000, 10, 16, 16])


In [15]:
# Train `num_shadow` shadow explainers, each on its own disjoint 10000-image
# slice of a 30000-image shadow pool drawn from the CIFAR training set. The
# remaining training images form `target_set`.
num_shadow = 3
train_set = dsets.CIFAR10('../', train=True, download=True, transform=transform_train)

# Random split of the training set: shadow pool vs. target set
random_indices = torch.randperm(len(train_set))
selected_i = random_indices[:30000]
n_selected_i = random_indices[30000:]
shadow_set = torch.utils.data.Subset(train_set, selected_i)
target_set = torch.utils.data.Subset(train_set, n_selected_i)

# Random split of the shadow pool into three disjoint parts
random_indices = torch.randperm(len(shadow_set))
first_i = random_indices[:10000]
second_i = random_indices[10000:20000]
third_i = random_indices[20000:]
shadow_set1 = torch.utils.data.Subset(shadow_set, first_i)
shadow_set2 = torch.utils.data.Subset(shadow_set, second_i)
shadow_set3 = torch.utils.data.Subset(shadow_set, third_i)

# Set up explainer models
explainer1 = UNet(n_classes=10, num_down=2, num_up=1, num_convs=3).to(device)
explainer2 = UNet(n_classes=10, num_down=2, num_up=1, num_convs=3).to(device)
explainer3 = UNet(n_classes=10, num_down=2, num_up=1, num_convs=3).to(device)

# Set up FastSHAP objects (all share the same imputer)
fastshap_shadow1 = FastSHAP(explainer1, imputer, link=nn.LogSoftmax(dim=1))
fastshap_shadow2 = FastSHAP(explainer2, imputer, link=nn.LogSoftmax(dim=1))
fastshap_shadow3 = FastSHAP(explainer3, imputer, link=nn.LogSoftmax(dim=1))

# Set up datasets: one training set per shadow explainer, shared validation set
fastshap_train1 = DatasetInputOnly(shadow_set1)
fastshap_train2 = DatasetInputOnly(shadow_set2)
fastshap_train3 = DatasetInputOnly(shadow_set3)
fastshap_val = DatasetInputOnly(val_set)

# Make sure the checkpoint directory exists before any (long) training starts,
# so that saving cannot fail afterwards.
os.makedirs('./ckpt', exist_ok=True)

shadow_runs = [
    (fastshap_shadow1, fastshap_train1, explainer1, './ckpt/cifar missingness explainer1.pt'),
    (fastshap_shadow2, fastshap_train2, explainer2, './ckpt/cifar missingness explainer2.pt'),
    (fastshap_shadow3, fastshap_train3, explainer3, './ckpt/cifar missingness explainer3.pt'),
]
assert len(shadow_runs) == num_shadow

for shadow_fastshap, shadow_train, shadow_explainer, ckpt_path in shadow_runs:
    # Train each shadow explainer on its OWN slice of the shadow pool.
    shadow_fastshap.train(
        shadow_train,
        fastshap_val,
        batch_size=128,
        num_samples=2,
        max_epochs=200,
        eff_lambda=1e-2,
        validation_samples=1,
        lookback=10,
        bar=True,
        verbose=True)

    # Save right after training so a later failure does not lose this model;
    # save from the CPU, then move back to the working device.
    shadow_explainer.cpu()
    torch.save(shadow_explainer, ckpt_path)
    shadow_explainer.to(device)

Files already downloaded and verified
