##### Pipeline for Generating Time Series Burn scar datasets for GFM Bench (outside of USA) 
In this pipeline, we use the `pre and post fire images   to define a "year" for generating the time series, generate 4 images during that event year, and use the same range of months in previous years to sample for control years. We ensure that the event year has a sample after the date of the post fire image.

In [1]:
import os
os.chdir("../")
os.environ["CPL_VSIL_CURL_NUM_CONNECTIONS"] = "20"

In [2]:
import yaml
from datetime import datetime, timedelta
import calendar

import pystac
import pystac_client
import planetary_computer
import stackstac 
from pystac_client.stac_api_io import StacApiIO
from urllib3 import Retry
from requests.adapters import HTTPAdapter
import dask.distributed
import numpy as np
import pandas as pd
import geopandas as gpd
import matplotlib.pyplot as plt
import xarray as xr
import rioxarray
from dask.distributed import Client, LocalCluster

from rasterio.features import rasterize
from shapely.geometry import mapping
from shapely.wkt import loads
from scipy.signal import convolve2d
from scipy.ndimage import uniform_filter
from shapely.geometry import shape
from src.utils import (search_s2_scenes, 
                       stack_s2_data, 
                       unique_class, 
                       missing_values, 
                       gen_chips, 
                       mask_cloudy_pixels
)
import warnings
warnings.filterwarnings("ignore")

In [3]:
with open("config.yml", "r") as file:
    config = yaml.safe_load(file)

In [4]:
# load fire events data
fire_df = pd.read_csv("data/s2_wcd_fires.csv")
fire_df.head()

Unnamed: 0,location,event_type,date,path,geometry
0,Almonaster,pre,2022-07-12,/workspace/Rufai/data/S2-WCD/Almonaster/img1_c...,MULTIPOLYGON (((-1.9465144286373106 41.3722766...
1,Almonaster,post,2022-07-27,/workspace/Rufai/data/S2-WCD/Almonaster/img2_c...,MULTIPOLYGON (((-1.9465144286373106 41.3722766...
2,Almonaster,mask,,/workspace/Rufai/data/S2-WCD/Almonaster/cm/cm.tif,MULTIPOLYGON (((-1.9465144286373106 41.3722766...
3,Attica,pre,2021-08-16,/workspace/Rufai/data/S2-WCD/Attica/img1_cropp...,MULTIPOLYGON (((-0.7064004420359428 38.0931080...
4,Attica,post,2021-08-28,/workspace/Rufai/data/S2-WCD/Attica/img2_cropp...,MULTIPOLYGON (((-0.7064004420359428 38.0931080...


In [5]:
def get_fire_mask_path(fire_df):
    df_filtered = fire_df[fire_df["event_type"].isin(["pre", "post"])]
    summary = (
        df_filtered
        .groupby("location")
        .agg(
            geometry=("geometry", "first"),
            pre_date=("date", lambda x: sorted(x[df_filtered.loc[x.index, "event_type"] == "pre"])[0] if any(df_filtered.loc[x.index, "event_type"] == "pre") else None),
            post_date=("date", lambda x: sorted(x[df_filtered.loc[x.index, "event_type"] == "post"])[0] if any(df_filtered.loc[x.index, "event_type"] == "post") else None),
        )
        .reset_index()
    )

    if isinstance(summary["geometry"].iloc[0], str):
        summary["geometry"] = summary["geometry"].apply(loads)
    summary_gdf = gpd.GeoDataFrame(summary, geometry=summary["geometry"], crs="EPSG:4326")
    summary_gdf["pre_date"] = pd.to_datetime(summary_gdf["pre_date"])
    summary_gdf["post_date"] = pd.to_datetime(summary_gdf["post_date"])

    mask_paths = (
        fire_df[fire_df["event_type"] == "mask"]
        .groupby("location")["path"]
        .first()  #one fire mask per location
        .reset_index()
        .rename(columns={"path": "mask_path"})
    )
    summary_gdf = summary_gdf.merge(mask_paths, on="location", how="left")
    return summary_gdf

In [6]:
summary_gdf = get_fire_mask_path(fire_df)
summary_gdf.head()

Unnamed: 0,location,geometry,pre_date,post_date,mask_path
0,Almonaster,"MULTIPOLYGON (((-1.94651 41.37228, -1.94652 41...",2022-07-12,2022-07-27,/workspace/Rufai/data/S2-WCD/Almonaster/cm/cm.tif
1,Attica,"MULTIPOLYGON (((-0.7064 38.09311, -0.7064 38.0...",2021-08-16,2021-08-28,/workspace/Rufai/data/S2-WCD/Attica/cm/cm.tif
2,Australia_1,"POLYGON ((-4.47264 58.56369, -4.47263 58.5636,...",2021-01-31,2021-02-20,/workspace/Rufai/data/S2-WCD/Australia_1/cm/cm...
3,Australia_2,"MULTIPOLYGON (((-4.49479 58.50658, -4.49513 58...",2021-01-31,2021-02-20,/workspace/Rufai/data/S2-WCD/Australia_2/cm/cm...
4,Bejis,"MULTIPOLYGON (((-0.71135 39.88491, -0.71147 39...",2022-08-08,2022-08-23,/workspace/Rufai/data/S2-WCD/Bejis/cm/cm.tif


In [7]:
print("Number of fire events:", len(summary_gdf))

Number of fire events: 41


### Load datasets from STAC API

In [8]:
cluster = LocalCluster()
client = Client(cluster)
print(client.dashboard_link)

http://127.0.0.1:8787/status


In [9]:
retry = Retry(
    total=10, backoff_factor=1, status_forcelist=[502, 503, 504], allowed_methods=None
)
stac_api_io = StacApiIO(max_retries=retry)
catalog = pystac_client.Client.open(
    "https://planetarycomputer.microsoft.com/api/stac/v1",
    modifier=planetary_computer.sign_inplace,
    stac_io=stac_api_io
)

In [10]:
# TODO: 4 PER event year (calendar not seasons)
# 4 for for each control year (depending on the availability of Sentinel-2 data)
def get_date_ranges(row, n_control_years: int = 7):
    post_date   = pd.to_datetime(row["post_date"])
    event_year  = post_date.year

    # Quarter breakpoints
    quarters = [(1, 3), (4, 6), (7, 9), (10, 12)]

    def quarter_window(year, start_m, end_m):
        """Return a 'YYYY-MM-DD/YYYY-MM-DD' string for a single quarter."""
        start = f"{year}-{start_m:02d}-01"
        end_day = calendar.monthrange(year, end_m)[1]
        end   = f"{year}-{end_m:02d}-{end_day:02d}"
        return f"{start}/{end}"

    # Event-year ranges (always 4)
    event_ranges = [
        quarter_window(event_year, sm, em) for sm, em in quarters
    ]

    # Control-year ranges
    control_ranges = []
    for delta in range(1, n_control_years + 1):
        cy = event_year - delta
        control_ranges.append([
            quarter_window(cy, sm, em) for sm, em in quarters
        ])
    return event_ranges, control_ranges

def rasterize_aoi(aoi, s2_stack):
    """
    Rasterize the AOI polygon into a burn mask
    """
    aoi_gdf = gpd.GeoDataFrame(
        {"geometry": [shape(aoi['geometry'])]},
        crs="EPSG:4326"
    )
    
    aoi_proj = aoi_gdf.to_crs(s2_stack.rio.crs)
    
    burn_mask = rasterize(
        [(mapping(aoi_proj['geometry'].iloc[0]), 1)],
        out_shape=(s2_stack.sizes['y'], s2_stack.sizes['x']),
        transform=s2_stack.rio.transform(),
        fill=0,
        dtype='uint8'
    )
    
    burn_mask_da = xr.DataArray(
        burn_mask,
        coords={"y": s2_stack["y"], "x": s2_stack["x"]},
        dims=("y", "x")
    )
    return burn_mask_da


# def crop_burn_window(s2_stack, burn_mask, config):
#     window_size = config['chips']['chip_size']

#     # Apply uniform filter (mean filter), then scale to get sum
#     burn_mean = uniform_filter(burn_mask.values.astype(float), size=window_size, mode='constant', cval=0.0)
#     burn_sum = burn_mean * (window_size ** 2)
    
#     # Find maximum sum
#     max_idx = np.unravel_index(np.argmax(burn_sum), burn_sum.shape)
#     y_idx, x_idx = max_idx
    
#     # Calculate start indices
#     y_start = max(y_idx - window_size // 2, 0)
#     x_start = max(x_idx - window_size // 2, 0)
        
#     cropped_stack = s2_stack.isel(
#         y=slice(y_start, y_start + window_size),
#         x=slice(x_start, x_start + window_size)
#     )
    
#     return cropped_stack

def crop_burn_windows(s2_stack, burn_mask, chip_size=224, pct_threshold=0.30):
    """
    Return *all* windows of size `chip_size`×`chip_size` whose burn-pixel
    fraction is ≥ `pct_threshold`.

    Returns
    -------
    list[xr.DataArray]  each element is a cropped stack
    """
    step = chip_size        # non-overlap; set <chip_size for stride/overlap
    ny, nx = burn_mask.shape
    chips = []

    for y0 in range(0, ny - chip_size + 1, step):
        for x0 in range(0, nx - chip_size + 1, step):
            window = burn_mask[y0:y0+chip_size, x0:x0+chip_size]
            frac   = window.mean()        # fraction of burn pixels in the window
            if frac >= pct_threshold:     # keep only “burn-rich” windows
                chips.append(
                    s2_stack.isel(
                        y=slice(y0, y0+chip_size),
                        x=slice(x0, x0+chip_size),
                        drop=False
                    )
                )
    return chips

def harmonize_to_old(data):
    """
    Harmonize new Sentinel-2 data to the old baseline.

    Parameters
    ----------
    data: xarray.DataArray
        A DataArray with four dimensions: time, band, y, x

    Returns
    -------
    harmonized: xarray.DataArray
        A DataArray with all values harmonized to the old
        processing baseline.
    """
    if "time" not in data.dims:
        # static composite → nothing to do
        return data
    if "time" not in data.dims:
        if "time" in data:
            data = data.expand_dims("time")  # convert scalar to 1D dimension
        else:
            raise ValueError("Variable 'time' not found in dataset.")
            
    data = data.set_index(time="time")

    cutoff = datetime(2022, 1, 25)
    offset = 1000
    bands = [
        "B01",
        "B02",
        "B03",
        "B04",
        "B05",
        "B06",
        "B07",
        "B08",
        "B8A",
        "B09",
        "B10",
        "B11",
        "B12",
    ]

    old = data.sel(time=slice(cutoff))

    to_process = list(set(bands) & set(data.band.data.tolist()))
    new = data.sel(time=slice(cutoff, None)).drop_sel(band=to_process)

    new_harmonized = data.sel(time=slice(cutoff, None), band=to_process).clip(offset)
    new_harmonized -= offset

    new = xr.concat([new, new_harmonized], "band").sel(band=data.band.data.tolist())
    
    return xr.concat([old, new], dim="time")

In [17]:
def generate_fire_chips(s2_stack, aoi, config, time_series_type, epsg, index, metadata_df):
    saved_any = False
    try:
        s2_stack = s2_stack.compute()
    except:
        print("skipping the AOI for no S2 data")
    if s2_stack.shape[2] != 224 or s2_stack.shape[3] != 224:
        print(f"Skipping chip ID {index} for mismatch dimensions")
        return False, metadata_df 
    
    for crop_idx, s2_stack_cropped in enumerate(s2_stack):
        if "time" not in s2_stack_cropped.dims:
            if "time" in s2_stack_cropped.coords:
                s2_stack_cropped = s2_stack_cropped.expand_dims("time")
            else:
                print(f"Skipping chip ID {index} — no time dimension present")
                continue
        if missing_values(s2_stack_cropped, config['chips']['chip_size'], config['chips']['chip_size']):
            print(f"Skipping chip ID {index} for missing values")
            continue      
                
        s2_stack_cropped = harmonize_to_old(s2_stack_cropped)
    
        s2_stack_cropped = s2_stack_cropped.fillna(-999)
        s2_stack_cropped = s2_stack_cropped.rio.write_nodata(-999)
        s2_stack_cropped = s2_stack_cropped.astype(np.dtype(np.int16))
        s2_stack_cropped = s2_stack_cropped.rename("s2")

        if time_series_type == "event":
            for dt in s2_stack_cropped.time.values:
                print(f"Processing chip ID {index} for event date {dt}")
                ts = pd.to_datetime(str(dt)) 
                s2_path = f"data/fire_data/s2_{index:06}_e_{ts.strftime('%Y%m%d')}.tif"
                if os.path.exists(s2_path):
                    print(f"Skipping chip ID {index}_{crop_idx} for chip {index} date: {ts.strftime('%Y%m%d')} — file already exists")
                    continue
                print(f"Saving chip ID {index}_{crop_idx} for chip {index} date: {ts.strftime('%Y%m%d')}")
                s2_stack_cropped.sel(time=dt).squeeze().rio.to_raster(s2_path)

                s2_stack_cropped.sel(time = dt).squeeze().rio.to_raster(s2_path)

                metadata_df = pd.concat([pd.DataFrame([[index,
                                                        ts.strftime('%Y%m%d'),
                                                        f"{index:06}_e_{ts.strftime('%Y%m%d')}",
                                                        "event",
                                                        s2_stack_cropped.x[int(len(s2_stack_cropped.x)/2)].data,
                                                        s2_stack_cropped.y[int(len(s2_stack_cropped.y)/2)].data,
                                                        epsg]
                                                    ],
                                                    columns=metadata_df.columns
                                                    ),
                                        metadata_df],
                                        ignore_index=True
                                    )
                saved_any = True
        else:
            for dt in s2_stack_cropped.time.values:
                print(f"Processing chip ID {index} for control date {dt}")
            # dt = s2_stack_cropped.time.values[0]
                ts = pd.to_datetime(str(dt)) 
                s2_path = f"data/fire_data/s2_{index:06}_c_{ts.strftime('%Y%m%d')}.tif"
                if os.path.exists(s2_path):
                    print(f"Skipping chip ID {index}_{crop_idx} for chip {index} date: {ts.strftime('%Y%m%d')} — file already exists")
                    continue
                print(f"Saving chip ID {index}_{crop_idx} for chip {index} date: {ts.strftime('%Y%m%d')}")
                s2_stack_cropped.sel(time = dt).squeeze().rio.to_raster(s2_path)
                metadata_df = pd.concat([pd.DataFrame([[index,
                                                        ts.strftime('%Y%m%d'),
                                                        f"{index:06}_c_{ts.strftime('%Y%m%d')}",
                                                        "control",
                                                        s2_stack_cropped.x[int(len(s2_stack_cropped.x)/2)].data,
                                                        s2_stack_cropped.y[int(len(s2_stack_cropped.y)/2)].data,
                                                        epsg]
                                                    ],
                                            columns=metadata_df.columns
                                            ),
                                metadata_df],
                                ignore_index=True
                            )
                saved_any = True
    return saved_any, metadata_df

In [18]:
# metadata_df = pd.DataFrame(columns=["chip_id", "date", "sample_id", "type", "x_center", "y_center", "epsg"])
metadata_df = pd.read_csv("data/metadata_df.csv")

In [19]:
def search_s2_scenes(aoi, date_range, catalog, config, best_one=True):
    search = catalog.search(
        collections=["sentinel-2-l2a"],
        bbox=aoi.geometry.bounds,
        datetime=date_range,
        query=[f"s2:nodata_pixel_percentage<{config["sentinel_2"]["nodata_pixel_percentage"]}",
                 f"eo:cloud_cover<{config["sentinel_2"]["cloud_cover"]}"
                ],
        limit=None,                     
    )
    items = list(search.get_items())

    if best_one and items:
        items = [min(items, key=lambda it: it.properties["eo:cloud_cover"])]

    return pystac.ItemCollection(items)


In [20]:
import os, planetary_computer as pc, rasterio
from rasterio.errors import RasterioIOError
import stackstac, dask.array as da
from tenacity import retry, wait_random_exponential, stop_after_attempt

@retry(wait=wait_random_exponential(max=30), stop=stop_after_attempt(4))
def _open_once(href):
    # one quick probe to fail fast if token is stale
    with rasterio.open(href):
        pass

In [None]:
for index, aoi in summary_gdf.iterrows():
    print(f"\nProcessing AOI at index {index}")
    # if index <=36:
    #     continue

    aoi_bounds = aoi['geometry'].bounds
    s2_items = pystac.item_collection.ItemCollection([])
    event_date_ranges, control_date_ranges = get_date_ranges(aoi, n_control_years=7)
    # print(event_date_ranges)
    for date_range in event_date_ranges:        
        s2_items_season = search_s2_scenes(aoi, date_range, catalog, config)
        s2_items += s2_items_season

    if len(s2_items)<2 or "bands" not in s2_items[0].assets:
        print(f"Invalid or Missing Sentinel-2 scenes for AOI {aoi_bounds}")
        continue
    try:
        epsg = s2_items[0].properties["proj:epsg"]
    except:
        epsg = int(s2_items[0].properties["proj:code"].split(":")[-1])
    clipping_geom = aoi["geometry"]

    try:
        s2_stack = stackstac.stack(
            s2_items,
            assets=config["sentinel_2"]["bands"],
            epsg=epsg,
            resolution=config["sentinel_2"]["resolution"],
            fill_value=np.nan,
            bounds_latlon = clipping_geom.bounds
        )
        s2_stack = mask_cloudy_pixels(s2_stack)
        s2_stack = s2_stack.drop_sel(band="SCL")
    except Exception as e:
        print(f"Error stacking Sentinel-2 data: {e}. Skipping AOI {index}")
        continue
    burn_mask = rasterize_aoi(aoi, s2_stack)

    chips = crop_burn_windows(
        s2_stack,
        burn_mask,
        chip_size=config['chips']['chip_size'],
        pct_threshold=0.30,
        )
    
    print(f"Found {len(chips)} chips for AOI {aoi_bounds}")
    
    for chip_id, s2_chip in enumerate(chips):
        event_status, metadata_df = generate_fire_chips(s2_chip, aoi, config, "event", epsg, f"{index}_{chip_id}", metadata_df)
    if event_status:
        for control_date_range in control_date_ranges:
            for date_range in control_date_range: 
                s2_items = search_s2_scenes(aoi, date_range, catalog, config)
            
                if len(s2_items)<1:
                    print(f"Missing Sentinel-2 scenes for AOI {aoi_bounds}")
                    continue
                try:
                    epsg = s2_items[0].properties["proj:epsg"]
                except:
                    epsg = int(s2_items[0].properties["proj:code"].split(":")[-1])
                try:
                    for it in s2_items:
                        planetary_computer.sign_inplace(it)
                    _open_once(s2_items[0].assets["bands"][0].href)
                    s2_stack = stackstac.stack(
                        s2_items,
                        assets=config["sentinel_2"]["bands"],
                        epsg=epsg,
                        resolution=config["sentinel_2"]["resolution"],
                        fill_value=np.nan,
                        bounds_latlon = clipping_geom.bounds
                    )
                    s2_stack = mask_cloudy_pixels(s2_stack)
                    s2_stack = s2_stack.drop_sel(band="SCL")
                except Exception as e:
                    print(f"Error stacking Sentinel-2 data: {e}. Skipping AOI {index}")
                    continue
                chips = crop_burn_windows(
                    s2_stack,
                    burn_mask,
                    chip_size=config['chips']['chip_size'],
                    pct_threshold=0.30,
                    )
                for chip_id, s2_chip in enumerate(chips):
                    control_status, metadata_df = generate_fire_chips(s2_chip, aoi, config, "control", epsg, f"{index}_{chip_id}", metadata_df)
            metadata_df.to_csv('data/metadata_df.csv', index=False)


Processing AOI at index 0
Found 7 chips for AOI (-1.9472346869945791, 41.37056606450764, -1.8334312864604134, 41.43045545782507)
Skipping chip ID 0_0 for missing values
Processing chip ID 0_0 for event date 2022-05-28T10:56:19.024000000
Skipping chip ID 0_0_1 for chip 0_0 date: 20220528 — file already exists
Processing chip ID 0_0 for event date 2022-08-31T10:56:31.024000000
Saving chip ID 0_0_2 for chip 0_0 date: 20220831
Processing chip ID 0_0 for event date 2022-10-05T10:58:19.024000000
Saving chip ID 0_0_3 for chip 0_0 date: 20221005
Processing chip ID 0_1 for event date 2022-02-22T11:00:51.024000000
Skipping chip ID 0_1_0 for chip 0_1 date: 20220222 — file already exists
Processing chip ID 0_1 for event date 2022-05-28T10:56:19.024000000
Saving chip ID 0_1_1 for chip 0_1 date: 20220528
Processing chip ID 0_1 for event date 2022-08-31T10:56:31.024000000
Saving chip ID 0_1_2 for chip 0_1 date: 20220831
Processing chip ID 0_1 for event date 2022-10-05T10:58:19.024000000
Saving chip 

In [16]:
metadata_df

Unnamed: 0,chip_id,date,sample_id,type,x_center,y_center,epsg
0,0_4,20220528,0_4000_e_20220528,event,591320.0,4581610.0,32630
1,0_0,20220528,0_0000_e_20220528,event,595800.0,4586090.0,32630
2,39_9,20230316,39_900_e_20230316,event,716180.0,4214570.0,32630
3,39_8,20230316,39_800_e_20230316,event,713940.0,4214570.0,32630
4,39_7,20230316,39_700_e_20230316,event,711700.0,4214570.0,32630
...,...,...,...,...,...,...,...
511,0_6,20220222,0_6000_e_20220222,event,595800.0,4581610.0,32630
512,0_5,20220222,0_5000_e_20220222,event,593560.0,4581610.0,32630
513,0_3,20220222,0_3000_e_20220222,event,595800.0,4583850.0,32630
514,0_2,20220222,0_2000_e_20220222,event,593560.0,4583850.0,32630
