# README for DEBUGGING

Note: By default, solution 1 below is applied in this branch.

### For the implementation with errors
change the following in `stacs.py`:
- In `fetch_stac_items()` - use unmodified image shapes for point matching on line 222 instead of line 224
- In `CustomDataSet` - use the original stackstac implementation beginning at line 406 (searches in the output image CRS)

### For solution 1 
Only assigning images that cover our point with both their shape and `proj:bbox` shape:
- In `fetch_stac_items()` - use `_get_trimmed_stac_shapes_gdf()` on line 224 instead of line 222

### For solution 2
Using stackstac in the lat/lon (EPSG:4326) CRS:
- Revert `fetch_stac_items()` to use image shapes directly
- In `CustomDataSet` - use the new stackstac implementation beginning at line 423 (searches in the lat/lon EPSG:4326 CRS)


Note: The resolution parameter in stackstac becomes a bit complicated in this solution.

# Notebook Setup

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
# get current working directory
import os

os.getcwd()


In [None]:
# change working directory to the root of the project
os.chdir("/home/jovyan/ds_nudge_up/")


In [None]:
import sys
import warnings

# Make modules that live one directory above the working directory importable.
sys.path.extend(["../"])

# Silence every warning for the rest of the session.
warnings.filterwarnings(action="ignore")


In [None]:
import os
import src.mosaiks.utils as utl

# Load a YAML config and export every entry as a process environment variable
# (presumably GDAL/rasterio options, going by the name — TODO confirm).
# NOTE(review): `os.environ.update` only accepts string values, so the YAML
# must not contain bare numbers or booleans.
# NOTE(review): "rasterioc_config.yaml" may be a typo for
# "rasterio_config.yaml" — confirm against the config folder.
rasterio_config = utl.load_yaml_config("rasterioc_config.yaml")
os.environ.update(rasterio_config)


#### `pip install` MOSAIKS

From local folder:

In [None]:
!pip install -e .

From GitHub package:

🚨🚨 **Make sure you update github token in the secrets file** 🚨🚨 

In [None]:
# secrets = utl.load_yaml_config("../config/secrets.yml")
# GITHUB_TOKEN = secrets["GITHUB_TOKEN"]
# mosaiks_package_link = f"git+https://{GITHUB_TOKEN}@github.com/IDinsight/ds_nudge_up@as-package"


In [None]:
# !pip uninstall mosaiks -y
# !pip install {mosaiks_package_link} --upgrade


# Setup Dask Cluster and Client

## Local Cluster

4 workers with 4 threads each seems to work best. Much of the time a thread is waiting for data to load, so the CPU is underutilized.

In [None]:
import logging
from dask.distributed import Client, LocalCluster

# Local Dask cluster: 4 worker processes with 4 threads each; worker logs are
# limited to ERROR level.
cluster = LocalCluster(
    n_workers=4,
    threads_per_worker=4,
    processes=True,
    silence_logs=logging.ERROR,
)

# Client connected to that cluster; the bare name renders it as the cell output.
client = Client(cluster)
client


## Gateway cluster

In [None]:
# from dask_gateway import Gateway
# import dask_gateway
# from dask.distributed import PipInstall

# gateway = Gateway()
# options = gateway.cluster_options()
# options


In [None]:
# from dask.distributed import PipInstall

# cluster = gateway.new_cluster(options)
# client = cluster.get_client()
# print(cluster.dashboard_link)

# plugin = PipInstall(packages=[mosaiks_package_link], pip_options=["--upgrade"], restart=False)
# client.register_worker_plugin(plugin)

# cluster.scale(10)


In [None]:
# cluster.shutdown()


# Load params

In [None]:
from mosaiks.featurize import *

from dask import delayed
from dask.distributed import as_completed
from time import sleep
import pandas as pd
import numpy as np


In [None]:
# Featurisation settings; this also holds the satellite search parameters.
featurization_params = utl.load_yaml_config("featurisation.yaml")
# Load the satellite config (indexed by satellite name below), then keep only
# the entry for the satellite named in the search parameters.
satellite_config = utl.load_yaml_config("satellite_config.yaml")
satellite_config = satellite_config[
    featurization_params["satellite_search_params"]["satellite_name"]
]
# Data catalogue; its "request_points_centroids" entry is used to load the points.
data_sources = utl.load_yaml_config("data_catalog.yaml")


# Load point coords

In [None]:
# Request points, loaded with the keyword arguments stored under
# "request_points_centroids" in the data catalogue.
points_gdf = utl.load_points_gdf(**data_sources["request_points_centroids"])

# Values of the "pc11_s_id" column -> state name, for the states of interest.
focus_states_id_dict = {
    20: "jharkhand", 22: "chhattisgarh", 8: "rajasthan",
    23: "madhya pradesh", 18: "assam", 16: "tripura",
}

# Boolean mask: True for rows whose "pc11_s_id" is one of the focus-state ids.
focus_states_filter = points_gdf["pc11_s_id"].isin(list(focus_states_id_dict))

In [None]:
points_gdf_focus = points_gdf[focus_states_filter]
points_gdf_focus.shape

In [None]:
points_gdf_focus["shrid"].drop_duplicates(keep="first").shape

In [None]:
# temp = points_gdf_focus.sample(300)

# Fetch image stac refs

`fetch_image_refs` now returns a Dask dataframe that has not been computed yet, so this cell finishes quite quickly.

In [None]:
%%time
# Attach STAC image references to the focus-state points, passing the
# configured number of Dask partitions and the satellite search parameters.
# The result is later split into per-partition delayed objects via `.to_delayed()`.
points_gdf_with_stac = fetch_image_refs(
    points_gdf_focus, 
    featurization_params['dask']['n_partitions'],
    featurization_params['satellite_search_params']
)

# Define delayed objects

We use the `delayed` decorator to turn our function into a delayed function. This means it does not run immediately when called; instead it returns a delayed object that can be run later.

In [None]:
@delayed
def partition_run(df, satellite_config, featurization_params, model, device):
    """Featurise one partition of points and return the result as a DataFrame.

    Because the function is wrapped with ``delayed``, calling it only builds a
    lazy task; the body runs when that task is computed.

    Parameters
    ----------
    df
        The partition to featurise. Its length is passed to
        ``create_features`` and its index is reused for the output.
    satellite_config
        Mapping that must provide ``"min_image_edge"``; it is also handed to
        ``create_data_loader`` unchanged.
    featurization_params
        Mapping that must provide ``"batch_size"`` and ``"num_features"``.
    model
        Model object forwarded to ``create_features``.
    device
        Device identifier forwarded to ``create_features``.

    Returns
    -------
    pandas.DataFrame
        Built from the output of ``create_features``, indexed by a copy of
        ``df.index``.
    """
    batch_size = featurization_params["batch_size"]
    loader = create_data_loader(df, satellite_config, batch_size)

    n_features = featurization_params["num_features"]
    n_points = len(df)
    min_edge = satellite_config["min_image_edge"]
    features = create_features(
        loader, n_features, n_points, model, device, min_edge
    )

    # Index by a copy so the result does not share its index object with `df`.
    return pd.DataFrame(features, index=df.index.copy())

We want to convert our Dask dataframe into "delayed" objects. Each partition is now a delayed pandas dataframe and can be passed to our delayed function above.

In [None]:
partitions = points_gdf_with_stac.to_delayed()

In [None]:
# Build the RCF model from the configured number of features and kernel size.
# The third argument is the number of configured satellite bands (presumably
# the model's input channel count — TODO confirm against `RCF`).
model = RCF(
    featurization_params["num_features"],
    featurization_params["kernel_size"],
    len(satellite_config["bands"]),
)

# Diagnostics

In [None]:
import shapely.geometry
import geopandas as gpd
import pyproj
import stackstac

### Run for the problematic partition

In [None]:
p_44 = partitions[44].compute()

In [None]:
# use loop below or run for selective indices using...
# i = 500015 # 500045 fails, 500015 works
# row = p_44.loc[i]

# Diagnostic pass over every row of partition 44. For each point it:
#   * builds a square crop around the point and requests it through stackstac,
#   * records the size of the returned array's first (time) dimension,
#   * records whether the crop overlaps (a) the STAC item geometry and
#     (b) the polygon built from the item's "proj:bbox" property.
# Results are written back into p_44 as the columns "xarray_time_dim",
# "intersects_bbox" and "intersects_geometry".

# Only filled by the commented-out reporting code that follows this loop.
failing_IDs = []
for i, row in p_44.iterrows():

    stac_item = row["stac_item"]
    lat = row["Lat"]
    lon = row["Lon"]
    # Half-width of the crop square, in units of the image CRS
    # (metres, assuming a projected/UTM CRS — TODO confirm).
    buffer = 1200

    # Convert the point from lat/lon (EPSG:4326) to the image CRS and build the
    # buffer square there. always_xy=True means coordinates are (lon, lat) / (x, y).
    stac_crs = stac_item.properties["proj:epsg"]
    proj_latlon_to_stac = pyproj.Transformer.from_crs(4326, stac_crs, always_xy=True)
    x_utm, y_utm = proj_latlon_to_stac.transform(lon, lat)

    x_min, x_max = x_utm - buffer, x_utm + buffer
    y_min, y_max = y_utm - buffer, y_utm + buffer

    # Convert the buffer bounds back to lat/lon. Only the two opposite corners
    # are transformed; from here on x_* are longitudes and y_* are latitudes.
    proj_stac_to_latlon = pyproj.Transformer.from_crs(stac_crs, 4326, always_xy=True)
    x_min, y_min = proj_stac_to_latlon.transform(x_min, y_min)
    x_max, y_max = proj_stac_to_latlon.transform(x_max, y_max)

    # Request the crop in EPSG:4326, so `resolution` is in degrees
    # (0.00027 deg is roughly 30 m at the equator — presumably chosen to match
    # the sensor's pixel size; TODO confirm).
    xarray = stackstac.stack(
        stac_item,
        assets=satellite_config["bands"],
        epsg=4326,
        resolution=0.00027,  # satellite_config["resolution"],
        bounds_latlon=[x_min, y_min, x_max, y_max],
        # rescale=False,
        dtype=np.uint8,
        fill_value=0,
        # snap_bounds=False
    )

    # if time dimension is 0 it means the returned xarray is unusable
    time_dim = xarray.shape[0]
    p_44.loc[i, "xarray_time_dim"] = time_dim

    # Check where the crop sits relative to the image footprints.
    # All three shapes below end up in EPSG:4326 so they can be compared directly.

    # 1. STAC geometry, treated as EPSG:4326 (the re-projection to the image
    #    CRS is commented out)
    stac_shape = shapely.geometry.shape(stac_item.geometry)
    stac_shape = gpd.GeoSeries(stac_shape).set_crs("EPSG:4326").geometry[0] # .to_crs(stac_crs)

    # 2. Use the STAC proj:bbox property (given in the image CRS) to make a
    #    shape, then re-project it to EPSG:4326
    x_min_p, y_min_p, x_max_p, y_max_p = stac_item.properties["proj:bbox"]
    image_bbox = shapely.geometry.Polygon(
        [[x_min_p, y_min_p], [x_min_p, y_max_p], [x_max_p, y_max_p], [x_max_p, y_min_p]]
    )
    image_bbox = gpd.GeoSeries(image_bbox).set_crs(stac_crs).to_crs("EPSG:4326").geometry[0]

    # 3. Convert the crop square (lat/lon corners computed above) to a shape
    crop_square = shapely.geometry.Polygon(
        [[x_min, y_min], [x_min, y_max], [x_max, y_max], [x_max, y_min]]
    )

    # Store whether the crop overlaps (intersects, not necessarily lies within)
    # the proj:bbox shape and the STAC geometry
    p_44.loc[i, "intersects_bbox"] = crop_square.intersects(image_bbox)
    p_44.loc[i, "intersects_geometry"] = crop_square.intersects(stac_shape)
    # p_44.loc[i, "xarray_sum"] = np.array(xarray).sum()

#     if p_44.loc[i, "xarray_sum"] == 0: #p_44.loc[i, "xarray_time_dim"] == 0:
#         failing_IDs.append(i)

#         print(i)
#         print("Intersects bbox?", p_44.loc[i, "intersects_bbox"])
#         print("Intersects STAC geometry?", p_44.loc[i, "intersects_geometry"])

#         # plot all shapes
#         shapes_gdf = gpd.GeoDataFrame(
#             {"item":["stac_shape", "image_bbox", "crop"]},
#             geometry=[stac_shape, image_bbox, crop_square]
#         ).set_crs("EPSG:4326")

#         shapes_gdf.plot(column="item", legend=True, alpha=0.6, figsize=(4,4))
#         plt.show()

In [None]:
# check how many of the failing points sit inside/outside STAC geometry
# Rows: size of the returned xarray's time dimension (0 = unusable crop).
# Columns: whether the crop overlaps the STAC item geometry.
crosstab_geom = pd.crosstab(
    p_44["xarray_time_dim"],
    p_44["intersects_geometry"],
)
crosstab_geom

In [None]:
# check how many of the failing points sit inside/outside bbox
# Rows: size of the returned xarray's time dimension (0 = unusable crop).
# Columns: whether the crop overlaps the shape built from "proj:bbox".
crosstab_bbox = pd.crosstab(
    p_44["xarray_time_dim"],
    p_44["intersects_bbox"],
)
crosstab_bbox

In [None]:
# p_44[~p_44["xarray_sum"]>0]

So every point that fails sits within the STAC geometry (by definition, since STAC items were only fetched for points that fall inside their given geometry) but outside the `proj:bbox` shape. This must be a data/coding issue on the database end: the bbox should always match the geometry parameter perfectly, but sometimes it does not.

Two solutions:
1. Trim each geometry to the area that also lies within the bbox (implemented through the `_get_trimmed_stac_shapes_gdf` function in `stacs.py`) when selecting which STAC item(s) to return for each point. As a result, a different STAC item (for example, a different least-cloudy image) is returned instead of the problematic one.
2. Catch xarrays with a 0 time dimension as errors and return None for the points that suffer from this issue. Not preferable as we lose datapoints for no real reason (usually there are other valid images that could be used).

# Run in parallel

## Trial run

The cell below is a trial run on a single partition (index 0). About 8 partitions at a time seems to be how many we can run in parallel on a local cluster. We may be able to do more on a Gateway Cluster once that is working.

There are also better schemes. For example, kick off another partition whenever one finishes. That might be a better use of resources.

In [None]:
%%time

# Trial run on a single partition (index 0).
i = 0
p = partitions[i]
# Build the lazy task. NOTE(review): the device is hard-coded to 'cuda' here,
# whereas the commented-out full run reads featurization_params['device'].
f = partition_run(p, satellite_config, featurization_params, model, 'cuda', dask_key_name=f'run_{i}')
# Submit the task to the cluster; this returns a future.
df_future = client.compute(f)
# Wait for the future to finish and fetch its result.
# Note that `f` is rebound here from the delayed object to the future.
for f in as_completed([df_future]):
    df = f.result()

df

In [None]:
_ = client.restart()

## Full run

This is going to create 200 dataframes - one for each partition. If any fail, we can always just rerun that single component.

In [None]:
# from datetime import datetime

# N_PARTITIONS = len(partitions)
# N_PER_RUN = 8
# START_IDX = 44
# str_column_names = [str(i) for i in range(featurization_params['num_features'])]

# p_ids = np.arange(START_IDX, N_PARTITIONS + N_PER_RUN, N_PER_RUN)

# for p_start_id, p_end_id in zip(p_ids[:-1], p_ids[1:]):
#     now = datetime.now().strftime("%d-%b %H:%M:%S")
#     print(f"{now} Running batch: ", p_start_id, "to", p_end_id-1)

#     delayed_dfs = []
#     for i, p in enumerate(partitions[p_start_id:p_end_id]):
#         f = partition_run(p, satellite_config, featurization_params, model,
#                           featurization_params['device'], dask_key_name=f'features_{p_start_id + i}')
#         delayed_dfs.append(f)
#     futures_dfs = client.compute(delayed_dfs)

#     for f in as_completed(futures_dfs):
#         try:
#             df = f.result()
#             df.columns = str_column_names
#             df.to_parquet(f'data/df_{f.key}.parquet.gzip', compression='gzip')
#         except Exception as e:
#             print(f"Partition {f.key} failed. Error:", e)

#     client.restart()
#     sleep(5)

In [None]:
# for f in as_completed(futures_dfs[-3:]):
#     df = f.result()
#     df.columns = str_column_names
#     df.to_parquet(f'data/df_{f.key}.parquet.gzip', compression='gzip')

In [None]:
# client.shutdown()

## Re-run failed partitions

Use this to just run partitions that failed

In [None]:
# %%time

# FAILED_IDX = [44]

# delayed_dfs = []
# for i in FAILED_IDX:
#     p = partitions[i]
#     f = partition_run(p, satellite_config, featurization_params, model,
#                       featurization_params['device'], dask_key_name=f'features_{i}')
#     delayed_dfs.append(f)
#     futures_dfs = client.compute(delayed_dfs)

#     for f in as_completed(futures_dfs):
#         f.result().to_csv(f'data/df_{f.key}.csv')

In [None]:
# _ = client.restart()

# Load checkpoint files and combine

In [None]:
import pandas as pd
import os

path = "./data"
all_files = os.listdir(path)

# Keep only the checkpoint files whose names end in ".gzip" (read as parquet
# further down), sort them by name, then keep the 2nd and 3rd of them.
# NOTE(review): the [1:3] slice looks like a temporary debugging subset —
# confirm before combining the full run.
parquet_files = sorted(
    filter(lambda name: name.endswith(".gzip"), all_files)
)[1:3]

In [None]:
pd.Series(parquet_files).to_csv("./data/file_list.csv")

In [None]:
# Read every selected checkpoint file from ./data, in list order.
dfs = [pd.read_parquet("./data/" + filename) for filename in parquet_files]

# Stack the per-partition frames row-wise and report the total memory
# footprint in (decimal) megabytes.
combined_df = pd.concat(dfs, axis=0)
print(
    "Dataset size in memory (MB):",
    combined_df.memory_usage().sum() / 1_000_000,
)

In [None]:
combined_df.shape

In [None]:
combined_df.sort_index()

In [None]:
combined_df.to_parquet(
    "centroid_features_landsat_TEMP.parquet.gzip", compression="gzip"
)

In [None]:
df = pd.read_parquet("df_features_198.parquet.gzip")
df