In [1]:
# repghost_unet.py
# U-Net with RepGhost blocks (PyTorch >= 1.10)
# Paper / repo (concept & fusion idea): RepGhost: A Hardware-Efficient Ghost Module via Re-parameterization
# https://arxiv.org/abs/2211.06088  (module: add instead of concat; move ReLU; identity-BN branch; fuse for deploy)
# https://github.com/ChengpengChen/RepGhost

import torch
import torch.nn as nn
import torch.nn.functional as F


# -------------------------
# Utils
# -------------------------

def _bn_to_scale_shift(bn: nn.BatchNorm2d):
    """Return per-channel scale and shift for folding BN into preceding linear op."""
    # y = gamma * (x - mean) / sqrt(var + eps) + beta  ==  scale * x + shift
    gamma = bn.weight
    beta = bn.bias
    mean = bn.running_mean
    var = bn.running_var
    eps = bn.eps
    scale = gamma / torch.sqrt(var + eps)
    shift = beta - mean * scale
    return scale, shift


# -------------------------
# Squeeze-and-Excitation (optional)
# -------------------------

class SqueezeExcite(nn.Module):
    def __init__(self, channels, reduction=16):
        super().__init__()
        hidden = max(8, channels // reduction)
        self.avg = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Conv2d(channels, hidden, 1, bias=True),
            nn.ReLU(inplace=True),
            nn.Conv2d(hidden, channels, 1, bias=True),
            nn.Sigmoid(),
        )

    def forward(self, x):
        w = self.fc(self.avg(x))
        return x * w


# -------------------------
# RepGhost module (training-time)  -> convertible to a single DW conv (deploy)
# Diagram matches Fig. 3(d-e) in the paper: depthwise conv branch + identity-BN branch, then ReLU.
# We assume Cin == Cout inside the module (preceded/followed by 1x1 convs in the bottleneck).
# -------------------------

class RepGhostModule(nn.Module):
    """
    Training graph:
        y = BN_dw(DW(x)) + BN_id(x)
        y = ReLU(y)
    Deploy graph (after convert_to_deploy):
        y = DW_fused(x) + bias; ReLU(y)
    """
    def __init__(self, channels: int, ksize: int = 3, stride: int = 1, deploy: bool = False):
        super().__init__()
        padding = ksize // 2
        self.channels = channels
        self.stride = stride
        self.ksize = ksize
        self.deploy = deploy

        if deploy:
            # Single depthwise conv with bias (fused)
            self.reparam = nn.Conv2d(channels, channels, ksize, stride, padding,
                                     groups=channels, bias=True)
            self.act = nn.ReLU(inplace=True)
        else:
            # Depthwise conv branch + BN
            self.dw = nn.Conv2d(channels, channels, ksize, stride, padding,
                                groups=channels, bias=False)
            self.dw_bn = nn.BatchNorm2d(channels)

            # Identity branch with BN (no spatial conv)
            self.id_bn = nn.BatchNorm2d(channels)

            self.act = nn.ReLU(inplace=True)

    @torch.no_grad()
    def convert_to_deploy(self):
        """Fuse BN_dw(DW) + BN_id into a single DW conv with bias."""
        if self.deploy:
            return

        # 1) Fold BN into depthwise conv weights/bias
        scale_dw, shift_dw = _bn_to_scale_shift(self.dw_bn)
        # dw conv has no bias:
        Wdw = self.dw.weight.clone()  # [C,1,kh,kw]
        # scale each channel's kernel by its scale_dw
        Wdw = Wdw * scale_dw.view(-1, 1, 1, 1)
        bdw = shift_dw.clone()  # [C]

        # 2) Convert BN_id(x) to a depthwise conv with an identity kernel
        scale_id, shift_id = _bn_to_scale_shift(self.id_bn)
        # Build an impulse (identity) kernel for depthwise conv
        k = torch.zeros_like(Wdw)  # [C,1,kh,kw]
        center = self.ksize // 2
        k[:, 0, center, center] = scale_id

        # 3) Sum both linear ops (same groups/channels), sum biases too
        W_fused = Wdw + k
        b_fused = bdw + shift_id

        # 4) Create reparam conv and load weights
        self.reparam = nn.Conv2d(self.channels, self.channels,
                                 self.ksize, self.stride, self.ksize // 2,
                                 groups=self.channels, bias=True)
        self.reparam.weight.data.copy_(W_fused)
        self.reparam.bias.data.copy_(b_fused)

        # 5) Cleanup training branches
        del self.dw, self.dw_bn, self.id_bn
        self.deploy = True

    def forward(self, x):
        if self.deploy:
            return self.act(self.reparam(x))
        else:
            y = self.dw_bn(self.dw(x))
            y = y + self.id_bn(x)
            return self.act(y)


# -------------------------
# RepGhost Bottleneck-ish block:
# 1x1 PW conv -> RepGhostModule -> (optional SE) -> 1x1 PW conv -> RepGhostModule
# Keeps Cin/Cout flexible (U-Net-style). Residual optional when Cin==Cout.
# -------------------------

class RGBlock(nn.Module):
    def __init__(self, in_ch, out_ch, use_se=False, se_reduction=16, residual=False):
        super().__init__()
        mid = out_ch // 2  # "thinner" middle channels (Fig. 4b hint)
        mid = max(8, mid)

        self.proj1 = nn.Sequential(
            nn.Conv2d(in_ch, mid, 1, bias=False),
            nn.BatchNorm2d(mid),
            nn.ReLU(inplace=True),
        )
        self.rg1 = RepGhostModule(mid, ksize=3, stride=1)

        self.se = SqueezeExcite(mid) if use_se else nn.Identity()

        self.proj2 = nn.Sequential(
            nn.Conv2d(mid, out_ch, 1, bias=False),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
        )
        # second RG
        self.rg2 = RepGhostModule(out_ch, ksize=3, stride=1)

        self.residual = residual and (in_ch == out_ch)

    def forward(self, x):
        identity = x
        x = self.proj1(x)
        x = self.rg1(x)
        x = self.se(x)
        x = self.proj2(x)
        x = self.rg2(x)
        if self.residual:
            x = x + identity
        return x

    @torch.no_grad()
    def convert_to_deploy(self):
        self.rg1.convert_to_deploy()
        self.rg2.convert_to_deploy()


# -------------------------
# U-Net with RepGhost blocks
# -------------------------

class DoubleRG(nn.Module):
    def __init__(self, in_ch, out_ch, use_se=False):
        super().__init__()
        self.b1 = RGBlock(in_ch, out_ch, use_se=use_se, residual=False)
        self.b2 = RGBlock(out_ch, out_ch, use_se=use_se, residual=True)

    def forward(self, x):
        x = self.b1(x)
        x = self.b2(x)
        return x

    @torch.no_grad()
    def convert_to_deploy(self):
        self.b1.convert_to_deploy()
        self.b2.convert_to_deploy()


class Down(nn.Module):
    def __init__(self, in_ch, out_ch, use_se=False):
        super().__init__()
        self.pool = nn.MaxPool2d(2)
        self.block = DoubleRG(in_ch, out_ch, use_se=use_se)

    def forward(self, x):
        return self.block(self.pool(x))

    @torch.no_grad()
    def convert_to_deploy(self):
        self.block.convert_to_deploy()


class Up(nn.Module):
    def __init__(self, in_ch, out_ch, use_se=False, bilinear=False):
        super().__init__()
        if bilinear:
            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)
            self.reduce = nn.Conv2d(in_ch, out_ch, 1, bias=False)
        else:
            self.up = nn.ConvTranspose2d(in_ch, out_ch, 2, stride=2)
            self.reduce = nn.Identity()
        self.block = DoubleRG(out_ch * 2, out_ch, use_se=use_se)

    def forward(self, x, skip):
        x = self.up(x)
        x = self.reduce(x)
        # pad if needed (odd dims)
        diffY = skip.size(-2) - x.size(-2)
        diffX = skip.size(-1) - x.size(-1)
        if diffY != 0 or diffX != 0:
            x = F.pad(x, [diffX // 2, diffX - diffX // 2, diffY // 2, diffY - diffY // 2])
        x = torch.cat([skip, x], dim=1)
        return self.block(x)

    @torch.no_grad()
    def convert_to_deploy(self):
        self.block.convert_to_deploy()


class OutConv(nn.Module):
    def __init__(self, in_ch, n_classes):
        super().__init__
        self.conv = nn.Conv2d(in_ch, n_classes, kernel_size=1)

    def forward(self, x):
        return self.conv(x)


class RepGhostUNet(nn.Module):
    def __init__(self, n_channels=3, n_classes=1, base_ch=32, use_se=False, bilinear=False):
        """
        base_ch=32 is a good lightweight start. Use 64 for larger models.
        """
        super().__init__()
        c1, c2, c3, c4, c5 = base_ch, base_ch*2, base_ch*4, base_ch*8, base_ch*16

        self.inc   = DoubleRG(n_channels, c1, use_se=use_se)
        self.down1 = Down(c1, c2, use_se=use_se)
        self.down2 = Down(c2, c3, use_se=use_se)
        self.down3 = Down(c3, c4, use_se=use_se)
        self.down4 = Down(c4, c5, use_se=use_se)

        self.up1 = Up(c5, c4, use_se=use_se, bilinear=bilinear)
        self.up2 = Up(c4, c3, use_se=use_se, bilinear=bilinear)
        self.up3 = Up(c3, c2, use_se=use_se, bilinear=bilinear)
        self.up4 = Up(c2, c1, use_se=use_se, bilinear=bilinear)

        self.outc = nn.Conv2d(c1, n_classes, kernel_size=1)

    def forward(self, x):
        x1 = self.inc(x)     # [B, c1, H, W]
        x2 = self.down1(x1)  # [B, c2, H/2, W/2]
        x3 = self.down2(x2)  # [B, c3, H/4, W/4]
        x4 = self.down3(x3)  # [B, c4, H/8, W/8]
        x5 = self.down4(x4)  # [B, c5, H/16, W/16]

        x = self.up1(x5, x4)
        x = self.up2(x,  x3)
        x = self.up3(x,  x2)
        x = self.up4(x,  x1)
        logits = self.outc(x)
        return logits

    @torch.no_grad()
    def convert_to_deploy(self):
        """Fuse all RepGhost modules in-place for faster inference."""
        for m in self.modules():
            if isinstance(m, DoubleRG):
                m.convert_to_deploy()
            elif isinstance(m, Down) or isinstance(m, Up):
                m.convert_to_deploy()
            elif isinstance(m, RGBlock):
                m.convert_to_deploy()


# -------------------------
# Quick sanity test
# -------------------------
if __name__ == "__main__":
    model = RepGhostUNet(n_channels=3, n_classes=1, base_ch=32, use_se=False, bilinear=False)
    x = torch.randn(1, 3, 256, 256)
    y = model(x)
    print("out:", y.shape)  # -> [1, 1, 256, 256]

    # Convert to deploy (after training + eval)
    model.eval()
    model.convert_to_deploy()
    with torch.no_grad():
        y2 = model(x)
    print("deploy out:", y2.shape)


out: torch.Size([1, 1, 256, 256])
deploy out: torch.Size([1, 1, 256, 256])


In [2]:
# Model Profiling: GFLOPs, Memory, Parameters, and Inference Time
import torch
import time
import numpy as np
from typing import Tuple, Dict

def count_parameters(model):
    """Count total and trainable parameters."""
    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    return total_params, trainable_params

def get_model_size_mb(model):
    """Calculate model size in MB."""
    param_size = 0
    buffer_size = 0
    
    for param in model.parameters():
        param_size += param.nelement() * param.element_size()
    
    for buffer in model.buffers():
        buffer_size += buffer.nelement() * buffer.element_size()
    
    size_mb = (param_size + buffer_size) / 1024**2
    return size_mb

def calculate_flops(model, input_shape=(1, 3, 256, 256), device='cpu'):
    """
    Calculate FLOPs for the model using hook-based method.
    Returns GFLOPs (Giga FLOPs).
    """
    model = model.to(device)
    model.eval()
    
    flops_dict = {}
    
    def conv_hook(module, input, output):
        batch_size = input[0].size(0)
        output_height, output_width = output.size(2), output.size(3)
        
        kernel_height, kernel_width = module.kernel_size
        in_channels = module.in_channels
        out_channels = module.out_channels
        groups = module.groups
        
        # FLOPs = batch_size √ó output_spatial √ó (kernel_ops √ó in_channels / groups) √ó out_channels
        # kernel_ops = kernel_h √ó kernel_w
        # For bias, add output_height √ó output_width √ó out_channels
        
        conv_flops = batch_size * output_height * output_width * \
                     (kernel_height * kernel_width * in_channels // groups) * out_channels
        
        if module.bias is not None:
            conv_flops += batch_size * output_height * output_width * out_channels
        
        flops_dict[id(module)] = conv_flops
    
    def bn_hook(module, input, output):
        batch_size = input[0].size(0)
        flops = input[0].numel() * 2  # mean and variance
        flops_dict[id(module)] = flops
    
    def relu_hook(module, input, output):
        flops = input[0].numel()
        flops_dict[id(module)] = flops
    
    def linear_hook(module, input, output):
        batch_size = input[0].size(0)
        weight_ops = module.weight.numel()
        flops = batch_size * weight_ops
        if module.bias is not None:
            flops += batch_size * module.out_features
        flops_dict[id(module)] = flops
    
    hooks = []
    for module in model.modules():
        if isinstance(module, nn.Conv2d) or isinstance(module, nn.ConvTranspose2d):
            hooks.append(module.register_forward_hook(conv_hook))
        elif isinstance(module, nn.BatchNorm2d):
            hooks.append(module.register_forward_hook(bn_hook))
        elif isinstance(module, nn.ReLU):
            hooks.append(module.register_forward_hook(relu_hook))
        elif isinstance(module, nn.Linear):
            hooks.append(module.register_forward_hook(linear_hook))
    
    with torch.no_grad():
        dummy_input = torch.randn(input_shape).to(device)
        _ = model(dummy_input)
    
    for hook in hooks:
        hook.remove()
    
    total_flops = sum(flops_dict.values())
    gflops = total_flops / 1e9
    
    return gflops

def measure_inference_time(model, input_shape=(1, 3, 256, 256), device='cpu', warmup=10, iterations=100):
    """
    Measure average inference time with warmup.
    Returns time in milliseconds.
    """
    model = model.to(device)
    model.eval()
    
    dummy_input = torch.randn(input_shape).to(device)
    
    # Warmup
    with torch.no_grad():
        for _ in range(warmup):
            _ = model(dummy_input)
    
    # Measure
    if device == 'cuda':
        torch.cuda.synchronize()
    
    times = []
    with torch.no_grad():
        for _ in range(iterations):
            start = time.time()
            _ = model(dummy_input)
            if device == 'cuda':
                torch.cuda.synchronize()
            end = time.time()
            times.append(end - start)
    
    avg_time = np.mean(times) * 1000  # Convert to ms
    std_time = np.std(times) * 1000
    
    return avg_time, std_time

def get_activation_memory(model, input_shape=(1, 3, 256, 256), device='cpu'):
    """
    Estimate peak activation memory during forward pass.
    Returns memory in MB.
    """
    model = model.to(device)
    model.eval()
    
    activation_sizes = []
    
    def hook(module, input, output):
        if isinstance(output, torch.Tensor):
            size = output.numel() * output.element_size()
            activation_sizes.append(size)
        elif isinstance(output, (list, tuple)):
            for o in output:
                if isinstance(o, torch.Tensor):
                    size = o.numel() * o.element_size()
                    activation_sizes.append(size)
    
    hooks = []
    for module in model.modules():
        hooks.append(module.register_forward_hook(hook))
    
    with torch.no_grad():
        dummy_input = torch.randn(input_shape).to(device)
        _ = model(dummy_input)
    
    for h in hooks:
        h.remove()
    
    # Peak memory (sum of all activations stored)
    peak_memory_mb = sum(activation_sizes) / 1024**2
    
    return peak_memory_mb

def profile_model(model, input_shape=(1, 3, 256, 256), device='cpu', verbose=True):
    """
    Comprehensive model profiling.
    Returns dictionary with all metrics.
    """
    metrics = {}
    
    # Parameters
    total_params, trainable_params = count_parameters(model)
    metrics['total_parameters'] = total_params
    metrics['trainable_parameters'] = trainable_params
    metrics['non_trainable_parameters'] = total_params - trainable_params
    
    # Model size
    metrics['model_size_mb'] = get_model_size_mb(model)
    
    # GFLOPs
    metrics['gflops'] = calculate_flops(model, input_shape, device)
    
    # Activation memory
    metrics['activation_memory_mb'] = get_activation_memory(model, input_shape, device)
    
    # Total memory (model + activations)
    metrics['total_memory_mb'] = metrics['model_size_mb'] + metrics['activation_memory_mb']
    
    # Inference time
    avg_time, std_time = measure_inference_time(model, input_shape, device)
    metrics['avg_inference_time_ms'] = avg_time
    metrics['std_inference_time_ms'] = std_time
    metrics['fps'] = 1000.0 / avg_time
    
    if verbose:
        print("=" * 70)
        print("MODEL PROFILING RESULTS")
        print("=" * 70)
        print(f"\nüìä Model Architecture: {model.__class__.__name__}")
        print(f"   Input Shape: {input_shape}")
        print(f"   Device: {device}")
        print("\n" + "-" * 70)
        print("PARAMETERS:")
        print(f"   Total Parameters:        {total_params:,}")
        print(f"   Trainable Parameters:    {trainable_params:,}")
        print(f"   Non-trainable Parameters: {total_params - trainable_params:,}")
        print("\n" + "-" * 70)
        print("MEMORY:")
        print(f"   Model Size:              {metrics['model_size_mb']:.2f} MB")
        print(f"   Activation Memory:       {metrics['activation_memory_mb']:.2f} MB")
        print(f"   Total Memory:            {metrics['total_memory_mb']:.2f} MB")
        print("\n" + "-" * 70)
        print("COMPUTE:")
        print(f"   GFLOPs:                  {metrics['gflops']:.3f}")
        print(f"   Avg Inference Time:      {avg_time:.2f} ¬± {std_time:.2f} ms")
        print(f"   Throughput (FPS):        {metrics['fps']:.1f}")
        print("=" * 70)
    
    return metrics

In [3]:
# Example Usage: Profile RepGhostUNet model

# Create model
model = RepGhostUNet(n_channels=3, n_classes=1, base_ch=32, use_se=False, bilinear=False)

# Profile the model (training mode - before deployment)
print("\nüîç PROFILING TRAINING MODEL (Before convert_to_deploy)")
metrics_train = profile_model(model, input_shape=(1, 3, 256, 256), device='cpu', verbose=True)

# Convert to deployment mode
model.eval()
model.convert_to_deploy()

# Profile the deployed model
print("\n\nüöÄ PROFILING DEPLOYED MODEL (After convert_to_deploy)")
metrics_deploy = profile_model(model, input_shape=(1, 3, 256, 256), device='cpu', verbose=True)

# Compare training vs deployed
print("\n\nüìà COMPARISON: Training vs Deployed")
print("=" * 70)
print(f"Parameter Reduction:     {metrics_train['total_parameters']:,} ‚Üí {metrics_deploy['total_parameters']:,}")
print(f"Model Size Reduction:    {metrics_train['model_size_mb']:.2f} MB ‚Üí {metrics_deploy['model_size_mb']:.2f} MB")
print(f"GFLOPs Reduction:        {metrics_train['gflops']:.3f} ‚Üí {metrics_deploy['gflops']:.3f}")
print(f"Inference Time Speedup:  {metrics_train['avg_inference_time_ms']:.2f} ms ‚Üí {metrics_deploy['avg_inference_time_ms']:.2f} ms")
speedup = metrics_train['avg_inference_time_ms'] / metrics_deploy['avg_inference_time_ms']
print(f"Speedup Factor:          {speedup:.2f}x")
print("=" * 70)


üîç PROFILING TRAINING MODEL (Before convert_to_deploy)
MODEL PROFILING RESULTS

üìä Model Architecture: RepGhostUNet
   Input Shape: (1, 3, 256, 256)
   Device: cpu

----------------------------------------------------------------------
PARAMETERS:
   Total Parameters:        1,591,537
   Trainable Parameters:    1,591,537
   Non-trainable Parameters: 0

----------------------------------------------------------------------
MEMORY:
   Model Size:              6.17 MB
   Activation Memory:       1002.25 MB
   Total Memory:            1008.42 MB

----------------------------------------------------------------------
COMPUTE:
   GFLOPs:                  3.806
   Avg Inference Time:      55.67 ¬± 7.57 ms
   Throughput (FPS):        18.0


üöÄ PROFILING DEPLOYED MODEL (After convert_to_deploy)
MODEL PROFILING RESULTS

üìä Model Architecture: RepGhostUNet
   Input Shape: (1, 3, 256, 256)
   Device: cpu

----------------------------------------------------------------------
PARAMETERS:


In [4]:
# Optional: Profile different model configurations

configs = [
    {'base_ch': 16, 'use_se': False, 'name': 'Tiny (16 ch, no SE)'},
    {'base_ch': 32, 'use_se': False, 'name': 'Small (32 ch, no SE)'},
    {'base_ch': 32, 'use_se': True, 'name': 'Small (32 ch, with SE)'},
    {'base_ch': 64, 'use_se': False, 'name': 'Medium (64 ch, no SE)'},
    {'base_ch': 64, 'use_se': True, 'name': 'Medium (64 ch, with SE)'},
]

print("\nüìä COMPARING DIFFERENT MODEL CONFIGURATIONS")
print("=" * 100)
print(f"{'Configuration':<30} {'Params (M)':<15} {'GFLOPs':<12} {'Memory (MB)':<15} {'Time (ms)':<12} {'FPS':<10}")
print("=" * 100)

for config in configs:
    model = RepGhostUNet(n_channels=3, n_classes=1, 
                         base_ch=config['base_ch'], 
                         use_se=config['use_se'], 
                         bilinear=False)
    
    metrics = profile_model(model, input_shape=(1, 3, 256, 256), device='cpu', verbose=False)
    
    params_m = metrics['total_parameters'] / 1e6
    print(f"{config['name']:<30} {params_m:<15.2f} {metrics['gflops']:<12.3f} "
          f"{metrics['total_memory_mb']:<15.2f} {metrics['avg_inference_time_ms']:<12.2f} "
          f"{metrics['fps']:<10.1f}")

print("=" * 100)


üìä COMPARING DIFFERENT MODEL CONFIGURATIONS
Configuration                  Params (M)      GFLOPs       Memory (MB)     Time (ms)    FPS       
Tiny (16 ch, no SE)            0.41            1.056        503.01          26.65        37.5      
Tiny (16 ch, no SE)            0.41            1.056        503.01          26.65        37.5      
Small (32 ch, no SE)           1.59            3.806        1008.42         52.49        19.1      
Small (32 ch, no SE)           1.59            3.806        1008.42         52.49        19.1      
Small (32 ch, with SE)         1.62            3.806        1008.57         59.70        16.7      
Small (32 ch, with SE)         1.62            3.806        1008.57         59.70        16.7      
Medium (64 ch, no SE)          6.23            14.390       2027.98         106.02       9.4       
Medium (64 ch, no SE)          6.23            14.390       2027.98         106.02       9.4       
Medium (64 ch, with SE)        6.35            14.390

In [5]:
# Install required libraries (run once)
# !pip install thop torchinfo fvcore

In [6]:
# Model Profiling using Popular Libraries
from thop import profile, clever_format
from torchinfo import summary
import torch

def profile_with_thop(model, input_shape=(1, 3, 256, 256)):
    """Profile model using THOP library."""
    model.eval()
    dummy_input = torch.randn(input_shape)
    
    # Calculate FLOPs and parameters
    macs, params = profile(model, inputs=(dummy_input,), verbose=False)
    macs, params = clever_format([macs, params], "%.3f")
    
    print("=" * 70)
    print("THOP PROFILING")
    print("=" * 70)
    print(f"MACs (Multiply-Accumulate Operations): {macs}")
    print(f"Parameters: {params}")
    print("=" * 70)
    print("Note: FLOPs ‚âà 2 √ó MACs")
    
    return macs, params

def profile_with_torchinfo(model, input_shape=(1, 3, 256, 256), device='cpu'):
    """Profile model using torchinfo library."""
    print("\n" + "=" * 70)
    print("TORCHINFO SUMMARY")
    print("=" * 70)
    
    model_stats = summary(
        model,
        input_size=input_shape,
        device=device,
        col_names=["input_size", "output_size", "num_params", "mult_adds"],
        row_settings=["var_names"],
        verbose=0
    )
    
    print(model_stats)
    return model_stats

def profile_with_fvcore(model, input_shape=(1, 3, 256, 256)):
    """Profile model using fvcore library (optional - more detailed)."""
    try:
        from fvcore.nn import FlopCountAnalysis, parameter_count
        
        model.eval()
        dummy_input = torch.randn(input_shape)
        
        flops = FlopCountAnalysis(model, dummy_input)
        params = parameter_count(model)
        
        print("\n" + "=" * 70)
        print("FVCORE PROFILING")
        print("=" * 70)
        print(f"Total FLOPs: {flops.total():,}")
        print(f"Total FLOPs (GFLOPs): {flops.total() / 1e9:.3f}")
        print(f"Total Parameters: {params['']:,}")
        print("=" * 70)
        
        return flops.total(), params['']
    except ImportError:
        print("\nfvcore not installed. Install with: pip install fvcore")
        return None, None

In [7]:
# Example: Profile RepGhostUNet using libraries

# Create a fresh model
model = RepGhostUNet(n_channels=3, n_classes=1, base_ch=32, use_se=False, bilinear=False)

print("\nüîç PROFILING TRAINING MODEL (Before convert_to_deploy)\n")

# Method 1: THOP
print("\n1Ô∏è‚É£ Using THOP:")
macs, params = profile_with_thop(model, input_shape=(1, 3, 256, 256))

# Method 2: Torchinfo (most detailed)
print("\n2Ô∏è‚É£ Using Torchinfo:")
stats = profile_with_torchinfo(model, input_shape=(1, 3, 256, 256))

# Method 3: FVCore (optional)
print("\n3Ô∏è‚É£ Using FVCore:")
flops, params_fv = profile_with_fvcore(model, input_shape=(1, 3, 256, 256))

# Now test deployed model
print("\n\nüöÄ PROFILING DEPLOYED MODEL (After convert_to_deploy)\n")
model.eval()
model.convert_to_deploy()

print("\n1Ô∏è‚É£ Using THOP (Deployed):")
macs_deploy, params_deploy = profile_with_thop(model, input_shape=(1, 3, 256, 256))

print("\n2Ô∏è‚É£ Using Torchinfo (Deployed):")
stats_deploy = profile_with_torchinfo(model, input_shape=(1, 3, 256, 256))


üîç PROFILING TRAINING MODEL (Before convert_to_deploy)


1Ô∏è‚É£ Using THOP:
THOP PROFILING
MACs (Multiply-Accumulate Operations): 3.898G
Parameters: 1.592M
Note: FLOPs ‚âà 2 √ó MACs

2Ô∏è‚É£ Using Torchinfo:

TORCHINFO SUMMARY
Layer (type (var_name))                       Input Shape               Output Shape              Param #                   Mult-Adds
RepGhostUNet (RepGhostUNet)                   [1, 3, 256, 256]          [1, 1, 256, 256]          --                        --
‚îú‚îÄDoubleRG (inc)                              [1, 3, 256, 256]          [1, 32, 256, 256]         --                        --
‚îÇ    ‚îî‚îÄRGBlock (b1)                           [1, 3, 256, 256]          [1, 32, 256, 256]         --                        --
‚îÇ    ‚îÇ    ‚îî‚îÄSequential (proj1)                [1, 3, 256, 256]          [1, 16, 256, 256]         80                        3,145,760
‚îÇ    ‚îÇ    ‚îî‚îÄRepGhostModule (rg1)              [1, 16, 256, 256]         [1, 16, 256, 256]     

Unsupported operator aten::add encountered 45 time(s)
Unsupported operator aten::max_pool2d encountered 4 time(s)
Unsupported operator aten::max_pool2d encountered 4 time(s)


Total FLOPs: 2,143,420,416
Total FLOPs (GFLOPs): 2.143
Total Parameters: 1,591,537


üöÄ PROFILING DEPLOYED MODEL (After convert_to_deploy)


1Ô∏è‚É£ Using THOP (Deployed):
THOP PROFILING
MACs (Multiply-Accumulate Operations): 3.706G
Parameters: 1.578M
Note: FLOPs ‚âà 2 √ó MACs

2Ô∏è‚É£ Using Torchinfo (Deployed):

TORCHINFO SUMMARY
Layer (type (var_name))                       Input Shape               Output Shape              Param #                   Mult-Adds
RepGhostUNet (RepGhostUNet)                   [1, 3, 256, 256]          [1, 1, 256, 256]          --                        --
‚îú‚îÄDoubleRG (inc)                              [1, 3, 256, 256]          [1, 32, 256, 256]         --                        --
‚îÇ    ‚îî‚îÄRGBlock (b1)                           [1, 3, 256, 256]          [1, 32, 256, 256]         --                        --
‚îÇ    ‚îÇ    ‚îî‚îÄSequential (proj1)                [1, 3, 256, 256]          [1, 16, 256, 256]         80                        3,145,7

## üìö Library Comparison

### **THOP (Recommended for simplicity)**
- ‚úÖ Easy to use
- ‚úÖ Gives MACs and parameters in human-readable format
- ‚úÖ Lightweight
- ‚ùå Less detailed breakdown

### **Torchinfo (Recommended for detailed analysis)**
- ‚úÖ Very detailed layer-by-layer breakdown
- ‚úÖ Shows input/output shapes
- ‚úÖ Shows memory usage per layer
- ‚úÖ Beautiful formatting
- ‚úÖ Can export to different formats

### **FVCore (Facebook Research)**
- ‚úÖ Very accurate FLOPs counting
- ‚úÖ Used in official PyTorch repositories
- ‚úÖ Can provide per-operation breakdown
- ‚ùå Slightly more complex API

### **Quick Start:**
```python
# Install (uncomment the first cell above and run it)
# pip install thop torchinfo fvcore

# Then just run the example cells!
```

### Note on MACs vs FLOPs:
- **MACs** (Multiply-Accumulate Operations): One multiply + one add
- **FLOPs** (Floating Point Operations): Individual operations
- **Relationship**: FLOPs ‚âà 2 √ó MACs (approximately)