In [None]:
import math
import os
import requests
import warnings

import geoarrow.pyarrow as ga
import numpy as np
import pystac_client
import pyarrow as pa
import pyarrow.parquet as pq
import torch
import yaml
from box import Box
from torchvision.transforms import v2

from stacchip.indexer import Sentinel2Indexer
from stacchip.chipper import Chipper

# Silence every warning for the rest of the session. Note that this also
# hides deprecation notices emitted by the libraries imported above.
warnings.filterwarnings("ignore")

### Find data for AOI
The first step is to find STAC items of imagery that we want to use to create embeddings. In this example we are going to use Sentinel-2 L2A imagery, discovered through Element 84's Earth Search STAC API.

We are also going to create embeddings along time so that we have multiple embeddings for the same location at different moments in time.

In [None]:
# Area of interest: a single point over Monchique, Portugal
lat = 37.30939
lon = -8.57207

# Search window covering a large forest fire
start, end = "2018-07-01", "2018-09-01"

In [None]:
# Optimize GDAL settings for cloud optimized reading
os.environ["GDAL_DISABLE_READDIR_ON_OPEN"] = "EMPTY_DIR"
os.environ["AWS_REQUEST_PAYER"] = "requester"

STAC_API = "https://earth-search.aws.element84.com/v1"
COLLECTION = "sentinel-2-l2a"

# Search the catalogue for scenes that intersect a tiny box around the
# point of interest, within the date range, below 80% cloud cover.
catalog = pystac_client.Client.open(STAC_API)
search = catalog.search(
    collections=[COLLECTION],
    datetime=f"{start}/{end}",
    bbox=(lon - 1e-5, lat - 1e-5, lon + 1e-5, lat + 1e-5),
    max_items=100,
    query={"eo:cloud_cover": {"lt": 80}},
)

# item_collection() is the supported replacement for the deprecated
# get_all_items(); it returns the same set of matching items.
all_items = search.item_collection()

# Reduce to one per date (there might be some duplicates
# based on the location). The first item seen for a date wins.
items = []
dates = []
for item in all_items:
    acquisition_date = item.datetime.date()
    if acquisition_date not in dates:
        items.append(item)
        dates.append(acquisition_date)

print(f"Found {len(items)} items")

To speed up processing in this example, we limit the number of chips to 3 per Sentinel-2 scene. Remove this limit in a real use case.

In [None]:
# Number of chips read per scene. Kept small to speed up this example;
# raise or remove the limit in a real use case.
MAX_CHIPS_PER_ITEM = 3

chips = []
datetimes = []
bboxs = []
chip_ids = []
item_ids = []

for item in items:
    print(f"Working on {item}")

    # Index the chips in the item
    indexer = Sentinel2Indexer(item)

    # Instantiate the chipper
    chipper = Chipper(indexer, assets=["red", "green", "blue", "nir", "scl"])

    # Take the first MAX_CHIPS_PER_ITEM chips of the scene. Zipping with
    # range() stops before the chipper is asked for one chip more than
    # needed, which enumerate() + break would fetch and then throw away.
    for _, (x, y, chip) in zip(range(MAX_CHIPS_PER_ITEM), chipper):
        # Drop the "scl" band so only red, green, blue and nir remain.
        del chip["scl"]
        chips.append(chip)
        datetimes.append(item.datetime)
        bboxs.append(indexer.get_chip_bbox(x, y))
        chip_ids.append((x, y))
        item_ids.append(item.id)

In [None]:
# Stack the bands of each chip, drop singleton axes, then stack all chips
# into one array.
chip_arrays = []
for chip in chips:
    band_stack = np.array(list(chip.values()))
    chip_arrays.append(band_stack.squeeze())

pixels = np.array(chip_arrays)
pixels.shape

In [None]:
# Extract mean, std, and wavelengths from metadata
platform = "sentinel-2-l2a"

# Retrieve the file content from the URL
url = (
    "https://raw.githubusercontent.com/Clay-foundation/model/main/configs/metadata.yaml"
)
# The timeout keeps the cell from hanging forever on a stalled connection,
# and raise_for_status() fails here on an HTTP error instead of letting an
# error page be parsed as YAML further down.
response = requests.get(url, allow_redirects=True, timeout=30)
response.raise_for_status()

# Decode the bytes and parse the YAML document
content = yaml.safe_load(response.content.decode("utf-8"))

metadata = Box(content)
band_metadata = metadata[platform].bands

mean = []
std = []
waves = []
# Use the band names to get the correct values in the correct order.
for band in chips[0].keys():
    mean.append(band_metadata.mean[band])
    std.append(band_metadata.std[band])
    waves.append(band_metadata.wavelength[band])

# Prepare the normalization transform function using the mean and std values.
transform = v2.Compose(
    [
        v2.Normalize(mean=mean, std=std),
    ]
)

In [None]:
def normalize_timestamp(date):
    """Encode a timestamp as cyclic sine/cosine pairs.

    Args:
        date: A datetime-like object providing ``isocalendar()`` and ``hour``.

    Returns:
        A pair ``((sin_week, cos_week), (sin_hour, cos_hour))`` where the ISO
        week is scaled onto a 52-step circle and the hour onto a 24-step one.
    """
    week_angle = date.isocalendar().week * 2 * np.pi / 52
    hour_angle = date.hour * 2 * np.pi / 24

    week_encoding = (math.sin(week_angle), math.cos(week_angle))
    hour_encoding = (math.sin(hour_angle), math.cos(hour_angle))
    return week_encoding, hour_encoding


# Encode every chip's acquisition time, then separate the week part from
# the hour part.
times = list(map(normalize_timestamp, datetimes))
week_norm = [week_part for week_part, _ in times]
hour_norm = [hour_part for _, hour_part in times]


# Prep lat/lon embedding using the sine and cosine of the coordinates
def normalize_latlon(lat, lon):
    """Encode a latitude/longitude pair as sine/cosine pairs.

    Args:
        lat: Latitude in degrees.
        lon: Longitude in degrees.

    Returns:
        A pair ``((sin_lat, cos_lat), (sin_lon, cos_lon))`` computed after
        converting both angles to radians.
    """
    lat_rad = lat * np.pi / 180
    lon_rad = lon * np.pi / 180

    lat_encoding = (math.sin(lat_rad), math.cos(lat_rad))
    lon_encoding = (math.sin(lon_rad), math.cos(lon_rad))
    return lat_encoding, lon_encoding


# Every chip receives the same location encoding, computed from the single
# (lat, lon) point used for the search.
# NOTE(review): per-chip centres derived from `bboxs` would be more precise;
# confirm the CRS of those boxes before switching.
latlons = [normalize_latlon(lat, lon)] * len(times)
lat_norm = [dat[0] for dat in latlons]
lon_norm = [dat[1] for dat in latlons]

# Prep gsd
gsd = [10]

# Normalize pixels.
# torchvision v2 transforms only act on tensors and pass other objects,
# such as NumPy arrays, through unchanged, so the array is converted to a
# float32 tensor first; otherwise the normalization is silently skipped.
# This reassigns `pixels`: re-run the cell that builds `pixels` before
# re-running this one, or the data will be normalized twice.
pixels = transform(torch.as_tensor(np.asarray(pixels, dtype=np.float32)))

In [None]:
# Join the per-chip encodings column-wise: (week, hour) and (lat, lon).
time_encoding = np.hstack((week_norm, hour_norm))
latlon_encoding = np.hstack((lat_norm, lon_norm))

# Bundle the encoder inputs, each converted to a float32 tensor.
datacube = {
    name: torch.tensor(values, dtype=torch.float32)
    for name, values in (
        ("pixels", pixels),
        ("time", time_encoding),
        ("latlon", latlon_encoding),
        ("waves", waves),
        ("gsd", gsd),
    )
}

In [None]:
# Show the shape of each datacube entry.
for name, tensor in datacube.items():
    print(name, tensor.shape)

### Clay Embedder

#### Load the embedder that is stored in ExportedProgram format using **cpu**.

In [None]:
!wget -q https://huggingface.co/made-with-clay/Clay/resolve/main/compiled/v1.0/clay-v1-encoder-cpu.pt2

In [None]:
# Load the exported CPU encoder and unwrap it into a callable module.
exported_cpu_program = torch.export.load("clay-v1-encoder-cpu.pt2")
ep_embedder_cpu = exported_cpu_program.module()

In [None]:
%%time
with torch.no_grad():
    embeddings = ep_embedder_cpu(datacube)
datacube["pixels"].shape, embeddings.shape

For each chip, we have an embedding of size `768`

#### Load the embedder that is stored in ExportedProgram format using **gpu**.

In [None]:
!wget -q https://huggingface.co/made-with-clay/Clay/resolve/main/compiled/v1.0/clay-v1-encoder.pt2

In [None]:
# Move every datacube tensor to the CUDA device, then load the exported
# encoder and unwrap it into a callable module.
datacube = {name: tensor.to("cuda") for name, tensor in datacube.items()}
exported_program = torch.export.load("clay-v1-encoder.pt2")
ep_embedder = exported_program.module()

In [None]:
%%time
with torch.no_grad():
    embeddings = ep_embedder(datacube)
datacube["pixels"].shape, embeddings.shape

For each chip, we have an embedding of size `768`

#### Load the embedder that is stored in ONNX format using **cpu**.

In [None]:
import onnx
import onnxruntime as ort

In [None]:
!wget -q https://huggingface.co/made-with-clay/Clay/resolve/main/compiled/v1.0/clay-v1-encoder-cpu.onnx

In [None]:
# Move the tensors back to the CPU so they can be converted to NumPy,
# then open the ONNX model with the CPU execution provider.
datacube = {name: tensor.to("cpu") for name, tensor in datacube.items()}

onnx_model_path = "clay-v1-encoder-cpu.onnx"
onnx_embedder = ort.InferenceSession(
    onnx_model_path,
    providers=["CPUExecutionProvider"],
)

In [None]:
%%time
embeddings = onnx_embedder.run(
    [],
    {
        "cube": datacube["pixels"].numpy(),
        "time": datacube["time"].numpy(),
        "latlon": datacube["latlon"].numpy(),
        "waves": datacube["waves"].numpy(),
        "gsd": datacube["gsd"].numpy(),
    },
)[0]
embeddings.shape

For each chip, we have an embedding of size `768`

### Store the results

We create a table containing the embeddings, bounding box, the STAC item ID, the datetime of the image capture, and the chip x and y ids. Then we save that data to disk.

In [None]:
# Write data to pyarrow table: one row per chip, holding the capture
# datetime, the chip (x, y) id, the STAC item id, the embedding vector and
# the chip bounding box as a geoarrow geometry built from WKT.
index = {
    "datetimes": datetimes,
    "chip_ids": chip_ids,
    "item_ids": item_ids,
    # Column name corrected from the misspelled "emeddings".
    "embeddings": [np.ascontiguousarray(dat) for dat in embeddings],
    "geometry": ga.as_geoarrow([dat.wkt for dat in bboxs]),
}
table = pa.table(index)
table

In [None]:
# Save the table as a Parquet file in the working directory.
output_path = "embeddings.parquet"
pq.write_table(table, output_path)