In [1]:
import os

import torch
from torchvision.models.resnet import resnet50
from torchvision.models.mobilenet import mobilenet_v2
from torchvision.models.squeezenet import squeezenet1_0, squeezenet1_1
from torchvision.models.shufflenetv2 import shufflenet_v2_x1_0

if os.path.basename(os.getcwd()) == 'analysis':
    os.chdir('..')
from external_models.dcp.pruned_resnet import PrunedResnet30, PrunedResnet50, PrunedResnet70
from models.gate_wrapped_module import compute_flop_cost_change
from models.custom_resnet import filter_mapping_from_default_resnet, custom_resnet_50, custom_resnet_56

In [81]:
path = '/home/eli/Downloads/best.resnet50.2018-07-11-9994.pth.tar'
aaa = torch.load(path)

In [82]:
state_dict = aaa['state_dict']
state_dict = {k[7:]:v for k,v in state_dict.items()}
net = resnet50(pretrained=False)
net.load_state_dict(state_dict)

IncompatibleKeys(missing_keys=[], unexpected_keys=[])

In [None]:
from matplotlib import pyplot as plt

w = net.layer1[0].downsample[0].weight.data[:,:,0,0].abs()
w_in = w.sum(0).numpy()
w_out = w.sum(1).numpy()

w_out.sort()
w_out
# plt.hist(w_in)
# (w_out<1e-1).sum(), 256-180

In [83]:
def conv_weights_analysis(weights, prev_out=None):
    no_kernals = (weights!=0).sum(-1).sum(-1) != 0
#     no_kernals = weights.sum(-1).sum(-1) != 0
#     init_filters = no_kernals.size(0), no_kernals.size(1)
    in_non_zero = (no_kernals.sum(0) > 0).sum().item()
    out_non_zero = (no_kernals.sum(1) > 0).sum().item()
    if prev_out is not None:
        in_non_zero = min(in_non_zero, prev_out)
    init_cost = weights.numel()
    final_cost  = in_non_zero*out_non_zero *weights.size(2) * weights.size(3)
    return init_cost, final_cost, in_non_zero, out_non_zero
    

In [84]:
downsample = 4

channel_config = {}

res1, res2, in_channels, out_channels  = conv_weights_analysis(net.conv1.weight.data)
channel_config['conv1'] = out_channels
first_conv_out = out_channels
orig_total = res1/4 
orig_memory = res1
new_total = res2/4
print(new_total )
new_memory = res2
for l in range(1,5):
    layer_config = {}
    if l > 1:
        downsample = downsample * 2
    layer = getattr(net, 'layer' + str(l))
    res1, res2, in_channels, out_channels  = conv_weights_analysis(layer[0].downsample[0].weight.data, first_conv_out if i==1 else None)
    orig_total += res1/(downsample**2) 
    orig_memory += res1
    new_total += res2/(downsample**2) 
    print('down', res2/(downsample**2),  in_channels, out_channels )
    print(new_total )
    new_memory += res2
    for i in range(len(layer)):
        block_config = {}
        if i==0:
            block_config['downsample'] = out_channels
        for j in range(1,4):
            res1, res2, in_channels, out_channels  = conv_weights_analysis(getattr(layer[i],'conv'+str(j)).weight.data, first_conv_out if (i == 1) else None)
            block_config['conv'+str(j)] = out_channels
            orig_total += res1/(downsample**2) 
            orig_memory += res1
            new_total += res2/(downsample**2) 
            print('conv'+str(j), res2/(downsample**2), in_channels, out_channels )
            print(new_total )
            new_memory += res2
        layer_config[str(i)] = block_config
    channel_config['layer'+str(l)] = layer_config
fc_size = net.fc.in_features * net.fc.out_features
fc_flops = fc_size / (224**2)

orig_total += fc_flops
orig_memory += fc_size
new_total += fc_flops
new_memory += fc_size

print (fc_flops)
print(new_total)

print(orig_total, new_total, 1-new_total/orig_total, orig_total/new_total, orig_memory, new_memory)
print(channel_config)

1433.25
down 1024.0 64 256
2457.25
conv1 156.0 64 39
2613.25
conv2 1404.0 64 39
4017.25
conv3 616.0 64 154
4633.25
conv1 95.0625 39 39
4728.3125
conv2 855.5625 39 39
5583.875
conv3 375.375 39 154
5959.25
conv1 624.0 256 39
6583.25
conv2 1404.0 64 39
7987.25
conv3 616.0 64 154
8603.25
down 2048.0 256 512
10651.25
conv1 308.0 256 77
10959.25
conv2 1386.0 128 77
12345.25
conv3 616.0 128 308
12961.25
conv1 46.921875 39 77
13008.171875
conv2 422.296875 39 77
13430.46875
conv3 187.6875 39 308
13618.15625
conv1 616.0 512 77
14234.15625
conv2 1386.0 128 77
15620.15625
conv3 616.0 128 308
16236.15625
conv1 616.0 512 77
16852.15625
conv2 1386.0 128 77
18238.15625
conv3 616.0 128 308
18854.15625
down 2048.0 512 1024
20902.15625
conv1 308.0 512 154
21210.15625
conv2 1386.0 256 154
22596.15625
conv3 615.0 256 615
23211.15625
conv1 23.4609375 39 154
23234.6171875
conv2 211.1484375 39 154
23445.765625
conv3 93.69140625 39 615
23539.45703125
conv1 616.0 1024 154
24155.45703125
conv2 1386.0 256 154
255

In [71]:
net_from_config =custom_resnet_50(channel_config)
print(net_from_config.compute_flops_memory())
# res = net_from_config(torch.randn(2,3,224,224))
# res.shape

1433.25
375.375
1808.625
95.0625
1903.6875
855.5625
2759.25
375.375
3134.625
375.375
3510.0
855.5625
4365.5625
375.375
4740.9375
375.375
5116.3125
855.5625
5971.875
375.375
6347.25
741.125
7088.375
185.28125
7273.65625
833.765625
8107.421875
370.5625
8477.984375
370.5625
8848.546875
833.765625
9682.3125
370.5625
10052.875
370.5625
10423.4375
833.765625
11257.203125
370.5625
11627.765625
370.5625
11998.328125
833.765625
12832.09375
370.5625
13202.65625
739.921875
13942.578125
185.28125
14127.859375
833.765625
14961.625
369.9609375
15331.5859375
369.9609375
15701.546875
833.765625
16535.3125
369.9609375
16905.2734375
369.9609375
17275.234375
833.765625
18109.0
369.9609375
18478.9609375
369.9609375
18848.921875
833.765625
19682.6875
369.9609375
20052.6484375
369.9609375
20422.609375
833.765625
21256.375
369.9609375
21626.3359375
369.9609375
21996.296875
833.765625
22830.0625
369.9609375
23200.0234375
738.1201171875
23938.1435546875
184.98046875
24123.1240234375
833.765625
24956.8896484375

In [None]:
import numpy as np
from torchvision import transforms
from PIL import Image

im = Image.open("/home/eli/Downloads/cat_1.jpeg")

normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
imagenet_transforms_test = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        normalize,
    ])

im_t = imagenet_transforms_test(im).reshape(1,3,224,224)

# net_from_config(im_t).argsort()

In [51]:
def get_conv_cost(m, down):
    assert isinstance(m, torch.nn.Conv2d)
    res = m.in_channels * m.out_channels * m.kernel_size[0] * m.kernel_size[1]
    return res / (m.groups * down**2), res

In [40]:
mob =mobilenet_v2(pretrained=False)
downsample = 1
flops = 0
memory = 0
for m in mob.modules():
    if not isinstance(m, torch.nn.Conv2d):
        if isinstance(m, torch.nn.Linear):
            flops += m.in_features * m.out_features / (224**2)
            memory +=m.in_features * m.out_features
    else:
        if m.stride[0] == 2:
            downsample = 2 * downsample
        flops += m.in_channels* m.out_channels* m.kernel_size[0]* m.kernel_size[1] / (m.groups * downsample**2)
        memory += m.in_channels* m.out_channels* m.kernel_size[0]* m.kernel_size[1]
print(flops, memory, downsample)

5994.385204081633 44015840 32


In [38]:
76848 / 5994.385204081633 , 76848 / 5968.875

(12.8199969444195, 12.874787962555757)

In [41]:
5994.385204081633 *224 *224

300774272.0

In [66]:
sq_0 = False

sq = squeezenet1_0(False) if sq_0 else squeezenet1_1(False)

downsample = 4

flops, memory = get_conv_cost(sq.features[0], 2)

for i in range(3,13):
    if i in ([6,11] if sq_0 else [5,8]):
        downsample *= 2
    else:
        for m in sq.features[i].modules():
            if isinstance(m, torch.nn.Conv2d):
                if m.stride[0] == 2:
                    downsample = 2 * downsample
                f, m = get_conv_cost(m, downsample)
                flops += f
                memory += m
f, m = get_conv_cost(sq.classifier[1], downsample)
flops += f
memory += m     

print(flops, memory, downsample)

7720.0 1231552 16


In [67]:
7720.0*224**2

387358720.0

In [29]:
path = '/media/victoria/d/Training/Eli/resnet50_pre_0_995_w_0_25_gm_0_2_w_0_5_w_1_w_2_custom_timing/net_e_140_simple'
state_dict = torch.load(path)
net = custom_resnet_50(state_dict['channels_config'], 1000)
net.load_state_dict({k[7:]:v for k,v in state_dict['state_dict'].items()})
print(*net.compute_flops_memory())

19385.307238520407 12194386.0


In [None]:
print(*net.compute_flops_memory(False))

In [5]:
net = PrunedResnet70()
print(*net.compute_flops_memory())

22223.351482780614 8670406.0
