In [1]:
from pathlib import Path
import torch
from gelos.gelosdatamodule import GELOSDataModule
import yaml
from gelos import config
from lightning.pytorch import Trainer
from pathlib import Path
from tqdm import tqdm
from gelos.config import PROJ_ROOT, PROCESSED_DATA_DIR, DATA_VERSION, RAW_DATA_DIR
from terratorch.tasks import EmbeddingGenerationTask


[32m2025-12-05 22:18:26.849[0m | [1mINFO    [0m | [36mgelos.config[0m:[36m<module>[0m:[36m16[0m - [1mPROJ_ROOT path is: /app[0m
  from .autonotebook import tqdm as notebook_tqdm
INFO:albumentations.check_version:A new version of Albumentations is available: 2.0.8 (you have 1.4.10). Upgrade using: pip install --upgrade albumentations


In [2]:
from lightning.pytorch.cli import instantiate_class


In [3]:
class LenientEmbeddingGenerationTask(EmbeddingGenerationTask):
    def check_file_ids(self, file_ids, x):
        return

def generate_embeddings(yaml_path: Path) -> None:

    with open(yaml_path, "r") as f:
        yaml_config = yaml.safe_load(f)
    
    print(yaml.dump(yaml_config))

    model_name = yaml_config['model']['init_args']['model']
    output_dir = PROCESSED_DATA_DIR / DATA_VERSION / model_name
    output_dir.mkdir(exist_ok=True, parents=True)
    data_root = RAW_DATA_DIR / DATA_VERSION

    # add variables to yaml config so it can be passed to classes
    yaml_config['data']['init_args']['data_root'] = data_root
    yaml_config['model']['init_args']['output_dir'] = output_dir

    gelos_datamodule = GELOSDataModule(**yaml_config['data']['init_args'])
    task = LenientEmbeddingGenerationTask(**yaml_config['model']['init_args'])

    device = 'gpu' if torch.cuda.is_available() else 'cpu'
    trainer = Trainer(accelerator=device, devices=1)

    trainer.predict(model=task, datamodule=gelos_datamodule)


In [4]:
yaml_config_directory = PROJ_ROOT / 'gelos' / 'configs'
yaml_paths = list(yaml_config_directory.glob('*.yaml'))
print(yaml_paths)

[PosixPath('/app/gelos/configs/terramind_embedding_generation.yaml'), PosixPath('/app/gelos/configs/prithvi_eo_600m_embedding_generation.yaml'), PosixPath('/app/gelos/configs/prithvi_eo_300m_embedding_generation.yaml')]


In [5]:
yaml_path = yaml_paths[1]

## Run Embedding Generation step by step

In [6]:
def materialize(transform_list):
    instantiated_transforms_list = []
    for class_path in transform_list:
        instantiated_transforms_list.append(instantiate_class(args=(), init=class_path))
    return instantiated_transforms_list 

In [35]:
with open(yaml_path, "r") as f:
        yaml_config = yaml.safe_load(f)

print(yaml.dump(yaml_config))

model_name = yaml_config['model']['init_args']['model']
output_dir = PROCESSED_DATA_DIR / DATA_VERSION / model_name
output_dir.mkdir(exist_ok=True, parents=True)
data_root = RAW_DATA_DIR / DATA_VERSION

# add variables to yaml config so it can be passed to classes
yaml_config['data']['init_args']['data_root'] = data_root
yaml_config['model']['init_args']['output_dir'] = output_dir

# instantiate transform classes if they exist
if yaml_config["data"]["init_args"]["transform"]:
      yaml_config["data"]["init_args"]["transform"] = [
            instantiate_class(args = (), init=class_path) for class_path in yaml_config["data"]["init_args"]["transform"]
      ]
gelos_datamodule = GELOSDataModule(**yaml_config['data']['init_args'])
task = LenientEmbeddingGenerationTask(**yaml_config['model']['init_args'])

device = 'gpu' if torch.cuda.is_available() else 'cpu'
trainer = Trainer(accelerator=device, devices=1)

INFO:terratorch.models.backbones.prithvi_vit:model_bands not passed. Assuming bands are ordered in the same way as [<HLSBands.BLUE: 'BLUE'>, <HLSBands.GREEN: 'GREEN'>, <HLSBands.RED: 'RED'>, <HLSBands.NIR_NARROW: 'NIR_NARROW'>, <HLSBands.SWIR_1: 'SWIR_1'>, <HLSBands.SWIR_2: 'SWIR_2'>].Pretrained patch_embed layer may be misaligned with current bands


data:
  class_path: gelos.gelosdatamodule.GELOSDataModule
  init_args:
    bands:
      S2L2A:
      - BLUE
      - GREEN
      - RED
      - NIR_NARROW
      - SWIR_1
      - SWIR_2
    batch_size: 1
    num_workers: 0
    transform:
    - class_path: terratorch.datasets.transforms.FlattenTemporalIntoChannels
    - class_path: albumentations.PadIfNeeded
      init_args:
        min_height: 98
        min_width: 98
    - class_path: albumentations.pytorch.transforms.ToTensorV2
    - class_path: terratorch.datasets.transforms.UnflattenTemporalFromChannels
      init_args:
        n_channels: 6
        n_timesteps: 4
embedding_extraction_strategies:
  All Patches from April to June:
  - start: 50
    step: 1
    stop: 99
  All Steps of Middle Patch:
  - start: 25
    step: 49
    stop: null
  CLS Token:
  - start: 0
    step: 1
    stop: 1
model:
  class_path: terratorch.tasks.EmbeddingGenerationTask
  init_args:
    embed_file_key: filename
    embedding_pooling: null
    has_cls: true


INFO: Using default `ModelCheckpoint`. Consider installing `litmodels` package to enable `LitModelCheckpoint` for automatic upload to the Lightning model registry.
INFO:lightning.pytorch.utilities.rank_zero:Using default `ModelCheckpoint`. Consider installing `litmodels` package to enable `LitModelCheckpoint` for automatic upload to the Lightning model registry.
INFO: GPU available: True (cuda), used: True
INFO:lightning.pytorch.utilities.rank_zero:GPU available: True (cuda), used: True
INFO: TPU available: False, using: 0 TPU cores
INFO:lightning.pytorch.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO: HPU available: False, using: 0 HPUs
INFO:lightning.pytorch.utilities.rank_zero:HPU available: False, using: 0 HPUs


In [36]:
gelos_datamodule.setup(stage = "predict")

In [37]:
gelos_datamodule.dataset[0]['image'].shape

torch.Size([6, 96, 98, 98])