In [1]:
import os
from time import time, sleep

from matplotlib import pyplot as plt
import numpy as np
import torch
from torchvision.models.resnet import resnet50

if os.path.basename(os.getcwd()) == 'analysis':
    os.chdir('..')
from models.gates_mapper import GatesModulesMapper, NaiveSequentialGatesModulesMapper, ResNetGatesModulesMapper
from models.wrapped_gated_models import ResNet18_gating, ResNet34_gating, ResNet50_gating, custom_resnet_from_gated_net
from models.gate_wrapped_module import compute_flop_cost_change, create_conv_channels_dict, create_edge_to_channels_map, dot_string_to_tree_dict
from models.custom_resnet import filter_mapping_from_default_resnet, custom_resnet_50


In [2]:
def time_net(net, run_times=1, measurements=100, lower_limit=0, upper_limit=500000, sleep_seconds=0):
    net.eval()
    total_times=[]
    for _ in range(run_times):
        sample = torch.rand((8,3,224,224)).cuda(1)
        times = []
        for i in range(measurements):
            t = time()
            res = net(sample)
            aa = res[-1,-1].item()
            times.append(time()-t)
        times = np.array(times)
        times = 1000 * times
#         times.sort()
#         times = times[lower_limit:upper_limit]
        total_times.append(times)
#         sleep(1.0)
    net.train()

    total_times=np.array(total_times)
    sleep(sleep_seconds)
    return(total_times.min())
    
#     remove_range=[]
#     std = np.median(total_times.std(1))
#     for i in range(len(total_times)):
#         if total_times[i].mean() - np.median(total_times) > std:
#             remove_range.append(i)
    
#     total_times = np.delete(total_times, remove_range, axis=0)

#     print('25-75 mean, std, min, max (ms) \t{:.2f}\t{:.2f}\t{:.2f}\t{:.2f}'.format(times.mean(), times.std(), times.min(), times.max()))

#     return total_times

#     return times

In [3]:
resnet_original = resnet50(True)
resnet_original = resnet_original.cuda(1)
print(time_net(resnet_original))
# for i in range(10):
#     times=time_net(resnet_original)
#     print(np.array(times).min())
#     sleep(20)
# for i in range(10):
#     times.append(time_net(resnet_original))
#     sleep(1.0)
    

14.162063598632812


In [6]:
# Resnet50
net_name = 'ResNet50_gating'
weight_path = '/media/victoria/d/Training/Eli/resnet50_pre_0_995_w_0_25_gm_0_2_w_0_5_w_1_w_2/net_e_140'

no_last_conv=False

net, aaa, _ = globals()[net_name](1000)
state_dict = torch.load(weight_path)['state_dict']
state_dict = {k[7:]: v for k,v in state_dict.items()}
net.load_state_dict(state_dict)
mapper = ResNetGatesModulesMapper(net.net, no_last_conv, map_for_replacement=True)
custom_net = custom_resnet_from_gated_net(net_name, weight_path).cuda(1)




In [6]:
for factor in ['flop', 'memory']:
    original_cost, new_cost = compute_flop_cost_change(net, mapper, factor_type= factor + "_factor")
    print("{} comperssion ratio: {:.2f} original cost {:.2f} new cost {:.2f}".format(factor, original_cost/new_cost, original_cost, new_cost))

flop comperssion ratio: 1.46 original cost 76848.00 new cost 52639.88
memory comperssion ratio: 1.11 original cost 23454912.00 new cost 21189636.00


In [11]:
for factor in ['flop', 'memory']:
    original_cost, new_cost = compute_flop_cost_change(net, mapper, factor_type= factor + "_factor")
    print("{} comperssion ratio: {:.2f} original cost {:.2f} new cost {:.2f}".format(factor, original_cost/new_cost, original_cost, new_cost))

flop comperssion ratio: 1.37 original cost 76848.00 new cost 56277.48
memory comperssion ratio: 1.06 original cost 23454912.00 new cost 22124172.00


In [4]:
for factor in ['flop', 'memory']:
    original_cost, new_cost = compute_flop_cost_change(net, mapper, factor_type= factor + "_factor")
    print("{} comperssion ratio: {:.2f} original cost {:.2f} new cost {:.2f}".format(factor, original_cost/new_cost, original_cost, new_cost))

flop comperssion ratio: 2.00 original cost 76848.00 new cost 38474.25
memory comperssion ratio: 1.26 original cost 23454912.00 new cost 18619922.00


In [9]:
for factor in ['flop', 'memory']:
    original_cost, new_cost = compute_flop_cost_change(net, mapper, factor_type= factor + "_factor")
    print("{} comperssion ratio: {:.2f} original cost {:.2f} new cost {:.2f}".format(factor, original_cost/new_cost, original_cost, new_cost))

flop comperssion ratio: 1.93 original cost 76848.00 new cost 39736.22
memory comperssion ratio: 1.24 original cost 23454912.00 new cost 18852860.00


In [9]:
for factor in ['flop', 'memory']:
    original_cost, new_cost = compute_flop_cost_change(net, mapper, factor_type= factor + "_factor")
    print("{} comperssion ratio: {:.2f} original cost {:.2f} new cost {:.2f}".format(factor, original_cost/new_cost, original_cost, new_cost))

flop comperssion ratio: 2.68 original cost 76848.00 new cost 28684.06
memory comperssion ratio: 1.52 original cost 23454912.00 new cost 15446608.00


In [7]:
for factor in ['flop', 'memory']:
    original_cost, new_cost = compute_flop_cost_change(net, mapper, factor_type= factor + "_factor")
    print("{} comperssion ratio: {:.2f} original cost {:.2f} new cost {:.2f}".format(factor, original_cost/new_cost, original_cost, new_cost))

flop comperssion ratio: 4.28 original cost 76848.00 new cost 17971.34
memory comperssion ratio: 2.50 original cost 23454912.00 new cost 9397143.00


In [4]:
for factor in ['flop', 'memory']:
    original_cost, new_cost = compute_flop_cost_change(net, mapper, factor_type= factor + "_factor")
    print("{} comperssion ratio: {:.2f} original cost {:.2f} new cost {:.2f}".format(factor, original_cost/new_cost, original_cost, new_cost))

flop comperssion ratio: 7.21 original cost 76848.00 new cost 10658.53
memory comperssion ratio: 4.49 original cost 23454912.00 new cost 5228501.00


In [9]:
time_net(custom_net)

12.209177017211914

In [13]:
time_net(custom_net)

11.297225952148438

In [15]:
time_net(custom_net)

11.02447509765625

In [18]:
time_net(custom_net)

9.547233581542969

In [11]:
time_net(custom_net)

9.53531265258789

In [34]:
time_net(custom_net)

7.979869842529297

In [5]:
time_net(custom_net)

6.571292877197266

In [None]:
plt.figure(figsize=(20,12))
for i in range(len(times)):
    plt.plot(times[i], label=str(i))
plt.legend(loc='best')

In [None]:
for i in range(len(times)):
    print(((times[i][1:]-times[i][:-1])<0).sum())

In [4]:
net, _ = ResNet50_gating(1000)
mapper = ResNetGatesModulesMapper(net.net, False, False)
hyper_edge_to_active_channels = create_edge_to_channels_map(mapper, net.forward_hooks)

original_cost = 76848.00

channels_config = filter_mapping_from_default_resnet(resnet50(False))
orig_resnet = custom_resnet_50(channels_config)
orig_resnet = orig_resnet.cuda(1)

# meas = []


for _ in range(9):
    print("{:2d}, {:8.4f}, {:8.4f}".format(0, 100, time_net(orig_resnet, measurements=100, sleep_seconds=100)))
    for i, (_,ch_num_change) in enumerate(hyper_edge_to_active_channels.items()):
        conv_channels_dict = {}
        ch_num_change = int(ch_num_change * 0.5)
        for j, (hyper_edge, ch_num_keep) in enumerate(hyper_edge_to_active_channels.items()):
            num_channels = hyper_edge_to_active_channels[hyper_edge]
            # take only out channels of convolutions
            num_channels = ch_num_change if i==j else ch_num_keep
            for k, (_, s) in enumerate(hyper_edge.convs_and_sides):
                if not s:
                    dot_string_to_tree_dict(conv_channels_dict, hyper_edge.conv_names[k], num_channels)

        custom_net = custom_resnet_50(conv_channels_dict).cuda(1)
        new_cost, _ = custom_net.compute_flops_memory()
        print("{:2d}, {:8.4f}, {:8.4f}".format(i+1, 100*new_cost/original_cost, time_net(custom_net, measurements=100, sleep_seconds=100)))
    
    # measure original net as a sanity check
#     if ((i % 10) == 0) and (i > 0):
#         meas.append(time_net(orig_resnet, measurements=100, sleep_seconds=100))

# measure original net as a sanity check   
# meas.append(time_net(orig_resnet, measurements=100, sleep_seconds=100))

# run again to double check similar timing
# for m in meas:
#     print(m)


 0, 100.0000,  14.1511
 1,  97.6369,  13.3538
 2,  98.3344,  13.6471
 3,  97.8347,  13.5651
 4,  97.8347,  13.6349
 5,  97.8347,  13.5915
 6,  97.8347,  13.6435
 7,  97.8347,  13.6027
 8,  94.3369,  12.4342
 9,  98.1678,  13.4819
10,  97.8347,  13.6530
11,  97.8347,  13.6702
12,  97.8347,  13.6149
13,  97.8347,  13.6433
14,  97.8347,  13.6054
15,  97.8347,  13.6774
16,  97.8347,  13.6030
17,  92.3381,  12.6061
18,  98.1678,  13.5615
19,  97.8347,  13.6697
20,  97.8347,  13.6843
21,  97.8347,  13.6261
22,  97.8347,  13.6459
23,  97.8347,  13.6266
24,  97.8347,  13.6690
25,  97.8347,  13.6325
26,  97.8347,  13.6769
27,  97.8347,  13.6366
28,  97.8347,  13.6580
29,  97.8347,  13.5863
30,  89.6731,  12.4362
31,  98.1678,  13.4728
32,  97.8347,  13.5407
33,  97.8347,  13.5648
34,  97.8347,  13.5281
35,  97.8347,  13.5434
36,  97.8347,  13.5252
37,  95.3362,  13.0966
 0, 100.0000,  13.7303
 1,  97.6369,  13.3791
 2,  98.3344,  13.6652
 3,  97.8347,  13.5932
 4,  97.8347,  13.6683
 5,  97.834