# 🌲 Change Detection CNN - Interactive Tutorial

Learn to build and train a Convolutional Neural Network for detecting deforestation from paired satellite/drone images.

## What You'll Learn:
1. Model architecture basics
2. Data preparation
3. Training process
4. Evaluation metrics
5. Making predictions
6. Real-world application

## 1. Setup 📦

In [None]:
# Install dependencies (run once)
# !pip install tensorflow numpy matplotlib opencv-python

In [None]:
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import numpy as np
import matplotlib.pyplot as plt
from datetime import datetime

print(f"TensorFlow version: {tf.__version__}")
print(f"GPU available: {len(tf.config.list_physical_devices('GPU')) > 0}")

# Set random seeds for reproducibility
np.random.seed(42)
tf.random.set_seed(42)

## 2. Understanding Change Detection 🔍

**Goal:** Detect WHERE forest has been removed between two time periods

**Input:** Two images
- T₀ (Before) - Baseline image
- T₁ (After) - Recent image

**Output:** Probability map
- Values 0-1 for each pixel
- High values = deforestation likely
- Low values = no change
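
To make the input/output contract concrete, here is a minimal NumPy sketch. The arrays are random placeholders (they stand in for real images and for a model prediction), and the 0.5 cut-off is an assumed default, not a tuned value.

```python
import numpy as np

rng = np.random.default_rng(0)

# Placeholder inputs: two RGB images of the same scene, values in [0, 1)
img_t0 = rng.random((256, 256, 3), dtype=np.float32)  # T0 (before)
img_t1 = rng.random((256, 256, 3), dtype=np.float32)  # T1 (after)

# Placeholder output: one probability per pixel (stands in for a model prediction)
prob_map = rng.random((256, 256), dtype=np.float32)

# Binarize with an assumed threshold of 0.5
change_mask = prob_map > 0.5

print(f"Inputs: {img_t0.shape} and {img_t1.shape}")
print(f"Output: {prob_map.shape}")
print(f"Pixels flagged as change: {change_mask.mean() * 100:.1f}%")
```

The probability map has the same height and width as the inputs but a single channel, so each pixel can be thresholded independently.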

## 3. Build the Model 🏗️
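
The model in this section halves the feature-map size three times (max pooling) and doubles it three times (upsampling). The sketch below traces the resulting side lengths in plain Python; `trace_spatial_sizes` is a hypothetical helper written for illustration, not part of Keras.

```python
def trace_spatial_sizes(img_size=256, n_levels=3):
    """Return the feature-map side length after each pooling/upsampling stage."""
    if img_size % (2 ** n_levels) != 0:
        raise ValueError(f"img_size must be divisible by {2 ** n_levels}")
    sizes = [img_size]
    for _ in range(n_levels):   # each 2x2 max pooling halves height and width
        sizes.append(sizes[-1] // 2)
    for _ in range(n_levels):   # each 2x upsampling doubles them again
        sizes.append(sizes[-1] * 2)
    return sizes

print(trace_spatial_sizes(256))
```

The divisibility check reflects a practical constraint: pooling rounds odd sizes down, so an input side that is not a multiple of 8 would likely produce an output smaller than the target mask.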

In [None]:
def build_change_detection_model(img_size=256, n_channels=3):
    """
    Build a CNN for change detection
    
    Architecture:
    1. Two inputs (T0 and T1)
    2. Concatenate
    3. Encoder (feature extraction)
    4. Decoder (mask reconstruction)
    5. Output (probability map)
    """
    
    # === INPUTS ===
    input_t0 = layers.Input(shape=(img_size, img_size, n_channels), name='image_t0')
    input_t1 = layers.Input(shape=(img_size, img_size, n_channels), name='image_t1')
    
    # Concatenate both images
    merged = layers.Concatenate(name='merge')([input_t0, input_t1])
    
    # === ENCODER (Downsampling) ===
    # Block 1
    x = layers.Conv2D(64, 3, padding='same', activation='relu', name='enc_conv1')(merged)
    x = layers.BatchNormalization()(x)
    x = layers.MaxPooling2D(2, name='enc_pool1')(x)
    
    # Block 2
    x = layers.Conv2D(128, 3, padding='same', activation='relu', name='enc_conv2')(x)
    x = layers.BatchNormalization()(x)
    x = layers.MaxPooling2D(2, name='enc_pool2')(x)
    
    # Block 3
    x = layers.Conv2D(256, 3, padding='same', activation='relu', name='enc_conv3')(x)
    x = layers.BatchNormalization()(x)
    x = layers.MaxPooling2D(2, name='enc_pool3')(x)
    
    # Bottleneck
    x = layers.Conv2D(512, 3, padding='same', activation='relu', name='bottleneck')(x)
    x = layers.BatchNormalization()(x)
    
    # === DECODER (Upsampling) ===
    # Upsample 1
    x = layers.UpSampling2D(2, name='dec_upsample1')(x)
    x = layers.Conv2D(256, 3, padding='same', activation='relu', name='dec_conv1')(x)
    x = layers.BatchNormalization()(x)
    
    # Upsample 2
    x = layers.UpSampling2D(2, name='dec_upsample2')(x)
    x = layers.Conv2D(128, 3, padding='same', activation='relu', name='dec_conv2')(x)
    x = layers.BatchNormalization()(x)
    
    # Upsample 3
    x = layers.UpSampling2D(2, name='dec_upsample3')(x)
    x = layers.Conv2D(64, 3, padding='same', activation='relu', name='dec_conv3')(x)
    x = layers.BatchNormalization()(x)
    
    # === OUTPUT ===
    output = layers.Conv2D(1, 1, activation='sigmoid', name='output')(x)
    
    # Create model
    model = keras.Model(inputs=[input_t0, input_t1], outputs=output, name='ChangeDetectionCNN')
    
    return model

# Build model
model = build_change_detection_model(img_size=256)
print(f"\n✓ Model created with {model.count_params():,} parameters")

In [None]:
# View architecture
model.summary()

## 4. Compile Model ⚙️

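Define the loss function and the metrics. The loss combines binary cross-entropy (per-pixel error) with Dice loss (overlap between predicted and true masks), which helps when changed pixels are only a small fraction of each image.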
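
As a worked example of the Dice coefficient, 2·|A∩B| / (|A| + |B|), here is a small NumPy sketch on 2×2 masks. `dice_numpy` is an illustrative name; it mirrors the formula of the notebook's `dice_coefficient` but is not the function used for training.

```python
import numpy as np

def dice_numpy(y_true, y_pred, smooth=1e-6):
    """Dice coefficient on flattened masks; smooth avoids division by zero."""
    y_true = np.asarray(y_true, dtype=np.float64).ravel()
    y_pred = np.asarray(y_pred, dtype=np.float64).ravel()
    intersection = np.sum(y_true * y_pred)
    return (2.0 * intersection + smooth) / (y_true.sum() + y_pred.sum() + smooth)

y_true = np.array([[1, 1], [0, 0]])  # two changed pixels

perfect = dice_numpy(y_true, y_true)                       # both pixels found
half = dice_numpy(y_true, np.array([[1, 0], [0, 0]]))      # one of two found
missed = dice_numpy(y_true, np.zeros((2, 2)))              # nothing found

print(f"perfect={perfect:.3f}  half={half:.3f}  missed={missed:.3f}")
```

Finding one of the two changed pixels gives 2·1 / (2 + 1) ≈ 0.667, so Dice loss (1 − Dice) would be about 0.333 for that prediction.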

In [None]:
def dice_coefficient(y_true, y_pred, smooth=1e-6):
    """Dice coefficient for segmentation tasks"""
    y_true_f = tf.reshape(y_true, [-1])
    y_pred_f = tf.reshape(y_pred, [-1])
    intersection = tf.reduce_sum(y_true_f * y_pred_f)
    return (2. * intersection + smooth) / (tf.reduce_sum(y_true_f) + tf.reduce_sum(y_pred_f) + smooth)

def dice_loss(y_true, y_pred):
    return 1 - dice_coefficient(y_true, y_pred)

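def combined_loss(y_true, y_pred):
    """BCE + Dice Loss"""
    # binary_crossentropy returns a per-pixel map; reduce it to a scalar
    # so it is added to the (scalar) Dice loss on equal footing
    bce = tf.reduce_mean(tf.keras.losses.binary_crossentropy(y_true, y_pred))
    return bce + dice_loss(y_true, y_pred)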

# Compile
model.compile(
    optimizer=keras.optimizers.Adam(learning_rate=0.0001),
    loss=combined_loss,
    metrics=[
        keras.metrics.BinaryAccuracy(name='accuracy'),  # explicit per-pixel binary accuracy
        dice_coefficient,
        keras.metrics.Precision(name='precision'),
        keras.metrics.Recall(name='recall')
    ]
)

print("✓ Model compiled successfully!")

## 5. Generate Training Data 🎲

For this tutorial, we'll generate synthetic data. In production, use real satellite/drone images.
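
When switching to real imagery, each pair needs to arrive in the same layout as the synthetic arrays. The sketch below is one possible way to do that, assuming 8-bit RGB images that are already co-registered and a label mask where any nonzero pixel means "deforested". `prepare_pair` is a hypothetical helper, and the demo arrays are random stand-ins for loaded files.

```python
import numpy as np

def prepare_pair(img_t0, img_t1, mask):
    """Convert a uint8 image pair and a label mask to model-ready float32 arrays."""
    if img_t0.shape != img_t1.shape:
        raise ValueError(f"Image shapes differ: {img_t0.shape} vs {img_t1.shape}")
    if mask.shape != img_t0.shape[:2]:
        raise ValueError(f"Mask shape {mask.shape} does not match the images")
    x_t0 = img_t0.astype(np.float32) / 255.0            # scale to [0, 1]
    x_t1 = img_t1.astype(np.float32) / 255.0
    y_mask = (mask > 0).astype(np.float32)[..., np.newaxis]  # (H, W) -> (H, W, 1)
    return x_t0, x_t1, y_mask

# Random stand-ins for images and a mask read from disk
rng = np.random.default_rng(0)
raw_t0 = rng.integers(0, 256, size=(256, 256, 3), dtype=np.uint8)
raw_t1 = rng.integers(0, 256, size=(256, 256, 3), dtype=np.uint8)
raw_mask = np.zeros((256, 256), dtype=np.uint8)
raw_mask[100:150, 60:120] = 255  # one labeled patch

x_t0, x_t1, y_mask = prepare_pair(raw_t0, raw_t1, raw_mask)
print(x_t0.shape, x_t1.shape, y_mask.shape, y_mask.dtype)
```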

In [None]:
def generate_synthetic_data(n_samples=200, img_size=256):
    """
    Generate synthetic change detection data
    """
    print(f"Generating {n_samples} synthetic samples...")
    
    images_t0 = []
    images_t1 = []
    masks = []
    
    for i in range(n_samples):
        # Create base forest image (T0)
        img_t0 = np.random.rand(img_size, img_size, 3).astype(np.float32)
        img_t0[:, :, [0, 2]] *= 0.8  # Dampen red and blue so green dominates
        
        # Create T1 with some changes
        img_t1 = img_t0.copy()
        
        # Create change mask
        mask = np.zeros((img_size, img_size, 1), dtype=np.float32)
        
        # Add random deforestation patches
        n_patches = np.random.randint(1, 5)
        for _ in range(n_patches):
            x = np.random.randint(0, img_size - 50)
            y = np.random.randint(0, img_size - 50)
            w = np.random.randint(20, 50)
            h = np.random.randint(20, 50)
            
            # Change image T1
            img_t1[y:y+h, x:x+w, :] *= 0.5
            img_t1[y:y+h, x:x+w, 0] += 0.2  # Boost red (channel 0 in RGB) for a browner tone
            
            # Mark in mask
            mask[y:y+h, x:x+w, 0] = 1.0
        
        images_t0.append(img_t0)
        images_t1.append(img_t1)
        masks.append(mask)
    
    print(f"✓ Generated {n_samples} samples")
    return np.array(images_t0), np.array(images_t1), np.array(masks)

# Generate data
X_t0, X_t1, y = generate_synthetic_data(n_samples=200, img_size=256)

print(f"\nDataset shapes:")
print(f"  X_t0: {X_t0.shape}")
print(f"  X_t1: {X_t1.shape}")
print(f"  y: {y.shape}")
print(f"\nChange percentage: {y.mean()*100:.1f}%")

In [None]:
# Visualize samples
fig, axes = plt.subplots(3, 3, figsize=(12, 12))

for i in range(3):
    idx = np.random.randint(0, len(X_t0))
    
    axes[i, 0].imshow(X_t0[idx])
    axes[i, 0].set_title(f'Sample {idx} - T₀ (Before)')
    axes[i, 0].axis('off')
    
    axes[i, 1].imshow(X_t1[idx])
    axes[i, 1].set_title(f'Sample {idx} - T₁ (After)')
    axes[i, 1].axis('off')
    
    axes[i, 2].imshow(y[idx, :, :, 0], cmap='Reds', vmin=0, vmax=1)
    axes[i, 2].set_title(f'Sample {idx} - Change Mask')
    axes[i, 2].axis('off')

plt.tight_layout()
plt.show()

## 6. Train/Val Split 📊
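
The cell in this section slices the arrays in order, which is fine for independently generated synthetic samples. Real datasets are often ordered by date or location, so a shuffled split may be preferable there. A minimal sketch, with `shuffled_split` as a hypothetical helper and a fixed seed assumed for repeatability:

```python
import numpy as np

def shuffled_split(n_samples, val_fraction=0.2, seed=42):
    """Return (train_indices, val_indices) from a random permutation."""
    rng = np.random.default_rng(seed)
    indices = rng.permutation(n_samples)
    n_val = int(round(n_samples * val_fraction))
    return indices[n_val:], indices[:n_val]

train_idx, val_idx = shuffled_split(200)
print(f"Training samples: {len(train_idx)}")
print(f"Validation samples: {len(val_idx)}")
```

The index arrays can then select matching rows from all three arrays, for example `X_t0[train_idx]`, `X_t1[train_idx]`, and `y[train_idx]`, so that image pairs and masks stay aligned.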

In [None]:
# Split data (80/20)
split_idx = int(0.8 * len(X_t0))

X_t0_train, X_t0_val = X_t0[:split_idx], X_t0[split_idx:]
X_t1_train, X_t1_val = X_t1[:split_idx], X_t1[split_idx:]
y_train, y_val = y[:split_idx], y[split_idx:]

print(f"Training samples: {len(X_t0_train)}")
print(f"Validation samples: {len(X_t0_val)}")

## 7. Training 🚀

In [None]:
# Configure callbacks
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")

callbacks = [
    keras.callbacks.ModelCheckpoint(
        f'best_model_{timestamp}.keras',
        monitor='val_dice_coefficient',
        mode='max',
        save_best_only=True,
        verbose=1
    ),
    keras.callbacks.EarlyStopping(
        monitor='val_loss',
        patience=10,
        restore_best_weights=True,
        verbose=1
    ),
    keras.callbacks.ReduceLROnPlateau(
        monitor='val_loss',
        factor=0.5,
        patience=5,
        min_lr=1e-7,
        verbose=1
    )
]

print("Callbacks configured:")
print("  ✓ ModelCheckpoint")
print("  ✓ EarlyStopping (patience=10)")
print("  ✓ ReduceLROnPlateau")

In [None]:
# Train
print("\n" + "="*60)
print("STARTING TRAINING")
print("="*60 + "\n")

history = model.fit(
    [X_t0_train, X_t1_train],
    y_train,
    validation_data=([X_t0_val, X_t1_val], y_val),
    epochs=30,
    batch_size=8,
    callbacks=callbacks,
    verbose=1
)

print("\n" + "="*60)
print("✓ TRAINING COMPLETED")
print("="*60)

## 8. Visualize Training Results 📈

In [None]:
# Plot training history
fig, axes = plt.subplots(2, 2, figsize=(15, 10))

# Loss
axes[0, 0].plot(history.history['loss'], label='Train', linewidth=2)
axes[0, 0].plot(history.history['val_loss'], label='Validation', linewidth=2)
axes[0, 0].set_title('Loss', fontsize=14, fontweight='bold')
axes[0, 0].set_xlabel('Epoch')
axes[0, 0].set_ylabel('Loss')
axes[0, 0].legend()
axes[0, 0].grid(alpha=0.3)

# Accuracy
axes[0, 1].plot(history.history['accuracy'], label='Train', linewidth=2)
axes[0, 1].plot(history.history['val_accuracy'], label='Validation', linewidth=2)
axes[0, 1].set_title('Accuracy', fontsize=14, fontweight='bold')
axes[0, 1].set_xlabel('Epoch')
axes[0, 1].set_ylabel('Accuracy')
axes[0, 1].legend()
axes[0, 1].grid(alpha=0.3)

# Dice
axes[1, 0].plot(history.history['dice_coefficient'], label='Train', linewidth=2)
axes[1, 0].plot(history.history['val_dice_coefficient'], label='Validation', linewidth=2)
axes[1, 0].set_title('Dice Coefficient', fontsize=14, fontweight='bold')
axes[1, 0].set_xlabel('Epoch')
axes[1, 0].set_ylabel('Dice')
axes[1, 0].legend()
axes[1, 0].grid(alpha=0.3)

# Precision & Recall
axes[1, 1].plot(history.history['precision'], label='Precision (Train)', linewidth=2)
axes[1, 1].plot(history.history['val_precision'], label='Precision (Val)', linewidth=2)
axes[1, 1].plot(history.history['recall'], label='Recall (Train)', linewidth=2, linestyle='--')
axes[1, 1].plot(history.history['val_recall'], label='Recall (Val)', linewidth=2, linestyle='--')
axes[1, 1].set_title('Precision & Recall', fontsize=14, fontweight='bold')
axes[1, 1].set_xlabel('Epoch')
axes[1, 1].set_ylabel('Score')
axes[1, 1].legend()
axes[1, 1].grid(alpha=0.3)

plt.tight_layout()
plt.savefig(f'training_results_{timestamp}.png', dpi=300)
plt.show()

print(f"\n✓ Plot saved: training_results_{timestamp}.png")

## 9. Evaluate Model 🎯
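
Precision, recall, and F1 all derive from pixel-level counts of true positives, false positives, and false negatives. The NumPy sketch below shows the arithmetic on a four-pixel example. `pixel_metrics` is an illustrative helper, and IoU is included only as a common companion metric; it is not computed elsewhere in this notebook.

```python
import numpy as np

def pixel_metrics(y_true, y_prob, threshold=0.5, eps=1e-7):
    """Compute precision, recall, F1 and IoU from a probability map."""
    truth = np.asarray(y_true) > 0.5
    pred = np.asarray(y_prob) > threshold
    tp = np.sum(pred & truth)     # change predicted and present
    fp = np.sum(pred & ~truth)    # change predicted but absent
    fn = np.sum(~pred & truth)    # change present but missed
    precision = tp / (tp + fp + eps)
    recall = tp / (tp + fn + eps)
    f1 = 2 * precision * recall / (precision + recall + eps)
    iou = tp / (tp + fp + fn + eps)
    return {"precision": precision, "recall": recall, "f1": f1, "iou": iou}

demo_truth = np.array([[1, 1, 0, 0]])
demo_probs = np.array([[0.9, 0.2, 0.8, 0.1]])  # one hit, one miss, one false alarm

scores = pixel_metrics(demo_truth, demo_probs)
for name, value in scores.items():
    print(f"{name:10s}: {value:.3f}")
```

With one true positive, one false positive, and one false negative, precision, recall, and F1 all come out near 0.5, while IoU is about 0.333.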

In [None]:
# Evaluate on validation set
print("\n📊 FINAL EVALUATION\n")
print("="*60)

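results = model.evaluate([X_t0_val, X_t1_val], y_val, verbose=0, return_dict=True)

# Look metrics up by name so the report does not depend on the order of returned values
labels = {
    'loss': 'Loss',
    'accuracy': 'Accuracy',
    'dice_coefficient': 'Dice Coefficient',
    'precision': 'Precision',
    'recall': 'Recall',
}
for key, label in labels.items():
    print(f"{label:20s}: {results[key]:.4f}")

print("="*60)

# Calculate F1-Score (the small constant guards against division by zero)
precision = results['precision']
recall = results['recall']
f1 = 2 * (precision * recall) / (precision + recall + 1e-7)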
print(f"\nF1-Score: {f1:.4f}")

## 10. Example Predictions 🔮

In [None]:
# Predict on validation samples
n_examples = 6
random_indices = np.random.choice(len(X_t0_val), n_examples, replace=False)

predictions = model.predict([X_t0_val[random_indices], X_t1_val[random_indices]], verbose=0)

# Visualize
fig, axes = plt.subplots(n_examples, 4, figsize=(16, 4*n_examples))

for i, idx in enumerate(random_indices):
    # T0
    axes[i, 0].imshow(X_t0_val[idx])
    axes[i, 0].set_title('T₀ (Before)', fontsize=12, fontweight='bold')
    axes[i, 0].axis('off')
    
    # T1
    axes[i, 1].imshow(X_t1_val[idx])
    axes[i, 1].set_title('T₁ (After)', fontsize=12, fontweight='bold')
    axes[i, 1].axis('off')
    
    # Ground Truth
    axes[i, 2].imshow(y_val[idx, :, :, 0], cmap='Reds', vmin=0, vmax=1)
    axes[i, 2].set_title('Ground Truth', fontsize=12, fontweight='bold')
    axes[i, 2].axis('off')
    
    # Prediction
    pred = predictions[i, :, :, 0]
    defor_pct = (pred > 0.5).sum() / pred.size * 100
    axes[i, 3].imshow(pred, cmap='Reds', vmin=0, vmax=1)
    axes[i, 3].set_title(f'Prediction ({defor_pct:.1f}% defor.)', fontsize=12, fontweight='bold')
    axes[i, 3].axis('off')

plt.tight_layout()
plt.savefig(f'predictions_{timestamp}.png', dpi=300)
plt.show()

print(f"\n✓ Predictions saved: predictions_{timestamp}.png")

## 11. Detailed Analysis 🔍
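
The threshold controls a trade-off: lowering it catches more true change but raises false alarms. The sketch below counts both kinds of error at the same three thresholds used in this section, on a hand-made four-pixel example. `sweep_thresholds` is a hypothetical helper and the probabilities are invented for illustration.

```python
import numpy as np

def sweep_thresholds(y_true, y_prob, thresholds=(0.3, 0.5, 0.7)):
    """Return (threshold, true positives, false positives, false negatives) rows."""
    truth = np.asarray(y_true) > 0.5
    probs = np.asarray(y_prob)
    rows = []
    for t in thresholds:
        pred = probs > t
        tp = int(np.sum(pred & truth))
        fp = int(np.sum(pred & ~truth))
        fn = int(np.sum(~pred & truth))
        rows.append((t, tp, fp, fn))
    return rows

demo_truth = np.array([1, 1, 0, 0])
demo_probs = np.array([0.9, 0.4, 0.6, 0.1])

rows = sweep_thresholds(demo_truth, demo_probs)
for t, tp, fp, fn in rows:
    print(f"threshold={t}: TP={tp} FP={fp} FN={fn}")
```

In this toy case, 0.3 finds both changed pixels at the cost of one false alarm, while 0.7 removes the false alarm but misses one changed pixel.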

In [None]:
# Analyze single prediction with different thresholds
idx = 0
test_t0 = X_t0_val[idx:idx+1]
test_t1 = X_t1_val[idx:idx+1]
true_mask = y_val[idx:idx+1]

pred_mask = model.predict([test_t0, test_t1], verbose=0)

# Different thresholds
thresholds = [0.3, 0.5, 0.7]

fig, axes = plt.subplots(2, 4, figsize=(16, 8))

# Top row
axes[0, 0].imshow(test_t0[0])
axes[0, 0].set_title('T₀ (Before)', fontweight='bold')
axes[0, 0].axis('off')

axes[0, 1].imshow(test_t1[0])
axes[0, 1].set_title('T₁ (After)', fontweight='bold')
axes[0, 1].axis('off')

axes[0, 2].imshow(true_mask[0, :, :, 0], cmap='Reds', vmin=0, vmax=1)
axes[0, 2].set_title('Ground Truth', fontweight='bold')
axes[0, 2].axis('off')

im = axes[0, 3].imshow(pred_mask[0, :, :, 0], cmap='jet', vmin=0, vmax=1)
axes[0, 3].set_title('Probability Map', fontweight='bold')
axes[0, 3].axis('off')
plt.colorbar(im, ax=axes[0, 3], fraction=0.046)

# Bottom row - different thresholds
for i, thresh in enumerate(thresholds):
    binary = (pred_mask[0, :, :, 0] > thresh).astype(float)
    defor_pct = binary.sum() / binary.size * 100
    axes[1, i].imshow(binary, cmap='Reds', vmin=0, vmax=1)  # fixed scale so all-0 or all-1 masks render consistently
    axes[1, i].set_title(f'Threshold={thresh}\n({defor_pct:.1f}% defor.)', fontweight='bold')
    axes[1, i].axis('off')

# Error map
diff = np.abs(pred_mask[0, :, :, 0] - true_mask[0, :, :, 0])
im = axes[1, 3].imshow(diff, cmap='RdYlGn_r', vmin=0, vmax=1)
axes[1, 3].set_title('Absolute Error', fontweight='bold')
axes[1, 3].axis('off')
plt.colorbar(im, ax=axes[1, 3], fraction=0.046)

plt.tight_layout()
plt.show()

# Statistics
print("\n📊 DETAILED STATISTICS\n")
print(f"Mean Absolute Error: {diff.mean():.4f}")
print(f"Max Error: {diff.max():.4f}")
correct = (pred_mask[0, :, :, 0] > 0.5) == (true_mask[0, :, :, 0] > 0.5)
print(f"Pixels correctly classified (threshold=0.5): {correct.mean() * 100:.2f}%")

## 12. Save Model 💾

In [None]:
# Save final model
final_model_path = f'change_detection_final_{timestamp}.keras'
model.save(final_model_path)
print(f"✓ Model saved: {final_model_path}")

print("\n📦 MODEL READY FOR DEPLOYMENT!")
print("\nTo load (compile=False avoids needing the custom loss and metric at load time):")
print(f">>> model = keras.models.load_model('{final_model_path}', compile=False)")
print("\nTo predict:")
print(">>> prediction = model.predict([img_t0, img_t1])")

## 13. Production Helper Function 🛠️

In [None]:
def predict_deforestation(model_path, img_t0_path, img_t1_path, threshold=0.5):
    """
    Complete function for predicting deforestation from image files
    
    Args:
        model_path: path to .keras model
        img_t0_path: path to T0 image
        img_t1_path: path to T1 image
        threshold: binarization threshold
    
    Returns:
        dict with results and statistics
    """
    import cv2
    
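    # Load model (compile=False: the custom loss and metric are not needed for inference)
    model = keras.models.load_model(model_path, compile=False)
    img_size = int(model.inputs[0].shape[1])  # Spatial size the model expects
    
    # Load images (cv2.imread returns None if a file cannot be read)
    img_t0 = cv2.imread(img_t0_path)
    img_t1 = cv2.imread(img_t1_path)
    if img_t0 is None or img_t1 is None:
        raise FileNotFoundError(f"Could not read: {img_t0_path} or {img_t1_path}")
    
    # OpenCV loads images as BGR; convert to RGB
    img_t0 = cv2.cvtColor(img_t0, cv2.COLOR_BGR2RGB)
    img_t1 = cv2.cvtColor(img_t1, cv2.COLOR_BGR2RGB)
    
    # Resize to the model input size and scale to [0, 1]
    img_t0 = cv2.resize(img_t0, (img_size, img_size)).astype(np.float32) / 255.0
    img_t1 = cv2.resize(img_t1, (img_size, img_size)).astype(np.float32) / 255.0
    
    # Add batch dimension
    img_t0 = np.expand_dims(img_t0, axis=0)
    img_t1 = np.expand_dims(img_t1, axis=0)
    
    # Predict
    prob_mask = model.predict([img_t0, img_t1], verbose=0)
    binary_mask = (prob_mask > threshold).astype(np.uint8)
    
    # Statistics (cast NumPy scalars to plain Python numbers)
    deforestation_pixels = int(binary_mask.sum())
    total_pixels = int(binary_mask.size)
    deforestation_percentage = deforestation_pixels / total_pixels * 100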
    
    results = {
        'probability_map': prob_mask[0, :, :, 0],
        'binary_mask': binary_mask[0, :, :, 0],
        'deforestation_percentage': deforestation_percentage,
        'deforestation_pixels': deforestation_pixels,
        'total_pixels': total_pixels,
        'threshold': threshold
    }
    
    return results

print("✓ Helper function defined")
print("\nUsage:")
print(">>> results = predict_deforestation('model.keras', 'before.jpg', 'after.jpg')")
print('>>> print(f"Deforestation: {results[\'deforestation_percentage\']:.2f}%")')

## 🎉 Congratulations!

You've completed the Change Detection CNN tutorial!

### What You've Learned:
- ✅ CNN architecture for change detection
- ✅ Combined loss functions (BCE + Dice)
- ✅ Training with callbacks
- ✅ Multiple evaluation metrics
- ✅ Threshold analysis
- ✅ Model deployment

### Next Steps:
1. **Use Real Data**: Replace synthetic data with actual drone/satellite images
2. **Try Advanced Architecture**: Implement U-Net or Siamese CNN
3. **Improve Performance**: Add data augmentation, try transfer learning
4. **Deploy**: Create a web service or integrate with GIS tools
5. **Scale Up**: Process large areas with sliding windows
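
For the sliding-window idea in step 5, one possible approach is to tile the scene, predict each tile, and average where tiles overlap. The sketch below is a NumPy-only illustration: `tile_origins`, `predict_large`, and `fake_predict` are hypothetical names, and `fake_predict` is a stand-in that returns the mean channel difference instead of calling a trained model.

```python
import numpy as np

def tile_origins(length, tile, stride):
    """Top-left offsets along one axis; the last tile is shifted back to fit."""
    if length < tile:
        raise ValueError("Image is smaller than one tile")
    origins = list(range(0, length - tile + 1, stride))
    if origins[-1] != length - tile:
        origins.append(length - tile)
    return origins

def predict_large(img_t0, img_t1, predict_fn, tile=256, stride=256):
    """Run predict_fn on each tile and average predictions where tiles overlap."""
    height, width = img_t0.shape[:2]
    prob_sum = np.zeros((height, width), dtype=np.float32)
    counts = np.zeros((height, width), dtype=np.float32)
    for r in tile_origins(height, tile, stride):
        for c in tile_origins(width, tile, stride):
            window = (slice(r, r + tile), slice(c, c + tile))
            prob_sum[window] += predict_fn(img_t0[window], img_t1[window])
            counts[window] += 1.0
    return prob_sum / counts

def fake_predict(patch_t0, patch_t1):
    """Stand-in for a model: mean absolute difference across channels."""
    return np.abs(patch_t1 - patch_t0).mean(axis=-1)

rng = np.random.default_rng(0)
big_t0 = rng.random((600, 700, 3), dtype=np.float32)
big_t1 = rng.random((600, 700, 3), dtype=np.float32)

prob_map = predict_large(big_t0, big_t1, fake_predict)
print(prob_map.shape)
```

With a trained model, `predict_fn` would instead add a batch dimension to each patch, call `model.predict`, and return the resulting 2-D probability map. A stride smaller than the tile size would give more overlap and smoother seams, at the cost of more predictions.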

### Resources:
- [TensorFlow Documentation](https://www.tensorflow.org/)
- [Keras Examples](https://keras.io/examples/)
- [Change Detection Papers](https://paperswithcode.com/task/change-detection)

**Happy monitoring! 🌲🛰️**