In [22]:
import math
import os
import warnings

import geoarrow.pyarrow as ga
import numpy as np
import pyarrow as pa
import pyarrow.parquet as pq
import pystac_client
import requests
import torch
import yaml
from box import Box
from stacchip.chipper import Chipper
from stacchip.indexer import Sentinel2Indexer
from torchvision.transforms import v2

warnings.filterwarnings("ignore")

### Find data for AOI
The first step is to find STAC items of imagery that we want to use to create embeddings. In this example we are going to use Earth Genome's composite dataset which comes with a great STAC catalog.

We are also going to create embeddings along time so that we have multiple embeddings for the same location at different moments in time.

In [2]:
# Point over Monchique Portugal
lat, lon = 37.30939, -8.57207

# Dates of a large forest fire
start = "2018-07-01"
end = "2018-09-01"

In [3]:
# Optimize GDAL settings for cloud optimized reading
os.environ["GDAL_DISABLE_READDIR_ON_OPEN"] = "EMPTY_DIR"
os.environ["AWS_REQUEST_PAYER"] = "requester"

STAC_API = "https://earth-search.aws.element84.com/v1"
COLLECTION = "sentinel-2-l2a"

# Search the catalogue
catalog = pystac_client.Client.open(STAC_API)
search = catalog.search(
    collections=[COLLECTION],
    datetime=f"{start}/{end}",
    bbox=(lon - 1e-5, lat - 1e-5, lon + 1e-5, lat + 1e-5),
    max_items=100,
    query={"eo:cloud_cover": {"lt": 80}},
)

all_items = search.get_all_items()

# Reduce to one per date (there might be some duplicates
# based on the location)
items = []
dates = []
for item in all_items:
    if item.datetime.date() not in dates:
        items.append(item)
        dates.append(item.datetime.date())

print(f"Found {len(items)} items")

Found 12 items


To speed up processing in this example, we limit the number of chips to 3 per Sentinel-2 scene. Remove this limit in a real use case.

In [4]:
chips = []
datetimes = []
bboxs = []
chip_ids = []
item_ids = []

for item in items:
    print(f"Working on {item}")

    # Index the chips in the item
    indexer = Sentinel2Indexer(item)

    # Instanciate the chipper
    chipper = Chipper(indexer, assets=["red", "green", "blue", "nir", "scl"])

    # Get first chip for the "image" asset key
    for idx, (x, y, chip) in enumerate(chipper):
        if idx > 2:
            break
        del chip["scl"]
        chips.append(chip)
        datetimes.append(item.datetime)
        bboxs.append(indexer.get_chip_bbox(x, y))
        chip_ids.append((x, y))
        item_ids.append(item.id)

Working on <Item id=S2A_29SNB_20180828_1_L2A>
Working on <Item id=S2B_29SNB_20180823_1_L2A>
Working on <Item id=S2A_29SNB_20180818_1_L2A>
Working on <Item id=S2B_29SNB_20180813_0_L2A>
Working on <Item id=S2A_29SNB_20180808_1_L2A>
Working on <Item id=S2B_29SNB_20180803_1_L2A>
Working on <Item id=S2A_29SNB_20180729_1_L2A>
Working on <Item id=S2B_29SNB_20180724_0_L2A>
Working on <Item id=S2A_29SNB_20180719_0_L2A>
Working on <Item id=S2B_29SNB_20180714_0_L2A>
Working on <Item id=S2A_29SNB_20180709_0_L2A>
Working on <Item id=S2B_29SNB_20180704_0_L2A>


In [5]:
pixels = np.array([np.array(list(chip.values())).squeeze() for chip in chips])
pixels.shape

(36, 4, 256, 256)

In [29]:
metadata_path = "https://raw.githubusercontent.com/Clay-foundation/model/refs/heads/main/configs/metadata.yaml"
headers = {
    "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36"
}

response = requests.get(metadata_path, headers=headers)
content = response.text

In [30]:
# Extract mean, std, and wavelengths from metadata
platform = "sentinel-2-l2a"

# Load the metadata
content = yaml.safe_load(content)

metadata = Box(content)
mean = []
std = []
waves = []
# Use the band names to get the correct values in the correct order.
for band in chips[0].keys():
    mean.append(metadata[platform].bands.mean[band])
    std.append(metadata[platform].bands.std[band])
    waves.append(metadata[platform].bands.wavelength[band])

# Prepare the normalization transform function using the mean and std values.
transform = v2.Compose(
    [
        v2.Normalize(mean=mean, std=std),
    ]
)

In [31]:
def normalize_timestamp(date):
    week = date.isocalendar().week * 2 * np.pi / 52
    hour = date.hour * 2 * np.pi / 24

    return (math.sin(week), math.cos(week)), (math.sin(hour), math.cos(hour))


times = [normalize_timestamp(dat) for dat in datetimes]
week_norm = [dat[0] for dat in times]
hour_norm = [dat[1] for dat in times]


# Prep lat/lon embedding using the
def normalize_latlon(lat, lon):
    lat = lat * np.pi / 180
    lon = lon * np.pi / 180

    return (math.sin(lat), math.cos(lat)), (math.sin(lon), math.cos(lon))


latlons = [normalize_latlon(lat, lon)] * len(times)
lat_norm = [dat[0] for dat in latlons]
lon_norm = [dat[1] for dat in latlons]

# Prep gsd
gsd = [10]

# Normalize pixels
pixels = transform(pixels)

In [32]:
datacube = {
    "pixels": torch.tensor(pixels, dtype=torch.float32),
    "time": torch.tensor(np.hstack((week_norm, hour_norm)), dtype=torch.float32),
    "latlon": torch.tensor(np.hstack((lat_norm, lon_norm)), dtype=torch.float32),
    "waves": torch.tensor(waves, dtype=torch.float32),
    "gsd": torch.tensor(gsd, dtype=torch.float32),
}

In [33]:
for k, v in datacube.items():
    print(k, v.shape)

pixels torch.Size([36, 4, 256, 256])
time torch.Size([36, 4])
latlon torch.Size([36, 4])
waves torch.Size([4])
gsd torch.Size([1])


### Clay Embedder

#### Load the embedder that is stored in ExportedProgram format using **cpu**.

In [10]:
!wget -q https://huggingface.co/made-with-clay/Clay/resolve/main/v1.5/compiled/clay-v1.5-encoder-cpu.pt2

In [36]:
%%time
ep_embedder_cpu = torch.export.load("clay-v1.5-encoder-cpu.pt2").module()

CPU times: user 2.18 s, sys: 842 ms, total: 3.02 s
Wall time: 3.02 s


In [11]:
%%time
with torch.no_grad():
    embeddings = ep_embedder_cpu(datacube)
datacube["pixels"].shape, embeddings.shape

CPU times: user 7min 8s, sys: 19.3 s, total: 7min 27s
Wall time: 2min 13s


(torch.Size([36, 4, 256, 256]), torch.Size([36, 1024]))

For each chip, we have an embedding of size `1024`

#### Load the embedder that is stored in ExportedProgram format using **gpu**.

In [12]:
!wget -q https://huggingface.co/made-with-clay/Clay/resolve/main/v1.5/compiled/clay-v1.5-encoder.pt2

In [None]:
datacube = {k: v.to("cuda") for k, v in datacube.items()}

In [35]:
%%time
ep_embedder = torch.export.load("clay-v1.5-encoder.pt2").module()

CPU times: user 2.33 s, sys: 1 s, total: 3.33 s
Wall time: 3.33 s


In [37]:
%%time
with torch.no_grad():
    embeddings = ep_embedder(datacube)
datacube["pixels"].shape, embeddings.shape

CPU times: user 47.1 ms, sys: 0 ns, total: 47.1 ms
Wall time: 17 ms


(torch.Size([36, 4, 256, 256]), torch.Size([36, 1024]))

For each chip, we have an embedding of size `1024`

#### Load the embedder that is stored in ONNX format using **cpu**.

In [15]:
import onnxruntime as ort

In [16]:
!wget -q https://huggingface.co/made-with-clay/Clay/resolve/main/v1.5/compiled/clay-v1.5-encoder-cpu.onnx

In [17]:
datacube = {k: v.to("cpu") for k, v in datacube.items()}
onnx_embedder = ort.InferenceSession(
    "clay-v1.5-encoder-cpu.onnx", providers=["CPUExecutionProvider"]
)

In [18]:
%%time
embeddings = onnx_embedder.run(
    [],
    {
        "cube": datacube["pixels"].numpy(),
        "time": datacube["time"].numpy(),
        "latlon": datacube["latlon"].numpy(),
        "waves": datacube["waves"].numpy(),
        "gsd": datacube["gsd"].numpy(),
    },
)[0]
embeddings.shape

CPU times: user 12min 17s, sys: 3.29 s, total: 12min 20s
Wall time: 1min 37s


(36, 1024)

For each chip, we have an embedding of size `1024`

### Store the results

We create a table containing the embeddings, bounding box, the STAC item ID, the datetime of the image capture, and the chip x and y ids. Then we save that data to disk.

In [17]:
# Write data to pyarrow table
index = {
    "datetimes": datetimes,
    "chip_ids": chip_ids,
    "item_ids": item_ids,
    "emeddings": [np.ascontiguousarray(dat) for dat in embeddings],
    "geometry": ga.as_geoarrow([dat.wkt for dat in bboxs]),
}
table = pa.table(index)
table

pyarrow.Table
datetimes: timestamp[us, tz=UTC]
chip_ids: list<item: int64>
  child 0, item: int64
item_ids: string
emeddings: list<item: float>
  child 0, item: float
geometry: extension<geoarrow.polygon<PolygonType>>
----
datetimes: [[2018-08-28 11:30:56.771000Z,2018-08-28 11:30:56.771000Z,2018-08-28 11:30:56.771000Z,2018-08-23 11:30:50.574000Z,2018-08-23 11:30:50.574000Z,...,2018-07-09 11:24:55.535000Z,2018-07-09 11:24:55.535000Z,2018-07-04 11:30:35.271000Z,2018-07-04 11:30:35.271000Z,2018-07-04 11:30:35.271000Z]]
chip_ids: [[[0,0],[1,0],...,[1,0],[2,0]]]
item_ids: [["S2A_29SNB_20180828_1_L2A","S2A_29SNB_20180828_1_L2A","S2A_29SNB_20180828_1_L2A","S2B_29SNB_20180823_1_L2A","S2B_29SNB_20180823_1_L2A",...,"S2A_29SNB_20180709_0_L2A","S2A_29SNB_20180709_0_L2A","S2B_29SNB_20180704_0_L2A","S2B_29SNB_20180704_0_L2A","S2B_29SNB_20180704_0_L2A"]]
emeddings: [[[0.08737555,0.09504964,0.053098626,-0.08628022,-0.048699543,...,-0.0032533202,-0.25458118,-0.022807367,-0.0469472,0.05704065],[0.086890

In [18]:
pq.write_table(table, "embeddings.parquet")