In [1]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import os
import pickle
from pathlib import Path

import torch
from torchvision import transforms
import torchvision.transforms.functional as TF
import ipywidgets as widgets
from PIL import Image as pilImage
from fastai.vision.all import *

import sys
sys.path.append("../") 

from src.utils import *
from src.gradcam import GuidedGradCam
from src.augmentationImpactAnalyzer import AugmentationImpactAnalyzer 

# Project layout, expressed relative to this notebook's directory.
ROOT_DIR = Path('../')
DATA_PATH, IMGS_PATH, MODEL_API_PATH = (
    ROOT_DIR / 'data/',
    ROOT_DIR / 'imgs/',
    ROOT_DIR / 'model_api',
)

In [2]:
# Seed the RNGs used by this notebook so runs are reproducible.
np.random.seed(0)
torch.manual_seed(0)

# Use the first GPU when one is available, otherwise fall back to the CPU.
use_cuda = torch.cuda.is_available()
device = torch.device("cuda:0") if use_cuda else torch.device("cpu")
device

device(type='cpu')

In [3]:

# One-time setup (kept for reference): download IMAGENETTE_160 and pickle the class labels loaded below.
#path = untar_data(URLs.IMAGENETTE_160,dest=DATA_PATH)

# load the classes to get the right labels
#import urllib.request, json 
#with urllib.request.urlopen('https://s3.amazonaws.com/deep-learning-models/image-models/imagenet_class_index.json') as url:
#    classes = json.loads(url.read().decode())
#f_to_idx = {val[0]:idx for idx,val in classes.items()}
#my_classes = [f_to_idx[f.name] for f in (path/"train").ls()]

#def save_pickle(item, path):
#    with open(path, 'wb') as handle:
#        pickle.dump(item, handle)
#
#save_pickle(my_classes, MODEL_API_PATH/'IMAGENETTE_160_classes')
#save_pickle(classes, MODEL_API_PATH/'IMAGENETTE_classes')

def load_pickle(path):
    """Unpickle and return the object stored in the file at ``path``.

    Only use this on trusted files: ``pickle.load`` can execute arbitrary
    code embedded in the data.
    """
    with open(path, 'rb') as fh:
        return pickle.load(fh)

# `classes`: ImageNet class-index mapping; `my_classes`: the indices present in
# IMAGENETTE_160 (both produced by the commented-out setup code in this cell).
classes, my_classes = (
    load_pickle(MODEL_API_PATH / pickle_name)
    for pickle_name in ('IMAGENETTE_classes', 'IMAGENETTE_160_classes')
)

In [4]:
# ImageNet-pretrained ResNet-50 from the torchvision hub (served from the local cache when present).
model = torch.hub.load('pytorch/vision:v0.6.0', 'resnet50', pretrained=True)
# torchvision constructs models in training mode; switch to inference mode so
# BatchNorm uses its running statistics on single images. eval() is idempotent,
# so this is harmless if the helpers below also call it internally.
model.eval()
# Register (Guided) Grad-CAM hooks on the last bottleneck block of layer4.
ggc = GuidedGradCam(model, use_cuda, target_type='classification', layer_ids=['layer4.2'])



Using cache found in /Users/ap/.cache/torch/hub/pytorch_vision_v0.6.0


register hooks for:
layer4.2


In [5]:

# ImageNet normalisation statistics, matching the pretrained ResNet-50.
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
size = 160  # square working resolution for every image shown in the GUI
# Force RGB so a PNG with an alpha channel or palette still yields the 3
# channels `normalize` expects; uploaded images go through fastai's
# PILImage.create, which also opens them in RGB mode.
img = pilImage.open(IMGS_PATH/'example_input.png').convert('RGB').resize((size, size))
aia = AugmentationImpactAnalyzer(img,
                                 model,
                                 cuda=use_cuda,
                                 add_output_act=True,
                                 # only report the IMAGENETTE classes, keyed by ImageNet index
                                 restrict_classes={i: classes[i][1] for i in my_classes},
                                 normalize=normalize,
                                 guided_grad_cam=ggc)


In [6]:
## ---------- test image transformation class ---------- 
#aia.tfms(crop_size=140,highlight_act='gradcam',perspective_w=180,perspective_h=180,perspective_d=20,erase_w=10,erase_h=10,rotate_ang=90)

In [7]:
# `.shape` of a 2-D image is (rows, cols) = (height, width); fastai patches PIL
# images with the same convention. Assumes aia.img.shape has exactly two
# entries, as it did when this notebook ran — TODO confirm in AugmentationImpactAnalyzer.
height, width = aia.img.shape

import ipywidgets as widgets
def create_gif_on_click(change):
    """Button callback: have the analyzer write a GIF to IMGS_PATH/'results.gif'.

    `change` is the Button instance passed by ipywidgets' on_click; it is unused.
    """
    # Create the output directory from the same IMGS_PATH the GIF is written to,
    # so the directory and the file path cannot drift apart.
    os.makedirs(IMGS_PATH, exist_ok=True)
    aia.create_gif(IMGS_PATH/'results.gif')

# Button wired to create_gif_on_click.
btn_create_gif = widgets.Button(description='Create gif')
btn_create_gif.on_click(create_gif_on_click)

btn_act_loc = widgets.RadioButtons(
    options=['none','gradcam','guided-gradient','guided-gradcam',],
    description='Activation Localization:',
    disabled=False)

# Brightness factor (1.0 = unchanged). min is a multiple of step, so the slider
# grid 0.2, 0.4, ..., 4.0 contains both the neutral value and the max.
sl_brightness = widgets.FloatSlider(value=1, min=0.2, max=4, step=0.2)
box_brightness = widgets.VBox([widgets.HTML('<em>Brightness</em>'), sl_brightness])

# Center-crop size. The grid is anchored at `width` (no crop), so that value is
# still reachable after the slider moves; min is the smallest grid point >= 33.
sl_crop_size = widgets.IntSlider(value=width, min=width - (width - 33) // 5 * 5,
                                 max=width, step=5)
box_crop = widgets.VBox([widgets.HTML('<em>Center Crop</em>'), sl_crop_size])

sl_rotate = widgets.IntSlider(value=0, min=0, max=360, step=5)
box_rotate = widgets.VBox([widgets.HTML('<em>Rotation Angle</em>'), sl_rotate])
sl_perspective_w = widgets.IntSlider(value=width, min=0, max=width, step=5)
sl_perspective_h = widgets.IntSlider(value=height, min=0, max=height, step=5)
sl_perspective_d = widgets.IntSlider(value=0, min=0, max=height, step=5)
box_perspective = widgets.VBox([widgets.HTML('<em>Perspective Distortion</em>'),
                                sl_perspective_w, sl_perspective_h, sl_perspective_d])

sl_erase_i = widgets.IntSlider(value=0, min=0, max=width, step=5)
sl_erase_j = widgets.IntSlider(value=0, min=0, max=height, step=5)
# Erase-box width is bounded by the image width, like sl_perspective_w.
sl_erase_w = widgets.IntSlider(value=0, min=0, max=width, step=5)
sl_erase_h = widgets.IntSlider(value=0, min=0, max=height, step=5)
box_erase = widgets.VBox([widgets.HTML('<em>Erase Box</em>'),
                          sl_erase_i, sl_erase_j, sl_erase_w, sl_erase_h])

# Restrict the file picker to images; uploads are decoded with PILImage.create.
btn_upload = widgets.FileUpload(description='Your Image', accept='image/*')

# Widget -> keyword-argument mapping consumed by interactive_output(aia.tfms, ...);
# the keys must match aia.tfms' parameter names.
tfm_args_sl = { "brightness":sl_brightness,
                "crop_size":sl_crop_size,
                "activation_localization":btn_act_loc,
                "rotate_ang":sl_rotate,
                "perspective_w":sl_perspective_w,
                "perspective_h":sl_perspective_h,
                "perspective_d":sl_perspective_d,
                "erase_i":sl_erase_i,
                "erase_j":sl_erase_j,
                "erase_w":sl_erase_w,
                "erase_h":sl_erase_h}
#tfm_args = {k:v.value for k,v in tfm_args_sl.items()}

In [8]:
#aia.reset(PILImage.create((path/"train/n02102040").ls()[32]).resize((size,size)))  
# Start from the example image and bind every control to aia.tfms;
# interactive_output re-runs it whenever a widget in tfm_args_sl changes.
aia.reset(img)
out_tfms = widgets.interactive_output(aia.tfms, tfm_args_sl)
gui = widgets.HBox([
    widgets.VBox([
        btn_upload,
        box_brightness,
        box_crop,
        box_rotate,
        box_perspective,
        box_erase,
        btn_act_loc,
        btn_create_gif,
    ]),
    out_tfms,
])

def on_upload_change(change):
    """FileUpload observer: load the newest upload into `aia` and force a redraw.

    The upload is resized to (size, size), like the example image, so the
    slider ranges derived from `width`/`height` stay valid for it.
    `change` is the traitlets change notification; it is unused.
    """
    # btn_upload.data holds the raw bytes of each upload (ipywidgets 7 FileUpload API);
    # take the most recent one.
    aia.reset(PILImage.create(btn_upload.data[-1]).resize((size, size)))
    # interactive_output only re-renders when a bound widget value changes, and
    # calling aia.tfms(**tfm_args) directly does not refresh the GUI output, so
    # nudge a widget: reset the localization mode if set, otherwise move the crop slider.
    if btn_act_loc.value != 'none':
        btn_act_loc.value = 'none'
    elif sl_crop_size.value != width:
        sl_crop_size.value = width
    else:
        sl_crop_size.value = sl_crop_size.value - sl_crop_size.step

# Run the upload handler on every new upload. `_counter` is a private
# FileUpload trait that presumably increments per upload, so re-uploading the
# same file still fires (ipywidgets 7) — verify if upgrading ipywidgets.
btn_upload.observe(handler=on_upload_change, names=['_counter'])
display(gui)

HBox(children=(VBox(children=(FileUpload(value={}, description='Your Image'), VBox(children=(HTML(value='<em>B…

[Open GIF](../imgs/results.gif "results.gif written by the 'Create gif' button")
