In [1]:
import torch
import torch.nn as nn
import torchvision
from torchvision import transforms
from torch.utils.data.dataloader import DataLoader
import numpy as np
from torchvision.models import resnet18

## Download Data and Pretrained Weights

In [2]:
# Pretrained torchvision ResNet-18, switched to eval mode immediately. The network is
# only used for inference / input optimisation in this notebook. In train mode every
# forward pass would normalise with per-batch statistics and overwrite the pretrained
# BatchNorm running statistics.
# NOTE: `pretrained=True` is deprecated in torchvision>=0.13 in favour of
# `weights=ResNet18_Weights.DEFAULT`; it is kept here for compatibility with older
# torchvision. nn.Module.eval() returns the module itself.
net = resnet18(pretrained=True).eval()

In [3]:
from torchvision.datasets import CIFAR10

In [4]:
# CIFAR10 images resized to the 224x224 resolution ResNet-18 was trained at.
# torchvision's ImageNet-pretrained models also expect inputs normalised with the
# ImageNet per-channel mean/std below; without it the pretrained weights see
# out-of-distribution inputs.
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)
preprocess = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
])
cif = CIFAR10('.', download=True, transform=preprocess)
dl = DataLoader(cif)  # default batch_size=1

Files already downloaded and verified


In [5]:
# Smoke-test the pipeline on a single CIFAR10 batch.
# eval() makes BatchNorm use its stored running statistics instead of updating them
# from this batch. no_grad() skips building an autograd graph for this check-only
# forward pass.
net.eval()
with torch.no_grad():
    x, y = next(iter(dl))
    logits = net(x)
assert logits.shape[0] == x.shape[0], "network output batch size does not match input"

In [6]:
# Size of the smoke-test batch from the previous cell; the recorded output below
# shows (1, 3, 224, 224).
x.size()

torch.Size([1, 3, 224, 224])

## Visualization Pipeline

In [7]:
from visualize_feature import visualize_feats
import matplotlib.pyplot as plt
%matplotlib inline

In [8]:
# Targets forwarded to visualize_feats as positional (module name, int) pairs.
# NOTE(review): the integer is presumably a channel/unit index inside that module —
# TODO confirm against visualize_feature.
modules = [("layer4", 0)]
# Keyword arguments for visualize_feats; the shape matches the smoke-test batch.
input_args = dict(shape=(1, 3, 224, 224))

In [9]:
# Run feature visualisation for each (module, index) target in `modules`, starting
# from an input of `input_args['shape']`. num_updates=1 presumably means a single
# optimisation step, i.e. a quick pipeline check — TODO confirm.
# NOTE(review): as recorded below, this currently raises
# "ValueError: Too many dimensions: 3 > 2." That message typically comes from PIL's
# Image.fromarray receiving a 3-D non-uint8 (e.g. float, channel-first) array. The
# following cells inspect the raw random input to track this down — TODO fix inside
# visualize_feature.
imgs = visualize_feats(net, *modules, num_updates=1, **input_args)

ValueError: Too many dimensions: 3 > 2.

In [10]:
from visualize_feature import InputOptimizer

# Build the input optimiser directly so its random starting input can be inspected.
# NOTE(review): positional arguments are presumably (model, input shape, low, high).
# The model is passed as None, and the shape matches input_args above — TODO confirm
# against InputOptimizer's signature.
inpopt = InputOptimizer(
    None,
    (1, 3, 224, 224),
    -1.0,
    1.0,
)

In [12]:
# Draw the optimiser's random starting input.
# NOTE(review): _random_input is a private (underscore-prefixed) helper of
# InputOptimizer. It is called here presumably to debug the visualize_feats failure
# above and may change without notice.
rand_inp = inpopt._random_input()

In [13]:
# Display the raw random input (a torch tensor, per the recorded output).
# The printed values lie within [-1, 1], presumably the bounds passed to
# InputOptimizer — TODO confirm.
rand_inp

tensor([[[[ 7.1191e-01, -4.2220e-01,  6.8434e-01,  ...,  8.0793e-01,
           -8.7538e-01,  6.3234e-02],
          [-2.0383e-01, -4.4825e-01,  2.6398e-02,  ..., -3.7541e-01,
           -3.8618e-01,  3.9126e-01],
          [-1.2458e-01,  1.5779e-01, -9.2306e-01,  ...,  5.6013e-02,
            2.5133e-01, -9.2479e-01],
          ...,
          [-7.1287e-01,  6.6691e-01,  1.1989e-01,  ...,  8.7140e-01,
            2.4641e-01,  3.4119e-01],
          [ 7.1168e-01,  6.2305e-01, -6.5067e-01,  ...,  9.4101e-01,
            1.4871e-01, -4.6589e-01],
          [ 6.5651e-01, -4.2035e-01, -8.5154e-01,  ..., -2.8733e-01,
            6.1462e-01, -4.1284e-01]],

         [[-2.0758e-01, -5.7112e-01,  5.9778e-01,  ...,  5.1396e-01,
           -8.2715e-01, -9.9032e-01],
          [ 1.3298e-01,  6.3986e-01, -5.9811e-01,  ..., -3.6895e-01,
           -6.1628e-01, -1.4478e-01],
          [-1.0403e-01, -5.2085e-01, -8.9522e-01,  ..., -8.4554e-01,
            1.6304e-01, -8.7285e-01],
          ...,
     

In [16]:
# Convert the random input to a NumPy array for inspection.
# - detach() is the supported replacement for the legacy `.data` attribute.
# - .copy() gives the array its own memory. On a CPU tensor, numpy() shares storage,
#   so the in-place `-=` / `*=` rescaling in the next cells would otherwise also
#   rewrite the original tensor, which is still reachable through Out[13].
# NOTE: this rebinds rand_inp from tensor to ndarray, so re-running this cell alone
# fails; re-run from the _random_input() cell instead.
rand_inp = rand_inp.detach().cpu().numpy().copy()

In [17]:
# Drop the leading batch axis, (1, C, H, W) -> (C, H, W), so the per-channel
# reductions below can use axis=(1, 2).
# Squeezing only axis 0 (rather than every size-1 axis) keeps a single-channel
# image's channel axis intact, and raises if the batch holds more than one image.
# The ndim check keeps the cell safe to re-run.
if rand_inp.ndim == 4:
    rand_inp = rand_inp.squeeze(axis=0)

In [24]:
# Shift every channel so its minimum becomes 0. keepdims=True leaves a (C, 1, 1)
# array that broadcasts back over (C, H, W).
rand_inp -= rand_inp.min(axis=(1, 2), keepdims=True)

In [25]:
# Stretch each channel to [0, 255]. After the previous cell's min-shift, each
# channel's max equals its range. A constant channel has a max of 0, which would give
# 255/0 = inf and then nan; substituting 1 for a zero max maps such a channel to 0
# instead.
rand_inp *= 255. / np.where(
    rand_inp.max(axis=(1, 2), keepdims=True) > 0,
    rand_inp.max(axis=(1, 2), keepdims=True),
    1.,
)

In [26]:
# Display the rescaled input: float values in [0, 255], still channel-first (C, H, W).
# NOTE(review): PIL's Image.fromarray expects an (H, W, C) uint8 array for RGB, e.g.
# rand_inp.transpose(1, 2, 0).astype(np.uint8). TODO confirm this is the conversion
# missing inside visualize_feats.
rand_inp

array([[[218.27548  ,  73.67101  , 214.75996  , ..., 230.51814  ,
          15.889099 , 135.56621  ],
        [101.51428  ,  70.34941  , 130.86942  , ...,  79.63676  ,
          78.26343  , 177.39142  ],
        [111.61932  , 147.62259  ,   9.809849 , ..., 134.64548  ,
         159.54988  ,   9.588914 ],
        ...,
        [ 36.609367 , 212.53743  , 142.79031  , ..., 238.61133  ,
         158.92133  , 171.0063   ],
        [218.24612  , 206.9452   ,  44.539883 , ..., 247.48615  ,
         146.46503  ,  68.100395 ],
        [211.21187  ,  73.90754  ,  18.928226 , ...,  90.86755  ,
         205.87003  ,  74.86421  ]],

       [[101.03495  ,  54.681976 , 203.72055  , ..., 193.03418  ,
          22.037733 ,   1.2321112],
        [144.45775  , 209.08665  ,  51.24109  , ...,  80.45938  ,
          48.924416 , 109.041595 ],
        [114.23796  ,  61.091465 ,  13.358433 , ...,  19.69254  ,
         148.28952  ,  16.210192 ],
        ...,
        [ 75.5657   , 254.4899   , 233.30627  , ..., 2