In [None]:
import mosaiks.utils as utl
# Build the pip-installable link for the mosaiks package at ref
# "update_unit_tests" (presumably a git branch) to install on the Dask workers.
mosaiks_git_link = utl.get_mosaiks_package_link(
    "update_unit_tests"
)

In [None]:
# Show the package link that is handed to the workers via EXTRA_PIP_PACKAGES below.
mosaiks_git_link

In [None]:
from dask_cloudprovider.azure import AzureVMCluster

# Start Dask workers on Azure VMs. EXTRA_PIP_PACKAGES is passed to the worker
# environment so that (presumably via the dask image's startup hook) each
# worker pip-installs the mosaiks package from the link built above.
# NOTE(review): "leaninnvoation" looks like a misspelling of "leaninnovation";
# left unchanged because it must match the actual Azure resource group name.
cluster = AzureVMCluster(
    location="westeurope",
    resource_group="leaninnvoation",
    vnet="aks-vnet-mosaik",
    security_group="aks-sg-mosaik",
    n_workers=4,
    env_vars={"EXTRA_PIP_PACKAGES": mosaiks_git_link},
)

In [None]:
from dask.distributed import Client

# Requires Python 3.10 locally — presumably to match the Python version in the
# workers' image; confirm before changing the kernel.
client = Client(address=cluster)
client

In [None]:
import sys
import os
import warnings

# Presumably makes the repository root importable so `src.mosaiks...` resolves
# below. NOTE(review): "../" is relative to the kernel's working directory.
sys.path += ["../"]
# NOTE(review): this silences *all* warnings for the whole session, which can
# hide real problems; consider filtering specific categories instead.
warnings.filterwarnings("ignore")

from pathlib import Path
# NOTE(review): this rebinds `utl`, which the first cell bound to the installed
# `mosaiks.utils`, while `RCF` below comes from the installed `mosaiks` package.
# Mixing `src.mosaiks` and `mosaiks` can load two different copies of the code —
# confirm which one is intended.
import src.mosaiks.utils as utl
from mosaiks.featurize import RCF

In [None]:
# Export the rasterio settings (presumably GDAL/rasterio options) as environment
# variables. os.environ only accepts str keys and values, but YAML parses
# unquoted values such as 3, YES or true into int/bool, which would make
# os.environ.update raise TypeError — so coerce everything to str. An empty
# YAML file loads as None, so fall back to an empty mapping.
# NOTE(review): "rasterioc_config.yaml" may be a typo of "rasterio_config.yaml";
# confirm the file name.
rasterio_config = utl.load_yaml_config("rasterioc_config.yaml")
os.environ.update(
    {str(key): str(value) for key, value in (rasterio_config or {}).items()}
)

featurization_config = utl.load_yaml_config("featurisation.yaml")

# The satellite config holds one entry per satellite name; keep only the one
# selected in the featurisation config. Loading into a separate name avoids
# reusing `satellite_config` for two different kinds of object.
all_satellite_configs = utl.load_yaml_config("satellite_config.yaml")
satellite_config = all_satellite_configs[
    featurization_config["satellite_search_params"]["satellite_name"]
]

In [None]:
# One output column per model feature: mosaiks_0, mosaiks_1, ...
mosaiks_col_names = list(
    map("mosaiks_{}".format, range(featurization_config["model"]["num_features"]))
)
# Folder for the test-run CSV written further down.
test_mosaiks_folder_path = Path("test_outputs")

In [None]:
# Build the featurization model from positional arguments taken from the
# configs: feature count, kernel size, and the number of satellite bands.
# NOTE(review): presumably these are (output features, conv kernel size,
# input channels) — confirm against RCF's signature.
model = RCF(
    featurization_config["model"]["num_features"],
    featurization_config["model"]["kernel_size"],
    len(satellite_config["bands"]),
)

In [None]:
# Load the coordinate set named in the featurisation config.
# NOTE(review): presumably returns a GeoDataFrame of lat/lon points (per the
# helper's name); which file is read is decided inside utl.
request_points_gdf = utl.load_df_w_latlons_to_gdf(
    dataset_name=featurization_config["coord_set"]["coord_set_name"]
)

In [None]:
# Cap the run at the first 1000 points, then take the row at position 1
# (at most one row) for a quick end-to-end test.
points_gdf = request_points_gdf.head(1000)
test_points_gdf = points_gdf.iloc[1:2]

In [None]:
# Output folder passed to the distributed run below; create it if missing.
# Alternative kept for reference: utl.make_output_folder_path(featurization_config)
mosaiks_folder_path = Path("test_outputs")
mosaiks_folder_path.mkdir(parents=True, exist_ok=True)

In [None]:
from mosaiks.fetch import fetch_image_refs, create_data_loader
from mosaiks.featurize import create_features, make_result_df
import pandas as pd

In [None]:
%%time
# Look up satellite image references for the test point(s) using the search
# parameters from the featurisation config (presumably attaching STAC items,
# per the result's name).
points_gdf_with_stac = fetch_image_refs(
    test_points_gdf, 
    featurization_config['satellite_search_params']
)


In [None]:

# Featurize the test point(s) locally (without Dask) and save the result to CSV.
data_loader = create_data_loader(
    points_gdf_with_stac=points_gdf_with_stac,
    satellite_params=satellite_config,
    batch_size=featurization_config["model"]["batch_size"],
)

X_features = create_features(
    dataloader=data_loader,
    n_features=featurization_config["model"]["num_features"],
    model=model,
    device=featurization_config["model"]["device"],
    min_image_edge=satellite_config["min_image_edge"],
)

# The feature rows come from the data loader, which is built from
# points_gdf_with_stac, so label them with *that* frame's index. Using
# test_points_gdf.index would mislabel rows (or fail on a length mismatch)
# if fetch_image_refs reordered or dropped points.
df = pd.DataFrame(
    data=X_features, index=points_gdf_with_stac.index, columns=mosaiks_col_names
)

utl.save_dataframe(
    df=df, file_path=str(test_mosaiks_folder_path / "df_TEST2.csv")
)

In [None]:
from mosaiks.dask import run_queued_futures_pipeline

In [None]:
# Recomputes the same column names and test folder as the earlier setup cell;
# redundant but harmless. NOTE(review): consider removing this duplicate cell.
mosaiks_col_names = [
    "mosaiks_" + str(idx)
    for idx in range(featurization_config["model"]["num_features"])
]
test_mosaiks_folder_path = Path("test_outputs")

In [None]:
%%time

# Note: interrupting this cell does not stop the Dask cluster from processing
# tasks that were already submitted; use client.restart() for that.
# Runs the featurization pipeline on the Dask cluster; the save_folder_path
# argument suggests outputs are written under mosaiks_folder_path.
# NOTE(review): this passes the one-row test_points_gdf rather than the
# (up to) 1000-row points_gdf — confirm that is intended.
run_queued_futures_pipeline(
    test_points_gdf,
    client=client,
    model=model,
    featurization_config=featurization_config,
    satellite_config=satellite_config,
    col_names=mosaiks_col_names,
    save_folder_path=mosaiks_folder_path,
)

In [None]:
# Release resources: close the client, then shut down the Azure VM cluster.
# try/finally ensures the cluster's VMs are torn down even if closing the
# client raises; otherwise they would keep running.
try:
    client.close()
finally:
    cluster.close()