<a href="https://colab.research.google.com/github/your-repo/your-project/blob/main/v2/nb/process_raw_data.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Process Raw Wildfire Data

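This notebook processes raw monthly parquet files containing wildfire data. For each batch of locations it:
1. Loads the rows for those locations from every raw parquet file, removing records duplicated across files
2. Processes each location's time series (sorting by date and interpolating missing values in the numeric features)
3. Saves the processed batch to a parquet file for later use, plus a location mapping recording which batch file holds each location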

In [1]:
# Install required packages
!pip install dask[dataframe] pyarrow

Collecting dask-expr<1.2,>=1.1 (from dask[dataframe])
  Downloading dask_expr-1.1.21-py3-none-any.whl.metadata (2.6 kB)
INFO: pip is looking at multiple versions of dask-expr to determine which version is compatible with other requirements. This could take a while.
  Downloading dask_expr-1.1.20-py3-none-any.whl.metadata (2.6 kB)
  Downloading dask_expr-1.1.19-py3-none-any.whl.metadata (2.6 kB)
  Downloading dask_expr-1.1.18-py3-none-any.whl.metadata (2.6 kB)
  Downloading dask_expr-1.1.16-py3-none-any.whl.metadata (2.5 kB)
Downloading dask_expr-1.1.16-py3-none-any.whl (243 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m243.2/243.2 kB[0m [31m4.1 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: dask-expr
Successfully installed dask-expr-1.1.16


In [2]:
# Mount Google Drive to access data
from google.colab import drive, runtime
# Import required modules
import os
import glob
import time
from datetime import datetime
import numpy as np
import pandas as pd

# Try to use Dask for efficient reading
try:
    import dask.dataframe as dd
    USE_DASK = True
except ImportError:
    USE_DASK = False

print(f"using dask: {USE_DASK}")

using dask: True


In [3]:
# Mount Google Drive at /content/drive; DATA_DIR and the output paths configured below live under it.
drive.mount('/content/drive')

Mounted at /content/drive


In [4]:
# --- Run configuration -------------------------------------------------------
year = 2024               # data year; DATA_DIR is built from it
LOC_BATCH_SIZE = 50000    # number of (longitude, latitude) locations per output batch file

# Input: raw monthly parquet exports for this run (the 8-10 subset of `year`).
# Built from `year` throughout so changing the year cannot yield a mixed path.
DATA_DIR = f'/content/drive/MyDrive/GEE_Exports/Alberta/{year}/{year}_8-10'
# Alternative input (all monthly exports of the year):
# DATA_DIR = f'/content/drive/MyDrive/GEE_Exports/Alberta/{year}'

# Output: batch_NNNN.parquet files and location_mapping.parquet are written here.
# NOTE(review): the folder is named '2025' although year=2024 -- confirm this is intended.
OUTPUT_PATH = '/content/drive/My Drive/Colab Notebooks/wildfire/new_data/processed/2025'
# Alternative: OUTPUT_PATH = f'/content/drive/My Drive/Colab Notebooks/wildfire/new_data/processed/{year}'
# Shared across runs: holds the cached all_locations_groups.csv.
OUTPUT_TOP_PATH = '/content/drive/My Drive/Colab Notebooks/wildfire/new_data/processed/'

# Glob pattern selecting the raw monthly parquet files inside DATA_DIR.
PARQUET_PATTERN = "*_n5.parquet"

# --- Schema ------------------------------------------------------------------
# Columns gap-filled by interpolation in process_single_group.
NUMERIC_FEATURES = [
    'LST_Day_1km', 'LST_Night_1km',
    'Emis_31', 'Emis_32', 'dewpoint_temperature_2m',
    'temperature_2m', 'soil_temperature_level_1',
    'surface_net_thermal_radiation', 'u_component_of_wind_10m',
    'v_component_of_wind_10m', 'surface_pressure', 'total_precipitation',
    'elevation', 'NDVI',
]
LABEL = 'targetY'
GROUP_COLS = ['longitude', 'latitude']   # one time series per (longitude, latitude) pair

# Additional target columns; they are read and carried through without interpolation.
EXTRA_TARGET_COLUMNS = [
    'targetY_o1', 'targetY_o2', 'targetY_o3',
    'targetY_prob', 'targetY_o1_prob', 'targetY_o2_prob', 'targetY_o3_prob',
]

# Columns read from every raw parquet file. Derived from the lists above so the
# loaded columns cannot drift out of sync with the interpolated features.
NEEDED_COLUMNS = ['date', *GROUP_COLS, *NUMERIC_FEATURES, LABEL, *EXTRA_TARGET_COLUMNS]

In [5]:
def validate_processed_batch(batch_df):
    """Check that ``batch_df`` holds at most one row per (date, longitude, latitude).

    Prints a warning with the number of surplus rows and returns False when such
    duplicates exist; otherwise prints a confirmation and returns True.
    """
    # duplicated() flags every repeat after the first occurrence of a key, i.e.
    # exactly the rows drop_duplicates() would remove.
    surplus_rows = int(batch_df.duplicated(subset=['date', 'longitude', 'latitude']).sum())

    if surplus_rows == 0:
        print("Validation passed: No duplicates found in processed batch")
        return True

    print(f"WARNING: Found {surplus_rows} duplicates in processed batch!")
    return False

def process_single_group(df, features=None):
    """Sort one location's time series by date and gap-fill its numeric features.

    Parameters
    ----------
    df : pandas.DataFrame
        Rows for a single (longitude, latitude) location; must have a ``date`` column.
        The input frame is not modified.
    features : list of str, optional
        Columns to gap-fill. Defaults to the notebook-level ``NUMERIC_FEATURES``.
        Names that are not columns of ``df`` are skipped.

    Returns
    -------
    pandas.DataFrame
        A new frame sorted by ``date`` with a fresh 0..n-1 index. Each listed column
        that had missing values is linearly interpolated by row position (not by the
        spacing of the dates), and any NaNs left over are forward- then back-filled.
        A column that is entirely NaN stays NaN.
    """
    if features is None:
        features = NUMERIC_FEATURES

    # Duplicate (date, longitude, latitude) rows are expected to have been removed
    # upstream (process_location_batch / the earlier feature-combination step),
    # so no de-duplication is done here.
    df = df.sort_values('date').reset_index(drop=True)

    for col in features:
        if col in df.columns and df[col].isna().any():
            # Linear interpolation fills interior gaps; limit_direction='both'
            # lets pandas also fill leading/trailing gaps.
            filled = df[col].interpolate(method='linear', limit_direction='both')
            # Safety net for anything interpolation left behind.
            df[col] = filled.ffill().bfill()

    return df

def extract_unique_locations(data_dir, file_pattern=PARQUET_PATTERN, output_dir=OUTPUT_TOP_PATH):
    """Return the unique (longitude, latitude) pairs, loading them from cache when available.

    The cache is ``<output_dir>/all_locations_groups.csv``. On a cache miss the
    locations are read from the alphabetically first parquet file in ``data_dir``
    matching ``file_pattern``, and the result is written to the cache.

    NOTE: the cache path depends only on ``output_dir``, not on ``data_dir``;
    delete the CSV if the set of locations in the raw data changes.

    Returns
    -------
    pandas.DataFrame
        Columns ``longitude`` and ``latitude``, one row per unique pair.

    Raises
    ------
    FileNotFoundError
        If there is no cache and no file in ``data_dir`` matches ``file_pattern``.
    """
    locations_cache_path = os.path.join(output_dir, 'all_locations_groups.csv')

    if os.path.exists(locations_cache_path):
        print(f"\nLoading cached locations from {locations_cache_path}")
        # round_trip parsing makes the floats read back bit-identical to those
        # written; locations are later matched by exact float equality.
        unique_locations = pd.read_csv(locations_cache_path, float_precision='round_trip')
        print(f"Loaded {len(unique_locations)} unique locations from cache")
        return unique_locations

    print(f"\nExtracting unique locations at {datetime.now().strftime('%H:%M:%S')}")

    # Sorted so the "first" file is deterministic (glob order is arbitrary).
    parquet_files = sorted(glob.glob(os.path.join(data_dir, file_pattern)))
    if not parquet_files:
        raise FileNotFoundError(
            f"No files matching {file_pattern!r} found in {data_dir!r}"
        )
    first_file = parquet_files[0]

    # Only the location columns are needed. All monthly files are expected to
    # cover the same locations (see validate_location_groups).
    location_df = pd.read_parquet(first_file, columns=['longitude', 'latitude'])
    unique_locations = (
        location_df
        .drop_duplicates(subset=['longitude', 'latitude'])
        .reset_index(drop=True)
    )

    print(f"Caching {len(unique_locations)} unique locations to {locations_cache_path}")
    os.makedirs(output_dir, exist_ok=True)
    unique_locations.to_csv(locations_cache_path, index=False)

    return unique_locations

def process_location_batch(location_batch, data_dir, file_pattern=PARQUET_PATTERN):
    """Gather all rows for a batch of locations from every monthly parquet file.

    Parameters
    ----------
    location_batch : pandas.DataFrame
        Must contain ``longitude`` and ``latitude``. A raw row is kept only when its
        (longitude, latitude) pair appears here (exact float match).
    data_dir : str
        Directory holding the monthly parquet files.
    file_pattern : str
        Glob pattern selecting the files inside ``data_dir``.

    Returns
    -------
    pandas.DataFrame
        The ``NEEDED_COLUMNS`` of the matching rows from all files, sorted by ``date``,
        with at most one row per (date, longitude, latitude). When the same key occurs
        in several files, the row from the file whose name sorts last is kept.

    Raises
    ------
    FileNotFoundError
        If no file in ``data_dir`` matches ``file_pattern``.
    """
    print(f"Processing batch of {len(location_batch)} locations")

    # Sorted so later files are concatenated last; the de-duplication below
    # relies on that order to keep the latest version of a record.
    parquet_files = sorted(glob.glob(os.path.join(data_dir, file_pattern)))
    if not parquet_files:
        raise FileNotFoundError(
            f"No files matching {file_pattern!r} found in {data_dir!r}"
        )

    # Loop-invariant: build the lookup index of requested locations only once.
    batch_index = pd.MultiIndex.from_frame(location_batch[['longitude', 'latitude']])

    all_data = []
    for file_path in parquet_files:
        print(f"Reading file: {os.path.basename(file_path)}")

        df = pd.read_parquet(file_path, columns=NEEDED_COLUMNS)
        # Build the key index from the two location columns only; set_index()
        # would copy the whole frame just to test membership.
        file_index = pd.MultiIndex.from_frame(df[['longitude', 'latitude']])
        all_data.append(df[file_index.isin(batch_index)])
        # Duplicates within a single file are removed by the earlier
        # feature-combination step; only cross-file duplicates are handled below.
        del df, file_index

    print("Combining data from all files...")
    combined_batch = pd.concat(all_data, ignore_index=True)
    del all_data
    initial_count = len(combined_batch)

    # A stable sort keeps rows with equal dates in concatenation (= file) order,
    # so keep='last' really keeps the version from the latest file. The default
    # quicksort is not stable and would keep an arbitrary duplicate.
    combined_batch = combined_batch.sort_values('date', kind='stable').drop_duplicates(
        subset=['date', 'longitude', 'latitude'],
        keep='last'
    )

    if initial_count != len(combined_batch):
        print(f"Removed {initial_count - len(combined_batch)} duplicate records across files")

    return combined_batch

def batch_process_locations(data_dir, output_dir, batch_size=1000):
    """Process all locations in fixed-size batches, writing one parquet file per batch.

    For each batch of unique (longitude, latitude) pairs (from extract_unique_locations):
    rows are gathered from every monthly file (process_location_batch), and each
    location's series is gap-filled (process_single_group). If the batch passes
    validate_processed_batch, it is written to ``<output_dir>/batch_NNNN.parquet``.

    Finally ``<output_dir>/location_mapping.parquet`` is written, indexed by
    (longitude, latitude), with columns:
      - ``batch_file``: name of the batch parquet file holding that location;
      - ``mapping_idx``: running 0-based position over mapped locations, used by the
        downstream sequence code to trace a sequence back to its batch file.
    Only locations of batches that were actually written are mapped, so every
    ``batch_file`` entry refers to an existing file.

    Parameters
    ----------
    data_dir : str
        Directory with the raw monthly parquet files.
    output_dir : str
        Destination directory (created if missing).
    batch_size : int
        Number of locations per batch; must be >= 1.

    Raises
    ------
    ValueError
        If ``batch_size`` is smaller than 1.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be a positive integer, got {batch_size!r}")

    batch_start = time.time()
    print(f"\n{'='*80}")
    print(f"Starting batch processing at {datetime.now().strftime('%H:%M:%S')}")
    print(f"Data directory: {data_dir}")
    print(f"Output directory: {output_dir}")
    print(f"Batch size: {batch_size}")
    print(f"{'='*80}\n")

    os.makedirs(output_dir, exist_ok=True)

    # All unique locations, taken from one raw file (cached after the first run).
    unique_locations = extract_unique_locations(data_dir)
    total_locations = len(unique_locations)

    total_batches = (total_locations + batch_size - 1) // batch_size  # ceiling division
    # (longitude, latitude) -> {'batch_file': ..., 'mapping_idx': ...}
    location_mapping = {}

    print(f"\nProcessing {total_locations} locations in {total_batches} batches")
    print(f"{'='*80}")

    for batch_idx in range(total_batches):
        batch_iteration_start = time.time()
        start_idx = batch_idx * batch_size
        end_idx = min(start_idx + batch_size, total_locations)
        batch_file = f'batch_{batch_idx + 1:04d}.parquet'

        print(f"\nBatch {batch_idx + 1}/{total_batches}")
        print(f"Processing locations {start_idx} to {end_idx} ({end_idx - start_idx} locations)")

        location_batch = unique_locations.iloc[start_idx:end_idx]

        print("Reading and combining monthly files for batch...")
        batch_data = process_location_batch(location_batch, data_dir)

        print("Processing individual location groups...")
        group_start = time.time()
        processed_groups = []
        # Location keys of this batch in processing order; added to the mapping
        # only after the batch file has been written successfully.
        batch_keys = []
        for idx, ((lon, lat), group) in enumerate(batch_data.groupby(['longitude', 'latitude'])):
            if idx % 10000 == 0:
                print(f"  Processing location {idx + 1}/{len(location_batch)} in current batch")
            processed_groups.append(process_single_group(group))
            batch_keys.append((lon, lat))

        group_duration = time.time() - group_start
        print(f"Group processing completed in {group_duration:.2f} seconds")

        print("Saving batch results...")
        save_start = time.time()
        if not processed_groups:
            # pd.concat([]) would raise; there is simply nothing to write.
            print("WARNING: No rows found for this batch; nothing saved")
        else:
            batch_df = pd.concat(processed_groups, ignore_index=True)
            if validate_processed_batch(batch_df):
                batch_df.to_parquet(os.path.join(output_dir, batch_file))
                for key in batch_keys:
                    location_mapping[key] = {
                        'batch_file': batch_file,
                        'mapping_idx': len(location_mapping),
                    }
            else:
                # The batch is not written, so its locations stay out of the
                # mapping instead of pointing at a file that does not exist.
                print("ERROR: Batch validation failed!")
            del batch_df
        save_duration = time.time() - save_start

        del batch_data, processed_groups, batch_keys

        batch_duration = time.time() - batch_iteration_start
        print(f"\nBatch {batch_idx + 1} completed:")
        print(f"  Total batch processing time: {batch_duration:.2f} seconds")
        print(f"  Average time per location: {batch_duration / (end_idx - start_idx):.2f} seconds")
        print(f"  Save time: {save_duration:.2f} seconds")
        print(f"{'='*80}")

    print("\nSaving location mapping...")
    if location_mapping:
        mapping_df = pd.DataFrame.from_dict(location_mapping, orient='index')
        mapping_df.index = pd.MultiIndex.from_tuples(mapping_df.index, names=['longitude', 'latitude'])
        mapping_df.to_parquet(os.path.join(output_dir, 'location_mapping.parquet'))
    else:
        print("WARNING: No locations were saved; location mapping not written")

    total_duration = time.time() - batch_start
    print(f"\nAll batches completed in {total_duration:.2f} seconds")
    if total_batches:
        # total_batches > 0 implies total_locations > 0, so both divisions are safe.
        print(f"Average time per batch: {total_duration / total_batches:.2f} seconds")
        print(f"Average time per location: {total_duration / total_locations:.2f} seconds")
    print(f"{'='*80}\n")

def validate_location_groups(data_dir, file_pattern=PARQUET_PATTERN, output_dir=OUTPUT_TOP_PATH):
    """Check that every parquet file covers exactly the cached set of locations.

    Compares the (longitude, latitude) pairs of each file in ``data_dir`` matching
    ``file_pattern`` with ``<output_dir>/all_locations_groups.csv``, printing a
    per-file report.

    Returns
    -------
    bool
        True only if the cache exists, at least one file matched, and every matched
        file could be read and has exactly the cached location set; False otherwise.
    """
    print(f"\n{'='*80}")
    print(f"Starting location group validation at {datetime.now().strftime('%H:%M:%S')}")

    cache_path = os.path.join(output_dir, 'all_locations_groups.csv')
    if not os.path.exists(cache_path):
        print(f"Error: Cache file not found at {cache_path}")
        return False

    # round_trip parsing keeps the cached floats bit-identical to what was written,
    # so the exact set comparison below is not skewed by float parsing.
    cached_locations = pd.read_csv(cache_path, float_precision='round_trip')
    cached_locations_set = set(
        zip(cached_locations['longitude'], cached_locations['latitude'])
    )
    print(f"Loaded {len(cached_locations_set)} cached unique locations")

    # Sorted for a deterministic report order.
    parquet_files = sorted(glob.glob(os.path.join(data_dir, file_pattern)))
    print(f"Found {len(parquet_files)} parquet files to validate")

    # If nothing matched, nothing was validated, and that must not count as a pass.
    validation_passed = bool(parquet_files)
    if not parquet_files:
        print(f"ERROR: No files matching {file_pattern} found in {data_dir}")

    for file_path in parquet_files:
        file_start = time.time()
        print(f"\nValidating: {os.path.basename(file_path)}")

        try:
            # Only the location columns are needed for the comparison.
            location_df = pd.read_parquet(file_path, columns=['longitude', 'latitude'])
            file_locations_set = set(
                zip(location_df['longitude'], location_df['latitude'])
            )

            if file_locations_set != cached_locations_set:
                validation_passed = False
                missing_in_file = cached_locations_set - file_locations_set
                extra_in_file = file_locations_set - cached_locations_set

                print(f"ERROR: Location mismatch in {os.path.basename(file_path)}")
                print(f"  - File has {len(file_locations_set)} unique locations")
                print(f"  - Cache has {len(cached_locations_set)} unique locations")
                if missing_in_file:
                    print(f"  - {len(missing_in_file)} locations missing in file")
                if extra_in_file:
                    print(f"  - {len(extra_in_file)} extra locations in file")
            else:
                print(f"✓ Validation passed: {len(file_locations_set)} locations match")

            file_duration = time.time() - file_start
            print(f"  Time taken: {file_duration:.2f} seconds")

        except Exception as e:
            # An unreadable file is reported and counted as a failure; the
            # remaining files are still checked.
            validation_passed = False
            print(f"ERROR: Failed to process {os.path.basename(file_path)}")
            print(f"Error details: {str(e)}")

    print(f"\n{'='*80}")
    if validation_passed:
        print("✓ All files passed location group validation")
    else:
        print("✗ Validation failed - see errors above")
    print(f"{'='*80}\n")

    return validation_passed

In [6]:
# Process the data: read DATA_DIR in batches of LOC_BATCH_SIZE locations and write
# batch_NNNN.parquet files plus location_mapping.parquet to OUTPUT_PATH.
# The unique-location list is cached separately under OUTPUT_TOP_PATH
# (all_locations_groups.csv) by extract_unique_locations.
# NOTE(review): the saved output of this cell stops mid-line during batch 1, so the
# recorded run did not visibly complete.
total_start = time.time()
print(f"Starting data processing pipeline at {datetime.now()}")

# Process the data in batches
batch_process_locations(DATA_DIR, OUTPUT_PATH, batch_size=LOC_BATCH_SIZE)

# Wall-clock time of the whole pipeline, including all Drive reads and writes.
total_duration = time.time() - total_start
print(f"\nEntire processing pipeline completed in {total_duration:.2f} seconds")
print(f"Finished at {datetime.now()}")

Starting data processing pipeline at 2025-02-19 16:09:43.067945

Starting batch processing at 16:09:43
Data directory: /content/drive/MyDrive/GEE_Exports/Alberta/2024/2024_8-10
Output directory: /content/drive/My Drive/Colab Notebooks/wildfire/new_data/processed/2025
Batch size: 50000


Loading cached locations from /content/drive/My Drive/Colab Notebooks/wildfire/new_data/processed/all_locations_groups.csv
Loaded 767323 unique locations from cache

Processing 767323 locations in 16 batches

Batch 1/16
Processing locations 0 to 50000 (50000 locations)
Reading and combining monthly files for batch...
Processing batch of 50000 locations
Reading file: final_file-08_n5.parquet
Reading file: final_file-09_n5.parquet
Reading file: final_file-10_n5.parquet
Combining data from all files...
Processing individual location groups...
  Processing location 1/50000 in current batch
  Processing location 10001/50000 in current batch
  Processing location 20001/50000 in current batch
  Processing loca

In [7]:
# Release this notebook's Colab runtime once processing has finished (useful for
# unattended runs). Kernel state is lost afterwards; only files written to Drive
# (e.g. under OUTPUT_PATH) persist.
runtime.unassign()