# Download datasets and models
(skip if you have already done this)

By default, this downloads all datasets and all models.
You can restrict what is downloaded by passing the --datasets flag to download_datasets.py, and the --datasets and --seeds flags to download_models.py.
Files are downloaded to ./datasets and ./models by default.
To change the directories, pass the --dest flag.

In [4]:
# Optional flags restrict what is downloaded and where to, e.g.:
#   --datasets pokemon celeba stl-10 --dest ./datasets
!python download_datasets.py
# Optional flags for the model download, e.g.:
#   --datasets pokemon celeba stl-10 --seeds 0 1 2 --dest ./models
!python download_models.py 

Downloading 1PLWW05OeHCWAR6o91PI8OyCLbXpOYSrZ into ./datasets/pokemon/pokemon.zip... Done.
Unzipping...Done.
all done!
  from collections import Sequence
Downloading 1okRM6Lqu5XJL2sQFrmOO1KifW2K4d9x2 into ./models/pokemon-0.zip... Done.
Unzipping...Done.
all done!


# Imports

In [26]:
import os
from collections import defaultdict, OrderedDict
# Sequence lives in collections.abc; the old `collections.Sequence` alias was
# removed in Python 3.10.
from collections.abc import Sequence

import cv2
import ipywidgets as ipw
import matplotlib.pyplot as plt
import numpy as np
import torch as pt

from datasets import Representations

# Load dataset and models

In [31]:
# Which dataset and training seed to inspect, and where the downloads were
# stored (the download scripts default to ./datasets and ./models).
dataset_type, seed = 'pokemon', 0
model_dir, data_dir = 'models/', 'datasets/'

In [None]:
# Load every model trained on `dataset_type` with the chosen `seed`.
# models[size][channels] holds the model whose latent has spatial size `size`
# and `channels` channels. Both levels are filled in ascending order, and the
# reconstruction grid below lays out its rows/columns in that order.
sizes = [4, 8, 16] if dataset_type == 'pokemon' else [3, 6, 12]
input_size = 128 if dataset_type == 'pokemon' else 96
models = defaultdict(dict)
for size in sizes:
    # Channel counts giving latent/input element ratios of 1/64, 1/16, 1/4, 1
    # (input has 3 * input_size**2 elements, latent channels * size**2).
    channel_configs = [(3 * input_size ** 2) // (4 ** k * size ** 2) for k in reversed(range(4))]
    for channels in channel_configs:
        # File names are keyed by the dataset *name* (dataset_type).
        filename = f'model-{dataset_type}-{size}-{channels}-{seed}.pt'
        # NOTE(review): these files are used below as whole modules (called,
        # .float(), .parameters()). Newer PyTorch versions default
        # torch.load to weights_only=True, which may refuse such pickles;
        # only pass weights_only=False for files you trust.
        models[size][channels] = pt.load(os.path.join(model_dir, filename), map_location=pt.device('cpu'))

In [32]:
# Load the precomputed `codes_<mode>.pt` split for the selected dataset.
mode = 'test'
dataset = Representations(
    os.path.join(data_dir, dataset_type, f'codes_{mode}.pt'),
    fraction=1,
    shuffle=0,
)

# Visualize reconstructions

In [17]:
def plot_samples(i, cuda=False):
    """Show dataset sample ``i`` next to its reconstruction by every loaded model.

    The input image occupies the left of the figure. On the right is a grid
    with one row per latent size (the ``models`` keys, used as row labels)
    and one column per channel count. The first row's titles give the latent
    size as a percentage of the input size,
    ``channels * size**2 / (3 * W**2)``.

    Reads the notebook globals ``dataset`` and ``models``, shows the figure
    and returns None.

    Args:
        i: index into ``dataset``.
        cuda: run the models on the GPU instead of the CPU. The models are
            converted with ``Module.float()`` / ``.cuda()`` / ``.cpu()``,
            which act in place, so they stay on the chosen device.
    """
    figsize = [16, 9]
    subsize = 3
    grid_shape = figsize[::-1]  # subplot2grid expects (rows, cols)

    # Set up the figure: original on the left, a 3x4 grid of reconstructions.
    plt.figure(figsize=figsize)
    original = plt.subplot2grid(grid_shape, (3, 0), rowspan=3, colspan=3)
    axs = [
        [
            plt.subplot2grid(grid_shape, (subsize * j, 4 + subsize * k),
                             rowspan=subsize, colspan=subsize, xticks=[], yticks=[])
            for k in range(4)
        ]
        for j in range(3)
    ]

    sample = dataset[i][0].float()
    if cuda:
        sample = sample.cuda()
    im_size = sample.shape[-1]
    original.imshow(sample.permute(1, 2, 0).cpu().numpy())
    original.set_axis_off()

    # Inference only: no_grad avoids building autograd graphs. It also keeps
    # the predictions convertible to NumPy for imshow, because tensors that
    # require grad cannot be converted.
    with pt.no_grad():
        for row, (ax_row, (size, models_sub)) in enumerate(zip(axs, models.items())):
            for col, (a, (channels, model)) in enumerate(zip(ax_row, models_sub.items())):
                model = model.float().cuda() if cuda else model.float().cpu()
                pred = model(sample[None, ...])[0]
                pred = pred / pred.max()  # scale so the brightest value is 1
                a.imshow(pred.permute(1, 2, 0).cpu().numpy())
                if row == 0:
                    a.set_title(f'{round(100*(channels*size**2)/(3*im_size**2), 2)}%', fontsize=28)
                if col == 0:
                    a.set_ylabel(size, fontsize=28)
    plt.subplots_adjust(0, 0, 1, 1, 0.03, 0.03)
    plt.show()
    
# Slider bounds are inclusive, so the last valid index is len(dataset) - 1.
ipw.interact(plot_samples, i=(0, len(dataset) - 1))

interactive(children=(IntSlider(value=317, description='i', max=634), Checkbox(value=False, description='cuda'…

<function __main__.plot_samples(i, cuda=False)>

# Inspect feature maps

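Visualize the feature map produced by each layer.

Left image: the selected feature-map channel overlaid onto the input image (the input is averaged over its color channels and shown in grayscale).
Right image: the feature map alone (red = positive, blue = negative).

The type (class name) of the selected layer is printed above the plot.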

In [34]:
def normalize(image, positive=False, relu=False, cap=1.1):
    """Rescale an array into [-1, 1] for display.

    Values are first clipped to [-cap, cap]. If the largest value is above 1,
    the non-negative entries are divided by it. If the smallest value is
    below -1, the negative entries are divided by its magnitude. The two
    halves are scaled independently, so zero stays at zero.

    Args:
        image: NumPy array to rescale. The input is not modified, because
            ``np.clip`` returns a new array.
        positive: if True and the result contains negatives, map it to
            [0, 1] with ``(x + 1) / 2``. This takes precedence over ``relu``.
        relu: set negative entries to 0.
        cap: clipping bound applied before rescaling.

    Returns:
        The rescaled array.
    """
    out = np.clip(image, -cap, cap)

    peak = out.max()
    if peak > 1:
        nonneg = out >= 0
        out[nonneg] = out[nonneg] / peak

    trough = out.min()
    if trough < -1:
        neg = out < 0
        out[neg] = out[neg] / -trough

    if positive and out.min() < 0:
        return (out + 1) / 2
    if relu:
        out[out < 0] = 0
    return out

def collect(model, l):
    """Recursively attach forward hooks that record module outputs into ``l``.

    Every object reachable from ``model`` that has ``register_forward_hook``
    gets a hook. On each forward call, the hook appends
    ``(module class, output)`` to ``l``. Traversal descends into
    ``children()`` and, for ``Sequence`` objects (e.g. a plain list of
    modules), into their items. The hook handles are not kept;
    ``release`` removes the hooks by clearing them.
    """
    def record(module, input, output):
        l.append((module.__class__, output))

    if hasattr(model, 'register_forward_hook'):
        model.register_forward_hook(record)

    nested = list(model.children()) if hasattr(model, 'children') else []
    if isinstance(model, Sequence):
        nested.extend(model)
    for sub in nested:
        collect(sub, l)
            
def release(model):
    """Remove every forward hook from ``model`` and everything nested in it.

    Walks the same structure as ``collect``: ``children()`` plus the items
    of ``Sequence`` objects. Each visited object's ``_forward_hooks`` is
    replaced with an empty ``OrderedDict``. This drops *all* forward hooks
    on those modules, not only the ones ``collect`` added.
    """
    stack = [model]
    while stack:
        node = stack.pop()
        if hasattr(node, '_forward_hooks'):
            node._forward_hooks = OrderedDict()
        if hasattr(node, 'children'):
            stack.extend(node.children())
        if isinstance(node, Sequence):
            stack.extend(node)

class FeatureMaps:
    """Interactive browser for the intermediate outputs of a model.

    For the selected dataset sample, one forward pass is run with temporary
    forward hooks on every submodule (``collect`` / ``release``). Each hooked
    module's class and output are stored in ``self.collection``.
    ``show_feature_map`` then plots one channel of one collected output,
    resized to the input's spatial size.

    Args:
        model: module to inspect. The input is moved to the device/dtype of
            its first parameter.
        dataset: indexable dataset. ``data_transform(dataset[i])`` must
            produce a model input.
        data_transform: maps a dataset item to a model input. The default
            takes the item's first element and prepends a batch axis.
        layer, channel: optional slider widgets. When omitted, IntSliders are
            created. Their ``max`` is kept in sync with the collected outputs.
        interpolation: OpenCV interpolation flag used to resize feature maps.
        filters: optional list of module class names. When non-empty, only
            outputs of those module types are kept.
    """

    def __init__(self, model, dataset, data_transform=lambda x: x[0][None, ...], layer=None, channel=None, interpolation=cv2.INTER_LINEAR, filters=None):
        self.collection = []
        self.model = model
        self.dataset = dataset
        self.data_transform = data_transform
        self.current_sample = None
        self.current_index = 0
        self.current_layer = None  # index into collection that current_map came from
        self.current_map = None
        self.interpolation = interpolation
        self.filters = [] if filters is None else filters
        self.layer = ipw.IntSlider(min=0, value=0, continuous_update=False) if layer is None else layer
        self.channel = ipw.IntSlider(min=0, value=0, continuous_update=False) if channel is None else channel

        self.change_sample(0)
        self.change_feature_map(0)

    def change_sample(self, index):
        """Select dataset item ``index`` and recollect all layer outputs for it."""
        self.current_index = index
        sample = self.dataset[index]
        reference = next(self.model.parameters())
        self.current_sample = self.data_transform(sample).to(reference)
        self.update_feature_maps()

    def change_feature_map(self, layer):
        """Select collected output ``layer`` (first batch element) and print its module type."""
        module_class, output = self.collection[layer]
        self.current_map = output[0]
        # Record the selection before touching widgets, whose updates may
        # trigger callbacks.
        self.current_layer = layer
        print(module_class.__name__)
        self.channel.max = self.current_map.shape[0] - 1

    def update_feature_maps(self):
        """Run the model once on the current sample, recording every module output."""
        # Any previously selected map belongs to the old sample; force a reload.
        self.current_layer = None
        self.collection = []
        collect(self.model, self.collection)
        try:
            # Inference only; also keeps the outputs convertible to NumPy.
            with pt.no_grad():
                self.model(self.current_sample)
        finally:
            # Always remove the hooks, even if the forward pass fails.
            release(self.model)
        if self.filters:
            self.collection = [c for c in self.collection if c[0].__name__ in self.filters]
        self.layer.max = len(self.collection) - 1

    def get_feature_map(self, index, layer, channel):
        """Return channel ``channel`` of layer ``layer`` for sample ``index`` as a NumPy array.

        The map is resized to the input's spatial size. Returns None if
        ``channel`` is out of range.
        """
        if index != self.current_index:
            self.change_sample(index)
        if layer != self.current_layer:
            self.change_feature_map(layer)
        try:
            fmap = self.current_map[channel].float().cpu().numpy()
        except IndexError:
            return None
        # cv2.resize expects dsize as (width, height); the input is (N, C, H, W).
        height, width = self.current_sample.shape[2:]
        return cv2.resize(fmap, dsize=(int(width), int(height)), interpolation=self.interpolation)

    def show_feature_map(self, index, layer, channel, alpha=0.65):
        """Plot the selected feature map over the input (left) and on its own (right).

        Does nothing if ``channel`` is out of range.
        """
        fmap = self.get_feature_map(index, layer, channel)
        if fmap is None:
            return
        # Grayscale background: the first batch item averaged over channels.
        sample = normalize(self.current_sample.float().cpu().numpy()[0].mean(0))
        _, (ax1, ax2) = plt.subplots(1, 2, figsize=(10.24, 5.12))
        ax1.imshow(sample, vmin=0, vmax=1, alpha=1 - alpha, cmap='gray')
        ax1.imshow(fmap, vmin=0, vmax=1, alpha=alpha, cmap='inferno')
        ax1.set_axis_off()
        # Diverging colormap: red = positive, blue = negative.
        ax2.imshow(normalize(fmap), 'seismic', vmin=-1, vmax=1)
        ax2.set_axis_off()
        plt.subplots_adjust(0, 0, 1, 1, 0, 0)
        plt.show()
        
# Browse the outputs of the 3-channel model at the largest latent size
# (16 for pokemon). It runs on the GPU when one is available, otherwise on the
# CPU. Pass e.g. filters=['GeneralConvolution', 'ResBlock2d'] to show only
# those layer types.
maps = FeatureMaps(
    models[max(models)][3].to(pt.device('cuda' if pt.cuda.is_available() else 'cpu')),
    dataset,
    filters=[],
)
ipw.interact(
    maps.show_feature_map,
    layer=maps.layer,
    channel=maps.channel,
    index=ipw.IntSlider(min=0, max=len(dataset) - 1, value=0, continuous_update=False),
)

ZeroPad2d


interactive(children=(IntSlider(value=0, continuous_update=False, description='index', max=157), IntSlider(val…

<function ipywidgets.widgets.interaction._InteractFactory.__call__.<locals>.<lambda>(*args, **kwargs)>