# Embedding creation

This tutorial shows how to create embeddings using Clay and store them in geoparquet.
Embeddings are useful in similarity search applications and for
training classification heads on top of them, as shown in the
[](bla) tutorial.

Creating embeddings consists of three simple steps:

1. Search for imagery to be used
2. Create chips dynamically from the source data with [stacchip](https://clay-foundation.github.io/stacchip/)
3. Pass chips to Clay and store the output as geoparquet

Let's look at these one by one. First, make sure that stacchip is installed;
we use this library to generate the dynamic chips that are passed to Clay.

In [25]:
! pip install stacchip



In [26]:
import math
import os

import geoarrow.pyarrow as ga
import geopandas as gpd
import numpy as np
import pandas as pd
import pyarrow as pa
import pyarrow.parquet as pq
import pystac_client
import requests
import torch
import yaml
from box import Box
from matplotlib import pyplot as plt
from rasterio.enums import Resampling
from shapely import Point
from stacchip.chipper import Chipper
from stacchip.indexer import Sentinel2Indexer
from torchvision.transforms import v2

### Note: This notebook requires CUDA

This is because we load the Clay encoder from a [torchscript](https://pytorch.org/tutorials/beginner/Intro_to_TorchScript_tutorial.html) export that was compiled using CUDA.

In [27]:
if not torch.cuda.is_available():
    raise ValueError("The compiled version of Clay needs CUDA")

## Find data for AOI

The first step is to find STAC items of imagery that we want to use
to create embeddings. In this example we use Sentinel-2 L2A imagery,
searched through the [Earth Search STAC API](https://earth-search.aws.element84.com/v1).

We also create embeddings over time, so that we have multiple
embeddings for the same location at different moments in time.

In [28]:
# Point over Monchique Portugal
lat, lon = 37.30939, -8.57207

# Dates of a large forest fire
start = "2018-07-01"
end = "2018-09-01"

In [29]:
# Optimize GDAL settings for cloud optimized reading
os.environ["GDAL_DISABLE_READDIR_ON_OPEN"] = "EMPTY_DIR"
os.environ["AWS_REQUEST_PAYER"] = "requester"

STAC_API = "https://earth-search.aws.element84.com/v1"
COLLECTION = "sentinel-2-l2a"

# Search the catalogue
catalog = pystac_client.Client.open(STAC_API)
search = catalog.search(
    collections=[COLLECTION],
    datetime=f"{start}/{end}",
    bbox=(lon - 1e-5, lat - 1e-5, lon + 1e-5, lat + 1e-5),
    max_items=100,
    query={"eo:cloud_cover": {"lt": 80}},
)

all_items = search.item_collection()

# Reduce to one per date (there might be some duplicates
# based on the location)
items = []
dates = []
for item in all_items:
    if item.datetime.date() not in dates:
        items.append(item)
        dates.append(item.datetime.date())

print(f"Found {len(items)} items")



Found 12 items
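
The one-item-per-date filter above can be tried in isolation. The following is a minimal sketch, assuming a hypothetical `FakeItem` class that stands in for a STAC item (only the `id` and `datetime` attributes are mimicked). It uses a set for the date lookup, which behaves the same as the list used above but stays fast for long item lists.

```python
from dataclasses import dataclass
from datetime import datetime, timezone


@dataclass
class FakeItem:
    """Hypothetical stand-in for a STAC item with an id and a datetime."""

    id: str
    datetime: datetime


def first_item_per_date(all_items):
    """Keep the first item seen for each calendar date, preserving order."""
    seen_dates = set()
    unique_items = []
    for item in all_items:
        day = item.datetime.date()
        if day not in seen_dates:
            seen_dates.add(day)
            unique_items.append(item)
    return unique_items


# Two items share a date, so only the first of them is kept.
fake_items = [
    FakeItem("a", datetime(2018, 8, 28, 11, 30, tzinfo=timezone.utc)),
    FakeItem("b", datetime(2018, 8, 28, 11, 31, tzinfo=timezone.utc)),
    FakeItem("c", datetime(2018, 8, 23, 11, 30, tzinfo=timezone.utc)),
]
unique = first_item_per_date(fake_items)
print([item.id for item in unique])  # ['a', 'c']
```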


To speed up processing in this example, we limit the number of chips to 3 per Sentinel-2 scene. Remove this limit in a real use case.

In [30]:
chips = []
datetimes = []
bboxs = []
chip_ids = []
item_ids = []

for item in items:
    print(f"Working on {item}")

    # Index the chips in the item
    indexer = Sentinel2Indexer(item)

    # Instantiate the chipper
    chipper = Chipper(indexer, assets=["red", "green", "blue", "nir", "scl"])

    # Keep only the first three chips of each scene and drop the "scl" band
    for idx, (x, y, chip) in enumerate(chipper):
        if idx > 2:
            break
        del chip["scl"]
        chips.append(chip)
        datetimes.append(item.datetime)
        bboxs.append(indexer.get_chip_bbox(x, y))
        chip_ids.append((x, y))
        item_ids.append(item.id)

Working on <Item id=S2A_29SNB_20180828_1_L2A>
Working on <Item id=S2B_29SNB_20180823_1_L2A>
Working on <Item id=S2A_29SNB_20180818_1_L2A>
Working on <Item id=S2B_29SNB_20180813_0_L2A>
Working on <Item id=S2A_29SNB_20180808_1_L2A>
Working on <Item id=S2B_29SNB_20180803_1_L2A>
Working on <Item id=S2A_29SNB_20180729_1_L2A>
Working on <Item id=S2B_29SNB_20180724_0_L2A>
Working on <Item id=S2A_29SNB_20180719_0_L2A>
Working on <Item id=S2B_29SNB_20180714_0_L2A>
Working on <Item id=S2A_29SNB_20180709_0_L2A>
Working on <Item id=S2B_29SNB_20180704_0_L2A>
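
The per-scene chip limit can also be written with `itertools.islice`, which makes the limit an explicit parameter. This is only a sketch: `fake_chipper` is a hypothetical generator that stands in for the chipper and yields `(x, y, chip)` tuples, and `MAX_CHIPS_PER_SCENE` is a name introduced here for illustration.

```python
from itertools import islice


def fake_chipper(n_chips):
    """Hypothetical stand-in for a chipper: yields (x, y, chip) tuples lazily."""
    for idx in range(n_chips):
        bands = ["red", "green", "blue", "nir", "scl"]
        yield idx, 0, {band: idx for band in bands}


MAX_CHIPS_PER_SCENE = 3

selected = []
# islice stops the iteration after the requested number of chips.
for x, y, chip in islice(fake_chipper(100), MAX_CHIPS_PER_SCENE):
    chip.pop("scl")  # drop the scene classification band
    selected.append(((x, y), chip))

print([chip_id for chip_id, _ in selected])  # [(0, 0), (1, 0), (2, 0)]
```

To process every chip of a scene, the `islice` call can be dropped and the chipper iterated directly.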


In [31]:
pixels = np.array([np.array(list(chip.values())).squeeze() for chip in chips])
pixels.shape

(36, 4, 256, 256)
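
To see how the stacking works, here is a small sketch with synthetic data. It assumes that every chip is a dictionary with one array of shape `(1, 256, 256)` per band, which is consistent with the `squeeze()` call and the output shape above but is not verified against stacchip here. The `fake_chips` values are made up.

```python
import numpy as np

BANDS = ["red", "green", "blue", "nir"]

# Hypothetical chips: one (1, 256, 256) array per band, as assumed above.
fake_chips = [
    {band: np.full((1, 256, 256), i, dtype="uint16") for band in BANDS}
    for i in range(3)
]

# Stack the bands of each chip, drop the singleton axis, then stack the chips.
stacked = np.array([np.array(list(chip.values())).squeeze() for chip in fake_chips])
print(stacked.shape)  # (3, 4, 256, 256)
```

The band order in the result follows the insertion order of the dictionary keys, which is why the same key order is reused later when collecting the normalization statistics.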

In [32]:
# Extract mean, std, and wavelengths from metadata
platform = "sentinel-2-l2a"

# Retrieve the file content from the URL
url = (
    "https://raw.githubusercontent.com/Clay-foundation/model/main/configs/metadata.yaml"
)
response = requests.get(url, allow_redirects=True)
# Fail early if the download did not succeed
response.raise_for_status()

# Convert bytes to string
content = response.content.decode("utf-8")

# Load the yaml
content = yaml.safe_load(content)

metadata = Box(content)
mean = []
std = []
waves = []
# Use the band names to get the correct values in the correct order.
for band in chips[0].keys():
    mean.append(metadata[platform].bands.mean[band])
    std.append(metadata[platform].bands.std[band])
    waves.append(metadata[platform].bands.wavelength[band])

# Prepare the normalization transform function using the mean and std values.
transform = v2.Compose(
    [
        v2.Normalize(mean=mean, std=std),
    ]
)
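
For intuition, the normalization subtracts the per-band mean and divides by the per-band standard deviation. The sketch below reproduces that arithmetic in plain NumPy on synthetic data. The `mean` and `std` values are hypothetical placeholders, not the values from `metadata.yaml`.

```python
import numpy as np

# Hypothetical per-band statistics, NOT the values from metadata.yaml.
mean = np.array([1000.0, 1100.0, 1200.0, 2500.0])
std = np.array([500.0, 550.0, 600.0, 900.0])

rng = np.random.default_rng(seed=0)
# Synthetic batch shaped (batch, bands, height, width), drawn per band
# from a normal distribution with the statistics above.
batch = rng.normal(
    loc=mean[:, None, None], scale=std[:, None, None], size=(8, 4, 64, 64)
)

# Broadcast the statistics over the batch, height and width axes.
normalized = (batch - mean[None, :, None, None]) / std[None, :, None, None]

# After normalization each band has roughly zero mean and unit variance.
print(normalized.mean(axis=(0, 2, 3)).round(2))
print(normalized.std(axis=(0, 2, 3)).round(2))
```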

In [33]:
def normalize_timestamp(date):
    week = date.isocalendar().week * 2 * np.pi / 52
    hour = date.hour * 2 * np.pi / 24

    return (math.sin(week), math.cos(week)), (math.sin(hour), math.cos(hour))


times = [normalize_timestamp(dat) for dat in datetimes]
week_norm = [dat[0] for dat in times]
hour_norm = [dat[1] for dat in times]


# Prep lat/lon embedding: encode each coordinate as a (sin, cos) pair
def normalize_latlon(lat, lon):
    lat = lat * np.pi / 180
    lon = lon * np.pi / 180

    return (math.sin(lat), math.cos(lat)), (math.sin(lon), math.cos(lon))


latlons = [normalize_latlon(lat, lon)] * len(times)
lat_norm = [dat[0] for dat in latlons]
lon_norm = [dat[1] for dat in latlons]

# Prep gsd
gsd = [10]

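# Normalize pixels. v2.Normalize works on float tensors, so convert the
# NumPy array to a tensor first and back to NumPy afterwards.
pixels = transform(torch.from_numpy(pixels.astype(np.float32))).numpy()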
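
The sine and cosine pairs used for time and location are a cyclic encoding: a periodic value is mapped onto the unit circle, so that values at the end of a cycle end up next to values at its start. The sketch below illustrates this with a hypothetical helper, `encode_cyclic`, that follows the same formula as `normalize_timestamp`.

```python
import math


def encode_cyclic(value, period):
    """Map a periodic value onto the unit circle as (sin, cos)."""
    angle = value * 2 * math.pi / period
    return math.sin(angle), math.cos(angle)


week_1 = encode_cyclic(1, 52)
week_26 = encode_cyclic(26, 52)
week_52 = encode_cyclic(52, 52)

# The last week of the year is close to the first one,
# while the middle of the year is far away from both.
near = math.dist(week_52, week_1)
far = math.dist(week_26, week_1)
print(near < far)  # True
```

Note that some ISO years have 53 weeks, so with a period of 52 the encoding of week 53 lands slightly past a full circle.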

In [34]:
datacube = (
    torch.tensor(pixels, dtype=torch.float32, device="cuda"),
    torch.tensor(np.hstack((week_norm, hour_norm)), dtype=torch.float32, device="cuda"),
    torch.tensor(np.hstack((lat_norm, lon_norm)), dtype=torch.float32, device="cuda"),
    torch.tensor(waves, dtype=torch.float32, device="cuda"),
    torch.tensor(gsd, dtype=torch.float32, device="cuda"),
)

In [35]:
[dat.shape for dat in datacube]

[torch.Size([36, 4, 256, 256]),
 torch.Size([36, 4]),
 torch.Size([36, 4]),
 torch.Size([4]),
 torch.Size([1])]
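
The `(36, 4)` shapes come from `np.hstack`, which places the two `(sin, cos)` pairs of each chip side by side. A minimal sketch with made-up values:

```python
import numpy as np

n_chips = 5
# Hypothetical (sin, cos) pairs, one pair per chip.
week_pairs = [(0.1, 0.9)] * n_chips
hour_pairs = [(0.2, 0.8)] * n_chips

# Each input has shape (n_chips, 2); hstack concatenates along the columns.
time_features = np.hstack((week_pairs, hour_pairs))
print(time_features.shape)  # (5, 4)
```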

## Generate embeddings using the Clay encoder

We are going to download the compiled version of the Clay
encoder, which has been prepared using torchscript.

In [36]:
!wget https://huggingface.co/made-with-clay/Clay/resolve/main/clay-v1-encoder.pt


Load the packaged encoder using PyTorch.

In [37]:
clay_encoder = torch.export.load("clay-v1-encoder.pt").module()



Run the encoder and extract the class embedding, which is the
main embedding vector that can be used for image classification
or similarity search.

In [38]:
# Run the clay encoder
with torch.no_grad():
    unmsk_patch, unmsk_idx, msk_idx, msk_matrix = clay_encoder(*datacube)
# Get class embeddings
cls_embedding = unmsk_patch[:, 0, :]
# Print shape of class embeddings
cls_embedding.shape

torch.Size([36, 768])
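
As an illustration of the similarity search use case, the sketch below slices the class token out of a synthetic encoder output and ranks all embeddings by cosine similarity to the first one. The array `fake_patches` is random data with a hypothetical sequence length; only the embedding size of 768 is taken from the output above.

```python
import numpy as np

rng = np.random.default_rng(seed=42)
# Hypothetical encoder output: (batch, 1 + n_patches, embedding_dim).
fake_patches = rng.normal(size=(6, 17, 768)).astype("float32")

# The class token sits at index 0 of the sequence axis.
cls = fake_patches[:, 0, :]

# Cosine similarity of every embedding against the first one.
unit = cls / np.linalg.norm(cls, axis=1, keepdims=True)
similarity = unit @ unit[0]

# Sort from most to least similar; the query itself comes first.
ranking = np.argsort(-similarity)
print(cls.shape)  # (6, 768)
print(ranking[0])  # 0
```

With real embeddings, `cls_embedding.cpu().numpy()` could take the place of `cls`.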

## Store the results in a geoparquet table

We create a table containing the embeddings, bounding box, the STAC item ID, the datetime of the image capture, and the chip x and y ids. Then we save that data to disk.

In [39]:
# Write data to pyarrow table
index = {
    "datetimes": datetimes,
    "chip_ids": chip_ids,
    "item_ids": item_ids,
    "embeddings": [np.ascontiguousarray(dat) for dat in cls_embedding.cpu().numpy()],
    "geometry": ga.as_geoarrow([dat.wkt for dat in bboxs]),
}
table = pa.table(index)
table

pyarrow.Table
datetimes: timestamp[us, tz=UTC]
chip_ids: list<item: int64>
  child 0, item: int64
item_ids: string
embeddings: list<item: float>
  child 0, item: float
geometry: extension<geoarrow.polygon<PolygonType>>
----
datetimes: [[2018-08-28 11:30:56.771000,2018-08-28 11:30:56.771000,2018-08-28 11:30:56.771000,2018-08-23 11:30:50.574000,2018-08-23 11:30:50.574000,...,2018-07-09 11:24:55.535000,2018-07-09 11:24:55.535000,2018-07-04 11:30:35.271000,2018-07-04 11:30:35.271000,2018-07-04 11:30:35.271000]]
chip_ids: [[[0,0],[1,0],...,[1,0],[2,0]]]
item_ids: [["S2A_29SNB_20180828_1_L2A","S2A_29SNB_20180828_1_L2A","S2A_29SNB_20180828_1_L2A","S2B_29SNB_20180823_1_L2A","S2B_29SNB_20180823_1_L2A",...,"S2A_29SNB_20180709_0_L2A","S2A_29SNB_20180709_0_L2A","S2B_29SNB_20180704_0_L2A","S2B_29SNB_20180704_0_L2A","S2B_29SNB_20180704_0_L2A"]]
embeddings: [[[-0.14773352,0.08466569,0.13797817,0.11150878,0.06517958,...,0.03668152,-0.092160314,0.025934448,-0.124962896,-0.034070194],[-0.14430065,0.085857

In [40]:
pq.write_table(table, "clay_embeddings.parquet")