In [None]:
# !pip install torch==2.3.0 torchaudio==2.3.0 torchvision==0.18.0
# !pip install albumentations numpy pandas scikit_learn kaggle
# !pip install resnest geffnet opencv-python pretrainedmodels tqdm Pillow packaging monai segmentation_models_pytorch

Training the U-nets

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader

import matplotlib.pyplot as plt
from tqdm import tqdm

import segmentation_models_pytorch as smp
import monai.networks.nets as monai_nets

from utils_2D import *
from Kits2019_2D import Kits20192DDataset


In [None]:
# from google.colab import drive, files
# drive.mount('/content/drive')

In [None]:
!unzip /content/drive/MyDrive/flattened_data.zip -d flattened_data

In [None]:
# Define the dataset configuration consumed by build_dataset_paths / prepare_datasets
config = {
    "image_dir": "flattened_data/images",  # Directory containing the image files
    "mask_dir": "flattened_data/masks",    # Directory containing the mask files
    "split_train": 0.7,                    # Training set proportion
    "split_val": 0.2,                      # Validation set proportion
    "split_test": 0.1,                     # Testing set proportion
    "image_size": 256,                     # Target image size for resizing
    "batch_size": 32,                      # Batch size for training
    "num_workers": 4,                      # Number of workers for data loading
}

# If in Google Colab, point the directories at the unzipped archive. This has to happen
# BEFORE the paths are collected below: changing the directories afterwards does not
# alter the already-built image_paths / mask_paths.
# config["image_dir"] = "flattened_data/flattened_data/images"
# config["mask_dir"] = "flattened_data/flattened_data/masks"

# Dynamically build the image and mask paths from the configured directories
image_paths, mask_paths = build_dataset_paths(config["image_dir"], config["mask_dir"])

# Store the collected paths in the config so prepare_datasets receives them
config["image_paths"] = image_paths
config["mask_paths"] = mask_paths

# Prepare the train / validation / test datasets using the updated config
train_dataset, val_dataset, test_dataset = prepare_datasets(config)

# Report how many samples ended up in each split
print(f"Number of training samples: {len(train_dataset)}")
print(f"Number of validation samples: {len(val_dataset)}")
print(f"Number of testing samples: {len(test_dataset)}")



## Training

In [None]:
def _run_epoch(model, loader, criterion, device, optimizer=None):
    """
    Run one full pass over ``loader`` and return the mean per-batch loss.

    The pass updates the weights when ``optimizer`` is given; otherwise the model is
    put in eval mode and gradients are disabled.
    """
    training = optimizer is not None
    model.train(training)  # train(False) is the same as eval()

    running_loss = 0.0
    with torch.set_grad_enabled(training):
        for images, masks in loader:
            images = images.to(device)
            # The loss takes integer class indices; squeeze(1) drops the channel axis
            # when it has size 1.
            # NOTE(review): assumes masks arrive as (N, 1, H, W) or (N, H, W) - confirm
            # against the dataset class.
            masks = masks.to(device).long().squeeze(1)

            if training:
                optimizer.zero_grad()
            loss = criterion(model(images), masks)
            if training:
                loss.backward()
                optimizer.step()

            running_loss += loss.item()

    return running_loss / len(loader)


def train_model(model, train_loader, val_loader, config):
    """
    Train a segmentation model with Adam and cross-entropy loss.

    After every epoch the model is evaluated on ``val_loader``; whenever the mean
    validation loss improves, the model's ``state_dict`` is written to
    ``config["checkpoint_path"]`` (unwrapped from ``nn.DataParallel`` if necessary).

    Parameters
    ----------
    model : torch.nn.Module
        The model to train. Its output is fed directly to ``nn.CrossEntropyLoss``.
    train_loader : DataLoader
        Loader yielding ``(images, masks)`` batches for training.
    val_loader : DataLoader
        Loader yielding ``(images, masks)`` batches for validation.
    config : dict
        Configuration dictionary containing:
        - "device" (str): Device to train on ("cuda" or "cpu").
        - "lr" (float): Learning rate for Adam.
        - "epochs" (int): Number of training epochs.
        - "checkpoint_path" (str): Path to save the best model checkpoint.

    Returns
    -------
    torch.nn.Module
        The model as it stands after the final epoch (not necessarily the best
        checkpoint, which lives on disk). When several CUDA devices are used this is
        the ``nn.DataParallel`` wrapper around the original model.

    Raises
    ------
    ValueError
        If either loader contains no batches.
    """
    device = config["device"]

    # An empty loader would otherwise surface as a ZeroDivisionError when averaging.
    if len(train_loader) == 0 or len(val_loader) == 0:
        raise ValueError("train_loader and val_loader must each contain at least one batch.")

    # Use DataParallel only when training on CUDA with several GPUs, and never wrap a
    # model that is already wrapped (e.g. when a trained model is passed back in).
    on_cuda = torch.device(device).type == "cuda"
    if on_cuda and torch.cuda.device_count() > 1 and not isinstance(model, nn.DataParallel):
        print(f"Using {torch.cuda.device_count()} GPUs!")
        model = nn.DataParallel(model)

    model = model.to(device)
    optimizer = optim.Adam(model.parameters(), lr=config["lr"])
    criterion = nn.CrossEntropyLoss()

    best_loss = float('inf')
    checkpoint_path = config["checkpoint_path"]

    for epoch in range(config["epochs"]):
        avg_loss = _run_epoch(model, train_loader, criterion, device, optimizer)
        print(f"Epoch {epoch + 1}/{config['epochs']} - Training Loss: {avg_loss:.4f}")

        val_loss = _run_epoch(model, val_loader, criterion, device)
        print(f"Epoch {epoch + 1}/{config['epochs']} - Validation Loss: {val_loss:.4f}")

        # Keep only the weights with the best validation loss seen so far.
        if val_loss < best_loss:
            best_loss = val_loss
            # Save the underlying module so the checkpoint loads without DataParallel.
            bare_model = model.module if isinstance(model, nn.DataParallel) else model
            torch.save(bare_model.state_dict(), checkpoint_path)
            print(f"Model saved with validation loss: {best_loss:.4f}")

    print("Training complete.")
    return model


In [None]:
def create_dataloaders(train_dataset, val_dataset, test_dataset, config):
    """
    Build the training, validation and test DataLoaders.

    Parameters
    ----------
    train_dataset, val_dataset, test_dataset : Dataset
        Datasets to wrap. Only the training loader shuffles.
    config : dict
        - "batch_size" (int): Batch size for the training and validation loaders.
        - "num_workers" (int, optional): Worker processes for the training and
          validation loaders. Defaults to 4 per visible GPU, or 4 without CUDA.

    Returns
    -------
    tuple of DataLoader
        ``(train_loader, val_loader, test_loader)``. The test loader always uses a
        batch size of 1 and at most one worker.
    """
    batch_size = config["batch_size"]
    cuda_available = torch.cuda.is_available()

    # Honour an explicit setting; otherwise scale the worker count with the GPUs.
    default_workers = 4 * torch.cuda.device_count() if cuda_available else 4
    num_workers = config.get("num_workers", default_workers)

    # Pinned memory only speeds up host-to-GPU copies; requesting it without CUDA just
    # produces a warning.
    pin_memory = cuda_available

    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        pin_memory=pin_memory,
    )

    val_loader = DataLoader(
        val_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=pin_memory,
    )

    test_loader = DataLoader(
        test_dataset,
        batch_size=1,
        shuffle=False,
        num_workers=min(1, num_workers),  # stays at 0 when workers are disabled
        pin_memory=pin_memory,
    )

    return train_loader, val_loader, test_loader


In [None]:
def select_model(model_name, config):
    """
    Initialize and return the segmentation model.

    Parameters
    ----------
    model_name : str
        One of "UNet" (segmentation_models_pytorch U-Net with a ResNet-34 encoder),
        "UNetV2" (MONAI UNet), "AttentionUNet" (MONAI AttentionUnet) or
        "SwinUNETR" (MONAI SwinUNETR).
    config : dict
        Keys read, depending on the model:
        - "in_channels" (int): Input channels (all models).
        - "num_classes" (int): Output classes ("UNet"; fallback for "out_channels").
        - "out_channels" (int, optional): Output channels of the MONAI models.
          Defaults to ``config["num_classes"]``.
        - "spatial_dims" (int, optional): Spatial dimensions of the MONAI models.
          Defaults to 2.
        - "channels" / "strides" (sequence of int, optional): Encoder widths and
          strides for "UNetV2" and "AttentionUNet". Default to five levels
          ``(16, 32, 64, 128, 256)`` with stride 2 between consecutive levels.
        - "img_size" (int or sequence, optional): Input size for "SwinUNETR".
          Defaults to ``config["image_size"]``.
        - "feature_size" (int): Feature size for "SwinUNETR".

    Returns
    -------
    torch.nn.Module
        The freshly constructed model.

    Raises
    ------
    ValueError
        If ``model_name`` is not one of the supported names.
    KeyError
        If a required key for the chosen model is missing from ``config``.
    """
    if model_name == "UNet":
        return smp.Unet(
            encoder_name="resnet34",
            encoder_weights="imagenet",
            in_channels=config["in_channels"],
            classes=config["num_classes"],
        )

    if model_name not in ("UNetV2", "AttentionUNet", "SwinUNETR"):
        raise ValueError(f"Unknown model name: {model_name}")

    # Arguments shared by every MONAI architecture. The fallbacks let the same config
    # that drives the "UNet" branch ("num_classes", "image_size") drive these too.
    # NOTE(review): spatial_dims defaults to 2 because this notebook works with the
    # 2D dataset helpers - confirm if reused for volumetric data.
    out_channels = config["out_channels"] if "out_channels" in config else config["num_classes"]
    monai_kwargs = {
        "spatial_dims": config.get("spatial_dims", 2),
        "in_channels": config["in_channels"],
        "out_channels": out_channels,
    }

    if model_name == "SwinUNETR":
        # NOTE(review): recent MONAI releases deprecate SwinUNETR's img_size argument -
        # confirm against the installed version.
        img_size = config["img_size"] if "img_size" in config else config["image_size"]
        return monai_nets.SwinUNETR(
            img_size=img_size,
            feature_size=config["feature_size"],
            **monai_kwargs,
        )

    # MONAI's UNet and AttentionUnet need the per-level channel widths and the strides
    # between levels (one stride fewer than there are levels).
    channels = tuple(config.get("channels", (16, 32, 64, 128, 256)))
    strides = tuple(config.get("strides", (2,) * (len(channels) - 1)))
    build = monai_nets.UNet if model_name == "UNetV2" else monai_nets.AttentionUnet
    return build(channels=channels, strides=strides, **monai_kwargs)

In [None]:
def check_mask_values(dataset):
    """
    Collect and print every distinct value that occurs in the dataset's masks.

    Parameters
    ----------
    dataset : sequence
        Indexable collection whose items are ``(image, mask)`` pairs. Every item is
        loaded once, so this takes a full pass over the dataset.

    Returns
    -------
    set
        The distinct mask values as plain Python numbers.
    """
    all_unique = set()
    for index in range(len(dataset)):
        _, mask = dataset[index]
        # as_tensor accepts array-like masks as well as tensors; tolist() yields plain
        # Python numbers so the printed list stays readable (no NumPy scalar reprs).
        all_unique.update(torch.unique(torch.as_tensor(mask)).tolist())
    print(f"All unique values in masks: {sorted(all_unique)}")
    return all_unique

In [None]:
selected_model_name = "UNet"

# Values that several config entries must agree on.
NUM_CLASSES = 3   # number of segmentation classes
IMAGE_SIZE = 256  # target image size for resizing

# Single configuration shared by dataset preparation, loader creation, model selection
# and training.
config = {
    # Training
    "device": "cuda" if torch.cuda.is_available() else "cpu",
    "lr": 1e-4,
    "epochs": 20,
    "checkpoint_path": f"best_{selected_model_name}.pth",
    # 16 samples per visible GPU, or 16 in total when running on CPU
    "batch_size": 16 * torch.cuda.device_count() if torch.cuda.is_available() else 16,
    # Model
    "num_classes": NUM_CLASSES,
    "in_channels": 1,
    "feature_size": 48,
    # Keys select_model looks up for the MONAI architectures
    # ("UNetV2", "AttentionUNet", "SwinUNETR")
    "spatial_dims": 2,
    "out_channels": NUM_CLASSES,
    "img_size": IMAGE_SIZE,
    # Data
    "image_dir": "flattened_data/images",  # Update this path to your image directory
    "mask_dir": "flattened_data/masks",    # Update this path to your mask directory
    "split_train": 0.7,
    "split_val": 0.15,
    "split_test": 0.15,
    "image_size": IMAGE_SIZE,
}

In [None]:
# Collect the image / mask file paths from the configured directories
image_paths, mask_paths = build_dataset_paths(
    images_dir=config["image_dir"],
    masks_dir=config["mask_dir"],
    image_ext=".jpg",  # Update if using different extension
    mask_ext=".jpg"    # Update if using different extension
)

# Update config with the paths
config["image_paths"] = image_paths
config["mask_paths"] = mask_paths

# Split into train / validation / test datasets and wrap them in loaders
train_dataset, val_dataset, test_dataset = prepare_datasets(config)
train_loader, val_loader, test_loader = create_dataloaders(train_dataset, val_dataset, test_dataset, config)

print("Checking mask values...")
unique_train = check_mask_values(train_dataset)
unique_val = check_mask_values(val_dataset)
unique_test = check_mask_values(test_dataset)

# Fail fast if the masks fed to the loss contain labels it cannot handle. Cross-entropy
# needs every target to be a class index in [0, num_classes) (or its default
# ignore_index of -100); anything else only surfaces as an opaque error on the first
# training batch. int() mirrors the .long() cast applied to the masks during training.
# NOTE(review): the masks are read from .jpg files, a lossy format - confirm the dataset
# class maps pixel values back to class indices.
valid_labels = set(range(config["num_classes"])) | {-100}
bad_labels = sorted({int(value) for value in unique_train | unique_val} - valid_labels)
if bad_labels:
    raise ValueError(
        f"Mask values {bad_labels} are outside the {config['num_classes']} classes "
        "expected by the loss; fix the mask encoding or update config['num_classes']."
    )

# Select and initialize the model
model = select_model(selected_model_name, config)

# Train the selected model
trained_model = train_model(model, train_loader, val_loader, config)

# If in Google Colab, download the trained model
# files.download(f"best_{selected_model_name}.pth")