In [1]:
import h5py
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt

In [2]:
# Load the precomputed sampling mask from disk and report its shape.
# NOTE(review): the recorded run of this cell printed "og shape: (1, 1, 320, 1)",
# which contradicts the (1, 320, 320) shape previously noted on this line.
# NOTE(review): `mask` is not passed to the generators created in the later visible
# cells -- confirm whether this cell is still needed.
mask = np.load(r"C:\Users\DU\aman_fastmri\Data\mask_4x_320_random.npy")  # Shape per recorded output: (1, 1, 320, 1)
print("og shape:", mask.shape)

# # Use np.tile to reshape it to (1, 320, 320, 1)
# # var_sampling_mask = np.tile(var_sampling_mask[..., np.newaxis], (1, 1, 1, 1))  # Final shape: (1, 320, 320, 1)
# mask = np.tile(mask, (1, 320, 1, 2))  # tile height=320 times

# # Confirm final shape
# print("New shape:", mask.shape) 
# mask_for_plot = np.squeeze(mask[...,0])  # Shape: (320, 320)

# # Plot
# plt.figure(figsize=(5, 5))
# plt.imshow(mask_for_plot, cmap='gray')
# plt.title("Tiled Sampling Mask (320x320)")
# plt.axis('off')
# plt.show()

og shape: (1, 1, 320, 1)


In [3]:
import h5py
import numpy as np
import tensorflow as tf

def to_complex(x):
    """Combine a two-channel real array into a single complex array.

    Args:
        x: Array-like supporting ``x[..., i]`` indexing whose last axis holds
            the real part at index 0 and the imaginary part at index 1.

    Returns:
        ``x[..., 0] + 1j * x[..., 1]`` -- the last axis is consumed, so the
        result has one dimension fewer than ``x``.
    """
    real_part = x[..., 0]
    imag_part = x[..., 1]
    return real_part + 1j * imag_part

class MRISliceGeneratorMag(tf.keras.utils.Sequence):
    """
    Data generator for magnitude-only MRI reconstruction.

    Input  : undersampled magnitude image  (B, H, W, 1)
    Target : fully-sampled magnitude image (B, H, W, 1)

    Every HDF5 file in ``file_list`` must provide the datasets ``image_under``
    and ``image_full``, indexed by slice along axis 0, with the real part at
    index 0 and the imaginary part at index 1 of the last axis.
    """

    def __init__(self, file_list, batch_size=4, shuffle=True, **kwargs):
        """
        Args:
            file_list: Sequence of HDF5 file paths.
            batch_size: Number of slices per batch; must be >= 1.
            shuffle: If True, the slice order is reshuffled when the index is
                built and at the end of every epoch.
            **kwargs: Forwarded to the Keras base-class constructor.

        Raises:
            ValueError: If ``batch_size`` is smaller than 1.
        """
        # Newer Keras versions require the base constructor to be called; it is
        # a harmless no-op on versions where Sequence defines no __init__.
        super().__init__(**kwargs)
        if batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")
        self.file_list = file_list
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.slice_index_map = []
        self._build_index()

    def _build_index(self):
        """Rebuild the flat list of (file_idx, slice_idx) pairs over all files."""
        # Build into a fresh list so calling this twice does not duplicate entries.
        index_map = []
        for file_idx, file_path in enumerate(self.file_list):
            with h5py.File(file_path, 'r') as f:
                num_slices = f['image_under'].shape[0]
            index_map.extend((file_idx, slice_idx) for slice_idx in range(num_slices))
        self.slice_index_map = index_map
        self.on_epoch_end()

    def __len__(self):
        """Number of batches per epoch (the last batch may be smaller)."""
        return int(np.ceil(len(self.slice_index_map) / self.batch_size))

    def __getitem__(self, index):
        """
        Load batch number ``index``.

        Returns:
            Tuple ``(x_batch, y_batch)`` of stacked magnitude images, each with
            a trailing channel axis of size 1.

        Raises:
            IndexError: If ``index`` is outside ``[0, len(self))``.
        """
        if not 0 <= index < len(self):
            raise IndexError(
                f"batch index {index} out of range for {len(self)} batches"
            )

        start = index * self.batch_size
        batch_map = self.slice_index_map[start : start + self.batch_size]

        # Group the requested slices by file so each HDF5 file is opened once
        # per batch; the position keeps the original batch order intact.
        slots_by_file = {}
        for position, (file_idx, slice_idx) in enumerate(batch_map):
            slots_by_file.setdefault(file_idx, []).append((position, slice_idx))

        input_mag_batch = [None] * len(batch_map)
        target_mag_batch = [None] * len(batch_map)

        for file_idx, slots in slots_by_file.items():
            with h5py.File(self.file_list[file_idx], 'r') as f:
                under_ds = f['image_under']
                full_ds = f['image_full']
                for position, slice_idx in slots:
                    # Each slice is expected to be (H, W, 2): real / imaginary channels.
                    # Convert to complex, then keep only the magnitude.
                    input_mag_batch[position] = np.abs(to_complex(under_ds[slice_idx]))
                    target_mag_batch[position] = np.abs(to_complex(full_ds[slice_idx]))

        # Stack and add channel dimension
        x_batch = np.stack(input_mag_batch, axis=0)[..., np.newaxis]  # (B, H, W, 1)
        y_batch = np.stack(target_mag_batch, axis=0)[..., np.newaxis] # (B, H, W, 1)

        return x_batch, y_batch

    def on_epoch_end(self):
        """Reshuffle the slice order in place when shuffling is enabled."""
        if self.shuffle:
            np.random.shuffle(self.slice_index_map)


In [4]:
# Folders that are globbed for "*.h5" files when the training / validation
# generators are created.
# NOTE(review): absolute, machine-specific paths -- adjust before running elsewhere.
train_folder = r"D:\fastmri_singlecoil_FSSCAN\train_norm"
val_folder = r"D:\fastmri_singlecoil_FSSCAN\val_norm"

In [5]:
import h5py
import numpy as np
import glob
import os
# Collect every HDF5 file in the train / validation folders; sorting keeps the
# file order stable between runs.
kspace_files_list_train = sorted(glob.glob(os.path.join(train_folder, "*.h5")))
kspace_files_list_val = sorted(glob.glob(os.path.join(val_folder, "*.h5")))

# Fail fast: a wrong or missing folder makes glob return an empty list, which
# would otherwise only surface later as a zero-length generator inside fit().
if not kspace_files_list_train:
    raise FileNotFoundError(f"No .h5 files found in training folder: {train_folder}")
if not kspace_files_list_val:
    raise FileNotFoundError(f"No .h5 files found in validation folder: {val_folder}")

# Number of files kept from each split. Despite the "half_" prefix the full
# lists are used; set smaller values (e.g. 20 / 10) for a quick subset run.
half_train = len(kspace_files_list_train)
half_val = len(kspace_files_list_val)
kspace_files_list_train = kspace_files_list_train[:half_train]
kspace_files_list_val = kspace_files_list_val[:half_val]

# Create generators (training batches are shuffled, validation batches are not).
train_gen = MRISliceGeneratorMag(kspace_files_list_train, batch_size=8, shuffle=True)
val_gen = MRISliceGeneratorMag(kspace_files_list_val, batch_size=4, shuffle=False)

# Number of batches per epoch for each generator.
print(len(train_gen))
print(len(val_gen))


4338
1784


In [6]:
%run ./DCRCNN.ipynb


In [7]:
import os
import glob
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping, ReduceLROnPlateau

# ============================================================
# Directory Setup
# ============================================================
# All checkpoints of this run are written below this directory.
save_dir = "./SavedModels_DCRCNN_full"
os.makedirs(save_dir, exist_ok=True)

# ============================================================
# Configuration
# ============================================================
H, W = 320, 320
EPOCHS = 50
LEARNING_RATE = 1e-4

# TensorFlow checkpoint paths (NO .h5)
# NOTE(review): INIT_CKPT is only printed in this cell; the resume logic further
# down looks up the latest checkpoint in save_dir instead.
INIT_CKPT  = os.path.join(save_dir, "init_ckpt")
BEST_CKPT  = os.path.join(save_dir, "best_ckpt")
FINAL_CKPT = os.path.join(save_dir, "final_ckpt")

# Banner summarising the run configuration. The leading emoji was stored as
# mis-decoded UTF-8 bytes and is restored here.
print("=" * 60)
print("🔧 TRAINING CONFIGURATION")
print("=" * 60)
print(f" Save Directory:       {save_dir}")
print(f" Model Dimensions:     {H}x{W}")
print(f" Epochs:               {EPOCHS}")
print(f" Learning Rate:        {LEARNING_RATE}")
print(f" Init Checkpoint:      {INIT_CKPT}")
print(f" Best Checkpoint:      {BEST_CKPT}")
print(f" Final Checkpoint:     {FINAL_CKPT}")
print("=" * 60)

# ============================================================
# Model Setup
# ============================================================
# Build the network from the configured H / W so the printed "Model Dimensions"
# and the actual input shape cannot drift apart.
# NOTE(review): build_dcr_cnn is not defined in the visible cells -- presumably
# provided by the `%run ./DCRCNN.ipynb` cell; confirm.
model = build_dcr_cnn(
    input_shape=(H, W, 1),
    num_dcr_blocks=10,   # or 3 / 8
    num_features=64,
    growth_rate=32
)


# ============================================================
# Optimizer & Compile
# ============================================================
optimizer = Adam(learning_rate=LEARNING_RATE)
model.compile(optimizer=optimizer, loss="mse")

# ============================================================
# Load Initial Weights (Optional Resume)
# ============================================================
# Look the checkpoint up once and reuse the result for both the test and the load.
latest_ckpt = tf.train.latest_checkpoint(save_dir)
if latest_ckpt:
    model.load_weights(latest_ckpt)
    print("✅ Loaded latest checkpoint")
else:
    print("ℹ️ No checkpoint found. Training from scratch.")

# ============================================================
# Callbacks
# ============================================================
# All three callbacks monitor the same metric, "val_loss".

# Writes weights only (not the full model) to BEST_CKPT, and only on epochs
# where val_loss improves (save_best_only=True).
checkpoint_cb = ModelCheckpoint(
    filepath=BEST_CKPT,
    monitor="val_loss",
    save_best_only=True,
    save_weights_only=True,
    verbose=1
)

# Stops training after 20 epochs without val_loss improvement and asks Keras
# to restore the best weights seen (restore_best_weights=True).
earlystop_cb = EarlyStopping(
    monitor="val_loss",
    patience=20,
    restore_best_weights=True,
    verbose=1
)

# Halves the learning rate (factor=0.5) after 10 epochs without improvement,
# never going below 1e-7. Its patience (10) is shorter than the early-stopping
# patience (20), so the rate is lowered before training can be stopped.
reduce_lr_cb = ReduceLROnPlateau(
    monitor="val_loss",
    factor=0.5,
    patience=10,
    min_lr=1e-7,
    verbose=1
)

# Callback list handed to model.fit().
callbacks = [checkpoint_cb, earlystop_cb, reduce_lr_cb]

# ============================================================
# Training
# ============================================================
print("\n🚀 STARTING TRAINING...")
print("=" * 60)

# Train on the slice generators; checkpointing, early stopping and learning-rate
# scheduling are driven by `callbacks`.
history = model.fit(
    train_gen,
    validation_data=val_gen,
    epochs=EPOCHS,
    callbacks=callbacks,
    verbose=1
)

print("\n✅ TRAINING COMPLETED")

# ============================================================
# Save Final Weights
# ============================================================
# Saves whatever weights the model holds when fit() returns.
model.save_weights(FINAL_CKPT)
print(f"✅ Final weights saved to {FINAL_CKPT}")

# ============================================================
# Training Analysis
# ============================================================
# Guard on the recorded validation losses themselves: a History object is
# always truthy, so testing `history` alone never skipped this section.
val_losses = history.history.get("val_loss", [])
if val_losses:
    print("\n📊 TRAINING ANALYSIS")
    print("=" * 60)
    # argmin is 0-based; epochs are reported 1-based.
    best_epoch = int(np.argmin(val_losses)) + 1
    print(f" Best Epoch: {best_epoch}")
    print(f" Best Val Loss: {np.min(val_losses):.6f}")

# ============================================================
# Plot Training Curves
# ============================================================

🔧 TRAINING CONFIGURATION
 Save Directory:       ./SavedModels_DCRCNN_full
 Model Dimensions:     320x320
 Epochs:               50
 Learning Rate:        0.0001
 Init Checkpoint:      ./SavedModels_DCRCNN_full\init_ckpt
 Best Checkpoint:      ./SavedModels_DCRCNN_full\best_ckpt
 Final Checkpoint:     ./SavedModels_DCRCNN_full\final_ckpt
ℹ️ No checkpoint found. Training from scratch.

🚀 STARTING TRAINING...
Epoch 1/50
Epoch 1: val_loss improved from inf to 0.00059, saving model to ./SavedModels_DCRCNN_full\best_ckpt
Epoch 2/50
Epoch 2: val_loss improved from 0.00059 to 0.00056, saving model to ./SavedModels_DCRCNN_full\best_ckpt
Epoch 3/50
Epoch 3: val_loss improved from 0.00056 to 0.00053, saving model to ./SavedModels_DCRCNN_full\best_ckpt
Epoch 4/50
Epoch 4: val_loss improved from 0.00053 to 0.00052, saving model to ./SavedModels_DCRCNN_full\best_ckpt
Epoch 5/50
Epoch 5: val_loss improved from 0.00052 to 0.00051, saving model to ./SavedModels_DCRCNN_full\best_ckpt
Epoch 6/5