Set up the environment, and use CUDA if it is available.

In [2]:
%matplotlib inline
import torch, json, numpy
from netdissect import proggan, nethook, easydict, zdataset
from netdissect.plotutil import plot_tensor_images, plot_max_heatmap

# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

Load the dining-room generator model, and instrument `layer4` so that its activations can be retained and modified.

In [58]:
# Load the dining-room ProgGAN generator onto the chosen device, then hook
# layer4 so later cells can read its output (model.retained) and overwrite
# selected channels (model.ablation / model.replacement).
model = proggan.from_pth_file('models/karras/diningroom_lsun.pth')
model = model.to(device)
for instrument in (nethook.retain_layers, nethook.edit_layers):
    instrument(model, ['layer4'])

In the dissection, find the 20 highest-ranked table units.

In [57]:
# Pick out the layer4 record of the dining-room dissection and keep the 20
# units that sort first under its 'table-iou' ranking.
# NOTE(review): ascending argsort treats the smallest score as the best rank;
# this assumes the dissection stores ranking scores negated — confirm.
dissect = easydict.load_json('dissect/diningroom/dissect.json')
lrec = next(layer_rec for layer_rec in dissect.layers
            if layer_rec.layer == 'layer4')
rrec = next(ranking for ranking in lrec.rankings
            if ranking.name == 'table-iou')
unit_order = numpy.argsort(rrec.score)
ct_units = torch.from_numpy(unit_order[:20])

Generate 30 example images.

In [59]:
# Sample 30 latent vectors and render them. The forward pass runs under
# no_grad: nothing here needs gradients, and keeping the autograd graph would
# hold every intermediate activation for all 30 images in memory for as long
# as base_images lives.
zbatch = zdataset.z_sample_for_model(model, 30).to(device)
with torch.no_grad():
    base_images = model(zbatch)
plot_tensor_images(base_images)

Define a function that randomly permutes the activations of the selected units among themselves, and generates the resulting images.

In [62]:
def make_mixes(model, layer, units, zbatch, mixcount=5):
    """Render images in which the activations of `units` are randomly permuted.

    For every latent in `zbatch` this produces `mixcount` consecutive images.
    The first is the unmodified generator output. The remaining
    `mixcount - 1` are variants in which the channels listed in `units` at
    `layer` are overwritten with a random permutation of those same channels'
    activations. All other channels keep their original values.

    Args:
        model: generator instrumented for `layer`. This function reads
            `model.retained[layer]` after a forward pass, and sets
            `model.ablation[layer]` / `model.replacement[layer]` to edit it.
        layer: name of the instrumented layer, e.g. 'layer4'.
        units: 1-D sequence, array or tensor of channel indices to shuffle.
        zbatch: batch of latent vectors, already on the model's device.
        mixcount: images per latent (1 original + `mixcount - 1` shuffles).

    Returns:
        CPU float tensor of shape `(len(zbatch) * mixcount, *image_shape)`.
        Rows `k*mixcount ... k*mixcount + mixcount - 1` belong to latent `k`.

    The model's edit state for `layer` is reset to None on exit, including
    when a forward pass raises.
    """
    # Accept lists/arrays as well as tensors; scatter_ needs a LongTensor.
    units = torch.as_tensor(units, dtype=torch.long)
    # No gradients are needed. Without no_grad the returned tensor would
    # carry an autograd graph that holds every forward pass alive.
    with torch.no_grad():
        try:
            model.ablation[layer] = None
            model.replacement[layer] = None
            base_images = model(zbatch)
            base_features = model.retained[layer]
            result = torch.zeros(
                (base_images.shape[0] * mixcount, ) + base_images.shape[1:])
            result[0::mixcount] = base_images
            # Channel mask selecting the shuffled units. It is built on the
            # features' device so that, on CUDA, the edit does not mix CPU and
            # GPU tensors. It is identical for every mix, so build it once.
            ablation = torch.zeros(base_features.shape[1],
                                   device=base_features.device)
            ablation.scatter_(0, units.to(ablation.device), 1)
            for i in range(1, mixcount):
                shuf = torch.from_numpy(numpy.random.permutation(len(units)))
                replacement = base_features.clone()
                replacement[:, units] = base_features[:, units][:, shuf]
                model.ablation[layer] = ablation
                model.replacement[layer] = replacement
                result[i::mixcount] = model(zbatch)
        finally:
            # Never leave the model edited, even if a forward pass fails.
            model.ablation[layer] = None
            model.replacement[layer] = None
    return result

Call the function to shuffle the table units in four selected dining room images.

In [63]:
# Sample 30 latents, keep four hand-picked ones (indices 10, 15, 16, 25), and
# show each original followed by its table-unit shuffles.
table_z = zdataset.z_sample_for_model(model, 30)
zbatch = table_z[[10, 15, 16, 25]].to(device)
table_mixes = make_mixes(model, 'layer4', ct_units, zbatch)
plot_tensor_images(table_mixes)

In [28]:
# Keep the 20 layer4 units that sort first under the 'door-iou' ranking of the
# same dissection record used for tables.
rrec = next(ranking for ranking in lrec.rankings if ranking.name == 'door-iou')
door_order = numpy.argsort(rrec.score)
door_units = torch.from_numpy(door_order[:20])

# Sample 50 latents, keep four hand-picked ones, and shuffle their door units.
door_z = zdataset.z_sample_for_model(model, 50)
zbatch = door_z[[13, 29, 34, 42]].to(device)
door_mixes = make_mixes(model, 'layer4', door_units, zbatch)
plot_tensor_images(door_mixes)


In [61]:
# Swap in the outdoor-church ProgGAN generator and hook its layer4 the same
# way: output retained (model.retained) and editable
# (model.ablation / model.replacement).
model = proggan.from_pth_file('models/karras/churchoutdoor_lsun.pth')
model = model.to(device)
for instrument in (nethook.retain_layers, nethook.edit_layers):
    instrument(model, ['layer4'])

In [18]:
# Unit indices belong to one trained generator. The door units must therefore
# come from the church model's own dissection: `door_units` was ranked on the
# dining-room model's layer4, and its indices say nothing about this network.
# NOTE(review): the dissection path is assumed by analogy with
# 'dissect/diningroom/dissect.json', and the church dissection is assumed to
# contain a 'door-iou' ranking — confirm both.
church_dissect = easydict.load_json('dissect/churchoutdoor/dissect.json')
church_lrec = next(x for x in church_dissect.layers if x.layer == 'layer4')
church_rrec = next(x for x in church_lrec.rankings if x.name == 'door-iou')
church_door_units = torch.from_numpy(numpy.argsort(church_rrec.score)[:20])

# Render 50 church images. Gradients are not needed, so skip the autograd graph.
zbatch = zdataset.z_sample_for_model(model, 50).to(device)
with torch.no_grad():
    base_images = model(zbatch)
plot_tensor_images(base_images)

# Heatmap of where the church model's door units fire in those images.
features = model.retained['layer4']
specific_features = features[:, church_door_units]
plot_max_heatmap(specific_features, shape=(256, 256))