In [None]:
import os, math, collections

# Fetch and unpack the CIFAR-10 python archive on first run. The guard only checks
# that the 'CIFAR10_data' directory exists, so an interrupted download or extraction
# is not retried automatically -- delete the directory to force a fresh download.
if not os.path.exists('CIFAR10_data'):
    # IPython shell escapes: download the tarball, create the target dir, extract into it.
    !wget https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz
    !mkdir CIFAR10_data
    !tar -xf cifar-10-python.tar.gz -C CIFAR10_data

from tqdm import tqdm

import matplotlib.pyplot as plt
import tensorflow as tf
import numpy as np

from utils import save, scale, unpickle, batch_run, preprocess
from deepexplain.tensorflow import DeepExplain
from models import CIFAR_CNN
from trainer import Trainer

# Location of the extracted CIFAR-10 batch files.
datadir = 'CIFAR10_data/cifar-10-batches-py/'

# Training batch paths, ordered by the trailing digit of the file name.
batches_train = sorted(
    (datadir + fname for fname in os.listdir(datadir) if 'data_batch' in fname),
    key=lambda path: int(path[-1]),
)

# Path of the single held-out test batch.
batch_test = datadir + 'test_batch'

In [None]:
def _to_nhwc(flat):
    '''Reshape flat CIFAR rows [N, 3*32*32] into channels-last images [N, 32, 32, 3].'''
    return np.transpose(np.reshape(flat, [-1,3,32,32]), [0,2,3,1])


# Collect every training batch first and concatenate once, instead of re-copying
# the growing array on each iteration.
image_chunks, label_chunks = [], []

for batch_path in tqdm(batches_train):

    batch = unpickle(batch_path)

    image_chunks.append(_to_nhwc(batch[b'data'].astype(np.float32)))
    label_chunks.append(batch[b'labels'])

cifar = np.concatenate(image_chunks, axis=0)
labels = np.concatenate(label_chunks, axis=0)

# Cast the test images to float32 exactly like the training images. Without the cast
# the test array keeps its raw dtype (uint8 per the CIFAR-10 format -- TODO confirm
# against `unpickle`), so assigning the float `cifar_mean` into a copy of it during
# occlusion would silently truncate the fill colour for the test set only.
test_batch = unpickle(batch_test)
cifar_test = _to_nhwc(test_batch[b'data'].astype(np.float32))
labels_test = np.array(test_batch[b'labels'])

# Copies rescaled from [0, 255] to [-1, 1], paired with their labels.
data_train = (cifar / 127.5 - 1.0, labels)
data_test = (cifar_test / 127.5 - 1.0, labels_test)

# Per-channel mean over all training pixels (still in [0, 255] units); used as the
# fill colour when pixels are occluded.
cifar_mean = np.mean(cifar, axis=(0,1,2))

In [None]:
def random_remove(images, percentile, keep=False):
    '''
    Replace a random subset of pixels with the dataset mean colour (`cifar_mean`).

    images     : array whose last axis is the channel axis; one independent draw is
                 made per position over the leading axes (images.shape[:-1])
    percentile : scalar between 0 and 100; every position draws 1 with probability
                 (100 - percentile) / 100 and 0 otherwise
    keep       : if true the positions that drew 1 are replaced; otherwise the
                 positions that drew 0 are replaced

    Returns a modified copy; the input array is not touched.
    '''

    occluded = np.copy(images)

    survive_prob = (100 - percentile) / 100
    draws = np.random.binomial(1, survive_prob, size=occluded.shape[:-1])

    # Which draw value marks a pixel for replacement depends on the mode.
    replaced_value = 1 if keep else 0
    occluded[draws == replaced_value] = cifar_mean

    return occluded


def remove(images, attributions, percentile, keep=False, random=False):
    '''
    Occlude pixels by attribution rank, filling them with the dataset mean colour (`cifar_mean`).

    images       : tensor of shape [N,H,W,C]
    attributions : tensor of shape [N,H,W]
    percentile   : scalar between 0 and 100, inclusive
    keep         : if true, pixels whose attribution lies below the per-image
                   (100 - percentile)-th percentile are replaced (the top part is kept);
                   otherwise pixels whose attribution lies above it are replaced
    random       : accepted for signature compatibility; not used by this function

    Returns a modified copy; the input images are not touched.
    '''

    result = np.copy(images)

    # One cut-off per image, kept broadcastable against [N,H,W].
    cutoffs = np.percentile(attributions, 100 - percentile, axis=(1,2), keepdims=True)

    if keep:
        occlude_mask = attributions < cutoffs
    else:
        occlude_mask = attributions > cutoffs

    result[occlude_mask] = cifar_mean

    return result


def occlude_dataset(DNN, de, attribution, percentiles, test=False, keep=False, random=False, batch_size=1000, savedir=''):
    '''
    Build one occluded copy of the train or test images per percentile and save each to disk.

    DNN         : model object; its `yv` and `X` attributes are forwarded to `de.explain`
    de          : explainer object providing `explain(method, target, input, data)`
    attribution : attribution method key; also embedded in the output file name
    percentiles : iterable of percentiles, one output file is written for each
    test        : if true use `cifar_test`, otherwise `cifar`
    keep        : forwarded to `remove` / `random_remove`
    random      : if true skip attribution maps and occlude with `random_remove`
    batch_size  : number of images processed per batch
    savedir     : prefix prepended to the output file name

    Files are written as savedir + '<train|test>_<attribution>_<percentile>.pickle'.

    NOTE(review): `Canny` is not defined or imported in the visible part of this
    notebook -- confirm it is available before using an attribution containing 'edge'.
    '''

    Xs, ys = (cifar_test, labels_test) if test else (cifar, labels)

    total_batch = math.ceil(len(Xs) / batch_size)

    def window(index):
        # Slice selecting the index-th batch.
        return slice(index * batch_size, (index + 1) * batch_size)

    if not random:

        per_batch_maps = []

        for index in tqdm(range(total_batch)):

            raw_images = Xs[window(index)]
            scaled_images = scale(raw_images)

            if 'edge' in attribution:
                heat = Canny(raw_images)
            elif 'rectgrad' in attribution:
                # Keep only positive attributions and sum them over the channel axis.
                heat = de.explain(attribution, DNN.yv, DNN.X, scaled_images)
                heat = np.sum(np.where(heat > 0, heat, 0.0), axis=-1)
            else:
                heat = preprocess(de.explain(attribution, DNN.yv, DNN.X, scaled_images), 0, 100)

            # Add small noise so np.percentile functions correctly in feature removal process
            heat += np.random.normal(scale=1e-4, size=heat.shape)
            per_batch_maps.append(heat)

        hmaps = np.concatenate(per_batch_maps, axis=0)

    for percentile in tqdm(percentiles):

        occluded_batches = []

        for index in range(total_batch):

            batch_window = window(index)
            batch_images = Xs[batch_window]

            if random:
                batch_images = random_remove(batch_images, percentile, keep)
            else:
                batch_images = remove(batch_images, hmaps[batch_window], percentile, keep)

            occluded_batches.append(scale(batch_images))

        split_name = 'test' if test else 'train'
        save(np.concatenate(occluded_batches, axis=0), savedir + '{}_{}_{}.pickle'.format(split_name, attribution, percentile))


def roar_kar(keep, train_only=False, n_runs=3):
    '''
    Generate occluded datasets, retrain a fresh model on each, and report mean test accuracy.

    keep       : forwarded to `occlude_dataset`; also selects the output directory
                 ('KAR/...' if true, 'ROAR/...' otherwise)
    train_only : if true, skip dataset generation and reuse the pickles already on disk
    n_runs     : how many times every (method, percentile) model is retrained; the
                 reported accuracy is the mean over these runs

    Returns a dict mapping each method's display name to an array holding one mean
    test accuracy per percentile.
    '''

    logdir = 'tf_logs/standard/'

    # Occluded datasets live next to the logs, under 'KAR' or 'ROAR' instead of 'tf_logs'.
    savedir = logdir.replace('tf_logs', 'KAR' if keep else 'ROAR')
    os.makedirs(savedir, exist_ok=True)

    percentiles = [10, 30, 50, 70, 90]

    # These methods are run with a smaller batch size (presumably to bound memory -- TODO confirm).
    heavy_methods = ['SmoothGrad', 'IntegGrad', 'DeepLIFT']

    # Display name -> attribution key. For 'Random' the key is only used in file names,
    # because occlude_dataset is called with random=True and computes no attribution maps.
    attribution_methods = collections.OrderedDict([
        ('Random'           , 'zero'),
        ('RectGrad'         , 'rectgrad'),
        ('Saliency Map'     , 'saliency'),
        ('Guided BP'        , 'guidedbp'),
        ('SmoothGrad'       , 'smoothgrad'),
        ('Gradient * Input' , 'grad*input'),
        ('IntegGrad'        , 'intgrad'),
        ('DeepLIFT'         , 'deeplift'),
    ])

    if not train_only:

        tf.reset_default_graph()

        sess = tf.InteractiveSession()

        # Always release the session, even if loading or explaining raises.
        try:

            with DeepExplain(session=sess, graph=sess.graph) as de:

                DNN = CIFAR_CNN(logdir, 'CNN')
                DNN.load(sess)

                for name, key in attribution_methods.items():

                    batch_size = 2500 if name in heavy_methods else 5000
                    is_random = name == 'Random'

                    occlude_dataset(DNN, de, key, percentiles, False, keep, is_random, batch_size, savedir)
                    occlude_dataset(DNN, de, key, percentiles, True,  keep, is_random, batch_size, savedir)

        finally:
            sess.close()

    ress = {name : [] for name in attribution_methods.keys()}

    for _ in range(n_runs):

        for name, key in attribution_methods.items():

            res = []

            for p in percentiles:

                # Local names deliberately differ from the notebook-level data_train / data_test.
                occ_train = (unpickle(savedir + '{}_{}_{}.pickle'.format('train', key, p)), labels)
                occ_test = (unpickle(savedir + '{}_{}_{}.pickle'.format('test', key, p)), labels_test)

                tf.reset_default_graph()

                DNN = CIFAR_CNN('tf_logs/roar_kar/', 'CNN')

                sess = tf.InteractiveSession()

                # Close the per-model session even when training or evaluation fails,
                # so a failed run does not leave the session open.
                try:
                    sess.run(tf.global_variables_initializer())

                    trainer = Trainer(sess, DNN, occ_train)
                    trainer.train(20, p_epochs=30)

                    acc = DNN.evaluate(sess, occ_test)
                finally:
                    sess.close()

                print('{}{} | Test Accuracy : {:.5f}'.format(name, p, acc))

                res.append(acc)

            ress[name].append(res)

    # Average over runs, leaving one value per percentile.
    res_mean = {name: np.mean(runs, axis=0) for name, runs in ress.items()}

    print(res_mean)

    return res_mean

In [None]:
# Run the full pipeline with keep=False, i.e. the variant that writes to / reads from
# the 'ROAR' output directory; `results` holds the per-method mean test accuracies.
results = roar_kar(keep=False)