# UNetR Model Training

This notebook demonstrates a complete pipeline for performing 3D medical image segmentation using the UNETR architecture. 

It has been prepared based on the following original study:

[1]: Hatamizadeh, A., Tang, Y., Nath, V., Yang, D., Myronenko, A., Landman, B., Roth, H.R. and Xu, D., 2022. Unetr: Transformers for 3d medical image segmentation. In Proceedings of the IEEE/CVF Winter Conference on Applications of Computer Vision (pp. 574-584).

Copyright (c) MONAI Consortium  
Licensed under the Apache License, Version 2.0 (the "License");  
you may not use this file except in compliance with the License.  
You may obtain a copy of the License at  
&nbsp;&nbsp;&nbsp;&nbsp;http://www.apache.org/licenses/LICENSE-2.0  
Unless required by applicable law or agreed to in writing, software  
distributed under the License is distributed on an "AS IS" BASIS,  
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.  
See the License for the specific language governing permissions and  
limitations under the License.

## 1. Setup

In [None]:
#!pip install -q "monai-weekly[nibabel, tqdm, einops]"
#!python -c "import matplotlib" || pip install -q matplotlib
#%matplotlib inline

### 1.1 Import Libraries

In [None]:
import os
import re
import time
import torch
import random
import subprocess
import numpy as np
import matplotlib.pyplot as plt
from datetime import datetime
from tqdm.notebook import tqdm

# MONAI modules
from monai.losses import DiceCELoss
from monai.inferers import sliding_window_inference
from monai.transforms import (
    AsDiscrete,
    EnsureChannelFirstd,
    Compose,
    LoadImaged,
    Orientationd,
    RandFlipd,
    RandAdjustContrastd,
    RandCropByPosNegLabeld,
    RandShiftIntensityd,
    ResizeWithPadOrCropd,
    ScaleIntensityRanged,
    RandRotate90d,
)
from monai.metrics import DiceMetric
from monai.networks.nets import UNETR
from monai.data import (
    DataLoader,
    CacheDataset,
    load_decathlon_datalist,
    decollate_batch,
)

# SimpleITK for medical image processing
import SimpleITK as sitk

# Seed every random number generator used in this notebook (Python, NumPy,
# PyTorch CPU and CUDA) with the same value so that runs are repeatable.
seed_value = 12345
for _seed_fn in (
    random.seed,
    np.random.seed,
    torch.manual_seed,
    torch.cuda.manual_seed,
    torch.cuda.manual_seed_all,
):
    _seed_fn(seed_value)

# Request deterministic cuDNN kernels and switch off the cuDNN autotuner.
torch.backends.cudnn.deterministic, torch.backends.cudnn.benchmark = True, False

### 1.2. Specify Dataset Directory

In [None]:
# Placeholder locations -- replace both before running the notebook.
#   data_dir:         dataset directory, e.g., '/path/to/your/dataset'
#   results_base_dir: base directory for saving results, e.g., '/path/to/your/results'
data_dir, results_base_dir = "<DATA_DIR>", "<RESULTS_BASE_DIR>"

# Create the results directory when it is not there yet
if not os.path.isdir(results_base_dir):
    os.makedirs(results_base_dir, exist_ok=True)

## 2. Data Preprocessing

In [None]:
def calculate_windowing_range(image_dir, lower_percentile=5, upper_percentile=95, sample_size=5000, exclude_zeros=True, show_histogram=True):
    """
    Estimates an intensity windowing range from percentiles (e.g., 5th and 95th) of voxel
    values randomly sampled from every NIfTI image in a directory.

    Args:
    - image_dir (str): Directory containing NIfTI (.nii.gz) format images.
    - lower_percentile (int, optional): Lower percentile threshold. Default is 5.
    - upper_percentile (int, optional): Upper percentile threshold. Default is 95.
    - sample_size (int, optional): Maximum number of values sampled from each image. Default is 5000.
    - exclude_zeros (bool, optional): Whether to drop zero-valued voxels before sampling. Default is True.
    - show_histogram (bool, optional): Whether to display a histogram of the sampled values. Default is True.

    Returns:
    - lower_bound (float): Lower bound of the windowing range.
    - upper_bound (float): Upper bound of the windowing range.

    Raises:
    - ValueError: If the percentiles are not ordered within [0, 100], if the directory holds no
      .nii.gz file, or if no voxel values could be sampled (e.g., every image is all zeros).
    """
    if not 0 <= lower_percentile < upper_percentile <= 100:
        raise ValueError(
            f"Percentiles must satisfy 0 <= lower < upper <= 100, got {lower_percentile} and {upper_percentile}."
        )

    # Sort the file names: os.listdir returns them in arbitrary order, which would make the
    # seeded random sampling below differ from one machine/run to the next.
    image_files = sorted(f for f in os.listdir(image_dir) if f.endswith('.nii.gz'))
    if not image_files:
        raise ValueError(f"No .nii.gz images found in {image_dir!r}.")

    sampled_chunks = []  # One array of sampled intensities per image
    for image_file in tqdm(image_files, desc="Processing images"):
        image = sitk.ReadImage(os.path.join(image_dir, image_file))
        voxels = sitk.GetArrayFromImage(image).ravel()

        # Exclude zero-intensity voxels (optional)
        if exclude_zeros:
            voxels = voxels[voxels != 0]

        # Keep only a random subset of each image so memory use stays bounded
        if voxels.size > 0:
            sampled_chunks.append(
                np.random.choice(voxels, size=min(sample_size, voxels.size), replace=False)
            )

    if not sampled_chunks:
        raise ValueError(
            f"No intensity values could be sampled from the images in {image_dir!r} "
            f"(exclude_zeros={exclude_zeros})."
        )
    all_sampled_values = np.concatenate(sampled_chunks)

    # Visualize the intensity distribution (optional)
    if show_histogram:
        plt.hist(all_sampled_values, bins=100)
        plt.title('Intensity Distribution')
        plt.xlabel('Intensity')
        plt.ylabel('Frequency')
        plt.show()

    # Calculate percentiles over the pooled samples
    lower_bound = float(np.percentile(all_sampled_values, lower_percentile))
    upper_bound = float(np.percentile(all_sampled_values, upper_percentile))

    return lower_bound, upper_bound

# Training images live in the 'imagesTr' sub-folder of the dataset directory
image_dir = os.path.join(data_dir, 'imagesTr')

# Windowing range with the function's default settings
# (5th to 95th percentile, zero-valued voxels excluded)
windowing_range = calculate_windowing_range(image_dir)
lower_bound, upper_bound = windowing_range

print(f"Recommended windowing range: [{lower_bound}, {upper_bound}]")

In [None]:
# Spatial size every volume is padded/cropped to (each dimension divisible by 16).
spatial_size = (224, 224, 64)


def _base_transforms():
    """Returns fresh instances of the deterministic preprocessing shared by training and validation."""
    return [
        LoadImaged(keys=["image", "label"]),
        EnsureChannelFirstd(keys=["image", "label"]),
        Orientationd(keys=["image", "label"], axcodes="RAS"),
        # Use the window computed from the data for BOTH pipelines: a different window at
        # training and validation time would feed the network differently scaled intensities.
        ScaleIntensityRanged(
            keys=["image"],
            a_min=lower_bound,
            a_max=upper_bound,
            b_min=0.0,
            b_max=1.0,
            clip=True,
        ),
        # Pad or crop so that every volume has exactly `spatial_size`
        ResizeWithPadOrCropd(keys=["image", "label"], spatial_size=spatial_size),
    ]


train_transforms = Compose(
    _base_transforms()
    + [
        # NOTE(review): the crop size equals the padded/cropped volume size, so this presumably
        # yields `num_samples` copies of the whole volume -- confirm whether smaller patches
        # were intended.
        RandCropByPosNegLabeld(
            keys=["image", "label"],
            label_key="label",
            spatial_size=spatial_size,
            pos=1,
            neg=1,
            num_samples=4,
            image_key="image",
            image_threshold=0,
            allow_smaller=True  # Allow cropping even if image is smaller than spatial_size
        ),
        # Random flips along each spatial axis
        RandFlipd(
            keys=["image", "label"],
            spatial_axis=[0],
            prob=0.10,
        ),
        RandFlipd(
            keys=["image", "label"],
            spatial_axis=[1],
            prob=0.10,
        ),
        RandFlipd(
            keys=["image", "label"],
            spatial_axis=[2],
            prob=0.10,
        ),
        RandRotate90d(
            keys=["image", "label"],
            prob=0.10,
            max_k=3,
        ),
        # Intensity augmentations are applied to the image only, never to the label
        RandShiftIntensityd(
            keys=["image"],
            offsets=(-0.1, 0.1),
            prob=0.30,
        ),
        RandAdjustContrastd(
            keys=["image"],
            prob=0.30,
            gamma=(0.8, 1.2)
        ),
    ]
)

# Validation uses the deterministic preprocessing only (no random augmentation).
val_transforms = Compose(_base_transforms())

## 3. Data Loading and Preparation

1. Download the dataset.
2. Put the images in `./data/imagesTr`.
3. Put the labels in `./data/labelsTr`.
4. Create the JSON datalist file accordingly: `./data/dataset_0.json`.

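Example of a JSON datalist file (`...` marks omitted entries; adapt `labels` to your own dataset, since this notebook trains a two-class model with `out_channels=2`):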

        {
    "description": "btcv yucheng",
    "labels": {
        "0": "background",
        "1": "spleen",
        "2": "rkid",
        "3": "lkid",
        "4": "gall",
        "5": "eso",
        "6": "liver",
        "7": "sto",
        "8": "aorta",
        "9": "IVC",
        "10": "veins",
        "11": "pancreas",
        "12": "rad",
        "13": "lad"
    },
    "licence": "yt",
    "modality": {
        "0": "CT"
    },
    "name": "btcv",
    "numTest": 20,
    "numTraining": 80,
    "reference": "Vanderbilt University",
    "release": "1.0 06/08/2015",
    "tensorImageSize": "3D",
    "test": [
        "imagesTs/img0061.nii.gz",
        "imagesTs/img0062.nii.gz",
        ...
        "imagesTs/img0080.nii.gz"
    ],
    "training": [
        {
            "image": "imagesTr/img0001.nii.gz",
            "label": "labelsTr/label0001.nii.gz"
        },
        ...
        {
            "image": "imagesTr/img0034.nii.gz",
            "label": "labelsTr/label0034.nii.gz"
        }
    ],
    "validation": [
        {
            "image": "imagesTr/img0035.nii.gz",
            "label": "labelsTr/label0035.nii.gz"
        },
        ...
        {
            "image": "imagesTr/img0040.nii.gz",
            "label": "labelsTr/label0040.nii.gz"
        }
        ]
    }
    

In [None]:
# Decathlon-style JSON datalist describing the training/validation split.
# It is expected next to the 'imagesTr'/'labelsTr' folders inside `data_dir`; when no base
# directory is passed, MONAI resolves the relative paths in the JSON against the JSON's own folder.
datasets = os.path.join(data_dir, "dataset_0.json")
if not os.path.isfile(datasets):
    raise FileNotFoundError(
        f"Datalist JSON not found: {datasets!r}. Create it as described above or point `datasets` to your file."
    )

datalist = load_decathlon_datalist(datasets, True, "training")
val_files = load_decathlon_datalist(datasets, True, "validation")

# Training samples: preprocessed items are cached in memory by CacheDataset
train_ds = CacheDataset(
    data=datalist,
    transform=train_transforms,
    cache_num=24,
    cache_rate=1.0,
    num_workers=8,
)
train_loader = DataLoader(train_ds, batch_size=1, shuffle=True, num_workers=8, pin_memory=True)

# Validation samples: same caching approach, never shuffled
val_ds = CacheDataset(data=val_files, transform=val_transforms, cache_num=6, cache_rate=1.0, num_workers=4)
val_loader = DataLoader(val_ds, batch_size=1, shuffle=False, num_workers=4, pin_memory=True)

### 3.1 Check Augmented Images

In [None]:
def visualize_slices(data_loader, slice_step=12):
    """
    Visualizes 5 slices around the center slice of the first sample in a batch:
    the image, the label, and the label overlaid on the image.

    Args:
    - data_loader (DataLoader): Iterable yielding dict batches with 'image' and 'label'
      tensors; the last tensor dimension is used as the slice axis.
    - slice_step (int, optional): Preferred distance between displayed slices. It is reduced
      automatically for shallow volumes so that all 5 slices stay inside the volume. Default is 12.

    Returns:
    None. Displays the visualization using matplotlib.
    """
    # Get a batch of data
    batch_data = next(iter(data_loader))

    # First sample of the batch as NumPy arrays (moved to CPU first in case it sits on a GPU)
    images_np = batch_data['image'][0].detach().cpu().numpy()
    labels_np = batch_data['label'][0].detach().cpu().numpy()

    # Five slice indices centered on the middle slice. The step is capped at (depth - 1) // 4
    # so that center +/- 2 * step can never leave the range [0, depth - 1].
    depth = images_np.shape[-1]
    center_slice_idx = depth // 2
    step = max(0, min(slice_step, (depth - 1) // 4))
    slice_indices = [center_slice_idx + k * step for k in (-2, -1, 0, 1, 2)]

    # Plot image slices, label slices, and overlays (channel 0 only)
    fig, axes = plt.subplots(3, 5, figsize=(15, 9))
    for i, slice_idx in enumerate(slice_indices):
        image_slice = images_np[0, :, :, slice_idx]
        label_slice = labels_np[0, :, :, slice_idx]

        axes[0, i].imshow(image_slice, cmap='gray')
        axes[0, i].set_title(f"Image Slice {slice_idx}")
        axes[0, i].axis('off')

        axes[1, i].imshow(label_slice, cmap='gray')
        axes[1, i].set_title(f"Label Slice {slice_idx}")
        axes[1, i].axis('off')

        # Overlay: semi-transparent label drawn on top of the image
        axes[2, i].imshow(image_slice, cmap='gray')
        axes[2, i].imshow(label_slice, cmap='Greens', alpha=0.5)
        axes[2, i].set_title(f"Overlay Slice {slice_idx}")
        axes[2, i].axis('off')

    plt.tight_layout()
    plt.show()

# Example usage: preview the first sample of one batch drawn from the training loader
visualize_slices(train_loader)

## 4. Training Pipeline

### 4.1. Dynamic Training Session Management
This section dynamically determines the next training session number based on existing folders in the results directory. It creates a new directory for the current training session, including a timestamp and dataset name, ensuring all results are organized systematically.

In [None]:
# Function to determine the next training number dynamically
def get_next_training_number(results_base_dir):
    """
    Returns the name of the next training session ('train001', 'train002', ...).

    Sub-directories of `results_base_dir` whose names start with 'train' followed by at
    least three digits are treated as previous sessions; the result is one above the
    highest number found. `results_base_dir` is created if it does not exist.

    Args:
    - results_base_dir (str): Directory that holds one sub-directory per training session.

    Returns:
    - str: 'train' followed by the next session number, zero-padded to at least 3 digits.
    """
    # Ensure the results directory exists
    os.makedirs(results_base_dir, exist_ok=True)

    # Capture ALL leading digits: with exactly three, 'train1000' would be read as 100
    # and the next session name would collide with an existing one.
    session_pattern = re.compile(r'train(\d{3,})')

    existing_numbers = []
    for entry in os.listdir(results_base_dir):
        match = session_pattern.match(entry)
        if match and os.path.isdir(os.path.join(results_base_dir, entry)):
            existing_numbers.append(int(match.group(1)))

    # With no previous session the numbering starts at train001
    next_number = max(existing_numbers, default=0) + 1
    return f'train{next_number:03d}'  # Ensure it's zero-padded to 3 digits

# Name parts of the current training session: running number, dataset tag, start time
train_name = 'schwannoma'  # You can update this to reflect the dataset you're using
training_number = get_next_training_number(results_base_dir)
timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")

# Session folder: <results_base_dir>/<number>_<name>_<timestamp>
session_folder_name = "_".join((training_number, train_name, timestamp))
results_save_dir = os.path.join(results_base_dir, session_folder_name)
os.makedirs(results_save_dir, exist_ok=True)

# Start the session with an empty configs.txt file
config_file_path = os.path.join(results_save_dir, 'configs.txt')
file = open(config_file_path, 'w')
file.close()

print(results_save_dir)

### 4.2. GPU Selection and Model Initialization

In [None]:
# Order CUDA devices by PCI bus ID, the ordering nvidia-smi reports, so that the index
# picked from the nvidia-smi output below refers to the same physical GPU inside PyTorch.
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"

def get_free_gpu():
    """
    Picks the GPU that currently has the most free memory according to nvidia-smi.

    Returns:
    - torch.device: 'cuda:<index>' of the GPU with the most free memory, or 'cpu' when CUDA
      is unavailable or nvidia-smi cannot be run / its output cannot be parsed.
    """
    if not torch.cuda.is_available():
        print("No GPU available, using CPU.")
        return torch.device("cpu")

    # Get GPU memory information using nvidia-smi
    try:
        result = subprocess.run(
            ['nvidia-smi', '--query-gpu=memory.free', '--format=csv,nounits,noheader'],
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            check=True
        )
        # One integer (free MB) per GPU; blank lines are ignored
        free_memory = [
            int(line) for line in result.stdout.decode('utf-8').split('\n') if line.strip()
        ]
        if not free_memory:
            raise ValueError("nvidia-smi did not report any GPU")
    except (subprocess.CalledProcessError, OSError, ValueError) as e:
        # CalledProcessError: nvidia-smi exited with an error; OSError: the binary is missing
        # or not executable; ValueError: unexpected output.
        print("Error using nvidia-smi:", e)
        return torch.device("cpu")

    # NOTE(review): if CUDA_VISIBLE_DEVICES restricts or reorders the visible GPUs, the
    # nvidia-smi index may not match the PyTorch index -- confirm for your environment.
    # Select the GPU with the most free memory (the first one wins a tie)
    best_gpu = max(range(len(free_memory)), key=free_memory.__getitem__)
    print(f"Using GPU: {best_gpu} with {free_memory[best_gpu]} MB free memory.")
    return torch.device(f"cuda:{best_gpu}")

# Automatically select the GPU with the most available memory
device = get_free_gpu()

# Optional manual override, e.g. to avoid a GPU that reports ECC errors: set this to the
# index of a healthy GPU. Leave it as None to keep the automatically selected device.
# The override is ignored when CUDA is unavailable, so the notebook still runs on CPU.
manual_gpu_id = None
if manual_gpu_id is not None and torch.cuda.is_available():
    device = torch.device(f"cuda:{manual_gpu_id}")
print(device)

# UNETR for single-channel input volumes of size 224 x 224 x 64 and two output classes
model = UNETR(
    in_channels=1,
    out_channels=2,
    img_size=(224, 224, 64),
    feature_size=32,
    hidden_size=768,
    mlp_dim=3072,
    num_heads=12,
    proj_type="perceptron",
    norm_name="instance",
    res_block=True,
    dropout_rate=0.0,
).to(device)

# Combined Dice + cross-entropy loss, with the Dice term weighted 1.5 and the CE term 0.5
loss_function = DiceCELoss(to_onehot_y=True, softmax=True, lambda_dice=1.5, lambda_ce=0.5)

# cuDNN autotuning (torch.backends.cudnn.benchmark) is intentionally NOT enabled here: the seed
# setup cell disables it together with cudnn.deterministic = True to keep runs reproducible.
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5, weight_decay=1e-5)

### 4.3. Training Loop with Early Stopping

This section defines the training loop for the model. It includes:

1. Early Stopping: Stops training if the validation Dice score does not improve by at least `min_delta` for `patience` consecutive validation runs.
2. Loss and Metric Visualization: Plots and saves training loss and validation Dice metrics after each epoch.
3. Best Model Saving: Automatically saves the best-performing model based on validation Dice scores.
4. Time Limit: The training process is constrained to a specified time limit (e.g., 5 days).

In [None]:
# Early-stopping settings. The counter below is advanced inside train() once per validation
# run (validation is triggered every `eval_num` training steps), not once per epoch.
patience = 10
no_improvement_count = 0  # Number of consecutive validation runs without a significant improvement
min_delta = 0.001  # Minimum Dice gain over the best value so far that counts as an improvement

# Function to plot and save loss and dice values
def plot_and_save_loss_dice(epoch_loss_values, metric_values, global_step, epoch_num, elapsed_time, results_save_dir):
    """
    Plots the recorded training loss and validation Dice values side by side and saves the
    figure as 'loss.png' in `results_save_dir` (overwriting any previous file).

    Args:
    - epoch_loss_values (list): Average training loss recorded at each validation run.
    - metric_values (list): Validation Dice recorded at each validation run.
    - global_step (int): Current global training step (accepted for interface compatibility; not used).
    - epoch_num (int): Current epoch number, shown in the plot titles.
    - elapsed_time (float): Seconds elapsed since training started, shown in the plot titles.
    - results_save_dir (str): Directory the figure is written to.

    Returns:
    None.

    Note: the x-axis positions are multiples of `eval_num`, which is read from the
    notebook's global scope.
    """
    # Report what is about to be plotted
    if not epoch_loss_values or not metric_values:
        print("Epoch loss values or metric values are empty!")
    else:
        print(f"Epoch loss values: {epoch_loss_values}")
        print(f"Metric values: {metric_values}")

    # Elapsed time as days / hours / minutes; hours and minutes are zero-padded so that
    # e.g. 3 h 5 min reads "03:05" rather than the ambiguous "3:5".
    days, remaining_minutes = divmod(int(elapsed_time // 60), 24 * 60)
    hours, minutes = divmod(remaining_minutes, 60)
    elapsed_label = f"{days}d {hours:02d}:{minutes:02d}"

    plt.figure("train", (12, 6))

    # Left panel: loss, right panel: Dice -- both use the same title format and x-axis
    panels = (
        ("Iteration Average Loss", epoch_loss_values),
        ("Val Mean Dice", metric_values),
    )
    for position, (panel_title, values) in enumerate(panels, start=1):
        plt.subplot(1, 2, position)
        plt.title(f"{panel_title} (Epoch {epoch_num} - {elapsed_label})")
        x = [eval_num * (i + 1) for i in range(len(values))]
        plt.xlabel("Iteration")
        plt.plot(x, values)

    plt.tight_layout()

    # Save the plot to the specified path
    loss_path = os.path.join(results_save_dir, "loss.png")
    plt.savefig(loss_path)
    plt.close()
    print(f"Plot saved to {loss_path}")

def _run_validation():
    """Returns the mean Dice of the global `model` over `val_loader`; restores train mode afterwards."""
    model.eval()
    with torch.no_grad():
        val_iterator = tqdm(
            val_loader,
            desc="Validation",
            dynamic_ncols=True,
            position=1,
            leave=False
        )
        dice_metric.reset()
        for val_batch in val_iterator:
            val_inputs, val_labels = (val_batch["image"].to(device), val_batch["label"].to(device))
            val_outputs = sliding_window_inference(val_inputs, (224, 224, 64), 4, model)
            # Split the batch into single items and convert labels / predictions with the
            # post-processing transforms before feeding the Dice metric
            val_labels_convert = [post_label(val_label_tensor) for val_label_tensor in decollate_batch(val_labels)]
            val_output_convert = [post_pred(val_pred_tensor) for val_pred_tensor in decollate_batch(val_outputs)]
            dice_metric(y_pred=val_output_convert, y=val_labels_convert)
        dice_val = dice_metric.aggregate().item()
        dice_metric.reset()
    model.train()
    return dice_val


# Training function
def train(global_step, train_loader, dice_val_best, global_step_best, no_improvement_count):
    """
    Runs one pass over `train_loader`, validating every `eval_num` global steps.

    At each validation run the average training loss of the pass so far and the validation
    Dice are appended to the global lists `epoch_loss_values` / `metric_values`, the model is
    saved as 'last_model.pth', and -- when the Dice improves on the best value by more than
    `min_delta` -- also as 'best_model.pth' (both inside `results_save_dir`). After the pass
    the loss/Dice plot is refreshed.

    Relies on notebook globals: model, optimizer, loss_function, device, val_loader, eval_num,
    dice_metric, post_label, post_pred, patience, min_delta, epoch_loss_values, metric_values,
    results_save_dir, start_time and epoch_num.

    Args:
    - global_step (int): Number of optimizer steps taken before this pass.
    - train_loader (DataLoader): Loader yielding dict batches with 'image' and 'label'.
    - dice_val_best (float): Best validation Dice so far.
    - global_step_best (int): Global step at which `dice_val_best` was reached.
    - no_improvement_count (int): Consecutive validation runs without significant improvement.

    Returns:
    - tuple: updated (global_step, dice_val_best, global_step_best, no_improvement_count).
      When `no_improvement_count` reaches `patience` the pass is cut short; the caller is
      responsible for checking the returned counter to end training altogether.
    """
    model.train()
    epoch_loss = 0.0  # Sum of the step losses of this pass
    epoch_iterator = tqdm(
        train_loader,
        desc=f"Training ({global_step} Steps) (loss=X.X)",
        dynamic_ncols=True,
        position=0,
        leave=True
    )
    for step, batch in enumerate(epoch_iterator):
        x, y = (batch["image"].to(device), batch["label"].to(device))
        logit_map = model(x)
        loss = loss_function(logit_map, y)
        loss.backward()
        epoch_loss += loss.item()
        optimizer.step()
        optimizer.zero_grad()
        if step % 10 == 0:
            epoch_iterator.set_description(
                f"Training ({global_step} Steps) (loss={loss.item():.5f})"
            )
        if global_step % eval_num == 0 and global_step != 0:
            dice_val = _run_validation()

            # Average over the steps of this pass so far. The running sum itself is left
            # untouched so that a later validation in the same pass still averages correctly.
            epoch_loss_values.append(epoch_loss / (step + 1))
            metric_values.append(dice_val)

            # Early Stopping check
            if dice_val > dice_val_best + min_delta:
                dice_val_best = dice_val
                global_step_best = global_step
                no_improvement_count = 0  # Reset counter
                best_model_path = os.path.join(results_save_dir, "best_model.pth")
                torch.save(model.state_dict(), best_model_path)
                tqdm.write(
                    f"Model Was Saved! Current Best Avg. Dice: {dice_val_best:.4f} Current Avg. Dice: {dice_val:.4f}"
                )
            else:
                no_improvement_count += 1  # Increment the counter
                tqdm.write(
                    f"Model Was Not Saved. Current Best Avg. Dice: {dice_val_best:.4f} Current Avg. Dice: {dice_val:.4f}"
                )

            # Always keep the most recent weights as well
            last_model_path = os.path.join(results_save_dir, "last_model.pth")
            torch.save(model.state_dict(), last_model_path)

            # Cut the pass short once 'patience' validation runs brought no improvement
            if no_improvement_count >= patience:
                print(f"Early stopping triggered after {no_improvement_count} epochs with no improvement.")
                break

        global_step += 1

    # Plot and save after each epoch
    elapsed_time = time.time() - start_time
    plot_and_save_loss_dice(epoch_loss_values, metric_values, global_step, epoch_num, elapsed_time, results_save_dir)

    return global_step, dice_val_best, global_step_best, no_improvement_count

# Define time limit: 5 days in seconds
time_limit = 5 * 24 * 60 * 60
eval_num = 500  # Validate every `eval_num` global training steps

# Post-processing for the Dice metric: one-hot labels, argmax + one-hot predictions (2 classes)
post_label = AsDiscrete(to_onehot=2)
post_pred = AsDiscrete(argmax=True, to_onehot=2)
dice_metric = DiceMetric(include_background=True, reduction="mean", get_not_nans=False)

# Training state shared with train()
global_step = 0
dice_val_best = 0.0
global_step_best = 0
epoch_loss_values = []
metric_values = []

# Record the start time
start_time = time.time()

epoch_num = 1
# Start training loop (without max_iterations): it ends on the time limit or on early stopping
while True:
    # Check if the elapsed time has exceeded the time limit
    elapsed_time = time.time() - start_time
    if elapsed_time > time_limit:
        print(
            f"Stopping training after {elapsed_time / 3600:.2f} hours as the time limit of "
            f"{time_limit / (24 * 3600):.2f} days has been reached."
        )
        break

    # Continue training
    global_step, dice_val_best, global_step_best, no_improvement_count = train(
        global_step, train_loader, dice_val_best, global_step_best, no_improvement_count
    )

    # train() only cuts its own pass short when patience is exhausted; the outer loop has to
    # stop here, otherwise training would simply resume with the next pass.
    if no_improvement_count >= patience:
        print(f"Training stopped early during epoch {epoch_num}.")
        break
    epoch_num += 1

# Record the end time
end_time = time.time()
elapsed_time = end_time - start_time
print(f"Training completed in {elapsed_time / 3600:.2f} hours.")

In [None]:
# Load the best model weights. The training loop saves them as "best_model.pth";
# map_location places the tensors on the device the model currently lives on.
best_model_path = os.path.join(results_save_dir, "best_model.pth")
model.load_state_dict(torch.load(best_model_path, map_location=device))

### 4.4. Finalizing Training and Organizing Results

In [None]:
print(f"train completed, best_metric: {dice_val_best:.4f} " f"at iteration: {global_step_best}")

# Mark the session as finished by appending "_completed" to its folder name.
# The checks make the cell safe to re-run after the folder has already been renamed.
completed_folder_name = results_save_dir + "_completed"
if os.path.isdir(results_save_dir):
    os.rename(results_save_dir, completed_folder_name)
    print(f"Training folder renamed to: {completed_folder_name}")
elif os.path.isdir(completed_folder_name):
    print(f"Training folder was already renamed to: {completed_folder_name}")
else:
    raise FileNotFoundError(
        f"Neither {results_save_dir!r} nor {completed_folder_name!r} exists; nothing to rename."
    )

## 5. Summary

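Plot the training loss and the validation Dice curves, and save the figure as `loss_final.png` in the completed results folder.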

In [None]:
# Final figure: training loss on the left, validation Dice on the right
plt.figure("train", (12, 6))
for panel_position, (panel_title, recorded_values) in enumerate(
    [("Iteration Average Loss", epoch_loss_values), ("Val Mean Dice", metric_values)],
    start=1,
):
    plt.subplot(1, 2, panel_position)
    plt.title(panel_title)
    # One value per validation run, i.e. every `eval_num` iterations
    x = [eval_num * k for k in range(1, len(recorded_values) + 1)]
    y = recorded_values
    plt.xlabel("Iteration")
    plt.plot(x, y)

# Write the figure into the renamed ("_completed") results folder
loss_path = os.path.join(f"{results_save_dir}_completed", "loss_final.png")
plt.savefig(loss_path)
plt.close()

print(f"Plot saved to {loss_path}")