In [None]:
# Re-import edited modules (e.g. src.mosaiks.utils) automatically before each cell runs,
# so code changes are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import sys
import warnings

# Make the repository root importable so `src.mosaiks...` resolves when the notebook
# runs from its own folder (the path is relative to the current working directory).
sys.path.append("../")

# Suppress every warning category to keep notebook output readable.
# NOTE(review): this also hides potentially useful warnings; consider narrowing the filter.
warnings.filterwarnings("ignore")

In [None]:
import os
import src.mosaiks.utils as utl

# Export every key of the rasterio config file as an environment variable
# (presumably GDAL/rasterio options, which are read from the environment).
rasterio_config = utl.load_yaml_config("rasterioc_config.yaml")
# os.environ only accepts str keys/values; YAML parses numbers and booleans into
# int/bool, which would make a plain `os.environ.update(rasterio_config)` raise TypeError.
os.environ.update({str(key): str(value) for key, value in rasterio_config.items()})

#### `pip install` MOSAIKS

From local folder:

In [None]:
!pip install -e ..

From GitHub package:

🚨🚨 **Make sure you update the GitHub token in the secrets file (`../config/secrets.yml`)** 🚨🚨 

In [None]:
# Read the GitHub token from the local secrets file (never hard-code it here).
secrets = utl.load_yaml_config("../config/secrets.yml")
GITHUB_TOKEN = secrets["GITHUB_TOKEN"]

# pip VCS URL for the private repo at the `as-package` ref. The token is embedded as
# URL credentials, so avoid displaying this string in cell output.
mosaiks_package_link = "git+https://{}@github.com/IDinsight/ds_nudge_up@as-package".format(GITHUB_TOKEN)

In [None]:
!pip uninstall mosaiks -y
!pip install {mosaiks_package_link} --upgrade

# Setup Dask Cluster and Client

## Local Cluster

4 workers with 4 threads each seem to work best. A thread spends a lot of its time waiting for data to load, so with fewer threads the CPU is underutilized.

In [None]:
import logging

from dask.distributed import Client, LocalCluster

# 4 worker processes x 4 threads each: threads spend much of their time waiting for
# data to load, so several threads per worker keep the CPU busier.
cluster = LocalCluster(
    n_workers=4,
    threads_per_worker=4,
    processes=True,
    silence_logs=logging.ERROR,  # only surface worker logs at ERROR level and above
)
client = Client(cluster)

# Last expression: display the client's summary.
client

## Gateway cluster

In [None]:
# from dask_gateway import Gateway
# import dask_gateway
# from dask.distributed import PipInstall

# gateway = Gateway()
# options = gateway.cluster_options()
# options

In [None]:
# from dask.distributed import PipInstall

# cluster = gateway.new_cluster(options)
# client = cluster.get_client()
# print(cluster.dashboard_link)

# plugin = PipInstall(packages=[mosaiks_package_link], pip_options=["--upgrade"], restart=False)
# client.register_worker_plugin(plugin)

# cluster.scale(10)

In [None]:
# cluster.shutdown()

# Load params

In [None]:
from mosaiks.featurize import *

from dask import delayed
from dask.distributed import as_completed
from time import sleep
import pandas as pd
import numpy as np

In [None]:
# Run parameters for featurization.
featurization_params = utl.load_yaml_config("featurisation.yaml")

# Keep only the settings block for the satellite named in the search parameters.
satellite_config = utl.load_yaml_config("satellite_config.yaml")[
    featurization_params["satellite_search_params"]["satellite_name"]
]

# Catalog of input data sources (e.g. the request points loaded below).
data_sources = utl.load_yaml_config("data_catalog.yaml")

# Load point coords

In [None]:
# Load the request points (centroids); the `request_points_centroids` catalog entry
# is passed through as keyword arguments.
points_gdf = utl.load_points_gdf(**data_sources['request_points_centroids'])

In [None]:
# State id (as stored in the `pc11_s_id` column) -> state name, for the states we featurize.
focus_states_id_dict = {
    20: "jharkhand",
    22: "chhattisgarh",
    8: "rajasthan",
    23: "madhya pradesh",
    18: "assam",
    16: "tripura",
}

# Boolean mask: True for points located in one of the focus states.
focus_state_ids = list(focus_states_id_dict)
focus_states_filter = points_gdf["pc11_s_id"].isin(focus_state_ids)

In [None]:
# Restrict to points in the focus states and show how many remain.
points_gdf_focus = points_gdf.loc[focus_states_filter]
points_gdf_focus.shape

In [None]:
# Shape after dropping duplicate `shrid`s; compare with the row count above to spot duplicates.
points_gdf_focus["shrid"].drop_duplicates(keep='first').shape

In [None]:
# Quick map of the focus-state points; the trailing `;` hides the returned object's repr.
points_gdf_focus.plot(markersize=0.01).set_title("Request points in focus states");

# Fetch image stac refs

`fetch_image_refs` now returns a dask dataframe that has not been computed yet, so this cell finishes quite quickly.

In [None]:
%%time
# Attach STAC image references to each point. The result is a lazy dask dataframe,
# so this returns quickly; `n_partitions` from the config presumably sets the number
# of dask partitions — confirm in fetch_image_refs.
points_gdf_with_stac = fetch_image_refs(
    points_gdf_focus, 
    featurization_params['dask']['n_partitions'],
    featurization_params['satellite_search_params']
)

# Define delayed objects

We use the `delayed` decorator to turn our function into a delayed function. This means it will not run immediately when called; instead, it returns a delayed object that can be computed later.

In [None]:
@delayed
def partition_run(df, satellite_config, featurization_params, model, device):
    """Featurize one partition of points.

    Decorated with ``delayed``: calling it returns a dask Delayed object and the
    work only runs when that object is computed.

    Args:
        df: one partition of the points dataframe (with STAC refs).
        satellite_config: settings for the chosen satellite; ``min_image_edge`` is read here.
        featurization_params: run parameters; ``batch_size`` and ``num_features`` are read here.
        model: featurization model, passed through to ``create_features``.
        device: device string, passed through to ``create_features``.

    Returns:
        pandas DataFrame of features whose index is a copy of ``df.index``.
    """
    loader = create_data_loader(df, satellite_config, featurization_params['batch_size'])
    n_points = len(df)
    features = create_features(
        loader,
        featurization_params['num_features'],
        n_points,
        model,
        device,
        satellite_config['min_image_edge'],
    )
    return pd.DataFrame(features, index=df.index.copy())

We want to convert our dask dataframe into "delayed" objects. Each partition is then a delayed pandas dataframe that can be passed to our delayed function above.

In [None]:
# One dask Delayed per partition; each materializes to a pandas dataframe that can be
# passed to partition_run.
partitions = points_gdf_with_stac.to_delayed()

In [None]:
# Featurization model: number of output features, kernel size, number of input bands.
# NOTE(review): if RCF draws random filters without a fixed seed, re-creating it in a new
# kernel (e.g. when resuming the full run from START_IDX) would give features that are not
# comparable across partitions — confirm RCF initialization is deterministic.
model = RCF(
    featurization_params['num_features'],
    featurization_params['kernel_size'],
    len(satellite_config['bands']),
)

A batch size of 10 seems to be the optimal balance between maximally using the CPU and not blowing up the memory.

In [None]:
# Configured batch size, passed to create_data_loader inside partition_run.
featurization_params['batch_size']

# Run in parallel

## Trial run

The cell below runs only a small subset of the partitions (currently the first 2). About 8 partitions seems to be as many as we can run in parallel on a local cluster. We may be able to do more on a Gateway Cluster once that is working.

There are also better schemes. For example, we could kick off another partition whenever one finishes. That might be a better use of resources.

In [None]:
%%time

dfs = []
for i, p in enumerate(partitions[:2]):
    f = partition_run(p, satellite_config, featurization_params, model, 'cuda', dask_key_name=f'run_{i}')
    dfs.append(f)
dfs = client.compute(dfs, )

df_list = []
for f in as_completed(dfs):
    df_list.append(f.result())


In [None]:
# First frame to *finish* (as_completed yields in completion order), not necessarily partition 0.
df_list[0]

8 partitions should take ~7-8 minutes on an MPC GPU instance. So that's <1 minute per partition. If nothing goes wrong, the whole job should finish in <4 hours.

In theory, objects should get garbage collected once there are no references to them. But it seems to take forever (or never happen!) for Python to do that, possibly because we have a lot of nested objects and a model object that we are still holding a reference to.

Restarting the cluster seems to be the sure way of clearing worker memory.

In [None]:
# Restart all workers to reclaim their memory; assigning to `_` suppresses displaying the return value.
_ = client.restart()

## Full run

This is going to create 200 dataframes, one for each partition, each saved as a parquet checkpoint under `data/`. If any fail, we can always just rerun that single partition.

In [None]:
from datetime import datetime

In [None]:
N_PARTITIONS = len(partitions)  # total number of partitions to featurize
N_PER_RUN = 8                   # partitions computed together before each cluster restart
START_IDX = 44                  # first partition to process; raise it to resume an interrupted run

# Column labels "0", "1", ... applied to each feature frame before it is written to parquet.
str_column_names = list(map(str, range(featurization_params['num_features'])))

In [None]:
import os

# Featurize partitions in batches of N_PER_RUN, checkpointing each one to
# data/df_features_<idx>.parquet.gzip, and restart the cluster between batches
# to release worker memory.
os.makedirs("data", exist_ok=True)  # otherwise every to_parquet below would fail after the compute
failed_keys = []  # keys of partitions that raised; re-run them with the cell further down

for p_start_id in range(START_IDX, N_PARTITIONS, N_PER_RUN):
    p_end_id = min(p_start_id + N_PER_RUN, N_PARTITIONS)  # last batch may be shorter
    now = datetime.now().strftime("%d-%b %H:%M:%S")
    print(f"{now} Running batch: ", p_start_id, "to", p_end_id-1)

    delayed_dfs = [
        partition_run(p, satellite_config, featurization_params, model,
                      featurization_params['device'], dask_key_name=f'features_{p_start_id + i}')
        for i, p in enumerate(partitions[p_start_id:p_end_id])
    ]
    futures_dfs = client.compute(delayed_dfs)

    for f in as_completed(futures_dfs):
        try:
            df = f.result()
            df.columns = str_column_names
            df.to_parquet(f'data/df_{f.key}.parquet.gzip', compression='gzip')
        except Exception as e:
            # Best-effort: keep going so one bad partition does not stop the batch.
            print(f"Partition {f.key} failed. Error:", e)
            failed_keys.append(f.key)

    client.restart()
    sleep(5)  # pause before the next batch (presumably to let restarted workers settle)

print("Failed partitions:", failed_keys)

In [None]:
# Futures of the most recent batch only. client.restart() ran after that batch, so these
# presumably no longer hold results — useful for checking status, not for collecting data.
futures_dfs

In [None]:
# for f in as_completed(futures_dfs[-3:]):
#     df = f.result()
#     df.columns = str_column_names
#     df.to_parquet(f'data/df_{f.key}.parquet.gzip', compression='gzip')

In [None]:
# client.shutdown()

## Diagnostics

In [None]:
# Materialize partition 44 (presumably one that failed in the full run) for inspection.
p_44 = partitions[44].compute()
# Distinct STAC items referenced by this partition.
stac_items = p_44.stac_item.unique()

In [None]:
# Row for point index 500042, the point debugged below.
p_44.loc[500042]

In [None]:
# STAC item referenced by that point.
p_44.loc[500042].stac_item

In [None]:
import pyproj
import stackstac

In [None]:
# Reproduce the image fetch for a single point of partition 44.
# 500042 presumably fails; 500015 works.
i = 500042  # 500015 works
row = p_44.loc[i]

print("Index:", i)
stac_item = row["stac_item"]
lat = row["Lat"]
lon = row["Lon"]
buffer = 1200  # half-width of the crop window, in the item's CRS units

# Project the point into the item's CRS and build a square crop window around it.
crs = stac_item.properties["proj:epsg"]
x_utm, y_utm = pyproj.Proj(crs)(lon, lat)
x_min, x_max = x_utm - buffer, x_utm + buffer
y_min, y_max = y_utm - buffer, y_utm + buffer

# Named image_stack rather than `xarray` to avoid shadowing the xarray package name.
image_stack = stackstac.stack(
    stac_item,
    assets=satellite_config["bands"],
    resolution=satellite_config["resolution"],
    rescale=False,
    dtype=np.uint8,
    bounds=[x_min, y_min, x_max, y_max],
    fill_value=0,
    # snap_bounds=False
)

# .values forces the lazy stack to load, so a failing read surfaces here.
print(image_stack.values.shape)

In [None]:
# Footprint of the STAC item (proj:bbox, in the item's CRS) to compare with the crop window.
x_min_p, y_min_p, x_max_p, y_max_p = p_44.loc[i].stac_item.properties["proj:bbox"]

In [None]:
# Image footprint vs. crop window, both printed as x_min, y_min, x_max, y_max.
print("image", x_min_p, y_min_p, x_max_p, y_max_p)
print("crop", x_min, y_min, x_max, y_max)

In [None]:
from shapely.geometry import Polygon, box
import geopandas as gpd

# Compare the STAC item's footprint with the crop window (both in the item's CRS)
# to see how much of the crop is actually covered by imagery.
# box(minx, miny, maxx, maxy) avoids hand-listing corners, where it is easy to mix the
# crop's coordinates into the image footprint.
image_square = box(x_min_p, y_min_p, x_max_p, y_max_p)
crop_square = box(x_min, y_min, x_max, y_max)
intersect = image_square.intersection(crop_square)

In [None]:
# Overlap of image footprint and crop window (shapely geometries render inline in Jupyter).
intersect

In [None]:
# The three shapes in one GeoDataFrame for plotting; no CRS is set (coordinates are in the item's CRS).
g = gpd.GeoDataFrame({"item":["image", "crop", "intersect"]}, geometry=[image_square, crop_square, intersect])

In [None]:
# Overlay image footprint, crop window and their intersection; `;` hides the Text repr.
ax = g.plot(column="item", legend=True, figsize=(7, 7))
ax.set_title("STAC image footprint vs crop window");

## Re-run failed partitions

Use this to re-run only the partitions that failed.

In [None]:
# %%time
#
# Re-run only the partitions listed in FAILED_IDX.
# NOTE(review): compute once after building the list (not inside the loop), and save as
# parquet with string column names like the full-run loop, so the combine step below
# (which reads parquet checkpoints) picks these files up.
#
# FAILED_IDX = [44]
#
# delayed_dfs = []
# for i in FAILED_IDX:
#     p = partitions[i]
#     f = partition_run(p, satellite_config, featurization_params, model,
#                       featurization_params['device'], dask_key_name=f'features_{i}')
#     delayed_dfs.append(f)
# futures_dfs = client.compute(delayed_dfs)
#
# for f in as_completed(futures_dfs):
#     df = f.result()
#     df.columns = str_column_names
#     df.to_parquet(f'data/df_{f.key}.parquet.gzip', compression='gzip')

In [None]:
# _ = client.restart()

# Load checkpoint files and combine

In [None]:
import os
import re

import pandas as pd

path = './data'
all_files = os.listdir(path)

# Select the per-partition parquet checkpoints written by the full-run loop
# (df_features_<idx>.parquet.gzip) and order them by partition index numerically,
# so df_features_10 sorts after df_features_9 rather than after df_features_1.
_checkpoint_re = re.compile(r"df_features_(\d+)\.parquet\.gzip")
parquet_files = sorted(
    (file for file in all_files if _checkpoint_re.fullmatch(file)),
    key=lambda file: int(_checkpoint_re.fullmatch(file).group(1)),
)

# TODO: debugging subset (2nd and 3rd checkpoint only) — remove this slice to combine all partitions.
parquet_files = parquet_files[1:3]

In [None]:
# Record which checkpoint files went into the combined dataset.
pd.Series(parquet_files).to_csv("./data/file_list.csv")

In [None]:
# Read every selected checkpoint and stack them row-wise into one features frame.
dfs = [pd.read_parquet("./data/" + filename) for filename in parquet_files]

combined_df = pd.concat(dfs, axis=0)
# memory_usage() defaults to deep=False (exact for numeric columns) and includes the index.
print("Dataset size in memory (MB):", combined_df.memory_usage().sum() / 1000000)

In [None]:
# (rows, feature columns) of the combined frame.
combined_df.shape

In [None]:
# Displays a sorted copy only; combined_df itself stays unsorted.
# NOTE(review): if the saved file should be index-sorted, assign
# `combined_df = combined_df.sort_index()` before saving.
combined_df.sort_index()

In [None]:
# Save the combined features to the current working directory (not ./data);
# "TEMP" in the name suggests this is a trial output.
combined_df.to_parquet("centroid_features_landsat_TEMP.parquet.gzip", compression="gzip")

In [None]:
# Spot-check one checkpoint; the full-run loop writes them under data/ as
# df_features_<idx>.parquet.gzip.
df = pd.read_parquet("data/df_features_198.parquet.gzip")
df