In [1]:
import os

# Presumably the notebook sits one directory below the project root; later cells
# open "config.yml" and "data/..." and import "src.utils" by relative path, so
# move to the directory that holds config.yml. Guarded so that re-running this
# cell does not keep climbing the directory tree.
if not os.path.exists("config.yml"):
    os.chdir("../")

In [2]:
import dask.distributed
import pystac_client
import planetary_computer
import numpy as np
import pandas as pd
import rioxarray
import geopandas as gpd
from src.utils import search_s2_scenes, search_lc_scene, stack_s2_data, stack_lc_data, unique_class, missing_values, gen_chips
import pystac
import yaml

In [3]:
# Project settings (Sentinel-2 time ranges, chip sizes, ...) read by the cells below.
with open(file="config.yml", mode="r") as config_file:
    config = yaml.safe_load(config_file)

In [4]:
# One row per area of interest; each row's geometry drives the searches in the main loop.
aoi_gdf = gpd.read_file(filename="data/all_aois.geojson")

In [None]:
from dask.distributed import Client, LocalCluster

# Local Dask cluster of 8 single-threaded workers; print the dashboard URL for monitoring.
cluster = LocalCluster(threads_per_worker=1, n_workers=8)
client = Client(address=cluster)
print(client.dashboard_link)

In [None]:
# Planetary Computer STAC API. `sign_inplace` is installed as the result modifier
# so returned items come back with signed asset hrefs (see planetary_computer docs).
catalog = pystac_client.Client.open(
    url="https://planetarycomputer.microsoft.com/api/stac/v1",
    modifier=planetary_computer.sign_inplace,
)

In [None]:
def _chip_slice(cell, sample_size, chip_size):
    """Pixel slice, along one axis, of a chip centred on coarse cell number ``cell``.

    The slice covers the cell's own ``sample_size`` pixels plus
    ``int((chip_size - sample_size) / 2)`` context pixels on each side.
    """
    pad = int((chip_size - sample_size) / 2)
    return slice(cell * sample_size - pad, (cell + 1) * sample_size + pad)


def _mask_outside_cell(chip, stack, x, y, sample_size, epsg):
    """Tag ``chip`` with the EPSG CRS and set every pixel outside coarse cell (y, x) to NaN.

    The cell bounds are read from ``stack``'s coordinate arrays. The y test
    (``<=`` start, ``>`` stop) assumes y decreases with pixel index.
    """
    chip.rio.write_crs(f"epsg:{epsg}", inplace=True)
    inside = (
        (chip.x >= stack.x[x * sample_size])
        & (chip.x < stack.x[(x + 1) * sample_size])
        & (chip.y <= stack.y[y * sample_size])
        & (chip.y > stack.y[(y + 1) * sample_size])
    )
    return chip.where(inside)


def process_chips(s2_stack, lc_stack, epsg, sample_size, chip_size, global_index, metadata_df):
    """Cut paired Sentinel-2 / land-cover chips for one AOI and record their metadata.

    The land-cover stack is split into ``sample_size`` x ``sample_size`` cells and
    reduced with ``unique_class``; every flagged cell (presumably one covered by a
    single class -- confirm against ``unique_class``) yields a ``chip_size`` window
    centred on it. Inside that window only the centre cell keeps data; the padding
    is masked, then filled with -999 (Sentinel-2) or -1 (land cover). Windows for
    which ``missing_values`` reports a problem are skipped. Surviving pairs are
    handed to ``gen_chips``.

    Args:
        s2_stack: Sentinel-2 DataArray with ``x``/``y`` coordinates; computed here.
        lc_stack: Land-cover DataArray, assumed to share ``s2_stack``'s pixel grid
            (the same x/y indices are used on both) -- TODO confirm.
        epsg: EPSG code written onto every chip and into the metadata rows.
        sample_size: Edge length, in pixels, of a coarse cell.
        chip_size: Edge length, in pixels, of an emitted chip.
        global_index: Id given to the next successfully written chip.
        metadata_df: Accumulated metadata. Its five columns are filled, in order,
            with chip id, land-cover value, x coordinate, y coordinate and EPSG.

    Returns:
        ``(global_index, metadata_df)``: the next unused chip id and the metadata
        with one row per chip written by this call, newest first, ahead of the
        rows passed in.
    """
    lc_stack = lc_stack.compute()

    lc_uniqueness = lc_stack.coarsen(x=sample_size,
                                     y=sample_size,
                                     boundary="trim"
                                    ).reduce(unique_class)
    # Never use the outermost ring of cells: their context padding would start
    # before pixel 0, and the (cell + 1) * sample_size bound lookups could run
    # past the end of the coordinate arrays. Assumes dims are ordered (y, x).
    lc_uniqueness[0, :] = False
    lc_uniqueness[-1, :] = False
    lc_uniqueness[:, 0] = False
    lc_uniqueness[:, -1] = False

    ys, xs = np.where(lc_uniqueness)
    print("Loading s2_stack")
    s2_stack = s2_stack.compute()

    new_rows = []
    for index, (y, x) in enumerate(zip(ys, xs)):
        x_window = _chip_slice(x, sample_size, chip_size)
        y_window = _chip_slice(y, sample_size, chip_size)

        s2_masked = _mask_outside_cell(s2_stack.isel(x=x_window, y=y_window),
                                       s2_stack, x, y, sample_size, epsg)
        if missing_values(s2_masked, chip_size, sample_size):
            print(f"Skipping chip at index {index}")
            continue
        s2_array = s2_masked.fillna(-999)
        s2_array = s2_array.rio.write_nodata(-999)
        s2_array = s2_array.astype(np.dtype(np.int16))
        s2_array = s2_array.rename("s2")

        lc_masked = _mask_outside_cell(lc_stack.isel(x=x_window, y=y_window),
                                       lc_stack, x, y, sample_size, epsg)
        if missing_values(lc_masked, chip_size, sample_size):
            print(f"Skipping chip at index {index}")
            continue
        lc_array = lc_masked.fillna(-1)
        lc_array = lc_array.rio.write_nodata(-1)
        lc_array = lc_array.astype(np.dtype(np.int16))
        lc_array = lc_array.rename("lc")

        if gen_chips(s2_array, lc_array, global_index):
            new_rows.append([
                global_index,
                # Averaged over the masked array, i.e. before the -1 fill, so
                # skipna=True drops the padding instead of counting it as -1.
                float(lc_masked.mean(skipna=True)),
                s2_stack.x[x * sample_size].item(),
                s2_stack.y[y * sample_size].item(),
                epsg,
            ])
            global_index += 1

    if not new_rows:
        return global_index, metadata_df

    # Build the frame once instead of one concat per chip. Newest chip first,
    # ahead of the incoming rows: the order a per-chip prepend produces.
    new_df = pd.DataFrame(new_rows[::-1], columns=metadata_df.columns)
    if metadata_df.empty:
        return global_index, new_df
    return global_index, pd.concat([new_df, metadata_df], ignore_index=True)

In [None]:
# Minimum number of Sentinel-2 scenes (pooled over all time ranges) needed to process an AOI.
MIN_S2_SCENES = 4

global_index = 0
metadata_df = pd.DataFrame(columns=["chip_id", "lc", "tlc_x", "tlc_y", "epsg"])
for index, aoi in aoi_gdf.iterrows():
    print(f"\nProcessing AOI at index {index}")

    aoi_bounds = aoi["geometry"].bounds

    # Pool the scenes found in every configured time range into one collection.
    s2_items = pystac.item_collection.ItemCollection([])
    for date_range in config["sentinel_2"]["time_ranges"]:
        s2_items += search_s2_scenes(aoi, date_range, catalog, config)

    if len(s2_items) < MIN_S2_SCENES:
        print(f"Missing Sentinel-2 scenes for AOI {aoi_bounds}")
        continue

    s2_stack = stack_s2_data(s2_items, config)
    if s2_stack is None:
        print(f"Failed to stack Sentinel-2 bands for AOI {aoi_bounds}")
        continue

    # Items may carry "proj:code" (e.g. "EPSG:32633") instead of "proj:epsg";
    # fall back explicitly rather than swallowing every error with a bare except.
    item_props = s2_items[0].properties
    epsg = item_props.get("proj:epsg")
    if epsg is None:
        epsg = int(item_props["proj:code"].split(":")[-1])

    lc_items = search_lc_scene(aoi_bounds, catalog, config)
    if not lc_items:
        print(f"No Land Cover data found for AOI {aoi_bounds}")
        continue

    lc_stack = stack_lc_data(lc_items, s2_stack.rio.crs.to_epsg(), s2_items[0].bbox, config)
    if lc_stack is None:
        # No date range here: the land-cover search is not per-season, and the
        # leftover `date_range` loop variable would only name the last season.
        print(f"Failed to stack Land Cover data for AOI {aoi_bounds}")
        continue

    global_index, metadata_df = process_chips(s2_stack,
                                              lc_stack,
                                              epsg,
                                              config["chips"]["sample_size"],
                                              config["chips"]["chip_size"],
                                              global_index,
                                              metadata_df)


In [None]:
# NOTE(review): saving is disabled, so the chip metadata built above lives only in
# this kernel's memory. The path below is absolute and machine-specific; consider
# a path relative to the project root before re-enabling it.
# metadata_df.to_csv('/home/benchuser/data/metadata_df.csv', index=False)