# Generating STAC Machine Learning Model (MLM) Items from a YAML Config Folder

This notebook reads all YAML configuration files in the [`configs/downscaling`](../configs/downscaling) folder,
merges them into a single configuration, and creates a STAC Item (together with its Collection and Catalog)
described with the [MLM Extension](https://github.com/stac-extensions/mlm).


In [None]:
!pip install pystac stac-model==0.4.0 resoterre

In [None]:
import json
from collections import OrderedDict
from dataclasses import asdict
from datetime import datetime, timezone
from pathlib import Path

import pystac
import shapely
import yaml
from dateutil.parser import parse as parse_dt
from pystac import STACValidationError
from pystac.extensions.datacube import DatacubeExtension, Dimension, Variable
from stac_model.input import InputStructure, ModelInput
from stac_model.output import ModelOutput, ModelResult
from stac_model.schema import MLModelExtension, MLModelProperties

from resoterre.datasets.hrdps.hrdps_variables import hrdps_variables
from resoterre.datasets.rdps.rdps_variables import rdps_variables
from resoterre.ml.network_manager import nb_of_parameters
from resoterre.ml.neural_networks_unet import UNet, UNetConfig

## Utility functions

In [None]:
def config_aggregator(config_folder: str | Path) -> dict:
    """
    Aggregate all YAML configuration files in a folder into a single dictionary.

    Args:
        config_folder (str | Path): Path to the folder containing YAML config files.

    Returns
    -------
        dict: Aggregated configuration from all YAML files.
    """
    config_folder = Path(config_folder)
    if not config_folder.exists():
        raise FileNotFoundError(f"Config folder not found: {config_folder}")

    aggregated_cfg = {}

    for yaml_file in sorted(config_folder.glob("*.yaml")):
        with yaml_file.open("r", encoding="utf-8") as f:
            cfg = yaml.safe_load(f) or {}  # treat empty YAML as empty dict
            aggregated_cfg.update(cfg)

    return aggregated_cfg


def reorder_stac_json(json_path: str | Path, top_keys: list[str]):
    """Reorder keys in a saved STAC JSON file so that top_keys appear right after 'id'."""
    json_path = Path(json_path)
    with json_path.open() as f:
        data = json.load(f)

    ordered = OrderedDict()
    for key in ["type", "stac_version", "stac_extensions", "id"]:
        if key in data:
            ordered[key] = data.pop(key)
    for key in top_keys:
        if key in data:
            ordered[key] = data.pop(key)
    for k, v in data.items():
        ordered[k] = v

    with json_path.open("w") as f:
        json.dump(ordered, f, indent=2)


def get_variable_dimensions(var_name: str) -> list[str]:
    """
    Return the datacube dimension names that apply to a variable.

    Every variable gets ``time``, ``latitude`` and ``longitude``. A ``level``
    dimension is appended when any of the following holds:

    - the name contains 'model_levels' or 'pressure_levels';
    - the variable is registered in ``rdps_variables`` (checked first) or
      ``hrdps_variables`` and its ``vertical_level`` attribute is not None;
    - the variable is in neither collection and the text after its last underscore
      is all digits and at least three characters long (e.g., HRDPS_P_TT_10000).
    """
    horizontal = ["time", "latitude", "longitude"]
    with_level = [*horizontal, "level"]

    # Multi-level naming patterns win regardless of any registry lookup.
    if any(marker in var_name for marker in ("model_levels", "pressure_levels")):
        return with_level

    # RDPS takes precedence over HRDPS when the name is known to both.
    handler = None
    for registry in (rdps_variables, hrdps_variables):
        if var_name in registry:
            handler = registry[var_name]
            break

    if handler is None:
        # Unknown variable: fall back to a numeric level suffix in the name.
        suffix = var_name.rsplit("_", 1)[-1]
        has_level = suffix.isdigit() and len(suffix) >= 3
    else:
        has_level = handler.vertical_level is not None

    return with_level if has_level else horizontal

## Function to convert a config to a STAC Item




In [None]:
def _flatten_command(command: str) -> str:
    """Collapse a multi-line command string into one single-spaced line."""
    # Remove line continuations (backslash + newline) and bare newlines.
    flattened = command.replace("\\\n", " ").replace("\n", " ")
    # Collapse multiple spaces into one.
    return " ".join(flattened.split())


def _format_utc_timestamp(dt: datetime) -> str:
    """
    Format a datetime as an ISO 8601 string with a trailing 'Z'.

    Naive datetimes are formatted as-is. Timezone-aware datetimes are converted to
    UTC first, so the offset is not emitted in addition to the 'Z' suffix.
    """
    if dt.utcoffset() is not None:
        dt = dt.astimezone(timezone.utc).replace(tzinfo=None)
    return dt.isoformat() + "Z"


def create_stac_item_from_config(yaml_config: dict) -> "pystac.Item":
    """
    Create a STAC Item (and associated Collection) from a YAML model configuration dictionary.

    Besides building the objects, this function writes files relative to the current
    working directory: ``./catalog.json`` (created if missing, otherwise loaded and
    updated), ``./ml-model-package/collection.json`` and
    ``./ml-model-package/<item id>/<item id>.json``. An item with the same id already
    present in the collection is replaced. The saved locations are printed.

    Parameters
    ----------
    yaml_config : dict
        Aggregated YAML configuration dictionary describing the model. The keys
        ``train_dataset`` (with ``start_datetime`` and ``end_datetime``), ``path_models``,
        ``networks`` (with a ``UNet`` entry), ``framework``, ``framework_version``,
        ``nb_of_new_epochs``, ``data_loader``, ``optimizers`` and
        ``north_america_8km_grid_extent_coordinates`` are required.
        ``preprocessing_command`` and ``device`` are optional.

    Returns
    -------
    pystac.Item
        The generated STAC Item.

    Raises
    ------
    KeyError
        If a required configuration key is missing.
    """
    cfg = yaml_config

    train_cfg = cfg["train_dataset"]
    model_path = cfg["path_models"] + "/unet-model.ckpt"
    model_name = "UNet"
    unet_cfg = cfg["networks"]["UNet"]

    # --- Model architecture and parameters ---
    # The network is instantiated only to count its parameters and resolve its class path.
    config = UNetConfig(
        in_channels=unet_cfg["in_channels"],
        out_channels=unet_cfg["out_channels"],
        depth=unet_cfg["depth"],
        initial_nb_of_hidden_channels=unet_cfg["initial_nb_of_hidden_channels"],
    )
    model = UNet(**asdict(config))
    nb_params = nb_of_parameters(model)
    module = model.__class__.__module__
    class_name = model.__class__.__name__
    arch = f"{module}.{class_name}"
    framework = cfg["framework"]
    framework_version = cfg["framework_version"]

    # --- Input/Output shapes ---
    input_variables = train_cfg.get("rdps_variables", [])
    output_variables = train_cfg.get("hrdps_variables", [])
    height, width = 64, 64  # TODO: Add to config

    # --- Input structure ---
    input_struct = InputStructure(
        shape=[-1, len(input_variables), height, width],
        dim_order=["time", "variables", "latitude", "longitude"],
        data_type="float32",
    )

    # --- Processing expression for pre-processing function ---
    # Only pass the pre-processing function when the config provides a command;
    # otherwise the field is left unset instead of referencing an undefined name.
    model_input_kwargs = {
        "name": "rdps_inputs",
        "variables": input_variables,
        "input": input_struct,
    }
    preprocessing_command = cfg.get("preprocessing_command", None)
    if preprocessing_command:
        model_input_kwargs["pre_processing_function"] = {
            "format": "snakemake",
            "expression": _flatten_command(preprocessing_command),
        }

    model_input = ModelInput(**model_input_kwargs)

    # --- Output structure ---
    result_struct = ModelResult(
        shape=[-1, len(output_variables), height, width],
        dim_order=["time", "variables", "latitude", "longitude"],
        data_type="float32",
    )

    model_output = ModelOutput(
        name="hrdps_outputs",
        variables=output_variables,
        tasks=["super-resolution", "downscaling"],
        result=result_struct,
    )

    # --- ML Model Properties ---
    ml_model_meta = MLModelProperties(
        name=f"{model_name} RDPS HRDPS Downscaling",
        architecture=arch,
        tasks={"super-resolution", "downscaling"},
        framework=framework,
        framework_version=framework_version,
        accelerator=cfg.get("device", "cuda"),
        pretrained=True,
        pretrained_source="Custom RDPS-HRDPS dataset",
        input=[model_input],
        output=[model_output],
        total_parameters=nb_params,
    )

    # --- Assets ---
    assets = {
        "model": pystac.Asset(
            title=f"{model_name} checkpoint",
            description=f"{model_name} trained on RDPS HRDPS downscaling task.",
            href=model_path,
            media_type="application/octet-stream; application/pytorch",
            roles=["mlm:model", "mlm:weights"],
            extra_fields={"mlm:artifact_type": "torch.save"},
        ),
        "source_code": pystac.Asset(
            title=f"Source code for {model_name}",
            description="GitHub repo of the PyTorch model",
            href="https://github.com/Ouranosinc/resoterre",
            media_type="text/html",
            roles=["mlm:source_code", "code"],
        ),
    }

    # --- Additional hyperparameters ---
    hyperparameters = {
        "nb_of_new_epochs": cfg["nb_of_new_epochs"],
        "networks": cfg["networks"],
        "data_loader": cfg["data_loader"],
        "optimizers": cfg["optimizers"],
    }

    # --- Temporal / spatial extent ---
    start_dt = parse_dt(train_cfg["start_datetime"])
    end_dt = parse_dt(train_cfg["end_datetime"])
    start_dt_str = _format_utc_timestamp(start_dt)
    end_dt_str = _format_utc_timestamp(end_dt)

    # --- Bounding box and geometry ---
    coords = cfg["north_america_8km_grid_extent_coordinates"]
    bbox = [coords["lon_min"], coords["lat_min"], coords["lon_max"], coords["lat_max"]]
    geometry = shapely.geometry.Polygon.from_bounds(*bbox).__geo_interface__

    # --- STAC Catalog ---
    catalog_href = "./catalog.json"  # Catalog file (create if missing)
    if not Path(catalog_href).exists():
        catalog = pystac.Catalog(
            id="ml-model-catalog",
            description="Catalog for ML model collections",
            title="ML Model Catalog",
        )
        catalog.set_self_href(catalog_href)
        catalog.save()
    else:
        catalog = pystac.Catalog.from_file(catalog_href)

    # --- STAC Collection ---
    collection_name = "ml-model-package"
    collection = pystac.Collection(
        id=collection_name,
        title="Machine Learning Model packaging",
        description="Collection of ML model items",
        strategy=pystac.layout.AsIsLayoutStrategy(),
        extent=pystac.Extent(
            spatial=pystac.SpatialExtent([bbox]),
            temporal=pystac.TemporalExtent([[start_dt, end_dt]]),
        ),
    )
    collection_href = "./ml-model-package/collection.json"
    collection.set_self_href(collection_href)

    # --- STAC Item ---
    item_name = f"{model_name.lower()}_rdps_to_hrdps"
    item = pystac.Item(
        id=item_name,
        collection=collection.id,
        geometry=geometry,
        bbox=bbox,
        datetime=None,
        properties={
            "description": f"{model_name} trained to downscale RDPS meteorological data to HRDPS resolution.",
            "start_datetime": start_dt_str,
            "end_datetime": end_dt_str,
            "datetime": None,
        },
        assets=assets,
        extra_fields={
            "mlm:entrypoint": arch,
            "mlm:hyperparameters": hyperparameters,
        },  # Path to the model class and hyperparameters
    )

    # --- Apply ML Model Extension ---
    item_mlm = MLModelExtension.ext(item, add_if_missing=True)
    item_mlm.apply(ml_model_meta.model_dump(by_alias=True, exclude_unset=True, exclude_defaults=True))

    # --- Define cube:dimensions ---

    levels = [850, 700, 500, 250]  # allowed pressures

    dimensions = {
        "time": Dimension(
            properties={
                "type": "temporal",
                "description": "Time dimension",
                "extent": [start_dt_str, end_dt_str],
                "unit": "s",
            }
        ),
        "latitude": Dimension(
            properties={
                "type": "spatial",
                "description": "Latitude",
                "extent": [bbox[1], bbox[3]],
                "axis": "y",
                "unit": "degree",
                "reference_system": "EPSG:4326",
            }
        ),
        "longitude": Dimension(
            properties={
                "type": "spatial",
                "description": "Longitude",
                "extent": [bbox[0], bbox[2]],
                "axis": "x",
                "unit": "degree",
                "reference_system": "EPSG:4326",
            }
        ),
        "level": Dimension(
            properties={
                "type": "spatial",
                "description": "Pressure levels",
                "extent": [min(levels), max(levels)],
                "values": levels,
                "axis": "z",
                "unit": "hPa",
            }
        ),
    }

    # --- Define cube:variables using rdps_variables and hrdps_variables ---
    variables = {}
    for var in input_variables + output_variables:
        dims = get_variable_dimensions(var)
        variables[var] = Variable(
            properties={
                "dimensions": dims,
                "type": "data",
                "data_type": "float32",
            }
        )

    # --- Apply Datacube Extension (dimensions and variables) to the Item ---
    item_dc = DatacubeExtension.ext(item_mlm.item, add_if_missing=True)
    item_dc.apply(dimensions=dimensions, variables=variables)

    # --- Now set a fixed flat path ---
    item_self_href = Path("ml-model-package") / item_dc.item.id / f"{item_name}.json"
    item.set_self_href(str(item_self_href))

    # --- Add or update catalog, collection and item ---
    existing_collection = next((c for c in catalog.get_children() if c.id == collection.id), None)
    if not existing_collection:
        catalog.add_child(collection)
    else:
        collection = existing_collection

    # Replace any previously registered item with the same id.
    existing_item = next((i for i in collection.get_all_items() if i.id == item.id), None)
    if existing_item:
        collection.remove_item(existing_item.id)
    collection.add_item(item)

    # --- Normalize collection links only ---

    # Use '.' as root_href to ensure all relative links are local.
    # This avoids absolute paths in the generated STAC JSON
    collection.normalize_hrefs(root_href=".", strategy=pystac.layout.AsIsLayoutStrategy())

    # --- Save collection and item
    collection.save_object(include_self_link=True)
    item.save_object(include_self_link=True)

    # --- Add collection to catalog and save catalog last

    # catalog_type=SELF_CONTAINED ensures that all references (collections, items)
    # are stored locally within the catalog folder structure respecting relative paths.
    # The use of catalog is needed to avoid having absolute paths in the generated STAC JSON file.
    catalog.save(catalog_type=pystac.CatalogType.SELF_CONTAINED)

    # --- Print relative paths ---
    cwd = Path.cwd()
    rel_item = Path(item.self_href).relative_to(cwd)
    rel_collection = Path(collection.self_href).relative_to(cwd)
    print(f"STAC Item saved to [{rel_item}]")
    print(f"Collection saved to [{rel_collection}] with {len(list(collection.get_all_items()))} item(s)")

    # --- Fix JSON order ---
    # Rewrites the saved item file only; the in-memory item is unchanged.
    top_keys = ["collection", "mlm:entrypoint", "mlm:hyperparameters"]
    reorder_stac_json(item.self_href, top_keys)

    return item

## Example usage

In [8]:
# Relative path to the folder holding the YAML configuration files.
config_folder = "../configs/downscaling"
# Merge every YAML file found in that folder into one configuration dictionary.
aggregated_config = config_aggregator(config_folder)
# Build the STAC Item from the merged configuration; the call also writes the catalog,
# collection and item JSON files and prints where they were saved.
stac_item = create_stac_item_from_config(aggregated_config)

STAC Item saved to [ml-model-package/unet_rdps_to_hrdps/unet_rdps_to_hrdps.json]
Collection saved to [ml-model-package/collection.json] with 1 item(s)


## Validate the collection and items

In [7]:
collection_path = "./ml-model-package/collection.json"
item_path = "./ml-model-package/unet_rdps_to_hrdps/unet_rdps_to_hrdps.json"

# --- Load Collection ---
collection = pystac.Collection.from_file(collection_path)
print("Collection loaded successfully")
print(f"Collection ID: {collection.id}")
item_count = len(list(collection.get_all_items()))
print(f"Number of items: {item_count}")

# Validate the collection; report a schema failure instead of stopping the cell.
try:
    collection.validate()
except STACValidationError as e:
    print(f"Collection validation error: {e}")
else:
    print("Collection is valid")

# --- Load Item ---
item = pystac.Item.from_file(item_path)
print("\nItem loaded successfully")
print(f"Item ID: {item.id}")

# Validate the item the same way.
try:
    item.validate()
except STACValidationError as e:
    print(f"Item validation error: {e}")
else:
    print("Item is valid")

# List each asset key with its href.
print("\nAssets in item:")
for key, asset in item.assets.items():
    print(f" - {key}: {asset.href}")

# Show the extension schema URIs declared on the item.
print("\nSTAC extensions used in item:", item.stac_extensions, sep="\n")

Collection loaded successfully
Collection ID: ml-model-package
Number of items: 1
Collection is valid

Item loaded successfully
Item ID: unet_rdps_to_hrdps
Item is valid

Assets in item:
 - model: /path/to/models/unet-model.ckpt
 - source_code: https://github.com/Ouranosinc/resoterre

STAC extensions used in item:
['https://stac-extensions.github.io/mlm/v1.5.0/schema.json', 'https://stac-extensions.github.io/datacube/v2.2.0/schema.json']
