In [None]:
import torch, os
import matplotlib.pyplot as plt
from netdissect import setting, show, imgviz, segmenter, renormalize, nethook, tally, upsample

In [None]:
# Classifier under study: load VGG-16 through netdissect's `setting` helper
# and wrap it in nethook.InstrumentedModel so that later cells can ask it to
# retain the output of a named layer.
model = nethook.InstrumentedModel(setting.load_vgg16())
model.cuda()

# 'val' split of the 'places' dataset, and a renormalizer built from it
# (target='zc'); renorm is applied to image batches before segmentation.
ds = setting.load_dataset('places', 'val')
renorm = renormalize.renormalizer(ds, target='zc')

# Segmentation network. get_label_and_category_names()[0] is iterated as
# (label, category) pairs; only the label half of each pair is kept.
segmodel = segmenter.UnifiedParsingSegmenter(segsizes=[256])
seglabels = [
    label
    for label, _category in segmodel.get_label_and_category_names()[0]
]

# Two visualizers over the same dataset that differ only in their size
# argument: 224 for `iv` and (56, 56) for `ivsmall`.
iv = imgviz.ImageVisualizer(224, source=ds, percent_level=0.99)
ivsmall = imgviz.ImageVisualizer((56, 56), source=ds, percent_level=0.99)
def resfile(f):
    """Return the path of result/cache file `f` inside 'results/vgg16-places'.

    The path is built with os.path.join, so an absolute `f` is returned
    unchanged by the join.
    """
    results_dir = 'results/vgg16-places'
    return os.path.join(results_dir, f)

In [None]:
# Compare the classifier's prediction with the dataset label on a small,
# fixed slice of the dataset.
indexes = range(100, 112)

# Index the dataset once per sample and reuse both the image (element 0) and
# the label (element 1), instead of calling ds[i] a second time just to read
# the label.
samples = [ds[i] for i in indexes]
batch = torch.stack([sample[0] for sample in samples])

# Inference only: no gradients are needed, so skip autograd bookkeeping.
with torch.no_grad():
    preds = model(batch.cuda()).max(1)[1]

# One row per image: the picture, its dataset label and the predicted class.
show([[
    iv.image(batch[j]),
    'label: ' + ds.classes[sample[1]],
    'pred: ' + ds.classes[preds[j]]]
    for j, sample in enumerate(samples)])

In [None]:
# Segment the renormalized example batch (downsample=4) and display, for each
# image, the picture, its segmentation and the matching segment key.
# Only index 0 along the second axis of `seg` is visualized.
seg = segmodel.segment_batch(renorm(batch).cuda(), downsample=4)
show([
    (
        iv.image(batch[n]),
        iv.segmentation(image_seg[0]),
        iv.segment_key(image_seg[0], segmodel),
    )
    for n, image_seg in enumerate(seg)
])

In [None]:
# Ask the instrumented model to keep the output of one layer, run the example
# batch through it, and fetch the retained activations back to the CPU.
layername = 'features.conv5_1'
model.retain_layer(layername)

# The forward pass is run only for its side effect of filling the retained
# layer; no gradients are needed, so autograd bookkeeping is disabled.
with torch.no_grad():
    model(batch.cuda())
acts = model.retained_layer(layername).cpu()

# For the first image of the batch, show up to 21 units (fewer if the layer
# has fewer channels): the masked image, a heatmap of the unit, and its index.
show([
    [
        [ivsmall.masked_image(batch[0], acts[0], u)],
        [ivsmall.heatmap(acts[0], u, mode='nearest')],
        'unit %d' % u
    ]
    for u in range(min(21, acts.shape[1]))
])

In [None]:
# Upsampling function applied to the retained activations before they are
# tallied or compared with segmentations; it targets a 56x56 grid.
# NOTE(review): data_shape presumably has to equal the spatial size of the
# `layername` activations. A standard VGG-16 conv5_1 on 224x224 input is
# usually 14x14 rather than 7x7 -- TODO confirm against acts.shape[2:].
upfn = upsample.upsampler(target_shape=(56, 56), data_shape=(7, 7))

def flatten_activations(batch, *args):
    """Return the upsampled activations of `layername` as a 2-D tensor.

    The batch is moved to the GPU and run through `model` so that the
    retained layer is refreshed. The retained activations are upsampled with
    `upfn`, the channel axis is moved last, and all remaining axes are
    merged, giving one row per (image, location) and one column per channel.
    Extra positional arguments are accepted and ignored.
    """
    model(batch.cuda())  # forward pass only to refresh the retained layer
    retained = model.retained_layer(layername)
    channels = retained.shape[1]
    channels_last = upfn(retained).permute(0, 2, 3, 1).contiguous()
    return channels_last.view(-1, channels)

# Quantile tally over the rows produced by flatten_activations.
# The result is cached under the results directory, keyed by the layer name.
rq = tally.tally_quantile(
    flatten_activations,
    cachefile=resfile(layername + '_rq.npz'),
    dataset=ds,
    batch_size=100,
    sample_size=1000,
)

In [None]:
# 0.9 quantile from the tally, plotted as a curve.
q90 = rq.quantiles(0.9)
plt.plot(q90)

# Which unit is activating more often than the others?
# (index of the largest 0.9-quantile value along dimension 0)
q90.max(0)[1]

In [None]:
# Value passed as `sample_size` to the tally calls that follow.
sample_size = 1000

def max_activations(batch, *args):
    """Return, per image and per channel, the maximum retained activation.

    The batch is moved to the GPU and run through `model` so that the
    retained `layername` output is refreshed. Every axis after the first two
    is then merged and reduced with max, giving a tensor indexed by
    (image, channel). Extra positional arguments are accepted and ignored.
    """
    model(batch.cuda())  # forward pass only to refresh the retained layer
    retained = model.retained_layer(layername)
    n_images, n_channels = retained.shape[:2]
    per_location = retained.view((n_images, n_channels, -1))
    return per_location.max(2)[0]

# Top-k tally of the per-image maxima returned by max_activations.
# The result is cached under the results directory, keyed by the layer name.
topk = tally.tally_topk(
    max_activations,
    cachefile=resfile(layername + '_topk.npz'),
    dataset=ds,
    batch_size=100,
    sample_size=sample_size,
)

# Second element of the tally result; the next cell indexes it as
# top_indexes[unit, rank] to obtain dataset indices.
top_indexes = topk.result()[1]

In [None]:
def _top_activation_rows(units, per_unit=50):
    """Build one display row per (unit, top-ranked image) pair.

    For each unit in `units`, the first `per_unit` entries of
    top_indexes[unit] are used as dataset indices. Each image is loaded from
    the dataset once and pushed through the model once; that single forward
    pass provides both the predicted class and the retained `layername`
    activations used for the masked image.

    Returns a list of rows of the form
    ['unit U', 'img I', 'pred: CLASS', [masked image]].
    """
    rows = []
    # Inference only: no gradients are needed for predictions or activations.
    with torch.no_grad():
        for u in units:
            for i in top_indexes[u, :per_unit]:
                i = int(i)  # plain int for dataset indexing and '%d'
                image = ds[i][0]
                pred = model(image[None].cuda()).max(1)[1].item()
                # Read the retained layer right after the forward pass above,
                # so it is guaranteed to belong to this image.
                unit_acts = model.retained_layer(layername)[0]
                rows.append([
                    'unit %d' % u,
                    'img %d' % i,
                    'pred: %s' % ds.classes[pred],
                    [iv.masked_image(image, unit_acts, u)],
                ])
    return rows


show.blocks(_top_activation_rows([12]))

In [None]:
# 0.99 quantile from the tally, moved to the GPU. A leading axis and two
# trailing axes are then inserted so it can be compared elementwise with the
# upsampled activations (presumably batch x unit x height x width -- confirm).
level_at_99 = rq.quantiles(0.99).cuda()
level_at_99 = level_at_99[None, :, None, None]

def compute_selected_segments(batch, *args):
    """Pair thresholded unit activations with the segmentation of a batch.

    The batch is moved to the GPU, segmented (after renormalization, with
    downsample=4) and run through `model` to refresh the retained
    `layername` output. The retained activations are upsampled with `upfn`
    and turned into a 0/1 float indicator of where they exceed
    `level_at_99`. The indicator and the segmentation are handed to
    tally.conditional_samples, whose result is returned. Extra positional
    arguments are accepted and ignored.
    """
    images = batch.cuda()
    segmentation = segmodel.segment_batch(renorm(images), downsample=4)
    model(images)  # forward pass only to refresh the retained layer
    upsampled = upfn(model.retained_layer(layername))
    # Indicator of locations whose activation is above the 0.99 quantile.
    above_threshold = (upsampled > level_at_99).float()
    return tally.conditional_samples(above_threshold, segmentation)

# Conditional-mean tally over the samples from compute_selected_segments.
# The result is cached under the results directory, keyed by the layer name.
# No batch_size is passed here, so the tally's own default is used.
condi99 = tally.tally_conditional_mean(
    compute_selected_segments,
    cachefile=resfile(layername + '_condi99.npz'),
    dataset=ds,
    sample_size=sample_size,
)