# Network Dissection

Network dissection is a systematic method for finding single units (filters, or neurons) that match meaningful semantic concepts in a vision network, and for quantifying the closeness of the match.

In this notebook we will use network dissection to explore the neurons in an image classification network.

Our fundamental question is this: how does the network decompose the task of understanding that an image is a baseball field?  Does it identify any features that are understandable to a human?

Running this notebook as-is will produce a basic dissection, but each step includes exercises for modifying the notebook to find more interesting results.

## About the netdissect library

The netdissect library contains several useful packages for inspecting internals of a vision network.
Here are the packages that we use in this notebook:

 * **nethook** wraps any pytorch model, adding the ability to record or modify any internal computation.
 * **imgviz** provides ImageVisualizer, which collects several useful image visualization functions.
 * **show** arranges nested arrays of PIL images and strings as nicely formatted HTML for display in a notebook.
 * **segmenter** provides an interface and a pretrained implementation for a semantic segmentation network.
 * **tally** gathers statistics over a dataset, based on your function to compute features for each datum.
 * **renormalize** deals with conversions between the zoo of RGB encoding scales typically seen in vision data.
 * **upsample** provides simple functions for resampling grid data at higher or lower resolutions.
 * **pbar** is a progress bar.

These will be explained a bit more in the exercises below.  Of course you can always run `help(object)` for a bit more information on most things in the library.
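
To give a sense of what "recording an internal computation" means, here is a minimal sketch that uses a plain PyTorch forward hook. It is not the nethook implementation; the toy network, the `retained` dictionary, and the `retain_output` helper are hypothetical names used only for illustration.

```python
import torch
from torch import nn

# A toy network standing in for a real vision model (hypothetical).
toy = nn.Sequential(
    nn.Conv2d(3, 4, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.Conv2d(4, 8, kernel_size=3, padding=1),
)

retained = {}  # layer name -> most recent output of that layer

def retain_output(name):
    # Build a hook that stores the layer output under the given name.
    def hook(module, inputs, output):
        retained[name] = output.detach()
    return hook

# Record the output of the first convolution on every forward pass.
handle = toy[0].register_forward_hook(retain_output('conv1'))

with torch.no_grad():
    out = toy(torch.zeros(2, 3, 16, 16))

print(retained['conv1'].shape)  # shape of the recorded activation
handle.remove()  # stop recording
```

A wrapper such as nethook saves you from managing hooks like this by hand.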

In addition, for this tutorial we have a package **setting**, which automatically downloads and creates the datasets and pretrained models that we will be looking at.

In [None]:
import torch, os, matplotlib.pyplot as plt
from netdissect import nethook, imgviz, show, segmenter, renormalize, upsample, tally, pbar
from netdissect import setting

## Loading pretrained models and data

Here are some fixed variables that we define up-front for all the objects that we will be inspecting in this tutorial.

* **model** is the network we will look at.  It is a VGG convolutional network, trained to classify images of scenes into one of 365 place categories.  We wrap `model` as a `nethook.InstrumentedModel` so that we can easily retrieve and modify its internal activations.
* **ds** is a small held-out sample from the Places dataset that was used to train the model; each entry is a pytorch tensor representing an image, and an integer representing the class.  A pytorch dataset can be dereferenced like an array, so `ds[35]` is a pair `(x, y)` where `x` is a tensor containing RGB image data for a scene and `y` is an integer for the human-given class label.  Class names are available as `ds.classes[y]`.
* **renorm** is a function that renormalizes RGB data from the statistically-based scaling used in `ds` to a simple `[-1...1]` range scale.
* **segmodel** is a semantic segmentation network trained to recognize a large vocabulary of objects and parts of objects within scenes.  We will use it as a reference, to see if there are any internal filters that approximately match the same concepts.
* **seglabels** are human-readable names for the numerical segmentation classes.
* **iv** is an image visualization object that visualizes 2d data such as images and heatmaps as 224x224 images.
* **ivsmall** is another visualization object, but outputs smaller 56x56 images.
* **resfile** is a function that generates filenames in a results subdirectory that we will use for caching data.
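
As a rough illustration of what a renormalizer like `renorm` does, the sketch below converts one pixel from a mean/std-normalized encoding to a `[-1, 1]` range. The mean and std values are the commonly used ImageNet statistics and are an assumption here; the statistics actually used by `ds` may differ. `to_zero_centered` is a hypothetical helper, not part of the netdissect library.

```python
# Assumed per-channel statistics (common ImageNet values, for illustration only).
MEAN = (0.485, 0.456, 0.406)
STD = (0.229, 0.224, 0.225)

def to_zero_centered(pixel):
    """Map one mean/std-normalized RGB pixel to the [-1, 1] range."""
    out = []
    for value, mean, std in zip(pixel, MEAN, STD):
        unit = value * std + mean      # undo the normalization: back to [0, 1]
        out.append(2.0 * unit - 1.0)   # rescale [0, 1] to [-1, 1]
    return tuple(out)

# A pure white pixel (1.0 in every channel), expressed in the normalized encoding.
white = tuple((1.0 - m) / s for m, s in zip(MEAN, STD))
print(to_zero_centered(white))
```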

In [None]:
model = setting.load_vgg16()
model = nethook.InstrumentedModel(model)
model.cuda()
ds = setting.load_dataset('places', 'val')
renorm = renormalize.renormalizer(ds, target='zc')
segmodel, seglabels, segcatlabels = setting.load_segmenter('netpqc')
iv = imgviz.ImageVisualizer(224, source=ds, percent_level=0.99)
ivsmall = imgviz.ImageVisualizer((56, 56), source=ds, percent_level=0.99)
def resfile(f):
    return os.path.join('results/vgg16-places', f)

## Step 1: test the model

When the model is run on a batch of images, it outputs a score for every class for each image.

In the short example below:
* **indexes** is a list of dataset indexes to retrieve.  `i` indicates a dataset index, and `j` is an index into the indexes array.
* **batch** is a `12 x 3 x 224 x 224` tensor that stacks up twelve RGB 224x224 images from the dataset.
* When we run `model(batch.cuda())`, it scores every image for every class, making a `12 x 365` tensor of scores.
* Then `.max(dim=1)` finds the maximum of 365 scores for each image; it returns a (scores, indexes) tuple.
* **preds** is a tensor of the 12 highest-scoring class indexes (each one a number from 0 to 364) predicted by the model.
* `iv.image(batch[j])` turns the jth `3 x 224 x 224` tensor into a PIL image for display.
* `ds.classes[ds[i][1]]` shows the human ground-truth label for the `i`th image in the dataset.

So the loop shows a set of twelve images, each with the dataset label and the model prediction.

Scene classification is difficult and sometimes ambiguous; nevertheless the model does reasonably well.
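
The prediction step reduces to taking, for each image, the index of its highest score. Here is a plain-Python sketch of what `.max(dim=1)` computes, using a made-up 3 x 4 score table in place of the real `12 x 365` tensor; `max_over_classes` is a hypothetical helper written for this illustration.

```python
# Made-up scores: 3 images (rows) by 4 classes (columns).
scores = [
    [0.1, 2.5, 0.3, 1.0],
    [4.0, 0.2, 0.1, 0.0],
    [0.5, 0.4, 0.9, 3.2],
]

def max_over_classes(rows):
    """Return (best_scores, best_indexes), with one entry per row."""
    best_scores, best_indexes = [], []
    for row in rows:
        index = max(range(len(row)), key=row.__getitem__)
        best_scores.append(row[index])
        best_indexes.append(index)
    return best_scores, best_indexes

best_scores, preds = max_over_classes(scores)
print(preds)  # predicted class index for each image
```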

In [None]:
indexes = range(100, 112)
batch = torch.stack([ds[i][0] for i in indexes])
preds = model(batch.cuda()).max(1)[1]
show([[
    iv.image(batch[j]),
    'label: ' + ds.classes[ds[i][1]],
    'pred: ' + ds.classes[preds[j]]]
    for j, i in enumerate(indexes)])

### Exercise 1: measure accuracy

Fix the loop below to measure accuracy of the model on a sample of 2000 images.

In [None]:
correct = 0
tested = 0
for imagebatch, labelbatch in pbar(torch.utils.data.DataLoader(ds, batch_size=100)):
    modelpreds = model(imagebatch.cuda()).max(1)[1]
    correct += 0 # fixme
    tested += len(labelbatch)
    if tested >= 2000:
        break
print('%d correct out of %d' % (correct, tested))


## Step 2: Examine raw unit activations.



In [None]:
layername = 'features.conv5_1'
model.retain_layer(layername)
model(batch.cuda())
acts = model.retained_layer(layername).cpu()
show([
    [
        [ivsmall.masked_image(batch[0], acts[0], u)],
        [ivsmall.heatmap(acts[0], u, mode='nearest')],
        'unit %d' % u
    ]
    for u in range(min(21, acts.shape[1]))
])

In [None]:
upfn = upsample.upsampler(
    target_shape=(56, 56),
    data_shape=(7, 7),
)

def flatten_activations(batch, *args):
    image_batch = batch.cuda()
    _ = model(image_batch)
    acts = model.retained_layer(layername)
    hacts = upfn(acts)
    return hacts.permute(0, 2, 3, 1).contiguous().view(-1, acts.shape[1])

rq = tally.tally_quantile(
    flatten_activations,
    dataset=ds,
    sample_size=1000,
    batch_size=100,
    cachefile=resfile(layername + '_rq.npz'))
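
The reshaping in `flatten_activations` turns a batch of feature maps into a table with one row per pixel location and one column per unit, so that each column collects all the samples for a single unit. Below is a minimal sketch of the same shape manipulation on random data. It uses `torch.nn.functional.interpolate` as a stand-in for `upfn`, which is an assumption: the library's upsampler may interpolate differently.

```python
import torch
import torch.nn.functional as F

# Random stand-in for retained activations: 2 images, 5 units, 7x7 grid.
acts = torch.randn(2, 5, 7, 7)

# Upsample each 7x7 feature map to 56x56 (stand-in for upfn).
hacts = F.interpolate(acts, size=(56, 56), mode='bilinear', align_corners=False)

# Move the unit dimension last, then merge batch, height and width
# so that each row is one pixel location and each column is one unit.
flat = hacts.permute(0, 2, 3, 1).contiguous().view(-1, acts.shape[1])

print(flat.shape)
```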

In [None]:
plt.plot(rq.quantiles(0.9))

# Which unit is activating more often than the others?
rq.quantiles(0.9).max(0)[1]
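
As used here, `rq.quantiles(0.9)` gives one activation level per unit: the level that the unit's activations stay at or below about 90% of the time. The plain-Python sketch below shows the idea on made-up data with a simple nearest-rank quantile. This is a simplification, and the library may use a different, approximate estimator; `quantile` and `samples` are hypothetical names.

```python
import math

def quantile(values, q):
    """Nearest-rank quantile: smallest value with at least a fraction q
    of the data at or below it."""
    ordered = sorted(values)
    rank = max(1, math.ceil(q * len(ordered)))
    return ordered[rank - 1]

# Made-up activations: one list of samples per unit.
samples = {
    'unit 0': [0.0, 0.1, 0.0, 0.2, 0.1, 0.0, 0.3, 0.1, 0.0, 0.2],
    'unit 1': [0.0, 1.5, 0.2, 2.0, 0.1, 1.8, 0.4, 2.2, 0.3, 1.9],
}

levels = {unit: quantile(vals, 0.9) for unit, vals in samples.items()}
# The unit with the highest 0.9 quantile, analogous to .max(0)[1].
strongest = max(levels, key=levels.get)
print(levels, strongest)
```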

In [None]:
seg = segmodel.segment_batch(renorm(batch).cuda(), downsample=4)
show([(iv.image(batch[i]), iv.segmentation(seg[i, 0]),
            iv.segment_key(seg[i,0], segmodel))
            for i in range(len(seg))])

In [None]:
sample_size = 1000

def max_activations(batch, *args):
    image_batch = batch.cuda()
    _ = model(image_batch)
    acts = model.retained_layer(layername)
    return acts.view(acts.shape[:2] + (-1,)).max(2)[0]

def mean_activations(batch, *args):
    image_batch = batch.cuda()
    _ = model(image_batch)
    acts = model.retained_layer(layername)
    return acts.view(acts.shape[:2] + (-1,)).mean(2)

topk = tally.tally_topk(
    mean_activations,
    dataset=ds,
    sample_size=sample_size,
    batch_size=100,
    cachefile=resfile(layername + '_mean_topk.npz')
)

top_indexes = topk.result()[1]
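
The top-k tally keeps, for every unit, the indexes of the images that produced the highest values of the tallied statistic (here, the mean activation). The plain-Python sketch below shows the same idea on made-up numbers. It uses `heapq.nlargest` rather than the library's running tally, and `top_images_per_unit` is a hypothetical helper.

```python
import heapq

# Made-up mean activations: one row per image, one column per unit.
mean_acts = [
    [0.1, 0.9],  # image 0
    [0.7, 0.2],  # image 1
    [0.4, 0.8],  # image 2
    [0.9, 0.1],  # image 3
]

def top_images_per_unit(rows, k):
    """For each unit (column), return the k image indexes with the largest values."""
    num_units = len(rows[0])
    return [
        heapq.nlargest(k, range(len(rows)), key=lambda i, u=u: rows[i][u])
        for u in range(num_units)
    ]

top = top_images_per_unit(mean_acts, k=2)
print(top)  # top[u] lists image indexes for unit u, best first
```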

In [None]:
show.blocks([
    ['unit %d' % u,
     'img %d' % i,
     'pred: %s' % ds.classes[model(ds[i][0][None].cuda()).max(1)[1].item()],
     [iv.masked_image(
        ds[i][0],
        model.retained_layer(layername)[0],
        u)]
    ]
    for u in [12]
    for i in top_indexes[u, :20]
])

In [None]:
def compute_activations(image_batch):
    image_batch = image_batch.cuda()
    _ = model(image_batch)
    acts_batch = model.retained_layer(layername)
    return acts_batch

unit_images = iv.masked_images_for_topk(
    compute_activations,
    ds,
    topk,
    k=5,
    num_workers=10,
    pin_memory=True,
    cachefile=resfile(layername + '_top10images.npz'))

In [None]:
level_at_99 = rq.quantiles(0.99).cuda()[None,:,None,None]

def compute_selected_segments(batch, *args):
    image_batch = batch.cuda()
    seg = segmodel.segment_batch(renorm(image_batch), downsample=4)
    _ = model(image_batch)
    acts = model.retained_layer(layername)
    hacts = upfn(acts)
    iacts = (hacts > level_at_99).float()  # indicator: 1 where a unit exceeds its 0.99 quantile
    return tally.conditional_samples(iacts, seg)

condi99 = tally.tally_conditional_mean(
    compute_selected_segments,
    dataset=ds,
    sample_size=sample_size,
    cachefile=resfile(layername + '_condi99.npz'))

iou99 = tally.iou_from_conditional_indicator_mean(condi99)
iou99.shape
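
The intersection-over-union (IoU) score measures how closely the region where a unit is strongly active matches the region that the segmenter labels with a concept. Here is a minimal plain-Python sketch of IoU for one unit and one concept on a made-up 3 x 4 grid. The library computes this in aggregate over the whole sample, which this sketch does not attempt to reproduce; `iou` is a hypothetical helper.

```python
def iou(mask_a, mask_b):
    """Intersection over union of two equally sized binary masks."""
    intersection = 0
    union = 0
    for row_a, row_b in zip(mask_a, mask_b):
        for a, b in zip(row_a, row_b):
            intersection += 1 if (a and b) else 0
            union += 1 if (a or b) else 0
    return intersection / union if union else 0.0

# 1 where the unit exceeds its activation threshold.
unit_mask = [
    [0, 1, 1, 0],
    [0, 1, 1, 0],
    [0, 0, 0, 0],
]
# 1 where the segmenter labels the pixel with the concept.
concept_mask = [
    [0, 0, 1, 1],
    [0, 0, 1, 1],
    [0, 0, 0, 0],
]

print(iou(unit_mask, concept_mask))
```

An IoU of 1 would mean the two regions coincide exactly, and an IoU of 0 would mean they never overlap.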

In [None]:
iou_unit_label_99 = sorted([(
    unit, concept.item(), seglabels[concept], bestiou.item())
    for unit, (bestiou, concept) in enumerate(zip(*iou99.max(0)))],
    key=lambda x: -x[-1])
for unit, concept, label, score in iou_unit_label_99[:20]:
    show(['unit %d; iou %g; label "%s"' % (unit, score, label),
          [unit_images[unit]]])


In [None]:
iou_threshold = 0.04
unit_label_99 = [
        (concept.item(), seglabels[concept],
            segcatlabels[concept], bestiou.item())
        for (bestiou, concept) in zip(*iou99.max(0))]
labelcat_list = [labelcat
        for concept, label, labelcat, iou in unit_label_99
        if iou > iou_threshold]
import IPython.display  # import the submodule explicitly instead of relying on it being preloaded
IPython.display.SVG(setting.graph_conceptcatlist(labelcat_list))