In [1]:
import h5py
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt

In [2]:
class MRISliceGenerator(tf.keras.utils.Sequence):
    """Keras Sequence yielding batches of 2-D slices read from HDF5 volumes.

    Every slice (index along axis 0) of every file in ``file_list`` becomes one
    sample. Each file must contain ``image_under`` and ``image_full``, plus
    ``kspace_under`` when ``use_dc`` is True.

    Args:
        file_list: Paths to the ``.h5`` files.
        batch_size: Slices per batch; the last batch may be smaller.
        shuffle: If True, the slice order is shuffled at construction and on
            every ``on_epoch_end()`` call.
        use_dc: If True, batches are ``([image_under, kspace_under], image_full)``
            for a model with a data-consistency input; otherwise
            ``(image_under, image_full)``.
    """

    def __init__(self, file_list, batch_size=4, shuffle=True, use_dc=False):
        # Keras 3's PyDataset (aliased as tf.keras.utils.Sequence) warns when the
        # base initializer is skipped; on tf.keras 2 this resolves to a no-op.
        super().__init__()
        self.file_list = file_list
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.use_dc = use_dc   # when True, also return under-sampled k-space
        self.slice_index_map = []   # flat list of (file_idx, slice_idx) pairs
        self._build_index()

    def _build_index(self):
        """Enumerate every (file, slice) pair, then shuffle if enabled."""
        # Rebuild from scratch so calling this twice never duplicates samples.
        self.slice_index_map = []
        for file_idx, file_path in enumerate(self.file_list):
            with h5py.File(file_path, 'r') as f:
                num_slices = f['image_under'].shape[0]
            self.slice_index_map.extend(
                (file_idx, slice_idx) for slice_idx in range(num_slices)
            )
        self.on_epoch_end()

    def __len__(self):
        """Number of batches (ceiling division, so a partial last batch counts)."""
        return (len(self.slice_index_map) + self.batch_size - 1) // self.batch_size

    def __getitem__(self, index):
        """Load batch ``index``.

        Returns:
            ``(x_img, y)`` or, when ``use_dc`` is True, ``([x_img, x_kspace], y)``,
            each stacked along a new leading batch axis.

        Raises:
            IndexError: If ``index`` is outside ``[0, len(self))``. Without this
                check an empty batch would reach ``np.stack`` and fail obscurely.
        """
        if not 0 <= index < len(self):
            raise IndexError(f"batch index {index} out of range for {len(self)} batches")

        start = index * self.batch_size
        batch_map = self.slice_index_map[start:start + self.batch_size]

        input_img_batch = []
        target_img_batch = []
        input_kspace_batch = []

        for file_idx, slice_idx in batch_map:
            with h5py.File(self.file_list[file_idx], 'r') as f:
                input_img_batch.append(f['image_under'][slice_idx])
                target_img_batch.append(f['image_full'][slice_idx])
                # k-space is only consumed by the DC variant; skip the read otherwise.
                if self.use_dc:
                    input_kspace_batch.append(f['kspace_under'][slice_idx])

        x_img = np.stack(input_img_batch, axis=0)
        y_batch = np.stack(target_img_batch, axis=0)

        if self.use_dc:
            # DSMENet expects two inputs when DC is used
            x_kspace = np.stack(input_kspace_batch, axis=0)
            return [x_img, x_kspace], y_batch
        # Only image input (ZF)
        return x_img, y_batch

    def on_epoch_end(self):
        """Reshuffle the slice order in place when ``shuffle`` is enabled."""
        if self.shuffle:
            np.random.shuffle(self.slice_index_map)


In [3]:
import os

# Root of the normalized single-coil data. The default is the original
# workstation path; set FASTMRI_DATA_ROOT to run on another machine.
DATA_ROOT = os.environ.get("FASTMRI_DATA_ROOT", r"E:\DATA\fastmri_single_coil_FSSCAN_4x")
train_folder = os.path.join(DATA_ROOT, "train_norm")
val_folder = os.path.join(DATA_ROOT, "val_norm")

In [4]:
import h5py
import numpy as np
import glob
import os
# Optional cap on the number of volumes per split for quick debug runs
# (e.g. 20 / 10). None keeps every file in the folder.
MAX_TRAIN_FILES = None
MAX_VAL_FILES = None

kspace_files_list_train = sorted(glob.glob(os.path.join(train_folder, "*.h5")))[:MAX_TRAIN_FILES]
kspace_files_list_val = sorted(glob.glob(os.path.join(val_folder, "*.h5")))[:MAX_VAL_FILES]

# Fail fast on a wrong path: an empty glob would otherwise give zero-length
# generators and a training loop that silently does nothing.
if not kspace_files_list_train:
    raise FileNotFoundError(f"No .h5 files found in {train_folder!r}")
if not kspace_files_list_val:
    raise FileNotFoundError(f"No .h5 files found in {val_folder!r}")

# Create generators (sanity check of the number of batches per split)
train_gen = MRISliceGenerator(kspace_files_list_train, batch_size=16, shuffle=True)
val_gen = MRISliceGenerator(kspace_files_list_val, batch_size=4, shuffle=False)

print(len(train_gen))
print(len(val_gen))


2171
1784


In [5]:
# Run the companion model notebook; its top-level names become available here.
# It presumably defines build_DSMENet_functional (used by the training cell)
# and prints the demo model summary shown in this cell's output.
%run model.ipynb

Model: "DSMENet_Functional"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input_image (InputLayer)       [(None, 256, 256, 2  0           []                               
                                )]                                                                
                                                                                                  
 SRUN_1 (SRUN)                  ((None, 256, 256, 1  542204      ['input_image[0][0]']            
                                6),                                                               
                                 (None, 256, 256, 2                                               
                                ))                                                                
                                                                                 

In [6]:
import tensorflow as tf
import numpy as np
from tqdm import tqdm

# ------------------------------
# Basic losses
# ------------------------------

def ssim_loss(x, y):
    """SSIM dissimilarity: ``1 - mean(SSIM(x, y))`` with ``max_val=1.0``.

    The per-image SSIM values from ``tf.image.ssim`` are averaged over the
    batch before being subtracted from 1.
    """
    mean_ssim = tf.reduce_mean(tf.image.ssim(x, y, max_val=1.0))
    return 1.0 - mean_ssim

def l1_loss(x, y):
    """Mean absolute error between ``x`` and ``y`` over all elements."""
    abs_diff = tf.abs(x - y)
    return tf.reduce_mean(abs_diff)

def nmse(pred, target):
    """Normalized MSE: ``sum((pred - target)^2) / sum(target^2)``.

    Both sums run over every element of the batch, so the error is pooled
    across images rather than averaged per image.
    """
    error_energy = tf.reduce_sum(tf.square(pred - target))
    target_energy = tf.reduce_sum(tf.square(target))
    return error_energy / target_energy

# ------------------------------
# DSMENet loss (REWEIGHTED)
# ------------------------------

def dmse_loss(
    F_first,
    F_final,
    target,
    alpha=0.5,   # SSIM weight on the final output (lowered in this reweighting)
    beta=0.5,    # SSIM weight on the intermediate output (lowered)
    gamma=10.0   # L1 pixel-loss weight (raised)
):
    """Reweighted DSMENet objective.

    total = alpha * (1 - SSIM(F_final, target))
          + beta  * (1 - SSIM(F_first, target))
          + gamma * L1(F_final, target)

    Returns:
        ``(total, Lroc, Lerc, Lmps)``: the weighted sum, then the unweighted
        final-output SSIM term, intermediate-output SSIM term and L1 term.
    """
    roc_term = ssim_loss(F_final, target)   # SSIM term on the final reconstruction
    erc_term = ssim_loss(F_first, target)   # SSIM term on the first/intermediate output
    mps_term = l1_loss(F_final, target)     # pixel-wise L1 term on the final output

    weighted_sum = alpha * roc_term + beta * erc_term + gamma * mps_term
    return weighted_sum, roc_term, erc_term, mps_term

# ==============================
# Learning rate: no scheduler (e.g. StepLR) is applied;
# train_dmse trains with a constant Adam learning rate.
# ==============================
# ------------------------------
# Training step
# ------------------------------

def train_step(model, optimizer, x, y):
    """Run one optimization step and report metrics on the final output.

    Args:
        model: Callable returning ``(F_first, F_final)``.
        optimizer: Optimizer applied to ``model.trainable_variables``.
        x: Input batch.
        y: Target batch.

    Returns:
        ``(total_loss, ssim, psnr, nmse)`` as scalar tensors. SSIM and PSNR use
        ``max_val=1.0`` and are batch means; all metrics come from the
        pre-update forward pass.
    """
    with tf.GradientTape() as tape:
        first_out, final_out = model(x, training=True)
        loss_terms = dmse_loss(first_out, final_out, y)
    total_loss = loss_terms[0]

    # Read the variables after the forward pass, as the original did.
    params = model.trainable_variables
    optimizer.apply_gradients(zip(tape.gradient(total_loss, params), params))

    return (
        total_loss,
        tf.reduce_mean(tf.image.ssim(final_out, y, max_val=1.0)),
        tf.reduce_mean(tf.image.psnr(final_out, y, max_val=1.0)),
        nmse(final_out, y),
    )


# ------------------------------
# Validation step
# ------------------------------

def val_step(model, x, y):
    """Evaluate one batch in inference mode (no weight update).

    Returns:
        ``(total_loss, ssim, psnr, nmse)`` as scalar tensors, computed exactly
        like ``train_step`` but with ``training=False``.
    """
    first_out, final_out = model(x, training=False)
    loss_value = dmse_loss(first_out, final_out, y)[0]

    metrics = (
        tf.reduce_mean(tf.image.ssim(final_out, y, max_val=1.0)),
        tf.reduce_mean(tf.image.psnr(final_out, y, max_val=1.0)),
        nmse(final_out, y),
    )
    return (loss_value, *metrics)


In [7]:
import os
import numpy as np
import tensorflow as tf
from tqdm import tqdm

# =========================================================
# TRAIN FUNCTION (SAFE RESUME WITH OLD + NEW CHECKPOINTS)
# =========================================================
def _run_epoch(gen, step_fn, desc):
    """Run ``step_fn`` over every batch of ``gen`` with a tqdm progress bar.

    Args:
        gen: Batch source supporting ``len(gen)`` and ``gen[i] -> (x, y)``.
        step_fn: Callable ``(x, y) -> (loss, ssim, psnr, nmse)`` returning
            scalar tensors.
        desc: Label shown on the progress bar.

    Returns:
        Dict with keys "Loss", "SSIM", "PSNR", "NMSE" mapping to the mean of the
        per-batch values (NaN, with a NumPy warning, if ``gen`` is empty).
    """
    history = {"Loss": [], "SSIM": [], "PSNR": [], "NMSE": []}
    bar = tqdm(range(len(gen)), desc=desc, ncols=120)

    for step in bar:
        x_batch, y_batch = gen[step]
        loss, ssim_val, psnr_val, nmse_val = (t.numpy() for t in step_fn(x_batch, y_batch))

        for key, value in zip(history, (loss, ssim_val, psnr_val, nmse_val)):
            history[key].append(value)

        bar.set_postfix({
            "Loss": f"{loss:.4f}",
            "SSIM": f"{ssim_val:.4f}",
            "PSNR": f"{psnr_val:.2f}",
            "NMSE": f"{nmse_val:.4f}"
        })

    return {key: np.mean(values) for key, values in history.items()}


def train_dmse(model, train_gen, val_gen, epochs=50,
               learning_rate=5e-4,
               checkpoint_dir="./checkpoints_dmse_full",
               latest_dir=None):
    """Train ``model`` with ``dmse_loss``, keeping a best-SSIM and a resume checkpoint.

    One ``tf.train.Checkpoint`` (model, optimizer, epoch counter, best
    validation SSIM) is written by two managers:

    * ``checkpoint_dir``: saved only when mean validation SSIM improves. This is
      the same directory and layout as before, so older runs still restore.
    * ``latest_dir`` (default ``<checkpoint_dir>/latest``): overwritten after
      every epoch and preferred when resuming. A restart therefore continues
      from the last finished epoch, with its optimizer state, instead of rolling
      back to the last improvement.

    Args:
        model: Model whose call returns ``(F_first, F_final)``.
        train_gen: Training batch source (``len()`` and ``[i] -> (x, y)``).
        val_gen: Validation batch source with the same interface.
        epochs: Total number of epochs, including any already completed.
        learning_rate: Constant Adam learning rate.
        checkpoint_dir: Directory of the best-SSIM checkpoint.
        latest_dir: Directory of the per-epoch resume checkpoint.
    """
    optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)

    # Resume-aware state stored inside the checkpoint.
    epoch_counter = tf.Variable(0, dtype=tf.int64, name="epoch")
    best_val_ssim = tf.Variable(-1.0, dtype=tf.float32, name="best_val_ssim")

    ckpt = tf.train.Checkpoint(
        model=model,
        optimizer=optimizer,
        epoch=epoch_counter,
        best_val_ssim=best_val_ssim
    )

    if latest_dir is None:
        latest_dir = os.path.join(checkpoint_dir, "latest")
    best_manager = tf.train.CheckpointManager(ckpt, directory=checkpoint_dir, max_to_keep=1)
    latest_manager = tf.train.CheckpointManager(ckpt, directory=latest_dir, max_to_keep=1)

    # Prefer the per-epoch checkpoint. Fall back to the best-SSIM one, which is
    # all that runs from before the latest/ directory existed will have.
    resume_path = latest_manager.latest_checkpoint or best_manager.latest_checkpoint
    if resume_path:
        ckpt.restore(resume_path).expect_partial()
        print(f"\n‚úÖ Restored weights from {resume_path}")

        # Checkpoints written before epoch/SSIM tracking existed leave both
        # variables at their initial values (0 and -1).
        if epoch_counter.numpy() == 0 and best_val_ssim.numpy() < 0:
            print("‚ö†Ô∏è Old checkpoint detected (no epoch / SSIM info).")
            print("‚û°Ô∏è Weights restored. Starting epoch count from 0.")
        start_epoch = int(epoch_counter.numpy())

        print(f"üîÅ Resuming from epoch {start_epoch}")
        print(f"‚≠ê Best Val SSIM so far: {best_val_ssim.numpy():.4f}")
    else:
        start_epoch = 0
        print("\nüÜï No checkpoint found. Training from scratch.")

    for epoch in range(start_epoch, epochs):
        print(f"\n===== Epoch {epoch+1}/{epochs} =====")

        train_stats = _run_epoch(
            train_gen, lambda x, y: train_step(model, optimizer, x, y), "Training"
        )
        val_stats = _run_epoch(
            val_gen, lambda x, y: val_step(model, x, y), "Validation"
        )
        mean_val_ssim = val_stats["SSIM"]

        print(f"\nEpoch {epoch+1} Summary:")
        print(f"  Train Loss : {train_stats['Loss']:.4f}")
        print(f"  Val Loss   : {val_stats['Loss']:.4f}")
        print(f"  Val SSIM   : {mean_val_ssim:.4f}")
        print(f"  Val PSNR   : {val_stats['PSNR']:.2f}")
        print(f"  Val NMSE   : {val_stats['NMSE']:.4f}")

        # Record the finished epoch before saving, so either checkpoint resumes after it.
        epoch_counter.assign(epoch + 1)

        if mean_val_ssim > best_val_ssim.numpy():
            best_val_ssim.assign(mean_val_ssim)
            print(f"üî• New BEST SSIM: {best_val_ssim.numpy():.4f}")
            best_manager.save()

        # Always save the resume point, even when validation did not improve.
        latest_manager.save()

        # This manual loop bypasses Keras' epoch hooks, so trigger them here;
        # otherwise shuffling generators would keep their initial order forever.
        for gen in (train_gen, val_gen):
            if hasattr(gen, "on_epoch_end"):
                gen.on_epoch_end()

    print("\n‚úÖ Training complete.")


In [8]:
# Full training run configuration.
# NOTE(review): H/W must match the stored slice size — confirm the data is 320x320.
TRAIN_BATCH_SIZE = 8
VAL_BATCH_SIZE = 1
NUM_EPOCHS = 25

model = build_DSMENet_functional(N=6, M=1, T=2, H=320, W=320, C=2)

train_gen = MRISliceGenerator(
    kspace_files_list_train, batch_size=TRAIN_BATCH_SIZE, shuffle=True
)
val_gen = MRISliceGenerator(
    kspace_files_list_val, batch_size=VAL_BATCH_SIZE, shuffle=False
)

train_dmse(model, train_gen, val_gen, epochs=NUM_EPOCHS)



‚úÖ Restored weights from ./checkpoints_dmse_full\ckpt-9
üîÅ Resuming from epoch 17
‚≠ê Best Val SSIM so far: 0.7150

===== Epoch 18/25 =====


Training: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 4342/4342 [3:28:30<00:00,  2.88s/it, Loss=0.4202, SSIM=0.7545, PSNR=33.38, NMSE=0.0288]
Validation: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 7135/7135 [1:28:33<00:00,  1.34it/s, Loss=0.6242, SSIM=0.6322, PSNR=29.56, NMSE=0.0577]



Epoch 18 Summary:
  Train Loss : 0.4650
  Val Loss   : 0.4818
  Val SSIM   : 0.7154
  Val PSNR   : 32.70
  Val NMSE   : 0.1021
üî• New BEST SSIM: 0.7154

===== Epoch 19/25 =====


Training: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 4342/4342 [3:17:38<00:00,  2.73s/it, Loss=0.4323, SSIM=0.7469, PSNR=32.68, NMSE=0.0317]
Validation: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 7135/7135 [1:25:34<00:00,  1.39it/s, Loss=0.6410, SSIM=0.6248, PSNR=29.20, NMSE=0.0626]



Epoch 19 Summary:
  Train Loss : 0.4618
  Val Loss   : 0.4992
  Val SSIM   : 0.7046
  Val PSNR   : 32.04
  Val NMSE   : 0.1081

===== Epoch 20/25 =====


Training: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 4342/4342 [3:16:02<00:00,  2.71s/it, Loss=0.4275, SSIM=0.7492, PSNR=33.05, NMSE=0.0302]
Validation: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 7135/7135 [1:27:26<00:00,  1.36it/s, Loss=0.6316, SSIM=0.6266, PSNR=29.43, NMSE=0.0594]



Epoch 20 Summary:
  Train Loss : 0.4621
  Val Loss   : 0.4895
  Val SSIM   : 0.7087
  Val PSNR   : 32.44
  Val NMSE   : 0.1044

===== Epoch 21/25 =====


Training: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 4342/4342 [3:19:05<00:00,  2.75s/it, Loss=0.4184, SSIM=0.7561, PSNR=33.43, NMSE=0.0286]
Validation: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 7135/7135 [1:28:10<00:00,  1.35it/s, Loss=0.6236, SSIM=0.6329, PSNR=29.56, NMSE=0.0576]



Epoch 21 Summary:
  Train Loss : 0.4596
  Val Loss   : 0.4803
  Val SSIM   : 0.7167
  Val PSNR   : 32.74
  Val NMSE   : 0.1016
üî• New BEST SSIM: 0.7167

===== Epoch 22/25 =====


Training: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 4342/4342 [3:18:28<00:00,  2.74s/it, Loss=0.4244, SSIM=0.7513, PSNR=33.19, NMSE=0.0296]
Validation: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 7135/7135 [1:28:31<00:00,  1.34it/s, Loss=0.6282, SSIM=0.6297, PSNR=29.50, NMSE=0.0584]



Epoch 22 Summary:
  Train Loss : 0.4599
  Val Loss   : 0.4859
  Val SSIM   : 0.7126
  Val PSNR   : 32.54
  Val NMSE   : 0.1033

===== Epoch 23/25 =====


Training: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 4342/4342 [3:18:43<00:00,  2.75s/it, Loss=0.4177, SSIM=0.7564, PSNR=33.45, NMSE=0.0285]
Validation: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 7135/7135 [1:29:49<00:00,  1.32it/s, Loss=0.6234, SSIM=0.6322, PSNR=29.59, NMSE=0.0573]



Epoch 23 Summary:
  Train Loss : 0.4590
  Val Loss   : 0.4798
  Val SSIM   : 0.7168
  Val PSNR   : 32.76
  Val NMSE   : 0.1009
üî• New BEST SSIM: 0.7168

===== Epoch 24/25 =====


Training: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 4342/4342 [3:17:27<00:00,  2.73s/it, Loss=0.4530, SSIM=0.7307, PSNR=32.06, NMSE=0.0356]
Validation: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 7135/7135 [1:29:24<00:00,  1.33it/s, Loss=0.6396, SSIM=0.6264, PSNR=29.27, NMSE=0.0616]



Epoch 24 Summary:
  Train Loss : 0.4579
  Val Loss   : 0.4956
  Val SSIM   : 0.7078
  Val PSNR   : 32.16
  Val NMSE   : 0.1057

===== Epoch 25/25 =====


Training: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 4342/4342 [3:20:09<00:00,  2.77s/it, Loss=0.4193, SSIM=0.7555, PSNR=33.38, NMSE=0.0289]
Validation: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 7135/7135 [1:30:30<00:00,  1.31it/s, Loss=0.6250, SSIM=0.6310, PSNR=29.56, NMSE=0.0577]


Epoch 25 Summary:
  Train Loss : 0.4584
  Val Loss   : 0.4819
  Val SSIM   : 0.7147
  Val PSNR   : 32.70
  Val NMSE   : 0.1014

‚úÖ Training complete.



