In [1]:
from pathlib import Path
import torch
from gelos.gelosdatamodule import GELOSDataModule
import yaml
from gelos import config
from lightning.pytorch import Trainer
from pathlib import Path
from tqdm import tqdm
from gelos.config import PROJ_ROOT, PROCESSED_DATA_DIR, DATA_VERSION, RAW_DATA_DIR
from terratorch.tasks import EmbeddingGenerationTask
from gelos.features import LenientEmbeddingGenerationTask

[32m2025-12-06 00:16:39.861[0m | [1mINFO    [0m | [36mgelos.config[0m:[36m<module>[0m:[36m16[0m - [1mPROJ_ROOT path is: /app[0m
  from .autonotebook import tqdm as notebook_tqdm
INFO:albumentations.check_version:A new version of Albumentations is available: 2.0.8 (you have 1.4.10). Upgrade using: pip install --upgrade albumentations


In [2]:
from lightning.pytorch.cli import instantiate_class

In [3]:
# Embedding-generation configs live under <PROJ_ROOT>/gelos/configs.
yaml_config_directory = PROJ_ROOT.joinpath('gelos', 'configs')

# Collect every YAML config in that directory (non-recursive) and show what was found.
yaml_paths = [config_path for config_path in yaml_config_directory.glob('*.yaml')]
print(yaml_paths)

[PosixPath('/app/gelos/configs/terramind_embedding_generation.yaml'), PosixPath('/app/gelos/configs/prithvi_eo_600m_embedding_generation.yaml'), PosixPath('/app/gelos/configs/prithvi_eo_300m_embedding_generation.yaml')]


In [4]:
# Pick the config by filename instead of by position: Path.glob() yields entries in
# filesystem-dependent order, so `yaml_paths[0]` is not guaranteed to be the same
# config on every machine or after a fresh checkout.
# Change CONFIG_NAME to run the walkthrough below with a different model config.
CONFIG_NAME = 'terramind_embedding_generation.yaml'
matching_paths = [config_path for config_path in yaml_paths if config_path.name == CONFIG_NAME]
if not matching_paths:
    # Fail here with a clear message rather than later with an opaque IndexError.
    raise FileNotFoundError(f"{CONFIG_NAME} not found in {yaml_config_directory}")
yaml_path = matching_paths[0]

## Run Embedding Generation step by step

In [5]:
# Parse the selected YAML config into a plain Python dict.
with open(yaml_path, "r") as config_file:
    yaml_config = yaml.safe_load(config_file)

# Echo the config exactly as loaded, before any runtime values are injected below.
print(yaml.dump(yaml_config))

# Aliases for the two nested sections; these are the same dict objects held by
# yaml_config, so the edits made through them are visible in yaml_config too.
data_init_args = yaml_config['data']['init_args']
model_init_args = yaml_config['model']['init_args']

# Inputs are read from the versioned raw-data folder; outputs go to a per-model
# folder under the versioned processed-data directory (created if missing).
model_name = model_init_args['model']
data_root = RAW_DATA_DIR / DATA_VERSION
output_dir = PROCESSED_DATA_DIR / DATA_VERSION / model_name
output_dir.mkdir(parents=True, exist_ok=True)

# add variables to yaml config so it can be passed to classes
data_init_args['data_root'] = data_root
model_init_args['output_dir'] = output_dir

# instantiate transform classes if they exist: each entry in the config is a
# class_path (+ optional init_args) spec that instantiate_class turns into an object
if "transform" in data_init_args:
    transforms = []
    for transform_spec in data_init_args["transform"]:
        transforms.append(instantiate_class(args=(), init=transform_spec))
    data_init_args["transform"] = transforms

gelos_datamodule = GELOSDataModule(**data_init_args)
task = LenientEmbeddingGenerationTask(**model_init_args)

# Use a single GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'gpu'
else:
    device = 'cpu'
trainer = Trainer(accelerator=device, devices=1)

data:
  class_path: gelos.gelosdatamodule.GELOSDataModule
  init_args:
    bands:
      DEM:
      - DEM
      S1RTC:
      - VV
      - VH
      S2L2A:
      - COASTAL_AEROSOL
      - BLUE
      - GREEN
      - RED
      - RED_EDGE_1
      - RED_EDGE_2
      - RED_EDGE_3
      - NIR_BROAD
      - NIR_NARROW
      - WATER_VAPOR
      - SWIR_1
      - SWIR_2
    batch_size: 1
    num_workers: 0
    repeat_bands:
      DEM: 4
    transform:
    - class_path: terratorch.datasets.transforms.FlattenTemporalIntoChannels
    - class_path: albumentations.augmentations.geometric.resize.Resize
      init_args:
        height: 96
        interpolation: 0
        width: 96
    - class_path: albumentations.pytorch.transforms.ToTensorV2
    - class_path: terratorch.datasets.transforms.UnflattenTemporalFromChannels
      init_args:
        n_timesteps: 4
embedding_extraction_strategies:
  All Patches from April to June:
  - start: 1
    step: 1
    stop: 2
  All Steps of Middle Patch:
  - start: 0
  

INFO:httpx:HTTP Request: HEAD https://huggingface.co/ibm-esa-geospatial/TerraMind-1.0-base/resolve/main/TerraMind_v1_base.pt "HTTP/1.1 302 Found"
INFO: Using default `ModelCheckpoint`. Consider installing `litmodels` package to enable `LitModelCheckpoint` for automatic upload to the Lightning model registry.
INFO:lightning.pytorch.utilities.rank_zero:Using default `ModelCheckpoint`. Consider installing `litmodels` package to enable `LitModelCheckpoint` for automatic upload to the Lightning model registry.
INFO: GPU available: True (cuda), used: True
INFO:lightning.pytorch.utilities.rank_zero:GPU available: True (cuda), used: True
INFO: TPU available: False, using: 0 TPU cores
INFO:lightning.pytorch.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO: HPU available: False, using: 0 HPUs
INFO:lightning.pytorch.utilities.rank_zero:HPU available: False, using: 0 HPUs


### Inspect outputs of dataset

In [6]:
# Prepare the datamodule for the "predict" stage. The next cell indexes
# `gelos_datamodule.dataset`, which is presumably created by this call —
# TODO confirm against GELOSDataModule.setup.
gelos_datamodule.setup(stage="predict")

In [13]:
# Look at the first sample: print the shape of each per-sensor array under the
# "image" key, and print every other field (e.g. identifiers) as-is.
first_sample = gelos_datamodule.dataset[0]
for field_name, field_value in first_sample.items():
    if field_name != "image":
        print(field_name, field_value)
        continue
    for sensor_name, sensor_data in field_value.items():
        print(sensor_name, sensor_data.shape)

S2L2A (4, 96, 96, 12)
S1RTC (4, 96, 96, 2)
DEM (4, 32, 32, 1)
S2L2A torch.Size([12, 4, 96, 96])
S1RTC torch.Size([2, 4, 96, 96])
DEM torch.Size([1, 4, 96, 96])
S2L2A torch.Size([12, 4, 96, 96])
S1RTC torch.Size([2, 4, 96, 96])
DEM torch.Size([1, 4, 96, 96])
filename 000000
file_id 0
