In [2]:
from copy import deepcopy
import os
# When the notebook is started from the analysis/ folder, step up one directory —
# presumably so the project-local packages imported below (data, models, utils, eval,
# external_models) resolve relative to the repository root.
_cwd_name = os.path.basename(os.getcwd())
if _cwd_name == 'analysis':
    os.chdir('..')
    
from time import time

import matplotlib
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import torch
from torch.utils.data import DataLoader
from torchvision.models.resnet import resnet50
from models.vgg_fully_convolutional import *

from data.dataset_factory import cifar_transform_test
from data.dataset import CIFARDataset
from eval import GenericEvaluator
from utils.net_utils import load_net, NetWithResult
from utils.run_arg_parser import parse_net_args_inner, NET_LOAD_TYPE
from utils.forward_hooks import OutputHook
from models.custom_resnet import custom_resnet_18, custom_resnet_34, custom_resnet_50, filter_mapping_from_default_resnet
from models.net_auxiliary_extension import NetWithAuxiliaryOutputs
from models.gated_grouped_conv import create_target_mask
from models.gates_mapper import GatesModulesMapper, NaiveSequentialGatesModulesMapper, ResNetGatesModulesMapper
from models.gate_wrapped_module import compute_flop_cost_change, compute_flop_cost, create_conv_channels_dict, create_edge_to_channels_map
from models.wrapped_models import ResNet18_gating_hard_simple_l1_combined_masked_adaptive_static, VGG16BnDropV2_gating_hard_simple_l1_static_naive_combined_masked_adaptive_generic
from models.wrapped_gated_models import ResNet18_gating, ResNet34_gating, ResNet50_gating
from models.gated_prunning import prune_net_with_hooks, prune_net_with_hooks

from external_models.dcp.pruned_resnet import PrunedResnet50

In [3]:
# Candidate gated checkpoints: (net builder name, checkpoint path).
# Select the one to analyse explicitly instead of overwriting net_name / weight_path
# several times in a row (only the last assignment would take effect).
NET_CANDIDATES = {
    'vgg16': (
        'VGG16_gating_hard_simple_l1_static_naive_combined_masked_adaptive_generic_rand',
        '/home/eli/Eli/Training/Cifar10/VGG16_memory/VGG16_trained_gating_hard_simple_l1_static_naive_combined_masked_adaptive_generic_0_20_from_0_10_from_c_0_05_mult_500000_init_0_99_schedule_80_memory_rand/net_backup.pt',
    ),
    'resnet18': (
        'ResNet18_gating_hard_simple_l1_combined_masked_adaptive_static',
        '/home/eli/Eli/Training/Imagenet/resnet18/resnet18_pre_l1_static_init_0_995_m_50000_w_4_0_adaptive_from_w_2_0_w_1_0/net_e_54',
    ),
    'resnet50': (
        'ResNet50_gating',
        '/media/eli/0FBF-BADB/resnet50/resnet50_pre_0_995_w_0_25_gm_0_2_memory_with_mult_increasing/net_e_23',
    ),
}
SELECTED_NET = 'resnet50'  # the pair the original cell ended up using
net_name, weight_path = NET_CANDIDATES[SELECTED_NET]

no_last_conv = False

if 'resnet' in net_name.lower():
    # The ResNet gating builders are called with 1000 classes and return (net, criterion);
    # only the net is needed here.
    net, _criterion = globals()[net_name](1000)
    # map_location='cpu' lets the checkpoint load without a GPU; load_state_dict copies the
    # values into the net's existing parameters, so the net's own device is unaffected.
    checkpoint = torch.load(weight_path, map_location='cpu')
    # Checkpoint keys carry a 'module.' prefix (save_net below adds the same prefix).
    # Strip it only where present instead of blindly chopping the first 7 characters.
    prefix = 'module.'
    state_dict = {(k[len(prefix):] if k.startswith(prefix) else k): v
                  for k, v in checkpoint['state_dict'].items()}
    net.load_state_dict(state_dict)
    mapper = ResNetGatesModulesMapper(net.net, no_last_conv, map_for_replacement=True)
else:
    load_type = NET_LOAD_TYPE.WithCriterion
    # 10 classes: the VGG checkpoint path lives under Cifar10.
    net = load_net(weight_path, parse_net_args_inner(load_type, net_name, 10, False))
    mapper = NaiveSequentialGatesModulesMapper(net, no_last_conv)

    

In [8]:
# Load the pruned ResNet-50 from external_models.dcp (the file name suggests a 0.5 pruning
# ratio) and print the 'flop' and 'memory' costs reported by compute_flop_cost.
# NOTE(review): this cell rebinds net_name, weight_path, net and mapper, so cells executed
# after it operate on this pruned net rather than on the gated net loaded earlier.
net_name = 'PrunedResnet50'
weight_path = '/home/eli/Downloads/resnet50_pruned0.5.pth'
net = PrunedResnet50()
net.load_state_dict(torch.load(weight_path))
mapper = ResNetGatesModulesMapper(net, False, map_for_replacement=True)
for cost_name in ('flop', 'memory'):
    cost = compute_flop_cost(net, mapper, cost_name + "_factor")
    print(cost_name + ' cost', cost)

flop cost 33968.0
memory cost 10287296


In [4]:
# For each cost metric, report original_cost / new_cost using the
# (original_cost, new_cost) pair returned by compute_flop_cost_change.
for factor in ['flop', 'memory']:
    original_cost, new_cost = compute_flop_cost_change(net, mapper, factor_type=factor + "_factor")
    # Avoid ZeroDivisionError if the new cost is reported as 0.
    ratio = original_cost / new_cost if new_cost else float('inf')
    print("{} compression ratio: {:.2f} original cost {:.2f} new cost {:.2f}".format(factor, ratio, original_cost, new_cost))

flop comperssion ratio: 1.22 original cost 76848.00 new cost 63065.72
memory comperssion ratio: 1.20 original cost 23454912.00 new cost 19557007.00


In [None]:
def time_net(net, batch_size=16, input_size=224, n_iters=100, device=None):
    """Time forward passes of ``net`` on a random image batch and print summary stats.

    Runs ``n_iters`` forward passes (in eval mode, without autograd) on a uniform-random
    ``(batch_size, 3, input_size, input_size)`` batch, discards the fastest and slowest
    quarter of the runs and prints mean / std / min / max in milliseconds of the rest.

    Args:
        net: module to time.
        batch_size: batch dimension of the random input (default 16).
        input_size: spatial height/width of the random input (default 224).
        n_iters: number of timed forward passes (default 100); must be >= 1.
        device: device for the input batch. Defaults to the device of the net's first
            parameter, or CUDA when the net has no parameters.

    Returns:
        numpy array of the retained per-iteration times, in milliseconds, sorted.

    Raises:
        ValueError: if ``n_iters`` < 1.
    """
    from time import perf_counter

    if n_iters < 1:
        raise ValueError('n_iters must be >= 1')
    if device is None:
        first_param = next(net.parameters(), None)
        device = first_param.device if first_param is not None else torch.device('cuda')
    device = torch.device(device)
    # CUDA kernels are launched asynchronously; without synchronizing, the wall-clock
    # interval would only measure kernel launch time, not the forward pass itself.
    synchronize = device.type == 'cuda'
    sample = torch.rand((batch_size, 3, input_size, input_size), device=device)

    was_training = net.training
    net.eval()
    times = []
    try:
        with torch.no_grad():
            for _ in range(n_iters):
                if synchronize:
                    torch.cuda.synchronize(device)
                start = perf_counter()
                net(sample)
                if synchronize:
                    torch.cuda.synchronize(device)
                times.append(perf_counter() - start)
    finally:
        # Restore the caller's mode instead of unconditionally switching to train mode.
        net.train(was_training)

    times = np.sort(np.asarray(times) * 1000.0)
    # Keep the middle half (the 25-75 percentile range for the default 100 runs).
    quarter = n_iters // 4
    times = times[quarter:n_iters - quarter]

    print('25-75 mean, std, min, max (ms) \t{:.2f}\t{:.2f}\t{:.2f}\t{:.2f}'.format(times.mean(), times.std(), times.min(), times.max()))
    return times

In [3]:
# Pick the custom ResNet builder whose depth appears in net_name. Fail fast with a clear
# message: with no match, custom_net_func would stay None and fail later as an opaque
# "'NoneType' object is not callable"; with several matches the last one would win silently.
CUSTOM_RESNET_BUILDERS = {'18': custom_resnet_18, '34': custom_resnet_34, '50': custom_resnet_50}
matching_depths = [depth for depth in CUSTOM_RESNET_BUILDERS if depth in net_name]
if len(matching_depths) != 1:
    raise ValueError('Expected exactly one ResNet depth (18/34/50) in net_name {!r}, found {}'.format(net_name, matching_depths))
custom_net_func = CUSTOM_RESNET_BUILDERS[matching_depths[0]]

# create_conv_channels_dict returns per-conv channel counts plus a state dict matching them;
# build the custom ResNet with those counts and load the weights into it.
channels_config, state_dict = create_conv_channels_dict(net, mapper)
custom_net = custom_net_func(channels_config)
custom_net.load_state_dict(state_dict)

IncompatibleKeys(missing_keys=[], unexpected_keys=[])

In [7]:
# Display the per-layer channel configuration returned by create_conv_channels_dict.
channels_config

{'conv1': 44,
 'layer1': {'0': {'conv1': 41, 'conv2': 28, 'downsample': 99, 'conv3': 99},
  '1': {'conv1': 28, 'conv2': 20, 'conv3': 99},
  '2': {'conv1': 25, 'conv2': 38, 'conv3': 99}},
 'layer2': {'0': {'conv1': 42, 'conv2': 112, 'downsample': 297, 'conv3': 297},
  '1': {'conv1': 39, 'conv2': 61, 'conv3': 297},
  '2': {'conv1': 71, 'conv2': 66, 'conv3': 297},
  '3': {'conv1': 68, 'conv2': 80, 'conv3': 297}},
 'layer3': {'0': {'conv1': 232, 'conv2': 225, 'downsample': 862, 'conv3': 862},
  '1': {'conv1': 177, 'conv2': 199, 'conv3': 862},
  '2': {'conv1': 142, 'conv2': 204, 'conv3': 862},
  '3': {'conv1': 193, 'conv2': 141, 'conv3': 862},
  '4': {'conv1': 194, 'conv2': 177, 'conv3': 862},
  '5': {'conv1': 233, 'conv2': 195, 'conv3': 862}},
 'layer4': {'0': {'conv1': 493,
   'conv2': 505,
   'downsample': 2048,
   'conv3': 2048},
  '1': {'conv1': 506, 'conv2': 512, 'conv3': 2048},
  '2': {'conv1': 512, 'conv2': 512, 'conv3': 2048}}}

In [4]:
def save_net(net, file_path, custom_channels_config=None, template_path=None):
    """Save ``net``'s weights as a checkpoint that reuses the metadata of an existing one.

    The checkpoint at ``template_path`` is loaded; its ``'optimizer'`` entry (if any) is
    dropped, its ``'state_dict'`` is replaced by ``net.state_dict()`` with every key
    prefixed by ``'module.'``, and ``custom_channels_config`` is stored under
    ``'channels_config'`` when given. The result is written to ``file_path``.

    Args:
        net: module whose weights are saved.
        file_path: destination path passed to ``torch.save``.
        custom_channels_config: optional channel configuration stored with the weights.
        template_path: checkpoint whose remaining entries are copied. Defaults to the
            notebook-global ``weight_path`` (the original behavior).
    """
    if template_path is None:
        # Backward-compatible fallback to the global set by the loading cell.
        template_path = weight_path
    checkpoint = torch.load(template_path)
    # pop() instead of del: tolerate templates saved without optimizer state.
    checkpoint.pop('optimizer', None)
    # 'module.' prefix matches the key format of the checkpoints this notebook loads.
    checkpoint['state_dict'] = {'module.' + k: v for k, v in net.state_dict().items()}
    if custom_channels_config is not None:
        checkpoint['channels_config'] = custom_channels_config
    torch.save(checkpoint, file_path)

# Write the custom ResNet (with its channel configuration) to disk.
# NOTE(review): save_net copies every non-weight entry from whatever the global weight_path
# points to at call time; confirm that is the checkpoint this net was converted from,
# since this destination names a different run/epoch (net_e_80).
file_path = '/home/eli/Eli/Training/Imagenet/resnet50/resnet50_pre_0_995_w_0_25_gm_0_2_w_0_5/net_e_80_custom_resnet'
save_net(net=custom_net, file_path=file_path, custom_channels_config=channels_config)

## Net Timing I: custom ResNet built from `channels_config`

In [None]:
# Move the custom ResNet to the GPU and time it over ten separate runs of time_net.
custom_net = custom_net.cuda()
for _run in range(10):
    time_net(custom_net)

## Net Timing II: gated net pruned with `prune_net_with_hooks` (cell currently disabled)

In [None]:
# pruned_net = deepcopy(net)
# prune_net_with_hooks(pruned_net, mapper, False)
# resnet = pruned_net.net.cuda()

# for i in range(10):
#     time_net(resnet)

## Original ResNet timing: stock torchvision ResNet-50 baseline

In [4]:
# Baseline: stock torchvision ResNet-50 with pretrained weights, timed over ten runs.
# resnet50 is already imported in the first cell. This cell needs the time_net cell to have
# been executed first (the saved output is a NameError from running it out of order).
resnet_original = resnet50(pretrained=True).cuda()
for _run in range(10):
    time_net(resnet_original)

NameError: name 'time_net' is not defined

## Plotting gating probabilities of the forward hooks

In [None]:
def get_gating_plots(hooks):
    """Scatter-plot each gate's channel probabilities and print its pruning statistics.

    For every hook, one new figure shows ``gating_module.gating_probabilities`` against the
    1-based channel index. A line is printed with the number of active channels
    (``gating_module.active_channels()``), the total channel count
    (``len(gating_module.gating_weights)``) and the resulting pruned percentage.

    Args:
        hooks: iterable of forward hooks, each exposing a ``gating_module`` attribute.
    """
    for gate_idx, hook in enumerate(hooks):
        gate = hook.gating_module
        # .cpu() so this also works for nets on the GPU (.numpy() rejects CUDA tensors).
        probs = gate.gating_probabilities.detach().cpu().numpy()
        n_total = len(gate.gating_weights)
        n_active = int(gate.active_channels())
        # A fresh figure per gate: plt.figure(i) would draw onto an existing figure with
        # the same number when the function is called more than once.
        _, ax = plt.subplots(figsize=(10, 6))
        ax.scatter(np.arange(1, len(probs) + 1), probs)
        ax.set(title='gate {}'.format(gate_idx), xlabel='channel', ylabel='gating probability')
        prune_pct = 100 * (1 - n_active / n_total) if n_total else 0.0
        # The first count is the number of ACTIVE channels (it was previously labelled "prune").
        print('gate {:2d}  active {:4d} / {:4d} prune percentage {:6.2f}%'.format(gate_idx, n_active, n_total, prune_pct))
get_gating_plots(net.forward_hooks)