# Generating STAC MLM Items from a YAML Config

This notebook reads the `config.yaml` file and creates a STAC Item using the Machine Learning Model (MLM) extension.

In [1]:
!pip install pystac stac-model==0.4.0 resoterre



In [2]:
from dataclasses import asdict
from datetime import datetime
from pathlib import Path

import yaml
import shapely
from dateutil.parser import parse as parse_dt
import pystac
from pystac import STACValidationError
from pystac.extensions.datacube import DatacubeExtension, Dimension, Variable
from stac_model.input import InputStructure, ModelInput
from stac_model.output import ModelOutput, ModelResult
from stac_model.schema import MLModelExtension, MLModelProperties

from resoterre.ml.network_manager import nb_of_parameters
from resoterre.ml.neural_networks_unet import UNet, UNetConfig


## Utility functions

In [3]:
def dt_to_epoch(dt: datetime) -> float:
    """
    Convert a datetime to seconds since the Unix epoch.

    Naive datetimes (no ``tzinfo``) are interpreted as UTC, which matches how this
    notebook serialises them (ISO string with a trailing "Z"). Without this,
    ``datetime.timestamp()`` would interpret naive values in the machine's local
    timezone and the result would differ from one machine to another.

    Parameters
    ----------
    dt : datetime
        Naive (assumed UTC) or timezone-aware datetime.

    Returns
    -------
    float
        POSIX timestamp in seconds.
    """
    from datetime import timezone  # local import: the notebook only imports `datetime` at the top

    if dt.tzinfo is None:
        dt = dt.replace(tzinfo=timezone.utc)
    return dt.timestamp()

## Function to convert a config to a STAC Item




In [4]:
def _parse_utc_datetime(value) -> datetime:
    """Parse a config timestamp (string, date or datetime) into a timezone-aware UTC datetime."""
    from datetime import timezone  # local import: the notebook only imports `datetime` at the top

    # YAML may already have converted an unquoted timestamp into a datetime object,
    # which dateutil's parser rejects; only parse when the value is not a datetime.
    dt = value if isinstance(value, datetime) else parse_dt(str(value))
    if dt.tzinfo is None:
        # Naive values are treated as UTC (they were previously serialised with a trailing "Z").
        return dt.replace(tzinfo=timezone.utc)
    return dt.astimezone(timezone.utc)


def _utc_isoformat(dt: datetime) -> str:
    """Format a UTC-aware datetime as an ISO 8601 string ending in "Z"."""
    return dt.isoformat().replace("+00:00", "Z")


def _variable_dimensions(var: str) -> list[str]:
    """Return the cube dimensions of a known variable; raise ValueError for an unknown one."""
    if var in ("TT_model_levels", "HU850", "GZ500", "PC", "HRDPS_P_TT_10000"):
        return ["time", "latitude", "longitude", "level"]
    if var in ("PR", "HRDPS_P_PR_SFC"):
        return ["time", "latitude", "longitude"]  # Surface-level, no vertical
    raise ValueError(
        f"Unknown variable {var!r}: cannot determine its cube dimensions. "
        "Add it to the variable lists in _variable_dimensions."
    )


def create_stac_item_from_config(yaml_file: str) -> pystac.Item:
    """
    Create a STAC Item (and associated Catalog and Collection) from a YAML model configuration file.

    The function reads the model configuration, builds the MLM and Datacube extension
    metadata, and writes ``./catalog.json``, ``./ml-model-package/collection.json`` and
    the item JSON under ``ml-model-package/<item id>/``, relative to the current working
    directory. It prints the relative paths of the saved item and collection.

    Parameters
    ----------
    yaml_file : str
        Path to the YAML configuration file describing the model. The keys read are
        ``train_dataset`` (with ``start_datetime``, ``end_datetime`` and optionally
        ``rdps_variables`` / ``hrdps_variables``), ``path_models``, ``framework``,
        ``framework_version`` and optionally ``device``.

    Returns
    -------
    pystac.Item
        The generated STAC Item.

    Raises
    ------
    KeyError
        If a required configuration key is missing.
    ValueError
        If a configured variable has no known cube dimensions.
    """
    yaml_path = Path(yaml_file)
    with yaml_path.open() as f:
        cfg = yaml.safe_load(f)

    train_cfg = cfg["train_dataset"]
    model_path = cfg["path_models"]
    model_name = "UNet"

    # NOTE(review): channel counts are hard-coded and not derived from the configured
    # variable lists below -- confirm they match the trained checkpoint.
    config = UNetConfig(in_channels=4, out_channels=2, depth=3, initial_nb_of_hidden_channels=32)
    model = UNet(**asdict(config))
    nb_params = nb_of_parameters(model)
    module = model.__class__.__module__
    class_name = model.__class__.__name__
    arch = f"{module}.{class_name}"
    framework = cfg["framework"]
    framework_version = cfg["framework_version"]

    # --- Input/Output shapes ---
    input_variables = train_cfg.get("rdps_variables", [])
    output_variables = train_cfg.get("hrdps_variables", [])
    height, width = 64, 64  # TODO: Add to config

    # The axis names must follow the same order as `shape`, where the variable
    # (channel) axis is at index 1: [batch/time, variables, height, width].
    dim_order = ["time", "variables", "latitude", "longitude"]

    # --- Input structure ---
    input_struct = InputStructure(
        shape=[-1, len(input_variables), height, width],
        dim_order=list(dim_order),
        data_type="float32",
    )

    model_input = ModelInput(name="rdps_inputs", variables=input_variables, input=input_struct)

    # --- Output structure ---
    result_struct = ModelResult(
        shape=[-1, len(output_variables), height, width],
        dim_order=list(dim_order),
        data_type="float32",
    )

    model_output = ModelOutput(
        name="hrdps_outputs",
        variables=output_variables,
        tasks=["super-resolution", "downscaling"],
        result=result_struct,
    )

    # --- ML Model Properties ---
    ml_model_meta = MLModelProperties(
        name=f"{model_name} RDPS HRDPS Downscaling",
        architecture=arch,
        tasks={"super-resolution", "downscaling"},
        framework=framework,
        framework_version=framework_version,
        accelerator=cfg.get("device", "cuda"),
        pretrained=True,
        pretrained_source="Custom RDPS-HRDPS dataset",
        input=[model_input],
        output=[model_output],
        total_parameters=nb_params,
    )

    # TODO Validate and add preprocessing function

    # --- Assets ---
    assets = {
        "model": pystac.Asset(
            title=f"{model_name} checkpoint",
            description=f"{model_name} trained on RDPS HRDPS downscaling task.",
            href=model_path,
            media_type="application/octet-stream; application/pytorch",
            roles=["mlm:model", "mlm:weights"],
            extra_fields={"mlm:artifact_type": "torch.save"},
        ),
        "source_code": pystac.Asset(
            title=f"Source code for {model_name}",
            description="GitHub repo of the PyTorch model",
            href="https://github.com/Ouranosinc/resoterre",
            media_type="text/html",
            roles=["mlm:source_code", "code"],
        ),
    }

    # --- Temporal / spatial extent ---
    # Normalise to timezone-aware UTC so the ISO strings end in a single "Z"
    # (never "+00:00Z") and the epoch values do not depend on the local timezone.
    start_dt = _parse_utc_datetime(train_cfg["start_datetime"])
    end_dt = _parse_utc_datetime(train_cfg["end_datetime"])
    start_dt_str = _utc_isoformat(start_dt)
    end_dt_str = _utc_isoformat(end_dt)

    # NOTE(review): hard-coded extent -- confirm it matches the RDPS/HRDPS training domain.
    bbox = [-7.88219, 37.13739, 27.91165, 58.21798]
    geometry = shapely.geometry.Polygon.from_bounds(*bbox).__geo_interface__

    # --- STAC Catalog ---
    catalog_href = "./catalog.json"  # Catalog file (create if missing)
    if not Path(catalog_href).exists():
        catalog = pystac.Catalog(
            id="ml-model-catalog",
            description="Catalog for ML model collections",
            title="ML Model Catalog",
        )
        catalog.set_self_href(catalog_href)
        catalog.save()
    else:
        catalog = pystac.Catalog.from_file(catalog_href)

    # --- STAC Collection ---
    collection_name = "ml-model-package"
    collection = pystac.Collection(
        id=collection_name,
        title="Machine Learning Model packaging",
        description="Collection of ML model items",
        strategy=pystac.layout.AsIsLayoutStrategy(),
        extent=pystac.Extent(
            spatial=pystac.SpatialExtent([bbox]),
            temporal=pystac.TemporalExtent([[start_dt, end_dt]]),
        ),
    )
    collection_href = "./ml-model-package/collection.json"
    collection.set_self_href(collection_href)

    # --- STAC Item ---
    item_name = f"{model_name.lower()}_rdps_to_hrdps"
    item = pystac.Item(
        id=item_name,
        collection=collection.id,
        geometry=geometry,
        bbox=bbox,
        datetime=None,
        properties={
            "description": f"{model_name} trained to downscale RDPS meteorological data to HRDPS resolution.",
            "start_datetime": start_dt_str,
            "end_datetime": end_dt_str,
            "datetime": None,
        },
        assets=assets,
        extra_fields={"mlm:entrypoint": arch},  # Path to the model class
        stac_extensions=[
            MLModelExtension.get_schema_uri(),
            "https://stac-extensions.github.io/datacube/v2.3.0/schema.json",
        ],
    )

    # --- Define cube:dimensions ---
    # NOTE(review): the time/latitude/longitude dimensions are declared with
    # type "geometry" plus a "bbox"; the datacube extension also defines
    # "temporal" and "spatial" dimension types -- confirm which is intended.

    levels = [850, 700, 500, 250]  # allowed pressures

    dimensions = {
        "time": Dimension(
            properties={
                "type": "geometry",
                "dim_type": "temporal",
                "description": "Time dimension",
                "bbox": [dt_to_epoch(start_dt), dt_to_epoch(start_dt), dt_to_epoch(end_dt), dt_to_epoch(end_dt)],
                "extent": [start_dt_str, end_dt_str],
                "values": [start_dt_str, end_dt_str],
                "axis": "t",
                "unit": "seconds",
            }
        ),
        "latitude": Dimension(
            properties={
                "type": "geometry",
                "dim_type": "spatial",
                "description": "Latitude",
                "bbox": [bbox[1], bbox[1], bbox[3], bbox[3]],
                "extent": [bbox[1], bbox[3]],
                "axis": "y",
                "unit": "degrees",
                "reference_system": "EPSG:4326",
            }
        ),
        "longitude": Dimension(
            properties={
                "type": "geometry",
                "dim_type": "spatial",
                "description": "Longitude",
                "bbox": [bbox[0], bbox[0], bbox[2], bbox[2]],
                "extent": [bbox[0], bbox[2]],
                "axis": "x",
                "unit": "degrees",
                "reference_system": "EPSG:4326",
            }
        ),
        "level": Dimension(
            properties={
                "type": "spatial",
                "dim_type": "vertical",
                "description": "Pressure levels",
                "bbox": [min(levels), max(levels), min(levels), max(levels)],
                "extent": [min(levels), max(levels)],
                "values": levels,
                "axis": "z",
                "unit": "hPa",
            }
        ),
    }

    # --- Define cube:variables ---
    # Unknown variables raise instead of silently reusing the previous variable's dimensions.
    variables = {
        var: Variable(
            properties={
                "dimensions": _variable_dimensions(var),
                "type": "data",
                "data_type": "float32",
            }
        )
        for var in input_variables + output_variables
    }

    # --- Apply the ML Model and Datacube extensions to the Item (once) ---
    item_mlm = MLModelExtension.ext(item, add_if_missing=True)
    item_mlm.apply(ml_model_meta.model_dump(by_alias=True, exclude_unset=True, exclude_defaults=True))
    item_dc = DatacubeExtension.ext(item_mlm.item, add_if_missing=True)
    item_dc.apply(dimensions=dimensions, variables=variables)

    # --- Now set a fixed flat path ---
    item_self_href = Path("ml-model-package") / item_dc.item.id / f"{item_name}.json"
    item.set_self_href(str(item_self_href))

    # --- Add or update catalog, collection and item ---
    existing_collection = next((c for c in catalog.get_children() if c.id == collection.id), None)
    if not existing_collection:
        catalog.add_child(collection)
    else:
        collection = existing_collection

    # Replace any previously saved item with the same id so re-runs do not duplicate it.
    existing_item = next((i for i in collection.get_all_items() if i.id == item.id), None)
    if existing_item:
        collection.remove_item(existing_item.id)
    collection.add_item(item)

    # --- Normalize collection links only ---
    collection.normalize_hrefs(root_href=".", strategy=pystac.layout.AsIsLayoutStrategy())

    # --- Save collection and item
    collection.save_object(include_self_link=True)
    item.save_object(include_self_link=True)

    # --- Add collection to catalog and save catalog last
    catalog.save(catalog_type=pystac.CatalogType.SELF_CONTAINED)

    # --- Print relative paths ---
    cwd = Path.cwd()
    rel_item = Path(item.self_href).relative_to(cwd)
    rel_collection = Path(collection.self_href).relative_to(cwd)
    print(f"STAC Item saved to {rel_item}")
    print(f"Collection saved to {rel_collection} with {len(list(collection.get_all_items()))} item(s)")

    return item

## Example Usage 

In [5]:
# Example training configuration, given relative to the notebook's working directory.
yaml_file = "../configs/downscaling/downscaling_training_rdps_to_hrdps.yaml"

# Generate and save the STAC metadata for that configuration.
stac_item = create_stac_item_from_config(yaml_file=yaml_file)

STAC Item saved to ml-model-package/unet_rdps_to_hrdps/unet_rdps_to_hrdps.json
Collection saved to ml-model-package/collection.json with 1 item(s)


## Validate the collection and items

In [6]:
# Locations of the files written by create_stac_item_from_config.
collection_path = "./ml-model-package/collection.json"
item_path = "./ml-model-package/unet_rdps_to_hrdps/unet_rdps_to_hrdps.json"

# --- Load Collection ---
collection = pystac.Collection.from_file(collection_path)
n_items = len(list(collection.get_all_items()))
print("Collection loaded successfully")
print(f"Collection ID: {collection.id}")
print(f"Number of items: {n_items}")

# Validate the collection; report a schema failure instead of stopping the notebook.
try:
    collection.validate()
except STACValidationError as err:
    print(f"Collection validation error: {err}")
else:
    print("Collection is valid")

# --- Load Item ---
item = pystac.Item.from_file(item_path)
print("\nItem loaded successfully")
print(f"Item ID: {item.id}")

# Validate the item the same way.
try:
    item.validate()
except STACValidationError as err:
    print(f"Item validation error: {err}")
else:
    print("Item is valid")

# List the assets and the extensions declared on the item.
print("\nAssets in item:")
for asset_key, asset_obj in item.assets.items():
    print(f" - {asset_key}: {asset_obj.href}")

print("\nSTAC extensions used in item:")
print(item.stac_extensions)

Collection loaded successfully
Collection ID: ml-model-package
Number of items: 1
Collection is valid

Item loaded successfully
Item ID: unet_rdps_to_hrdps
Item is valid

Assets in item:
 - model: /path/to/models
 - source_code: https://github.com/Ouranosinc/resoterre

STAC extensions used in item:
['https://stac-extensions.github.io/mlm/v1.5.0/schema.json', 'https://stac-extensions.github.io/datacube/v2.3.0/schema.json']
