# U-Net Training for Simplified Bike Detection with On-the-Fly Data Generation

## 1. Introduction

This notebook demonstrates how to train a U-Net model to segment simplified drawings of bicycles. The key features are:

*   **Synthetic Data:** We generate training data programmatically. Each sample consists of:
    *   An input image: A simple urban background sketch (street, buildings, lampposts), a distracting sun sketch, and a simplified bike sketch (2 circles, lines for frame).
    *   A mask image: A binary mask showing *only* the pixels belonging to the bike.
*   **On-the-Fly Generation:** Data is generated in batches as needed during training, so no large dataset has to be stored. This also provides a virtually unlimited supply of unique training examples.
*   **Variability:** The bike's position, size, and color (grayscale intensity) vary. The sun's position and brightness also vary. The background details change slightly.
*   **PyTorch Implementation:** We use PyTorch to define the U-Net architecture and the training loop.
*   **Goal:** Train the U-Net to accurately segment the bike, ignoring the background clutter and the sun.
*   **Inference & Visualization:** After training, we demonstrate inference by predicting the mask for a new image and drawing a red bounding box around the detected bike.

## 2. Imports

Import necessary libraries.

In [None]:
import os
import random
import math
import time

import numpy as np
from PIL import Image, ImageDraw, ImageFilter
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms.functional as TF
from tqdm.notebook import tqdm # Progress bar

## 3. Configuration

Define key parameters for data generation and training.

In [None]:
# Data Generation Config
IMG_SIZE = 128 # Keep it small for faster training initially (e.g., 128x128)
MIN_BIKE_INTENSITY = 150 # Min grayscale value for the bike (0-255)
MAX_BIKE_INTENSITY = 255 # Max grayscale value for the bike
MIN_SUN_INTENSITY = 180
MAX_SUN_INTENSITY = 255
ADD_NOISE_PROB = 0.3 # Probability to add salt & pepper noise to input image
NOISE_AMOUNT = 0.01

# Training Config
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
BATCH_SIZE = 16 # Adjust based on your GPU memory
LEARNING_RATE = 1e-4
NUM_EPOCHS = 20 # Start with a moderate number, increase if needed
STEPS_PER_EPOCH = 200 # Number of batches per epoch (since data is generated on-the-fly)
VALIDATION_STEPS = 50 # Number of batches for validation check per epoch

print(f"Using device: {DEVICE}")
print(f"Image Size: {IMG_SIZE}x{IMG_SIZE}")
print(f"Batch Size: {BATCH_SIZE}")
print(f"Steps per Epoch: {STEPS_PER_EPOCH}")

## 4. Data Generation Functions

These functions create the synthetic images and masks.

In [None]:
def add_noise(image_array, amount=0.02):
    """Adds salt and pepper noise to a numpy image array."""
    noisy_image = np.copy(image_array)
    num_pixels = image_array.size
    # Salt noise (np.random.randint excludes the upper bound, so pass i rather than i - 1
    # to allow the last row and column to be hit as well)
    num_salt = int(np.ceil(amount * num_pixels * 0.5))
    coords = tuple(np.random.randint(0, i, num_salt) for i in image_array.shape)
    noisy_image[coords] = 255
    # Pepper noise
    num_pepper = int(np.ceil(amount * num_pixels * 0.5))
    coords = tuple(np.random.randint(0, i, num_pepper) for i in image_array.shape)
    noisy_image[coords] = 0
    return noisy_image

def draw_urban_background(draw, width, height, street_y):
    """Draws a simple urban background sketch."""
    line_color = random.randint(20, 75) # Varying shades of grey for background
    line_thickness = 1

    # Street level
    draw.line([(0, street_y), (width, street_y)], fill=line_color, width=line_thickness + 1)
    trottoir_y = street_y + random.randint(3, 6)
    if trottoir_y < height:
       draw.line([(0, trottoir_y), (width, trottoir_y)], fill=line_color, width=line_thickness)

    # Buildings
    current_x = 0
    min_building_width = max(15, int(width * 0.1))
    max_building_width = max(30, int(width * 0.3))

    while current_x < width - min_building_width:
        line_color = random.randint(50, 120) # Outline shade for this building
        fill_color = random.randint(45, line_color) # Fill shade, never brighter than the outline
        building_width = random.randint(min_building_width, max_building_width)
        building_height = random.randint(int(height*0.2), street_y - 10) # Ensure below top & above street
        building_top = street_y - building_height
        draw.rectangle(
            [(current_x, building_top), (current_x + building_width, street_y)],
            outline=line_color, width=line_thickness, fill=fill_color
        )
        # Simple windows (optional)
        if random.random() < 0.7:
            win_size = max(2, int(building_width * 0.1))
            for wx in range(current_x + 5, current_x + building_width - win_size - 2, win_size*2 + 3):
                 for wy in range(building_top + 5, street_y - win_size - 2, win_size*2 + 3):
                      if wy + win_size < street_y and wx + win_size < current_x + building_width:
                           draw.rectangle([(wx, wy), (wx + win_size, wy + win_size)], outline=line_color, width=1)

        current_x += building_width + random.randint(0, 5)

    # Lampposts
    num_lampposts = random.randint(1, 4)
    lamp_height = random.randint(int(height*0.15), int(height*0.3))
    lamp_bottom = street_y + 2
    for _ in range(num_lampposts):
        lamp_x = random.randint(10, width - 10)
        lamp_top = max(5, lamp_bottom - lamp_height)
        draw.line([(lamp_x, lamp_bottom), (lamp_x, lamp_top)], fill=20, width=line_thickness)
        draw.ellipse([(lamp_x - 2, lamp_top - 3), (lamp_x + 2, lamp_top)], fill=line_color)
    

def draw_sun(draw, width, height, street_y):
    """Draws a simple sun sketch in the sky."""
    sun_color = random.randint(MIN_SUN_INTENSITY, MAX_SUN_INTENSITY)
    sun_radius = random.randint(max(5, int(width*0.03)), max(15, int(width*0.08)))
    # Position in the sky (above buildings/street)
    max_y_for_sun = street_y - int(height * 0.1) # Ensure some space above street/buildings
    if max_y_for_sun <= sun_radius * 2:
        max_y_for_sun = sun_radius * 2 + 5 # Avoid edge case

    sun_cx = random.randint(sun_radius, width - sun_radius)
    sun_cy = random.randint(sun_radius, max(sun_radius + 1, max_y_for_sun - sun_radius)) # Place in upper part

    # Draw the sun as an outlined circle, ONLY on the input image drawing context (never on the mask)
    draw.ellipse(
        (sun_cx - sun_radius, sun_cy - sun_radius, sun_cx + sun_radius, sun_cy + sun_radius),
        outline=sun_color
    )

def draw_bike_simplified(draw_sketch, draw_mask, center_x, center_y, wheel_radius, bike_color, line_thickness=1):
    """Draws the simplified bike on the sketch (input) and mask (target)."""
    mask_color = 255 # Mask is always white for the bike
    mask_line_thickness = max(2, line_thickness * 2, wheel_radius // 4) # Thicker for mask coverage

    # Simplified fixed geometry relative to wheel radius
    wheel_distance = int(wheel_radius * 3.2)
    rear_wheel_cx = center_x - wheel_distance // 2
    rear_wheel_cy = center_y
    front_wheel_cx = center_x + wheel_distance // 2
    front_wheel_cy = center_y

    A = (rear_wheel_cx, rear_wheel_cy) # Rear wheel center
    E = (front_wheel_cx, front_wheel_cy) # Front wheel center
    B = (center_x, center_y + wheel_radius * 0.3) # Bottom bracket approx
    C = (rear_wheel_cx + wheel_radius * 0.4, center_y - wheel_radius * 1.8) # Seat approx
    D = (front_wheel_cx - wheel_radius * 0.2, center_y - wheel_radius * 1.5) # Handlebar approx

    # --- Draw on Sketch (Input Image) ---
    # Wheels (thin outline)
    draw_sketch.ellipse((A[0]-wheel_radius, A[1]-wheel_radius, A[0]+wheel_radius, A[1]+wheel_radius), outline=bike_color, width=line_thickness)
    draw_sketch.ellipse((E[0]-wheel_radius, E[1]-wheel_radius, E[0]+wheel_radius, E[1]+wheel_radius), outline=bike_color, width=line_thickness)
    # Frame (thin lines)
    draw_sketch.line([A, B], fill=bike_color, width=line_thickness) # Chainstay
    draw_sketch.line([B, C], fill=bike_color, width=line_thickness) # Seat tube
    draw_sketch.line([C, A], fill=bike_color, width=line_thickness) # Seat stay
    draw_sketch.line([C, D], fill=bike_color, width=line_thickness) # Top tube
    draw_sketch.line([B, D], fill=bike_color, width=line_thickness) # Down tube
    draw_sketch.line([D, E], fill=bike_color, width=line_thickness) # Fork

    # --- Draw on Mask (Target Image) ---
    # Wheels (filled circles)
    draw_mask.ellipse((A[0]-wheel_radius, A[1]-wheel_radius, A[0]+wheel_radius, A[1]+wheel_radius), fill=mask_color)
    draw_mask.ellipse((E[0]-wheel_radius, E[1]-wheel_radius, E[0]+wheel_radius, E[1]+wheel_radius), fill=mask_color)
    # Frame (thick lines for coverage)
    draw_mask.line([A, B], fill=mask_color, width=mask_line_thickness)
    draw_mask.line([B, C], fill=mask_color, width=mask_line_thickness)
    draw_mask.line([C, A], fill=mask_color, width=mask_line_thickness)
    draw_mask.line([C, D], fill=mask_color, width=mask_line_thickness)
    draw_mask.line([B, D], fill=mask_color, width=mask_line_thickness)
    draw_mask.line([D, E], fill=mask_color, width=mask_line_thickness)


def generate_bike_sample(img_size):
    """Generates one pair of (input_image, mask_image) as NumPy arrays."""
    width, height = img_size, img_size
    input_img = Image.new('L', (width, height), 0) # Black background
    mask_img = Image.new('L', (width, height), 0)  # Black background for mask

    draw_input = ImageDraw.Draw(input_img)
    draw_mask = ImageDraw.Draw(mask_img)

    street_y = int(height * random.uniform(0.7, 0.85))

    # 1. Draw Sun first (only on input image), so that buildings drawn later can overlap it
    draw_sun(draw_input, width, height, street_y)

    # 2. Draw Background elements (street, buildings, etc.)
    draw_urban_background(draw_input, width, height, street_y)

    

    # 3. Draw Bike (on input image and mask)
    # Bike parameters
    min_radius = max(5, int(width * 0.04))
    max_radius = max(10, int(width * 0.1))
    wheel_radius = random.randint(min_radius, max_radius)
    bike_color = random.randint(MIN_BIKE_INTENSITY, MAX_BIKE_INTENSITY)

    # Bike position (mostly horizontal variation, vertically near street)
    bike_width_approx = wheel_radius * 3.2 + 2 * wheel_radius
    margin_x = int(bike_width_approx / 2) + 5
    bike_center_x = random.randint(margin_x, width - margin_x)
    # Adjust y so wheels are near the generated street level
    bike_center_y = street_y - random.randint(0, wheel_radius // 2) # Slightly above or on the line

    draw_bike_simplified(draw_input, draw_mask, bike_center_x, bike_center_y, wheel_radius, bike_color)

    # Convert to NumPy arrays
    input_array = np.array(input_img)
    mask_array = np.array(mask_img)

    # 4. Optional Noise (only on input image)
    if random.random() < ADD_NOISE_PROB:
        input_array = add_noise(input_array, amount=NOISE_AMOUNT)

    return input_array, mask_array

# Let's test the generator and visualize one sample
test_img_np, test_mask_np = generate_bike_sample(IMG_SIZE)

fig, ax = plt.subplots(1, 2, figsize=(8, 4))
ax[0].imshow(test_img_np, cmap='gray')
ax[0].set_title('Generated Input Image')
ax[0].axis('off')
ax[1].imshow(test_mask_np, cmap='gray')
ax[1].set_title('Generated Mask')
ax[1].axis('off')
plt.tight_layout()
plt.show()

print(f"Input shape: {test_img_np.shape}, Mask shape: {test_mask_np.shape}")
print(f"Input dtype: {test_img_np.dtype}, Mask dtype: {test_mask_np.dtype}")
print(f"Input min/max: {np.min(test_img_np)}/{np.max(test_img_np)}")
print(f"Mask min/max: {np.min(test_mask_np)}/{np.max(test_mask_np)}")
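
As an alternative to sampling explicit pixel coordinates, salt-and-pepper noise can also be drawn as a per-pixel random mask. The sketch below is not the notebook's `add_noise`; the name `salt_and_pepper` and the use of `np.random.default_rng` are assumptions made for illustration. Each pixel is corrupted independently with probability `amount`, so the number of noisy pixels varies slightly from call to call.

```python
import numpy as np

def salt_and_pepper(image_array, amount=0.01, rng=None):
    """Return a copy with roughly `amount` of the pixels set to 0 or 255.

    Hypothetical helper: every pixel is corrupted independently, so the
    count of noisy pixels is random rather than fixed.
    """
    rng = np.random.default_rng() if rng is None else rng
    noisy = image_array.copy()
    u = rng.random(image_array.shape)      # uniform values in [0, 1)
    noisy[u < amount / 2] = 255            # salt: lowest amount/2 of the draws
    noisy[u > 1 - amount / 2] = 0          # pepper: highest amount/2 of the draws
    return noisy

# Usage on a mid-grey test image
img = np.full((128, 128), 100, dtype=np.uint8)
noisy = salt_and_pepper(img, amount=0.01, rng=np.random.default_rng(0))
changed = np.mean(noisy != img)            # fraction of corrupted pixels
```

Passing an explicit `rng` keeps the noise reproducible without touching NumPy's global random state.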

## 5. PyTorch Dataset and DataLoader

We create a custom `Dataset` that uses our `generate_bike_sample` function. The `DataLoader` will then handle batching.

In [None]:
class BikeDataset(Dataset):
    """PyTorch Dataset for generating bike images and masks on-the-fly."""
    def __init__(self, img_size, num_samples):
        self.img_size = img_size
        self.num_samples = num_samples # Effectively, steps per epoch

    def __len__(self):
        # Length is the number of samples we want to generate per epoch
        return self.num_samples

    def __getitem__(self, idx):
        # Generate a new sample each time __getitem__ is called
        input_np, mask_np = generate_bike_sample(self.img_size)

        # Convert NumPy arrays to PyTorch tensors
        # Input: Add channel dimension (C, H, W) and normalize to [0, 1]
        input_tensor = torch.from_numpy(input_np).float().unsqueeze(0) / 255.0

        # Mask: Add channel dimension and normalize to [0, 1] (for BCE loss)
        mask_tensor = torch.from_numpy(mask_np).float().unsqueeze(0) / 255.0

        return input_tensor, mask_tensor

# Create DataLoaders for training and validation
# For validation, samples are also generated on-the-fly, so the validation set differs between epochs
num_workers = 0 # Alternative: os.cpu_count() // 2
train_dataset = BikeDataset(IMG_SIZE, STEPS_PER_EPOCH * BATCH_SIZE)
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=num_workers) # Shuffling only permutes indices; each sample is freshly generated regardless of idx


val_dataset = BikeDataset(IMG_SIZE, VALIDATION_STEPS * BATCH_SIZE)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=num_workers)

# Let's check one batch
try:
    first_batch_img, first_batch_mask = next(iter(train_loader))
    print("Successfully loaded one batch.")
    print(f"Image batch shape: {first_batch_img.shape}") # Should be [BATCH_SIZE, 1, IMG_SIZE, IMG_SIZE]
    print(f"Mask batch shape: {first_batch_mask.shape}")
    print(f"Image batch dtype: {first_batch_img.dtype}") # Should be torch.float32
    print(f"Image batch min/max: {first_batch_img.min():.2f}/{first_batch_img.max():.2f}") # Should be ~0.0/1.0
    print(f"Mask batch min/max: {first_batch_mask.min():.2f}/{first_batch_mask.max():.2f}") # Should be 0.0/1.0
except Exception as e:
    print(f"Error loading batch: {e}")
    print("Check num_workers or data generation logic if issues persist.")
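
Because every `__getitem__` call draws fresh random numbers, the validation samples change from epoch to epoch, which adds some noise to the validation curve. One possible remedy, sketched here under assumptions, is to seed the generators from the sample index. `seeded_call`, `toy_sample`, and `base_seed` are hypothetical names, and `toy_sample` merely stands in for `generate_bike_sample` so that the example runs on its own.

```python
import random
import numpy as np

def toy_sample(img_size):
    """Hypothetical stand-in for generate_bike_sample: noise image + square mask."""
    img = np.random.randint(0, 256, (img_size, img_size), dtype=np.uint8)
    mask = np.zeros((img_size, img_size), dtype=np.uint8)
    x = random.randint(0, img_size - 8)
    y = random.randint(0, img_size - 8)
    mask[y:y + 8, x:x + 8] = 255
    return img, mask

def seeded_call(sample_fn, seed, *args):
    """Call sample_fn with `random` and `np.random` seeded, then restore their state."""
    py_state = random.getstate()
    np_state = np.random.get_state()
    random.seed(seed)
    np.random.seed(seed)
    try:
        return sample_fn(*args)
    finally:
        # Restore the global generators so training randomness is unaffected
        random.setstate(py_state)
        np.random.set_state(np_state)

base_seed = 1234
img_a, mask_a = seeded_call(toy_sample, base_seed + 3, 32)
img_b, mask_b = seeded_call(toy_sample, base_seed + 3, 32)   # same index, same sample
img_c, mask_c = seeded_call(toy_sample, base_seed + 4, 32)   # other index, other sample
```

Inside a validation dataset, `__getitem__(self, idx)` could call `seeded_call(generate_bike_sample, base_seed + idx, self.img_size)` so that a given index always yields the same image, while the training dataset keeps drawing new samples.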

## 6. U-Net Model Definition

Define the U-Net architecture using standard PyTorch modules.

In [None]:
class DoubleConv(nn.Module):
    """(convolution => [BN] => ReLU) * 2"""
    def __init__(self, in_channels, out_channels, mid_channels=None):
        super().__init__()
        if not mid_channels:
            mid_channels = out_channels
        self.double_conv = nn.Sequential(
            nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(mid_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        return self.double_conv(x)

class Down(nn.Module):
    """Downscaling with maxpool then double conv"""
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.maxpool_conv = nn.Sequential(
            nn.MaxPool2d(2),
            DoubleConv(in_channels, out_channels)
        )

    def forward(self, x):
        return self.maxpool_conv(x)

class Up(nn.Module):
    """Upscaling then double conv"""
    def __init__(self, in_channels, out_channels, bilinear=True):
        super().__init__()
        # if bilinear, use the normal convolutions to reduce the number of channels
        if bilinear:
            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
            self.conv = DoubleConv(in_channels, out_channels, in_channels // 2)
        else:
            self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
            self.conv = DoubleConv(in_channels, out_channels) # ConvTranspose halves the channels; concatenating the skip connection brings them back to in_channels

    def forward(self, x1, x2):
        # x1 comes from the decoder path (upsampled here), x2 is the skip connection
        x1 = self.up(x1)
        # tensors are NCHW, so dims 2 and 3 are height and width
        diffY = x2.size()[2] - x1.size()[2]
        diffX = x2.size()[3] - x1.size()[3]

        # Pad x1 to match x2's dimensions if needed
        x1 = nn.functional.pad(x1, [diffX // 2, diffX - diffX // 2,
                                    diffY // 2, diffY - diffY // 2])
        # if you have padding issues, see
        # https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a
        # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd

        x = torch.cat([x2, x1], dim=1) # Concatenate along channel dimension
        return self.conv(x)

class OutConv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(OutConv, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        return self.conv(x)

class UNet(nn.Module):
    def __init__(self, n_channels, n_classes, bilinear=True):
        super(UNet, self).__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.bilinear = bilinear

        self.inc = DoubleConv(n_channels, 64)
        self.down1 = Down(64, 128)
        self.down2 = Down(128, 256)
        self.down3 = Down(256, 512)
        factor = 2 if bilinear else 1
        self.down4 = Down(512, 1024 // factor) # Adjusted for factor
        self.up1 = Up(1024, 512 // factor, bilinear) # Adjusted for factor
        self.up2 = Up(512, 256 // factor, bilinear) # Adjusted for factor
        self.up3 = Up(256, 128 // factor, bilinear) # Adjusted for factor
        self.up4 = Up(128, 64, bilinear)
        self.outc = OutConv(64, n_classes)

    def forward(self, x):
        x1 = self.inc(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)
        x = self.up1(x5, x4)
        x = self.up2(x, x3)
        x = self.up3(x, x2)
        x = self.up4(x, x1)
        logits = self.outc(x)
        return logits # Return logits for BCEWithLogitsLoss

# Instantiate the model
# n_channels=1 (grayscale), n_classes=1 (bike or not bike)
model = UNet(n_channels=1, n_classes=1).to(DEVICE)

# Test with a dummy input
dummy_input = torch.randn(BATCH_SIZE, 1, IMG_SIZE, IMG_SIZE).to(DEVICE)
try:
    output = model(dummy_input)
    print(f"Model output shape: {output.shape}") # Should be [BATCH_SIZE, 1, IMG_SIZE, IMG_SIZE]
    print("Model definition seems okay.")
except Exception as e:
    print(f"Error during model forward pass test: {e}")

# Count model parameters
total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Total trainable parameters: {total_params:,}")
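
The network above halves the spatial resolution four times (`down1` to `down4`, each starting with `MaxPool2d(2)`), and every `Up` block doubles it again. The following sketch traces those sizes in plain Python to show why an input size divisible by 16, such as 128, needs no padding in `Up.forward`, while other sizes do. The helper name `unet_spatial_sizes` is hypothetical.

```python
def unet_spatial_sizes(size, depth=4):
    """Trace the side length through `depth` 2x2 max-pools and the matching upsamplings.

    Returns (encoder_sizes, pads), where pads[i] is the padding Up.forward
    would add to match the i-th skip connection (deepest first).
    """
    encoder = [size]
    for _ in range(depth):
        encoder.append(encoder[-1] // 2)   # MaxPool2d(2) floors odd sizes
    pads = []
    current = encoder[-1]
    for skip in reversed(encoder[:-1]):
        upsampled = current * 2            # upsampling by a factor of 2
        pads.append(skip - upsampled)      # corresponds to diffY / diffX in Up.forward
        current = skip                     # after padding, the size matches the skip
    return encoder, pads

print(unet_spatial_sizes(128))  # ([128, 64, 32, 16, 8], [0, 0, 0, 0])
print(unet_spatial_sizes(100))  # ([100, 50, 25, 12, 6], [0, 1, 0, 0])
```

For a 100-pixel input, the 25-pixel feature map is floored to 12 by pooling, so the upsampled 24-pixel map has to be padded by one pixel before concatenation.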

## 7. Training Setup

Define the loss function and optimizer. We use `BCEWithLogitsLoss`, which is suitable for binary segmentation and expects raw logits from the model.

In [None]:
# Loss function
# BCEWithLogitsLoss combines Sigmoid layer and BCELoss in one single class.
# It's more numerically stable than using a plain Sigmoid followed by BCELoss.
criterion = nn.BCEWithLogitsLoss()

# Optimizer
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

# Learning rate scheduler (optional, but can help)
# Reduce LR if validation loss plateaus
# scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=2, factor=0.1)
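
To illustrate the stability remark, the per-pixel loss for a logit `x` and target `z` can be written as `max(x, 0) - x*z + log(1 + exp(-|x|))`, which never exponentiates a large positive number. The NumPy re-implementation below is for illustration only; the function names are hypothetical and this is not PyTorch's internal code.

```python
import numpy as np

def bce_with_logits(logits, targets):
    """Mean binary cross-entropy computed directly from logits (stable form)."""
    x = np.asarray(logits, dtype=np.float64)
    z = np.asarray(targets, dtype=np.float64)
    # max(x, 0) - x*z + log(1 + exp(-|x|)): the exponent is always <= 0
    return np.mean(np.maximum(x, 0) - x * z + np.log1p(np.exp(-np.abs(x))))

def bce_naive(logits, targets):
    """Sigmoid followed by BCE; breaks down when the sigmoid saturates to 0 or 1."""
    x = np.asarray(logits, dtype=np.float64)
    z = np.asarray(targets, dtype=np.float64)
    p = 1.0 / (1.0 + np.exp(-x))
    with np.errstate(divide="ignore", invalid="ignore"):
        return np.mean(-(z * np.log(p) + (1 - z) * np.log(1 - p)))

moderate = bce_with_logits([2.0, -1.0], [1.0, 0.0])
extreme_stable = bce_with_logits([50.0, -50.0], [0.0, 1.0])   # confidently wrong predictions
extreme_naive = bce_naive([50.0, -50.0], [0.0, 1.0])
```

For moderate logits both versions agree. For the confidently wrong predictions, the stable form returns a loss of about 50, whereas the naive version evaluates `log(0)` and yields `inf` in float64.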

## 8. Training Loop

Train the model using the generated data.

In [None]:
train_losses = []
val_losses = []

print(f"Starting training for {NUM_EPOCHS} epochs on {DEVICE}...")

for epoch in range(NUM_EPOCHS):
    model.train() # Set model to training mode
    epoch_train_loss = 0.0
    
    # Use tqdm for progress bar on the training loader
    train_pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{NUM_EPOCHS} [Train]", unit="batch")

    for images, masks in train_pbar:
        images = images.to(DEVICE)
        masks = masks.to(DEVICE) # Target masks should be float [0, 1]

        # Forward pass
        outputs = model(images)
        loss = criterion(outputs, masks) # Compare logits with target mask

        # Backward pass and optimize
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        epoch_train_loss += loss.item()
        
        # Update progress bar description with current loss
        train_pbar.set_postfix({'loss': loss.item()})

    avg_train_loss = epoch_train_loss / len(train_loader) # len(train_loader) is STEPS_PER_EPOCH
    train_losses.append(avg_train_loss)

    # --- Validation Phase ---
    model.eval() # Set model to evaluation mode
    epoch_val_loss = 0.0
    val_pbar = tqdm(val_loader, desc=f"Epoch {epoch+1}/{NUM_EPOCHS} [Val]", unit="batch")
    
    with torch.no_grad(): # No need to track gradients during validation
        for images, masks in val_pbar:
            images = images.to(DEVICE)
            masks = masks.to(DEVICE)

            outputs = model(images)
            loss = criterion(outputs, masks)
            epoch_val_loss += loss.item()
            val_pbar.set_postfix({'loss': loss.item()})

    avg_val_loss = epoch_val_loss / len(val_loader) # len(val_loader) is VALIDATION_STEPS
    val_losses.append(avg_val_loss)
    
    # Optional: Adjust learning rate with scheduler based on validation loss
    # scheduler.step(avg_val_loss)

    print(f"Epoch {epoch+1}/{NUM_EPOCHS} - Train Loss: {avg_train_loss:.4f} - Val Loss: {avg_val_loss:.4f}")
    
    # Optional: Save model checkpoint periodically or based on best validation loss
    # (requires best_val_loss = float('inf') to be defined before the epoch loop)
    # if avg_val_loss < best_val_loss:
    #     best_val_loss = avg_val_loss
    #     torch.save(model.state_dict(), 'best_unet_model.pth')

print("Training finished.")

# Plot training and validation loss
plt.figure(figsize=(10, 5))
plt.plot(range(1, NUM_EPOCHS + 1), train_losses, label='Training Loss')
plt.plot(range(1, NUM_EPOCHS + 1), val_losses, label='Validation Loss')
plt.xlabel('Epochs')
plt.ylabel('Loss (BCEWithLogits)')
plt.title('Training and Validation Loss Over Epochs')
plt.legend()
plt.grid(True)
plt.show()
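
The commented-out checkpoint snippet in the loop relies on a `best_val_loss` variable that has to exist before the first epoch. A minimal sketch of that bookkeeping might look as follows; the class `BestLossTracker` is hypothetical and the loss values are made up for illustration, not results from this notebook.

```python
class BestLossTracker:
    """Remembers the lowest validation loss seen so far (hypothetical helper)."""

    def __init__(self):
        self.best = float("inf")   # any real loss improves on infinity
        self.best_epoch = None

    def update(self, epoch, val_loss):
        """Return True when val_loss improves on the best value so far."""
        if val_loss < self.best:
            self.best = val_loss
            self.best_epoch = epoch
            return True
        return False

tracker = BestLossTracker()
example_val_losses = [0.50, 0.40, 0.45, 0.30]   # illustrative values only
improved = []
for epoch, val_loss in enumerate(example_val_losses, start=1):
    is_best = tracker.update(epoch, val_loss)
    improved.append(is_best)
    # if is_best: this is where a checkpoint would be saved
```

In the training loop, `tracker.update(epoch + 1, avg_val_loss)` could gate the `torch.save(model.state_dict(), ...)` call so that only the best weights are kept.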

## 9. Inference and Visualization

Now, let's use the trained model to predict the mask for a new generated image and draw a bounding box around the detected bike.

In [None]:
def mask_to_bbox(mask_np, threshold=0.5):
    """Calculates the bounding box from a binary mask (NumPy array)."""
    # Ensure mask is binary
    binary_mask = (mask_np > threshold).astype(np.uint8)

    # Find contours - OpenCV is good for this, but let's use NumPy for simplicity
    rows = np.any(binary_mask, axis=1)
    cols = np.any(binary_mask, axis=0)
    
    if not np.any(rows) or not np.any(cols):
        return None # No object found

    rmin, rmax = np.where(rows)[0][[0, -1]]
    cmin, cmax = np.where(cols)[0][[0, -1]]

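    # Return format: (xmin, ymin, xmax, ymax) consistent with PIL draw.
    # Cast to plain Python ints so the tuple prints cleanly and is accepted by PIL.
    return (int(cmin), int(rmin), int(cmax), int(rmax))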

def draw_bbox(image_pil, bbox, color="red", thickness=2):
    """Draws a bounding box on a PIL image."""
    if bbox is None:
        return image_pil # Return original if no bbox
        
    draw = ImageDraw.Draw(image_pil)
    # The bbox tuple is (xmin, ymin, xmax, ymax)
    draw.rectangle(bbox, outline=color, width=thickness)
    return image_pil

# Generate new samples for inference (lower this number for a shorter figure)
num_inference_samples = 30
inference_results = []

model.eval() # Set model to evaluation mode
with torch.no_grad():
    for _ in range(num_inference_samples):
        # 1. Generate new data
        input_np, true_mask_np = generate_bike_sample(IMG_SIZE)
        
        # 2. Prepare for model (convert to tensor, normalize, add batch dim, move to device)
        input_tensor = torch.from_numpy(input_np).float().unsqueeze(0).unsqueeze(0) / 255.0
        input_tensor = input_tensor.to(DEVICE)

        # 3. Get model prediction (logits)
        pred_logits = model(input_tensor)

        # 4. Convert prediction to probability map (sigmoid) and then to NumPy mask
        pred_prob = torch.sigmoid(pred_logits)
        pred_mask_np = pred_prob.squeeze().cpu().numpy() # Remove batch and channel dims

        # 5. Calculate bounding box from the predicted mask
        bbox = mask_to_bbox(pred_mask_np, threshold=0.5) # Use 0.5 threshold

        # 6. Prepare images for display (convert input back to PIL)
        # input_np is still the original uint8 array (only the tensor copy was normalized),
        # so no denormalization is needed here
        display_img_np = input_np.astype(np.uint8)
        img_pil = Image.fromarray(display_img_np).convert("RGB") # Convert to RGB for red box

        # 7. Draw the bounding box on the PIL image
        img_with_bbox = draw_bbox(img_pil, bbox, color="red", thickness=1)

        inference_results.append({
            "input_pil": Image.fromarray(display_img_np), # Original grayscale input
            "true_mask_np": true_mask_np,
            "pred_mask_np": pred_mask_np,
            "img_with_bbox": img_with_bbox,
            "bbox": bbox
        })

# Display the results
fig, axes = plt.subplots(num_inference_samples, 4, figsize=(16, num_inference_samples * 4))
fig.suptitle("Inference Results", fontsize=16)

for i, result in enumerate(inference_results):
    ax_input = axes[i, 0]
    ax_true_mask = axes[i, 1]
    ax_pred_mask = axes[i, 2]
    ax_bbox = axes[i, 3]

    ax_input.imshow(result["input_pil"], cmap='gray')
    ax_input.set_title(f"Input Image {i+1}")
    ax_input.axis('off')

    ax_true_mask.imshow(result["true_mask_np"], cmap='gray')
    ax_true_mask.set_title("True Mask")
    ax_true_mask.axis('off')

    im = ax_pred_mask.imshow(result["pred_mask_np"], cmap='viridis', vmin=0, vmax=1) # Show probability map
    ax_pred_mask.set_title("Predicted Mask (Prob)")
    ax_pred_mask.axis('off')
    # fig.colorbar(im, ax=ax_pred_mask, fraction=0.046, pad=0.04) # Optional colorbar

    ax_bbox.imshow(result["img_with_bbox"])
    ax_bbox.set_title(f"Prediction w/ BBox\n{result['bbox']}") # Display coords
    ax_bbox.axis('off')

plt.tight_layout(rect=[0, 0.03, 1, 0.97]) # Adjust layout to prevent title overlap
plt.show()

# Optional: Save the trained model
# torch.save(model.state_dict(), 'final_unet_bike_detector.pth')
# print("Model state dictionary saved to final_unet_bike_detector.pth")
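
Since the notebook ends with bounding boxes, one way to score them is the IoU between the box derived from the predicted mask and the box derived from the true mask. The sketch below assumes boxes in the format returned by `mask_to_bbox`, i.e. `(xmin, ymin, xmax, ymax)` with inclusive pixel indices. The function name `bbox_iou` and the choice to score a missing box as 0 are assumptions.

```python
def bbox_iou(box_a, box_b):
    """IoU of two boxes given as (xmin, ymin, xmax, ymax) with inclusive pixel indices."""
    if box_a is None or box_b is None:
        return 0.0   # treat a missing detection as no overlap
    ax0, ay0, ax1, ay1 = box_a
    bx0, by0, bx1, by1 = box_b
    # Width/height of the overlap; +1 because both end indices lie inside the box
    inter_w = max(0, min(ax1, bx1) - max(ax0, bx0) + 1)
    inter_h = max(0, min(ay1, by1) - max(ay0, by0) + 1)
    inter = inter_w * inter_h
    area_a = (ax1 - ax0 + 1) * (ay1 - ay0 + 1)
    area_b = (bx1 - bx0 + 1) * (by1 - by0 + 1)
    return inter / (area_a + area_b - inter)

same = bbox_iou((10, 20, 40, 50), (10, 20, 40, 50))
partial = bbox_iou((0, 0, 3, 3), (2, 2, 5, 5))       # 4 shared pixels out of 28
disjoint = bbox_iou((0, 0, 3, 3), (10, 10, 12, 12))
```

In the inference loop, `bbox_iou(mask_to_bbox(pred_mask_np), mask_to_bbox(true_mask_np / 255.0))` would give one score per sample.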

## 10. Conclusion

This notebook demonstrated the process of training a U-Net for object segmentation using entirely synthetic, on-the-fly generated data.

**Key takeaways:**

*   **Synthetic data** can be effective for training, especially for bootstrapping or understanding model behavior on specific features (like shape).
*   **On-the-fly generation** is memory-efficient and provides a vast amount of training data, which reduces the risk of overfitting to a fixed set of examples.
*   **Controlling variability** (object color, distractors) is crucial to ensure the model learns the desired features (shape) rather than simple shortcuts (color intensity).
*   The U-Net architecture is well-suited for **pixel-wise segmentation tasks**.
*   The output mask from the U-Net can be easily post-processed (e.g., thresholding, finding contours) to extract higher-level information like **bounding boxes**.

**Potential improvements:**

*   More complex backgrounds and bike variations.
*   Data augmentation (rotation, scaling, elastic deformations) applied to generated samples.
*   Hyperparameter tuning (learning rate, batch size, network depth/width).
*   Using more advanced metrics beyond loss (e.g., Dice coefficient, IoU) for evaluation.
*   Training for longer or using learning rate scheduling.
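
As a starting point for the metrics mentioned above, the Dice coefficient and IoU can be computed from a predicted probability map and a ground-truth mask. This is a minimal NumPy sketch, assuming the 0.5 threshold used during inference and a mask that is nonzero on bike pixels. The function name `dice_and_iou` and the smoothing term `eps` are assumptions.

```python
import numpy as np

def dice_and_iou(pred_prob, true_mask, threshold=0.5, eps=1e-7):
    """Dice coefficient and IoU between a thresholded prediction and a binary mask."""
    pred = np.asarray(pred_prob) > threshold
    true = np.asarray(true_mask) > 0
    inter = np.logical_and(pred, true).sum()
    pred_sum, true_sum = pred.sum(), true.sum()
    union = pred_sum + true_sum - inter
    # eps avoids division by zero when both masks are empty
    dice = (2 * inter + eps) / (pred_sum + true_sum + eps)
    iou = (inter + eps) / (union + eps)
    return float(dice), float(iou)

# Toy example: 4 predicted pixels, 4 true pixels, 2 in common
true_mask = np.zeros((4, 4), dtype=np.uint8)
true_mask[1, 0:4] = 255
pred_prob = np.zeros((4, 4), dtype=np.float32)
pred_prob[1, 2:4] = 0.9
pred_prob[2, 2:4] = 0.8
dice, iou = dice_and_iou(pred_prob, true_mask)
```

In the toy example the masks share 2 of 6 distinct pixels, so the Dice coefficient is about 0.5 and the IoU about 0.33. Averaging these values over the validation batches would complement the BCE loss curve.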