In [None]:
# basic settings: execution-environment switches read by the Pre-Setting cell
running_remote = False  # master switch: platform-specific setup runs only when True
use_colab = False       # Colab: mount Drive, cd into MyDrive/XAI, fetch data + checkpoint
use_kaggle = False      # Kaggle: cd into /kaggle/input/XAT

In [None]:
# Pre-Setting
import os


if running_remote:
    if use_colab:
        from google.drive import drive
        drive.mount('/content/drive')
        
        target_path = "/content/drive/MyDrive/XAI"
        if os.path.exists(target_path):
            %cd /content/drive/MyDrive/XAI
        else:
            os.makedirs(target_path)
            %cd /content/drive/MyDrive/XAI   
        !pwd
        !ls

        # download dataset.zip
        !gdown --id '1cYBWwYab3djiaYuOU6CxkYHQyUYws4Ce' --output food.zip
        !unzip food.zip
        !pip install lime=0.1.1.37
        # download pre-trained Model
        !gdown --id '1CShZHsO8oAZwxQkMe7jRtEgSNb2w_OZu' --output checkpoint.pth
        !ls
    if use_kaggle:
        %cd /kaggle/input/XAT
        !ls 

In [None]:
import torch 
import numpy as np
import os, sys, argparse
import matplotlib.pyplot as plt 
import torch.nn as nn
import torch.functional as F
import torchvision.transforms as transforms
from PIL import Image
from torch.utils.data import Dataset
from skimage.segmentation import slic
from lime import lime_image
from pdb import  set_trace
from torch.autograd import Variable

In [None]:
# hyper-args, exposed with attribute access (args.ckptpath, args.dataset_dir)
args = argparse.Namespace(
    ckptpath='./checkpoint.pth',
    dataset_dir='./food/',
)

# When True, the model cell restores weights from args.ckptpath.
# NOTE(review): with False the classifier keeps its random initialization.
renew = False

# Prefer the GPU when one is visible to torch, otherwise run on CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

In [None]:
# Define Model

class Classifier(nn.Module):
    """CNN classifier: five conv stages followed by a two-layer MLP head.

    Each stage ends in a 2x2 max-pool, so the spatial size is halved five
    times; the head expects a 1024 x 4 x 4 feature map, so inputs are
    presumably 3 x 128 x 128 images (matching the dataset's Resize(128, 128)).
    Returns 11 raw (un-normalized) class scores per sample.
    """

    def __init__(self):
        super().__init__()

        def conv_bn_relu(channels_in, channels_out):
            # 3x3 conv, stride 1, padding 1 -> spatial size unchanged
            return [
                nn.Conv2d(channels_in, channels_out, 3, 1, 1),
                nn.BatchNorm2d(channels_out),
                nn.ReLU(inplace=True),
            ]

        def stage(channels_in, channels_out, depth):
            # `depth` conv blocks (only the first changes the width), then a 2x2 pool
            modules = []
            for k in range(depth):
                width_in = channels_in if k == 0 else channels_out
                modules.extend(conv_bn_relu(width_in, channels_out))
            modules.append(nn.MaxPool2d(2, 2, 0))
            return modules

        # (in_channels, out_channels, conv blocks) per stage; order fixes state_dict keys
        stage_specs = [
            (3, 128, 3),
            (128, 256, 3),
            (256, 512, 3),
            (512, 1024, 1),
            (1024, 1024, 1),
        ]
        self.cnn = nn.Sequential(*[m for spec in stage_specs for m in stage(*spec)])

        self.fcn = nn.Sequential(
            nn.Linear(1024 * 4 * 4, 2048),
            nn.BatchNorm1d(2048),
            nn.LeakyReLU(inplace=True),
            nn.Dropout(p=0.3),
            nn.Linear(2048, 11),
        )

    def forward(self, x):
        features = self.cnn(x)
        # flatten everything except the batch dimension for the linear head
        flat = features.reshape(features.size(0), -1)
        return self.fcn(flat)
    
# Instantiate the classifier on the selected device and, when `renew` is set,
# restore the weights stored under 'model_state_dict' in args.ckptpath.
model = Classifier().to(device)
if renew:
    # map_location remaps tensors that were saved on a GPU so the checkpoint
    # also loads on a CPU-only runtime (device == "cpu")
    ckpt = torch.load(args.ckptpath, map_location=device)
    model.load_state_dict(ckpt['model_state_dict'])


In [None]:
# Define Dataset and Dataloader

class FoodDataset(Dataset):
    """Dataset of image files paired with integer class labels.

    Args:
        paths: image file paths, indexed in parallel with `labels`.
        labels: class label for each path.
        mode: 'train' selects the augmenting transform; any other value
            selects the deterministic resize-only transform.
    """

    def __init__(self, paths, labels, mode):
        self.paths = paths
        self.labels = labels

        train_transform = transforms.Compose([
            transforms.Resize(size=(128, 128)),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.RandomRotation(15),
            transforms.ToTensor(),
        ])

        test_transform = transforms.Compose([
            transforms.Resize(size=(128, 128)),
            transforms.ToTensor(),
        ])

        self.transform = train_transform if mode == 'train' else test_transform

    def __getitem__(self, idx):
        """Return (transformed image, label) for sample `idx`."""
        # convert() forces the pixel data to load, so the file can be closed
        # right away; 'RGB' guarantees 3 channels even for grayscale/RGBA files.
        with Image.open(self.paths[idx]) as raw:
            img = raw.convert('RGB')
        x = self.transform(img)
        y = self.labels[idx]
        return x, y

    def __len__(self):
        return len(self.paths)

    def getbatch(self, indices):
        """Load the samples at `indices` and collate them.

        Returns:
            (imgs, labels): images stacked along a new leading dimension and
            the labels as a 1-D tensor.
        """
        samples = [self[idx] for idx in indices]
        imgs = torch.stack([img for img, _ in samples])
        # labels are plain Python ints, which torch.stack rejects; build a tensor instead
        labels = torch.tensor([label for _, label in samples])
        return imgs, labels

def get_paths_labels(path):
    """List the images in `path` and derive each label from its file name.

    File names are expected to look like ``<label>_<index>.<ext>`` (e.g.
    ``3_120.jpg``). Entries are ordered numerically by (label, index), with
    the file name as a tie-breaker, so the order does not depend on
    ``os.listdir``.

    Returns:
        (imgpaths, labels): parallel lists of full paths and int labels.
    """
    def parse(name):
        # "3_120.jpg" -> (3, 120); numeric, so "0_9" sorts before "0_10"
        stem = os.path.splitext(name)[0]
        label, _, index = stem.partition('_')
        return int(label), int(index or 0)

    entries = sorted((parse(name), name) for name in os.listdir(path))
    imgpaths = [os.path.join(path, name) for _, name in entries]
    labels = [key[0] for key, _ in entries]
    return imgpaths, labels

# Index every image under args.dataset_dir and wrap it in a dataset that uses
# the resize-only (non-'train') transform, so samples are deterministic.
train_paths, train_labels = get_paths_labels(path=args.dataset_dir)
train_set = FoodDataset(paths=train_paths, labels=train_labels, mode='eval')

## Start XAI Homework!!

In [None]:
# Preview the first samples of the dataset together with their class labels.
img_indices = list(range(10))
images, labels = train_set.getbatch(img_indices)
# squeeze=False keeps a 2-D grid even for a single image; take its only row
fig, axs = plt.subplots(1, len(img_indices), figsize=(15, 8), squeeze=False)
axs = axs[0]
for ax, img, label in zip(axs, images, labels):
    ax.imshow(img.cpu().permute(1, 2, 0))  # CHW tensor -> HWC for matplotlib
    ax.set_title(f"label {int(label)}")
    ax.axis("off")
plt.show()

In [None]:
def predict(input):
    """Run the classifier on a batch of channels-last images.

    Intended as the classifier function handed to LIME.

    Args:
        input: array-like of shape [batches, height, width, channels].

    Returns:
        numpy array [batches, n_outputs] of raw model outputs (no softmax).
    """
    model.eval()
    # [B, H, W, C] -> [B, C, H, W], the layout nn.Conv2d expects
    batch = torch.as_tensor(input, dtype=torch.float32).permute(0, 3, 1, 2)
    # pure inference: skip autograd bookkeeping (LIME calls this many times)
    with torch.no_grad():
        output = model(batch.to(device))
    return output.cpu().numpy()
 
def segmentation(input, n_segments=200, compactness=1, sigma=1):
    """Split an image into superpixels with SLIC.

    Presumably passed to LIME as its segmentation function; the defaults
    reproduce the original hard-coded settings.

    Args:
        input: image array forwarded unchanged to ``slic``.
        n_segments: approximate number of superpixels to produce.
        compactness: colour-vs-space trade-off; higher values give more
            regular, square-ish segments.
        sigma: width of the Gaussian smoothing applied before segmenting.

    Returns:
        The label map returned by ``slic``.
    """
    return slic(input, n_segments=n_segments, compactness=compactness, sigma=sigma)
    