In [None]:
import os
import torch
import torch.optim as optim
import torch.nn as nn
from torch.utils.data import Dataset, TensorDataset, DataLoader
from torch.utils.data import RandomSampler, BatchSampler
from torch.distributions.categorical import Categorical

from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import accuracy_score

from copy import deepcopy
from tqdm.auto import tqdm
import itertools
import numpy as np
import pickle
import os.path
# import lightgbm as lgb
import matplotlib.pyplot as plt
import time
import random
import pandas as pd
import collections

import tensorflow as tf
from tensorflow.keras.layers import (Input, Layer, Dense, Lambda, 
                                     Dropout, Multiply, BatchNormalization, 
                                     Reshape, Concatenate, Conv2D, Permute)
from tensorflow.keras.models import Model, Sequential
from tensorflow.keras import regularizers
from tensorflow.keras import backend as K
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping, ReduceLROnPlateau
from tensorflow.keras.layers.experimental.preprocessing import Resizing

from tensorflow.keras.datasets import cifar10

from datetime import datetime

#Select GPU
os.environ['CUDA_VISIBLE_DEVICES'] = '0'

In [None]:
# IMPORTANT: SET RANDOM SEEDS FOR REPRODUCIBILITY
# A single seed is applied to every RNG this notebook draws from.
SEED = 420
# NOTE(review): PYTHONHASHSEED is read at interpreter start-up, so setting it
# here presumably only affects child processes — confirm if hash order matters.
os.environ['PYTHONHASHSEED'] = str(SEED)
import random
random.seed(SEED)
np.random.seed(SEED)
tf.random.set_seed(SEED)
# PyTorch builds and trains the explainer and draws the coalition samples, but
# was previously never seeded, so runs were not repeatable.
torch.manual_seed(SEED)

In [None]:
# Cap TensorFlow's allocation on the first GPU so that it does not claim the
# whole device (PyTorch is used later in this notebook on the same GPU).
gpus = tf.config.list_physical_devices('GPU')
if gpus:
  # Restrict TensorFlow to a single logical device with memory_limit=12288
  # (the limit is expressed in MB, i.e. 12 GB) on the first GPU
  try:
    tf.config.set_logical_device_configuration(
        gpus[0],
        [tf.config.LogicalDeviceConfiguration(memory_limit=12288)])
    logical_gpus = tf.config.list_logical_devices('GPU')
    print(len(gpus), "Physical GPUs,", len(logical_gpus), "Logical GPUs")
  except RuntimeError as e:
    # Virtual devices must be set before GPUs have been initialized
    print(e)

# Load Data

## CIFAR10

In [None]:
# Load CIFAR10, carve a validation split out of the official test split, one-hot
# encode the labels and wrap each split in an (unbatched) tf.data.Dataset.
from sklearn.model_selection import train_test_split

num_classes = 10

# the data, split between train and test sets
(x_train, y_train), (x_test, y_test) = cifar10.load_data()
# Half of the original test split is kept as test, the other half becomes validation.
x_test, x_val, y_test, y_val = train_test_split(x_test, y_test, train_size=0.5, random_state=420)

x_train = x_train.astype('float32')
x_val = x_val.astype('float32')
x_test = x_test.astype('float32')
# No resizing happens here; the CIFAR `reformat` function resizes each image
# to INPUT_SHAPE when the datasets are mapped.

print(x_train.shape[0], 'train samples')
print(x_val.shape[0], 'val samples')
print(x_test.shape[0], 'test samples')

# convert class vectors to binary class matrices
y_train = tf.keras.utils.to_categorical(y_train, num_classes)
y_val = tf.keras.utils.to_categorical(y_val, num_classes)
y_test = tf.keras.utils.to_categorical(y_test, num_classes)

# Make TF Dataset
ds_train = tf.data.Dataset.from_tensor_slices((x_train, y_train))
ds_val = tf.data.Dataset.from_tensor_slices((x_val, y_val))
ds_test = tf.data.Dataset.from_tensor_slices((x_test, y_test))

# Dataset tag; the save cell interpolates it into the explainer filename.
dataset="CIFAR"

In [None]:
def batch_data(dataset, fn, batch_size=32):
    """Map ``fn`` over ``dataset``, then batch and prefetch the result.

    Args:
        dataset: object exposing ``map``, ``batch`` and ``prefetch``
            (a ``tf.data.Dataset`` in this notebook).
        fn: per-element transform handed to ``dataset.map``.
        batch_size: number of elements per batch (default 32).

    Returns:
        The mapped, batched and prefetched dataset.
    """
    # The prefetch buffer size is left to TensorFlow (AUTOTUNE).
    return (
        dataset
        .map(fn)
        .batch(batch_size)
        .prefetch(tf.data.experimental.AUTOTUNE)
    )

In [None]:
def reformat(x, y):
    """Prepare one CIFAR10 example for the ResNet50 black-box model.

    The image is cast to float32, resized to the spatial size given by the
    global ``INPUT_SHAPE`` with nearest-neighbour interpolation, and passed
    through the ResNet50 ``preprocess_input``. The label is returned untouched.
    """
    target_height, target_width = INPUT_SHAPE[0], INPUT_SHAPE[1]
    resize = Resizing(target_height, target_width, interpolation='nearest')

    image = resize(tf.cast(x, tf.float32))
    image = tf.keras.applications.resnet50.preprocess_input(image)

    return (image, y)

# Apply `reformat` to every CIFAR10 split and batch it.
# NOTE(review): BATCH_SIZE (and INPUT_SHAPE, read inside `reformat`) are only
# assigned in the Imagenette configuration cell further down, so this cell
# fails on a fresh kernel run top to bottom — confirm the intended cell order.
# NOTE(review): each name is rebound to its own mapped version, so running
# this cell twice applies `reformat` twice.
ds_train = batch_data(ds_train, reformat, BATCH_SIZE)
ds_val = batch_data(ds_val, reformat, BATCH_SIZE)
ds_test = batch_data(ds_test, reformat, BATCH_SIZE)

In [None]:
# Materialise the preprocessed CIFAR10 splits as channels-first numpy arrays,
# reusing the on-disk copies when they exist.
# The cache is only trusted when all three files are present; previously only
# the train file was checked, so a partial cache crashed on np.load.
if all(os.path.exists(p) for p in ('train_numpy.npy', 'val_numpy.npy', 'test_numpy.npy')):
    print('Loading numpy arrays...')
    train_numpy = np.load('train_numpy.npy')
    print(train_numpy.shape)
    val_numpy = np.load('val_numpy.npy')
    print(val_numpy.shape)
    test_numpy = np.load('test_numpy.npy')
    print(test_numpy.shape)
else:
    # Every image of every batch is kept (HWC -> CHW). The earlier `x[0]`
    # kept only the first image per batch, which silently dropped data for
    # any BATCH_SIZE > 1; with BATCH_SIZE == 1 the result is unchanged.
    train_numpy = np.array(
        [np.transpose(img, (2, 0, 1)) for x, y in tqdm(ds_train) for img in x.numpy()])
    print(train_numpy.shape)
    val_numpy = np.array(
        [np.transpose(img, (2, 0, 1)) for x, y in tqdm(ds_val) for img in x.numpy()])
    print(val_numpy.shape)
    test_numpy = np.array(
        [np.transpose(img, (2, 0, 1)) for x, y in tqdm(ds_test) for img in x.numpy()])
    print(test_numpy.shape)
    # Save the numpy arrays
    np.save('train_numpy.npy', train_numpy)
    np.save('val_numpy.npy', val_numpy)
    np.save('test_numpy.npy', test_numpy)

## Imagenette

In [None]:
# Batch size handed to `batch_data` when the tf.data pipelines are built;
# each batch therefore holds a single image.
BATCH_SIZE = 1
# NOTE(review): EPOCHS and LR are not referenced by any cell visible in this
# part of the notebook — confirm they are still needed.
EPOCHS = 50
LR = 1e-2
# (height, width, channels) expected by the ResNet50 black-box model.
INPUT_SHAPE = (224, 224, 3)

# Dataset tag; the save cell interpolates it into the explainer filename.
dataset="Imagenette"

In [None]:
# Load Imagenette through TensorFlow Datasets. The official validation split is
# cut in two: its first half becomes the validation set and its last half the
# test set. With as_supervised=False each element is a dict (the Imagenette
# `reformat` function reads its 'image' and 'label' keys).
import tensorflow_datasets as tfds
(ds_train, ds_val, ds_test), ds_info = tfds.load(
    'imagenette/full-size-v2',
    split=['train', 'validation[:50%]', 'validation[-50%:]'],
    as_supervised=False,
    with_info=True
)

In [None]:
def batch_data(dataset, fn, batch_size=32):
    """Return ``dataset`` after mapping ``fn``, batching and prefetching.

    NOTE(review): a function with the same name and body is already defined in
    the CIFAR10 section; this definition shadows it.

    Args:
        dataset: object exposing ``map``, ``batch`` and ``prefetch``.
        fn: transform applied to each element before batching.
        batch_size: elements per batch (default 32).
    """
    transformed = dataset.map(fn)
    batched = transformed.batch(batch_size)
    # Let TensorFlow choose the prefetch buffer size.
    autotune = tf.data.experimental.AUTOTUNE
    return batched.prefetch(autotune)

In [None]:
def reformat(input_dict):
    """Turn one Imagenette example dict into an ``(image, one_hot_label)`` pair.

    The 'image' entry is cast to float32, centre-cropped or zero-padded to
    224x224 and run through the ResNet50 ``preprocess_input``; the 'label'
    entry is one-hot encoded over 10 classes.
    """
    image = tf.cast(input_dict['image'], tf.float32)
    image = tf.image.resize_with_crop_or_pad(image, 224, 224)
    image = tf.keras.applications.resnet50.preprocess_input(image)

    label = tf.one_hot(input_dict['label'], depth = 10)

    return (image, label)

# Apply `reformat` to every Imagenette split and batch it with BATCH_SIZE.
# NOTE(review): each name is rebound to its own mapped version, so running
# this cell a second time would feed already-reformatted tuples back into
# `reformat`, which expects a dict.
ds_train = batch_data(ds_train, reformat, BATCH_SIZE)
ds_val = batch_data(ds_val, reformat, BATCH_SIZE)
ds_test = batch_data(ds_test, reformat, BATCH_SIZE)

In [None]:
# Materialise the preprocessed Imagenette splits as channels-first numpy
# arrays, reusing the copies under cache/IMAGENETTE when they exist.
# The cache is only trusted when all three files are present; previously only
# the train file was checked, so a partial cache crashed on np.load.
if all(os.path.exists(p) for p in ('cache/IMAGENETTE/train_numpy.npy',
                                   'cache/IMAGENETTE/val_numpy.npy',
                                   'cache/IMAGENETTE/test_numpy.npy')):
    print('Loading numpy arrays...')
    train_numpy = np.load('cache/IMAGENETTE/train_numpy.npy')
    print(train_numpy.shape)
    val_numpy = np.load('cache/IMAGENETTE/val_numpy.npy')
    print(val_numpy.shape)
    test_numpy = np.load('cache/IMAGENETTE/test_numpy.npy')
    print(test_numpy.shape)
else:
    # Every image of every batch is kept (HWC -> CHW). The earlier `x[0]`
    # kept only the first image per batch, which silently dropped data for
    # any BATCH_SIZE > 1; with BATCH_SIZE == 1 the result is unchanged.
    train_numpy = np.array(
        [np.transpose(img, (2, 0, 1)) for x, y in tqdm(ds_train) for img in x.numpy()])
    print(train_numpy.shape)
    val_numpy = np.array(
        [np.transpose(img, (2, 0, 1)) for x, y in tqdm(ds_val) for img in x.numpy()])
    print(val_numpy.shape)
    test_numpy = np.array(
        [np.transpose(img, (2, 0, 1)) for x, y in tqdm(ds_test) for img in x.numpy()])
    print(test_numpy.shape)
    # Save the numpy arrays
    # np.save does not create missing directories, so make sure the cache
    # folder exists before writing into it.
    os.makedirs('cache/IMAGENETTE', exist_ok=True)
    np.save('cache/IMAGENETTE/train_numpy.npy', train_numpy)
    np.save('cache/IMAGENETTE/val_numpy.npy', val_numpy)
    np.save('cache/IMAGENETTE/test_numpy.npy', test_numpy)

# Load Black-Box Model

## CIFAR10

In [None]:
# Rebuild the CIFAR10 black-box classifier (ResNet50 backbone with global
# average pooling followed by a 10-way softmax layer) and load its weights.
# NOTE(review): INPUT_SHAPE is only assigned in the Imagenette cells of this
# notebook, so this cell depends on one of them having been run first.
from tensorflow.keras.applications.resnet50 import ResNet50

base_model = ResNet50(
    include_top=False, weights='imagenet', 
    input_shape=INPUT_SHAPE, pooling='avg'
)
base_model.trainable = True

model_input = Input(shape=INPUT_SHAPE, name='input')

net = base_model(model_input)
out = Dense(10, activation='softmax')(net)

bb_model = Model(model_input, out)

# NOTE(review): the Imagenette cell loads the same path into a different
# architecture — confirm which dataset this weights file belongs to.
model_weights_path = 'MODEL/OM/model_weights.h5'

bb_model.load_weights(model_weights_path)
# The black-box model is only queried, never trained, from here on.
bb_model.trainable = False

## Imagenette

In [None]:
# Rebuild the Imagenette black-box classifier: the full ImageNet ResNet50
# (classification head included, frozen) followed by a 10-way softmax layer
# applied to its output, then load the weights.
from tensorflow.keras.applications.resnet50 import ResNet50

INPUT_SHAPE = (224,224,3)

base_model = ResNet50(
    include_top=True, weights='imagenet', 
    input_shape=INPUT_SHAPE
)
base_model.trainable = False

model_input = Input(shape=INPUT_SHAPE, dtype='float32', name='input')

net = base_model(model_input)
out = Dense(10, activation='softmax')(net)

bb_model = Model(model_input, out)

# NOTE(review): the CIFAR10 cell loads the same path into a different
# architecture — confirm which dataset this weights file belongs to.
model_weights_path = 'MODEL/OM/model_weights.h5'

bb_model.load_weights(model_weights_path)
# The black-box model is only queried, never trained, from here on.
bb_model.trainable = False

# LightningSHAP

## Code

In [None]:
def validate_STFS(model, loss_fn1, loss_fn2, data_loader, batch_size, num_samples, sampler, sampler_surr, paired_sampling, epoch):
    """Compute the running-mean validation losses of a LightningSHAP explainer.

    Args:
        model: explainer wrapper; it must be callable as ``model(x, S)`` and
            expose ``model.model`` (the network), ``num_players``, ``resize``
            and ``wait``.
        loss_fn1: loss between the player-summed output under a random mask
            and the labels ``y`` (weighted by 10).
        loss_fn2: loss used for the grand-coalition term (weighted by 50) and,
            once ``epoch >= model.wait``, for the coalition-value term
            (weighted by ``num_players``).
        data_loader: iterable of ``(x, y)`` batches with a ``dataset``
            attribute; every batch must hold exactly ``batch_size`` items.
        batch_size: batch size of ``data_loader``.
        num_samples: coalitions drawn per input.
        sampler: coalition sampler, called as
            ``sample(n, paired_sampling=...)``.
        sampler_surr: mask sampler for the ``loss_fn1`` term, called as
            ``sample(n)``.
        paired_sampling: forwarded to ``sampler.sample``.
        epoch: current epoch, compared against ``model.wait``.

    Returns:
        ``(mean_loss, mean_loss1, mean_loss2, mean_loss4)``. ``mean_loss2``
        stays the integer 0 while ``epoch < model.wait``.

    Raises:
        ValueError: if ``data_loader`` yields no batch. The previous version
            failed here with an UnboundLocalError from a trailing ``del``.
    """
    with torch.no_grad():
        # Setup.
        device = next(model.model.parameters()).device
        mean_loss = 0
        mean_loss1 = 0
        mean_loss2 = 0
        mean_loss4 = 0
        N = 0
        link = nn.Softmax(dim=-1)

        # COMPUTE NULL COALITION: output for the first sample with every
        # player masked out, summed over players and passed through softmax.
        sample = data_loader.dataset[0][0]
        sample = sample.to(device)
        zeros = torch.zeros(1, model.num_players, device=device)
        zeros = model.resize(zeros)
        null = model(sample, zeros)
        null_reshape = null.reshape(1, -1, model.num_players).permute(0, 2, 1)
        null = link(null_reshape.sum(dim=1))

        for x, y in data_loader:
            x = x.to(device)
            y = y.to(device)

            # Generate subsets (order kept: coalition sampler first).
            S = sampler.sample(batch_size*num_samples, paired_sampling=paired_sampling).to(device=device)
            S_surr = sampler_surr.sample(batch_size).to(device=device)
            S_surr = model.resize(S_surr)

            # Term 1: player-summed output under a random mask vs. labels.
            pred_xs = model(x, S_surr)
            pred_xs_reshape = pred_xs.reshape(len(x), -1, model.num_players).permute(0, 2, 1)
            pred_xs_sum = pred_xs_reshape.sum(dim=1)
            loss1 = loss_fn1(pred_xs_sum, y)

            # Grand coalition: all players present.
            ones = torch.ones_like(S_surr).to(device)
            pred = model(x, ones)
            pred_reshape = pred.reshape(len(x), -1, model.num_players).permute(0, 2, 1)
            grand = link(pred_reshape.sum(dim=1))

            # Per-player values shifted so that they sum to y - null.
            pred_eff = additive_efficient_normalization(pred_reshape, y, null)

            # Evaluate the explainer on num_samples coalitions per input.
            x_tiled = x.unsqueeze(1).repeat(
                1, num_samples, *[1 for _ in range(len(x.shape) - 1)]
                ).reshape(batch_size * num_samples, *x.shape[1:])
            S1 = model.resize(S)
            val = model(x_tiled, S1)
            val_reshape = val.reshape(len(x_tiled), -1, model.num_players).permute(0, 2, 1)
            values = link(val_reshape.sum(dim=1))

            S = S.reshape(batch_size, num_samples, model.num_players)
            values = values.reshape(batch_size, num_samples, -1)

            approx = null + torch.matmul(S, pred_eff)
            loss4 = loss_fn2(y, grand)

            loss1 = loss1*10
            loss4 = loss4*50

            if epoch >= model.wait:
                loss2 = loss_fn2(approx, values)
                loss2 = loss2 * model.num_players
                loss = loss1 + loss2 + loss4
            else:
                loss2 = 0
                loss = loss1 + loss4

            # Incremental (running) means weighted by batch size.
            N += len(x)
            mean_loss += len(x) * (loss - mean_loss) / N
            mean_loss1 += len(x) * (loss1 - mean_loss1) / N
            mean_loss2 += len(x) * (loss2 - mean_loss2) / N
            mean_loss4 += len(x) * (loss4 - mean_loss4) / N

    if N == 0:
        raise ValueError('data_loader produced no batches; cannot compute validation loss')
    return mean_loss, mean_loss1, mean_loss2, mean_loss4

def generate_labels_STFS(dataset, model, batch_size):
    """Run ``model`` over ``dataset`` in batches and concatenate its outputs.

    Args:
        dataset: anything a ``DataLoader`` can batch into tensors.
        model: callable applied to each batch. When it is an ``nn.Module`` the
            batches are moved to the device of its parameters, otherwise they
            are kept on the CPU.
        batch_size: batch size used for the forward passes.

    Returns:
        A CPU tensor with the outputs of all batches concatenated along dim 0.
    """
    if isinstance(model, torch.nn.Module):
        target_device = next(model.parameters()).device
    else:
        target_device = torch.device('cpu')

    batches = DataLoader(dataset, batch_size=batch_size)
    with torch.no_grad():
        outputs = [model(batch.to(target_device)).cpu() for batch in tqdm(batches)]

    return torch.cat(outputs)

def additive_efficient_normalization(pred, grand, null):
    """Shift ``pred`` so that its sum over dim 1 equals ``grand - null``.

    The difference between ``grand - null`` and the current sum over dim 1 is
    split equally across the ``pred.shape[1]`` entries of that dimension.
    """
    count = pred.shape[1]
    current_total = torch.sum(pred, dim=1)
    shortfall = (grand - null) - current_total
    per_entry = shortfall.unsqueeze(1) / count
    return pred + per_entry


def multiplicative_efficient_normalization(pred, grand, null):
    """Rescale ``pred`` so that its sum over dim 1 equals ``grand - null``.

    Every entry along dim 1 is multiplied by ``(grand - null) / sum(pred)``;
    no guard is applied when that sum is zero.
    """
    current_total = torch.sum(pred, dim=1)
    scale = (grand - null) / current_total
    return torch.mul(pred, scale.unsqueeze(1))


class LightningSHAP:

    def __init__(self, model, om, width, height, superpixel_size=1, groups=None):
        # Store surrogate model.
        self.model = model
        self.batch_size = None
        self.validation_batch_size = None
        self.num_samples = None
        self.link = None
        self.bbm=om

        self.width = width
        self.height = height
        self.supsize = superpixel_size
        if superpixel_size == 1:
            self.upsample = nn.Identity()
        else:
            self.upsample = nn.Upsample(
                scale_factor=superpixel_size, mode='nearest')

        self.small_width = width // superpixel_size
        self.small_height = height // superpixel_size
        self.num_players = self.small_width * self.small_height

    def resize(self, S):
        if len(S.shape) == 2:
            S = S.reshape(S.shape[0], self.small_height, self.small_width).unsqueeze(1)
        return self.upsample(S)
                
    def train_original_model(self,
                             train_data,
                             val_data,
                             original_model,
                             batch_size,
                             max_epochs,
                             loss_fn1,
                             loss_fn2,
                             validation_samples=1,
                             validation_batch_size=None,
                             lr=None,
                             min_lr=None,
                             lr_factor=None,
                             weight_decay=None,
                             lookback=None,
                             num_samples=None,
                             training_seed=None,
                             validation_seed=None,
                             paired_sampling=False,
                             bar=False,
                             verbose=False,
                             debug=False,
                             wait=0
                             ):
        """Train the explainer against labels produced by ``original_model``.

        Three loss terms are computed for every batch:

        * ``loss1 = 10 * loss_fn1(player-summed output under a random mask, y)``
        * ``loss4 = 50 * loss_fn2(y, softmax(player-summed output, all-ones mask))``
        * ``loss2 = num_players * loss_fn2(null + S @ pred_eff, values)`` where
          ``values`` are the softmax outputs for the sampled coalitions ``S``.

        While ``epoch < wait`` only ``loss1`` is optimised; afterwards the
        objective is ``loss1 + loss2 + loss4``. The weights with the lowest
        validation loss are restored at the end and the per-epoch losses are
        stored on ``self`` (``train_loss*_list`` / ``val_loss*_list``).

        Args:
            train_data: numpy array, tensor, or a PyTorch ``Dataset`` that
                already yields ``(x, y)`` pairs.
            val_data: numpy array or tensor.
            original_model: callable producing the target labels.
            batch_size: training batch size.
            max_epochs: upper bound on the number of epochs.
            loss_fn1, loss_fn2: loss callables (see above).
            validation_samples: accepted but not used.
            validation_batch_size: defaults to ``batch_size``.
            lr, weight_decay: forwarded to ``AdamW``.
            min_lr, lr_factor: forwarded to ``ReduceLROnPlateau``.
            lookback: early-stopping patience in epochs; half of it is the
                scheduler patience. Must be an int despite the ``None`` default.
            num_samples: coalitions per input. Must be an int despite the
                ``None`` default.
            training_seed, validation_seed: optional ``torch.manual_seed`` values.
            paired_sampling: forwarded to the coalition sampler.
            bar: wrap the training loader in a tqdm progress bar.
            verbose: print per-epoch losses.
            debug: print tensor shapes for every training batch.
            wait: first epoch at which ``loss2`` and ``loss4`` join the objective.

        Raises:
            ValueError: if ``train_data`` or ``val_data`` has an unsupported type.
        """
        def _dbg(label, tensor, *extra):
            # Shape tracing; silent unless debug=True.
            if debug:
                print(label, tensor.shape, *extra)

        def _detached(value):
            # Running means must not keep autograd graphs alive. loss2 is the
            # plain int 0 before `wait`, so only tensors are detached.
            return value.detach() if torch.is_tensor(value) else value

        # Set up train dataset.
        if isinstance(train_data, np.ndarray):
            train_data = torch.tensor(train_data, dtype=torch.float32)

        if isinstance(train_data, torch.Tensor):
            # Labels are only needed (and can only be generated) for raw
            # tensors; a Dataset is expected to carry its own targets.
            # NOTE(review): the cache name is fixed, so labels cached for one
            # dataset are silently reused for another — confirm or delete the
            # file when switching datasets. pickle.load must only be used on
            # files this notebook wrote itself.
            if os.path.isfile('Cifar10_224_TF_lables_train.pkl'):
                print('Loading saved labels')
                with open('Cifar10_224_TF_lables_train.pkl', 'rb') as f:
                    y_tr = pickle.load(f)
            else:
                y_tr = generate_labels_STFS(train_data, original_model, batch_size)
                with open('Cifar10_224_TF_lables_train.pkl', 'wb') as f:
                    pickle.dump(y_tr, f)
            train_set = TensorDataset(train_data, y_tr)
        elif isinstance(train_data, Dataset):
            train_set = train_data
        else:
            raise ValueError('train_data must be either tensor or a PyTorch Dataset')

        # Set up train data loader: sampling with replacement, rounded up to a
        # whole number of full batches.
        random_sampler = RandomSampler(train_set, replacement=True, num_samples=int(np.ceil(len(train_set) / batch_size))*batch_size)
        batch_sampler = BatchSampler(random_sampler, batch_size=batch_size, drop_last=True)
        train_loader = DataLoader(train_set, batch_sampler=batch_sampler, num_workers=4)

        # Set up validation dataset.
        sampler_surr = UniformSampler(self.num_players)
        sampler = ShapleySampler(self.num_players)
        if validation_seed is not None:
            torch.manual_seed(validation_seed)

        if validation_batch_size is None:
            validation_batch_size = batch_size

        if isinstance(val_data, np.ndarray):
            val_data = torch.tensor(val_data, dtype=torch.float32)

        if isinstance(val_data, torch.Tensor):
            # NOTE(review): same fixed cache name caveat as for the train labels.
            if os.path.isfile('Cifar10_224_TF_lables_val.pkl'):
                print('Loading saved labels')
                with open('Cifar10_224_TF_lables_val.pkl', 'rb') as f:
                    y_val = pickle.load(f)
            else:
                y_val = generate_labels_STFS(val_data, original_model, validation_batch_size)
                with open('Cifar10_224_TF_lables_val.pkl', 'wb') as f:
                    pickle.dump(y_val, f)
            val_set = TensorDataset(val_data, y_val)
        else:
            raise ValueError('val_data must be either tuple of tensors or a PyTorch Dataset')

        val_loader = DataLoader(val_set, batch_size=validation_batch_size, drop_last=True, num_workers=4)

        self.batch_size = batch_size
        self.validation_batch_size = validation_batch_size
        self.num_samples = num_samples
        self.wait = wait

        # Setup for training.
        link = nn.Softmax(dim=-1)
        model = self.model
        device = next(model.parameters()).device
        optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=lr_factor, patience=int(lookback // 2), min_lr=min_lr,verbose=verbose)
        best_loss = 100000000
        best_epoch = 0
        best_model = deepcopy(model)
        val_loss_list = []
        val_loss1_list = []
        val_loss2_list = []
        train_loss_list = []
        train_loss1_list = []
        train_loss2_list = []
        if training_seed is not None:
            torch.manual_seed(training_seed)

        print('OPT_training')
        for epoch in range(max_epochs):
            # Batch iterable.
            batch_iter = tqdm(train_loader, desc='Training epoch') if bar else train_loader

            mean_loss = 0
            mean_loss1 = 0
            mean_loss2 = 0
            mean_loss4 = 0
            N = 0

            for (x, y) in batch_iter:
                # Prepare data.
                x = x.to(device)
                y = y.to(device)

                # Generate subsets (order kept: coalition sampler first).
                S = sampler.sample(batch_size*num_samples, paired_sampling=paired_sampling).to(device=device)
                S_surr = sampler_surr.sample(batch_size).to(device=device)
                _dbg("x", x)
                _dbg("S", S)
                _dbg("S_surr", S_surr)
                S_surr = self.resize(S_surr)
                _dbg("S_surr reshape", S_surr)
                _dbg("y", y, y)

                # Term 1: player-summed output under a random mask vs. labels.
                pred_xs = self(x, S_surr)
                _dbg("pred_xs", pred_xs)
                pred_xs_reshape = pred_xs.reshape(len(x), -1, self.num_players)
                _dbg("pred_xs_reshape", pred_xs_reshape)
                pred_xs_reshape = pred_xs_reshape.permute(0, 2, 1)
                _dbg("pred_xs permute", pred_xs_reshape)
                pred_xs_sum = pred_xs_reshape.sum(dim=1)
                _dbg("pred_xs_sum", pred_xs_sum)

                loss1 = loss_fn1(pred_xs_sum, y)

                # COMPUTE NULL COALITION (eval mode, no gradients): first
                # sample of the batch with every player masked out.
                self.model.eval()
                with torch.no_grad():
                    zeros = torch.zeros_like(S_surr[0]).to(device)
                    _dbg("zeros", zeros)
                    null = self(x[:1], zeros)
                    _dbg("null", null)
                    null_reshape = null.reshape(1, -1, self.num_players)
                    _dbg("null reshape", null_reshape)
                    null_reshape = null_reshape.permute(0, 2, 1)
                    _dbg("null permute", null_reshape)
                    null_sum = null_reshape.sum(dim=1)
                    null = link(null_sum)
                self.model.train()
                _dbg("null", null)

                # Grand coalition: all players present.
                ones = torch.ones_like(S_surr).to(device)
                _dbg("ones", ones)
                pred = self(x, ones)
                _dbg("pred", pred)
                pred_reshape = pred.reshape(len(x), -1, self.num_players)
                _dbg("pred_reshape", pred_reshape)
                pred_reshape = pred_reshape.permute(0, 2, 1)
                _dbg("pred_reshape permute", pred_reshape)
                grand_sum = pred_reshape.sum(dim=1)
                grand = link(grand_sum)
                _dbg("grand", grand)

                # Per-player values shifted so that they sum to y - null.
                pred_eff = additive_efficient_normalization(pred_reshape, y, null)
                _dbg("pred_eff", pred_eff)
                if debug:
                    print("total", pred_eff.sum(dim=1).shape)

                # Evaluate the explainer on num_samples coalitions per input.
                x_tiled = x.unsqueeze(1).repeat(
                    1, num_samples, *[1 for _ in range(len(x.shape) - 1)]
                    ).reshape(batch_size * num_samples, *x.shape[1:])
                _dbg("x_tiled", x_tiled)

                S1 = self.resize(S)
                _dbg("S reshape", S1)

                val = self(x_tiled, S1)
                _dbg("val", val)
                val_reshape = val.reshape(len(x_tiled), -1, self.num_players)
                _dbg("val_reshape", val_reshape)
                val_reshape = val_reshape.permute(0, 2, 1)
                _dbg("val_reshape permute", val_reshape)
                val_sum = val_reshape.sum(dim=1)

                values = link(val_sum)
                _dbg("values", values)

                values = values.reshape(batch_size, num_samples, -1)
                _dbg("values reshape", values)
                S = S.reshape(batch_size, num_samples, self.num_players)
                _dbg("S reshape", S)

                approx = null + torch.matmul(S, pred_eff)
                _dbg("approx", approx)

                loss4 = loss_fn2(y, grand)
                loss4 = loss4*50
                loss1 = loss1*10

                if epoch >= wait:
                    loss2 = loss_fn2(approx, values)
                    loss2 = loss2 * self.num_players
                    loss = loss1 + loss2 + loss4
                else:
                    # Warm-up: only the masked-prediction term is optimised.
                    loss2 = 0
                    loss = loss1

                # Running means, on detached values: accumulating the live
                # tensors kept every batch's autograd graph reachable (loss4's
                # graph is never freed by backward() during warm-up).
                N += len(x)
                mean_loss += len(x) * (_detached(loss) - mean_loss) / N
                mean_loss1 += len(x) * (_detached(loss1) - mean_loss1) / N
                mean_loss2 += len(x) * (_detached(loss2) - mean_loss2) / N
                mean_loss4 += len(x) * (_detached(loss4) - mean_loss4) / N

                # Optimizer step.
                loss.backward()
                optimizer.step()
                model.zero_grad()

            # The objective changes at `wait`, so earlier best losses are not
            # comparable any more.
            if epoch == wait:
                print("reset best loss")
                best_loss = 100000000

            if verbose:
                print('----- Epoch = {} -----'.format(epoch + 1))
                print('Train loss = {:.6f}'.format(mean_loss))
                print('Train loss1 = {:.6f}'.format(mean_loss1))
                if epoch >= wait:
                    print('Train loss2 = {:.6f}'.format(mean_loss2))
                print('Train loss4 = {:.6f}'.format(mean_loss4))
                print('')

            # Evaluate validation loss. The validation loader is batched with
            # validation_batch_size, so that size (not the training batch
            # size) must drive the sampling and reshapes inside validate_STFS.
            self.model.eval()
            val_loss, val_loss1, val_loss2, val_loss4 = validate_STFS(self, loss_fn1, loss_fn2, val_loader, validation_batch_size, num_samples, sampler, sampler_surr, paired_sampling, epoch)
            self.model.train()

            # Print progress.
            if verbose:
                print('Val loss = {:.6f}'.format(val_loss))
                print('Val loss1 = {:.6f}'.format(val_loss1))
                if epoch >= wait:
                    print('Val loss2 = {:.6f}'.format(val_loss2))
                print('Val loss4 = {:.6f}'.format(val_loss4))
                print('')

            scheduler.step(val_loss)
            val_loss_list.append(val_loss)
            val_loss1_list.append(val_loss1)
            val_loss2_list.append(val_loss2)
            train_loss_list.append(mean_loss)
            train_loss1_list.append(mean_loss1)
            train_loss2_list.append(mean_loss2)

            # Check if best model.
            if val_loss < best_loss and epoch > 0:
                best_loss = val_loss
                best_model = deepcopy(model)
                best_epoch = epoch
                if verbose:
                    print('\t=> New best epoch, loss = {:.4f}'.format(val_loss))
                    print('')
            elif epoch - best_epoch == lookback:
                if verbose:
                    print('Stopping early')
                break

        # Clean up: restore the best weights. load_state_dict also restores
        # buffers (e.g. BatchNorm running statistics), which copying only
        # `parameters()` left at their last-epoch values.
        model.load_state_dict(best_model.state_dict())

        self.val_loss_list = val_loss_list
        self.val_loss1_list = val_loss1_list
        self.val_loss2_list = val_loss2_list
        self.train_loss_list = train_loss_list
        self.train_loss1_list = train_loss1_list
        self.train_loss2_list = train_loss2_list
        self.model.eval()


    def __call__(self, x, S):

        return self.model((x,S))
    

    def shap_values(self, x, debug=False):

        # Data conversion.
        if isinstance(x, np.ndarray):
            x = torch.tensor(x, dtype=torch.float32)
        elif isinstance(x, torch.Tensor):
            pass
        else:
            raise ValueError('data must be np.ndarray or torch.Tensor')

        # Ensure null coalition is calculated.
        device = next(self.model.parameters()).device
        link=nn.Softmax(dim=-1)
        x=x.to(device)
        
        # Generate explanations.
        with torch.no_grad():
            # Calculate grand coalition (for normalization).
            if debug:
                print("x",x.shape)
            

            # zeros=torch.zeros(1, self.num_players, device=device)
            # print("zeros",zeros.shape)
            # zeros=self.resize(zeros)
            # print("zeros",zeros.shape)

            zeros=torch.zeros(1, x.shape[-2], x.shape[-1], device=device)
            if debug:
                print("zeros",zeros.shape)

            # ones=torch.ones(x.shape[0], self.num_players, device=device)
            # print("ones",ones.shape)
            # ones=self.resize(ones)
            # print("ones",ones.shape)

            # ones=torch.ones_like(x).to(device)
            ones=torch.ones(x.shape[0], 1, x.shape[-2], x.shape[-1], device=device)
            if debug:
                print("ones",ones.shape)

            null=self.__call__(x[0],zeros)
            if debug:
                print("null",null.shape)
            null_reshape = null.reshape(1, -1, self.num_players)
            if debug:
                print("null_reshape",null_reshape.shape)
            null_reshape = null_reshape.permute(0, 2, 1)
            if debug:
                print("null_reshape permute",null_reshape.shape)
            null_sum = null_reshape.sum(dim=1)
            null=link(null_sum)
            # if len(null.shape) == 1:
            #     null = null.reshape(1, 1)

            # ones=torch.ones(1, self.num_players, device=device)
            # ones=torch.ones_like(zeros).to(device)
            # match the dimensionality of the input
            # ones=self.resize(ones
            
            # make ones of the follwing dimensionality(x.shape[0],1,x.shape[2],x.shape[3])
            # ones=torch.ones(1, self.num_players, device=device)
            # ones=self.resize(ones)



            pred=self.__call__(x, ones)
            if debug:
                print("pred",pred.shape)
            image_shape=pred.shape
            pred_reshape = pred.reshape(len(x), -1, self.num_players)
            if debug:
                print("pred_reshape",pred_reshape.shape)
            pred_reshape = pred_reshape.permute(0, 2, 1)
            if debug:
                print("pred_reshape permute",pred_reshape.shape)
            
            # grand_sum = pred_reshape.sum(dim=1)
            # grand=link(grand_sum)

            y=self.bbm(x)

            pred = pred_reshape #additive_efficient_normalization(pred_reshape, y, null)
            # pred = additive_efficient_normalization(pred_reshape, y, null)

            pred = pred.permute(0, 2, 1)
            pred = pred.reshape(image_shape)

        return pred.cpu().data.numpy()

## Train

In [None]:
from importlib import reload
import utils.fastshap
import utils.unet
reload(utils.unet)
reload(utils.fastshap)
from utils.unet import UNet
from utils.fastshap import FastSHAP
import utils.utils
reload(utils.utils)
from utils.utils import MaskLayer2d, KLDivLoss, DatasetInputOnly, UniformSampler, ShapleySampler, DatasetRepeat

In [None]:
class ModelWrapperTFtoPT():
    """Expose a model with a numpy ``predict`` method as a tensor-in, tensor-out callable.

    Inputs arrive channels-first (N, C, H, W) and are transposed to
    channels-last (N, H, W, C) before being handed to ``model.predict``.
    """

    def __init__(self, model,device):
        # `device` is where the returned prediction tensors are placed.
        self.model, self.device = model, device

    def __call__(self, x):
        # Tensor -> numpy, channels-first -> channels-last.
        channels_last = np.transpose(x.cpu().numpy(), (0, 2, 3, 1))
        scores = self.model.predict(channels_last)
        # numpy -> float32 tensor on the configured device.
        return torch.tensor(scores, dtype=torch.float32).to(self.device)

In [None]:
# Select device: first CUDA GPU when one is available, otherwise the CPU, so
# the later `.to(device)` calls do not fail on a machine without CUDA.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

In [None]:
# Wrap the Keras black-box model so it can be called on PyTorch tensors; its
# predictions are returned as tensors on `device`.
original_model = ModelWrapperTFtoPT(bb_model, device)

In [None]:
# View the cached arrays as float32 tensors. torch.as_tensor reuses the numpy
# memory when the dtype already matches, whereas torch.tensor always made a
# second full copy of each array. The tensors therefore alias
# train_numpy / val_numpy; do not modify those arrays in place afterwards.
train_set_tensor = torch.as_tensor(train_numpy, dtype=torch.float32)
val_set_tensor = torch.as_tensor(val_numpy, dtype=torch.float32)
train_set_tensor.shape, val_set_tensor.shape

In [None]:
from LightningSHAP_Image_NN import ModifiedResNet50

In [None]:
# Checkpoint of a previously trained explainer. Kept in a single constant so the
# existence check and the load cannot drift apart.
EXPLAINER_CKPT = 'Imagenette_LS_TF_N_x10_NS=4_L4_W2.pt'
explainer_is_cached = os.path.isfile(EXPLAINER_CKPT)

if explainer_is_cached:
    print('Loading saved explainer model')
    # map_location lets a checkpoint saved from another device be loaded here.
    # NOTE: torch.load unpickles the file -- only load checkpoints you trust.
    explainer = torch.load(EXPLAINER_CKPT, map_location=device).to(device)
else:
    # Set up explainer model: the mask layer feeds a 4-channel input
    # (in_channels=4 below) into the ResNet-50 based explainer.
    explainer = nn.Sequential(
        MaskLayer2d(value=0, append=True),
        ModifiedResNet50(num_input_channels=4, num_classes=10, init_type='kaiming_uniform')
    ).to(device)

# Set up LightningSHAP object (identical for a loaded or a fresh explainer)
lshap = LightningSHAP(explainer, original_model, width=224, height=224, superpixel_size=16)

if not explainer_is_cached:
    # Set up datasets
    lshap_train = train_set_tensor
    lshap_val = val_set_tensor

    # Train
    lshap.train_original_model(
        lshap_train,
        lshap_val,
        original_model,
        batch_size=64,
        num_samples=4,
        max_epochs=100,
        paired_sampling=True,
        loss_fn1=KLDivLoss(),
        loss_fn2=nn.MSELoss(),
        validation_samples=1,
        lookback=10,
        lr=2e-4,
        min_lr=1e-8,
        weight_decay=1e-2,
        lr_factor=0.5,
        verbose=True,
        debug=False,
        bar=True,
        wait=0
    )

In [None]:
# Save explainer
# (the heading above used to be bare text, which made this cell a SyntaxError)
# NOTE(review): this writes f'{dataset}_LS.pt', while the loading cell looks for
# 'Imagenette_LS_TF_N_x10_NS=4_L4_W2.pt' -- confirm which file name is intended.
explainer.cpu()  # move to CPU before pickling
torch.save(explainer, f'{dataset}_LS.pt')
explainer.to(device)  # move back so later cells can keep using it on `device`

# Results

## Compute Shapley values (SV)

In [None]:
# Load the previously exported images, labels and predictions from ./images.
# NOTE: allow_pickle=True executes pickled payloads -- only use trusted files.
images_dir = os.path.join('', 'images')
images, labels, predictions = (
    np.load(os.path.join(images_dir, file_name), allow_pickle=True)
    for file_name in ('processed_images.npy', 'labels.npy', 'predictions.npy')
)

In [None]:
# Unbatch the test dataset so the next cell can iterate over it one element at a time.
# WARNING: this rebinds ds_test to its own unbatched version, so the cell is not
# idempotent -- running it a second time calls unbatch() on the already-unbatched data.
ds_test = ds_test.unbatch()

In [None]:
# Materialise the first 1000 (image, label) pairs of the test dataset as NumPy arrays.
processed_imgs, labels2 = [], []
for x, y in itertools.islice(ds_test, 1000):
    processed_imgs.append(x.numpy())
    labels2.append(y.numpy())

In [None]:
# Sanity check: each saved label must equal the argmax of the matching dataset label.
# Any mismatching pair is printed; no output means the two sources agree.
for saved_label, dataset_label in zip(labels, labels2):
    if saved_label != np.argmax(dataset_label):
        print(saved_label, dataset_label)

In [None]:
# Sanity check: the saved images must match the ones just read from the dataset.
for saved_img, dataset_img in zip(images, processed_imgs):
    if np.allclose(saved_img, dataset_img):
        continue
    print('not equal')

In [None]:
# Sanity check: the first 1000 entries of test_numpy must equal the dataset images
# once the latter have their last axis moved to the front (transpose (2, 0, 1)).
for reference_img, dataset_img in zip(test_numpy[:1000], processed_imgs):
    if np.allclose(reference_img, np.transpose(dataset_img, (2, 0, 1))):
        continue
    print('not equal')

In [None]:
# Tensor holding the first 1000 test images that will be explained.
to_process = torch.tensor(test_numpy[:1000])
to_process.shape

In [None]:
# Time the explanation of the whole batch. perf_counter() is monotonic and
# high-resolution, so the measured duration cannot be distorted by system
# clock adjustments the way a time.time() difference can.
start = time.perf_counter()
values = lshap.shap_values(to_process.to(device))
explaining_time = time.perf_counter() - start
print(values.shape)

# Repeat every attribution 16x along axes 2 and 3 (16 matches the
# superpixel_size passed to LightningSHAP) to get a per-pixel map.
values2 = np.repeat(np.repeat(values, 16, axis=2), 16, axis=3)
print(values2.shape)
print(f"Explaining time: {explaining_time}")

In [None]:
# Persist the measured explanation time. Create the output directory first:
# open(..., 'wb') raises FileNotFoundError when the parent folder is missing.
time_out_dir = "MODEL/LIGHTNINGSHAP"
os.makedirs(time_out_dir, exist_ok=True)
with open(os.path.join(time_out_dir, 'explaining_time.pkl'), 'wb') as f:
    pickle.dump(explaining_time, f)

In [None]:
# Swap the first two axes of the upscaled attributions; the trailing axes stay put.
values3 = values2.transpose(1, 0, 2, 3)
values3.shape

In [None]:
# Persist the transposed attributions. Create the output directory first:
# open(..., 'wb') raises FileNotFoundError when the parent folder is missing.
shap_out_dir = "MODEL/LIGHTNINGSHAP"
os.makedirs(shap_out_dir, exist_ok=True)
with open(os.path.join(shap_out_dir, 'shap_values.pkl'), 'wb') as f:
    pickle.dump(values3, f)

## Visualization

In [None]:
import matplotlib.pyplot as plt

In [None]:
# Class index of the first 1000 validation samples. islice stops the iteration
# after 1000 elements instead of walking the whole dataset and slicing afterwards.
max_vis_samples = 1000
targets = np.array([
    np.argmax(y.numpy())
    for _, y in tqdm(itertools.islice(ds_val, max_vis_samples), total=max_vis_samples)
])
print(targets.shape)

num_classes = np.max(targets) + 1

# Pick one random sample index per class, in class order (so y[c] == c below).
inds_lists = [np.where(targets == cat)[0] for cat in range(num_classes)]
inds = [np.random.choice(cat_inds) for cat_inds in inds_lists]
print(inds)

x = torch.tensor(val_numpy[inds])
y = targets[inds]
print(x.shape, y)

# Black-box predictions and explainer attributions for the selected images.
pred = original_model(x.to(device)).detach().cpu().numpy()
values = lshap.shap_values(x.to(device))

# upscale values from 10x10x14x14 to 10x10x224x224. use np.repeat for this
values2 = np.repeat(np.repeat(values, 16, axis=2), 16, axis=3)
values2.shape

In [None]:
def restore_original_image_from_array(x, data_format='channels_last'):
    """Add the per-channel mean back to an image array and reverse its channel order.

    The input is never modified: the work is done on a copy, so calling the
    function repeatedly on the same array (e.g. re-running a notebook cell, or
    passing ``tensor.numpy()``, which shares memory with the tensor) always
    gives the same result.

    Args:
        x: image array. With ``data_format='channels_first'`` the channel axis
            is axis 0 for a 3-D array and axis 1 otherwise; for any other
            ``data_format`` it is the last axis.
        data_format: ``'channels_first'`` or ``'channels_last'`` (default).

    Returns:
        Array of the same shape with ``mean`` added to channels 0, 1, 2 and the
        channel axis reversed ('BGR' -> 'RGB'). Floating inputs keep their
        dtype; other dtypes are promoted to float64 so the mean can be added.
    """
    mean = [103.939, 116.779, 123.68]

    x = np.asarray(x)
    out_dtype = x.dtype if np.issubdtype(x.dtype, np.floating) else np.float64
    x = x.astype(out_dtype, copy=True)  # private copy: never touch the caller's data

    if data_format == 'channels_first':
        channel_axis = 0 if x.ndim == 3 else 1
    else:
        channel_axis = x.ndim - 1

    # Undo the zero-centering: add the mean pixel value back, channel by channel.
    for channel, offset in enumerate(mean):
        selector = [slice(None)] * x.ndim
        selector[channel_axis] = channel
        x[tuple(selector)] += offset

    # 'BGR'->'RGB'
    return np.flip(x, axis=channel_axis)

In [None]:
# x.numpy() shares memory with the tensor x, so hand over a copy: otherwise the
# restore step would modify x itself and re-running this cell would add the
# mean again on top of the previous result.
x2 = restore_original_image_from_array(x.numpy().copy(), data_format='channels_first')

In [None]:
# Display names of the classes, indexed by class id (loop-invariant, so defined once).
classes = ['Tench', 'English\n Springer', 'Cassette\n Player', 'Chain Saw', 'Church', 'French\n Horn', 'Garbage\n Truck', 'Gas Pump', 'Golf Ball', 'Parachute']

# One row per selected image: column 0 shows the image, columns 1.. show the
# attribution map for each class.
fig, axarr = plt.subplots(num_classes, num_classes + 1, figsize=(22, 20))

for row in range(num_classes):
    # Image: move the channel axis last for imshow and bring values into [0, 1].
    im = x2[row].transpose(1, 2, 0).astype(float)
    if im.max() > 1.0:
        im = im / 255.0
    im = np.clip(im, a_min=0, a_max=1)

    image_ax = axarr[row, 0]
    image_ax.imshow(im)
    image_ax.set_xticks([])
    image_ax.set_yticks([])
    image_ax.set_ylabel('{}'.format(classes[y[row]]), fontsize=14)

    # Explanations: one symmetric colour scale per row so maps are comparable.
    m = np.abs(values2[row]).max()
    for col in range(num_classes):
        ax = axarr[row, col + 1]
        ax.imshow(values2[row, col], cmap='seismic', vmin=-m, vmax=m)
        ax.set_xticks([])
        ax.set_yticks([])

        # Predicted probability below each map; bold for the true class.
        label_kwargs = {'fontweight': 'bold'} if col == y[row] else {}
        ax.set_xlabel('{:.2f}'.format(pred[row, col]), fontsize=12, **label_kwargs)

        # Class labels: column `col` shows class `col` (values2[row, col],
        # pred[row, col]), so index the names by the class id directly
        # instead of going through the label of the col-th sampled image.
        if row == 0:
            ax.set_title('{}'.format(classes[col]), fontsize=14)

plt.tight_layout()
plt.show()