##### Fire chip generation for areas outside the USA
Use the dates of the **pre- and post-fire images** to define the event "year" used to generate the time series.

Then sample up to four seasonal windows in that event year (the season containing the post-fire date and any later seasons), and sample the same months in previous years as control years.

Note: the event year should include at least one sample acquired after the post-fire image date.


In [1]:
import os
os.chdir("../")

import pystac_client
import pystac
from requests.adapters import HTTPAdapter
from urllib3 import Retry
from pystac_client.stac_api_io import StacApiIO
import planetary_computer

import dask.distributed
import numpy as np
import rioxarray
import pandas as pd
import geopandas as gpd
from src.utils import search_s2_scenes, search_lc_scene, stack_s2_data, stack_lc_data, unique_class, missing_values, gen_chips
import yaml
from datetime import datetime, timedelta
import rasterio
import stackstac 
from src.utils import mask_cloudy_pixels
import matplotlib.pyplot as plt
import xarray as xr
from rasterio.features import rasterize
from shapely.geometry import mapping
from scipy.signal import convolve2d
from scipy.ndimage import uniform_filter
from shapely.geometry import shape
import warnings
warnings.filterwarnings("ignore")

In [2]:
# Load the project configuration (chip size, Sentinel-2 bands and resolution, ...).
with open("config.yml") as config_file:
    config = yaml.safe_load(config_file)

In [3]:
# Fire inventory: each row is a pre-fire image, post-fire image or burn mask (`event_type`)
# for a `location`, with its acquisition `date`, file `path` and `geometry` (WKT text in the CSV).
fire_df = pd.read_csv("data/s2_wcd_fires.csv")
fire_df.head()

Unnamed: 0,location,event_type,date,path,geometry
0,Almonaster,pre,2022-07-12,/workspace/Rufai/data/S2-WCD/Almonaster/img1_c...,MULTIPOLYGON (((-1.9465144286373106 41.3722766...
1,Almonaster,post,2022-07-27,/workspace/Rufai/data/S2-WCD/Almonaster/img2_c...,MULTIPOLYGON (((-1.9465144286373106 41.3722766...
2,Almonaster,mask,,/workspace/Rufai/data/S2-WCD/Almonaster/cm/cm.tif,MULTIPOLYGON (((-1.9465144286373106 41.3722766...
3,Attica,pre,2021-08-16,/workspace/Rufai/data/S2-WCD/Attica/img1_cropp...,MULTIPOLYGON (((-0.7064004420359428 38.0931080...
4,Attica,post,2021-08-28,/workspace/Rufai/data/S2-WCD/Attica/img2_cropp...,MULTIPOLYGON (((-0.7064004420359428 38.0931080...


In [4]:
# Collapse the per-image rows into one row per fire location:
# geometry, earliest pre-fire date, earliest post-fire date and the burn-mask path.
from shapely.wkt import loads

df_filtered = fire_df[fire_df["event_type"].isin(["pre", "post"])]

# Earliest date per (location, event_type). Missing dates are skipped by `min`
# instead of breaking a sort, and a location without a pre or post image gets NaT.
first_dates = (
    df_filtered
    .assign(date=pd.to_datetime(df_filtered["date"]))
    .groupby(["location", "event_type"])["date"]
    .min()
    .unstack("event_type")
    .reindex(columns=["pre", "post"])
    .rename(columns={"pre": "pre_date", "post": "post_date"})
    .rename_axis(columns=None)
)
geometries = df_filtered.groupby("location")["geometry"].first()

summary = (
    pd.concat([geometries, first_dates], axis=1)
    .rename_axis("location")
    .reset_index()
)

# Geometries come from the CSV as WKT text; parse element-wise (safe on an empty frame).
summary["geometry"] = summary["geometry"].map(lambda g: loads(g) if isinstance(g, str) else g)
summary_gdf = gpd.GeoDataFrame(summary, geometry="geometry", crs="EPSG:4326")
# Re-coerce in case a whole column was missing (all-NaN float) after the reindex above.
summary_gdf["pre_date"] = pd.to_datetime(summary_gdf["pre_date"])
summary_gdf["post_date"] = pd.to_datetime(summary_gdf["post_date"])

mask_paths = (
    fire_df[fire_df["event_type"] == "mask"]
    .groupby("location")["path"]
    .first()  # one fire mask per location
    .reset_index()
    .rename(columns={"path": "mask_path"})
)
summary_gdf = summary_gdf.merge(mask_paths, on="location", how="left")

summary_gdf.head()

Unnamed: 0,location,geometry,pre_date,post_date,mask_path
0,Almonaster,"MULTIPOLYGON (((-1.94651 41.37228, -1.94652 41...",2022-07-12,2022-07-27,/workspace/Rufai/data/S2-WCD/Almonaster/cm/cm.tif
1,Attica,"MULTIPOLYGON (((-0.7064 38.09311, -0.7064 38.0...",2021-08-16,2021-08-28,/workspace/Rufai/data/S2-WCD/Attica/cm/cm.tif
2,Australia_1,"POLYGON ((-4.47264 58.56369, -4.47263 58.5636,...",2021-01-31,2021-02-20,/workspace/Rufai/data/S2-WCD/Australia_1/cm/cm...
3,Australia_2,"MULTIPOLYGON (((-4.49479 58.50658, -4.49513 58...",2021-01-31,2021-02-20,/workspace/Rufai/data/S2-WCD/Australia_2/cm/cm...
4,Bejis,"MULTIPOLYGON (((-0.71135 39.88491, -0.71147 39...",2022-08-08,2022-08-23,/workspace/Rufai/data/S2-WCD/Bejis/cm/cm.tif


In [5]:
from dask.distributed import Client, LocalCluster

# In-process Dask cluster with default sizing; pass e.g.
# n_workers=8, threads_per_worker=2 to LocalCluster to pin the worker layout.
cluster = LocalCluster()
client = Client(address=cluster)

dashboard_url = client.dashboard_link
print(dashboard_url)

http://127.0.0.1:8787/status


In [6]:
# Retry transient gateway errors (502/503/504) with exponential backoff for any HTTP
# method (allowed_methods=None), then open the Planetary Computer STAC API with
# request signing applied to every returned item.
retry = Retry(
    total=10,
    backoff_factor=1,
    status_forcelist=[502, 503, 504],
    allowed_methods=None,
)
stac_api_io = StacApiIO(max_retries=retry)

catalog = pystac_client.Client.open(
    "https://planetarycomputer.microsoft.com/api/stac/v1",
    stac_io=stac_api_io,
    modifier=planetary_computer.sign_inplace,
)

In [7]:
# fire_mask = rioxarray.open_rasterio("/workspace/Rufai/data/S2-WCD/Almonaster/cm/cm.tif", masked=True).squeeze()
# fire_mask.plot()

In [8]:
# fire_mask.values

In [9]:
from datetime import timedelta
from datetime import timedelta
import xarray as xr
import rioxarray
import calendar

def get_date_ranges(row, n_control_years=7):
    """Build the seasonal search windows for the event year and the control years.

    The year is split into four seasons (Dec-Feb, Mar-May, Jun-Aug, Sep-Nov).
    The event windows are the season containing the post-fire date plus every
    later season of the same season-year, so the event series always covers
    the post-fire acquisition. A December post-fire date belongs to the Dec-Feb
    season that ends in the *following* calendar year.

    Parameters
    ----------
    row : mapping
        Must provide ``row["post_date"]`` as a datetime-like with ``.year`` and
        ``.month`` (a ``pd.Timestamp`` in this notebook).
    n_control_years : int, default 7
        Number of preceding years into which the event windows are mirrored.

    Returns
    -------
    event_ranges : list[str]
        ``"YYYY-MM-DD/YYYY-MM-DD"`` ranges, first to last day of each selected season.
    control_date_ranges : list[list[str]]
        One list per control year (1 .. n_control_years years back), holding the
        same seasons as ``event_ranges`` shifted back by that many years.
    """
    post_date = row["post_date"]
    # December belongs to the Dec-Feb season ending in the next calendar year.
    event_year = post_date.year + 1 if post_date.month == 12 else post_date.year

    # (start_month, end_month, offset of the start year relative to the season-year)
    month_windows = [(12, 2, -1), (3, 5, 0), (6, 8, 0), (9, 11, 0)]

    def _window(season_year, start_month, end_month, start_offset):
        # The last day is recomputed for the actual year, so Feb 28/29 is always valid.
        start_year = season_year + start_offset
        last_day = calendar.monthrange(season_year, end_month)[1]
        return (
            f"{start_year}-{start_month:02d}-01"
            f"/{season_year}-{end_month:02d}-{last_day}"
        )

    # Keep seasons that contain or follow the post-fire month. The comparison is at
    # month granularity, so a time-of-day on post_date cannot drop its own season.
    selected = [
        (start_month, end_month, start_offset)
        for start_month, end_month, start_offset in month_windows
        if (event_year, end_month) >= (post_date.year, post_date.month)
    ]

    event_ranges = [_window(event_year, *w) for w in selected]
    # Shift both ends of every window by the same number of years, so Dec-Feb control
    # windows stay correctly ordered (Dec of year N-1 to Feb of year N).
    control_date_ranges = [
        [_window(event_year - delta_year, *w) for w in selected]
        for delta_year in range(1, n_control_years + 1)
    ]
    return event_ranges, control_date_ranges


def rasterize_aoi(aoi, s2_stack):
    """Burn the AOI polygon onto the pixel grid of ``s2_stack``.

    Parameters
    ----------
    aoi : mapping-like (a GeoDataFrame row in this notebook)
        ``aoi['geometry']`` is taken to be in EPSG:4326 and must be accepted by
        ``shapely.geometry.shape``.
    s2_stack : xarray.DataArray
        Raster with ``y``/``x`` dims and a rioxarray CRS and transform.

    Returns
    -------
    xarray.DataArray
        uint8 (y, x) mask on the grid of ``s2_stack``: 1 where rasterio burns the
        polygon, 0 elsewhere.
    """
    geom_wgs84 = shape(aoi['geometry'])
    # Reproject the polygon into the raster's CRS before burning it.
    geom_native = (
        gpd.GeoDataFrame({"geometry": [geom_wgs84]}, crs="EPSG:4326")
        .to_crs(s2_stack.rio.crs)['geometry']
        .iloc[0]
    )

    grid_shape = (s2_stack.sizes['y'], s2_stack.sizes['x'])
    mask_values = rasterize(
        [(mapping(geom_native), 1)],
        out_shape=grid_shape,
        transform=s2_stack.rio.transform(),
        fill=0,
        dtype='uint8',
    )

    return xr.DataArray(
        mask_values,
        dims=("y", "x"),
        coords={"y": s2_stack["y"], "x": s2_stack["x"]},
    )



def read_fire_mask_from_file(mask_path, s2_stack):
    """Load a fire-mask raster and resample it onto the grid of ``s2_stack``.

    The file is opened with ``masked=True`` (nodata becomes NaN) and squeezed of
    length-1 dims. It is then passed through ``rio.reproject_match`` against ``s2_stack``.

    Raises
    ------
    ValueError
        If the mask file carries no CRS, since it cannot be reprojected.
    """
    mask = rioxarray.open_rasterio(mask_path, masked=True).squeeze()
    if mask.rio.crs is None:
        raise ValueError(f"Fire mask at {mask_path} has no CRS defined.")
    return mask.rio.reproject_match(s2_stack)


def crop_burn_window(s2_stack, burn_mask, config):
    """Crop ``s2_stack`` to the chip-sized window containing the most burned pixels.

    A moving-window mean of ``burn_mask`` (on the same grid as ``s2_stack``) locates
    the ``chip_size`` x ``chip_size`` window with the largest burned fraction. The
    stack is then sliced to that window along ``y``/``x``.

    Parameters
    ----------
    s2_stack : xarray.DataArray
        Stack with ``y`` and ``x`` dims aligned with ``burn_mask``.
    burn_mask : xarray.DataArray
        2-D (y, x) mask where larger values mean burned. NaN (e.g. nodata from a mask
        read with ``masked=True``) is treated as unburned.
    config : dict
        ``config['chips']['chip_size']`` is the window side in pixels.

    Returns
    -------
    xarray.DataArray
        ``s2_stack`` sliced to ``chip_size`` rows/cols when the raster is at least
        that large. The window is kept fully inside the raster; it is only smaller
        when the raster itself is smaller.
    """
    window_size = config['chips']['chip_size']

    # NaN would poison both the filter and np.argmax, which returns the first NaN it finds.
    burn = np.nan_to_num(np.asarray(burn_mask.values, dtype=float), nan=0.0)

    # Moving-window mean. It has the same argmax as the windowed sum, since the scale is constant.
    burn_mean = uniform_filter(burn, size=window_size, mode='constant', cval=0.0)
    y_idx, x_idx = np.unravel_index(np.argmax(burn_mean), burn_mean.shape)

    n_rows, n_cols = burn.shape
    # With scipy's default origin, the window centred at i starts at i - window_size // 2.
    # Clamp the start to [0, n - window_size] so the crop never runs past the far edge
    # and produces a truncated chip.
    y_start = max(min(int(y_idx) - window_size // 2, n_rows - window_size), 0)
    x_start = max(min(int(x_idx) - window_size // 2, n_cols - window_size), 0)

    return s2_stack.isel(
        y=slice(y_start, y_start + window_size),
        x=slice(x_start, x_start + window_size),
    )

def harmonize_to_old(data):
    """
    Bring Sentinel-2 scenes acquired on/after the cutoff back to the old baseline.

    For time steps from ``cutoff`` (2022-01-25) onwards, every reflectance band
    B01-B12/B8A present in ``data`` is clipped from below at ``offset`` (1000) and
    then has ``offset`` subtracted. Other bands and earlier time steps pass
    through unchanged. The original band order is restored.

    Parameters
    ----------
    data: xarray.DataArray
        A DataArray with four dimensions: time, band, y, x

    Returns
    -------
    harmonized: xarray.DataArray
        ``data`` with the post-cutoff reflectance bands harmonized.

    NOTE(review): both time slices are inclusive of ``cutoff``, so a scene stamped
    exactly at 2022-01-25T00:00 would appear twice in the output.
    """
    cutoff = datetime(2022, 1, 25)
    offset = 1000
    reflectance_bands = {
        "B01", "B02", "B03", "B04", "B05", "B06", "B07",
        "B08", "B8A", "B09", "B10", "B11", "B12",
    }

    band_order = data.band.data.tolist()
    shifted_bands = list(reflectance_bands & set(band_order))

    before_cutoff = data.sel(time=slice(None, cutoff))
    after_cutoff = data.sel(time=slice(cutoff, None))

    passthrough = after_cutoff.drop_sel(band=shifted_bands)
    corrected = after_cutoff.sel(band=shifted_bands).clip(offset)
    corrected -= offset

    # Recombine, then restore the caller's band order.
    after_cutoff = xr.concat([passthrough, corrected], "band").sel(band=band_order)

    return xr.concat([before_cutoff, after_cutoff], dim="time")

In [10]:
def _write_chip_timestep(s2_chip, dt, index, tag, label, epsg):
    """Write one time step of ``s2_chip`` to GeoTIFF and return its metadata row.

    The file is ``data/fire_data/s2_<index:06>_<tag>_<YYYYMMDD>.tif``. The returned
    list is ordered as chip_id, date, sample_id, type, x_center, y_center, epsg.
    """
    date_str = pd.to_datetime(str(dt)).strftime('%Y%m%d')
    sample_id = f"{index:06}_{tag}_{date_str}"
    s2_chip.sel(time=dt).squeeze().rio.to_raster(f"data/fire_data/s2_{sample_id}.tif")
    return [
        index,
        date_str,
        sample_id,
        label,
        s2_chip.x[len(s2_chip.x) // 2].data,
        s2_chip.y[len(s2_chip.y) // 2].data,
        epsg,
    ]


def generate_fire_chips(s2_stack, aoi, config, time_series_type, epsg, index, metadata_df):
    """Crop, validate, harmonize and write the chip(s) for one AOI time series.

    Parameters
    ----------
    s2_stack : xarray.DataArray
        Lazy (time, band, y, x) stack for the AOI. It is computed here.
    aoi : row-like
        AOI record. Its ``geometry`` is rasterized as the burn mask.
    config : dict
        ``config['chips']['chip_size']`` is the required chip side in pixels.
    time_series_type : str
        ``"event"`` writes every time step. Any other value is treated as control
        and writes only the first time step.
    epsg : int
        EPSG code recorded in the metadata rows.
    index : int
        Chip id, used in file names and the ``chip_id`` column.
    metadata_df : pandas.DataFrame
        Table to which the new rows are added (new rows go on top).

    Returns
    -------
    (bool, pandas.DataFrame)
        Whether chips were written, and the (possibly extended) metadata table.
    """
    chip_size = config['chips']['chip_size']

    try:
        s2_stack = s2_stack.compute()
    except Exception as exc:
        # Previously this only printed and carried on with the lazy stack.
        print(f"skipping the AOI for no S2 data ({exc})")
        return False, metadata_df

    # NOTE(review): the burn mask is the AOI polygon itself; the file-based mask
    # (aoi["mask_path"]) is an alternative that has not been switched on.
    burn_mask = rasterize_aoi(aoi, s2_stack)
    # burn_mask = read_fire_mask_from_file(aoi["mask_path"], s2_stack)

    try:
        s2_stack_cropped = crop_burn_window(s2_stack, burn_mask, config)
    except Exception as exc:
        # Returning here avoids using an unbound s2_stack_cropped below.
        print(f"Cropping S2 stack failed; skipping AOI ({exc})")
        return False, metadata_df

    if s2_stack_cropped.sizes["y"] != chip_size or s2_stack_cropped.sizes["x"] != chip_size:
        print(f"Skipping chip ID {index} for mismatch dimensions")
        return False, metadata_df

    if missing_values(s2_stack_cropped, chip_size, chip_size):
        print(f"Skipping chip ID {index} for missing values")
        return False, metadata_df

    s2_stack_cropped = (
        harmonize_to_old(s2_stack_cropped)
        .fillna(-999)
        .rio.write_nodata(-999)
        .astype(np.dtype(np.int16))
        .rename("s2")
    )

    if time_series_type == "event":
        tag, label = "e", "event"
        times = s2_stack_cropped.time.values
    else:
        tag, label = "c", "control"
        times = s2_stack_cropped.time.values[:1]

    if len(times) == 0:
        print(f"Skipping chip ID {index} for empty time series")
        return False, metadata_df

    os.makedirs("data/fire_data", exist_ok=True)
    rows = [_write_chip_timestep(s2_stack_cropped, dt, index, tag, label, epsg) for dt in times]

    # The rows used to be prepended one at a time, so the newest time step ended up on top.
    # Build them once and reverse them to keep that order without quadratic concatenation.
    new_rows = pd.DataFrame(rows[::-1], columns=metadata_df.columns)
    metadata_df = pd.concat([new_rows, metadata_df], ignore_index=True)

    return True, metadata_df

In [11]:
# Empty metadata table; generate_fire_chips adds one row per written chip.
metadata_df = pd.DataFrame(
    columns=[
        "chip_id",
        "date",
        "sample_id",
        "type",
        "x_center",
        "y_center",
        "epsg",
    ]
)
# To resume a previous run instead: metadata_df = pd.read_csv("data/metadata_df.csv")

In [12]:
def _item_epsg(item):
    """Return the integer EPSG code of a STAC item.

    Prefers ``proj:epsg``. Falls back to parsing ``proj:code`` (e.g. ``"EPSG:32630"``)
    when ``proj:epsg`` is absent or null.
    """
    epsg = item.properties.get("proj:epsg")
    if epsg is None:
        epsg = int(item.properties["proj:code"].split(":")[-1])
    return epsg


def _stack_s2_for_aoi(items, clipping_geom, epsg, config):
    """Lazily stack ``items`` over the lon/lat bounds of ``clipping_geom``.

    The stack is then passed through ``mask_cloudy_pixels`` and the SCL band is dropped.
    """
    stack = stackstac.stack(
        items,
        assets=config["sentinel_2"]["bands"],
        epsg=epsg,
        resolution=config["sentinel_2"]["resolution"],
        fill_value=np.nan,
        bounds_latlon=clipping_geom.bounds,
    )
    stack = mask_cloudy_pixels(stack)
    return stack.drop_sel(band="SCL")


for index, aoi in summary_gdf.iterrows():
    print(f"\nProcessing AOI at index {index}")

    aoi_bounds = aoi['geometry'].bounds
    clipping_geom = aoi["geometry"]
    event_date_ranges, control_date_ranges = get_date_ranges(aoi)

    # Event year: pool the scenes from every event window.
    s2_items = pystac.item_collection.ItemCollection([])
    for date_range in event_date_ranges:
        s2_items += search_s2_scenes(aoi, date_range, catalog, config)

    if len(s2_items) < 2:
        print(f"Missing Sentinel-2 scenes for AOI {aoi_bounds}")
        continue

    # stack_s2_data only serves as a validity check; the stack actually used is built below.
    if stack_s2_data(s2_items, config) is None:
        print(f"Failed to stack Sentinel-2 bands for AOI {aoi_bounds}")
        continue

    epsg = _item_epsg(s2_items[0])
    try:
        s2_stack = _stack_s2_for_aoi(s2_items, clipping_geom, epsg, config)
    except Exception as e:
        print(f"Error stacking Sentinel-2 data: {e}. Skipping AOI {index}")
        continue

    event_status, metadata_df = generate_fire_chips(s2_stack, aoi, config, "event", epsg, index, metadata_df)

    # Control years are sampled only once the event chips were written.
    if event_status:
        for control_date_range in control_date_ranges:
            if not control_date_range:
                continue
            # Only the first window of each control year is searched.
            s2_items = search_s2_scenes(aoi, control_date_range[0], catalog, config)

            if len(s2_items) < 1:
                print(f"Missing Sentinel-2 scenes for AOI {aoi_bounds}")
                continue

            if stack_s2_data(s2_items, config) is None:
                print(f"Failed to stack Sentinel-2 bands for AOI {aoi_bounds} in control year")
                continue

            epsg = _item_epsg(s2_items[0])
            try:
                s2_stack = _stack_s2_for_aoi(s2_items, clipping_geom, epsg, config)
            except Exception as e:
                print(f"Error stacking Sentinel-2 data: {e}. Skipping control window {control_date_range[0]} for AOI {index}")
                continue

            control_status, metadata_df = generate_fire_chips(s2_stack, aoi, config, "control", epsg, index, metadata_df)

    # Checkpoint the metadata after every AOI that reaches this point.
    metadata_df.to_csv('data/metadata_df.csv', index=False)


Processing AOI at index 0

Processing AOI at index 1
Skipping chip ID 1 for missing values

Processing AOI at index 2
Skipping chip ID 2 for missing values

Processing AOI at index 3
Skipping chip ID 3 for missing values

Processing AOI at index 4
Skipping chip ID 4 for missing values

Processing AOI at index 5
Skipping chip ID 5 for missing values

Processing AOI at index 6
Skipping chip ID 6 for missing values

Processing AOI at index 7

Processing AOI at index 8
Skipping chip ID 8 for missing values

Processing AOI at index 9
Missing Sentinel-2 scenes for AOI (-0.3022279953603219, 38.787349583081784, -0.1977485520888272, 38.82807764488892)

Processing AOI at index 10
Missing Sentinel-2 scenes for AOI (-4.057573790038135, 41.03112188063114, -3.9300354295976905, 41.12805034396455)

Processing AOI at index 11
Missing Sentinel-2 scenes for AOI (-3.7992271949845042, 41.14634551157562, -3.704587757944652, 41.24256398914331)

Processing AOI at index 12
Missing Sentinel-2 scenes for AOI (-

In [13]:
# Event windows computed for the last AOI the loop above visited.
event_date_ranges

['2022-06-01/2022-08-31', '2022-09-01/2022-11-30']

In [14]:
# Control windows (one list per earlier year) for the last AOI the loop above visited.
control_date_ranges

[['2021-06-01/2021-08-31', '2021-09-01/2021-11-30'],
 ['2020-06-01/2020-08-31', '2020-09-01/2020-11-30'],
 ['2019-06-01/2019-08-31', '2019-09-01/2019-11-30'],
 ['2018-06-01/2018-08-31', '2018-09-01/2018-11-30'],
 ['2017-06-01/2017-08-31', '2017-09-01/2017-11-30'],
 ['2016-06-01/2016-08-31', '2016-09-01/2016-11-30'],
 ['2015-06-01/2015-08-31', '2015-09-01/2015-11-30']]