In [None]:
import os
# Move to the project root so `src` is importable and relative paths such as
# "config.yml" and "data/..." resolve. The guard keeps the cell idempotent:
# re-running it no longer climbs one directory further each time.
if not os.path.isdir("src"):
    os.chdir("../")

In [None]:
import pystac_client
import pystac
from requests.adapters import HTTPAdapter
from urllib3 import Retry
from pystac_client.stac_api_io import StacApiIO
import planetary_computer

import dask.distributed
import numpy as np
import rioxarray
import pandas as pd
import geopandas as gpd
from src.utils import search_s2_scenes, search_lc_scene, stack_s2_data, stack_lc_data, unique_class, missing_values, gen_chips
import yaml

In [None]:
from datetime import datetime, timedelta
import rasterio
import stackstac 
from src.utils import mask_cloudy_pixels

In [None]:
import matplotlib.pyplot as plt
import xarray as xr
from rasterio.features import rasterize
from shapely.geometry import mapping
from scipy.signal import convolve2d
from scipy.ndimage import uniform_filter
from shapely.geometry import shape

In [None]:
import warnings
# Suppress every warning for the rest of the session (no category or module
# filter is given, so this also hides pandas/xarray deprecation notices).
warnings.filterwarnings("ignore")

In [None]:
# Parse the YAML pipeline settings; `config` is read later for chip size and
# the Sentinel-2 band/resolution options.
with open("config.yml", mode="r") as file:
    config = yaml.safe_load(file)

In [None]:
# Read the fire perimeter shapefile into a GeoDataFrame (one row per fire;
# path suggests MTBS perimeters in decimal degrees - confirm CRS).
aoi_gdf = gpd.read_file("data/mtbs/mtbs_perims_DD.shp")

In [None]:
def add_start_date(event_id):
    """Parse the trailing eight characters of an event ID as a YYYYMMDD date."""
    date_token = event_id[-8:]
    return pd.to_datetime(date_token, format="%Y%m%d")

def add_pre_date(pre_id):
    """Extract the pre-fire scene date (YYYYMMDD) from a scene ID.

    IDs of at most 15 characters end with the date. Longer IDs are split on
    "_" and the date is taken from the end of the first field.
    """
    date_field = pre_id if len(pre_id) <= 15 else pre_id.split("_")[0]
    return pd.to_datetime(date_field[-8:], format="%Y%m%d")

def add_post_date(post_id):
    """Extract the post-fire scene date (YYYYMMDD) from a scene ID.

    IDs of at most 15 characters end with the date. Longer IDs are split on
    "_" and the date is taken from the end of the first field.
    """
    date_field = post_id if len(post_id) <= 15 else post_id.split("_")[0]
    return pd.to_datetime(date_field[-8:], format="%Y%m%d")

In [None]:
# Derive each fire's start date from the date suffix of its Event_ID.
aoi_gdf["start_date"] = aoi_gdf["Event_ID"].map(add_start_date)

In [None]:
# Keep only fires usable for chip generation:
#   * incident type "Wildfire" ("Prescribed Fire" could be added to the list),
#   * no reviewer comment,
#   * both pre- and post-fire scene IDs present,
#   * start date strictly after 2023-01-01 (no upper bound is applied).
min_start_date = pd.to_datetime("20230101", format="%Y%m%d")
fire_filter = (
    aoi_gdf["Incid_Type"].isin(["Wildfire"])
    & aoi_gdf["Comment"].isnull()
    & aoi_gdf["Pre_ID"].notnull()
    & aoi_gdf["Post_ID"].notnull()
    & (aoi_gdf["start_date"] > min_start_date)
)
# .copy() detaches the selection from aoi_gdf so that columns added to
# selected_fires later are written to an independent frame rather than a slice.
selected_fires = aoi_gdf[fire_filter].copy()
len(selected_fires)

In [None]:
# Attach the acquisition dates parsed from the pre- and post-fire scene IDs.
for date_column, id_column, parse_date in (
    ("pre_date", "Pre_ID", add_pre_date),
    ("post_date", "Post_ID", add_post_date),
):
    selected_fires[date_column] = selected_fires[id_column].apply(parse_date)

In [None]:
from dask.distributed import Client, LocalCluster
# Start a local Dask cluster with default sizing. An explicit alternative is
# LocalCluster(n_workers=8, threads_per_worker=2).
cluster = LocalCluster()
# Connect a client to the cluster so it becomes the scheduler for later dask work.
client = Client(cluster)
# Print the dashboard URL for monitoring.
print(client.dashboard_link)

In [None]:
# Retry policy for STAC API requests: up to 10 attempts with backoff on
# gateway-type errors. allowed_methods=None lifts the HTTP-verb restriction.
retry = Retry(
    total=10,
    backoff_factor=1,
    status_forcelist=[502, 503, 504],
    allowed_methods=None,
)
stac_api_io = StacApiIO(max_retries=retry)

# Open the Planetary Computer STAC catalog; the modifier signs returned
# items in place.
catalog = pystac_client.Client.open(
    "https://planetarycomputer.microsoft.com/api/stac/v1",
    stac_io=stac_api_io,
    modifier=planetary_computer.sign_inplace,
)

In [None]:
# Chip metadata table. Resume from a previous run's CSV when it exists;
# otherwise start from an empty frame with the expected columns so the
# notebook also runs on a fresh checkout.
METADATA_COLUMNS = ["chip_id", "date", "sample_id", "type", "x_center", "y_center", "epsg"]
metadata_csv_path = "../fire_data/metadata_df.csv"
if os.path.exists(metadata_csv_path):
    metadata_df = pd.read_csv(metadata_csv_path)
else:
    metadata_df = pd.DataFrame(columns=METADATA_COLUMNS)

In [None]:
def get_date_ranges(row):
    """Build the STAC date-range strings used to search scenes for one fire.

    Parameters
    ----------
    row : mapping (e.g. a DataFrame row)
        Must provide ``"start_date"`` and ``"post_date"`` as timestamps.

    Returns
    -------
    event_ranges : list[str]
        ``[pre, post]`` where ``pre`` spans the 91 days up to the day before
        ``start_date`` and ``post`` spans the 91 days starting the day after
        ``post_date``. Each is formatted ``"YYYY-MM-DD/YYYY-MM-DD"``.
    control_ranges : list[list[str]]
        Seven single-element lists: the ``pre`` window shifted back by
        1..7 times 365 days (leap days are not compensated).
    """

    def _day(ts):
        # Keep only the date part of the timestamp's string form. Done in a
        # helper so no quote character is reused inside an f-string, which is
        # a SyntaxError on Python versions before 3.12.
        return str(ts).split(" ")[0]

    pre_start_date = row["start_date"] - timedelta(days=91)
    pre_end_date = row["start_date"] - timedelta(days=1)

    post_start_date = row["post_date"] + timedelta(days=1)
    post_end_date = row["post_date"] + timedelta(days=91)

    pre_dates = f"{_day(pre_start_date)}/{_day(pre_end_date)}"
    post_dates = f"{_day(post_start_date)}/{_day(post_end_date)}"

    control_dates = []
    for delta_year in range(1, 8):
        shift = timedelta(days=delta_year * 365)
        control_date = f"{_day(pre_start_date - shift)}/{_day(pre_end_date - shift)}"
        # Each control range is wrapped in its own list; callers index it with [0].
        control_dates.append([control_date])

    return [pre_dates, post_dates], control_dates

def rasterize_aoi(aoi, s2_stack):
    """
    Burn the AOI polygon onto the grid of ``s2_stack``.

    The AOI geometry is interpreted as EPSG:4326, reprojected to the CRS of
    ``s2_stack`` and rasterized with the stack's affine transform. The result
    is a uint8 ``xr.DataArray`` on the stack's ``y``/``x`` coordinates with 1
    inside the polygon and 0 elsewhere.
    """
    # Wrap the single geometry in a GeoDataFrame so it can be reprojected.
    perimeter = gpd.GeoDataFrame(
        {"geometry": [shape(aoi['geometry'])]},
        crs="EPSG:4326",
    ).to_crs(s2_stack.rio.crs)
    projected_geom = perimeter['geometry'].iloc[0]

    n_rows = s2_stack.sizes['y']
    n_cols = s2_stack.sizes['x']

    mask_array = rasterize(
        [(mapping(projected_geom), 1)],
        out_shape=(n_rows, n_cols),
        transform=s2_stack.rio.transform(),
        fill=0,
        dtype='uint8',
    )

    return xr.DataArray(
        mask_array,
        coords={"y": s2_stack["y"], "x": s2_stack["x"]},
        dims=("y", "x"),
    )

def crop_burn_window(s2_stack, burn_mask, config):
    """Crop ``s2_stack`` to the chip-sized window holding the most burned pixels.

    A box filter of side ``config['chips']['chip_size']`` (zero-padded at the
    borders) gives the burned-pixel count of the window centred on each
    pixel. The window around the first maximum is returned, with its
    top-left corner clamped to index 0. The crop can be smaller than the chip
    size when the raster itself is smaller.
    """
    chip = config['chips']['chip_size']
    half = chip // 2

    # Mean filter scaled by the window area == per-window burned-pixel count.
    local_mean = uniform_filter(
        burn_mask.values.astype(float), size=chip, mode='constant', cval=0.0
    )
    burned_count = local_mean * (chip ** 2)

    # Row/column of the first window centre with the highest count.
    row, col = np.unravel_index(np.argmax(burned_count), burned_count.shape)

    top = max(row - half, 0)
    left = max(col - half, 0)

    return s2_stack.isel(
        y=slice(top, top + chip),
        x=slice(left, left + chip),
    )

def harmonize_to_old(data):
    """
    Bring Sentinel-2 values acquired after the baseline change back onto
    the scale of the old processing baseline.

    Time steps from 2022-01-25 onwards have 1000 subtracted from the
    reflectance bands (after clipping those values to a minimum of 1000).
    Other bands and earlier time steps pass through unchanged.

    Parameters
    ----------
    data: xarray.DataArray
        A DataArray with four dimensions: time, band, y, x

    Returns
    -------
    harmonized: xarray.DataArray
        A DataArray with all values harmonized to the old
        processing baseline, with the original band order.
    """
    baseline_change = datetime(2022, 1, 25)
    dn_offset = 1000
    reflectance_bands = {
        "B01", "B02", "B03", "B04", "B05", "B06", "B07",
        "B08", "B8A", "B09", "B10", "B11", "B12",
    }

    band_order = data.band.data.tolist()
    # Only the reflectance bands actually present in the data need correcting.
    bands_to_shift = [name for name in band_order if name in reflectance_bands]

    before_change = data.sel(time=slice(baseline_change))
    after_change = data.sel(time=slice(baseline_change, None))

    untouched = after_change.drop_sel(band=bands_to_shift)

    # Clip first so the subtraction never goes below zero, then shift in place.
    corrected = after_change.sel(band=bands_to_shift).clip(dn_offset)
    corrected -= dn_offset

    # Re-assemble the bands and restore the caller's band ordering.
    after_harmonized = xr.concat([untouched, corrected], "band").sel(band=band_order)

    return xr.concat([before_change, after_harmonized], dim="time")

In [None]:
def generate_fire_chips(s2_stack, aoi, config, time_series_type, epsg, index, metadata_df,
                        output_dir="/home/benchuser/fire_data"):
    """Crop a Sentinel-2 stack around a fire, write GeoTIFF chips and log them.

    Parameters
    ----------
    s2_stack : lazy DataArray
        Stack to materialise with ``.compute()``; assumed to have dimensions
        (time, band, y, x) - the size check below reads axes 2 and 3.
    aoi : mapping
        Fire record; its ``"geometry"`` is rasterized into the burn mask.
    config : dict
        Must provide ``config['chips']['chip_size']``.
    time_series_type : str
        ``"event"`` writes one chip per time step; any other value is treated
        as a control series and only the first time step is written.
    epsg : int
        EPSG code recorded in the metadata.
    index : int
        Fire identifier used in file names (zero-padded to six digits).
    metadata_df : pandas.DataFrame
        Existing metadata table; new rows are prepended to it.
    output_dir : str, optional
        Directory receiving the ``s2_*.tif`` files.

    Returns
    -------
    (bool, pandas.DataFrame)
        ``True`` and the extended table when chips were written; ``False``
        and the unchanged table when the AOI was skipped.
    """
    chip_size = config['chips']['chip_size']

    try:
        s2_stack = s2_stack.compute()
    except Exception as exc:
        # Previously this only printed and carried on with the unloaded stack.
        print(f"skipping the AOI for no S2 data ({exc})")
        return False, metadata_df

    burn_mask = rasterize_aoi(aoi, s2_stack)

    try:
        s2_stack_cropped = crop_burn_window(s2_stack, burn_mask, config)
    except Exception as exc:
        # Previously this fell through to a NameError on s2_stack_cropped.
        print(f"Cropping S2 stack failed; skipping AOI ({exc})")
        return False, metadata_df

    # The crop is smaller than a chip when the AOI raster itself is too small.
    if s2_stack_cropped.shape[2] != chip_size or s2_stack_cropped.shape[3] != chip_size:
        print(f"Skipping chip ID {index} for mismatch dimensions")
        return False, metadata_df

    # missing_values comes from src.utils; presumably flags chips with too
    # many NaN pixels - confirm against its implementation.
    if missing_values(s2_stack_cropped, chip_size, chip_size):
        print(f"Skipping chip ID {index} for missing values")
        return False, metadata_df

    s2_stack_cropped = harmonize_to_old(s2_stack_cropped)

    # Encode remaining gaps as the nodata value before casting to int16.
    s2_stack_cropped = s2_stack_cropped.fillna(-999)
    s2_stack_cropped = s2_stack_cropped.rio.write_nodata(-999)
    s2_stack_cropped = s2_stack_cropped.astype(np.dtype(np.int16))
    s2_stack_cropped = s2_stack_cropped.rename("s2")

    time_steps = s2_stack_cropped.time.values
    if time_series_type == "event":
        type_code, type_label = "e", "event"
    else:
        type_code, type_label = "c", "control"
        time_steps = time_steps[:1]

    if len(time_steps) == 0:
        print(f"Skipping chip ID {index} for empty time axis")
        return False, metadata_df

    # Chip centre in the stack's CRS, stored as plain Python scalars.
    x_center = s2_stack_cropped.x[len(s2_stack_cropped.x) // 2].item()
    y_center = s2_stack_cropped.y[len(s2_stack_cropped.y) // 2].item()

    new_rows = []
    for dt in time_steps:
        date_str = pd.to_datetime(str(dt)).strftime('%Y%m%d')
        sample_id = f"{index:06}_{type_code}_{date_str}"
        s2_path = os.path.join(output_dir, f"s2_{sample_id}.tif")
        s2_stack_cropped.sel(time=dt).squeeze().rio.to_raster(s2_path)
        new_rows.append([index, date_str, sample_id, type_label, x_center, y_center, epsg])

    # One concat instead of one per row. Reversing keeps the historical
    # ordering, where each newly written chip was prepended to the table.
    metadata_df = pd.concat(
        [pd.DataFrame(new_rows[::-1], columns=metadata_df.columns), metadata_df],
        ignore_index=True,
    )

    return True, metadata_df
        

In [None]:
def _item_epsg(item):
    """Return a STAC item's EPSG code from ``proj:epsg``, else from ``proj:code``."""
    properties = item.properties
    epsg_code = properties.get("proj:epsg")
    if epsg_code is None:
        # Items without proj:epsg carry a string such as "EPSG:32610".
        epsg_code = int(properties["proj:code"].split(":")[-1])
    return epsg_code


def _stack_masked_s2(items, epsg_code, bounds_latlon):
    """Stack the configured Sentinel-2 assets, mask cloudy pixels and drop SCL."""
    stack = stackstac.stack(
        items,
        assets=config["sentinel_2"]["bands"],
        epsg=epsg_code,
        resolution=config["sentinel_2"]["resolution"],
        fill_value=np.nan,
        bounds_latlon=bounds_latlon,
    )
    stack = mask_cloudy_pixels(stack)
    return stack.drop_sel(band="SCL")


# Positional offset into selected_fires; earlier rows are skipped
# (presumably already processed in a previous run - confirm before changing).
RESUME_FROM = 106

for index, aoi in selected_fires[RESUME_FROM:].iterrows():
    print(f"\nProcessing AOI at index {index}")

    clipping_geom = aoi["geometry"]
    aoi_bounds = clipping_geom.bounds
    event_date_ranges, control_date_ranges = get_date_ranges(aoi)

    # Event series: scenes from both the pre-fire and the post-fire window.
    s2_items = pystac.item_collection.ItemCollection([])
    for date_range in event_date_ranges:
        s2_items_season = search_s2_scenes(aoi, date_range, catalog, config)
        s2_items += s2_items_season

    if len(s2_items) < 2:
        print(f"Missing Sentinel-2 scenes for AOI {aoi_bounds}")
        continue

    # The result of stack_s2_data is only used as a go/no-go check; the stack
    # that is actually chipped is rebuilt below, clipped to the AOI bounds.
    if stack_s2_data(s2_items, config) is None:
        print(f"Failed to stack Sentinel-2 bands for AOI {aoi_bounds}")
        continue

    try:
        epsg = _item_epsg(s2_items[0])
        s2_stack = _stack_masked_s2(s2_items, epsg, aoi_bounds)
    except Exception as e:
        print(f"Error stacking Sentinel-2 data: {e}. Skipping AOI {index}")
        continue

    event_status, metadata_df = generate_fire_chips(s2_stack, aoi, config, "event", epsg, index, metadata_df)

    # Control chips are only produced for fires whose event chips were written.
    if event_status:
        for control_date_range in control_date_ranges:
            s2_items = search_s2_scenes(aoi, control_date_range[0], catalog, config)

            if len(s2_items) < 1:
                print(f"Missing Sentinel-2 scenes for AOI {aoi_bounds}")
                continue

            if stack_s2_data(s2_items, config) is None:
                print(f"Failed to stack Sentinel-2 bands for AOI {aoi_bounds} in control year")
                continue

            try:
                epsg = _item_epsg(s2_items[0])
                s2_stack = _stack_masked_s2(s2_items, epsg, aoi_bounds)
            except Exception as e:
                print(f"Error stacking Sentinel-2 data: {e}. Skipping AOI {index}")
                continue

            control_status, metadata_df = generate_fire_chips(s2_stack, aoi, config, "control", epsg, index, metadata_df)

    # Checkpoint the metadata after every processed fire.
    metadata_df.to_csv('/home/benchuser/fire_data/metadata_df.csv', index=False)

In [None]:
import os
import re
import numpy as np
import rasterio
import matplotlib.pyplot as plt
from collections import defaultdict

folder_path = "/home/benchuser/fire_data/"

# Chip files are named s2_<id>_<type>_<YYYYMMDD>.tif
pattern = re.compile(r"s2_(\w+)_(\w+)_(\d{8})\.tif")

files_by_id = defaultdict(list)

for filename in os.listdir(folder_path):
    # fullmatch (not match) so sidecar files such as "*.tif.aux.xml" are not
    # mistaken for rasters and handed to rasterio.open.
    match = pattern.fullmatch(filename)
    if match:
        id_, time_series_type, date = match.groups()
        full_path = os.path.join(folder_path, filename)
        files_by_id[id_].append((date, full_path))

# One grid row per fire, one column per scene.
n_rows, n_cols = 6, 6

# Keep fires with enough scenes to fill a row, then take the first n_rows IDs.
valid_ids = [id_ for id_, files in files_by_id.items() if len(files) >= n_cols]
valid_ids = sorted(valid_ids)[:n_rows]

fig, axes = plt.subplots(nrows=n_rows, ncols=n_cols, figsize=(6, 5))
fig.tight_layout(pad=-.8)

# Hide every axis up front so rows left unused (fewer than n_rows valid IDs)
# do not show empty frames.
for ax in axes.ravel():
    ax.axis('off')

for i, id_ in enumerate(valid_ids):
    scenes = sorted(files_by_id[id_], key=lambda x: x[0])

    # Show the n_cols most recent scenes of this fire.
    for j, (date, path) in enumerate(scenes[-n_cols:]):
        with rasterio.open(path) as src:
            # Bands 3, 2, 1 are read as R, G, B (assumes the chip stores
            # B02, B03, B04 first - confirm against config band order).
            img = src.read([3, 2, 1]).astype(np.float32)

        # Fixed stretch: 4000 maps to full brightness. Clipping keeps nodata
        # (-999) and saturated pixels inside the [0, 1] range imshow expects.
        img_rgb = np.clip(np.transpose(img / 4000, (1, 2, 0)), 0, 1)

        ax = axes[i, j]
        ax.imshow(img_rgb)
        ax.set_title(f"{date}", fontsize=6)

plt.savefig(os.path.join(folder_path, "sample_fires.png"), dpi=300, bbox_inches="tight")
plt.show()


In [None]:
import shutil

# Bundle the whole output folder into <output_zip_file>.zip.
folder_to_zip = '/home/benchuser/fire_data/'
output_zip_file = '/home/benchuser/fire_v0.10'

shutil.make_archive(base_name=output_zip_file, format='zip', root_dir=folder_to_zip)