# Wildfire Detection and Segmentation using CNNs

This notebook implements a Convolutional Neural Network (CNN) approach for wildfire detection and segmentation using satellite imagery. We'll use pre-fire and post-fire images from the MTBS (Monitoring Trends in Burn Severity) dataset to train a model that detects and segments burned areas.

In [None]:
# Import required libraries
import os
import sys
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib import colors
import seaborn as sns
import rasterio
from rasterio.plot import show
import glob
import geopandas as gpd
from shapely.geometry import box
import tensorflow as tf
from tensorflow.keras import layers, models
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint
from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix, classification_report, accuracy_score, precision_score, recall_score, f1_score

# Report which TensorFlow build is active and which GPUs it can see
print("TensorFlow version: {}".format(tf.__version__))
print("GPU Available: {}".format(tf.config.list_physical_devices('GPU')))

# Make modules under ./src importable from this notebook
sys.path.append(os.path.join(os.getcwd(), 'src'))

## Load and Prepare MTBS Fire Data

We'll load the pre-fire and post-fire satellite imagery along with burn perimeter information from the MTBS dataset.

In [None]:
# Function to find available fire data
def find_fire_data(base_dir=None):
    """Locate the fire data directory and list the per-fire sub-folders inside it.

    Args:
        base_dir: Directory holding one sub-folder per fire. Defaults to
            ``<current working directory>/data``.

    Returns:
        (data_dir, fire_folders): the directory that was scanned, and the names
        (not full paths) of its immediate sub-directories. The names are sorted,
        because ``os.listdir`` order is arbitrary and the notebook later picks
        ``fire_folders[0]`` / ``fire_folders[1]``.

    Raises:
        FileNotFoundError: if the directory does not exist.
    """
    data_dir = os.path.join(os.getcwd(), 'data') if base_dir is None else base_dir
    fire_folders = sorted(
        entry for entry in os.listdir(data_dir)
        if os.path.isdir(os.path.join(data_dir, entry))
    )
    return data_dir, fire_folders

# Discover the fire sub-folders under ./data and list them
data_dir, fire_folders = find_fire_data()
print(f"Found {len(fire_folders)} fire datasets:")
for fire_name in fire_folders:
    print("- " + fire_name)

# Function to load MTBS fire data for a specific fire
def load_fire_data(fire_folder, base_dir=None):
    """Find the pre-/post-fire reflectance rasters and the burn perimeter for one fire.

    Args:
        fire_folder: Name of the fire's sub-folder.
        base_dir: Directory containing the fire folders. Defaults to the
            notebook-level ``data_dir`` returned by ``find_fire_data()``.

    Returns:
        (pre_fire_raster, post_fire_raster, burn_bndy) file paths.
        All three are None when fewer than two ``*_refl.tif`` files exist;
        only ``burn_bndy`` is None when no ``*_burn_bndy.shp`` file is present.
    """
    fire_path = os.path.join(data_dir if base_dir is None else base_dir, fire_folder)

    refl_files = sorted(glob.glob(os.path.join(fire_path, "*_refl.tif")))
    if len(refl_files) < 2:
        print(f"Insufficient reflectance files for {fire_folder}")
        return None, None, None

    # Lexicographic order is taken as chronological order. That only holds if the
    # file names embed the acquisition date in a sortable form (e.g. YYYYMMDD) —
    # assumption to confirm against the MTBS naming used here.
    pre_fire_raster, post_fire_raster = refl_files[0], refl_files[1]
    if len(refl_files) > 2:
        print(f"Note: {len(refl_files)} reflectance files found; using the first two after sorting")

    burn_bndy_files = sorted(glob.glob(os.path.join(fire_path, "*_burn_bndy.shp")))
    if not burn_bndy_files:
        print(f"No burn boundary shapefile found for {fire_folder}")
        return pre_fire_raster, post_fire_raster, None

    burn_bndy = burn_bndy_files[0]

    print(f"Pre-fire: {os.path.basename(pre_fire_raster)}")
    print(f"Post-fire: {os.path.basename(post_fire_raster)}")
    print(f"Burn boundary: {os.path.basename(burn_bndy)}")

    return pre_fire_raster, post_fire_raster, burn_bndy

# Use the first discovered fire as the worked example for the rest of the notebook
if not fire_folders:
    print("No fire datasets found.")
else:
    selected_fire = fire_folders[0]
    print(f"\nLoading data for fire: {selected_fire}")
    pre_fire_raster, post_fire_raster, burn_bndy = load_fire_data(selected_fire)

## Visualize the Fire Data

Let's visualize the pre-fire and post-fire imagery along with the burn perimeter to understand the data.

In [None]:
# Function to visualize reflectance data with burn perimeter overlay
def _stretch_for_display(img):
    """Return a float32 copy of an (H, W, C) image with each channel clipped to
    its 2nd-98th percentile range and rescaled to [0, 1] for display."""
    out = img.astype(np.float32)
    for c in range(out.shape[-1]):
        band = out[:, :, c]
        lo = np.percentile(band, 2)
        hi = np.percentile(band, 98)
        if hi > lo:
            out[:, :, c] = np.clip((band - lo) / (hi - lo), 0, 1)
        else:
            # Flat channel: the stretch would divide by zero and yield NaNs
            out[:, :, c] = 0.0
    return out


def _ndvi_from_dataset(src):
    """NDVI from band 4 (NIR) and band 3 (Red) of an open rasterio dataset.

    Bands are cast to float32 before the arithmetic. The 1e-6 term keeps the
    denominator non-zero."""
    nir = src.read(4).astype(np.float32)
    red = src.read(3).astype(np.float32)
    return (nir - red) / (nir + red + 1e-6)


def _imshow_extent(src):
    """(left, right, bottom, top) bounds of a rasterio dataset.

    Passed as ``extent=``, this makes ``imshow`` draw the raster in its own map
    coordinates instead of pixel indices."""
    b = src.bounds
    return (b.left, b.right, b.bottom, b.top)


def visualize_fire_data(pre_fire_path, post_fire_path, burn_bndy_path):
    """Plot pre-/post-fire false-colour composites and the NDVI change side by side.

    Bands 4, 3, 2 (NIR, Red, Green) are shown as RGB, each stretched to its
    2nd-98th percentiles. If ``burn_bndy_path`` is given, the perimeter is
    reprojected to the pre-fire raster's CRS and outlined in red on every panel.
    The images are drawn in map coordinates so that the outline lines up.

    Args:
        pre_fire_path: Path to the pre-fire reflectance GeoTIFF.
        post_fire_path: Path to the post-fire reflectance GeoTIFF.
        burn_bndy_path: Path to the burn-perimeter shapefile, or None.

    Returns:
        (pre_fire_data, post_fire_data, ndvi_diff): the two stretched (H, W, 3)
        float32 composites and the NDVI difference (pre - post). Returns None
        when either raster path is missing.
    """
    if not (pre_fire_path and post_fire_path):
        print("Missing pre-fire or post-fire data")
        return

    fig, axs = plt.subplots(1, 3, figsize=(20, 7))

    with rasterio.open(pre_fire_path) as pre_src, rasterio.open(post_fire_path) as post_src:
        crs = pre_src.crs
        pre_extent = _imshow_extent(pre_src)
        post_extent = _imshow_extent(post_src)

        pre_fire_data = _stretch_for_display(np.transpose(pre_src.read([4, 3, 2]), (1, 2, 0)))
        post_fire_data = _stretch_for_display(np.transpose(post_src.read([4, 3, 2]), (1, 2, 0)))

        # Positive values mean NDVI dropped between the two dates.
        # Assumes both rasters share one grid — TODO confirm for each dataset.
        ndvi_diff = _ndvi_from_dataset(pre_src) - _ndvi_from_dataset(post_src)

    axs[0].imshow(pre_fire_data, extent=pre_extent)
    axs[0].set_title("Pre-Fire Image (False Color)")

    axs[1].imshow(post_fire_data, extent=post_extent)
    axs[1].set_title("Post-Fire Image (False Color)")

    im = axs[2].imshow(ndvi_diff, cmap='RdYlGn_r', vmin=-0.5, vmax=0.5, extent=pre_extent)
    axs[2].set_title("NDVI Change (Pre - Post)")
    plt.colorbar(im, ax=axs[2], label='NDVI Difference')

    for ax in axs:
        ax.axis('off')

    if burn_bndy_path:
        try:
            burn_gdf = gpd.read_file(burn_bndy_path)
            if burn_gdf.crs != crs:
                burn_gdf = burn_gdf.to_crs(crs)
            for ax in axs:
                burn_gdf.boundary.plot(ax=ax, color='red', linewidth=2)
            print("Burn perimeter overlaid successfully")
        except Exception as e:
            # Best-effort overlay: still show the imagery if the perimeter can't be drawn
            print(f"Error overlaying burn perimeter: {e}")

    plt.tight_layout()
    plt.show()

    return pre_fire_data, post_fire_data, ndvi_diff

# Plot the selected fire, but only when a pre-fire raster was actually found
if 'pre_fire_raster' not in locals() or not pre_fire_raster:
    print("No fire data available to visualize.")
else:
    pre_fire_img, post_fire_img, ndvi_diff = visualize_fire_data(
        pre_fire_raster, post_fire_raster, burn_bndy
    )

## Prepare Training Data for CNN

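Now we'll prepare the data for training our CNN model. We'll cut the imagery into overlapping image patches. Each patch is paired with a per-pixel burned/unburned mask, rasterized from the burn perimeter. Only patches that contain some burned pixels (but are less than 90% burned) are kept.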

In [None]:
# Function to create burned area mask from burn perimeter
def create_burn_mask(burn_bndy_path, reference_raster_path):
    """Rasterize a burn-perimeter shapefile onto the grid of a reference raster.

    The shapefile is reprojected to the raster's CRS when the two differ.
    Null or empty geometries are ignored.

    Args:
        burn_bndy_path: Path to the burn-perimeter shapefile.
        reference_raster_path: Raster whose height, width and transform define
            the output grid.

    Returns:
        uint8 array of shape (height, width): 1 inside the perimeter polygons,
        0 elsewhere. It is all zeros if the shapefile holds no usable geometry.
    """
    from rasterio.features import rasterize

    with rasterio.open(reference_raster_path) as src:
        height, width = src.height, src.width
        transform = src.transform
        crs = src.crs

    burn_gdf = gpd.read_file(burn_bndy_path)
    if burn_gdf.crs != crs:
        burn_gdf = burn_gdf.to_crs(crs)

    # Shapefile rows can carry null/empty geometries, which can make rasterize raise
    shapes = [
        (geom, 1)
        for geom in burn_gdf.geometry
        if geom is not None and not geom.is_empty
    ]
    if not shapes:
        print(f"No valid burn perimeter geometries in {os.path.basename(burn_bndy_path)}; mask is all unburned")
        return np.zeros((height, width), dtype=np.uint8)

    return rasterize(shapes, out_shape=(height, width), transform=transform, fill=0, dtype=np.uint8)

# Function to prepare training data
def prepare_training_data(pre_fire_path, post_fire_path, burn_bndy_path, patch_size=64, stride=32,
                          max_burn_fraction=0.9):
    """Build (feature patch, burn-mask patch) pairs for training the segmentation CNN.

    Each pixel gets 9 features:
      - channels 0-3: bands 1-4 (Blue, Green, Red, NIR) of the pre-fire raster / 10000;
      - channels 4-7: the same bands of the post-fire raster / 10000;
      - channel 8: NDVI difference (pre - post) from bands 4 (NIR) and 3 (Red).
    Labels come from rasterizing the burn perimeter onto the post-fire grid.

    Windows of ``patch_size`` x ``patch_size`` are taken every ``stride`` pixels.
    A window is kept only if it contains at least one burned pixel AND its
    burned fraction is below ``max_burn_fraction``. Fully unburned windows are
    therefore never included.

    Returns:
        (patches_X, patches_y) with shapes (N, patch_size, patch_size, 9)
        float32 and (N, patch_size, patch_size) in the mask's dtype. N may be 0.

    Raises:
        ValueError: if the pre-fire, post-fire and mask grids differ in size.
    """
    print("Preparing training data...")

    burn_mask = create_burn_mask(burn_bndy_path, post_fire_path)
    print(f"Created burn mask with shape {burn_mask.shape}")

    with rasterio.open(pre_fire_path) as pre_src, rasterio.open(post_fire_path) as post_src:
        # Blue, Green, Red, NIR. Cast to float before any arithmetic so that
        # unsigned-integer bands cannot wrap around in NIR - Red.
        pre_bands = pre_src.read([1, 2, 3, 4]).astype(np.float32)
        post_bands = post_src.read([1, 2, 3, 4]).astype(np.float32)

    if pre_bands.shape != post_bands.shape or pre_bands.shape[1:] != burn_mask.shape:
        raise ValueError(
            f"Grid mismatch: pre {pre_bands.shape[1:]}, post {post_bands.shape[1:]}, "
            f"mask {burn_mask.shape}"
        )

    pre_ndvi = (pre_bands[3] - pre_bands[2]) / (pre_bands[3] + pre_bands[2] + 1e-6)
    post_ndvi = (post_bands[3] - post_bands[2]) / (post_bands[3] + post_bands[2] + 1e-6)
    ndvi_diff = pre_ndvi - post_ndvi

    # Stack features as (height, width, channels)
    h, w = burn_mask.shape
    X = np.zeros((h, w, 9), dtype=np.float32)
    X[:, :, 0:4] = np.transpose(pre_bands, (1, 2, 0)) / 10000.0  # reflectance scaling
    X[:, :, 4:8] = np.transpose(post_bands, (1, 2, 0)) / 10000.0
    X[:, :, 8] = ndvi_diff

    patches_X = []
    patches_y = []
    for i in range(0, h - patch_size + 1, stride):
        for j in range(0, w - patch_size + 1, stride):
            patch_y = burn_mask[i:i + patch_size, j:j + patch_size]
            # The mask is 0/1, so mean > 0 <=> at least one burned pixel
            if 0 < np.mean(patch_y) < max_burn_fraction:
                patches_X.append(X[i:i + patch_size, j:j + patch_size])
                patches_y.append(patch_y)

    if patches_X:
        patches_X = np.stack(patches_X)
        patches_y = np.stack(patches_y)
    else:
        # Keep the 4-D / 3-D layout even when nothing qualifies
        patches_X = np.empty((0, patch_size, patch_size, X.shape[2]), dtype=np.float32)
        patches_y = np.empty((0, patch_size, patch_size), dtype=burn_mask.dtype)

    print(f"Created {len(patches_X)} patches with shape {patches_X.shape}")

    return patches_X, patches_y

# Build patches and split them, but only when a burn perimeter is available
if not ('pre_fire_raster' in locals() and pre_fire_raster and burn_bndy):
    print("Cannot prepare training data: Missing fire data or burn boundary.")
else:
    try:
        X_patches, y_patches = prepare_training_data(
            pre_fire_raster,
            post_fire_raster,
            burn_bndy,
            patch_size=64,
            stride=32,
        )

        # Hold out 20% of the patches for validation, with a fixed seed
        X_train, X_val, y_train, y_val = train_test_split(
            X_patches, y_patches, test_size=0.2, random_state=42
        )

        for split_name, split_X, split_y in (("Training", X_train, y_train),
                                             ("Validation", X_val, y_val)):
            print(f"{split_name} data: {split_X.shape}, {split_y.shape}")
    except Exception as e:
        print(f"Error preparing training data: {e}")

## Create and Train the CNN Model

We'll create a U-Net style CNN for semantic segmentation of burned areas.

In [None]:
# Define U-Net model for semantic segmentation
def create_unet_model(input_shape):
    """Build an uncompiled U-Net that outputs one sigmoid probability per pixel.

    Architecture:
      - Encoder: three stages with 64, 128 and 256 filters. Each stage is two
        3x3 ReLU convolutions followed by 2x2 max-pooling.
      - Bridge: two 512-filter convolutions.
      - Decoder: three mirrored stages. Each upsamples with a stride-2
        transposed convolution, concatenates the matching encoder output
        (skip connection) and applies two 3x3 convolutions.
      - Output: a 1x1 sigmoid convolution.

    The spatial size in ``input_shape`` should be divisible by 8. The three
    2x pooling steps must be undone exactly for the skip connections to align.
    """
    def double_conv(tensor, filters):
        # Two same-padded 3x3 ReLU convolutions
        for _ in range(2):
            tensor = layers.Conv2D(filters, 3, activation='relu', padding='same')(tensor)
        return tensor

    inputs = tf.keras.Input(shape=input_shape)

    skips = []
    x = inputs
    for filters in (64, 128, 256):
        x = double_conv(x, filters)
        skips.append(x)
        x = layers.MaxPooling2D(pool_size=(2, 2))(x)

    x = double_conv(x, 512)

    for filters, skip in zip((256, 128, 64), reversed(skips)):
        x = layers.Conv2DTranspose(filters, 2, strides=(2, 2), padding='same')(x)
        x = layers.concatenate([x, skip])
        x = double_conv(x, filters)

    outputs = layers.Conv2D(1, 1, activation='sigmoid')(x)
    return models.Model(inputs=inputs, outputs=outputs)

# Build, compile and fit the U-Net when training patches exist
if 'X_train' in locals() and len(X_train) > 0:
    try:
        # (patch height, patch width, channels) of the training patches
        input_shape = X_train.shape[1:]
        model = create_unet_model(input_shape)

        model.compile(
            optimizer='adam',
            loss='binary_crossentropy',
            metrics=['accuracy', tf.keras.metrics.Recall(), tf.keras.metrics.Precision()],
        )
        model.summary()

        callbacks = [
            # Stop after 5 epochs without improvement and restore the best weights
            EarlyStopping(patience=5, restore_best_weights=True),
            # Overwrite the checkpoint only when the monitored score improves
            ModelCheckpoint('best_fire_cnn_model.h5', save_best_only=True),
        ]

        print("Training CNN model...")
        history = model.fit(
            X_train,
            y_train,
            validation_data=(X_val, y_val),
            epochs=30,
            batch_size=16,
            callbacks=callbacks,
        )

        # Learning curves: one panel per quantity, train vs. validation
        plt.figure(figsize=(12, 5))
        curve_specs = [
            ('loss', 'Loss', 'Binary Crossentropy'),
            ('accuracy', 'Accuracy', 'Accuracy'),
        ]
        for panel, (metric_key, panel_title, y_label) in enumerate(curve_specs, start=1):
            plt.subplot(1, 2, panel)
            plt.plot(history.history[metric_key], label='Train')
            plt.plot(history.history['val_' + metric_key], label='Validation')
            plt.title(panel_title)
            plt.xlabel('Epoch')
            plt.ylabel(y_label)
            plt.legend()

        plt.tight_layout()
        plt.show()

    except Exception as e:
        print(f"Error creating or training model: {e}")
else:
    print("Cannot train model: No training data available.")

## Evaluate Model Performance

Let's evaluate the model's performance on the validation set and visualize some predictions.

In [None]:
# Evaluate model if available
def _stretch_patch_for_display(rgb):
    """Rescale each channel of an (H, W, 3) float array in place to [0, 1].

    Uses each channel's 2nd/98th percentiles. Flat channels become 0 instead
    of NaN. Returns the same array."""
    for c in range(rgb.shape[-1]):
        lo = np.percentile(rgb[:, :, c], 2)
        hi = np.percentile(rgb[:, :, c], 98)
        if hi > lo:
            rgb[:, :, c] = np.clip((rgb[:, :, c] - lo) / (hi - lo), 0, 1)
        else:
            rgb[:, :, c] = 0.0
    return rgb


if 'model' in locals() and 'X_val' in locals() and 'y_val' in locals():
    # Unpacking order follows the metrics list passed to compile(): accuracy, Recall, Precision
    val_loss, val_acc, val_recall, val_precision = model.evaluate(X_val, y_val)
    # F1 is undefined when precision and recall are both 0; report 0 instead of dividing by zero
    pr_sum = val_precision + val_recall
    val_f1 = 2 * (val_precision * val_recall) / pr_sum if pr_sum > 0 else 0.0

    print("\nValidation Performance:")
    print(f"Loss: {val_loss:.4f}")
    print(f"Accuracy: {val_acc:.4f}")
    print(f"Precision: {val_precision:.4f}")
    print(f"Recall: {val_recall:.4f}")
    print(f"F1 Score: {val_f1:.4f}")

    # Per-pixel predictions on the validation set
    y_pred = model.predict(X_val)
    y_pred_binary = (y_pred > 0.5).astype(np.uint8)

    # labels=[0, 1] keeps the matrix 2x2 (matching the tick labels) even when one
    # class is absent from the validation pixels
    y_true_flat = y_val.flatten()
    y_pred_flat = y_pred_binary.flatten()
    cm = confusion_matrix(y_true_flat, y_pred_flat, labels=[0, 1])

    plt.figure(figsize=(8, 6))
    sns.heatmap(cm, annot=True, fmt="d", cmap="Blues",
                xticklabels=['Unburned', 'Burned'],
                yticklabels=['Unburned', 'Burned'])
    plt.xlabel('Predicted')
    plt.ylabel('True')
    plt.title('Confusion Matrix')
    plt.show()

    # Show a few samples: pre/post false colour, true mask, predicted probability
    n_samples = min(5, len(X_val))
    plt.figure(figsize=(15, n_samples * 3))
    # Channels 0-3 / 4-7 of X hold Blue, Green, Red, NIR of the pre/post image;
    # each composite lists its (NIR, Red, Green) channel indices
    composites = (("Pre-Fire", [3, 2, 1]), ("Post-Fire", [7, 6, 5]))

    for i in range(n_samples):
        for col, (label, channels) in enumerate(composites, start=1):
            plt.subplot(n_samples, 4, i * 4 + col)
            rgb = _stretch_patch_for_display(X_val[i][:, :, channels].astype(np.float64))
            plt.imshow(rgb)
            plt.title(f"{label} (Sample {i+1})")
            plt.axis('off')

        # The labels built in this notebook are (N, H, W) with no channel axis,
        # so indexing [..., 0] unconditionally raises IndexError. Drop a trailing
        # channel only when one exists.
        true_mask = y_val[i, ..., 0] if y_val.ndim == 4 else y_val[i]
        plt.subplot(n_samples, 4, i * 4 + 3)
        plt.imshow(true_mask, cmap='gray', vmin=0, vmax=1)
        plt.title("True Burn Mask")
        plt.axis('off')

        # Fixed 0-1 scale so near-zero probabilities are not auto-contrasted to white
        plt.subplot(n_samples, 4, i * 4 + 4)
        plt.imshow(y_pred[i, :, :, 0], cmap='gray', vmin=0, vmax=1)
        plt.title("Predicted Burn Mask")
        plt.axis('off')

    plt.tight_layout()
    plt.show()
else:
    print("No model or validation data available for evaluation.")

## Apply the Model to a New Fire

Let's test our model on a different fire dataset to assess its generalization capability.

In [None]:
# Function to apply trained model to a new fire dataset
def _patch_origins(length, patch_size, stride):
    """Start offsets of windows of ``patch_size`` stepping by ``stride`` along an
    axis of ``length``, plus one final window flush with the far edge so that
    every index is covered. Requires ``length >= patch_size``."""
    starts = list(range(0, length - patch_size + 1, stride))
    if starts[-1] != length - patch_size:
        starts.append(length - patch_size)
    return starts


def apply_model_to_new_fire(model, fire_folder, patch_size=64, stride=32, batch_size=64):
    """Predict a burn mask for another fire, plot it, score it, and save it as a GeoTIFF.

    The feature stack matches training: bands 1-4 of the pre- and post-fire
    rasters / 10000, plus the pre - post NDVI difference. Overlapping
    ``patch_size`` windows (step ``stride``) are predicted in batches of
    ``batch_size`` and averaged where they overlap. A final row/column of
    windows is aligned to the far edge so that every pixel is predicted.
    Rasters smaller than one patch are zero-padded for prediction and then
    cropped back.

    If a burn perimeter exists and its mask matches the prediction grid,
    accuracy/precision/recall/F1 are printed. The binary mask is written to
    ``<cwd>/<fire_folder>_cnn_prediction.tif`` using the post-fire georeferencing.

    Returns:
        None.
    """
    print(f"Applying model to new fire dataset: {fire_folder}")

    pre_fire_raster, post_fire_raster, burn_bndy = load_fire_data(fire_folder)
    if not (pre_fire_raster and post_fire_raster):
        print("Cannot apply model: Missing fire data")
        return

    with rasterio.open(pre_fire_raster) as pre_src, rasterio.open(post_fire_raster) as post_src:
        out_meta = post_src.meta.copy()
        # Blue, Green, Red, NIR. Cast before arithmetic so unsigned bands cannot wrap in NIR - Red.
        pre_bands = pre_src.read([1, 2, 3, 4]).astype(np.float32)
        post_bands = post_src.read([1, 2, 3, 4]).astype(np.float32)

    pre_ndvi = (pre_bands[3] - pre_bands[2]) / (pre_bands[3] + pre_bands[2] + 1e-6)
    post_ndvi = (post_bands[3] - post_bands[2]) / (post_bands[3] + post_bands[2] + 1e-6)
    ndvi_diff = pre_ndvi - post_ndvi

    actual_mask = create_burn_mask(burn_bndy, post_fire_raster) if burn_bndy else None

    height, width = pre_bands.shape[1:]
    X = np.zeros((height, width, 9), dtype=np.float32)
    X[:, :, 0:4] = np.transpose(pre_bands, (1, 2, 0)) / 10000.0  # reflectance scaling
    X[:, :, 4:8] = np.transpose(post_bands, (1, 2, 0)) / 10000.0
    X[:, :, 8] = ndvi_diff

    # Zero-pad rasters smaller than one patch so at least one window fits
    pad_h = max(0, patch_size - height)
    pad_w = max(0, patch_size - width)
    if pad_h or pad_w:
        X = np.pad(X, ((0, pad_h), (0, pad_w), (0, 0)))
    padded_h, padded_w = X.shape[:2]

    pred_mask = np.zeros((padded_h, padded_w), dtype=np.float32)
    count_mask = np.zeros_like(pred_mask)  # windows covering each pixel

    origins = [
        (i, j)
        for i in _patch_origins(padded_h, patch_size, stride)
        for j in _patch_origins(padded_w, patch_size, stride)
    ]
    # Batched prediction: one predict() call per patch is slow and prints a progress bar each time
    for start in range(0, len(origins), batch_size):
        chunk = origins[start:start + batch_size]
        batch = np.stack([X[i:i + patch_size, j:j + patch_size] for i, j in chunk])
        preds = model.predict(batch, verbose=0)[..., 0]
        for (i, j), pred in zip(chunk, preds):
            pred_mask[i:i + patch_size, j:j + patch_size] += pred
            count_mask[i:i + patch_size, j:j + patch_size] += 1

    # Average overlapping windows, then crop away any padding
    pred_mask = (pred_mask / np.maximum(count_mask, 1))[:height, :width]
    binary_mask = (pred_mask > 0.5).astype(np.uint8)

    if actual_mask is not None and actual_mask.shape != binary_mask.shape:
        print(f"Burn mask shape {actual_mask.shape} does not match prediction shape "
              f"{binary_mask.shape}; skipping comparison")
        actual_mask = None

    fig, axs = plt.subplots(1, 3 if actual_mask is not None else 2, figsize=(15, 5))

    im = axs[0].imshow(ndvi_diff, cmap='RdYlGn_r', vmin=-0.5, vmax=0.5)
    axs[0].set_title("NDVI Change (Pre - Post)")
    axs[0].axis('off')
    plt.colorbar(im, ax=axs[0], label='NDVI Difference')

    axs[1].imshow(binary_mask, cmap='gray')
    axs[1].set_title("Model Prediction (Burned Areas)")
    axs[1].axis('off')

    if actual_mask is not None:
        axs[2].imshow(actual_mask, cmap='gray')
        axs[2].set_title("Actual Burned Areas")
        axs[2].axis('off')

        y_true = actual_mask.ravel()
        y_hat = binary_mask.ravel()
        acc = accuracy_score(y_true, y_hat)
        # zero_division=0: report 0 rather than warn when no burned pixels are predicted/present
        prec = precision_score(y_true, y_hat, zero_division=0)
        rec = recall_score(y_true, y_hat, zero_division=0)
        f1 = f1_score(y_true, y_hat, zero_division=0)

        print(f"\nModel Performance on New Fire:")
        print(f"Accuracy: {acc:.4f}")
        print(f"Precision: {prec:.4f}")
        print(f"Recall: {rec:.4f}")
        print(f"F1 Score: {f1:.4f}")

    plt.tight_layout()
    plt.show()

    # Single-band uint8 output. Drop a source nodata value that uint8 cannot
    # represent (e.g. a negative fill), which rasterio would reject.
    nodata = out_meta.get('nodata')
    if nodata is not None and not (0 <= nodata <= 255):
        nodata = None
    out_meta.update({
        'dtype': 'uint8',
        'count': 1,
        'nodata': nodata,
    })

    output_file = os.path.join(os.getcwd(), f"{fire_folder}_cnn_prediction.tif")
    with rasterio.open(output_file, 'w', **out_meta) as dest:
        dest.write(binary_mask.astype(np.uint8), 1)

    print(f"Prediction saved to: {output_file}")

# Check generalisation on a second fire (training used fire_folders[0])
if 'model' not in locals() or len(fire_folders) <= 1:
    print("Cannot test model on new fire: Missing model or additional fire datasets.")
else:
    try:
        test_fire = fire_folders[1]
        apply_model_to_new_fire(model, test_fire)
    except Exception as e:
        print(f"Error applying model to new fire: {e}")

## Conclusion and Future Work

In this notebook, we've built a CNN-based approach for wildfire burn area detection and segmentation. The model uses multi-spectral satellite imagery from before and after a fire to identify burned areas. How well it generalizes should be judged from the validation and new-fire metrics reported above.

Future improvements could include:

1. **More training data**: Incorporating more diverse fire events for better generalization.
2. **Time series analysis**: Including multiple time points to capture the progression of fire and recovery.
3. **Additional features**: Incorporating topographic data (slope, aspect, elevation) and weather conditions.
4. **Model refinement**: Experimenting with different architectures like DeepLabV3+ or attention mechanisms.
5. **Early detection**: Adapting the model for early fire detection rather than post-fire mapping.
6. **Severity classification**: Extending the model to classify burn severity levels rather than binary classification.

The trained model could be integrated with the cellular automata simulation system to provide better initialization or validation data.

In [None]:
# Persist the trained network in two formats when a model exists
if 'model' in locals():
    model_path = 'fire_detection_cnn_model.h5'
    tf_model_path = 'fire_detection_model'

    # Single-file Keras copy (HDF5, per the .h5 extension)
    model.save(model_path)
    print(f"Model saved to {model_path}")

    # SavedModel directory, intended for deployment
    tf.saved_model.save(model, tf_model_path)
    print(f"Model saved in TensorFlow SavedModel format to {tf_model_path}")