In [None]:
import numpy
import os
import torch
from PIL import Image
import lime 
import sys
import pandas
from torch import nn

In [None]:
# Make modules in the parent directory importable from this notebook.
sys.path.extend(["../"])

In [None]:
# Relative path (not a URL, despite the name) to the CSV holding the
# validation-set prediction statistics.
PREDICTION_STATS_URL = (
    "prediction_stats/"
    "validation_set_pred_stats.csv"
)

In [None]:
# Prediction-statistics table; the filtering cell reads its
# `output_score` and `baseline_score` columns.
stats = pandas.read_csv(filepath_or_buffer=PREDICTION_STATS_URL)

In [None]:
# Boolean row mask: True where output_score <= baseline_score.
local_mask = stats['output_score'] <= stats['baseline_score']
local_int_imgs = stats[local_mask]
# Invert the boolean mask, not the filtered DataFrame: `~` on a DataFrame
# negates its values (or raises for floats) instead of selecting the other
# rows. Rows with a NaN score compare False above and so land here.
global_int_imgs = stats[~local_mask]

In [None]:
def load_images(path_url: pandas.DataFrame):
    """
    Intended to load the images referenced by a subset of the stats table.

    Placeholder: not implemented yet, so it currently returns ``None``.

    Parameters
    ----------
    path_url : pandas.DataFrame
        Row subset of ``stats`` selecting the images to load. Despite the
        parameter name (and the former ``int`` annotation), the callers in
        this notebook pass a DataFrame.
    """
    # TODO: implement image loading; which column holds the image paths is
    # not visible from here -- confirm against the CSV schema.
    return None

## Loading images for global interpretation of network performance

In [None]:
# Images selected for the global interpretation.
global_imgs = load_images(path_url=global_int_imgs)

## Loading images for local interpretation of network performance

In [None]:
# Images selected for the local interpretation.
local_imgs = load_images(path_url=local_int_imgs)

In [None]:
import skimage.segmentation 
import cv2

class NetworkPerformanceExplainer:
    """
    LIME-style explainer skeleton built on quick shift superpixel segmentation.

    Images are split into superpixels with
    ``skimage.segmentation.quickshift``. Perturbed copies of an image keep
    only a chosen subset of those superpixels and zero out the rest. The
    linear-surrogate step of LIME is not implemented yet (``explain_local``
    and ``explain_global`` are placeholders).

    Parameters
    ----------
    input_network : nn.Module
        Network under study. Stored on the instance; not called by any
        method of this class yet.
    image_height : int
        Height of the input images. Stored; not used by the methods yet.
    image_width : int
        Width of the input images. Stored; not used by the methods yet.
    seg_kernel_size : int
        ``kernel_size`` forwarded to quick shift.
    max_seg_dist : int
        ``max_dist`` forwarded to quick shift.
    seg_ratio : float
        ``ratio`` forwarded to quick shift.
    """

    def __init__(
        self,
        input_network: nn.Module,
        image_height: int,
        image_width: int,
        seg_kernel_size: int,
        max_seg_dist: int,
        seg_ratio: float,
    ):
        self.input_network = input_network
        self.image_height = image_height
        self.image_width = image_width
        # Quick shift hyper-parameters, forwarded by _compute_perturbations.
        self.kernel_size = seg_kernel_size
        self.maximum_seg_dist = max_seg_dist
        self.seg_ratio = seg_ratio

    def _compute_perturbations(self, image):
        """
        Segment ``image`` into superpixels with quick shift.

        Despite the name, this returns the segmentation, not perturbations:
        the integer label array produced by
        ``skimage.segmentation.quickshift``. The number of distinct labels
        is printed as a progress hint.
        """
        superpixels = skimage.segmentation.quickshift(
            image=image,
            kernel_size=self.kernel_size,
            max_dist=self.maximum_seg_dist,
            ratio=self.seg_ratio,
        )
        # Each distinct label is one superpixel, not one pixel.
        print(
            "computed number of superpixels: ",
            numpy.unique(superpixels).shape[0],
        )
        return superpixels

    def perturb_image(
        self,
        image: numpy.ndarray,
        perturbation: numpy.ndarray,
        img_segments: numpy.ndarray,
    ):
        """
        Return a copy of ``image`` keeping only the active superpixels.

        Parameters
        ----------
        image : numpy.ndarray
            (H, W) or (H, W, C) image; its leading two axes must match
            ``img_segments``.
        perturbation : numpy.ndarray
            1-D vector indexed by superpixel label; entries equal to 1 mark
            the superpixels to keep.
        img_segments : numpy.ndarray
            (H, W) superpixel label array, e.g. from
            ``_compute_perturbations``.

        Returns
        -------
        numpy.ndarray
            Array with the same shape and dtype as ``image``. Pixels outside
            the active superpixels are 0; ``image`` itself is not modified.
        """
        active_segments = numpy.where(perturbation == 1)[0]
        # True for every pixel whose superpixel label is active. A 2-D mask
        # selects all channels of an (H, W, C) image at once.
        keep = numpy.isin(img_segments, active_segments)
        # Copy the kept pixels into a zeroed array. cv2.bitwise_and with a
        # 0/1 mask is not a mask: pixel & 1 keeps only the lowest bit, and
        # OpenCV rejects a uint8 image paired with a float64 mask.
        pert_img = numpy.zeros_like(image)
        pert_img[keep] = image[keep]
        return pert_img

    def explain_local(self, input_img: numpy.ndarray):
        """Placeholder: not implemented yet; currently returns None."""
        pass

    def explain_global(self, input_img: numpy.ndarray):
        """Placeholder: not implemented yet; currently returns None."""
        pass