In [2]:
import os

import torch
from torchvision.models.resnet import resnet50
from torchvision.models.mobilenet import mobilenet_v2
from torchvision.models.squeezenet import squeezenet1_0, squeezenet1_1
from torchvision.models.shufflenetv2 import shufflenet_v2_x1_0

if os.path.basename(os.getcwd()) == 'analysis':
    os.chdir('..')
from external_models.dcp.pruned_resnet import PrunedResnet30, PrunedResnet50, PrunedResnet70
from models.gate_wrapped_module import compute_flop_cost_change
from models.custom_resnet import filter_mapping_from_default_resnet, custom_resnet_50, custom_resnet_56

In [37]:
path = '/home/victoria/Downloads/resnet50-20210116T145733Z-001/resnet50/pretrain-resnet50-ratenorm1-ratedist0.3/best.resnet50.2018-07-03-1645.pth.tar'
aaa = torch.load(path)

In [38]:
state_dict = aaa['state_dict']
state_dict = {k[7:]:v for k,v in state_dict.items()}
net = resnet50(pretrained=False)
net.load_state_dict(state_dict)

<All keys matched successfully>

In [35]:
def conv_weights_analysis(weights):
    no_kernals = weights.sum(-1).sum(-1) != 0
#     init_filters = no_kernals.size(0), no_kernals.size(1)
    in_non_zero = (no_kernals.sum(0) > 0).sum().item()
    out_non_zero = (no_kernals.sum(1) > 0).sum().item()
    init_cost = weights.numel()
    final_cost  = in_non_zero*out_non_zero *weights.size(2) * weights.size(3)
    return init_cost, final_cost
    

In [39]:
downsample = 4
res1, res2  = conv_weights_analysis(net.conv1.weight.data)
orig_total = res1/4 
orig_memory = res1
new_total = res2/4
new_memory = res2
for l in range(1,5):
    if l > 1:
        downsample = downsample * 2
    layer = getattr(net, 'layer' + str(l))
    res1, res2  = conv_weights_analysis(layer[0].downsample[0].weight.data)
    orig_total += res1/(downsample**2) 
    orig_memory += res1
    new_total += res2/(downsample**2) 
    new_memory += res2
    for i in range(len(layer)):
        for j in range(1,4):
            res1, res2  = conv_weights_analysis(getattr(layer[i],'conv'+str(j)).weight.data)
            orig_total += res1/(downsample**2) 
            orig_memory += res1
            new_total += res2/(downsample**2) 
            new_memory += res2
fc_size = net.fc.in_features * net.fc.out_features
fc_flops = fc_size / (224**2)

orig_total += fc_flops
orig_memory += fc_size
new_total += fc_flops
new_memory += fc_size

print(orig_total, new_total, 1-new_total/orig_total, orig_total/new_total, orig_memory, new_memory)
            

76888.8163265306 56149.066326530614 0.2697368875067948 1.369369454505077 25502912 19328471


In [51]:
def get_conv_cost(m, down):
    assert isinstance(m, torch.nn.Conv2d)
    res = m.in_channels * m.out_channels * m.kernel_size[0] * m.kernel_size[1]
    return res / (m.groups * down**2), res

In [40]:
mob =mobilenet_v2(pretrained=False)
downsample = 1
flops = 0
memory = 0
for m in mob.modules():
    if not isinstance(m, torch.nn.Conv2d):
        if isinstance(m, torch.nn.Linear):
            flops += m.in_features * m.out_features / (224**2)
            memory +=m.in_features * m.out_features
    else:
        if m.stride[0] == 2:
            downsample = 2 * downsample
        flops += m.in_channels* m.out_channels* m.kernel_size[0]* m.kernel_size[1] / (m.groups * downsample**2)
        memory += m.in_channels* m.out_channels* m.kernel_size[0]* m.kernel_size[1]
print(flops, memory, downsample)

5994.385204081633 44015840 32


In [38]:
76848 / 5994.385204081633 , 76848 / 5968.875

(12.8199969444195, 12.874787962555757)

In [41]:
5994.385204081633 *224 *224

300774272.0

In [66]:
sq_0 = False

sq = squeezenet1_0(False) if sq_0 else squeezenet1_1(False)

downsample = 4

flops, memory = get_conv_cost(sq.features[0], 2)

for i in range(3,13):
    if i in ([6,11] if sq_0 else [5,8]):
        downsample *= 2
    else:
        for m in sq.features[i].modules():
            if isinstance(m, torch.nn.Conv2d):
                if m.stride[0] == 2:
                    downsample = 2 * downsample
                f, m = get_conv_cost(m, downsample)
                flops += f
                memory += m
f, m = get_conv_cost(sq.classifier[1], downsample)
flops += f
memory += m     

print(flops, memory, downsample)

7720.0 1231552 16


In [67]:
7720.0*224**2

387358720.0

In [29]:
path = '/media/victoria/d/Training/Eli/resnet50_pre_0_995_w_0_25_gm_0_2_w_0_5_w_1_w_2_custom_timing/net_e_140_simple'
state_dict = torch.load(path)
net = custom_resnet_50(state_dict['channels_config'], 1000)
net.load_state_dict({k[7:]:v for k,v in state_dict['state_dict'].items()})
print(*net.compute_flops_memory())

19385.307238520407 12194386.0


In [None]:
print(*net.compute_flops_memory(False))

In [5]:
net = PrunedResnet70()
print(*net.compute_flops_memory())

22223.351482780614 8670406.0
