In [1]:
from model.EfficientNet import EfficientNet
from model.Mango import Mango
from model import Explainable
import torch
import torchvision.transforms as transforms
from torch.optim import Adam
import matplotlib.pyplot as plt
import math


# --- Paths -------------------------------------------------------------------
TRAIN_DIR = "./data/C1-P1_Train/"
TRAIN_CSV = "./data/train.csv"
DEV_DIR = "./data/C1-P1_Dev/"
DEV_CSV = "./data/dev.csv"

# Class letter -> integer label, passed to the Mango dataset.
Mango_Class = {'A': 0, 'B': 1, 'C': 2}

# --- Reproducibility -----------------------------------------------------------
# The training transforms (random flips / crop / erasing) are stochastic; seeding torch
# makes a fresh "Restart & Run All" reproduce the same random draws.
SEED = 0
torch.manual_seed(SEED)

# --- Hyper parameters ----------------------------------------------------------
# DEPTH and WIDTH are passed to EfficientNet(...) and also select the checkpoint file
# (Efficient_<DEPTH><WIDTH>.pkl), so they must match the saved weights.
DEPTH = 2
WIDTH = 1.5
RESOLUTION = 0.25
BS_PER_GPU = 5
NUM_CHANNELS = 3
# Derived from the class map so the classifier head size can never disagree with it.
NUM_CLASSES = len(Mango_Class)
# Square input side length; with RESOLUTION = 0.25 this is 56 pixels.
IMG_SIZE = int(224 * RESOLUTION)

# DataLoader

In [2]:
# Training-time augmentation: random flips, resize, a random crop to 2/3 of IMG_SIZE,
# then random erasing (placed after ToTensor because RandomErasing expects a tensor image).
# NOTE(review): the crop makes training inputs int(IMG_SIZE*2/3) px while the test inputs
# below stay IMG_SIZE px — confirm the model is meant to accept both sizes.
train_transform = transforms.Compose([
    transforms.RandomHorizontalFlip(0.5),
    transforms.RandomVerticalFlip(0.5),
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.RandomCrop(int(IMG_SIZE*2/3)),
    transforms.ToTensor(),
    transforms.RandomErasing(),
])
# Evaluation: deterministic resize only, so every run sees identical inputs.
test_transform = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.ToTensor(),
])

trainset = Mango(TRAIN_CSV, TRAIN_DIR, Mango_Class, train_transform)
testset = Mango(DEV_CSV, DEV_DIR, Mango_Class, test_transform)

NUM_WORKERS = 6
# Shuffle the training data so batches do not replay the CSV order every epoch.
trainloader = torch.utils.data.DataLoader(trainset, batch_size=BS_PER_GPU, shuffle=True, num_workers=NUM_WORKERS)
# Keep the dev loader unshuffled: the explanation cells take its first batch and should
# always inspect the same images.
testloader = torch.utils.data.DataLoader(testset, batch_size=BS_PER_GPU, shuffle=False, num_workers=NUM_WORKERS)

In [None]:
model = EfficientNet(DEPTH, WIDTH, NUM_CHANNELS, IMG_SIZE, dropout=0.2, classes=NUM_CLASSES)
# Inference mode for the explanation cells below: disables dropout (p=0.2 above) and makes
# any BatchNorm layers use their running statistics, so repeated explanations of the same
# batch agree and one image's map does not depend on the other images in its batch.
model.eval()

# Same string as './model/weights/Efficient_' + str(DEPTH) + str(WIDTH) + '.pkl',
# e.g. DEPTH=2, WIDTH=1.5 -> ./model/weights/Efficient_21.5.pkl
weights_path = f'./model/weights/Efficient_{DEPTH}{WIDTH}.pkl'
# The model is never moved to a GPU in this notebook, so map tensors onto the CPU;
# this also lets a GPU-saved checkpoint load on a CUDA-less machine.
# NOTE(review): torch.load unpickles the file — only load trusted checkpoints.
model.load_state_dict(torch.load(weights_path, map_location='cpu'))

# Saliency

In [None]:
data = next(iter(testloader))
imgs, labels = data['data'], data['label']
saliencies = Explainable.get_saliency(imgs, labels, model)

# Map integer labels back to their class letters for panel titles (otherwise matplotlib
# renders e.g. "tensor(0)"); a label not found in the map is shown unchanged.
idx_to_class = {idx: name for name, idx in Mango_Class.items()}

# Show up to 5 images, but never more than the batch actually holds.
plt_num = min(5, len(imgs))
# squeeze=False keeps `axes` 2-D even when plt_num == 1.
fig, axes = plt.subplots(2, plt_num, figsize=(15, 8), squeeze=False)
for i in range(plt_num):
    label = labels[i].item() if torch.is_tensor(labels[i]) else labels[i]
    axes[0, i].set_title(idx_to_class.get(label, label))
    # Top row: the input image, CHW tensor -> HWC array for imshow.
    axes[0, i].imshow(imgs[i].permute(1, 2, 0).detach().numpy())
    # Bottom row: the saliency map returned for the same image.
    axes[1, i].imshow(saliencies[i])
plt.show()

# Fisher Sensitivity

In [None]:
data = next(iter(testloader))
imgs, labels = data['data'], data['label']
# NOTE(review): the positional arguments 2 and 0 are passed straight through; their meaning
# is defined by Explainable.fisher_sensitivity — consider passing them by keyword.
fisher = Explainable.fisher_sensitivity(imgs, labels, model, 2, 0)

# Map integer labels back to their class letters for panel titles (otherwise matplotlib
# renders e.g. "tensor(0)"); a label not found in the map is shown unchanged.
idx_to_class = {idx: name for name, idx in Mango_Class.items()}

# Show up to 5 images, but never more than the batch actually holds.
plt_num = min(5, len(imgs))
# squeeze=False keeps `axes` 2-D even when plt_num == 1.
fig, axes = plt.subplots(2, plt_num, figsize=(15, 8), squeeze=False)
for i in range(plt_num):
    label = labels[i].item() if torch.is_tensor(labels[i]) else labels[i]
    axes[0, i].set_title(idx_to_class.get(label, label))
    # Top row: the input image, CHW tensor -> HWC array for imshow.
    axes[0, i].imshow(imgs[i].permute(1, 2, 0).detach().numpy())
    # Bottom row: the Fisher sensitivity map returned for the same image.
    axes[1, i].imshow(fisher[i])
plt.show()

# Filter Explanation

In [None]:
img_num = 0  # index (within the batch) of the image whose per-filter activations are shown
plt_num = 5  # panels per row

data = next(iter(testloader))
imgs, labels = data['data'], data['label']
filter_activations, filter_visualization = Explainable.filter_explaination(imgs, model, model.stage1)

# filter_activations[img_num][k] is indexed image-first, filter-second, so shape[1]
# is the number of filters to display.
assert 0 <= img_num < filter_activations.shape[0], "img_num is outside the batch"
n_filters = filter_activations.shape[1]
# ceil so every filter gets a panel; the last row may be partially empty.
filter_rows = math.ceil(n_filters / plt_num)
n_rows = filter_rows + 2

# Map integer labels back to their class letters for panel titles; a label not found
# in the map is shown unchanged.
idx_to_class = {idx: name for name, idx in Mango_Class.items()}

# Height scales with the number of rows (about square panels at 15 / 5 = 3 inches wide);
# squeeze=False keeps `axes` 2-D for any row/column count.
fig, axes = plt.subplots(n_rows, plt_num, figsize=(15, 3 * n_rows), squeeze=False)

# Rows 0-1: input images and their filter visualizations (never more than the batch holds).
for i in range(min(plt_num, len(imgs))):
    label = labels[i].item() if torch.is_tensor(labels[i]) else labels[i]
    axes[0, i].set_title(idx_to_class.get(label, label))
    axes[0, i].imshow(imgs[i].permute(1, 2, 0).detach().numpy())
    axes[1, i].imshow(filter_visualization[i])

# Remaining rows: one panel per filter activation of image `img_num`.
for row in range(filter_rows):
    for col in range(plt_num):
        ax = axes[row + 2, col]
        k = row * plt_num + col
        if k < n_filters:
            ax.imshow(filter_activations[img_num][k])
        else:
            ax.axis('off')  # past the last filter: leave the panel blank
plt.show()