# Instant segmentation with U-Net

## 0. Utils

In [7]:
from torch.utils.data import Dataset, DataLoader
from albumentations.pytorch import ToTensorV2
from albumentations import (
    Resize,
    Compose,
    Normalize,
    HorizontalFlip,
    RandomBrightnessContrast,
)
import torch
import matplotlib.pyplot as plt
import cv2
from pathlib import Path
import os 
import numpy as np
from torchvision import transforms


## 1. Dataset creation

### 1.1 Dataset class

In [2]:
class SegmentationDataset(Dataset):
    """Paired image/mask dataset for single-channel segmentation.

    Images are the ``*.jpg`` files found in ``images_dir`` and masks are the
    ``*.png`` files found in ``masks_dir``. Both listings are sorted and then
    paired by position, so the two directories must hold the same number of
    files in matching sort order.
    """

    def __init__(self, images_dir, masks_dir, transform=None):
        """Index the image and mask files.

        Args:
            images_dir: Directory containing the ``*.jpg`` input images.
            masks_dir: Directory containing the ``*.png`` masks.
            transform: Optional callable invoked as
                ``transform(image=..., mask=...)`` and returning a mapping
                with ``"image"`` and ``"mask"`` entries.

        Raises:
            ValueError: If the number of images and masks differ, since
                positional pairing would then be wrong or incomplete.
        """
        self.images_dir = images_dir
        self.masks_dir = masks_dir
        self.image_paths = sorted(Path(images_dir).glob("*.jpg"))
        self.mask_paths = sorted(Path(masks_dir).glob("*.png"))
        self.transform = transform

        # Fail fast: a count mismatch would otherwise surface as an IndexError
        # in the middle of an epoch, or silently pair the wrong files.
        if len(self.image_paths) != len(self.mask_paths):
            raise ValueError(
                f"Found {len(self.image_paths)} images in {images_dir} but "
                f"{len(self.mask_paths)} masks in {masks_dir}; "
                "the counts must match."
            )

    def __len__(self):
        """Return the number of image/mask pairs."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Load, transform and return the ``(image, mask)`` pair at ``idx``.

        The image is read with OpenCV and converted from BGR to RGB. The mask
        is read as a single-channel grayscale image, given a leading channel
        dimension and divided by 255, so a 0/255 mask becomes 0.0/1.0.

        Without a ``transform`` the image is returned as loaded (a NumPy
        array); with one, it is whatever the transform produces.

        Raises:
            OSError: If the image or mask file cannot be decoded.
        """
        image_path = self.image_paths[idx]
        mask_path = self.mask_paths[idx]

        # cv2.imread signals failure by returning None instead of raising.
        image = cv2.imread(str(image_path))
        if image is None:
            raise OSError(f"Could not read image: {image_path}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        mask = cv2.imread(str(mask_path), cv2.IMREAD_GRAYSCALE)
        if mask is None:
            raise OSError(f"Could not read mask: {mask_path}")

        # Apply the same transformations to image and mask
        if self.transform:
            augmented = self.transform(image=image, mask=mask)
            image = augmented["image"]
            mask = augmented["mask"]

        # The mask may be a NumPy array (no tensor conversion in the
        # transform) or already a tensor; as_tensor accepts both without the
        # copy-construct warning torch.tensor() emits for tensor inputs.
        mask = torch.as_tensor(mask, dtype=torch.float32).unsqueeze(0) / 255.0

        return image, mask

### 1.2 Data augmentation during training

In [3]:
# Training pipeline: random augmentation first, then the deterministic
# preprocessing (normalize -> resize -> tensor conversion).
train_transform = Compose(
    transforms=[
        # Mirror left/right for half of the samples.
        HorizontalFlip(p=0.5),
        # Photometric jitter on 20% of the samples.
        RandomBrightnessContrast(p=0.2),
        # Per-channel standardization; these are the widely used ImageNet
        # mean/std values.
        Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
        # Fixed 512x512 network input.
        Resize(height=512, width=512),
        ToTensorV2(),
    ]
)

# Validation pipeline: deterministic preprocessing only, no random
# augmentation (normalize -> resize -> tensor conversion).
val_transform = Compose(
    transforms=[
        # Per-channel standardization; these are the widely used ImageNet
        # mean/std values.
        Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
        # Fixed 512x512 network input.
        Resize(height=512, width=512),
        ToTensorV2(),
    ]
)

## 2. Training

Run the following cells if you want to train the model; otherwise, go to Section 3.

Data Loader

In [4]:
# Dataset locations, all laid out as
# ../data/model_training/split/<train|val|test>/<images|masks>.
train_images_dir = "../data/model_training/split/train/images"
train_masks_dir = "../data/model_training/split/train/masks"
val_images_dir = "../data/model_training/split/val/images"
val_masks_dir = "../data/model_training/split/val/masks"
test_images_dir = "../data/model_training/split/test/images"
# Previously "../data/split/model_training/test/masks": the path components
# were swapped relative to every other entry above.
test_masks_dir = "../data/model_training/split/test/masks"

In [5]:
# Datasets: each split gets its own transform pipeline.
train_dataset = SegmentationDataset(
    images_dir=train_images_dir,
    masks_dir=train_masks_dir,
    transform=train_transform,
)
val_dataset = SegmentationDataset(
    images_dir=val_images_dir,
    masks_dir=val_masks_dir,
    transform=val_transform,
)

# Loaders: batches of 8; only the training split is shuffled.
train_loader = DataLoader(dataset=train_dataset, batch_size=8, shuffle=True)
val_loader = DataLoader(dataset=val_dataset, batch_size=8, shuffle=False)

Model definition

In [None]:
%pip install segmentation_models_pytorch

In [6]:
from segmentation_models_pytorch import Unet

# U-Net with a ResNet-34 encoder initialised from ImageNet weights.
# It takes 3-channel (RGB) input and produces a single output channel,
# i.e. one logit map for binary segmentation.
model = Unet(
    in_channels=3,
    classes=1,
    encoder_name="resnet34",
    encoder_weights="imagenet",
)

### 2.1 Training function

In [None]:
from torch.optim import Adam
from torch.nn import BCEWithLogitsLoss

# Release cached GPU memory left over from earlier cells, then report what
# hardware is visible.
torch.cuda.empty_cache()
print(torch.cuda.is_available())
print(torch.cuda.device_count())

# Pick the device (GPU when available) and print its name.
if torch.cuda.is_available():
    print(torch.cuda.get_device_name(0))
    device = "cuda"
else:
    print("No GPU")
    device = "cpu"
model = model.to(device)

# Binary cross-entropy on raw logits, optimised with Adam.
criterion = BCEWithLogitsLoss()
optimizer = Adam(params=model.parameters(), lr=0.001)

def save_checkpoint(model, optimizer, epoch, path="../checkpoints/test_best_unet_model.pth"):
    """Save a resumable training checkpoint to ``path``.

    The checkpoint is a dict with three entries: ``"model_state_dict"``,
    ``"optimizer_state_dict"`` and ``"epoch"``. The parent directory of
    ``path`` is created if it does not exist yet.

    NOTE: the default ``path`` is the same file the training loop writes the
    bare best-model ``state_dict`` to, and which Section 3 loads with
    ``load_state_dict``. Pass an explicit ``path`` to avoid overwriting that
    file with the dict format written here.

    Args:
        model (torch.nn.Module): Model whose weights are saved.
        optimizer (torch.optim.Optimizer): Optimizer whose state is saved.
        epoch (int): Epoch number stored in the checkpoint and shown in the log.
        path (str): Destination file.
    """
    # torch.save does not create missing directories.
    parent_dir = os.path.dirname(path)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)

    checkpoint = {
        "model_state_dict": model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
        "epoch": epoch,
    }
    torch.save(checkpoint, path)
    print(f"Checkpoint saved at epoch {epoch} to {path}")

# The best-model file and the periodic checkpoints are both written here.
# Create the directory up front so a missing folder cannot abort the run
# after a full epoch of training.
os.makedirs("../checkpoints", exist_ok=True)

best_val_loss = float("inf")
num_epochs = 10
for epoch in range(num_epochs):
    # ---- Training ----
    model.train()
    train_loss = 0.0
    for images, masks in train_loader:
        images, masks = images.to(device), masks.to(device)

        # Forward pass
        outputs = model(images)
        loss = criterion(outputs, masks)

        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        train_loss += loss.item()

    # Report the mean loss per batch rather than the sum, so the number does
    # not scale with dataset size and is comparable with the validation loss.
    train_loss /= max(len(train_loader), 1)
    print(f"Epoch [{epoch+1}/{num_epochs}], Training Loss: {train_loss:.4f}")

    # ---- Validation ----
    model.eval()
    val_loss = 0.0
    with torch.no_grad():
        for images, masks in val_loader:
            images, masks = images.to(device), masks.to(device)
            outputs = model(images)
            loss = criterion(outputs, masks)
            val_loss += loss.item()

    val_loss /= max(len(val_loader), 1)
    print(f"Epoch [{epoch+1}/{num_epochs}], Validation Loss: {val_loss:.4f}")

    # Keep only the weights of the epoch with the lowest validation loss.
    # This file holds a bare state_dict, the format load_state_dict expects.
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        torch.save(model.state_dict(), "../checkpoints/test_best_unet_model.pth")
        print("Best model saved!")

    # Save periodic checkpoints (model + optimizer + epoch).
    # `epoch` is 0-based, so this fires after the 1st, 6th, ... epoch.
    if epoch % 5 == 0:
        save_checkpoint(model, optimizer, epoch, path=f"../checkpoints/unet_checkpoint_epoch_{epoch}.pth")


### 2.2 Visualize results

In [9]:
import random

def visualize_test_results(
    model,
    dataloader,
    num_samples=10,
    mean=(0.485, 0.456, 0.406),
    std=(0.229, 0.224, 0.225),
):
    """
    Visualizes random test results by displaying original images, true masks, and predicted masks.

    Samples are drawn without replacement from ``dataloader.dataset``. The
    loader's batching and ordering are not used. Each sample gets one figure
    with three panels: input image, ground-truth mask and the prediction
    thresholded at 0.5 after a sigmoid.

    Parameters:
        model (torch.nn.Module): Trained model for inference.
        dataloader (DataLoader): DataLoader containing test dataset. Its
            dataset must return ``(image, mask)`` tensors with the image in
            channel-first layout.
        num_samples (int): Number of random samples to visualize (capped at
            the dataset size).
        mean (sequence of float or None): Per-channel mean used to normalize
            the images, needed to undo the normalization for display. The
            default matches the ``Normalize`` step of this notebook's
            transforms. Pass ``None`` to show the image tensor unchanged.
        std (sequence of float or None): Per-channel std used to normalize
            the images. Pass ``None`` to show the image tensor unchanged.
    """
    model.eval()

    # Run inference on whatever device the model already lives on, instead of
    # relying on a global `device` variable being defined and in sync.
    first_param = next(model.parameters(), None)
    model_device = first_param.device if first_param is not None else torch.device("cpu")

    dataset = dataloader.dataset

    # Randomly sample dataset indices (without replacement)
    selected_indices = random.sample(
        range(len(dataset)), min(num_samples, len(dataset))
    )

    denormalize = mean is not None and std is not None
    if denormalize:
        # Shaped (C, 1, 1) so they broadcast over a channel-first image.
        mean_t = torch.tensor(mean, dtype=torch.float32).view(-1, 1, 1)
        std_t = torch.tensor(std, dtype=torch.float32).view(-1, 1, 1)

    with torch.no_grad():
        for idx in selected_indices:
            # Retrieve the image and mask at the selected index
            image, mask = dataset[idx]

            # Add batch dimension and move to the model's device
            image_tensor = image.unsqueeze(0).to(model_device)

            # Model prediction: logits -> probabilities -> binary mask
            output = model(image_tensor)
            pred = (torch.sigmoid(output) > 0.5).float()

            # The dataset returns normalized images, whose values fall outside
            # matplotlib's displayable [0, 1] range; undo the normalization so
            # the panel shows true colors instead of clipped ones.
            display_image = image.detach().cpu()
            if denormalize:
                display_image = (display_image.float() * std_t + mean_t).clamp(0.0, 1.0)

            # Visualize
            plt.figure(figsize=(12, 6))

            # Original image
            plt.subplot(1, 3, 1)
            plt.imshow(display_image.permute(1, 2, 0).numpy())
            plt.title("Original Image")
            plt.axis("off")

            # True mask
            plt.subplot(1, 3, 2)
            plt.imshow(mask.squeeze().cpu().numpy(), cmap="gray")
            plt.title("True Mask")
            plt.axis("off")

            # Predicted mask
            plt.subplot(1, 3, 3)
            plt.imshow(pred.squeeze().cpu().numpy(), cmap="gray")
            plt.title("Predicted Mask")
            plt.axis("off")

            plt.tight_layout()
            plt.show()

## 3. Load checkpoint

Load the pre-saved checkpoint

In [None]:
from segmentation_models_pytorch import Unet

device = "cuda" if torch.cuda.is_available() else "cpu"

# Rebuild the architecture used for training. encoder_weights=None skips
# fetching the pretrained ImageNet encoder weights (the library default),
# which would be overwritten by the checkpoint on the next line anyway.
model = Unet(encoder_name="resnet34", encoder_weights=None, in_channels=3, classes=1).to(device)

# The file holds a bare state_dict; map_location lets a checkpoint saved on
# GPU load on a CPU-only machine.
model.load_state_dict(torch.load("../checkpoints/test_best_unet_model.pth", map_location=device))

# Inference mode
model.eval()

# Inference-time preprocessing: deterministic steps only
# (normalize -> resize -> tensor conversion), no random augmentation.
test_transform = Compose(
    transforms=[
        # Per-channel standardization; these are the widely used ImageNet
        # mean/std values.
        Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
        # Fixed 512x512 network input.
        Resize(height=512, width=512),
        ToTensorV2(),
    ]
)

## 4. Predict on Images

In [46]:
def predict_and_save_masks(input_path, model, transform, output_dir, device=device):
    """
    Predict masks for a single image or all images in a directory and save them.

    Each image is read with OpenCV, converted from BGR to RGB (the channel
    order the training dataset feeds the model), preprocessed with
    ``transform`` and passed through the model. The sigmoid output is
    thresholded at 0.5 and written as a 0/255 PNG named ``<stem>_mask.png``.
    The mask is saved at the resolution the model outputs, not resized back
    to the source image size.

    Args:
        input_path (str): Path to a single image or a directory containing images.
            For a directory, every entry matching ``*.*`` is tried and
            unreadable ones are skipped with a message.
        model (torch.nn.Module): Trained PyTorch model.
        transform (albumentations.Compose): Transformations for preprocessing.
            Called as ``transform(image=...)``; its ``"image"`` result must be
            a tensor without a batch dimension.
        output_dir (str): Path to save the predicted masks (created if missing).
        device (str): Device to run the prediction on ("cuda" or "cpu").

    Returns:
        None (saves all predicted masks to the specified output directory).

    Raises:
        ValueError: If ``input_path`` is neither a file nor a directory.
    """
    model.eval()
    model.to(device)

    os.makedirs(output_dir, exist_ok=True)

    input_path = Path(input_path)
    if input_path.is_file():
        image_paths = [input_path]
    elif input_path.is_dir():
        # Sorted so the processing/log order is deterministic.
        image_paths = sorted(input_path.glob("*.*"))
    else:
        raise ValueError(f"Input path {input_path} is neither a file nor a directory.")

    for image_path in image_paths:
        # cv2.imread returns None for anything it cannot decode.
        image = cv2.imread(str(image_path))
        if image is None:
            print(f"Could not read image {image_path}. Skipping.")
            continue

        # OpenCV loads BGR, but the model was trained on RGB images; without
        # this conversion the red and blue channels are swapped at inference.
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        # Apply transformations
        transformed = transform(image=image)
        input_image = transformed["image"].unsqueeze(0).to(device)  # Add batch dimension

        # Predict the mask: drop the batch and channel dimensions
        with torch.no_grad():
            output = model(input_image)
            predicted_mask = torch.sigmoid(output).squeeze(0).squeeze(0).cpu().numpy()

        # Binarize at 0.5 and scale to 0/255 for saving as an 8-bit image
        predicted_mask = (predicted_mask > 0.5).astype(np.uint8) * 255

        mask_name = f"{image_path.stem}_mask.png"
        mask_path = os.path.join(output_dir, mask_name)
        cv2.imwrite(mask_path, predicted_mask)

        print(f"Mask saved for {image_path} at {mask_path}")


### Single image

Replace `image_path` with the path of the single image you want to predict.\
Replace `output_dir` with the output path.

In [None]:
image_path = "CONFIGURE_PATH_TO_IMAGE" 
output_dir = "CONFIGURE_OUTPUT_DIR"            
predict_and_save_masks(image_path, model, test_transform, output_dir, device=device)


### Multiple images

Replace `images_dir` with the path of the directory containing the images you want to predict.\
Replace `output_dir` with the output path.

In [None]:
images_dir = "CONFIGURE_PATH_TO_IMAGES" 
output_dir = "CONFIGURE_OUTPUT_DIR" 

predict_and_save_masks(images_dir, model, test_transform, output_dir, device=device)
