In [11]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import os
from fastai.vision.all import *
from torchvision import transforms
import torchvision.transforms.functional as TF
import ipywidgets as widgets
from PIL import Image as pilImage
from IPython.core.display import display, HTML
display(HTML("<style>.container { width:95% !important; }</style>"))

import sys
sys.path.append("../") 
%load_ext autoreload
%autoreload 2
from src.utils import *
from src.gradcam import GuidedGradCam
from src.augmentationImpactAnalyzer import AugmentationImpactAnalyzer 

# Repository-relative locations. They are resolved against the notebook's
# working directory, assumed to be one level below the repository root
# (consistent with the `sys.path.append("../")` import setup).
ROOT_DIR = Path('..')
DATA_PATH = ROOT_DIR / 'data'
IMGS_PATH = ROOT_DIR / 'imgs'

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [12]:
# Seed the global NumPy and PyTorch RNGs so stochastic steps are repeatable.
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED)

# Use the first CUDA device when one is available, otherwise the CPU.
use_cuda = torch.cuda.is_available()
device = torch.device('cuda:0') if use_cuda else torch.device('cpu')
device

device(type='cpu')

In [13]:
# Download and extract Imagenette-160 into DATA_PATH (skipped when the
# extracted dataset is already present).
path = untar_data(URLs.IMAGENETTE_160, dest=DATA_PATH)

# Pick an example image from one class folder. `ls()` yields directory
# entries in filesystem order, which differs between machines and file
# systems, so sort first to make the chosen example reproducible.
fname = sorted((path/"train/n02102040").ls())[1]

In [14]:
# Load the ImageNet class index so Imagenette folder names can be mapped onto
# the pretrained model's output indices. The first element of each entry is
# the WordNet id that Imagenette uses as its folder name.
import json
import urllib.request

CLASS_INDEX_URL = 'https://s3.amazonaws.com/deep-learning-models/image-models/imagenet_class_index.json'
# A timeout prevents the cell from hanging forever on a stalled connection.
with urllib.request.urlopen(CLASS_INDEX_URL, timeout=30) as response:
    classes = json.load(response)

# Folder name (WordNet id) -> class index. JSON object keys are strings, so
# the indices stored in `my_classes` are strings as well.
f_to_idx = {entry[0]: idx for idx, entry in classes.items()}
my_classes = [f_to_idx[folder.name] for folder in (path/"train").ls()]


In [15]:
# Pretrained ResNet-50 from the torchvision v0.6.0 hub entry point.
model = torch.hub.load('pytorch/vision:v0.6.0', 'resnet50', pretrained=True)
# nn.Modules are constructed in training mode. Switch to inference mode so
# BatchNorm uses its running statistics (and dropout, if any, is disabled)
# when scoring single images. Harmless if the analyzer also calls eval().
model.eval()
# Guided Grad-CAM with hooks on the module named 'layer4.2'.
ggc = GuidedGradCam(model, use_cuda, target_type='classification', layer_ids=['layer4.2'])



Using cache found in /Users/ap/.cache/torch/hub/pytorch_vision_v0.6.0


register hooks for:
layer4.2


In [16]:
## ---------- test gradcam ---------- 
#fname = (path/"train/n02102040").ls()[1]
#img = PILImage.create(fname)#.resize((160,160))
#normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
#                                 std=[0.229, 0.224, 0.225])
#x = normalize(transforms.ToTensor()(img)).unsqueeze(0)
#heatmap, gb, cam_gb = ggc(x.requires_grad_(True))
#
#fig, ax = plt.subplots(1,4,figsize=(18,6))
#ax[0].imshow(img)
#cam = np.float32(img)+np.float32(heatmap)
#ax[1].imshow(arr_to_img(cam))
#ax[2].imshow(gb)
#ax[3].imshow(cam_gb)
#plt.show()

In [17]:

# Normalise inputs with the standard ImageNet per-channel mean / std.
imagenet_mean = [0.485, 0.456, 0.406]
imagenet_std = [0.229, 0.224, 0.225]
normalize = transforms.Normalize(mean=imagenet_mean, std=imagenet_std)

# Work on a square size x size copy of the example image.
size = 160
img = PILImage.create(fname).resize((size, size))

# Passed as `restrict_classes`: each Imagenette class's ImageNet index mapped
# to the second field of its class-index entry (presumably the readable label).
restricted = {class_idx: classes[class_idx][1] for class_idx in my_classes}

aia = AugmentationImpactAnalyzer(
    img,
    model,
    cuda=use_cuda,
    add_output_act=True,
    restrict_classes=restricted,
    normalize=normalize,
    guided_grad_cam=ggc,
)


In [18]:
## ---------- test image transformation class ---------- 
#aia.tfms(crop_size=140,highlight_act='gradcam',perspective_w=180,perspective_h=180,perspective_d=20,erase_w=10,erase_h=10,rotate_ang=90)

In [19]:
# Image dimensions used to scale the slider ranges below.
# NOTE(review): assumes `aia.img` is a PIL image. fastai patches
# `Image.shape` to return (height, width), the reverse of PIL's `.size`,
# so unpack in that order. The former `width, height = ...` order only
# worked because the example image is square.
height, width = aia.img.shape
def create_gif_on_click(change):
    """Button callback: have the analyzer write a GIF to IMGS_PATH/results.gif.

    Args:
        change: the clicked Button (ipywidgets passes it to on_click
            handlers); unused.
    """
    gif_path = IMGS_PATH/'results.gif'
    # Create the output folder from the same IMGS_PATH constant used for the
    # file itself, instead of a separately hard-coded '../imgs/' string.
    gif_path.parent.mkdir(parents=True, exist_ok=True)
    aia.create_gif(gif_path)

# Button that asks the analyzer to write a GIF.
btn_create_gif = widgets.Button(description='Create gif')
btn_create_gif.on_click(create_gif_on_click)

# Which activation-localization overlay (if any) to use.
btn_act_loc = widgets.RadioButtons(
    options=['none', 'gradcam', 'guided-gradient', 'guided-gradcam'],
    description='Activation Localization:',
    disabled=False)

# The slider front end snaps to min + k*step, so each min/step pair below is
# chosen so the neutral/initial value is still reachable after the slider has
# been moved.

# Brightness factor. With min=0.1 and step=0.2, neither 1.0 (unchanged)
# nor 4.0 lay on the grid; step=0.1 makes both reachable.
sl_brightness = widgets.FloatSlider(value=1, min=0.1, max=4, step=0.1)
box_brightness = widgets.VBox([widgets.HTML('<em>Brightness</em>'), sl_brightness])

# Center-crop size. The original 33 px lower bound is rounded up onto the
# step grid that ends at `width`, so the full (uncropped) size stays selectable.
CROP_STEP = 5
crop_min = 33 + (width - 33) % CROP_STEP
sl_crop_size = widgets.IntSlider(value=width, min=crop_min, max=width, step=CROP_STEP)
box_crop = widgets.VBox([widgets.HTML('<em>Center Crop</em>'), sl_crop_size])

sl_rotate = widgets.IntSlider(value=0, min=0, max=360, step=5)
box_rotate = widgets.VBox([widgets.HTML('<em>Rotation Angle</em>'), sl_rotate])

sl_perspective_w = widgets.IntSlider(value=width, min=0, max=width, step=5)
sl_perspective_h = widgets.IntSlider(value=height, min=0, max=height, step=5)
sl_perspective_d = widgets.IntSlider(value=0, min=0, max=height, step=5)
box_perspective = widgets.VBox([widgets.HTML('<em>Perspective Distortion</em>'),
                                sl_perspective_w, sl_perspective_h, sl_perspective_d])

sl_erase_i = widgets.IntSlider(value=0, min=0, max=width, step=5)
sl_erase_j = widgets.IntSlider(value=0, min=0, max=height, step=5)
# The erase-box width is bounded by the image width, matching how the
# perspective width slider is bounded (it was previously bounded by height).
sl_erase_w = widgets.IntSlider(value=0, min=0, max=width, step=5)
sl_erase_h = widgets.IntSlider(value=0, min=0, max=height, step=5)
box_erase = widgets.VBox([widgets.HTML('<em>Erase Box</em>'),
                          sl_erase_i, sl_erase_j, sl_erase_w, sl_erase_h])

btn_upload = widgets.FileUpload(description='Your Image')

# Keyword argument of `aia.tfms` -> the widget that supplies its value.
tfm_args_sl = {"brightness": sl_brightness,
               "crop_size": sl_crop_size,
               "activation_localization": btn_act_loc,
               "rotate_ang": sl_rotate,
               "perspective_w": sl_perspective_w,
               "perspective_h": sl_perspective_h,
               "perspective_d": sl_perspective_d,
               "erase_i": sl_erase_i,
               "erase_j": sl_erase_j,
               "erase_w": sl_erase_w,
               "erase_h": sl_erase_h}
#tfm_args = {k:v.value for k,v in tfm_args_sl.items()}

In [20]:
# Start from the example image and bind the controls to `aia.tfms`:
# interactive_output re-invokes it with the current widget values whenever
# one of them changes.
aia.reset(img)
out_tfms = widgets.interactive_output(aia.tfms, tfm_args_sl)

controls = widgets.VBox([
    btn_upload,
    box_brightness,
    box_crop,
    box_rotate,
    box_perspective,
    box_erase,
    btn_act_loc,
    btn_create_gif,
])
gui = widgets.HBox([controls, out_tfms])

def on_upload_change(change):
    """Load the most recently uploaded file into the analyzer and force a redraw.

    Registered as an observer of the FileUpload widget's `_counter` trait.

    Args:
        change: traitlets change notification; unused.
    """
    # NOTE(review): `FileUpload.data` / `_counter` exist in ipywidgets 7 but
    # were removed in ipywidgets 8. Confirm the pinned version.
    if not btn_upload.data:
        return
    # Resize to the same size x size frame as the initial example image. The
    # slider ranges (crop, perspective, erase) and the crop reset value below
    # were derived from that image, so an un-resized upload would not match them.
    aia.reset(PILImage.create(btn_upload.data[-1]).resize((size, size)))
    # The interactive output only re-renders when a control value changes, so
    # nudge one control to force a redraw with the new image (quick hack).
    if btn_act_loc.value != 'none':
        btn_act_loc.value = 'none'
    elif sl_crop_size.value != width:
        sl_crop_size.value = width
    else:
        sl_crop_size.value = sl_crop_size.value - sl_crop_size.step

# Run the upload handler whenever the upload widget's `_counter` trait
# changes, then render the assembled GUI.
btn_upload.observe(on_upload_change, names=['_counter'])
display(gui)

HBox(children=(VBox(children=(FileUpload(value={}, description='Your Image'), VBox(children=(HTML(value='<em>B…

[Open Gif](../imgs/results.gif "segment")


In [38]:
# To save the images as separate files:
#os.makedirs('../imgs/img_list/',exist_ok=True)
#os.makedirs('../imgs/results/',exist_ok=True)
#
#for i,pic in enumerate(imgs):
#    pic.save(fp=f'../imgs/img_list/img_{str(i).zfill(2)}.png', format='PNG')

# To use ffmpeg, deactivate conda and run the commands below in a terminal.
# In my tests the results were not noticeably better, so this is left disabled for now.
#ffmpeg -f image2 -i imgs/img_list/img_%02d.png -vf scale=2480:-1:sws_dither=ed,palettegen imgs/results/palette.png -y
#ffmpeg -f image2 -framerate 10. -i imgs/img_list/img_%02d.png imgs/results/img.flv -y
#ffmpeg -i imgs/results/img.flv -i imgs/results/palette.png -filter_complex "fps=10,scale=248:-1:flags=lanczos[x];[x][1:v]paletteuse" imgs/results/test.gif -y
#ffmpeg -i imgs/results/img.flv -i imgs/results/palette.png -filter_complex "fps=10,scale=248:-1:flags=lanczos[x];[x][1:v]paletteuse" -loop -1 imgs/results/test_no_loop.gif -y
