In [2]:
import matplotlib.pyplot as plt
import numpy as np

import torch
import torchvision
import torchvision.transforms as transforms
from torchvision.models import resnet101

import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from cosine_annealing_warmup import CosineAnnealingWarmupRestarts

from torchsummary import summary
from torch.utils.tensorboard import SummaryWriter
writer = SummaryWriter()

import pathlib
import os

In [3]:
# Randomly initialised ResNet-101 with a 10-way classifier head.
# `weights=None` replaces the `pretrained=False` keyword, which torchvision
# deprecated in 0.13 (it still works there but emits a warning).
model = resnet101(weights=None, num_classes=10)

# Small-image stem: torchvision's default stem (7x7 stride-2 conv followed by a
# stride-2 max-pool) would shrink a 32x32 input to 8x8 before layer1. A 3x3
# stride-1 conv with no max-pool keeps the full 32x32 resolution.
model.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
model.maxpool = nn.Identity()



In [5]:
class Fire(nn.Module):
    """SqueezeNet-style Fire module.

    A 1x1 "squeeze" convolution (+ReLU) feeds two parallel "expand"
    convolutions, 1x1 and 3x3 (each +ReLU). Their outputs are concatenated
    along dim 1, so the result has ``expand1x1_planes + expand3x3_planes``
    channels.

    Args:
        inplanes: number of channels of the incoming tensor.
        squeeze_planes: output channels of the squeeze conv.
        expand1x1_planes: output channels of the 1x1 expand branch.
        expand3x3_planes: output channels of the 3x3 expand branch.
        stride: stride applied to BOTH expand convs. The squeeze conv always
            uses stride 1, so ``stride=2`` halves the spatial size in the
            expand stage.
    """

    def __init__(self, inplanes, squeeze_planes,
                 expand1x1_planes, expand3x3_planes, stride=1):
        super().__init__()
        self.inplanes = inplanes

        # Squeeze stage: 1x1 bottleneck down to `squeeze_planes` channels.
        self.squeeze = nn.Conv2d(
            in_channels=inplanes,
            out_channels=squeeze_planes,
            kernel_size=1,
        )
        self.squeeze_activation = nn.ReLU(inplace=True)

        # Expand stage: both branches share the stride. padding=1 on the 3x3
        # branch keeps it the same spatial size as the 1x1 branch, so the two
        # can be concatenated channel-wise.
        self.expand1x1 = nn.Conv2d(
            in_channels=squeeze_planes,
            out_channels=expand1x1_planes,
            kernel_size=1,
            stride=stride,
        )
        self.expand1x1_activation = nn.ReLU(inplace=True)
        self.expand3x3 = nn.Conv2d(
            in_channels=squeeze_planes,
            out_channels=expand3x3_planes,
            kernel_size=3,
            padding=1,
            stride=stride,
        )
        self.expand3x3_activation = nn.ReLU(inplace=True)

    def forward(self, x):
        squeezed = self.squeeze_activation(self.squeeze(x))
        branch_1x1 = self.expand1x1_activation(self.expand1x1(squeezed))
        branch_3x3 = self.expand3x3_activation(self.expand3x3(squeezed))
        return torch.cat((branch_1x1, branch_3x3), dim=1)

In [6]:
# Replace every block of layer1..layer4 with a Fire module.
# A Fire block outputs 2 * expand_planes channels (its 1x1 and 3x3 branches are
# concatenated). So only the first block of a stage takes the previous stage's
# width (and carries the stride); every later block takes 2 * expand_planes.
# The number of blocks per stage is read from the model itself, so the same
# table works for any ResNet depth.
FIRE_STAGES = [
    # (stage attribute, in_planes of first block, squeeze_planes, expand_planes, first-block stride)
    ("layer1", 64, 64, 128, 1),
    ("layer2", 256, 128, 256, 2),
    ("layer3", 512, 256, 512, 2),
    ("layer4", 1024, 512, 1024, 2),
]

for stage_name, in_planes, squeeze_planes, expand_planes, first_stride in FIRE_STAGES:
    stage = getattr(model, stage_name)
    for block_idx in range(len(stage)):
        if block_idx == 0:
            stage[block_idx] = Fire(in_planes, squeeze_planes,
                                    expand_planes, expand_planes, first_stride)
        else:
            stage[block_idx] = Fire(2 * expand_planes, squeeze_planes,
                                    expand_planes, expand_planes)

# The classifier head must still receive the last stage's output width.
assert model.fc.in_features == 2 * FIRE_STAGES[-1][3], "fc input width mismatch"

In [7]:
# Put the model in inference mode (equivalent to `model.eval()`): BatchNorm
# layers use their running statistics. The call returns the module itself, so
# it is displayed as the cell output.
model.train(False)

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): Identity()
  (layer1): Sequential(
    (0): Fire(
      (squeeze): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1))
      (squeeze_activation): ReLU(inplace=True)
      (expand1x1): Conv2d(64, 128, kernel_size=(1, 1), stride=(1, 1))
      (expand1x1_activation): ReLU(inplace=True)
      (expand3x3): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (expand3x3_activation): ReLU(inplace=True)
    )
    (1): Fire(
      (squeeze): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1))
      (squeeze_activation): ReLU(inplace=True)
      (expand1x1): Conv2d(64, 128, kernel_size=(1, 1), stride=(1, 1))
      (expand1x1_activation): ReLU(inplace=True)
      (expand3x3): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (expand3

In [8]:
from prettytable import PrettyTable

def count_parameters(model):
    """Print a per-tensor table of trainable parameters and return their total.

    Only parameters with ``requires_grad=True`` are included. Frozen tensors are
    left out of both the table and the returned total.

    Args:
        model: module whose ``named_parameters()`` are inspected.

    Returns:
        Total number of trainable scalar parameters (an ``int``).
    """
    rows = [
        [name, tensor.numel()]
        for name, tensor in model.named_parameters()
        if tensor.requires_grad
    ]

    table = PrettyTable(["Modules", "Parameters"])
    for row in rows:
        table.add_row(row)

    total_params = sum(count for _, count in rows)
    print(table)
    print(f"Total Trainable Params: {total_params}")
    return total_params

# Per-tensor breakdown plus the trainable-only total (printed by the helper).
count_parameters(model)

# Grand total over every parameter tensor, trainable or frozen.
pytorch_total_params = sum(tensor.numel() for _, tensor in model.named_parameters())
print("Total Params: ", pytorch_total_params)

+----------------------------+------------+
|          Modules           | Parameters |
+----------------------------+------------+
|        conv1.weight        |    1728    |
|         bn1.weight         |     64     |
|          bn1.bias          |     64     |
|  layer1.0.squeeze.weight   |    4096    |
|   layer1.0.squeeze.bias    |     64     |
| layer1.0.expand1x1.weight  |    8192    |
|  layer1.0.expand1x1.bias   |    128     |
| layer1.0.expand3x3.weight  |   73728    |
|  layer1.0.expand3x3.bias   |    128     |
|  layer1.1.squeeze.weight   |   16384    |
|   layer1.1.squeeze.bias    |     64     |
| layer1.1.expand1x1.weight  |    8192    |
|  layer1.1.expand1x1.bias   |    128     |
| layer1.1.expand3x3.weight  |   73728    |
|  layer1.1.expand3x3.bias   |    128     |
|  layer1.2.squeeze.weight   |   16384    |
|   layer1.2.squeeze.bias    |     64     |
| layer1.2.expand1x1.weight  |    8192    |
|  layer1.2.expand1x1.bias   |    128     |
| layer1.2.expand3x3.weight  |  

In [9]:
pip install -U fvcore

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting fvcore
  Downloading fvcore-0.1.5.post20221221.tar.gz (50 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m50.2/50.2 KB[0m [31m7.6 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting yacs>=0.1.6
  Downloading yacs-0.1.8-py3-none-any.whl (14 kB)
Collecting iopath>=0.1.7
  Downloading iopath-0.1.10.tar.gz (42 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m42.2/42.2 KB[0m [31m6.3 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting portalocker
  Downloading portalocker-2.6.0-py2.py3-none-any.whl (15 kB)
Building wheels for collected packages: fvcore, iopath
  Building wheel for fvcore (setup.py) ... [?25l[?25hdone
  Created wheel for fvcore: filename=fvcore-0.1.5.post20221221-py3-none-any.whl size=61431 sha256=1339241ae5e4335dd5405485900c63e7

In [10]:
from fvcore.nn import FlopCountAnalysis
from fvcore.nn import flop_count_table

# One dummy 3x32x32 image. Only its shape matters for FLOP counting.
# It is named `dummy_input` so it does not shadow Python's built-in `input()`.
dummy_input = torch.rand(1, 3, 32, 32)

# NOTE: fvcore counts one fused multiply-add as one "flop", so these totals are
# closer to MAC counts than to strict floating-point operation counts.
flops = FlopCountAnalysis(model, dummy_input)
print(flop_count_table(flops))
print("Total number of FLOPS: ", flops.total())

| module                 | #parameters or shape   | #flops     |
|:-----------------------|:-----------------------|:-----------|
| model                  | 56.281M                | 3.362G     |
|  conv1                 |  1.728K                |  1.769M    |
|   conv1.weight         |   (64, 3, 3, 3)        |            |
|  bn1                   |  0.128K                |  0.131M    |
|   bn1.weight           |   (64,)                |            |
|   bn1.bias             |   (64,)                |            |
|  layer1                |  0.284M                |  0.289G    |
|   layer1.0             |   86.336K              |   88.08M   |
|    layer1.0.squeeze    |    4.16K               |    4.194M  |
|    layer1.0.expand1x1  |    8.32K               |    8.389M  |
|    layer1.0.expand3x3  |    73.856K             |    75.497M |
|   layer1.1             |   98.624K              |   0.101G   |
|    layer1.1.squeeze    |    16.448K             |    16.777M |
|    layer1.1.expand1x1  

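# Conclusion
This Fire-module variant of ResNet-101 is too large to justify further experiments: it has about 56.3 M parameters and needs about 3.36 G FLOPs (as counted by fvcore) for a single 3×32×32 input.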