In [None]:
from time import time
import numpy as np
import os

import torch
from torchvision.models.resnet import resnet50
from torchvision.models.mobilenet import mobilenet_v2
from torchvision.models.squeezenet import squeezenet1_0, squeezenet1_1
from torchvision.models.shufflenetv2 import shufflenet_v2_x1_0

# When the kernel's working directory is named 'analysis', step one level up.
# NOTE(review): presumably this makes the project packages imported right below
# (`external_models`, `models`) importable from the repository root — confirm.
if os.path.basename(os.getcwd()) == 'analysis':
    os.chdir('..')
from external_models.dcp.pruned_resnet import PrunedResnet30, PrunedResnet50, PrunedResnet70
from models.gate_wrapped_module import compute_flop_cost_change
from models.custom_resnet import filter_mapping_from_default_resnet, custom_resnet_50, custom_resnet_56

In [None]:
# NOTE(review): hard-coded absolute path to a local download; adjust it for your machine.
path = '/home/victoria/Downloads/resnet50-20210116T145733Z-001/resnet50/pretrain-resnet50-ratenorm1-ratedist0.4/best.resnet50.2018-07-16-4310.pth.tar'
# torch.load unpickles the file, so only open checkpoints that come from a trusted source.
# map_location='cpu' keeps the load working on machines that lack the GPU the
# checkpoint was saved from; the tensors are only copied into a CPU model afterwards.
aaa = torch.load(path, map_location='cpu')

In [None]:
# The checkpoint keys carry a leading 'module.' (7 characters). Strip it only where it
# is actually present, so keys without the prefix are not silently mangled.
state_dict = {
    (k[len('module.'):] if k.startswith('module.') else k): v
    for k, v in aaa['state_dict'].items()
}
# Architecture only; the weights come from the checkpoint loaded above.
net = resnet50(pretrained=False)
net.load_state_dict(state_dict)

In [None]:
def conv_weights_analysis(weights, prev_out=None):
    """Compare the dense size of a conv weight tensor with its size after pruning.

    A kernel (the 2-D slice `weights[o, i]`) counts as surviving when it holds at
    least one nonzero value. An input/output channel survives when at least one
    of its kernels survives.

    Args:
        weights: 4-D weight tensor indexed as (out, in, kernel_h, kernel_w).
        prev_out: optional upper bound on the surviving input channels (the
            number of channels the preceding layer still produces).

    Returns:
        (dense_cost, pruned_cost, surviving_inputs, surviving_outputs) where
        dense_cost is `weights.numel()` and pruned_cost is
        surviving_inputs * surviving_outputs * kernel_h * kernel_w.
    """
    # alive[o, i] is True when kernel (o, i) contains any nonzero weight.
    nonzero_per_kernel = weights.ne(0).sum(-1).sum(-1)
    alive = nonzero_per_kernel.ne(0)

    used_inputs = alive.sum(0).gt(0).sum().item()
    used_outputs = alive.sum(1).gt(0).sum().item()

    # The previous layer cannot feed more channels than it still outputs.
    if prev_out is not None and prev_out < used_inputs:
        used_inputs = prev_out

    kernel_height, kernel_width = weights.size(2), weights.size(3)
    dense_cost = weights.numel()
    pruned_cost = used_inputs * used_outputs * kernel_height * kernel_width
    return dense_cost, pruned_cost, used_inputs, used_outputs
    

In [None]:
# Walk the pruned ResNet-50 in `net` (conv1, then layer1..layer4, then fc) and accumulate:
#   orig_total / new_total   - dense vs. pruned cost, each conv divided by downsample**2
#   orig_memory / new_memory - dense vs. pruned number of conv weights (plus the fc weights)
# while recording the surviving output channels of every conv in `channel_config`.
# `downsample` is the divisor applied to conv costs: 4 for layer1, doubled for each later layer.
# NOTE(review): presumably it is the spatial stride relative to the input image — confirm.
downsample = 4

channel_config = {}

res1, res2, in_channels, out_channels  = conv_weights_analysis(net.conv1.weight.data)
channel_config['conv1'] = out_channels
# Remembered so the first convs of layer1 can be capped by what conv1 still outputs.
first_conv_out = out_channels
# conv1 uses a fixed divisor of 4 instead of downsample**2.
orig_total = res1/4 
orig_memory = res1
new_total = res2/4
# print(new_total )
new_memory = res2
for l in range(1,5):
    layer_config = {}
    if l > 1:
        downsample *= 2
    layer = getattr(net, 'layer' + str(l))
    # Projection shortcut of the first block; only in layer1 is its input capped by conv1's output.
    res1, res2, in_channels, out_channels  = conv_weights_analysis(layer[0].downsample[0].weight.data, first_conv_out if l==1 else None)
    orig_total += res1/(downsample**2) 
    orig_memory += res1
    new_total += res2/(downsample**2) 
#     print('down', res2/(downsample**2),  in_channels, out_channels )
#     print(new_total )
    new_memory += res2
    for i in range(len(layer)):
        block_config = {}
        if i==0:
            # `out_channels` still holds the shortcut conv's result at this point.
            block_config['downsample'] = out_channels
        # conv1..conv3 of the block.
        # NOTE(review): every conv here uses the layer's already-doubled divisor; if the
        # stride sits on conv2 (as in torchvision's Bottleneck), conv1 of block 0 in
        # layers 2-4 would run at the previous resolution — confirm this is intended.
        for j in range(1,4):
            # Cap the surviving inputs by the previous conv's surviving outputs. conv1 of a
            # block is only capped for the very first block (by the stem conv1).
            prev_out = None
            if j==1 and l==1 and i==0:
                prev_out = first_conv_out
            elif j>1:
                prev_out = out_channels
            
            res1, res2, in_channels, out_channels  = conv_weights_analysis(getattr(layer[i],'conv'+str(j)).weight.data, prev_out)
#             if j == 3:
#                 out_channels = layer[i].conv3.weight.size(0)
            block_config['conv'+str(j)] = out_channels
            orig_total += res1/(downsample**2) 
            orig_memory += res1
            new_total += res2/(downsample**2) 
#             print('conv'+str(j), res2/(downsample**2), in_channels, out_channels )
#             print(new_total )
            new_memory += res2
        layer_config[str(i)] = block_config
    channel_config['layer'+str(l)] = layer_config
# The fc layer is added unchanged to both the dense and the pruned figures.
fc_size = net.fc.in_features * net.fc.out_features
# Divided by 224**2. NOTE(review): presumably this puts the once-per-image fc cost on
# the same per-input-pixel scale as the convs for a 224x224 input — confirm.
fc_flops = fc_size / (224**2)

orig_total += fc_flops
orig_memory += fc_size
new_total += fc_flops
new_memory += fc_size

# print (fc_flops)
# print(new_total)

# Printed: dense cost, pruned cost, fraction removed, dense/pruned ratio, dense weights, pruned weights.
print(orig_total, new_total, 1-new_total/orig_total, orig_total/new_total, orig_memory, new_memory)
print(channel_config)

In [None]:
# Build a network from the channel counts collected in `channel_config` and print
# whatever its compute_flops_memory() returns.
net_from_config =custom_resnet_50(channel_config)
print(net_from_config.compute_flops_memory())
# res = net_from_config(torch.randn(1,3,64,64))
# res.shape

In [None]:
def time_net(net, batch_size=8, run_times=1, measurements=100, lower_limit=0, upper_limit=500000, sleep_seconds=150, return_times=False):
    """Time forward passes of `net` on a random batch of 224x224 RGB images.

    The network is put in eval mode while measuring and returned to the mode it
    was in before the call.

    Args:
        net: module to time. The input batch is created on the device of its
            first parameter (falls back to CUDA device 1 if it has none).
        batch_size: number of images in the random input batch.
        run_times: number of measurement rounds.
        measurements: forward passes timed in each round.
        lower_limit: currently unused (kept so existing calls keep working).
        upper_limit: currently unused (kept so existing calls keep working).
        sleep_seconds: pause inserted before every round except the first.
        return_times: if True return every timing, otherwise only the fastest.

    Returns:
        If `return_times`, an array of shape (run_times, measurements) holding
        per-pass durations in milliseconds; otherwise the minimum of that array.
    """
    # The notebook only imports `time` from the time module, so `sleep` is brought in here.
    from time import sleep

    first_param = next(net.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cuda', 1)

    was_training = getattr(net, 'training', True)
    net.eval()
    total_times = []
    try:
        for r in range(run_times):
            if r > 0:
                sleep(sleep_seconds)
            sample = torch.rand((batch_size, 3, 224, 224)).to(device)
            times = []
            for _ in range(measurements):
                t = time()
                res = net(sample)
                # Read one output value back into Python so the forward pass has
                # finished before the clock is read again.
                res[-1, -1].item()
                times.append(time() - t)
            # seconds -> milliseconds
            total_times.append(1000 * np.array(times))
    finally:
        # Put the network back in the mode the caller had it in.
        net.train(was_training)

    total_times = np.array(total_times)

    if return_times:
        return total_times
    return total_times.min()

In [None]:
# Move the rebuilt network to GPU index 1 and time it with batches of 32 images.
# With the default return_times=False, time_net returns the fastest single
# forward pass in milliseconds, which is displayed as the cell output.
net_from_config = net_from_config.cuda(1)
time_net(net_from_config, batch_size=32)

In [None]:
def get_conv_cost(m, down):
    """Return the (scaled cost, weight count) of a 2-D convolution.

    Args:
        m: a `torch.nn.Conv2d` module.
        down: divisor for the cost; the cost is divided by `down**2`.

    Returns:
        A tuple `(flops, memory)` where `memory` is the number of weights in the
        convolution kernel (bias excluded) and `flops` is that count divided by
        `down**2`.

    Raises:
        AssertionError: if `m` is not a `torch.nn.Conv2d`.
    """
    assert isinstance(m, torch.nn.Conv2d)
    res = m.in_channels * m.out_channels * m.kernel_size[0] * m.kernel_size[1]
    # A grouped convolution connects every filter to in_channels / groups inputs only,
    # so both figures shrink by `groups`; Conv2d requires the division to be exact.
    return res / (m.groups * down**2), res // m.groups

In [None]:
# Two dummy feature maps that differ only in their channel count (10 vs. 6).
# torch.zeros gives defined contents; the legacy torch.Tensor(size=...) constructor
# returns uninitialised memory.
x = torch.zeros((8, 10, 5, 5))
y = torch.zeros((8, 6, 5, 5))

In [None]:
def unmatched_channels_addition(x, y):
    """Add `y` to `x` after zero-padding `y` along dim 1 up to `x`'s channel count.

    Args:
        x: 4-D tensor; its size along dim 1 sets the channel count of the result.
        y: 4-D tensor with the same sizes as `x` on dims 0, 2 and 3 and at most
            as many channels as `x` (a negative padding size is rejected by torch).

    Returns:
        `x + y_padded`, where the channels `y` lacks are treated as zeros.
    """
    missing_channels = x.size(1) - y.size(1)
    # new_zeros inherits y's dtype and device, so padding never changes y's dtype
    # (a plain torch.zeros would always be float32).
    padding = y.new_zeros((y.size(0), missing_channels, y.size(2), y.size(3)))
    y_new = torch.cat([y, padding], 1)
    return x + y_new

In [None]:
# Shape of the zero block needed to pad y up to x's channel count.
torch.zeros((y.size(0), x.size(1)-y.size(1), y.size(2), y.size(3))).shape

In [None]:
# Sanity check: the padded sum should come out with x's shape.
unmatched_channels_addition(x,y).shape

In [None]:
# Cost of MobileNetV2: every Conv2d is costed at the divisor in force after its own
# stride, and Linear layers are divided by 224**2 like the fc layer elsewhere in this
# notebook.
mob = mobilenet_v2(pretrained=False)
downsample = 1
flops = 0
memory = 0
for m in mob.modules():
    if isinstance(m, torch.nn.Conv2d):
        # A stride-2 convolution doubles the divisor before its own cost is added.
        if m.stride[0] == 2:
            downsample = 2 * downsample
        kernel_params = m.in_channels * m.out_channels * m.kernel_size[0] * m.kernel_size[1]
        flops += kernel_params / (m.groups * downsample**2)
        # Grouped (depthwise) convolutions only store in_channels / groups inputs per
        # filter, so the weight count has to be divided by `groups` as well.
        memory += kernel_params // m.groups
    elif isinstance(m, torch.nn.Linear):
        fc_params = m.in_features * m.out_features
        flops += fc_params / (224**2)
        memory += fc_params
print(flops, memory, downsample)

In [None]:
# NOTE(review): ad-hoc ratios between hard-coded figures; the notebook does not record
# where these constants come from — confirm before relying on them.
76848 / 5994.385204081633 , 76848 / 5968.875

In [None]:
# Scales a hard-coded figure by 224*224.
# NOTE(review): presumably converts a cost normalised by the 224x224 input area back
# to an absolute count — confirm.
5994.385204081633 *224 *224

In [None]:
# Cost of SqueezeNet 1.0 (sq_0 = True) or 1.1 (sq_0 = False), built on get_conv_cost.
sq_0 = False

sq = squeezenet1_0(False) if sq_0 else squeezenet1_1(False)

# Entries of sq.features that only double the divisor; no cost is added for them.
pool_indices = [6, 11] if sq_0 else [5, 8]

downsample = 4

# The first convolution is costed with a fixed divisor of 2.
flops, memory = get_conv_cost(sq.features[0], 2)

for i in range(3, 13):
    if i in pool_indices:
        downsample *= 2
        continue
    for module in sq.features[i].modules():
        if not isinstance(module, torch.nn.Conv2d):
            continue
        if module.stride[0] == 2:
            downsample = 2 * downsample
        # Separate names for the results: the loop variable is no longer overwritten.
        conv_flops, conv_memory = get_conv_cost(module, downsample)
        flops += conv_flops
        memory += conv_memory

# The classifier's convolution is costed at the final divisor.
conv_flops, conv_memory = get_conv_cost(sq.classifier[1], downsample)
flops += conv_flops
memory += conv_memory

print(flops, memory, downsample)

In [None]:
# Scales a hard-coded figure by 224**2.
# NOTE(review): presumably converts a per-input-pixel cost back to an absolute count — confirm.
7720.0*224**2

In [None]:
# NOTE(review): hard-coded absolute path to a local training run; adjust it for your machine.
path = '/media/victoria/d/Training/Eli/resnet50_pre_0_995_w_0_25_gm_0_2_w_0_5_w_1_w_2_custom_timing/net_e_140_simple'
# torch.load unpickles the file (trusted checkpoints only); map_location='cpu' avoids
# needing the GPU the checkpoint was saved from.
state_dict = torch.load(path, map_location='cpu')
# The checkpoint stores the channel layout next to the weights.
net = custom_resnet_50(state_dict['channels_config'], 1000)
# Drop the leading 'module.' from the keys, but only where it is actually present.
net.load_state_dict({
    (k[len('module.'):] if k.startswith('module.') else k): v
    for k, v in state_dict['state_dict'].items()
})
print(*net.compute_flops_memory())

In [None]:
# Print the values returned by compute_flops_memory when called with False.
# NOTE(review): the meaning of this positional flag is defined in the model class — confirm.
print(*net.compute_flops_memory(False))

In [None]:
# Instantiate PrunedResnet70 (imported from external_models.dcp.pruned_resnet) and
# print the values returned by its compute_flops_memory().
net = PrunedResnet70()
print(*net.compute_flops_memory())

In [None]:
# NOTE(review): hard-coded absolute path to a local training run; adjust it for your machine.
path = '/home/eli/Eli/Training/Cifar10/ResNet56_long/resnet56_w_0_25_w_5_w_1_w_2_w_4_w_8_w16/net_e_900'

# torch.load unpickles the file (trusted checkpoints only); map_location='cpu' avoids
# needing the GPU the checkpoint was saved from.
state_dict = torch.load(path, map_location='cpu')
# The checkpoint stores the channel layout next to the weights.
net = custom_resnet_56(state_dict['channels_config'], 10)
# Drop the leading 'module.' from the keys, but only where it is actually present.
net.load_state_dict({
    (k[len('module.'):] if k.startswith('module.') else k): v
    for k, v in state_dict['state_dict'].items()
})
print(*net.compute_flops_memory())
# Number of blocks in each of the three stages (shown as the cell output).
len(net.layer1), len(net.layer2), len(net.layer3)

In [None]:
# Display the modules that make up layer2 of the current `net`.
net.layer2

## channel pruning article

In [None]:
import json

config_path = './analysis/channel_pruning_article_resnet50_config.txt'
with open(config_path, 'r') as f:
    raw_config = f.read()

# Join the lines and switch single quotes to double quotes so the text parses as JSON.
# NOTE(review): presumably the file holds a printed Python dict — confirm.
json_text = raw_config.replace('\n', '').replace("'", '"')
channels_config = json.loads(json_text)

# Build a network from the loaded channel layout.
net = custom_resnet_50(channels_config)

In [None]:
# Display what compute_flops_memory() returns for the current `net`.
net.compute_flops_memory()

In [16]:
import torch
# NOTE(review): hard-coded absolute path to a local training run; adjust it for your machine.
# torch.load unpickles the file (trusted checkpoints only); map_location='cpu' avoids
# needing the GPU the checkpoint was saved from.
aa = torch.load('/home/eli/Eli/Training/Cifar10/ResNet56_long/resnet56_120_180_240/net_e_240', map_location='cpu')

In [None]:
# Inspect the shape of the fc layer's weight as stored in the checkpoint.
aa['state_dict']['module.fc.weight'].shape

In [None]:
# Count the weights of every checkpoint entry whose key contains 'conv'.
# The accumulator is named `total_params` rather than `sum` so the builtin sum()
# stays usable in the rest of the kernel session.
total_params = 0
for k, v in aa['state_dict'].items():
    if 'conv' in k:
        print(k, '\t', v.shape)
        total_params += v.size(0) * v.size(1) * v.size(2) * v.size(3)
# Hard-coded extra term. NOTE(review): presumably the final fc weight (64 x 10) —
# confirm against the shape stored under 'module.fc.weight'.
total_params += 64*10 #+ 40*3*16
print(total_params)