# Semantic Segmentation with U-Net

This notebook demonstrates semantic segmentation of remote sensing imagery using a U-Net model with `torch` and `segmentation-models-pytorch` in Python. U-Net is well suited to pixel-wise classification tasks such as land cover segmentation.

## Prerequisites
- Install required libraries: `rasterio`, `torch`, `segmentation-models-pytorch`, `numpy`, `matplotlib` (listed in `requirements.txt`), plus `scikit-learn`, which the import cell uses for `train_test_split`.
- A multi-band GeoTIFF file (e.g., `sample.tif`) and a labeled raster mask (e.g., `mask.tif`) with the same height and width for training. Replace file paths with your own data.
- A GPU is recommended for faster training.

## Learning Objectives
- Prepare raster data for U-Net training.
- Train a U-Net model for semantic segmentation.
- Predict and visualize segmentation results.

In [None]:
# Import required libraries
import rasterio
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from segmentation_models_pytorch import Unet
from sklearn.model_selection import train_test_split

## Step 1: Create Custom Dataset

Define a custom dataset that cuts the image and mask into non-overlapping square patches. Rows and columns left over after the last full patch are dropped.

In [None]:
class RasterDataset(Dataset):
    def __init__(self, image_path, mask_path, patch_size=256):
        self.image_path = image_path
        self.mask_path = mask_path
        self.patch_size = patch_size
        
        # Load image and mask
        with rasterio.open(image_path) as src_img, rasterio.open(mask_path) as src_mask:
            self.image = src_img.read().astype(np.float32)
            self.mask = src_mask.read(1).astype(np.int64)
            self.profile = src_img.profile
        
        # Normalize each band by its maximum (guarding against all-zero bands)
        band_max = np.max(self.image, axis=(1, 2), keepdims=True)
        self.image = self.image / np.maximum(band_max, 1e-8)
        
        # Get dimensions
        self.height, self.width = self.image.shape[1:]
        self.n_patches_x = self.width // patch_size
        self.n_patches_y = self.height // patch_size
    
    def __len__(self):
        return self.n_patches_x * self.n_patches_y
    
    def __getitem__(self, idx):
        # Calculate patch coordinates
        y = (idx // self.n_patches_x) * self.patch_size
        x = (idx % self.n_patches_x) * self.patch_size
        
        # Extract patch
        img_patch = self.image[:, y:y+self.patch_size, x:x+self.patch_size]
        mask_patch = self.mask[y:y+self.patch_size, x:x+self.patch_size]
        
        # Convert to tensors
        img_patch = torch.from_numpy(img_patch)
        mask_patch = torch.from_numpy(mask_patch)
        
        return img_patch, mask_patch
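
The index arithmetic in `__getitem__` can be checked in isolation. The sketch below assumes a hypothetical 600 x 1000 pixel raster and uses a helper name (`patch_origin`) that is not part of the notebook. It maps a flat patch index to the top-left corner of its patch and shows how much of the raster the full patches actually cover.

```python
def patch_origin(idx, width, patch_size=256):
    """Return the (row, col) of the top-left pixel of patch `idx`.

    Mirrors the arithmetic in RasterDataset.__getitem__; patches are
    numbered row by row, left to right.
    """
    n_patches_x = width // patch_size
    y = (idx // n_patches_x) * patch_size
    x = (idx % n_patches_x) * patch_size
    return y, x

# Hypothetical raster size, chosen only for illustration
height, width, patch_size = 600, 1000, 256

n_patches_x = width // patch_size   # full patches per row
n_patches_y = height // patch_size  # full rows of patches
origins = [patch_origin(i, width, patch_size)
           for i in range(n_patches_x * n_patches_y)]
print(origins)

# Extent covered by full patches; pixels beyond it are never sampled
covered_height = n_patches_y * patch_size
covered_width = n_patches_x * patch_size
print(covered_height, covered_width)
```

With these numbers the dataset yields six patches, and the rightmost 232 columns and bottom 88 rows never reach the model during training.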

## Step 2: Load Data and Prepare Dataloaders

Load the image and mask rasters, split into training/validation sets, and create dataloaders.

In [None]:
# Define file paths
image_path = 'sample.tif'
mask_path = 'mask.tif'

# Create dataset
dataset = RasterDataset(image_path, mask_path, patch_size=256)

# Split into training and validation
train_idx, val_idx = train_test_split(range(len(dataset)), test_size=0.2, random_state=42)
train_dataset = torch.utils.data.Subset(dataset, train_idx)
val_dataset = torch.utils.data.Subset(dataset, val_idx)

# Create dataloaders
train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=8, shuffle=False)

# Print dataset information
print(f'Total patches: {len(dataset)}')
print(f'Training patches: {len(train_dataset)}')
print(f'Validation patches: {len(val_dataset)}')

## Step 3: Initialize U-Net Model

Set up a U-Net model with a specified backbone and number of classes.

In [None]:
# Define model parameters
n_classes = int(dataset.mask.max()) + 1  # Assumes labels start at 0; covers label values missing from this mask
n_channels = dataset.image.shape[0]  # Number of input bands

# Initialize U-Net model
model = Unet(
    encoder_name='resnet18',  # Use ResNet18 backbone
    encoder_weights='imagenet',
    in_channels=n_channels,
    classes=n_classes
)

# Move model to GPU if available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)

# Define loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
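
The model and `nn.CrossEntropyLoss` expect mask labels to be class indices in the range 0 to K-1. If a mask uses arbitrary codes instead (for example legend values such as 10, 20, 50), one possible fix is to remap them before training. The following is a minimal sketch on a made-up array; the helper name `remap_labels` is hypothetical and not part of the notebook.

```python
import numpy as np

def remap_labels(mask):
    """Map arbitrary integer labels to contiguous indices 0..K-1.

    Returns the remapped mask and the sorted array of original codes,
    so that original_codes[new_label] recovers the original value.
    """
    original_codes, inverse = np.unique(mask, return_inverse=True)
    remapped = inverse.reshape(mask.shape).astype(np.int64)
    return remapped, original_codes

# Made-up mask with non-contiguous class codes
demo_mask = np.array([[10, 10, 50],
                      [20, 50, 50]])
remapped, original_codes = remap_labels(demo_mask)

print(remapped)
print(original_codes)
demo_n_classes = len(original_codes)
```

Keeping `original_codes` around makes it possible to translate predicted indices back to the original legend values before saving the result.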

## Step 4: Train the Model

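Train the U-Net model for a fixed number of epochs and print the mean per-sample training loss after each one. The validation loader from Step 2 is not used in this loop.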

In [None]:
# Training loop
n_epochs = 10
for epoch in range(n_epochs):
    model.train()
    train_loss = 0.0
    for images, masks in train_loader:
        images, masks = images.to(device), masks.to(device)
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, masks)
        loss.backward()
        optimizer.step()
        train_loss += loss.item() * images.size(0)
    
    train_loss /= len(train_loader.dataset)
    print(f'Epoch {epoch+1}/{n_epochs}, Training Loss: {train_loss:.4f}')

## Step 5: Predict and Visualize Segmentation

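Predict segmentation patch by patch across the raster and visualize the result. Incomplete patches along the right and bottom edges are skipped, so those pixels keep the initial value 0 in the output.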

In [None]:
# Load full raster for prediction
with rasterio.open(image_path) as src:
    full_image = src.read().astype(np.float32)
    profile = src.profile
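# Same per-band normalization as in RasterDataset (guarding against all-zero bands)
band_max = np.max(full_image, axis=(1, 2), keepdims=True)
full_image = full_image / np.maximum(band_max, 1e-8)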

# Initialize output array
height, width = full_image.shape[1], full_image.shape[2]
predictions = np.zeros((height, width), dtype=np.int64)

# Predict in patches
patch_size = 256  # Same patch size as used for training
model.eval()
with torch.no_grad():
    for i in range(0, height, patch_size):
        for j in range(0, width, patch_size):
            patch = full_image[:, i:i+patch_size, j:j+patch_size]
            if patch.shape[1:] != (patch_size, patch_size):
                continue  # Skip incomplete edge patches; they keep the initial value 0
            patch = torch.from_numpy(patch).unsqueeze(0).to(device)
            output = model(patch)
            pred = torch.argmax(output, dim=1).cpu().numpy()[0]
            predictions[i:i+patch_size, j:j+patch_size] = pred

# Visualize predictions
plt.figure(figsize=(8, 8))
plt.imshow(predictions, cmap='tab10')
plt.colorbar(label='Class')
plt.title('U-Net Segmentation Result')
plt.xlabel('Column')
plt.ylabel('Row')
plt.show()
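
Because the prediction loop skips any patch smaller than 256 x 256, pixels along the right and bottom edges are never classified. One common workaround is to pad the raster up to a multiple of the patch size, predict every tile, and crop the result back. The sketch below illustrates the idea with NumPy only. `predict_patch` is a hypothetical stand-in for the model call, and reflect padding is an assumption rather than the notebook's method.

```python
import numpy as np

def predict_patch(patch):
    """Hypothetical stand-in for the model: class 1 where band 0 > 0.5."""
    return (patch[0] > 0.5).astype(np.int64)

def predict_with_padding(image, patch_size=256):
    """Tile a (bands, height, width) array, padding edges so no pixel is skipped."""
    _, height, width = image.shape
    pad_h = (-height) % patch_size  # rows needed to reach a multiple of patch_size
    pad_w = (-width) % patch_size   # columns needed to reach a multiple of patch_size
    padded = np.pad(image, ((0, 0), (0, pad_h), (0, pad_w)), mode='reflect')

    out = np.zeros(padded.shape[1:], dtype=np.int64)
    for i in range(0, padded.shape[1], patch_size):
        for j in range(0, padded.shape[2], patch_size):
            patch = padded[:, i:i + patch_size, j:j + patch_size]
            out[i:i + patch_size, j:j + patch_size] = predict_patch(patch)

    # Crop the padding away so the output matches the input extent
    return out[:height, :width]

# Made-up 4-band image whose size is not a multiple of 256
rng = np.random.default_rng(0)
demo_image = rng.random((4, 300, 500)).astype(np.float32)
padded_predictions = predict_with_padding(demo_image, patch_size=256)
print(padded_predictions.shape)
```

In the notebook, the body of `predict_patch` would be replaced by the `model(...)` and `argmax` calls from the loop above.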

## Step 6: Save Segmentation Result

Save the segmentation result as a single-band GeoTIFF.

In [None]:
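# Update profile for single-band output
# uint8 is widely supported in GeoTIFF and holds up to 256 class labels;
# nodata is cleared because the source value may not fit in uint8
output_profile = profile.copy()
output_profile.update(count=1, dtype=rasterio.uint8, nodata=None)

# Save predictions
with rasterio.open('unet_segmentation.tif', 'w', **output_profile) as dst:
    dst.write(predictions.astype(np.uint8), 1)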

print('Segmentation result saved to: unet_segmentation.tif')

## Next Steps

- Replace `sample.tif` and `mask.tif` with your own image and labeled mask.
- Adjust patch size, number of epochs, or model architecture (e.g., `encoder_name`).
- Add validation metrics (e.g., IoU, accuracy) for model evaluation.
- Proceed to the next notebook (`16_landcover_classification_cnn.ipynb`) for CNN-based classification.
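
As a starting point for the validation metrics mentioned above, per-class IoU and overall pixel accuracy can both be derived from a confusion matrix. The sketch below uses small made-up label arrays, and the function name `confusion_matrix_metrics` is hypothetical. In practice the inputs would be the ground-truth masks and `argmax` predictions collected from `val_loader`.

```python
import numpy as np

def confusion_matrix_metrics(y_true, y_pred, n_classes):
    """Return (per-class IoU, pixel accuracy) for integer label arrays."""
    y_true = np.asarray(y_true).ravel()
    y_pred = np.asarray(y_pred).ravel()
    # Rows are true classes, columns are predicted classes
    cm = np.bincount(y_true * n_classes + y_pred,
                     minlength=n_classes ** 2).reshape(n_classes, n_classes)
    intersection = np.diag(cm)
    union = cm.sum(axis=0) + cm.sum(axis=1) - intersection
    # Classes absent from both arrays have union 0; report NaN for those
    iou = np.where(union > 0, intersection / np.maximum(union, 1), np.nan)
    accuracy = intersection.sum() / cm.sum()
    return iou, accuracy

# Made-up ground truth and predictions for three classes
y_true = np.array([[0, 0, 1], [1, 2, 2]])
y_pred = np.array([[0, 1, 1], [1, 2, 0]])
iou, accuracy = confusion_matrix_metrics(y_true, y_pred, n_classes=3)
print(iou, accuracy)
```

Averaging `iou` with `np.nanmean` gives a mean IoU that ignores classes missing from the validation patches.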

## Notes
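- Ensure the mask raster has contiguous integer class labels starting from 0 (0, 1, ..., K-1), since `nn.CrossEntropyLoss` expects class indices in that range.
- Normalize input data to improve training stability, and apply the same normalization at prediction time.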
- See `docs/installation.md` for troubleshooting library installation.
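
To illustrate the normalization note, the sketch below scales each band to the 0 to 1 range using per-band percentiles, which is less sensitive to outlier pixels than dividing by the band maximum. This is an alternative, not what the notebook does. The function name `normalize_bands` and the 2nd/98th percentile choice are assumptions.

```python
import numpy as np

def normalize_bands(image, lower=2, upper=98, eps=1e-8):
    """Scale each band of a (bands, height, width) array to [0, 1].

    Values are clipped to the per-band [lower, upper] percentiles first;
    eps guards against division by zero for constant bands.
    """
    lo = np.percentile(image, lower, axis=(1, 2), keepdims=True)
    hi = np.percentile(image, upper, axis=(1, 2), keepdims=True)
    clipped = np.clip(image, lo, hi)
    return ((clipped - lo) / np.maximum(hi - lo, eps)).astype(np.float32)

# Made-up 3-band image; the last band is constant to exercise the eps guard
rng = np.random.default_rng(0)
demo = rng.random((3, 64, 64)).astype(np.float32) * 1000
demo[2] = 5.0
normalized = normalize_bands(demo)
print(normalized.min(), normalized.max())
```

If a scheme like this replaces the division by the band maximum, it needs to be used both in `RasterDataset` and in the prediction cell.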