In [1]:
import sys
sys.path.append('./code')
import os
import multiprocessing

import matplotlib.pylab as plt
import PIL.Image as Image
import numpy as np
import tensorflow as tf
import tensorflow_hub as hub
from skimage.segmentation import quickshift

from sedc_t2_fast import sedc_t2_fast

In [2]:
# os.environ['CUDA_VISIBLE_DEVICES'] = '-1'

In [None]:
# Import model
# MobileNetV2 ImageNet classifier hosted on TF Hub.
classifier_url = "https://tfhub.dev/google/tf2-preview/mobilenet_v2/classification/2"

# Define image shape: spatial size fed to the classifier (channels are appended below).
IMAGE_SHAPE = (224, 224)

# Configure classifier
# NOTE(review): trainable=True is kept as in the original; this notebook only runs
# inference, so confirm nothing downstream (e.g. sedc_t2_fast) relies on it before changing.
classifier = tf.keras.Sequential([
    hub.KerasLayer(classifier_url, input_shape=IMAGE_SHAPE + (3,), trainable=True)
])

# Get labels: one class name per line; elsewhere in this notebook a label's array
# position is looked up and passed on as its class index.
labels_path = tf.keras.utils.get_file(
    'ImageNetLabels.txt',
    'https://storage.googleapis.com/download.tensorflow.org/data/ImageNetLabels.txt')
# Use a context manager so the file handle is closed (a bare open(...).read() leaks it).
with open(labels_path, encoding='utf-8') as labels_file:
    imagenet_labels = np.array(labels_file.read().splitlines())


In [None]:
def find_cf_region(i_data, target_label='French bulldog'):
    """Find a counterfactual region for one image and save it highlighted in green.

    Parameters
    ----------
    i_data : sequence
        ``(img_path, output_folder)`` pair. It is packed into a single argument,
        presumably so the function can be mapped over a list of work items.
    target_label : str, optional
        ImageNet label whose class index is passed to ``sedc_t2_fast``.
        Defaults to ``'French bulldog'``, the value previously hard-coded here.

    Side effects
    ------------
    Explanations come from ``sedc_t2_fast`` using the ``'blur'`` option and a
    ``max_time`` of 10. When an explanation is returned, the image is written to
    ``<output_folder>/<image file name>``. In that image every pixel of the
    explanation segments is painted pure green; all other pixels are unchanged.
    Nothing is written when no explanation is found.

    Relies on the notebook globals ``IMAGE_SHAPE``, ``classifier`` and ``imagenet_labels``.
    """
    img_path, output_folder = i_data[0], i_data[1]
    img_filename = os.path.basename(img_path)

    # Resize first (as before), then force 3-channel RGB. This matches the old
    # [:, :, 0:3] slice for RGB/RGBA inputs, and it no longer crashes on
    # grayscale/palette images (2-D arrays) or mis-reads LA/CMYK ones.
    # The context manager closes the underlying file handle.
    with Image.open(img_path) as img:
        image = np.array(img.resize(IMAGE_SHAPE).convert('RGB')) / 255.0

    segments = quickshift(image, kernel_size=4, max_dist=200, ratio=0.2)

    target_class = np.where(imagenet_labels == target_label)[0][0]
    explanation, segments_in_explanation, _, _ = sedc_t2_fast(
        image,
        classifier,
        segments,
        target_class,
        'blur',
        max_time=10)

    if explanation is None:
        return

    # Pixels whose segment belongs to the explanation become green (0, 1, 0);
    # the rest keep their colour. This is a vectorised equivalent of the
    # former per-pixel Python loop.
    in_explanation = np.isin(segments, segments_in_explanation)
    highlighted = np.where(in_explanation[..., np.newaxis], (0, 1, 0), image)

    os.makedirs(output_folder, exist_ok=True)
    Image.fromarray((highlighted * 255).astype(np.uint8)).save(
        os.path.join(output_folder, img_filename))

In [None]:
# One [input image path, output folder] work item per image file in the input folder.
FACTUAL_DIR = './chihuahua_test/'
CF_REGION_DIR = './cf_region_test/'

# Sorted because os.listdir order is arbitrary, which makes runs non-reproducible.
# Hidden entries and non-files (e.g. .DS_Store, .ipynb_checkpoints/) are skipped so
# they are never handed to Image.open.
img_factual = sorted(
    name for name in os.listdir(FACTUAL_DIR)
    if not name.startswith('.') and os.path.isfile(os.path.join(FACTUAL_DIR, name))
)
img_factual_paths = [[os.path.join(FACTUAL_DIR, name), CF_REGION_DIR] for name in img_factual]

In [None]:
%%time
for i_data in img_factual_paths:
    find_cf_region(i_data)