In [1]:
import math
import torch
import torch.nn as nn


In [2]:
def conv2d_flops(h, w, cin, cout, k):
    # mul + add
    return h * w * cout * (cin * k * k * 2)


In [3]:
def complex_conv_flops(x, module):
    """
    x: [B,C,H,W,2]
    """
    _, cin, h, w, _ = x.shape
    cout = module.out_channel
    k = module.conv_re.kernel_size[0]

    real_conv = conv2d_flops(h, w, cin, cout, k)
    return 2 * real_conv   # conv_re + conv_im


In [4]:
def fft2d_flops(x):
    """
    x: [B,C,H,W,2]
    """
    _, c, h, w, _ = x.shape
    n = h * w
    return 5 * c * n * math.log2(n)


In [5]:
def count_fgnet_flops(model, input_shape):
    """
    input_shape: (1,1,H,W,2)
    returns: total FLOPs (float)
    """

    model.eval()
    flops = 0

    # dummy inputs
    x = torch.zeros(input_shape)
    k = torch.zeros(input_shape)
    mask = torch.zeros((1,1,input_shape[2],1,1))

    def forward_hook(module, inp, out):
        nonlocal flops

        if isinstance(module, ComplexConv):
            flops += complex_conv_flops(inp[0], module)

        elif isinstance(module, (nn.Conv2d, nn.ConvTranspose2d)):
            # used only in attention blocks (real-valued)
            x0 = inp[0]
            _, cin, h, w = x0.shape
            cout = module.out_channels
            k = module.kernel_size[0]
            flops += conv2d_flops(h, w, cin, cout, k)

    # register hooks
    hooks = []
    for m in model.modules():
        hooks.append(m.register_forward_hook(forward_hook))

    # ---- manual FFT accounting (based on your forward) ----
    # HF branch
    flops += fft2d_flops(x) * 2    # RCAB has fft + ifft
    flops += fft2d_flops(x) * 2

    # 4 Recon RCAB blocks
    flops += 4 * (fft2d_flops(x) * 2)

    # DC block
    flops += fft2d_flops(x)
    flops += fft2d_flops(x)

    # forward once (shape propagation only)
    with torch.no_grad():
        model(x, k, mask)

    # cleanup
    for h in hooks:
        h.remove()

    return flops


In [6]:
%run HFGN_Model.ipynb

In [7]:
model = FGNet()

ckpt = torch.load("fgnet_best.pth")
model.load_state_dict(ckpt["model"])

model.eval()
print("✔ FGNet weights loaded")


✔ FGNet weights loaded


In [8]:
H, W = 320, 320   # your image size

flops = count_fgnet_flops(
    model,
    input_shape=(1,1,H,W,2)
)

print(f"FGNet FLOPs per slice: {flops/1e9:.2f} GFLOPs")


FGNet FLOPs per slice: 696.50 GFLOPs


In [9]:
# ============================================================
# FGNet FLOPs Counter (Single Cell, Safe, Reproducible)
# ============================================================

import math
import torch
import torch.nn as nn

# ------------------------------------------------------------
# FLOPs utilities
# ------------------------------------------------------------
def conv2d_flops(h, w, cin, cout, k):
    # mul + add
    return h * w * cout * (cin * k * k * 2)

def complex_conv_flops(x, module):
    """
    x: [B,C,H,W,2]
    ComplexConv = conv_re + conv_im
    """
    _, cin, h, w, _ = x.shape
    cout = module.out_channel
    k = module.conv_re.kernel_size[0]
    real_conv = conv2d_flops(h, w, cin, cout, k)
    return 2 * real_conv

def fft2d_flops(x):
    """
    x: [B,C,H,W,2]
    FFT complexity ~ 5*N*log2(N)
    """
    _, c, h, w, _ = x.shape
    n = h * w
    return 5 * c * n * math.log2(n)

# ------------------------------------------------------------
# FGNet FLOPs counter
# ------------------------------------------------------------
def count_fgnet_flops(model, input_shape):
    """
    input_shape: (1,1,H,W,2)
    returns FLOPs per slice
    """
    model.eval()
    total_flops = 0

    # ----------------------------
    # Dummy inputs (batch = 1)
    # ----------------------------
    x = torch.zeros(input_shape)
    k = torch.zeros(input_shape)
    mask = torch.zeros((1, 1, input_shape[2], 1, 1))

    # ----------------------------
    # Forward hook
    # ----------------------------
    def hook_fn(module, inp, out):
        nonlocal total_flops

        if isinstance(module, ComplexConv):
            total_flops += complex_conv_flops(inp[0], module)

        elif isinstance(module, (nn.Conv2d, nn.ConvTranspose2d)):
            x0 = inp[0]
            _, cin, h, w = x0.shape
            cout = module.out_channels
            k = module.kernel_size[0]
            total_flops += conv2d_flops(h, w, cin, cout, k)

    hooks = []
    for m in model.modules():
        hooks.append(m.register_forward_hook(hook_fn))

    # ----------------------------
    # Manual FFT / iFFT accounting
    # ----------------------------
    # HF branch: 2 RCAB (FFT + iFFT)
    total_flops += 2 * (fft2d_flops(x) * 2)

    # Recon branch: 4 RCAB (FFT + iFFT)
    total_flops += 4 * (fft2d_flops(x) * 2)

    # Data Consistency (fft + ifft)
    total_flops += fft2d_flops(x)
    total_flops += fft2d_flops(x)

    # ----------------------------
    # Forward pass (shape propagation)
    # ----------------------------
    with torch.no_grad():
        model(x, k, mask)

    # ----------------------------
    # Cleanup hooks
    # ----------------------------
    for h in hooks:
        h.remove()

    return total_flops

# ============================================================
# RUN FLOPs COUNT
# ============================================================
H, W = 320, 320   # image size

flops = count_fgnet_flops(
    model,
    input_shape=(1, 1, H, W, 2)
)

print(f"FGNet FLOPs per slice: {flops / 1e9:.2f} GFLOPs")


FGNet FLOPs per slice: 696.50 GFLOPs
