In [5]:
import segmentation_models_pytorch as smp
import torch
import time
from thop import profile

def count_parametersWD(module):
    return sum(p.numel() for p in module.parameters() if p.requires_grad)

# Configurations to evaluate
encoders = ["mobilenet_v2", "efficientnet-b0"]
decoders = {
    "Unet": smp.Unet,
    "Unet++": smp.UnetPlusPlus,
    "DeepLabV3": smp.DeepLabV3,
}

num_classes = 101   # use 101 for CropAndWeed
in_channels = 3
device = "cuda" if torch.cuda.is_available() else "cpu"

# Input resolution for FLOPs + inference test
input_size = (1, in_channels, 512, 512)   # batch_size=1, 512x512 image
dummy_input = torch.randn(input_size).to(device)

for enc in encoders:
    for dec_name, dec_class in decoders.items():
        model = dec_class(
            encoder_name=enc,
            encoder_weights="imagenet",
            in_channels=in_channels,
            classes=num_classes,
        ).to(device)
        model.eval()

        # Count params
        encoder_params = count_parametersWD(model.encoder)
        decoder_params = count_parametersWD(model.decoder)
        total_params = count_parametersWD(model)

        # FLOPs
        macs, params = profile(model, inputs=(dummy_input,), verbose=False)
        gflops = macs / 1e9

        # Inference time (average over N runs)
        N = 50
        with torch.no_grad():
            # warmup
            for _ in range(10):
                _ = model(dummy_input)

            torch.cuda.synchronize() if device == "cuda" else None
            start = time.perf_counter()

            for _ in range(N):
                _ = model(dummy_input)
            
            torch.cuda.synchronize() if device == "cuda" else None
            end = time.perf_counter()

        avg_time = (end - start) / N * 1000  # ms per image

        # Print results
        print(f"{dec_name} + {enc}")
        print(f"  Encoder params: {encoder_params/1e6:.2f}M")
        print(f"  Decoder params: {decoder_params/1e6:.2f}M")
        print(f"  Total params:   {total_params/1e6:.2f}M")
        print(f"  FLOPs:          {gflops:.2f} GFLOPs")
        print(f"  Inference time: {avg_time:.2f} ms/image\n")

Unet + mobilenet_v2
  Encoder params: 2.22M
  Decoder params: 4.40M
  Total params:   6.64M
  FLOPs:          17.36 GFLOPs
  Inference time: 4.83 ms/image

Unet++ + mobilenet_v2
  Encoder params: 2.22M
  Decoder params: 4.60M
  Total params:   6.84M
  FLOPs:          21.78 GFLOPs
  Inference time: 6.44 ms/image

DeepLabV3 + mobilenet_v2
  Encoder params: 2.22M
  Decoder params: 10.42M
  Total params:   12.67M
  FLOPs:          51.47 GFLOPs
  Inference time: 7.33 ms/image

Unet + efficientnet-b0
  Encoder params: 4.01M
  Decoder params: 2.24M
  Total params:   6.27M
  FLOPs:          13.91 GFLOPs
  Inference time: 8.28 ms/image

Unet++ + efficientnet-b0
  Encoder params: 4.01M
  Decoder params: 2.56M
  Total params:   6.58M
  FLOPs:          24.20 GFLOPs
  Inference time: 9.63 ms/image

DeepLabV3 + efficientnet-b0
  Encoder params: 4.01M
  Decoder params: 3.30M
  Total params:   7.33M
  FLOPs:          13.97 GFLOPs
  Inference time: 10.23 ms/image

