Prithvi EO 2.0 Burn Scar Dataset Generator

This notebook tiles bi-temporal (pre-fire and post-fire) satellite imagery into fixed-size chips to create
training data for fine-tuning the Prithvi EO 2.0 model on burn scar detection and severity classification.

Input: GeoTIFF with 13 bands (bands 1-6 pre-fire reflectance, bands 7-12 post-fire reflectance, band 13 severity label with values 1-5)
Output: Temporal image chips and corresponding label masks, saved as NumPy (.npy) files

Author: Tushar Thokdar
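The band layout above can be illustrated on a small synthetic array. This is a sketch only: it uses random values instead of a real GeoTIFF, and `split_bands` is a hypothetical helper name; the notebook performs the same slicing inline inside `generate_dataset`.

```python
import numpy as np


def split_bands(data: np.ndarray):
    """Hypothetical helper (not in the notebook) mirroring the slicing in generate_dataset.

    Bands 0-5 are pre-fire, 6-11 post-fire, and 12 is the label (0-based indices).
    """
    if data.shape[0] != 13:
        raise ValueError(f"expected 13 bands, got {data.shape[0]}")
    pre_fire = data[0:6].astype(np.float32)
    post_fire = data[6:12].astype(np.float32)
    raw_label = data[12]
    return pre_fire, post_fire, raw_label


# Synthetic 13-band tile standing in for one 224x224 window of the GeoTIFF
rng = np.random.default_rng(0)
tile = rng.integers(0, 10000, size=(13, 224, 224)).astype(np.float32)
pre, post, lab = split_bands(tile)
print(pre.shape, post.shape, lab.shape)  # (6, 224, 224) (6, 224, 224) (224, 224)
```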

In [1]:
import os
import numpy as np
import rasterio
from tqdm import tqdm
from glob import glob
from collections import Counter

In [2]:
# ============================================================================
# CONFIGURATION
# ============================================================================

class Config:
    """Configuration parameters for dataset generation"""

    # File paths
    INPUT_TIF = "/content/drive/MyDrive/data_in_TIFF/Prithvi_PrePost_Training_Data.tif"
    OUTPUT_DIR = "prithvi_dataset"

    # Processing parameters
    TILE_SIZE = 224
    IGNORE_VALUE = 255
    EXPECTED_BANDS = 13  # 6 pre + 6 post + 1 label

    # Quality thresholds
    MIN_VALID_PIXELS = 0.95  # Minimum 95% valid reflectance pixels
    MIN_LABELED_PIXELS = 0.01  # Minimum 1% labeled pixels

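    # Data scaling: if a tile's maximum exceeds REFLECTANCE_THRESHOLD, the values are
    # assumed to be integer-scaled reflectance and are divided by REFLECTANCE_SCALE
    REFLECTANCE_SCALE = 10000.0
    REFLECTANCE_THRESHOLD = 1.5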

    # Optional: add the post-minus-pre difference (delta) as a third time step
    INCLUDE_DELTA = True
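To see what these settings mean in pixels, the standalone arithmetic below computes how many full 224x224 tiles fit in the 2785x2228 input reported in the output further down, and the minimum pixel counts implied by the two quality thresholds. The function names `tile_grid` and `min_pixels` are illustrative and not part of the notebook. Only full tiles are read, so strips at the bottom and right edges are never used.

```python
import math


def tile_grid(height: int, width: int, tile: int):
    """Illustrative helper: (tile_rows, tile_cols, unused_rows, unused_cols) for full tiles."""
    rows, cols = height // tile, width // tile
    return rows, cols, height - rows * tile, width - cols * tile


def min_pixels(tile: int, fraction: float) -> int:
    """Illustrative helper: smallest pixel count in a tile x tile chip whose ratio is >= fraction."""
    return math.ceil(tile * tile * fraction)


rows, cols, spare_h, spare_w = tile_grid(2785, 2228, 224)
print(rows * cols, spare_h, spare_w)                 # 108 97 212
print(min_pixels(224, 0.95), min_pixels(224, 0.01))  # 47668 502
```

In this run every candidate tile passed the quality checks: 108 tiles were read and 108 chips were written.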

In [3]:
# ============================================================================
# UTILITY FUNCTIONS
# ============================================================================

def create_output_directories(base_dir: str) -> None:
    """Create necessary output directories"""
    os.makedirs(f"{base_dir}/temporal_images", exist_ok=True)
    os.makedirs(f"{base_dir}/masks", exist_ok=True)
    print(f"üìÅ Output directories created: {base_dir}/")


def scale_reflectance(data: np.ndarray, scale: float = Config.REFLECTANCE_SCALE) -> np.ndarray:
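    """
    Convert integer-scaled reflectance (e.g. 0-10000) to the 0-1 range.

    The check is applied per call, i.e. per tile and per date: data are divided
    by `scale` only if their maximum exceeds Config.REFLECTANCE_THRESHOLD.

    Args:
        data: Input reflectance array
        scale: Scaling factor

    Returns:
        Reflectance array in 0-1 units (unchanged if already scaled)
    """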
    if np.nanmax(data) > Config.REFLECTANCE_THRESHOLD:
        return data / scale
    return data


def clean_array(data: np.ndarray) -> np.ndarray:
    """
    Remove NaN and Inf values from array

    Args:
        data: Input array

    Returns:
        Cleaned array with NaN/Inf replaced by 0
    """
    return np.nan_to_num(data, nan=0.0, posinf=0.0, neginf=0.0)


def process_label(raw_label: np.ndarray, bad_pixels_mask: np.ndarray) -> np.ndarray:
    """
    Process raw labels to class indices (0-4) with ignore regions

    Args:
        raw_label: Raw label array (values 1-5)
        bad_pixels_mask: Boolean mask of invalid pixels

    Returns:
        Processed label array (0-4 for valid, 255 for ignore)
    """
    label = np.full(raw_label.shape, Config.IGNORE_VALUE, dtype=np.uint8)

    # Map valid labels from 1-5 to 0-4
    valid = (raw_label >= 1) & (raw_label <= 5)
    label[valid] = raw_label[valid] - 1

    # Mark bad pixels as ignore
    label[bad_pixels_mask] = Config.IGNORE_VALUE

    return label


def passes_quality_checks(bad_pixels_mask: np.ndarray, label: np.ndarray) -> bool:
    """
    Check if chip meets quality thresholds

    Args:
        bad_pixels_mask: Boolean mask of invalid pixels
        label: Processed label array

    Returns:
        True if chip passes quality checks
    """
    # Check valid pixel percentage
    valid_pixel_ratio = (~bad_pixels_mask).mean()
    if valid_pixel_ratio < Config.MIN_VALID_PIXELS:
        return False

    # Check labeled pixel percentage
    labeled_pixel_ratio = (label != Config.IGNORE_VALUE).mean()
    if labeled_pixel_ratio < Config.MIN_LABELED_PIXELS:
        return False

    return True
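The scaling heuristic and the label remapping can be checked on tiny arrays. The sketch below re-implements the logic of `scale_reflectance` and `process_label` standalone (with the scale, threshold, and ignore value hard-coded to the `Config` defaults) so it runs without the rest of the notebook; the helper names and input values are made up for illustration.

```python
import numpy as np

IGNORE_VALUE = 255  # same value as Config.IGNORE_VALUE


def scale_if_needed(data: np.ndarray, scale: float = 10000.0, threshold: float = 1.5) -> np.ndarray:
    """Illustrative re-implementation of scale_reflectance: divide only if max > threshold."""
    return data / scale if np.nanmax(data) > threshold else data


def remap_labels(raw_label: np.ndarray, bad_pixels: np.ndarray) -> np.ndarray:
    """Illustrative re-implementation of process_label: 1-5 -> 0-4, everything else ignored."""
    label = np.full(raw_label.shape, IGNORE_VALUE, dtype=np.uint8)
    valid = (raw_label >= 1) & (raw_label <= 5)
    label[valid] = raw_label[valid] - 1
    label[bad_pixels] = IGNORE_VALUE
    return label


print(scale_if_needed(np.array([500.0, 2500.0])).tolist())  # [0.05, 0.25]
print(scale_if_needed(np.array([0.05, 0.25])).tolist())     # [0.05, 0.25] (left as is)

raw = np.array([[0, 1, 2], [3, 4, 5], [6, 5, 1]])
bad = np.zeros(raw.shape, dtype=bool)
bad[2, 2] = True  # pretend this pixel had a NaN reflectance
print(remap_labels(raw, bad).tolist())  # [[255, 0, 1], [2, 3, 4], [255, 4, 255]]
```

Values 0 and 6 fall outside 1-5 and become 255, and the flagged pixel is ignored even though its raw label (1) was valid.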



In [4]:
# ============================================================================
# MAIN PROCESSING FUNCTION
# ============================================================================

def generate_dataset(config: Config = Config()) -> int:
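    """
    Generate training dataset from multi-temporal satellite imagery

    Args:
        config: Configuration object. Note that the helper functions
            (scale_reflectance, process_label, passes_quality_checks) read
            their thresholds from the global Config class, not from this argument.

    Returns:
        Number of chips generated
    """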
    print("üöÄ Starting Prithvi EO 2.0 Dataset Generation")
    print(f"   Input: {config.INPUT_TIF}")
    print(f"   Output: {config.OUTPUT_DIR}")
    print(f"   Tile size: {config.TILE_SIZE}x{config.TILE_SIZE}")
    print(f"   Delta channel: {'Enabled' if config.INCLUDE_DELTA else 'Disabled'}")
    print("-" * 70)

    # Create output directories
    create_output_directories(config.OUTPUT_DIR)

    chip_count = 0

    with rasterio.open(config.INPUT_TIF) as src:
        height, width = src.height, src.width

        # Validate input bands
        if src.count != config.EXPECTED_BANDS:
            raise ValueError(
                f"Expected {config.EXPECTED_BANDS} bands, got {src.count}"
            )

        print(f"üìä Image dimensions: {height}x{width}")
        print(f"üî¢ Processing {src.count} bands")
        print()

        # Only full tiles are processed; partial tiles at the bottom and right edges are skipped
        total_tiles = (height // config.TILE_SIZE) * (width // config.TILE_SIZE)

        with tqdm(total=total_tiles, desc="Generating chips") as pbar:
            for row in range(0, height - config.TILE_SIZE + 1, config.TILE_SIZE):
                for col in range(0, width - config.TILE_SIZE + 1, config.TILE_SIZE):

                    # Read tile
                    window = rasterio.windows.Window(
                        col, row, config.TILE_SIZE, config.TILE_SIZE
                    )
                    data = src.read(window=window)  # (13, 224, 224)

                    # Split bands
                    pre_fire = data[0:6].astype(np.float32)
                    post_fire = data[6:12].astype(np.float32)
                    raw_label = data[12]

                    # Scale reflectance if needed
                    pre_fire = scale_reflectance(pre_fire)
                    post_fire = scale_reflectance(post_fire)

                    # Identify bad pixels: NaN/Inf in any pre-fire or post-fire band
                    bad_pixels = ~(
                        np.isfinite(pre_fire).all(axis=0)
                        & np.isfinite(post_fire).all(axis=0)
                    )

                    # Clean arrays
                    pre_fire = clean_array(pre_fire)
                    post_fire = clean_array(post_fire)

                    # Process labels
                    label = process_label(raw_label, bad_pixels)

                    # Quality checks
                    if not passes_quality_checks(bad_pixels, label):
                        pbar.update(1)
                        continue

                    # Create temporal stack along a new leading axis:
                    # (2 or 3 time steps, 6 bands, TILE_SIZE, TILE_SIZE)
                    if config.INCLUDE_DELTA:
                        # Post-minus-pre change, limited to [-1, 1]
                        delta = np.clip(post_fire - pre_fire, -1.0, 1.0)
                        temporal = np.stack([pre_fire, post_fire, delta], axis=0)
                        # Clip only the pre- and post-fire reflectance to [0, 1]
                        temporal[0:2] = np.clip(temporal[0:2], 0, 1)
                    else:
                        temporal = np.stack([pre_fire, post_fire], axis=0)
                        temporal = np.clip(temporal, 0, 1)

                    temporal = temporal.astype(np.float32)

                    # Save chip
                    np.save(
                        f"{config.OUTPUT_DIR}/temporal_images/chip_{chip_count:06d}.npy",
                        temporal
                    )
                    np.save(
                        f"{config.OUTPUT_DIR}/masks/chip_{chip_count:06d}.npy",
                        label
                    )

                    chip_count += 1
                    pbar.update(1)

    print()
    print("✅ Dataset generation complete!")
    print(f"📦 Total chips generated: {chip_count}")
    print(f"📂 Location: {config.OUTPUT_DIR}/")

    return chip_count
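The chip layout written by `generate_dataset` with `INCLUDE_DELTA = True` can be reproduced on synthetic data. The sketch below (random reflectances, hypothetical function name `build_temporal_stack`) shows why the sample check in the output reports an image shape of `(3, 6, 224, 224)`: axis 0 holds pre-fire, post-fire, and delta, and axis 1 holds the six bands. Whether your training code expects this (time, band, H, W) order or a (band, time, H, W) order is something to confirm against your Prithvi fine-tuning setup; the `np.transpose` at the end is only an example of switching between the two.

```python
import numpy as np


def build_temporal_stack(pre_fire, post_fire, include_delta=True):
    """Hypothetical standalone version of the stacking step in generate_dataset."""
    if include_delta:
        # Change computed from the unclipped inputs, then limited to [-1, 1]
        delta = np.clip(post_fire - pre_fire, -1.0, 1.0)
        temporal = np.stack([pre_fire, post_fire, delta], axis=0)
        # Only the two reflectance time steps are clipped to [0, 1]
        temporal[0:2] = np.clip(temporal[0:2], 0, 1)
    else:
        temporal = np.clip(np.stack([pre_fire, post_fire], axis=0), 0, 1)
    return temporal.astype(np.float32)


rng = np.random.default_rng(42)
pre = rng.uniform(0.0, 0.5, size=(6, 224, 224)).astype(np.float32)
post = rng.uniform(0.0, 0.5, size=(6, 224, 224)).astype(np.float32)
chip = build_temporal_stack(pre, post)
print(chip.shape)  # (3, 6, 224, 224)

# Example only (an assumption about the consumer): reorder to (band, time, H, W)
chip_cthw = np.transpose(chip, (1, 0, 2, 3))
print(chip_cthw.shape)  # (6, 3, 224, 224)
```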

In [5]:
# ============================================================================
# DATASET ANALYSIS
# ============================================================================

def analyze_dataset(dataset_dir: str) -> dict:
    """
    Analyze generated dataset for class distribution and statistics

    Args:
        dataset_dir: Path to dataset directory

    Returns:
        Dictionary containing dataset statistics
    """
    print("\n" + "=" * 70)
    print("üìä DATASET ANALYSIS")
    print("=" * 70)

    mask_files = glob(f"{dataset_dir}/masks/*.npy")

    if not mask_files:
        print("⚠️  No mask files found!")
        return {}

    # Accumulate per-class pixel counts, excluding the ignore value
    label_counts = Counter()
    for mask_file in tqdm(mask_files, desc="Analyzing masks"):
        mask = np.load(mask_file)
        values, counts = np.unique(mask[mask != Config.IGNORE_VALUE], return_counts=True)
        label_counts.update(dict(zip(values.tolist(), counts.tolist())))

    # Calculate statistics
    total_pixels = sum(label_counts.values())

    print(f"\nüìà Class Distribution:")
    print(f"{'Class':<15} {'Count':<12} {'Percentage':<12}")
    print("-" * 40)

    class_names = {
        0: "Unburned",
        1: "Low Severity",
        2: "Moderate-Low",
        3: "Moderate-High",
        4: "High Severity"
    }

    stats = {}
    for class_id in sorted(label_counts.keys()):
        count = label_counts[class_id]
        percentage = (count / total_pixels) * 100
        class_name = class_names.get(class_id, f"Class {class_id}")
        print(f"{class_name:<15} {count:<12,} {percentage:<12.2f}%")
        stats[class_name] = {"count": count, "percentage": percentage}

    print("-" * 40)
    print(f"{'Total':<15} {total_pixels:<12,} {100.0:<12.2f}%")

    # Load and check first sample
    print(f"\nüîç Sample Data Check:")
    sample_img = np.load(f"{dataset_dir}/temporal_images/chip_000000.npy")
    sample_mask = np.load(f"{dataset_dir}/masks/chip_000000.npy")

    print(f"   Image shape: {sample_img.shape}")
    print(f"   Image range: [{sample_img.min():.4f}, {sample_img.max():.4f}]")
    print(f"   Mask unique values: {np.unique(sample_mask)}")

    if Config.INCLUDE_DELTA:
        print(f"   Delta channel range: [{sample_img[2].min():.4f}, {sample_img[2].max():.4f}]")

    stats['total_samples'] = len(mask_files)
    stats['total_pixels'] = total_pixels

    return stats
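As a cross-check of the class distribution printed at the end of this notebook, the percentages can be recomputed from the recorded counts. The sketch also derives normalized inverse-frequency class weights; these are not used anywhere in the notebook and are shown only as one common way such a distribution is fed into a weighted loss, assuming you want one.

```python
import numpy as np

# Per-class pixel counts copied from the recorded class distribution output
class_names = ["Unburned", "Low Severity", "Moderate-Low", "Moderate-High", "High Severity"]
counts = np.array([1_369_197, 771_342, 924_961, 822_241, 1_182_633], dtype=np.float64)

total = counts.sum()
percentages = counts / total * 100

# Assumption, not part of the notebook: inverse-frequency weights rescaled to mean 1
weights = total / (len(counts) * counts)
weights /= weights.mean()

for name, pct, w in zip(class_names, percentages, weights):
    print(f"{name:<15} {pct:6.2f}%  weight={w:.3f}")
```

The rarest class (Low Severity) receives the largest weight and the most common class (Unburned) the smallest.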

In [6]:
# ============================================================================
# MAIN EXECUTION
# ============================================================================

if __name__ == "__main__":
    # Generate dataset
    num_chips = generate_dataset()

    # Analyze results
    if num_chips > 0:
        analyze_dataset(Config.OUTPUT_DIR)

        print("\n" + "=" * 70)
        print("üíæ Next Steps:")
        print("=" * 70)
        print("1. Review the class distribution above")
        print("2. Zip the dataset for backup/sharing:")
        print(f"   !zip -r prithvi_dataset.zip {Config.OUTPUT_DIR}/")
        print("3. Copy to Google Drive (if applicable):")
        print(f"   !cp prithvi_dataset.zip /content/drive/MyDrive/")
        print("=" * 70)
    else:
        print("\n⚠️  No chips were generated. Please check your input data and thresholds.")
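Images and masks are paired only by their shared file name (`chip_000000.npy`, `chip_000001.npy`, ...). A minimal, NumPy-only loader sketch for the `temporal_images/` and `masks/` layout is shown below; the class name `ChipDataset` is hypothetical, and in a real training pipeline you would likely wrap the same logic in your framework's dataset class. The usage example writes two dummy chips to a temporary directory so it runs anywhere.

```python
import os
import tempfile

import numpy as np


class ChipDataset:
    """Hypothetical loader pairing temporal_images/*.npy with masks/*.npy by file name."""

    def __init__(self, root: str):
        self.image_dir = os.path.join(root, "temporal_images")
        self.mask_dir = os.path.join(root, "masks")
        self.names = sorted(os.listdir(self.image_dir))
        # Fail early if any image chip has no mask with the same file name
        missing = [n for n in self.names if not os.path.exists(os.path.join(self.mask_dir, n))]
        if missing:
            raise FileNotFoundError(f"No mask for: {missing[:5]}")

    def __len__(self) -> int:
        return len(self.names)

    def __getitem__(self, idx: int):
        name = self.names[idx]
        image = np.load(os.path.join(self.image_dir, name))  # (3, 6, H, W) float32 with delta
        mask = np.load(os.path.join(self.mask_dir, name))    # (H, W) uint8, 255 = ignore
        return image, mask


# Usage: build a tiny dummy dataset with the same directory layout
root = tempfile.mkdtemp()
os.makedirs(os.path.join(root, "temporal_images"))
os.makedirs(os.path.join(root, "masks"))
for i in range(2):
    np.save(os.path.join(root, "temporal_images", f"chip_{i:06d}.npy"),
            np.zeros((3, 6, 224, 224), dtype=np.float32))
    np.save(os.path.join(root, "masks", f"chip_{i:06d}.npy"),
            np.full((224, 224), 255, dtype=np.uint8))

ds = ChipDataset(root)
image, mask = ds[0]
print(len(ds), image.shape, mask.shape)  # 2 (3, 6, 224, 224) (224, 224)
```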


🚀 Starting Prithvi EO 2.0 Dataset Generation
   Input: /content/drive/MyDrive/data_in_TIFF/Prithvi_PrePost_Training_Data.tif
   Output: prithvi_dataset
   Tile size: 224x224
   Delta channel: Enabled
----------------------------------------------------------------------
📁 Output directories created: prithvi_dataset/
📊 Image dimensions: 2785x2228
🔢 Processing 13 bands

Generating chips: 100%|██████████| 108/108 [00:05<00:00, 20.56it/s]

✅ Dataset generation complete!
📦 Total chips generated: 108
📂 Location: prithvi_dataset/

📊 DATASET ANALYSIS

Analyzing masks: 100%|██████████| 108/108 [00:00<00:00, 1169.24it/s]

📈 Class Distribution:
Class           Count        Percentage  
----------------------------------------
Unburned        1,369,197    27.00       %
Low Severity    771,342      15.21       %
Moderate-Low    924,961      18.24       %
Moderate-High   822,241      16.22       %
High Severity   1,182,633    23.32       %
----------------------------------------
Total           5,070,374    100.00      %

🔍 Sample Data Check:
   Image shape: (3, 6, 224, 224)
   Image range: [-0.3542, 0.4540]
   Mask unique values: [  0   1   2   3   4 255]
   Delta channel range: [-0.3542, 0.1316]

💾 Next Steps:
1. Review the class distribution above
2. Zip the dataset for backup/sharing:
   !zip -r prithvi_dataset.zip prithvi_dataset/
3. Copy to Google Drive (if applicable):
   !cp prithvi_dataset.zip /content/drive/MyDrive/
