In [1]:
# Enable IPython's autoreload extension in mode 2 so that edited project
# modules are re-imported before each cell runs (no kernel restart needed).
%load_ext autoreload
%autoreload 2

In [2]:
import os
# Move the working directory one level up; the relative "config.yml" path and
# the `src` package import used in later cells are resolved from there
# (presumably the repository root — confirm).
# NOTE(review): not idempotent — re-running this cell moves up one more level each time.
os.chdir("../")

In [3]:
from pathlib import Path

import dask.distributed
import geopandas as gpd
import numpy as np
import pandas as pd
import planetary_computer
import pystac
import pystac_client
import rioxarray
import yaml
from pystac_client.stac_api_io import StacApiIO
from requests.adapters import HTTPAdapter
from urllib3 import Retry

from src.utils import search_s2_scenes, search_s1_scenes, search_landsat_scenes, search_dem_scene, search_lc_scene
from src.utils import stack_data, stack_dem_data, stack_lc_data, unique_class, missing_values, gen_chips

In [4]:
import warnings
# Silence every warning for the rest of the session.
# NOTE(review): this also hides deprecation and data-related warnings raised by
# the libraries used below; consider narrowing the filter.
warnings.filterwarnings("ignore")

In [5]:
# Load the run configuration into `config`. The path is relative to the current
# working directory, which was changed earlier in the notebook.
with open("config.yml", "r") as file:
    config = yaml.safe_load(file)

In [6]:
# Select the AOI GeoJSON that matches the configured version and load it.
version = config['version']
# NOTE(review): the working directory was already moved up one level earlier,
# so "../data" is resolved relative to that directory — confirm this is the
# intended location of the map file.
aoi_path = (f'../data/map_{version}.geojson')
aoi_gdf = gpd.read_file(aoi_path)

In [7]:
# The following AOIs have broken scenes in the STAC catalog and are removed
# before processing. Their index labels are read from `excluded_aoi_indices`
# in the configuration.
# NOTE: `drop` raises KeyError when a label is no longer present, so re-run the
# AOI loading cell before re-running this one.

aoi_gdf = aoi_gdf.drop(index=config['excluded_aoi_indices'])

In [8]:
from dask.distributed import Client, LocalCluster
# Start a local Dask cluster with default sizing and attach a client to it.
# Explicit sizing can be passed instead, e.g. (n_workers=8, threads_per_worker=2).
cluster = LocalCluster()#(n_workers=8, threads_per_worker=2)
client = Client(cluster)
# Show the dashboard URL so task progress can be followed in a browser.
print(client.dashboard_link)

http://127.0.0.1:8787/status




In [9]:
# Retry policy for STAC API requests: up to 10 attempts with backoff whenever
# the server answers with a 502, 503 or 504 status.
retry = Retry(
    total=10,
    backoff_factor=1,
    status_forcelist=[502, 503, 504],
    allowed_methods=None,
)

# Route all catalog traffic through an IO layer that applies the retry policy.
stac_api_io = StacApiIO(max_retries=retry)

# Open the Planetary Computer STAC API; `sign_inplace` is registered as the
# modifier that the client applies to the results it returns.
catalog = pystac_client.Client.open(
    "https://planetarycomputer.microsoft.com/api/stac/v1",
    modifier=planetary_computer.sign_inplace,
    stac_io=stac_api_io,
)

In [10]:
def process_array(
    stack,
    epsg: int,
    coords: tuple[float, float],
    array_name: str,
    sample_size: int = 64,
    chip_size: int = 64,
    na_value: int = -999,
    dtype=np.int16,
):
    """Cut one chip out of ``stack`` and prepare it for writing.

    The chip covers the ``sample_size`` x ``sample_size`` window addressed by
    ``coords`` plus a symmetric margin of ``int((chip_size - sample_size) / 2)``
    pixels on each side. Pixels in that margin are masked and end up holding
    ``na_value``.

    Args:
        stack: Array with ``x`` and ``y`` dimensions/coordinates that supports
            ``isel`` and the ``rio`` accessor.
        epsg: EPSG code written to the chip as its CRS.
        coords: ``(x, y)`` window indices, in units of ``sample_size`` pixels.
        array_name: Name given to the returned array.
        sample_size: Edge length, in pixels, of the window kept unmasked.
        chip_size: Edge length, in pixels, of the returned chip.
        na_value: Fill value for masked/missing pixels, also recorded as nodata.
        dtype: Data type the chip is cast to before it is returned.

    Returns:
        The chip with CRS and nodata set, cast to ``dtype`` and renamed to
        ``array_name``.

    Note:
        The mask comparisons assume ``x`` increases and ``y`` decreases with
        the pixel index, and the caller must keep windows away from the stack
        borders: a negative slice start wraps around and
        ``(index + 1) * sample_size`` must be a valid coordinate index.
    """
    col, row = coords
    margin = int((chip_size - sample_size) / 2)

    # Pixel bounds of the sample window along each axis.
    col_start = col * sample_size
    col_stop = (col + 1) * sample_size
    row_start = row * sample_size
    row_stop = (row + 1) * sample_size

    chip = stack.isel(
        x=slice(col_start - margin, col_stop + margin),
        y=slice(row_start - margin, row_stop + margin),
    )
    chip.rio.write_crs(f"epsg:{epsg}", inplace=True)

    # Keep only the pixels inside the sample window; the margin becomes NaN.
    within_x = (chip.x >= stack.x[col_start]) & (chip.x < stack.x[col_stop])
    within_y = (chip.y <= stack.y[row_start]) & (chip.y > stack.y[row_stop])
    chip = chip.where(within_x & within_y)

    # Replace NaN with the nodata value before casting to the target dtype.
    chip = chip.fillna(na_value).rio.write_nodata(na_value)
    return chip.astype(np.dtype(dtype)).rename(array_name)

In [11]:
def process_chips(s2_stack, s1_stack, landsat_stack, lc_stack, dem_stack, epsg, sample_size, chip_size, global_index, aoi_index, metadata_df, root_path, max_chips_per_class=400):
    """Generate chips for one AOI and record their metadata.

    All stacks are computed (loaded) first. Candidate windows are the
    ``sample_size`` x ``sample_size`` tiles that ``unique_class`` flags on the
    land-cover stack, excluding the two outermost rows/columns of tiles. For
    each candidate a chip is cut from every stack with ``process_array``; the
    candidate is skipped when any chip fails ``missing_values``, contains
    flooded vegetation (class 4), or exceeds the per-class cap.

    Args:
        s2_stack, s1_stack, landsat_stack, lc_stack, dem_stack: Lazy stacks
            exposing ``compute()`` and ``x``/``y`` dimensions.
        epsg: EPSG code written to every chip and stored in the metadata.
        sample_size: Edge length in pixels of the labelled window.
        chip_size: Edge length in pixels of the generated chips.
        global_index: Id assigned to the next chip that is written.
        aoi_index: Index of the AOI, stored in the metadata row.
        metadata_df: Metadata frame; its columns must match the row layout
            (chip id, AOI index, S2/S1/Landsat dates, lc, x, y, epsg).
        root_path: Passed through to ``gen_chips``.
        max_chips_per_class: Maximum number of chips per capped land-cover
            class (water, trees, crops, bare ground, rangeland) for this AOI.

    Returns:
        ``(global_index, metadata_df)`` updated with the chips generated here.
        New rows are placed before the existing ones, newest first.

    Raises:
        ValueError: If a land-cover chip contains 255, 130 or 133.
    """
    # Load the stacks in a fixed order. `except Exception` (rather than a bare
    # `except`) lets KeyboardInterrupt/SystemExit stop the run, and the error
    # is printed so that real failures are not mistaken for missing data.
    loaded = {}
    for key, label, lazy_stack in (
        ("lc", "LC", lc_stack),
        ("s2", "S2", s2_stack),
        ("s1", "S1", s1_stack),
        ("dem", "dem", dem_stack),
        ("landsat", "landsat", landsat_stack),
    ):
        print(f"Loading {key}_stack")
        try:
            loaded[key] = lazy_stack.compute()
        except Exception as exc:
            print(f"skipping the AOI for no {label} data ({exc!r})")
            return global_index, metadata_df

    lc_stack = loaded["lc"]
    s2_stack = loaded["s2"]
    s1_stack = loaded["s1"]
    dem_stack = loaded["dem"]
    landsat_stack = loaded["landsat"]

    lc_uniqueness = lc_stack.coarsen(x = sample_size,
                                     y = sample_size,
                                     boundary = "trim"
                                    ).reduce(unique_class)
    # Exclude the two outermost rows/columns of tiles so that every chip
    # (including its margin) stays inside the stacks.
    lc_uniqueness[0:2, :] = False
    lc_uniqueness[-2:, :] = False
    lc_uniqueness[:, 0:2] = False
    lc_uniqueness[:, -2:] = False

    ys, xs = np.where(lc_uniqueness)

    # Per-AOI counters used to cap the number of chips of each of these
    # land-cover classes: 1 water, 2 trees, 5 crops, 8 bare ground, 11 rangeland.
    class_counts = {1: 0, 2: 0, 5: 0, 8: 0, 11: 0}

    # (array name, source stack, nodata value, dtype), in processing order.
    # NOTE(review): the LC chip is cast to np.int8, which cannot represent
    # 130, 133 or 255, so the "Wrong LC value" guard below may never fire —
    # confirm the intended dtype (np.uint8?).
    chip_layers = (
        ("S2", s2_stack, -999, np.int16),
        ("S1", s1_stack, -999, np.float32),
        ("L8", landsat_stack, -999, np.float32),
        ("lc", lc_stack, 0, np.int8),
        ("dem", dem_stack, -999, np.float32),
    )

    new_rows = []
    for y, x in zip(ys, xs):
        # Cut a chip from every layer; stop at the first one with missing
        # values. Every layer, Landsat included, is checked against itself.
        chips = {}
        for array_name, stack, na_value, dtype in chip_layers:
            chip = process_array(
                stack = stack,
                epsg = epsg,
                coords = (x, y),
                array_name = array_name,
                sample_size = sample_size,
                chip_size = chip_size,
                na_value = na_value,
                dtype = dtype,
            )
            if missing_values(chip, chip_size, sample_size):
                break
            chips[array_name] = chip
        if len(chips) < len(chip_layers):
            continue

        lc_array = chips["lc"]

        if (np.isin(lc_array, [255, 130, 133])).any():
            raise ValueError('Wrong LC value')

        # Skipping Flooded Vegetation
        if (np.isin(lc_array, [4])).any():
            continue

        # Apply the per-class cap only when the chip holds exactly one class.
        # The nodata value (0) is ignored so that a masked margin does not
        # turn the comparison into an ambiguous multi-element array test.
        lc_classes = np.unique(lc_array)
        lc_classes = lc_classes[lc_classes != 0]
        if lc_classes.size == 1:
            lc_class = int(lc_classes[0])
            if lc_class in class_counts:
                class_counts[lc_class] += 1
                if class_counts[lc_class] > max_chips_per_class:
                    continue

        gen_status, s2_dts, s1_dts, landsat_dts = gen_chips(
            chips["S2"], chips["S1"], chips["L8"], lc_array, chips["dem"], global_index, root_path
        )
        if gen_status:
            new_rows.append([global_index,
                             aoi_index,
                             s2_dts,
                             s1_dts,
                             landsat_dts,
                             np.unique(lc_array),
                             s2_stack.x[(x) * sample_size + int(sample_size / 2)].data,
                             s2_stack.y[(y) * sample_size + int(sample_size / 2)].data,
                             epsg])
            global_index += 1

    if new_rows:
        # Build the frame once instead of concatenating per chip; reversing
        # keeps the newest-first ordering of the row-by-row prepend.
        new_rows.reverse()
        metadata_df = pd.concat([pd.DataFrame(new_rows, columns=metadata_df.columns),
                                 metadata_df],
                                ignore_index=True
                               )

    return global_index, metadata_df

In [12]:
# Running chip id shared across all AOIs.
global_index = 0
# One row per generated chip. The column order must match the 9-value rows
# built in `process_chips` (chip id, AOI index, S2/S1/Landsat dates, land
# cover, chip centre x/y, EPSG); "aoi_index" was missing, which made the row
# construction fail with a column-count mismatch.
metadata_df = pd.DataFrame(columns=["chip_id", "aoi_index", "s2_dates", "s1_dts", "landsat_dts", "lc", "x_center", "y_center", "epsg"])
# metadata_df = pd.read_csv("../data/metadata_df.csv") # Use this line to continue from a previous iteration if the code stops.
# When resuming, `global_index` must also be set past the last chip id already written.

In [13]:
# For every AOI: search and stack Sentinel-2, Sentinel-1, Landsat, land cover
# and DEM data, generate chips, then persist the metadata collected so far.
for index, aoi in aoi_gdf.iterrows():
    print(f"\nProcessing AOI at index {index}")
    aoi_bounds = aoi['geometry'].bounds

    # Collect Sentinel-2 scenes over every configured time range.
    s2_items = pystac.item_collection.ItemCollection([])
    for date_range in config["sentinel_2"]["time_ranges"]:
        s2_items += search_s2_scenes(aoi, date_range, catalog, config)

    # At least four Sentinel-2 scenes are required to keep the AOI.
    if len(s2_items) < 4:
        print(f"Missing Sentinel-2 scenes for AOI {aoi_bounds}")
        continue

    s2_stack = stack_data(s2_items, "sentinel_2", config)
    if s2_stack is None:
        print(f"Failed to stack Sentinel-2 bands for AOI {aoi_bounds}")
        continue

    # The first Sentinel-2 item is the spatial reference for all other layers.
    # Prefer "proj:epsg"; fall back to parsing "proj:code" when it is absent
    # or null (explicit lookup instead of a bare `except`).
    reference_item = s2_items[0]
    epsg = reference_item.properties.get("proj:epsg")
    if epsg is None:
        epsg = int(reference_item.properties["proj:code"].split(":")[-1])
    bbox = reference_item.bbox

    # For each Sentinel-2 acquisition, look up Sentinel-1 and Landsat scenes
    # by its datetime.
    s1_items = pystac.item_collection.ItemCollection([])
    landsat_items = pystac.item_collection.ItemCollection([])

    for s2_item in s2_items:
        s2_datetime = s2_item.datetime
        s1_items += search_s1_scenes(aoi, s2_datetime, catalog, config)
        landsat_items += search_landsat_scenes(aoi, s2_datetime, catalog, config)

    # EPSG of the Sentinel-2 stack, passed to every other stacking call.
    stack_epsg = s2_stack.rio.crs.to_epsg()

    s1_stack = stack_data(s1_items, "sentinel_1", config, stack_epsg, bbox)
    if s1_stack is None:
        print(f"Failed to stack Sentinel-1 bands for AOI {aoi_bounds}")
        continue

    landsat_stack = stack_data(landsat_items, "landsat", config, stack_epsg, bbox)
    if landsat_stack is None:
        print(f"Failed to stack Landsat bands for AOI {aoi_bounds}")
        continue

    lc_items = search_lc_scene(bbox, catalog, config)
    if not lc_items:
        print(f"No Land Cover data found for AOI {aoi_bounds}")
        continue

    # Land cover and DEM are searched by bounding box only, so the failure
    # messages no longer print the stale Sentinel-2 `date_range` variable.
    lc_stack = stack_lc_data(lc_items, config, stack_epsg, bbox)
    if lc_stack is None:
        print(f"Failed to stack Land Cover data for AOI {aoi_bounds}")
        continue

    dem_items = search_dem_scene(bbox, catalog, config)
    if not dem_items:
        print(f"No DEM data found for AOI {aoi_bounds}")
        continue

    dem_stack = stack_dem_data(dem_items, config, stack_epsg, bbox)
    if dem_stack is None:
        print(f"Failed to stack DEM data for AOI {aoi_bounds}")
        continue

    global_index, metadata_df = process_chips(s2_stack,
                                              s1_stack,
                                              landsat_stack,
                                              lc_stack,
                                              dem_stack,
                                              epsg,
                                              config["chips"]["sample_size"],
                                              config["chips"]["chip_size"],
                                              global_index,
                                              index,
                                              metadata_df,
                                              config['working_dir']
                                             )

    # Persist after each processed AOI so an interrupted run keeps its metadata.
    metadata_df.to_csv(Path(config['working_dir']) / 'metadata_df.csv', index=False)


Processing AOI at index 0
Loading lc_stack
Loading s2_stack
Loading s1_stack
Loading dem_stack
Attempt to create new tiff file '/home/benchuser/final_data/S2_000000_0_20230218.tif' failed: /home/benchuser/final_data/S2_000000_0_20230218.tif: Permission denied
Attempt to create new tiff file '/home/benchuser/final_data/S2_000000_0_20230218.tif' failed: /home/benchuser/final_data/S2_000000_0_20230218.tif: Permission denied
Attempt to create new tiff file '/home/benchuser/final_data/S2_000000_0_20230218.tif' failed: /home/benchuser/final_data/S2_000000_0_20230218.tif: Permission denied
Attempt to create new tiff file '/home/benchuser/final_data/S2_000000_0_20230218.tif' failed: /home/benchuser/final_data/S2_000000_0_20230218.tif: Permission denied
Attempt to create new tiff file '/home/benchuser/final_data/S2_000000_0_20230218.tif' failed: /home/benchuser/final_data/S2_000000_0_20230218.tif: Permission denied
Attempt to create new tiff file '/home/benchuser/final_data/S2_000000_0_2023021

KeyboardInterrupt: 

Process Dask Worker process (from Nanny):
2025-07-04 19:04:08,832 - distributed.nanny - ERROR - Worker process died unexpectedly
Process Dask Worker process (from Nanny):
2025-07-04 19:04:08,832 - distributed.nanny - ERROR - Worker process died unexpectedly
2025-07-04 19:04:08,832 - distributed.nanny - ERROR - Worker process died unexpectedly
Process Dask Worker process (from Nanny):
Process Dask Worker process (from Nanny):
Traceback (most recent call last):
  File "/opt/conda/envs/gfm_bench/lib/python3.12/asyncio/runners.py", line 118, in run
    return self._loop.run_until_complete(task)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/envs/gfm_bench/lib/python3.12/asyncio/base_events.py", line 691, in run_until_complete
    return future.result()
           ^^^^^^^^^^^^^^^
  File "/opt/conda/envs/gfm_bench/lib/python3.12/site-packages/distributed/nanny.py", line 984, in run
    await worker.finished()
  File "/opt/conda/envs/gfm_bench/lib/python3.12/site-packages/d