This notebook demonstrates the process of optimizing hyperparameters for training a UNet model on a 2D image segmentation task using Optuna. <br> It covers data preparation, model definition, and the use of Optuna to find the best hyperparameters by minimizing validation loss. Additionally, it includes code for retraining the model with the optimal parameters and saving the final results.

<b>Dataset link:</b> https://www.kaggle.com/datasets/kmader/finding-lungs-in-ct-data/data

<h4> Import Required Libraries </h4>

In [None]:
import json
import glob
import numpy as np
import pandas as pd
from PIL import Image
import optuna
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import os
from sklearn.model_selection import train_test_split

<h4>Root dir path initialization </h4>

In [None]:
# Root directory of the dataset. The image and mask sub-folders
# ('2d_images/' and '2d_masks/') are looked up beneath this path.
# Keep the trailing slash so that string-concatenated paths stay valid.
root_dir = "/home/dataset/"

<h4> Collect all the image and mask files along with their paths </h4>

In [None]:
# Load image and mask file paths.
# os.path.join builds the glob patterns correctly whether or not root_dir ends
# with a path separator. Both lists are sorted so that image i is paired with
# mask i; this relies on the two folders sorting into the same order.
image_files = sorted(glob.glob(os.path.join(root_dir, '2d_images', '*.tif')))
mask_files = sorted(glob.glob(os.path.join(root_dir, '2d_masks', '*.tif')))

# Two empty lists would pass the length comparison below, so fail early with a
# message that points at the real problem (wrong path / missing data).
if not image_files:
    raise ValueError(
        f"No .tif images found under {os.path.join(root_dir, '2d_images')}."
    )

# Check if the number of images matches the number of masks
if len(image_files) != len(mask_files):
    raise ValueError("Number of image files and mask files do not match.")

<h4> Dataset split for train, val and test </h4>

In [None]:
# Split data into training, validation, and test sets.
# Step 1: hold out 20% of the image/mask pairs for testing.
trainval_image_files, test_image_files, trainval_mask_files, test_mask_files = train_test_split(
    image_files, mask_files, test_size=0.2, random_state=42
)

# Step 2: take 25% of the remaining pairs for validation, i.e. roughly 20% of
# the full dataset, leaving roughly 60% for training.
train_image_files, val_image_files, train_mask_files, val_mask_files = train_test_split(
    trainval_image_files, trainval_mask_files, test_size=0.25, random_state=42
)

# Report the size of every split plus the totals as a sanity check.
image_splits = (train_image_files, val_image_files, test_image_files)
mask_splits = (train_mask_files, val_mask_files, test_mask_files)

print(f"Number of training image files: {len(train_image_files)}")
print(f"Number of validation image files: {len(val_image_files)}")
print(f"Number of testing image files: {len(test_image_files)}")
print(f"Total number of image files: {sum(len(split) for split in image_splits)}")
print(f"Total number of mask files: {sum(len(split) for split in mask_splits)}")


<h4> Check Pixel value range </h4>

In [None]:
# Inspect the first mask file: print its array shape and the smallest and
# largest raw pixel values it contains.
temp_image = Image.open(mask_files[0])
temp_image_array = np.array(temp_image)
lowest_value = temp_image_array.min()
highest_value = temp_image_array.max()
print("Image shape:", temp_image_array.shape)
print(f"Pixel value range: {lowest_value} to {highest_value}")

<h4> Convert the mask values to binary {0, 1} </h4>

In [None]:
def convert_to_0_and_1(mask_image):
    """Binarize a mask: strictly positive pixels become 1, all others 0.

    Args:
        mask_image: Anything ``np.array`` accepts (e.g. a PIL image or an
            array-like of numbers).

    Returns:
        A ``np.uint8`` array with the same shape as the input, containing
        only the values 0 and 1.
    """
    pixels = np.array(mask_image)
    binary = np.zeros(pixels.shape, dtype=np.uint8)
    binary[pixels > 0] = 1
    return binary

<h4> Custom Dataset class for image/mask pairs </h4>

In [None]:
class CustomDataset(Dataset):
    """Dataset of (image, binary mask) pairs for segmentation.

    Item ``idx`` is built from ``image_files[idx]`` and ``mask_files[idx]``,
    so the two lists must be aligned index-by-index.

    Args:
        image_files: Sequence of image file paths.
        mask_files: Sequence of mask file paths, same order as ``image_files``.
        transform: Callable applied to the image and, in a separate call, to
            the mask. It is expected to return tensors (e.g. a Compose ending
            in ``ToTensor``). Because it is called twice, a transform with
            random behaviour would not be applied identically to both.
    """

    def __init__(self, image_files, mask_files, transform=None):
        self.image_files = image_files
        self.mask_files = mask_files
        self.transform = transform

    def __len__(self):
        """Return the number of image/mask pairs."""
        return len(self.image_files)

    def __getitem__(self, idx):
        """Load pair ``idx`` and return ``(image, mask)``.

        The mask is returned as a float tensor with a leading channel
        dimension whose values are exactly 0.0 or 1.0.
        """
        img_path = self.image_files[idx]
        mask_path = self.mask_files[idx]

        image = Image.open(img_path).convert("L")
        mask = Image.open(mask_path).convert("L")

        # Keep the binary mask as 0/255 inside the PIL image. An 8-bit image
        # holding 0/1 would be divided by 255 by ToTensor, turning every
        # foreground pixel into ~0.004 and leaving the loss with targets that
        # are effectively all zero.
        mask = Image.fromarray(convert_to_0_and_1(mask) * 255)

        if self.transform:
            image = self.transform(image)
            mask = self.transform(mask)

        # Re-binarize: scaling and interpolating resizes leave fractional
        # values, while the loss expects hard 0/1 targets. The 0.5 threshold
        # works for masks scaled to [0, 1] as well as for unscaled 0..255.
        mask = (mask > 0.5).float()
        # Normalise to a single leading channel dimension.
        mask = mask.squeeze(0).unsqueeze(0)
        return image, mask

In [None]:
# Define data transformations: resize every sample to 128x128, then convert
# it to a tensor.
resize_to_128 = transforms.Resize((128, 128))
transform = transforms.Compose([resize_to_128, transforms.ToTensor()])

# Create one dataset per split; all three share the same transform pipeline.
train_dataset = CustomDataset(train_image_files, train_mask_files, transform)
val_dataset = CustomDataset(val_image_files, val_mask_files, transform)
test_dataset = CustomDataset(test_image_files, test_mask_files, transform)


<h4> UNet Model architecture </h4>

In [None]:
# Define UNet model
class UNet(nn.Module):
    """U-Net style encoder/decoder mapping a 1-channel input to a 1-channel map.

    Five encoder stages (``enc1`` .. ``enc5``) widen the features from 64 to
    1024 channels, with 2x2 max pooling between stages. Four decoder stages
    (``upconv5``/``dec4`` .. ``upconv2``/``dec1``) upsample by transposed
    convolution and concatenate the encoder output of the same level. A final
    1x1 convolution followed by a sigmoid yields values in [0, 1].
    """

    def __init__(self):
        super().__init__()
        widths = (1, 64, 128, 256, 512, 1024)

        # Encoder: enc1 (1 -> 64) up to enc5 (512 -> 1024).
        for level, (c_in, c_out) in enumerate(zip(widths, widths[1:]), start=1):
            setattr(self, f"enc{level}", self.conv_block(c_in, c_out))

        # Decoder, deepest level first: upconv5/dec4 down to upconv2/dec1.
        # Each dec block takes `wide` input channels because the upsampled
        # features are concatenated with the matching encoder output.
        for level in range(4, 0, -1):
            wide, narrow = widths[level + 1], widths[level]
            setattr(
                self,
                f"upconv{level + 1}",
                nn.ConvTranspose2d(wide, narrow, kernel_size=2, stride=2),
            )
            setattr(self, f"dec{level}", self.conv_block(wide, narrow))

        self.final_conv = nn.Conv2d(64, 1, kernel_size=1)

    def conv_block(self, in_channels, out_channels):
        """Return two stacked (3x3 conv -> batch norm -> ReLU) units."""
        layers = []
        for c_in in (in_channels, out_channels):
            layers += [
                nn.Conv2d(c_in, out_channels, kernel_size=3, padding=1),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
            ]
        return nn.Sequential(*layers)

    def forward(self, x):
        """Run the network on ``x`` and return the sigmoid-activated output."""
        # Contracting path; keep each stage's output for the skip connections.
        skip1 = self.enc1(x)
        skip2 = self.enc2(F.max_pool2d(skip1, kernel_size=2))
        skip3 = self.enc3(F.max_pool2d(skip2, kernel_size=2))
        skip4 = self.enc4(F.max_pool2d(skip3, kernel_size=2))
        bottleneck = self.enc5(F.max_pool2d(skip4, kernel_size=2))

        # Expanding path: upsample, concatenate the skip, then convolve.
        up = self.dec4(torch.cat([self.upconv5(bottleneck), skip4], dim=1))
        up = self.dec3(torch.cat([self.upconv4(up), skip3], dim=1))
        up = self.dec2(torch.cat([self.upconv3(up), skip2], dim=1))
        up = self.dec1(torch.cat([self.upconv2(up), skip1], dim=1))

        logits = self.final_conv(up)
        return torch.sigmoid(logits)


<h4>Objective function for Optuna</h4>

In [None]:
# Define the objective function for Optuna
def objective(trial):
    """Optuna objective: train a fresh UNet and return its best validation loss.

    Samples a learning rate (log scale, 1e-5 to 1e-2), a number of epochs
    (20 to 500) and a batch size (4, 8, 16 or 32), trains a new model on
    ``train_dataset`` and evaluates it on ``val_dataset`` after every epoch.

    Args:
        trial: The ``optuna`` trial used to sample the hyperparameters.

    Returns:
        The lowest mean validation loss observed over all epochs.

    Raises:
        optuna.TrialPruned: When the study's pruner decides, from the
            intermediate values reported each epoch, to stop this trial.
    """
    # Hyperparameters to tune. suggest_float(..., log=True) replaces the
    # deprecated suggest_loguniform and samples from the same distribution.
    learning_rate = trial.suggest_float('learning_rate', 1e-5, 1e-2, log=True)
    num_epochs = trial.suggest_int('num_epochs', 20, 500)
    batch_size = trial.suggest_categorical('batch_size', [4, 8, 16, 32])

    # Define the device to be used for training
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Define the model, criterion, and optimizer.
    # BCELoss (not BCEWithLogitsLoss) because UNet already applies a sigmoid.
    model = UNet().to(device)
    criterion = nn.BCELoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)

    # Create data loaders with the suggested batch size
    train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_dataloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

    best_loss = float('inf')

    for epoch in range(num_epochs):
        model.train()  # Set the model to training mode
        for images, masks in train_dataloader:
            images, masks = images.to(device), masks.to(device)
            loss = criterion(model(images), masks)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # Validation loop
        model.eval()  # Set the model to evaluation mode
        val_loss = 0.0
        with torch.no_grad():
            for images, masks in val_dataloader:
                images, masks = images.to(device), masks.to(device)
                val_loss += criterion(model(images), masks).item()

        val_loss /= len(val_dataloader)
        best_loss = min(best_loss, val_loss)

        # Report the running best (the quantity this function returns) so the
        # study's pruner can stop unpromising trials instead of letting them
        # run for up to 500 epochs.
        trial.report(best_loss, epoch)
        if trial.should_prune():
            raise optuna.TrialPruned()

    return best_loss

<h4>Initialize Optuna study and optimize</h4>

In [None]:
# Initialize Optuna study and optimize
study = optuna.create_study(direction='minimize')
study.optimize(objective, n_trials=20)

# Print and save the best hyperparameters
print("Best Hyperparameters:")
print(study.best_params)
print("Best Value (Validation Loss):")
print(study.best_value)

# Save the study results as a single JSON document. Two consecutive
# json.dump calls on the same file would produce output that no JSON parser
# can read back.
with open('optuna_study_results.json', 'w') as f:
    json.dump(
        {'best_params': study.best_params, 'best_value': study.best_value},
        f,
        indent=2,
    )

# `device` inside objective() is a local variable, so it has to be defined
# again here before the final model can be moved to it.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Initialize the final model with the best hyperparameters and retrain if desired
best_params = study.best_params
best_model = UNet().to(device)
best_optimizer = optim.Adam(best_model.parameters(), lr=best_params['learning_rate'])
best_criterion = nn.BCELoss()

# Retrain with best parameters
train_dataloader = DataLoader(train_dataset, batch_size=best_params['batch_size'], shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=best_params['batch_size'], shuffle=False)

for epoch in range(best_params['num_epochs']):
    best_model.train()
    epoch_train_loss = 0.0

    for images, masks in train_dataloader:
        images, masks = images.to(device), masks.to(device)
        outputs = best_model(images)
        loss = best_criterion(outputs, masks)

        best_optimizer.zero_grad()
        loss.backward()
        best_optimizer.step()

        epoch_train_loss += loss.item()

    epoch_train_loss /= len(train_dataloader)

    # Validation loop
    best_model.eval()
    val_loss = 0.0
    with torch.no_grad():
        for images, masks in val_dataloader:
            images, masks = images.to(device), masks.to(device)
            outputs = best_model(images)
            loss = best_criterion(outputs, masks)
            val_loss += loss.item()

    val_loss /= len(val_dataloader)
    print(f'Epoch [{epoch + 1}/{best_params["num_epochs"]}], Training Loss: {epoch_train_loss:.4f}, Validation Loss: {val_loss:.4f}')

# Save the final model. These are the weights after the last epoch, which is
# not necessarily the epoch with the lowest validation loss.
torch.save(best_model.state_dict(), 'best_model_with_optuna.pth')
