In [None]:
# train.py
import os
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm

# from models.model import UNetDeep
from models.model import UNet2DDeep

class MRISuperResDataset(Dataset):
    """
    Paired dataset of low-resolution (LR) and high-resolution (HR) 2D slices.

    Each item is a ``(lr_tensor, hr_tensor)`` tuple of float32 tensors with a
    leading channel axis added, so a slice of shape [H, W] becomes [1, H, W].
    """
    def __init__(self, lr_slices, hr_slices, transform=None):
        """
        lr_slices, hr_slices: indexable sequences (lists or NumPy arrays) of
            2D slices; element ``i`` of each forms one LR/HR pair.
        transform: optional callable applied to the LR tensor and then,
            in a separate call, to the HR tensor.

        Raises:
            ValueError: if the two sequences do not have the same length.
        """
        # Samples are paired purely by index, so a length mismatch would
        # either fail late with IndexError or silently ignore extra HR slices.
        if len(lr_slices) != len(hr_slices):
            raise ValueError(
                "lr_slices and hr_slices must have the same length, got "
                f"{len(lr_slices)} and {len(hr_slices)}"
            )
        self.lr_slices = lr_slices
        self.hr_slices = hr_slices
        self.transform = transform

    def __len__(self):
        """Return the number of LR/HR pairs."""
        return len(self.lr_slices)

    @staticmethod
    def _to_tensor(slice_2d):
        """Copy one slice to float32 and return it as a [1, H, W] tensor."""
        # np.array (copy by default) keeps the returned tensor independent of
        # the stored data and also accepts plain nested lists.
        as_float = np.array(slice_2d, dtype=np.float32)
        # Add channel dimension: [H, W] -> [1, H, W]
        return torch.from_numpy(as_float[np.newaxis, ...])

    def __getitem__(self, idx):
        """Return the ``(lr_tensor, hr_tensor)`` pair stored at ``idx``."""
        lr_tensor = self._to_tensor(self.lr_slices[idx])
        hr_tensor = self._to_tensor(self.hr_slices[idx])

        if self.transform is not None:
            # NOTE: the transform is called once per tensor, so a *random*
            # transform would draw independent parameters for LR and HR.
            lr_tensor = self.transform(lr_tensor)
            hr_tensor = self.transform(hr_tensor)

        return lr_tensor, hr_tensor

# def update_learning_rate(schedulers, val_loss):
#     """
#     Update the learning rate for each scheduler.
    
#     For schedulers of type ReduceLROnPlateau, use scheduler.step(val_loss),
#     otherwise, call scheduler.step() without arguments.
    
#     Parameters:
#       schedulers (list): List of learning rate scheduler objects.
#       val_loss (float): The validation loss used by ReduceLROnPlateau.
#     """
#     for scheduler in schedulers:
#         if isinstance(scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
#             scheduler.step(val_loss)
#         else:
#             scheduler.step()

def main():
    """
    Train UNet2DDeep on paired LR/HR 2D slices and save the weights.

    Loads the axial and coronal slice arrays from ``data/``, concatenates
    them into a single training set, trains for a fixed number of epochs
    with MSE loss and Adam, prints the mean loss per epoch, and writes the
    model's state dict to ``superres_unet_v4.pth``.
    """
    # 1) Load pre-saved 2D slices.
    # NOTE(review): allow_pickle=True lets np.load unpickle arbitrary Python
    # objects; only use it on trusted files.
    # The sagittal arrays are not part of the training set below, so they
    # are no longer loaded.
    slices_lr_axial = np.load("data/lr_slices_axial.npy", allow_pickle=True)
    slices_hr_axial = np.load("data/hr_slices_axial.npy", allow_pickle=True)
    slices_lr_coronal = np.load("data/lr_slices_coronal.npy", allow_pickle=True)
    slices_hr_coronal = np.load("data/hr_slices_coronal.npy", allow_pickle=True)

    # Here, we just combine everything for a single training set.
    all_lr_volumes = np.concatenate((slices_lr_axial, slices_lr_coronal), axis=0)
    all_hr_volumes = np.concatenate((slices_hr_axial, slices_hr_coronal), axis=0)

    # 2) Create Dataset and DataLoader
    train_dataset = MRISuperResDataset(all_lr_volumes, all_hr_volumes)
    train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True)

    # 3) Initialize the model
    model = UNet2DDeep(in_channels=1, out_channels=1)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)

    # 4) Define loss & optimizer
    criterion = nn.MSELoss()  # or L1Loss, SmoothL1Loss, etc.
    optimizer = optim.Adam(model.parameters(), lr=1e-2)

    # # Setup learning rate schedulers
    # from torch.optim.lr_scheduler import ReduceLROnPlateau, StepLR
    # scheduler_plateau = ReduceLROnPlateau(optimizer, mode='min', patience=2, factor=0.5)
    # scheduler_step = StepLR(optimizer, step_size=10, gamma=0.1)
    # schedulers = [scheduler_plateau, scheduler_step]

    # 5) Train loop
    num_epochs = 20
    model.train()

    for epoch in range(num_epochs):
        running_loss = 0.0
        loop = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}")

        for lr_batch, hr_batch in loop:
            lr_batch = lr_batch.to(device)
            hr_batch = hr_batch.to(device)

            optimizer.zero_grad()
            outputs = model(lr_batch)
            loss = criterion(outputs, hr_batch)
            loss.backward()
            optimizer.step()

            # loss.item() is the batch-mean loss; accumulate it so the epoch
            # figure is the mean over batches.
            running_loss += loss.item()
            loop.set_postfix(loss=loss.item())

        epoch_loss = running_loss / len(train_loader)
        print(f"Epoch [{epoch+1}/{num_epochs}]  Loss: {epoch_loss:.4f}")

    # Save model weights
    torch.save(model.state_dict(), "superres_unet_v4.pth")
    print("Model saved as superres_unet_v4.pth")

# Start training only when this code is run directly, not when it is imported.
if __name__ == "__main__":
    main()


Epoch 1/20: 100%|██████████| 128/128 [37:18<00:00, 17.49s/it, loss=1.91e+3]


Epoch [1/20]  Loss: 4748.8583


Epoch 2/20: 100%|██████████| 128/128 [33:24<00:00, 15.66s/it, loss=1.81e+3]


Epoch [2/20]  Loss: 1020.8412


Epoch 3/20: 100%|██████████| 128/128 [33:11<00:00, 15.56s/it, loss=662]    


Epoch [3/20]  Loss: 974.2535


Epoch 4/20: 100%|██████████| 128/128 [30:26<00:00, 14.27s/it, loss=766]    


Epoch [4/20]  Loss: 890.7086


Epoch 5/20: 100%|██████████| 128/128 [30:07<00:00, 14.12s/it, loss=1.34e+3]


Epoch [5/20]  Loss: 776.3741


Epoch 6/20: 100%|██████████| 128/128 [30:40<00:00, 14.38s/it, loss=320]    


Epoch [6/20]  Loss: 677.5660


Epoch 7/20: 100%|██████████| 128/128 [30:07<00:00, 14.12s/it, loss=1.14e+3]


Epoch [7/20]  Loss: 662.8645


Epoch 8/20: 100%|██████████| 128/128 [29:53<00:00, 14.01s/it, loss=416]    


Epoch [8/20]  Loss: 623.5797


Epoch 9/20: 100%|██████████| 128/128 [28:59<00:00, 13.59s/it, loss=690]    


Epoch [9/20]  Loss: 563.3344


Epoch 10/20: 100%|██████████| 128/128 [30:13<00:00, 14.17s/it, loss=564]    


Epoch [10/20]  Loss: 604.0691


Epoch 11/20: 100%|██████████| 128/128 [32:19<00:00, 15.16s/it, loss=1.27e+3]


Epoch [11/20]  Loss: 562.6161


Epoch 12/20: 100%|██████████| 128/128 [45:55<00:00, 21.53s/it, loss=541]


Epoch [12/20]  Loss: 521.6514


Epoch 13/20: 100%|██████████| 128/128 [34:22<00:00, 16.12s/it, loss=678]   


Epoch [13/20]  Loss: 536.7205


Epoch 14/20: 100%|██████████| 128/128 [1:10:01<00:00, 32.83s/it, loss=719]    


Epoch [14/20]  Loss: 618.4904


Epoch 15/20:   9%|▊         | 11/128 [02:44<31:05, 15.95s/it, loss=695]

Overall Purpose
Data Loading:
You load pre-saved 2D slices from multiple anatomical views and combine them into one training dataset.

Dataset & DataLoader:
The dataset class converts each 2D slice into a tensor with shape [1, H, W]. The DataLoader batches these samples.

Model Training:
A 2D U‑Net is trained on these 2D slices using MSE loss and the Adam optimizer. The training loop over 20 epochs minimizes the loss.

Model Saving:
The trained model’s weights are saved for later use in inference (for example, to super-resolve new LR images).

Super-Resolution Context:
In a super-resolution task, you use LR slices as input and the corresponding HR slices as the target. The model learns to map the low-quality, upsampled LR slices to the high-quality HR slices. Later, during inference, you can apply the model slice-by-slice to a new volume and then reassemble the 2D outputs into a full 3D volume.

Concatenating Slices from Different Views:
By concatenating axial and coronal slices into one array, you're effectively treating each 2D slice as an independent training sample. This means your training set will include slices from different anatomical views. Although they come from different orientations, the network will learn the mapping from low-resolution to high-resolution on a per-slice basis. For a baseline or overfitting experiment, this is fine.

Data Shape for 2D U‑Net:
In your dataset class, each 2D slice is converted to a tensor of shape [1, H, W] (1 channel, height, width). When the DataLoader batches these samples, you'll get an input of shape [B, 1, H, W].

Your model is instantiated as:

UNet2DDeep(in_channels=1, out_channels=1)
This is exactly what the model expects—a single-channel 2D image per sample.

Considerations:

Heterogeneity: The network will see slices from both axial and coronal views. While this can be useful for a more diverse training set, keep in mind that the network might have to learn different mappings if the appearance of these views differs significantly.


In [9]:
# # train.py
# import os
# import numpy as np
# import torch
# import torch.nn as nn
# import torch.optim as optim
# from torch.utils.data import Dataset, DataLoader
# from tqdm import tqdm

# from models.model import UNetDeep
# # 

# class MRISuperResDataset(Dataset):
#     def __init__(self, lr_volumes, hr_volumes, transform=None):
#         """
#         lr_volumes and hr_volumes are lists or arrays where each element is a 
#         3D volume with shape (1, 176, 256, 256).
#         """
#         self.lr_volumes = lr_volumes
#         self.hr_volumes = hr_volumes
#         self.transform = transform

#     def __len__(self):
#         return len(self.lr_volumes)

#     def __getitem__(self, idx):
#         lr_volume = self.lr_volumes[idx].astype(np.float32)
#         hr_volume = self.hr_volumes[idx].astype(np.float32)

#         # Convert to torch tensors
#         lr_tensor = torch.from_numpy(lr_volume)
#         hr_tensor = torch.from_numpy(hr_volume)

#         if self.transform:
#             lr_tensor = self.transform(lr_tensor)
#             hr_tensor = self.transform(hr_tensor)

#         return lr_tensor, hr_tensor


        
# def update_learning_rate(schedulers, val_loss):
#     """
#     Update the learning rate for each scheduler.
    
#     For schedulers of type ReduceLROnPlateau, use scheduler.step(val_loss),
#     otherwise, call scheduler.step() without arguments.
    
#     Parameters:
#       schedulers (list): List of learning rate scheduler objects.
#       val_loss (float): The validation loss used by ReduceLROnPlateau.
#     """
#     for scheduler in schedulers:
#         if isinstance(scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
#             scheduler.step(val_loss)
#         else:
#             scheduler.step()

# def main():
#     # 1) Load pre-saved 2D slices (uncomment these lines)
#     slices_lr_axial = np.load("data/lr_slices_axial.npy", allow_pickle=True)
#     slices_hr_axial = np.load("data/hr_slices_axial.npy", allow_pickle=True)
#     slices_lr_coronal = np.load("data/lr_slices_coronal.npy", allow_pickle=True)
#     slices_hr_coronal = np.load("data/hr_slices_coronal.npy", allow_pickle=True)
#     slices_lr_sagittal = np.load("data/lr_slices_sagittal.npy", allow_pickle=True)
#     slices_hr_sagittal = np.load("data/hr_slices_sagittal.npy", allow_pickle=True)

#     # Stack the list of 2D slices along a new axis (axis 0) to form a 3D volume.
#     # For example, if there are 176 slices, each volume will be (176, 256, 256)
#     lr_volume_axial    = np.stack(slices_lr_axial, axis=0)
#     lr_volume_coronal  = np.stack(slices_lr_coronal, axis=0)
#     lr_volume_sagittal = np.stack(slices_lr_sagittal, axis=0)
    
#     hr_volume_axial    = np.stack(slices_hr_axial, axis=0)
#     hr_volume_coronal  = np.stack(slices_hr_coronal, axis=0)
#     hr_volume_sagittal = np.stack(slices_hr_sagittal, axis=0)
    
#     # # Add a channel dimension so that each sample becomes shape (1, 176, 256, 256)
#     # lr_volume_axial    = np.expand_dims(lr_volume_axial, axis=0)
#     # lr_volume_coronal  = np.expand_dims(lr_volume_coronal, axis=0)
#     # lr_volume_sagittal = np.expand_dims(lr_volume_sagittal, axis=0)
    
#     # hr_volume_axial    = np.expand_dims(hr_volume_axial, axis=0)
#     # hr_volume_coronal  = np.expand_dims(hr_volume_coronal, axis=0)
#     # hr_volume_sagittal = np.expand_dims(hr_volume_sagittal, axis=0)

#     all_lr_volumes = np.concatenate((slices_lr_axial, slices_lr_coronal), axis=0)
#     all_hr_volumes = np.concatenate((slices_hr_axial, slices_hr_coronal), axis=0)

#     # Create Dataset and DataLoader
#     train_dataset = MRISuperResDataset(all_lr_volumes, all_hr_volumes)
#     train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True)  # batch size 1 for overfitting

    
#     # Initialize the model
#     model = UNetDeep(in_channels=1, out_channels=1)
#     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
#     model.to(device)
    
#     # Define loss & optimizer
#     criterion = nn.MSELoss()  # aiming for near-zero error (overfitting)
#     optimizer = optim.Adam(model.parameters(), lr=1e-2)
    
#     num_epochs = 1  # Increase epochs to fully overfit
#     model.train()
    
#     for epoch in range(num_epochs):
#         running_loss = 0.0
#         loop = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}")
    
#         for lr_tensor, hr_tensor in loop:
#             lr_tensor = lr_tensor.to(device)
#             hr_tensor = hr_tensor.to(device)
    
#             optimizer.zero_grad()
#             outputs = model(lr_tensor)
#             loss = criterion(outputs, hr_tensor)
#             loss.backward()
#             optimizer.step()
    
#             running_loss += loss.item()
#             loop.set_postfix(loss=loss.item())
    
#         epoch_loss = running_loss / len(train_loader)
#         print(f"Epoch [{epoch+1}/{num_epochs}]  Loss: {epoch_loss:.6f}")
    
#     torch.save(model.state_dict(), "superres_unet_overfit.pth")
#     print("Model saved as superres_unet_overfit.pth")

# if __name__ == "__main__":
#     main()

Epoch 1/1:   0%|          | 0/128 [00:00<?, ?it/s]


RuntimeError: Given groups=1, weight of size [64, 1, 3, 3], expected input[1, 4, 176, 256] to have 1 channels, but got 4 channels instead