In [None]:
import os

# Move the working directory up one level; the relative paths used further
# down ("config.yml", "data/...") and the `src` import are presumably meant
# to resolve from there.
os.chdir(os.pardir)

In [None]:
import pystac_client
import pystac
from requests.adapters import HTTPAdapter
from urllib3 import Retry
from pystac_client.stac_api_io import StacApiIO
import planetary_computer

import dask.distributed
import numpy as np
import rioxarray
import pandas as pd
import geopandas as gpd
from src.utils import search_s2_scenes, search_lc_scene, stack_s2_data, stack_lc_data, unique_class, missing_values, gen_chips
import yaml

In [None]:
import warnings
warnings.filterwarnings("ignore")

In [None]:
# Parse the YAML settings file into `config`; later cells read the
# "sentinel_2" and "chips" sections from it.
with open("config.yml") as config_file:
    config = yaml.safe_load(config_file)

In [None]:
# Load the areas of interest (one row per AOI) from the GeoJSON map file.
aoi_path = "data/map_v0.11.geojson"
aoi_gdf = gpd.read_file(aoi_path)

In [None]:
# The AOIs with these index labels have broken scenes in the STAC catalog,
# so they are dropped before any processing.
broken_aoi_labels = [12, 25, 46, 60, 81, 153]
aoi_gdf = aoi_gdf[~aoi_gdf.index.isin(broken_aoi_labels)]

In [None]:
from dask.distributed import Client, LocalCluster
# Start a local Dask cluster with its default sizing and attach a client to it.
# To pin the worker layout, pass explicit arguments instead, e.g.
# LocalCluster(n_workers=8, threads_per_worker=2).
cluster = LocalCluster()
client = Client(cluster)
# Show the dashboard URL so progress can be monitored while the notebook runs.
print(client.dashboard_link)

In [None]:
# Retry policy for the STAC API: up to 10 attempts with backoff whenever the
# server answers 502/503/504. allowed_methods=None removes the HTTP-method
# restriction, so every request type is retried.
retry = Retry(
    total=10,
    backoff_factor=1,
    status_forcelist=[502, 503, 504],
    allowed_methods=None,
)
stac_api_io = StacApiIO(max_retries=retry)

# Open the Planetary Computer STAC catalog; `sign_inplace` is applied to the
# results so that asset links come back signed.
stac_api_url = "https://planetarycomputer.microsoft.com/api/stac/v1"
catalog = pystac_client.Client.open(
    stac_api_url,
    stac_io=stac_api_io,
    modifier=planetary_computer.sign_inplace,
)

In [None]:
def _center_window(stack, epsg, x_coords, y_coords, x_start, x_stop, y_start, y_stop):
    """Slice a padded window out of ``stack`` and blank everything outside its central sample."""
    window = stack.isel(x=x_coords, y=y_coords)
    window.rio.write_crs(f"epsg:{epsg}", inplace=True)
    # Keep only pixels whose coordinates fall inside the central sample; the
    # padding around it becomes NaN. The y comparisons are written for a
    # descending y coordinate (start >= stop).
    return window.where(
        (window.x >= stack.x[x_start])
        & (window.x < stack.x[x_stop])
        & (window.y <= stack.y[y_start])
        & (window.y > stack.y[y_stop])
    )


def process_chips(s2_stack, lc_stack, epsg, sample_size, chip_size, global_index, metadata_df):
    """Cut Sentinel-2 / land-cover chips for one AOI and record them in the metadata table.

    The land-cover stack is coarsened into ``sample_size`` x ``sample_size``
    cells and reduced with ``unique_class``; every cell flagged by that
    reduction (except those on the outer border) becomes a chip candidate.
    For each candidate a ``chip_size`` window centred on the cell is taken
    from both stacks, with pixels outside the central cell masked out.
    Candidates are skipped when ``missing_values`` reports a problem, and a
    chip is only handed to ``gen_chips`` when its land-cover window contains
    class 7.

    Parameters
    ----------
    s2_stack, lc_stack : lazily evaluated arrays exposing ``.compute()``,
        ``x``/``y`` coordinates and the ``rio`` accessor (assumed to share the
        same grid - TODO confirm against ``stack_s2_data`` / ``stack_lc_data``).
    epsg : EPSG code written onto every chip.
    sample_size : edge length, in pixels, of the labelled central cell.
    chip_size : edge length, in pixels, of the full chip window.
    global_index : id assigned to the next generated chip.
    metadata_df : table with one row per generated chip.

    Returns
    -------
    tuple
        ``(global_index, metadata_df)`` - the next free chip id and the
        metadata table with the new rows prepended (newest first). Both are
        returned unchanged when either stack fails to load.

    Raises
    ------
    ValueError
        If a land-cover window contains one of the values 255, 130 or 133.
    """
    try:
        lc_stack = lc_stack.compute()
    except Exception:
        # Best effort: an AOI whose land-cover data cannot be loaded is skipped.
        # `Exception` (rather than a bare except) lets Ctrl-C / SystemExit
        # still stop a long run.
        print("skipping the AOI for no LC data")
        return global_index, metadata_df

    lc_uniqueness = lc_stack.coarsen(
        x=sample_size, y=sample_size, boundary="trim"
    ).reduce(unique_class)
    # Border cells are never used: their padded window would reach outside
    # the stack.
    lc_uniqueness[0, :] = False
    lc_uniqueness[-1, :] = False
    lc_uniqueness[:, 0] = False
    lc_uniqueness[:, -1] = False

    ys, xs = np.where(lc_uniqueness)
    print("Loading s2_stack")

    try:
        s2_stack = s2_stack.compute()
    except Exception:
        print("skipping the AOI for no S2 data")
        return global_index, metadata_df

    # Loop-invariant geometry: padding on each side of the central cell and
    # the offset from a cell's first pixel to its centre pixel.
    pad = int((chip_size - sample_size) / 2)
    half_sample = int(sample_size / 2)

    # Rows are gathered here and merged into the table once at the end,
    # instead of re-copying the whole table for every chip.
    new_rows = []

    for y, x in zip(ys, xs):
        x_start = x * sample_size
        x_stop = (x + 1) * sample_size
        y_start = y * sample_size
        y_stop = (y + 1) * sample_size

        x_coords = slice(x_start - pad, x_stop + pad)
        y_coords = slice(y_start - pad, y_stop + pad)

        s2_array = _center_window(
            s2_stack, epsg, x_coords, y_coords, x_start, x_stop, y_start, y_stop
        )
        if missing_values(s2_array, chip_size, sample_size):
            continue

        # Masked padding becomes the int16 nodata value -999.
        s2_array = s2_array.fillna(-999)
        s2_array = s2_array.rio.write_nodata(-999)
        s2_array = s2_array.astype(np.dtype(np.int16))
        s2_array = s2_array.rename("s2")

        lc_array = _center_window(
            lc_stack, epsg, x_coords, y_coords, x_start, x_stop, y_start, y_stop
        )
        if missing_values(lc_array, chip_size, sample_size):
            continue

        if (np.isin(lc_array, [255, 130, 133])).any():
            raise ValueError('Wrong LC value')

        # Masked padding becomes the int8 nodata value 0.
        lc_array = lc_array.fillna(0)
        lc_array = lc_array.rio.write_nodata(0)
        lc_array = lc_array.astype(np.dtype(np.int8))
        lc_array = lc_array.rename("lc")

        # Only chips that contain land-cover class 7 are generated.
        if (np.isin(lc_array, [7])).any():
            gen_status, dts = gen_chips(s2_array, lc_array, global_index)
            if gen_status:
                new_rows.append([
                    global_index,
                    dts,
                    # Second-smallest distinct value, i.e. the first one after
                    # the 0 used for the masked padding.
                    np.unique(lc_array)[1],
                    s2_stack.x[x_start + half_sample].data,
                    s2_stack.y[y_start + half_sample].data,
                    epsg,
                ])
                global_index += 1

    if new_rows:
        # Newest chip first, followed by the rows that were already present.
        new_rows.reverse()
        metadata_df = pd.concat(
            [pd.DataFrame(new_rows, columns=metadata_df.columns), metadata_df],
            ignore_index=True,
        )

    return global_index, metadata_df

In [None]:
# Fresh run: chip ids start at zero and the metadata table starts empty.
metadata_columns = ["chip_id", "dates", "lc", "x_center", "y_center", "epsg"]
global_index = 0
metadata_df = pd.DataFrame(columns=metadata_columns)
# To continue from a previous iteration after the code stopped, load the
# saved table instead:
# metadata_df = pd.read_csv("../data/metadata_df.csv")

In [None]:
# Positional offset (row count, not index label) of the first AOI to process.
# NOTE(review): hardcoded resume point - set to 0 for a full run.
first_aoi_position = 219

for index, aoi in aoi_gdf[first_aoi_position:].iterrows():
    print(f"\nProcessing AOI at index {index}")

    aoi_bounds = aoi['geometry'].bounds

    # Gather the Sentinel-2 scenes of every configured time range.
    s2_items = pystac.item_collection.ItemCollection([])
    for date_range in config["sentinel_2"]["time_ranges"]:
        s2_items_season = search_s2_scenes(aoi, date_range, catalog, config)
        s2_items += s2_items_season

    # Fewer than four scenes in total means the AOI cannot be used
    # (presumably one scene per time range is expected - TODO confirm).
    if len(s2_items) < 4:
        print(f"Missing Sentinel-2 scenes for AOI {aoi_bounds}")
        continue

    s2_stack = stack_s2_data(s2_items, config)
    if s2_stack is None:
        print(f"Failed to stack Sentinel-2 bands for AOI {aoi_bounds}")
        continue

    # Items carry the projection either as "proj:epsg" or as a "proj:code"
    # string whose last ":"-separated field is the numeric code. Only a
    # missing key triggers the fallback; any other error surfaces.
    first_item_properties = s2_items[0].properties
    try:
        epsg = first_item_properties["proj:epsg"]
    except KeyError:
        epsg = int(first_item_properties["proj:code"].split(":")[-1])

    lc_items = search_lc_scene(s2_items[0].bbox, catalog, config)
    if not lc_items:
        print(f"No Land Cover data found for AOI {aoi_bounds}")
        continue

    lc_stack = stack_lc_data(lc_items, s2_stack.rio.crs.to_epsg(), s2_items[0].bbox, config)
    if lc_stack is None:
        # The land-cover search is not tied to a date range, so none is reported.
        print(f"Failed to stack Land Cover data for AOI {aoi_bounds}")
        continue

    global_index, metadata_df = process_chips(s2_stack,
                                              lc_stack,
                                              epsg,
                                              config["chips"]["sample_size"],
                                              config["chips"]["chip_size"],
                                              global_index,
                                              metadata_df)
    # Checkpoint after every AOI so an interrupted run keeps its metadata.
    # NOTE(review): absolute, machine-specific path - consider a configurable
    # data directory.
    metadata_df.to_csv('/home/benchuser/data/metadata_df.csv', index=False)