# U-Net Training for Simplified Bike Detection with On-the-Fly Data Generation

## 1. Introduction

This notebook demonstrates how to train a U-Net model to segment simplified drawings of bicycles. The key features are:

*   **Synthetic Data:** We generate training data programmatically. Each sample consists of:
    *   An input image: A simple urban background sketch (street, buildings, lampposts), a distracting sun sketch, and a simplified bike sketch (2 circles, lines for frame).
    *   A mask image: A binary mask showing *only* the pixels belonging to the bike.
*   **On-the-Fly Generation:** Data is generated in batches as needed during training, avoiding the need to store a large dataset. This also provides virtually infinite unique training examples.
*   **Variability:** The bike's position, size, and color (grayscale intensity) vary. The sun's position and brightness also vary. The background details change slightly.
*   **PyTorch Implementation:** We use PyTorch to define the U-Net architecture and the training loop.
*   **Goal:** Train the U-Net to accurately segment the bike, ignoring the background clutter and the sun.
*   **Inference & Visualization:** After training, we demonstrate inference by predicting the mask for a new image and drawing a red bounding box around the detected bike.

## 2. Imports

Import necessary libraries.

In [None]:
import os

import numpy as np
from PIL import Image
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from tqdm.notebook import tqdm # Progress bar

# Import local modules
import unet_model as unet
import drawing_bike as pic

## 3. Configuration

Define key parameters for data generation and training.

In [None]:
# Data Generation Config
IMG_SIZE = 128 # Keep images small (e.g., 128x128) for faster initial training
ADD_NOISE_PROB = 0.5 # Probability to add salt & pepper noise to input image
NOISE_AMOUNT = 0.04 # Amount of noise to add (0.0 to 1.0)

# Training Config
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
BATCH_SIZE = 16 # Adjust based on your GPU memory
LEARNING_RATE = 1e-4
NUM_EPOCHS = 20 # Start with a moderate number, increase if needed
STEPS_PER_EPOCH = 200 # Number of batches per epoch (since data is generated on-the-fly)
VALIDATION_STEPS = 50 # Number of batches for validation check per epoch

MODEL_FILE_PATH = "../../models/unet/bike_" # Path prefix shared by all saved files
FILE_VAL_LOSS = f"{MODEL_FILE_PATH}val_loss.pth" # File to save the best validation loss
FILE_MODEL = f"{MODEL_FILE_PATH}model.pth" # File to save the model weights (state dict)

print(f"Using device: {DEVICE}")
print(f"Image Size: {IMG_SIZE}x{IMG_SIZE}")
print(f"Batch Size: {BATCH_SIZE}")
print(f"Steps per Epoch: {STEPS_PER_EPOCH}")
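
`ADD_NOISE_PROB` and `NOISE_AMOUNT` are passed on to `pic.generate_bike_sample`, whose implementation is not shown in this notebook. As a rough illustration of how two such parameters could drive salt & pepper noise, here is a minimal NumPy sketch. The helper name `add_salt_pepper_noise` and the explicit `rng` argument are assumptions made for this example and are not part of the `drawing_bike` module.

```python
import numpy as np

def add_salt_pepper_noise(img, amount, prob, rng):
    """Hypothetical helper: with probability `prob`, set roughly a fraction
    `amount` of the pixels of a uint8 grayscale image to 0 or 255."""
    if rng.random() >= prob:
        return img.copy()  # Leave this sample clean
    noisy = img.copy()
    # Pick the corrupted pixels, then split them evenly between salt and pepper
    corrupt = rng.random(img.shape) < amount
    salt = rng.random(img.shape) < 0.5
    noisy[corrupt & salt] = 255
    noisy[corrupt & ~salt] = 0
    return noisy

rng = np.random.default_rng(0)
clean = np.full((128, 128), 128, dtype=np.uint8)
noisy = add_salt_pepper_noise(clean, amount=0.04, prob=1.0, rng=rng)
changed_fraction = float((noisy != clean).mean())
print(f"Fraction of pixels changed: {changed_fraction:.3f}")
```

With `amount=0.04`, about 4% of the pixels change on average, which is light enough to leave thin bike lines recognizable.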

## 4. Data Generation Functions

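The functions that create the synthetic images and masks live in the local `drawing_bike` module (imported as `pic`). The cell below initializes that module for the configured image size.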

In [None]:
pic.main(IMG_SIZE) # Initialize the drawing module
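
Since the source of `drawing_bike` is not reproduced here, the following is only a toy sketch of what an image/mask generator of this kind might look like. It uses plain NumPy and draws two wheel outlines and three frame lines on a two-tone background. All names (`draw_line`, `draw_ring`, `toy_bike_sample`) are hypothetical, and the sun, buildings, and lampposts are left out for brevity.

```python
import numpy as np

def draw_line(canvas, p0, p1, value):
    """Rasterize a thin line between two (x, y) points by dense sampling."""
    n = int(max(abs(p1[0] - p0[0]), abs(p1[1] - p0[1]))) + 1
    xs = np.linspace(p0[0], p1[0], n).round().astype(int)
    ys = np.linspace(p0[1], p1[1], n).round().astype(int)
    keep = (xs >= 0) & (xs < canvas.shape[1]) & (ys >= 0) & (ys < canvas.shape[0])
    canvas[ys[keep], xs[keep]] = value

def draw_ring(canvas, center, radius, value):
    """Draw a thin circle outline around `center` given as (x, y)."""
    yy, xx = np.mgrid[:canvas.shape[0], :canvas.shape[1]]
    dist = np.hypot(xx - center[0], yy - center[1])
    canvas[np.abs(dist - radius) <= 0.75] = value

def toy_bike_sample(size, rng):
    """Return (image, mask) as uint8 arrays; the mask is 255 on bike pixels only."""
    img = np.full((size, size), 200, dtype=np.uint8)  # Light background
    img[int(size * 0.8):, :] = 120                     # Darker "street" band
    mask = np.zeros((size, size), dtype=np.uint8)
    r = int(rng.integers(size // 16, size // 8))       # Wheel radius
    cx = int(rng.integers(3 * r, size - 3 * r))        # Bike center (x)
    cy = int(rng.integers(3 * r, size - 3 * r))        # Wheel axle height (y)
    shade = int(rng.integers(0, 100))                  # Bike grayscale intensity
    rear, front, seat = (cx - 2 * r, cy), (cx + 2 * r, cy), (cx, cy - 2 * r)
    # Draw the same shapes into the image (bike shade) and the mask (255)
    for canvas, value in ((img, shade), (mask, 255)):
        draw_ring(canvas, rear, r, value)
        draw_ring(canvas, front, r, value)
        for a, b in ((rear, seat), (seat, front), (rear, front)):
            draw_line(canvas, a, b, value)
    return img, mask

toy_img, toy_mask = toy_bike_sample(128, np.random.default_rng(0))
print(toy_img.shape, toy_mask.shape, int((toy_mask > 0).sum()))
```

Drawing every shape twice, once into the image and once into the mask, keeps the mask pixel-exact without any separate labeling step.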

## 5. PyTorch Dataset and DataLoader

We create a custom `Dataset` that calls `pic.generate_bike_sample` every time an item is requested, so the index passed to `__getitem__` is ignored. The `DataLoader` then handles batching.

In [None]:
class BikeDataset(Dataset):
    """PyTorch Dataset for generating bike images and masks on-the-fly."""
    def __init__(self, img_size, num_samples):
        self.img_size = img_size
        self.num_samples = num_samples # Effectively, steps per epoch

    def __len__(self):
        # Length is the number of samples we want to generate per epoch
        return self.num_samples

    def __getitem__(self, idx):
        # Generate a new sample each time __getitem__ is called
        input_np, mask_np = pic.generate_bike_sample(self.img_size, self.img_size, ADD_NOISE_PROB, NOISE_AMOUNT)

        # Convert NumPy arrays to PyTorch tensors
        # Input: Add channel dimension (C, H, W) and normalize to [0, 1]
        input_tensor = torch.from_numpy(input_np).float().unsqueeze(0) / 255.0

        # Mask: Add channel dimension and normalize to [0, 1] (for BCE loss)
        mask_tensor = torch.from_numpy(mask_np).float().unsqueeze(0) / 255.0

        return input_tensor, mask_tensor

# Create DataLoaders for training and validation
# For validation, we generate a separate set of samples on-the-fly
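# 0 workers on Windows for compatibility, otherwise half of the available cores
# (os.cpu_count() can return None, so fall back to 2 in that case)
num_workers = 0 if os.name == 'nt' else (os.cpu_count() or 2) // 2
train_dataset = BikeDataset(IMG_SIZE, STEPS_PER_EPOCH * BATCH_SIZE)
# shuffle has no practical effect here because every sample is generated fresh regardless of its index
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=num_workers)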


val_dataset = BikeDataset(IMG_SIZE, VALIDATION_STEPS * BATCH_SIZE)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=num_workers)

# Let's check one batch
try:
    first_batch_img, first_batch_mask = next(iter(train_loader))
    print("Successfully loaded one batch.")
    print(f"Image batch shape: {first_batch_img.shape}") # Should be [BATCH_SIZE, 1, IMG_SIZE, IMG_SIZE]
    print(f"Mask batch shape: {first_batch_mask.shape}")
    print(f"Image batch dtype: {first_batch_img.dtype}") # Should be torch.float32
    print(f"Image batch min/max: {first_batch_img.min():.2f}/{first_batch_img.max():.2f}") # Should be ~0.0/1.0
    print(f"Mask batch min/max: {first_batch_mask.min():.2f}/{first_batch_mask.max():.2f}") # Should be 0.0/1.0
except Exception as e:
    print(f"Error loading batch: {e}")
    print("Check num_workers or data generation logic if issues persist.")
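
One point to keep in mind with on-the-fly generation and `num_workers > 0`: each worker is a separate process, and if the generator draws from NumPy's global random state (an assumption here, since the source of `drawing_bike` is not shown), workers can start from identical states and produce duplicate samples. A common remedy is a `worker_init_fn` that reseeds NumPy and Python's `random` from the per-worker PyTorch seed. The name `seed_worker` is our own choice for this sketch.

```python
import random

import numpy as np
import torch

def seed_worker(worker_id):
    """Seed NumPy and Python's `random` from the per-worker PyTorch seed."""
    worker_seed = torch.initial_seed() % 2**32  # NumPy seeds must fit in 32 bits
    np.random.seed(worker_seed)
    random.seed(worker_seed)

# Usage (assumes the datasets defined above):
# train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE,
#                           num_workers=num_workers, worker_init_fn=seed_worker)
```

If the drawing module uses its own random generator instead, the same idea applies, but the seeding call would have to target that generator.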

## 6. U-Net Model Definition

The U-Net architecture itself is defined in the local `unet_model` module (imported as `unet`). Here we only instantiate it.

In [None]:
# Instantiate the model
# n_channels=1 (grayscale), n_classes=1 (bike or not bike)
model = unet.main(device=DEVICE, img_size=IMG_SIZE, n_channels=1, n_classes=1, batch_size=BATCH_SIZE)

## 7. Training Setup

Define the loss function and optimizer. We use `BCEWithLogitsLoss`, which is suitable for binary segmentation and expects raw logits from the model.
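
To see why working on raw logits is numerically safer than applying a sigmoid first, the sketch below compares a literal "sigmoid, then BCE" computation with the algebraically equivalent form `max(x, 0) - x*y + log(1 + exp(-|x|))`. It is written in NumPy purely as an illustration and is not PyTorch's actual implementation. The function names are ours.

```python
import numpy as np

def bce_naive(logits, targets):
    """Sigmoid followed by binary cross-entropy, computed literally."""
    p = 1.0 / (1.0 + np.exp(-logits))
    return -(targets * np.log(p) + (1.0 - targets) * np.log(1.0 - p))

def bce_with_logits(logits, targets):
    """Equivalent form that never evaluates log(0)."""
    return np.maximum(logits, 0.0) - logits * targets + np.log1p(np.exp(-np.abs(logits)))

x = np.array([-3.0, 0.0, 2.5])  # Logits
y = np.array([0.0, 1.0, 1.0])   # Targets
print(bce_naive(x, y))
print(bce_with_logits(x, y))    # Same values for moderate logits

# For a confident wrong prediction, 1 - sigmoid(40) rounds to 0 in float64
big_x, big_y = np.array([40.0]), np.array([0.0])
with np.errstate(divide="ignore", invalid="ignore"):
    naive_big = bce_naive(big_x, big_y)       # Becomes inf
stable_big = bce_with_logits(big_x, big_y)    # Stays finite (about 40)
print(naive_big, stable_big)
```

An infinite loss would produce unusable gradients, which is the failure mode the logits-based formulation avoids.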

In [None]:
# Loss function
# BCEWithLogitsLoss combines Sigmoid layer and BCELoss in one single class.
# It's more numerically stable than using a plain Sigmoid followed by BCELoss.
criterion = nn.BCEWithLogitsLoss()

# Optimizer
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

# Learning rate scheduler (optional, but can help)
# Reduce LR if validation loss plateaus
# scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=2, factor=0.1)
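
The scheduler above is left commented out. To enable it, both this constructor and the `scheduler.step(avg_val_loss)` call in the training loop need to be uncommented. The following toy run shows how `ReduceLROnPlateau` reacts to a validation loss that stops improving. It uses a small `nn.Linear` as a stand-in for the U-Net, and the `toy_` names are ours.

```python
import torch
import torch.nn as nn
import torch.optim as optim

toy_model = nn.Linear(4, 1)  # Stand-in for the U-Net
toy_optimizer = optim.Adam(toy_model.parameters(), lr=1e-4)
toy_scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    toy_optimizer, mode='min', patience=2, factor=0.1)

lr_history = []
for fake_val_loss in [1.0, 1.0, 1.0, 1.0]:  # Validation loss that never improves
    toy_scheduler.step(fake_val_loss)
    lr_history.append(toy_optimizer.param_groups[0]['lr'])
print(lr_history)
```

With `patience=2`, the learning rate is multiplied by `factor` only once the number of consecutive epochs without improvement exceeds the patience, so it stays at its initial value for the first three steps of this run.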

## 8. Training Loop

Train the model using the generated data.

In [None]:
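# Resume from a previous checkpoint if one exists
best_val_loss = float('inf')
if os.path.exists(FILE_MODEL):
    # map_location lets a checkpoint saved on a GPU load on a CPU-only machine
    model.load_state_dict(torch.load(FILE_MODEL, map_location=DEVICE))
    # Only trust the stored validation loss when its model file is present
    if os.path.exists(FILE_VAL_LOSS):
        best_val_loss = torch.load(FILE_VAL_LOSS)
    print(f"Model loaded successfully. val_loss: {best_val_loss:.5f}")
else:
    print("No pre-trained model found. Starting from scratch.")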

In [None]:
train_losses = []
val_losses = []

print(f"Starting training for {NUM_EPOCHS} epochs on {DEVICE}...")

for epoch in range(NUM_EPOCHS):
    model.train() # Set model to training mode
    epoch_train_loss = 0.0
    
    # Use tqdm for progress bar on the training loader
    train_pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{NUM_EPOCHS} [Train]", unit="batch")

    for images, masks in train_pbar:
        images = images.to(DEVICE)
        masks = masks.to(DEVICE) # Target masks should be float [0, 1]

        # Forward pass
        outputs = model(images)
        loss = criterion(outputs, masks) # Compare logits with target mask

        # Backward pass and optimize
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        epoch_train_loss += loss.item()
        
        # Update progress bar description with current loss
        train_pbar.set_postfix({'loss': loss.item()})

    avg_train_loss = epoch_train_loss / len(train_loader) # len(train_loader) is STEPS_PER_EPOCH
    train_losses.append(avg_train_loss)

    # --- Validation Phase ---
    model.eval() # Set model to evaluation mode
    epoch_val_loss = 0.0
    val_pbar = tqdm(val_loader, desc=f"Epoch {epoch+1}/{NUM_EPOCHS} [Val]", unit="batch")
    
    with torch.no_grad(): # No need to track gradients during validation
        for images, masks in val_pbar:
            images = images.to(DEVICE)
            masks = masks.to(DEVICE)

            outputs = model(images)
            loss = criterion(outputs, masks)
            epoch_val_loss += loss.item()
            val_pbar.set_postfix({'loss': loss.item()})

    avg_val_loss = epoch_val_loss / len(val_loader) # len(val_loader) is VALIDATION_STEPS
    val_losses.append(avg_val_loss)
    
    # Optional: Adjust learning rate with scheduler based on validation loss
    # scheduler.step(avg_val_loss)

    print(f"Epoch {epoch+1}/{NUM_EPOCHS} - Train Loss: {avg_train_loss:.4f} - Val Loss: {avg_val_loss:.4f}")
    
    # Save a checkpoint whenever the validation loss improves
    if avg_val_loss < best_val_loss:
        best_val_loss = avg_val_loss
        os.makedirs(os.path.dirname(FILE_MODEL), exist_ok=True) # Make sure the target directory exists
        torch.save(model.state_dict(), FILE_MODEL) # Save model state dict
        torch.save(avg_val_loss, FILE_VAL_LOSS) # Save the matching validation loss

print("Training finished.")

# Plot training and validation loss
plt.figure(figsize=(10, 5))
plt.plot(range(1, NUM_EPOCHS + 1), train_losses, label='Training Loss')
plt.plot(range(1, NUM_EPOCHS + 1), val_losses, label='Validation Loss')
plt.xlabel('Epochs')
plt.ylabel('Loss (BCEWithLogits)')
plt.title('Training and Validation Loss Over Epochs')
plt.legend()
plt.grid(True)
plt.show()

## 9. Inference and Visualization

Now, let's use the trained model to predict masks for newly generated images and draw a bounding box around the detected bike in each one.

In [None]:
# Generate new samples for inference
num_inference_samples = 30
inference_results = []

model.eval() # Set model to evaluation mode
with torch.no_grad():
    for _ in range(num_inference_samples):
        # 1. Generate new data
        input_np, true_mask_np = pic.generate_bike_sample(IMG_SIZE, IMG_SIZE)
        
        # 2. Prepare for model (convert to tensor, normalize, add batch dim, move to device)
        input_tensor = torch.from_numpy(input_np).float().unsqueeze(0).unsqueeze(0) / 255.0
        input_tensor = input_tensor.to(DEVICE)

        # 3. Get model prediction (logits)
        pred_logits = model(input_tensor)

        # 4. Convert prediction to probability map (sigmoid) and then to NumPy mask
        pred_prob = torch.sigmoid(pred_logits)
        pred_mask_np = pred_prob.squeeze().cpu().numpy() # Remove batch and channel dims

        # 5. Calculate bounding box from the predicted mask
        bbox = pic.mask_to_bbox(pred_mask_np, threshold=0.5) # Use 0.5 threshold

        # 6. Prepare images for display (convert input back to PIL)
        # input_np itself was never normalized (only the tensor copy was divided by 255),
        # so it only needs a cast to uint8
        display_img_np = input_np.astype(np.uint8)
        img_pil = Image.fromarray(display_img_np).convert("RGB") # Convert to RGB for red box

        # 7. Draw the bounding box on the PIL image
        img_with_bbox = pic.draw_bbox(img_pil, bbox, color="red", thickness=1)

        inference_results.append({
            "input_pil": Image.fromarray(display_img_np), # Original grayscale input
            "true_mask_np": true_mask_np,
            "pred_mask_np": pred_mask_np,
            "img_with_bbox": img_with_bbox,
            "bbox": bbox
        })

# Display the results
# squeeze=False keeps `axes` two-dimensional even when num_inference_samples is 1
fig, axes = plt.subplots(num_inference_samples, 4, figsize=(16, num_inference_samples * 4), squeeze=False)
fig.suptitle("Inference Results", fontsize=16)

for i, result in enumerate(inference_results):
    ax_input = axes[i, 0]
    ax_true_mask = axes[i, 1]
    ax_pred_mask = axes[i, 2]
    ax_bbox = axes[i, 3]

    ax_input.imshow(result["input_pil"], cmap='gray')
    ax_input.set_title(f"Input Image {i+1}")
    ax_input.axis('off')

    ax_true_mask.imshow(result["true_mask_np"], cmap='gray')
    ax_true_mask.set_title("True Mask")
    ax_true_mask.axis('off')

    im = ax_pred_mask.imshow(result["pred_mask_np"], cmap='viridis', vmin=0, vmax=1) # Show probability map
    ax_pred_mask.set_title("Predicted Mask (Prob)")
    ax_pred_mask.axis('off')
    # fig.colorbar(im, ax=ax_pred_mask, fraction=0.046, pad=0.04) # Optional colorbar

    ax_bbox.imshow(result["img_with_bbox"])
    ax_bbox.set_title(f"Prediction w/ BBox\n{result['bbox']}") # Display coords
    ax_bbox.axis('off')

plt.tight_layout(rect=[0, 0.03, 1, 0.97]) # Adjust layout to prevent title overlap
plt.show()

# Optional: Save the trained model
# torch.save(model.state_dict(), 'final_unet_bike_detector.pth')
# print("Model state dictionary saved to final_unet_bike_detector.pth")
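
The bounding box in step 5 comes from `pic.mask_to_bbox`, which is not shown in this notebook. A minimal version of the same idea (threshold the probability map, then take the extent of the remaining pixels) could look like the sketch below. The name `mask_to_bbox_sketch` and the `(x_min, y_min, x_max, y_max)` return format are assumptions and may differ from the real helper.

```python
import numpy as np

def mask_to_bbox_sketch(prob_mask, threshold=0.5):
    """Return (x_min, y_min, x_max, y_max) of pixels above `threshold`, or None."""
    ys, xs = np.nonzero(prob_mask > threshold)
    if ys.size == 0:
        return None  # Nothing detected
    return int(xs.min()), int(ys.min()), int(xs.max()), int(ys.max())

demo_mask = np.zeros((8, 8), dtype=np.float32)
demo_mask[2:5, 3:7] = 0.9  # Rows 2-4, columns 3-6
demo_bbox = mask_to_bbox_sketch(demo_mask)
print(demo_bbox)  # (3, 2, 6, 4)
```

Because this takes the extent of all pixels above the threshold, a single false-positive pixel far from the bike would enlarge the box. Keeping only the largest connected component first is one way to make it more robust.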

## 10. Conclusion

This notebook demonstrated the process of training a U-Net for object segmentation using entirely synthetic, on-the-fly generated data.

**Key takeaways:**

*   **Synthetic data** can be effective for training, especially for bootstrapping or understanding model behavior on specific features (like shape).
*   **On-the-fly generation** is memory-efficient and provides an effectively unlimited stream of unique training samples, which reduces the risk of overfitting.
*   **Controlling variability** (object color, distractors) is crucial to ensure the model learns the desired features (shape) rather than simple shortcuts (color intensity).
*   The U-Net architecture is well-suited for **pixel-wise segmentation tasks**.
*   The output mask from the U-Net can be easily post-processed (e.g., thresholding, finding contours) to extract higher-level information like **bounding boxes**.

**Potential improvements:**

*   More complex backgrounds and bike variations.
*   Data augmentation (rotation, scaling, elastic deformations) applied to generated samples.
*   Hyperparameter tuning (learning rate, batch size, network depth/width).
*   Using more advanced metrics beyond loss (e.g., Dice coefficient, IoU) for evaluation.
*   Training for longer or using learning rate scheduling.
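
As a starting point for the evaluation metrics mentioned above, here is a small NumPy sketch of the Dice coefficient and IoU computed on a thresholded prediction. The function name `dice_and_iou` and the smoothing term `eps` are our own choices. The true mask is assumed to follow the 0/255 convention used for the masks in this notebook.

```python
import numpy as np

def dice_and_iou(pred_prob, true_mask, threshold=0.5, eps=1e-7):
    """Compute Dice and IoU between a probability map and a 0/255 mask."""
    pred = pred_prob > threshold
    true = true_mask > 0
    inter = np.logical_and(pred, true).sum()
    union = np.logical_or(pred, true).sum()
    # eps keeps both scores defined (and equal to 1) when both masks are empty
    dice = (2.0 * inter + eps) / (pred.sum() + true.sum() + eps)
    iou = (inter + eps) / (union + eps)
    return float(dice), float(iou)

true_demo = np.zeros((4, 4), dtype=np.uint8)
true_demo[1:3, 1:3] = 255        # 4 true pixels
pred_demo = np.zeros((4, 4), dtype=np.float32)
pred_demo[1:3, 1:4] = 0.8        # 6 predicted pixels, 4 of them correct
dice, iou = dice_and_iou(pred_demo, true_demo)
print(f"Dice: {dice:.3f}, IoU: {iou:.3f}")
```

Unlike the BCE loss, both scores are computed only from the object pixels and their overlap, so they are not dominated by the large background area around a small bike.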