In [None]:
import numpy as np
import matplotlib.pyplot as plt
import os
import sys
import time
import copy
import torch
import torch.nn as nn

list_pathstoadd = ["../../../", "../../../../PyDmed/", "../../../../TGP/",\
                   "../../../../uda_and_microscopyimaging_repo2/Src/BoVWPipeline/"]
for path in list_pathstoadd:
    if(path not in sys.path):
        sys.path.append(path)
#import generalGPmodule
import skimage
from skimage import io
import torchofgp.tgp as tgp
import relatedwork
from relatedwork.utils.generativemodels import ResidualEncoder
import projutils
from projutils.kernelmodules import Resnet18List,\
        Resnet50List, SqueezeNetList,\
        MultiResnet18ListAndOneLayer,\
        MultiResnet50ListAndOneLayer, TinyResNet18List,\
        Resnet18BackboneKernel, Resnet50BackboneKernel,\
        Resnet34BackboneKernel,\
        Resnet18BackboneKernelDivideAvgPool,\
        Resnet34BackboneKernelDivideAvgPool,\
        Resnet50BackboneKernelDivideAvgPool,\
        Resnet101BackboneKernelDivideAvgPool,\
        Resnet152BackboneKernelDivideAvgPool

import akresnetforcifar
from akresnetforcifar import *
import loadcifar
from loadcifar import *
import relatedwork.utils.transforms
tfm_denormalize = relatedwork.utils.transforms.ImgnetDenormalize()

In [None]:
# ==== settings ====
# -- data --
idx_trainingbatch = 1  # CIFAR-10 training file used: "data_batch_{idx_trainingbatch}"
flag_enabledataaugmentation = True  # only forwarded to the training Cifar10Dataset
num_classes = 10
batchsize = 10
idx_split = 0  # NOTE(review): not referenced in the visible cells

# -- checkpoint to evaluate --
fname_gpmodel = (
    "TrainingHistory/Phase2/Nov17_Explainattention_onBeluga/"
    "output_explainattention_version4_afterepoch2.pt"
)

# -- kernel / network architecture --
int_mode_modulekernel = 16  # backbone selector read by ModuleF1
du_per_class = 20
dim_wideoutput = 1024
dim_before_wideoutput_attention = 200
int_exposedclass = None  # None -> vis_inspectkernel draws one figure per output column

# -- options forwarded to tgp.TGPModule --
flag_train_memefficient = False
memefficeint_heads_in_compgraph = None  # (sic) spelling must match its use below
flag_efficient = True
flag_detachcovpvn = True
flag_controlvariate = True
flag_setcovtoOne = False
int_mode_controlvariate = 2

In [None]:
# Use the first CUDA GPU when one is visible, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print(device)

In [None]:
#make datasets ====
# Folder holding the extracted CIFAR-10 python batches (shared by every split).
dir_cifar10 = "./cifar-10-batches-py"
# The same training batch file backs both ds_train and ds_recurring.
fname_trainingbatch = "data_batch_{}".format(idx_trainingbatch)

# Training split ("train" mode, augmentation controlled by the settings cell).
ds_train = Cifar10Dataset(
    rootdir = dir_cifar10,
    fname_batchfile = fname_trainingbatch,
    str_trainoreval = "train",
    flag_enabledataaugmentation = flag_enabledataaugmentation
)
# Recurring split: same images as ds_train, loaded in "eval" mode
# (presumably without augmentation -- confirm in Cifar10Dataset).
ds_recurring = Cifar10Dataset(
    rootdir = dir_cifar10,
    fname_batchfile = fname_trainingbatch,
    str_trainoreval = "eval"
)
# CIFAR-10 test batch.
ds_test = Cifar10Dataset(
    rootdir = dir_cifar10,
    fname_batchfile = "test_batch",
    str_trainoreval = "eval"
)

In [None]:
# DataLoader is not imported in the import cell. It is presumably re-exported
# by one of the star imports, so import it explicitly instead.
from torch.utils.data import DataLoader

# Shuffled loader over the training split.
dl_train = DataLoader(ds_train, batch_size=batchsize, shuffle=True, num_workers=0)
# Shuffled loader over the recurring split.
dl_recurring = DataLoader(ds_recurring, batch_size=batchsize, shuffle=True, num_workers=0)
# The test loader keeps dataset order.
dl_test = DataLoader(ds_test, batch_size=batchsize, shuffle=False, num_workers=0)

In [None]:
class ClampAndTanh(torch.nn.Module):
    '''
    Clamps its input elementwise to [minval, maxval], then applies tanh.

    Inputs:
        - minval, maxval: clamping bounds (default -1.0 and 1.0).
    '''
    # The constructor was previously misspelled `__init_`, so it never ran:
    # the bounds could not be passed in, and forward() hard-coded -1/1.
    def __init__(self, minval=-1.0, maxval=1.0):
        # nn.Module.__init__ must run before attributes are assigned.
        super(ClampAndTanh, self).__init__()
        self.minval, self.maxval = minval, maxval

    def forward(self, x):
        # torch.tanh replaces the deprecated torch.nn.functional.tanh.
        return torch.tanh(torch.clamp(x, self.minval, self.maxval))
    
def initweights_to_zero(m):
    '''
    Init hook (e.g. for `module.apply`): zeroes the weight of nn.Linear /
    nn.Conv2d layers and fills their bias with one random scalar.

    Inputs:
        - m: a torch.nn.Module. Layers of any other exact type are untouched.
    '''
    # Exact-type match as before; the original set listed nn.Linear twice.
    if type(m) in (nn.Linear, nn.Conv2d):
        torch.nn.init.zeros_(m.weight)
        # Layers built with bias=False have m.bias == None. Filling it used to raise.
        if m.bias is not None:
            # A single N(0, 0.1^2) draw shared by every bias entry. TODO:check
            m.bias.data.fill_(np.random.randn()*0.1)
    
class ModuleF1(torch.nn.Module):
    '''
    Kernel module "f1". It builds the backbone selected by the global
    `int_mode_modulekernel` and appends two singleton dims to its output.

    The backbone is built from the globals `num_classes` and `du_per_class`.
    Some modes add `num_backbones` or a fixed `scale_macrokernel`.
    '''
    def __init__(self, module_caller):
        '''
        Inputs:
            - module_caller: the owning module. It is accepted for interface
              compatibility and is not used.
        Raises:
            - ValueError: if `int_mode_modulekernel` is not a supported mode.
              Previously only a message was printed and `self.module` was never
              created, so the failure surfaced later in forward().
        '''
        super(ModuleF1, self).__init__()
        #make internals ===
        self.module = self._make_kernelmodule(int_mode_modulekernel)
        print("<><><><><><><><><> finisehd creating module_tail <><><><><><><><>.")

    @staticmethod
    def _make_kernelmodule(int_mode):
        '''Builds and returns the backbone for `int_mode`; raises ValueError if the mode is unknown.'''
        # Each builder is a lambda, so only the selected mode's constructor name
        # is looked up. Some names (e.g. PrimaryNetwork, Resnet34List) are not
        # imported explicitly in this notebook.
        dict_mode_to_builder = {
            1: lambda kw: TinyResNet18List(scale_macrokernel=1, **kw),   #a list of resnet18s
            2: lambda kw: Resnet50List(**kw),                             #a list of resnet50s
            3: lambda kw: TinyResNet18List(scale_macrokernel=2, **kw),   #resnet18s eroded by 2
            4: lambda kw: SqueezeNetList(**kw),                           #a list of squeezenets
            5: lambda kw: TinyResNet18List(scale_macrokernel=4, **kw),   #resnet18s eroded by 4
            # NOTE(review): `num_backbones` is not set in the visible settings cell;
            # modes 6 and 7 need it to exist as a global.
            6: lambda kw: MultiResnet18ListAndOneLayer(num_backbones=num_backbones, **kw),
            7: lambda kw: MultiResnet50ListAndOneLayer(num_backbones=num_backbones, **kw),
            9: lambda kw: PrimaryNetwork(num_GPs=kw["num_classes"], du_per_class=kw["du_per_class"]),
            12: lambda kw: Resnet18BackboneKernel(**kw),
            13: lambda kw: Resnet50BackboneKernel(**kw),
            14: lambda kw: Resnet34BackboneKernel(**kw),
            15: lambda kw: Resnet18BackboneKernelDivideAvgPool(**kw),
            16: lambda kw: Resnet50BackboneKernelDivideAvgPool(**kw),
            17: lambda kw: Resnet34BackboneKernelDivideAvgPool(**kw),
            18: lambda kw: Resnet101BackboneKernelDivideAvgPool(**kw),
            19: lambda kw: Resnet152BackboneKernelDivideAvgPool(**kw),
            21: lambda kw: Resnet50List(scale_macrokernel=2.0, **kw),
            22: lambda kw: Resnet34List(scale_macrokernel=2.0, **kw),
            23: lambda kw: Resnet101List(scale_macrokernel=2.0, **kw),
            24: lambda kw: Resnet152List(scale_macrokernel=2.0, **kw),
        }
        if int_mode not in dict_mode_to_builder:
            raise ValueError("Unknown mode_modulekernel {}.".format(int_mode))
        kw = {"num_classes": num_classes, "du_per_class": du_per_class}
        return dict_mode_to_builder[int_mode](kw)

    def set_rng_outputheads(self, rng_outputhead):
        '''Forwards `rng_outputhead` to the backbone's own set_rng_outputheads.'''
        self.module.set_rng_outputheads(rng_outputhead)

    def forward(self, x):
        '''Runs the backbone and appends two trailing singleton dims to its output.'''
        toret = self.module(x)
        toret = toret.unsqueeze(-1).unsqueeze(-1)
        return toret
            
class MainModule(nn.Module):
    '''
    Wraps the attention ResNet classifier and exposes the callbacks
    (minibatch feeders, index getters, sub-module accessors) that are handed
    to tgp.TGPModule.
    '''
    def __init__(self, num_classes, device, dl_recurring, dl_nonrecurring, dl_test, batchsize, dim_wideoutput):
        '''
        Inputs:
            - num_classes: number of classes of the classifier.
            - device: device that every fed minibatch `x` is moved to.
            - dl_recurring, dl_nonrecurring, dl_test: dataloaders yielding
              (x, y, n) triples. `n` is used as the instance index.
            - batchsize: stored on the instance.
            - dim_wideoutput: width of the attention branch's wide output.
        '''
        super(MainModule, self).__init__()
        #grab args ===
        self.num_classes = num_classes
        self.device = device
        self.dl_recurring = dl_recurring
        self.dl_nonrecurring = dl_nonrecurring
        self.dl_test = dl_test
        self.batchsize = batchsize
        #make internal module_tobecomeGP ===
        # NOTE: dim_before_wideoutput_attention is read from the settings cell.
        self.module_classifier = akresnetforcifar.ResnetClassifierWithAttention(
                num_classes = num_classes,
                block_classifier = akresnetforcifar.BasicBlock,
                num_blocks_classifier = [2, 2, 2, 2],
                block_attention = akresnetforcifar.BasicBlock,
                num_blocks_attention = [2, 2, 2, 2],
                dim_before_wideoutput_attention = dim_before_wideoutput_attention,
                dim_wideoutput_attention = dim_wideoutput
        )
        # One persistent iterator per dataloader. _func_feed_minibatch restarts
        # an iterator when it is exhausted.
        self.dic_dlname_to_iter = {
            "dl_recurring":iter(self.dl_recurring),
            "dl_nonrecurring":iter(self.dl_nonrecurring),
            "dl_test":iter(self.dl_test)
        }
        #make module f1 ===
        self.module_f1 = ModuleF1(self)
        # Indices n of the last recurring / test minibatch that was fed. They are
        # initialised so the getters work before any minibatch is fed.
        self._lastidx_recurring = []
        self._last_idx_test = []
        # When not None, each fed minibatch is cut to its first n_subsampleminibatch instances.
        self.n_subsampleminibatch = None

    def forward(self, x, y, n):
        '''Returns (first output of module_classifier, y, n).'''
        return self.module_classifier(x, y, n)[0], y, n

    def func_get_modulef1(self):
        return self.module_f1

    def func_mainmodule_to_moduletobecomeGP(self, module_input):
        '''Returns the attention sub-module of `module_input`'s classifier.'''
        return module_input.module_classifier.module_attention

    def _func_feed_minibatch(self, dl_input, str_dlname, flag_addnoisetoX = False):
        '''
        Draws the next (x, y, n) minibatch from the iterator named `str_dlname`,
        optionally corrupts x, feeds it through forward(), and returns (output, y, n).

        Inputs:
            - dl_input: the dataloader behind `str_dlname`, used to restart
              its iterator when exhausted.
            - str_dlname: key in self.dic_dlname_to_iter.
            - flag_addnoisetoX: if True, each image is mixed with a randomly
              permuted image of the same batch, then Gaussian noise (std 0.1)
              is added.
        '''
        try:
            x, y, n = next(self.dic_dlname_to_iter[str_dlname])
        except StopIteration:
            # Loader exhausted: start a new pass over it.
            self.dic_dlname_to_iter[str_dlname] = iter(dl_input)
            x, y, n = next(self.dic_dlname_to_iter[str_dlname])

        if flag_addnoisetoX:
            size_batch = x.size(0)
            x_permuted = x[torch.randperm(size_batch)]
            # Per-instance mixing weight in [rng_randw[0], rng_randw[1]], shaped [N x 1 x 1 x 1].
            # NOTE(review): rng_randw is not defined in the visible cells; it must exist as a global.
            rand_w = torch.rand(size_batch).float().view(-1, 1, 1, 1)
            rand_w = (rand_w*(rng_randw[1]-rng_randw[0]) + rng_randw[0]).detach()
            x = rand_w*x + (1.0-rand_w)*x_permuted
            x = x + 0.1*torch.randn_like(x).float()

        if self.n_subsampleminibatch is not None:
            k = self.n_subsampleminibatch
            x, y, n = x[:k], y[:k], n[:k]

        if str_dlname == "dl_recurring":
            # Remember the recurring indices just fed
            # (exposed through func_get_indices_lastrecurringinstances).
            self._lastidx_recurring = n
            assert len(n) == x.size(0), \
                "recurring minibatch has {} indices for {} instances".format(len(n), x.size(0))

        output, _, _ = self.forward(x.to(self.device), y, n)
        return output, y, n

    def func_feed_recurring_minibatch(self):
        return self._func_feed_minibatch(self.dl_recurring, str_dlname = "dl_recurring")

    def func_feed_noise_minibatch(self):
        '''Feeds a corrupted (mixed + noisy) minibatch from the non-recurring loader.'''
        return self._func_feed_minibatch(
                        self.dl_nonrecurring,
                        str_dlname = "dl_nonrecurring",
                        flag_addnoisetoX = True
                      )

    def func_feed_nonrecurring_minibatch(self):
        return self._func_feed_minibatch(
                        self.dl_nonrecurring,
                        str_dlname = "dl_nonrecurring",
                        flag_addnoisetoX = False
                      )

    def func_feed_test_minibatch(self):
        '''Feeds the next test minibatch and records its indices for func_get_idxlastfed_testdl.'''
        output, y, n = self._func_feed_minibatch(self.dl_test, str_dlname = "dl_test")
        self._last_idx_test = n.tolist()
        return output, y, n

    def func_get_indices_lastrecurringinstances(self):
        '''Returns the indices of the last recurring minibatch as a list ([] before any was fed).'''
        idx = self._lastidx_recurring
        if torch.is_tensor(idx):
            return idx.cpu().numpy().tolist()
        return list(idx)

    def func_get_idxlastfed_testdl(self):
        '''Returns the indices (list) of the last test minibatch fed.'''
        return self._last_idx_test


In [None]:
# Build the wrapped classifier. The training loader serves as the non-recurring stream.
model = MainModule(
    num_classes = num_classes,
    device = device,
    dl_recurring = dl_recurring,
    dl_nonrecurring = dl_train,
    dl_test = dl_test,
    batchsize = batchsize,
    dim_wideoutput = dim_wideoutput
)
model.to(device)

In [None]:
# Wrap `model` in a TGP module; the callbacks below give TGPModule access to
# the model's sub-modules and minibatch feeders.
gpmodel = tgp.TGPModule(
    # -- wrapped network and recurring-set size --
    module_rawmodule = model,
    size_recurringdataset = len(ds_recurring),
    device = device,
    # -- accessors into `model` --
    func_mainmodule_to_moduletobecomeGP = model.func_mainmodule_to_moduletobecomeGP,
    func_get_modulef1 = model.func_get_modulef1,
    func_get_indices_lastrecurringinstances = model.func_get_indices_lastrecurringinstances,
    # -- minibatch feeders --
    func_feed_recurring_minibatch = model.func_feed_recurring_minibatch,
    func_feed_nonrecurring_minibatch = model.func_feed_nonrecurring_minibatch,
    func_feed_noise_minibatch = model.func_feed_noise_minibatch,
    func_feed_test_minibatch = model.func_feed_test_minibatch,
    # -- options from the settings cell --
    flag_efficient = flag_efficient,
    flag_detachcovpvn = flag_detachcovpvn,
    flag_controlvariate = flag_controlvariate,
    flag_setcovtoOne = flag_setcovtoOne,
    int_mode_controlvariate = int_mode_controlvariate,
    flag_train_memefficient = flag_train_memefficient,
    memefficeint_heads_in_compgraph = memefficeint_heads_in_compgraph
)
# Optional: feed only the first 50 instances of every minibatch.
#model.n_subsampleminibatch = 50
gpmodel.sigma2_GP = 1.0 #TODO:check
gpmodel.train()
gpmodel.to(device)
print("gpmodel was created on {}".format(device))

In [None]:
# map_location lets a checkpoint that was saved on a GPU be restored onto
# `device` (e.g. a CPU-only machine); without it torch.load fails there.
# NOTE: torch.load unpickles the file -- only load trusted checkpoints.
gpmodel.load_state_dict(
    torch.load(fname_gpmodel, map_location = device),
    strict = True
)
gpmodel.train()
gpmodel.to(device)
# Refresh the precomputed X^T X after loading the weights
# (presumably derived from them -- confirm in tgp).
gpmodel.renew_precomputed_XTX()
print("gpmodel was loaded from checkpoint.")

In [None]:
def inspect_kernel(model_input, ds_input, input_device, idx_begin, idx_end, tbegin=0.0, tend=1.0, t_stepsize=0.01):
    '''
    Probes the model along the straight line through two dataset images.

    Points are x(s) = x_begin + s*(x_end - x_begin), where s = tbegin + (tend-tbegin)*t
    and t runs over np.arange(0, 1, t_stepsize). Each point is fed to the GP path
    (`testingtime_forward`) and to the raw network (`module_rawmodule`).

    NOTE(review): a second `inspect_kernel` with a different signature is
    defined further down this notebook and shadows this one from there on.

    Inputs:
        - model_input: TGP model providing testingtime_forward, module_rawmodule, eval, train.
        - ds_input: dataset returning (x, y, n) triples.
        - input_device: device every probe point is moved to.
        - idx_begin, idx_end: dataset indices of the two endpoint images.
        - tbegin, tend: range of s (s=0 -> x_begin, s=1 -> x_end).
        - t_stepsize: step of the normalised grid t over [0, 1).
    Returns:
        - output_gp: np.array, one squeezed GP mean per grid point.
        - output_nn: np.array, one squeezed raw-network output per grid point.
        - output_uncertainty: np.array, one flattened GP uncertainty per grid point.
        - idx_t1, idx_t2: fractional grid positions of s=0 (x_begin) and s=1 (x_end).
    The model is put back into train mode on exit.
    '''
    x_begin, _, _ = ds_input[idx_begin]
    x_end, _, _ = ds_input[idx_end]
    interval_t = np.arange(0.0, 1.0, t_stepsize)
    interval_toinspect = [
        x_begin + ((tend-tbegin)*t + tbegin)*(x_end-x_begin)
        for t in interval_t
    ]
    # Grid positions of s=0 and s=1. Because interval_t[k] == k*t_stepsize, the
    # fractional index of a grid value t is simply t/t_stepsize. The previous
    # interpolation used (t[k+1]-t) instead of (t-t[k]) and so landed one step
    # off. It also raised IndexError whenever the point lay outside the grid,
    # e.g. with the default tbegin=0, tend=1.
    t1 = (0.0-tbegin)/(tend-tbegin)
    t2 = (1.0-tbegin)/(tend-tbegin)
    idx_t1 = t1/t_stepsize
    idx_t2 = t2/t_stepsize

    output_gp, output_nn = [], []
    output_uncertainty = []
    model_input.eval()
    try:
        with torch.no_grad():
            for idx_x, x in enumerate(interval_toinspect):
                print(" instance {} from {}".format(idx_x, len(interval_toinspect)), end='\r')
                x_batch = x.unsqueeze(0).to(input_device)
                #feed to GP ======
                output, uncertainty, _ = model_input.testingtime_forward(x_batch, 0, 0)
                output_gp.append(output[0].squeeze().detach().cpu().numpy())
                # Bring (possibly on-device) tensors to numpy so np.array() below works.
                if torch.is_tensor(uncertainty):
                    uncertainty = uncertainty.detach().cpu().numpy()
                output_uncertainty.append(np.asarray(uncertainty).flatten())
                #feed to NN ===
                netout, _, _ = model_input.module_rawmodule(x_batch, 0, 0)
                output_nn.append(netout.squeeze().detach().cpu().numpy())
        print("\n")
    finally:
        model_input.train()
    return np.array(output_gp), np.array(output_nn), np.array(output_uncertainty), idx_t1, idx_t2

In [None]:
# gpmodel.module_f1.module.set_rng_outputheads(rng_outputhead = None)

In [None]:
# print(gpmodel.module_f1.module._rng_outputheads)

In [None]:
# Walk the line from test image 10 (s=0) to test image 1100 (s=1), extended to s in [-5, 5).
output_gp, output_nn, output_uncertainty, t1, t2 = inspect_kernel(
    gpmodel, ds_test, device,
    idx_begin = 10, idx_end = 1100,
    tbegin = -5.0, tend = 5.0
)

In [None]:
def vis_inspectkernel(output_gp, output_nn, output_uncertainty, t1, t2):
    '''
    Plots the curves returned by the line-probing `inspect_kernel`.

    Each figure shows the GP mean (red) and NN output (blue) on the left axis,
    1/uncertainty (green) on a twin right axis, and dashed markers at t1, t2.
    Reads the global `int_exposedclass`: when it is None, one figure is drawn per
    output column. Otherwise the arrays are plotted whole, followed by a GP-only figure.
    '''
    xs = range(output_gp.shape[0])

    def _draw(curve_gp, curve_nn, curve_uncertainty, str_title):
        fig, ax_left = plt.subplots()
        line_gp = ax_left.plot(xs, curve_gp, label="GP-mean", c='r')
        line_nn = ax_left.plot(xs, curve_nn, label="NN-mean", c='b')
        ax_right = ax_left.twinx()
        line_unc = ax_right.plot(xs, 1.0/curve_uncertainty, label="1.0/GP-uncertainty", c='g')
        for t_marker in (t1, t2):
            ax_right.axvline(x=t_marker, color='k', linestyle='--')
        all_lines = line_gp + line_nn + line_unc
        plt.legend(all_lines, [ln.get_label() for ln in all_lines], loc=0)
        plt.title(str_title)
        plt.show()

    if int_exposedclass is not None:
        #the output index is not None ===
        print(output_nn.shape)
        _draw(output_gp, output_nn, output_uncertainty, "output {}".format(int_exposedclass))
        plt.figure()
        plt.plot(xs, output_gp, label="GP-mean", c='r')
        plt.axvline(x=t1, color='k', linestyle='--')
        plt.axvline(x=t2, color='k', linestyle='--')
        plt.show()
        return

    for idx_col in range(output_gp.shape[1]):
        _draw(output_gp[:, idx_col], output_nn[:, idx_col],
              output_uncertainty[:, idx_col], "output {}".format(idx_col))

In [None]:
# Plot the line-probe results with markers at the two endpoint images.
vis_inspectkernel(output_gp=output_gp, output_nn=output_nn,
                  output_uncertainty=output_uncertainty, t1=t1, t2=t2)

In [None]:
def evaluate_model(model_input, ds_input, input_device):
    model_input.eval()
    with torch.no_grad():
        toret = []
        list_gty = []
        for n in range(len(ds_input)):
            if(True):#try:
                print(" instance {} from {}".format(n, len(ds_input)), end='\r')
                x, y, _ = ds_input[n]
                output, _, _ = model_input.testingtime_forward(
                                  x.unsqueeze(0).to(input_device), y, n
                               )
                #TODO:check output = output.clamp(min=-clampval_netout, max=clampval_netout)
                toret.append(output[0].detach().cpu().numpy())
                list_gty.append(y)
            #except:
            #    print("An exception occured for instnace {}".format(n))
        print("\n")
        toret = np.array(toret)
        toret = toret[:,0,:]
    model_input.train()
    return toret, list_gty

def evaluate_g(model_input, ds_input, input_device):
    '''
    Runs the raw network g(.) on every instance of `ds_input`.

    Inputs:
        - model_input: callable module (x, y, n) -> (output, y, n) with eval/train.
        - ds_input: dataset returning (x, y, n) triples.
        - input_device: device each instance is moved to.
    Returns:
        - toret: np.array stacking output[0, :] of every instance.
        - list_gty: list of ground-truth labels y.
    The model is put back into train mode on exit, even if an instance raises.
    '''
    toret = []
    list_gty = []
    model_input.eval()
    try:
        with torch.no_grad():
            for idx_instance in range(len(ds_input)):
                print(" instance {} from {}".format(idx_instance, len(ds_input)), end='\r')
                # The dataset's own index n is forwarded. It no longer reuses the
                # name of the loop variable.
                x, y, n = ds_input[idx_instance]
                output, _, _ = model_input(x.unsqueeze(0).to(input_device), y, n)
                toret.append(output[0,:].detach().cpu().numpy())
                list_gty.append(y)
        print("\n")
    finally:
        model_input.train()
    return np.array(toret), list_gty

In [None]:
#evaluate g(.) i.e. the bypass network ======
from sklearn.metrics import confusion_matrix
predy, list_gty = evaluate_g(gpmodel.module_rawmodule, ds_test, device)
# sklearn's signature is confusion_matrix(y_true, y_pred): rows are ground truth,
# columns are predictions. The arguments were swapped before, which transposed it.
np_confmatrix = confusion_matrix(list_gty, np.argmax(predy, 1))
print(np_confmatrix)
print("\n\n\n")
# Accuracy = diagonal / total. The former `np_confmatrix*np.eye(9)` only
# broadcasts for a 9x9 matrix and fails for the 10 CIFAR-10 classes.
print("accuracy of g(.) = {}".format(np.trace(np_confmatrix)/np.sum(np_confmatrix) ))

In [None]:
#evaluate the GP path ===
from sklearn.metrics import confusion_matrix
gpmodel.renew_precomputed_XTX()
predy, list_gty = evaluate_model(gpmodel, ds_test, device)
# confusion_matrix(y_true, y_pred): rows are ground truth, columns are predictions.
np_confmatrix = confusion_matrix(list_gty, np.argmax(predy[:,:,0,0],1))
print(np_confmatrix)
print("\n\n\n")
# Accuracy = diagonal / total. The former `*np.eye(9)` fails to broadcast for 10 classes.
print("accuracy = {}".format(np.trace(np_confmatrix)/np.sum(np_confmatrix) ))

In [None]:
#check if g(.) and GP path match on test instances ====
FLAG_RELU = False
gpmodel.renew_precomputed_XTX()
list_outputgpoutputg = gpmodel.checkequal_f1path_gpath_ontest(10)
for n in range(len(list_outputgpoutputg)):
    # Each entry is a pair of 2-D arrays to compare. They are presumably (GP path,
    # g path); the publication cell titles them "GPs output" / "ANN output".
    a = list_outputgpoutputg[n][0]
    b = list_outputgpoutputg[n][1]
    if FLAG_RELU:
        # `np_relu` is not defined anywhere in this notebook; use numpy directly.
        a = np.maximum(a, 0); b = np.maximum(b, 0)

    min_ab = min(np.min(a), np.min(b))
    max_ab = max(np.max(a), np.max(b))
    diff_ab = a - b
    absmax_diff = np.max(np.abs(diff_ab))
    print("a-b in range [{} , {}]".format(np.min(diff_ab), np.max(diff_ab)))
    # The class-activation / one-hot block that used to sit here was removed.
    # Its results were never used, MainModule has no `.linear`, and its fallback
    # raised IndexError whenever argmax(a) >= 16.

    plt.figure(figsize=(20,10))
    plt.subplot(1,3,1)
    plt.imshow(np.round(a, 2), vmin=min_ab, vmax=max_ab, aspect="auto"); plt.colorbar()
    plt.subplot(1,3,2)
    plt.imshow(np.round(b, 2), vmin=min_ab, vmax=max_ab, aspect="auto"); plt.colorbar()
    plt.subplot(1,3,3)
    # Diverging colormap centred on 0, with a single colorbar (it was added twice before).
    plt.imshow(np.round(diff_ab, 2), cmap="seismic", vmin=-absmax_diff,
               vmax=absmax_diff, aspect="auto"); plt.colorbar()
    plt.show()

# Figures for Publication: GP vs. ANN Outputs

In [None]:
#check if g(.) and GP path match on test instances (figures for publication) ====
FLAG_RELU = False
dir_savefigs = "InterpGP/ToPublish/Tables/Cifar10_Attention/"
# savefig fails if the output folder does not exist.
os.makedirs(dir_savefigs, exist_ok=True)
gpmodel.renew_precomputed_XTX()
list_outputgpoutputg = gpmodel.checkequal_f1path_gpath_ontest(10)
for n in range(len(list_outputgpoutputg)):
    a = list_outputgpoutputg[n][0]
    b = list_outputgpoutputg[n][1]
    if FLAG_RELU:
        # `np_relu` is not defined anywhere in this notebook; use numpy directly.
        a = np.maximum(a, 0); b = np.maximum(b, 0)

    min_ab = min(np.min(a), np.min(b))
    max_ab = max(np.max(a), np.max(b))
    print("a-b in range [{} , {}]".format(np.min(a-b), np.max(a-b)))
    # The unused class-activation / one-hot block was removed. It relied on a
    # nonexistent `.linear` and its fallback could raise IndexError.

    plt.figure(figsize=(20,10))
    plt.subplot(1,2,1)
    plt.imshow(np.round(a, 2), vmin=min_ab, vmax=max_ab, aspect="auto"); plt.colorbar()
    plt.title("GPs output \n (batchsize x D)", font = 'Formata', fontsize = 22)
    plt.axis('off')
    plt.subplot(1,2,2)
    plt.imshow(np.round(b, 2), vmin=min_ab, vmax=max_ab, aspect="auto"); plt.colorbar()
    plt.title("ANN output \n (batchsize x D)", font = 'Formata', fontsize = 22)
    plt.axis('off')
    # `Q` is not a savefig option: older matplotlib ignored it, newer versions
    # reject unknown keywords. It has been dropped.
    plt.savefig(
        os.path.join(dir_savefigs, "{}.png".format(time.time())),
        dpi=100, bbox_inches='tight', pad_inches=0
    )

# Compare GP and ANN Outputs on the Test Set

In [None]:
# Gather GP and ANN outputs for every test instance (both dicts are keyed by instance index).
dict_idx_to_gpout, dict_idx_to_annout = gpmodel.check_GP_match_ANN_on_aDataloader(
    list_allidx = list(range(len(ds_test))),
    func_feed_dlinstances = model.func_feed_test_minibatch,
    func_get_lastidx_fedinstances = model.func_get_idxlastfed_testdl
)

In [None]:
#convert the dicts to np_arrays, one row per test instance in index order ====
list_gpout = [dict_idx_to_gpout[idx] for idx in range(len(ds_test))]
list_annout = [dict_idx_to_annout[idx] for idx in range(len(ds_test))]
np_gpout = np.array(list_gpout)
np_annout = np.array(list_annout)
print(np_gpout.shape)
print(np_annout.shape)

In [None]:
import projutils.evaluation
from projutils.evaluation import GPdiffANN_ontorchdl
from sklearn.metrics import cohen_kappa_score
from scipy import stats

# Pearson correlation between GP and ANN outputs, one value per output column.
list_correlation = [
    stats.pearsonr(np_gpout[:, c], np_annout[:, c])[0]
    for c in range(np_gpout.shape[1])
]
print("list correl = {}".format(list_correlation))


# Depict the Correlation Scatter

In [None]:
from scipy.special import softmax
np_gp = np_gpout #np_gpann[:,0:num_classes]
np_ann = np_annout #np_gpann[:,num_classes::]
np_gp_softmax = softmax(np_gp, 1)
np_ann_softmax = softmax(np_ann, 1)

# True where GP and ANN predict different classes (currently only sizes list_c).
list_disaggrement = (np.argmax(np_gp, 1) != np.argmax(np_ann, 1)).tolist()
# Same faint black RGBA colour for every point; both branches of the former
# conditional produced [0,0,0,0.05].
list_c = [[0, 0, 0, 0.05] for _ in list_disaggrement]
dir_scatters = "InterpGP/Correlation_Scatters/Attention/"
# savefig fails if the output folder does not exist.
os.makedirs(dir_scatters, exist_ok=True)
plt.ioff()
count_plotted = 0
for c in range(num_classes):
    plt.figure()
    # NOTE(review): matplotlib gives `c` precedence over `facecolors`, so the
    # markers are drawn filled with the RGBA colour, not hollow.
    plt.scatter(np_gp[:,c], np_ann[:,c], c=np.array(list_c), marker='o', facecolors='none')
    plt.axis("equal")
    plt.xlabel("GP output (head {})".format(c+1), fontsize=22, font = 'Formata')
    plt.ylabel("ANN output (head {})".format(c+1), fontsize=22, font = 'Formata')
    # The former `Q=100` is not a savefig option and has been dropped.
    plt.savefig(
        os.path.join(dir_scatters, "{}.png".format(count_plotted)),
        dpi=100, bbox_inches='tight', pad_inches=0
    )
    count_plotted += 1
    plt.close()

# Inspect Nearest Neighbours

In [None]:
def inspect_kernel(model_input, ds_input, input_device):
    model_input.eval()
    toret = []
    list_gty = []
    list_uncertainty = []
    list_similarities = []
    list_x, list_y = [], []
    list_output_g = []
    for n in range(len(ds_input)):
        print(" instance {} from {}".format(n, len(ds_input)), end='\r')
        x, y, _ = ds_input[n]
        output, uncertainty, output_similarities = \
                model_input.testingtime_forward(x.unsqueeze(0).to(input_device), y, n)
        #print("output_similaritites.shape = {}".format(output_similarities.shape))
        output = output[0]
        toret.append(output.detach().cpu().numpy())
        list_gty.append(y)
        list_uncertainty.append(uncertainty)
        list_similarities.append(output_similarities.detach().cpu().numpy())
        list_x.append(x); list_y.append(y)
        #feed the model to g(.) ====
        output_g, _, _ = \
                model_input.module_rawmodule(x.unsqueeze(0).to(input_device), y, n)
        list_output_g.append(output_g.detach().cpu().numpy().flatten().tolist())

    print("\n")
    toret = np.array(toret)
    toret = toret[:,0,:]
    output_g = np.array(list_output_g)
    print(output_g.shape)
    model_input.train()
    return toret, list_gty, list_uncertainty, list_similarities, list_x, list_y, output_g

In [None]:
# Collect GP / g(.) outputs and similarities for the whole test set (one model inspected).
retval_inspectkernel = inspect_kernel(gpmodel, ds_test, device)
list_retval_inspectmodel = [retval_inspectkernel]

In [None]:
%%capture
m = 10  # number of most-similar recurring instances shown next to each test image
import warnings
warnings.filterwarnings(action = "once")
warnings.filterwarnings("ignore")
plt.ioff()
# Only one set of results is inspected (list_retval_inspectmodel holds one entry),
# so it is unpacked once instead of on every iteration.
idx_model = 0
list_predy, list_gty, list_uncertainty, list_similarities, list_x, list_y, output_g = \
                    list_retval_inspectmodel[idx_model]
# All panels share one 1 x (m+1) grid: the query image followed by m neighbours.
# Before, the query used an (m+1)-column grid and the neighbours an (m+2)-column
# grid, so the first panels overlapped.
num_subplotcols = m + 1
for n in range(len(ds_test)):
    print("n = {} ==========================".format(n))
    x, y = list_x[n], list_y[n]

    plt.figure(figsize=((m+2)*10, 1*10))
    # Class predicted by g(.); its row of the similarity matrix ranks the neighbours.
    np_argmax_list_predyn = np.argmax(output_g[n])
    # The output sub-folder records whether g(.) classified this instance correctly.
    str_subfolder = "True" if np_argmax_list_predyn == list_gty[n] else "False"
    kn = list_similarities[n][np_argmax_list_predyn, :].flatten()
    idx_similars = np.argsort(-kn).tolist()[0:m]

    count_subplot = 1
    plt.subplot(1, num_subplotcols, count_subplot); count_subplot += 1
    plt.imshow(tfm_denormalize(x).cpu().numpy().transpose(1,2,0))
    plt.axis('off')
    print("===== showed.")
    plt.title("gt-label = {}\n predicted = {}\n instance {}".format(
                            ds_recurring.label_names[y],
                            ds_recurring.label_names[np_argmax_list_predyn],
                            n
                ), fontsize=100
             )

    list_relevantinstances = []
    for idx_similar in idx_similars:
        plt.subplot(1, num_subplotcols, count_subplot); count_subplot += 1
        plt.imshow(
            tfm_denormalize(ds_recurring[idx_similar][0]).cpu().numpy().transpose(1,2,0),
        )
        plt.axis('off')
        plt.title("instnace {}".format(idx_similar), fontsize=100)
        list_relevantinstances.append(idx_similar)

    # savefig fails if the output folder does not exist.
    os.makedirs("InterpGP/{}".format(str_subfolder), exist_ok=True)
    # `Q` is not a savefig option and has been dropped. To set JPEG quality, use
    # pil_kwargs={"quality": ...} on matplotlib versions that support it.
    plt.savefig(
            "InterpGP/{}/{}.jpg".format(str_subfolder, n),
            dpi=50, bbox_inches='tight', pad_inches=0
        )
    print("================= saved.")
    plt.close()

# Explain Similarities (CAM-like)

In [None]:
import torchofgp
import torchofgp.kernel_explainers
from torchofgp.kernel_explainers import *

In [None]:
# Cache inspect_kernel's results for ds_test; later cells unpack each entry as
# (list_predy, list_gty, list_uncertainty, list_similarities, list_x, list_y, output_g).
# NOTE(review): inspect_kernel is redefined further down with the signature
# (module_input, dl_input, list_label_recurring); re-running this cell after that
# definition calls the new version — confirm which one is intended.
list_retval_inspectmodel = [inspect_kernel(gpmodel, ds_test, device)]

In [None]:
%%capture
#explain the similarity itself =====
m = 10
import warnings
warnings.filterwarnings(action = "once")
for n in range(len(ds_test)):
    
    #fields common between all models
    _, list_gty, _, _, list_x, list_y, _ = list_retval_inspectmodel[0]
    x, y = list_x[n], list_y[n]
    
    plt.figure(figsize=((m+1)*10, 3*10))
    warnings.filterwarnings("ignore")
    count_subplot = 1
    str_subfolder = ""
    for idx_model in range(len(list_retval_inspectmodel)):
        list_predy, list_gty, list_uncertainty, list_similarities, _, _, output_g = \
                            list_retval_inspectmodel[idx_model]
        np_argmax_list_predyn = np.argmax(output_g[n])
        if(np_argmax_list_predyn == list_gty[n]):
            str_subfolder = str_subfolder + "True"
        else:
            str_subfolder = str_subfolder + "False"
        kn = list_similarities[n][np_argmax_list_predyn, :].flatten()
        idx_similars = np.argsort(-kn).tolist()[0:m]
        plt.ioff()
        plt.subplot(3, 1*(m+1), count_subplot); count_subplot+=1;
        plt.imshow(tfm_denormalize(x).cpu().numpy().transpose(1,2,0))
        plt.title("gt-label = {}\n predicted = {}\n instance {}".format(
                                ds_recurring.label_names[y],
                                ds_recurring.label_names[np_argmax_list_predyn],
                                n
                    ), fontsize=100
                 )
        plt.axis('off')
        
        list_relevantinstances =[]
        for count_similars in range(len(idx_similars)):
            plt.subplot(3, 1*(m+1), count_subplot); count_subplot+=1;
            plt.imshow(
              tfm_denormalize(ds_recurring[idx_similars[count_similars]][0]).cpu().numpy().transpose(1,2,0),
            )
            plt.axis('off')
            plt.title("instnace {}".format(idx_similars[count_similars]), fontsize=100)
            list_relevantinstances.append(idx_similars[count_similars])
            
        
        
        
        #plot rows 2 (explanations for x2)
        plt.subplot(3, 1*(m+1), count_subplot); count_subplot+=1;
        plt.imshow(tfm_denormalize(x).cpu().numpy().transpose(1,2,0))
        plt.axis('off')
        for count_similars in range(len(idx_similars)):
            x2 = ds_recurring[idx_similars[count_similars]][0].unsqueeze(0).to(device)
            x1 = x.unsqueeze(0).to(device)
            explanationx1x2 = explainkern_imgimg_CAMlike(
              gpmodel = gpmodel,
              func_forward_beforeavgpool = gpmodel.module_f1.module.forward_untilbeforeavgpooling,
              x1 = x.to(device),
              x2 = ds_recurring[idx_similars[count_similars]][0].to(device),
              idx_ann_outputhead = y,
              du_per_gp = du_per_class,
              scale_resizemaps = 1.0
            )
            
            toret1 = np.sum(np.sum(explanationx1x2, axis=-1), axis=-1)
            toret2 = np.sum(np.sum(explanationx1x2, axis=0), axis=0)
            plt.subplot(3, 1*(m+1), count_subplot); count_subplot+=1;
            plt.imshow(toret2, cmap="seismic", vmin=np.min(toret2), vmax=np.max(toret2))
            plt.axis('off')
        #plt.subplot(3, 1*(m+1), count_subplot); count_subplot+=1;
        #plt.hist(kn.flatten(), bins=200)
        
        #plot rows 3 (explanations for the instance itself)
        plt.subplot(3, 1*(m+1), count_subplot); count_subplot+=1;
        plt.imshow(tfm_denormalize(x).cpu().numpy().transpose(1,2,0))
        plt.axis('off')
        for count_similars in range(len(idx_similars)):
            x2 = ds_recurring[idx_similars[count_similars]][0].unsqueeze(0).to(device)
            x1 = x.unsqueeze(0).to(device)
            explanationx1x2 = explainkern_imgimg_CAMlike(
              gpmodel = gpmodel,
              func_forward_beforeavgpool = gpmodel.module_f1.module.forward_untilbeforeavgpooling,
              x1 = x.to(device),
              x2 = ds_recurring[idx_similars[count_similars]][0].to(device),
              idx_ann_outputhead = y,
              du_per_gp = du_per_class,
              scale_resizemaps = 1.0
            )
            
            toret1 = np.sum(np.sum(explanationx1x2, axis=-1), axis=-1)
            toret2 = np.sum(np.sum(explanationx1x2, axis=0), axis=0)
            plt.subplot(3, 1*(m+1), count_subplot); count_subplot+=1;
            plt.imshow(toret1, cmap="seismic", vmin=np.min(toret1), vmax=np.max(toret1))
                
        
    if(os.path.isfile("InterpGP/Kernel/{}/{}.jpg".format(str_subfolder, n)) == False):
        plt.savefig(
                "InterpGP/Kernel/{}/{}.jpg".format(str_subfolder, n),
                dpi=20, bbox_inches='tight', pad_inches=0, Q=80
            )
    plt.close()
    #assert False
#enable_print()

# Explain by Regions

In [None]:
import torchofgp
import torchofgp.kernel_explainers
from torchofgp.kernel_explainers import *

In [None]:
# Cache inspect_kernel's results for ds_test again for the region-based section; each
# entry is unpacked later as (.., list_gty, .., .., list_x, list_y, ..).
# NOTE(review): inspect_kernel is redefined further down with a different signature
# (module_input, dl_input, list_label_recurring); confirm which version this expects.
list_retval_inspectmodel = [inspect_kernel(gpmodel, ds_test, device)]

In [None]:
# with open('InterpGP/retvals.pkl', 'wb') as file_retvals:
#     pickle.dump(list_retval_inspectmodel, file_retvals, protocol=pickle.HIGHEST_PROTOCOL)

In [None]:
import pickle
#load the cached results of inspect_kernel; the with-block closes the file afterwards.
#NOTE: pickle.load can execute arbitrary code -- only load retvals.pkl files produced locally.
with open('InterpGP/retvals.pkl', 'rb') as file_retvals:
    list_retval_inspectmodel = pickle.load(file_retvals)

In [None]:
#settings ===
n1 = 21 #index into list_x/list_y of the query instance
n2 = 9064 #index into ds_recurring of the instance it is compared against
rgbsigma = 0.0 #bandwidth of the spatial RBF between pixels of one image (0.0 disables it)
#explain the similarity ===
_, list_gty, _, _, list_x, list_y, _ = list_retval_inspectmodel[0]
x1, y1 = list_x[n1], list_y[n1]
x2 = ds_recurring[n2][0]
explanationx1x2 = explainkern_imgimg_CAMlike(
      gpmodel = gpmodel,
      func_forward_beforeavgpool = gpmodel.module_f1.module.forward_untilbeforeavgpooling,
      x1 = x1.to(device),
      x2 = x2.to(device),
      idx_ann_outputhead = y1,
      du_per_gp = du_per_class,
      scale_resizemaps = 1.0,
      mode_upsample='nearest',
      flag_aligcorners = None
)
#keep only the positive contributions (negative entries become zero)
explanationx1x2 = explanationx1x2 * ((explanationx1x2>0.0) + 0.0)
print("shape of CAM-like output = {}".format(explanationx1x2.shape))

#explanationx1x2 = explanationx1x2/np.max(explanationx1x2) #TODO:check
#explanationx1x2[explanationx1x2 < 0.1] = 0.0 #TODO:check

#compute the heatmaps for img1 and img2 ===
#heatmap_1 sums over the pixels of img2 (axes 2, 3), heatmap_2 over those of img1 (axes 0, 1)
heatmap_1 = np.sum(np.sum(explanationx1x2, 3), 2)
heatmap_2 = np.sum(np.sum(explanationx1x2, 0), 0)

def _normalize_by_max(heatmap):
    """Scale a non-negative heatmap into [0, 1]; an all-zero heatmap stays all-zero
    instead of turning into NaNs through a division by zero."""
    max_heatmap = np.max(heatmap)
    if(max_heatmap > 0.0):
        return heatmap/max_heatmap
    return np.zeros_like(heatmap)

normalized_heatmap_1 = _normalize_by_max(heatmap_1)
normalized_heatmap_2 = _normalize_by_max(heatmap_2)
#assert False
#make the pixel similarity matrix ====
# Node layout of np_pixelmatrix: indices [0, num_pix1) are the pixels of image 1
# (row-major over explanationx1x2.shape[0:2]); indices [num_pix1, num_pix1+num_pix2)
# are the pixels of image 2 (row-major over shape[2:4]). The clustering cell below
# decodes labels_ with exactly this layout.
def _spatial_rbf_affinity(height, width, sigma):
    """RBF affinity between the pixel positions of one height x width grid.

    Entry (p, q) is exp(-||pos_p - pos_q||^2 / (2 sigma^2)) for p != q and 0 on the
    diagonal. sigma == 0.0 disables the term (all zeros).
    """
    num_pix = height * width
    if(sigma == 0.0):
        return np.zeros((num_pix, num_pix))
    rows, cols = np.unravel_index(np.arange(num_pix), (height, width))
    coords = np.stack([rows, cols], axis=1).astype(np.float64)
    sqdist = np.sum((coords[:, None, :] - coords[None, :, :])**2, axis=-1)
    affinity = np.exp(-sqdist / (2.0 * sigma * sigma))
    np.fill_diagonal(affinity, 0.0) #a pixel has no affinity with itself
    return affinity

h1, w1, h2, w2 = explanationx1x2.shape
num_pix1 = h1 * w1
num_pix2 = h2 * w2
#cross-image block: entry (ij, kl) = explanationx1x2[i, j, k, l]; the C-order reshape
#matches np.unravel_index's default ordering used when decoding the clusters
np_cross = np.asarray(explanationx1x2, dtype=np.float64).reshape(num_pix1, num_pix2)
np_pixelmatrix = np.block([
    [_spatial_rbf_affinity(h1, w1, rgbsigma), np_cross],
    [np_cross.T, _spatial_rbf_affinity(h2, w2, rgbsigma)],
])
            
print("Computed the similarity matrix of shape {}".format(np_pixelmatrix.shape))



In [None]:
#plot the histogram of two heatmaps ===
#use the same bin count for both so the two distributions are comparable
num_bins_heatmaphist = 50
for str_imagename, heatmap in [("image 1", heatmap_1), ("image 2", heatmap_2)]:
    plt.figure()
    plt.hist(heatmap.flatten(), bins=num_bins_heatmaphist)
    plt.title("histogram of the CAM-like heatmap of {}".format(str_imagename))
    plt.xlabel("heatmap value")
    plt.ylabel("count")
    plt.show()

In [None]:

#run spectral clustering on pixel similarity matrix ====
#settings ===
n_clusters = 2

from sklearn.cluster import SpectralClustering
clustering = SpectralClustering(
    affinity="precomputed",
    n_clusters=n_clusters,
    random_state=1,
).fit(np_pixelmatrix + 0.0)

#show the clustering result
# Grid of 2 rows x (n_clusters+2) columns: row 0 belongs to image 1, row 1 to image 2.
# Column 0 = the images, column 1 = their heatmaps, column 2+c = pixels in cluster c.
num_cols = clustering.n_clusters + 2
plt.figure(figsize=(num_cols*10, 2*10))

def _show_panel(row, col, img, **kwargs_imshow):
    """Draw `img` into grid cell (row, col) with the axes hidden."""
    plt.subplot(2, num_cols, row*num_cols + col + 1)
    plt.imshow(img, **kwargs_imshow)
    plt.axis("off")

_show_panel(0, 0, tfm_denormalize(x1).cpu().numpy().transpose(1,2,0))
_show_panel(1, 0, tfm_denormalize(x2).cpu().numpy().transpose(1,2,0))
_show_panel(0, 1, heatmap_1, cmap='seismic')
_show_panel(1, 1, heatmap_2, cmap='seismic')

np_labels = np.array(clustering.labels_) #one label per pixel, image 1 first then image 2
for c in range(clustering.n_clusters):
    np_inc = (np_labels == c)
    #cluster c restricted to the pixels of image 1
    img1_inc = np_inc[:num_pix1].reshape(
        explanationx1x2.shape[0], explanationx1x2.shape[1]
    ).astype(np.float64)
    _show_panel(0, 2+c, img1_inc, vmin=0, vmax=1)
    #cluster c restricted to the pixels of image 2
    img2_inc = np_inc[num_pix1:].reshape(
        explanationx1x2.shape[2], explanationx1x2.shape[3]
    ).astype(np.float64)
    _show_panel(1, 2+c, img2_inc, vmin=0, vmax=1)
plt.show()


In [None]:
#explain the similarity itself =====
# For the first `num_instances_toexplain` test instances, plot a 3 x (m+1) grid and
# save it under InterpGP/RegionBased/<True|False>/<n>.png:
#   row 1: the instance and its m most similar recurring instances,
#   row 2: the CAM-like explanation summed onto each similar instance,
#   row 3: the same explanation summed onto the test instance itself.
m = 5
num_instances_toexplain = 1 #set to None to explain every instance of ds_test
list_labelnames_binary = ["dog", "wolf"]
import warnings
warnings.filterwarnings(action = "once")
if(num_instances_toexplain is None):
    num_toexplain = len(ds_test)
else:
    num_toexplain = min(num_instances_toexplain, len(ds_test))
for n in range(num_toexplain):
    fname_n, _ = ds_test._ntoimage(n)
    fname_n = os.path.relpath(fname_n, ds_rootdir)
#     if(fname_n not in ds_split["fname_hoskys"]):
#         continue
    #fields common between all models
    _, list_gty, _, _, list_x, list_y, _ = list_retval_inspectmodel[0]
    x, y = list_x[n], list_y[n]
    
    plt.figure(figsize=((m+1)*10, 3*10))
    warnings.filterwarnings("ignore")
    count_subplot = 1
    str_subfolder = ""
    for idx_model in range(len(list_retval_inspectmodel)):
        list_predy, list_gty, list_uncertainty, list_similarities, _, _, output_g = \
                            list_retval_inspectmodel[idx_model]
        np_argmax_list_predyn = np.argmax(output_g[n])
        #the subfolder records whether each model's prediction is correct
        if(np_argmax_list_predyn == list_gty[n]):
            str_subfolder = str_subfolder + "True"
        else:
            str_subfolder = str_subfolder + "False"
        kn = list_similarities[n][np_argmax_list_predyn, :].flatten()
        idx_similars = np.argsort(-kn).tolist()[0:m]
        plt.ioff()
        
        #plot row 1 (the instance and its most similar recurring instances)
        plt.subplot(3, 1*(m+1), count_subplot); count_subplot+=1;
        plt.imshow(tfm_denormalize(x).cpu().numpy().transpose(1,2,0))
        plt.title("gt-label = {}\n predicted = {}\n instance {}".format(
                                list_labelnames_binary[y],
                                list_labelnames_binary[np_argmax_list_predyn],
                                n
                    ), fontsize=100
                 )
        plt.axis('off')
        
        list_relevantinstances =[]
        for idx_similar in idx_similars:
            plt.subplot(3, 1*(m+1), count_subplot); count_subplot+=1;
            plt.imshow(
              tfm_denormalize(ds_recurring[idx_similar][0]).cpu().numpy().transpose(1,2,0),
            )
            plt.axis('off')
            plt.title("instance {}".format(idx_similar), fontsize=100)
            list_relevantinstances.append(idx_similar)
        
        #compute every explanation once; rows 2 and 3 show two marginals of the same map
        list_heatmaps_onx, list_heatmaps_onsimilar = [], []
        for idx_similar in idx_similars:
            explanationx1x2 = explainkern_imgimg_CAMlike(
              gpmodel = gpmodel,
              func_forward_beforeavgpool = gpmodel.module_f1.module.forward_untilbeforeavgpooling,
              x1 = x.to(device),
              x2 = ds_recurring[idx_similar][0].to(device),
              idx_ann_outputhead = y,
              du_per_gp = du_per_class,
              scale_resizemaps = 2.0
            )
            #summing the last two axes gives the map on x, the first two the map on x2
            list_heatmaps_onx.append(np.sum(np.sum(explanationx1x2, axis=-1), axis=-1))
            list_heatmaps_onsimilar.append(np.sum(np.sum(explanationx1x2, axis=0), axis=0))
        
        #plot row 2 (explanations for x2)
        plt.subplot(3, 1*(m+1), count_subplot); count_subplot+=1;
        plt.imshow(tfm_denormalize(x).cpu().numpy().transpose(1,2,0))
        plt.axis('off')
        for toret2 in list_heatmaps_onsimilar:
            plt.subplot(3, 1*(m+1), count_subplot); count_subplot+=1;
            plt.imshow(toret2, cmap="seismic", vmin=np.min(toret2), vmax=np.max(toret2))
            plt.axis('off')
        
        #plot row 3 (explanations for the instance itself)
        plt.subplot(3, 1*(m+1), count_subplot); count_subplot+=1;
        plt.imshow(tfm_denormalize(x).cpu().numpy().transpose(1,2,0))
        plt.axis('off')
        for toret1 in list_heatmaps_onx:
            plt.subplot(3, 1*(m+1), count_subplot); count_subplot+=1;
            plt.imshow(toret1, cmap="seismic", vmin=np.min(toret1), vmax=np.max(toret1))
            plt.axis('off')
    
    str_fname_output = "InterpGP/RegionBased/{}/{}.png".format(str_subfolder, n)
    if(not os.path.isfile(str_fname_output)):
        #savefig does not create missing folders; PNG is lossless so no quality option
        os.makedirs(os.path.dirname(str_fname_output), exist_ok=True)
        plt.savefig(
                str_fname_output,
                dpi=80, bbox_inches='tight', pad_inches=0
            )
    plt.close()
#enable_print()


# Depict the Prototypes

In [None]:
#class index of every recurring instance, looked up by its label name with "/" appended
list_label_recurring = []
for n_recurring in range(len(ds_recurring)):
    str_labelname = ds_recurring._ntoimage(n_recurring)[1] + "/"
    list_label_recurring.append(ds_recurring.list_labelnames.index(str_labelname))
def inspect_kernel(module_input, dl_input, list_label_recurring):
    """Sum, per class, the similarity ranks of that class's recurring instances.

    For every instance in `dl_input`, the similarities produced by its own-class head
    to the recurring instances of the same class are ranked (rank 0 = most similar),
    and each recurring instance's rank is accumulated.

    Args:
        module_input: model exposing `Dv` (used here as the number of classes) and
            `testingtime_forward(x, y, n) -> (output, uncertainty, similarities)`.
        dl_input: dataloader yielding `(x, y, n)` batches; must expose `.dataset`
            and `.batch_size` (used for the progress message).
        list_label_recurring: class index of every recurring instance, in order.

    Returns:
        dict mapping class u -> float array with one entry per recurring instance of
        class u (in their original order), holding that instance's summed rank.
    """
    np_label_recurring = np.array(list_label_recurring)
    #start every class from a zero array of the right length, so classes that never
    #occur in dl_input still map to an array rather than a scalar
    dict_class_to_sumranks = {
        u: np.zeros(int(np.sum(np_label_recurring == u)))
        for u in range(module_input.Dv)
    }
    with torch.no_grad():
        for idx, data in enumerate(dl_input):
            if((idx%10) == 0):
                print("Finished {} out of {}".format(
                    idx, len(dl_input.dataset)/dl_input.batch_size
                ), end='\r')
            x, y, n = data
            output, uncertainty, output_similarities = \
                module_input.testingtime_forward(x.to(device), y, n)
            #indexed below as [class, instance-in-minibatch, recurring instance]
            output_similarities = output_similarities.detach().cpu().numpy()
            for idx_inminibatch, _ in enumerate(n):
                label = int(y[idx_inminibatch])
                xn_similarities = output_similarities[label, idx_inminibatch, :].flatten()
                xn_similarities = xn_similarities[np_label_recurring == label]
                #argsort(-s) is the order of instances by decreasing similarity;
                #argsort of that order is each instance's rank (0 = most similar)
                dict_class_to_sumranks[label] += np.argsort(np.argsort(-xn_similarities))
    
    return dict_class_to_sumranks

In [None]:
# Accumulate per-class rank sums of the recurring instances over the test loader
# (uses the 3-argument inspect_kernel defined in the previous cell).
dict_class_to_sumranks = inspect_kernel(gpmodel, dl_test, list_label_recurring)

In [None]:
# pickle.dump( dict_class_to_sumranks, open("TrainingHistory/dict_class_to_sumranks.pkl", "wb"))

In [None]:
# dict_class_to_sumranks = pickle.load( open( "TrainingHistory/dict_class_to_sumranks.pkl", "rb" ) )

In [None]:
#one histogram per class of the rank sums divided by the number of recurring instances
for k, np_sumranks in dict_class_to_sumranks.items():
    plt.figure()
    np_avgranks = np_sumranks/len(np_sumranks.tolist())
    plt.hist(np_avgranks, bins=200)
    plt.title("histogram of avg. ranks among the training instances \n for class {}.".format(k))
    plt.show()

In [None]:
%matplotlib inline
import relatedwork
import relatedwork.utils.transforms
tfm_denormalize = relatedwork.utils.transforms.ImgnetDenormalize()
#open the app to check whether the sample-compression scheme matches human's ====
num_tocompare = 10
input_class = int(input("Please input the class (between 0 and 9)."))
plt.figure()
plt.hist(dict_class_to_sumranks[input_class]/\
         len(dict_class_to_sumranks[input_class].tolist()), bins=200)
plt.show()
input_threshold = float(input("Please enter a threshold based on the histogram."))
#divide instances based on their ranks
np_label_recurring = np.array(list_label_recurring)
idx_k_inds = np.array(range(len(ds_recurring)))[np_label_recurring == input_class]

idx_highrank_inlocal =  np.where(dict_class_to_sumranks[input_class] <\
                          (input_threshold*len(dict_class_to_sumranks[input_class].tolist())))[0]
idx_highrank_inlocal = idx_highrank_inlocal.tolist()
idx_highrank_inds = idx_k_inds[dict_class_to_sumranks[input_class] <\
                               input_threshold*len(dict_class_to_sumranks[input_class].tolist())]
idx_lowrank_inds = np.array(list(set(range(len(ds_recurring))).difference(idx_highrank_inds)))
assert(len(idx_highrank_inds.tolist())+len(idx_lowrank_inds.tolist()) == len(ds_recurring))
print("{} percent of instances are below the threshold. exact num = {}"\
      .format(100*len(idx_highrank_inds)/(np.sum(np_label_recurring==input_class)+0.0),\
              len(idx_highrank_inds.tolist())))
score = 0
for n in range(num_tocompare):
    n_high = random.choice(idx_highrank_inds)
    n_low = random.choice(idx_lowrank_inds)
    img_high = ds_recurring[n_high][0]
    img_low  = ds_recurring[n_low][0]
    
    list_toshow = [img_high, img_low]
    flag_fillped = False
    if(np.random.rand()<0.5):
        list_toshow = [img_low, img_high]
        flag_fillped = True
    
    
    plt.figure()
    plt.subplot(1,2,1)
    plt.imshow(tfm_denormalize(list_toshow[0]).cpu().numpy().transpose(1,2,0))
    plt.subplot(1,2,2)
    plt.imshow(tfm_denormalize(list_toshow[1]).cpu().numpy().transpose(1,2,0))
    plt.show()
    input_selection = int(input("Which image (left or right) seems a prototype?"))
    assert(input_selection in [0,1])
    if(input_selection == 0):
        if(flag_fillped == True):
            score += 1
    if(input_selection == 1):
        if(flag_fillped == False):
            score += 1
print("score = {}".format(score/(num_tocompare+0.0)))

# np.max(dict_class_to_sumranks[input_class])

In [None]:
#smallest rank sum of the chosen class, divided by its number of recurring instances
np_sumranks_inputclass = dict_class_to_sumranks[input_class]
np.min(np_sumranks_inputclass/len(np_sumranks_inputclass.tolist()))

In [None]:
def get_x_onkernelspace(gpmodel, dl_input):
    """Collect the kernel-space features `module_f1(x)[:, :, 0, 0]` of every instance.

    Args:
        gpmodel: model exposing `module_f1`, `eval()` and `train()`.
        dl_input: dataloader yielding `(x, y, n)` batches (`n` = instance indices);
            must expose `.dataset` and `.batch_size` for the progress message.

    Returns:
        list with one feature vector (list of floats) per instance, ordered by the
        instance index `n`.
    """
    flag_wastraining = getattr(gpmodel, "training", True)
    gpmodel.eval()
    dict_n_to_kernelspace = {}
    try:
        with torch.no_grad():
            for idx, data in enumerate(dl_input):
                if((idx%10) == 0):
                    print("Finished {} out of {}".format(
                        idx, len(dl_input.dataset)/dl_input.batch_size
                    ), end='\r')
                x, _, n = data
                output_f1 = gpmodel.module_f1(x.to(device))
                #presumably [batch, dim, H, W] with a 1x1 spatial map -- TODO confirm
                output_f1 = output_f1[:,:,0,0].detach().cpu().numpy()
                
                for idx_inminibatch, idx_inds in enumerate(n):
                    #int() so tensor indices are keyed by value (tensors hash by identity)
                    dict_n_to_kernelspace[int(idx_inds)] = \
                        output_f1[idx_inminibatch,:].flatten().tolist()
    finally:
        #restore training mode only if the model was training before the call
        if(flag_wastraining):
            gpmodel.train()
    
    X_on_kernelspace = [dict_n_to_kernelspace[k] for k in sorted(dict_n_to_kernelspace)]
    return X_on_kernelspace

In [None]:
#get_x_onkernelspace (defined in the previous cell) takes (gpmodel, dataloader);
#inspect_kernel requires a third argument, so the call below uses the former.
X_on_kernelspace = get_x_onkernelspace(gpmodel, dl_recurring)

In [None]:
# NOTE(review): allocates a 40000 x 40000 float64 array (~12.8 GB) and `a` is
# overwritten by the next cell; looks like a scratch memory probe -- consider removing.
a = np.zeros((40000, 40000))

In [None]:
import numpy as np
a = [
1,
2,
5,
6,
7,
8,
9,
11,
14
]
a = np.array(a)
b = [u+1 for u in range(len(a.tolist()))]
b.reverse()
print(np.sum(a*b))

In [None]:
#scratch: list.reverse() / reversed() ordering check
l = list(reversed([1, 2, 3]))
print(l)

In [None]:
# display the descending weight list `b` from the weighted-sum scratch cell
b