# Session 4A: Generate Your Own Clay Embeddings

**Road to SKA: Foundation Models, Embeddings, and Latent Spaces**

In Session 2, we used **precomputed Clay embeddings** so everyone could run on CPU without installing Clay. In this advanced notebook, you will learn to generate embeddings **from scratch** using the Clay foundation model.

By the end of this notebook, you will be able to:

1. **Acquire satellite imagery** via STAC, falling back to synthetic data when STAC is unavailable
2. **Chip imagery** into 256×256 tiles suitable for Clay
3. **Prepare Clay inputs** (normalized pixels, timestamps, wavelengths)
4. **Generate embeddings** using the Clay encoder
5. **Save embeddings** in a format compatible with Session 2 workflows
6. **Reuse the Session 2 classifier/retrieval pipeline** on your own embeddings

---

## Prerequisites

This notebook requires a **dedicated Clay environment** (separate from the main `r2ska-tutorial` environment) because Clay needs Python ≥3.11 and PyTorch ≥2.4.

### Quick Setup (Recommended)

Run the setup script, which creates the environment and downloads the checkpoint:

```bash
# From the tutorial directory
./setup_clay_env.sh
```

Then activate the environment:
```bash
conda activate r2ska-clay
jupyter lab
```

### Manual Setup

If you prefer to set up manually:

**1. Create the Clay environment:**
```bash
conda env create -f environment-clay.yml
conda activate r2ska-clay
```

**2. Download the Clay v1.5 checkpoint (~1.2 GB):**
```bash
wget https://huggingface.co/made-with-clay/Clay/resolve/main/v1.5/clay-v1.5.ckpt
```

Or download manually from: https://huggingface.co/made-with-clay/Clay/tree/main/v1.5

Place `clay-v1.5.ckpt` in the same directory as this notebook (or update `checkpoint_path` in the `Config` class below).

### Environment Differences

| Feature | `r2ska-tutorial` | `r2ska-clay` |
|---------|------------------|--------------|
| Python | 3.10 | 3.11 |
| PyTorch | ≥2.0 | ≥2.4 |
| Clay | Not installed | Installed |
| STAC tools | Optional | Included |

---

## References

- Clay documentation: https://clay-foundation.github.io/model/
- Clay "Basic Use" tutorial: https://clay-foundation.github.io/model/getting-started/basic_use.html
- Earth Search STAC: https://earth-search.aws.element84.com/v1

## Part 1: Setup and Configuration

In [None]:
# Core dependencies
import os
import json
import random
from pathlib import Path
from dataclasses import dataclass, field
from typing import Optional, List, Tuple
from datetime import datetime

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import requests

from sklearn.decomposition import PCA
from sklearn.neighbors import NearestNeighbors

# Check for PyTorch and Clay
try:
    import torch
    TORCH_OK = True
except ImportError:
    TORCH_OK = False
    print("ERROR: PyTorch not installed. Install with: pip install torch")

try:
    from clay.module import ClayMAEModule
    CLAY_OK = True
except ImportError:
    try:
        # Alternative import path
        from claymodel.module import ClayMAEModule
        CLAY_OK = True
    except ImportError:
        CLAY_OK = False
        print("ERROR: Clay not installed.")
        print("Install with: pip install git+https://github.com/Clay-foundation/model.git")

# Check for optional STAC dependencies
try:
    import pystac_client
    import stackstac
    STAC_OK = True
except ImportError:
    STAC_OK = False
    print("Note: pystac-client/stackstac not available. Will use synthetic-data fallback.")
    print("Install with: pip install pystac-client stackstac")

# Check for rasterio (needed for GeoTIFF loading)
try:
    import rasterio
    RASTERIO_OK = True
except ImportError:
    RASTERIO_OK = False
    print("Note: rasterio not available. Install with: pip install rasterio")

# Check for shapely (needed for geometry operations)
try:
    from shapely.geometry import box, Point, Polygon
    from shapely import wkb
    SHAPELY_OK = True
except ImportError:
    SHAPELY_OK = False
    print("Note: shapely not available. Install with: pip install shapely")

print("\n=== Dependency Status ===")
print(f"PyTorch:    {'OK' if TORCH_OK else 'MISSING'}")
print(f"Clay:       {'OK' if CLAY_OK else 'MISSING'}")
print(f"STAC tools: {'OK' if STAC_OK else 'Not installed (will use fallback)'}")
print(f"Rasterio:   {'OK' if RASTERIO_OK else 'Not installed'}")
print(f"Shapely:    {'OK' if SHAPELY_OK else 'Not installed'}")

if not CLAY_OK:
    print("\n" + "="*60)
    print("CLAY IS REQUIRED FOR THIS NOTEBOOK")
    print("Please install Clay before continuing.")
    print("="*60)

In [None]:
# Device selection: CUDA > MPS (Apple Silicon) > CPU
if TORCH_OK:
    if torch.cuda.is_available():
        device = torch.device("cuda")
        print(f"Using CUDA: {torch.cuda.get_device_name(0)}")
    elif torch.backends.mps.is_available():
        device = torch.device("mps")
        print("Using MPS (Apple Silicon)")
    else:
        device = torch.device("cpu")
        print("Using CPU (embedding generation will be slow)")
else:
    device = None
    print("PyTorch not available")

In [None]:
@dataclass
class Config:
    """Configuration for embedding generation."""
    
    # Paths
    data_dir: str = "./data/session2b"
    checkpoint_path: str = "./clay-v1.5.ckpt"
    output_parquet: str = "./data/session2b/my_embeddings.parquet"
    
    # AOI: SF Bay Area (same region as Session 2 precomputed data)
    bbox: Tuple[float, float, float, float] = (-122.52, 37.70, -122.35, 37.83)
    
    # STAC parameters
    stac_url: str = "https://earth-search.aws.element84.com/v1"
    collection: str = "sentinel-2-l2a"
    datetime_range: str = "2022-05-01/2022-05-31"
    max_cloud_cover: int = 20
    
    # Sentinel-2 bands and wavelengths (nm)
    # Using 4 bands for simplicity: Blue, Green, Red, NIR
    bands: List[str] = field(default_factory=lambda: ["B02", "B03", "B04", "B08"])
    wavelengths_nm: List[float] = field(default_factory=lambda: [490.0, 560.0, 665.0, 842.0])
    resolution_m: int = 10
    
    # Chipping parameters
    chip_size: int = 256  # Clay expects 256x256
    max_chips: int = 100  # Limit for workshop
    nodata_threshold: float = 0.1  # Skip chips with >10% nodata
    
    # Processing
    batch_size: int = 8
    random_seed: int = 42
    
    def __repr__(self):
        lines = ["Config:"]
        for k, v in self.__dict__.items():
            lines.append(f"  {k}: {v}")
        return "\n".join(lines)

cfg = Config()

# Set random seeds
random.seed(cfg.random_seed)
np.random.seed(cfg.random_seed)
if TORCH_OK:
    torch.manual_seed(cfg.random_seed)

# Create data directory
Path(cfg.data_dir).mkdir(parents=True, exist_ok=True)

print(cfg)

## Part 2: Data Acquisition via STAC

### What is STAC?

**STAC (SpatioTemporal Asset Catalog)** is a standardized way to describe and search for geospatial data. Think of it as a "search engine for satellite imagery" that allows you to:

- **Search** for imagery by location (bounding box), time range, and properties (e.g., cloud cover)
- **Discover** what data is available without downloading everything first
- **Access** imagery directly from cloud storage (Cloud-Optimized GeoTIFFs)

### Key STAC Concepts

| Term | Description |
|------|-------------|
| **Catalog** | The top-level object that groups collections and items, like a library catalog |
| **Collection** | A group of related items (e.g., "sentinel-2-l2a" for Sentinel-2 Level-2A) |
| **Item** | A single spatiotemporal asset (one satellite scene at one time) |
| **Asset** | A file within an item (e.g., individual spectral bands) |

### Our Data Source: Earth Search

We use [Earth Search](https://earth-search.aws.element84.com/v1), a free STAC API hosted by Element 84 that provides access to:
- **Sentinel-2 L2A**: Surface reflectance imagery (atmospherically corrected)
- **Landsat**: USGS Landsat Collection 2
- **NAIP**: US aerial imagery
- And more...

### The Query Process

```
1. Connect to STAC catalog (Earth Search)
2. Search for items matching:
   - Collection: sentinel-2-l2a
   - Bounding box: our area of interest
   - Date range: May 2022
   - Cloud cover: < 20%
3. Take the first matching scene (not necessarily the least cloudy)
4. Load specific bands (Blue, Green, Red, NIR)
5. Stack bands into a single array
```
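The search in step 2 corresponds to a JSON request body sent to the catalog's `/search` endpoint. The sketch below builds that body without sending it. It also adds a `least_cloudy` helper for step 3, because taking the first result (as `acquire_via_stac` does with `max_items=1`) does not guarantee the clearest scene. Both helpers are hypothetical; the `query` field relies on the STAC API Query extension, which the notebook's own `pystac_client` search also uses.

```python
import json

def build_search_body(collection, bbox, datetime_range, max_cloud_cover, limit=10):
    """STAC API item-search body mirroring the query above."""
    return {
        "collections": [collection],
        "bbox": list(bbox),
        "datetime": datetime_range,
        # Cloud-cover filter via the Query extension
        "query": {"eo:cloud_cover": {"lt": max_cloud_cover}},
        "limit": limit,  # page size
    }

def least_cloudy(features):
    """Pick the item with the lowest eo:cloud_cover; items without it sort last."""
    return min(features, key=lambda f: f["properties"].get("eo:cloud_cover", float("inf")))

body = build_search_body("sentinel-2-l2a", (-122.52, 37.70, -122.35, 37.83),
                         "2022-05-01/2022-05-31", max_cloud_cover=20)
print(json.dumps(body, indent=2))
```

To use `least_cloudy`, pass a list of item dictionaries, for example the features of a search response.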

If STAC is unavailable, the notebook falls back to synthetic data for demonstration.

In [None]:
def download_file(url: str, dst: Path, chunk_size: int = 1 << 20) -> Path:
    """Download a file with progress indication."""
    dst = Path(dst)
    dst.parent.mkdir(parents=True, exist_ok=True)
    
    if dst.exists() and dst.stat().st_size > 0:
        print(f"  File exists: {dst}")
        return dst
    
    print(f"  Downloading: {url}")
    with requests.get(url, stream=True, timeout=120) as r:
        r.raise_for_status()
        total = int(r.headers.get('content-length', 0))
        downloaded = 0
        with open(dst, "wb") as f:
            for chunk in r.iter_content(chunk_size=chunk_size):
                if chunk:
                    f.write(chunk)
                    downloaded += len(chunk)
                    if total > 0:
                        pct = 100 * downloaded / total
                        print(f"\r  Progress: {pct:.1f}%", end="", flush=True)
        print()
    return dst


# Mapping from Sentinel-2 band codes to Earth Search asset names
SENTINEL2_BAND_MAP = {
    "B01": "coastal",
    "B02": "blue",
    "B03": "green",
    "B04": "red",
    "B05": "rededge1",
    "B06": "rededge2",
    "B07": "rededge3",
    "B08": "nir",
    "B8A": "nir08",
    "B09": "nir09",
    "B11": "swir16",
    "B12": "swir22",
    "SCL": "scl",
}


def acquire_via_stac(cfg: Config) -> Tuple[np.ndarray, dict]:
    """
    Acquire imagery via STAC query.
    
    Returns:
        imagery: np.ndarray of shape (bands, height, width)
        metadata: dict with acquisition info
    """
    if not STAC_OK:
        raise ImportError("STAC dependencies not available")
    
    print(f"Searching STAC: {cfg.stac_url}")
    print(f"  Collection: {cfg.collection}")
    print(f"  BBox: {cfg.bbox}")
    print(f"  Date range: {cfg.datetime_range}")
    
    catalog = pystac_client.Client.open(cfg.stac_url)
    
    search = catalog.search(
        collections=[cfg.collection],
        bbox=cfg.bbox,
        datetime=cfg.datetime_range,
        query={"eo:cloud_cover": {"lt": cfg.max_cloud_cover}},
        max_items=1,
    )
    
    items = list(search.items())
    if len(items) == 0:
        raise RuntimeError("No STAC items found. Try different date range or cloud cover threshold.")
    
    item = items[0]
    print(f"\nFound item: {item.id}")
    print(f"  Datetime: {item.datetime}")
    print(f"  Cloud cover: {item.properties.get('eo:cloud_cover', 'N/A')}%")
    
    # Debug: show available assets
    available_assets = list(item.assets.keys())
    print(f"  Available assets: {available_assets[:10]}{'...' if len(available_assets) > 10 else ''}")
    
    # Map band codes to asset names (Earth Search uses descriptive names)
    asset_names = []
    for band in cfg.bands:
        if band in available_assets:
            # Band code exists as-is
            asset_names.append(band)
        elif band in SENTINEL2_BAND_MAP and SENTINEL2_BAND_MAP[band] in available_assets:
            # Map band code to descriptive name
            asset_names.append(SENTINEL2_BAND_MAP[band])
        elif band.lower() in available_assets:
            # Try lowercase
            asset_names.append(band.lower())
        else:
            raise RuntimeError(f"Band {band} not found in available assets: {available_assets}")
    
    print(f"  Mapped bands: {cfg.bands} -> {asset_names}")

    # Stack bands using stackstac
    print(f"\nLoading bands: {asset_names}")

    # Debug: show STAC item metadata for CRS/bounds
    print(f"  Item EPSG from metadata: {item.properties.get('proj:epsg')}")
    print(f"  Item bbox: {item.bbox}")

    # Get the item's native EPSG code from metadata
    # If not available, determine UTM zone from bbox center longitude
    item_epsg = item.properties.get('proj:epsg')
    if item_epsg is None:
        # Calculate UTM zone from center longitude of bbox
        center_lon = (cfg.bbox[0] + cfg.bbox[2]) / 2
        utm_zone = int((center_lon + 180) / 6) + 1
        # EPSG:326xx for the northern hemisphere, EPSG:327xx for the southern
        center_lat = (cfg.bbox[1] + cfg.bbox[3]) / 2
        hemisphere = "N" if center_lat >= 0 else "S"
        item_epsg = (32600 if hemisphere == "N" else 32700) + utm_zone
        print(f"  Computed UTM zone {utm_zone}{hemisphere} -> EPSG:{item_epsg}")
    print(f"  Using EPSG: {item_epsg}")

    # Transform bbox from EPSG:4326 to target CRS for explicit bounds
    # stackstac needs explicit bounds when STAC metadata is incomplete (GitHub #187)
    # transform_bounds samples points along every edge, so the result encloses the
    # whole AOI; transforming only two corners can clip it because UTM grid lines
    # are not aligned with lines of latitude and longitude.
    try:
        from pyproj import Transformer
        transformer = Transformer.from_crs("EPSG:4326", f"EPSG:{item_epsg}", always_xy=True)
        stack_bounds = transformer.transform_bounds(*cfg.bbox)
        print(f"  Transformed bounds: {stack_bounds}")
    except ImportError:
        print("  Warning: pyproj not available; loading the full scene extent")
        stack_bounds = None

    stack = stackstac.stack(
        [item],
        assets=asset_names,
        resolution=cfg.resolution_m,
        epsg=item_epsg,  # Explicit EPSG prevents metadata inference issues
        bounds=stack_bounds,  # Explicit bounds in target CRS
        chunksize=2048,  # Load in reasonable chunks
    )
    
    # Debug: check stack shape before computing
    print(f"  Stack shape (lazy): {stack.shape}")
    print(f"  Stack dims: {stack.dims}")
    
    if stack.shape[0] == 0:
        raise RuntimeError("Stack has no time dimension - no data was loaded")
    if stack.shape[1] == 0:
        raise RuntimeError("Stack has no bands - check asset names")
    
    # Load into memory (first time slice)
    arr = stack.isel(time=0).compute().values  # (bands, height, width)
    
    print(f"  Loaded array: {arr.shape}")
    
    # Handle NaN values (nodata)
    arr = np.nan_to_num(arr, nan=0).astype(np.float32)
    
    metadata = {
        "source": "stac",
        "item_id": item.id,
        "datetime": item.datetime.isoformat() if item.datetime else None,
        "bbox": cfg.bbox,
        "bands": cfg.bands,
    }
    
    return arr, metadata


def create_synthetic_imagery(cfg: Config) -> Tuple[np.ndarray, dict]:
    """
    Create synthetic imagery for demonstration when real data is unavailable.
    
    Returns:
        imagery: np.ndarray of shape (bands, height, width)
        metadata: dict with acquisition info
    """
    print("Creating synthetic sample data for demonstration...")
    
    # Create synthetic data matching expected dimensions
    # Enough for ~16 chips at 256x256
    height = width = cfg.chip_size * 4
    n_bands = len(cfg.bands)
    
    # Generate realistic-looking data with some structure
    np.random.seed(cfg.random_seed)
    arr = np.random.uniform(0, 3000, (n_bands, height, width)).astype(np.float32)
    
    # Add some spatial structure (gradient + noise)
    y_grad = np.linspace(0, 1, height).reshape(-1, 1)
    x_grad = np.linspace(0, 1, width).reshape(1, -1)
    for b in range(n_bands):
        arr[b] += 1000 * y_grad * x_grad
    
    print(f"Created synthetic data: {arr.shape}")
    
    metadata = {
        "source": "synthetic",
        "datetime": "2022-05-18T00:00:00",
        "bbox": cfg.bbox,
        "bands": cfg.bands,
    }
    
    return arr, metadata


def acquire_imagery(cfg: Config) -> Tuple[np.ndarray, dict]:
    """Acquire imagery using STAC, falling back to synthetic data if unavailable."""
    
    # Try STAC if dependencies are available
    if STAC_OK:
        try:
            print("Attempting STAC acquisition...")
            return acquire_via_stac(cfg)
        except Exception as e:
            print(f"STAC acquisition failed: {e}")
            print("Falling back to synthetic data...\n")
            return create_synthetic_imagery(cfg)
    else:
        # STAC not available - use synthetic data
        print("STAC dependencies not installed.")
        return create_synthetic_imagery(cfg)

In [None]:
# Acquire the imagery
imagery, acq_metadata = acquire_imagery(cfg)

print(f"\n=== Acquisition Summary ===")
print(f"Source: {acq_metadata['source']}")
print(f"Shape: {imagery.shape} (bands, height, width)")
print(f"Dtype: {imagery.dtype}")
print(f"Value range: [{imagery.min():.1f}, {imagery.max():.1f}]")
print(f"NaN percentage: {100 * np.isnan(imagery).mean():.2f}%")

In [None]:
# Visualize the acquired imagery
def normalize_for_display(arr, percentile_low=2, percentile_high=98):
    """Normalize array for display using percentile stretching."""
    arr = np.nan_to_num(arr, nan=0)
    vmin = np.percentile(arr, percentile_low)
    vmax = np.percentile(arr, percentile_high)
    return np.clip((arr - vmin) / (vmax - vmin + 1e-8), 0, 1)

# Create RGB composite (bands: B04, B03, B02 = indices 2, 1, 0)
if imagery.shape[0] >= 3:
    rgb = np.stack([
        normalize_for_display(imagery[2]),  # Red (B04)
        normalize_for_display(imagery[1]),  # Green (B03)
        normalize_for_display(imagery[0]),  # Blue (B02)
    ], axis=-1)
else:
    rgb = normalize_for_display(imagery[0])

fig, axes = plt.subplots(1, 2, figsize=(14, 6))

# RGB composite
axes[0].imshow(rgb)
axes[0].set_title(f"RGB Composite\nSource: {acq_metadata['source']}")
axes[0].axis("off")

# NIR band (index 3)
if imagery.shape[0] >= 4:
    nir_display = normalize_for_display(imagery[3])
    axes[1].imshow(nir_display, cmap="RdYlGn")
    axes[1].set_title("NIR Band (B08)")
else:
    axes[1].imshow(normalize_for_display(imagery[-1]), cmap="gray")
    axes[1].set_title("Last Band")
axes[1].axis("off")

plt.tight_layout()
plt.show()

print(f"Image dimensions: {imagery.shape[1]} x {imagery.shape[2]} pixels")
print(f"At {cfg.resolution_m}m resolution: ~{imagery.shape[1] * cfg.resolution_m / 1000:.1f} x {imagery.shape[2] * cfg.resolution_m / 1000:.1f} km")

## Part 3: Chipping and Preprocessing

Clay expects input chips of shape `(batch, bands, 256, 256)`. We need to:

1. **Tile** the imagery into non-overlapping 256×256 chips
2. **Filter** chips that contain too much nodata
3. **Normalize** pixel values (Sentinel-2 L2A values are typically 0-10000, so we divide by 10000 to scale them to 0-1)
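The three steps can also be written with whole-array NumPy operations instead of an explicit loop. This is a sketch with hypothetical helper names, not part of Clay. `tile_nonoverlapping` returns chips in the same row-major order as a row-then-column loop, and `valid_mask` uses the same NaN-or-zero nodata rule. `standardize` is included because Clay's Basic Use tutorial (linked in References) normalizes each band with per-band means and standard deviations from Clay's `configs/metadata.yaml` rather than dividing by 10000; the statistics are left as arguments here instead of being copied from that file.

```python
import numpy as np

def tile_nonoverlapping(imagery: np.ndarray, chip_size: int = 256) -> np.ndarray:
    """(bands, H, W) -> (n_chips, bands, chip_size, chip_size), row-major order.
    Pixels past the last full chip on the right and bottom edges are dropped."""
    bands, height, width = imagery.shape
    n_rows, n_cols = height // chip_size, width // chip_size
    cropped = imagery[:, : n_rows * chip_size, : n_cols * chip_size]
    # Split each spatial axis into (grid index, offset within chip)
    grid = cropped.reshape(bands, n_rows, chip_size, n_cols, chip_size)
    # Reorder to (row, col, band, y, x), then flatten the grid
    grid = grid.transpose(1, 3, 0, 2, 4)
    return grid.reshape(n_rows * n_cols, bands, chip_size, chip_size)

def valid_mask(chips: np.ndarray, nodata_threshold: float = 0.1) -> np.ndarray:
    """True for chips whose NaN-or-zero fraction is at most nodata_threshold."""
    nodata = np.isnan(chips) | (chips == 0)
    return nodata.reshape(len(chips), -1).mean(axis=1) <= nodata_threshold

def scale_reflectance(chips: np.ndarray) -> np.ndarray:
    """Replace NaN with 0, divide by 10000 and clip to [0, 1]."""
    return np.clip(np.nan_to_num(chips, nan=0.0) / 10000.0, 0.0, 1.0).astype(np.float32)

def standardize(chips: np.ndarray, mean, std) -> np.ndarray:
    """Per-band z-score (x - mean) / std; mean and std hold one value per band."""
    mean = np.asarray(mean, dtype=np.float32).reshape(1, -1, 1, 1)
    std = np.asarray(std, dtype=np.float32).reshape(1, -1, 1, 1)
    return ((chips - mean) / std).astype(np.float32)

# Usage with this notebook's variables (not run here):
# chips = tile_nonoverlapping(imagery, cfg.chip_size)
# chips = scale_reflectance(chips[valid_mask(chips, cfg.nodata_threshold)])
```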

In [None]:
def create_chips(imagery: np.ndarray, chip_size: int = 256,
                 nodata_threshold: float = 0.1,
                 max_chips: Optional[int] = None) -> Tuple[np.ndarray, List[dict]]:
    """
    Create non-overlapping chips from imagery.
    
    Args:
        imagery: Array of shape (bands, height, width)
        chip_size: Size of each chip (default 256 for Clay)
        nodata_threshold: Skip chips with more than this fraction of nodata
        max_chips: Maximum number of chips to return (None for no limit)
    
    Returns:
        chips: Array of shape (n_chips, bands, chip_size, chip_size)
        chip_info: List of dicts with chip metadata (row, col, y0, x0)
    """
    bands, height, width = imagery.shape
    
    chips = []
    chip_info = []
    n_skipped = 0  # chips rejected for too much nodata
    
    n_rows = height // chip_size
    n_cols = width // chip_size
    
    print(f"Creating {chip_size}x{chip_size} chips...")
    print(f"  Image size: {height} x {width}")
    print(f"  Potential chips: {n_rows} rows x {n_cols} cols = {n_rows * n_cols}")
    
    for row in range(n_rows):
        for col in range(n_cols):
            # Extract chip
            y0 = row * chip_size
            x0 = col * chip_size
            chip = imagery[:, y0:y0+chip_size, x0:x0+chip_size]
            
            # Fraction of nodata values (NaN or exact zero) across all bands
            nodata_frac = np.isnan(chip).mean() + (chip == 0).mean()
            if nodata_frac > nodata_threshold:
                n_skipped += 1
                continue
            
            chips.append(chip)
            chip_info.append({
                "row": row,
                "col": col,
                "y0": y0,
                "x0": x0,
            })
            
            if max_chips and len(chips) >= max_chips:
                break
        if max_chips and len(chips) >= max_chips:
            break
    
    if not chips:
        raise RuntimeError("No valid chips: the image is smaller than chip_size "
                           "or every chip exceeds the nodata threshold.")
    
    chips = np.stack(chips, axis=0)  # (n_chips, bands, h, w)
    
    print(f"  Valid chips: {len(chips)} (skipped {n_skipped} with nodata)")
    
    return chips, chip_info


def normalize_sentinel2(chips: np.ndarray) -> np.ndarray:
    """
    Normalize Sentinel-2 L2A values to [0, 1] range.
    
    Sentinel-2 L2A surface reflectance values are typically in [0, 10000].
    We divide by 10000 and clip to [0, 1].
    """
    chips = np.nan_to_num(chips, nan=0)
    chips = chips / 10000.0
    chips = np.clip(chips, 0, 1)
    return chips.astype(np.float32)


# Create chips
chips_raw, chip_info = create_chips(
    imagery, 
    chip_size=cfg.chip_size,
    nodata_threshold=cfg.nodata_threshold,
    max_chips=cfg.max_chips
)

# Normalize
chips_norm = normalize_sentinel2(chips_raw)

print(f"\n=== Chips Summary ===")
print(f"Shape: {chips_norm.shape} (n_chips, bands, height, width)")
print(f"Value range: [{chips_norm.min():.4f}, {chips_norm.max():.4f}]")

In [None]:
# Visualize sample chips as RGB
def chip_normalize(band):
    """Per-chip 2nd-98th percentile stretch, for display only."""
    vmin, vmax = np.percentile(band, [2, 98])
    return np.clip((band - vmin) / (vmax - vmin + 1e-8), 0, 1)

n_show = min(12, len(chips_norm))
n_cols = 4
n_rows = (n_show + n_cols - 1) // n_cols

fig, axes = plt.subplots(n_rows, n_cols, figsize=(3 * n_cols, 3 * n_rows))
axes = np.atleast_1d(axes).ravel()  # always a flat array of Axes

for i in range(n_show):
    chip = chips_norm[i]
    
    # Create RGB (B04, B03, B02 = indices 2, 1, 0)
    if chip.shape[0] >= 3:
        rgb = np.stack([chip_normalize(chip[2]), chip_normalize(chip[1]), chip_normalize(chip[0])], axis=-1)
    else:
        rgb = chip[0]
    
    axes[i].imshow(rgb, cmap="gray" if rgb.ndim == 2 else None)
    axes[i].set_title(f"Chip {i}\n({chip_info[i]['row']}, {chip_info[i]['col']})")
    axes[i].axis("off")

# Hide unused axes
for i in range(n_show, len(axes)):
    axes[i].axis("off")

plt.suptitle(f"Sample chips ({cfg.chip_size}x{cfg.chip_size} pixels)", fontsize=14)
plt.tight_layout()
plt.show()

## Part 4: Prepare Clay Inputs

### Why Does Clay Need Special Input Preparation?

Clay is a **foundation model** trained on diverse satellite imagery from multiple sensors. To handle this diversity, Clay uses a flexible input format that explicitly encodes:

1. **What** the model sees (pixel values and which spectral bands)
2. **When** the image was captured (temporal encoding)
3. **Where** on Earth the image is located (spatial encoding)

This design allows Clay to generalize across different sensors, seasons, and geographic regions.

### The Five Input Components

| Input | Shape | Description |
|-------|-------|-------------|
| **pixels** | `[B, C, 256, 256]` | Normalized pixel values (0-1 range) |
| **time** | `[B, 4]` | Encoded timestamp: `[week_sin, week_cos, hour_sin, hour_cos]` |
| **latlon** | `[B, 4]` | Encoded location: `[lat_sin, lat_cos, lon_sin, lon_cos]` |
| **waves** | `[N]` | Center wavelengths of each band in nanometers |
| **gsd** | scalar | Ground sampling distance in meters |
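A minimal shape check for these inputs, using NumPy arrays as stand-ins for the PyTorch tensors the encoder actually receives. `make_dummy_datacube` and `check_datacube` are hypothetical helpers: they only validate shapes against the table above and never call Clay. The default wavelengths are the four values from this notebook's `Config`.

```python
import numpy as np

def make_dummy_datacube(batch_size=2, wavelengths=(490.0, 560.0, 665.0, 842.0),
                        chip_size=256, gsd=10.0) -> dict:
    """Zero-filled arrays shaped like the encoder inputs in the table."""
    n_bands = len(wavelengths)
    return {
        "pixels": np.zeros((batch_size, n_bands, chip_size, chip_size), dtype=np.float32),
        "time": np.zeros((batch_size, 4), dtype=np.float32),
        "latlon": np.zeros((batch_size, 4), dtype=np.float32),
        "waves": np.asarray(wavelengths, dtype=np.float32),
        "gsd": np.float32(gsd),
    }

def check_datacube(cube: dict) -> tuple:
    """Raise ValueError if shapes disagree; otherwise return (B, C, H, W)."""
    B, C, H, W = cube["pixels"].shape
    if cube["time"].shape != (B, 4) or cube["latlon"].shape != (B, 4):
        raise ValueError("time and latlon must both be [B, 4]")
    if cube["waves"].shape != (C,):
        raise ValueError(f"expected {C} wavelengths (one per band), got {cube['waves'].shape}")
    if np.ndim(cube["gsd"]) != 0:
        raise ValueError("gsd must be a scalar")
    return B, C, H, W

print(check_datacube(make_dummy_datacube()))  # (2, 4, 256, 256)
```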

### Why Sine/Cosine Encoding?

Time and location are **cyclical** quantities:
- Week 52 is close to week 1 (end of year wraps to beginning)
- Hour 23 is close to hour 0 (midnight)
- Longitude -180° is the same as +180°

Simple linear encoding would make week 52 seem "far" from week 1. Sine/cosine encoding preserves cyclical relationships:

```python
import numpy as np

# Week 1 and week 52 get similar encodings:
week_1 = [np.sin(2 * np.pi * 1 / 52), np.cos(2 * np.pi * 1 / 52)]     # ≈ [0.12, 0.99]
week_52 = [np.sin(2 * np.pi * 52 / 52), np.cos(2 * np.pi * 52 / 52)]  # ≈ [0.00, 1.00]
```
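Comparing distances makes the point concrete. The sketch below measures how far apart week 1 and week 52 are under a linear encoding (week scaled to 0-1) and under the sine/cosine encoding; `cyclic_encode` is a hypothetical helper using the same 2π/52 scaling as the formula above.

```python
import numpy as np

def cyclic_encode(value, period):
    """Map a cyclical value onto the unit circle as [sin, cos]."""
    angle = 2 * np.pi * value / period
    return np.array([np.sin(angle), np.cos(angle)])

linear_gap = abs(52 - 1) / 52  # linear scaling puts weeks 1 and 52 almost a full unit apart
cyclic_gap = np.linalg.norm(cyclic_encode(52, 52) - cyclic_encode(1, 52))
opposite_gap = np.linalg.norm(cyclic_encode(26, 52) - cyclic_encode(1, 52))

print(f"linear: {linear_gap:.2f}  cyclic: {cyclic_gap:.2f}  week 1 vs 26 (cyclic): {opposite_gap:.2f}")
# linear: 0.98  cyclic: 0.12  week 1 vs 26 (cyclic): 2.00
```

On the circle, weeks 1 and 52 are neighbours, while weeks half a year apart sit at nearly the maximum possible distance of 2.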

### Why Wavelengths Matter

Clay was trained on imagery from multiple sensors with different spectral bands. By explicitly providing the center wavelength of each band (e.g., 490nm for blue, 842nm for NIR), Clay can:
- Understand what each band measures
- Generalize to sensors it wasn't explicitly trained on
- Handle different band combinations

In [None]:
def prepare_timestamps(chip_info: List[dict], acq_metadata: dict, 
                       imagery_shape: Tuple, cfg: Config) -> np.ndarray:
    """
    Prepare timestamp arrays for Clay.
    
    Args:
        chip_info: List of chip metadata dicts
        acq_metadata: Acquisition metadata with datetime
        imagery_shape: Shape of original imagery (bands, height, width)
        cfg: Configuration object
    
    Returns:
        timestamps: Array of shape (n_chips, 4) with [week, hour, lat, lon]
    """
    # Parse datetime, falling back to a fixed date if it is missing or malformed
    dt_str = acq_metadata.get("datetime") or "2022-05-18T00:00:00"
    try:
        dt = datetime.fromisoformat(dt_str.replace("Z", "+00:00"))
    except ValueError:
        dt = datetime(2022, 5, 18, 0, 0)
    
    week = dt.isocalendar()[1]  # Week of year
    hour = dt.hour
    
    # Calculate chip centroids in lat/lon
    _, height, width = imagery_shape
    min_lon, min_lat, max_lon, max_lat = cfg.bbox
    
    timestamps = []
    for info in chip_info:
        # Chip center in pixel coordinates
        cy = info["y0"] + cfg.chip_size / 2
        cx = info["x0"] + cfg.chip_size / 2
        
        # Approximate lat/lon by linear interpolation, assuming the image spans cfg.bbox exactly
        lat = max_lat - (cy / height) * (max_lat - min_lat)
        lon = min_lon + (cx / width) * (max_lon - min_lon)
        
        timestamps.append([week, hour, lat, lon])
    
    return np.array(timestamps, dtype=np.float32)


# Prepare timestamps
timestamps = prepare_timestamps(chip_info, acq_metadata, imagery.shape, cfg)

# Wavelengths (in nanometers)
wavelengths = np.array([cfg.wavelengths_nm], dtype=np.float32)  # (1, n_bands)

print("=== Clay Input Summary ===")
print(f"Chips: {chips_norm.shape}")
print(f"Timestamps: {timestamps.shape}")
print(f"  Sample: week={timestamps[0,0]:.0f}, hour={timestamps[0,1]:.0f}, lat={timestamps[0,2]:.4f}, lon={timestamps[0,3]:.4f}")
print(f"Wavelengths: {wavelengths.shape}")
print(f"  Values (nm): {wavelengths[0].tolist()}")

## Part 5: Generate Clay Embeddings

### What Happens Inside Clay?

Clay uses a **Vision Transformer (ViT)** architecture adapted for satellite imagery. Here's the processing pipeline:

```
Input Chip (256×256×C)
    ↓
┌─────────────────────────────────────┐
│  1. PATCH EMBEDDING                 │
│     Split into 8×8 pixel patches    │
│     → 32×32 = 1024 patches          │
│     Each patch → 1024-dim vector    │
└─────────────────────────────────────┘
    ↓
┌─────────────────────────────────────┐
│  2. POSITION ENCODING               │
│     Add spatial position info       │
│     Add time/location encoding      │
└─────────────────────────────────────┘
    ↓
┌─────────────────────────────────────┐
│  3. TRANSFORMER ENCODER             │
│     Self-attention across patches   │
│     12 transformer layers           │
│     Learns relationships between    │
│     all parts of the image          │
└─────────────────────────────────────┘
    ↓
┌─────────────────────────────────────┐
│  4. CLS TOKEN EXTRACTION            │
│     Special [CLS] token aggregates  │
│     information from all patches    │
│     → 1024-dim embedding vector     │
└─────────────────────────────────────┘
    ↓
Output Embedding (1024-dim)
```
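Step 1 of the diagram can be reproduced with plain NumPy to check the arithmetic: a 256×256 chip cut into 8×8 patches gives a 32×32 grid of 1024 patches. This sketch only performs the split. `patchify` is a hypothetical helper, and Clay's learned projection of each patch to a 1024-dimensional vector is not reproduced.

```python
import numpy as np

def patchify(chip: np.ndarray, patch_size: int = 8) -> np.ndarray:
    """(C, H, W) -> (num_patches, C * patch_size * patch_size), row-major patch order."""
    c, h, w = chip.shape
    gh, gw = h // patch_size, w // patch_size
    patches = chip.reshape(c, gh, patch_size, gw, patch_size)
    patches = patches.transpose(1, 3, 0, 2, 4)  # (gh, gw, C, p, p)
    return patches.reshape(gh * gw, c * patch_size * patch_size)

rng = np.random.default_rng(0)
chip = rng.random((4, 256, 256), dtype=np.float32)  # stand-in for one normalized chip
patches = patchify(chip)
print(patches.shape)         # (1024, 256)
print(patches.shape[0] + 1)  # 1025 tokens once a CLS token is prepended
```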

### The CLS Token

The **CLS (classification) token** is a special learnable token prepended to the patch sequence. Through self-attention, it "attends" to all patches and learns to aggregate the most important information into a single vector. This makes it ideal as a global representation of the entire chip.
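With the encoder output laid out as `[batch, tokens, dim]` and the CLS token in position 0 (as the embedding code in this notebook assumes), extracting the chip embedding is a single index. The sketch uses random numbers in place of real encoder output. The mean-pooled variant is a common alternative shown only for comparison; it is not what this notebook uses.

```python
import numpy as np

rng = np.random.default_rng(0)
# Random stand-in for encoder output: batch of 2, CLS + 1024 patch tokens, 1024-dim each
encoded = rng.standard_normal((2, 1 + 1024, 1024)).astype(np.float32)

cls_embedding = encoded[:, 0, :]                 # what this notebook keeps: [B, D]
mean_embedding = encoded[:, 1:, :].mean(axis=1)  # alternative: average the patch tokens
print(cls_embedding.shape, mean_embedding.shape)  # (2, 1024) (2, 1024)
```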

### Inference vs Training Mode

During training, Clay uses **masking** (hiding 75% of patches) to learn robust representations. For inference, we set `mask_ratio=0` so that every patch reaches the encoder and the embedding describes the whole chip.
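A simplified illustration of what `mask_ratio` controls: MAE-style training keeps a random subset of patch indices and hides the rest. `random_mask` is a hypothetical helper. Clay performs its own masking inside the encoder, so this sketch only shows how the kept/hidden split changes between a ratio of 0.75 and 0.

```python
import numpy as np

def random_mask(num_patches: int, mask_ratio: float, rng: np.random.Generator):
    """Split patch indices into (kept, hidden) sets, MAE-style."""
    n_keep = int(num_patches * (1 - mask_ratio))
    order = rng.permutation(num_patches)
    return np.sort(order[:n_keep]), np.sort(order[n_keep:])

rng = np.random.default_rng(42)
kept, hidden = random_mask(1024, 0.75, rng)          # training-style masking
kept_all, hidden_none = random_mask(1024, 0.0, rng)  # inference: mask_ratio=0
print(len(kept), len(hidden))           # 256 768
print(len(kept_all), len(hidden_none))  # 1024 0
```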

### Output: A 1024-Dimensional Vector

Each chip is compressed from 262,144 pixel values (256×256×4 bands) into a single 1024-dimensional vector that captures:
- Land cover type (urban, forest, water, agriculture)
- Spatial patterns and textures
- Seasonal characteristics (via time encoding)
- Geographic context (via location encoding)
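That is 256 times fewer numbers per chip. Because every chip becomes a vector of the same length, similar chips can be found by comparing their vectors. Below is a minimal cosine-similarity sketch, assuming a stacked array of embeddings such as `np.vstack(embeddings_list)`; random vectors stand in for real embeddings, `cosine_top_k` is a hypothetical helper, and Session 2's retrieval may use a different metric or library.

```python
import numpy as np

def cosine_top_k(embeddings: np.ndarray, query_idx: int, k: int = 3):
    """Return (indices, similarities) of the k rows most similar to row query_idx."""
    unit = embeddings / np.linalg.norm(embeddings, axis=1, keepdims=True)
    sims = unit @ unit[query_idx]
    sims[query_idx] = -np.inf  # never return the query itself
    top = np.argsort(-sims)[:k]
    return top, sims[top]

rng = np.random.default_rng(0)
fake_embeddings = rng.standard_normal((16, 1024)).astype(np.float32)  # stand-in data
idx, sim = cosine_top_k(fake_embeddings, query_idx=0, k=3)
print(idx, np.round(sim, 3))
```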

In [None]:
# Check if checkpoint exists
checkpoint_path = Path(cfg.checkpoint_path)

if not checkpoint_path.exists():
    # Try alternate locations
    alt_paths = [
        Path("clay-v1.5.ckpt"),
        Path.home() / "clay-v1.5.ckpt",
        Path(cfg.data_dir) / "clay-v1.5.ckpt",
    ]
    for alt in alt_paths:
        if alt.exists():
            checkpoint_path = alt
            print(f"Found checkpoint at: {checkpoint_path}")
            break

if not checkpoint_path.exists():
    print("ERROR: Clay checkpoint not found!")
    print("\nPlease download the checkpoint:")
    print("  wget https://huggingface.co/made-with-clay/Clay/resolve/main/v1.5/clay-v1.5.ckpt")
    print(f"\nSearched locations:")
    print(f"  - {cfg.checkpoint_path}")
    for alt in alt_paths:
        print(f"  - {alt}")
    CHECKPOINT_OK = False
else:
    print(f"Checkpoint found: {checkpoint_path}")
    print(f"  Size: {checkpoint_path.stat().st_size / 1e9:.2f} GB")
    CHECKPOINT_OK = True

In [None]:
if CLAY_OK and CHECKPOINT_OK and TORCH_OK:
    print("Loading Clay model...")
    
    # Clay expects configs/metadata.yaml in the current directory.
    # This file is not included in the pip package, so we download it from the repo.
    configs_dir = Path("configs")
    metadata_path = configs_dir / "metadata.yaml"
    
    if not metadata_path.exists():
        print("Downloading Clay metadata.yaml from GitHub...")
        configs_dir.mkdir(exist_ok=True)
        metadata_url = "https://raw.githubusercontent.com/Clay-foundation/model/main/configs/metadata.yaml"
        response = requests.get(metadata_url, timeout=30)
        response.raise_for_status()
        metadata_path.write_text(response.text)
        print(f"  Saved to {metadata_path}")
    else:
        print(f"Using existing {metadata_path}")
    
    # Load model with mask_ratio=0 for inference (no masking)
    model = ClayMAEModule.load_from_checkpoint(
        str(checkpoint_path),
        mask_ratio=0.0,  # No masking during inference
        shuffle=False,   # Don't shuffle patches
    )
    model = model.to(device)
    model.eval()
    
    print(f"Model loaded on {device}")
    
    # Convert inputs to tensors
    chips_tensor = torch.tensor(chips_norm, dtype=torch.float32)
    
    # Clay expects time and latlon as sine/cosine encoded pairs:
    # time: [week_sin, week_cos, hour_sin, hour_cos] -> [B, 4]
    # latlon: [lat_sin, lat_cos, lon_sin, lon_cos] -> [B, 4]
    
    def encode_time_latlon(timestamps):
        """Encode timestamps to Clay's expected format with sin/cos encoding."""
        week = timestamps[:, 0]  # ISO week of year (1-53)
        hour = timestamps[:, 1]  # Hour of day (0-23)
        lat = timestamps[:, 2]   # Latitude (-90 to 90)
        lon = timestamps[:, 3]   # Longitude (-180 to 180)
        
        # Normalize and encode with sin/cos
        # Week: scaled by 2π/52 so one year maps onto one full cycle
        week_rad = 2 * np.pi * week / 52
        week_norm = np.stack([np.sin(week_rad), np.cos(week_rad)], axis=1)
        
        # Hour: 0-24 -> 0-2π
        hour_rad = 2 * np.pi * hour / 24
        hour_norm = np.stack([np.sin(hour_rad), np.cos(hour_rad)], axis=1)
        
        # Latitude: -90 to 90 -> -π/2 to π/2
        lat_rad = np.pi * lat / 180
        lat_norm = np.stack([np.sin(lat_rad), np.cos(lat_rad)], axis=1)
        
        # Longitude: -180 to 180 -> -π to π
        lon_rad = np.pi * lon / 180
        lon_norm = np.stack([np.sin(lon_rad), np.cos(lon_rad)], axis=1)
        
        # Combine: time [B, 4], latlon [B, 4]
        time_encoded = np.hstack([week_norm, hour_norm]).astype(np.float32)
        latlon_encoded = np.hstack([lat_norm, lon_norm]).astype(np.float32)
        
        return time_encoded, latlon_encoded
    
    time_encoded, latlon_encoded = encode_time_latlon(timestamps)
    time_tensor = torch.tensor(time_encoded, dtype=torch.float32)
    latlon_tensor = torch.tensor(latlon_encoded, dtype=torch.float32)
    
    # Wavelengths as 1D tensor
    waves_tensor = torch.tensor(cfg.wavelengths_nm, dtype=torch.float32).to(device)
    
    # GSD (ground sampling distance) in meters
    gsd_tensor = torch.tensor(cfg.resolution_m, dtype=torch.float32).to(device)
    
    # Generate embeddings in batches
    print(f"\nGenerating embeddings (batch_size={cfg.batch_size})...")
    
    embeddings_list = []
    n_chips = len(chips_tensor)
    
    with torch.no_grad():
        for i in range(0, n_chips, cfg.batch_size):
            batch_end = min(i + cfg.batch_size, n_chips)
            
            # Get batch
            batch_chips = chips_tensor[i:batch_end].to(device)
            batch_time = time_tensor[i:batch_end].to(device)
            batch_latlon = latlon_tensor[i:batch_end].to(device)
            
            # Create datacube dict for Clay encoder
            datacube = {
                "pixels": batch_chips,      # [B C H W]
                "time": batch_time,         # [B 4] - week_sin, week_cos, hour_sin, hour_cos
                "latlon": batch_latlon,     # [B 4] - lat_sin, lat_cos, lon_sin, lon_cos
                "gsd": gsd_tensor,          # scalar
                "waves": waves_tensor,      # [N] - wavelengths in nm
            }
            
            # Get encoder output
            # Returns: (encoded_patches, unmasked_indices, masked_indices, masked_matrix)
            encoded_patches, _, _, _ = model.model.encoder(datacube)
            
            # Extract CLS token embedding (first token)
            cls_embedding = encoded_patches[:, 0, :]  # [B D]
            
            embeddings_list.append(cls_embedding.cpu().numpy())
            
            print(f"  Processed {batch_end}/{n_chips} chips", end="\r")
            
            # Clear GPU cache periodically
            if device.type == "cuda" and (i + cfg.batch_size) % (cfg.batch_size * 10) == 0:
                torch.cuda.empty_cache()
    
    print()
    embeddings = np.vstack(embeddings_list)
    
    print(f"\n=== Embeddings Generated ===")
    print(f"Shape: {embeddings.shape}")
    print(f"Embedding dimension: {embeddings.shape[1]}")
    print(f"Value range: [{embeddings.min():.4f}, {embeddings.max():.4f}]")
    
    EMBEDDINGS_OK = True
    
else:
    print("\n" + "="*60)
    print("CLAY NOT AVAILABLE - Creating placeholder embeddings")
    print("="*60)
    print("\nTo generate real embeddings:")
    print("1. Install Clay: pip install git+https://github.com/Clay-foundation/model.git")
    print("2. Download checkpoint: wget https://huggingface.co/made-with-clay/Clay/resolve/main/v1.5/clay-v1.5.ckpt")
    print("\nCreating PCA-based placeholder embeddings for demonstration...")
    
    # Create placeholder embeddings using PCA
    flat_chips = chips_norm.reshape(chips_norm.shape[0], -1)
    # PCA can return at most min(n_samples, n_features) components
    pca = PCA(n_components=min(128, *flat_chips.shape), random_state=cfg.random_seed)
    embeddings = pca.fit_transform(flat_chips).astype(np.float32)
    
    print(f"Placeholder embeddings shape: {embeddings.shape}")
    print("Note: These are NOT real Clay embeddings!")
    
    EMBEDDINGS_OK = False
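
The sin/cos encoding in `encode_time_latlon` keeps cyclic quantities continuous. Hour 23 and hour 0 are one hour apart on the clock but 23 apart as raw numbers. A minimal sketch with a hypothetical helper `cyclic_encode` (an illustration of the idea, not Clay code) shows that the end and start of a cycle land next to each other on the unit circle:

```python
import numpy as np

def cyclic_encode(values, period):
    """Map values on a cycle of length `period` to (sin, cos) pairs on the unit circle."""
    angle = 2 * np.pi * np.asarray(values, dtype=np.float64) / period
    return np.stack([np.sin(angle), np.cos(angle)], axis=-1)

# Hours 23 and 0 are neighbours on the clock; hours 0 and 12 are opposite
h = cyclic_encode([23, 0, 12], period=24)
dist_23_to_0 = np.linalg.norm(h[0] - h[1])   # short chord between adjacent hours
dist_0_to_12 = np.linalg.norm(h[1] - h[2])   # full diameter of the unit circle

print(f"23h -> 0h:  {dist_23_to_0:.3f}")
print(f"0h  -> 12h: {dist_0_to_12:.3f}")

# Every encoded point has unit length, so only the angle carries information
assert np.allclose(np.linalg.norm(h, axis=1), 1.0)
```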

In [None]:
# Visualize embedding distribution
fig, axes = plt.subplots(1, 3, figsize=(15, 4))

# Histogram of embedding values
axes[0].hist(embeddings.flatten(), bins=50, edgecolor='none', alpha=0.7)
axes[0].set_xlabel("Embedding value")
axes[0].set_ylabel("Frequency")
axes[0].set_title("Distribution of embedding values")

# Per-dimension mean
dim_means = embeddings.mean(axis=0)
axes[1].bar(range(len(dim_means)), dim_means, width=1.0, color='steelblue', edgecolor='none')
axes[1].set_xlabel("Dimension")
axes[1].set_ylabel("Mean value")
axes[1].set_title("Per-dimension mean")
axes[1].axhline(0, color='black', linewidth=0.5)

# PCA projection
pca_2d = PCA(n_components=2, random_state=cfg.random_seed)
emb_2d = pca_2d.fit_transform(embeddings)

axes[2].scatter(emb_2d[:, 0], emb_2d[:, 1], s=30, alpha=0.7, c=range(len(emb_2d)), cmap='viridis')
axes[2].set_xlabel(f"PC1 ({pca_2d.explained_variance_ratio_[0]*100:.1f}% var)")
axes[2].set_ylabel(f"PC2 ({pca_2d.explained_variance_ratio_[1]*100:.1f}% var)")
axes[2].set_title("Embeddings (PCA projection)")

plt.tight_layout()
plt.show()

print(f"Embedding dimension: {embeddings.shape[1]}")
if EMBEDDINGS_OK:
    print("(768 = Clay v1, 1024 = Clay v1.5)")

### Visualizing the Embedding Space

The plots above help us understand the structure of our embeddings:

1. **Distribution of values**: Shows the range and spread of all embedding values. A roughly symmetric distribution centered near zero is typical.

2. **Per-dimension mean**: Each dimension (1024 for Clay v1.5) may capture a different feature. Dimensions whose mean is far from zero take consistently positive or consistently negative values across chips.

3. **PCA projection**: Projects the high-dimensional embeddings to 2D, with points colored by chip index. Points close together have similar embeddings (and likely similar content or position).
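
To put a number on how spread out the embedding space is, you can count how many principal components are needed to reach a given fraction of the variance. A minimal NumPy-only sketch; the helper name `n_components_for` and the 90% threshold are illustrative choices, not part of the Clay tooling:

```python
import numpy as np

def n_components_for(X, target=0.90):
    """Return (k, ratios): k is the smallest number of principal components whose
    cumulative explained variance reaches `target`; ratios are per-component fractions."""
    Xc = X - X.mean(axis=0)                    # PCA operates on centered data
    s = np.linalg.svd(Xc, compute_uv=False)    # singular values, largest first
    ratios = s**2 / np.sum(s**2)               # explained variance ratio per component
    cumulative = np.cumsum(ratios)
    k = int(np.searchsorted(cumulative, target)) + 1   # first index reaching target, as a count
    return min(k, len(ratios)), ratios

# Synthetic data: 200 points living (almost) in a 3-D subspace of a 64-D space
rng = np.random.default_rng(0)
latent = rng.normal(size=(200, 3)) * np.array([10.0, 5.0, 2.0])
mixing = rng.normal(size=(3, 64))
X = latent @ mixing + 0.01 * rng.normal(size=(200, 64))

k, ratios = n_components_for(X, target=0.90)
print(f"{k} components reach 90% of the variance; PC1 + PC2 explain {ratios[:2].sum():.1%}")
```

On the notebook data, `n_components_for(embeddings)` gives the count for your chips; a large count is consistent with a low PC1 + PC2 percentage in the projection plot.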

## Part 6: Save Embeddings

### Why Save Embeddings?

Generating embeddings is computationally expensive (slow on CPU, much faster on a GPU). Once generated, embeddings can be:
- **Reused** for multiple downstream tasks (classification, retrieval, clustering)
- **Shared** with others who don't have GPU access
- **Combined** with embeddings from other regions or time periods
- **Indexed** for fast similarity search at scale

### The Parquet Format

We save embeddings as **Parquet** files, a columnar storage format that:
- Compresses numerical data efficiently
- Supports complex types (lists, nested structures)
- Is widely supported (pandas, PyArrow, Spark, DuckDB)
- Enables fast partial reads (only load columns you need)

### Our Schema

| Column | Type | Description |
|--------|------|-------------|
| `item_id` | string | Unique identifier for each chip |
| `embeddings` | list[float] | The embedding vector (1024 values for Clay v1.5) |
| `geometry` | bytes (WKB) | Chip footprint as a lon/lat polygon |
| `chip_row` | int | Row position in the original image grid |
| `chip_col` | int | Column position in the original image grid |
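
Because downstream Session 2 code expects these column names, a quick check before saving can catch mismatches early. A minimal sketch with a hypothetical `check_embedding_frame` helper (not part of any library); the required columns mirror the schema table:

```python
import numpy as np
import pandas as pd

# Column names mirror the schema table
REQUIRED_COLUMNS = ["item_id", "embeddings", "geometry", "chip_row", "chip_col"]

def check_embedding_frame(df: pd.DataFrame, expected_dim=None) -> int:
    """Raise ValueError if `df` does not match the schema; otherwise return the embedding dimension."""
    missing = [c for c in REQUIRED_COLUMNS if c not in df.columns]
    if missing:
        raise ValueError(f"missing columns: {missing}")
    if df.empty:
        raise ValueError("no rows to save")
    if not df["item_id"].is_unique:
        raise ValueError("item_id values must be unique")
    lengths = df["embeddings"].map(len)
    if lengths.nunique() != 1:
        raise ValueError(f"embedding lengths differ: {sorted(lengths.unique())}")
    dim = int(lengths.iloc[0])
    if expected_dim is not None and dim != expected_dim:
        raise ValueError(f"expected dimension {expected_dim}, got {dim}")
    return dim

# Tiny hand-made frame with the same columns as the notebook's `df`
demo = pd.DataFrame({
    "item_id": ["a", "b"],
    "embeddings": [np.zeros(4).tolist(), np.ones(4).tolist()],
    "geometry": [b"", b""],
    "chip_row": [0, 0],
    "chip_col": [0, 1],
})
print(check_embedding_frame(demo, expected_dim=4))
```

In the notebook, `check_embedding_frame(df, expected_dim=embeddings.shape[1])` could run just before `df.to_parquet(...)`.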

### Why WKB for Geometry?

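**WKB (Well-Known Binary)** is a compact binary format for geometries:
- More space-efficient than the equivalent text format (WKT)
- The geometry encoding used by the GeoParquet specification (a plain `DataFrame.to_parquet` call does not add GeoParquet metadata, so tools may need to be told which column holds WKB)
- Decodable by GIS tools and libraries (QGIS, PostGIS, GeoPandas)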

This allows you to later visualize or query embeddings spatially.
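
To see what "compact binary" means concretely, the sketch below hand-encodes a 2-D polygon in the standard WKB layout (byte-order flag, geometry type, ring count, point count, then raw float64 coordinates) using only `struct`. It is for illustration only: in the notebook, `shapely.wkb.dumps` does this for you, and the point order it uses for `box()` may differ from the ring here. The coordinates are made up.

```python
import struct

def polygon_to_wkb(ring):
    """Encode a single-ring 2-D polygon as little-endian WKB.

    `ring` is a list of (x, y) tuples; it is closed automatically if needed.
    """
    if ring[0] != ring[-1]:
        ring = ring + [ring[0]]                 # WKB rings must be closed
    out = struct.pack("<BI", 1, 3)              # byte order 1 = little-endian, type 3 = Polygon
    out += struct.pack("<I", 1)                 # number of rings
    out += struct.pack("<I", len(ring))         # number of points in the ring
    for x, y in ring:
        out += struct.pack("<dd", x, y)         # each coordinate is a float64
    return out

def wkb_to_polygon(data):
    """Decode the output of polygon_to_wkb back into a list of (x, y) tuples."""
    order, gtype, n_rings, n_points = struct.unpack_from("<BIII", data, 0)
    if order != 1 or gtype != 3 or n_rings != 1:
        raise ValueError("expected a little-endian single-ring polygon")
    # Header is 13 bytes; each point is two 8-byte doubles
    return [struct.unpack_from("<dd", data, 13 + 16 * i) for i in range(n_points)]

# An illustrative chip footprint in lon/lat
lon0, lat0, lon1, lat1 = 18.40, -33.95, 18.42, -33.93
corners = [(lon0, lat0), (lon1, lat0), (lon1, lat1), (lon0, lat1)]

blob = polygon_to_wkb(corners)
print(len(blob), "bytes")       # 1 + 4 + 4 + 4 header bytes + 5 points * 16 bytes
print(wkb_to_polygon(blob)[:2])
```

Bytes in this layout can be decoded with `shapely.wkb.loads` or, for a whole column, `geopandas.GeoSeries.from_wkb`.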

In [None]:
def create_chip_geometries(chip_info: List[dict], imagery_shape: Tuple, 
                           cfg: Config) -> List[bytes]:
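    """
    Create WKB polygon geometries (lon/lat) for each chip.
    
    Assumes the image exactly spans cfg.bbox and maps pixels to lon/lat
    linearly. If the imagery is in a projected CRS (e.g. UTM), the
    resulting footprints are approximate.
    
    Returns:
        List of WKB-encoded polygon geometries
    """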
    if not SHAPELY_OK:
        print("Shapely not available - storing placeholder geometries")
        return [b"" for _ in chip_info]
    
    _, height, width = imagery_shape
    min_lon, min_lat, max_lon, max_lat = cfg.bbox
    
    # Pixel to geo coordinate conversion
    def pixel_to_geo(px, py):
        lon = min_lon + (px / width) * (max_lon - min_lon)
        lat = max_lat - (py / height) * (max_lat - min_lat)
        return lon, lat
    
    geometries = []
    for info in chip_info:
        x0, y0 = info["x0"], info["y0"]
        x1, y1 = x0 + cfg.chip_size, y0 + cfg.chip_size
        
        # Get corners in lon/lat
        lon0, lat0 = pixel_to_geo(x0, y0)
        lon1, lat1 = pixel_to_geo(x1, y1)
        
        # Create polygon
        poly = box(lon0, lat1, lon1, lat0)  # (minx, miny, maxx, maxy)
        geometries.append(wkb.dumps(poly))
    
    return geometries


# Create geometries
geometries = create_chip_geometries(chip_info, imagery.shape, cfg)

# Create DataFrame
df = pd.DataFrame({
    "item_id": [f"{acq_metadata.get('item_id', 'generated')}_{i}" for i in range(len(embeddings))],
    "embeddings": [emb.tolist() for emb in embeddings],
    "geometry": geometries,
    "chip_row": [info["row"] for info in chip_info],
    "chip_col": [info["col"] for info in chip_info],
})

print("DataFrame created:")
print(df.head())
print(f"\nShape: {df.shape}")
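
`create_chip_geometries` interpolates linearly across `cfg.bbox`, which assumes the image covers exactly that box in lon/lat. If your raster carries its own affine transform (for example a scene stored in UTM), computing chip bounds from that transform is more precise. The bounds are then in the raster's CRS and would still need reprojecting to lon/lat (e.g. with pyproj). A minimal sketch using a GDAL-style six-coefficient geotransform; the helper names and transform values are made up for the example:

```python
from typing import Sequence, Tuple

def pixel_to_crs(gt: Sequence[float], col: float, row: float) -> Tuple[float, float]:
    """Apply a GDAL-style geotransform to a pixel corner.

    gt = (x_origin, pixel_width, row_rotation, y_origin, col_rotation, pixel_height)
    """
    x = gt[0] + col * gt[1] + row * gt[2]
    y = gt[3] + col * gt[4] + row * gt[5]
    return x, y

def chip_bounds(gt, x0: int, y0: int, size: int) -> Tuple[float, float, float, float]:
    """Return (minx, miny, maxx, maxy) of a size x size chip whose top-left pixel is (x0, y0)."""
    corners = [pixel_to_crs(gt, x0 + dx, y0 + dy) for dx in (0, size) for dy in (0, size)]
    xs, ys = zip(*corners)
    return min(xs), min(ys), max(xs), max(ys)

# North-up 10 m grid whose top-left corner is at (500000, 4000000) in a projected CRS
geotransform = (500000.0, 10.0, 0.0, 4000000.0, 0.0, -10.0)
print(chip_bounds(geotransform, x0=256, y0=0, size=256))
```

rasterio exposes the transform as an `affine.Affine` object with a different coefficient order; `transform.to_gdal()` converts it to the ordering used here.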

In [None]:
# Save to Parquet
output_path = Path(cfg.output_parquet)
output_path.parent.mkdir(parents=True, exist_ok=True)

df.to_parquet(output_path, index=False)

print(f"Embeddings saved to: {output_path}")
print(f"File size: {output_path.stat().st_size / 1e6:.2f} MB")

# Verify by reading back
df_verify = pd.read_parquet(output_path)
print(f"\nVerification - loaded {len(df_verify)} rows")
print(f"Embedding length: {len(df_verify['embeddings'].iloc[0])}")

## Part 7: Explore the Embedding Space

Now that we have embeddings, we can explore and use them with the workflows you learned in Session 2.

### What Can We Do With Embeddings?

| Task | How It Works |
|------|--------------|
| **Similarity Search** | Find chips with similar embeddings (nearest neighbors) |
| **Classification** | Train a classifier on labeled embeddings |
| **Clustering** | Group similar chips without labels (k-means, HDBSCAN) |
| **Anomaly Detection** | Find chips that are "far" from all others |
| **Visualization** | Project to 2D/3D with PCA, t-SNE, or UMAP |

### Similarity Search with Cosine Distance

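We use **cosine similarity** to find similar embeddings:
- Measures the angle between vectors (ignores magnitude)
- Value of 1.0 = identical direction, 0.0 = orthogonal, -1.0 = opposite direction
- Widely used to compare high-dimensional embeddings

```python
import numpy as np
A, B = np.array([1.0, 0.0]), np.array([1.0, 1.0])  # two example vectors
cosine_similarity = np.dot(A, B) / (np.linalg.norm(A) * np.linalg.norm(B))
cosine_distance = 1 - cosine_similarity  # what NearestNeighbors(metric="cosine") reports
```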
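
The similarity search below L2-normalizes the embeddings before fitting `NearestNeighbors`. For unit vectors, squared Euclidean distance and cosine distance give the same ranking, because ‖a − b‖² = 2(1 − cos θ). A quick NumPy check of that identity on random vectors:

```python
import numpy as np

rng = np.random.default_rng(42)
a, b = rng.normal(size=(2, 1024))             # two random 1024-dim vectors

cos_sim = a @ b / (np.linalg.norm(a) * np.linalg.norm(b))

# Normalize to unit length, then compute squared Euclidean distance
a_u = a / np.linalg.norm(a)
b_u = b / np.linalg.norm(b)
sq_euclid = np.sum((a_u - b_u) ** 2)

print(f"cosine distance (1 - cos):    {1 - cos_sim:.6f}")
print(f"squared Euclidean / 2 (unit): {sq_euclid / 2:.6f}")
```

Because the relationship is monotonic, Euclidean search on normalized vectors returns the same neighbors as cosine search, which helps when an index only supports Euclidean distance.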

In [None]:
# Load the embeddings we just saved
emb_df = pd.read_parquet(cfg.output_parquet)

# Extract embedding matrix
X_all = np.vstack(emb_df["embeddings"].values)

print(f"Loaded {len(emb_df)} embeddings")
print(f"Embedding matrix shape: {X_all.shape}")

In [None]:
# PCA visualization of embedding space
pca = PCA(n_components=2, random_state=cfg.random_seed)
X_pca = pca.fit_transform(X_all)

fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Colored by chip position (row)
sc1 = axes[0].scatter(X_pca[:, 0], X_pca[:, 1], 
                      c=emb_df["chip_row"], cmap='viridis', 
                      s=50, alpha=0.7)
plt.colorbar(sc1, ax=axes[0], label='Chip row')
axes[0].set_xlabel(f"PC1 ({pca.explained_variance_ratio_[0]*100:.1f}% var)")
axes[0].set_ylabel(f"PC2 ({pca.explained_variance_ratio_[1]*100:.1f}% var)")
axes[0].set_title("Embedding space (colored by row position)")

# Colored by chip position (column)
sc2 = axes[1].scatter(X_pca[:, 0], X_pca[:, 1], 
                      c=emb_df["chip_col"], cmap='plasma', 
                      s=50, alpha=0.7)
plt.colorbar(sc2, ax=axes[1], label='Chip column')
axes[1].set_xlabel(f"PC1 ({pca.explained_variance_ratio_[0]*100:.1f}% var)")
axes[1].set_ylabel(f"PC2 ({pca.explained_variance_ratio_[1]*100:.1f}% var)")
axes[1].set_title("Embedding space (colored by column position)")

plt.tight_layout()
plt.show()

print(f"Total variance explained: {pca.explained_variance_ratio_.sum()*100:.1f}%")

### Understanding the PCA Visualization

**PCA (Principal Component Analysis)** projects our 1024-dimensional embeddings down to 2D for visualization. Each point represents one chip's embedding.

### Why Do We See a Grid Pattern?

If you see a **regular grid pattern** in the PCA plot, this is expected behavior when processing chips from a single satellite image. Here's why:

**Clay embeddings encode BOTH content AND position.** The model receives:
- `time`: Same for all chips (one acquisition time)
- `latlon`: Varies systematically in a grid pattern across the image

Since all chips share the same timestamp, the differences between embeddings come from image content and from latitude/longitude. Latitude and longitude vary in a regular grid across the image (row by row, column by column), so when the position signal dominates, the PCA projection inherits that grid structure.

### What This Means

| Observation | Interpretation |
|-------------|----------------|
| Grid pattern | Position encoding is working correctly |
| Smooth color gradient | Embeddings vary continuously with position |
| Low variance explained | Most of the variance lies beyond the first two principal components |

### When Would We NOT See a Grid?

You would see **semantic clustering** (grouping by land cover type) when:
1. Chips come from **different locations** (diverse lat/lon, breaking the grid)
2. Chips come from **different times** (varied timestamps)
3. Position encoding is **zeroed out** (Clay's training sometimes does this)

### The Takeaway

A grid pattern is consistent with Clay encoding spatial information as intended. For downstream tasks like classification, a linear classifier can learn to use BOTH the position-dependent AND content-dependent parts of the embedding.
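
One way to test the grid explanation on your own data is to ask how much of the embedding variance a purely linear function of chip position can explain. The sketch below fits every embedding dimension on (row, col) by least squares and reports a pooled R². The helper `position_r2` is hypothetical, and a linear fit only captures smooth trends, so a low score does not rule out nonlinear position effects.

```python
import numpy as np

def position_r2(X, rows, cols):
    """Pooled R^2 of a least-squares fit X ~ 1 + row + col, over all embedding dims."""
    X = np.asarray(X, dtype=np.float64)
    A = np.column_stack([np.ones(len(X)), rows, cols])   # design matrix [1, row, col]
    coef, *_ = np.linalg.lstsq(A, X, rcond=None)          # one fit per column of X
    residual = X - A @ coef
    ss_res = np.sum(residual ** 2)
    ss_tot = np.sum((X - X.mean(axis=0)) ** 2)
    return 1.0 - ss_res / ss_tot

# Synthetic check on a 10 x 20 grid of chips with 16-dim "embeddings"
rng = np.random.default_rng(0)
rr, cc = np.meshgrid(np.arange(10), np.arange(20), indexing="ij")
rows, cols = rr.ravel(), cc.ravel()

position_driven = np.outer(rows, rng.normal(size=16)) + np.outer(cols, rng.normal(size=16))
content_driven = rng.normal(size=(len(rows), 16))

print(f"position-driven R^2: {position_r2(position_driven, rows, cols):.3f}")
print(f"content-driven  R^2: {position_r2(content_driven, rows, cols):.3f}")
```

On the notebook data, `position_r2(X_all, emb_df["chip_row"], emb_df["chip_col"])` gives the corresponding score for your chips.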

In [None]:
# Similarity search using nearest neighbors
print("Building similarity search index...")

# Normalize embeddings for cosine similarity
X_norm = X_all / (np.linalg.norm(X_all, axis=1, keepdims=True) + 1e-12)

# Build index
nn = NearestNeighbors(n_neighbors=min(10, len(X_norm)), metric="cosine")
nn.fit(X_norm)

# Query with the first chip
query_idx = 0
distances, indices = nn.kneighbors(X_norm[query_idx:query_idx+1])

print(f"\nQuery chip: {query_idx}")
print("Nearest neighbors by cosine distance (lower = more similar; rank 1 is usually the query itself):")
for rank, (dist, idx) in enumerate(zip(distances[0], indices[0])):
    print(f"  {rank+1}. idx={idx:4d}  distance={dist:.4f}  row={emb_df.iloc[idx]['chip_row']}  col={emb_df.iloc[idx]['chip_col']}")
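
`NearestNeighbors` fitted on the same data typically returns the query chip as its own first neighbor (distance 0). For a few thousand chips, a brute-force NumPy search is just as easy and makes it simple to exclude the query. A minimal sketch; `top_k_cosine` is a hypothetical helper name:

```python
import numpy as np

def top_k_cosine(X, query_idx, k=5):
    """Return (indices, cosine distances) of the k rows most similar to X[query_idx], excluding itself."""
    X_unit = X / (np.linalg.norm(X, axis=1, keepdims=True) + 1e-12)
    sims = X_unit @ X_unit[query_idx]          # cosine similarity to every row
    sims[query_idx] = -np.inf                  # exclude the query itself
    k = min(k, len(X) - 1)
    order = np.argsort(-sims)[:k]              # highest similarity first
    return order, 1.0 - sims[order]

rng = np.random.default_rng(1)
X = rng.normal(size=(100, 32))
X[7] = X[3] * 2.5                              # same direction as row 3, different length

idx, dist = top_k_cosine(X, query_idx=3, k=3)
print(idx, np.round(dist, 4))
```

With the notebook's data, `top_k_cosine(X_all, query_idx, k=9)` returns the nine nearest chips without the query.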

In [None]:
# Visualize query and retrieved chips
k_show = min(6, len(indices[0]))

# squeeze=False keeps `axes` 2-D even when k_show == 1
fig, axes = plt.subplots(2, k_show, figsize=(3 * k_show, 6), squeeze=False)

def chip_normalize(band):
    """Stretch a band to [0, 1] using its 2nd and 98th percentiles."""
    vmin, vmax = np.percentile(band, [2, 98])
    return np.clip((band - vmin) / (vmax - vmin + 1e-8), 0, 1)

# Shared y-limits so the embedding bar charts are comparable across panels
emb_ylim = [embeddings[:, :50].min(), embeddings[:, :50].max()]

for i, idx in enumerate(indices[0][:k_show]):
    chip = chips_norm[idx]
    
    # RGB composite; assumes bands 0, 1, 2 are blue, green, red
    if chip.shape[0] >= 3:
        rgb = np.stack([chip_normalize(chip[2]), chip_normalize(chip[1]), chip_normalize(chip[0])], axis=-1)
    else:
        rgb = chip_normalize(chip[0])
    
    # Top row: image
    axes[0, i].imshow(rgb, cmap="gray" if rgb.ndim == 2 else None)
    if i == 0:
        axes[0, i].set_title(f"Query (idx={idx})\ndist={distances[0][i]:.4f}", fontsize=10)
    else:
        axes[0, i].set_title(f"Neighbor {i}\nidx={idx}, dist={distances[0][i]:.4f}", fontsize=10)
    axes[0, i].axis("off")
    
    # Bottom row: first 50 embedding dimensions
    emb = embeddings[idx][:50]
    axes[1, i].bar(range(len(emb)), emb, width=1.0, color='steelblue', edgecolor='none')
    axes[1, i].set_xlabel("Dimension (first 50)")
    if i == 0:
        axes[1, i].set_ylabel("Value")
    axes[1, i].set_ylim(emb_ylim)

plt.suptitle("Similarity Search Results", fontsize=14)
plt.tight_layout()
plt.show()

In [None]:
# Visualize spatial distribution of nearest neighbors
fig, ax = plt.subplots(figsize=(10, 8))

# All chips (gray)
ax.scatter(emb_df["chip_col"], emb_df["chip_row"], 
           s=50, c='lightgray', alpha=0.5, label='All chips')

# Nearest neighbors (colored by distance)
neighbor_rows = [emb_df.iloc[idx]["chip_row"] for idx in indices[0]]
neighbor_cols = [emb_df.iloc[idx]["chip_col"] for idx in indices[0]]

sc = ax.scatter(neighbor_cols, neighbor_rows, 
                c=distances[0], cmap='viridis_r', 
                s=200, edgecolor='black', linewidth=1.5,
                label='Nearest neighbors')
plt.colorbar(sc, ax=ax, label='Cosine distance')

# Mark query with star
ax.scatter([emb_df.iloc[query_idx]["chip_col"]], 
           [emb_df.iloc[query_idx]["chip_row"]],
           marker='*', s=400, c='red', edgecolor='black',
           label='Query chip', zorder=10)

ax.set_xlabel("Chip column")
ax.set_ylabel("Chip row")
ax.set_title("Spatial distribution of nearest neighbors")
ax.legend(loc='upper right')
ax.invert_yaxis()  # Match image coordinates

plt.tight_layout()
plt.show()

print("\nNote: nearest neighbors in embedding space need not be neighbors on the ground.")
print("Chips with similar content can match across the image, although the position encoding also pulls nearby chips together.")
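
To measure how tightly the retrieved neighbors cluster around the query on the ground, you can compare their average grid distance to the query with that of all other chips. A minimal sketch; the helper name and the choice of Chebyshev (king-move) distance in chip units are illustrative assumptions:

```python
import numpy as np

def grid_distance_ratio(rows, cols, query_idx, neighbor_idx):
    """Mean Chebyshev grid distance of neighbors to the query, divided by that of all other chips.

    Values well below 1 mean the embedding neighbors are spatially clustered near the query.
    """
    rows, cols = np.asarray(rows), np.asarray(cols)
    d = np.maximum(np.abs(rows - rows[query_idx]), np.abs(cols - cols[query_idx]))
    others = np.delete(np.arange(len(rows)), query_idx)
    neighbors = [i for i in neighbor_idx if i != query_idx]
    return d[neighbors].mean() / d[others].mean()

# Toy 10 x 10 grid: query in the middle, "neighbors" that are its immediate surroundings
rr, cc = np.meshgrid(np.arange(10), np.arange(10), indexing="ij")
rows, cols = rr.ravel(), cc.ravel()
query = 5 * 10 + 5
adjacent = [4 * 10 + 5, 6 * 10 + 5, 5 * 10 + 4, 5 * 10 + 6]

print(f"ratio for adjacent chips: {grid_distance_ratio(rows, cols, query, adjacent):.2f}")
```

With the notebook's results, `grid_distance_ratio(emb_df["chip_row"], emb_df["chip_col"], query_idx, indices[0])` summarizes the scatter plot above in one number.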

## Part 8: Summary

Congratulations! You have learned to generate your own Clay embeddings from scratch.

In [None]:
# Summary statistics
print("=" * 60)
print("SESSION 4A SUMMARY: Generate Your Own Clay Embeddings")
print("=" * 60)

print("\n--- Data Acquisition ---")
print(f"Source:           {acq_metadata['source']}")
print(f"Image size:       {imagery.shape[1]} x {imagery.shape[2]} pixels")
print(f"Bands:            {len(cfg.bands)} ({', '.join(cfg.bands)})")

print("\n--- Chipping ---")
print(f"Chip size:        {cfg.chip_size} x {cfg.chip_size} pixels")
print(f"Valid chips:      {len(chips_norm)}")

print("\n--- Embeddings ---")
print(f"Embedding dim:    {embeddings.shape[1]}")
print(f"Total embeddings: {len(embeddings)}")
print(f"Generated with:   {'Clay encoder' if EMBEDDINGS_OK else 'PCA (placeholder)'}")

print("\n--- Output ---")
print(f"Saved to:         {cfg.output_parquet}")
print(f"File size:        {Path(cfg.output_parquet).stat().st_size / 1e6:.2f} MB")

print("\n" + "=" * 60)
print("WHAT YOU LEARNED:")
print("=" * 60)
print("1. Acquire satellite imagery via STAC or sample files")
print("2. Chip imagery into 256x256 tiles for Clay")
print("3. Prepare Clay inputs (pixels, timestamps, wavelengths)")
print("4. Generate embeddings using Clay's encoder")
print("5. Save embeddings in Session 2-compatible format")
print("6. Use embeddings for similarity search")

print("\n" + "=" * 60)
print("NEXT STEPS:")
print("=" * 60)
print("- Session 3: Fine-tune Clay with LoRA for your specific task")
print("- Try different AOIs and time ranges")
print("- Add labels and train a classifier (like Session 2)")
print("- Scale up: process larger areas with batching")

## Optional: Compare with Session 2 Precomputed Embeddings

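If you have the Session 2 precomputed embeddings, you can compare their dimension and value statistics with yours. Embeddings from different model versions live in different spaces, so compare summary statistics rather than individual vectors.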

In [None]:
# Try to load Session 2 precomputed embeddings for comparison
session2_dir = Path("./data/practical2_clay")
session2_files = list(session2_dir.glob("*.gpq")) if session2_dir.exists() else []

if len(session2_files) > 0:
    print(f"Found {len(session2_files)} Session 2 embedding files")
    
    # Load first file
    s2_df = pd.read_parquet(session2_files[0])
    s2_emb = np.vstack(s2_df["embeddings"].values)
    
    print(f"\nSession 2 embeddings: {s2_emb.shape}")
    print(f"Your embeddings:      {embeddings.shape}")
    
    print(f"\n--- Comparison ---")
    def clay_version(dim):
        return {768: "Clay v1", 1024: "Clay v1.5"}.get(dim, "Other")
    print(f"Session 2 dim: {s2_emb.shape[1]}  ({clay_version(s2_emb.shape[1])})")
    print(f"Your dim:      {embeddings.shape[1]}  ({clay_version(embeddings.shape[1])})")
    
    print(f"\n--- Value Statistics ---")
    print(f"Session 2 mean: {s2_emb.mean():.4f}, std: {s2_emb.std():.4f}")
    print(f"Your mean:      {embeddings.mean():.4f}, std: {embeddings.std():.4f}")
else:
    print("Session 2 precomputed embeddings not found.")
    print("Run Session 2 first to download the precomputed embeddings for comparison.")