In [1]:
import h5py
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt

In [3]:
mask = np.load(r"C:\Users\DU\aman_fastmri\Data\mask_4x_320_random.npy")  # Shape: (1, 320, 320)
#C:\Users\DU\aman_fastmri\Data
print("og_mask shape:", mask.shape)

og_mask shape: (1, 1, 320, 1)


In [4]:
import h5py
import numpy as np
import tensorflow as tf

class MRISliceGenerator(tf.keras.utils.Sequence):

    def __init__(self, file_list, batch_size=4, shuffle=True):
        self.file_list = file_list
        self.batch_size = batch_size
        self.shuffle = shuffle

        # Open files once (important for stability & speed)
        self.files = [h5py.File(fp, 'r') for fp in self.file_list]

        self.slice_index_map = []
        self._build_index()

    def _build_index(self):
        for file_idx, f in enumerate(self.files):
            num_slices = f['image_under'].shape[0]
            for slice_idx in range(num_slices):
                self.slice_index_map.append((file_idx, slice_idx))
        self.on_epoch_end()

    def __len__(self):
        return int(np.ceil(len(self.slice_index_map) / self.batch_size))

    def __getitem__(self, index):

        batch_map = self.slice_index_map[
            index * self.batch_size:(index + 1) * self.batch_size
        ]

        SZF_batch = []
        SGT_batch = []

        for file_idx, slice_idx in batch_map:
            f = self.files[file_idx]

            # Shape: (H, W, 2)  -> (real, imag)
            SZF_batch.append(f['image_under'][slice_idx])
            SGT_batch.append(f['image_full'][slice_idx])

        SZF_batch = np.stack(SZF_batch).astype(np.float32)
        SGT_batch = np.stack(SGT_batch).astype(np.float32)

        return SZF_batch, SGT_batch

    def on_epoch_end(self):
        if self.shuffle:
            np.random.shuffle(self.slice_index_map)

    def __del__(self):
        for f in self.files:
            try:
                f.close()
            except Exception:
                pass


In [5]:
import h5py
import numpy as np
import tensorflow as tf

class MRISliceValGenerator(tf.keras.utils.Sequence):

    def __init__(self, file_list, batch_size=4, shuffle=False):
        self.file_list = file_list
        self.batch_size = batch_size
        self.shuffle = shuffle

        # Open files once
        self.files = [h5py.File(fp, 'r') for fp in self.file_list]

        self.slice_index_map = []
        self._build_index()

    def _build_index(self):
        for file_idx, f in enumerate(self.files):
            num_slices = f['image_under'].shape[0]
            for slice_idx in range(num_slices):
                self.slice_index_map.append((file_idx, slice_idx))
        self.on_epoch_end()

    def __len__(self):
        return int(np.ceil(len(self.slice_index_map) / self.batch_size))

    def __getitem__(self, index):

        batch_map = self.slice_index_map[
            index * self.batch_size:(index + 1) * self.batch_size
        ]

        SZF_batch = []
        SGT_batch = []
        MAXV_batch = []

        for file_idx, slice_idx in batch_map:
            f = self.files[file_idx]

            # Complex images (H, W, 2)
            SZF_batch.append(f['image_under'][slice_idx])
            SGT_batch.append(f['image_full'][slice_idx])

            # Volume-level max (same for all slices of that volume)
            MAXV_batch.append(f['max_val_full_image'][0])

        SZF_batch = np.stack(SZF_batch).astype(np.float32)
        SGT_batch = np.stack(SGT_batch).astype(np.float32)
        MAXV_batch = np.array(MAXV_batch).astype(np.float32)

        return SZF_batch, SGT_batch, MAXV_batch

    def on_epoch_end(self):
        if self.shuffle:
            np.random.shuffle(self.slice_index_map)

    def __del__(self):
        for f in self.files:
            try:
                f.close()
            except Exception:
                pass


In [7]:
train_folder = r"D:\fastmri_singlecoil_FSSCAN\train_norm"
val_folder = r"D:\fastmri_singlecoil_FSSCAN\val_norm"

In [8]:
import h5py
import numpy as np
import glob
import os
kspace_files_list_train = sorted(glob.glob(os.path.join(train_folder, "*.h5")))
kspace_files_list_val = sorted(glob.glob(os.path.join(val_folder, "*.h5")))

# half_train = 20
# half_val = 10
half_train = len(kspace_files_list_train) 
half_val = len(kspace_files_list_val) 
# print("half_train",half_train)
# print("half_val",half_val)
kspace_files_list_train = kspace_files_list_train[:]
kspace_files_list_val = kspace_files_list_val[:]

# # Create generators
# train_gen = MRISliceGenerator(kspace_files_list_train,batch_size=16, shuffle=True,mask=mask)
# val_gen = MRISliceGenerator(kspace_files_list_val, batch_size=4, shuffle=False,mask=mask)
train_gen = MRISliceGenerator(kspace_files_list_train,batch_size=2, shuffle=True)
val_gen = MRISliceValGenerator(kspace_files_list_val, batch_size=1, shuffle=False)

print(len(train_gen))  
print(len(val_gen))  


17349
7135


In [1]:
%run model.ipynb

In [5]:
model = RSCAGAN(in_channels=2, base_channels=32)
model.summary()

ValueError: This model has not yet been built. Build the model first by calling `build()` or by calling the model on a batch of data.

In [4]:

import numpy as np

def count_parameters_millions(model):
    trainable = np.sum([np.prod(v.shape) for v in model.trainable_variables])
    non_trainable = np.sum([np.prod(v.shape) for v in model.non_trainable_variables])
    total = trainable + non_trainable
    return (
        total / 1e6,
        trainable / 1e6,
        non_trainable / 1e6
    )

total_M, trainable_M, non_trainable_M = count_parameters_millions(model)

print("\n" + "=" * 40)
print(f"Total parameters:       {total_M:.3f} M")
print(f"Trainable parameters:   {trainable_M:.3f} M")
print(f"Non-trainable params:   {non_trainable_M:.3f} M")
print("=" * 40)
import tensorflow as tf
from tensorflow.python.framework.convert_to_constants import convert_variables_to_constants_v2

def compute_flops(model, input_shape):
    """
    input_shape: tuple, e.g. (1, H, W, 2)
    returns FLOPs (float operations) for one forward pass
    """

    @tf.function
    def forward(x):
        return model(x)

    concrete_func = forward.get_concrete_function(
        tf.TensorSpec(input_shape, tf.float32)
    )

    frozen_func = convert_variables_to_constants_v2(concrete_func)
    graph_def = frozen_func.graph.as_graph_def()

    with tf.Graph().as_default() as graph:
        tf.graph_util.import_graph_def(graph_def, name="")

        run_meta = tf.compat.v1.RunMetadata()
        opts = tf.compat.v1.profiler.ProfileOptionBuilder.float_operation()

        flops = tf.compat.v1.profiler.profile(
            graph=graph,
            run_meta=run_meta,
            cmd="op",
            options=opts
        )

    return flops.total_float_ops
# Example: infer H, W from your data or define explicitly
H, W = 320,320

flops = compute_flops(model, input_shape=(1,1, H, W, 2))

print(f"FLOPs (single forward pass): {flops / 1e9:.2f} GFLOPs")



Total parameters:       0.000 M
Trainable parameters:   0.000 M
Non-trainable params:   0.000 M


OperatorNotAllowedInGraphError: in user code:

    File "C:\Users\DU\AppData\Local\Temp\ipykernel_23596\915890014.py", line 31, in forward  *
        return model(x)
    File "C:\Users\DU\anaconda3\envs\WNet\lib\site-packages\keras\utils\traceback_utils.py", line 70, in error_handler  **
        raise e.with_traceback(filtered_tb) from None

    OperatorNotAllowedInGraphError: Exception encountered when calling layer "RSCA_GAN" "                 f"(type RSCAGAN).
    
    in user code:
    
        File "C:\Users\DU\AppData\Local\Temp\ipykernel_23596\1834071404.py", line 33, in call  *
            SZF, SGT = inputs
    
        OperatorNotAllowedInGraphError: Iterating over a symbolic `tf.Tensor` is not allowed: AutoGraph did convert this function. This might indicate you are trying to use an unsupported feature.
    
    
    Call arguments received by layer "RSCA_GAN" "                 f"(type RSCAGAN):
      â€¢ inputs=tf.Tensor(shape=(1, 1, 320, 320, 2), dtype=float32)
      â€¢ training=False


In [10]:
import os
import numpy as np
import tensorflow as tf
from tensorflow.keras.optimizers import Adam

# --------------------------------------------------
# Directory Setup
# --------------------------------------------------
save_dir = "./SavedModels_RSCA_GAN_full_2"
os.makedirs(save_dir, exist_ok=True)

# --------------------------------------------------
# Configuration (PAPER-FAITHFUL)
# --------------------------------------------------
H, W = 320, 320        # paper resizes all images to 256x256

BATCH_SIZE     = 8   # ðŸ”‘ paper value
LEARNING_RATE  = 1e-4  # ðŸ”‘ paper value
BETA_1         = 0.9
BETA_2         = 0.999

EPOCHS         = 50   # ðŸ”‘ paper uses epochs

GEN_CKPT_PATH  = os.path.join(save_dir, "RSCA_GAN_Generator")
DISC_CKPT_PATH = os.path.join(save_dir, "RSCA_GAN_Discriminator")

print("=" * 60)
print("ðŸ”§ RSCA-GAN TRAINING CONFIGURATION (Paper-Aligned)")
print("=" * 60)
print(f" Save Directory:       {save_dir}")
print(f" Image Size:           {H} x {W}")
print(f" Batch Size:           {BATCH_SIZE}")
print(f" Learning Rate:        {LEARNING_RATE}")
print(f" Adam Betas:           ({BETA_1}, {BETA_2})")
print(f" Epochs:               {EPOCHS}")
print(f" Generator Ckpt:       {GEN_CKPT_PATH}")
print(f" Discriminator Ckpt:   {DISC_CKPT_PATH}")
print("=" * 60)


ðŸ”§ RSCA-GAN TRAINING CONFIGURATION (Paper-Aligned)
 Save Directory:       ./SavedModels_RSCA_GAN_full_2
 Image Size:           320 x 320
 Batch Size:           8
 Learning Rate:        0.0001
 Adam Betas:           (0.9, 0.999)
 Epochs:               50
 Generator Ckpt:       ./SavedModels_RSCA_GAN_full_2\RSCA_GAN_Generator
 Discriminator Ckpt:   ./SavedModels_RSCA_GAN_full_2\RSCA_GAN_Discriminator


In [11]:
model = RSCAGAN(in_channels=2, base_channels=32)

generator = model.generator
discriminator = model.discriminator


In [12]:
gen_optimizer = Adam(
    learning_rate=LEARNING_RATE,
    beta_1=BETA_1,
    beta_2=BETA_2,
    # clipnorm=1.0
)

"""disc_optimizer = Adam(
    #learning_rate=LEARNING_RATE,
    learning_rate=8e-5,
    beta_1=BETA_1,
    beta_2=BETA_2,
    clipnorm=1.0
)"""
disc_optimizer = Adam(
    learning_rate=8e-5,   
    beta_1=0.9,
    beta_2=0.999
)


In [13]:
bce = tf.keras.losses.BinaryCrossentropy(from_logits=True)

def discriminator_loss(D_real, D_fake):
   # real_loss = bce(tf.ones_like(D_real), D_real)
   # fake_loss = bce(tf.zeros_like(D_fake), D_fake)
    real_labels = tf.ones_like(D_real) * 0.9
    fake_labels = tf.zeros_like(D_fake) + 0.1

    real_loss = bce(real_labels, D_real)
    fake_loss = bce(fake_labels, D_fake)

    return real_loss + fake_loss

def generator_adversarial_loss(D_fake):
    return bce(tf.ones_like(D_fake), D_fake)

def image_l1_loss(SGT, SRE):
    return tf.reduce_mean(tf.abs(SGT - SRE))

def frequency_l1_loss(SGT, SRE, mask):
    r_gt  = fft2c_tf(SGT)
    r_rec = fft2c_tf(SRE)
    return tf.reduce_mean(tf.abs(r_gt * mask - r_rec * mask))


In [14]:
# def compute_nmse(gt, pred):
#     return tf.reduce_sum(tf.square(gt - pred)) / tf.reduce_sum(tf.square(gt))

# def compute_psnr(gt, pred):
#     return tf.reduce_mean(tf.image.psnr(gt, pred, max_val=1.0))


# def compute_ssim(gt, pred):
#     return tf.reduce_mean(
#         tf.image.ssim(gt, pred, max_val=1.0)
#     )



In [15]:
@tf.function
def train_step(SZF, SGT, mask):

    # ======================
    # Train Discriminator
    # ======================
    with tf.GradientTape() as d_tape:
        SRE = generator(SZF, training=True)

        D_real = discriminator(SGT, training=True)
        D_fake = discriminator(tf.stop_gradient(SRE), training=True)

        d_loss = discriminator_loss(D_real, D_fake)

    d_grads = d_tape.gradient(d_loss, discriminator.trainable_variables)
    disc_optimizer.apply_gradients(
        zip(d_grads, discriminator.trainable_variables)
    )

    # ======================
    # Train Generator
    # ======================
    with tf.GradientTape() as g_tape:
        SRE = generator(SZF, training=True)

        D_fake = discriminator(SRE, training=True)

        adv_loss  = generator_adversarial_loss(D_fake)
        img_loss  = image_l1_loss(SGT, SRE)
        freq_loss = frequency_l1_loss(SGT, SRE, mask)

        g_loss = 0.001 * adv_loss + freq_loss + 10.0 * img_loss

    g_grads = g_tape.gradient(g_loss, generator.trainable_variables)
    gen_optimizer.apply_gradients(
        zip(g_grads, generator.trainable_variables)
    )

    return d_loss, g_loss, adv_loss, freq_loss, img_loss


In [21]:
def complex_to_mag(x):
    return tf.sqrt(x[..., 0]**2 + x[..., 1]**2 + 1e-8)
@tf.function
def validation_step(SZF, SGT, mask, MAXV):

    # Generator inference
    SRE = generator(SZF, training=False)

    # ======================
    # Reconstruction loss (NO adversarial)
    # ======================
    img_loss  = image_l1_loss(SGT, SRE)
    freq_loss = frequency_l1_loss(SGT, SRE, mask)

    total_loss = freq_loss + 10.0 * img_loss

    # ======================
    # Metrics (magnitude domain)
    # ======================
    gt_mag  = complex_to_mag(SGT)
    rec_mag = complex_to_mag(SRE)
    print("gt_mag",gt_mag.shape)
    print("rec_mag",rec_mag.shape)
    

    # Denormalize using volume-level MAXV
    MAXV = tf.reshape(MAXV, (-1, 1, 1))
    gt_mag  = gt_mag * MAXV
    rec_mag = rec_mag * MAXV
    gt_mag  = tf.expand_dims(gt_mag, axis=-1)   # [B,H,W,1]
    rec_mag = tf.expand_dims(rec_mag, axis=-1)  # [B,H,W,1]


    psnr = tf.reduce_mean(
        tf.image.psnr(gt_mag, rec_mag, max_val=MAXV)
    )

    ssim = tf.reduce_mean(
        tf.image.ssim(gt_mag, rec_mag, max_val=MAXV)
    )

    nmse = tf.reduce_sum(
        tf.square(gt_mag - rec_mag)
    ) / tf.reduce_sum(tf.square(gt_mag))

    return total_loss, psnr, ssim, nmse


In [22]:
def run_validation(val_gen, mask):
    val_loss = []
    val_psnr = []
    val_ssim = []
    val_nmse = []

    for SZF, SGT, MAXV in val_gen:
        loss, psnr, ssim, nmse = validation_step(SZF, SGT, mask, MAXV)


        val_loss.append(loss)
        val_psnr.append(psnr)
        val_ssim.append(ssim)
        val_nmse.append(nmse)

    return (
        tf.reduce_mean(val_loss),
        tf.reduce_mean(val_psnr),
        tf.reduce_mean(val_ssim),
        tf.reduce_mean(val_nmse)
    )


In [23]:
# >>> ADD <<<
epoch_counter = tf.Variable(0, dtype=tf.int64)

ckpt = tf.train.Checkpoint(
    generator=generator,
    discriminator=discriminator,
    gen_optimizer=gen_optimizer,
    disc_optimizer=disc_optimizer,
    epoch=epoch_counter
)

ckpt_manager = tf.train.CheckpointManager(
    ckpt,
    directory=save_dir,
    max_to_keep=5
)

best_val_loss = float("inf")


In [24]:

if ckpt_manager.latest_checkpoint:
    ckpt.restore(ckpt_manager.latest_checkpoint).expect_partial()
    start_epoch = int(epoch_counter.numpy())
    print(f"âœ… Resuming from epoch {start_epoch}")
else:
    start_epoch = 0
    print("ðŸ†• Training from scratch")


ðŸ†• Training from scratch


In [25]:
for epoch in range(start_epoch, EPOCHS):
    print(f"\n===== Epoch {epoch+1}/{EPOCHS} =====")

    # ======================
    # Training
    # ======================
    for step, (SZF, SGT) in enumerate(train_gen):

        d_loss, g_loss, adv, freq, img = train_step(
            SZF, SGT, mask
        )

        if step % 1000 == 0:
            print(
                f"[Train {step}] "
                f"D: {d_loss:.4f} | "
                f"G: {g_loss:.4f} | "
                f"Adv: {adv:.4f} | "
                f"F: {freq:.4f} | "
                f"I: {img:.4f}"
            )

    # ======================
    # Validation
    # ======================
    val_loss, val_psnr, val_ssim, val_nmse = run_validation(val_gen, mask)

    print("\nðŸ“Š Validation Results")
    print(f" Val Loss : {val_loss:.4f}")
    print(f" Val PSNR : {val_psnr:.2f}")
    print(f" Val SSIM : {val_ssim:.4f}")
    print(f" Val NMSE : {val_nmse:.4f}")

    # Save epoch counter
    epoch_counter.assign(epoch + 1)

    # ======================
    # Save BEST model
    # ======================
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        ckpt_manager.save()
        print(f"ðŸ”¥ Saved BEST model (Val Loss = {val_loss:.4f})")



===== Epoch 1/50 =====
[Train 0] D: 0.8204 | G: 0.2787 | Adv: 1.4261 | F: 0.0043 | I: 0.0273
[Train 1000] D: 0.8385 | G: 0.2041 | Adv: 2.0108 | F: 0.0032 | I: 0.0199
[Train 2000] D: 0.9586 | G: 0.1808 | Adv: 1.5680 | F: 0.0028 | I: 0.0176
[Train 3000] D: 1.0475 | G: 0.1058 | Adv: 2.0778 | F: 0.0017 | I: 0.0102
[Train 4000] D: 0.8246 | G: 0.2453 | Adv: 1.1141 | F: 0.0039 | I: 0.0240
[Train 5000] D: 1.0908 | G: 0.1162 | Adv: 1.9670 | F: 0.0018 | I: 0.0112
[Train 6000] D: 0.8894 | G: 0.2364 | Adv: 1.2889 | F: 0.0039 | I: 0.0231
[Train 7000] D: 1.0552 | G: 0.1453 | Adv: 2.0084 | F: 0.0023 | I: 0.0141
[Train 8000] D: 0.8295 | G: 0.1948 | Adv: 1.3593 | F: 0.0032 | I: 0.0190
[Train 9000] D: 0.9029 | G: 0.1836 | Adv: 2.4953 | F: 0.0029 | I: 0.0178
[Train 10000] D: 2.2569 | G: 0.1257 | Adv: 2.8443 | F: 0.0021 | I: 0.0121
[Train 11000] D: 0.9070 | G: 0.1613 | Adv: 1.4189 | F: 0.0026 | I: 0.0157
[Train 12000] D: 0.8209 | G: 0.1533 | Adv: 1.8302 | F: 0.0024 | I: 0.0149
[Train 13000] D: 0.8183 | G

KeyboardInterrupt: 