In [1]:
import h5py
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt

In [2]:
# 4x random undersampling mask (320 columns).
# NOTE: the recorded output of this cell shows the saved array is (1, 1, 320, 1),
# not (1, 320, 320). The generators below are constructed without `mask=`, so this
# array is currently only inspected, not used for training.
# TODO: if a full 2-D, 2-channel mask is needed, tile it along the row axis, e.g.
#   np.tile(mask, (1, 320, 1, 2))  # -> (1, 320, 320, 2)
MASK_PATH = r"C:\Users\DU\aman_fastmri\Data\mask_4x_320_random.npy"
mask = np.load(MASK_PATH)
print("og shape:", mask.shape)

og shape: (1, 1, 320, 1)


In [3]:
import h5py
import numpy as np
import tensorflow as tf

class MRISliceGenerator(tf.keras.utils.Sequence):
    """Keras ``Sequence`` that yields batches of 2-D slices from HDF5 volumes.

    Each file in ``file_list`` must provide the datasets ``'image_under'``
    (model input) and ``'image_full'`` (target), both indexed by slice along
    axis 0. A batch is the tuple ``(x_img, y_batch)``: the selected slices of
    each dataset stacked along a new leading batch axis (per-slice shape is
    assumed to be ``[H, W, 2]`` — TODO confirm against the data files).

    Args:
        file_list: Paths of the ``.h5`` files to read (copied to a list).
        batch_size: Number of slices per batch; the last batch may be smaller.
        shuffle: If True, the (file, slice) order is shuffled at construction
            and again at the end of every epoch via ``np.random.shuffle``.
        mask: Stored on ``self.mask`` but not used when building batches;
            kept so callers that pass it keep working.
        **kwargs: Forwarded to the Keras base class (e.g. ``workers`` on Keras
            versions whose base class accepts such arguments).
    """

    def __init__(self, file_list, batch_size=4, shuffle=True, mask=None, **kwargs):
        # Newer Keras (PyDataset-based) asks subclasses to call super().__init__(**kwargs);
        # on older versions the base class has no __init__, so this is object.__init__().
        super().__init__(**kwargs)
        self.file_list = list(file_list)
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.mask = mask
        self.slice_index_map = []  # one (file_idx, slice_idx) pair per available slice
        self._build_index()

    def _build_index(self):
        """Fill ``slice_index_map`` with every slice of every file, then shuffle if enabled."""
        self.slice_index_map = []  # rebuild from scratch so a repeated call cannot duplicate entries
        for file_idx, file_path in enumerate(self.file_list):
            with h5py.File(file_path, 'r') as f:
                num_slices = f['image_under'].shape[0]
            self.slice_index_map.extend((file_idx, slice_idx) for slice_idx in range(num_slices))
        self.on_epoch_end()

    def __len__(self):
        """Number of batches per epoch; a trailing partial batch counts as one."""
        return int(np.ceil(len(self.slice_index_map) / self.batch_size))

    def __getitem__(self, index):
        """Return batch ``index`` as ``(x_img, y_batch)``.

        ``x_img`` stacks the ``'image_under'`` slices and ``y_batch`` the
        matching ``'image_full'`` slices, in ``slice_index_map`` order.

        Raises:
            IndexError: if ``index`` is not in ``[0, len(self))``.
        """
        num_batches = len(self)
        if not 0 <= index < num_batches:
            raise IndexError(f"batch index {index} out of range for {num_batches} batches")

        start = index * self.batch_size
        batch_map = self.slice_index_map[start:start + self.batch_size]

        input_img_batch = []
        target_img_batch = []
        for file_idx, slice_idx in batch_map:
            with h5py.File(self.file_list[file_idx], 'r') as f:
                input_img_batch.append(f['image_under'][slice_idx])
                target_img_batch.append(f['image_full'][slice_idx])
        # 'kspace_under' is deliberately not read: batches carry only image-domain
        # data, so loading it would be wasted disk I/O on every slice.

        x_img = np.stack(input_img_batch, axis=0)
        y_batch = np.stack(target_img_batch, axis=0)
        return x_img, y_batch

    def on_epoch_end(self):
        """Reshuffle the slice order between epochs when ``shuffle`` is enabled."""
        if self.shuffle:
            np.random.shuffle(self.slice_index_map)


In [4]:
# Root of the preprocessed single-coil fastMRI data; each split is a sub-folder of it.
DATA_ROOT = r"D:\fastmri_singlecoil_FSSCAN"
train_folder = DATA_ROOT + "\\train_norm"
val_folder = DATA_ROOT + "\\val_norm"

In [6]:
import h5py
import numpy as np
import glob
import os
# Collect the preprocessed .h5 volumes of each split; sorting keeps the order stable.
kspace_files_list_train = sorted(glob.glob(os.path.join(train_folder, "*.h5")))
kspace_files_list_val = sorted(glob.glob(os.path.join(val_folder, "*.h5")))

# Fail fast on a wrong or unmounted data path: an empty glob would otherwise give
# zero-length generators that only fail later, deep inside model.fit.
if not kspace_files_list_train:
    raise FileNotFoundError(f"No .h5 files found in train folder {train_folder!r}")
if not kspace_files_list_val:
    raise FileNotFoundError(f"No .h5 files found in val folder {val_folder!r}")

# Optional caps for quick experiments (e.g. 20 / 10 for a smoke test); None keeps every file.
MAX_TRAIN_FILES = None
MAX_VAL_FILES = None
kspace_files_list_train = kspace_files_list_train[:MAX_TRAIN_FILES]
kspace_files_list_val = kspace_files_list_val[:MAX_VAL_FILES]
# Legacy names kept for compatibility; they hold the number of files actually used.
half_train = len(kspace_files_list_train)
half_val = len(kspace_files_list_val)

# Slice-level generators (MRISliceGenerator is defined in an earlier cell).
TRAIN_BATCH_SIZE = 8
VAL_BATCH_SIZE = 4
train_gen = MRISliceGenerator(kspace_files_list_train, batch_size=TRAIN_BATCH_SIZE, shuffle=True)
val_gen = MRISliceGenerator(kspace_files_list_val, batch_size=VAL_BATCH_SIZE, shuffle=False)

print("train batches:", len(train_gen))
print("val batches:", len(val_gen))


4338
1784


In [7]:
# Execute the model-definition notebook inside this kernel's namespace. Presumably it
# defines `build_dcr_unet`, which the training cell calls but which is not defined in
# this notebook; its printed summary shows a model named "DCR_UNet".
# TODO: move the builder into a .py module and import it, so this notebook does not
# depend on the side effects of running another notebook.
%run ./DCR-Unet.ipynb
# Original location of the model notebook:
#SOTA_paper_2_DCRCNN/SOTA_paper_2_DCRCNN-20251226T060810Z-1-001/SOTA_paper_2_DCRCNN/DCR-Unet.ipynb

Model: "DCR_UNet"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input_1 (InputLayer)           [(None, 320, 320, 2  0           []                               
                                )]                                                                
                                                                                                  
 conv2d (Conv2D)                (None, 320, 320, 32  96          ['input_1[0][0]']                
                                )                                                                 
                                                                                                  
 p_re_lu (PReLU)                (None, 320, 320, 32  32          ['conv2d[0][0]']                 
                                )                                                          

In [8]:
import os
import glob
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping, ReduceLROnPlateau

# ============================================================
# Directory Setup
# ============================================================
# Every checkpoint of this run is written under save_dir (created if missing).
save_dir = "./SavedModels_DCUNET_full_2"
os.makedirs(save_dir, exist_ok=True)

# ============================================================
# Configuration
# ============================================================
H, W = 320, 320  # spatial size of the model input
EPOCHS = 50
LEARNING_RATE = 1e-4

# TensorFlow checkpoint prefixes (no .h5 suffix)
INIT_CKPT  = os.path.join(save_dir, "init_ckpt")  # NOTE(review): not used by any visible cell
BEST_CKPT  = os.path.join(save_dir, "best_ckpt")
FINAL_CKPT = os.path.join(save_dir, "final_ckpt")

print("=" * 60)
print("🔧 TRAINING CONFIGURATION")
print("=" * 60)
print(f" Save Directory:       {save_dir}")
print(f" Model Dimensions:     {H}x{W}")
print(f" Epochs:               {EPOCHS}")
print(f" Learning Rate:        {LEARNING_RATE}")
print(f" Init Checkpoint:      {INIT_CKPT}")
print(f" Best Checkpoint:      {BEST_CKPT}")
print(f" Final Checkpoint:     {FINAL_CKPT}")
print("=" * 60)

# ============================================================
# Model Setup
# ============================================================
# `build_dcr_unet` is not defined in this notebook; presumably it comes from the
# `%run ./DCR-Unet.ipynb` cell. The 2 input channels match the generator's
# per-slice [H, W, 2] arrays; H and W come from the configuration block.
model = build_dcr_unet(input_shape=(H, W, 2))

# ============================================================
# Optimizer & Compile
# ============================================================
optimizer = Adam(learning_rate=LEARNING_RATE)
model.compile(optimizer=optimizer, loss="mse")

# ============================================================
# Load Initial Weights (Optional Resume)
# ============================================================
# Resume from whichever checkpoint tf.train.latest_checkpoint reports for save_dir
# (looked up once and reused). This restores weights via model.load_weights only;
# model.fit is not given initial_epoch, so epoch numbering restarts at 1.
# TODO: persist the last completed epoch and pass initial_epoch when resuming.
latest_ckpt = tf.train.latest_checkpoint(save_dir)
if latest_ckpt:
    model.load_weights(latest_ckpt)
    print(f"✅ Loaded latest checkpoint: {latest_ckpt}")
else:
    print("ℹ️ No checkpoint found. Training from scratch.")

# ============================================================
# Callbacks
# ============================================================
# All three callbacks track the validation loss and log what they do.
_val_loss_kwargs = {"monitor": "val_loss", "verbose": 1}

# Write weights to BEST_CKPT only when val_loss reaches a new best.
checkpoint_cb = ModelCheckpoint(
    filepath=BEST_CKPT,
    save_best_only=True,
    save_weights_only=True,
    **_val_loss_kwargs,
)

# Stop after 20 epochs without val_loss improvement; restore_best_weights=True asks
# Keras to restore the best-epoch weights (exact timing depends on the Keras version).
earlystop_cb = EarlyStopping(patience=20, restore_best_weights=True, **_val_loss_kwargs)

# Halve the learning rate after 10 stagnant epochs, with a floor of 1e-7.
reduce_lr_cb = ReduceLROnPlateau(factor=0.5, patience=10, min_lr=1e-7, **_val_loss_kwargs)

callbacks = [checkpoint_cb, earlystop_cb, reduce_lr_cb]

# ============================================================
# Training
# ============================================================
print("\n🚀 STARTING TRAINING...")
print("=" * 60)

# Interrupting the kernel raises KeyboardInterrupt inside model.fit. Save the
# in-progress weights to a separate prefix (so FINAL_CKPT keeps its meaning), then
# re-raise so the cell still stops exactly as it would without this handler.
INTERRUPTED_CKPT = os.path.join(save_dir, "interrupted_ckpt")
try:
    history = model.fit(
        train_gen,
        validation_data=val_gen,
        epochs=EPOCHS,
        callbacks=callbacks,
        verbose=1,
    )
except KeyboardInterrupt:
    model.save_weights(INTERRUPTED_CKPT)
    print(f"\n⚠️ Training interrupted; current weights saved to {INTERRUPTED_CKPT}")
    raise

print("\n✅ TRAINING COMPLETED")

# ============================================================
# Save Final Weights
# ============================================================
model.save_weights(FINAL_CKPT)
print(f"✅ Final weights saved to {FINAL_CKPT}")

# ============================================================
# Training Analysis
# ============================================================
# A Keras History object is always truthy, so guard on the recorded
# validation losses instead of on `history` itself.
val_losses = history.history.get("val_loss", [])
if val_losses:
    print("\n📊 TRAINING ANALYSIS")
    print("=" * 60)
    best_epoch = int(np.argmin(val_losses)) + 1  # +1: the Keras log numbers epochs from 1
    print(f" Best Epoch: {best_epoch}")
    print(f" Best Val Loss: {np.min(val_losses):.6f}")

# ============================================================
# Plot Training Curves
# ============================================================
if val_losses:
    fig, ax = plt.subplots(figsize=(8, 5))
    for metric_key, metric_label in (("loss", "train loss"), ("val_loss", "val loss")):
        values = history.history.get(metric_key, [])
        ax.plot(np.arange(1, len(values) + 1), values, label=metric_label)
    ax.axvline(best_epoch, color="gray", linestyle="--", label=f"best epoch ({best_epoch})")
    ax.set(title="DCR-UNet training curves", xlabel="epoch", ylabel="MSE loss")
    ax.legend()
    plt.show()

🔧 TRAINING CONFIGURATION
 Save Directory:       ./SavedModels_DCUNET_full_2
 Model Dimensions:     320x320
 Epochs:               50
 Learning Rate:        0.0001
 Init Checkpoint:      ./SavedModels_DCUNET_full_2\init_ckpt
 Best Checkpoint:      ./SavedModels_DCUNET_full_2\best_ckpt
 Final Checkpoint:     ./SavedModels_DCUNET_full_2\final_ckpt
ℹ️ No checkpoint found. Training from scratch.

🚀 STARTING TRAINING...
Epoch 1/50
Epoch 1: val_loss improved from inf to 0.00082, saving model to ./SavedModels_DCUNET_full_2\best_ckpt
Epoch 2/50
Epoch 2: val_loss improved from 0.00082 to 0.00079, saving model to ./SavedModels_DCUNET_full_2\best_ckpt
Epoch 3/50
Epoch 3: val_loss improved from 0.00079 to 0.00077, saving model to ./SavedModels_DCUNET_full_2\best_ckpt
Epoch 4/50
Epoch 4: val_loss improved from 0.00077 to 0.00076, saving model to ./SavedModels_DCUNET_full_2\best_ckpt
Epoch 5/50
Epoch 5: val_loss improved from 0.00076 to 0.00075, saving model to ./SavedModels_DCUNET_full_2\b

KeyboardInterrupt: 