In [2]:
pip install fvcore

Collecting fvcore
  Downloading fvcore-0.1.5.post20221221.tar.gz (50 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m50.2/50.2 kB[0m [31m3.7 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting yacs>=0.1.6 (from fvcore)
  Downloading yacs-0.1.8-py3-none-any.whl.metadata (639 bytes)
Collecting iopath>=0.1.7 (from fvcore)
  Downloading iopath-0.1.10.tar.gz (42 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m42.2/42.2 kB[0m [31m2.7 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting portalocker (from iopath>=0.1.7->fvcore)
  Downloading portalocker-2.10.1-py3-none-any.whl.metadata (8.5 kB)
Downloading yacs-0.1.8-py3-none-any.whl (14 kB)
Downloading portalocker-2.10.1-py3-none-any.whl (18 kB)
Building wheels for collected packages: fvcore, iopath
  Building wheel for fvcore (setup.py) ... [?25l[?25hdone
  Created wheel for fvcore: filename=fvcore-0.1.

In [4]:
import torch
import torchvision.models as models
from transformers import BertModel
from fvcore.nn import FlopCountAnalysis, flop_count_table, flop_count_str

def analyze_model_flops(model, input_tensor, model_name):
    flops = FlopCountAnalysis(model, input_tensor)

    print(f"Total FLOPs for {model_name}: {flops.total()}")

    print(f"FLOPs by operator for {model_name}: {flops.by_operator()}")

    print(f"FLOPs by module for {model_name}: {flops.by_module()}")

    print(f"FLOPs by module and operator for {model_name}: {flops.by_module_and_operator()}")

    print(flop_count_table(flops))

    print(flop_count_str(flops))

dummy_input_cnn = torch.randn(1, 3, 224, 224)
dummy_input_bert = torch.randint(0, 30522, (1, 128))

In [5]:
vgg16_model = models.vgg16(weights=models.VGG16_Weights.IMAGENET1K_V1)
vgg16_model.classifier[6] = torch.nn.Linear(in_features=4096, out_features=4)
analyze_model_flops(vgg16_model, dummy_input_cnn, "VGG16")



Total FLOPs for VGG16: 15466209792
FLOPs by operator for VGG16: Counter({'conv': 15346630656, 'linear': 119554048, 'adaptive_avg_pool2d': 25088})
FLOPs by module for VGG16: Counter({'': 15466209792, 'features': 15346630656, 'features.2': 1849688064, 'features.7': 1849688064, 'features.12': 1849688064, 'features.14': 1849688064, 'features.19': 1849688064, 'features.21': 1849688064, 'features.5': 924844032, 'features.10': 924844032, 'features.17': 924844032, 'features.24': 462422016, 'features.26': 462422016, 'features.28': 462422016, 'classifier': 119554048, 'classifier.0': 102760448, 'features.0': 86704128, 'classifier.3': 16777216, 'avgpool': 25088, 'classifier.6': 16384, 'features.1': 0, 'features.3': 0, 'features.4': 0, 'features.6': 0, 'features.8': 0, 'features.9': 0, 'features.11': 0, 'features.13': 0, 'features.15': 0, 'features.16': 0, 'features.18': 0, 'features.20': 0, 'features.22': 0, 'features.23': 0, 'features.25': 0, 'features.27': 0, 'features.29': 0, 'features.30': 0, 

In [6]:
resnet50_model = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)
resnet50_model.fc = torch.nn.Linear(in_features=2048, out_features=4)
analyze_model_flops(resnet50_model, dummy_input_cnn, "ResNet50")



Total FLOPs for ResNet50: 4142814720
FLOPs by operator for ResNet50: Counter({'conv': 4087136256, 'batch_norm': 55569920, 'adaptive_avg_pool2d': 100352, 'linear': 8192})
FLOPs by module for ResNet50: Counter({'': 4142814720, 'layer3': 1475124224, 'layer2': 1043159040, 'layer4': 812374528, 'layer1': 690020352, 'layer2.0': 379029504, 'layer3.0': 375768064, 'layer4.0': 374137344, 'layer1.0': 241246208, 'layer1.1': 224387072, 'layer1.2': 224387072, 'layer2.1': 221376512, 'layer2.2': 221376512, 'layer2.3': 221376512, 'layer3.1': 219871232, 'layer3.2': 219871232, 'layer3.3': 219871232, 'layer3.4': 219871232, 'layer3.5': 219871232, 'layer4.1': 219118592, 'layer4.2': 219118592, 'conv1': 118013952, 'layer1.0.conv2': 115605504, 'layer1.1.conv2': 115605504, 'layer1.2.conv2': 115605504, 'layer2.0.conv2': 115605504, 'layer2.1.conv2': 115605504, 'layer2.2.conv2': 115605504, 'layer2.3.conv2': 115605504, 'layer3.0.conv2': 115605504, 'layer3.1.conv2': 115605504, 'layer3.2.conv2': 115605504, 'layer3.3.c

In [7]:
densenet121_model = models.densenet121(weights=models.DenseNet121_Weights.IMAGENET1K_V1)
densenet121_model.classifier = torch.nn.Linear(in_features=1024, out_features=4)
analyze_model_flops(densenet121_model, dummy_input_cnn, "DenseNet121")



Total FLOPs for DenseNet121: 2911529216
FLOPs by operator for DenseNet121: Counter({'conv': 2833137664, 'batch_norm': 78337280, 'adaptive_avg_pool2d': 50176, 'linear': 4096})
FLOPs by module for DenseNet121: Counter({'': 2911529216, 'features': 2911474944, 'features.denseblock1': 1066039296, 'features.denseblock2': 733221888, 'features.denseblock3': 566813184, 'features.denseblock1.denselayer6': 211040256, 'features.denseblock1.denselayer5': 197693440, 'features.denseblock1.denselayer4': 184346624, 'features.denseblock1.denselayer3': 170999808, 'features.denseblock1.denselayer2': 157652992, 'features.denseblock1.denselayer1': 144306176, 'features.conv0': 118013952, 'features.denseblock1.denselayer1.conv2': 115605504, 'features.denseblock1.denselayer2.conv2': 115605504, 'features.denseblock1.denselayer3.conv2': 115605504, 'features.denseblock1.denselayer4.conv2': 115605504, 'features.denseblock1.denselayer5.conv2': 115605504, 'features.denseblock1.denselayer6.conv2': 115605504, 'feature

In [8]:
bert_model = BertModel.from_pretrained("bert-base-uncased")
analyze_model_flops(bert_model, dummy_input_bert, "BERT-uncased")

encoder.layer.0.attention.self.dropout, encoder.layer.1.attention.self.dropout, encoder.layer.10.attention.self.dropout, encoder.layer.11.attention.self.dropout, encoder.layer.2.attention.self.dropout, encoder.layer.3.attention.self.dropout, encoder.layer.4.attention.self.dropout, encoder.layer.5.attention.self.dropout, encoder.layer.6.attention.self.dropout, encoder.layer.7.attention.self.dropout, encoder.layer.8.attention.self.dropout, encoder.layer.9.attention.self.dropout


Total FLOPs for BERT-uncased: 10884513792
FLOPs by operator for BERT-uncased: Counter({'linear': 10872225792, 'layer_norm': 12288000})
FLOPs by module for BERT-uncased: Counter({'': 10884513792, 'encoder': 10883432448, 'encoder.layer': 10883432448, 'encoder.layer.0': 906952704, 'encoder.layer.1': 906952704, 'encoder.layer.2': 906952704, 'encoder.layer.3': 906952704, 'encoder.layer.4': 906952704, 'encoder.layer.5': 906952704, 'encoder.layer.6': 906952704, 'encoder.layer.7': 906952704, 'encoder.layer.8': 906952704, 'encoder.layer.9': 906952704, 'encoder.layer.10': 906952704, 'encoder.layer.11': 906952704, 'encoder.layer.0.attention': 302481408, 'encoder.layer.0.output': 302481408, 'encoder.layer.1.attention': 302481408, 'encoder.layer.1.output': 302481408, 'encoder.layer.2.attention': 302481408, 'encoder.layer.2.output': 302481408, 'encoder.layer.3.attention': 302481408, 'encoder.layer.3.output': 302481408, 'encoder.layer.4.attention': 302481408, 'encoder.layer.4.output': 302481408, 'enc