In [7]:
from __future__ import print_function
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch import logical_not
from torchvision import datasets, transforms
from torch.optim.lr_scheduler import StepLR
import matplotlib.pyplot as plt
from custom_transform import *
from model import CIFAR10Net
from transforms import TRANSFORMS
from mask import *
import imageio
from utils import generate_gif, cifar10_img
import os




TMP_IMG_FOLDER = "./cf10_gif_img"
RESULT_FOLDER_NAME = "result_samples_cf10"
GIF_RESULT = "./" + RESULT_FOLDER_NAME + "/{}/gif_results/img_{}.gif"
IMG_RESULT = "./" + RESULT_FOLDER_NAME + "/{}/img_results/img_{}.png"
MODEL_PATH = "./cifar10_net.pth"

CLASSES = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
EPOCH = 40
BATCH_SIZE = 128
THRESHOLD = 0.8

def load_model(filepath, device):
    model = CIFAR10Net().to(device)
    model.load_state_dict(torch.load(filepath))
    model.eval()
    return model

In [8]:
device = torch.device("cpu")
model = load_model(MODEL_PATH, device)


# test_data = datasets.CIFAR10('./data', train=False, download=False, transform=TRANSFORMS['Original'])


transform = transforms.Compose(
        [transforms.ToTensor(),
         transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])


test_data = datasets.CIFAR10(root='./data', train=False, download=False, transform=transform)
testloader = torch.utils.data.DataLoader(test_data, batch_size=BATCH_SIZE, shuffle=False, num_workers=2)

# Transform tool
img_eraser = CustomizeMask(value=-1)


In [9]:
def delta_debugging_general_new(img_eraser, model, threshold, img, label):

    c, h, w = img.shape

    current_unmasked_pixels = torch.ones((c, h, w))
    
    keep_going = True

    current_img = img.clone()
    current_pred = None
    current_pred_prob = 1

    prev_img = img.clone()
    prev_pred = None
    prev_pred_prob = 1

    img_counter = 0
    img_names = []
    smaller_group_number = None  
    
    if not os.path.exists(TMP_IMG_FOLDER):
        os.mkdir(TMP_IMG_FOLDER)
        

    # Start with binary search
    smaller_group_number = 2
    
    while keep_going:
        prev_img, prev_pred, prev_pred_prob = current_img, current_pred, current_pred_prob

        # save figures for gif

        figure = plt.figure()
        plt.title("Pred prob: {}".format(prev_pred_prob))
        plt.imshow(cifar10_img(prev_img))
        img_name = '{}.png'.format(img_counter)
        plt.savefig(os.path.join(TMP_IMG_FOLDER, img_name))
        plt.close(figure)
        img_names.append(img_name)
        img_counter += 1

        pixel_left = int(torch.sum(current_unmasked_pixels).item())
        if smaller_group_number > pixel_left:
            smaller_group_number = pixel_left

        try:
            all_new_masks = split_in_smaller_groups(current_unmasked_pixels, smaller_group_number)
        except:
            keep_going = False
            continue

        best_group_num = None
        best_pred_prob = 0
        for m_idx in range(smaller_group_number):
            new_unmask = all_new_masks[m_idx]
            new_img = img_eraser(img, new_unmask)
            new_extend_img = new_img[None, :]
            with torch.no_grad():
                output = model(new_extend_img)
            new_pred_prob = torch.max(torch.softmax(output, dim=1)).item()
            new_pred = torch.argmax(output).item()
            if new_pred_prob > threshold and new_pred == label and new_pred_prob > best_pred_prob:
                current_unmasked_pixels = new_unmask
                current_img, current_pred, current_pred_prob = new_img, new_pred, new_pred_prob
                best_group_num = m_idx

        if best_group_num is None:
            if smaller_group_number == pixel_left:
                keep_going = False
            else:
                smaller_group_number += 1
                
    print("in the end:", smaller_group_number)
    generate_gif(GIF_RESULT.format(CLASSES[label], img_idx), TMP_IMG_FOLDER, img_names, True)
    
    return prev_img , prev_pred, prev_pred_prob

In [10]:
# Create the folders for saving the results

# os.mkdir("./result_samples_cf10/")
# for i in range(10):
#     os.mkdir("./result_samples_cf10/{}/".format(CLASSES[i]))
#     os.mkdir("./result_samples_cf10/{}/{}".format(CLASSES[i],'gif_results'))
#     os.mkdir("./result_samples_cf10/{}/{}".format(CLASSES[i],'img_results'))

In [None]:

for img_idx in range(0, 10):
    (img, label) = torch.utils.data.Subset(test_data, [img_idx])[0]
    pred_img, pred_class, pred_prob = delta_debugging_general_new(img_eraser, model, THRESHOLD, img, label)


    figure = plt.figure()
    f, axarr = plt.subplots(1,2)
    axarr[0].imshow(cifar10_img(img))
    axarr[1].imshow(cifar10_img(pred_img))
    plt.title("Final pred prob: {}".format(pred_prob))
    plt.savefig(IMG_RESULT.format(CLASSES[label], img_idx))
    plt.close(figure)


in the end: 37
in the end: 128
in the end: 33
in the end: 12
in the end: 22
in the end: 20
in the end: 18
in the end: 36
