In [None]:
import os

# config.yml, data/ and src/ are all resolved relative to the directory one level
# above the notebook. Only step up when config.yml is not already reachable, so
# re-running this cell does not keep climbing the directory tree.
if not os.path.isfile("config.yml"):
    os.chdir("../")

In [None]:
import dask.distributed
import pystac_client
import planetary_computer
import numpy as np
import pandas as pd
import rioxarray
import geopandas as gpd
from src.utils import search_s2_scenes, search_lc_scene, stack_s2_data, stack_lc_data, unique_class, missing_values, gen_chips
import pystac
import yaml



In [None]:
# Pipeline settings, read from config.yml in the current working directory.
with open("config.yml", mode="r") as cfg_stream:
    config = yaml.safe_load(cfg_stream)

In [None]:
# # Sentinel-2 settings
# # s2_collection = config["sentinel_2"]["collection"]
# # s2_date_ranges = config["sentinel_2"]["time_ranges"]
# # s2_bands = config["sentinel_2"]["bands"]
# s2_resolution = config["sentinel_2"]["resolution"]
# # cloud_cover_threshold = config["sentinel_2"]["cloud_cover"]  # Max allowed cloud cover

# # Land Cover settings
# # lc_collection = config["land_cover"]["collection"]
# # lc_year = config["land_cover"]["year"]  # Year of LC dataset

# # Chip settings
# # sample_size = config["chips"]["sample_size"]  # Grid size for homogeneity check
# chip_size = config["chips"]["chip_size"]  # Output chip size

# # Output settings

# chip_naming_convention = config["output"]["naming_convention"]

# # Metadata settings
# metadata_file = config["metadata"]["file"]

In [None]:
# Directory configured for generated output; created up front (no error if present).
output_dir = config["output"]["directory"]
os.makedirs(name=output_dir, exist_ok=True)

# Three-month season labels used for indexing.
seasons = "JFM AMJ JAS OND".split()

# AOI polygons to process; "data/aois.geojson" is the alternative input.
aoi_gdf = gpd.read_file("data/urbans.geojson")

In [None]:
from dask.distributed import Client, LocalCluster

# Without this, re-running the cell would start another 8-worker cluster and
# could leave the previous client/cluster running. Close them first.
for _stale_name in ("client", "cluster"):
    _stale = globals().get(_stale_name)
    if _stale is not None:
        _stale.close()
del _stale_name, _stale

# Local Dask cluster: 8 single-threaded worker processes.
cluster = LocalCluster(n_workers=8, threads_per_worker=1)
client = Client(cluster)
print(client.dashboard_link)

In [None]:
# Planetary Computer STAC endpoint. The sign_inplace modifier signs returned
# items so their assets can be opened.
PC_STAC_API = "https://planetarycomputer.microsoft.com/api/stac/v1"
catalog = pystac_client.Client.open(PC_STAC_API, modifier=planetary_computer.sign_inplace)

In [None]:
# For each AOI:
#   1. gather Sentinel-2 items across all configured date ranges and stack them;
#   2. fetch Land Cover and stack it using the S2 EPSG code and the first S2
#      item's bbox.
# Processing stops after the first AOI where both stacks succeed. The cells
# below use that AOI's `s2_items`, `s2_stack` and `lc_stack`.
global_index = 0  # chip counter used by the (currently commented-out) chip export
chip_dict = {}
for index, aoi in aoi_gdf.iterrows():
    print(f"\nProcessing AOI at index {index}")

    # Bounding box of the AOI geometry; used for the LC search and in log messages.
    aoi_bounds = aoi['geometry'].bounds

    # Accumulate Sentinel-2 items from every configured date range.
    s2_items = pystac.item_collection.ItemCollection([])
    for date_range in config["sentinel_2"]["time_ranges"]:
        print(f"Querying for date range: {date_range}")

        s2_items_season = search_s2_scenes(aoi, date_range, catalog, config)
        if not s2_items_season:
            print(f"No Sentinel-2 scenes found for AOI {aoi_bounds} and date range {date_range}")
            continue

        s2_items += s2_items_season

    # With no scenes at all there is nothing to stack, and s2_items[0] below
    # would raise IndexError.
    if len(s2_items) == 0:
        print(f"No Sentinel-2 scenes found for AOI {aoi_bounds} in any date range")
        continue

    # The failure messages below do not mention `date_range`: after the inner
    # loop it only holds the last range queried, not the one responsible.
    s2_stack = stack_s2_data(s2_items, config)
    if s2_stack is None:
        print(f"Failed to stack Sentinel-2 bands for AOI {aoi_bounds}")
        continue

    lc_items = search_lc_scene(aoi_bounds, catalog, config)
    if not lc_items:
        print(f"No Land Cover data found for AOI {aoi_bounds}")
        continue

    # Land Cover is requested with the S2 stack's EPSG code and the first S2 item's bbox.
    lc_stack = stack_lc_data(lc_items, s2_stack.rio.crs.to_epsg(), s2_items[0].bbox, config)
    if lc_stack is None:
        print(f"Failed to stack Land Cover data for AOI {aoi_bounds}")
        continue
    print(f"Land Cover stack shape: {lc_stack.shape}")
    # Only the first fully stacked AOI is processed for now.
    break
else:
    # No AOI produced both stacks. Stop here rather than letting later cells fail
    # on undefined names, or silently reuse stale stacks from an earlier run.
    raise RuntimeError("No AOI yielded both a Sentinel-2 and a Land Cover stack")

    #     # Process chips after confirming the data is available
    #     global_index = process_chips(aoi, s2_stack, lc_stack, output_dir, global_index, chip_dict, sample_size)
    # print("stopping after 1st AOI")


In [None]:
# Block-reduce the Land Cover stack over sample_size x sample_size windows.
# Edge windows that do not fit are trimmed, and the result is computed eagerly.
# NOTE(review): unique_class presumably flags windows containing a single
# class; it is used as a boolean mask below. Confirm in src.utils.
window_px = config["chips"]["sample_size"]
lc_uniqueness = (
    lc_stack
    .coarsen(x=window_px, y=window_px, boundary="trim")
    .reduce(unique_class)
    .compute()
)


In [None]:
# Exclude the outermost ring of coarsened windows from chip candidates.
# NOTE(review): presumably so the padded chip cut around a window, and the
# window-(i-1) lookup in the chip loop, stay inside the raster. Confirm.
for _edge in ((0, slice(None)), (-1, slice(None)), (slice(None), 0), (slice(None), -1)):
    lc_uniqueness[_edge] = False

In [None]:
# (first-axis, second-axis) indices of every window still flagged True.
rows, cols = np.asarray(lc_uniqueness).nonzero()

In [None]:
# Cut a Sentinel-2 / Land Cover chip pair around each candidate window. Stops
# at the first chip whose S2 data has no negative values, leaving `index`,
# `s2_array` and `lc_array` for the export cell below.
sample_px = config["chips"]["sample_size"]
# Extra pixels added on each side of the sample window to reach chip_size.
pad_px = int((config["chips"]["chip_size"] - sample_px) / 2)
# Single quotes inside the f-string keep this line valid before Python 3.12 (PEP 701).
# NOTE(review): assumes "proj:code" looks like "EPSG:<code>".
chip_crs = f"epsg:{int(s2_items[0].properties['proj:code'].split(':')[-1])}"

for index, (i, j) in enumerate(zip(rows, cols)):
    # NOTE(review): i comes from the FIRST axis of lc_uniqueness but indexes x,
    # and j from the second axis but indexes y. If lc_stack is laid out (y, x),
    # these are swapped. Confirm the dim order produced by stack_lc_data.
    x_coords = slice(i * sample_px - pad_px, (i + 1) * sample_px + pad_px)
    y_coords = slice(j * sample_px - pad_px, (j + 1) * sample_px + pad_px)

    s2_array = s2_stack.isel(x=x_coords, y=y_coords)
    s2_array.rio.write_crs(chip_crs, inplace=True)
    # NOTE(review): this mask keeps only pixels of window i-1 / j-1, while the
    # slice above is centred on window i / j. When chip_size >= sample_size,
    # some chip pixels always fall outside the mask, become -999, and the chip
    # is then rejected below. Verify the intended bounds.
    s2_array = s2_array.where((s2_array.x >= s2_stack.x[(i - 1) * sample_px]) &
                              (s2_array.x < s2_stack.x[i * sample_px]) &
                              (s2_array.y <= s2_stack.y[(j - 1) * sample_px]) &
                              (s2_array.y > s2_stack.y[j * sample_px]))
    s2_array = (
        s2_array.fillna(-999)
        .rio.write_nodata(-999)
        .astype(np.int16)
        .rename("s2")
        .compute()
    )

    # Reject the chip if any S2 value is negative; this includes every -999 filled above.
    if (s2_array < 0).any():
        continue

    lc_array = lc_stack.isel(x=x_coords, y=y_coords)
    lc_array.rio.write_crs(chip_crs, inplace=True)
    lc_array = lc_array.where((lc_array.x >= lc_stack.x[(i - 1) * sample_px]) &
                              (lc_array.x < lc_stack.x[i * sample_px]) &
                              (lc_array.y <= lc_stack.y[(j - 1) * sample_px]) &
                              (lc_array.y > lc_stack.y[j * sample_px]))
    lc_array = (
        lc_array.fillna(-99)
        .rio.write_nodata(-99)
        .astype(np.int8)
        .rename("lc")
    )
    # Debug mode: keep only the first accepted chip.
    break
else:
    # Without this, the export cell would write a rejected S2 chip and then fail
    # on an undefined (or stale) lc_array.
    raise RuntimeError("No candidate chip passed the Sentinel-2 nodata check")
#     gen_status = gen_chips(s2_array, lc_array, global_index)
#     if gen_status:
#         metadata_df = pd.concat([pd.DataFrame([[global_index,
#                                                 lc_array.mean(skipna=True).compute().data,
#                                                 s2_stack.x[i].data,
#                                                 s2_stack.y[j].data,
#                                                 s2_items[0].properties["proj:epsg"]]
#                                               ],
#                                               columns=metadata_df.columns
#                                              ),
#                                  metadata_df],
#                                 ignore_index=True
#                                )
#                     global_index += 1
# metadata_df.to_csv('/home/benchuser/data/metadata_df.csv', index=False)

In [None]:
# Write the selected chip pair into the configured output directory (created
# at the top of the notebook) rather than a hard-coded absolute path:
#   - one Sentinel-2 GeoTIFF per time step, named with the acquisition date;
#   - one Land Cover GeoTIFF.
# All files are keyed by the candidate index `index`.
lc_path = os.path.join(output_dir, f"lc_{index:05}.tif")
for dt in s2_array.time.values:
    ts = pd.Timestamp(dt)
    s2_path = os.path.join(output_dir, f"s2_{index:05}_{ts.strftime('%Y%m%d')}.tif")
    s2_array.sel(time=dt).squeeze().rio.to_raster(s2_path)

lc_array.rio.to_raster(lc_path)
gen_status = True

In [None]:
# if len(s2_items) == 0:
#     continue
# metadata_df = pd.DataFrame(columns=["chip_id", "lc", "tlc_x", "tlc_y", "epsg"])