In [1]:
import h5py
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt

In [2]:
class MRISliceGenerator(tf.keras.utils.Sequence):
    """Keras ``Sequence`` yielding batches of 2-D slices from a list of HDF5 files.

    Every slice (index along axis 0 of ``image_under``) of every file is one
    sample. Each file must provide the datasets ``image_under`` and
    ``image_full``; it must additionally provide ``kspace_under`` when
    ``use_dc=True``. Per-slice arrays are presumably (H, W, 2) — TODO confirm
    against the preprocessing script.

    Args:
        file_list: Paths of the ``.h5`` files to read.
        batch_size: Slices per batch; the final batch may be smaller.
        shuffle: If True, the (file, slice) order is shuffled at construction
            and again on every ``on_epoch_end()`` call.
        use_dc: If True, batches are ``([image_under, kspace_under], image_full)``
            for a two-input data-consistency model; otherwise
            ``(image_under, image_full)``.
    """

    def __init__(self, file_list, batch_size=4, shuffle=True, use_dc=False):
        # Keras 3's PyDataset warns unless its initializer runs; on Keras 2,
        # Sequence has no initializer of its own, so this is harmless.
        super().__init__()
        self.file_list = file_list
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.use_dc = use_dc
        # Flat list of (file_idx, slice_idx) pairs. Batches are contiguous
        # windows over it, so shuffling this list shuffles the batches.
        self.slice_index_map = []
        self._build_index()

    def _build_index(self):
        """Record one (file_idx, slice_idx) entry per slice, then shuffle if enabled."""
        for file_idx, file_path in enumerate(self.file_list):
            with h5py.File(file_path, 'r') as f:
                num_slices = f['image_under'].shape[0]
            self.slice_index_map.extend((file_idx, slice_idx) for slice_idx in range(num_slices))
        self.on_epoch_end()

    def __len__(self):
        """Number of batches per epoch (the last one may be partial)."""
        return int(np.ceil(len(self.slice_index_map) / self.batch_size))

    def __getitem__(self, index):
        """Load batch ``index``; see the class docstring for the returned structure."""
        batch_map = self.slice_index_map[index * self.batch_size:(index + 1) * self.batch_size]

        input_img_batch = []
        target_img_batch = []
        input_kspace_batch = []

        for file_idx, slice_idx in batch_map:
            with h5py.File(self.file_list[file_idx], 'r') as f:
                input_img_batch.append(f['image_under'][slice_idx])
                target_img_batch.append(f['image_full'][slice_idx])
                # k-space is only consumed by the two-input (DC) model. Skipping it
                # otherwise saves one HDF5 read per slice and lets files without a
                # 'kspace_under' dataset work in image-only mode.
                if self.use_dc:
                    input_kspace_batch.append(f['kspace_under'][slice_idx])

        x_img = np.stack(input_img_batch, axis=0)
        y_batch = np.stack(target_img_batch, axis=0)

        if self.use_dc:
            # Two inputs when data consistency is used.
            x_kspace = np.stack(input_kspace_batch, axis=0)
            return [x_img, x_kspace], y_batch
        # Only the zero-filled image input.
        return x_img, y_batch

    def on_epoch_end(self):
        """Reshuffle the slice order in place when ``shuffle`` is enabled.

        Keras ``fit`` calls this after each epoch. A hand-written training loop
        must call it explicitly.
        """
        if self.shuffle:
            np.random.shuffle(self.slice_index_map)


In [3]:
import os

# Root of the preprocessed fastMRI single-coil 4x dataset. Set the
# FASTMRI_DATA_ROOT environment variable to point elsewhere on other machines;
# on the original machine the defaults below produce the same paths as before.
DATA_ROOT = os.environ.get("FASTMRI_DATA_ROOT", r"E:\DATA\fastmri_single_coil_FSSCAN_4x")
train_folder = os.path.join(DATA_ROOT, "train_norm")
val_folder = os.path.join(DATA_ROOT, "val_norm")

In [4]:
import h5py
import numpy as np
import glob
import os
kspace_files_list_train = sorted(glob.glob(os.path.join(train_folder, "*.h5")))
kspace_files_list_val = sorted(glob.glob(os.path.join(val_folder, "*.h5")))

# Fail fast. An empty glob (wrong path, unmounted drive) would otherwise build
# zero-length generators, and the training loop would "finish" without data.
if not kspace_files_list_train:
    raise FileNotFoundError(f"No .h5 files found in {train_folder}")
if not kspace_files_list_val:
    raise FileNotFoundError(f"No .h5 files found in {val_folder}")

# Number of files to use from each split. Lower these (e.g. 20 / 10) for a quick
# debugging run; the defaults keep every file, so the slicing below is a no-op.
half_train = len(kspace_files_list_train)
half_val = len(kspace_files_list_val)
kspace_files_list_train = kspace_files_list_train[:half_train]
kspace_files_list_val = kspace_files_list_val[:half_val]

# Quick-look generators, used here only to report batches per epoch.
# The training cell later rebuilds both with the batch sizes actually used.
train_gen = MRISliceGenerator(kspace_files_list_train, batch_size=16, shuffle=True)
val_gen = MRISliceGenerator(kspace_files_list_val, batch_size=4, shuffle=False)

print(len(train_gen))
print(len(val_gen))


2171
1784


In [5]:
# Execute model.ipynb in this kernel's namespace. It presumably defines
# build_DSMENet_functional (called later in this notebook, not defined here)
# and prints the DSMENet_Functional summary shown below.
%run model.ipynb

Model: "DSMENet_Functional"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input_image (InputLayer)       [(None, 256, 256, 2  0           []                               
                                )]                                                                
                                                                                                  
 SRUN_1 (SRUN)                  ((None, 256, 256, 1  542204      ['input_image[0][0]']            
                                6),                                                               
                                 (None, 256, 256, 2                                               
                                ))                                                                
                                                                                 

In [6]:
import tensorflow as tf
import numpy as np
from tqdm import tqdm

# ------------------------------
# Basic losses
# ------------------------------

def ssim_loss(x, y):
    """SSIM dissimilarity: ``1 - mean(SSIM(x, y))`` with ``max_val=1.0``.

    NOTE(review): max_val=1.0 assumes the images are scaled to [0, 1]. Confirm
    this holds for the *_norm data.
    """
    mean_similarity = tf.reduce_mean(tf.image.ssim(x, y, max_val=1.0))
    return 1.0 - mean_similarity

def l1_loss(x, y):
    """Mean absolute error between ``x`` and ``y``, averaged over all elements."""
    abs_error = tf.abs(x - y)
    return tf.reduce_mean(abs_error)

def nmse(pred, target):
    """Normalized MSE over the whole batch: ``sum((pred - target)^2) / sum(target^2)``."""
    error_energy = tf.reduce_sum(tf.square(pred - target))
    target_energy = tf.reduce_sum(tf.square(target))
    return error_energy / target_energy

# ------------------------------
# DSMENet loss (REWEIGHTED)
# ------------------------------

def dmse_loss(
    F_first,
    F_final,
    target,
    alpha=0.5,   # weight of SSIM loss on the final output (down-weighted)
    beta=0.5,    # weight of SSIM loss on the intermediate output (down-weighted)
    gamma=10.0   # weight of the L1 pixel loss (up-weighted)
):
    """Composite DSMENet loss.

    total = alpha * (1 - SSIM(F_final, target))
          + beta  * (1 - SSIM(F_first, target))
          + gamma * L1(F_final, target)

    Returns:
        ``(total, Lroc, Lerc, Lmps)``. Lroc is the final-output SSIM term,
        Lerc the intermediate-output SSIM term and Lmps the final-output
        L1 term, each unweighted.
    """
    Lroc, Lerc, Lmps = (
        ssim_loss(F_final, target),
        ssim_loss(F_first, target),
        l1_loss(F_final, target),
    )
    weighted_sum = alpha * Lroc + beta * Lerc + gamma * Lmps
    return weighted_sum, Lroc, Lerc, Lmps

# ==============================
# Learning Rate Scheduler (StepLR) — not implemented: train_dmse uses a constant Adam learning rate
# ==============================
# ------------------------------
# Training step
# ------------------------------

def train_step(model, optimizer, x, y):
    """Apply one optimizer update to ``model`` on the batch ``(x, y)``.

    ``model`` must return ``(F_first, F_final)``. The loss is ``dmse_loss``.

    Returns:
        ``(total_loss, ssim, psnr, nmse)``. The metrics are computed on the
        ``F_final`` from this step's forward pass, i.e. before the weights
        are updated.
    """
    with tf.GradientTape() as tape:
        first_out, final_out = model(x, training=True)
        loss_terms = dmse_loss(first_out, final_out, y)
        total_loss = loss_terms[0]

    params = model.trainable_variables
    gradients = tape.gradient(total_loss, params)
    optimizer.apply_gradients(zip(gradients, params))

    batch_ssim = tf.reduce_mean(tf.image.ssim(final_out, y, max_val=1.0))
    batch_psnr = tf.reduce_mean(tf.image.psnr(final_out, y, max_val=1.0))
    batch_nmse = nmse(final_out, y)
    return total_loss, batch_ssim, batch_psnr, batch_nmse


# ------------------------------
# Validation step
# ------------------------------

def val_step(model, x, y):
    """Score ``model`` on the batch ``(x, y)`` in inference mode, without updating it.

    Returns:
        ``(total_loss, ssim, psnr, nmse)``, with every metric computed on the
        model's second output (``F_final``).
    """
    first_out, final_out = model(x, training=False)
    total_loss = dmse_loss(first_out, final_out, y)[0]
    metrics = (
        tf.reduce_mean(tf.image.ssim(final_out, y, max_val=1.0)),
        tf.reduce_mean(tf.image.psnr(final_out, y, max_val=1.0)),
        nmse(final_out, y),
    )
    return (total_loss, *metrics)


In [7]:
import os
import numpy as np
import tensorflow as tf
from tqdm import tqdm

# =========================================================
# TRAIN FUNCTION (SAFE RESUME WITH OLD + NEW CHECKPOINTS)
# =========================================================
def _restore_dmse_checkpoint(ckpt, manager, epoch_counter, best_val_ssim):
    """Restore the newest checkpoint tracked by ``manager``; return the epoch index to resume at.

    ``ckpt`` must track ``epoch_counter`` and ``best_val_ssim``, which the
    restore updates in place. Returns 0 when no checkpoint exists.
    """
    if not manager.latest_checkpoint:
        print("\nüÜï No checkpoint found. Training from scratch.")
        return 0

    ckpt.restore(manager.latest_checkpoint).expect_partial()
    print(f"\n‚úÖ Restored weights from {manager.latest_checkpoint}")

    # Checkpoints written before epoch / best_val_ssim were tracked leave both
    # variables at their initial values (0 and -1.0).
    if epoch_counter.numpy() == 0 and best_val_ssim.numpy() < 0:
        print("‚ö†Ô∏è Old checkpoint detected (no epoch / SSIM info).")
        print("‚û°Ô∏è Weights restored. Starting epoch count from 0.")
        start_epoch = 0
    else:
        start_epoch = int(epoch_counter.numpy())

    print(f"üîÅ Resuming from epoch {start_epoch}")
    print(f"‚≠ê Best Val SSIM so far: {best_val_ssim.numpy():.4f}")
    return start_epoch


def _run_dmse_phase(gen, desc, step_fn):
    """Apply ``step_fn(x, y)`` to every batch of ``gen`` behind a tqdm progress bar.

    ``step_fn`` must return four scalar tensors (loss, SSIM, PSNR, NMSE).
    Returns four lists holding those values as numpy scalars, one entry per batch.
    """
    losses, ssims, psnrs, nmses = [], [], [], []

    bar = tqdm(range(len(gen)), desc=desc, ncols=120)
    for step in bar:
        x_batch, y_batch = gen[step]
        loss, ssim_val, psnr_val, nmse_val = step_fn(x_batch, y_batch)

        losses.append(loss.numpy())
        ssims.append(ssim_val.numpy())
        psnrs.append(psnr_val.numpy())
        nmses.append(nmse_val.numpy())

        bar.set_postfix({
            "Loss": f"{loss.numpy():.4f}",
            "SSIM": f"{ssim_val.numpy():.4f}",
            "PSNR": f"{psnr_val.numpy():.2f}",
            "NMSE": f"{nmse_val.numpy():.4f}"
        })

    return losses, ssims, psnrs, nmses


def train_dmse(model, train_gen, val_gen, epochs=50, learning_rate=5e-4,
               ckpt_dir="./checkpoints_dmse_full", compile_steps=True):
    """Train ``model`` with ``dmse_loss`` and Adam, resuming from checkpoints.

    Each epoch does the following:
      1. runs ``train_step`` over every batch of ``train_gen``;
      2. runs ``val_step`` over every batch of ``val_gen``;
      3. calls ``on_epoch_end()`` on both generators and prints a summary;
      4. saves a checkpoint (model, optimizer, epoch, best_val_ssim) to
         ``ckpt_dir``, but only when the mean validation SSIM improves.

    On start, the newest checkpoint in ``ckpt_dir`` (if any) is restored, and
    training resumes from the epoch stored in it.

    Args:
        model: Keras model returning ``(F_first, F_final)``.
        train_gen: Indexable batch source returning ``(x, y)`` with ``__len__``
            and ``on_epoch_end()`` (e.g. ``MRISliceGenerator``).
        val_gen: Same interface as ``train_gen``; used for validation.
        epochs: Total epoch count, including epochs completed before resuming.
        learning_rate: Adam learning rate. No schedule is applied.
            NOTE(review): restoring a checkpoint may also restore the
            optimizer's saved learning rate — confirm for the TF version in use.
        ckpt_dir: Directory holding the single best-so-far checkpoint.
        compile_steps: If True, run the train and val steps as ``tf.function``
            graphs. Set to False to debug eagerly.
    """
    optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)

    # Resume-aware state, saved alongside the weights.
    epoch_counter = tf.Variable(0, dtype=tf.int64, name="epoch")
    best_val_ssim = tf.Variable(-1.0, dtype=tf.float32, name="best_val_ssim")

    ckpt = tf.train.Checkpoint(
        model=model,
        optimizer=optimizer,
        epoch=epoch_counter,
        best_val_ssim=best_val_ssim
    )
    # With max_to_keep=1, and saving only on improvement, the one checkpoint
    # on disk is always the best-so-far model.
    manager = tf.train.CheckpointManager(
        ckpt,
        directory=ckpt_dir,
        max_to_keep=1
    )

    start_epoch = _restore_dmse_checkpoint(ckpt, manager, epoch_counter, best_val_ssim)

    def run_train(x, y):
        return train_step(model, optimizer, x, y)

    def run_val(x, y):
        return val_step(model, x, y)

    if compile_steps:
        # Eager execution re-dispatches every op from Python on each batch.
        # inference_step already runs this same model under tf.function. The
        # wrappers are created fresh on every call, so a new optimizer in a
        # later call can still create its slot variables on its first trace.
        run_train = tf.function(run_train)
        run_val = tf.function(run_val)

    for epoch in range(start_epoch, epochs):
        print(f"\n===== Epoch {epoch+1}/{epochs} =====")

        train_losses, _, _, _ = _run_dmse_phase(train_gen, "Training", run_train)
        val_losses, val_ssim_list, val_psnr_list, val_nmse_list = _run_dmse_phase(
            val_gen, "Validation", run_val
        )

        # Keras calls on_epoch_end() only inside fit(). Without these calls the
        # training order is shuffled once at construction and then replayed
        # identically every epoch.
        train_gen.on_epoch_end()
        val_gen.on_epoch_end()

        mean_val_ssim = np.mean(val_ssim_list)

        print(f"\nEpoch {epoch+1} Summary:")
        print(f"  Train Loss : {np.mean(train_losses):.4f}")
        print(f"  Val Loss   : {np.mean(val_losses):.4f}")
        print(f"  Val SSIM   : {mean_val_ssim:.4f}")
        print(f"  Val PSNR   : {np.mean(val_psnr_list):.2f}")
        print(f"  Val NMSE   : {np.mean(val_nmse_list):.4f}")

        if mean_val_ssim > best_val_ssim.numpy():
            best_val_ssim.assign(mean_val_ssim)
            print(f"üî• New BEST SSIM: {best_val_ssim.numpy():.4f}")

            # The epoch counter is persisted only here. After an interruption,
            # training therefore resumes right after the best epoch, and any
            # later non-improving epochs are re-run.
            epoch_counter.assign(epoch + 1)
            manager.save()

    print("\n‚úÖ Training complete.")


In [8]:
# DSMENet configuration for this run. All values are forwarded unchanged to
# build_DSMENet_functional; H, W, C are presumably the input height, width and
# channel count — TODO confirm against model.ipynb.
dsmenet_kwargs = dict(N=6, M=1, T=2, H=320, W=320, C=2)
model = build_DSMENet_functional(**dsmenet_kwargs)

# Generators used for the real run. Training batches have 8 shuffled slices.
# Validation uses batch_size=1 without shuffling, so every validation metric
# is computed per slice.
train_gen = MRISliceGenerator(kspace_files_list_train, batch_size=8, shuffle=True)
val_gen = MRISliceGenerator(kspace_files_list_val, batch_size=1, shuffle=False)

train_dmse(model, train_gen, val_gen, epochs=50)



✅ Restored weights from ./checkpoints_dmse_full\ckpt-12
🔁 Resuming from epoch 23
⭐ Best Val SSIM so far: 0.7168

===== Epoch 24/50 =====


Training: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 4342/4342 [3:16:35<00:00,  2.72s/it, Loss=0.4145, SSIM=0.7568, PSNR=33.81, NMSE=0.0488]
Validation: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 7135/7135 [1:26:21<00:00,  1.38it/s, Loss=0.6225, SSIM=0.6329, PSNR=29.59, NMSE=0.0572]



Epoch 24 Summary:
  Train Loss : 0.4594
  Val Loss   : 0.4793
  Val SSIM   : 0.7174
  Val PSNR   : 32.76
  Val NMSE   : 0.1012
🔥 New BEST SSIM: 0.7174

===== Epoch 25/50 =====


Training: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 4342/4342 [3:18:40<00:00,  2.75s/it, Loss=0.4145, SSIM=0.7562, PSNR=33.86, NMSE=0.0484]
Validation: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 7135/7135 [1:27:04<00:00,  1.37it/s, Loss=0.6235, SSIM=0.6318, PSNR=29.59, NMSE=0.0573]



Epoch 25 Summary:
  Train Loss : 0.4592
  Val Loss   : 0.4803
  Val SSIM   : 0.7164
  Val PSNR   : 32.72
  Val NMSE   : 0.1016

===== Epoch 26/50 =====


Training: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 4342/4342 [3:16:07<00:00,  2.71s/it, Loss=0.4130, SSIM=0.7575, PSNR=33.89, NMSE=0.0482]
Validation: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 7135/7135 [1:28:13<00:00,  1.35it/s, Loss=0.6225, SSIM=0.6336, PSNR=29.57, NMSE=0.0575]



Epoch 26 Summary:
  Train Loss : 0.4581
  Val Loss   : 0.4795
  Val SSIM   : 0.7174
  Val PSNR   : 32.75
  Val NMSE   : 0.1016
🔥 New BEST SSIM: 0.7174

===== Epoch 27/50 =====


Training: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 4342/4342 [3:19:43<00:00,  2.76s/it, Loss=0.4127, SSIM=0.7573, PSNR=33.91, NMSE=0.0480]
Validation: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 7135/7135 [1:29:04<00:00,  1.33it/s, Loss=0.6221, SSIM=0.6334, PSNR=29.59, NMSE=0.0572]



Epoch 27 Summary:
  Train Loss : 0.4589
  Val Loss   : 0.4791
  Val SSIM   : 0.7176
  Val PSNR   : 32.76
  Val NMSE   : 0.1013
🔥 New BEST SSIM: 0.7176

===== Epoch 28/50 =====


Training: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 4342/4342 [3:21:02<00:00,  2.78s/it, Loss=0.4154, SSIM=0.7553, PSNR=33.83, NMSE=0.0486]
Validation: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 7135/7135 [1:30:16<00:00,  1.32it/s, Loss=0.6238, SSIM=0.6316, PSNR=29.57, NMSE=0.0574]



Epoch 28 Summary:
  Train Loss : 0.4595
  Val Loss   : 0.4811
  Val SSIM   : 0.7157
  Val PSNR   : 32.69
  Val NMSE   : 0.1021

===== Epoch 29/50 =====


Training: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 4342/4342 [3:21:36<00:00,  2.79s/it, Loss=0.4127, SSIM=0.7578, PSNR=33.89, NMSE=0.0481]
Validation: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 7135/7135 [1:30:29<00:00,  1.31it/s, Loss=0.6220, SSIM=0.6337, PSNR=29.59, NMSE=0.0572]



Epoch 29 Summary:
  Train Loss : 0.4575
  Val Loss   : 0.4791
  Val SSIM   : 0.7178
  Val PSNR   : 32.76
  Val NMSE   : 0.1013
🔥 New BEST SSIM: 0.7178

===== Epoch 30/50 =====


Training: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 4342/4342 [3:20:33<00:00,  2.77s/it, Loss=0.4123, SSIM=0.7576, PSNR=33.92, NMSE=0.0479]
Validation: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 7135/7135 [1:32:48<00:00,  1.28it/s, Loss=0.6215, SSIM=0.6339, PSNR=29.60, NMSE=0.0571]



Epoch 30 Summary:
  Train Loss : 0.4573
  Val Loss   : 0.4787
  Val SSIM   : 0.7178
  Val PSNR   : 32.77
  Val NMSE   : 0.1012
🔥 New BEST SSIM: 0.7178

===== Epoch 31/50 =====


Training: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 4342/4342 [3:23:24<00:00,  2.81s/it, Loss=0.4143, SSIM=0.7572, PSNR=33.80, NMSE=0.0489]
Validation: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 7135/7135 [1:35:01<00:00,  1.25it/s, Loss=0.6216, SSIM=0.6341, PSNR=29.59, NMSE=0.0573]



Epoch 31 Summary:
  Train Loss : 0.4568
  Val Loss   : 0.4789
  Val SSIM   : 0.7178
  Val PSNR   : 32.76
  Val NMSE   : 0.1013
🔥 New BEST SSIM: 0.7178

===== Epoch 32/50 =====


Training: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 4342/4342 [3:25:25<00:00,  2.84s/it, Loss=0.4145, SSIM=0.7557, PSNR=33.87, NMSE=0.0484]
Validation: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 7135/7135 [1:34:35<00:00,  1.26it/s, Loss=0.6248, SSIM=0.6304, PSNR=29.56, NMSE=0.0576]



Epoch 32 Summary:
  Train Loss : 0.4620
  Val Loss   : 0.4809
  Val SSIM   : 0.7157
  Val PSNR   : 32.71
  Val NMSE   : 0.1020

===== Epoch 33/50 =====


Training: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 4342/4342 [3:27:14<00:00,  2.86s/it, Loss=0.4119, SSIM=0.7579, PSNR=33.93, NMSE=0.0478]
Validation: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 7135/7135 [1:39:24<00:00,  1.20it/s, Loss=0.6213, SSIM=0.6340, PSNR=29.60, NMSE=0.0571]



Epoch 33 Summary:
  Train Loss : 0.4574
  Val Loss   : 0.4782
  Val SSIM   : 0.7182
  Val PSNR   : 32.79
  Val NMSE   : 0.1010
🔥 New BEST SSIM: 0.7182

===== Epoch 34/50 =====


Training: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 4342/4342 [3:29:27<00:00,  2.89s/it, Loss=0.4118, SSIM=0.7579, PSNR=33.94, NMSE=0.0478]
Validation: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 7135/7135 [1:38:46<00:00,  1.20it/s, Loss=0.6210, SSIM=0.6344, PSNR=29.60, NMSE=0.0570]



Epoch 34 Summary:
  Train Loss : 0.4570
  Val Loss   : 0.4779
  Val SSIM   : 0.7184
  Val PSNR   : 32.79
  Val NMSE   : 0.1010
🔥 New BEST SSIM: 0.7184

===== Epoch 35/50 =====


Training: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 4342/4342 [3:26:59<00:00,  2.86s/it, Loss=0.4115, SSIM=0.7582, PSNR=33.94, NMSE=0.0477]
Validation: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 7135/7135 [1:41:59<00:00,  1.17it/s, Loss=0.6209, SSIM=0.6346, PSNR=29.60, NMSE=0.0571]



Epoch 35 Summary:
  Train Loss : 0.4562
  Val Loss   : 0.4777
  Val SSIM   : 0.7185
  Val PSNR   : 32.80
  Val NMSE   : 0.1009
🔥 New BEST SSIM: 0.7185

===== Epoch 36/50 =====


Training: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 4342/4342 [3:29:14<00:00,  2.89s/it, Loss=0.4114, SSIM=0.7583, PSNR=33.94, NMSE=0.0477]
Validation: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 7135/7135 [1:44:05<00:00,  1.14it/s, Loss=0.6212, SSIM=0.6344, PSNR=29.60, NMSE=0.0571]



Epoch 36 Summary:
  Train Loss : 0.4556
  Val Loss   : 0.4779
  Val SSIM   : 0.7183
  Val PSNR   : 32.79
  Val NMSE   : 0.1009

===== Epoch 37/50 =====


Training: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 4342/4342 [3:31:53<00:00,  2.93s/it, Loss=0.4236, SSIM=0.7422, PSNR=33.72, NMSE=0.0490]
Validation: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 7135/7135 [1:46:49<00:00,  1.11it/s, Loss=0.6267, SSIM=0.6316, PSNR=29.47, NMSE=0.0588]



Epoch 37 Summary:
  Train Loss : 0.4609
  Val Loss   : 0.4835
  Val SSIM   : 0.7148
  Val PSNR   : 32.60
  Val NMSE   : 0.1029

===== Epoch 38/50 =====


Training: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 4342/4342 [3:30:48<00:00,  2.91s/it, Loss=0.4110, SSIM=0.7585, PSNR=33.96, NMSE=0.0476]
Validation: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 7135/7135 [1:47:46<00:00,  1.10it/s, Loss=0.6203, SSIM=0.6349, PSNR=29.61, NMSE=0.0569]



Epoch 38 Summary:
  Train Loss : 0.4565
  Val Loss   : 0.4777
  Val SSIM   : 0.7185
  Val PSNR   : 32.80
  Val NMSE   : 0.1009
🔥 New BEST SSIM: 0.7185

===== Epoch 39/50 =====


Training: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 4342/4342 [3:34:35<00:00,  2.97s/it, Loss=0.4116, SSIM=0.7582, PSNR=33.94, NMSE=0.0477]
Validation: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 7135/7135 [1:47:22<00:00,  1.11it/s, Loss=0.6205, SSIM=0.6347, PSNR=29.61, NMSE=0.0569]



Epoch 39 Summary:
  Train Loss : 0.4555
  Val Loss   : 0.4776
  Val SSIM   : 0.7186
  Val PSNR   : 32.80
  Val NMSE   : 0.1006
🔥 New BEST SSIM: 0.7186

===== Epoch 40/50 =====


Training: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 4342/4342 [3:35:06<00:00,  2.97s/it, Loss=0.4117, SSIM=0.7579, PSNR=33.93, NMSE=0.0477]
Validation: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 7135/7135 [1:51:12<00:00,  1.07it/s, Loss=0.6210, SSIM=0.6348, PSNR=29.59, NMSE=0.0572]



Epoch 40 Summary:
  Train Loss : 0.4593
  Val Loss   : 0.4780
  Val SSIM   : 0.7184
  Val PSNR   : 32.78
  Val NMSE   : 0.1011

===== Epoch 41/50 =====


Training: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 4342/4342 [3:36:35<00:00,  2.99s/it, Loss=0.4135, SSIM=0.7565, PSNR=33.86, NMSE=0.0481]
Validation: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 7135/7135 [1:55:09<00:00,  1.03it/s, Loss=0.6206, SSIM=0.6345, PSNR=29.61, NMSE=0.0569]



Epoch 41 Summary:
  Train Loss : 0.4555
  Val Loss   : 0.4780
  Val SSIM   : 0.7179
  Val PSNR   : 32.80
  Val NMSE   : 0.1008

===== Epoch 42/50 =====


Training: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 4342/4342 [3:38:12<00:00,  3.02s/it, Loss=0.4106, SSIM=0.7587, PSNR=33.96, NMSE=0.0475]
Validation: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 7135/7135 [1:52:29<00:00,  1.06it/s, Loss=0.6202, SSIM=0.6347, PSNR=29.63, NMSE=0.0567]



Epoch 42 Summary:
  Train Loss : 0.4550
  Val Loss   : 0.4772
  Val SSIM   : 0.7187
  Val PSNR   : 32.82
  Val NMSE   : 0.1005
🔥 New BEST SSIM: 0.7187

===== Epoch 43/50 =====


Training: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 4342/4342 [3:37:51<00:00,  3.01s/it, Loss=0.4110, SSIM=0.7584, PSNR=33.95, NMSE=0.0475]
Validation: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 7135/7135 [1:59:15<00:00,  1.00s/it, Loss=0.6215, SSIM=0.6342, PSNR=29.59, NMSE=0.0572]



Epoch 43 Summary:
  Train Loss : 0.4548
  Val Loss   : 0.4786
  Val SSIM   : 0.7178
  Val PSNR   : 32.76
  Val NMSE   : 0.1009

===== Epoch 44/50 =====


Training: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 4342/4342 [3:41:01<00:00,  3.05s/it, Loss=0.4110, SSIM=0.7585, PSNR=33.95, NMSE=0.0476]
Validation: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 7135/7135 [2:01:40<00:00,  1.02s/it, Loss=0.6204, SSIM=0.6350, PSNR=29.62, NMSE=0.0569]



Epoch 44 Summary:
  Train Loss : 0.4579
  Val Loss   : 0.4775
  Val SSIM   : 0.7187
  Val PSNR   : 32.80
  Val NMSE   : 0.1007
🔥 New BEST SSIM: 0.7187

===== Epoch 45/50 =====


Training: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 4342/4342 [3:42:19<00:00,  3.07s/it, Loss=0.4101, SSIM=0.7591, PSNR=33.99, NMSE=0.0474]
Validation: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 7135/7135 [2:03:33<00:00,  1.04s/it, Loss=0.6199, SSIM=0.6350, PSNR=29.63, NMSE=0.0567]



Epoch 45 Summary:
  Train Loss : 0.4547
  Val Loss   : 0.4769
  Val SSIM   : 0.7189
  Val PSNR   : 32.82
  Val NMSE   : 0.1004
🔥 New BEST SSIM: 0.7189

===== Epoch 46/50 =====


Training: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 4342/4342 [3:43:51<00:00,  3.09s/it, Loss=0.4100, SSIM=0.7591, PSNR=33.98, NMSE=0.0474]
Validation: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 7135/7135 [2:05:34<00:00,  1.06s/it, Loss=0.6197, SSIM=0.6353, PSNR=29.62, NMSE=0.0568]



Epoch 46 Summary:
  Train Loss : 0.4556
  Val Loss   : 0.4767
  Val SSIM   : 0.7191
  Val PSNR   : 32.82
  Val NMSE   : 0.1005
🔥 New BEST SSIM: 0.7191

===== Epoch 47/50 =====


Training: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 4342/4342 [3:46:04<00:00,  3.12s/it, Loss=0.4108, SSIM=0.7587, PSNR=33.96, NMSE=0.0475]
Validation: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 7135/7135 [2:07:49<00:00,  1.07s/it, Loss=0.6207, SSIM=0.6349, PSNR=29.60, NMSE=0.0571]



Epoch 47 Summary:
  Train Loss : 0.4591
  Val Loss   : 0.4779
  Val SSIM   : 0.7182
  Val PSNR   : 32.79
  Val NMSE   : 0.1009

===== Epoch 48/50 =====


Training: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 4342/4342 [3:44:45<00:00,  3.11s/it, Loss=0.4106, SSIM=0.7587, PSNR=33.97, NMSE=0.0474]
Validation: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 7135/7135 [2:04:19<00:00,  1.05s/it, Loss=0.6202, SSIM=0.6350, PSNR=29.61, NMSE=0.0569]



Epoch 48 Summary:
  Train Loss : 0.4546
  Val Loss   : 0.4773
  Val SSIM   : 0.7190
  Val PSNR   : 32.80
  Val NMSE   : 0.1006

===== Epoch 49/50 =====


Training: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 4342/4342 [3:45:50<00:00,  3.12s/it, Loss=0.4103, SSIM=0.7588, PSNR=33.97, NMSE=0.0474]
Validation: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 7135/7135 [2:11:07<00:00,  1.10s/it, Loss=0.6199, SSIM=0.6349, PSNR=29.63, NMSE=0.0568]



Epoch 49 Summary:
  Train Loss : 0.4543
  Val Loss   : 0.4769
  Val SSIM   : 0.7189
  Val PSNR   : 32.82
  Val NMSE   : 0.1004

===== Epoch 50/50 =====


Training: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 4342/4342 [3:47:02<00:00,  3.14s/it, Loss=0.4124, SSIM=0.7571, PSNR=33.92, NMSE=0.0478]
Validation: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 7135/7135 [2:12:41<00:00,  1.12s/it, Loss=0.6217, SSIM=0.6337, PSNR=29.59, NMSE=0.0573]


Epoch 50 Summary:
  Train Loss : 0.4583
  Val Loss   : 0.4796
  Val SSIM   : 0.7169
  Val PSNR   : 32.72
  Val NMSE   : 0.1016

✅ Training complete.





In [10]:
import numpy as np
import tensorflow as tf
from tqdm import tqdm

# =========================================================
# METRICS (SAME AS TRAINING)
# =========================================================

def nmse(pred, target):
    """Batch-level normalized MSE, ||pred - target||^2 / ||target||^2 summed over all elements.

    This matches the nmse used in the training metrics.
    """
    squared_error_total = tf.reduce_sum(tf.square(pred - target))
    reference_total = tf.reduce_sum(tf.square(target))
    return squared_error_total / reference_total


# =========================================================
# CHECKPOINT LOADER
# =========================================================

def load_dmse_checkpoint(model, ckpt_dir="./checkpoints_dmse_full"):
    """Restore the newest checkpoint in ``ckpt_dir`` into ``model`` and return ``model``.

    The object graph mirrors the one train_dmse saves (model, optimizer,
    epoch, best_val_ssim), so the same checkpoint files resolve here. The
    optimizer is a throw-away placeholder. ``expect_partial()`` silences
    warnings about checkpoint values that are not used.

    Raises:
        RuntimeError: If ``ckpt_dir`` contains no checkpoint.
    """
    placeholder_optimizer = tf.keras.optimizers.Adam()
    epoch_var = tf.Variable(0, dtype=tf.int64)
    best_ssim_var = tf.Variable(-1.0, dtype=tf.float32)

    checkpoint = tf.train.Checkpoint(
        model=model,
        optimizer=placeholder_optimizer,
        epoch=epoch_var,
        best_val_ssim=best_ssim_var,
    )
    latest_path = tf.train.CheckpointManager(
        checkpoint, directory=ckpt_dir, max_to_keep=1
    ).latest_checkpoint

    if not latest_path:
        raise RuntimeError("‚ùå No checkpoint found!")

    checkpoint.restore(latest_path).expect_partial()

    print(f"\n‚úÖ Restored checkpoint: {latest_path}")
    print(f"‚≠ê Best Val SSIM: {best_ssim_var.numpy():.4f}")

    return model


# =========================================================
# INFERENCE STEP
# =========================================================

@tf.function
def inference_step(model, x):
    """Graph-compiled forward pass in inference mode, returning only the final reconstruction."""
    predictions = model(x, training=False)
    _intermediate, final_recon = predictions
    return final_recon


# =========================================================
# EVALUATION LOOP (NO SAVING)
# =========================================================

def evaluate_dmse(model, val_gen):
    """Score ``model`` on every batch of ``val_gen`` and print the mean metrics.

    Each metric is first averaged within a batch, and the per-batch values
    are then averaged with equal weight. With ``batch_size=1`` this is a plain
    per-slice mean. Nothing is written to disk.

    Returns:
        ``(mean_ssim, mean_psnr, mean_nmse)`` as Python floats.
    """
    per_batch = {"ssim": [], "psnr": [], "nmse": []}

    progress = tqdm(range(len(val_gen)), desc="Evaluating", ncols=120)

    for batch_idx in progress:
        x_val, y_val = val_gen[batch_idx]

        prediction = inference_step(model, x_val)

        ssim_t = tf.reduce_mean(tf.image.ssim(prediction, y_val, max_val=1.0))
        psnr_t = tf.reduce_mean(tf.image.psnr(prediction, y_val, max_val=1.0))
        nmse_t = nmse(prediction, y_val)

        ssim_np, psnr_np, nmse_np = ssim_t.numpy(), psnr_t.numpy(), nmse_t.numpy()
        per_batch["ssim"].append(ssim_np)
        per_batch["psnr"].append(psnr_np)
        per_batch["nmse"].append(nmse_np)

        progress.set_postfix({
            "SSIM": f"{ssim_np:.4f}",
            "PSNR": f"{psnr_np:.2f}",
            "NMSE": f"{nmse_np:.4f}"
        })

    mean_ssim, mean_psnr, mean_nmse = (
        float(np.mean(per_batch[key])) for key in ("ssim", "psnr", "nmse")
    )

    print("\nüìä FINAL VALIDATION RESULTS")
    print(f"  Mean SSIM : {mean_ssim:.4f}")
    print(f"  Mean PSNR : {mean_psnr:.2f}")
    print(f"  Mean NMSE : {mean_nmse:.4f}")

    return mean_ssim, mean_psnr, mean_nmse


# =========================================================
# RUN EVALUATION
# =========================================================

# Load the latest checkpoint into the in-memory model. train_dmse saves only on
# validation-SSIM improvement, so this is the best model. Then score it on the
# unshuffled validation generator.
model = load_dmse_checkpoint(model)
eval_results = evaluate_dmse(model, val_gen)
mean_ssim, mean_psnr, mean_nmse = eval_results



✅ Restored checkpoint: ./checkpoints_dmse_full\ckpt-27
⭐ Best Val SSIM: 0.7191


Evaluating: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 7135/7135 [42:00<00:00,  2.83it/s, SSIM=0.6353, PSNR=29.62, NMSE=0.0568]


📊 FINAL VALIDATION RESULTS
  Mean SSIM : 0.7192
  Mean PSNR : 32.82
  Mean NMSE : 0.1005



