# Satellite Image Change Detection Demo

This notebook demonstrates how to use our satellite image change detection system. We'll go through the following steps:

1. Setup and environment preparation
2. Data loading and preprocessing
3. Model training (small demo)
4. Running inference on test images
5. Visualizing and evaluating results
6. Post-processing for output refinement

Let's begin!

## 1. Setup and Environment Preparation

First, let's set up our environment and import necessary libraries.

In [None]:
import os
import sys
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
import torch
import rasterio
from rasterio.plot import show
import geopandas as gpd
from tqdm.notebook import tqdm

# Add the project root to path to import our modules
project_root = os.path.abspath(os.path.join(os.getcwd(), '..'))
if project_root not in sys.path:
    sys.path.append(project_root)

# Import our modules
from changedetect.src.models.siamese_unet import get_change_detection_model
from changedetect.src.data.dataset import create_dataloaders, create_test_dataloader
from changedetect.src.data.preprocess import preprocess_pair
from changedetect.src.utils.metrics import calculate_all_metrics
from changedetect.src.utils.postprocessing import process_change_detection_mask

# Set paths
DATA_DIR = Path("../data")
OUTPUT_DIR = Path("../outputs")

# Create directories if they don't exist
DATA_DIR.mkdir(exist_ok=True, parents=True)
OUTPUT_DIR.mkdir(exist_ok=True, parents=True)

# Check for CUDA availability
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

## 2. Data Loading and Preprocessing

Now, let's load some sample data and preprocess it. If you don't have any sample data, you can download it using our data download module.

In [None]:
# Example to download data (commented out)
# from changedetect.src.data.download import download_sentinel2_copernicus
# download_sentinel2_copernicus(
#     output_dir=str(DATA_DIR / "raw"),
#     lat=28.1740, lon=77.6126,  # Urban area example
#     date_range=("2023-01-01", "2023-02-01"),
#     cloud_cover_max=10
# )

In [None]:
# For this demo, let's use a pre-downloaded image pair
# Replace with your actual image paths or use sample data
image_t1_path = DATA_DIR / "samples" / "urban_t1.tif"
image_t2_path = DATA_DIR / "samples" / "urban_t2.tif"
mask_path = DATA_DIR / "samples" / "urban_mask.tif"

# Check if files exist
if not image_t1_path.exists() or not image_t2_path.exists():
    print("Sample files not found. Please add sample data or modify the paths.")
else:
    # Display the images
    fig, axs = plt.subplots(1, 3, figsize=(15, 5))
    
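    # Show only the first three bands (assumed to be an RGB-like composite):
    # matplotlib cannot display arrays with more than four channels, so a
    # full multispectral stack would fail to plot
    with rasterio.open(image_t1_path) as src:
        img_t1 = src.read()
        show(img_t1[:3], ax=axs[0])
        axs[0].set_title("Time 1 Image")
    
    with rasterio.open(image_t2_path) as src:
        img_t2 = src.read()
        show(img_t2[:3], ax=axs[1])
        axs[1].set_title("Time 2 Image")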
    
    if mask_path.exists():
        with rasterio.open(mask_path) as src:
            mask = src.read(1)
            axs[2].imshow(mask, cmap='viridis')
            axs[2].set_title("Ground Truth Change Mask")
    else:
        axs[2].set_title("No Ground Truth Available")
        axs[2].axis('off')
    
    plt.tight_layout()
    plt.show()

### Preprocessing the Images

Now, let's preprocess the images before feeding them to the model. This includes co-registration, normalization, and other preprocessing steps.
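The internals of `preprocess_pair` aren't shown here, so for intuition about the co-registration step, here is a minimal sketch of phase correlation, a standard FFT-based way to estimate the translation between two acquisitions. It assumes a single band, a pure integer-pixel shift (no rotation, scaling, or sub-pixel offset), and no nodata gaps. `estimate_shift` is a hypothetical helper, and the project may well use a different method.

```python
import numpy as np


def estimate_shift(ref: np.ndarray, moving: np.ndarray) -> tuple[int, int]:
    """Estimate the integer (row, col) translation of `moving` relative to `ref`
    using phase correlation (hypothetical helper, translation only)."""
    f_ref = np.fft.fft2(ref)
    f_mov = np.fft.fft2(moving)
    cross_power = f_mov * np.conj(f_ref)
    cross_power /= np.abs(cross_power) + 1e-12  # keep only the phase
    corr = np.fft.ifft2(cross_power).real       # sharp peak at the shift
    peak = np.unravel_index(np.argmax(corr), corr.shape)
    # Peaks past the midpoint wrap around and correspond to negative shifts
    shift = [int(p) if p <= s // 2 else int(p) - s for p, s in zip(peak, corr.shape)]
    return shift[0], shift[1]


# Example with a synthetic band shifted by (3, -5) pixels
rng = np.random.default_rng(0)
ref = rng.random((64, 64))
moving = np.roll(ref, shift=(3, -5), axis=(0, 1))
dy, dx = estimate_shift(ref, moving)             # (3, -5)
aligned = np.roll(moving, shift=(-dy, -dx), axis=(0, 1))
print(np.allclose(aligned, ref))                 # True
```

`np.roll` wraps pixels around the border, which is fine for this synthetic check; on real scenes you would crop or resample the shifted image instead.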

In [None]:
# Read images
if image_t1_path.exists() and image_t2_path.exists():
    with rasterio.open(image_t1_path) as src:
        img_t1 = src.read()
        profile = src.profile
    
    with rasterio.open(image_t2_path) as src:
        img_t2 = src.read()
    
    # Preprocess image pair
    preprocessed_t1, preprocessed_t2 = preprocess_pair(img_t1, img_t2)
    
    # Visualize preprocessed images
    fig, axs = plt.subplots(2, 2, figsize=(10, 10))
    
    # Original images (first 3 bands for RGB)
    rgb_t1 = np.transpose(img_t1[:3], (1, 2, 0))
    rgb_t2 = np.transpose(img_t2[:3], (1, 2, 0))
    
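    # Normalize for display: min-max stretch each channel to [0, 1].
    # Cast to float first; assigning normalized values back into an integer
    # array (e.g. uint16 imagery) would truncate them to 0 or 1.
    def normalize_for_display(img):
        img = img.astype(np.float32)  # astype also returns a copy
        for i in range(img.shape[2]):
            band = img[:, :, i]
            img[:, :, i] = (band - band.min()) / (band.max() - band.min() + 1e-8)
        return np.clip(img, 0, 1)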
    
    rgb_t1 = normalize_for_display(rgb_t1)
    rgb_t2 = normalize_for_display(rgb_t2)
    
    axs[0, 0].imshow(rgb_t1)
    axs[0, 0].set_title("Original Time 1")
    
    axs[0, 1].imshow(rgb_t2)
    axs[0, 1].set_title("Original Time 2")
    
    # Preprocessed images
    pre_rgb_t1 = np.transpose(preprocessed_t1[:3], (1, 2, 0))
    pre_rgb_t2 = np.transpose(preprocessed_t2[:3], (1, 2, 0))
    
    pre_rgb_t1 = normalize_for_display(pre_rgb_t1)
    pre_rgb_t2 = normalize_for_display(pre_rgb_t2)
    
    axs[1, 0].imshow(pre_rgb_t1)
    axs[1, 0].set_title("Preprocessed Time 1")
    
    axs[1, 1].imshow(pre_rgb_t2)
    axs[1, 1].set_title("Preprocessed Time 2")
    
    plt.tight_layout()
    plt.show()
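The cell above only min-max stretches the images for display; the normalization performed inside `preprocess_pair` isn't shown. One common, outlier-robust scheme for multispectral rasters is a per-band percentile stretch. The sketch below illustrates that idea under the assumption of a `(bands, H, W)` array; `percentile_normalize` and its 2/98 percentile defaults are hypothetical, not the project's implementation.

```python
import numpy as np


def percentile_normalize(img: np.ndarray, low: float = 2.0, high: float = 98.0) -> np.ndarray:
    """Scale each band of a (bands, H, W) array to [0, 1] using per-band percentiles."""
    out = np.empty(img.shape, dtype=np.float32)
    for b in range(img.shape[0]):
        lo, hi = np.percentile(img[b], [low, high])
        # Values outside [lo, hi] are clipped, so a few saturated or very dark
        # pixels do not compress the contrast of the rest of the band
        out[b] = np.clip((img[b] - lo) / (hi - lo + 1e-8), 0.0, 1.0)
    return out


# Example on a small synthetic two-band uint16 image
demo = np.arange(2 * 10 * 10).reshape(2, 10, 10).astype(np.uint16)
scaled = percentile_normalize(demo)
print(scaled.dtype, scaled.min(), scaled.max())  # float32 0.0 1.0
```

Stretching each date independently also evens out global brightness differences between acquisitions. That helps with illumination and sensor differences, but it can hide genuine scene-wide changes, so some pipelines compute the statistics jointly over both dates instead.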

## 3. Model Training (Small Demo)

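For demonstration purposes, the next cells build a small model and show what a training loop looks like. The training code is wrapped in a string literal, so it is displayed but not executed; run it only once you have a labelled dataset. In practice, you would use our training script with more data and more epochs.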

In [None]:
# Create a model
model = get_change_detection_model(
    model_type='siamese_unet',
    in_channels=3,  # Using RGB channels for demo
    out_channels=1,  # Binary change detection
    features=32,    # Reduced features for demo
    bilinear=True,
    dropout=0.2
)

# Move model to device
model.to(device)

print(f"Model created with {sum(p.numel() for p in model.parameters())} parameters")
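With `out_channels=1`, the model produces a single value per pixel. The inference cell later applies `torch.sigmoid` to the raw output, so these values are logits, and the training cell below pairs them with `torch.nn.BCEWithLogitsLoss`, which folds the sigmoid into the loss. The NumPy sketch below shows the numerically stable form of that loss (mean reduction, no `pos_weight`). `bce_with_logits` is an illustrative helper, not part of the project.

```python
import numpy as np


def bce_with_logits(logits, targets) -> float:
    """Mean binary cross-entropy computed directly from logits."""
    x = np.asarray(logits, dtype=np.float64)
    y = np.asarray(targets, dtype=np.float64)
    # max(x, 0) - x*y + log(1 + exp(-|x|)) is algebraically equal to
    # -[y*log(sigmoid(x)) + (1 - y)*log(1 - sigmoid(x))], but it never
    # evaluates log(0) or overflows exp() for large |x|
    loss = np.maximum(x, 0) - x * y + np.log1p(np.exp(-np.abs(x)))
    return float(loss.mean())


print(bce_with_logits([0.0], [1.0]))    # log(2), about 0.6931
print(bce_with_logits([100.0], [0.0]))  # 100.0, finite despite the huge logit
```

Computing `sigmoid` first and then taking logs would give `log(0) = -inf` for a logit of 100 with target 0, which is why the fused loss is preferred. When changed pixels are rare, the `pos_weight` argument of `BCEWithLogitsLoss` can up-weight the positive (change) term.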

In [None]:
# Create data loaders (you need a proper dataset for this)
# For demo, we'll just show the code without executing it

'''
# Define paths
image_dir = DATA_DIR / "train" / "images"
mask_dir = DATA_DIR / "train" / "masks"

# Create dataloaders
train_loader, val_loader = create_dataloaders(
    image_pairs_dir=str(image_dir),
    mask_dir=str(mask_dir),
    tile_size=256,
    batch_size=4,  # Small batch size for demo
    val_split=0.2,
    num_workers=2,
    overlap=32
)

# Define loss function and optimizer
criterion = torch.nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# Training loop
num_epochs = 5  # Few epochs for demo

for epoch in range(num_epochs):
    model.train()
    epoch_loss = 0.0
    
    for batch in tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}"):
        # Get data
        t1_images = batch['t1'].to(device).float()
        t2_images = batch['t2'].to(device).float()
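        masks = batch['mask'].to(device).float()
        # BCEWithLogitsLoss needs targets with the same shape as the
        # (N, 1, H, W) logits; add a channel axis if masks come as (N, H, W)
        if masks.dim() == 3:
            masks = masks.unsqueeze(1)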
        
        # Forward pass
        optimizer.zero_grad()
        outputs = model(t1_images, t2_images)
        loss = criterion(outputs, masks)
        
        # Backward pass
        loss.backward()
        optimizer.step()
        
        epoch_loss += loss.item()
    
    # Print epoch loss
    print(f"Epoch {epoch+1}/{num_epochs}, Loss: {epoch_loss/len(train_loader):.4f}")
    
    # Validation
    if epoch % 2 == 0:
        model.eval()
        val_loss = 0.0
        
        with torch.no_grad():
            for batch in val_loader:
                # Get data
                t1_images = batch['t1'].to(device).float()
                t2_images = batch['t2'].to(device).float()
                masks = batch['mask'].to(device).float()
                if masks.dim() == 3:  # match the (N, 1, H, W) logits
                    masks = masks.unsqueeze(1)
                
                # Forward pass
                outputs = model(t1_images, t2_images)
                loss = criterion(outputs, masks)
                val_loss += loss.item()
        
        print(f"Validation Loss: {val_loss/len(val_loader):.4f}")

# Save model
model_save_path = OUTPUT_DIR / "models" / "demo_model.pth"
model_save_path.parent.mkdir(exist_ok=True, parents=True)

torch.save({
    'epoch': num_epochs,
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
}, model_save_path)

print(f"Model saved to {model_save_path}")
'''
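`create_dataloaders` is called with `tile_size=256` and `overlap=32`, which suggests that scenes are cut into overlapping 256-pixel tiles. How the project computes the tile windows isn't shown; the sketch below is one straightforward approach, assuming a stride of `tile_size - overlap` and a final tile aligned with the image edge. `tile_offsets` is a hypothetical helper.

```python
def tile_offsets(length: int, tile_size: int = 256, overlap: int = 32) -> list[int]:
    """Start offsets along one axis so that tiles of `tile_size` cover [0, length),
    with at least `overlap` pixels shared between neighbouring tiles."""
    if overlap >= tile_size:
        raise ValueError("overlap must be smaller than tile_size")
    if length <= tile_size:
        return [0]  # image smaller than one tile; it would need padding
    stride = tile_size - overlap
    offsets = list(range(0, length - tile_size + 1, stride))
    if offsets[-1] + tile_size < length:
        offsets.append(length - tile_size)  # final tile flush with the edge
    return offsets


# All (row, col) top-left corners for a 600 x 480 scene
windows = [(r, c) for r in tile_offsets(600) for c in tile_offsets(480)]
print(tile_offsets(600))  # [0, 224, 344]
print(len(windows))       # 6
```

Aligning the last tile with the edge avoids padding but means it can overlap its neighbour by more than `overlap` pixels (here 136 instead of 32).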

## 4. Running Inference on Test Images

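Now, let's run inference on a test image pair. Normally you would load a trained checkpoint here. Because the training cell above is not executed, the model created in Section 3 still has random weights, so its predictions only show that the pipeline runs end to end; they are not meaningful change maps.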

In [None]:
# In a real workflow, load trained weights from a checkpoint, as shown below.
# For this demo we keep the model created in Section 3, which is untrained
# because the training cell was not executed.

'''
# Load pre-trained model
model_path = OUTPUT_DIR / "models" / "best_model.pth"
checkpoint = torch.load(model_path, map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
'''

# Set model to evaluation mode
model.eval()

In [None]:
# Run inference on our sample images
if image_t1_path.exists() and image_t2_path.exists():
    # Read images again if needed
    with rasterio.open(image_t1_path) as src:
        img_t1 = src.read()
        profile = src.profile
    
    with rasterio.open(image_t2_path) as src:
        img_t2 = src.read()
    
    # Preprocess images
    preprocessed_t1, preprocessed_t2 = preprocess_pair(img_t1, img_t2)
    
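    # Convert to PyTorch tensors with a batch axis: (1, bands, H, W).
    # The demo model was built with in_channels=3, so keep only the first
    # three bands; drop the slice if your model uses all bands.
    t1_tensor = torch.from_numpy(preprocessed_t1[:3]).unsqueeze(0).to(device).float()
    t2_tensor = torch.from_numpy(preprocessed_t2[:3]).unsqueeze(0).to(device).float()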
    
    # Run inference
    with torch.no_grad():
        output = model(t1_tensor, t2_tensor)
        prediction = torch.sigmoid(output) > 0.5
    
    # Convert to numpy array
    prediction = prediction.squeeze().cpu().numpy().astype(np.uint8) * 255
    
    # Save prediction
    output_dir = OUTPUT_DIR / "predictions"
    output_dir.mkdir(exist_ok=True, parents=True)
    
    output_path = output_dir / "demo_prediction.tif"
    
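    # Update profile for a single-band uint8 output. Leave nodata unset:
    # 0 is a valid "no change" value here, not missing data.
    out_profile = profile.copy()
    out_profile.update({
        'count': 1,
        'dtype': 'uint8',
        'compress': 'lzw',
        'nodata': None
    })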
    
    # Write prediction to file
    with rasterio.open(output_path, 'w', **out_profile) as dst:
        dst.write(prediction, 1)
    
    print(f"Prediction saved to {output_path}")
    
    # Display results
    fig, axs = plt.subplots(1, 3, figsize=(15, 5))
    
    # Display the preprocessed images (first three bands as RGB)
    rgb_t1 = np.transpose(preprocessed_t1[:3], (1, 2, 0))
    rgb_t2 = np.transpose(preprocessed_t2[:3], (1, 2, 0))
    
    # Normalize for display
    rgb_t1 = normalize_for_display(rgb_t1)
    rgb_t2 = normalize_for_display(rgb_t2)
    
    axs[0].imshow(rgb_t1)
    axs[0].set_title("Time 1 Image")
    
    axs[1].imshow(rgb_t2)
    axs[1].set_title("Time 2 Image")
    
    # Display prediction
    axs[2].imshow(prediction, cmap='viridis')
    axs[2].set_title("Predicted Change Mask")
    
    plt.tight_layout()
    plt.show()
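The cell above thresholds `torch.sigmoid(output)` at 0.5. Because the sigmoid is monotonic and sigmoid(0) = 0.5, this is equivalent to `output > 0`, and in general a probability threshold t corresponds to the logit log(t / (1 - t)). The NumPy sketch below checks this equivalence on a few hand-picked logits; `sigmoid` and `logit_threshold` are illustrative helpers, not part of the project.

```python
import numpy as np


def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


def logit_threshold(prob_threshold: float) -> float:
    """Logit corresponding to a probability threshold in (0, 1)."""
    return float(np.log(prob_threshold / (1.0 - prob_threshold)))


logits = np.array([-3.0, -0.2, 0.0, 0.4, 2.5])
for t in (0.5, 0.3, 0.8):
    same = np.array_equal(sigmoid(logits) > t, logits > logit_threshold(t))
    print(t, same)  # True for each threshold
```

Lowering the threshold marks more pixels as changed, so recall can only stay the same or rise, usually at the cost of precision; the metrics in the next section are a natural way to choose t on validation data.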

## 5. Visualizing and Evaluating Results

Let's compare the prediction with the ground truth and calculate evaluation metrics.

In [None]:
# Compare with ground truth if available
if mask_path.exists() and 'prediction' in locals():
    with rasterio.open(mask_path) as src:
        gt_mask = src.read(1)
        gt_mask_binary = (gt_mask > 0).astype(bool)
    
    # Convert prediction to binary mask
    pred_binary = (prediction > 0).astype(bool)
    
    # Calculate metrics
    metrics = calculate_all_metrics(pred_binary, gt_mask_binary)
    
    # Display metrics
    print("Evaluation Metrics:")
    print(f"IoU: {metrics['iou']:.4f}")
    print(f"Dice (F1): {metrics['dice']:.4f}")
    print(f"Precision: {metrics['precision']:.4f}")
    print(f"Recall: {metrics['recall']:.4f}")
    print(f"Accuracy: {metrics['accuracy']:.4f}")
    
    # Visualize comparison
    fig, axs = plt.subplots(1, 3, figsize=(15, 5))
    
    axs[0].imshow(gt_mask_binary, cmap='viridis')
    axs[0].set_title("Ground Truth")
    
    axs[1].imshow(pred_binary, cmap='viridis')
    axs[1].set_title("Prediction")
    
    # Confusion visualization
    # - True positive (green)
    # - False positive (red)
    # - False negative (blue)
    # - True negative (black)
    confusion = np.zeros((*pred_binary.shape, 3), dtype=np.uint8)
    
    # True positive: Prediction = 1, Ground Truth = 1
    confusion[pred_binary & gt_mask_binary] = [0, 255, 0]  # Green
    
    # False positive: Prediction = 1, Ground Truth = 0
    confusion[pred_binary & ~gt_mask_binary] = [255, 0, 0]  # Red
    
    # False negative: Prediction = 0, Ground Truth = 1
    confusion[~pred_binary & gt_mask_binary] = [0, 0, 255]  # Blue
    
    axs[2].imshow(confusion)
    axs[2].set_title("Confusion (Green=TP, Red=FP, Blue=FN)")
    
    plt.tight_layout()
    plt.show()
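`calculate_all_metrics` comes from the project, and its exact definitions (for example, how it handles empty masks) aren't shown here. For reference, the sketch below computes the same five quantities from the true/false positive and negative counts that the confusion image colours, using the standard definitions. `binary_metrics` is a hypothetical helper, and the small epsilon is an assumption to avoid division by zero.

```python
import numpy as np


def binary_metrics(pred: np.ndarray, gt: np.ndarray) -> dict:
    """IoU, Dice/F1, precision, recall and accuracy for two binary masks."""
    pred = pred.astype(bool)
    gt = gt.astype(bool)
    tp = np.count_nonzero(pred & gt)    # green in the confusion image
    fp = np.count_nonzero(pred & ~gt)   # red
    fn = np.count_nonzero(~pred & gt)   # blue
    tn = np.count_nonzero(~pred & ~gt)  # black
    eps = 1e-8  # avoids division by zero when a mask is empty
    return {
        "iou": tp / (tp + fp + fn + eps),
        "dice": 2 * tp / (2 * tp + fp + fn + eps),
        "precision": tp / (tp + fp + eps),
        "recall": tp / (tp + fn + eps),
        "accuracy": (tp + tn) / (tp + fp + fn + tn),
    }


demo_pred = np.array([[1, 1, 0, 0]], dtype=bool)
demo_gt = np.array([[1, 0, 1, 0]], dtype=bool)
print(binary_metrics(demo_pred, demo_gt))  # iou ~0.333, all others ~0.5
```

For a single mask pair, Dice equals F1 and is tied to IoU by Dice = 2 * IoU / (1 + IoU), so the two metrics always rank predictions in the same order. Accuracy, by contrast, is dominated by the usually large "no change" class and can look high even for a poor change mask.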

## 6. Post-processing for Output Refinement

Finally, let's refine the predicted mask with post-processing: removing small regions (`min_size`), applying morphological opening and closing, and filling holes.

In [None]:
# Apply post-processing
if 'prediction' in locals():
    # Write the raw prediction to disk (same content as demo_prediction.tif),
    # since process_change_detection_mask operates on file paths
    raw_pred_path = output_dir / "raw_prediction.tif"
    
    with rasterio.open(raw_pred_path, 'w', **out_profile) as dst:
        dst.write(prediction, 1)
    
    # Apply post-processing
    processed_path = output_dir / "processed_prediction.tif"
    
    process_change_detection_mask(
        str(raw_pred_path),
        str(processed_path),
        min_size=10,
        apply_opening=True,
        apply_closing=True,
        fill_holes=True
    )
    
    # Read the processed mask
    with rasterio.open(processed_path) as src:
        processed_mask = src.read(1)
    
    # Compare raw and processed masks
    fig, axs = plt.subplots(1, 2, figsize=(10, 5))
    
    axs[0].imshow(prediction, cmap='viridis')
    axs[0].set_title("Raw Prediction")
    
    axs[1].imshow(processed_mask, cmap='viridis')
    axs[1].set_title("Processed Prediction")
    
    plt.tight_layout()
    plt.show()
    
    # Calculate metrics for the processed mask
    if mask_path.exists():
        processed_binary = (processed_mask > 0).astype(bool)
        processed_metrics = calculate_all_metrics(processed_binary, gt_mask_binary)
        
        print("\nProcessed Mask Metrics:")
        print(f"IoU: {processed_metrics['iou']:.4f}")
        print(f"Dice (F1): {processed_metrics['dice']:.4f}")
        print(f"Precision: {processed_metrics['precision']:.4f}")
        print(f"Recall: {processed_metrics['recall']:.4f}")
        print(f"Accuracy: {processed_metrics['accuracy']:.4f}")

## Conclusion

In this notebook, we demonstrated how to use our satellite image change detection system. We covered:

1. Setting up the environment
2. Loading and preprocessing data
3. Training a model (demo code)
4. Running inference on test images
5. Visualizing and evaluating results
6. Applying post-processing for output refinement

This system can be used for detecting various types of changes in satellite imagery, such as urban expansion, deforestation, construction activities, and more. For production use, you would typically train on larger datasets and for more epochs to achieve better performance.