In [None]:
# %pip installs into the running kernel's environment (a bare !pip may hit a different one).
# Pinned to the thop release this notebook was last executed with, for reproducible counts.
%pip install thop==0.1.1.post2209072238

Collecting thop
  Using cached thop-0.1.1.post2209072238-py3-none-any.whl.metadata (2.7 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch->thop)
  Using cached nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch->thop)
  Using cached nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch->thop)
  Using cached nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch->thop)
  Using cached nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.4.5.8 (from torch->thop)
  Using cached nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cufft-cu12==11.2.1.3 (from torch->thop)
  Using cached nvidia_cufft_cu12-11.2.1.3-py3-none-man

In [None]:
import torch
import torchvision.models as models
import timm
from thop import profile, clever_format

def get_model_stats(model_name, model_creator_fn, input_size=(1, 3, 224, 224)):
    """
    Profile a model and print its parameter count and estimated FLOPs.

    The model is built by calling ``model_creator_fn``, switched to eval mode and
    run once on a random tensor of shape ``input_size`` under ``thop.profile``.

    Args:
        model_name (str): Name of the model for display purposes.
        model_creator_fn (function): Zero-argument callable returning the model instance.
        input_size (tuple, optional): The input tensor shape (batch_size, channels, height, width).
                                      Defaults to (1, 3, 224, 224).

    Returns:
        tuple: ``(mparams, gflops)`` -- parameter count in millions and FLOPs in
        billions -- or ``(None, None)`` if building or profiling the model failed
        (the error is printed, not raised).

    Notes:
        * Parameters are counted directly from ``model.parameters()`` rather than
          taken from thop. thop only tallies parameters of modules whose forward
          hook fires, so bare ``nn.Parameter`` attributes and layers whose weights
          are applied through ``torch.nn.functional`` are left out of its total.
        * FLOPs are approximated as ``2 * MACs``.
          NOTE(review): thop's MAC count presumably has the same blind spot
          (functional ops such as attention matmuls are not hooked), so figures
          for transformer models may be underestimates -- confirm with a second
          counter before quoting them.
    """
    try:
        model = model_creator_fn()
        model.eval()  # Set model to evaluation mode

        # Authoritative parameter count: every unique parameter tensor of the model,
        # taken before thop touches the module tree.
        params = sum(p.numel() for p in model.parameters())

        # Create a dummy input tensor
        dummy_input = torch.randn(input_size)

        # thop.profile returns (MACs, params); only the MACs are used here because
        # its parameter total can miss weights (see Notes in the docstring).
        macs, _ = profile(model, inputs=(dummy_input,), verbose=False)

        # Convert MACs to FLOPs (commonly approximated as 2 * MACs)
        flops = 2 * macs

        # Use clever_format to get human-readable strings
        params_str, flops_str = clever_format([params, flops], "%.3f")

        # Fixed-unit values (millions / billions) for tabulation and the return value
        gflops = flops / 1e9
        mparams = params / 1e6

        print(f"--- {model_name} ---")
        print(f"Input size: {input_size}")
        print(f"Parameters: {params_str} ({mparams:.3f} M)")
        print(f"FLOPs: {flops_str} ({gflops:.3f} GFLOPs)")
        print("-" * 30)
        return mparams, gflops

    except Exception as e:
        # Deliberate best-effort: one model failing must not abort the whole comparison.
        print(f"Could not profile {model_name}: {e}")
        print("-" * 30)
        return None, None

if __name__ == '__main__':
    # Input resolutions used for the comparison. The ResNets are profiled on
    # 500x500 images (not the usual 224x224 ImageNet crop); the SwinV2 variants
    # below are the 256x256 configurations.
    input_res = 500
    swin_res = 256
    input_tensor_size = (1, 3, input_res, input_res)
    input_tensor_size_256 = (1, 3, swin_res, swin_res)

    # One entry per model: (display name, zero-argument factory, input shape).
    # Factories are lambdas so each model is only built when it is profiled.
    # No pretrained weights are loaded (weights=None / pretrained=False): random
    # initialisation is sufficient for counting parameters and operations.
    benchmark_specs = [
        # --- torchvision ResNets ---
        ("ResNet-18", lambda: models.resnet18(weights=None), input_tensor_size),
        ("ResNet-34", lambda: models.resnet34(weights=None), input_tensor_size),
        ("ResNet-50", lambda: models.resnet50(weights=None), input_tensor_size),

        # --- timm SwinV2, window size 16 ---
        (
            "Swinv2 Tiny (W16)",
            lambda: timm.create_model('swinv2_tiny_window16_256.ms_in1k', pretrained=False),
            input_tensor_size_256,
        ),
        (
            "Swinv2 Small (W16)",
            lambda: timm.create_model('swinv2_small_window16_256.ms_in1k', pretrained=False),
            input_tensor_size_256,
        ),

        # --- timm SwinV2, window size 8 ---
        # The model name already selects the window-8 variant; window_size=8 is
        # passed explicitly as well.
        (
            "Swinv2 Tiny (W8)",
            lambda: timm.create_model('swinv2_tiny_window8_256.ms_in1k', pretrained=False, window_size=8),
            input_tensor_size_256,
        ),
    ]

    # Profile in listing order so the printed report matches the table above.
    for display_name, build_model, input_shape in benchmark_specs:
        get_model_stats(display_name, build_model, input_shape)

    # To compare against a SwinV1 model, append another tuple, e.g. a factory calling
    # timm.create_model('swin_tiny_patch4_window7_224', pretrained=False) with a
    # (1, 3, 224, 224) input -- not exercised here, verify the name/resolution first.


--- ResNet-18 ---
Input size: (1, 3, 500, 500)
Parameters: 11.690M (11.690 M)
FLOPs: 18.637G (18.637 GFLOPs)
------------------------------
--- ResNet-34 ---
Input size: (1, 3, 500, 500)
Parameters: 21.798M (21.798 M)
FLOPs: 37.754G (37.754 GFLOPs)
------------------------------
--- ResNet-50 ---
Input size: (1, 3, 500, 500)
Parameters: 25.557M (25.557 M)
FLOPs: 42.398G (42.398 GFLOPs)
------------------------------
--- Swinv2 Tiny (W16) ---
Input size: (1, 3, 256, 256)
Parameters: 21.869M (21.869 M)
FLOPs: 8.888G (8.888 GFLOPs)
------------------------------
--- Swinv2 Small (W16) ---
Input size: (1, 3, 256, 256)
Parameters: 37.932M (37.932 M)
FLOPs: 17.283G (17.283 GFLOPs)
------------------------------
--- Swinv2 Tiny (W8) ---
Input size: (1, 3, 256, 256)
Parameters: 21.869M (21.869 M)
FLOPs: 8.742G (8.742 GFLOPs)
------------------------------
