# Module 09: Introduction to Image Segmentation

**Pixel-Level Understanding**

Classify every pixel in an image!

## What You'll Learn
- What is image segmentation?
- Semantic vs instance segmentation
- U-Net architecture
- Using pre-trained segmentation models
- Real-world applications

## Time: 45 minutes

In [None]:
import torch
import torchvision
from torchvision.models.segmentation import deeplabv3_resnet50
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

## Part 1: What is Image Segmentation?

### Progression of Vision Tasks

**1. Classification:**
- "What is in the image?"
- Output: One label

**2. Object Detection:**
- "What and where are objects?"
- Output: Boxes + labels

**3. Segmentation:**
- "Which pixels belong to which object?"
- Output: Labeled mask for every pixel
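
As a rough illustration of how the three outputs differ in shape, here is a minimal NumPy sketch. The image size, box coordinates, and class indices are made-up values chosen only for this example.

```python
import numpy as np

height, width = 4, 6  # toy image size, chosen only for illustration

# Classification: a single label for the whole image.
classification_output = 2  # e.g. class index 2

# Object detection: one (x_min, y_min, x_max, y_max) box and one label per object.
detection_boxes = np.array([[0, 0, 3, 2], [2, 1, 5, 3]])
detection_labels = np.array([1, 2])

# Segmentation: one class index per pixel, same height and width as the image.
segmentation_mask = np.zeros((height, width), dtype=np.int64)
segmentation_mask[0:2, 0:3] = 1
segmentation_mask[2:4, 3:6] = 2

print(detection_boxes.shape)    # (2, 4)
print(segmentation_mask.shape)  # (4, 6)
```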

### Types of Segmentation

#### **Semantic Segmentation**
- Classify each pixel into a class
- Don't distinguish between instances
- Example: All people labeled as "person", all cars as "car"

#### **Instance Segmentation**
- Classify AND separate individual instances
- Example: Person1, Person2, Car1, Car2
- More complex but more informative

#### **Panoptic Segmentation**
- Combines semantic and instance segmentation
- Labels "stuff" (road, sky) semantically
- Labels "things" (people, cars) by instance
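
The difference between the three types can be seen on a toy mask. The sketch below is hand-built, not model output: the class ids are hypothetical, and the `class_id * 1000 + instance_id` encoding is one assumed convention for combining the two masks, not the only one.

```python
import numpy as np

BACKGROUND, PERSON = 0, 1  # hypothetical class ids for this toy example

# Semantic mask: both people share the class id PERSON.
semantic = np.array([
    [0, 1, 1, 0, 0, 1, 1],
    [0, 1, 1, 0, 0, 1, 1],
    [0, 0, 0, 0, 0, 0, 0],
])

# Instance mask: each person gets its own id (0 means "no instance").
instance = np.array([
    [0, 1, 1, 0, 0, 2, 2],
    [0, 1, 1, 0, 0, 2, 2],
    [0, 0, 0, 0, 0, 0, 0],
])

# Panoptic-style encoding (assumed convention): class_id * 1000 + instance_id.
# Background "stuff" keeps instance id 0; each person "thing" stays separate.
panoptic = semantic * 1000 + instance

print(np.unique(semantic).tolist())  # [0, 1]
print(np.unique(instance).tolist())  # [0, 1, 2]
print(np.unique(panoptic).tolist())  # [0, 1001, 1002]
```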

## Part 2: U-Net Architecture

**One of the most widely used architectures for segmentation, especially in biomedical imaging.**

### Structure

```
Input Image
    ↓
[Encoder Path] → Downsample, extract features
    ↓
[Bottleneck] → Deepest features
    ↓
[Decoder Path] → Upsample, restore resolution
    ↑ ↖ (skip connections from encoder)
    ↓
Output Segmentation Map
```

### Key Features

1. **Encoder (Contracting Path)**
   - Similar to normal CNN
   - Downsample with pooling
   - Increase channels
   - Extract high-level features

2. **Decoder (Expanding Path)**
   - Upsample with transposed convolutions
   - Decrease channels
   - Restore spatial resolution

3. **Skip Connections**
   - Connect encoder to decoder
   - Preserve fine details
   - Help recover spatial information
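
The three pieces above can be sketched in a few lines of PyTorch. This is a minimal two-level illustration, not the original U-Net: the class name `TinyUNet`, the channel widths, and the use of `padding=1` (so feature maps keep their size and skip connections need no cropping) are all assumptions made for brevity.

```python
import torch
import torch.nn as nn


def conv_block(in_channels, out_channels):
    """Two 3x3 convolutions with ReLU; padding=1 keeps height and width unchanged."""
    return nn.Sequential(
        nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
        nn.ReLU(inplace=True),
        nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
        nn.ReLU(inplace=True),
    )


class TinyUNet(nn.Module):
    """Hypothetical two-level U-Net, for illustration only."""

    def __init__(self, in_channels=3, num_classes=2):
        super().__init__()
        # Encoder: downsample with pooling, increase channels
        self.enc1 = conv_block(in_channels, 16)
        self.enc2 = conv_block(16, 32)
        self.pool = nn.MaxPool2d(2)
        self.bottleneck = conv_block(32, 64)
        # Decoder: upsample with transposed convolutions, decrease channels
        self.up2 = nn.ConvTranspose2d(64, 32, kernel_size=2, stride=2)
        self.dec2 = conv_block(64, 32)  # 32 upsampled + 32 skip channels
        self.up1 = nn.ConvTranspose2d(32, 16, kernel_size=2, stride=2)
        self.dec1 = conv_block(32, 16)  # 16 upsampled + 16 skip channels
        self.head = nn.Conv2d(16, num_classes, kernel_size=1)

    def forward(self, x):
        e1 = self.enc1(x)
        e2 = self.enc2(self.pool(e1))
        b = self.bottleneck(self.pool(e2))
        # Skip connections: concatenate encoder features with upsampled features
        d2 = self.dec2(torch.cat([self.up2(b), e2], dim=1))
        d1 = self.dec1(torch.cat([self.up1(d2), e1], dim=1))
        return self.head(d1)  # per-pixel class scores


model_sketch = TinyUNet(in_channels=3, num_classes=2)
logits = model_sketch(torch.randn(1, 3, 64, 64))
print(logits.shape)  # torch.Size([1, 2, 64, 64])
```

With two pooling steps, the input height and width should be divisible by 4 so the upsampled maps line up with the skip connections.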

### Why U-Net?

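- **Works with small datasets** (originally designed for biomedical image segmentation)
- **Precise localization** thanks to skip connections
- **Fast**: a single forward pass segments the whole image
- **Trains end to end** with a standard per-pixel loss such as cross-entropy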

## Part 3: Using Pre-Trained Segmentation Models

In [None]:
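# Load pre-trained DeepLabV3
# (torchvision >= 0.13 uses `weights=`; the older `pretrained=True` is deprecated)
model = deeplabv3_resnet50(weights="DEFAULT")
model = model.to(device)
model.eval()  # inference mode: fixes batch-norm statistics and disables dropout

print("DeepLabV3 loaded!")
print("Trained on a subset of COCO, using the 20 Pascal VOC categories plus background")
print("Can segment 21 classes including:")
print("- person, car, bicycle, dog, cat")
print("- chair, dining table, potted plant")
print("- background, and more...")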

## Part 4: Segmentation Example

In [None]:
def segment_image(image_path, model):
    """
    Perform semantic segmentation on an image

    Args:
        image_path: Path to image
        model: Segmentation model

    Returns:
        Original image and segmentation mask
    """
    # Load and preprocess image
    image = Image.open(image_path).convert("RGB")
    transform = torchvision.transforms.Compose(
        [
            torchvision.transforms.ToTensor(),
            torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ]
    )

    input_tensor = transform(image).unsqueeze(0).to(device)

    # Segment
    with torch.no_grad():
        output = model(input_tensor)["out"][0]

    # Get class predictions
    output_predictions = output.argmax(0).cpu().numpy()

    return image, output_predictions


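def visualize_segmentation(image, mask, num_classes=21):
    """
    Visualize segmentation results

    Args:
        image: Original PIL image
        mask: 2D array of class indices
        num_classes: Number of classes the model predicts
    """
    # One discrete color per class (tab20 has only 20 colors, one short of 21).
    # Fixed limits keep each class the same color from image to image.
    cmap = plt.get_cmap("nipy_spectral", num_classes)
    mask_style = dict(cmap=cmap, vmin=-0.5, vmax=num_classes - 0.5, interpolation="nearest")

    fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(18, 6))

    # Original image
    ax1.imshow(image)
    ax1.set_title("Original Image", fontsize=14, fontweight="bold")
    ax1.axis("off")

    # Segmentation mask
    ax2.imshow(mask, **mask_style)
    ax2.set_title("Segmentation Mask", fontsize=14, fontweight="bold")
    ax2.axis("off")

    # Overlay
    ax3.imshow(image)
    ax3.imshow(mask, alpha=0.5, **mask_style)
    ax3.set_title("Overlay", fontsize=14, fontweight="bold")
    ax3.axis("off")

    plt.tight_layout()
    plt.show()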


print("Segmentation functions ready!")
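
The key step in `segment_image` is `output.argmax(0)`, which picks the highest-scoring class at every pixel. Here is a small sketch of that step on made-up scores for a 2x2 image with 3 classes; the numbers are illustrative only.

```python
import torch

# Toy scores with shape (num_classes, height, width); values are made up.
toy_logits = torch.tensor([
    [[2.0, 0.1], [0.3, 0.2]],  # class 0 scores
    [[0.5, 1.5], [0.1, 0.9]],  # class 1 scores
    [[0.1, 0.2], [3.0, 0.4]],  # class 2 scores
])

toy_mask = toy_logits.argmax(0)  # highest-scoring class at each pixel
print(toy_mask.tolist())         # [[0, 1], [2, 1]]

# Count how many pixels were assigned to each class.
class_ids, pixel_counts = torch.unique(toy_mask, return_counts=True)
print(class_ids.tolist(), pixel_counts.tolist())  # [0, 1, 2] [1, 2, 1]
```

On a real photo, the two functions above would be used as `image, mask = segment_image("photo.jpg", model)` followed by `visualize_segmentation(image, mask)`, where `photo.jpg` is a placeholder for your own file.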

## Part 5: Real-World Applications

### Medical Imaging
- **Tumor detection**: Segment cancerous regions
- **Organ segmentation**: Identify organs in CT/MRI scans
- **Cell counting**: Count and segment individual cells

### Autonomous Vehicles
- **Road scene understanding**: Segment road, lanes, vehicles, pedestrians
- **Drivable area detection**: Where can the car go?
- **Obstacle detection**: Identify and segment obstacles

### Satellite Imagery
- **Land use classification**: Forest, urban, agriculture
- **Building detection**: Segment buildings from aerial images
- **Crop monitoring**: Identify crop types and health

### Photo/Video Editing
- **Background removal**: Segment person from background
- **Object selection**: Select objects for editing
- **Depth-style effects**: Use the mask to blur the background, as in portrait mode
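
Background removal, for instance, comes down to a few array operations once a mask is available. The sketch below uses a toy array instead of a real photo; the helper name `remove_background` is hypothetical, and class index 15 is assumed to be "person", as in the Pascal VOC label order used by the DeepLabV3 weights in Part 3.

```python
import numpy as np

PERSON_CLASS = 15  # assumed index of "person" in the Pascal VOC label order


def remove_background(image, mask, keep_class=PERSON_CLASS, fill_value=255):
    """Return a copy of `image` with every pixel outside `keep_class` set to `fill_value`."""
    image = np.asarray(image)        # accepts a PIL image or an (H, W, 3) array
    keep = mask == keep_class        # boolean (H, W) array
    result = np.full_like(image, fill_value)
    result[keep] = image[keep]       # copy only the pixels we want to keep
    return result


# Toy 2x3 RGB "image" and a matching mask with two person pixels.
toy_image = np.arange(2 * 3 * 3, dtype=np.uint8).reshape(2, 3, 3)
toy_mask = np.array([[15, 0, 0], [0, 15, 0]])

cutout = remove_background(toy_image, toy_mask)
print(cutout[0, 0].tolist())  # [0, 1, 2]        (person pixel kept)
print(cutout[0, 1].tolist())  # [255, 255, 255]  (background filled with white)
```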

### Manufacturing
- **Defect detection**: Segment defective areas
- **Quality control**: Identify and segment anomalies
- **Part inspection**: Segment and measure components

## Summary

### What You Learned:

1. **Segmentation Types**
   - Semantic: Classify pixels by class
   - Instance: Separate individual objects
   - Panoptic: Combine both

2. **U-Net Architecture**
   - Encoder-decoder structure
   - Skip connections for detail
   - Widely used baseline, especially in medical imaging

3. **Practical Implementation**
   - Used pre-trained DeepLabV3
   - Segmented images pixel-by-pixel
   - Visualized results

4. **Applications**
   - Medical imaging
   - Autonomous vehicles
   - Satellite imagery
   - Many more!

### Key Insight:
Segmentation provides the most detailed understanding of images - every pixel is classified!

### Next: Module 10 - Final Projects & Next Steps
Wrap up and plan your deep learning journey!