In [1]:
import sys
if '..' not in sys.path:
    sys.path.append('..')

In [2]:
import torch
import torch.nn as nn
from torchvision.models import resnet18, resnet50

from source.layer_map import get_layer_list

![Original-ResNet-18-Architecture.png](Original-ResNet-18-Architecture.png)

In [3]:
orig_model = regnet_y_400mf(pretrained=True)
orig_model.eval()



RegNet(
  (stem): SimpleStemIN(
    (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
    (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
  )
  (trunk_output): Sequential(
    (block1): AnyStage(
      (block1-0): ResBottleneckBlock(
        (proj): Conv2dNormActivation(
          (0): Conv2d(32, 48, kernel_size=(1, 1), stride=(2, 2), bias=False)
          (1): BatchNorm2d(48, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
        (f): BottleneckTransform(
          (a): Conv2dNormActivation(
            (0): Conv2d(32, 48, kernel_size=(1, 1), stride=(1, 1), bias=False)
            (1): BatchNorm2d(48, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (2): ReLU(inplace=True)
          )
          (b): Conv2dNormActivation(
            (0): Conv2d(48, 48, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), groups=6, bias=False)
            

In [3]:
orig_model = regnet_y_3_2gf(pretrained=True)
orig_model.eval()



RegNet(
  (stem): SimpleStemIN(
    (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
    (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
  )
  (trunk_output): Sequential(
    (block1): AnyStage(
      (block1-0): ResBottleneckBlock(
        (proj): Conv2dNormActivation(
          (0): Conv2d(32, 72, kernel_size=(1, 1), stride=(2, 2), bias=False)
          (1): BatchNorm2d(72, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
        (f): BottleneckTransform(
          (a): Conv2dNormActivation(
            (0): Conv2d(32, 72, kernel_size=(1, 1), stride=(1, 1), bias=False)
            (1): BatchNorm2d(72, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (2): ReLU(inplace=True)
          )
          (b): Conv2dNormActivation(
            (0): Conv2d(72, 72, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), groups=3, bias=False)
            

In [3]:
orig_model = resnet50(pretrained=True)
orig_model.eval()

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 

In [4]:
rate = 2.5
rank_map = {}
layer_list = get_layer_list('resnet50', downsample=False, conv1=False)
dummy_input = torch.rand(1, 3, 224, 224)

for layer_name in layer_list:
    layer = orig_model
    for attr in layer_name.split('.'):
        layer = layer.__getattr__(attr)
    
    weight = layer.weight.detach()
    bias = layer.bias
    kernel_size = layer.kernel_size
    groups = layer.groups
    print('Layer:', layer_name)
    print('bias:', bias is not None)
    print('Original Shape:', weight.shape)
    print('Kernel Size:', kernel_size)
    weight = weight.reshape((weight.shape[0], weight.shape[1], -1))
    if kernel_size == (1,1): weight = weight.reshape((weight.shape[0], weight.shape[1]))
    print('Flattened Shape:', weight.shape)
    
    print('Parameters:', weight.numel())
#     orig_macs = torch.tensor(weight.shape).prod() * y.shape[-1] * y.shape[-2]
#     print('Original MACs:', orig_macs.item())
#     ranks = list(int(weight.numel()/sum(list(weight.shape))/rate) for rate in REDUCTION_RATES)
    rank = int(weight.numel()/sum(list(weight.shape))/rate)
#     print('Reduction rate:', REDUCTION_RATES)
    print('Reduction rank:', rank)
    while rank % groups != 0:
        rank += 1
    print('Updated reduction rank:', rank)
#     if ltype != 'downsample': 
#         macs = [(rank * (c1 + c2 + c3) * 100/orig_macs).item() for rank in ranks]
#         print('Reduced MACs(%):', [round(mac, 2) for mac in macs])
    print()
    rank_map[layer_name] = rank

Layer: layer1.0.conv1
bias: False
Original Shape: torch.Size([64, 64, 1, 1])
Kernel Size: (1, 1)
Flattened Shape: torch.Size([64, 64])
Parameters: 4096
Reduction rank: 12
Updated reduction rank: 12

Layer: layer1.0.conv2
bias: False
Original Shape: torch.Size([64, 64, 3, 3])
Kernel Size: (3, 3)
Flattened Shape: torch.Size([64, 64, 9])
Parameters: 36864
Reduction rank: 107
Updated reduction rank: 107

Layer: layer1.0.conv3
bias: False
Original Shape: torch.Size([256, 64, 1, 1])
Kernel Size: (1, 1)
Flattened Shape: torch.Size([256, 64])
Parameters: 16384
Reduction rank: 20
Updated reduction rank: 20

Layer: layer1.1.conv1
bias: False
Original Shape: torch.Size([64, 256, 1, 1])
Kernel Size: (1, 1)
Flattened Shape: torch.Size([64, 256])
Parameters: 16384
Reduction rank: 20
Updated reduction rank: 20

Layer: layer1.1.conv2
bias: False
Original Shape: torch.Size([64, 64, 3, 3])
Kernel Size: (3, 3)
Flattened Shape: torch.Size([64, 64, 9])
Parameters: 36864
Reduction rank: 107
Updated reductio

In [5]:
rank_map

{'layer1.0.conv1': 12,
 'layer1.0.conv2': 107,
 'layer1.0.conv3': 20,
 'layer1.1.conv1': 20,
 'layer1.1.conv2': 107,
 'layer1.1.conv3': 20,
 'layer1.2.conv1': 20,
 'layer1.2.conv2': 107,
 'layer1.2.conv3': 20,
 'layer2.0.conv1': 34,
 'layer2.0.conv2': 222,
 'layer2.0.conv3': 40,
 'layer2.1.conv1': 40,
 'layer2.1.conv2': 222,
 'layer2.1.conv3': 40,
 'layer2.2.conv1': 40,
 'layer2.2.conv2': 222,
 'layer2.2.conv3': 40,
 'layer2.3.conv1': 40,
 'layer2.3.conv2': 222,
 'layer2.3.conv3': 40,
 'layer3.0.conv1': 68,
 'layer3.0.conv2': 452,
 'layer3.0.conv3': 81,
 'layer3.1.conv1': 81,
 'layer3.1.conv2': 452,
 'layer3.1.conv3': 81,
 'layer3.2.conv1': 81,
 'layer3.2.conv2': 452,
 'layer3.2.conv3': 81,
 'layer3.3.conv1': 81,
 'layer3.3.conv2': 452,
 'layer3.3.conv3': 81,
 'layer3.4.conv1': 81,
 'layer3.4.conv2': 452,
 'layer3.4.conv3': 81,
 'layer3.5.conv1': 81,
 'layer3.5.conv2': 452,
 'layer3.5.conv3': 81,
 'layer4.0.conv1': 136,
 'layer4.0.conv2': 913,
 'layer4.0.conv3': 163,
 'layer4.1.conv1':

In [7]:
{k:v[0] for (k,v) in rank_map.items()}

{'conv1': 27,
 'layer1.0.conv1': 89,
 'layer1.0.conv2': 89,
 'layer1.1.conv1': 89,
 'layer1.1.conv2': 89,
 'layer2.0.conv1': 122,
 'layer2.0.conv2': 185,
 'layer2.0.downsample': 14,
 'layer2.1.conv1': 185,
 'layer2.1.conv2': 185,
 'layer3.0.conv1': 250,
 'layer3.0.conv2': 377,
 'layer3.0.downsample': 28,
 'layer3.1.conv1': 377,
 'layer3.1.conv2': 377,
 'layer4.0.conv1': 506,
 'layer4.0.conv2': 761,
 'layer4.0.downsample': 56,
 'layer4.1.conv1': 761,
 'layer4.1.conv2': 761}

In [8]:
rank_map.keys()

dict_keys(['conv1', 'layer1.0.conv1', 'layer1.0.conv2', 'layer1.1.conv1', 'layer1.1.conv2', 'layer2.0.conv1', 'layer2.0.conv2', 'layer2.0.downsample', 'layer2.1.conv1', 'layer2.1.conv2', 'layer3.0.conv1', 'layer3.0.conv2', 'layer3.0.downsample', 'layer3.1.conv1', 'layer3.1.conv2', 'layer4.0.conv1', 'layer4.0.conv2', 'layer4.0.downsample', 'layer4.1.conv1', 'layer4.1.conv2'])