# OlmoEarth Embeddings using data from Google Earth Engine (GEE)

**Author**: Ivan Zvonkov (ivan.zvonkov@gmail.com)

**Last modified**: Oct 27, 2025

**Description**: One-stop shop for generating OlmoEarth embeddings using data from Google Earth Engine. The notebook is intended to be run on Google Colab with a GPU.

1. **Setup**: Specifies the inference run configuration and shows the number of embeddings already generated.
2. **GEE data exports**: Exports Earth observation data to a cloud bucket in tiles.
3. **OlmoEarth Setup**: Loads the OlmoEarth model and creates a function for converting Google Earth Engine data into the OlmoEarth format.

4. **OlmoEarth Inference**: Runs the OlmoEarth model on the Earth observation data and uploads the embeddings to a bucket.

The final section **Debugging / Visualization** can be ignored but may be useful when modifying input data sources.

## 1. Setup



In [None]:
# Inference run configuration
#-------------------------------------------------------------------------------
# Run identity: a label plus the date range the input data is drawn from.
NAME = "Togo_v1"
START_DATE, END_DATE = "2019-03-01", "2020-03-01"
# RUN is also the object-name prefix used in both buckets.
RUN = "_".join([NAME, START_DATE, END_DATE])

# Side length, in pixels, of the square window passed to the model at once.
INFERENCE_WINDOW_SIZE = 100
# Band count of every embedding GeoTIFF that is written.
EMBEDDINGS_SIZE = 192

# Google Cloud project and storage buckets.
GCLOUD_PROJECT = "ai2-ivan"
IN_BUCKET = "ai2-ivan-helios-input-data"    # Bucket for GEE input data
OUT_BUCKET = "ai2-ivan-helios-output-data"  # Bucket for embedding outputs


# General setup
#-------------------------------------------------------------------------------
from google.colab import auth
from google.cloud import storage
from tqdm.notebook import tqdm
from pathlib import Path

# Authenticate the Colab user, then open handles to the input and output buckets.
auth.authenticate_user()
client = storage.Client(project=GCLOUD_PROJECT)
in_bucket = client.bucket(IN_BUCKET)
out_bucket = client.bucket(OUT_BUCKET)

# Band names per modality. List order matters: a band's position here is the
# index it is written to when the model inputs are assembled.
BANDS = dict(
    sentinel1="VV VH".split(),
    sentinel2="B1 B2 B3 B4 B5 B6 B7 B8 B8A B9 B11 B12".split(),
    landsat="B1 B2 B3 B4 B5 B6 B7 B8 B9 B10 B11".split(),
)

def remaining_tiles(mod=None, index=None):
  """Return the input tiles that have no embedding in the output bucket yet.

  Lists every object under the ``RUN`` prefix in both buckets, prints a
  progress summary and returns the names that exist in the input bucket but
  not in the output bucket.

  Args:
    mod: Optional number of shards. Falsy values (``None`` or ``0``) disable
      sharding.
    index: Optional shard to keep. A tile is kept when the integer file stem
      of its name, modulo ``mod``, equals ``index``. Ignored when ``None``.

  Returns:
    Unordered list of blob names still to be processed.
  """
  in_tifs = {b.name for b in in_bucket.list_blobs(prefix=RUN)}
  out_tifs = {b.name for b in out_bucket.list_blobs(prefix=RUN)}
  print(f"Embeddings generated:  {len(out_tifs)}/{len(in_tifs)}")
  remaining = list(in_tifs - out_tifs)
  # `index` is compared against None explicitly: shard 0 is a valid shard and
  # a plain truthiness test would silently skip the filtering for it.
  if mod and index is not None:
    remaining = [t for t in remaining if int(Path(t).stem) % mod == index]
  return remaining
# Print current progress; the trailing semicolon stops the notebook from echoing the returned list.
remaining_tiles();

## 2. GEE data exports (only run once)

In [None]:
import ee
import google

# OAuth scopes requested for the default credentials.
SCOPES = [
    "https://www.googleapis.com/auth/cloud-platform",
    "https://www.googleapis.com/auth/earthengine",
]
CREDENTIALS, _ = google.auth.default(default_scopes=SCOPES)
ee.Initialize(
    CREDENTIALS,
    project=GCLOUD_PROJECT,
    opt_url='https://earthengine-highvolume.googleapis.com',
)

# Region of interest: the merged geometry of all GAUL level-2 units whose country name is Togo.
roi = (
    ee.FeatureCollection("FAO/GAUL/2015/level2")
    .filter(ee.Filter.eq('ADM0_NAME', 'Togo'))
    .geometry()
)
# Edge length of one export tile, later passed to coveringGrid
# (10 km per side, assuming the CRS unit is metres).
GEE_TILE_SIZE = 10 * 1000

start, end = ee.Date(START_DATE), ee.Date(END_DATE)

# Sentinel-1 Data
#-------------------------------------------------------------------------------
# The date filter is padded by 31 days on both sides of the run period
# (presumably so composites near the period edges still find imagery).
S1_all = (
    ee.ImageCollection('COPERNICUS/S1_GRD')
    .filterBounds(roi)
    .filterDate(start.advance(-31, 'days'), end.advance(31, 'days'))
)
# Keep IW-mode scenes that share the orbit pass of the first image in S1_all.
S1 = (
    S1_all
    .filter(ee.Filter.eq("orbitProperties_pass", S1_all.first().get("orbitProperties_pass")))
    .filter(ee.Filter.eq("instrumentMode", "IW"))
)
# Sub-collections that contain each polarisation.
S1_VV = S1.filter(
    ee.Filter.listContains("transmitterReceiverPolarisation", "VV")
)
S1_VH = S1.filter(
    ee.Filter.listContains("transmitterReceiverPolarisation", "VH")
)

def getCloseImages(middleDate, imageCollection):
  """Return the images of `imageCollection` acquired closest to `middleDate`.

  Each image is tagged with `dateDist`, the absolute difference in
  milliseconds between its `system:time_start` and `middleDate`. An image is
  kept when its `dateDist` does not exceed the larger of 15 days and the
  `dateDist` of the nearest image, so the nearest image is always retained.
  """
  def tag_with_distance(image):
    acquired = ee.Number(image.get("system:time_start"))
    return image.set("dateDist", acquired.subtract(middleDate.millis()).abs())

  nearest_first = imageCollection.map(tag_with_distance).sort("dateDist", True)
  window_ms = ee.Number(1296000000)  # 15 days in milliseconds
  smallest_dist = ee.Number(nearest_first.first().get("dateDist"))
  cutoff = smallest_dist.max(window_ms)
  return nearest_first.filterMetadata("dateDist", "not_greater_than", cutoff)

def get_S1_img(date1, date2):
  """Build a two-band (VV, VH) Sentinel-1 median composite for one period.

  Images are picked around the midpoint between `date1` and `date2` with
  getCloseImages, separately for each polarisation, and reduced with a
  median. The result is clipped to `roi` and cast to float.
  """
  half_span_days = date2.difference(date1, 'days').divide(2)
  midpoint = date1.advance(half_span_days, 'days')
  vv_median = getCloseImages(midpoint, S1_VV).select("VV").median()
  vh_median = getCloseImages(midpoint, S1_VH).select("VH").median()
  composite = ee.Image.cat([vv_median, vh_median]).select(BANDS["sentinel1"])
  return composite.clip(roi).float() # S1 ranges from -50 to 1


# Sentinel-2 data
#-------------------------------------------------------------------------------
# In Togo, Cloud Score+ gives better mosaics than sorting by CLOUD_COVERAGE_ASSESSMENT
S2 = (
    ee.ImageCollection("COPERNICUS/S2_SR_HARMONIZED")
    .filterBounds(roi)
    .filterDate(start, end)
)
csPlus = (
    ee.ImageCollection('GOOGLE/CLOUD_SCORE_PLUS/V1/S2_HARMONIZED')
    .filterBounds(roi)
    .filterDate(start, end)
)
QA_BAND = 'cs_cdf'  # Better than cs here
# Attach the Cloud Score+ quality band to the Sentinel-2 images.
S2_cf = S2.linkCollection(csPlus, [QA_BAND])

def get_S2_img(date1, date2):
  """Build a Sentinel-2 mosaic for the images dated between `date1` and `date2`.

  Pixels are chosen with qualityMosaic on QA_BAND; the result keeps only the
  bands in BANDS["sentinel2"], is clipped to `roi` and cast to float.
  """
  in_period = S2_cf.filterDate(date1, date2)
  best_pixels = in_period.qualityMosaic(QA_BAND)
  return best_pixels.select(BANDS["sentinel2"]).clip(roi).float()


# Landsat 8 data
#-------------------------------------------------------------------------------
# Level-2 and top-of-atmosphere collections, each with tiers 1 and 2 merged.
LANDSAT_SR = (
    ee.ImageCollection('LANDSAT/LC08/C02/T1_L2')
    .merge(ee.ImageCollection('LANDSAT/LC08/C02/T2_L2'))
    .filterBounds(roi)
    .filterDate(start, end)
)
LANDSAT_TOA = (
    ee.ImageCollection('LANDSAT/LC08/C02/T1_TOA')
    .merge(ee.ImageCollection('LANDSAT/LC08/C02/T2_TOA'))
    .filterBounds(roi)
    .filterDate(start, end)
)
# Bring the TOA bands B8, B9 and B11 onto the matching Level-2 images.
landsat = LANDSAT_SR.linkCollection(LANDSAT_TOA, ["B8", "B9", "B11"])

def get_landsat_img(date1, date2):
  """Build an 11-band Landsat 8 mosaic from the images dated between `date1` and `date2`.

  The mosaic is clipped to `roi`. Level-2 bands SR_B1..SR_B7 and ST_B10 are
  renamed to B1..B7 and B10; the linked TOA bands B8, B9 and B11 are rescaled
  and appended. The returned image is cast to float and its band order is
  B1..B7, B10, B8, B9, B11.
  """
  # mosaic() draws later images over earlier ones, so the collection is sorted
  # by *descending* CLOUD_COVER to leave the least cloudy scene on top. An
  # ascending sort would put the cloudiest scene on top.
  landsat_img = landsat.filterDate(date1, date2).sort("CLOUD_COVER", False).mosaic().clip(roi).set("system:index", "")
  SR_BANDS = ["SR_B1", "SR_B2", "SR_B3", "SR_B4", "SR_B5", "SR_B6", "SR_B7", "ST_B10"]
  landsat_SR = landsat_img.select(SR_BANDS).rename(["B1", "B2", "B3", "B4", "B5", "B6", "B7", "B10"])
  # (value + 0.2) / 0.0000275 presumably inverts the Level-2 reflectance
  # scaling so the TOA bands land on the same integer scale as the SR bands.
  # NOTE(review): B11 in the TOA product is a thermal band, and bright values
  # can exceed the int16 range after this rescale; confirm against the
  # pipeline that produced the model's training data.
  landsat_TOA = landsat_img.select(["B8", "B9", "B11"]).add(0.2).divide(0.0000275).toInt16()
  return landsat_SR.addBands(landsat_TOA).float()

# Use the projection of the first Landsat image as the CRS for the tile grid and the exports.
proj = landsat.first().select("SR_B1").projection()
crs = proj.crs().getInfo()
print(crs)

In [None]:
# Create GEE tasks
#-------------------------------------------------------------------------------
# One composite per modality for every whole month between `start` and `end`.
numMonths = end.difference(start, 'month').toInt().getInfo()
S1_img_list, S2_img_list, landsat_img_list = [], [], []
for i in range(numMonths):
  # [d1, d2) is the i-th month of the run period.
  d1 = start.advance(i, 'month')
  d2 = d1.advance(1, 'month')
  S1_img_list += [get_S1_img(d1, d2)]
  S2_img_list += [get_S2_img(d1, d2)]
  landsat_img_list += [get_landsat_img(d1, d2)]

def imageFromList(pre, imgList):
  """Stack `imgList` into one multi-band image with `pre`-prefixed band names.

  The images are turned into a collection and flattened with toBands();
  every resulting band name then gets "<pre>_" prepended.
  """
  def add_prefix(band_name):
    return ee.String(pre+"_").cat(band_name)

  stacked = ee.ImageCollection.fromImages(imgList).toBands()
  return stacked.rename(stacked.bandNames().map(add_prefix))

# Stack everything that gets exported into one image: per-pixel latitude/longitude
# followed by the monthly composites of each modality.
theBigInputImage = ee.Image.cat([
  ee.Image.pixelLonLat().clip(roi).select("latitude", "longitude").float(),
  imageFromList("sentinel2",  S2_img_list),
  imageFromList("sentinel1",  S1_img_list),
  imageFromList("landsat", landsat_img_list)
])

# Grid of GEE_TILE_SIZE cells covering the ROI in the export CRS; each cell becomes one export.
grid = roi.buffer(-100).coveringGrid(crs, GEE_TILE_SIZE) # Buffered to avoid almost empty tiles
grid_list = grid.toList(grid.size())

# Prepare (but do not start) one export task per grid tile that is not yet
# present in the input bucket as "<RUN>/<tile index>.tif".
tasks = []
already_exist = 0
print("Preparing EarthEngine tasks...")
# List the bucket once instead of issuing one exists() request per tile.
existing_tifs = {b.name for b in in_bucket.list_blobs(prefix=f"{RUN}/")}
for i in tqdm(range(grid.size().getInfo())):
  if f"{RUN}/{i}.tif" in existing_tifs:
    already_exist += 1
    continue

  tile = ee.Feature(grid_list.get(i)).geometry()
  task = ee.batch.Export.image.toCloudStorage(
    image=theBigInputImage.clip(tile),
    description=f"{RUN}_{i}", bucket=IN_BUCKET, fileFormat='GeoTIFF',
    fileNamePrefix=f"{RUN}/{i}", scale=10, crs=crs, region=tile,
  )
  tasks.append(task)

if already_exist > 0:
  print(f"{already_exist} tiles already exist in Cloud Storage.")
if len(tasks) > 0:
  print("Run next cell to start exports.")

In [None]:
# Start gee tasks for tiles not yet exported
# `tasks` is the list prepared by the previous cell; each entry is submitted to Earth Engine here.
for task in tqdm(tasks):
  task.start()
print(f"Started {len(tasks)} export tasks, see: https://code.earthengine.google.com/tasks")

## 3. OlmoEarth Setup

In [None]:
# 3 minutes to setup
# Installs rasterio, clones the model repository and installs its requirements,
# imports the model/data helpers and downloads the checkpoint.
!pip install rasterio -q

from getpass import getpass
from datetime import datetime as dt

import numpy as np
import pandas as pd
import time
import json
import torch

# NOTE(review): the token typed here is interpolated into the clone URL, so it
# becomes part of the shell command and of the cloned repository's remote URL.
!git clone https://{getpass("Github Token: ")}@github.com/allenai/helios.git
!pip install -q -r helios/requirements.txt

# The project imports below are resolved from inside the cloned repository.
%cd /content/helios

from torch.utils.data import default_collate
from olmo_core.config import Config
from olmo_core.distributed.checkpoint import load_model_and_optim_state
from olmoearth_pretrain.data.dataset import OlmoEarthSample
from olmoearth_pretrain.train.masking import MaskedOlmoEarthSample, Modality
from olmoearth_pretrain.data.normalize import Normalizer, Strategy

# NOTE(review): rasterio was already installed at the top of this cell; this
# second install looks redundant.
!pip install rasterio -q
import rasterio as rio
from rasterio.windows import Window

# Download model weights
!gcloud storage cp --no-user-output-enabled -r gs://helios-embeddings-bucket/latent_mim_tiny_shallow_decoder_lr2e-4_255000 .

In [None]:
# Load model
# Build the model from the config stored with the checkpoint, then restore its weights.
with open("latent_mim_tiny_shallow_decoder_lr2e-4_255000/config.json", "r") as f:
  config_dict = json.load(f)
model = Config.from_dict(config_dict["model"]).build()
load_model_and_optim_state("latent_mim_tiny_shallow_decoder_lr2e-4_255000/model_and_optim", model)
model.eval()

# Only the encoder is kept for inference; it runs on the GPU when one is available.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.encoder.to(device)

# Derive timestamps
# One entry per monthly composite: every month start from START_DATE up to
# END_DATE, with the last generated month start dropped.
def to_date_obj(d):
  """Parse a "YYYY-MM-DD" string into a datetime.date."""
  return dt.strptime(d, "%Y-%m-%d").date()

timestamps_pd = pd.date_range(to_date_obj(START_DATE), to_date_obj(END_DATE), freq="MS")[:-1]
# Each entry is ordered [day, month, year] with a zero-based month, matching
# the layout documented for the `timestamps` field in the commented field
# reference in the last cell of this notebook. The previous
# [year, month, day] order had day and year swapped.
timestamps = [[t.day, t.month - 1, t.year] for t in timestamps_pd]

# Helper function for data prep
def prepare_masked_olmo_earth_sample(tile, bands, device=None):
  """Convert one raster window into a batched MaskedOlmoEarthSample.

  Every pixel of the window becomes one sample with a 1x1 spatial extent.

  Args:
    tile: Array indexed as (band, row, column); it is reshaped to
      (len(bands), rows * columns).
    bands: Sequence of band names, one per band of `tile`. Apart from
      "latitude" and "longitude", each name must have the form
      "<modality>_<timestep>_<band>" with <modality> a key of BANDS.
    device: Device the collated tensors are moved to.

  Returns:
    A MaskedOlmoEarthSample whose leading dimension indexes the pixels.

  Relies on the notebook-level `timestamps` list and `BANDS` mapping.
  """
  n_steps = len(timestamps)
  n_pixels = tile.shape[1] * tile.shape[2]
  flat = tile.reshape(len(bands), n_pixels)

  # Fill input dict using geotiff data
  lat_row, lon_row = bands.index("latitude"), bands.index("longitude")
  input_dict_raw = {
    "timestamps": np.array([timestamps] * n_pixels),
    "latlon": flat[[lat_row, lon_row]].transpose(1, 0),
  }
  for modality_name in ("landsat", "sentinel1", "sentinel2"):
    input_dict_raw[modality_name] = np.zeros(
      (n_pixels, 1, 1, n_steps, len(BANDS[modality_name]))
    )
  for row, band_key in enumerate(bands):
    if band_key in ("latitude", "longitude"):
      continue
    modality_name, step, band_name = band_key.split("_")
    slot = BANDS[modality_name].index(band_name)
    input_dict_raw[modality_name][:, 0, 0, int(step), slot] = flat[row]

  # Normalize input dict
  computed = Normalizer(Strategy.COMPUTED)
  predefined = Normalizer(Strategy.PREDEFINED)
  input_dict_normed = {
    "timestamps": input_dict_raw["timestamps"],
    "latlon": predefined.normalize(Modality.LATLON, input_dict_raw["latlon"]).astype(np.float32),
  }
  computed_modalities = (
    ("landsat", Modality.LANDSAT),
    ("sentinel1", Modality.SENTINEL1),
    ("sentinel2", Modality.SENTINEL2_L2A),
  )
  for modality_name, modality in computed_modalities:
    normalized = computed.normalize(modality, input_dict_raw[modality_name])
    input_dict_normed[modality_name] = normalized.astype(np.float32)

  # Build one masked sample per pixel, then collate them into a single batch
  per_pixel_dicts = [
    MaskedOlmoEarthSample.from_olmoearthsample(
      OlmoEarthSample(
        sentinel2_l2a=input_dict_normed["sentinel2"][px],
        sentinel1=input_dict_normed["sentinel1"][px],
        landsat=input_dict_normed["landsat"][px],
        timestamps=input_dict_normed["timestamps"][px],
        latlon=input_dict_normed["latlon"][px],
      )
    ).as_dict(return_none=False)
    for px in range(n_pixels)
  ]
  batch = default_collate(per_pixel_dicts)
  return MaskedOlmoEarthSample(**{name: tensor.to(device) for name, tensor in batch.items()})

## 4. OlmoEarth Inference

In [None]:
# For every tile without embeddings: download the input GeoTIFF, run the
# encoder window by window, write the embeddings to a GeoTIFF with the same
# georeferencing and upload it under the same object name.
INFERENCE_WINDOW_SIZE = 100
remaining = remaining_tiles()
while len(remaining) > 0:
  for tile in tqdm(remaining):
    print(tile)
    print("\n\tDownloading input data ...\t", end="")
    # Timers use `t0` so the `start` date defined in the GEE export section is not overwritten.
    t0 = time.perf_counter()
    in_bucket.blob(tile).download_to_filename("in.tif")
    duration = time.perf_counter() - t0
    print(f"{duration:.2f}s\t ✓")

    with rio.open("in.tif") as src:
      profile = src.profile
      bands = src.descriptions
      height, width = src.height, src.width

      profile.update(count=EMBEDDINGS_SIZE, dtype="float32", compress="deflate", bigtiff="YES")

      # Ceil-divide so that the partial windows along the right and bottom
      # edges, which the loops below also process, are included in the total.
      windows_down = -(-height // INFERENCE_WINDOW_SIZE)
      windows_across = -(-width // INFERENCE_WINDOW_SIZE)
      total_windows = windows_down * windows_across
      with tqdm(total=total_windows, desc="\tRunning inference") as pbar:
        with rio.open("out.tif", "w", **profile) as dst:
          for y in range(0, height, INFERENCE_WINDOW_SIZE):
            for x in range(0, width, INFERENCE_WINDOW_SIZE):
              win = Window(x, y, min(INFERENCE_WINDOW_SIZE, width - x), min(INFERENCE_WINDOW_SIZE, height - y))
              data = src.read(window=win)
              masked_sample = prepare_masked_olmo_earth_sample(data, bands, device)

              with torch.no_grad():
                preds = model(masked_sample, patch_size=1, fast_pass=True)

              # (pixels, features) -> (features, rows, cols) so each feature becomes one output band
              embeddings = preds["project_aggregated"].cpu().numpy().transpose(1, 0)
              embeddings_reshaped = embeddings.reshape(embeddings.shape[0], data.shape[1], data.shape[2])

              dst.write(embeddings_reshaped.astype("float32"), window=win)
              pbar.update(1)

    print("\tUploading embeddings ...\t", end="")
    t0 = time.perf_counter()
    out_bucket.blob(tile).upload_from_filename("out.tif")
    duration = time.perf_counter() - t0
    print(f"{duration:.2f}s\t ✓")

    # Remove the local scratch files before the next tile.
    for scratch in ("in.tif", "out.tif"):
      Path(scratch).unlink(missing_ok=True)

    # Re-list the buckets to print progress after each tile. The for-loop keeps
    # iterating over the list it started with; the while-loop re-checks this
    # refreshed list once the for-loop is done.
    remaining = remaining_tiles()

## Debugging / Visualization

In [None]:
import matplotlib.pyplot as plt

In [None]:
def visualize(imgs):
  """Show up to 12 images on a 2x6 grid of axes with the axis decorations hidden.

  Axes and images are paired with zip, so extra images beyond 12 are not drawn.
  """
  _, grid_axes = plt.subplots(2, 6, figsize=(12, 4))
  for panel, picture in zip(grid_axes.flat, imgs):
    panel.imshow(picture)
    panel.set_axis_off()
  plt.tight_layout()

def get_imgs(modality_data, band_indices):
  """Turn per-pixel modality data into one image per timestep.

  Args:
    modality_data: Array indexed as (pixel, 0, 0, timestep, band).
    band_indices: Indices of the bands to use as image channels, in order.

  Returns:
    Array indexed as (timestep, row, column, channel) with NaNs replaced by 0.

  Relies on the notebook-level `timestamps` list and on a `tile` array whose
  shape[1] and shape[2] give the image height and width.
  """
  n_channels = len(band_indices)
  # The integer and list indices are separated by a slice, so NumPy places the
  # selected-band axis first: the result is indexed as (band, pixel, timestep).
  band_first = modality_data[:, 0, 0, :, band_indices]
  time_first = band_first.transpose(2, 1, 0)
  frames = time_first.reshape(len(timestamps), tile.shape[1], tile.shape[2], n_channels)
  return np.nan_to_num(frames, nan=0)

In [None]:
# S2 raw with manual normalization
# NOTE: `input_dict_raw` (and the `tile` array used by get_imgs) are not defined at
# notebook level; they mirror values inside prepare_masked_olmo_earth_sample and
# have to be populated manually before running the debugging cells.
s2_imgs_raw = get_imgs(input_dict_raw["sentinel2"], [BANDS["sentinel2"].index(b) for b in ["B4", "B3", "B2"]])
s2_imgs_rough_norm = np.clip((s2_imgs_raw / 10000) / 0.15, 0, 1) # Rough norm
visualize(s2_imgs_rough_norm)

In [None]:
# S2 helios normalization
# Same B4/B3/B2 composite, built from the already-normalized Sentinel-2 values.
s2_imgs_helios_normed = get_imgs(input_dict_normed["sentinel2"], [BANDS["sentinel2"].index(b) for b in ["B4", "B3", "B2"]])
visualize(s2_imgs_helios_normed)

In [None]:
# S1 raw with manual normalization
# Three-channel composite using VV, VH, VV as the image channels.
s1_imgs_raw = get_imgs(input_dict_raw["sentinel1"], [BANDS["sentinel1"].index(b) for b in ["VV", "VH", "VV"]])
# Shift and scale the raw values by 25, then clip and offset each channel separately.
s1_imgs_rough_norm = (s1_imgs_raw + 25) / 25
s1_imgs_rough_norm[:, :, :, 0] = np.clip(s1_imgs_rough_norm[:, :, :, 0], 0, 1)
s1_imgs_rough_norm[:, :, :, 1] = np.clip(s1_imgs_rough_norm[:, :, :, 1], -0.2, 0.8) + 0.2
s1_imgs_rough_norm[:, :, :, 2] = np.clip(s1_imgs_rough_norm[:, :, :, 2], 0.4, 1.2) + -0.2
visualize(s1_imgs_rough_norm)

In [None]:
# S1 with Helios normalization (pretty good)
# Same VV/VH/VV composite, built from the already-normalized Sentinel-1 values.
s1_imgs_normed = get_imgs(input_dict_normed["sentinel1"], [BANDS["sentinel1"].index(b) for b in ["VV", "VH", "VV"]])
visualize(s1_imgs_normed)

In [None]:
# landsat raw with manual normalization
# B4/B3/B2 composite from the raw Landsat values, divided by 25000 and clipped to [0, 1].
landsat_imgs_raw = get_imgs(input_dict_raw["landsat"], [BANDS["landsat"].index(b) for b in ["B4", "B3", "B2"]])
landsat_imgs_rough_norm = np.clip((landsat_imgs_raw / 25000), 0, 1) # Rough norm
visualize(landsat_imgs_rough_norm)

In [None]:
# Landsat with Helios normalization (pretty good)
# Same B4/B3/B2 composite, built from the already-normalized Landsat values.
landsat_imgs_normed = get_imgs(input_dict_normed["landsat"], [BANDS["landsat"].index(b) for b in ["B4", "B3", "B2"]])
visualize(landsat_imgs_normed)

In [None]:
# Reference only: fields accepted by OlmoEarthSample and their expected shapes.
# sentinel2_l2a: ArrayTensor | None = None  # [B, H, W, T, len(S2_bands)]
# latlon: ArrayTensor | None = None  # [B, 2]
# timestamps: ArrayTensor | None = None  # [B, T, D=3], where D=[day, month, year]
# sentinel1: ArrayTensor | None = None  # [B, H, W, T, len(S1_bands)]
# worldcover: ArrayTensor | None = None  # [B, H, W, 1, len(WC_bands)]
# openstreetmap_raster: ArrayTensor | None = None  # [B, H, W, 1, len(OSM_bands)]
# srtm: ArrayTensor | None = None  # [B, H, W, 1, len(SRTM_bands)]
# landsat: ArrayTensor | None = None  # [B, H, W, T, len(LANDSAT_bands)]
# # naip with different tile resolution is currently not used in favor of naip_10.
# naip: ArrayTensor | None = None  # [B, H, W, T, len(NAIP_bands)]
# # naip_10 is currently 4x the height/width of sentinel2_l2a.
# naip_10: ArrayTensor | None = None  # [B, H, W, T, len(NAIP_bands)]
# gse: ArrayTensor | None = None  # [B, H, W, 1, len(GSE_bands)]
# cdl: ArrayTensor | None = None  # [B, H, W, 1, len(CDL_bands)]
# worldpop: ArrayTensor | None = None  # [B, H, W, 1, len(WORLDPOP_bands)]
# worldcereal: ArrayTensor | None = None  # [B, H, W, 1, len(CDL_bands)]
# wri_canopy_height_map: ArrayTensor | None = None  # [B, H, W, 1, 1]
# # era5_10 is not spatially varying, so it has no height/width dimensions.
# era5_10: ArrayTensor | None = None  # [B, T, len(ERA5_bands)]