In [1]:
import torchvision.models as models
from utils.train import num_params
import warnings
import torch

warnings.filterwarnings("ignore", category=FutureWarning)
warnings.filterwarnings("ignore", category=UserWarning)



In [2]:
model_creating_fns = [
    models.resnet18,
    models.resnet50,
    models.alexnet,
    models.vgg16,
    models.squeezenet1_0,
    models.densenet161,
    # models.inception_v3,
    # models.googlenet,
    models.shufflenet_v2_x1_0,
    models.mobilenet_v2,
    models.mobilenet_v3_large,
    models.mobilenet_v3_small,
    models.resnext50_32x4d,
    models.wide_resnet50_2,
    models.mnasnet1_0,
    models.efficientnet_b0,
    models.efficientnet_b1,
    models.efficientnet_b2,
    models.efficientnet_b3,
    models.efficientnet_b4,
    models.efficientnet_b5,
    models.efficientnet_b6,
    models.efficientnet_b7,
    models.regnet_y_400mf,
    models.regnet_y_800mf,
    models.regnet_y_1_6gf,
    models.regnet_y_3_2gf,
    models.regnet_y_8gf,
    models.regnet_y_16gf,
    models.regnet_y_32gf,
    models.regnet_y_128gf,
    models.regnet_x_400mf,
    models.regnet_x_800mf,
    models.regnet_x_1_6gf,
    models.regnet_x_3_2gf,
    models.regnet_x_8gf,
    models.regnet_x_16gf,
    models.regnet_x_32gf,
    models.vit_b_16,
    models.vit_b_32,
    models.vit_l_16,
    models.vit_l_32,
    models.convnext_tiny,
    models.convnext_small,
    models.convnext_base,
    models.convnext_large,
]



In [3]:
import torch.nn as nn

class NoAction(nn.Module):
    def __init__(self,):
        super().__init__()

    def forward(self,x):
        return x


In [4]:
test_input = torch.randn(1,3,1024,1024)

In [8]:
m = models.densenet161()

In [9]:
m

DenseNet(
  (features): Sequential(
    (conv0): Conv2d(3, 96, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (norm0): BatchNorm2d(96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu0): ReLU(inplace=True)
    (pool0): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (denseblock1): _DenseBlock(
      (denselayer1): _DenseLayer(
        (norm1): BatchNorm2d(96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu1): ReLU(inplace=True)
        (conv1): Conv2d(96, 192, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (norm2): BatchNorm2d(192, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu2): ReLU(inplace=True)
        (conv2): Conv2d(192, 48, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      )
      (denselayer2): _DenseLayer(
        (norm1): BatchNorm2d(144, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (rel

In [5]:
for creator in model_creating_fns:
    model = creator()   

    out_channels = "Unknown"
    
    if hasattr(model, "features"):
        out_channels = model.features(test_input).shape[1]

    elif hasattr(model, "fc"):
        model.fc = NoAction()
        out_channels = model(test_input).shape[1]

    elif hasattr(model, "classifier"):
        model.classifier = NoAction()
        out_channels = model(test_input).shape[1]

    print(f"| [{creator.__name__}] | Size: [{num_params(model):,}] | Out: [{out_channels}] |")

| [resnet18] | Size: [11,176,512] | Out: [512] |
| [resnet50] | Size: [23,508,032] | Out: [2048] |
| [alexnet] | Size: [61,100,840] | Out: [256] |
| [vgg16] | Size: [138,357,544] | Out: [512] |
| [squeezenet1_0] | Size: [1,248,424] | Out: [512] |
| [densenet161] | Size: [28,681,000] | Out: [2208] |
| [shufflenet_v2_x1_0] | Size: [1,253,604] | Out: [1024] |
| [mobilenet_v2] | Size: [3,504,872] | Out: [1280] |
| [mobilenet_v3_large] | Size: [5,483,032] | Out: [960] |
| [mobilenet_v3_small] | Size: [2,542,856] | Out: [576] |
| [resnext50_32x4d] | Size: [22,979,904] | Out: [2048] |
| [wide_resnet50_2] | Size: [66,834,240] | Out: [2048] |
| [mnasnet1_0] | Size: [3,102,312] | Out: [1280] |
| [efficientnet_b0] | Size: [5,288,548] | Out: [1280] |
| [efficientnet_b1] | Size: [7,794,184] | Out: [1280] |
| [efficientnet_b2] | Size: [9,109,994] | Out: [1408] |
| [efficientnet_b3] | Size: [12,233,232] | Out: [1536] |
| [efficientnet_b4] | Size: [19,341,616] | Out: [1792] |
| [efficientnet_b5] | Siz

In [6]:
creator.__name__

'convnext_large'