Imports and Jupyter/matplotlib setup.

In [None]:
%matplotlib inline
%config InlineBackend.figure_format = 'retina'

import matplotlib.pyplot as plt
import matplotlib as mpl
mpl.rcParams['lines.linewidth'] = 0.25
mpl.rcParams['axes.spines.top'] = False
mpl.rcParams['axes.spines.right'] = False
mpl.rcParams['axes.linewidth'] = 0.25

In [None]:
import torch

# Use the first GPU when CUDA is available; otherwise fall back to the CPU so
# the notebook still runs (more slowly) on machines without a GPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

We're going to use a progressive GAN.

Let's analyze a pretrained model.  I've chosen an outdoor church model.

You can uncomment the model of your choice.

In [None]:
import torchvision
import torch.hub
from netdissect import nethook, proggan

# n = 'proggan_bedroom-d8a89ff1.pth'
n = 'proggan_churchoutdoor-7e701dd5.pth'
# n = 'proggan_conferenceroom-21e85882.pth'
# n = 'proggan_diningroom-3aa0ab80.pth'
# n = 'proggan_kitchen-67f1e16c.pth'
# n = 'proggan_livingroom-5ef336dd.pth'
# n = 'proggan_restaurant-b8578299.pth'

url = 'http://gandissect.csail.mit.edu/models/' + n
# map_location places the loaded tensors directly on `device`, so a checkpoint
# whose tensors were saved on a GPU can still be opened on a CPU-only machine.
sd = torch.hub.load_state_dict_from_url(url, map_location=device)
model = proggan.from_state_dict(sd).to(device)
model

The GAN generator is just a function $G: z \mapsto x$ that transforms a random vector $z$ into a realistic image $x$.

To generate images, all we need is a source of random z.  Let's make a micro dataset with thirty random z vectors.

In [None]:
from netdissect import zdataset

# 30 latent vectors sized for this generator, drawn with a fixed seed.
zds = zdataset.z_dataset_for_model(model, seed=5555, size=30)
num_z, z_shape = len(zds), zds[0][0].shape
num_z, z_shape

We can just invoke `model(z[None, ...])` to generate a single image; indexing with `None` adds a leading batch dimension of size 1.

In [None]:
# Generate one image from the first z; `[None, ...]` prepends a batch axis of size 1.
first_z = zds[0][0]
model(first_z[None, ...].to(device))

In [None]:
import torch
from netdissect import renormalize, show

# Render one generated image per z in the dataset (one inner list per z).
# Gradients are not needed just to look at the images, so skip building an
# autograd graph for each of the forward passes.
with torch.no_grad():
    show([
        [renormalize.as_image(model(z[None, ...].to(device))[0])]
        for [z] in zds
    ])


To analyze what a model is doing inside, we can wrap it with a `nethook.InstrumentedModel`, which makes it easy to hook or modify a particular layer.

In [None]:
# InstrumentedModel wraps the generator so individual layers can be hooked.
# retain_layer(name) registers a layer whose output should be kept; after a
# forward pass, retained_layer(name) returns that output.
# TODO: also document how to edit a layer's output.

from netdissect import nethook

already_instrumented = isinstance(model, nethook.InstrumentedModel)
if not already_instrumented:
    # Wrap only once, so re-running this cell does not nest wrappers.
    model = nethook.InstrumentedModel(model)
    model.retain_layer('layer4')

Now we can run the model and inspect the internal units.

In [None]:
from importlib import reload

from netdissect import imgviz, upsample

# Re-execute these modules so edits to their source are picked up without
# restarting the kernel.
reload(upsample)
reload(imgviz)

# One forward pass on the first z, then read back the 'layer4' output that
# the instrumented model retained during that pass.
img = model(zds[0][0][None, ...].to(device))
acts = model.retained_layer('layer4')

# This is the intermediate value inside the network - how much data is it?
acts.shape

In [None]:
# Quick look at how `show` lays out nested lists, using plain numbers.
layout_demo = [[1, 2, 3], [3, 4], [5, 6]]
show(layout_demo)

In [None]:
# For units 400-423, show: a label, the generated image, the image masked by
# that unit (sample 0, channel u of acts), and the unit's heatmap.
iv = imgviz.ImageVisualizer((100, 100), image_size=(256, 256))
# The unmasked image is identical for every unit, so render it once.
base_image = iv.image(img[0])
show(
    [['unit %d' % u,
      [base_image],
      [iv.masked_image(img[0], acts, (0, u))],
      [iv.heatmap(acts, (0, u), mode='nearest')],
     ] for u in range(400, 424)]
)

Each unit has a different scale, which makes the heatmaps harder to interpret.

We can normalize the scales by collecting stats.

In [None]:
# Move the unit axis last and flatten everything else: one row per
# (sample, y, x) location, one column per unit.
flat_acts = acts.permute(0, 2, 3, 1).contiguous().view(-1, acts.shape[1])
print(acts.shape)
print(flat_acts.shape)

In [None]:
import torch
from netdissect import tally
reload(tally)
upfn = upsample.upsampler(
    (64, 64),                     # The target output shape
    (8, 8),                       # The source data shape
    image_size=(256, 256)         # The actual image shape
)

# To collect stats, define a function that returns 2d [samples, units]
def compute_samples(batch):
    """Run the generator on one batch of z and return its layer4 activations.

    The unit (channel) axis is moved last and all other axes are flattened,
    so each row is one spatial location of one image and each column is one
    unit: a 2-D [samples, units] tensor for tally_quantile.
    """
    # Use the notebook-wide `device` rather than assuming CUDA, and skip
    # autograd bookkeeping: only the retained activations are needed.
    z_batch = batch[0].to(device)
    with torch.no_grad():
        model(z_batch)
    layer_acts = model.retained_layer('layer4')
    # To tally upsampled activations instead: layer_acts = upfn(layer_acts)
    return layer_acts.permute(0, 2, 3, 1).contiguous().view(-1, layer_acts.shape[1])

rq = tally.tally_quantile(compute_samples, zds)


In [None]:
# rq holds the statistics tallied over the [samples, units] activations.
# Useful queries (presumably one value per unit):
#   rq.mean()              - mean activation
#   rq.quantiles([0.5])    - median
#   rq.quantiles([0.99])   - 99th percentile
unit_means = rq.mean()
unit_means

Now we can see all the activations on a reasonable scale.

In [None]:
# Same grid as before, but the visualizer now gets the tallied quantiles
# (rq), presumably so each unit is scaled by its own activation statistics.
iv = imgviz.ImageVisualizer((100, 100), image_size=(256, 256), quantiles=rq)
# The unmasked image is identical for every unit, so render it once.
base_image = iv.image(img[0])
show(
    [[
       'unit %d' % u,
       [base_image],
       [iv.masked_image(img[0], acts, (0, u))],
       [iv.heatmap(acts, (0, u), mode='nearest')],
    ]
      for u in range(400, 424)]
)

Let's quantify what's inside these images by segmenting them.

In [None]:
# TODO: make the segmentation model downloadable.
import torch
from netdissect import renormalize, segmenter

segmodel = segmenter.UnifiedParsingSegmenter(segsizes=[256])
# Label names from the first list returned by get_label_and_category_names().
seglabels = [l for l, c in segmodel.get_label_and_category_names()[0]]

# Generate a few images from zds and segment them.  Indices must be smaller
# than len(zds) (the dataset above was built with size=30).
indices = range(0, 4)
with torch.no_grad():
    z_batch = torch.cat([zds[i][0][None, ...] for i in indices]).to(device)
    batch = model(z_batch)
    # NOTE(review): assumes segment_batch accepts raw generator output and
    # returns per-image label maps - confirm against netdissect.segmenter.
    seg = segmodel.segment_batch(batch, downsample=4)
imgs = [renormalize.as_image(t) for t in batch]
prednames = [classlabels[p.item()] for p in preds]

