In [28]:
import h5py
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt

In [29]:
# Undersampling mask. It is not referenced by the cells shown here: the
# validation files already store undersampled images under 'image_under'.
MASK_PATH = r"C:\Users\DU\aman_fastmri\Data\mask_4x_320_random.npy"
mask = np.load(MASK_PATH)  # recorded output shows shape (1, 1, 320, 1)
print("og_mask shape:", mask.shape)

og_mask shape: (1, 1, 320, 1)


In [30]:
# Execute model.ipynb in this kernel; presumably it defines RSCAGAN, which the
# cells below use but which is neither defined nor imported in this notebook.
%run model.ipynb

In [31]:
# The generator is constructed (and its weights restored) in the checkpoint
# cell below, which rebuilds `model` and `generator` from scratch. Building
# RSCAGAN here as well only allocated a second copy that was immediately
# discarded.


In [32]:
import tensorflow as tf
import os

# --------------------------------------------------
# Paths
# --------------------------------------------------
SAVE_DIR = "./SavedModels_RSCA_GAN_full_2"

# --------------------------------------------------
# Build model -- must use the same constructor arguments as training,
# otherwise the checkpointed variables will not line up with this graph.
# --------------------------------------------------
model = RSCAGAN(in_channels=2, base_channels=32)
generator = model.generator

# --------------------------------------------------
# Restore the latest generator checkpoint
# --------------------------------------------------
ckpt = tf.train.Checkpoint(generator=generator)

ckpt_manager = tf.train.CheckpointManager(
    ckpt,
    directory=SAVE_DIR,
    max_to_keep=5
)

latest_ckpt = ckpt_manager.latest_checkpoint
if not latest_ckpt:
    raise RuntimeError("❌ No checkpoint found for inference")

restore_status = ckpt.restore(latest_ckpt)
# expect_partial() silences warnings about checkpoint values this
# generator-only Checkpoint does not consume (presumably the training-time
# optimizer/discriminator state). On its own it would equally silence a
# generator that matched nothing, so additionally require that every
# existing object reachable from `generator` was found in the checkpoint.
restore_status.expect_partial()
restore_status.assert_existing_objects_matched()
print(f"✅ Generator weights restored from {latest_ckpt}")


✅ Generator weights restored


In [33]:
import h5py
import numpy as np
import tensorflow as tf

class MRISliceValGenerator(tf.keras.utils.Sequence):
    """Keras Sequence that yields per-slice batches from a list of HDF5 files.

    Every slice of every file is one sample. Each file must provide the
    datasets ``image_under`` and ``image_full`` (slices along axis 0) and
    ``max_val_full_image`` (only element 0 is read).

    ``__getitem__`` returns a tuple ``(SZF, SGT, MAXV, VOL, SLICE)``:
        SZF   -- float32 stack of ``image_under`` slices
        SGT   -- float32 stack of ``image_full`` slices
        MAXV  -- float32 ``max_val_full_image[0]`` of each slice's file
        VOL   -- index of each slice's file within ``file_list``
        SLICE -- index of each slice within its file

    The HDF5 files stay open while the object is in use; call ``close()``
    (or use the object as a context manager) to release them.
    """

    def __init__(self, file_list, batch_size=4, shuffle=False):
        # Initialise the Keras base class before setting our own state.
        super().__init__()
        self.file_list = file_list
        self.batch_size = batch_size
        self.shuffle = shuffle

        # Open each file once. If any open fails, close the ones already
        # opened instead of leaking their handles, then re-raise.
        self.files = []
        try:
            for fp in self.file_list:
                self.files.append(h5py.File(fp, 'r'))
        except BaseException:
            self.close()
            raise

        self.slice_index_map = []
        self._build_index()

    def _build_index(self):
        """Fill ``slice_index_map`` with one (file_idx, slice_idx) pair per slice."""
        for file_idx, f in enumerate(self.files):
            num_slices = f['image_under'].shape[0]
            for slice_idx in range(num_slices):
                self.slice_index_map.append((file_idx, slice_idx))
        self.on_epoch_end()

    def on_epoch_end(self):
        """Reshuffle the sample order when ``shuffle=True``; no-op otherwise."""
        # Without this override the ``shuffle`` flag would have no effect.
        if self.shuffle:
            np.random.shuffle(self.slice_index_map)

    def __len__(self):
        """Number of batches; the last batch may hold fewer than ``batch_size``."""
        return int(np.ceil(len(self.slice_index_map) / self.batch_size))

    def __getitem__(self, index):
        """Return batch ``index`` as ``(SZF, SGT, MAXV, VOL, SLICE)``.

        Raises:
            IndexError: if ``index`` is outside ``range(len(self))``. This also
                lets Python's ``__getitem__``-based iteration stop cleanly
                instead of failing inside ``np.stack`` on an empty batch.
        """
        if not 0 <= index < len(self):
            raise IndexError(
                f"batch index {index} out of range for {len(self)} batches"
            )

        start = index * self.batch_size
        batch_map = self.slice_index_map[start:start + self.batch_size]

        SZF_batch = []
        SGT_batch = []
        MAXV_batch = []
        VOL_batch = []
        SLICE_batch = []

        for file_idx, slice_idx in batch_map:
            f = self.files[file_idx]

            SZF_batch.append(f['image_under'][slice_idx])
            SGT_batch.append(f['image_full'][slice_idx])
            # One value per file: element 0 is read regardless of slice.
            MAXV_batch.append(f['max_val_full_image'][0])

            VOL_batch.append(file_idx)
            SLICE_batch.append(slice_idx)

        return (
            np.stack(SZF_batch).astype(np.float32),
            np.stack(SGT_batch).astype(np.float32),
            np.array(MAXV_batch).astype(np.float32),
            np.array(VOL_batch),
            np.array(SLICE_batch),
        )

    def close(self):
        """Close every HDF5 file opened by this generator."""
        for f in self.files:
            f.close()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        self.close()
        return False


In [34]:
import h5py
import numpy as np
import glob
import os
# Training data location (not read by this evaluation notebook).
train_folder = r"D:\fastmri_singlecoil_FSSCAN\train_norm"
val_folder = r"D:\fastmri_singlecoil_FSSCAN\val_norm"
kspace_files_list_val = sorted(glob.glob(os.path.join(val_folder, "*.h5")))

# glob() returns [] for a wrong or missing folder, which would otherwise give
# an empty generator and silently write empty metric spreadsheets.
if not kspace_files_list_val:
    raise FileNotFoundError(f"No .h5 files found in {val_folder!r}")

# One slice per batch, in a fixed file/slice order so the volume and slice
# ids written by the evaluation are reproducible.
val_gen = MRISliceValGenerator(kspace_files_list_val, batch_size=1, shuffle=False)

print(len(val_gen))


7135


In [35]:
def complex_to_mag(x):
    """Magnitude of a real/imaginary pair stored in the last axis of ``x``.

    Channel 0 is taken as the real part and channel 1 as the imaginary part;
    the last axis is dropped from the result. A 1e-8 term under the square
    root keeps the result strictly positive (presumably to keep the sqrt
    gradient finite during training).
    """
    real = x[..., 0]
    imag = x[..., 1]
    return tf.sqrt(real**2 + imag**2 + 1e-8)


In [36]:
@tf.function
def inference_metrics_step_slice(SZF, SGT, MAXV):
    """Reconstruct a batch with ``generator`` and score it against ground truth.

    Args:
        SZF:  undersampled input batch, passed straight to ``generator``.
        SGT:  ground-truth batch; like the generator output it is reduced with
              ``complex_to_mag``, i.e. real/imag assumed in the last axis
              (TODO confirm against the data files).
        MAXV: one scaling value per sample, any shape reshapeable to (B,).

    Returns:
        ``(psnr, ssim, nmse)``, each a float tensor of shape (B,).
    """
    SRE = generator(SZF, training=False)

    # Per-sample max value in the two broadcast shapes needed below.
    maxv = tf.reshape(tf.cast(MAXV, tf.float32), (-1,))   # (B,)
    maxv_img = tf.reshape(maxv, (-1, 1, 1, 1))            # (B, 1, 1, 1)

    # Magnitude images rescaled by MAXV, with a trailing channel axis so they
    # are (B, H, W, 1) as tf.image.ssim / tf.image.psnr expect.
    gt_mag  = tf.expand_dims(complex_to_mag(SGT), axis=-1) * maxv_img
    rec_mag = tf.expand_dims(complex_to_mag(SRE), axis=-1) * maxv_img

    # ssim broadcasts max_val against per-pixel statistics of shape
    # (B, H', W', C); (B, 1, 1, 1) pairs each image with its own max value.
    # The previous (B, 1, 1) shape only lined up when B == 1.
    ssim = tf.image.ssim(gt_mag, rec_mag, max_val=maxv_img)

    # psnr combines max_val with a per-image MSE of shape (B,), so max_val
    # must be (B,); (B, 1, 1) produced a (B, 1, B) result for B > 1.
    psnr = tf.image.psnr(gt_mag, rec_mag, max_val=maxv)

    # Per-slice normalised MSE: ||gt - rec||^2 / ||gt||^2.
    nmse = tf.reduce_sum(
        tf.square(gt_mag - rec_mag), axis=[1, 2, 3]
    ) / tf.reduce_sum(
        tf.square(gt_mag), axis=[1, 2, 3]
    )

    return psnr, ssim, nmse


In [37]:
import pandas as pd
from collections import defaultdict

def run_evaluation_and_save(
    val_gen,
    volume_xlsx="volume_wise_PSNR_NMSE.xlsx",
    slice_xlsx="slice_wise_SSIM.xlsx",
):
    """Score every batch of ``val_gen`` and write the metrics to Excel files.

    PSNR and NMSE are averaged over each volume's slices. SSIM is kept per
    slice.

    Args:
        val_gen: iterable of ``(SZF, SGT, MAXV, VOL, SLICE)`` batches, such as
            ``MRISliceValGenerator``. If it has a ``file_list`` attribute, each
            row also gets a ``file`` column so ``volume_id`` (an index into
            that list) can be traced back to its source file.
        volume_xlsx: output path for the volume-wise PSNR/NMSE table.
        slice_xlsx: output path for the slice-wise SSIM table.

    Returns:
        ``(df_volume, df_slice)``: the two DataFrames that were written.
    """
    # -----------------------
    # Storage
    # -----------------------
    volume_psnr = defaultdict(list)
    volume_nmse = defaultdict(list)
    slice_ssim_records = []

    file_list = getattr(val_gen, "file_list", None)

    def _id_columns(vol_id):
        """Identifier columns shared by the volume and slice tables."""
        columns = {"volume_id": vol_id}
        if file_list is not None:
            columns["file"] = os.path.basename(str(file_list[vol_id]))
        return columns

    # -----------------------
    # Loop
    # -----------------------
    for SZF, SGT, MAXV, VOL, SLICE in val_gen:
        psnr, ssim, nmse = inference_metrics_step_slice(SZF, SGT, MAXV)
        psnr = psnr.numpy()
        ssim = ssim.numpy()
        nmse = nmse.numpy()

        for i in range(len(VOL)):
            vol_id = int(VOL[i])
            slice_id = int(SLICE[i])

            volume_psnr[vol_id].append(psnr[i])
            volume_nmse[vol_id].append(nmse[i])

            slice_ssim_records.append({
                **_id_columns(vol_id),
                "slice_id": slice_id,
                "SSIM": float(ssim[i]),
            })

    # -----------------------
    # Volume-wise aggregation (mean over each volume's slices)
    # -----------------------
    volume_records = [
        {
            **_id_columns(vol_id),
            "PSNR": float(np.mean(volume_psnr[vol_id])),
            "NMSE": float(np.mean(volume_nmse[vol_id])),
        }
        for vol_id in volume_psnr
    ]

    # -----------------------
    # Save to Excel
    # -----------------------
    df_volume = pd.DataFrame(volume_records)
    df_slice = pd.DataFrame(slice_ssim_records)

    df_volume.to_excel(volume_xlsx, index=False)
    df_slice.to_excel(slice_xlsx, index=False)

    print("✅ Saved:")
    print(f"  - {volume_xlsx}")
    print(f"  - {slice_xlsx}")

    return df_volume, df_slice


In [38]:
df_volume, df_slice = run_evaluation_and_save(val_gen)

# Headline numbers, displayed as the cell's output so results are visible
# without opening the spreadsheets.
pd.Series({
    "PSNR (mean over volumes)": df_volume["PSNR"].mean(),
    "NMSE (mean over volumes)": df_volume["NMSE"].mean(),
    "SSIM (mean over slices)": df_slice["SSIM"].mean(),
})


✅ Saved:
  - volume_wise_PSNR_NMSE.xlsx
  - slice_wise_SSIM.xlsx
