In [1]:
import sys
from pathlib import Path
import numpy as np
import pandas as pd
from tqdm import tqdm
import yaml
from claymodel.module import ClayMAEModule

# Make the project's `src/` directory importable (needed for `import config`).
# The path is resolved from the kernel's current working directory, so this
# assumes the notebook is run from a sibling directory of `src/`.
# The membership check keeps repeated runs of this cell from appending the
# same entry to sys.path over and over.
SRC_DIR = Path().resolve().parent / "src"
if str(SRC_DIR) not in sys.path:
    sys.path.append(str(SRC_DIR))

import config

# Import necessary libraries
import torch
from torch import nn
from torchvision.transforms import v2
from torch.utils.data import Dataset, DataLoader, random_split, Subset

# Choose where tensors are allocated: the CUDA GPU when PyTorch reports one
# is usable, the CPU otherwise.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print("Using device:", device)


Using device: cuda


In [None]:

# Load the pretrained Clay v1.5 checkpoint and prepare its encoder for inference.
# `map_location` and `.to(device)` make the model's device explicit. Without them
# the weights land wherever the checkpoint was saved, which need not match
# `device` (where the datacube tensors are created).
model = ClayMAEModule.load_from_checkpoint(
    config.FM_MODEL_DIR / "clay-v1.5.ckpt",
    map_location=device,
    metadata_path=config.PROJECT_ROOT / "configs/metadata.yaml",
)
model.to(device)
model.eval()

# Disable MAE masking so every patch is encoded.
model.model.encoder.mask_ratio = 0
# Clay's embedding tutorial loads with both mask_ratio=0 and shuffle=False, so
# patch tokens come back in spatial order.
# NOTE(review): assumes the encoder reads a `shuffle` attribute; confirm against
# claymodel. If it does not, this assignment is inert.
model.model.encoder.shuffle = False

ClayMAEModule(
  (model): ClayMAE(
    (teacher): VisionTransformer(
      (patch_embed): PatchEmbed(
        (proj): Conv2d(3, 1024, kernel_size=(14, 14), stride=(14, 14))
        (norm): Identity()
      )
      (pos_drop): Dropout(p=0.0, inplace=False)
      (patch_drop): Identity()
      (norm_pre): Identity()
      (blocks): Sequential(
        (0): Block(
          (norm1): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)
          (attn): Attention(
            (qkv): Linear(in_features=1024, out_features=3072, bias=True)
            (q_norm): Identity()
            (k_norm): Identity()
            (attn_drop): Dropout(p=0.0, inplace=False)
            (proj): Linear(in_features=1024, out_features=1024, bias=True)
            (proj_drop): Dropout(p=0.0, inplace=False)
          )
          (ls1): LayerScale()
          (drop_path1): Identity()
          (norm2): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)
          (mlp): Mlp(
            (fc1): Linear(in_featur

In [18]:

# Load per-sensor metadata (band order, wavelengths, per-band stats).
with open(config.PROJECT_ROOT / "configs/metadata.yaml", "r", encoding="utf-8") as f:
    metadata = yaml.safe_load(f)

# Smoke-test: run a random Sentinel-2 chip through the Clay encoder.
sensor = "sentinel-2"
CHIP_SIZE = 64  # pixels per side of the dummy chip
GSD_M = 10.0  # ground sample distance (metres) reported for the dummy chip
# Clay's tutorial passes the metadata wavelengths (stored in μm) to the model
# unchanged, so no unit conversion is applied here.
# NOTE(review): the previous version multiplied by 1000 (μm -> nm). Set this to
# 1000.0 only if this checkpoint was really trained on nanometres.
WAVES_SCALE = 1.0

bands = metadata[sensor]["band_order"]
waves = torch.tensor(
    [metadata[sensor]["bands"]["wavelength"][b] * WAVES_SCALE for b in bands],
    dtype=torch.float32,
    device=device,
)  # [C]

# Derive the channel count from the metadata so `pixels` and `waves` always agree.
# A seeded generator makes re-runs produce identical embeddings.
chip_rng = torch.Generator().manual_seed(0)
chips = torch.randn(1, len(bands), CHIP_SIZE, CHIP_SIZE, generator=chip_rng)  # [B, C, H, W]

# The encoder combines `time` and `latlon` into 8 metadata channels per patch
# (it prints a [1, 64, 8] tensor). Clay's docs encode each as sin/cos pairs:
# time = week/hour -> [B, 4], latlon = lat/lon -> [B, 4].
# NOTE(review): zeros are used as placeholders for unknown date/location.
# Confirm that is acceptable for the model.
datacube = {
    "pixels": chips.to(device),                   # [B, C, H, W]
    "time": torch.zeros(1, 4, device=device),     # [B, 4] week/hour sin-cos
    "latlon": torch.zeros(1, 4, device=device),   # [B, 4] lat/lon sin-cos
    "gsd": torch.tensor([GSD_M], device=device),  # [1]
    "waves": waves,                               # [C]
}

# Generate embeddings (the encoder returns a tuple; element 0 holds the tokens).
with torch.no_grad():
    embeddings = model.model.encoder(datacube)

# embeddings[0] is [B, 1 + num_patches, D], i.e. [1, 65, 1024] for a 64x64 chip.
# Clay's docs treat token 0 as the class token, giving one [B, D] vector per chip.
cls_embedding = embeddings[0][:, 0, :]
print(f"Embeddings shape: {embeddings[0].shape}")  # [1, 65, 1024]
print(f"CLS embedding shape: {cls_embedding.shape}")  # [1, 1024]

torch.Size([1, 64, 1016]) torch.Size([1, 64, 8])
Embeddings shape: torch.Size([1, 65, 1024])


In [27]:
# Per-band `mean` entries for Sentinel-2 from the metadata YAML. They are shown
# in the order they appear in the file, which is not necessarily `band_order`.
s2_band_means = metadata["sentinel-2"]["bands"]["mean"]
s2_band_means.values()

dict_values([1552.0, 1355.0, 1105.0, 2743.0])