In [None]:
# Auto-reload imported modules before each cell executes, so edits to project
# code (e.g. src.mosaiks.utils) take effect without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import sys
import os
import warnings
from pathlib import Path

# Make the parent directory importable so `src.mosaiks...` resolves. Guard the
# append so re-running this cell does not stack duplicate "../" entries.
if "../" not in sys.path:
    sys.path.append("../")

# NOTE(review): this silences every warning for the whole notebook (including
# deprecations from geopandas/dask) — consider narrowing to specific categories.
warnings.filterwarnings("ignore")

# geopandas reads USE_PYGEOS when it is first imported, so this must be set
# before that happens; presumably the `src.mosaiks.utils` import pulls it in.
os.environ["USE_PYGEOS"] = "0"

In [None]:
import src.mosaiks.utils as utl
from mosaiks.featurize import RCF

## Params and input data

In [None]:
# GDAL/rasterio options are passed through environment variables. os.environ
# only accepts str keys/values, but YAML loads unquoted numbers/booleans as
# int/bool, so coerce everything to str before exporting.
# (assumes the YAML top level is a mapping of option name -> value)
# NOTE(review): "rasterioc_config.yaml" may be a typo for "rasterio_config.yaml"
# — confirm the filename on disk before renaming.
rasterio_config = utl.load_yaml_config("rasterioc_config.yaml")
os.environ.update({str(key): str(value) for key, value in rasterio_config.items()})

featurization_config = utl.load_yaml_config("featurisation.yaml")

# The satellite YAML holds one entry per satellite; keep only the entry named
# in the featurization config's search params.
all_satellite_configs = utl.load_yaml_config("satellite_config.yaml")
satellite_config = all_satellite_configs[
    featurization_config["satellite_search_params"]["satellite_name"]
]

In [None]:
# One output column per model feature: mosaiks_0, mosaiks_1, ..., mosaiks_{n-1}.
n_mosaiks_features = featurization_config["model"]["num_features"]
mosaiks_col_names = list(map("mosaiks_{}".format, range(n_mosaiks_features)))

In [None]:
# Build the RCF featurisation model from the "model" section of the config.
# The third argument is the number of bands configured for the chosen
# satellite (presumably the model's input channel count).
model_params = featurization_config["model"]
n_bands = len(satellite_config["bands"])
model = RCF(model_params["num_features"], model_params["kernel_size"], n_bands)

In [None]:
# Load the coordinate set named in the config as a frame with point geometries.
coord_set_name = featurization_config["coord_set"]["coord_set_name"]
request_points_gdf = utl.load_df_w_latlons_to_gdf(dataset_name=coord_set_name)

# Smoke-test on a single point: the row at position 1, kept as a one-row frame
# (slicing rather than .iloc[1] so the result stays a frame, not a Series).
test_points_gdf = request_points_gdf.iloc[1:2]

In [None]:
# Fixed local folder for smoke-test outputs. For full runs, a config-derived
# folder from utl.make_output_folder_path(featurization_config) was used instead.
local_folder_path = Path("test_outputs")

# NOTE: despite its name this is a full S3 object key (one gzipped parquet
# file), not just the bucket.
s3_bucket_path = "s3://gs-test-then-delete/test_mosaiks_df.parquet.gzip"

## Local test run

In [None]:
from mosaiks.fetch import fetch_image_refs, create_data_loader
from mosaiks.featurize import create_features, make_result_df

In [None]:
%%time
points_gdf_with_stac = fetch_image_refs(
    test_points_gdf, 
    featurization_config['satellite_search_params']
)

data_loader = create_data_loader(
    points_gdf_with_stac=points_gdf_with_stac,
    satellite_params=satellite_config,
    batch_size=featurization_config["model"]["batch_size"],
)

X_features = create_features(
    dataloader=data_loader,
    n_features=featurization_config["model"]["num_features"],
    model=model,
    device=featurization_config["model"]["device"],
    min_image_edge=satellite_config["min_image_edge"],
)

result_df = make_result_df(
    features=X_features,
    mosaiks_col_names=mosaiks_col_names,
    context_gdf=points_gdf_with_stac,
    context_cols_to_keep=featurization_config["coord_set"]["context_cols_to_keep"],
)
result_df

In [None]:
# Save the local smoke-test result as CSV. Create the output folder first so the
# save does not fail on a fresh checkout where it does not exist yet
# (Path(...) also accepts a str if the folder path is switched to a string).
Path(local_folder_path).mkdir(parents=True, exist_ok=True)
utl.save_dataframe(
    df=result_df, file_path=f"{local_folder_path}/df_TEST.csv"
)

In [None]:
# Export the AWS credentials from secrets.yaml into the environment so the
# S3 write can authenticate; the `secrets` dict is also reused for the cluster.
secrets = utl.load_yaml_config("secrets.yaml")
for aws_key in ("AWS_ACCESS_KEY_ID", "AWS_SECRET_ACCESS_KEY"):
    os.environ[aws_key] = secrets[aws_key]

utl.save_dataframe(df=result_df, file_path=s3_bucket_path)

## Cloud run

In [None]:
from dask_cloudprovider.azure import AzureVMCluster

In [None]:
# Pip-installable link to the mosaiks package at the given ref (presumably a git
# branch); handed to the cluster workers via EXTRA_PIP_PACKAGES.
mosaiks_branch = "main"
mosaiks_git_link = utl.get_mosaiks_package_link(mosaiks_branch)

In [None]:
# Start a 4-worker Dask cluster on Azure VMs. Workers pip-install the mosaiks
# package and receive the AWS credentials (presumably so they can write results
# to the s3:// save path).
# Load the secrets here as well, so the cloud section still runs when the
# local S3-save cell was skipped (otherwise `secrets` would be undefined).
secrets = utl.load_yaml_config("secrets.yaml")
aws_credentials = {
    key: secrets[key] for key in ("AWS_ACCESS_KEY_ID", "AWS_SECRET_ACCESS_KEY")
}

cluster = AzureVMCluster(
    # NOTE(review): "leaninnvoation" looks like a typo of "leaninnovation";
    # confirm against the real Azure resource group before changing it.
    resource_group="leaninnvoation",
    vnet="aks-vnet-mosaik",
    security_group="aks-sg-mosaik",
    location="westeurope",
    env_vars={
        "EXTRA_PIP_PACKAGES": mosaiks_git_link,
        "USE_PYGEOS": "0",
        **aws_credentials,
    },
    n_workers=4,
)

In [None]:
# Requires Python 3.10 locally — presumably to match the Python version of the
# cluster's worker image (dask.distributed checks client/worker version
# consistency). TODO confirm.
from dask.distributed import Client

client = Client(address=cluster)
client

In [None]:
from mosaiks.dask import run_queued_futures_pipeline, run_batched_delayed_pipeline

In [None]:
%%time

run_queued_futures_pipeline(
    test_points_gdf,
    client=client,
    model=model,
    featurization_config=featurization_config,
    satellite_config=satellite_config,
    col_names=mosaiks_col_names,
    save_folder_path=s3_bucket_path,
)

In [None]:
%%time

run_batched_delayed_pipeline(
    test_points_gdf,
    client=client,
    model=model,
    featurization_config=featurization_config,
    satellite_config=satellite_config,
    col_names=mosaiks_col_names,
    save_folder_path=s3_bucket_path,
)

In [None]:
# Disconnect the client first, then shut the cluster down (which presumably
# releases the Azure VMs — don't leave them running).
for dask_resource in (client, cluster):
    dask_resource.close()