In [None]:
import os
# The notebook is expected to run from the root folder: when the kernel was
# started inside 'analysis', step up one directory.
if os.path.split(os.getcwd())[1] == 'analysis':
    os.chdir(os.pardir)
    
from time import time

import matplotlib
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import torch
from torch.utils.data import DataLoader

from models.vgg_fully_convolutional import *

from data.dataset_factory import cifar_transform_test
from data.dataset import CIFARDataset
from eval import GenericEvaluator
from utils.net_utils import load_net, NetWithResult
from utils.run_arg_parser import parse_net_args_inner, NET_LOAD_TYPE
from utils.forward_hooks import OutputHook
from models.net_auxiliary_extension import NetWithAuxiliaryOutputs
from models.gated_grouped_conv import create_target_mask

In [None]:
# net = resnet18(True)
# Checkpoint to analyse: architecture name, weights file and loader flavour.
load_type = NET_LOAD_TYPE.Cifar10
net_name = 'VGG16BnDropV2_GatedHardSimple'
weight_path = '/home/eli/Eli/Training/Cifar10/VGG16DropV2/VGG16BnDropV2_trained_gating_hard_simple_l1_static_0_000005_no_decay_sum_init_4_T_inverse_mult_10000_optimizer_softmax_schedule_200/net_backup.pt'

# Build the network arguments first, then restore the weights into `net`.
# NOTE(review): the positional 10 / False are presumably the class count and a
# pretrained-style flag - confirm against parse_net_args_inner.
net_args = parse_net_args_inner(load_type, net_name, 10, False)
net = load_net(weight_path, net_args, load_type='inner')

In [None]:
# Dataset read from a local folder with the test-time transform.
# NOTE(review): the positional False / 10 are presumably the train flag and the
# number of classes - confirm against CIFARDataset.
cifar_root = '/home/eli/Data/Cifar10/cifar-10-batches-py'
cifar_dataset = CIFARDataset(cifar_root, cifar_transform_test, False, 10)

# Shuffled mini-batches of 32 samples, loaded by 16 worker processes.
cifar_dataloader = DataLoader(
    cifar_dataset,
    batch_size=32,
    shuffle=True,
    num_workers=16,
)

In [None]:
# Every ConvBlock of the network, in net.modules() order.
conv_blocks = list(filter(lambda module: isinstance(module, ConvBlock), net.modules()))

# For each gated block, hook the `gumble` sub-module of its gate.
gate_outputs = [block.gate.gumble for block in conv_blocks if block.gated]
hooks = list(map(OutputHook, gate_outputs))

# Calling wrapped_net yields the network result together with the hooked outputs.
wrapped_net = NetWithAuxiliaryOutputs(net, hooks)

In [None]:
# Record, per hooked gate and per class label, the gate output of every sample
# drawn from (at most) the first `num_batches` batches of the loader.
num_batches = 20
num_classes = 10  # class labels are used directly as list indices below
current_batch = 0

# features_usage[gate][label] -> list of per-sample gate outputs.
# Each inner list is a distinct object (no shared references).
features_usage = [[[] for _ in range(num_classes)] for _ in hooks]

# Inference only: no autograd graph is needed for collecting gate outputs.
with torch.no_grad():
    for _, data, targets in cifar_dataloader:
        if current_batch >= num_batches:
            break
        _, hooks_out = wrapped_net(data)
        for i, hook_out in enumerate(hooks_out):
            # Convert the whole batch once; rows are indexed per sample below.
            gate_values = hook_out.detach().cpu().numpy()
            for j in range(data.size(0)):
                features_usage[i][int(targets[j])].append(gate_values[j])
        current_batch += 1

# Stack each per-class list into an array with one row per sample.
# Iterating over features_usage itself (rather than the loop-leaked hooks_out)
# keeps this step valid even when no batch was processed.
# A class that received no samples ends up as an empty 1-D array.
for layer_usage in features_usage:
    for label in range(num_classes):
        layer_usage[label] = np.array(layer_usage[label])

In [None]:
# Scatter, for every hooked gate, the mean gate output of each channel.
class_split = False  # True: one series per class label; False: all classes pooled

for layer_idx, layer_usage in enumerate(features_usage):
    # Classes without samples are stored as empty 1-D arrays; drop them so that
    # shape[1] and np.concatenate below are well defined.
    per_class = [(label, usage) for label, usage in enumerate(layer_usage) if len(usage) > 0]
    if not per_class:
        continue

    if class_split:
        plt.figure(layer_idx, figsize=(20, 12))
        for label, usage in per_class:
            channels = np.arange(1, usage.shape[1] + 1)
            plt.scatter(channels, usage.mean(0), label='class {}'.format(label))
        plt.legend()
    else:
        combined = np.concatenate([usage for _, usage in per_class], 0)
        plt.figure(layer_idx, figsize=(10, 6))
        plt.scatter(np.arange(1, combined.shape[1] + 1), combined.mean(0))
#         print(((((np.absolute(combined.mean(0) - 0.5) - 0.45) > 0).sum())/combined.shape[1]))

    # Make each figure readable on its own.
    plt.title('Gate {}: mean output per channel'.format(layer_idx))
    plt.xlabel('channel')
    plt.ylabel('mean gate output')
        

In [None]:
def get_channels_to_prune(features_usage, threshold=0.5, blocks=None):
    """Select, per layer, the channels whose mean recorded gate output is low.

    Args:
        features_usage: one entry per hooked gate; each entry is a sequence of
            per-class arrays holding one row of gate outputs per sample.
            Classes without samples (empty arrays) are ignored.
        threshold: a channel is selected when its mean gate output over all
            samples is strictly below ``threshold``. When the matching block
            has a non-None ``dropout`` attribute the threshold is scaled by
            ``1 - dropout.p``.
        blocks: sequence indexed like ``features_usage`` whose items may carry
            a ``dropout`` attribute. Defaults to the notebook-level
            ``conv_blocks`` so existing two-argument calls behave as before.

    Returns:
        A list of lists of channel indices. The first entry is always empty
        (the first layer is not gated); entry ``i + 1`` is computed from
        ``features_usage[i]``.
    """
    if blocks is None:
        blocks = conv_blocks  # notebook-level list of ConvBlock modules

    features_to_prune = [[]]  # first layer not gated
    for i, layer_usage in enumerate(features_usage):
        # Empty per-class arrays are 1-D and cannot be concatenated with the
        # 2-D (samples x channels) ones, so they are filtered out first.
        samples = [np.asarray(class_usage) for class_usage in layer_usage if len(class_usage) > 0]
        if not samples:
            features_to_prune.append([])
            continue
        mean_usage = np.concatenate(samples).mean(0)

        # NOTE(review): the dropout is looked up on blocks[i] while the result
        # is stored at index i + 1 - kept as in the original; confirm the
        # intended block.
        multiplier = 1
        dropout = getattr(blocks[i], 'dropout', None)
        if dropout is not None:
            multiplier = 1 - dropout.p

        features_to_prune.append(np.where(mean_usage < threshold * multiplier)[0].tolist())
    return features_to_prune

In [None]:
# Channels whose mean gate output falls below 0.7 (dropout-scaled) per block.
channels_to_prune = get_channels_to_prune(features_usage, 0.7)

# One row per block: input channels, pruned count, pruned share, per-channel cost.
print('ch\tpruned\tpercent\tcost')
for idx, pruned in enumerate(channels_to_prune):
    block = conv_blocks[idx]
    n_in = block.conv.in_channels
    n_pruned = len(pruned)
    percent = '{:5.2f}'.format(100 * n_pruned / n_in)
    print(n_in, '\t', n_pruned, '\t', percent, '\t', block.in_channel_flop_cost_prune)

In [None]:
# Reference result: evaluate the network exactly as it was loaded.
loaded_model = NetWithResult(net)
eval1 = GenericEvaluator(loaded_model, cifar_dataloader).eval()

In [None]:
# Network that the pruning cells below evaluate and modify.
# NOTE(review): presumably a variant of `net` with the gating removed; whether it
# shares weights with `net` or is a copy is not visible here - confirm in the model class.
net_no_gates = net.get_net_without_gating()

In [None]:
def create_pruned_net(net, channels_to_prune):
    """Nullify the selected channels of `net` in place and print the savings.

    For every ConvBlock of ``net`` (in ``net.modules()`` order) the listed
    input channels are nullified, together with the same-index output channels
    of the preceding ConvBlock. The first block has no predecessor, so only its
    own input channels are nullified.

    Args:
        net: model exposing ``modules()`` and ``total_pixel_flop_cost``.
        channels_to_prune: one list of input-channel indices per ConvBlock.
            Entries beyond the number of blocks are ignored.

    Returns:
        None. The network is modified in place and two summary lines are
        printed: the per-pixel FLOPs saved and the compression ratio, the
        latter counted in (input channel, output channel) pairs.

    Raises:
        ValueError: if ``channels_to_prune`` has fewer entries than there are
            ConvBlocks (checked before anything is modified).
    """
    conv_blocks = [module for module in net.modules() if isinstance(module, ConvBlock)]
    if len(channels_to_prune) < len(conv_blocks):
        raise ValueError('channels_to_prune has {} entries but the network has {} conv blocks'.format(
            len(channels_to_prune), len(conv_blocks)))

    flops_cost = 0
    total_memory = 0
    saved_memory = 0
    prev_conv_module = None
    for conv_module, pruned in zip(conv_blocks, channels_to_prune):
        in_channels = conv_module.conv.in_channels
        out_channels = conv_module.conv.out_channels
        total_memory += in_channels * out_channels
        saved_memory += len(pruned) * out_channels
        flops_cost += conv_module.in_channel_flop_cost * (in_channels - len(pruned))
        for ind in pruned:
            conv_module.nullify_input_channel(ind)
            # The first block has no producer whose output could be nullified.
            if prev_conv_module is not None:
                prev_conv_module.nullify_output_channel(ind)
        prev_conv_module = conv_module

    total_flops = net.total_pixel_flop_cost
    flops_saved = total_flops - flops_cost
    # Guard the summary against degenerate totals instead of dividing by zero.
    saved_percent = 100 * flops_saved / total_flops if total_flops else 0.0
    print('Saved {} flops per pixel from a total of {} ({:5.2f}%)'.format(flops_saved, total_flops, saved_percent))
    remaining_memory = total_memory - saved_memory
    compression_ratio = total_memory / remaining_memory if remaining_memory else float('inf')
    print('Compression ratio: {:5.2f}'.format(compression_ratio))

In [None]:
# Evaluate the gate-free network before any channel is nullified.
ungated_model = NetWithResult(net_no_gates)
eval2 = GenericEvaluator(ungated_model, cifar_dataloader).eval()

In [None]:
# Nullify the channels listed in channels_to_prune on net_no_gates (modified in
# place) and print the FLOP-saving / compression summary.
create_pruned_net(net_no_gates, channels_to_prune)

In [None]:
# Evaluate the same network object again, now with the selected channels nullified.
pruned_model = NetWithResult(net_no_gates)
eval3 = GenericEvaluator(pruned_model, cifar_dataloader).eval()