In [2]:
# %run ./Modules/cv_can_new.ipynb
# %run sfanet_new-Copy2.ipynb

# %run sfanet_new_mpca_fsa-Copy1.ipynb
%run "./sfanet_new-Copy2.ipynb"

# %run sfanet_new_mpca_fsa.ipynb

In [3]:
model = build_dual_output_model()
# model.load_weights(".\SavedModels\weight_sfu_fastmri_complex_perploss_mpca_fsa2.h5")
model.load_weights("./weight_sfu_fastmri_complex_perploss_fsa.h5")


# output = model(input_slice)


In [4]:
import glob
import os
val_folder = r"D:\fastmri_singlecoil_FSSCAN\val_norm"
# files = sorted([os.path.join(val_folder, f) for f in os.listdir(val_folder) if f.endswith(".h5")])
kspace_files_list_val = sorted(glob.glob(os.path.join(val_folder, "*.h5")))

In [5]:
kspace_files_list_val

['D:\\fastmri_singlecoil_FSSCAN\\val_norm\\volume_000.h5',
 'D:\\fastmri_singlecoil_FSSCAN\\val_norm\\volume_001.h5',
 'D:\\fastmri_singlecoil_FSSCAN\\val_norm\\volume_002.h5',
 'D:\\fastmri_singlecoil_FSSCAN\\val_norm\\volume_003.h5',
 'D:\\fastmri_singlecoil_FSSCAN\\val_norm\\volume_004.h5',
 'D:\\fastmri_singlecoil_FSSCAN\\val_norm\\volume_005.h5',
 'D:\\fastmri_singlecoil_FSSCAN\\val_norm\\volume_006.h5',
 'D:\\fastmri_singlecoil_FSSCAN\\val_norm\\volume_007.h5',
 'D:\\fastmri_singlecoil_FSSCAN\\val_norm\\volume_008.h5',
 'D:\\fastmri_singlecoil_FSSCAN\\val_norm\\volume_009.h5',
 'D:\\fastmri_singlecoil_FSSCAN\\val_norm\\volume_010.h5',
 'D:\\fastmri_singlecoil_FSSCAN\\val_norm\\volume_011.h5',
 'D:\\fastmri_singlecoil_FSSCAN\\val_norm\\volume_012.h5',
 'D:\\fastmri_singlecoil_FSSCAN\\val_norm\\volume_013.h5',
 'D:\\fastmri_singlecoil_FSSCAN\\val_norm\\volume_014.h5',
 'D:\\fastmri_singlecoil_FSSCAN\\val_norm\\volume_015.h5',
 'D:\\fastmri_singlecoil_FSSCAN\\val_norm\\volume_016.h5

In [11]:
pd_files = []
pdfs_files = []

for f in kspace_files_list_val:
    if "PDFS" in f:
        pdfs_files.append(f)
    else:
        pd_files.append(f)

print(f"PD volumes: {len(pd_files)}")
print(f"PDFS volumes: {len(pdfs_files)}")


PD volumes: 483
PDFS volumes: 489


In [6]:
import os
import numpy as np
import h5py
import glob
from tqdm import tqdm
from skimage.metrics import peak_signal_noise_ratio, structural_similarity

# Path to validation folder
# val_folder = "F:/denoised_preprocessed_h5_val"

# val_folder = r"E:\fastmri\val_norm"
# val_folder = r"D:\val_norm"
# train_folder = r"D:\train_norm"
val_folder = r"D:\fastmri_singlecoil_FSSCAN\val_norm"
# files = sorted([os.path.join(val_folder, f) for f in os.listdir(val_folder) if f.endswith(".h5")])
kspace_files_list_val = sorted(glob.glob(os.path.join(val_folder, "*.h5")))
# file_paths = kspace_files_list_val[0:5]

file_paths = kspace_files_list_val

# ----------------------
# HELPERS
# ----------------------
def to_complex(x):
    return x[..., 0] + 1j * x[..., 1]

def nmse(gt, pred):
    return np.linalg.norm(gt - pred) ** 2 / (np.linalg.norm(gt) ** 2 + 1e-10)

def compute_ssim(gt, pred, max_val):
    return structural_similarity(
        gt, pred,
        data_range=max_val,
        win_size=9,
        gaussian_weights=False,
        use_sample_covariance=False,
        K1=0.01,
        K2=0.03
    )

# ----------------------
# STORAGE
# ----------------------
ssim_list = []
psnr_list = []
nmse_list = []

# ----------------------
# PROCESSING
# ----------------------
for file in tqdm(file_paths, desc="Processing volumes"):
    with h5py.File(file, 'r') as f:
        image_full = f["image_full"][:]       # (slices, H, W, 2)
        image_under = f["image_under"][:]     # (slices, H, W, 2)
        max_val = float(f["max_val_full_image"][0])

#     mask_batch = np.tile(mask, (image_under.shape[0], 1, 1, 1)) 
    # Get model prediction (still in normalized form)
    # pred = model.predict([image_under,mask_batch,image_under], verbose=0)  # shape (slices, H, W, 2)
    pred = model.predict(image_under, verbose=0)  # shape (slices, H, W, 2)
    
    image_full *= max_val
    
    pred *= max_val  # Scale predicted output to original intensity range

    # Convert to complex and get magnitude
    gt_mag = np.abs(to_complex(image_full))
    pred_mag = np.abs(to_complex(pred))

    # Volume-wise PSNR and NMSE
    psnr_val = peak_signal_noise_ratio(gt_mag, pred_mag, data_range=max_val)
    nmse_val = nmse(gt_mag.flatten(), pred_mag.flatten())

    psnr_list.append(psnr_val)
    nmse_list.append(nmse_val)

    # Slice-wise SSIM
    for i in range(gt_mag.shape[0]):
        ssim_val = compute_ssim(gt_mag[i], pred_mag[i], max_val)
        ssim_list.append(ssim_val)

# ----------------------
# REPORT
# ----------------------
print("\n" + "=" * 40)
print(f"PSNR (Mag, volume): {np.mean(psnr_list):.2f} ¬± {np.std(psnr_list):.2f} dB")
print(f"NMSE (Mag, volume): {np.mean(nmse_list):.6f} ¬± {np.std(nmse_list):.6f}")
print(f"SSIM (Mag, slice):  {np.mean(ssim_list):.4f} ¬± {np.std(ssim_list):.4f}")

print("=" * 40)


Processing volumes: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 199/199 [05:26<00:00,  1.64s/it]


PSNR (Mag, volume): 34.27 ¬± 2.36 dB
NMSE (Mag, volume): 0.073141 ¬± 0.062938
SSIM (Mag, slice):  0.8332 ¬± 0.0754





In [12]:
import os
import time
import glob
import psutil
import numpy as np
import h5py
from tqdm import tqdm
import tensorflow as tf
from keras_flops import get_flops

# ============================================================
# CONFIGURATION
# ============================================================
VAL_FOLDER = r"D:\fastmri_singlecoil_FSSCAN\val_norm"
WARMUP_SLICES = 10
NUM_TIMING_SLICES = 100  # increase for more stable stats

# ============================================================
# FILE LIST
# ============================================================
file_paths = sorted(glob.glob(os.path.join(VAL_FOLDER, "*.h5")))

# ============================================================
# MODEL MUST BE LOADED BEFORE THIS SCRIPT
# ============================================================
# model = ...
assert model is not None, "Model is not loaded."

# ============================================================
# PARAMETER COUNT
# ============================================================
num_params = model.count_params()

# ============================================================
# FLOPs (PER SLICE, BATCH SIZE = 1)
# ============================================================
flops = get_flops(model, batch_size=1)

# ============================================================
# MEMORY HELPERS
# ============================================================
process = psutil.Process(os.getpid())

def cpu_memory_mb():
    return process.memory_info().rss / (1024 ** 2)

def gpu_memory_mb():
    info = tf.config.experimental.get_memory_info("GPU:0")
    return info["peak"] / (1024 ** 2)

# ============================================================
# LATENCY / THROUGHPUT MEASUREMENT
# ============================================================
def measure_latency(device):

    latencies = []

    with tf.device(device):

        # -----------------------------
        # WARM-UP
        # -----------------------------
        for file in file_paths[:1]:
            with h5py.File(file, "r") as f:
                image_under = f["image_under"][:]

            for s in range(min(WARMUP_SLICES, image_under.shape[0])):
                _ = model(image_under[s:s+1], training=False)

        # -----------------------------
        # TIMED INFERENCE
        # -----------------------------
        count = 0
        for file in tqdm(file_paths, desc=f"Timing on {device}"):

            with h5py.File(file, "r") as f:
                image_under = f["image_under"][:]

            for s in range(image_under.shape[0]):

                if count >= NUM_TIMING_SLICES:
                    break

                slice_input = image_under[s:s+1]
                assert slice_input.shape[0] == 1  # batch size check

                start = time.perf_counter()
                _ = model(slice_input, training=False)

                # üîë GPU synchronization
                if "GPU" in device:
                    tf.config.experimental.get_memory_info("GPU:0")

                end = time.perf_counter()

                latencies.append(end - start)
                count += 1

            if count >= NUM_TIMING_SLICES:
                break

    latencies = np.array(latencies)

    mean_s = latencies.mean()
    median_s = np.median(latencies)

    return {
        "mean_s": mean_s,
        "median_s": median_s,
        "std_s": latencies.std(),
        "slices_per_sec_mean": 1.0 / mean_s,
        "slices_per_sec_median": 1.0 / median_s
    }

# ============================================================
# CPU BENCHMARK
# ============================================================
cpu_mem_before = cpu_memory_mb()
cpu_latency = measure_latency("/CPU:0")
cpu_mem_after = cpu_memory_mb()
cpu_mem_peak = cpu_mem_after - cpu_mem_before

# ============================================================
# GPU BENCHMARK (IF AVAILABLE)
# ============================================================
gpu_latency = None
gpu_mem_peak = None
gpu_name = None

gpus = tf.config.list_physical_devices("GPU")
if gpus:
    tf.config.experimental.reset_memory_stats("GPU:0")
    gpu_name = tf.config.experimental.get_device_details(gpus[0])["device_name"]
    gpu_latency = measure_latency("/GPU:0")
    gpu_mem_peak = gpu_memory_mb()

# ============================================================
# FINAL REPORT
# ============================================================
print("\n" + "=" * 70)
print("MODEL EFFICIENCY REPORT (BATCH SIZE = 1)")
print("=" * 70)

print(f"Parameters: {num_params / 1e6:.2f} M")
print(f"FLOPs:      {flops / 1e9:.2f} GFLOPs (per slice)")

print("\n--- CPU Inference ---")
print(f"Mean latency:   {cpu_latency['mean_s']:.2f} s / slice")
print(f"Throughput:    {cpu_latency['slices_per_sec_mean']:.3f} slices/sec")
print(f"Memory usage:  {cpu_mem_peak:.2f} MB")

if gpu_latency:
    print("\n--- GPU Inference ---")
    print(f"GPU:           {gpu_name}")
    print(f"Mean latency:  {gpu_latency['mean_s']:.2f} s / slice")
    print(f"Throughput:   {gpu_latency['slices_per_sec_mean']:.3f} slices/sec")
    print(f"Peak memory:  {gpu_mem_peak:.2f} MB")
else:
    print("\nGPU not available.")

print("=" * 70)


Timing on /CPU:0:   1%|‚ñå                                                           | 2/199 [08:57<14:41:35, 268.50s/it]
Timing on /GPU:0:   1%|‚ñã                                                               | 2/199 [00:19<32:25,  9.88s/it]


MODEL EFFICIENCY REPORT (BATCH SIZE = 1)
Parameters: 22.22 M
FLOPs:      159.10 GFLOPs (per slice)

--- CPU Inference ---
Mean latency:   5.37 s / slice
Throughput:    0.186 slices/sec
Memory usage:  0.08 MB

--- GPU Inference ---
GPU:           NVIDIA RTX A5000
Mean latency:  0.20 s / slice
Throughput:   5.082 slices/sec
Peak memory:  626.97 MB





In [7]:
import os
import numpy as np
import h5py
import glob
from tqdm import tqdm
from skimage.metrics import peak_signal_noise_ratio, structural_similarity

# ============================================================
# PATHS
# ============================================================
val_folder = r"D:\fastmri_singlecoil_FSSCAN\val_norm"
file_paths = sorted(glob.glob(os.path.join(val_folder, "*.h5")))

# ============================================================
# HELPERS
# ============================================================
def to_complex(x):
    """Convert [..., 2] real/imag array to complex."""
    return x[..., 0] + 1j * x[..., 1]

def nmse(gt, pred):
    """Normalized MSE."""
    return np.linalg.norm(gt - pred) ** 2 / (np.linalg.norm(gt) ** 2 + 1e-10)

def compute_ssim(gt, pred, max_val):
    return structural_similarity(
        gt, pred,
        data_range=max_val,
        win_size=9,
        gaussian_weights=False,
        use_sample_covariance=False,
        K1=0.01,
        K2=0.03
    )

def confidence_score(gt_mag, pred_mag):
    """
    Residual-based confidence score.
    Higher value => higher reconstruction reliability.
    """
    num = np.linalg.norm(gt_mag - pred_mag) ** 2
    den = np.linalg.norm(gt_mag) ** 2 + 1e-10
    return 1.0 - num / den

# ============================================================
# STORAGE
# ============================================================
psnr_list = []          # volume-wise
nmse_list = []          # volume-wise
ssim_list = []          # slice-wise
confidence_list = []    # slice-wise
confidence_vol_list = []  # volume-wise

# ============================================================
# PROCESSING
# ============================================================
for file in tqdm(file_paths, desc="Processing volumes"):

    with h5py.File(file, 'r') as f:
        image_full = f["image_full"][:]        # (S, H, W, 2)
        image_under = f["image_under"][:]      # (S, H, W, 2)
        max_val = float(f["max_val_full_image"][0])

    # --------------------------------------------------------
    # MODEL INFERENCE (normalized domain)
    # --------------------------------------------------------
    pred = model.predict(image_under, verbose=0)  # (S, H, W, 2)

    # --------------------------------------------------------
    # SCALE BACK TO ORIGINAL INTENSITY RANGE
    # --------------------------------------------------------
    image_full = image_full * max_val
    pred = pred * max_val

    # --------------------------------------------------------
    # MAGNITUDE IMAGES
    # --------------------------------------------------------
    gt_mag = np.abs(to_complex(image_full))
    pred_mag = np.abs(to_complex(pred))

    # --------------------------------------------------------
    # VOLUME-WISE METRICS
    # --------------------------------------------------------
    psnr_val = peak_signal_noise_ratio(
        gt_mag, pred_mag, data_range=max_val
    )
    nmse_val = nmse(gt_mag.flatten(), pred_mag.flatten())

    psnr_list.append(psnr_val)
    nmse_list.append(nmse_val)

    # --------------------------------------------------------
    # SLICE-WISE METRICS + CONFIDENCE
    # --------------------------------------------------------
    slice_confidence = []

    for s in range(gt_mag.shape[0]):
        ssim_val = compute_ssim(gt_mag[s], pred_mag[s], max_val)
        conf_val = confidence_score(gt_mag[s], pred_mag[s])

        ssim_list.append(ssim_val)
        confidence_list.append(conf_val)
        slice_confidence.append(conf_val)

    # Volume-wise confidence (mean over slices)
    confidence_vol_list.append(np.mean(slice_confidence))

# ============================================================
# REPORT
# ============================================================
print("\n" + "=" * 50)
print("RECONSTRUCTION PERFORMANCE (VALIDATION SET)")
print("=" * 50)

print(f"PSNR  (Mag, volume): {np.mean(psnr_list):.2f} ¬± {np.std(psnr_list):.2f} dB")
print(f"NMSE  (Mag, volume): {np.mean(nmse_list):.6f} ¬± {np.std(nmse_list):.6f}")
print(f"SSIM  (Mag, slice):  {np.mean(ssim_list):.4f} ¬± {np.std(ssim_list):.4f}")
print(f"CONF  (slice-wise):  {np.mean(confidence_list):.4f} ¬± {np.std(confidence_list):.4f}")
print(f"CONF  (volume-wise): {np.mean(confidence_vol_list):.4f} ¬± {np.std(confidence_vol_list):.4f}")

print("=" * 50)


Processing volumes: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 199/199 [05:15<00:00,  1.59s/it]


RECONSTRUCTION PERFORMANCE (VALIDATION SET)
PSNR  (Mag, volume): 34.27 ¬± 2.36 dB
NMSE  (Mag, volume): 0.016860 ¬± 0.007025
SSIM  (Mag, slice):  0.8332 ¬± 0.0754
CONF  (slice-wise):  0.9792 ¬± 0.0177
CONF  (volume-wise): 0.9790 ¬± 0.0079





In [10]:
import os
import h5py
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.backends.backend_pdf import PdfPages
from skimage.metrics import structural_similarity, peak_signal_noise_ratio

# -----------------------
# HELPERS
# -----------------------
def to_complex(x):
    return x[..., 0] + 1j * x[..., 1]

def compute_metrics(gt, pred):
    eps = 1e-10
    nmse = np.linalg.norm(gt - pred) ** 2 / (np.linalg.norm(gt) ** 2 + eps)
    psnr = peak_signal_noise_ratio(np.abs(gt), np.abs(pred), data_range=np.abs(gt).max() - np.abs(gt).min())
    ssim = structural_similarity(
        np.abs(gt), np.abs(pred),
        data_range=np.abs(gt).max() - np.abs(gt).min(),
        win_size=9,
        gaussian_weights=False,
        use_sample_covariance=False,
        K1=0.01,
        K2=0.03
    )
    return nmse, psnr, ssim

def overlay_stats(ax, text):
    ax.text(
        0.05, 0.95, text,
        transform=ax.transAxes,
        ha="left", va="top",
        color="yellow", fontsize=20, fontweight="bold",
        bbox=dict(facecolor="black", alpha=0.5, pad=2)
    )

def plot_slice(gt, under, pred, slice_idx, pdf, max_val):
    domains = ['Real', 'Imag', 'Abs']
    gt_split = [gt.real, gt.imag, np.abs(gt)]
    under_split = [under.real, under.imag, np.abs(under)]
    pred_split = [pred.real, pred.imag, np.abs(pred)]

    # Compute error maps
    error_abs_under = [np.abs(gt_split[i] - under_split[i]) for i in range(3)]
    error_abs_pred = [np.abs(gt_split[i] - pred_split[i]) for i in range(3)]
    error_signed_under = [gt_split[i] - under_split[i] for i in range(3)]
    error_signed_pred = [gt_split[i] - pred_split[i] for i in range(3)]

    
    # fig, axs = plt.subplots(3, 5, figsize=(28, 20), constrained_layout=True)

    fig, axs = plt.subplots(3, 5, figsize=(28, 18))  # 3 rows (domains) √ó 5 columns (types)
    plt.suptitle(f"Slice {slice_idx}", fontsize=28, fontweight="bold")

    for row, domain in enumerate(domains):
        # Compute metrics
        nmse_u = compute_metrics(gt_split[row], under_split[row])[0]
        psnr_u = compute_metrics(gt_split[row], under_split[row])[1]
        ssim_u = compute_metrics(gt_split[row], under_split[row])[2]

        nmse_p = compute_metrics(gt_split[row], pred_split[row])[0]
        psnr_p = compute_metrics(gt_split[row], pred_split[row])[1]
        ssim_p = compute_metrics(gt_split[row], pred_split[row])[2]

        metric_text_under = f"NMSE: {nmse_u:.4f}\nPSNR: {psnr_u:.2f} dB\nSSIM: {ssim_u:.4f}"
        metric_text_pred = f"NMSE: {nmse_p:.4f}\nPSNR: {psnr_p:.2f} dB\nSSIM: {ssim_p:.4f}"

        vmax_abs = max(np.max(error_abs_under[row]), np.max(error_abs_pred[row]))
        # vmax_signed = max(np.max(np.abs(error_signed_under[i])), np.max(np.abs(error_signed_pred[i])))

        axs[row, 0].imshow(gt_split[row], cmap="gray")
        # axs[row, 0].set_title(f"GT {domain}", fontsize=18)
        axs[row, 0].axis("off")

        axs[row, 1].imshow(under_split[row], cmap="gray")
        # axs[row, 1].set_title(f"Undersampled {domain}", fontsize=18)
        axs[row, 1].axis("off")
        overlay_stats(axs[row, 1], metric_text_under)

        axs[row, 2].imshow(pred_split[row], cmap="gray")
        # axs[row, 2].set_title(f"Reconstructed {domain}", fontsize=18)
        axs[row, 2].axis("off")
        overlay_stats(axs[row, 2], metric_text_pred)

        axs[row, 3].imshow(error_abs_under[row], cmap="hot", vmin=0, vmax=vmax_abs)
        # axs[row, 3].set_title(f"|GT ‚àí Und| {domain}", fontsize=18)
        axs[row, 3].axis("off")

        axs[row, 4].imshow(error_abs_pred[row], cmap="hot", vmin=0, vmax=vmax_abs)
        # axs[row, 4].set_title(f"|GT ‚àí Pred| {domain}", fontsize=18)
        axs[row, 4].axis("off")

    # plt.tight_layout(rect=[0, 0, 1, 0.95])
    fig.subplots_adjust(hspace=0.1,wspace=0.1) 
    pdf.savefig(fig)
    plt.close()

# -----------------------
# LOAD DATA
# -----------------------
# file_path = "your_file_path_here.h5"  # ‚Üê Replace this path
val_folder = r"D:\val_norm"
# files = sorted([os.path.join(val_folder, f) for f in os.listdir(val_folder) if f.endswith(".h5")])
kspace_files_list_val = sorted(glob.glob(os.path.join(val_folder, "*.h5")))
file_path = kspace_files_list_val[0]


with h5py.File(file_path, 'r') as f:
    image_full = f["image_full"][:]
    image_under = f["image_under"][:]
    max_val = float(f["max_val_full_image"][0])

# -----------------------
# CONVERT TO COMPLEX
# -----------------------
gt_complex = to_complex(image_full)
under_complex = to_complex(image_under)

# -----------------------
# MODEL PREDICTION
# -----------------------
pred_slices = []
for i in range(image_under.shape[0]):
    input_slice = image_under[i:i+1]
    pred_slice = model.predict(input_slice, verbose=0)[0]
    pred_slices.append(pred_slice)

pred_slices = np.array(pred_slices)
pred_complex = to_complex(pred_slices)

# -----------------------
# PDF VISUALIZATION
# -----------------------
output_pdf_path = "volume_visualization_fsa_mse_cmap_hot.pdf"
with PdfPages(output_pdf_path) as pdf:
    # slice_indices = np.linspace(0, gt_complex.shape[0] - 1, 4, dtype=int)
    slice_indices = range(gt_complex.shape[0])

    for idx in slice_indices:
        plot_slice(gt_complex[idx], under_complex[idx], pred_complex[idx], idx, pdf, max_val)
        
    # Summary
    fig, ax = plt.subplots(figsize=(10, 6))
    ax.set_title("Volume Intensity Summary", fontsize=18)
    ax.axis("off")

    text = (
        f"GT Abs:    min={np.abs(gt_complex).min():.4f}, max={np.abs(gt_complex).max():.4f}\n"
        f"Under Abs: min={np.abs(under_complex).min():.4f}, max={np.abs(under_complex).max():.4f}\n"
        f"Pred Abs:  min={np.abs(pred_complex).min():.4f}, max={np.abs(pred_complex).max():.4f}\n\n"
        f"GT Real:   min={gt_complex.real.min():.4f}, max={gt_complex.real.max():.4f}\n"
        f"GT Imag:   min={gt_complex.imag.min():.4f}, max={gt_complex.imag.max():.4f}"
    )
    ax.text(0.05, 0.5, text, fontsize=12, va="center")
    pdf.savefig(fig)
    plt.close()

print(f"‚úÖ PDF saved at: {output_pdf_path}")


IndexError: list index out of range

In [None]:
import os
import glob
import h5py
import numpy as np
import matplotlib.pyplot as plt
from skimage.metrics import structural_similarity, peak_signal_noise_ratio

# -----------------------
# CONFIG (edit these)
# -----------------------
val_folder = r"D:\val_norm"           # folder with .h5 files
file_index = 0                         # which file in sorted list to use
requested = [18, 19, 20]                # list of slice indices you want saved as PNG
out_dir = "slice_pngs"                 # where PNGs will be saved
dpi = 300                              # output dpi (300 recommended)
# -----------------------

os.makedirs(out_dir, exist_ok=True)

# -----------------------
# HELPERS
# -----------------------
def to_complex(x):
    return x[..., 0] + 1j * x[..., 1]

def compute_metrics(gt, pred):
    eps = 1e-10
    nmse = np.linalg.norm(gt - pred) ** 2 / (np.linalg.norm(gt) ** 2 + eps)
    psnr = peak_signal_noise_ratio(np.abs(gt), np.abs(pred), data_range=np.abs(gt).max() - np.abs(gt).min())
    ssim = structural_similarity(
        np.abs(gt), np.abs(pred),
        data_range=np.abs(gt).max() - np.abs(gt).min(),
        win_size=9,
        gaussian_weights=False,
        use_sample_covariance=False,
        K1=0.01,
        K2=0.03
    )
    return nmse, psnr, ssim

def overlay_stats(ax, text):
    ax.text(
        0.05, 0.95, text,
        transform=ax.transAxes,
        ha="left", va="top",
        color="yellow", fontsize=14, fontweight="bold",
        bbox=dict(facecolor="black", alpha=0.5, pad=2)
    )

def plot_slice(gt, under, pred, slice_idx, pdf, max_val):
    domains = ['Real', 'Imag', 'Abs']
    gt_split = [gt.real, gt.imag, np.abs(gt)]
    under_split = [under.real, under.imag, np.abs(under)]
    pred_split = [pred.real, pred.imag, np.abs(pred)]

    # Compute error maps
    error_abs_under = [np.abs(gt_split[i] - under_split[i]) for i in range(3)]
    error_abs_pred = [np.abs(gt_split[i] - pred_split[i]) for i in range(3)]

    fig, axs = plt.subplots(3, 5, figsize=(28, 18))
    plt.suptitle(f"Slice {slice_idx}", fontsize=28, fontweight="bold")

    for row, domain in enumerate(domains):
        # Compute metrics
        nmse_u = compute_metrics(gt_split[row], under_split[row])[0]
        psnr_u = compute_metrics(gt_split[row], under_split[row])[1]
        ssim_u = compute_metrics(gt_split[row], under_split[row])[2]

        nmse_p = compute_metrics(gt_split[row], pred_split[row])[0]
        psnr_p = compute_metrics(gt_split[row], pred_split[row])[1]
        ssim_p = compute_metrics(gt_split[row], pred_split[row])[2]

        metric_text_under = f"NMSE: {nmse_u:.4f}\nPSNR: {psnr_u:.2f} dB\nSSIM: {ssim_u:.4f}"
        metric_text_pred = f"NMSE: {nmse_p:.4f}\nPSNR: {psnr_p:.2f} dB\nSSIM: {ssim_p:.4f}"

        vmax_abs = max(np.max(error_abs_under[row]), np.max(error_abs_pred[row]))
        if vmax_abs == 0:
            vmax_abs = 1e-8  # avoid vmin==vmax

        # Column 0: GT
        axs[row, 0].imshow(gt_split[row], cmap="gray", interpolation='nearest')
        axs[row, 0].axis("off")

        # Column 1: Undersampled
        axs[row, 1].imshow(under_split[row], cmap="gray", interpolation='nearest')
        axs[row, 1].axis("off")
        overlay_stats(axs[row, 1], metric_text_under)

        # Column 2: Reconstructed
        axs[row, 2].imshow(pred_split[row], cmap="gray", interpolation='nearest')
        axs[row, 2].axis("off")
        overlay_stats(axs[row, 2], metric_text_pred)

        # Column 3: |GT - Und|
        axs[row, 3].imshow(error_abs_under[row], cmap="hot", vmin=0, vmax=vmax_abs, interpolation='nearest')
        axs[row, 3].axis("off")

        # Column 4: |GT - Pred|
        axs[row, 4].imshow(error_abs_pred[row], cmap="hot", vmin=0, vmax=vmax_abs, interpolation='nearest')
        axs[row, 4].axis("off")

    fig.subplots_adjust(hspace=0.1, wspace=0.1)

    # Minimal change: either save into provided pdf, or save as PNG into out_dir
    if pdf is not None:
        pdf.savefig(fig)
        plt.close(fig)
    else:
        out_path = os.path.join(out_dir, f"slice_{slice_idx:03d}.png")
        fig.savefig(out_path, dpi=300, bbox_inches='tight', pad_inches=0.04)
        plt.close(fig)

# -----------------------
# LOAD DATA
# -----------------------
h5_list = sorted(glob.glob(os.path.join(val_folder, "*.h5")))
if len(h5_list) == 0:
    raise FileNotFoundError(f"No .h5 files found in {val_folder}")
if file_index < 0 or file_index >= len(h5_list):
    raise IndexError(f"file_index {file_index} out of range (0..{len(h5_list)-1})")

file_path = h5_list[file_index]
with h5py.File(file_path, 'r') as f:
    print("Keys in file:", list(f.keys()))
    image_full = f["image_full"][:]       # (slices, H, W, 2)
    image_under = f["image_under"][:]     # (slices, H, W, 2)
    max_val = float(f["max_val_full_image"][0]) if "max_val_full_image" in f else None

# -----------------------
# CONVERT TO COMPLEX
# -----------------------
gt_complex = to_complex(image_full)
under_complex = to_complex(image_under)

# -----------------------
# MODEL PREDICTION (slice-by-slice to match your earlier approach)
# -----------------------
if "model" not in globals():
    raise RuntimeError("No 'model' found in the session. Load your TF model into variable name `model` first.")

pred_slices = []
for i in range(image_under.shape[0]):
    input_slice = image_under[i:i+1].astype(np.float32)
    pred_slice = model.predict(input_slice, verbose=0)[0]
    pred_slices.append(pred_slice)
pred_slices = np.array(pred_slices)
pred_complex = to_complex(pred_slices)

# -----------------------
# SAVE SPECIFIC SLICES AS PNG
# -----------------------
num_slices = gt_complex.shape[0]
slice_indices = [i for i in requested if 0 <= i < num_slices]
if len(slice_indices) == 0:
    raise ValueError(f"No valid slices in requested={requested} for volume length {num_slices}")

for idx in slice_indices:
    plot_slice(gt_complex[idx], under_complex[idx], pred_complex[idx], idx, None, max_val)
    print(f"Saved slice_{idx:03d}.png to {out_dir}")

# Optional: save a summary image
fig, ax = plt.subplots(figsize=(10, 6))
ax.set_title("Volume Intensity Summary", fontsize=18)
ax.axis("off")
text = (
    f"GT Abs:    min={np.abs(gt_complex).min():.4f}, max={np.abs(gt_complex).max():.4f}\n"
    f"Under Abs: min={np.abs(under_complex).min():.4f}, max={np.abs(under_complex).max():.4f}\n"
    f"Pred Abs:  min={np.abs(pred_complex).min():.4f}, max={np.abs(pred_complex).max():.4f}\n\n"
    f"GT Real:   min={gt_complex.real.min():.4f}, max={gt_complex.real.max():.4f}\n"
    f"GT Imag:   min={gt_complex.imag.min():.4f}, max={gt_complex.imag.max():.4f}"
)
ax.text(0.05, 0.5, text, fontsize=12, va="center")
summary_out = os.path.join(out_dir, "summary.png")
fig.savefig(summary_out, dpi=300, bbox_inches='tight', pad_inches=0.04)
plt.close(fig)
print(f"‚úÖ Saved summary: {summary_out}")


In [None]:
import h5py
import glob, os

kspace_files_list_val = sorted(glob.glob(os.path.join(val_folder, "*.h5")))
file_path = kspace_files_list_val[0]

with h5py.File(file_path, 'r') as f:
    print("Keys in file:", list(f.keys()))


In [None]:
def get_attention_maps(model):
    attn_maps = {}

    for layer in model.layers:
        if isinstance(layer, SF_UNet_TF):
            for name in ['skip4', 'skip3', 'skip2', 'skip1']:
                skip = getattr(layer, name)
                attn = skip.fsa.sa.last_attn_map
                if attn is not None:
                    attn_maps[name] = attn[0, :, :, 0, 0].numpy()  # [H, W]
    return attn_maps


In [None]:
import matplotlib.pyplot as plt

def plot_attention_grid(attn_maps):
    plt.figure(figsize=(16, 4))
    for i, (name, attn) in enumerate(attn_maps.items(), 1):
        plt.subplot(1, len(attn_maps), i)
        plt.imshow(attn, cmap='viridis')
        plt.title(name)
        plt.axis('off')
    plt.tight_layout()
    plt.show()


In [None]:
from matplotlib.backends.backend_pdf import PdfPages

def save_attention_maps_to_pdf(attn_maps, save_path="attention_maps_new_fsa.pdf"):
    with PdfPages(save_path) as pdf:
        for name, attn in attn_maps.items():
            plt.figure(figsize=(6, 6))
            plt.imshow(attn, cmap='viridis')
            plt.title(f"{name} Attention Map")
            plt.axis('off')
            pdf.savefig()
            plt.close()
    print(f"‚úÖ Saved attention maps to {save_path}")


In [None]:
# attn_maps = get_attention_maps(model)
# plot_attention_grid(attn_maps)
# save_attention_maps_to_pdf(attn_maps, "attention_maps_from_volume000_mse_mpca_fsa.pdf")


In [None]:
import tensorflow as tf
import numpy as np
import h5py
import matplotlib.pyplot as plt
from matplotlib.backends.backend_pdf import PdfPages

# -- Extract attention maps from model, now including skip0 --
def get_attention_maps(model):
    attn_maps = {}
    for layer in model.layers:
        if isinstance(layer, SF_UNet_TF):
            for name in ['skip4', 'skip3', 'skip2', 'skip1']:  # ‚¨ÖÔ∏è added skip0
                skip = getattr(layer, name)
                if hasattr(skip.fsa.sa, 'last_attn_map') and skip.fsa.sa.last_attn_map is not None:
                    attn = skip.fsa.sa.last_attn_map
                    attn_maps[name] = attn[0, :, :, 0, 0].numpy()  # [H, W]
    return attn_maps

# -- Plot one row of attention maps --
def plot_attention_maps_for_slice(attn_maps, slice_idx):
    fig, axs = plt.subplots(1, len(attn_maps), figsize=(4 * len(attn_maps), 4))
    for i, (name, attn) in enumerate(attn_maps.items()):
        axs[i].imshow(attn, cmap='viridis')
        axs[i].set_title(name)
        axs[i].axis('off')
    fig.suptitle(f"Slice {slice_idx}", fontsize=14)
    return fig

# -- Run volume through model and save attention maps to PDF --
def process_volume_and_save_attention_pdf(h5_path, model, save_path="attention_all_slices.pdf", key='image_under'):
    with h5py.File(h5_path, 'r') as f, PdfPages(save_path) as pdf:
        volume = f[key][...]  # shape: [num_slices, H, W, 2] or [num_slices, H, W]
        if volume.ndim == 3:
            volume = np.stack([volume, np.zeros_like(volume)], axis=-1)  # Make complex

        num_slices = volume.shape[0]
        for i in range(num_slices):
            input_slice = np.expand_dims(volume[i], axis=0).astype(np.float32)  # [1, H, W, 2]
            _ = model(input_slice, training=False)

            attn_maps = get_attention_maps(model)
            fig = plot_attention_maps_for_slice(attn_maps, i)
            pdf.savefig(fig)
            plt.close(fig)

        print(f"‚úÖ Saved attention maps for all {num_slices} slices to: {save_path}")
process_volume_and_save_attention_pdf(
    h5_path="D:/val_norm/volume_000.h5",
    model=model,
    save_path="attention_maps_fsa.pdf",
    key='image_under'
)


In [None]:
import h5py
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.backends.backend_pdf import PdfPages

# ---------- Helpers ----------
def get_attention_maps(model):
    """Return dict of attention maps found on SF_UNet_TF skip layers."""
    attn_maps = {}
    for layer in model.layers:
        # If your model itself is the SF_UNet_TF instance, consider checking the model object directly
        if isinstance(layer, SF_UNet_TF):
            for name in ['skip4', 'skip3', 'skip2', 'skip1']:
                skip = getattr(layer, name)
                # safe attribute check
                if hasattr(skip, 'fsa') and hasattr(skip.fsa, 'sa') and hasattr(skip.fsa.sa, 'last_attn_map'):
                    attn = skip.fsa.sa.last_attn_map
                    if attn is not None:
                        # adapt indexing if your attn shape differs
                        attn_maps[name] = attn[0, :, :, 0, 0].numpy()  # [H, W]
    return attn_maps

def plot_attention_row(attn_maps, under_img, full_img, slice_idx, attn_cmap='viridis'):
    """
    attn_maps: Ordered dict-like (name->2D array)
    under_img: (H, W, 2) or (H, W) -> will be converted to magnitude for display
    full_img:  same as under_img (ground truth)
    """
    # Ensure magnitude images
    if under_img.ndim == 3 and under_img.shape[-1] == 2:
        under_disp = np.abs(under_img[...,0] + 1j * under_img[...,1])
    else:
        under_disp = np.abs(under_img)

    if full_img.ndim == 3 and full_img.shape[-1] == 2:
        full_disp = np.abs(full_img[...,0] + 1j * full_img[...,1])
    else:
        full_disp = np.abs(full_img)

    # Number of columns: under + attention maps + full
    n_attn = len(attn_maps)
    n_cols = 1 + max(0, n_attn) + 1

    fig, axs = plt.subplots(1, n_cols, figsize=(4 * n_cols, 4))
    if n_cols == 1:
        axs = np.array([axs])  # keep indexing consistent

    # Leftmost = under-sampled magnitude
    axs[0].imshow(under_disp, cmap='gray')
    axs[0].set_title("Under-sampled (Abs)")
    axs[0].axis('off')

    # Middle = attention maps in order
    for i, (name, attn) in enumerate(attn_maps.items(), start=1):
        ax = axs[i]
        # optional: set vmin/vmax for cross-slice consistency outside this function
        ax.imshow(attn, cmap=attn_cmap)
        ax.set_title(name)
        ax.axis('off')

    # Rightmost = full-sampled magnitude
    axs[-1].imshow(full_disp, cmap='gray')
    axs[-1].set_title("Full-sampled (Abs)")
    axs[-1].axis('off')

    fig.suptitle(f"Slice {slice_idx}", fontsize=14)
    plt.tight_layout()
    return fig

# ---------- Main processing & PDF saving ----------
def process_volume_and_save_attention_pdf(h5_path, model, save_path="attention_with_under_and_full.pdf",
                                         key_under='image_under', key_full='image_full'):
    with h5py.File(h5_path, 'r') as f, PdfPages(save_path) as pdf:
        # read volumes
        under_vol = f[key_under][...]   # expected shape: [num_slices, H, W, 2] or [num_slices, H, W]
        full_vol = f[key_full][...]     # same shape

        # ensure complex-like shape (H, W, 2) if stored as 3D
        if under_vol.ndim == 3:
            under_vol = np.stack([under_vol, np.zeros_like(under_vol)], axis=-1)
        if full_vol.ndim == 3:
            full_vol = np.stack([full_vol, np.zeros_like(full_vol)], axis=-1)

        num_slices = under_vol.shape[0]
        for i in range(num_slices):
            # prepare input for model: [1, H, W, 2], float32
            input_slice = np.expand_dims(under_vol[i], axis=0).astype(np.float32)

            # forward pass (this should populate last_attn_map inside the skip blocks)
            _ = model(input_slice, training=False)

            # extract attention maps
            attn_maps = get_attention_maps(model)  # dict: name -> 2D

            # plot row with under | attn maps... | full
            fig = plot_attention_row(attn_maps, under_vol[i], full_vol[i], slice_idx=i)
            pdf.savefig(fig)
            plt.close(fig)

        print(f"‚úÖ Saved attention maps + under + full for {num_slices} slices to: {save_path}")

# Example usage (update path & model as needed)
process_volume_and_save_attention_pdf(
    h5_path="D:/val_norm/volume_000.h5",
    model=model,
    save_path="attention_maps_with_under_full_mse_FSA.pdf",
    key_under='image_under',
    key_full='image_full'
)


In [None]:
import os   # <- small necessary addition

# ---------- Main: process specific slices and save PNGs ----------
def process_and_save_slices_png(h5_path, model, out_dir,
                                slices=[0], key_under='image_under', key_full='image_full',
                                dpi=300, figsize_per_col=4, attn_cmap='viridis'):
    """
    Minimal replacement to save specific slices as PNGs.
    - slices: list of slice indices to process (e.g. [0,10,20]) or "all"
    - out_dir: directory to save PNGs
    """
    os.makedirs(out_dir, exist_ok=True)

    with h5py.File(h5_path, 'r') as f:
        under_vol = f[key_under][...]   # [S, H, W, 2] or [S, H, W]
        full_vol = f[key_full][...]     # same

    # make sure shape is (S, H, W, 2)
    if under_vol.ndim == 3:
        under_vol = np.stack([under_vol, np.zeros_like(under_vol)], axis=-1)
    if full_vol.ndim == 3:
        full_vol = np.stack([full_vol, np.zeros_like(full_vol)], axis=-1)

    num_slices = under_vol.shape[0]
    if slices == "all":
        selected = list(range(num_slices))
    else:
        # filter valid indices
        selected = [int(s) for s in slices if 0 <= int(s) < num_slices]

    if len(selected) == 0:
        raise ValueError("No valid slices selected.")

    for idx in selected:
        inp = np.expand_dims(under_vol[idx], axis=0).astype(np.float32)
        # forward pass to populate attention maps
        _ = model(inp, training=False)

        attn_maps = get_attention_maps(model)  # same helper you already have

        # build figure (reuses your plot_attention_row which returns a fig)
        fig = plot_attention_row(attn_maps, under_vol[idx], full_vol[idx], slice_idx=idx, attn_cmap=attn_cmap)

        # Save PNG
        out_name = f"fsa_perp_slice_{idx:03d}.png"
        out_path = os.path.join(out_dir, out_name)
        fig.savefig(out_path, dpi=dpi, bbox_inches='tight', pad_inches=0.05)
        plt.close(fig)

    print(f"‚úÖ Saved {len(selected)} images to: {out_dir}")

# ---------- Example usage ----------
process_and_save_slices_png(
    h5_path="D:/val_norm/volume_000.h5",
    model=model,
    out_dir="attention_pngs",    # where PNGs will go
    slices=[18, 19, 20],         # <-- specify slices you want
    dpi=300,                    # change if you want higher/lower resolution
    figsize_per_col=4,          # not used here directly (plot_attention_row uses fixed figsize), kept for compatibility
    attn_cmap='viridis'
)


In [None]:
def get_attention_maps(model):
    attn_maps = {}
    for layer in model.layers:
        if isinstance(layer, SF_UNet_TF):
            for name in ['skip4', 'skip3', 'skip2', 'skip1']:
                skip = getattr(layer, name)
                attn = skip.fsa.sa.last_attn_map
                if attn is not None:
                    attn_maps[name] = attn[0, :, :, 0, 0].numpy()  # [H, W]
    return attn_maps


from matplotlib.backends.backend_pdf import PdfPages
import matplotlib.pyplot as plt

def plot_attention_maps_for_slice(attn_maps, slice_idx):
    fig, axs = plt.subplots(1, len(attn_maps), figsize=(16, 4))
    for i, (name, attn) in enumerate(attn_maps.items()):
        axs[i].imshow(attn, cmap='viridis')
        axs[i].set_title(f"{name}")
        axs[i].axis('off')
    fig.suptitle(f"Slice {slice_idx}", fontsize=14)
    return fig


In [37]:
import h5py
import numpy as np

def process_volume_and_save_attention_pdf(h5_path, model, save_path="attention_all_slices_mpca_fsa_mse.pdf", key='image_under'):
    with h5py.File(h5_path, 'r') as f, PdfPages(save_path) as pdf:
        volume = f[key][...]  # shape: [num_slices, H, W, 2] or [num_slices, H, W]
        if volume.ndim == 3:  # Real-only, no complex
            volume = np.stack([volume, np.zeros_like(volume)], axis=-1)  # Add imaginary

        num_slices = volume.shape[0]

        for i in range(num_slices):
            input_slice = np.expand_dims(volume[i], axis=0).astype(np.float32)  # [1, H, W, 2]

            # Run forward pass
            _ = model(input_slice)

            # Extract attention maps
            attn_maps = get_attention_maps(model)

            # Plot and save figure
            fig = plot_attention_maps_for_slice(attn_maps, i)
            pdf.savefig(fig)
            plt.close(fig)

        print(f"‚úÖ Saved all {num_slices} slices' attention maps to {save_path}")


In [38]:
# model = build_dual_output_model()
# model.load_weights("weight_sfu_fastmri_complex_perploss_fsa.h5")

process_volume_and_save_attention_pdf(
    h5_path="G:/val_norm/val_norm/volume_000.h5",
    model=model,
    save_path="attention_maps_volume_000_mpca_fsa_perploss.pdf",
    key='image_under'  
)


‚úÖ Saved all 40 slices' attention maps to attention_maps_volume_000_mpca_fsa_perploss.pdf
