In [None]:
# Reload edited modules (e.g. the locally installed mosaiks package) automatically before each cell runs.
%load_ext autoreload
%autoreload 2

In [None]:
import sys
import os
import warnings
import logging

sys.path += ["../"]
warnings.filterwarnings("ignore")

# dask
from dask.distributed import Client, LocalCluster, as_completed
from dask import delayed

from datetime import datetime
from time import sleep
import pandas as pd
import numpy as np


# Install 'mosaiks' package

### Install locally

In [None]:
# !pip uninstall mosaiks -y

In [None]:
!pip install -e ..

In [None]:
from mosaiks.featurize import *
import mosaiks.utils as utl

### Install from GitHub (outdated — consider installing from the `main` branch instead?)

🚨🚨 **Make sure you update the GitHub token in the secrets file (`../config/secrets.yml`)** 🚨🚨

In [None]:
# import src.mosaiks.utils as utl
# secrets = utl.load_yaml_config("../config/secrets.yml")
# GITHUB_TOKEN = secrets["GITHUB_TOKEN"]
# mosaiks_package_link = f"git+https://{GITHUB_TOKEN}@github.com/IDinsight/mosaiks@as-package"

In [None]:
# !pip install {mosaiks_package_link} --upgrade
# from mosaiks.featurize import *
# import mosaiks.utils as utl

# Set up Rasterio

In [None]:
# Export the Rasterio settings from the YAML file as environment variables.
rasterio_config = utl.load_yaml_config("rasterioc_config.yaml")
# os.environ only accepts strings, but YAML can parse values such as numbers or YES/NO
# into int/bool; coerce everything to str so the update cannot raise TypeError.
os.environ.update({str(key): str(value) for key, value in rasterio_config.items()})

# Set up Dask cluster and client

### Local Cluster

4 workers with 4 threads each seem to work best. Much of the time a thread is just waiting for data to load, so the CPU is underutilized.

In [None]:
# Local cluster of separate worker processes, each running several threads.
cluster = LocalCluster(
    processes=True,
    n_workers=4,
    threads_per_worker=4,
    silence_logs=logging.ERROR,
)
client = Client(address=cluster)
client


### Gateway cluster

In [None]:
# from dask_gateway import Gateway
# import dask_gateway
# from dask.distributed import PipInstall

# gateway = Gateway()
# options = gateway.cluster_options()
# options


In [None]:
# cluster = gateway.new_cluster(options)
# client = cluster.get_client()
# print(cluster.dashboard_link)

# plugin = PipInstall(packages=[mosaiks_package_link], pip_options=["--upgrade"], restart=False)
# client.register_worker_plugin(plugin)

# cluster.scale(10)

# Load params

In [None]:
featurization_params = utl.load_yaml_config("featurisation.yaml")

# The satellite config is keyed by satellite name; keep only the satellite being searched.
selected_satellite_name = featurization_params["satellite_search_params"]["satellite_name"]
satellite_config = utl.load_yaml_config("satellite_config.yaml")[selected_satellite_name]

coord_set_name = "request_points_centroids"

# Load point coords

In [None]:
# Load the point set named by `coord_set_name`.
request_points_gdf = utl.load_df_w_latlons_to_gdf(
    dataset_name=coord_set_name,
)

In [None]:
# Keep only the points whose `pc11_s_id` state code belongs to a focus state.
focus_states_id_dict = {
    20: "jharkhand",
    22: "chhattisgarh",
    8: "rajasthan",
    23: "madhya pradesh",
    18: "assam",
    16: "tripura",
}

focus_state_ids = list(focus_states_id_dict)
focus_states_filter = request_points_gdf["pc11_s_id"].isin(focus_state_ids)
points_gdf_focus = request_points_gdf.loc[focus_states_filter]
points_gdf_focus.shape

In [None]:
# Optionally featurize only a random subsample for a quick test run.
# None = use every focus point; e.g. 200 for a fast, reproducible test.
N_TEST_POINTS = None
points_gdf = (
    points_gdf_focus
    if N_TEST_POINTS is None
    else points_gdf_focus.sample(N_TEST_POINTS, random_state=0)
)

# Fetch image stac refs

`fetch_image_refs` now returns a dask dataframe that has not been computed yet, so this cell finishes quickly.

🌱 **SUGGESTION:** Change the `n_partitions` parameter to `n_per_partition` (points per partition) and calculate `n_partitions` here? This would keep the processing time per partition similar across point sets of different sizes.

In [None]:
%%time
points_gdf_with_stac = fetch_image_refs(
    points_gdf, 
    featurization_params['dask']['n_partitions'],
    featurization_params['satellite_search_params']
)

# Define delayed objects

We use the `delayed` decorator to turn our function into a delayed function: calling it does not run it immediately, but instead returns a delayed object that can be computed later.

In [None]:
# Feature column names: mosaiks_0, mosaiks_1, ..., one per configured feature.
n_mosaiks_features = featurization_params["num_features"]
mosaiks_column_names = ["mosaiks_" + str(idx) for idx in range(n_mosaiks_features)]

In [None]:
@delayed
def partition_run(df, satellite_config, featurization_params, model, device):
    """Featurize one partition of points and return the features as a DataFrame.

    Wrapped in ``dask.delayed``: calling it returns a lazy task that only runs
    (on a dask worker) when it is computed.

    Parameters
    ----------
    df : pandas.DataFrame
        One partition of the points dataframe (as produced by ``to_delayed()``).
        Its index is reused for the returned features.
    satellite_config : dict
        Config for the selected satellite; ``min_image_edge`` is read here.
    featurization_params : dict
        Featurization config; ``batch_size`` and ``num_features`` are read here.
    model
        Featurization model, passed through to ``create_features``.
    device
        Device passed through to ``create_features`` (e.g. ``"cuda"``).

    Returns
    -------
    pandas.DataFrame
        Features indexed like ``df``, with ``num_features`` columns named
        ``mosaiks_0``, ``mosaiks_1``, ...
    """
    num_features = featurization_params["num_features"]
    # Build the column names here instead of reading a notebook global, so the
    # task is self-contained when it is serialized and shipped to workers.
    column_names = [f"mosaiks_{i}" for i in range(num_features)]

    data_loader = create_data_loader(
        df, satellite_config, featurization_params["batch_size"]
    )
    X_features = create_features(
        data_loader,
        num_features,
        len(df),
        model,
        device,
        satellite_config["min_image_edge"],
    )

    return pd.DataFrame(X_features, index=df.index.copy(), columns=column_names)

Next, we convert our dask dataframe into "delayed" objects: each partition becomes a delayed pandas dataframe that can be passed to the delayed function above.

In [None]:
# One delayed pandas DataFrame per partition of the dask dataframe.
partitions = list(points_gdf_with_stac.to_delayed())

In [None]:
# Build the featurization model from the configured feature count, kernel size
# and the number of bands configured for the selected satellite.
n_bands = len(satellite_config["bands"])
model = RCF(
    featurization_params["num_features"],
    featurization_params["kernel_size"],
    n_bands,
)

# Run in parallel

## Trial run

The cell below runs only a single partition (partition 0) as a quick check that everything works. About 8 partitions seems to be the most we can run in parallel on a local cluster; we may be able to run more on a Gateway cluster once that is working.

There are also better scheduling schemes — for example, starting another partition whenever one finishes instead of waiting for the whole batch. That might make better use of resources.

In [None]:
%%time

i = 0
p = partitions[i]
f = partition_run(p, satellite_config, featurization_params, model, 'cuda', dask_key_name=f'run_{i}')
df_future = client.compute(f)
for f in as_completed([df_future]):
    df = f.result()

In [None]:
mean_feature_value = df.mean().mean()
print("Average feature value:", mean_feature_value)

# Histogram of the feature values of the first point.
df.iloc[0].hist()

# Restart the cluster's workers before the full run; assign the result so it is not displayed.
_ = client.restart()

In [None]:
# %%time

# for i in range(4):
#     p = partitions[i]
#     f = partition_run(p, satellite_config, featurization_params, model, 'cuda', dask_key_name=f'run_{i}')
#     df_future = client.compute(f)
#     for f in as_completed([df_future]):
#         df = f.result()


## Full run

This creates one dataframe per partition (e.g. 200 dataframes for 200 partitions), each saved as its own checkpoint file. If a partition fails, its id is added to `failed_list` so that it can be rerun on its own below.

### Setup saving location

In [None]:
# Build the output folder path for the feature checkpoints and create it if missing.
satellite = featurization_params["satellite_search_params"]["satellite_name"]
# Year = first "-"-separated component of the search start date (assumes a "YYYY-..." format).
year = featurization_params["satellite_search_params"]["search_start"].split("-")[0]
n_features = str(featurization_params["num_features"])

mosaiks_folder_path = utl.make_features_path(
    satellite,
    year,
    coord_set_name,
    n_features,
    filename=None,
)

os.makedirs(mosaiks_folder_path, exist_ok=True)

### Create features and save checkpoints to file

In [None]:
# Batching controls for the full run.
N_PER_RUN = 8  # partitions submitted to the cluster together in one batch
START_IDX = 0  # first partition to process; set > 0 to skip partitions already done
N_PARTITIONS = len(partitions)

In [None]:
# Featurize partitions in batches of N_PER_RUN, saving each result as its own
# checkpoint file; ids of partitions that fail are collected in `failed_list`.
failed_list = []
for batch_start in range(START_IDX, N_PARTITIONS, N_PER_RUN):
    # Clamp the last batch so the log never reports a partition that doesn't exist.
    batch_end = min(batch_start + N_PER_RUN, N_PARTITIONS)
    now = datetime.now().strftime("%d-%b %H:%M:%S")
    print(f"{now} Running batch: ", batch_start, "to", batch_end - 1)

    delayed_dfs = []
    for partition_id in range(batch_start, batch_end):
        str_i = str(partition_id).zfill(3)  # makes 1 into '001'
        delayed_dfs.append(
            partition_run(
                partitions[partition_id],
                satellite_config,
                featurization_params,
                model,
                featurization_params["device"],
                dask_key_name=f"features_{str_i}",
            )
        )
    futures_dfs = client.compute(delayed_dfs)

    # Save results as they arrive; a failure in one partition doesn't stop the batch.
    for future in as_completed(futures_dfs):
        try:
            df = future.result()
            df.to_parquet(f"{mosaiks_folder_path}/df_{future.key}.parquet.gzip")
        except Exception as e:
            failed_id = int(future.key.split("features_")[1])
            print(f"Partition {failed_id} failed. Error:", e)
            failed_list.append(failed_id)

    # Restart the workers between batches and give them a moment to come back up.
    client.restart()
    sleep(5)


## Re-run failed partitions

Use this to rerun only the partitions that failed (by default, those listed in `failed_list`).

In [None]:
%%time

FAILED_IDX = failed_list #[44]

delayed_dfs = []
failed_list_1 = []
for i in FAILED_IDX:
    p = partitions[i]
    str_i = str(i).zfill(3)
    f = partition_run(
        p, 
        satellite_config, 
        featurization_params, model, 
        featurization_params['device'], 
        dask_key_name=f'features_{str_i}'
    )
    delayed_dfs.append(f)
    futures_dfs = client.compute(delayed_dfs)
    
    for f in as_completed(futures_dfs):
        try:
            df = f.result()
            df.to_parquet(f"{mosaiks_folder_path}/df_{f.key}.parquet.gzip")
        except Exception as e:
            print(f"Partition {f.key} failed. Error:", e)
            failed_list_1.append(f.key)

# Load checkpoint files and combine

In [None]:
# Per-partition checkpoints are saved with a "df_" prefix; the combined file is not.
checkpoint_filenames = utl.get_filtered_filenames(
    mosaiks_folder_path,
    prefix="df_",
)

In [None]:
combined_df = utl.load_and_combine_dataframes(mosaiks_folder_path, checkpoint_filenames)
# Join on the index: each partition's features reuse the points' index
# (assumes load_and_combine_dataframes preserves it).
combined_df = combined_df.join(points_gdf[["Lat", "Lon", "shrid"]])

# deep=True also counts the contents of object (e.g. string) columns, not just their pointers.
print("Dataset size in memory (MB):", combined_df.memory_usage(deep=True).sum() / 1000000)

In [None]:
%%time
combined_filename = "features.parquet.gzip"
utl.save_dataframe(combined_df, file_path=mosaiks_folder_path / combined_filename)