In [1]:
import torch
import utils
import os
import numpy as np
from collections import OrderedDict
from vgg_face_dag import vgg_face_dag

# Layers of VGG-Face that register_my_hook should NOT hook: the ReLU
# activations, the dropouts and the final classifier (fc8). The printed
# output of this cell lists the layers that do get a hook (conv*, pool*,
# fc6, fc7).
SKIP_LAYERS = [
    'relu1_1',
    'relu1_2',
    'relu2_1',
    'relu2_2',
    'relu3_1',
    'relu3_2',
    'relu3_3',
    'relu4_1',
    'relu4_2',
    'relu4_3',
    'relu5_1',
    'relu5_2',
    'relu5_3',
    'relu6',
    'dropout6',
    'relu7',
    'dropout7',
    'fc8',
]

# Prefer the GPU, but fall back to CPU instead of crashing on machines
# without CUDA.
# NOTE(review): if vgg_face_dag loads the weights with torch.load and no
# map_location, a CPU-only run may still need that fixed inside the loader.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
vgg_weight = './vgg_face_dag.pth'
vgg_net = vgg_face_dag(vgg_weight)
vgg_net.to(device)
vgg_net.eval()  # inference mode: dropout/batch-norm behave deterministically
vgg_net.register_my_hook(skip_layers=SKIP_LAYERS)

register hook for conv1_1
register hook for conv1_2
register hook for pool1
register hook for conv2_1
register hook for conv2_2
register hook for pool2
register hook for conv3_1
register hook for conv3_2
register hook for conv3_3
register hook for pool3
register hook for conv4_1
register hook for conv4_2
register hook for conv4_3
register hook for pool4
register hook for conv5_1
register hook for conv5_2
register hook for conv5_3
register hook for pool5
register hook for fc6
register hook for fc7


In [2]:
# Facial attributes to analyse. Each has two mutated-image directories under
# root_path, named '<attribute>_<action>' for every action below.
attributes = 'leye reye nose mouth'.split()
actions = 'substitution preservation'.split()

# Data locations, relative to this notebook.
root_path = '../../data/attribute_mutated/'
base_img = '../../data/base_img.jpg'

# Reference (unmutated) face, preprocessed by utils.get_data and placed on
# the same device as the model.
base_img_tensor = (
    utils.get_data(base_img)
    .to(device)
)

In [3]:
def _channel_diff(layout, org_layout):
    """Absolute activation change per neuron/channel (leading axis of the layout)."""
    diff = np.abs(layout - org_layout)
    if len(org_layout.shape) == 1:
        # fc layers: already one value per neuron
        return diff
    # conv/pool layers: total change over every remaining (spatial) axis
    return np.sum(diff, axis=tuple(range(1, len(diff.shape))))


def _filter_neurons(neuron_set_lists, org_layouts, layouts, above_median):
    """Intersect each layer's candidate set with neurons changing more/less than the layer median."""
    assert len(org_layouts) == len(layouts)
    for layer_idx, org_layout in org_layouts.items():
        layout = layouts[layer_idx]
        if layout.shape != org_layout.shape:
            print(f'{layer_idx} has different shape, {layout.shape} != {org_layout.shape}')
        diff_layout = _channel_diff(layout, org_layout)
        # Computed once per layer (the original recomputed it for every neuron).
        median = np.median(diff_layout)
        if above_median:
            keep = {i for i, v in enumerate(diff_layout) if v > median}
        else:
            keep = {i for i, v in enumerate(diff_layout) if v < median}
        neuron_set_lists[layer_idx].intersection_update(keep)


def extract_witness(model, attri, base_img_tensor, out_dir='ami_data'):
    """Find the witness neurons of every hooked layer for one facial attribute.

    Each neuron/channel starts as a candidate and survives only if, for every
    image name in ``<root_path>/<attri>_<actions[0]>``:

    * the substitution image changes its activation by MORE than the layer's
      median change, and
    * the same-named image in ``<attri>_<actions[1]>`` (preservation) changes
      its activation by LESS than the layer's median change.

    Both changes are measured against the activations of ``base_img_tensor``.

    Relies on the notebook globals ``root_path``, ``actions``, ``device`` and
    ``utils``.

    Args:
        model: the hooked network. Calling it returns a mapping from layer
            name to activation. Presumably these are numpy arrays with the
            channel axis first and no batch axis (TODO confirm against
            register_my_hook).
        attri: attribute name such as ``'leye'``. It selects the mutated-image
            directories and names the report file.
        base_img_tensor: the preprocessed reference image, already on ``device``.
        out_dir: directory for ``<attri>_neurons.txt``. It is created if
            missing and defaults to ``'ami_data'``.

    Returns:
        OrderedDict mapping each layer name to its set of witness indices.
        This is the same data written to the report file.
    """
    sub_dir = os.path.join(root_path, f'{attri}_{actions[0]}')
    pre_dir = os.path.join(root_path, f'{attri}_{actions[1]}')

    # Inference only: no autograd graph is needed for the activations.
    with torch.no_grad():
        org_layouts = model(base_img_tensor)
        neuron_set_lists = OrderedDict(
            (name, set(range(layout.shape[0]))) for name, layout in org_layouts.items()
        )

        for img_name in os.listdir(sub_dir):
            attri_sub_img_tensor = utils.get_data(os.path.join(sub_dir, img_name)).to(device)
            attri_pre_img_tensor = utils.get_data(os.path.join(pre_dir, img_name)).to(device)

            # Substituting the attribute should strongly move a witness neuron ...
            _filter_neurons(neuron_set_lists, org_layouts,
                            model(attri_sub_img_tensor), above_median=True)
            # ... while preserving it should barely move the same neuron.
            _filter_neurons(neuron_set_lists, org_layouts,
                            model(attri_pre_img_tensor), above_median=False)

    res = ['%s->%s' % (name, ','.join(map(str, sorted(neuron_set))))
           for name, neuron_set in neuron_set_lists.items()]

    print('Extracted witnesses:\n\t%s' % ('\n\t'.join(res)))

    os.makedirs(out_dir, exist_ok=True)
    with open(os.path.join(out_dir, f'{attri}_neurons.txt'), 'w') as out_file:
        out_file.write('\n'.join(res))

    return neuron_set_lists

In [4]:
# Run witness extraction for each attribute. Every call prints its layers and
# writes a '<attribute>_neurons.txt' report.
for attri in attributes:
    print(f'Extracting witnesses for {attri}')
    extract_witness(model=vgg_net, attri=attri, base_img_tensor=base_img_tensor)

Extracting witnesses for leye
	conv1_1->
	conv1_2->
	pool1->
	conv2_1->
	conv2_2->28,58,60,102
	pool2->28,102
	conv3_1->45,185
	conv3_2->126
	conv3_3->48,52,54,93,126,135,185,207,239,246
	pool3->173,181,192
	conv4_1->95,276,323,327,412
	conv4_2->15,328,362
	conv4_3->235,246,251,272,410,448,480
	pool4->25,46,75,148,171,187,251,329,355,403,471
	conv5_1->251,497,504
	conv5_2->
	conv5_3->79
	pool5->145,387
	fc6->
	fc7->2384
Extracting witnesses for reye
	conv1_1->
	conv1_2->
	pool1->
	conv2_1->
	conv2_2->28,58,60,102
	pool2->28,102
	conv3_1->83,169,184,185,199
	conv3_2->
	conv3_3->4,48,101,135,185,207
	pool3->173,182
	conv4_1->
	conv4_2->184,448
	conv4_3->215,235,355
	pool4->0,15,17,91,171,238,258,274,329,446,475
	conv5_1->228,409
	conv5_2->104
	conv5_3->
	pool5->
	fc6->
	fc7->
Extracting witnesses for nose
	conv1_1->
	conv1_2->
	pool1->
	conv2_1->
	conv2_2->28,58,60,102
	pool2->28,102
	conv3_1->57,136
	conv3_2->
	conv3_3->48,101,135,174,185,200
	pool3->59
	conv4_1->327
	conv4_2->280,391,4