In [None]:
%matplotlib inline
import torch, json, numpy
from netdissect import proggan, nethook, easydict, zdataset
from netdissect.plotutil import plot_tensor_images, plot_max_heatmap
from scipy.stats import wasserstein_distance
import matplotlib.pyplot as plt

# Prefer the GPU when one is available; later cells move the model and z batches here.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

Load the outdoor church model, make layer4 editable, and retain the activations of layer4 and every later layer (through the 256x256 RGB output).

In [None]:
# Layers whose activations are kept after every forward pass: layer4 (the
# intervention point), every later layer, and the final RGB output.
instrumented_layers = ['layer%d' % n for n in range(4, 15)]
instrumented_layers.append('output_256x256')

model = proggan.from_pth_file('models/karras/churchoutdoor_lsun.pth')
model = model.to(device)
# Only layer4 is edited; later cells set model.ablation / model.replacement for it.
nethook.edit_layers(model, ['layer4'])
nethook.retain_layers(model, instrumented_layers)

Select the 20 top-ranked door units at layer4 and visualize their activations on six sample images.

In [None]:
# Find the layer4 record in the dissection results, then its door IoU ranking.
dissect = easydict.load_json('dissect/churchoutdoor/dissect.json')
lrec = next(rec for rec in dissect.layers if rec.layer == 'layer4')
rrec = next(ranking for ranking in lrec.rankings if ranking.name == 'door-iou')
# NOTE(review): an ascending argsort keeps the 20 *lowest* scores; this assumes
# the dissection stores ranking scores so that lower = more door-like -- confirm.
door_order = numpy.argsort(rrec.score)
units = torch.from_numpy(door_order[:20])

In [17]:
# Six hand-picked samples from the z vectors returned by
# zdataset.z_sample_for_model(model, 350).
sample_ids = [13, 34, 109, 134, 139, 346]
z_pool = zdataset.z_sample_for_model(model, 350)
zbatch = z_pool[sample_ids].to(device)
plot_tensor_images(model(zbatch))
# NOTE(review): presumably shows, per location, the strongest response among
# the selected door units at layer4 -- confirm against plot_max_heatmap.
plot_max_heatmap(model.retained['layer4'][:, units], shape=(256, 256))

Use the door-unit activations at grid cell (5, 3) of the first image as the "canonical door" feature, and keep the layer4 features of the other images as baselines to edit.

In [None]:
# Run the sample batch so model.retained holds its layer4 features.
model(zbatch)
layer4_features = model.retained['layer4']
# Canonical door: the door-unit activations of image 0 at grid cell (5, 3).
door_feature = layer4_features[0][units][:, 5, 3]
# Layer4 features of the remaining images (1..N-1), used as edit templates.
baseline_target = layer4_features[1:].clone()

Create modification #1: put a door in the church wall.

How does this change propagate to RGB? Trace its effect through each subsequent layer.

In [None]:
# Modification #1 is a pairwise comparison. Image 0 is rebuilt from image 1's
# layer4 features plus the canonical door at grid cell (7, 4); image 1 runs
# unedited as the baseline. `baseline_target` starts at image 1, so its first
# entry is the template. The ablation mask has one entry per sample, so only
# the first two samples are run. Feeding all six samples of `zbatch` would not
# line up with the 2-entry mask and the 1-image replacement.
pair = zbatch[:2]
modified_target = baseline_target[:1].clone()
modified_target[:, units, 7, 4] = door_feature
model.ablation['layer4'] = torch.tensor([1, 0])[:, None, None, None]
model.replacement['layer4'] = modified_target
plot_tensor_images(model(pair))
model.ablation['layer4'] = None
model.replacement['layer4'] = None

graph1 = []
for layer in instrumented_layers:
    features = model.retained[layer]
    # Per-channel scale from the baseline image. abs() keeps the normalizer
    # from collapsing toward zero for mixed-sign channels; this is the same
    # convention the grid sweeps use.
    scale = features[1].reshape(features.shape[1], -1).abs().mean(1)[:, None, None]
    # scipy needs host-side numpy arrays without autograd history.
    changed_features = (features[0] / scale).detach().cpu().numpy()
    base_features = (features[1] / scale).detach().cpu().numpy()
    # 1-D Wasserstein distance between each channel's value distributions, averaged.
    wd = numpy.array([wasserstein_distance(c.ravel(), b.ravel())
                      for c, b in zip(changed_features, base_features)])
    norm = float(wd.mean())
    print(layer, norm)
    graph1.append(norm)

Create modification #2: try to put a door in the sky.

In [None]:
# Modification #2 uses the same pairwise setup as #1. Image 0 is rebuilt from
# image 1's layer4 features plus the canonical door at grid cell (5, 6), which
# is in the sky; image 1 runs unedited as the baseline. Only the first two
# samples are run so the batch lines up with the 2-entry per-sample ablation
# mask and the 1-image replacement.
pair = zbatch[:2]
modified_target_2 = baseline_target[:1].clone()
modified_target_2[:, units, 5, 6] = door_feature
model.ablation['layer4'] = torch.tensor([1, 0])[:, None, None, None]
model.replacement['layer4'] = modified_target_2
plot_tensor_images(model(pair))
model.ablation['layer4'] = None
model.replacement['layer4'] = None

graph2 = []
for layer in instrumented_layers:
    features = model.retained[layer]
    # Per-channel scale from the baseline image. abs() keeps the normalizer
    # from collapsing toward zero for mixed-sign channels; this is the same
    # convention the grid sweeps use.
    scale = features[1].reshape(features.shape[1], -1).abs().mean(1)[:, None, None]
    # scipy needs host-side numpy arrays without autograd history.
    changed_features = (features[0] / scale).detach().cpu().numpy()
    base_features = (features[1] / scale).detach().cpu().numpy()
    # 1-D Wasserstein distance between each channel's value distributions, averaged.
    wd = numpy.array([wasserstein_distance(c.ravel(), b.ravel())
                      for c, b in zip(changed_features, base_features)])
    norm = float(wd.mean())
    print(layer, norm)
    graph2.append(norm)

In [None]:
# Compare how far each edit's effect spreads through the network.
# x = layer index after the intervention (15 stands for the RGB output);
# y = mean per-channel Wasserstein distance. `plt` comes from the first cell.
layer_ids = range(4, 4 + len(instrumented_layers))
plt.plot(layer_ids, graph2, label="sky")
plt.plot(layer_ids, graph1, label="wall")
plt.xlabel('layer after intervention')
plt.ylabel('Wasserstein feat dist')
plt.legend()
plt.show()

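Ground-truth evaluation.

For each interior insertion cell (row, col) of the layer4 grid and each of the six sample images, the label records the visible effect of inserting the door:

- 0 = no significant change
- 1 = changed or added a door
- 2 = changed or added a window
- 3 = some other change

Perturbations on the border of the grid are not evaluated.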

In [None]:
# Hand-labelled effect of inserting the door, indexed gt[row - 1, col - 1, image].
# (row, col) is the interior layer4 grid cell the door is written to (1..6;
# border cells are not labelled). `image` indexes the samples in `zbatch`.
# Labels: 0 = no significant change, 1 = door, 2 = window, 3 = other change.
gt = torch.tensor([
  [[0,0,0,0,0,0], [0,0,0,0,0,3], [0,0,2,0,0,0], [0,3,0,0,0,0], [0,0,3,0,2,0], [0,0,0,0,0,0]],
  [[0,0,0,0,0,0], [2,0,0,2,2,2], [0,2,2,0,0,0], [0,0,0,2,2,0], [2,0,2,2,2,2], [2,0,0,2,3,0]],
  [[0,0,0,0,0,0], [0,0,0,0,2,3], [0,2,2,0,2,0], [0,0,0,0,1,2], [0,0,0,0,2,2], [0,0,0,3,1,0]],
  [[0,0,0,3,0,0], [0,2,0,0,3,3], [0,2,0,2,2,0], [0,2,2,0,2,2], [0,0,2,0,2,2], [0,0,0,3,2,0]],
  [[0,0,0,1,1,0], [0,0,0,1,1,2], [0,2,0,1,0,0], [0,3,1,1,2,2], [0,0,1,0,0,1], [0,3,2,1,2,2]],
  [[0,2,0,1,1,1], [1,2,0,0,1,1], [0,0,1,0,1,0], [0,0,1,1,1,0], [0,2,0,0,1,1], [0,2,1,1,1,0]]
])


In [None]:
with torch.no_grad():
    # Duplicate every sample. Even rows of `myzbatch` keep their own layer4
    # features (the baseline); odd rows receive the door edit (ablation = 1).
    myzbatch = torch.zeros((zbatch.shape[0] * 2,) + zbatch.shape[1:]).to(device)
    myzbatch[0::2] = zbatch
    myzbatch[1::2] = zbatch
    ablation = torch.zeros(len(myzbatch))[:,None,None,None]
    ablation[1::2] = 1

    model.ablation['layer4'] = None
    model.replacement['layer4'] = None
    model(myzbatch)
    # Canonical door: door-unit activations of image 0 at grid cell (5, 3).
    door_feature = model.retained['layer4'][0][units][:,5,3]
    baseline_target = model.retained['layer4'].clone()
    n_images = len(myzbatch) // 2
    layer_ids = range(4, 4 + len(instrumented_layers))
    all_results = []
    # Only interior cells are probed (matches the 6x6 `gt` labels).
    for r in range(1, 7):
        for c in range(1, 7):
            modified_target_2 = baseline_target.clone()
            modified_target_2[1::2,units,r,c] = door_feature
            model.ablation['layer4'] = ablation
            model.replacement['layer4'] = modified_target_2
            plot_tensor_images(model(myzbatch))
            model.ablation['layer4'] = None
            model.replacement['layer4'] = None

            # curves[i][k]: metric for image i at the k-th instrumented layer.
            curves = [[] for _ in range(n_images)]
            for layer in instrumented_layers:
                features = model.retained[layer]
                # Per-channel mean |activation| over all baseline rows. It does
                # not depend on the image, so compute it once per layer.
                avg_features = (features[0::2].transpose(0, 1)
                                .reshape(features.shape[1], -1)
                                .abs().mean(1)[:,None,None])
                for i in range(n_images):
                    # scipy needs host-side numpy arrays (features may live on the GPU).
                    changed_features = (features[2 * i + 1] / avg_features).cpu().numpy()
                    base_features = (features[2 * i] / avg_features).cpu().numpy()
                    wd = [wasserstein_distance(ch.ravel(), b.ravel())
                          for ch, b in zip(changed_features, base_features)]
                    curves[i].append(float(numpy.mean(wd)))

            loc_results = []
            for i, graphv in enumerate(curves):
                case = 2 * i
                plt.plot(layer_ids, graphv, label="%d p%d,%d" % (case, r, c))
                plt.legend()
                plt.show()
                loc_results.append(graphv)
                print('%d location results' % len(loc_results))
            all_results.append(loc_results)
            print('%d results so far' % len(all_results))

In [None]:
# Persist the Wasserstein sweep so later analysis can reload it without
# re-running the expensive loop. `json` is imported in the first cell.
with open('norm_was_dist_all_results.json', 'w') as f:
    json.dump(all_results, f)

In [None]:
# Split the per-(location, image) Wasserstein curves by whether the ground
# truth records a visible change (label != 0), then plot each group's mean.
effect = []
noeffect = []
n_cols = gt.shape[1]
for loc, locg in enumerate(all_results):
    # all_results was filled row-major over the interior grid cells.
    r, c = divmod(loc, n_cols)
    for i, graphv in enumerate(locg):
        # NOTE(review): image 0 (the source of `door_feature`) is included here
        # but skipped in the normalized-difference analysis. Confirm which is
        # intended before comparing the two figures.
        if gt[r, c, i] != 0:
            effect.append(graphv)
        else:
            noeffect.append(graphv)
avg_effect = torch.tensor(effect).mean(0).numpy()
avg_noeffect = torch.tensor(noeffect).mean(0).numpy()

fig, ax = plt.subplots(figsize=(2.5, 2))
layer_ids = range(4, 4 + len(avg_effect))
ax.plot(layer_ids, avg_effect, label="effect")
ax.plot(layer_ids, avg_noeffect, label="not")
ax.legend(frameon=False)
ax.spines['right'].set_visible(False)
ax.spines['top'].set_visible(False)
ax.set_xlabel('layer after intervention')
ax.set_ylabel('Wasserstein feat dist')
# Save before plt.show(): once shown, pyplot may close the figure.
fig.savefig("layerwass.pdf", bbox_inches='tight')
plt.show()

In [None]:
# Spot-check of one layer, without normalization, so it will not match
# all_results. Assumes model.retained still holds a `myzbatch` sweep run, in
# which row 0 is image 0's unedited baseline and row 1 its door-edited copy.
layer = 'layer5'
features = model.retained[layer]
# scipy needs host-side numpy arrays without autograd history.
base_features = features[0].detach().cpu().numpy()
changed_features = features[1].detach().cpu().numpy()
wd = torch.tensor([wasserstein_distance(c.ravel(), b.ravel())
                   for c, b in zip(changed_features, base_features)])
wd.mean()


In [None]:
with torch.no_grad():
    # Same sweep as the Wasserstein one, but the metric is the mean absolute
    # difference of the normalized features. Even rows of `myzbatch` are
    # unedited baselines; odd rows receive the door (ablation = 1).
    myzbatch = torch.zeros((zbatch.shape[0] * 2,) + zbatch.shape[1:]).to(device)
    myzbatch[0::2] = zbatch
    myzbatch[1::2] = zbatch
    ablation = torch.zeros(len(myzbatch))[:,None,None,None]
    ablation[1::2] = 1

    model.ablation['layer4'] = None
    model.replacement['layer4'] = None
    model(myzbatch)
    # Canonical door: door-unit activations of image 0 at grid cell (5, 3).
    door_feature = model.retained['layer4'][0][units][:,5,3]
    baseline_target = model.retained['layer4'].clone()
    n_images = len(myzbatch) // 2
    layer_ids = range(4, 4 + len(instrumented_layers))
    mean_results = []
    # Only interior cells are probed (matches the 6x6 `gt` labels).
    for r in range(1, 7):
        for c in range(1, 7):
            modified_target_2 = baseline_target.clone()
            modified_target_2[1::2,units,r,c] = door_feature
            model.ablation['layer4'] = ablation
            model.replacement['layer4'] = modified_target_2
            plot_tensor_images(model(myzbatch))
            model.ablation['layer4'] = None
            model.replacement['layer4'] = None

            # curves[i][k]: metric for image i at the k-th instrumented layer.
            curves = [[] for _ in range(n_images)]
            for layer in instrumented_layers:
                features = model.retained[layer]
                # Per-channel mean |activation| over all baseline rows. It does
                # not depend on the image, so compute it once per layer.
                avg_features = (features[0::2].transpose(0, 1)
                                .reshape(features.shape[1], -1)
                                .abs().mean(1)[:,None,None])
                for i in range(n_images):
                    changed_features = features[2 * i + 1] / avg_features
                    base_features = features[2 * i] / avg_features
                    diff = changed_features - base_features
                    curves[i].append(diff.abs().mean().item())

            loc_results = []
            for i, graphv in enumerate(curves):
                case = 2 * i
                plt.plot(layer_ids, graphv, label="%d p%d,%d" % (case, r, c))
                plt.legend()
                plt.show()
                loc_results.append(graphv)
                print('%d location results' % len(loc_results))
            mean_results.append(loc_results)
            print('%d results so far' % len(mean_results))

In [None]:
# Average the normalized-difference curves separately for (location, image)
# pairs whose ground truth records a visible change and for those that do not.
effect = []
noeffect = []
n_cols = gt.shape[1]
for loc, locg in enumerate(mean_results):
    # mean_results was filled row-major over the interior grid cells.
    r, c = divmod(loc, n_cols)
    for i, graphv in enumerate(locg):
        # Image 0 is skipped, presumably because it supplied `door_feature`.
        # NOTE(review): the Wasserstein analysis keeps image 0. Confirm which
        # convention is intended before comparing the two figures.
        if i == 0:
            continue
        if gt[r, c, i] != 0:
            effect.append(graphv)
        else:
            noeffect.append(graphv)
avg_effect = torch.tensor(effect).mean(0).numpy()
avg_noeffect = torch.tensor(noeffect).mean(0).numpy()

fig, ax = plt.subplots(figsize=(2.5, 2))
layer_ids = range(4, 4 + len(avg_effect))
ax.plot(layer_ids, avg_effect, label="effect")
ax.plot(layer_ids, avg_noeffect, label="not")
ax.spines['right'].set_visible(False)
ax.spines['top'].set_visible(False)
ax.legend(frameon=False)
ax.set_xlabel('layer after intervention')
ax.set_ylabel('normalized feature diff')
# Save before plt.show(): once shown, pyplot may close the figure.
fig.savefig("layernorm.pdf", bbox_inches='tight')
plt.show()

In [None]:
# Rearrange the flat [row, col, image, layer] results into [image, layer, row, col].
t_all_results = torch.tensor(all_results).reshape(6, 6, 6, 12).permute(2, 3, 0, 1)
t_mean_results = torch.tensor(mean_results).reshape(6, 6, 6, 12).permute(2, 3, 0, 1)
# Zero-pad the unprobed border so the grids cover the full 8x8 layer4 layout.
pad_all_results = torch.nn.functional.pad(t_all_results, (1, 1, 1, 1))
pad_mean_results = torch.nn.functional.pad(t_mean_results, (1, 1, 1, 1))

# Heatmaps for the last instrumented layer only (index 11, labelled 15 = RGB output).
for i in [11]:
    print('wass', i + 4)
    plot_max_heatmap(t_all_results[:, i:i + 1], shape=(256, 256))
    print('mean', i + 4)
    plot_max_heatmap(t_mean_results[:, i:i + 1], shape=(256, 256))


In [18]:
with torch.no_grad():
    # Same normalized-difference sweep, but over every cell of the layer4 grid
    # (border included) for full spatial heatmaps. Even rows of `myzbatch` are
    # unedited baselines; odd rows receive the door (ablation = 1).
    myzbatch = torch.zeros((zbatch.shape[0] * 2,) + zbatch.shape[1:]).to(device)
    myzbatch[0::2] = zbatch
    myzbatch[1::2] = zbatch
    ablation = torch.zeros(len(myzbatch))[:,None,None,None]
    ablation[1::2] = 1

    model.ablation['layer4'] = None
    model.replacement['layer4'] = None
    model(myzbatch)
    # Canonical door: door-unit activations of image 0 at grid cell (5, 3).
    door_feature = model.retained['layer4'][0][units][:,5,3]
    baseline_target = model.retained['layer4'].clone()
    grid_h, grid_w = baseline_target.shape[2:]
    n_images = len(myzbatch) // 2
    layer_ids = range(4, 4 + len(instrumented_layers))
    full_results = []
    for r in range(grid_h):
        for c in range(grid_w):
            modified_target_2 = baseline_target.clone()
            modified_target_2[1::2,units,r,c] = door_feature
            model.ablation['layer4'] = ablation
            model.replacement['layer4'] = modified_target_2
            plot_tensor_images(model(myzbatch))
            model.ablation['layer4'] = None
            model.replacement['layer4'] = None

            # curves[i][k]: metric for image i at the k-th instrumented layer.
            curves = [[] for _ in range(n_images)]
            for layer in instrumented_layers:
                features = model.retained[layer]
                # Per-channel mean |activation| over all baseline rows. It does
                # not depend on the image, so compute it once per layer.
                avg_features = (features[0::2].transpose(0, 1)
                                .reshape(features.shape[1], -1)
                                .abs().mean(1)[:,None,None])
                for i in range(n_images):
                    changed_features = features[2 * i + 1] / avg_features
                    base_features = features[2 * i] / avg_features
                    diff = changed_features - base_features
                    curves[i].append(diff.abs().mean().item())

            loc_results = []
            for i, graphv in enumerate(curves):
                case = 2 * i
                plt.plot(layer_ids, graphv, label="%d p%d,%d" % (case, r, c))
                plt.legend()
                plt.show()
                loc_results.append(graphv)
                print('%d location results' % len(loc_results))
            full_results.append(loc_results)
            # Progress of *this* sweep.
            print('%d results so far' % len(full_results))

In [19]:
# [row, col, image, layer] -> [image, layer, row, col] over the full 8x8 grid.
t_full_results = torch.tensor(full_results).reshape(8, 8, 6, 12).permute(2, 3, 0, 1)

# One heatmap per instrumented layer (labelled 4..15; 15 is the RGB output).
for layer_idx in range(t_full_results.shape[1]):
    print('full', layer_idx + 4)
    plot_max_heatmap(t_full_results[:, layer_idx:layer_idx + 1], shape=(256, 256))
