### TODO

- Simplify further?
    - STAC search per point instead of the Hilbert-distance shenanigans
    - Instead of 500 points per thread on a dataloader, just have each thread work on 1 point at a time
- Add error catching to Queued Futures

# Notebook prep

In [None]:
# Enable IPython's autoreload extension; mode 2 re-imports modules before each
# cell runs, so edits to the package source are picked up without a restart.
%load_ext autoreload
%autoreload 2

## Install the mosaiks package

In [None]:
# Locally
# !pip install -e .. --upgrade

In [None]:
# From github
# 🚨 Make sure you update the GitHub token in the secrets file 🚨
# import src.mosaiks.utils as utl
# mosaiks_package_link = utl.get_mosaiks_package_link
# !pip install {mosaiks_package_link} --upgrade

## Import packages

In [None]:
import logging

# Calling getLogger() with no name returns the root logger; lowering its
# threshold to INFO lets INFO-level records through.
logger = logging.getLogger()
logger.setLevel(level=logging.INFO)

In [None]:
import sys
import os
import warnings

# Make the parent directory importable from this notebook.
sys.path.append("../")

# Silence every warning for the rest of the session.
warnings.filterwarnings(action="ignore")

In [None]:
from pathlib import Path
import mosaiks.utils as utl
from mosaiks.featurize import RCF

## Setup Rasterio

In [None]:
# Load the rasterio settings from YAML and copy them into this process's
# environment variables.
# NOTE(review): the filename "rasterioc_config.yaml" looks like it could be a
# typo of "rasterio_config.yaml" — confirm against the actual config file.
# NOTE(review): os.environ only accepts string keys and values — confirm the
# YAML values are loaded as strings.
rasterio_config = utl.load_yaml_config("rasterioc_config.yaml")
os.environ.update(rasterio_config)

# Load params + defaults

In [None]:
# Featurization settings (model, dask, satellite search parameters, ...).
featurization_config = utl.load_yaml_config("featurisation.yaml")

# The satellite YAML holds one entry per satellite; keep only the entry for
# the satellite named in the featurization settings.
all_satellite_configs = utl.load_yaml_config("satellite_config.yaml")
selected_satellite = featurization_config["satellite_search_params"]["satellite_name"]
satellite_config = all_satellite_configs[selected_satellite]

In [None]:
# One column name per model feature: mosaiks_0, mosaiks_1, ...
num_features = featurization_config["model"]["num_features"]
mosaiks_column_names = [f"mosaiks_{feature_idx}" for feature_idx in range(num_features)]

# Folder used for the outputs of the small test runs in this notebook.
test_mosaiks_folder_path = Path("test_outputs")

In [None]:
# Instantiate the featurization model from the YAML settings. The third
# positional argument is the number of entries in the satellite's band list.
model_settings = featurization_config["model"]
num_features = model_settings["num_features"]
kernel_size = model_settings["kernel_size"]
num_bands = len(satellite_config["bands"])

model = RCF(num_features, kernel_size, num_bands)

# Load Data

In [None]:
# Load the coordinate set named in the featurization settings.
coord_set_name = featurization_config["coord_set_name"]
request_points_gdf = utl.load_df_w_latlons_to_gdf(dataset_name=coord_set_name)

In [None]:
# Keep only the first 1000 points (for testing).
# To draw a random 1000 instead, use:
#   request_points_gdf.sample(1000, random_state=0)
points_gdf = request_points_gdf.iloc[:1000]

In [None]:
# One-row slice (the row at position 1) used for the quick test runs below.
test_points_gdf = points_gdf.iloc[1:2]

# Simple non-dask run

In [None]:
from mosaiks.featurize import fetch_image_refs, create_data_loader, create_features
import pandas as pd

In [None]:
%%time
# 1. Attach satellite image references to the test point(s).
search_params = featurization_config['satellite_search_params']
points_gdf_with_stac = fetch_image_refs(test_points_gdf, search_params)

# 2. Wrap the points and their image references in a data loader.
model_settings = featurization_config["model"]
data_loader = create_data_loader(
    points_gdf_with_stac=points_gdf_with_stac,
    satellite_params=satellite_config,
    batch_size=model_settings["batch_size"],
)

# 3. Run the model over the data loader to obtain the feature values.
X_features = create_features(
    dataloader=data_loader,
    n_features=model_settings["num_features"],
    model=model,
    device=model_settings["device"],
    min_image_edge=satellite_config["min_image_edge"],
)

# 4. Label the features with the input index and the mosaiks_* column names.
# NOTE(review): this assumes X_features has one row per point, in the same
# order as test_points_gdf — confirm fetch_image_refs does not reorder or
# drop points.
df = pd.DataFrame(
    X_features,
    index=test_points_gdf.index,
    columns=mosaiks_column_names,
)

# 5. Write the result to the test output folder.
test_output_file = f"{test_mosaiks_folder_path}/df_TEST.csv"
utl.save_dataframe(df=df, file_path=test_output_file)

In [None]:
# Display the feature table from the test run as the cell output.
df

# Dask runs

In [None]:
from mosaiks.dask_run import get_local_dask_client

In [None]:
# Start a local dask client sized by the "dask" section of the settings.
dask_settings = featurization_config["dask"]
client = get_local_dask_client(
    dask_settings["n_workers"],
    dask_settings["threads_per_worker"],
)

# Show the client summary as the cell output.
client

In [None]:
# Derive the output folder from the featurization settings and create it;
# exist_ok=True makes re-running this cell harmless.
mosaiks_folder_path = utl.make_output_folder_path(featurization_config)
os.makedirs(name=mosaiks_folder_path, exist_ok=True)

## Method 1 (Preferred) - Queued Futures

In [None]:
from mosaiks.dask_run import run_queued_futures_pipeline

In [None]:
%%time

# NOTE: interrupting this cell does not cancel work that has already been
# submitted to the dask cluster. Call client.restart() to stop it.
run_queued_futures_pipeline(
    points_gdf,
    client=client,
    model=model,
    column_names=mosaiks_column_names,
    save_folder_path=mosaiks_folder_path,
    featurization_config=featurization_config,
    satellite_config=satellite_config,
)

## Method 2 - Batched Delayed

In [None]:
from mosaiks.dask_run import run_batched_delayed_pipeline

In [None]:
%%time

# NOTE: interrupting this cell does not cancel work that has already been
# submitted to the dask cluster. Call client.restart() to stop it.
run_batched_delayed_pipeline(
    points_gdf,
    client=client,
    model=model,
    column_names=mosaiks_column_names,
    save_folder_path=mosaiks_folder_path,
    featurization_config=featurization_config,
    satellite_config=satellite_config,
)

## Method 3 - Unbatched Delayed

In [None]:
# Builders for the unbatched dask delayed tasks used in the cells below.
# (No trailing comma: an unparenthesised import list must not end with one.)
from mosaiks.dask_run import make_delayed_task, make_all_delayed_tasks

### Single task

In [None]:
# Build a delayed task for the one-row test GeoDataFrame. The task is only
# executed later, when `.compute()` is called on it.
# NOTE(review): the last two positional arguments look like the output folder
# and output file name — confirm against make_delayed_task's signature.
delayed_task = make_delayed_task(
    test_points_gdf,
    model,
    featurization_config,
    satellite_config,
    mosaiks_column_names,
    test_mosaiks_folder_path,
    "TEST_dask_delayed.csv",
)

In [None]:
# Render the task graph of the delayed task to a PNG in the test output folder.
delayed_task.visualize(filename=f"{test_mosaiks_folder_path}/TEST_dask_graph.png")

In [None]:
# Execute the single delayed task now; its return value is the cell output.
delayed_task.compute()

### Full run

In [None]:
from dask.distributed import progress

In [None]:
# Build the delayed tasks for the full set of points; nothing is executed
# until the list is handed to the dask client.
# NOTE(review): unlike the other pipeline calls in this notebook, no column
# names are passed here — confirm make_all_delayed_tasks does not need them.
delayed_task_list = make_all_delayed_tasks(
    points_gdf=points_gdf,
    save_folder_path=mosaiks_folder_path,
    model=model,
    featurization_config=featurization_config,
    satellite_config=satellite_config,
)

In [None]:
# Hand the delayed tasks to the dask cluster for execution.
persist_tasks = client.persist(delayed_task_list)
# Display a progress indicator for the submitted tasks.
progress(persist_tasks)

# Load checkpoint files and combine

In [None]:
# simple test
# data = utl.load_dataframe(mosaiks_folder_path / "df_0.parquet.gzip")
# data

In [None]:
# Collect every checkpoint file in the output folder whose name starts
# with "df_".
checkpoint_filenames = utl.get_filtered_filenames(
    folder_path=mosaiks_folder_path,
    prefix="df_",
)

# Load the checkpoints into a single dataframe.
combined_df = utl.load_and_combine_dataframes(
    folder_path=mosaiks_folder_path,
    filenames=checkpoint_filenames,
)

# Attach the location / identifier columns of the requested points (joined
# on the index).
point_info_columns = ["Lat", "Lon", "shrid"]
combined_df = combined_df.join(points_gdf[point_info_columns])

# Report the in-memory size (bytes / 1,000,000).
size_in_mb = combined_df.memory_usage().sum() / 1_000_000
print("Dataset size in memory (MB):", size_in_mb)