In [1]:
import itertools

import pandas as pd
import torch
from rich.progress import track
from segmentation_models_pytorch import create_model

In [None]:
# Decoder architectures included in the sweep. "UnetPlusPlus", "DeepLabV3"
# and "DeepLabV3Plus" are currently disabled; append them here to include them.
archs = ["Unet", "MAnet", "UPerNet", "Segformer"]

# Encoder backbones to pair with every architecture above. They were picked
# to be of roughly comparable size ("around 20M parameters" per the original
# note — NOTE(review): confirm against the n_params column of the results).
# The "tu-" prefix presumably selects timm-provided encoders; verify in smp docs.
encoders = [
    "resnet50",
    "resnext50_32x4d",
    "mit_b2",
    "tu-convnextv2_tiny",
    "tu-hiera_tiny_224",
    "tu-mambaout_tiny",
    "tu-maxvit_tiny_rw_224",
    "tu-swin_s3_tiny_224",
]


c = 11
x = torch.randn(1, c, 256, 256)
x = torch.randn(1, c, 224, 224)

combs = []
for arch, encoder in track(itertools.product(archs, encoders), total=len(archs) * len(encoders)):
    error, n_params, out_shape = None, None, None
    try:
        model = create_model(arch, encoder, encoder_weights=None, in_channels=c)
        model.eval()
        y_hat = model(x)
        out_shape = y_hat.shape
        n_params = sum(p.numel() for p in model.parameters())
        n_params /= 1000 * 1000
        n_params = round(n_params, 2)
        # print(f"{arch} {encoder} {n_params / 1000 / 1000:.2f}M")
    except Exception as e:
        error = str(e)
    combs.append({"arch": arch, "encoder": encoder, "n_params": n_params, "out_shape": out_shape, "error": error})

combs = pd.DataFrame(combs)
combs