In [1]:
import os
# Work from the project root, which holds config.yml (opened below) and the `src` package
# (imported below). Guarded so re-running this cell does not climb one more directory each time.
if not (os.path.isfile("config.yml") and os.path.isdir("src")):
    os.chdir("../")

In [2]:
import pystac_client
import pystac
from requests.adapters import HTTPAdapter
from urllib3 import Retry
from pystac_client.stac_api_io import StacApiIO
import planetary_computer

import dask.distributed
import numpy as np
import rioxarray
import pandas as pd
import geopandas as gpd
from src.utils import search_s2_scenes, search_lc_scene, stack_s2_data, stack_lc_data, unique_class, missing_values, gen_chips
import yaml

In [3]:
from datetime import datetime, timedelta
import rasterio
import stackstac 
from src.utils import mask_cloudy_pixels

In [4]:
import matplotlib.pyplot as plt
import xarray as xr
from rasterio.features import rasterize
from shapely.geometry import mapping
from scipy.signal import convolve2d
from scipy.ndimage import uniform_filter

In [5]:
import warnings
warnings.filterwarnings("ignore")

In [6]:
# Project-wide settings (chip sizes, Sentinel-2 bands/resolution, ...) used throughout the notebook.
with open("config.yml", mode="r") as config_file:
    config = yaml.safe_load(config_file)

In [7]:
# Burn-perimeter polygons, one row per fire event (columns used below: Event_ID, Incid_Type,
# Comment, Pre_ID, Post_ID). NOTE(review): "_DD" presumably means decimal degrees; the main loop
# passes these geometries' bounds as `bounds_latlon`, so confirm the CRS is lon/lat (EPSG:4326).
aoi_gdf = gpd.read_file("data/mtbs/mtbs_perims_DD.shp")

In [8]:
def _date_from_scene_id(scene_id):
    """Parse the ``YYYYMMDD`` date encoded in a scene ID into a pandas Timestamp.

    IDs of at most 15 characters carry the date in their last 8 characters. Longer IDs
    are cut at the first ``_`` and the date is read from the end of that first segment.
    """
    if len(scene_id) > 15:
        scene_id = scene_id.split("_")[0]
    return pd.to_datetime(scene_id[-8:], format="%Y%m%d")


def add_start_date(event_id):
    """Return the fire start date encoded in the last 8 characters (``YYYYMMDD``) of ``event_id``."""
    return pd.to_datetime(event_id[-8:], format="%Y%m%d")


def add_pre_date(pre_id):
    """Return the acquisition date encoded in a pre-fire scene ID.

    Delegates to the shared ID parser so pre- and post-fire IDs are handled identically.
    """
    return _date_from_scene_id(pre_id)


def add_post_date(post_id):
    """Return the acquisition date encoded in a post-fire scene ID.

    Delegates to the shared ID parser so pre- and post-fire IDs are handled identically.
    """
    return _date_from_scene_id(post_id)

In [9]:
# Fire start date parsed element-wise from the trailing YYYYMMDD of each Event_ID.
aoi_gdf["start_date"] = aoi_gdf["Event_ID"].map(add_start_date)

In [10]:
# Keep wildfires (add "Prescribed Fire" to the isin list to include prescribed burns) that have
# no Comment, both a Pre_ID and a Post_ID, and a start date after 2023-01-01.
min_start_date = pd.to_datetime("20230101", format="%Y%m%d")
# max_start_date = pd.to_datetime("20240101", format="%Y%m%d")
fire_mask = (
    aoi_gdf["Incid_Type"].isin(["Wildfire"])
    & aoi_gdf["Comment"].isnull()
    & aoi_gdf["Pre_ID"].notnull()
    & aoi_gdf["Post_ID"].notnull()
    & (aoi_gdf["start_date"] > min_start_date)
    # & (aoi_gdf["start_date"] < max_start_date)
)
# .copy() gives an independent frame, so the pre/post date columns added in the next cell are
# not written into a filtered slice of aoi_gdf (SettingWithCopyWarning).
selected_fires = aoi_gdf[fire_mask].copy()
len(selected_fires)

199

In [11]:
# Parse the pre- and post-fire scene dates out of the Pre_ID / Post_ID strings.
selected_fires = selected_fires.assign(
    pre_date=selected_fires["Pre_ID"].apply(add_pre_date),
    post_date=selected_fires["Post_ID"].apply(add_post_date),
)

In [12]:
from dask.distributed import Client, LocalCluster
# Local dask cluster with default sizing; pass e.g. n_workers=8, threads_per_worker=2 to pin it.
cluster = LocalCluster()
client = Client(cluster)
dashboard_url = client.dashboard_link
print(dashboard_url)

http://127.0.0.1:8787/status


In [13]:
# Retry transient gateway failures (HTTP 502/503/504) up to 10 times with back-off;
# allowed_methods=None lets urllib3 retry on every HTTP verb, not only idempotent ones.
retry = Retry(
    total=10,
    backoff_factor=1,
    status_forcelist=[502, 503, 504],
    allowed_methods=None,
)
stac_api_io = StacApiIO(max_retries=retry)

# Planetary Computer STAC API; the modifier signs returned items in place so assets are readable.
catalog = pystac_client.Client.open(
    "https://planetarycomputer.microsoft.com/api/stac/v1",
    stac_io=stac_api_io,
    modifier=planetary_computer.sign_inplace,
)

In [14]:
METADATA_COLUMNS = ["chip_id", "date", "sample_id", "type", "x_center", "y_center", "epsg"]
# Running table of written chips: filled by generate_fire_chips, saved to metadata_df.csv.
metadata_df = pd.DataFrame(columns=METADATA_COLUMNS)
# To resume from an earlier run instead: metadata_df = pd.read_csv("../data/metadata_df.csv")

In [15]:
def get_date_ranges(row, window_days=91, n_control_years=7):
    """Build ``"YYYY-MM-DD/YYYY-MM-DD"`` search ranges around one fire.

    Parameters
    ----------
    row : mapping (e.g. a pandas Series from ``iterrows``)
        Must provide ``"start_date"`` and ``"post_date"`` as pandas Timestamps.
    window_days : int, default 91
        Span of the pre- and post-fire windows. The pre window ends the day before
        ``start_date``; the post window starts the day after ``post_date``.
    n_control_years : int, default 7
        Number of control windows. Control window ``k`` (1..n) is the pre-fire window
        shifted back by ``k * 365`` days.

    Returns
    -------
    tuple[list[str], list[list[str]]]
        ``([pre_range, post_range], control_ranges)``. Each control range is wrapped in a
        one-element list.
    """
    def _fmt(start, end):
        # strftime replaces str(ts).split(" ")[0], whose nested same-quote f-string is a
        # SyntaxError before Python 3.12.
        return f"{start.strftime('%Y-%m-%d')}/{end.strftime('%Y-%m-%d')}"

    pre_start_date = row["start_date"] - timedelta(days=window_days)
    pre_end_date = row["start_date"] - timedelta(days=1)

    post_start_date = row["post_date"] + timedelta(days=1)
    post_end_date = row["post_date"] + timedelta(days=window_days)

    control_dates = []
    for delta_year in range(1, n_control_years + 1):
        # Fixed 365-day shift: windows crossing a leap day drift by one calendar day.
        shift = timedelta(days=delta_year * 365)
        control_dates.append([_fmt(pre_start_date - shift, pre_end_date - shift)])

    event_dates = [_fmt(pre_start_date, pre_end_date), _fmt(post_start_date, post_end_date)]
    return event_dates, control_dates

In [16]:
def generate_fire_chips(s2_stack, config, time_series_type, epsg, index, metadata_df,
                        output_dir="/home/benchuser/fire_data"):
    """Centre-crop a Sentinel-2 stack to one chip, write it as GeoTIFF(s) and record metadata.

    Parameters
    ----------
    s2_stack : xarray.DataArray
        Stack with ``time``, ``x`` and ``y`` dimensions; presumably lazy/dask-backed
        (it is ``.compute()``-d here).
    config : dict
        Parsed ``config.yml``; ``config['chips']['chip_size']`` and
        ``config['chips']['sample_size']`` are read.
    time_series_type : str
        ``"event"`` writes one GeoTIFF per time step (file tag ``e``). Any other value
        writes only the first time step as a control chip (file tag ``c``).
    epsg : int
        CRS code stored in the metadata.
    index : int
        AOI identifier, zero-padded to 6 digits in file names and sample IDs.
    metadata_df : pandas.DataFrame
        Running metadata table; its columns define the record layout.
    output_dir : str, optional
        Directory receiving ``s2_<index>_<tag>_<YYYYMMDD>.tif`` files (created if missing).

    Returns
    -------
    tuple[bool, pandas.DataFrame]
        ``(True, table)`` with the new rows appended after the existing ones, or
        ``(False, metadata_df)`` unchanged when the chip is skipped.
    """
    chip_size = config['chips']['chip_size']

    # Centre crop. If the scene is smaller than chip_size the offsets go negative and the
    # crop comes out short, which the size check below rejects.
    x_i = int((s2_stack.sizes["x"] - chip_size) / 2)
    y_i = int((s2_stack.sizes["y"] - chip_size) / 2)
    s2_stack = s2_stack.isel(x=slice(x_i, x_i + chip_size), y=slice(y_i, y_i + chip_size))

    # Compare against the configured chip size (was a hard-coded 224), looking dimensions up
    # by name instead of assuming a (time, band, y, x) axis order.
    if s2_stack.sizes["x"] != chip_size or s2_stack.sizes["y"] != chip_size:
        print(f"Skipping chip ID {index} for mismatch dimensions")
        return False, metadata_df

    try:
        s2_stack = s2_stack.compute()
    except Exception as exc:
        # Actually skip: continuing with the lazy array would only re-trigger the failure.
        print(f"skipping the AOI for no S2 data: {exc}")
        return False, metadata_df

    if missing_values(s2_stack, chip_size, config['chips']['sample_size']):
        print(f"Skipping chip ID {index} for missing values")
        return False, metadata_df

    # NaN gaps become the -999 nodata value so the stack can be stored as int16.
    # NOTE(review): values above 32767 would wrap on the int16 cast — confirm the data range fits.
    s2_stack = (
        s2_stack.fillna(-999)
        .rio.write_nodata(-999)
        .astype(np.dtype(np.int16))
        .rename("s2")
    )

    if time_series_type == "event":
        times, tag, label = list(s2_stack.time.values), "e", "event"
    else:
        times, tag, label = [s2_stack.time.values[0]], "c", "control"

    # The chip centre is identical for every time step.
    x_center = s2_stack.x[len(s2_stack.x) // 2].item()
    y_center = s2_stack.y[len(s2_stack.y) // 2].item()

    os.makedirs(output_dir, exist_ok=True)
    new_records = []
    for dt in times:
        date_str = pd.to_datetime(str(dt)).strftime('%Y%m%d')
        sample_id = f"{index:06}_{tag}_{date_str}"
        s2_path = os.path.join(output_dir, f"s2_{sample_id}.tif")
        s2_stack.sel(time=dt).squeeze().rio.to_raster(s2_path)
        new_records.append([index, date_str, sample_id, label, x_center, y_center, epsg])

    # Build the new rows once rather than concatenating inside the loop.
    new_rows = pd.DataFrame(new_records, columns=metadata_df.columns)
    if metadata_df.empty:
        return True, new_rows
    return True, pd.concat([metadata_df, new_rows], ignore_index=True)
        

In [17]:
def _item_epsg(item):
    """Return the EPSG code of a STAC item.

    Uses ``proj:epsg`` when present, otherwise parses the numeric suffix of ``proj:code``
    (e.g. ``"EPSG:32611"`` -> ``32611``). Raises ``KeyError`` when neither property exists.
    """
    try:
        return item.properties["proj:epsg"]
    except KeyError:
        return int(item.properties["proj:code"].split(":")[-1])


for index, aoi in selected_fires.iterrows():
    print(f"\nProcessing AOI at index {index}")

    aoi_bounds = aoi['geometry'].bounds
    event_date_ranges, control_date_ranges = get_date_ranges(aoi)

    # Gather the pre-fire and post-fire scenes into one item collection.
    s2_items = pystac.item_collection.ItemCollection([])
    for date_range in event_date_ranges:
        s2_items += search_s2_scenes(aoi, date_range, catalog, config)

    if len(s2_items) < 2:
        print(f"Missing Sentinel-2 scenes for AOI {aoi_bounds}")
        continue

    # NOTE(review): this result only gates the AOI; s2_stack is rebuilt with stackstac below.
    # Confirm what stack_s2_data validates before dropping this call.
    s2_stack = stack_s2_data(s2_items, config)
    if s2_stack is None:
        print(f"Failed to stack Sentinel-2 bands for AOI {aoi_bounds}")
        continue

    try:
        epsg = _item_epsg(s2_items[0])
    except KeyError:
        print(f"No projection metadata on Sentinel-2 items. Skipping AOI {index}")
        continue

    clipping_geom = aoi["geometry"]

    try:
        # bounds_latlon assumes the AOI geometry is lon/lat (EPSG:4326) — TODO confirm.
        s2_stack = stackstac.stack(
            s2_items,
            assets=config["sentinel_2"]["bands"],
            epsg=epsg,
            resolution=config["sentinel_2"]["resolution"],
            fill_value=np.nan,
            bounds_latlon=clipping_geom.bounds,
        )
        s2_stack = mask_cloudy_pixels(s2_stack)
        s2_stack = s2_stack.drop_sel(band="SCL")
    except Exception as e:
        print(f"Error stacking Sentinel-2 data: {e}. Skipping AOI {index}")
        continue

    event_status, metadata_df = generate_fire_chips(s2_stack, config, "event", epsg, index, metadata_df)
    # TODO: control-year chips are currently disabled; the block below is the pending version.
    # if event_status:
    #     for control_date_range in control_date_ranges:
    #         s2_items = search_s2_scenes(aoi, control_date_range[0], catalog, config)
            
    #         if len(s2_items)<1:
    #             print(f"Missing Sentinel-2 scenes for AOI {aoi_bounds}")
    #             continue
                
        
    #         s2_stack = stack_s2_data(s2_items, config)
    #         if s2_stack is None:
    #             print(f"Failed to stack Sentinel-2 bands for AOI {aoi_bounds} in control year")
    #             continue
                
    #         try:
    #             epsg = s2_items[0].properties["proj:epsg"]
    #         except:
    #             epsg = int(s2_items[0].properties["proj:code"].split(":")[-1])
            
    #         try:
    #             s2_stack = stackstac.stack(
    #                 s2_items,
    #                 assets=config["sentinel_2"]["bands"],
    #                 epsg=epsg,
    #                 resolution=config["sentinel_2"]["resolution"],
    #                 fill_value=np.nan,
    #                 bounds_latlon = clipping_geom.bounds
    #             )
    #             s2_stack = mask_cloudy_pixels(s2_stack)
    #             s2_stack = s2_stack.drop_sel(band="SCL")
    #         except Exception as e:
    #             print(f"Error stacking Sentinel-2 data: {e}. Skipping AOI {index}")
    #             continue
        
    #         event_status, metadata_df = generate_fire_chips(s2_stack, config, "control", epsg, index, metadata_df)

    # Rewritten after every AOI that reaches this point, so progress so far is on disk.
    metadata_df.to_csv('/home/benchuser/fire_data/metadata_df.csv', index=False)


Processing AOI at index 29590


In [None]:
import os
import re
import numpy as np
import rasterio
import matplotlib.pyplot as plt
from collections import defaultdict

N_SAMPLE_IDS = 10      # figure rows: how many fire IDs to show
N_EVENT_SCENES = 9     # figure columns: only IDs with exactly this many event chips are shown
RGB_BANDS = [3, 2, 1]  # 1-based GeoTIFF band indexes shown as R, G, B — TODO confirm vs config["sentinel_2"]["bands"]
DISPLAY_SCALE = 4000   # raw values are divided by this to land roughly in imshow's [0, 1] range

folder_path = "/home/benchuser/fire_data/"

# Chip files are named s2_<id>_<e|c>_<YYYYMMDD>.tif (e = event, c = control).
pattern = re.compile(r"s2_(\w+)_([ec])_(\d{8})\.tif")

files_by_id = defaultdict(list)
for filename in os.listdir(folder_path):
    match = pattern.match(filename)
    # Count event chips only: control chips would push an ID past N_EVENT_SCENES (dropping it
    # from the figure) and would interleave older dates into its row.
    if match and match.group(2) == "e":
        id_, _, date = match.groups()
        files_by_id[id_].append((date, os.path.join(folder_path, filename)))

valid_ids = sorted(id_ for id_, files in files_by_id.items() if len(files) == N_EVENT_SCENES)
valid_ids = valid_ids[:N_SAMPLE_IDS]  # first N_SAMPLE_IDS IDs in sorted order

fig, axes = plt.subplots(nrows=N_SAMPLE_IDS, ncols=N_EVENT_SCENES, figsize=(10, 18), squeeze=False)
fig.tight_layout(pad=5.0)
# Hide every axis up front so rows left empty (fewer qualifying IDs) render blank as well.
for ax in axes.flat:
    ax.axis('off')

for i, id_ in enumerate(valid_ids):
    scenes = sorted(files_by_id[id_], key=lambda scene: scene[0])
    for j, (date, path) in enumerate(scenes):
        with rasterio.open(path) as src:
            img = src.read(RGB_BANDS).astype(np.float32)
        # Clip explicitly (e.g. the -999 nodata fill becomes 0) instead of relying on imshow's clipping.
        img_rgb = np.clip(np.transpose(img / DISPLAY_SCALE, (1, 2, 0)), 0.0, 1.0)
        ax = axes[i, j]
        ax.imshow(img_rgb)
        ax.set_title(date, fontsize=8)

plt.savefig(os.path.join(folder_path, "sample_fires.png"), dpi=300, bbox_inches="tight")
plt.show()


In [None]:
import shutil

# make_archive appends ".zip" to the base name. The archive lands outside the source folder,
# so it is not swept into itself.
source_dir = '/home/benchuser/fire_data/'
archive_base = '/home/benchuser/fire_v0.10'

shutil.make_archive(archive_base, format='zip', root_dir=source_dir)