In [1]:
import h5py
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt

In [2]:
mask = np.load(r"C:\Users\DU\aman_fastmri\Data\mask_4x_320_random.npy")  # Shape: (1, 320, 320)
#C:\Users\DU\aman_fastmri\Data
print("og_mask shape:", mask.shape)

og_mask shape: (1, 1, 320, 1)


In [3]:
class MRISliceGenerator(tf.keras.utils.Sequence):

    def __init__(self, file_list, batch_size=4, shuffle=True):
        self.file_list = file_list
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.slice_index_map = []
        self._build_index()

    def _build_index(self):
        for file_idx, file_path in enumerate(self.file_list):
            with h5py.File(file_path, 'r') as f:
                num_slices = f['image_under'].shape[0]
                for slice_idx in range(num_slices):
                    self.slice_index_map.append((file_idx, slice_idx))
        self.on_epoch_end()

    def __len__(self):
        return int(np.ceil(len(self.slice_index_map) / self.batch_size))

    def __getitem__(self, index):

        batch_map = self.slice_index_map[
            index * self.batch_size:(index + 1) * self.batch_size
        ]

        SZF_batch = []
        SGT_batch = []
        MAX_batch = []

        for file_idx, slice_idx in batch_map:
            with h5py.File(self.file_list[file_idx], 'r') as f:
                SZF_batch.append(f['image_under'][slice_idx])
                SGT_batch.append(f['image_full'][slice_idx])
                MAX_batch.append(f['max_val_full_image'][0])


        SZF_batch = np.stack(SZF_batch).astype(np.float32)
        SGT_batch = np.stack(SGT_batch).astype(np.float32)
        MAX_batch = np.array(MAX_batch).astype(np.float32)

        return SZF_batch, SGT_batch, MAX_batch

    def on_epoch_end(self):
        if self.shuffle:
            np.random.shuffle(self.slice_index_map)


In [4]:
train_folder = r"D:\fastmri_singlecoil_FSSCAN\train_norm"
val_folder = r"D:\fastmri_singlecoil_FSSCAN\val_norm"

In [5]:
import h5py
import numpy as np
import glob
import os
kspace_files_list_train = sorted(glob.glob(os.path.join(train_folder, "*.h5")))
kspace_files_list_val = sorted(glob.glob(os.path.join(val_folder, "*.h5")))

# half_train = 20
# half_val = 10
half_train = len(kspace_files_list_train) 
half_val = len(kspace_files_list_val) 
# print("half_train",half_train)
# print("half_val",half_val)
kspace_files_list_train = kspace_files_list_train[:]
kspace_files_list_val = kspace_files_list_val[:]

# # Create generators
# train_gen = MRISliceGenerator(kspace_files_list_train,batch_size=16, shuffle=True,mask=mask)
# val_gen = MRISliceGenerator(kspace_files_list_val, batch_size=4, shuffle=False,mask=mask)
train_gen = MRISliceGenerator(kspace_files_list_train,batch_size=2, shuffle=True)
val_gen = MRISliceGenerator(kspace_files_list_val, batch_size=1, shuffle=False)

print(len(train_gen))  
print(len(val_gen))  


17349
7135


In [6]:
%run model.ipynb

In [7]:
import os
import numpy as np
import tensorflow as tf
from tensorflow.keras.optimizers import Adam

# --------------------------------------------------
# Directory Setup
# --------------------------------------------------
save_dir = "./SavedModels_RSCA_GAN_full"
os.makedirs(save_dir, exist_ok=True)

# --------------------------------------------------
# Configuration (PAPER-FAITHFUL)
# --------------------------------------------------
H, W = 320, 320        # paper resizes all images to 256x256

BATCH_SIZE     = 8   # ðŸ”‘ paper value
LEARNING_RATE  = 1e-4  # ðŸ”‘ paper value
BETA_1         = 0.9
BETA_2         = 0.999

EPOCHS         = 50   # ðŸ”‘ paper uses epochs

GEN_CKPT_PATH  = os.path.join(save_dir, "RSCA_GAN_Generator")
DISC_CKPT_PATH = os.path.join(save_dir, "RSCA_GAN_Discriminator")

print("=" * 60)
print("ðŸ”§ RSCA-GAN TRAINING CONFIGURATION (Paper-Aligned)")
print("=" * 60)
print(f" Save Directory:       {save_dir}")
print(f" Image Size:           {H} x {W}")
print(f" Batch Size:           {BATCH_SIZE}")
print(f" Learning Rate:        {LEARNING_RATE}")
print(f" Adam Betas:           ({BETA_1}, {BETA_2})")
print(f" Epochs:               {EPOCHS}")
print(f" Generator Ckpt:       {GEN_CKPT_PATH}")
print(f" Discriminator Ckpt:   {DISC_CKPT_PATH}")
print("=" * 60)


ðŸ”§ RSCA-GAN TRAINING CONFIGURATION (Paper-Aligned)
 Save Directory:       ./SavedModels_RSCA_GAN_full
 Image Size:           320 x 320
 Batch Size:           8
 Learning Rate:        0.0001
 Adam Betas:           (0.9, 0.999)
 Epochs:               50
 Generator Ckpt:       ./SavedModels_RSCA_GAN_full\RSCA_GAN_Generator
 Discriminator Ckpt:   ./SavedModels_RSCA_GAN_full\RSCA_GAN_Discriminator


In [8]:
model = RSCAGAN(in_channels=2, base_channels=64)

generator = model.generator
discriminator = model.discriminator


In [9]:
gen_optimizer = Adam(
    learning_rate=LEARNING_RATE,
    beta_1=BETA_1,
    beta_2=BETA_2,
    clipnorm=1.0
)

"""disc_optimizer = Adam(
    #learning_rate=LEARNING_RATE,
    learning_rate=8e-5,
    beta_1=BETA_1,
    beta_2=BETA_2,
    clipnorm=1.0
)"""
disc_optimizer = Adam(
    learning_rate=8e-5,   
    beta_1=0.9,
    beta_2=0.999
)


In [10]:
bce = tf.keras.losses.BinaryCrossentropy(from_logits=True)

def discriminator_loss(D_real, D_fake):
   # real_loss = bce(tf.ones_like(D_real), D_real)
   # fake_loss = bce(tf.zeros_like(D_fake), D_fake)
    real_labels = tf.ones_like(D_real) * 0.9
    fake_labels = tf.zeros_like(D_fake) + 0.1

    real_loss = bce(real_labels, D_real)
    fake_loss = bce(fake_labels, D_fake)

    return real_loss + fake_loss

def generator_adversarial_loss(D_fake):
    return bce(tf.ones_like(D_fake), D_fake)

def image_l1_loss(SGT, SRE):
    return tf.reduce_mean(tf.abs(SGT - SRE))

def frequency_l1_loss(SGT, SRE, mask):
    r_gt  = fft2c_tf(SGT)
    r_rec = fft2c_tf(SRE)
    return tf.reduce_mean(tf.abs(r_gt * mask - r_rec * mask))


In [11]:
def compute_nmse(gt, pred):
    return tf.reduce_sum(tf.square(gt - pred)) / tf.reduce_sum(tf.square(gt))

def compute_psnr(gt, pred):
    return tf.reduce_mean(tf.image.psnr(gt, pred, max_val=1.0))


def compute_ssim(gt, pred):
    return tf.reduce_mean(
        tf.image.ssim(gt, pred, max_val=1.0)
    )



In [12]:
@tf.function
def train_step(SZF, SGT, mask, MAXV):

    # -------- Train Discriminator --------
    with tf.GradientTape() as d_tape:
        SRE = generator(SZF, training=True)

        MAXV = tf.reshape(MAXV, (-1,1,1,1))

        # Normalize
        SGT_n = SGT / MAXV
        SRE_n = SRE / MAXV

        # âœ” Correct: normalized inputs to D
        D_real = discriminator(SGT_n, training=True)
        D_fake = discriminator(tf.stop_gradient(SRE_n), training=True)

        d_loss = discriminator_loss(D_real, D_fake)

    d_grads = d_tape.gradient(d_loss, discriminator.trainable_variables)
    disc_optimizer.apply_gradients(
        zip(d_grads, discriminator.trainable_variables)
    )

    # -------- Train Generator --------
    with tf.GradientTape() as g_tape:
        SRE = generator(SZF, training=True)

        # Clip
        SRE = tf.clip_by_value(SRE, 0.0, MAXV)

        # Normalize AGAIN (important!)
        SRE_n = SRE / MAXV

        # âœ” Correct: normalized input to D
        D_fake = discriminator(SRE_n, training=True)

        adv_loss  = generator_adversarial_loss(D_fake)
        img_loss  = image_l1_loss(SGT_n, SRE_n)
        freq_loss = frequency_l1_loss(SGT_n, SRE_n, mask)

        g_loss = 0.001 * adv_loss + freq_loss + 10.0 * img_loss

    g_grads = g_tape.gradient(g_loss, generator.trainable_variables)
    gen_optimizer.apply_gradients(
        zip(g_grads, generator.trainable_variables)
    )

    return d_loss, g_loss, adv_loss, freq_loss, img_loss


In [13]:
"""@tf.function
def validation_step(SZF, SGT, mask, MAXV):

    SRE = generator(SZF, training=False)

    MAXV = tf.reshape(MAXV, (-1,1,1,1))

    SGT_n = SGT / MAXV
    SRE_n = SRE / MAXV

    D_fake = discriminator(SRE_n, training=False)
    


    adv_loss  = generator_adversarial_loss(D_fake)
    img_loss  = image_l1_loss(SGT_n, SRE_n)
    freq_loss = frequency_l1_loss(SGT_n, SRE_n, mask)

    total_loss = adv_loss + freq_loss + 10.0 * img_loss

    psnr = compute_psnr(SGT_n, SRE_n)
    ssim = compute_ssim(SGT_n, SRE_n)
    nmse = compute_nmse(SGT_n, SRE_n)

    return total_loss, psnr, ssim, nmse
"""
@tf.function
def validation_step(SZF, SGT, mask, MAXV):

    SRE = generator(SZF, training=False)

    MAXV = tf.reshape(MAXV, (-1,1,1,1))

    # Avoid divide by zero
    MAXV = tf.maximum(MAXV, 1e-6)

    SGT_n = SGT / MAXV
    SRE_n = SRE / MAXV

    # Clip
    SGT_n = tf.clip_by_value(SGT_n, 0.0, 1.0)
    SRE_n = tf.clip_by_value(SRE_n, 0.0, 1.0)

    # Metrics
    psnr = tf.reduce_mean(
        tf.image.psnr(SGT_n, SRE_n, max_val=1.0)
    )

    ssim = tf.reduce_mean(
        tf.image.ssim(SGT_n, SRE_n, max_val=1.0)
    )

    nmse = tf.reduce_sum(
        tf.square(SGT_n - SRE_n)
    ) / tf.reduce_sum(tf.square(SGT_n))

    # Loss (same as train)
    D_fake = discriminator(SRE_n, training=False)

    adv_loss  = generator_adversarial_loss(D_fake)
    img_loss  = image_l1_loss(SGT_n, SRE_n)
    freq_loss = frequency_l1_loss(SGT_n, SRE_n, mask)

    total_loss = 0.005 * adv_loss + freq_loss + 10.0 * img_loss

    return total_loss, psnr, ssim, nmse


In [14]:
def run_validation(val_gen, mask):
    val_loss = []
    val_psnr = []
    val_ssim = []
    val_nmse = []

    for SZF, SGT, MAXV in val_gen:
        loss, psnr, ssim, nmse = validation_step(SZF, SGT, mask, MAXV)


        val_loss.append(loss)
        val_psnr.append(psnr)
        val_ssim.append(ssim)
        val_nmse.append(nmse)

    return (
        tf.reduce_mean(val_loss),
        tf.reduce_mean(val_psnr),
        tf.reduce_mean(val_ssim),
        tf.reduce_mean(val_nmse)
    )


In [15]:
# >>> ADD <<<
epoch_counter = tf.Variable(0, dtype=tf.int64)

ckpt = tf.train.Checkpoint(
    generator=generator,
    discriminator=discriminator,
    gen_optimizer=gen_optimizer,
    disc_optimizer=disc_optimizer,
    epoch=epoch_counter
)

ckpt_manager = tf.train.CheckpointManager(
    ckpt,
    directory=save_dir,
    max_to_keep=5
)

best_val_loss = float("inf")


In [16]:
][ipjuhy
   mk >>> PLACE THIS HERE <<<
if ckpt_manager.latest_checkpoint:
    ckpt.restore(ckpt_manager.latest_checkpoint).expect_partial()
    start_epoch = int(epoch_counter.numpy())
    print(f"âœ… Resuming from epoch {start_epoch}")
else:
    start_epoch = 0
    print("ðŸ†• Training from scratch")


âœ… Resuming from epoch 4


In [17]:
for epoch in range(start_epoch, EPOCHS):
    print(f"\n===== Epoch {epoch+1}/{EPOCHS} =====")

    # ======================
    # Training
    # ======================
    for step, (SZF, SGT, MAXV) in enumerate(train_gen):

        d_loss, g_loss, adv, freq, img = train_step(
            SZF, SGT, mask, MAXV
        )

        if step % 1000 == 0:
            print(
                f"[Train {step}] "
                f"D: {d_loss:.4f} | "
                f"G: {g_loss:.4f} | "
                f"Adv: {adv:.4f} | "
                f"F: {freq:.4f} | "
                f"I: {img:.4f}"
            )

    # ======================
    # Validation
    # ======================
    val_loss, val_psnr, val_ssim, val_nmse = run_validation(val_gen, mask)

    print("\nðŸ“Š Validation Results")
    print(f" Val Loss : {val_loss:.4f}")
    print(f" Val PSNR : {val_psnr:.2f}")
    print(f" Val SSIM : {val_ssim:.4f}")
    print(f" Val NMSE : {val_nmse:.4f}")

    # Save epoch counter
    epoch_counter.assign(epoch + 1)

    # ======================
    # Save BEST model
    # ======================
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        ckpt_manager.save()
        print(f"ðŸ”¥ Saved BEST model (Val Loss = {val_loss:.4f})")



===== Epoch 5/50 =====
[Train 0] D: 48.2926 | G: 0.0023 | Adv: 0.1074 | F: 0.0007 | I: 0.0001
[Train 1000] D: 19.0455 | G: 0.0028 | Adv: 0.1060 | F: 0.0007 | I: 0.0002
[Train 2000] D: 9.9197 | G: 0.0027 | Adv: 0.1020 | F: 0.0007 | I: 0.0002
[Train 3000] D: 8.1066 | G: 0.0027 | Adv: 0.1432 | F: 0.0007 | I: 0.0002
[Train 4000] D: 3.1589 | G: 0.0027 | Adv: 0.1195 | F: 0.0007 | I: 0.0002
[Train 5000] D: 13.6206 | G: 0.0039 | Adv: 0.1651 | F: 0.0007 | I: 0.0003
[Train 6000] D: 4.7320 | G: 0.0032 | Adv: 0.1245 | F: 0.0007 | I: 0.0002
[Train 7000] D: 1013.4033 | G: 0.0084 | Adv: 0.0884 | F: 0.0022 | I: 0.0006
[Train 8000] D: 32.5183 | G: 0.0090 | Adv: 0.1564 | F: 0.0017 | I: 0.0007
[Train 9000] D: 9.9809 | G: 0.0040 | Adv: 0.1376 | F: 0.0014 | I: 0.0002
[Train 10000] D: 102.2062 | G: 0.0045 | Adv: 0.1468 | F: 0.0010 | I: 0.0003
[Train 11000] D: 114.6001 | G: 0.0067 | Adv: 0.1134 | F: 0.0019 | I: 0.0005
[Train 12000] D: 11.4457 | G: 0.0043 | Adv: 0.1520 | F: 0.0011 | I: 0.0003
[Train 13000] D