# STAC EuroSAT

This notebook demonstrates how to convert annotations provided by the [EuroSAT](https://github.com/phelber/EuroSAT) dataset
into STAC-compatible definitions with extensions relevant for machine learning tasks.
Notably, the STAC [Label](https://github.com/stac-extensions/label) and [Scientific](https://github.com/stac-extensions/scientific)
extensions are used, respectively, to reference the labeled annotations from the train, validation and test splits and to provide
a citation reference to the original work.

To facilitate parsing of the EuroSAT metadata itself,
the [torchgeo.datasets.EuroSAT](https://torchgeo.readthedocs.io/en/stable/api/datasets.html#torchgeo.datasets.EuroSAT)
class is used to handle metadata extraction, parse the labeled data hierarchy, and generate the split
definitions and the samples they contain.

## First Step

The `torchgeo.datasets.EuroSAT100` and `torchgeo.datasets.EuroSAT` classes below share the same interface and define a subset and the complete dataset,
respectively. While developing or editing the STAC generation pipeline, it is recommended to work with the 100-sample subset to speed up
processing and quickly preview the expected result.

In [6]:
# pick one:
from torchgeo.datasets import EuroSAT100 as DatasetEuroSAT  # subset of (6 train, 2 val, 2 test) image samples per class (10)
# from torchgeo.datasets import EuroSAT as DatasetEuroSAT  # full dataset

## General configurations

In [7]:
import os

DATA_ROOT_DIR = os.path.abspath("../data")
CATALOG_ROOT_DIR = DATA_ROOT_DIR
EUROSAT_ROOT_DIR = os.path.join(DATA_ROOT_DIR, "EuroSAT")
EUROSAT_DATA_DIR = os.path.join(EUROSAT_ROOT_DIR, "data")
EUROSAT_STAC_DIR = os.path.join(EUROSAT_ROOT_DIR, "stac")
EUROSAT_STAC_URL = "https://example.com/"  # base URL where samples would be accessible from (links in STAC)

os.makedirs(EUROSAT_DATA_DIR, exist_ok=True)
os.makedirs(EUROSAT_STAC_DIR, exist_ok=True)

## STAC Definitions

The following types are **NOT** used for "strong" type checking.
They are provided as a reference for the expected STAC properties and to let IDEs quickly validate the structure.
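
As a small illustration of why these definitions are only hints, the sketch below shows that a `TypedDict` is not enforced at runtime. The `MiniLink` and `MiniCatalog` names are hypothetical, trimmed-down stand-ins and are not part of this notebook's definitions.

```python
from typing import List, TypedDict


class MiniLink(TypedDict):
    # hypothetical, trimmed-down counterpart of a STAC link
    rel: str
    href: str


class MiniCatalog(TypedDict):
    id: str
    links: List[MiniLink]


# "links" is missing and "id" is not a string, yet nothing is raised at runtime:
# the value is a plain dict, and only static checkers or IDEs report the mismatch
incomplete: MiniCatalog = {"id": 123}  # type: ignore

is_plain_dict = type(incomplete) is dict
declared_keys = sorted(MiniCatalog.__annotations__)
print(is_plain_dict, declared_keys)
```

Actual conformance is therefore left to the schema validation step at the end of the notebook.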

### Base Definitions

In [20]:
import datetime
import os
from typing import Any, Dict, List, Literal, Optional, TypedDict, Union
from typing_extensions import NotRequired, Required


STAC_VERSION = "1.0.0"

STAC_CATALOG_EXTENSIONS = []
STAC_CATALOG_SCHEMAS = [
    f"https://schemas.stacspec.org/v{STAC_VERSION}/catalog-spec/json-schema/catalog.json",
] + STAC_CATALOG_EXTENSIONS

STAC_COLLECTION_EXTENSIONS = [
    "https://stac-extensions.github.io/ml-aoi/v0.1.0/schema.json",
    "https://stac-extensions.github.io/version/v1.0.0/schema.json",
]
STAC_COLLECTION_SCHEMAS = [
    f"https://schemas.stacspec.org/v{STAC_VERSION}/collection-spec/json-schema/collection.json",
] + STAC_COLLECTION_EXTENSIONS

STAC_ITEM_EXTENSIONS = [
    "https://stac-extensions.github.io/file/v1.0.0/schema.json",
    "https://stac-extensions.github.io/label/v1.0.1/schema.json",
    "https://stac-extensions.github.io/ml-aoi/v0.1.0/schema.json",
    "https://stac-extensions.github.io/scientific/v1.0.0/schema.json",
    "https://stac-extensions.github.io/version/v1.0.0/schema.json",
]
STAC_ITEM_SCHEMAS = [
    f"https://schemas.stacspec.org/v{STAC_VERSION}/item-spec/json-schema/item.json",
] + STAC_ITEM_EXTENSIONS

# technically, tuples would be better for bbox/point, but not the types used in JSON
Number = Union[int, float]
BoundingBox = List[Number]  # 4 values
Point = List[Number]  # 2 values
DateTimeInterval = List[Union[str, None]]
GeoJSONGeometry = TypedDict(
    "GeoJSONGeometry",
    {
        "type": Literal["Polygon", "MultiPolygon"],  # others exist, but not really applicable for this case
        "coordinates": List[List[Point]],  # list of linear rings for a Polygon (at least 4 points each); MultiPolygon nests one level deeper
    }
)
SpatialExtent = TypedDict(
    "SpatialExtent",
    {
        "bbox": Required[List[BoundingBox]],
    }
)
TemporalExtent = TypedDict(
    "TemporalExtent",
    {
        "interval": Required[List[DateTimeInterval]],
    }
)
Extent = TypedDict(
    "Extent",
    {
        "spatial": Required[SpatialExtent],
        "temporal": Required[TemporalExtent],
    }
)
Provider = TypedDict(
    "Provider",
    {
        "name": str,
        "roles": List[str],
        "url": str,
    }
)
Link = TypedDict(
    "Link",
    {
        "rel": str,
        "href": str,
        "type": str,  # media-type
        "title": NotRequired[str],
    },
    total=False,
)
STACMetadata = TypedDict(
    "STACMetadata",
    {
        "stac_version": Required[str],
        "type": Required[Literal["Catalog", "Collection", "Feature"]],  # NB: Feature == STAC Item
        "id": Required[str],
        "title": NotRequired[str],
        "description": NotRequired[str],
        "links": Required[List[Link]],
    }
)

### STAC Catalog

In [21]:
STACCatalog = STACMetadata  # only requires "links" to contain the STAC Collection as "child"
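
For illustration, a minimal catalog that follows this structure could look like the sketch below. The URLs are assumptions built from the `https://example.com/` placeholder used in this notebook, and the single `child` link stands for one STAC Collection.

```python
import json

# hypothetical URLs, used only for illustration
minimal_catalog = {
    "stac_version": "1.0.0",
    "type": "Catalog",
    "id": "example",
    "description": "Catalog that references a single collection.",
    "links": [
        {"rel": "root", "href": "https://example.com/catalog.json", "type": "application/json"},
        {"rel": "child", "href": "https://example.com/EuroSAT/stac/train/collection.json", "type": "application/json"},
    ],
}

# the collections referenced by the catalog are the links with the "child" relation
child_links = [link for link in minimal_catalog["links"] if link["rel"] == "child"]
print(json.dumps(child_links, indent=2))
```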

### STAC Collection and Extensions

In [22]:
STACMetadataCollection = TypedDict(
    "STACMetadataCollection",
    {
        "stac_extensions": Required[List[str]],
        "version": Required[str],
        "keywords": NotRequired[List[str]],
        "license": Required[str],  # anything, but commonly "CC-BY-SA-4.0"
    }
)
STACMetadataCollection = Union[STACMetadataCollection, STACMetadata]
STACExtensionVersion = TypedDict(
    "STACExtensionVersion",
    {
        "version": Required[str],
        "deprecated": NotRequired[bool],
        "experimental": NotRequired[bool],
    }
)
STACCoreCollection = TypedDict(
    "STACCoreCollection",
    {
        "extent": Required[Extent],
        "providers": NotRequired[List[Provider]],
    }
)
STACLabelRef = TypedDict(
    "STACLabelRef",
    {
        "title": str,
        "href": str,    # URL to GeoJSON FeatureCollection or GeoTiff/COG
        "type": str,    # media-type
    }
)
STACLabelAssets = TypedDict(
    "STACLabelAssets",
    {
        "labels": STACLabelRef,
        "raster": STACLabelRef,
    }
)
STACLabelClass = TypedDict(
    "STACLabelClass",
    {
        # name that can define a "category" of class names
        # those categories should be specified in Features under the keys defined by 'label:properties'
        "name": Required[Union[str, None]],
        # all the applicable classes that should be part of the "category"
        "classes": Required[Union[List[str], List[int]]],
    }
)
STACLabelCount = TypedDict(
    "STACLabelCount",
    {
        "name": str,  # class
        "count": int,
    }
)
STACLabelOverview = TypedDict(
    "STACLabelOverview",
    {
        "property_key": str,
        "counts": List[STACLabelCount],
    }
)
STACLabelProperties = TypedDict(
    "STACLabelProperties",
    {
        # properties in the linked 'labels' asset with GeoJSON FeatureCollection
        # those properties should be provided for each Feature in the GeoJSON
        # the values of those properties should contain the keys from 'label:classes'
        "label:properties": Required[List[str]],
        "label:classes": Required[List[STACLabelClass]],
        "label:type": Required[Literal["raster", "vector"]],
        "label:description": Required[str],
        "label:methods": NotRequired[List[Literal["manual", "automatic"]]],
        "label:tasks": NotRequired[List[Literal["classification", "detection", "segmentation", "regression"]]],
        "label:overviews": NotRequired[List[Dict[str, Any]]],
    }
)
STACExtensionLabel = TypedDict(
    "STACExtensionLabel",
    {
        "assets": Required[STACLabelAssets],
        "properties": Required[Union[STACLabelProperties, STACExtensionVersion]],
    }
)
STACCitation = TypedDict(
    "STACCitation",
    {
        "doi": str,
        "citation": str,
    }
)
STACExtensionScientific = TypedDict(
    "STACExtensionScientific",
    {
        # how to cite this collection
        "sci:doi": Required[str],  # ++ "Link" with 'cite-as' using the DOI reference (RFC-8574)
        "sci:citation": Required[str],
        # related work/citations that use this STAC data collection
        "sci:publications": NotRequired[List[STACCitation]],
    }
)
STACExtendedCollection = Union[
    STACMetadataCollection,
    STACCoreCollection,
    STACExtensionScientific,
    STACExtensionVersion,
]

### STAC Item and Assets

In [23]:
STACCoreMetadataItem = Union[STACMetadata, ]
STACCoreItemProperties = TypedDict(
    "STACCoreItemProperties",
    {
        "datetime": Required[str],
        "license": Required[str],
    },
    total=False,
)
STACCoreFeatureItem = TypedDict(
    "STACCoreFeatureItem",
    {
        "type": Literal["Feature"],
        "bbox": Required[BoundingBox],
        "geometry": Required[GeoJSONGeometry],
        "assets": Required[Dict[str, Dict[str, Any]]],  # mapping of asset key to asset definition
        "properties": Required[STACCoreItemProperties],
        "collection": str,
    }
)
STACExtendedItem = Union[
    STACCoreMetadataItem,
    STACCoreFeatureItem,
    STACExtensionLabel,
]

### STAC Metadata Definition

To make the calling functions more succinct, start by defining constant metadata references used by STAC Collections.

In [24]:
EUROSAT_STAC_COLLECTION_BASE: STACExtendedCollection = {
    "stac_version": STAC_VERSION,
    "stac_extensions": STAC_COLLECTION_EXTENSIONS,
    "type": "Collection",
    "version": "1",
    "experimental": True,
    # https://github.com/phelber/EuroSAT/blob/master/README.md
    "sci:doi": "arXiv:1709.00029",
    "sci:citation": (
        "Eurosat: A novel dataset and deep learning benchmark for land use and land cover classification. "
        "Patrick Helber, Benjamin Bischke, Andreas Dengel, Damian Borth. "
        "IEEE Journal of Selected Topics in Applied Earth Observations and Remote Sensing, 2019."
    ),
    "sci:publications": [
        {
            "doi": "10.1109/IGARSS.2018.8519248",
            "citation": (
                "Introducing EuroSAT: A Novel Dataset and Deep Learning Benchmark for Land Use and Land Cover Classification. "
                "Patrick Helber, Benjamin Bischke, Andreas Dengel. 2018 "
                "IEEE International Geoscience and Remote Sensing Symposium, 2018."
            ),
        }
    ],
    "links": [
        {
            "rel": "cite-as",
            "href": "https://arxiv.org/abs/1709.00029",
            "title": "EuroSAT: A Novel Dataset and Deep Learning Benchmark for Land Use and Land Cover Classification",
        }
    ]
    # other items to fill by script
}
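
Since the base definition holds a nested `links` list, any per-split copy derived from it shares that list unless it is deep-copied. The sketch below illustrates the difference with a hypothetical stand-in dictionary, not the actual base definition.

```python
from copy import deepcopy

# hypothetical stand-in for a base definition that holds a nested list
base = {"type": "Collection", "links": [{"rel": "cite-as"}]}

shallow = base.copy()
shallow["links"].append({"rel": "item"})  # also mutates base["links"]
links_after_shallow = len(base["links"])

base = {"type": "Collection", "links": [{"rel": "cite-as"}]}
deep = deepcopy(base)
deep["links"].append({"rel": "item"})  # base["links"] is left untouched
links_after_deep = len(base["links"])

print(links_after_shallow, links_after_deep)
```

With a shallow copy, links appended for one split would leak into every other collection derived from the same base.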

## Populate STAC Catalog, Collections, Items and Assets from EuroSAT Dataset

**Note** <br>
Because we need the metadata to populate STAC Collections, Items and Assets, the `download=True` parameter is used.
However, we don't need the actual data (imagery pixel values), but instead the metadata that each GeoTiff contains.
Therefore, we override the `__getitem__` method to retrieve only metadata, by bypassing image loading/conversion, and make parsing faster.
Nevertheless, it can take some time to download and extract the ZIP contents on the first run.
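
The override pattern described in this note can be sketched in isolation as follows. `FakeImageDataset` is a hypothetical stand-in for a dataset class that decodes pixels on access; it is not the torchgeo implementation.

```python
from typing import Dict, List, Tuple


class FakeImageDataset:
    """Hypothetical stand-in for a dataset class that loads pixels on access."""

    classes: List[str] = ["Forest", "River"]

    def __init__(self) -> None:
        # (image path, class index) pairs, mirroring a `samples` attribute
        self.samples: List[Tuple[str, int]] = [("a.tif", 0), ("b.tif", 1)]

    def __len__(self) -> int:
        return len(self.samples)

    def __getitem__(self, index: int) -> Dict[str, object]:
        raise RuntimeError("expensive image decoding would happen here")


class MetadataOnlyDataset(FakeImageDataset):
    def __getitem__(self, index: int) -> Dict[str, object]:
        # bypass image loading entirely and return only what is needed for STAC metadata
        path, target = self.samples[index]
        return {"image": path, "index": index, "label": str(target), "class": self.classes[target]}


dataset = MetadataOnlyDataset()
records = [dataset[i] for i in range(len(dataset))]
print(records)
```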

In [37]:
import json
from copy import deepcopy
from datetime import datetime

import rasterio
from rasterio.crs import CRS
from rasterio.warp import transform_bounds, transform_geom
from osgeo import gdal
from shapely.geometry import box
from torchvision.datasets.folder import pil_loader
from PIL.Image import Resampling


SampleMetadata = TypedDict(
    "SampleMetadata",
    {
        "image": str,
        "index": int,
        "label": str,
        "class": str,
    }
)

EPSG_4326 = CRS.from_epsg(4326)  # required by STAC as standard for all Geo references


class DataLoaderEuroSAT(DatasetEuroSAT):
    def __getitem__(
        self,
        sample_index: int,
    ) -> Optional[SampleMetadata]:  # type: ignore  # mismatch 'Dict[str, Tensor]' on purpose
        img_path, target_index = self.samples[sample_index]
        class_name = self.classes[target_index]
        return {"image": img_path, "index": sample_index, "label": str(target_index), "class": class_name}


def save_thumbnail(src_path: str, png_path: str) -> None:
    if not os.path.isfile(png_path):
        os.makedirs(os.path.dirname(png_path), exist_ok=True)
        img = pil_loader(src_path)
        w_size = 64
        if img.size != (w_size, w_size):
            w_scale = w_size / float(img.size[0])
            h_size = int(float(img.size[1]) * float(w_scale))
            img = img.resize((w_size, h_size), Resampling.LANCZOS)
        img.save(png_path)


def convert_sample_to_stac_item(sample: SampleMetadata) -> STACExtendedItem:
    #img = gdal.Open(path)
    #geo = img.GetGeoTransform()
    img = rasterio.open(sample["image"])
    png_path = sample["image"].replace("/tif/", "/png/").replace(".tif", ".png")
    save_thumbnail(sample["image"], png_path)
    bbox_bounds = transform_bounds(
        img.crs,
        EPSG_4326,
        *img.bounds,
    )
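    # bbox_bounds is already expressed in EPSG:4326, so the footprint is built from it directly
    # (reprojecting it again from the image CRS would distort the coordinates)
    geom = box(*bbox_bounds)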
    image_name = os.path.splitext(os.path.split(sample["image"])[-1])[0]
    label_geojson = geom.__geo_interface__  # GeoJSON (no trailing comma, which would wrap it in a tuple)
    label_path = sample["image"].replace("/tif/", "/label/").replace(".tif", ".geojson")
    os.makedirs(os.path.dirname(label_path), exist_ok=True)
    save_json(label_geojson, label_path)
    class_name = sample["class"]
    stac_item: STACExtendedItem = {
        "stac_version": STAC_VERSION,
        "stac_extensions": STAC_ITEM_EXTENSIONS,
        "type": "Feature",
        "bbox": bbox_bounds,
        "geometry": geom.__geo_interface__,  # GeoJSON
        "assets": {
            "labels": {
                "title": f"Labels for image {image_name} with {class_name} class",
                "href": get_url(label_path),
                "type": "application/geo+json"
            },
            "raster": {
                "title": f"Raster {image_name} with {class_name} class",
                "href": get_url(sample["image"]),
                "type": "image/tiff; application=geotiff",  # add "; profile=cloud-optimized" if applicable
            },
            "thumbnail": {
                "title": f"Preview of {image_name}.",
                "href": get_url(png_path),
                "type": "image/png",
            }
        },
        "properties": {
            "datetime": datetime.utcnow().isoformat(timespec="seconds") + "Z",  # RFC 3339 timestamp in UTC
            "license": "CC-BY-4.0",
            "version": "1",
            "label:properties": [
                "class"
            ],
            "label:methods": ["manual"],
            "label:description": "Land-cover area classification on Sentinel-2 image.",
            "label:classes": [
                {
                    # property key listed in 'label:properties' and the class names it can take
                    "name": "class",
                    "classes": [class_name],
                }
            ],
            "label:overviews": [
                # basic overview since each sample has its own STAC Item
                {
                    "property_key": "class",
                    "counts": [{"name": sample["class"], "count": 1}],
                }
            ]
        }
    }
    return stac_item


def get_url(url_path: str) -> str:
    url_path = url_path.replace(CATALOG_ROOT_DIR, "").lstrip("/")
    url_path = os.path.join(EUROSAT_STAC_URL, url_path)
    return url_path


def save_json(data: Dict[str, Any], json_file_path: str) -> None:
    with open(json_file_path, mode="w", encoding="utf-8") as fs:
        json.dump(data, fs, indent=2, ensure_ascii=False, sort_keys=False)


def load_json(json_file_path: str) -> Dict[str, Any]:
    with open(json_file_path, mode="r", encoding="utf-8") as fs:
        return json.load(fs)
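
To make the path-to-URL mapping concrete, here is a minimal standalone sketch. `CATALOG_ROOT`, `BASE_URL` and `to_url` are hypothetical names with assumed values. This variant joins the URL with plain string operations instead of `os.path.join`, so the result does not depend on the platform's path separator.

```python
import posixpath

# hypothetical values standing in for the notebook's configuration
CATALOG_ROOT = "/work/data"
BASE_URL = "https://example.com/"


def to_url(local_path: str) -> str:
    # drop the catalog root prefix, then append the remainder to the base URL
    relative = posixpath.relpath(local_path, CATALOG_ROOT)
    return BASE_URL.rstrip("/") + "/" + relative


item_url = to_url("/work/data/EuroSAT/stac/train/item-0.json")
print(item_url)
```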

In [38]:
CATALOG_URL = get_url("catalog.json")
CATALOG_ROOT_LINK = {
  "rel": "root",
  "href": CATALOG_URL,
  "type": "application/json"
}


def generate_stac_collections():
    stac_catalog_collection_links: List[Link] = []
    for split in DatasetEuroSAT.splits:
        # deep copy so that the nested "links" list is not shared between splits
        stac_collection = deepcopy(EUROSAT_STAC_COLLECTION_BASE)
        stac_collection_path = os.path.join(EUROSAT_STAC_DIR, split, "collection.json")
        stac_collection_url = get_url(stac_collection_path)
        stac_collection_link: Link = {
            "rel": "collection",
            "href": stac_collection_url,
            "type": "application/json",
            "title":  f"EuroSAT STAC Collection with samples from '{split}' split.",
        }
        stac_collection_parent = stac_collection_link.copy()
        stac_collection_parent["rel"] = "parent"

        data_loader = DataLoaderEuroSAT(root=EUROSAT_DATA_DIR, split=split, download=True)
        for sample in data_loader:
            label = sample["label"]
            if not label:  # ignore tiles by themselves without annotations
                continue
            index = sample["index"]
            item = convert_sample_to_stac_item(sample)
            item["links"] = [
                CATALOG_ROOT_LINK,
                stac_collection_parent,
                stac_collection_link,
            ]
            path = os.path.join(EUROSAT_STAC_DIR, split, f"item-{index}.json")
            os.makedirs(os.path.dirname(path), exist_ok=True)
            save_json(item, path)
            stac_collection["links"].append(
                {
                    "rel": "item",
                    "href": get_url(path),  # same location as the saved file, relative to the catalog root
                    "type": "application/geo+json",
                }
            )

        stac_collection_self = stac_collection_link.copy()
        stac_collection_self["rel"] = "self"
        stac_catalog_collection_links.append(stac_collection_self)
        stac_collection["links"].extend([
            CATALOG_ROOT_LINK,
            stac_collection_link,
            {
                "rel": "parent",
                "href": get_url("catalog.json"),
                "type": "application/json",
                "title": "STAC Catalog",
            }
        ])
        save_json(stac_collection, stac_collection_path)

    return stac_catalog_collection_links

In [None]:
stac_catalog_collections = generate_stac_collections()
catalog_collections_refs = deepcopy(stac_catalog_collections)
for link in catalog_collections_refs:
    link["rel"] = "child"


catalog: STACCatalog = {
    "id": "example",
    "type": "Catalog",
    "title": "Example STAC Catalog",
    "stac_version": STAC_VERSION,
    "description": "Example catalog with annotated label collections.",
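    "links": [
        CATALOG_ROOT_LINK,
        {
            "rel": "self",
            "href": CATALOG_URL,
            "type": "application/json"
        }
    ] + catalog_collections_refs
}
# write the catalog at the root of the catalog directory, matching CATALOG_URL
save_json(catalog, os.path.join(CATALOG_ROOT_DIR, "catalog.json"))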

More samples per pixel than can be decoded: 13
More samples per pixel than can be decoded: 13


## STAC Schema Validation

Verify that all items that were generated respect the various STAC schemas.
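
Before running the full schema validation, a quick pre-check of required top-level fields can catch obvious omissions without fetching any schema. The sketch below is not a substitute for JSON schema validation. `missing_fields` is a hypothetical helper, and the field lists follow the STAC 1.0.0 core specification (`bbox` is left out because it is only required when `geometry` is not null).

```python
from typing import Any, Dict, List

# top-level fields required by the STAC 1.0.0 core specification for each object type
REQUIRED_FIELDS: Dict[str, List[str]] = {
    "Catalog": ["type", "stac_version", "id", "description", "links"],
    "Collection": ["type", "stac_version", "id", "description", "license", "extent", "links"],
    "Feature": ["type", "stac_version", "id", "geometry", "properties", "links", "assets"],
}


def missing_fields(stac_object: Dict[str, Any]) -> List[str]:
    """Hypothetical helper: list required top-level fields absent from a STAC object."""
    required = REQUIRED_FIELDS[stac_object["type"]]
    return [name for name in required if name not in stac_object]


partial_item = {"type": "Feature", "stac_version": "1.0.0", "geometry": None, "properties": {}, "assets": {}}
missing = missing_fields(partial_item)
print(missing)
```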

In [None]:
import urllib.request

from jsonschema.exceptions import ValidationError
from jsonschema.validators import validate


def load_schema(schema_url: str) -> Dict[str, Any]:
    # schemas are referenced by URL, so they are fetched rather than opened as local files
    with urllib.request.urlopen(schema_url) as response:
        return json.load(response)


stac_catalog_files = []
stac_collection_files = []
stac_item_files = []
for root, _, files in os.walk(CATALOG_ROOT_DIR):
    for file in files:
        file_path = os.path.join(root, file)
        if file == "catalog.json":
            stac_catalog_files.append(file_path)
        elif file == "collection.json":
            stac_collection_files.append(file_path)
        elif file.startswith("item-") and file.endswith(".json"):
            stac_item_files.append(file_path)

assert stac_catalog_files
assert stac_collection_files
assert stac_item_files

stac_file = None
schema_file = None
try:
    for files, schemas in [
        (stac_catalog_files, STAC_CATALOG_SCHEMAS),
        (stac_collection_files, STAC_COLLECTION_SCHEMAS),
        (stac_item_files, STAC_ITEM_SCHEMAS),
    ]:
        for schema_file in schemas:
            schema = load_schema(schema_file)
            for stac_file in files:
                content = load_json(stac_file)
                validate(content, schema)
except ValidationError as exc:
    raise AssertionError(f"Failed [{stac_file}] validation against [{schema_file}]") from exc
