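# Base NN Architectures

Trainable-parameter counts for each backbone imported from `scaling.models` (ResNet, ResNeXt, ConvNeXt, ViT), plus a forward-pass smoke test of each model on a random input of shape `(1, 12, 992)`.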

In [1]:
import torch

from scaling.models.resnet import resnet18, resnet50, resnet101, resnext18, resnext50, resnext101
from scaling.models.convnext import convnext_mini, convnext_tiny, convnext_small, convnext_base, convnext_large
from scaling.models.vit import vit_tiny, vit_small, vit_base, vit_large

In [2]:
def count_params(model, trainable_only=True):
    """Return the number of scalar parameter elements in ``model``.

    Args:
        model: a ``torch.nn.Module``.
        trainable_only: when True (the default, and the original behaviour),
            count only parameters with ``requires_grad=True``; when False,
            count every parameter, frozen ones included.

    Returns:
        int: the summed ``numel()`` of the selected parameters (0 if none).
    """
    return sum(
        p.numel()
        for p in model.parameters()
        if p.requires_grad or not trainable_only
    )

def test_models(*models, input_shape=(1, 12, 992)):
    """Print each model's trainable-parameter count and run it on a random input.

    Args:
        *models: zero-argument callables (e.g. ``resnet18``) that build a model.
        input_shape: keyword-only; shape of the ``torch.rand`` dummy input
            passed to each model. The default ``(1, 12, 992)`` is the shape
            used throughout this notebook.

    The forward pass is only a smoke test: its output is discarded, and any
    exception raised by a model propagates to the caller.
    """
    for model_fn in models:
        model = model_fn()
        # Inference mode: any BatchNorm layers use (and do not update) their
        # running statistics, and any Dropout layers are disabled.
        model.eval()
        # functools.partial and similar callables have no __name__.
        name = getattr(model_fn, "__name__", repr(model_fn))
        print(f"{name} params: \t {count_params(model):,}")
        with torch.no_grad():
            model(torch.rand(*input_shape))

In [3]:
# ResNet: parameter count + forward-pass smoke test for resnet18/50/101.
test_models(
    resnet18,
    resnet50,
    resnet101,
)

resnet18 params: 	 3,862,170
resnet50 params: 	 16,012,442
resnet101 params: 	 28,319,898


In [4]:
# ResNeXt: parameter count + forward-pass smoke test for resnext18/50/101.
test_models(
    resnext18,
    resnext50,
    resnext101,
)

resnext18 params: 	 12,867,482
resnext50 params: 	 22,086,042
resnext101 params: 	 79,676,826


In [5]:
# ConvNeXt: parameter count + forward-pass smoke test, mini through large.
test_models(
    convnext_mini,
    convnext_tiny,
    convnext_small,
    convnext_base,
    convnext_large,
)

convnext_mini params: 	 13,383,098
convnext_tiny params: 	 26,787,770
convnext_small params: 	 48,132,026
convnext_base params: 	 85,458,842
convnext_large params: 	 192,036,698


In [6]:
# ViTs: parameter count + forward-pass smoke test, tiny through large.
test_models(
    vit_tiny,
    vit_small,
    vit_base,
    vit_large,
)

No GPU detected, using math or mem efficient attention
vit_tiny params: 	 6,421,834
vit_small params: 	 25,441,114
vit_base params: 	 85,430,186
vit_large params: 	 302,982,298
