In [1]:
!pip install flopco-pytorch



In [2]:
import torch
from torch import nn
import torchvision
import flopco
from flopco import FlopCo

from tqdm import tqdm as tqdm
import math
import time
import sys
import gc
import statistics

In [3]:

# Seed torch's default random generator so runs are repeatable.
torch.random.manual_seed(10)

# Explicit CPU handle (passed to FlopCo further down), plus the preferred
# compute device: CUDA when available, CPU otherwise.
# To force CPU execution, set `device = device_cpu` instead.
device_cpu = torch.device("cpu")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

cuda


## ResNet

In [4]:
# Build ResNet-18 with torchvision's default arguments (no weights argument is
# passed here). NOTE(review): presumably untrained weights are fine because the
# notebook only counts FLOPs — confirm.
resnet18 = torchvision.models.resnet18()

![image.png](https://cdn-images-1.medium.com/max/759/1*PLDIbqMGxoSqWKsiqAGF4g.png)

In [5]:
# Collect FlopCo statistics for the whole ResNet18 on the CPU device.
# `img_size` is presumably (batch, channels, height, width) — verify against
# FlopCo; `instances` lists the layer classes handed to FlopCo for tracking.
stats_resnet18 = FlopCo(
    resnet18,
    img_size=(1, 3, 224, 224),
    device=device_cpu,
    instances=[nn.Conv2d, nn.BatchNorm2d, nn.Linear],
)

### Absolute flops in ResNet18

In [6]:
# Tally ResNet18's absolute FLOPs into three buckets.
# NOTE(review): the accumulator names do not match their contents —
#   conv_1x1_flops : Conv2d layers with groups > 1 and a non-1x1 kernel
#                    (printed as "Depthwise")
#   conv_kxk_flops : every other Conv2d, 1x1 convolutions included
#                    (printed as "Standart")
# The names are kept because they are notebook globals.
conv_kxk_flops = conv_1x1_flops = lins_flops = 0

for layer_name, layer_info in stats_resnet18.ltypes.items():
    layer_cls = layer_info['type']
    if layer_cls is nn.Linear:
        # [0] reads the first recorded entry for this layer; presumably there is
        # one entry per forward call — verify if a layer is ever reused.
        lins_flops += stats_resnet18.flops[layer_name][0]
    elif layer_cls is nn.Conv2d:
        grouped_kxk = layer_info['groups'] > 1 and layer_info['kernel_size'] != (1,1)
        if grouped_kxk:
            conv_1x1_flops += stats_resnet18.flops[layer_name][0]
        else:
            conv_kxk_flops += stats_resnet18.flops[layer_name][0]
    # Any other tracked layer type (e.g. BatchNorm2d) is left out of the tally.

for label, value in (
    ("Depthwise conv flops:\t", conv_1x1_flops),
    ("Standart conv flops:\t", conv_kxk_flops),
    ("Linear flops:\t", lins_flops),
):
    print(label, value)

Depthwise conv flops:	 0
Standart conv flops:	 3627122688
Linear flops:	 1024512


### Relative flops in ResNet18

In [7]:
# Tally ResNet18's relative FLOPs (values read from `relative_flops`) into the
# same three buckets used for the absolute counts.
# NOTE(review): the accumulator names do not match their contents —
#   conv_1x1_flops : Conv2d layers with groups > 1 and a non-1x1 kernel
#                    (printed as "Depthwise")
#   conv_kxk_flops : every other Conv2d, 1x1 convolutions included
#                    (printed as "Standart")
# The names are kept because they are notebook globals.
conv_kxk_flops = conv_1x1_flops = lins_flops = 0

for layer_name, layer_info in stats_resnet18.ltypes.items():
    layer_cls = layer_info['type']
    if layer_cls is nn.Linear:
        lins_flops += stats_resnet18.relative_flops[layer_name]
    elif layer_cls is nn.Conv2d:
        grouped_kxk = layer_info['groups'] > 1 and layer_info['kernel_size'] != (1,1)
        if grouped_kxk:
            conv_1x1_flops += stats_resnet18.relative_flops[layer_name]
        else:
            conv_kxk_flops += stats_resnet18.relative_flops[layer_name]
    # Any other tracked layer type (e.g. BatchNorm2d) is left out of the tally.

# Shares are rounded to four decimals for display only.
for label, share in (
    ("Depthwise conv flop:\t", conv_1x1_flops),
    ("Standart conv flop:\t", conv_kxk_flops),
    ("Linear flop:\t", lins_flops),
):
    print(label, round(share, 4))

Depthwise conv flop:	 0
Standart conv flop:	 0.997
Linear flop:	 0.0003


In [8]:
# Stats for the individual residual stages of resnet18 (not used below).
#
# Each stage is profiled with the feature-map shape it receives when the full
# network is fed a 224x224 image, so the numbers are comparable with
# `stats_resnet18`:
#   stem (stride-2 conv1 + stride-2 maxpool): 224 -> 56
#   layer1 input:  64 x 56 x 56   (layer1 keeps the resolution)
#   layer2 input:  64 x 56 x 56   (layer2 downsamples to 28)
#   layer3 input: 128 x 28 x 28   (layer3 downsamples to 14)
#   layer4 input: 256 x 14 x 14
# Previously every stage was fed a 224x224 map, which inflated the counts.
stats_resnet18_l1 = FlopCo(resnet18.layer1,
               img_size = (1, 64, 56, 56),
               device = device_cpu,
               instances = [nn.Conv2d, nn.BatchNorm2d, nn.Linear])
stats_resnet18_l2 = FlopCo(resnet18.layer2,
               img_size = (1, 64, 56, 56),
               device = device_cpu,
               instances = [nn.Conv2d, nn.BatchNorm2d, nn.Linear])
stats_resnet18_l3 = FlopCo(resnet18.layer3,
               img_size = (1, 128, 28, 28),
               device = device_cpu,
               instances = [nn.Conv2d, nn.BatchNorm2d, nn.Linear])
stats_resnet18_l4 = FlopCo(resnet18.layer4,
               img_size = (1, 256, 14, 14),
               device = device_cpu,
               instances = [nn.Conv2d, nn.BatchNorm2d, nn.Linear])

## MobileNet

![image.png](https://user-images.githubusercontent.com/3350865/77837270-d2df3580-7199-11ea-9b2b-704966a3c19d.png)

In [9]:
# Build MobileNetV2 with torchvision's default arguments (no weights argument
# is passed here). NOTE(review): presumably untrained weights are fine because
# the notebook only counts FLOPs — confirm.
mobilenet = torchvision.models.mobilenet_v2()

In [10]:
# Show how many entries `mobilenet.features` has and inspect entry 12: per the
# output below, an InvertedResidual made of a 1x1 conv, a 3x3 conv with
# groups=576, and a final 1x1 conv.
len(mobilenet.features), mobilenet.features[12]

(19,
 InvertedResidual(
   (conv): Sequential(
     (0): Conv2dNormActivation(
       (0): Conv2d(96, 576, kernel_size=(1, 1), stride=(1, 1), bias=False)
       (1): BatchNorm2d(576, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
       (2): ReLU6(inplace=True)
     )
     (1): Conv2dNormActivation(
       (0): Conv2d(576, 576, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=576, bias=False)
       (1): BatchNorm2d(576, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
       (2): ReLU6(inplace=True)
     )
     (2): Conv2d(576, 96, kernel_size=(1, 1), stride=(1, 1), bias=False)
     (3): BatchNorm2d(96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
   )
 ))

In [11]:
# Collect FlopCo statistics for the whole MobileNetV2 on the CPU device.
# `img_size` is presumably (batch, channels, height, width) — verify against
# FlopCo; `instances` lists the layer classes handed to FlopCo for tracking.
stats_mobilenet = FlopCo(
    mobilenet,
    img_size=(1, 3, 224, 224),
    device=device_cpu,
    instances=[nn.Conv2d, nn.BatchNorm2d, nn.Linear],
)


### Absolute flops in MobileNet

In [12]:
# Tally MobileNetV2's absolute FLOPs into three buckets.
# NOTE(review): the accumulator names do not match their contents —
#   conv_1x1_flops : Conv2d layers with groups > 1 and a non-1x1 kernel
#                    (printed as "Depthwise")
#   conv_kxk_flops : every other Conv2d, 1x1 convolutions included
#                    (printed as "Standart")
# The names are kept because they are notebook globals.
conv_kxk_flops = conv_1x1_flops = lins_flops = 0

for layer_name, layer_info in stats_mobilenet.ltypes.items():
    layer_cls = layer_info['type']
    if layer_cls is nn.Linear:
        # [0] reads the first recorded entry for this layer; presumably there is
        # one entry per forward call — verify if a layer is ever reused.
        lins_flops += stats_mobilenet.flops[layer_name][0]
    elif layer_cls is nn.Conv2d:
        grouped_kxk = layer_info['groups'] > 1 and layer_info['kernel_size'] != (1,1)
        if grouped_kxk:
            conv_1x1_flops += stats_mobilenet.flops[layer_name][0]
        else:
            conv_kxk_flops += stats_mobilenet.flops[layer_name][0]
    # Any other tracked layer type (e.g. BatchNorm2d) is left out of the tally.

for label, value in (
    ("Depthwise conv flop:\t", conv_1x1_flops),
    ("Standart conv flop:\t", conv_kxk_flops),
    ("Linear flop:\t", lins_flops),
):
    print(label, value)

Depthwise conv flop:	 41432832
Standart conv flop:	 557555712
Linear flop:	 2561280


### Relative flops in MobileNet

In [13]:
# Tally MobileNetV2's relative FLOPs (values read from `relative_flops`) into
# the same three buckets used for the absolute counts.
# NOTE(review): the accumulator names do not match their contents —
#   conv_1x1_flops : Conv2d layers with groups > 1 and a non-1x1 kernel
#                    (printed as "Depthwise")
#   conv_kxk_flops : every other Conv2d, 1x1 convolutions included
#                    (printed as "Standart")
# The names are kept because they are notebook globals.
conv_kxk_flops = conv_1x1_flops = lins_flops = 0

for layer_name, layer_info in stats_mobilenet.ltypes.items():
    layer_cls = layer_info['type']
    if layer_cls is nn.Linear:
        lins_flops += stats_mobilenet.relative_flops[layer_name]
    elif layer_cls is nn.Conv2d:
        grouped_kxk = layer_info['groups'] > 1 and layer_info['kernel_size'] != (1,1)
        if grouped_kxk:
            conv_1x1_flops += stats_mobilenet.relative_flops[layer_name]
        else:
            conv_kxk_flops += stats_mobilenet.relative_flops[layer_name]
    # Any other tracked layer type (e.g. BatchNorm2d) is left out of the tally.

# Shares are rounded to four decimals for display only.
for label, share in (
    ("Depthwise conv flop:\t", conv_1x1_flops),
    ("Standart conv flop:\t", conv_kxk_flops),
    ("Linear flop:\t", lins_flops),
):
    print(label, round(share, 4))

Depthwise conv flop:	 0.0659
Standart conv flop:	 0.8875
Linear flop:	 0.0041
