# Satellite Image Segmentation Using Deep Learning

## Introduction

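This notebook implements deep learning techniques (a U-Net trained with a Dice loss) for semantic segmentation of satellite imagery, alongside classical image-processing baselines such as edge detection and contour extraction.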

## Segmentation Methods

1. **Traditional Methods**: Thresholding, clustering, edge detection
2. **Deep Learning Methods**: U-Net, DeepLabV3+, SegNet

## Mathematical Foundation

In semantic segmentation, we assign a class label to each pixel in the image. For a pixel $(i,j)$, the predicted class is:

$$\hat{y}_{i,j} = \arg\max_{c} p(y_{i,j} = c | x)$$

where $p(y_{i,j} = c | x)$ is the probability that pixel $(i,j)$ belongs to class $c$ given input image $x$.

In [None]:
# Basic imports
import os
import numpy as np
import matplotlib.pyplot as plt
import random
import time

# Image processing
import cv2
from PIL import Image
from skimage import io, transform, segmentation, color
import rasterio
from rasterio.plot import show

# Deep learning imports
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchvision
import torchvision.transforms as transforms
import segmentation_models_pytorch as smp

# Metrics and evaluation
from sklearn.metrics import confusion_matrix, accuracy_score, jaccard_score
import seaborn as sns
from tqdm.notebook import tqdm

# Seed every RNG the notebook draws from (Python, NumPy, PyTorch) with the same
# value so that repeated runs produce the same data splits and weight inits.
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed):
    _seed_fn(42)
del _seed_fn

# Run on the GPU when CUDA is visible, otherwise on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")

print("Libraries imported successfully.")

# Load and Display Satellite Images

Satellite images require special handling due to their unique characteristics:

1. **Multiple bands**: Beyond RGB (e.g., NIR, SWIR)
2. **Large dimensions**: Often much larger than standard images
3. **Different resolutions**: Spatial, spectral, and temporal
4. **Georeferencing**: Contains geographical metadata

In [None]:
class SatelliteImageHandler:
    """
    A class for loading, displaying and compositing satellite images.

    GeoTIFFs (``.tif`` / ``.tiff``) are read with rasterio so that their
    georeferencing metadata is captured in ``self.metadata``; every other
    format is read with OpenCV.
    """

    def __init__(self, data_dir=None):
        """
        Initialize the satellite image handler.

        Args:
            data_dir (str, optional): Directory containing satellite images.
        """
        self.data_dir = data_dir
        # Filled in by load_image() after a GeoTIFF has been read with rasterio.
        self.metadata = None

    def load_image(self, image_path):
        """
        Load a satellite image from path.

        GeoTIFFs keep at most their first three bands, are transposed from
        rasterio's ``(C, H, W)`` layout to ``(H, W, C)`` and, when the maximum
        value is positive, rescaled to ``uint8`` by dividing by the global
        maximum. If rasterio fails, the file is retried with OpenCV.

        Args:
            image_path (str or os.PathLike): Path to the image file.

        Returns:
            numpy.ndarray: Loaded image in ``(H, W, C)`` layout.

        Raises:
            ValueError: If the OpenCV path cannot decode the file.
        """
        path = os.fspath(image_path)
        if not path.lower().endswith(('.tif', '.tiff')):
            # For regular image formats
            return self._load_with_opencv(path)

        try:
            with rasterio.open(path) as src:
                # Read all bands as (C, H, W)
                img = src.read()

                # Extract metadata
                self.metadata = {
                    'crs': src.crs,
                    'transform': src.transform,
                    'bounds': src.bounds,
                    'height': src.height,
                    'width': src.width
                }
        except Exception as e:  # deliberate best-effort: fall back to OpenCV
            print(f"Error loading GeoTIFF: {e}")
            return self._load_with_opencv(path)

        # If more than 3 bands, take first 3.
        # NOTE(review): assumes bands 1-3 are R, G, B; many sensors order bands
        # differently (e.g. Sentinel-2), so confirm for your product.
        if img.shape[0] > 3:
            img = img[:3]

        # Transpose to (H, W, C) format for display
        img = np.transpose(img, (1, 2, 0))

        # Stretch to 0-255 using the global maximum (all bands share one scale)
        if img.max() > 0:
            img = (img / img.max() * 255).astype(np.uint8)

        return img

    def _load_with_opencv(self, image_path):
        """Read ``image_path`` with OpenCV and convert from BGR to RGB order.

        Raises:
            ValueError: If ``cv2.imread`` cannot decode the file (it returns None).
        """
        img = cv2.imread(image_path)
        if img is None:
            raise ValueError(f"Could not load image from {image_path}")
        return cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

    def display_image(self, image=None, image_path=None, title="Satellite Image", figsize=(12, 10)):
        """
        Display a satellite image and print its shape, dtype and value range.

        Args:
            image (numpy.ndarray, optional): Image to display.
            image_path (str, optional): Path to image file (used only if
                ``image`` is not provided).
            title (str): Title for the plot.
            figsize (tuple): Figure size.
        """
        if image is None and image_path is not None:
            image = self.load_image(image_path)

        if image is None:
            print("No image provided to display")
            return

        plt.figure(figsize=figsize)
        plt.imshow(image)
        plt.title(title)
        plt.axis('off')
        plt.tight_layout()
        plt.show()

        # Print image stats
        print(f"Image shape: {image.shape}")
        print(f"Image type: {image.dtype}")
        print(f"Value range: [{image.min()}, {image.max()}]")

    def create_false_color(self, img, bands=(3, 2, 1)):
        """
        Create a false color composite from a multi-band image.

        Unlike ``load_image``'s output, ``img`` here is expected in band-first
        ``(C, H, W)`` layout (as returned by ``rasterio``'s ``read()``), with at
        least four bands.

        Args:
            img (numpy.ndarray): Multi-band ``(C, H, W)`` image.
            bands (sequence of int): 1-based band numbers mapped to the output
                channels, in order (default ``(3, 2, 1)``).

        Returns:
            numpy.ndarray: ``(H, W, len(bands))`` ``uint8`` composite in which
            each band is independently min-max stretched to 0-255 (constant
            bands become 0). If ``img`` has fewer than four bands it is
            returned unchanged.
        """
        if len(img.shape) < 3 or img.shape[0] < 4:  # Not enough bands
            print("Not enough bands for false color composite")
            return img

        # Select bands (convert 1-based band numbers to 0-based indices)
        indices = [b - 1 for b in bands]
        selected_bands = img[indices]

        # Normalize each band independently to [0, 1]
        normalized = np.zeros_like(selected_bands, dtype=np.float32)
        for i in range(len(indices)):
            band = selected_bands[i]
            if band.max() > band.min():
                normalized[i] = (band - band.min()) / (band.max() - band.min())

        # Convert to 8-bit and transpose to (H, W, C)
        normalized = (normalized * 255).astype(np.uint8)
        return np.transpose(normalized, (1, 2, 0))

# Try to load the real scene; if it is missing, synthesise a placeholder so
# that the downstream cells still have an `image_rgb` to work with.
handler = SatelliteImageHandler()

# Define the image path - update with your actual image path
image_path = '../data/sentinelsat/sample_image.tif'

if not os.path.exists(image_path):
    print(f"Image file {image_path} not found. Please update the path.")
    print("Using a sample image for demonstration...")
    # Mid-grey 512x512 RGB canvas with a filled circle and a filled rectangle
    sample_img = np.full((512, 512, 3), 128, dtype=np.uint8)
    cv2.circle(sample_img, (256, 256), 100, (200, 100, 100), -1)
    cv2.rectangle(sample_img, (100, 100), (400, 150), (100, 200, 100), -1)
    image_rgb = sample_img
    handler.display_image(image_rgb, title="Sample Image (Placeholder)")
else:
    image_rgb = handler.load_image(image_path)
    handler.display_image(image_rgb, title="Loaded Satellite Image")

In [None]:
class SatelliteImagePreprocessor:
    """
    Class for preprocessing satellite images for segmentation tasks.

    Bundles per-channel normalisation, overlapping patch extraction, random
    flip/rotate augmentation and an NDVI helper. Images are indexed as
    ``(H, W, C)`` numpy arrays; masks share the image's first two (H, W) axes.
    """

    def __init__(self, patch_size=256, overlap=0.2, augment=True):
        """
        Initialize the preprocessor.

        Args:
            patch_size (int): Size of patches to extract (square).
            overlap (float): Fractional overlap between neighbouring patches (0-1).
            augment (bool): Whether to apply data augmentation.
        """
        self.patch_size = patch_size
        self.overlap = overlap
        self.augment = augment

    def normalize_image(self, image, method='minmax'):
        """
        Normalize image pixel values channel by channel.

        Args:
            image (numpy.ndarray): ``(H, W, C)`` image; a 2-D ``(H, W)`` image
                is treated as a single channel.
            method (str): ``'minmax'`` scales each channel to [0, 1] (constant
                channels are left unchanged); ``'std'`` standardises each
                channel to zero mean / unit variance; ``'imagenet'`` divides by
                255 and applies the ImageNet mean/std to the first three
                channels (skipped when there are fewer than three). Any other
                value returns an unmodified float32 copy.

        Returns:
            numpy.ndarray: float32 copy with the same shape as ``image``, or
            ``None`` if ``image`` is ``None``.
        """
        if image is None:
            return None

        # np.array(..., dtype=...) always copies, so the caller's array is untouched.
        img = np.array(image, dtype=np.float32)
        is_2d = img.ndim == 2
        if is_2d:
            # View with a channel axis; writes below land in the copy.
            img = img[:, :, np.newaxis]

        if method == 'minmax':
            for c in range(img.shape[2]):
                band = img[:, :, c]
                lo, hi = band.min(), band.max()
                if hi > lo:
                    img[:, :, c] = (band - lo) / (hi - lo)

        elif method == 'std':
            for c in range(img.shape[2]):
                band = img[:, :, c]
                # Epsilon avoids division by zero for constant channels.
                img[:, :, c] = (band - band.mean()) / (band.std() + 1e-8)

        elif method == 'imagenet':
            # ImageNet normalization (assumes RGB input on a 0-255 scale)
            if img.shape[2] >= 3:
                img = img / 255.0
                mean = np.array([0.485, 0.456, 0.406])
                std = np.array([0.229, 0.224, 0.225])
                for c in range(3):
                    img[:, :, c] = (img[:, :, c] - mean[c]) / std[c]

        return img[:, :, 0] if is_2d else img

    def extract_patches(self, image, mask=None):
        """
        Extract overlapping ``patch_size`` x ``patch_size`` patches.

        Patch origins follow a regular grid with stride
        ``int(patch_size * (1 - overlap))`` (at least 1). When the grid does
        not end exactly on the bottom/right border, one extra row/column of
        patches is anchored to that border, so every pixel is covered and no
        patch is duplicated. Patches are returned in row-major order. A
        dimension smaller than ``patch_size`` yields no patches.

        Args:
            image (numpy.ndarray): Input image.
            mask (numpy.ndarray, optional): Input mask aligned with ``image``.

        Returns:
            list, or tuple(list, list): image patches, or
            ``(image_patches, mask_patches)`` when ``mask`` is given. Patches
            are views into the inputs. Returns ``[]`` if ``image`` is None.
        """
        if image is None:
            return []

        h, w = image.shape[:2]
        size = self.patch_size
        # Guard against overlap >= 1, which would give a zero/negative stride.
        stride = max(1, int(size * (1 - self.overlap)))

        ys = self._patch_starts(h, size, stride)
        xs = self._patch_starts(w, size, stride)

        image_patches = [image[y:y + size, x:x + size] for y in ys for x in xs]
        if mask is None:
            return image_patches

        mask_patches = [mask[y:y + size, x:x + size] for y in ys for x in xs]
        return image_patches, mask_patches

    @staticmethod
    def _patch_starts(length, size, stride):
        """Return window start offsets tiling ``[0, length)``; the last window ends at ``length``."""
        if length < size:
            return []
        starts = list(range(0, length - size + 1, stride))
        if starts[-1] != length - size:
            starts.append(length - size)
        return starts

    def augment_data(self, image, mask=None):
        """
        Apply one random augmentation to image and mask.

        One of horizontal flip, vertical flip, rotation by 90/180/270 degrees
        or no-op is drawn from NumPy's global RNG; the same transform is
        applied to ``mask``.

        Args:
            image (numpy.ndarray): Input image.
            mask (numpy.ndarray, optional): Input mask.

        Returns:
            tuple: ``(image, mask)``; ``mask`` is ``None`` if not provided.
            Arrays are C-contiguous so they can be passed to ``torch.from_numpy``
            (which rejects the negative strides produced by flips/rotations).
        """
        if not self.augment or image is None:
            return image, mask

        # Choose a random augmentation
        aug_type = np.random.choice(['flip_h', 'flip_v', 'rotate', 'none'])

        if aug_type == 'flip_h':
            image = np.fliplr(image)
            if mask is not None:
                mask = np.fliplr(mask)

        elif aug_type == 'flip_v':
            image = np.flipud(image)
            if mask is not None:
                mask = np.flipud(mask)

        elif aug_type == 'rotate':
            k = np.random.choice([1, 2, 3])  # 90, 180, 270 degrees
            image = np.rot90(image, k=k)
            if mask is not None:
                mask = np.rot90(mask, k=k)

        # No copy is made when the array is already contiguous (e.g. 'none').
        image = np.ascontiguousarray(image)
        if mask is not None:
            mask = np.ascontiguousarray(mask)
        return image, mask

    def create_ndvi_band(self, image, nir_idx=3, red_idx=0):
        """
        Create an NDVI band from the NIR and Red channels.

        NDVI = (NIR - Red) / (NIR + Red), with pixels where NIR + Red == 0
        divided by 1 instead, then rescaled from [-1, 1] to [0, 1].

        Args:
            image (numpy.ndarray): Multi-band ``(H, W, C)`` satellite image.
            nir_idx (int): Channel index of the NIR band.
            red_idx (int): Channel index of the Red band.

        Returns:
            numpy.ndarray or None: ``(H, W)`` float32 NDVI in [0, 1], or
            ``None`` if the image lacks the requested channels.
        """
        # Check if image has enough bands
        if len(image.shape) < 3 or image.shape[2] <= max(nir_idx, red_idx):
            print("Not enough bands for NDVI calculation")
            return None

        nir = image[:, :, nir_idx].astype(np.float32)
        red = image[:, :, red_idx].astype(np.float32)

        # Avoid division by zero
        denominator = nir + red
        denominator[denominator == 0] = 1

        ndvi = (nir - red) / denominator

        # Scale from [-1, 1] to [0, 1] for visualization
        return (ndvi + 1) / 2

    def apply_preprocessing(self, image, mask=None, normalize_method='minmax', extract_patches=True):
        """
        Normalize, optionally tile, and optionally augment an image (and mask).

        Args:
            image (numpy.ndarray): Input image.
            mask (numpy.ndarray, optional): Input mask (not normalized).
            normalize_method (str): Method passed to ``normalize_image``.
            extract_patches (bool): Whether to extract patches.

        Returns:
            With ``extract_patches`` and a mask: ``(image_patches, mask_patches)``.
            With ``extract_patches`` and no mask: list of image patches.
            Without ``extract_patches``: ``(image, mask)``.
            ``(None, None)`` if ``image`` is ``None``.
        """
        if image is None:
            return None, None

        norm_image = self.normalize_image(image, method=normalize_method)

        if not extract_patches:
            # Just normalize and potentially augment
            if self.augment:
                norm_image, mask = self.augment_data(norm_image, mask)
            return norm_image, mask

        if mask is not None:
            image_patches, mask_patches = self.extract_patches(norm_image, mask)
            if self.augment:
                for i in range(len(image_patches)):
                    image_patches[i], mask_patches[i] = self.augment_data(image_patches[i], mask_patches[i])
            return image_patches, mask_patches

        image_patches = self.extract_patches(norm_image)
        if self.augment:
            for i in range(len(image_patches)):
                image_patches[i], _ = self.augment_data(image_patches[i])
        return image_patches

# Create a preprocessor (augmentation disabled so the previews show untouched patches)
preprocessor = SatelliteImagePreprocessor(patch_size=256, overlap=0.2, augment=False)

# Everything below needs the `image_rgb` produced by the loading cell.
if 'image_rgb' not in locals():
    print("No image available for preprocessing.")
else:
    # Per-channel min-max scaling (the default method), then tiling into patches
    normalized_image = preprocessor.normalize_image(image_rgb)
    image_patches = preprocessor.extract_patches(normalized_image)

    plt.figure(figsize=(10, 6))
    plt.imshow(normalized_image)
    plt.title('Normalized Image')
    plt.axis('off')
    plt.show()

    # Preview up to four patches side by side
    n_patches = min(len(image_patches), 4)
    if n_patches == 0:
        print("No patches were extracted. The image may be too small.")
    else:
        plt.figure(figsize=(15, 4))
        for i, patch in enumerate(image_patches[:n_patches]):
            plt.subplot(1, n_patches, i + 1)
            plt.imshow(patch)
            plt.title(f"Patch {i+1}")
            plt.axis('off')
        plt.tight_layout()
        plt.show()

        print(f"Total patches extracted: {len(image_patches)}")

    # NDVI needs a 4th (NIR) channel; RGB-only input falls back to grayscale.
    if image_rgb.ndim < 3 or image_rgb.shape[2] < 4:
        print("Not enough bands to calculate NDVI. Using grayscale instead.")
        gray_image = cv2.cvtColor(image_rgb, cv2.COLOR_RGB2GRAY)
        plt.figure(figsize=(10, 6))
        plt.imshow(gray_image, cmap='gray')
        plt.title('Grayscale Image')
        plt.axis('off')
        plt.show()
    else:
        ndvi = preprocessor.create_ndvi_band(image_rgb)
        if ndvi is not None:
            plt.figure(figsize=(10, 6))
            plt.imshow(ndvi, cmap='RdYlGn')
            plt.title('NDVI')
            plt.colorbar(label='NDVI')
            plt.axis('off')
            plt.show()

# Implementing U-Net for Satellite Image Segmentation

U-Net is a convolutional neural network originally designed for biomedical image segmentation (Ronneberger et al., 2015); its encoder-decoder structure with skip connections also works well for satellite imagery.

## U-Net Architecture

The U-Net architecture consists of:
1. **Encoder Path (Contracting)**: Captures context through downsampling
2. **Decoder Path (Expanding)**: Precise localization through upsampling
3. **Skip Connections**: Combine high-resolution features with upsampled features

## Mathematical Background

The U-Net is trained to minimize a loss function between the predicted segmentation $\hat{y}$ and ground truth $y$:

$$\mathcal{L}(\hat{y}, y) = -\sum_{i,j} w_{i,j} [y_{i,j} \log(\hat{y}_{i,j}) + (1 - y_{i,j}) \log(1 - \hat{y}_{i,j})]$$

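where $w_{i,j}$ are weights that can be used to emphasize certain pixels or classes. Note that the training code below instead minimizes the soft Dice loss

$$\mathcal{L}_{\text{Dice}}(\hat{y}, y) = 1 - \frac{2\sum_{i,j} \hat{y}_{i,j}\, y_{i,j} + s}{\sum_{i,j} \hat{y}_{i,j} + \sum_{i,j} y_{i,j} + s}$$

where $s$ is a small smoothing constant.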

In [None]:
class DoubleConv(nn.Module):
    """
    Two stacked ``3x3 Conv2d -> BatchNorm2d -> ReLU`` units.

    ``padding=1`` keeps the spatial size unchanged; the first convolution maps
    ``in_channels`` to ``out_channels`` and the second keeps ``out_channels``.
    """
    def __init__(self, in_channels, out_channels):
        super().__init__()
        layers = []
        for conv_in in (in_channels, out_channels):
            layers.extend([
                nn.Conv2d(conv_in, out_channels, kernel_size=3, padding=1),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
            ])
        # Attribute name kept as `double_conv` so state_dict keys stay the same.
        self.double_conv = nn.Sequential(*layers)

    def forward(self, x):
        out = self.double_conv(x)
        return out

class Down(nn.Module):
    """
    Encoder stage: 2x2 max-pooling (halving H and W) followed by a DoubleConv.
    """
    def __init__(self, in_channels, out_channels):
        super().__init__()
        pool = nn.MaxPool2d(kernel_size=2)
        convs = DoubleConv(in_channels, out_channels)
        # Attribute name kept as `maxpool_conv` so state_dict keys stay the same.
        self.maxpool_conv = nn.Sequential(pool, convs)

    def forward(self, x):
        out = self.maxpool_conv(x)
        return out

class Up(nn.Module):
    """
    Decoder stage: upsample by 2, pad to the skip tensor's size, concatenate, DoubleConv.

    Args:
        in_channels: channels entering the DoubleConv after concatenation;
            also the input channels of the transposed convolution when
            ``bilinear`` is False.
        out_channels: channels produced by the stage (and by the transposed
            convolution when ``bilinear`` is False).
        bilinear: use parameter-free bilinear upsampling instead of
            ``ConvTranspose2d``.
    """
    def __init__(self, in_channels, out_channels, bilinear=True):
        super().__init__()
        self.up = (
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
            if bilinear
            else nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2)
        )
        # Both upsampling variants feed an identical DoubleConv.
        self.conv = DoubleConv(in_channels, out_channels)

    def forward(self, x1, x2):
        upsampled = self.up(x1)

        # Size mismatch (e.g. from odd input sizes) between skip tensor and upsampled map
        dy = x2.size()[2] - upsampled.size()[2]
        dx = x2.size()[3] - upsampled.size()[3]

        # F.pad order for the last two dims is (left, right, top, bottom);
        # the gap is split as evenly as possible.
        padding = [dx // 2, dx - dx // 2, dy // 2, dy - dy // 2]
        upsampled = F.pad(upsampled, padding)

        # Skip features first, then upsampled features, along the channel dim
        return self.conv(torch.cat([x2, upsampled], dim=1))

class OutConv(nn.Module):
    """
    Final 1x1 convolution mapping feature channels to per-class logits.
    """
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, 1)

    def forward(self, x):
        logits = self.conv(x)
        return logits

class UNet(nn.Module):
    """
    U-Net: a four-level encoder/decoder with skip connections.

    Args:
        n_channels: number of input image channels.
        n_classes: number of output logit maps (1 for binary segmentation).
        bilinear: if True, decoder stages upsample bilinearly and the deepest
            widths are halved (``factor = 2``); otherwise ConvTranspose2d is used.

    ``forward`` returns raw logits (no sigmoid/softmax applied).
    """
    def __init__(self, n_channels=3, n_classes=1, bilinear=True):
        super().__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.bilinear = bilinear
        factor = 2 if bilinear else 1

        # Encoder (downsampling path); construction order fixes parameter init order
        self.inc = DoubleConv(n_channels, 64)
        self.down1 = Down(64, 128)
        self.down2 = Down(128, 256)
        self.down3 = Down(256, 512)
        self.down4 = Down(512, 1024 // factor)

        # Decoder (upsampling path)
        self.up1 = Up(1024, 512 // factor, bilinear)
        self.up2 = Up(512, 256 // factor, bilinear)
        self.up3 = Up(256, 128 // factor, bilinear)
        self.up4 = Up(128, 64, bilinear)
        self.outc = OutConv(64, n_classes)

    def forward(self, x):
        # Encoder: keep every stage's output for the skip connections
        features = [self.inc(x)]
        for stage in (self.down1, self.down2, self.down3, self.down4):
            features.append(stage(features[-1]))

        # Decoder: start from the bottleneck, consume skips deepest-first
        out = features.pop()
        for stage, skip in zip((self.up1, self.up2, self.up3, self.up4), reversed(features)):
            out = stage(out, skip)

        return self.outc(out)

# Define Dice Loss for segmentation
class DiceLoss(nn.Module):
    """
    Soft Dice loss: ``1 - (2 * sum(P*T) + s) / (sum(P) + sum(T) + s)``.

    All elements of the batch are flattened together, so the overlap is
    computed over the whole batch rather than per sample.

    Args:
        smooth (float): Additive smoothing ``s``; keeps the ratio defined when
            both predictions and targets are all zero.
    """
    def __init__(self, smooth=1.0):
        super().__init__()
        self.smooth = smooth

    def forward(self, predictions, targets):
        """
        Args:
            predictions (torch.Tensor): expected to be probabilities in [0, 1]
                (this notebook applies ``torch.sigmoid`` before calling).
            targets (torch.Tensor): binary targets with the same number of elements.

        Returns:
            torch.Tensor: scalar loss.
        """
        # reshape instead of view: view raises on non-contiguous tensors
        # (e.g. after permute/transpose), reshape copies only when needed.
        predictions = predictions.reshape(-1)
        targets = targets.reshape(-1)

        intersection = (predictions * targets).sum()
        union = predictions.sum() + targets.sum()

        dice = (2. * intersection + self.smooth) / (union + self.smooth)
        return 1 - dice

# Create a simple dataset class for demonstration
class SatelliteSegmentationDataset(Dataset):
    """
    Pair in-memory images with binary masks for segmentation training.

    Args:
        images: sequence of images, each an ``(H, W, C)`` or ``(H, W)`` array
            (anything ``np.asarray`` accepts), or a ``(C, H, W)`` tensor.
        masks: sequence of ``(H, W)`` masks (arrays or tensors).
        transform: optional callable applied to the image only.

    ``__getitem__`` returns ``(image, mask)``. ``image`` is a 3-channel float
    tensor (1 channel is repeated, extra channels dropped), divided by 255
    when its maximum exceeds 1. ``mask`` has a leading channel axis and is
    binarised (``> 0``) when its maximum exceeds 1.
    """
    def __init__(self, images, masks, transform=None):
        self.images = images
        self.masks = masks
        self.transform = transform

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        image = self.images[idx]
        mask = self.masks[idx]

        if self.transform:
            image = self.transform(image)

        # Convert to tensor if not already. ascontiguousarray (a no-op for
        # contiguous input) is needed because torch.from_numpy rejects the
        # negative strides produced by np.fliplr / np.rot90.
        if not isinstance(image, torch.Tensor):
            image = np.ascontiguousarray(np.asarray(image))
            if image.ndim == 2:
                # Grayscale (H, W) -> (H, W, 1) so the CHW transpose works
                image = image[:, :, np.newaxis]
            image = torch.from_numpy(image.transpose((2, 0, 1))).float()

        if not isinstance(mask, torch.Tensor):
            mask = torch.from_numpy(np.ascontiguousarray(mask)).float()

        # Ensure image has 3 channels
        if image.shape[0] == 1:
            image = image.repeat(3, 1, 1)
        elif image.shape[0] > 3:
            image = image[:3]

        # Heuristic: values above 1 are assumed to be on a 0-255 scale
        if image.max() > 1.0:
            image = image / 255.0

        # Ensure mask is binary
        if mask.max() > 1.0:
            mask = (mask > 0).float()

        return image, mask.unsqueeze(0)  # Add channel dimension to mask

# Function to train the model
def train_model(model, train_loader, val_loader, criterion, optimizer, device, num_epochs=5):
    """
    Train a binary segmentation model (e.g. the U-Net) and record per-epoch metrics.

    ``criterion`` is always called as ``criterion(torch.sigmoid(logits), masks)``,
    i.e. on probabilities, so it must be a probability-space loss such as
    ``DiceLoss`` (not a logits-based loss like ``BCEWithLogitsLoss``).

    Args:
        model: Model to train; maps an image batch to one logit map per image.
        train_loader: Training data loader yielding ``(images, masks)``.
        val_loader: Validation data loader yielding ``(images, masks)``.
        criterion: Loss function applied to sigmoid probabilities.
        optimizer: Optimizer over ``model``'s parameters.
        device: Device to use.
        num_epochs: Number of epochs to train.

    Returns:
        tuple: ``(model, history)``. ``history`` maps ``'train_loss'``,
        ``'val_loss'`` and ``'val_dice'`` to per-epoch lists. Losses are means
        of per-batch values. ``'val_dice'`` is the soft Dice coefficient
        (smoothing 1.0) averaged over validation batches. It is computed
        independently of ``criterion``, so it stays meaningful for any loss.
        An empty loader yields ``nan`` for its metrics.
    """
    # Same smoothing as DiceLoss() so val_dice equals 1 - DiceLoss for that criterion
    smooth = 1.0

    model = model.to(device)
    history = {'train_loss': [], 'val_loss': [], 'val_dice': []}

    for epoch in range(num_epochs):
        model.train()
        train_loss = 0.0

        # Training loop
        for images, masks in tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}"):
            images = images.to(device)
            masks = masks.to(device)

            outputs = model(images)
            loss = criterion(torch.sigmoid(outputs), masks)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            train_loss += loss.item()

        n_train = len(train_loader)
        train_loss = train_loss / n_train if n_train else float('nan')
        history['train_loss'].append(train_loss)

        # Validation loop
        model.eval()
        val_loss = 0.0
        val_dice = 0.0

        with torch.no_grad():
            for images, masks in val_loader:
                images = images.to(device)
                masks = masks.to(device)

                probs = torch.sigmoid(model(images))
                val_loss += criterion(probs, masks).item()

                # Soft Dice over the whole batch, computed directly (not via the loss)
                intersection = (probs * masks).sum()
                dice = (2.0 * intersection + smooth) / (probs.sum() + masks.sum() + smooth)
                val_dice += dice.item()

        n_val = len(val_loader)
        val_loss = val_loss / n_val if n_val else float('nan')
        val_dice = val_dice / n_val if n_val else float('nan')

        history['val_loss'].append(val_loss)
        history['val_dice'].append(val_dice)

        print(f"Epoch {epoch+1}/{num_epochs} - "
              f"Train Loss: {train_loss:.4f}, "
              f"Val Loss: {val_loss:.4f}, "
              f"Val Dice: {val_dice:.4f}")

    return model, history

# Generate some synthetic data for demonstration
def generate_synthetic_data(num_samples=100, img_size=128):
    """
    Build a toy segmentation dataset of random filled circles and squares.

    Each image is an ``(img_size, img_size, 3)`` float32 array on a 0.1
    background with 1-4 shapes in random colours (each channel in [0.2, 1.0]).
    Each mask is an ``(img_size, img_size)`` float32 array that is 1.0 where a
    shape was drawn and 0.0 elsewhere. Randomness comes from NumPy's global RNG.

    Returns:
        tuple[list, list]: ``(images, masks)``, each of length ``num_samples``.
    """
    images, masks = [], []
    lo, hi = 10, img_size - 10  # keep shape anchors away from the border

    for _ in range(num_samples):
        canvas = np.full((img_size, img_size, 3), 0.1, dtype=np.float32)
        label = np.zeros((img_size, img_size), dtype=np.float32)

        for _ in range(np.random.randint(1, 5)):
            # Draw order matters: shape kind, x, y, size, colour (fixed RNG sequence)
            kind = np.random.choice(['circle', 'rectangle'])
            cx = np.random.randint(lo, hi)
            cy = np.random.randint(lo, hi)
            extent = np.random.randint(10, 40)
            tint = np.random.rand(3) * 0.8 + 0.2

            if kind == 'circle':
                cv2.circle(canvas, (cx, cy), extent, tint, -1)
                cv2.circle(label, (cx, cy), extent, 1, -1)
            else:
                # Square with top-left corner at (cx, cy); OpenCV clips to the image
                corner = (cx + extent, cy + extent)
                cv2.rectangle(canvas, (cx, cy), corner, tint, -1)
                cv2.rectangle(label, (cx, cy), corner, 1, -1)

        images.append(canvas)
        masks.append(label)

    return images, masks

# Create synthetic dataset for demonstration
# NOTE: this cell consumes the global NumPy and PyTorch RNGs (data generation,
# model weight initialisation, DataLoader shuffling, choice of test sample), so
# reordering its statements changes the results.
print("Generating synthetic data for demonstration...")
synthetic_images, synthetic_masks = generate_synthetic_data(num_samples=100, img_size=128)

# Display a few examples: each image with its ground-truth mask outline in red
plt.figure(figsize=(15, 5))
for i in range(3):
    plt.subplot(1, 3, i+1)
    plt.imshow(synthetic_images[i])
    plt.contour(synthetic_masks[i], colors='r', levels=[0.5])
    plt.title(f"Sample {i+1}")
    plt.axis('off')
plt.tight_layout()
plt.show()

# Split into training and validation sets: first 80% / last 20%.
# Samples are generated independently, so a sequential split is fine here.
train_size = int(0.8 * len(synthetic_images))
train_images = synthetic_images[:train_size]
train_masks = synthetic_masks[:train_size]
val_images = synthetic_images[train_size:]
val_masks = synthetic_masks[train_size:]

# Create datasets
train_dataset = SatelliteSegmentationDataset(train_images, train_masks)
val_dataset = SatelliteSegmentationDataset(val_images, val_masks)

# Create data loaders (only the training set is shuffled)
train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=8, shuffle=False)

# Initialize the model: 3 input channels, 1 logit map for binary segmentation
model = UNet(n_channels=3, n_classes=1, bilinear=True)

# Define loss function and optimizer
criterion = DiceLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-4)

# Train the model (using a small number of epochs for demonstration)
print("\nTraining U-Net model...")
model, history = train_model(model, train_loader, val_loader, criterion, optimizer, device, num_epochs=2)

# Plot training history: losses on the left, validation Dice on the right
plt.figure(figsize=(12, 5))

plt.subplot(1, 2, 1)
plt.plot(history['train_loss'], label='Train Loss')
plt.plot(history['val_loss'], label='Validation Loss')
plt.title('Loss History')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()

plt.subplot(1, 2, 2)
plt.plot(history['val_dice'], label='Validation Dice')
plt.title('Dice Coefficient')
plt.xlabel('Epoch')
plt.ylabel('Dice')
plt.legend()

plt.tight_layout()
plt.show()

# Test the model on a randomly chosen validation sample
model.eval()
with torch.no_grad():
    sample_idx = np.random.randint(len(val_images))
    sample_image = val_images[sample_idx]
    sample_mask = val_masks[sample_idx]
    
    # Convert HWC float32 image to a (1, C, H, W) batch. The synthetic images
    # are already in [0, 1], so no /255 scaling is needed here.
    sample_tensor = torch.from_numpy(sample_image.transpose((2, 0, 1))).float().unsqueeze(0).to(device)
    
    # Predict: sigmoid probabilities thresholded at 0.5 -> boolean (H, W) mask
    output = model(sample_tensor)
    pred_mask = torch.sigmoid(output).cpu().squeeze().numpy() > 0.5
    
    # Display results
    plt.figure(figsize=(15, 5))
    
    plt.subplot(1, 3, 1)
    plt.imshow(sample_image)
    plt.title('Original Image')
    plt.axis('off')
    
    plt.subplot(1, 3, 2)
    plt.imshow(sample_mask, cmap='gray')
    plt.title('Ground Truth')
    plt.axis('off')
    
    plt.subplot(1, 3, 3)
    plt.imshow(pred_mask, cmap='gray')
    plt.title('Predicted Mask')
    plt.axis('off')
    
    plt.tight_layout()
    plt.show()
    
    # Hard Dice on the thresholded prediction; 1e-8 guards against an empty union
    dice = 2 * np.sum(pred_mask * sample_mask) / (np.sum(pred_mask) + np.sum(sample_mask) + 1e-8)
    print(f"Dice coefficient on test sample: {dice:.4f}")

# Perform Edge Detection
Use Canny edge detection to find edges in the binary image.

In [None]:
# `binary_image` is not created by any earlier cell of this notebook. If it is
# missing, derive it from `image_rgb` with Otsu's automatic global threshold
# (single-channel uint8, values 0/255).
if 'binary_image' not in globals():
    _gray_for_binary = cv2.cvtColor(image_rgb, cv2.COLOR_RGB2GRAY)
    _otsu_thresh, binary_image = cv2.threshold(
        _gray_for_binary, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU
    )

# Canny with hysteresis thresholds 100 (low) and 200 (high)
edges = cv2.Canny(binary_image, 100, 200)

# Display the edges
plt.figure(figsize=(10, 6))
plt.imshow(edges, cmap='gray')
plt.title('Edge Detection using Canny')
plt.axis('off')  # Hide the axis
plt.show()

# Segment the Image Using Contours
Find and draw contours on the original image to segment it.

In [None]:
# `binary_image` is not created by any earlier cell of this notebook. If it is
# missing, derive it from `image_rgb` with Otsu's automatic global threshold
# (single-channel uint8, values 0/255).
if 'binary_image' not in globals():
    _gray_for_binary = cv2.cvtColor(image_rgb, cv2.COLOR_RGB2GRAY)
    _otsu_thresh, binary_image = cv2.threshold(
        _gray_for_binary, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU
    )

# OpenCV 3 returns (image, contours, hierarchy) while OpenCV 4 returns
# (contours, hierarchy); indexing [-2] picks the contours in both.
contours = cv2.findContours(binary_image, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)[-2]

# Draw all contours in green (thickness 2) on a copy of the original image
image_with_contours = image_rgb.copy()
cv2.drawContours(image_with_contours, contours, -1, (0, 255, 0), 2)

# image with contours
plt.figure(figsize=(10, 6))
plt.imshow(image_with_contours)
plt.title('Image with Contours')
plt.axis('off')
plt.show()