In [1]:
import os
import glob
import h5py
import numpy as np
import torch
from tqdm import tqdm
from skimage.metrics import peak_signal_noise_ratio, structural_similarity


In [2]:
# Run on the GPU when torch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
else:
    DEVICE = torch.device("cpu")
print("Using device:", DEVICE)


Using device: cuda


In [3]:
def to_complex(x):
    """Merge a trailing (real, imag) axis into a single complex-valued array.

    Args:
        x: array indexed as ``x[..., 0]`` (real part) and ``x[..., 1]``
            (imaginary part), e.g. of shape (H, W, 2).

    Returns:
        ``real + 1j * imag`` with the trailing axis removed.
    """
    real_part = x[..., 0]
    imag_part = x[..., 1]
    return real_part + 1j * imag_part

def complex_abs(x):
    """Element-wise magnitude of a tensor that stores complex values as pairs.

    Args:
        x: torch tensor whose last axis holds (real, imag), shape (..., 2).

    Returns:
        Tensor of shape (...) containing ``sqrt(real**2 + imag**2)``.
    """
    real_part, imag_part = x[..., 0], x[..., 1]
    squared_magnitude = real_part ** 2 + imag_part ** 2
    return torch.sqrt(squared_magnitude)

def nmse(gt, pred):
    """Normalised mean squared error between a reference and a prediction.

    Computes ``||gt - pred||^2 / (||gt||^2 + 1e-10)`` using ``np.linalg.norm``.

    Args:
        gt: reference array.
        pred: predicted array, broadcast-compatible with ``gt``.

    Returns:
        The error energy divided by the reference energy.
    """
    error_energy = np.linalg.norm(gt - pred) ** 2
    reference_energy = np.linalg.norm(gt) ** 2
    # The small constant keeps the denominator non-zero for an all-zero reference.
    return error_energy / (reference_energy + 1e-10)

def compute_ssim(gt, pred, max_val):
    """Structural similarity between a reference image and a prediction.

    Thin wrapper around ``structural_similarity`` that pins the settings used
    throughout this notebook.

    Args:
        gt: reference image.
        pred: predicted image, same shape as ``gt``.
        max_val: value passed as ``data_range``.

    Returns:
        Whatever ``structural_similarity`` returns for these settings.
    """
    ssim_options = {
        "data_range": max_val,
        "win_size": 9,
        "gaussian_weights": False,
        "use_sample_covariance": False,
        "K1": 0.01,
        "K2": 0.03,
    }
    return structural_similarity(gt, pred, **ssim_options)


In [4]:
MASK_PATH = "./mask_4x_320_random.npy"

# Sampling mask as stored on disk; the original note gives its shape as (1,1,H,1).
mask = np.load(MASK_PATH)

# float32 torch version of the mask, used when building model inputs.
mask_t = torch.from_numpy(mask).to(torch.float32)


In [6]:
# Execute HFGN_Model.ipynb inside this kernel's namespace.
# NOTE(review): presumably this is what defines `FGNet`, which the next cell
# instantiates without importing it — confirm against that notebook.
%run HFGN_Model.ipynb

In [8]:
# Build the network and place it on the device selected earlier in the notebook.
model = FGNet().to(DEVICE)

# map_location remaps the stored tensors onto DEVICE. Without it, a checkpoint
# saved from a CUDA session fails to deserialize on a CPU-only machine, even
# though DEVICE itself already falls back to the CPU.
ckpt = torch.load("fgnet_best.pth", map_location=DEVICE)
# The checkpoint is a dict; the weights live under the "model" key.
model.load_state_dict(ckpt["model"])

# Switch to evaluation mode before any inference.
model.eval()
print("✔ FGNet weights loaded")


✔ FGNet weights loaded


In [9]:
val_folder = r"D:\fastmri_singlecoil_FSSCAN\val_norm"

# Collect every .h5 file in the folder and put the list into sorted order.
file_paths = glob.glob(os.path.join(val_folder, "*.h5"))
file_paths.sort()

print("Number of volumes:", len(file_paths))


Number of volumes: 199


In [10]:
# Split the volumes into two groups by whether the *file name* contains "PDFS".
# The test is applied to the basename only, so a directory component that
# happens to contain "PDFS" cannot push every file into the PDFS group.
# The order of file_paths is preserved within each group.
pdfs_files = [path for path in file_paths if "PDFS" in os.path.basename(path)]
pd_files = [path for path in file_paths if "PDFS" not in os.path.basename(path)]

print(f"PD volumes: {len(pd_files)}")
print(f"PDFS volumes: {len(pdfs_files)}")


PD volumes: 100
PDFS volumes: 99


In [14]:
val_folder = r"D:\fastmri_singlecoil_FSSCAN\val_norm"

# Sorted list of every .h5 file in the validation folder.
kspace_files_list_val = sorted(glob.glob(os.path.join(val_folder, "*.h5")))

# Split by whether the *file name* contains "PDFS". Matching on the basename
# (not the full path) keeps a directory called e.g. "..._PDFS_..." from
# classifying every file as PDFS. Order within each group follows the sorted list.
pdfs_files = [
    path for path in kspace_files_list_val if "PDFS" in os.path.basename(path)
]
pd_files = [
    path for path in kspace_files_list_val if "PDFS" not in os.path.basename(path)
]

print(f"PD volumes: {len(pd_files)}")
print(f"PDFS volumes: {len(pdfs_files)}")


PD volumes: 100
PDFS volumes: 99


In [15]:
# Per-volume PSNR / NMSE and per-slice SSIM, accumulated over every file in pd_files.
psnr_list, nmse_list, ssim_list = [], [], []

with torch.no_grad():
    for vol_path in tqdm(pd_files, desc="Running FGNet inference"):
        # ---- Read one volume (original notes give each array as (S,H,W,2)) ----
        with h5py.File(vol_path, "r") as h5:
            img_us = h5["image_under"][:]
            img_gt = h5["image_full"][:]
            kspace = h5["kspace_full"][:]
            max_val = float(h5["max_val_full_image"][0])

        num_slices = img_us.shape[0]

        # ---- To torch; the model inputs also get a channel axis at dim 1 ----
        img_us_t = torch.from_numpy(img_us).float().to(DEVICE).unsqueeze(1)
        img_gt_t = torch.from_numpy(img_gt).float().to(DEVICE)
        kspace_t = torch.from_numpy(kspace).float().to(DEVICE).unsqueeze(1)

        # One copy of the mask per slice, so the whole volume runs as a single batch.
        mask_batch = mask_t.repeat(num_slices, 1, 1, 1, 1).to(DEVICE)

        # ---- Forward pass; only the first of the two returned values is used ----
        out, _ = model(img_us_t, kspace_t, mask_batch)

        # ---- Magnitude images, multiplied back by the stored max_val ----
        pred_mag = complex_abs(out).cpu().numpy() * max_val
        gt_mag = complex_abs(img_gt_t).cpu().numpy() * max_val

        # ---- Volume-level metrics on the flattened magnitudes ----
        gt_flat = gt_mag.flatten()
        pred_flat = pred_mag.flatten()
        psnr_list.append(
            peak_signal_noise_ratio(gt_flat, pred_flat, data_range=max_val)
        )
        nmse_list.append(nmse(gt_flat, pred_flat))

        # ---- Slice-level SSIM ----
        ssim_list.extend(
            compute_ssim(gt_mag[i], pred_mag[i], max_val)
            for i in range(num_slices)
        )

# Summary: mean ± std over volumes (PSNR, NMSE) and over slices (SSIM).
print("\n" + "=" * 50)
print(f"PSNR (Mag, volume): {np.mean(psnr_list):.4f} ± {np.std(psnr_list):.4f} dB")
print(f"NMSE (Mag, volume): {np.mean(nmse_list):.6f} ± {np.std(nmse_list):.6f}")
print(f"SSIM (Mag, slice):  {np.mean(ssim_list):.4f} ± {np.std(ssim_list):.4f}")
print("=" * 50)


Running FGNet inference: 100%|████████████████████████████████████████████████████| 100/100 [3:28:05<00:00, 124.85s/it]


PSNR (Mag, volume): 35.7083 ± 3.0233 dB
NMSE (Mag, volume): 0.011012 ± 0.006443
SSIM (Mag, slice):  0.8564 ± 0.0648





In [7]:
import os
import time
import glob
import psutil
import numpy as np
import h5py
from tqdm import tqdm

import torch

# ============================================================
# CONFIG
# ============================================================
VAL_FOLDER = r"D:\fastmri_singlecoil_FSSCAN\val_norm"

WARMUP_SLICES = 10          # untimed forward passes run before measuring
NUM_TIMING_SLICES = 100     # number of timed single-slice forward passes

DEVICE_CPU = torch.device("cpu")
if torch.cuda.is_available():
    DEVICE_GPU = torch.device("cuda:0")
else:
    DEVICE_GPU = None       # the GPU benchmark is skipped when this is None

# ============================================================
# FILE LIST
# ============================================================
# Rebinds the notebook-level `file_paths` to the sorted .h5 files in VAL_FOLDER.
file_paths = glob.glob(os.path.join(VAL_FOLDER, "*.h5"))
file_paths.sort()
assert file_paths, "No validation files found"

# ============================================================
# MODEL
# ============================================================
# `model` must already exist in the kernel (it is created by an earlier cell).
assert model is not None, "Model not loaded"
model.eval()

# ============================================================
# PARAMETER COUNT
# ============================================================
# Total number of elements across all parameter tensors of the model.
num_params = sum(weight.numel() for weight in model.parameters())

# ============================================================
# MEMORY HELPERS
# ============================================================
# Handle on the current Python process, used to query its memory usage.
process = psutil.Process(os.getpid())

def cpu_memory_mb():
    """Return the resident set size (RSS) of this process right now, in MiB."""
    rss_bytes = process.memory_info().rss
    return rss_bytes / (1024 ** 2)

def gpu_memory_mb():
    """Return ``torch.cuda.max_memory_allocated()`` converted from bytes to MiB."""
    peak_bytes = torch.cuda.max_memory_allocated()
    return peak_bytes / (1024 ** 2)

# ============================================================
# LATENCY MEASUREMENT
# ============================================================
def measure_latency(device):
    """Benchmark batch-size-1 forward passes of the global ``model`` on ``device``.

    Runs up to ``WARMUP_SLICES`` untimed passes on slices of the first file in
    ``file_paths``, then times up to ``NUM_TIMING_SLICES`` single-slice passes
    taken from ``file_paths`` in order. Only the forward pass (plus the CUDA
    synchronisation needed to observe its completion) is inside the timed
    region; file reads and host-to-device copies are excluded.

    The model is moved to ``device`` for the benchmark and moved back to the
    device it was on beforehand, so other cells keep working afterwards.

    Args:
        device: ``torch.device`` to benchmark on.

    Returns:
        dict with ``mean_s``, ``median_s`` and ``std_s`` (seconds per slice)
        and ``slices_per_sec`` (reciprocal of the mean latency).

    Raises:
        RuntimeError: if no slice could be timed (e.g. every volume is empty).
    """
    use_cuda = device.type == "cuda"

    def _sync():
        # CUDA work is queued asynchronously; block until it has finished.
        if use_cuda:
            torch.cuda.synchronize()

    def _read_volume(path):
        # Load the undersampled images and the k-space of one volume.
        with h5py.File(path, "r") as f:
            return f["image_under"][:], f["kspace_full"][:]

    def _slice_inputs(img_us, kspace, s):
        # Build the (image, k-space, mask) inputs for slice ``s`` on ``device``.
        x = torch.from_numpy(img_us[s:s+1]).float().unsqueeze(1).to(device)
        k = torch.from_numpy(kspace[s:s+1]).float().unsqueeze(1).to(device)
        m = mask_t[:1].to(device)
        return x, k, m

    # Remember where the model lives so it can be put back after the benchmark.
    first_param = next(model.parameters(), None)
    original_device = first_param.device if first_param is not None else None

    latencies = []
    try:
        model.to(device)

        with torch.no_grad():
            # -----------------------------
            # WARM-UP (untimed)
            # -----------------------------
            for file in file_paths[:1]:
                img_us, kspace = _read_volume(file)
                for s in range(min(WARMUP_SLICES, img_us.shape[0])):
                    _ = model(*_slice_inputs(img_us, kspace, s))
                    _sync()

            # -----------------------------
            # TIMED INFERENCE
            # -----------------------------
            for file in tqdm(file_paths, desc=f"Timing on {device}"):
                img_us, kspace = _read_volume(file)

                for s in range(img_us.shape[0]):
                    if len(latencies) >= NUM_TIMING_SLICES:
                        break

                    x, k, m = _slice_inputs(img_us, kspace, s)

                    # Drain pending copies / earlier kernels so they are not
                    # billed to this forward pass.
                    _sync()
                    start = time.perf_counter()
                    _ = model(x, k, m)
                    _sync()
                    end = time.perf_counter()

                    latencies.append(end - start)

                # Stop reading further volumes once enough slices were timed.
                if len(latencies) >= NUM_TIMING_SLICES:
                    break
    finally:
        if original_device is not None:
            model.to(original_device)

    if not latencies:
        # Statistics of an empty sample would silently be NaN.
        raise RuntimeError("No slices were timed; cannot compute latency statistics")

    latencies = np.array(latencies)
    mean_latency = latencies.mean()

    return {
        "mean_s": mean_latency,
        "median_s": np.median(latencies),
        "std_s": latencies.std(),
        "slices_per_sec": 1.0 / mean_latency
    }

# ============================================================
# CPU BENCHMARK
# ============================================================
# ---- CPU benchmark ----
# RSS is sampled before and after the run; the difference is what the report
# prints as CPU "Memory" (RSS growth across the run, not a sampled peak).
cpu_mem_before = cpu_memory_mb()
cpu_latency = measure_latency(DEVICE_CPU)
cpu_mem_after = cpu_memory_mb()
cpu_mem_peak = cpu_mem_after - cpu_mem_before

# ============================================================
# GPU BENCHMARK
# ============================================================
# These stay None when no GPU is available; the report checks gpu_latency.
gpu_latency = gpu_mem_peak = gpu_name = None

if DEVICE_GPU is not None:
    # Reset the peak counter so the reading reflects this benchmark run.
    torch.cuda.reset_peak_memory_stats()
    gpu_name = torch.cuda.get_device_name(0)
    gpu_latency = measure_latency(DEVICE_GPU)
    gpu_mem_peak = gpu_memory_mb()

# ============================================================
# FINAL REPORT
# ============================================================
# Header, followed by the model-size lines.
print(
    "\n" + "=" * 70,
    "FGNet EFFICIENCY REPORT (SLICE-WISE, BATCH SIZE = 1)",
    "=" * 70,
    sep="\n",
)
print(
    f"Parameters: {num_params / 1e6:.2f} M",
    "FLOPs:      Not reported (complex-valued + FFT operations)",
    sep="\n",
)

# CPU section.
print(
    "\n--- CPU Inference ---",
    f"Latency:     {cpu_latency['mean_s']:.2f} s / slice",
    f"Throughput:  {cpu_latency['slices_per_sec']:.3f} slices/sec",
    f"Memory:      {cpu_mem_peak:.2f} MB",
    sep="\n",
)

# GPU section, printed only when the GPU benchmark produced results.
if not gpu_latency:
    print("\nGPU not available.")
else:
    print(
        "\n--- GPU Inference ---",
        f"GPU:         {gpu_name}",
        f"Latency:     {gpu_latency['mean_s']:.2f} s / slice",
        f"Throughput:  {gpu_latency['slices_per_sec']:.3f} slices/sec",
        f"Peak VRAM:   {gpu_mem_peak:.2f} MB",
        sep="\n",
    )

print("=" * 70)


Timing on cpu:   1%|▋                                                               | 2/199 [03:49<6:16:23, 114.64s/it]
Timing on cuda:0:   1%|▋                                                               | 2/199 [00:08<14:36,  4.45s/it]


FGNet EFFICIENCY REPORT (SLICE-WISE, BATCH SIZE = 1)
Parameters: 1.16 M
FLOPs:      Not reported (complex-valued + FFT operations)

--- CPU Inference ---
Latency:     2.29 s / slice
Throughput:  0.436 slices/sec
Memory:      741.25 MB

--- GPU Inference ---
GPU:         NVIDIA RTX A5000
Latency:     0.09 s / slice
Throughput:  11.448 slices/sec
Peak VRAM:   631.31 MB



