In [None]:
#Please set Colaboratory Runtime to GPU before running this notebook

In [None]:
#Clone this repository to access 
#i.   Base code of ULIME, L2 LIME and Cosine LIME.
#ii.  Dataset and preset perturbations + labels to test all three LIME variants on
#iii. requirements.txt for installing necessary packages
!git clone https://github.com/ansariminhaj/ulime_github.git

In [None]:
#Install all packages
!pip install -r ulime_github/requirements.txt

In [None]:
#Unzip folder containing base code of ULIME, L2 LIME and Cosine LIME.
!unzip ulime_github/lime_lib_fixed.zip

In [None]:
#Import all required packages
import torch
import torchvision
import torch.nn.functional as F
import torchvision.transforms as transforms
import numpy as np
import matplotlib.pyplot as plt
import copy
import skimage.io 
import skimage.segmentation
import copy
import os
from PIL import Image
import ast
from itertools import product
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
%matplotlib inline

In [None]:
#Black box model to explain
#We use the ImageNet pretrained inception_v3 model in this tutorial
model = torchvision.models.inception_v3(pretrained=True) #Load pretrained model
#Move the model to the device selected in the import cell (GPU when available, otherwise CPU)
#so that this cell does not raise on a CPU-only runtime.
model.to(device)

In [None]:
from lime_lib_fixed.lime_org_U.lime import lime_image #(Use a fresh kernel before importing so that cache is empty)
#from lime_lib_fixed.lime_org_l2.lime import lime_image #(Use a fresh kernel before importing so that cache is empty)
#from lime_lib_fixed.lime_org_cosine.lime import lime_image #(Use a fresh kernel before importing so that cache is empty)

In [None]:
#Heatmaps are used to visualize the LIME explanations.
#We use the Viridis Colormap for visualization
def heatmap_image(img, c_list, segments):
    """Paint every superpixel with its coefficient value.

    img      -- unused; kept so existing callers keep working.
    c_list   -- sequence of values, indexed by superpixel label.
    segments -- integer label array; labels 0..K-1 are looked up, where K is
                the number of distinct labels present in the array.

    Returns a float array shaped like `segments` in which each pixel holds
    the c_list entry of the superpixel it belongs to (0 where no label in
    0..K-1 matches).
    """
    label_count = np.unique(segments).size
    heat = np.zeros(segments.shape)
    for label in range(label_count):
        belongs_to_label = (segments == label)
        heat[belongs_to_label] = c_list[label]
    return heat

#This function loads an image from disk and converts it to RGB
def get_image(path):
  """Open the image file at `path` and return it converted to RGB mode."""
  absolute_path = os.path.abspath(path)
  with open(absolute_path, 'rb') as stream, Image.open(stream) as picture:
      return picture.convert('RGB')
        
#Resize and take the center part of image to what our model expects
def get_input_transform():
    """Return the full PIL-image -> normalized-tensor pipeline.

    The pipeline resizes to 256x256, center-crops to 224, converts to a
    tensor and normalizes with the mean/std values listed below.
    """
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    transf = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        normalize
    ])

    return transf

def get_input_tensors(img):
    """Transform a single image and return it as a batch of size 1."""
    # get_input_transform must be defined (it was previously commented out,
    # which made this function raise NameError when called).
    transf = get_input_transform()
    # unsqueeze converts single image to batch of 1
    return transf(img).unsqueeze(0)

def get_pil_transform():
    """Return the geometric-only transform: resize to 256x256, then center-crop to 224."""
    geometric_steps = [
        transforms.Resize((256, 256)),
        transforms.CenterCrop(224),
    ]
    return transforms.Compose(geometric_steps)

def get_preprocess_transform():
    """Return the transform that converts an image to a tensor and normalizes it."""
    channel_mean = [0.485, 0.456, 0.406]
    channel_std = [0.229, 0.224, 0.225]
    tensor_steps = [
        transforms.ToTensor(),
        transforms.Normalize(mean=channel_mean, std=channel_std),
    ]
    return transforms.Compose(tensor_steps)

#Build both transforms once so they can be reused by the cells below.
#pill_transf: resize + center crop only (no normalization).
pill_transf = get_pil_transform()
#preprocess_transform: tensor conversion + normalization, applied inside batch_predict.
preprocess_transform = get_preprocess_transform()

#This function is used to feed input to the black box model 
#The purpose is to obtain class probabilities for perturbed image
#These probabilities are used as labels along with the perturbations to train a linear model
def batch_predict(images):
    """Return the black box model's class probabilities for a batch of images.

    images -- iterable of images; each one is passed through
              preprocess_transform before being stacked into a batch.

    Returns a numpy array of softmax probabilities with one row per image.
    """
    model.eval()
    batch = torch.stack(tuple(preprocess_transform(i) for i in images), dim=0)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    batch = batch.to(device)

    #Inference only: disabling autograd avoids building a computation graph
    #for every call, which saves memory and time when called many times.
    with torch.no_grad():
        logits = model(batch)
        probs = F.softmax(logits, dim=1)
    return probs.cpu().numpy()

In [None]:
#GLOBAL PARAMETERS

#List of perturbation quantities P
perturbations_num = [10, 50, 100, 200, 300, 500, 700, 1000, 1500, 2000, 2500, 3000]

#This folder contains fixed perturbations (data.txt) + associated labels (labels.txt) for the first 10 ImageNet classes.
data_labels_folder = "ulime_github/pert_labels"

#Folder to store explanations and metrics (LnO, heatmaps, Combined Variance)
folder_name = "imagenet_ulime"
#exist_ok=True keeps this cell re-runnable: os.mkdir raised FileExistsError
#whenever the folder was left over from a previous execution.
os.makedirs(folder_name, exist_ok=True)

#Number of times LIME runs per image
runs = 10

#Number of images that this tutorial needs to run on
#We have provided 10 images for the tutorial, please contact us at minhaj3737@gmail.com if you need the entire dataset.
#The 10 images are from ImageNet classes 0-9, 1 image from each class
number_images = 10

In [None]:
def _read_fixed_array(path):
    """Read a stored Python list literal from `path` and return it as a numpy array.

    The first line of the file is skipped (the files written by this notebook
    start with a "Class: <id>" header line); the rest is parsed with
    ast.literal_eval.
    """
    with open(path, "r") as handle:
        handle.readline()
        return np.array(ast.literal_eval(handle.read()))


def _write_class_file(path, class_i, payload):
    """Write a "Class: <id>" header line followed by str(payload) and a blank line."""
    with open(path, "w+") as handle:
        handle.write("Class: "+str(class_i))
        handle.write("\n")
        handle.write(str(payload))
        handle.write("\n\n")


def _normalize_coefficients(coeff):
    """Min-max scale a list of coefficients to the range [0, 1].

    When all coefficients are equal the range is zero, so a list of zeros is
    returned instead of dividing by zero.
    """
    lowest = min(coeff)
    span = max(coeff) - lowest
    if span == 0:
        return [0.0 for _ in coeff]
    return [(float(c) - lowest) / span for c in coeff]


def _leave_n_out_probabilities(image, segments, coefficients, class_i):
    """Return the LnO curve for one explanation.

    Entry k of the returned list is the black box probability of `class_i`
    after the k superpixels with the largest coefficients have been set to 0
    in `image` (k = 0 is the untouched image).
    """
    ranking = np.argsort(coefficients) #ascending, so the most important superpixels are last
    probabilities = []
    for k in range(len(coefficients) + 1):
        #ranking[-0:] would select everything, so k == 0 is handled explicitly
        removed = ranking[-k:] if k else []
        occluded = image.copy()
        occluded[np.isin(segments, removed)] = 0

        #Each time we remove a superpixel, we note down the black box classification value
        preds = batch_predict(np.array([occluded]))
        probabilities.append(preds[0][class_i])
    return probabilities


explainer = lime_image.LimeImageExplainer(feature_selection='none')

for class_i in range(number_images):

    #Make a folder for each image that is being explained
    #(exist_ok=True everywhere below keeps the cell re-runnable)
    class_dir = os.path.join(folder_name, str(class_i))
    os.makedirs(class_dir, exist_ok=True)

    #Walk through all images per class folder. In this tutorial, we only have one image per class folder
    #NOTE: outputs are stored per class, so a second image in the same folder would overwrite the first one's results.
    for path_dir, dirnames, filenames in os.walk(os.path.join('ulime_github/imagenet_10images', str(class_i))):

        #This returns the image file in the folder
        for filename in filenames:

            #Convert image to RGB
            img = get_image(os.path.join(path_dir, filename))

            #Transform image (Only resize). Normalization will be applied in the batch_predict function
            #This is because the segmentation algorithm needed by LIME should only be applied on the actual image
            #Not a normalized image. The normalization is only needed for the black box model, therefore,
            #it has been exclusively kept in the black box prediction function.
            image = np.array(pill_transf(img))

            coeffs_list = [] #Normalized coefficients, indexed as [run][perturbation][superpixel]
            segments = None #Stores superpixel segments to calculate LnO
            num_superpixels = 0 #Stores number of superpixels

            #Run 'r' times per image.
            for run in range(runs):

                #Create folders for the perturbations (data) and associated labels being used in this run.
                #This is useful when you want to create your own set of random perturbations and labels.
                #You run one LIME variant once, store the perturbations and labels, and use those for
                #other LIME variants so that the comparison is fair.
                #Modifications in the code needed to create your own set of perturbations given in the
                #last code block.
                data_run_dir = os.path.join(class_dir, "data", "run_"+str(run))
                labels_run_dir = os.path.join(class_dir, "labels", "run_"+str(run))
                os.makedirs(data_run_dir, exist_ok=True)
                os.makedirs(labels_run_dir, exist_ok=True)

                #Location of the fixed perturbations and labels shipped with the repository
                fixed_data_dir = os.path.join(data_labels_folder, str(class_i), "data", "run_"+str(run))
                fixed_labels_dir = os.path.join(data_labels_folder, str(class_i), "labels", "run_"+str(run))

                print("RUN: ", run+1)

                #Coeff_list stores the coefficients of the linear model for each LIME explanation
                coeff_list = []

                #Run through all the perturbation quantities defined in perturbations_num
                for p in perturbations_num:

                    #Extract the fixed perturbations
                    data_fixed = _read_fixed_array(os.path.join(fixed_data_dir, str(p)+".txt"))
                    print("Perturbations shape: ", data_fixed.shape)

                    #Extract the associated labels for the fixed perturbations
                    labels_fixed = _read_fixed_array(os.path.join(fixed_labels_dir, str(p)+".txt"))
                    print("Labels shape: ", labels_fixed.shape)

                    #Explain the image with the fixed perturbations and labels
                    #This is where the linear model is trained
                    #The output is the:
                    #i. explanation: Coefficients of the linear model
                    #ii. segments: Superpixels. Used to view the explanation heatmaps
                    #iii. label: The label we gave LIME to generate an explanation for the image
                    #iv. data: Fixed perturbations we gave as input to LIME
                    #v. labels: Fixed labels we gave as input to LIME
                    explanation, segments, label, data, labels = explainer.explain_instance(image,
                                                            batch_predict, data_fixed, labels_fixed,
                                                            class_i,
                                                            top_labels=1,
                                                            hide_color=0,
                                                            num_samples=p) # number of images that will be sent to classification function

                    #Write perturbations in the perturbations folder
                    _write_class_file(os.path.join(data_run_dir, str(p)+".txt"), class_i,
                                      [list(row) for row in data])

                    #Write labels in the labels folder
                    _write_class_file(os.path.join(labels_run_dir, str(p)+".txt"), class_i,
                                      list(labels))

                    num_superpixels = np.unique(segments).shape[0]
                    print("Number of superpixels: ", num_superpixels)

                    #Save the linear model coefficients in the coeff list (indexed by superpixel)
                    coeff = [None] * num_superpixels
                    for exp_tuple in explanation.local_exp[label]:
                        coeff[exp_tuple[0]] = exp_tuple[1]

                    #Normalize coeff list and store normalized coefficients in another list coeff_list
                    coeff_list.append(_normalize_coefficients(coeff))

                #coeffs_list contains coefficients of all the linear models trained on all the perturbations
                coeffs_list.append(coeff_list)

            #Show the superpixilized image and save it
            plt.imshow(skimage.segmentation.mark_boundaries(image, segments))
            plt.axis('off')
            plt.savefig(os.path.join(class_dir, 'seg_image.png'), bbox_inches='tight')
            plt.show()
            plt.close()

            #Show the original image and save it
            plt.imshow(image)
            plt.axis('off')
            plt.savefig(os.path.join(class_dir, 'org_image.png'), bbox_inches='tight')
            plt.show()
            plt.close()

            #Save all coefficients in a text file in case they need to be accessed in the future.
            _write_class_file(os.path.join(class_dir, "coeffs_list.txt"), class_i, coeffs_list)

            ### This section is for creating LnO's values ###
            #For every linear model trained in the previous step, record the probability of the
            #explained class as superpixels are removed from most important to least important.
            #Each coefficient is associated with a superpixel. The magnitude of the coefficient
            #specifies the importance of the superpixel in the explanation.
            #probability_list_run is structured as (Run, Perturbation, LnO (Superpixels + 1))
            probability_list_run = [
                [_leave_n_out_probabilities(image, segments, coeffs_list[run][pert], class_i)
                 for pert in range(len(perturbations_num))]
                for run in range(runs)
            ]

            #Save the LnO's in a text file
            with open(os.path.join(class_dir, "LnO.txt"), "w+") as f:
                for run in range(runs):
                    f.write("RUN: %d\n\n" % run)
                    for pert in range(len(perturbations_num)):
                        f.write("PERTURBATION: %d\n" % perturbations_num[pert])
                        f.write(str(probability_list_run[run][pert]))
                        f.write("\n")

            ### This section is for graphing the LnO's values ###
            n_list = list(range(num_superpixels+1))
            lno_dir = os.path.join(class_dir, 'LnO')
            os.makedirs(lno_dir, exist_ok=True)

            for pert in range(len(perturbations_num)):
                pert_dir = os.path.join(lno_dir, str(perturbations_num[pert]))
                os.makedirs(pert_dir, exist_ok=True)

                #Graph LnO for each run and perturbation individually
                for run in range(runs):
                    plt.xlabel("n")
                    plt.ylabel("LnO Accuracy")
                    plt.plot(n_list, probability_list_run[run][pert])
                    plt.savefig(os.path.join(pert_dir, "run_"+str(run)+'.png'), bbox_inches='tight')
                    plt.close()

                #Graph all runs for a perturbation in a single figure
                for run in range(runs):
                    plt.xlabel("n")
                    plt.ylabel("LnO Accuracy")
                    plt.plot(n_list, probability_list_run[run][pert])
                plt.savefig(os.path.join(pert_dir, "all_runs.png"), bbox_inches='tight')
                plt.close()

            #Plot heatmaps of each explanation based on the normalized coefficient values
            for run in range(runs):
                heatmap_dir = os.path.join(class_dir, str(run), "heatmaps")
                os.makedirs(heatmap_dir, exist_ok=True)
                for pert in range(len(perturbations_num)):
                    heat_img = heatmap_image(image, coeffs_list[run][pert], segments)
                    plt.imshow(heat_img)
                    plt.axis('off')
                    plt.savefig(os.path.join(heatmap_dir, str(perturbations_num[pert])+".png"), bbox_inches='tight')
                    plt.close()

            ### This section is for plotting the combined variance ###
            #We do this for each perturbation. We find how much the combined variance exists across the runs
            #per perturbation
            num_coeffs = len(probability_list_run[0][0]) - 1 #LnO curves have one entry more than there are coefficients

            total_var_l = []
            with open(os.path.join(class_dir, "pert_var.txt"), "w+") as f:
                for pert in range(len(perturbations_num)):
                    #per_coeff is structured like [ [1st coeff for all runs], [2nd coeff for all runs], ... ]
                    per_coeff = [[coeffs_list[run][pert][n] for run in range(runs)]
                                 for n in range(num_coeffs)]

                    print("Perturbation: ", perturbations_num[pert],"\n")
                    f.write("Perturbation: " + str(perturbations_num[pert]) +"\n")
                    total_var = 0
                    total_mean = 0
                    for n in range(len(per_coeff)): #go through all coefficients
                        total_var += np.var(per_coeff[n])
                        total_mean += np.mean(per_coeff[n])
                        for j in range(n+1, len(per_coeff)):
                            total_var += 2 * np.cov(per_coeff[n], per_coeff[j], ddof=0)[0][1] #Extract covariance value from covariance matrix
                    print("Mean: ", total_mean)
                    print("Var: ",total_var)
                    f.write("Mean: "+str(total_mean)+"\n")
                    f.write("Var: "+str(total_var)+"\n")
                    total_var_l.append(total_var)

            #Plot the combined variance graph.
            plt.xlabel("Perturbations")
            plt.ylabel("Variance")
            plt.plot(perturbations_num, total_var_l)
            plt.savefig(os.path.join(class_dir, 'pert_var.png'), bbox_inches='tight')
            plt.show()
            #Close the figure so the next image does not draw on top of it
            plt.close()

In [None]:
#For generating your own perturbations and labels, follow the steps below: 

#UNCOMMENT these lines in lime_image.py: 
#201 data, labels = self.data_labels(image, fudged_image, segments,
#202                                 classifier_fn, num_samples,
#203                                 batch_size=batch_size,
#204                                 progress_bar=progress_bar)

#COMMENT these lines in lime_image.py:
#206 data = data_fixed
#207 labels = labels_fixed

#UNCOMMENT #[:, label] in line 184 (labels_column = neighborhood_labels#[:, label])