In [3]:
import os

# Later cells import `src.utils` and read `data/...` with paths relative to the
# working directory, so the kernel has to run from the directory that contains
# `src`. Only step up when `src` is not already visible: an unconditional
# relative chdir is not idempotent, and re-running this cell would otherwise
# climb one directory further each time.
if not os.path.isdir("src"):
    os.chdir("../")

In [4]:
import dask.distributed
import pystac_client
import planetary_computer
import stackstac 
import numpy as np
import pandas as pd
import rioxarray
import geopandas as gpd
from src.utils import gen_chips

# STAC endpoint queried by every search in this notebook.
stac_api_url = "https://planetarycomputer.microsoft.com/api/stac/v1"

# Open the catalog, passing `planetary_computer.sign_inplace` as the modifier
# applied to the results the client returns.
catalog = pystac_client.Client.open(
    stac_api_url, modifier=planetary_computer.sign_inplace
)

In [5]:
from dask.distributed import Client, LocalCluster

# Local Dask cluster layout: eight workers with one thread each.
worker_layout = {"n_workers": 8, "threads_per_worker": 1}
cluster = LocalCluster(**worker_layout)

# Attach a client to the cluster and show where its dashboard is served.
client = Client(cluster)
print(client.dashboard_link)

http://127.0.0.1:8787/status


Task exception was never retrieved
future: <Task finished name='Task-16518' coro=<Client._gather.<locals>.wait() done, defined at /opt/conda/envs/gfm_bench/lib/python3.12/site-packages/distributed/client.py:2385> exception=AllExit()>
Traceback (most recent call last):
  File "/opt/conda/envs/gfm_bench/lib/python3.12/site-packages/distributed/client.py", line 2394, in wait
    raise AllExit()
distributed.client.AllExit
Task exception was never retrieved
future: <Task finished name='Task-19170' coro=<Client._gather.<locals>.wait() done, defined at /opt/conda/envs/gfm_bench/lib/python3.12/site-packages/distributed/client.py:2385> exception=AllExit()>
Traceback (most recent call last):
  File "/opt/conda/envs/gfm_bench/lib/python3.12/site-packages/distributed/client.py", line 2394, in wait
    raise AllExit()
distributed.client.AllExit


In [6]:
# Areas of interest to sample chips from; point this at "data/aois.geojson"
# to use the alternative AOI file instead.
aoi_path = "data/urbans.geojson"
aoi_gdf = gpd.read_file(aoi_path)

In [8]:
# Asset keys requested from each Sentinel-2 STAC item when stacking.
s2_assets = ["B02", "B03", "B04", "B08", "B11", "B12"]

# `sample_size` is the side (in pixels) of the window whose land cover is
# checked, and also the dask chunk size; `chip_size` is the side of the
# exported chip, i.e. that window plus an equal border on every side.
chip_size, sample_size = 224, 100

# One row per generated chip is accumulated in this (initially empty) table.
metadata_columns = ["chip_id", "lc", "tlc_x", "tlc_y", "epsg"]
metadata_df = pd.DataFrame(columns=metadata_columns)

In [10]:
# Chip generation.
#
# For every AOI: take the least cloudy Sentinel-2 L2A scene, stack it together
# with the 2023 io-lulc-annual-v02 land-cover tiles on the same grid, then slide
# a `sample_size` x `sample_size` window over the scene. Wherever the window
# holds a single land-cover value equal to 7, cut a chip of side
# `sample_size + 2 * pad` around it, mask everything outside the central window
# with nodata, and hand the S2 / LC pair to `gen_chips`. One metadata row is
# recorded for every chip `gen_chips` reports as written.
global_index = 0

# Border (in pixels) added on each side of the window to reach `chip_size`.
pad = (chip_size - sample_size) // 2
# Margin left free at the far edge of the scene when placing windows.
# NOTE(review): the chip extends `sample_size + pad` pixels past the window
# origin, which is larger than this margin, so chips next to the far edge are
# truncated by `isel`. Presumably `gen_chips` rejects those - confirm.
edge_margin = chip_size // 2 + 5
# Land-cover value a window must be uniformly filled with to be sampled
# (presumably the built-area class of the io-lulc legend - confirm).
target_lc_class = 7

for _, aoi in aoi_gdf.iterrows():
    # Single least-cloudy scene intersecting the AOI.
    s2_search = catalog.search(
        collections=["sentinel-2-l2a"],
        bbox=aoi["geometry"].bounds,
        datetime="2023-02-01/2023-08-30",
        query=["eo:cloud_cover<1"],
        sortby=["+properties.eo:cloud_cover"],
        max_items=1,
    )
    s2_items = s2_search.item_collection()
    if len(s2_items) == 0:
        continue

    scene = s2_items[0]
    # Hoisted so the CRS string is built once; this also avoids nesting the same
    # quote character inside an f-string, which only parses on Python >= 3.12.
    epsg = scene.properties["proj:epsg"]
    crs = f"epsg:{epsg}"

    s2_stack = stackstac.stack(
        s2_items,
        assets=s2_assets,
        epsg=epsg,
        resolution=10,
        bounds_latlon=scene.bbox,
    )
    # Collapse the time axis and drop any length-1 dimensions.
    s2_stack_resampled = s2_stack.median("time", skipna=True).squeeze()
    s2_stack_resampled = s2_stack_resampled.chunk(
        chunks={"band": 6, "x": sample_size, "y": sample_size}
    )

    lc_search = catalog.search(
        collections=["io-lulc-annual-v02"],
        bbox=scene.bbox,
        datetime="2023-01-02/2023-12-30",  # This only returns 2023 tiles
    )
    lc_items = lc_search.item_collection()
    if len(lc_items) == 0:
        continue

    # Land cover is stacked with the same EPSG, resolution and bounds as the S2
    # scene so that integer pixel offsets address the same location in both.
    lc_stack = stackstac.stack(
        lc_items,
        dtype=np.ubyte,
        fill_value=255,
        sortby_date=False,
        epsg=epsg,
        resolution=10,
        bounds_latlon=scene.bbox,
    ).squeeze()
    lc_stack = lc_stack.chunk(chunks={"x": sample_size, "y": sample_size})

    # `i` indexes the x dimension and `j` the y dimension everywhere below, so
    # each loop is bounded by the size of the dimension it actually indexes.
    # Looking the sizes up by name rather than by positional `shape[...]` keeps
    # the bounds right for non-square scenes.
    n_x = s2_stack_resampled.sizes["x"]
    n_y = s2_stack_resampled.sizes["y"]

    # Rows for the chips written from this AOI, in creation order.
    new_rows = []
    try:
        for i in range(3 * sample_size, n_x - edge_margin, sample_size):
            for j in range(3 * sample_size, n_y - edge_margin, sample_size):
                lc_window = lc_stack.isel(
                    x=slice(i, i + sample_size), y=slice(j, j + sample_size)
                )
                # Materialise the window once; every conversion of the lazy
                # array to NumPy triggers a separate dask computation.
                lc_classes = np.unique(lc_window.values)
                if lc_classes.shape[0] != 1 or lc_classes[0] != target_lc_class:
                    continue

                # Chip extent: the window plus `pad` pixels on every side.
                x_coords = slice(i - pad, i + sample_size + pad)
                y_coords = slice(j - pad, j + sample_size + pad)

                s2_array = s2_stack_resampled.isel(x=x_coords, y=y_coords)
                s2_array.rio.write_crs(crs, inplace=True)
                # Keep only the central window; the border becomes NaN and is
                # then filled with the nodata value. The y comparisons run the
                # other way round from the x ones (upper bound at index j).
                s2_array = s2_array.where(
                    (s2_array.x >= s2_stack_resampled.x[i])
                    & (s2_array.x < s2_stack_resampled.x[i + sample_size])
                    & (s2_array.y <= s2_stack_resampled.y[j])
                    & (s2_array.y > s2_stack_resampled.y[j + sample_size])
                )
                s2_array = s2_array.fillna(-999)
                s2_array = s2_array.rio.write_nodata(-999)
                s2_array = s2_array.astype(np.dtype(np.int16))
                s2_array = s2_array.rename("s2")

                lc_array = lc_stack.isel(x=x_coords, y=y_coords)
                lc_array.rio.write_crs(crs, inplace=True)
                lc_array = lc_array.where(
                    (lc_array.x >= lc_stack.x[i])
                    & (lc_array.x < lc_stack.x[i + sample_size])
                    & (lc_array.y <= lc_stack.y[j])
                    & (lc_array.y > lc_stack.y[j + sample_size])
                )
                lc_array = lc_array.fillna(-99)
                lc_array = lc_array.rio.write_nodata(-99)
                lc_array = lc_array.astype(np.dtype(np.int8))
                lc_array = lc_array.rename("lc")

                gen_status = gen_chips(s2_array, lc_array, global_index)
                if gen_status:
                    # NOTE(review): this mean is taken after the border was
                    # filled with -99, so it mixes nodata into the "lc" value -
                    # confirm whether the class of the central window was meant.
                    new_rows.append(
                        [
                            global_index,
                            lc_array.mean(skipna=True).compute().data,
                            s2_stack_resampled.x[i].data,
                            s2_stack_resampled.y[j].data,
                            epsg,
                        ]
                    )
                    global_index += 1
    finally:
        # Runs on normal completion and on interruption alike, so the CSV always
        # lists every chip already handed to `gen_chips`. One concat per AOI
        # replaces one per chip; the rows are reversed because the table is kept
        # newest-first.
        if new_rows:
            metadata_df = pd.concat(
                [
                    pd.DataFrame(new_rows[::-1], columns=metadata_df.columns),
                    metadata_df,
                ],
                ignore_index=True,
            )
        metadata_df.to_csv('/home/benchuser/data/metadata_df.csv', index=False)

  times = pd.to_datetime(
  times = pd.to_datetime(


KeyboardInterrupt: 