In [1]:
import torch, warnings

import torchvision.models as models

from models.build import NoAction
from utils.train import num_params

warnings.filterwarnings("ignore", category=FutureWarning)
warnings.filterwarnings("ignore", category=UserWarning)

In [2]:
model_creating_fns = [
    models.resnet18,
    models.resnet50,
    models.alexnet,
    models.vgg16,
    models.squeezenet1_0,
    models.densenet161,
    models.inception_v3,
    models.googlenet,
    models.shufflenet_v2_x1_0,
    models.mobilenet_v2,
    models.mobilenet_v3_large,
    models.mobilenet_v3_small,
    models.resnext50_32x4d,
    models.wide_resnet50_2,
    models.mnasnet1_0,
    models.efficientnet_b0,
    models.efficientnet_b1,
    models.efficientnet_b2,
    models.efficientnet_b3,
    models.efficientnet_b4,
    models.efficientnet_b5,
    models.efficientnet_b6,
    models.efficientnet_b7,
    models.regnet_y_400mf,
    models.regnet_y_800mf,
    models.regnet_y_1_6gf,
    models.regnet_y_3_2gf,
    models.regnet_y_8gf,
    models.regnet_y_16gf,
    models.regnet_y_32gf,
    models.regnet_y_128gf,
    models.regnet_x_400mf,
    models.regnet_x_800mf,
    models.regnet_x_1_6gf,
    models.regnet_x_3_2gf,
    models.regnet_x_8gf,
    models.regnet_x_16gf,
    models.regnet_x_32gf,
    models.vit_b_16,
    models.vit_b_32,
    models.vit_l_16,
    models.vit_l_32,
    models.convnext_tiny,
    models.convnext_small,
    models.convnext_base,
    models.convnext_large,
]

In [3]:
test_input = torch.randn(1,3,512,512)

In [4]:
for creator in model_creating_fns:
    model = creator()   

    out_channels = "Unknown"
    
    if hasattr(model, "features"):
        out_channels = model.features(test_input).shape

    elif hasattr(model, "fc"):
        model.fc = NoAction()
        out_channels = model(test_input).shape

    elif hasattr(model, "classifier"):
        model.classifier = NoAction()
        out_channels = model(test_input).shape

    print(f"| [{creator.__name__}] | Size: [{num_params(model):,}] | Out: [{out_channels}] |")

| [resnet18] | Size: [11,176,512] | Out: [torch.Size([1, 512, 1, 1])] |
| [resnet50] | Size: [23,508,032] | Out: [torch.Size([1, 2048, 1, 1])] |
| [alexnet] | Size: [61,100,840] | Out: [torch.Size([1, 256, 15, 15])] |
| [vgg16] | Size: [138,357,544] | Out: [torch.Size([1, 512, 16, 16])] |
| [squeezenet1_0] | Size: [1,248,424] | Out: [torch.Size([1, 512, 31, 31])] |


KeyboardInterrupt: 

In [7]:
m = models.resnet18()
m.fc = NoAction()
m.avgpool = NoAction()
out = m(test_input)



In [8]:
out.shape

torch.Size([1, 512, 16, 16])

In [9]:
out.logits.shape

AttributeError: 'Tensor' object has no attribute 'logits'