# Network Dissection (for ResNet152)

In this notebook, we will examine internal layer representations for a classifier trained to recognize scene categories.

Set up matplotlib for high-resolution (retina) inline figures with thin lines and minimal axes.

In [None]:
%matplotlib inline
%config InlineBackend.figure_format = 'retina'

import matplotlib.pyplot as plt
import matplotlib as mpl
from importlib import reload
mpl.rcParams['lines.linewidth'] = 0.25
mpl.rcParams['axes.spines.top'] = False
mpl.rcParams['axes.spines.right'] = False
mpl.rcParams['axes.linewidth'] = 0.25

Load a ResNet-152 pretrained on Places365, and wrap it with `nethook` so that internal layers can be retained.

In [None]:
import torchvision
import torch.hub
from netdissect import oldresnet152
model = oldresnet152.OldResNet152()
# Other Places365 checkpoints hosted at the same location:
# url = 'http://gandissect.csail.mit.edu/models/resnet18_places365-2f475921.pth'
# url = 'http://gandissect.csail.mit.edu/models/resnet50_places365-46529c86.pth'
url = 'http://gandissect.csail.mit.edu/models/resnet152_places365-f928166e5c.pth'
try:
    sd = torch.hub.load_state_dict_from_url(url)  # pytorch 1.1+
except AttributeError:
    # pytorch 1.0 has no load_state_dict_from_url. Catch only that case so real
    # download/checksum failures surface instead of being silently retried.
    sd = torch.hub.model_zoo.load_url(url)
# Unused for this checkpoint; kept for checkpoints that nest weights under
# 'state_dict' with 'module.' prefixes (presumably from DataParallel training).
# sd = sd['state_dict']
# sd = {k.replace('module.', ''): v for k, v in sd.items()}
model.load_state_dict(sd)

from netdissect import nethook
# Instrument the model so that named layers can be retained and read back.
model = nethook.InstrumentedModel(model)
model = model.cuda()
model

In [None]:
# Load labels
from urllib.request import urlopen

synset_url = 'http://gandissect.csail.mit.edu/models/categories_places365.txt'
# Close the HTTP response deterministically instead of leaking it.
with urlopen(synset_url) as response:
    synset_text = response.read().decode('utf-8')
# Keep the first space-separated token of each line minus its first three
# characters (presumably a "/x/" prefix, e.g. "/a/airfield" -> "airfield").
# splitlines() avoids the spurious empty label that a trailing newline produced.
classlabels = [line.split(' ')[0][3:] for line in synset_text.splitlines()]

Load the unified-parsing segmenter and its list of segmentation labels.

In [None]:
from netdissect import segmenter
segmodel = segmenter.UnifiedParsingSegmenter(segsizes=[256])
# The first element is a sequence of (label, category) pairs; keep just the label names.
label_category_pairs = segmodel.get_label_and_category_names()[0]
seglabels = [name for name, _category in label_category_pairs]

Load the Places dataset (validation and training splits).

In [None]:
from importlib import reload
from netdissect import parallelfolder, renormalize
from torchvision import transforms

reload(parallelfolder)

# Resize to 256x256, center-crop to 224x224, then apply ImageNet normalization.
center_crop = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    renormalize.NORMALIZER['imagenet'],
])


def _places_split(split_dir):
    """Build a shuffled, class-labelled ParallelImageFolders over one split directory."""
    return parallelfolder.ParallelImageFolders(
        [split_dir],
        transform=[center_crop],
        classification=True,
        shuffle=True,
    )


dataset = _places_split('dataset/places/val')
train_dataset = _places_split('dataset/places/train')

Test the classifier on a few validation images.

In [None]:
from netdissect import renormalize

indices = [200, 755, 709, 423] #range(200,224)
# Fetch each dataset item once: indexing decodes and transforms the image, and
# both the image (item[0]) and the class label (item[1]) are needed.
samples = [dataset[i] for i in indices]
batch = torch.stack([item[0] for item in samples])
with torch.no_grad():  # inference only; avoid building an autograd graph
    preds = model(batch.cuda()).max(1)[1]
imgs = [renormalize.as_image(t, source=dataset) for t in batch]
prednames = [classlabels[p.item()] for p in preds]
truenames = [classlabels[item[1]] for item in samples]

In [None]:
from netdissect import show

# One row per sample: [image, predicted class, true class].
rows = [list(triple) for triple in zip(imgs, prednames, truenames)]
show(rows)

Create an adapter that renormalizes dataset images into the form the segmenter expects.

In [None]:
from netdissect import renormalize
reload(renormalize)
# Converter from `dataset`'s normalization to the 'zc' mode; it is applied to
# images before every segmodel.segment_batch call in this notebook.
renorm = renormalize.renormalizer(dataset, mode='zc')

Segment the sample images and visualize the labels.

In [None]:
from netdissect import imgviz, upsample, segviz
reload(segviz)

iv = imgviz.ImageVisualizer(120, source=dataset)

# Segment the sample batch at 1/4 resolution (renorm adapts the normalization).
seg = segmodel.segment_batch(renorm(batch).cuda(), downsample=4)

# For each image: the image, its first segmentation label channel, and a legend.
rows = []
for i in range(len(seg)):
    first_channel = seg[i, 0]
    rows.append((iv.image(batch[i]),
                 iv.segmentation(first_channel),
                 iv.segment_key(first_channel, segmodel)))
show(rows)

Visualize the activations of a single layer for a single image.

In [None]:
# Retain the output of child module '7' so its activations can be read back
# after a forward pass (presumably the last residual stage, whose 7x7 output
# matches the (7, 7) upsampler used below — confirm).
model.retain_layer('7')
batch = torch.cat([dataset[i][0][None,...] for i in indices])
with torch.no_grad():  # inference only; the retained activations need no gradients
    preds = model(batch.cuda()).max(1)[1]
imgs = [renormalize.as_image(t, source=dataset) for t in batch]
prednames = [classlabels[p.item()] for p in preds]

In [None]:
from netdissect import imgviz

acts = model.retained_layer('7').cpu()
unit_count = acts.shape[1]
# Preview at most this many units of the first image; rendering every channel is slow.
max_units_shown = 50
ivsmall = imgviz.ImageVisualizer((56, 56), source=dataset)
# Last expression of the cell, so the notebook actually displays the blocks.
show.blocks(
    [[[ivsmall.masked_image(batch[0], acts, (0, u))],
      [ivsmall.heatmap(acts, (0, u), mode='nearest')]]
     for u in range(min(unit_count, max_units_shown))]
)

## Collect quantile statistics

First, collect unconditional quantiles over the activations. We upsample the activations to 56x56 so that they line up with the segmentations later.


In [None]:
from netdissect import tally

# Upsample layer activations to the 56x56 resolution of the segmentations.
upfn = upsample.upsampler(
    (56, 56),   # target output shape
    (7, 7),     # presumably the retained layer's spatial shape — confirm
    source=dataset,
)


def compute_samples(batch, *args):
    """Return upsampled layer-'7' activations as a (pixels, units) matrix."""
    model(batch.cuda())
    layer_acts = model.retained_layer('7')
    upsampled = upfn(layer_acts)
    n_units = layer_acts.shape[1]
    # Move the unit axis last so each row is one pixel's activation vector.
    return upsampled.permute(0, 2, 3, 1).reshape(-1, n_units)


rq = tally.tally_quantile(compute_samples, dataset, sample_size=1000)


Next, let's collect bincounts of the segmentations, also unconditional.

In [None]:
from netdissect import tally, upsample


def compute_segments(batch, *args):
    """Segment a batch at 1/4 resolution so segmentation labels can be bincounted."""
    return segmodel.segment_batch(renorm(batch.cuda()), downsample=4)


# multi_label_axis=1: the segmentation carries several label channels per pixel.
segbc = tally.tally_bincount(
    compute_segments, dataset, sample_size=1000, multi_label_axis=1)

Here is the main statistic: `condq` holds the activation quantile statistics conditioned on each segmentation class.

In [None]:
    from netdissect import tally
    reload(tally)

    def compute_conditional_samples(batch, *args):
        """Pair upsampled layer-'7' activations with per-pixel segmentation labels.

        Activations go through `upfn`, the same upsampler used for the
        unconditional quantiles in `rq`, so conditional and unconditional
        statistics are collected at the same resolution.
        """
        image_batch = batch.cuda()
        _ = model(image_batch)
        acts = model.retained_layer('7')
        seg = segmodel.segment_batch(renorm(image_batch), downsample=4)
        # (A separate bilinear interpolate() here was dead code: its result was
        # immediately overwritten by upfn(acts).)
        hacts = upfn(acts)
        return tally.conditional_samples(hacts, seg)


    condq = tally.tally_conditional_quantile(compute_conditional_samples,
            dataset, sample_size=1000)


Condvar is a conditional mean (and variance) statistic, collected directly so it's more accurate than you can obtain from the quantile sketch.

In [None]:
from netdissect import tally

# Per-condition mean and variance of the same conditional samples that feed
# condq (both come from compute_conditional_samples).
condvar = tally.tally_conditional_mean(compute_conditional_samples,
        dataset, sample_size=1000)

### Wasserstein distance experiment.

For each segmented concept, find the unit whose behavior changes the most when the concept is present. The change is measured by the Wasserstein distance between the unit's unconditional activation distribution and its distribution when the concept is present.

In [None]:
import numpy

# Compute the Wasserstein distance between the unconditional (condition 0) and
# concept-conditional activation distributions. It is approximated as the mean
# absolute difference between the two quantile readouts. Then report the unit
# with the maximum distance for each concept.
num_quantiles = 1000
# Loop-invariant: read the unconditional quantiles once.
uncond_readout = condq.conditional(0).readout(num_quantiles)
print('Unit with largest wasserstein shift under each segmented concept')
for conceptnum, label in enumerate(seglabels):
    # Condition 0 is the baseline itself, so its shift would trivially be 0.
    if conceptnum == 0 or conceptnum not in condq.keys():
        continue
    conditional = condq.conditional(conceptnum)
    shift = (conditional.readout(num_quantiles) - uncond_readout).abs().sum(1) / num_quantiles
    amt, ind = shift.max(0)
    print(conceptnum, amt.item(), 'unit', ind.item(), conditional.size(), label)

In [None]:
# Plot the wasserstein shift for one unit under a concept where the shift is large.
# Histogramming the quantile readouts approximates each activation distribution.
segconcept = 9  # window
unit = 1058     # largest wasserstein shift under window
baseline = condq.conditional(0).readout(1001)[unit].numpy()
conditioned = condq.conditional(segconcept).readout(1001)[unit].numpy()
top = max(baseline.max(), conditioned.max())
buckets = numpy.linspace(0, top, 25)

fig, ax = plt.subplots()
ax.hist(baseline, buckets, alpha=0.5, label='unconditional')
ax.hist(conditioned, buckets, alpha=0.5, label=seglabels[segconcept])
ax.set(title='Unit %d activation quantiles' % unit,
       xlabel='activation', ylabel='number of quantile points')
ax.legend();

## Visualizing max activating images.

Visualize high-activation regions in individual images.

To support this function, we need topk statistics over a sample.  This collects them:

In [None]:
from netdissect import tally
reload(tally)


def compute_image_max(batch, *args):
    """Return each image's peak activation for every unit of layer '7'."""
    model(batch.cuda())
    layer_acts = model.retained_layer('7')
    # Collapse the spatial dimensions into one axis, then take the max over it.
    spatial = layer_acts.view(layer_acts.shape[0], layer_acts.shape[1], -1)
    peak_values, _peak_positions = spatial.max(2)
    return peak_values

topk = tally.tally_topk(compute_image_max, dataset, sample_size=1000)

This function visualizes one unit by showing its top-activating images.

In [None]:
from netdissect import imgviz
from IPython.display import display

reload(imgviz)
iv = imgviz.ImageVisualizer((100, 100), source=dataset, quantiles=rq)

def unit_viz_row(unitnum, percent_level=None, count=8):
    """Render the top-activating images for one unit of layer '7'.

    Args:
        unitnum: index of the unit (channel) in the retained layer.
        percent_level: forwarded to ``iv.masked_image``; None uses its default.
        count: how many of the unit's top images (from ``topk``) to show.

    Returns:
        A list with one ``[[masked_image, image_index]]`` entry per image,
        suitable for ``show.blocks``.
    """
    out = []
    for imgnum in topk.result()[1][unitnum][:count]:
        img = dataset[imgnum][0][None,...].cuda()
        # The forward pass only refreshes the retained layer-'7' activations for
        # this image; no gradients are needed.
        with torch.no_grad():
            model(img)
        acts = model.retained_layer('7')
        out.append([
            [iv.masked_image(img[0], acts, (0, unitnum), percent_level=percent_level), imgnum.item()],
        ])
    return out
display(show.blocks(unit_viz_row(444)))

Simpler version of the wasserstein shift experiment:

For each concept, just show the unit that has the largest difference of conditional mean vs unconditional mean, when normalized by the unconditional mean.


In [None]:
conceptlist = []
# Unconditional per-unit mean; loop-invariant, so compute it once.
uncond_mean = rq.mean()
for conceptnum, label in enumerate(seglabels):
    if not condq.has_conditional(conceptnum):
        continue
    conditional = condq.conditional(conceptnum)
    # Relative shift of each unit's mean under this concept; keep the largest.
    # NOTE(review): a unit whose unconditional mean is 0 yields inf/nan here.
    ratio, index = ((conditional.mean() - uncond_mean).abs() / uncond_mean).max(0)
    print(label, 'unit', index.item(), 'ratio', ratio.item(), 'size', conditional.size())
    conceptlist.append(label)


In [None]:
from netdissect import bargraph
reload(bargraph)
from IPython.display import display, SVG, HTML
from collections import defaultdict

def graph_conceptlist(conceptlist):
    """Render a bar graph of how often each label occurs in ``conceptlist``.

    Bars are ordered from most to least frequent; ties keep first-seen order.
    Returns an IPython ``HTML`` object. An empty list yields a small
    placeholder instead of failing on an empty unpack.
    """
    from collections import Counter
    counted = Counter(conceptlist).most_common()
    if not counted:
        return HTML('<div>(no concepts)</div>')
    labels, counts = zip(*counted)
    return HTML('<div style="height:200px;width:5000px">' + bargraph.make_svg_bargraph(labels, counts) + '</div>')

Experiment 1.

Assign a label to each unit according to the highest IoU at a fixed threshold (the top 1% of each unit's activations).

In [None]:
import math
# Experiment 1: score every (concept, unit) pair by the IoU between the region
# where the unit is in its top `actquantile` of activations and the region
# labelled with the concept. Then give each unit its best-scoring concept.
iouscores_at_99 = torch.zeros((max(condq.keys()) + 1, unit_count))
# Compute at fixed quantile: each unit's top-1% activation level.
actquantile = 0.01
uncond_q = condq.conditional(0)
actlevel = uncond_q.quantiles([1 - actquantile])[:,0]
for c in sorted(condq.keys()):
    if c == 0 or condq.conditional(c).batchcount <= 1:
        continue
    # levelp: presumably P(act < actlevel | concept c), the conditional CDF.
    levelp = condq.conditional(c).normalize(actlevel)
    # cp: fraction of sampled pixels labelled with concept c.
    cp = float(condq.conditional(c).size()) / uncond_q.size()
    # intersection = cp * (1 - levelp); union = actquantile + cp - intersection.
    iouscores_at_99[c] = cp * (1 - levelp) / (actquantile + cp * levelp)
conceptlist_at_99, unitlist_at_99 = [], []
uncond_mean = condvar.conditional(0).mean()
# Use the retained layer's actual width rather than a hard-coded unit count.
for u in range(unit_count):
    iou, c = iouscores_at_99[:,u].max(0)
    c = c.item()
    diff = condvar.conditional(c).mean()[u] - uncond_mean[u]
    unitlist_at_99.append(dict(
        unit=u,
        label=seglabels[c],
        iou=iou.item(),
        diff=diff.item(),
        cnt=condvar.conditional(c).batchcount,
    ))
    conceptlist_at_99.append(seglabels[c])
for d in sorted(unitlist_at_99, key=lambda x: -x['iou'])[:20]:
    display(show.blocks([[d['label'],
                          'iou %.2f' % d['iou'],
                          'dm %.2f' % d['diff'],
                          'cnt %d' % d['cnt'],
                          'unit %d' % d['unit']]] + unit_viz_row(d['unit'])))


In [None]:
# Re-display a few hand-picked units from Experiment 1.
units_of_interest = [1242, 636]
for entry in unitlist_at_99:
    if entry['unit'] not in units_of_interest:
        continue
    header = [entry['label'],
              'iou %.2f' % entry['iou'],
              'dm %.2f' % entry['diff'],
              'cnt %d' % entry['cnt'],
              'unit %d' % entry['unit']]
    display(show.blocks([header] + unit_viz_row(entry['unit'])))

In [None]:
# Number of (concept, unit) pairs whose fixed-threshold IoU exceeds 0.01;
# displaying the full boolean matrix is unreadable for thousands of units.
(iouscores_at_99 > 0.01).sum().item()

In [None]:
# Bar graph of every unit's best label (no IoU cutoff applied).
display(graph_conceptlist([entry['label'] for entry in unitlist_at_99]))
# Count how many units clear an IoU of 0.05.
confident_units = sum(1 for entry in unitlist_at_99 if entry['iou'] > 0.05)
print(str(confident_units) + ' units')

In [None]:
# Best-matching concept label for every unit from Experiment 1, in unit order.
conceptlist_at_99

Experiment 1.2:

Now repeat, but adjust thresholds so that we measure each segmented concept at a quantile that matches the pixel frequency of that concept in the dataset.

In [None]:
import math
from netdissect import pbar

# Experiment 1.2: as in Experiment 1, but each concept c is scored at a unit
# threshold chosen so the unit is "on" for the same fraction of pixels (cp)
# as the concept itself.
iouscores_at_match = torch.zeros((max(condq.keys()) + 1, unit_count))
# Compute at matching quantile; record each concept's percentile level (1 - cp).
concept_percent_level = torch.zeros(max(condq.keys()) + 1)
uncond_q = condq.conditional(0)

print('step 1')
# Iterate over condq's keys because every statistic in this loop comes from
# condq (previously keyed on condvar, which could disagree).
for c in pbar(sorted(condq.keys())):
    if c == 0 or condq.conditional(c).batchcount <= 1:
        continue
    # cp: fraction of sampled pixels labelled with concept c.
    cp = float(condq.conditional(c).size()) / uncond_q.size()
    concept_percent_level[c] = 1 - cp
    actquantile = cp
    actlevel = uncond_q.quantiles([1 - actquantile])[:,0]
    levelp = condq.conditional(c).normalize(actlevel)
    iouscores_at_match[c] = cp * (1 - levelp) / (actquantile + cp * levelp)
print('step 2')
uncond_mean = condvar.conditional(0).mean()
conceptlist_at_match, unitlist_at_match = [], []
for u in pbar(range(unit_count)):
    iou, c = iouscores_at_match[:,u].max(0)
    c = c.item()
    diff = condvar.conditional(c).mean()[u] - uncond_mean[u]
    unitlist_at_match.append(dict(
        unit=u,
        label=seglabels[c],
        iou=iou.item(),
        diff=diff.item(),
        percent_level=concept_percent_level[c].item(),
        cnt=condvar.conditional(c).batchcount,
    ))
    conceptlist_at_match.append(seglabels[c])
for d in sorted(unitlist_at_match, key=lambda x: -x['iou'])[:20]:
    display(show.blocks([[d['label'],
                          'iou %.2f' % d['iou'],
                          'per %.2f' % d['percent_level'],
                          'dm %.2f' % d['diff'],
                          'cnt %d' % d['cnt'],
                          'unit %d' % d['unit']]] + unit_viz_row(d['unit'], percent_level=d['percent_level'])))


In [None]:
# Show the 20 units with the best matched-quantile IoU.
top_matches = sorted(unitlist_at_match, key=lambda entry: -entry['iou'])[:20]
for entry in top_matches:
    header = [
        entry['label'],
        'iou %.2f' % entry['iou'],
        'per %.2f' % entry['percent_level'],
        'dm %.2f' % entry['diff'],
        'cnt %d' % entry['cnt'],
        'unit %d' % entry['unit'],
    ]
    viz_rows = unit_viz_row(entry['unit'], percent_level=entry['percent_level'])
    display(show.blocks([header] + viz_rows))


In [None]:
# Fraction of sampled pixels labelled with segmentation class 20 (the same
# ratio the IoU cells call `cp`).
float(condq.conditional(20).size()) / condq.conditional(0).size()

In [None]:
# Bar graph of the best-matching labels from Experiment 1.2.
graph_conceptlist(conceptlist_at_match)

Experiment 2.

Assign a label to each unit according to the shift in conditional mean with highest statistical significance.

In [None]:
import math
# Experiment 2: z-score of the shift in each unit's conditional mean for each
# concept, using a Welch-style standard error. batchcount is presumably the
# sample count for each condition — confirm.
uncond_var = condvar.conditional(0)
uncond_mean = uncond_var.mean()
# Size by the retained layer's width; a hard-coded 512 fails for wider layers.
zscores = torch.zeros((max(condvar.keys()) + 1, unit_count))
for c in sorted(condvar.keys()):
    if c == 0 or condvar.conditional(c).batchcount <= 1:
        continue
    cond = condvar.conditional(c)
    zscores[c] = ((cond.mean() - uncond_mean) /
        (uncond_var.variance() / uncond_var.batchcount
         + cond.variance() / cond.batchcount).sqrt())
conceptlist_by_zscore, unitlist_by_zscore = [], []
for u in range(unit_count):
    zt, c = zscores[:,u].max(0)
    c = c.item()
    diff = condvar.conditional(c).mean()[u] - uncond_mean[u]
    unitlist_by_zscore.append(dict(
        unit=u,
        label=seglabels[c],
        zscore=zt.item(),
        diff=diff.item(),
        cnt=condvar.conditional(c).batchcount,
    ))
    conceptlist_by_zscore.append(seglabels[c])
# unit_viz_row runs several forward passes per unit, so render only the 20
# most significant units (as in Experiments 1 and 1.2).
for d in sorted(unitlist_by_zscore, key=lambda x: -x['zscore'])[:20]:
    display(show.blocks([[d['label'],
                          'dm %.2f' % d['diff'],
                          'zs %.2f' % d['zscore'],
                          'cnt %d' % d['cnt'],
                          'unit %d' % d['unit']]] + unit_viz_row(d['unit'])))


In [None]:
# Bar graph of the labels chosen by the z-score criterion in Experiment 2.
graph_conceptlist(conceptlist_by_zscore)

Experiment 3.

Assign a label according to the highest relative mutual information at a threshold. (Planned; not yet implemented in this notebook.)

## Discrimination metric.

Given only the fraction of pixels labeled with a particular segmentation class, how accurately can we do binary classification of a particular scene class?

To answer this, we can use conditional quantile information.

Conditioned on each scene class, we collect the fraction of pixels in each segmentation class.
Let $s$ be that pixel fraction and $c$ the scene class. At a given threshold $t$, the accuracy score of scene classification is:

$$p(c \mid s>t) + p(\lnot c \mid s<t) = p(c \mid s>t) + 1 - p(c \mid s<t)$$

$$= \frac{p(s>t \mid c)\,p(c)}{p(s>t)} + 1 - \frac{p(s<t \mid c)\,p(c)}{p(s<t)}$$

In [None]:
reload(tally)

def compute_conditional_discrimination(batch, classnum, *args):
    """Per-image pixel fractions of each segmentation class, tallied two ways.

    Returns two (condition, feature) pairs with the same feature row:
    condition 0 collects every image, and condition ``classnum + 1`` collects
    the images of that scene class.
    """
    # classnum.item() below only works when the batch holds a single image.
    assert len(batch) == 1
    seg = segmodel.segment_batch(renorm(batch.cuda()), downsample=4)
    height, width = seg.shape[2], seg.shape[3]
    label_counts = seg.view(-1).bincount(minlength=len(seglabels))
    fractions = label_counts.float() / (width * height)
    # Segmentation class 0 is excluded (presumably the unlabeled class).
    fractions[0] = 0
    fractions = fractions.unsqueeze(0)
    return [(0, fractions), (classnum.item() + 1, fractions)]

conddis = tally.tally_conditional_quantile(
    compute_conditional_discrimination, train_dataset,
    sample_size=10000, num_workers=20)  # TODO: switch to sample_size=50,000.

In [None]:
# Conditions present in the tally: 0 is every image; k >= 1 is scene class k - 1.
sorted(conddis.keys())

In [None]:
# Target values of p(s > t | c): 9 log-spaced values from 1e-3 up to (but excluding) 1.
condprob = torch.logspace(-3, 0, 10)[:-1]

# conddis condition c holds scene class c - 1 (see compute_conditional_discrimination).
c = 1
print(classlabels[c - 1])
# Threshold t for each segmentation class such that p(s > t | c) = condprob:
# the (1 - condprob) quantile of the class-conditional distribution.
level = conddis.conditional(c).quantiles(1 - condprob)
# normalize() gives the unconditional CDF p(s < t) (as in the IoU cells), so
# p(s > t) is its complement.
segprob = 1 - conddis.conditional(0).normalize(level)
# p(c): fraction of sampled images that belong to this scene class.
margprob = float(conddis.conditional(c).size()) / conddis.conditional(0).size()
# acc = p(c | s > t) + p(~c | s < t), each term expanded with Bayes' rule.
acc1 = condprob * margprob / segprob
acc2 = 1 - (1 - condprob) * margprob / (1 - segprob)
acc = acc1 + acc2
print(acc.shape)

s = 1
print(seglabels[s])
level