In [1]:
import matplotlib.pyplot as plt
from PIL import Image
import torch.nn as nn
import numpy as np
import os, json
import lime_image
import torch
from torchvision import models, transforms
from sal_resnet import resnet50, resnext50_32x4d, wide_resnet50_2
from torch.autograd import Variable
import torch.nn.functional as F
import glob
from skimage.segmentation import mark_boundaries

In [2]:
import matplotlib 
# Global matplotlib font size, applied to every figure drawn after this cell
# (the LIME panels below are 10 inches per subplot, so a large size is used).
font = dict(size=26)
matplotlib.rc('font', **font)

In [3]:
def turn_off(ax):
    """Hide all tick marks and tick labels on a matplotlib Axes.

    The axes themselves are left on, so ``set_ylabel``/``set_title`` can still
    annotate the panel afterwards.

    Args:
        ax: matplotlib Axes, modified in place.
    """
    # left/right must be disabled too; otherwise y tick marks remain on the
    # left edge of every image panel while the x marks are hidden.
    ax.tick_params(
        axis='both',
        which='both',
        bottom=False,
        top=False,
        left=False,
        right=False,
        labelbottom=False,
        labelleft=False)
    # Blank the tick labels explicitly as well.
    ax.set_yticklabels([])
    ax.set_xticklabels([])

In [4]:
def get_image(path):
    """Load the image stored at ``path`` and return it as an RGB PIL image.

    Both the file handle and the opened image are closed on exit; ``convert``
    returns a new image object, so the result remains usable afterwards.
    """
    with open(os.path.abspath(path), 'rb') as fh, Image.open(fh) as raw:
        rgb = raw.convert('RGB')
    return rgb
        
def get_input_transform():
    """Full single-image preprocessing: PIL resize/crop followed by tensor normalization.

    Built from the steps of ``get_pil_transform()`` and
    ``get_preprocess_transform()``. ``get_input_tensors`` therefore feeds the model
    the same pipeline that ``batch_predict`` applies to ``pill_transf`` output, and
    the resize/crop sizes and mean/std are not duplicated here.

    Returns:
        transforms.Compose of Resize((256, 256)), CenterCrop(224), ToTensor(), Normalize.
    """
    steps = list(get_pil_transform().transforms) + list(get_preprocess_transform().transforms)
    return transforms.Compose(steps)

def get_input_tensors(img):
    """Preprocess one PIL image into a model-ready batch containing a single item."""
    single = get_input_transform()(img)
    # unsqueeze(0) adds the leading batch dimension (batch of 1).
    return single.unsqueeze(0)

def get_pil_transform():
    """Geometric preprocessing applied to PIL images before they are handed to LIME.

    Resizes to 256x256 (aspect ratio is not preserved), then center-crops to 224x224.
    No tensor conversion happens here.
    """
    resize = transforms.Resize((256, 256))
    crop = transforms.CenterCrop(224)
    return transforms.Compose([resize, crop])

def get_preprocess_transform():
    """Tensor conversion plus mean/std normalization for already-cropped images.

    The step order (ToTensor, then Normalize) must stay as is: the notebook reads
    the Normalize mean via ``.transforms[1]``.
    """
    channel_mean = [0.485, 0.456, 0.406]
    channel_std = [0.229, 0.224, 0.225]
    pipeline = [
        transforms.ToTensor(),
        transforms.Normalize(mean=channel_mean, std=channel_std),
    ]
    return transforms.Compose(pipeline)

# Run on the GPU when one is available, otherwise on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Shared transforms: pill_transf crops the raw PIL image that LIME receives;
# preprocess_transform turns each perturbed image into a normalized tensor
# inside batch_predict.
pill_transf = get_pil_transform()
preprocess_transform = get_preprocess_transform()

def batch_predict(images, masks=None):
    """Classify a batch of images and return softmax probabilities (LIME classifier_fn).

    Reads the module-level ``model``, ``preprocess_transform`` and ``device``.

    Args:
        images: iterable of images accepted by ``preprocess_transform``
            (ToTensor + Normalize), e.g. the perturbed arrays LIME produces.
        masks: optional masks for the layer-masking model. They are inverted
            (``1 - masks``) before being passed to the model.
            NOTE(review): presumably 1 marks a hidden segment — confirm in lime_image.

    Returns:
        numpy array of softmax probabilities over dim 1 of the logits
        (presumably shape (n_images, n_classes)).
    """
    model.eval()
    model.to(device)
    batch = torch.stack([preprocess_transform(im) for im in images], dim=0).to(device)

    with torch.no_grad():
        if masks is None:
            logits = model(batch)
        else:
            # np.asarray first: building a tensor directly from a list of
            # ndarrays goes through torch's slow element-wise path.
            mask_t = torch.as_tensor(np.asarray(masks, dtype=np.float32), device=device)
            keep = 1 - mask_t
            # [100, 5] is forwarded verbatim to the sal_resnet model; its meaning
            # is defined there — TODO document.
            logits = model((batch, keep, [100, 5]))
        probs = F.softmax(logits, dim=1)
    return probs.cpu().numpy()

In [5]:
def block_segment(img, grid=14):
    """Partition an image into a ``grid`` x ``grid`` lattice of rectangular segments.

    Usable as a LIME ``segmentation_fn``. Segment ids run row-major from 0 to
    ``grid * grid - 1``, and the map has the spatial size of ``img`` (its first two
    dimensions). For a 224x224 image and the default ``grid=14`` each segment is a
    16x16 block, which matches the previous hard-coded output. If ``img`` is None,
    a 224x224 map is returned, as before.

    Args:
        img: image array (H, W[, C]) or None.
        grid: number of segments along each side.

    Returns:
        Integer array of shape (H, W) with segment ids.
    """
    height, width = (224, 224) if img is None else np.shape(img)[:2]
    # Map each pixel row/column to its block index; exact 16-px blocks for 224/14.
    row_ids = np.arange(height) * grid // height
    col_ids = np.arange(width) * grid // width
    return row_ids[:, None] * grid + col_ids[None, :]

In [6]:
from skimage.segmentation import quickshift, slic

images = ['./data/cat_mouse.jpg', './data/dogs.png']


def seg_fn(x):
    """Quickshift superpixel segmentation used as LIME's ``segmentation_fn``.

    kernel_size=2, max_dist=200 and ratio=0.2 are fixed. random_seed=0 pins the
    segmentation so repeated runs produce the same superpixels.
    """
    return quickshift(x, kernel_size=2, max_dist=200, ratio=0.2, random_seed=0)

In [7]:
# LIME run settings: classifier batch size, perturbed samples per explanation,
# and how many top predicted labels are explained (one figure row each).
bs, num_samples, topk = 256, 3000, 2
# ResNet-50 from the local sal_resnet module; pretrained=True presumably loads
# ImageNet weights — confirm in sal_resnet.
model = resnet50(pretrained=True)
model.eval()  # inference mode
# Class index JSON keyed by str(integer index); each entry's field 0 is the class
# id (presumably a WordNet id) and field 1 the human-readable label.
# JSON text is UTF-8, so decode explicitly rather than via the platform locale.
with open(os.path.abspath('./imagenet_class_index.json'), 'r', encoding='utf-8') as read_file:
    class_idx = json.load(read_file)

# Entries ordered by integer class index 0..N-1.
_class_entries = [class_idx[str(k)] for k in range(len(class_idx))]
idx2label = [entry[1] for entry in _class_entries]                    # index -> label
cls2label = {entry[0]: entry[1] for entry in _class_entries}          # class id -> label
cls2idx = {entry[0]: k for k, entry in enumerate(_class_entries)}     # class id -> index

# Per-channel fill value for the "Greyout" baseline: the Normalize mean
# (step 1 of preprocess_transform) rescaled from [0, 1] to 0-255 pixel units.
grey = np.array(preprocess_transform.transforms[1].mean) * 255

# (fill value, column title) pairs. 0 blacks out hidden segments, grey fills them
# with the mean colour, and None selects the custom layer-masking explainer.
_baselines = [
    (0, 'Blackout'),
    (grey, 'Greyout'),
    (None, r'$\bf{Layer\ Masking}$'),
]
miss_approxns = [value for value, _ in _baselines]
miss_approxns_name = [title for _, title in _baselines]

# Create the output directory up front so savefig cannot fail on a missing path.
results_dir = './results'
os.makedirs(results_dir, exist_ok=True)


def _explain(image_np, miss_approx):
    """Run LIME on one cropped image for a single missingness approximation.

    ``miss_approx is None`` selects the custom ``MyLimeImageExplainer`` (layer
    masking). Any other value uses the stock ``LimeImageExplainer`` and fills
    hidden segments with ``miss_approx`` via ``hide_color``.
    """
    common = dict(batch_size=bs, top_labels=topk,
                  segmentation_fn=seg_fn, num_samples=num_samples)
    if miss_approx is None:
        explainer = lime_image.MyLimeImageExplainer()
        return explainer.explain_instance(image_np, batch_predict, **common)
    explainer = lime_image.LimeImageExplainer()
    return explainer.explain_instance(image_np, batch_predict,
                                      hide_color=miss_approx, **common)


for imgpath in images:
    img = get_image(imgpath)
    # Resize/crop once per image; np.array below still hands each explainer a fresh copy.
    cropped = pill_transf(img)

    # squeeze=False keeps axarr 2-D even when topk == 1 or only one baseline is used.
    fig, axarr = plt.subplots(topk, len(miss_approxns),
                              figsize=(len(miss_approxns) * 10, topk * 10),
                              squeeze=False)
    for j, miss_approx in enumerate(miss_approxns):
        explanation = _explain(np.array(cropped), miss_approx)
        for i in range(topk):
            ax = axarr[i, j]
            im_label = explanation.top_labels[i]
            temp, mask = explanation.get_image_and_mask(
                im_label, positive_only=False, num_features=50, hide_rest=False)
            turn_off(ax)
            ax.imshow(mark_boundaries(temp / 255.0, mask))
            if j == 0:
                ax.set_ylabel(f'{idx2label[im_label]}')
            if i == 0:
                ax.set_title(f'{miss_approxns_name[j]}')

    out_path = os.path.join(results_dir, f"lime_{os.path.basename(imgpath)}")
    fig.savefig(out_path, bbox_inches='tight')
    # Close this figure explicitly so figures do not accumulate across images.
    plt.close(fig)
