# Embeddings factory script

Generate embeddings in bulk, one MGRS tile at a time.

Embeddings are written to the local `data/embeddings/` folder as `.gpq` (Parquet) files,
and then uploaded to `s3://clay-vector-embeddings/v001/`.

In [None]:
import glob
import os
import subprocess
import warnings

import duckdb
import lightning as L
import torch
import tqdm

from src.datamodule import ClayDataModule
from src.model_clay import CLAYModule

In [2]:
# Runtime configuration for inference.
# Disable the cuDNN v8 API via its environment switch; set it up front, before any
# CUDA kernels run (presumably a workaround for a cuDNN v8 issue — TODO confirm why).
os.environ["TORCH_CUDNN_V8_API_DISABLED"] = "1"
# "medium" lets float32 matrix multiplications use lower internal precision for speed.
torch.set_float32_matmul_precision("medium")

In [3]:
# Generate list of MGRS tiles. The list file was produced with:
#!aws s3 ls s3://clay-tiles-02/02/ | tr -s ' ' |  cut -d ' ' -f 3 | cut -d '/' -f 1 > mgrs_world.txt
# Read one tile ID per line (e.g. "59HQB"). The `with` block closes the file handle,
# and blank lines are dropped so the sort key below never indexes an empty string.
with open(file="mgrs_world.txt") as mgrs_file:
    mgrs_tiles = [line.strip() for line in mgrs_file if line.strip()]
# Sort by latitudinal band from South to North. Character index 2 is the band letter
# only if every UTM zone is written with two digits (assumption — TODO confirm).
# `list.sort` is stable, so tiles within the same band keep their listing order.
mgrs_tiles.sort(key=lambda m: m[2])

In [4]:
# Build a prediction Trainer (bf16 mixed precision, no logger) and restore the
# CLAY model weights from a local checkpoint file.
# The checkpoint can be fetched with:
# !aws s3 cp s3://clay-model-ckpt/v0/clay-small-70MT-1100T-10E.ckpt checkpoints/
trainer = L.Trainer(logger=False, precision="bf16-mixed")
model: L.LightningModule = CLAYModule.load_from_checkpoint(
    "checkpoints/clay-small-70MT-1100T-10E.ckpt"
)
# Optional speed-up, currently disabled (requires triton, e.g. `mamba install triton`):
# model.model.encoder = torch.compile(model=model.model.encoder)

Using bfloat16 Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [None]:
# Generate embeddings for each MGRS tile and upload the resulting files to S3.
# A tile that already has a local `data/embeddings/{tile}_*.gpq` file is skipped, so
# this cell can be re-run to resume an interrupted job.
# NOTE: if a tile's upload fails, its local file still exists and the tile is skipped
# on the next run. Re-upload such tiles manually (they are reported by the warnings).
for mgrs_tile in (pbar := tqdm.tqdm(iterable=mgrs_tiles)):
    if glob.glob(pathname=f"data/embeddings/{mgrs_tile}_*.gpq"):
        continue  # already processed on a previous run
    pbar.set_description(desc=f"Processing MGRS Tile {mgrs_tile}")
    datamodule: L.LightningDataModule = ClayDataModule(
        data_dir=f"s3://clay-tiles-02/02/{mgrs_tile}", batch_size=32, num_workers=16
    )
    try:
        trainer.predict(model=model, datamodule=datamodule)
        # Upload the files one by one. `aws s3 cp` accepts a single local source, so
        # passing every glob match at once fails when a tile yields several files.
        # `check=True` raises on a non-zero exit status, which a `!` shell line does not.
        for gpq_file in sorted(glob.glob(pathname=f"data/embeddings/{mgrs_tile}_*.gpq")):
            subprocess.run(
                ["aws", "s3", "cp", gpq_file, "s3://clay-vector-embeddings/v001/"],
                check=True,
            )
    except (RuntimeError, AssertionError, subprocess.CalledProcessError) as err:
        # Report and move on to the next tile, so one bad tile does not stop the batch.
        print(f"Processing of MGRS Tile {mgrs_tile} failed because of {err}")
        warnings.warn(message=repr(err))

print("All done!")

Processing MGRS Tile 59HQB:   0%|                      | 0/1203 [00:00<?, ?it/s]
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Predicting: |                                             | 0/? [00:00<?, ?it/s]

In [6]:
# Sanity check: total row count across all local embedding files (shown as cell output).
duckdb.sql("SELECT COUNT(*) from read_parquet('data/embeddings/*.gpq')")

┌──────────────┐
│ count_star() │
│    int64     │
├──────────────┤
│       947019 │
└──────────────┘