In [1]:
from pathlib import Path
import torch
from gelos.gelosdatamodule import GELOSDataModule
import yaml
from gelos import config
from lightning.pytorch import Trainer
from pathlib import Path
from tqdm import tqdm
from gelos.config import PROJ_ROOT, EXTERNAL_DATA_DIR, PROCESSED_DATA_DIR, DATA_VERSION, RAW_DATA_DIR
from terratorch.tasks import EmbeddingGenerationTask
from gelos.features import LenientEmbeddingGenerationTask

[32m2025-12-06 21:13:05.726[0m | [1mINFO    [0m | [36mgelos.config[0m:[36m<module>[0m:[36m16[0m - [1mPROJ_ROOT path is: /app[0m
  from .autonotebook import tqdm as notebook_tqdm
INFO:albumentations.check_version:A new version of Albumentations is available: 2.0.8 (you have 1.4.10). Upgrade using: pip install --upgrade albumentations


In [2]:
from lightning.pytorch.cli import instantiate_class

In [3]:
# File name of the embedding-generation config used in this notebook; it is
# joined onto the config directory to form `yaml_path`.
yaml_file = "prithvi_eo_300m_embedding_generation.yaml"

In [4]:
# Resolve the selected config file under <PROJ_ROOT>/gelos/configs and echo
# the resulting path so the location is visible in the notebook output.
yaml_config_directory = PROJ_ROOT.joinpath('gelos', 'configs')
yaml_path = yaml_config_directory.joinpath(yaml_file)
print(yaml_path)

/app/gelos/configs/prithvi_eo_300m_embedding_generation.yaml


## Run Embedding Generation step by step

In [5]:
# Read the YAML config and echo it back for reference.
with yaml_path.open("r") as f:
    yaml_config = yaml.safe_load(f)

print(yaml.dump(yaml_config))

# Short handles on the two nested keyword-argument dicts; they alias the
# entries inside `yaml_config`, so the updates below are visible there too.
model_init_args = yaml_config['model']['init_args']
data_init_args = yaml_config['data']['init_args']

# Output goes to <PROCESSED_DATA_DIR>/<DATA_VERSION>/<model name>;
# input data is read from <RAW_DATA_DIR>/<DATA_VERSION>.
model_name = model_init_args['model']
output_dir = PROCESSED_DATA_DIR / DATA_VERSION / model_name
output_dir.mkdir(parents=True, exist_ok=True)
data_root = RAW_DATA_DIR / DATA_VERSION

# Inject the runtime paths so the dicts can be splatted into the constructors.
data_init_args['data_root'] = data_root
model_init_args['output_dir'] = output_dir

# When the config lists transforms, swap each spec for an instantiated object.
if "transform" in data_init_args:
    instantiated_transforms = []
    for transform_spec in data_init_args["transform"]:
        instantiated_transforms.append(instantiate_class(args=(), init=transform_spec))
    data_init_args["transform"] = instantiated_transforms

gelos_datamodule = GELOSDataModule(**data_init_args)
task = LenientEmbeddingGenerationTask(**model_init_args)

# Single-device trainer: GPU when CUDA is available, otherwise CPU.
if torch.cuda.is_available():
    device = 'gpu'
else:
    device = 'cpu'
trainer = Trainer(accelerator=device, devices=1)

INFO:terratorch.models.backbones.prithvi_vit:model_bands not passed. Assuming bands are ordered in the same way as [<HLSBands.BLUE: 'BLUE'>, <HLSBands.GREEN: 'GREEN'>, <HLSBands.RED: 'RED'>, <HLSBands.NIR_NARROW: 'NIR_NARROW'>, <HLSBands.SWIR_1: 'SWIR_1'>, <HLSBands.SWIR_2: 'SWIR_2'>].Pretrained patch_embed layer may be misaligned with current bands


data:
  class_path: gelos.gelosdatamodule.GELOSDataModule
  init_args:
    bands:
      S2L2A:
      - BLUE
      - GREEN
      - RED
      - NIR_NARROW
      - SWIR_1
      - SWIR_2
    batch_size: 1
    num_workers: 0
embedding_extraction_strategies:
  All Patches from April to June:
  - start: 37
    step: 1
    stop: 73
  All Steps of Middle Patch:
  - start: 19
    step: 36
    stop: null
  CLS Token:
  - start: 0
    step: 1
    stop: 1
model:
  class_path: terratorch.tasks.EmbeddingGenerationTask
  init_args:
    embed_file_key: filename
    embedding_pooling: null
    has_cls: true
    model: prithvi_eo_v2_300
    model_args:
      backbone: prithvi_eo_v2_300
      backbone_bands:
      - BLUE
      - GREEN
      - RED
      - NIR_NARROW
      - SWIR_1
      - SWIR_2
      backbone_pretrained: true
    output_format: parquet
  title: Prithvi EO V2 300M
seed_everything: 0
trainer:
  accelerator: auto
  callbacks: []
  devices: auto
  max_epochs: 0
  num_nodes: 1
  strategy: auto

INFO: Using default `ModelCheckpoint`. Consider installing `litmodels` package to enable `LitModelCheckpoint` for automatic upload to the Lightning model registry.
INFO:lightning.pytorch.utilities.rank_zero:Using default `ModelCheckpoint`. Consider installing `litmodels` package to enable `LitModelCheckpoint` for automatic upload to the Lightning model registry.
INFO: GPU available: True (cuda), used: True
INFO:lightning.pytorch.utilities.rank_zero:GPU available: True (cuda), used: True
INFO: TPU available: False, using: 0 TPU cores
INFO:lightning.pytorch.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO: HPU available: False, using: 0 HPUs
INFO:lightning.pytorch.utilities.rank_zero:HPU available: False, using: 0 HPUs


### Inspect model and ensure weights have been loaded correctly

In [7]:
# Print the name and shape of every parameter registered on the task's model.
for param_name, param in task.model.named_parameters():
    print(param_name, param.shape)

cls_token torch.Size([1, 1, 1024])
patch_embed.proj.weight torch.Size([1024, 6, 1, 16, 16])
patch_embed.proj.bias torch.Size([1024])
blocks.0.norm1.weight torch.Size([1024])
blocks.0.norm1.bias torch.Size([1024])
blocks.0.attn.qkv.weight torch.Size([3072, 1024])
blocks.0.attn.qkv.bias torch.Size([3072])
blocks.0.attn.proj.weight torch.Size([1024, 1024])
blocks.0.attn.proj.bias torch.Size([1024])
blocks.0.norm2.weight torch.Size([1024])
blocks.0.norm2.bias torch.Size([1024])
blocks.0.mlp.fc1.weight torch.Size([4096, 1024])
blocks.0.mlp.fc1.bias torch.Size([4096])
blocks.0.mlp.fc2.weight torch.Size([1024, 4096])
blocks.0.mlp.fc2.bias torch.Size([1024])
blocks.1.norm1.weight torch.Size([1024])
blocks.1.norm1.bias torch.Size([1024])
blocks.1.attn.qkv.weight torch.Size([3072, 1024])
blocks.1.attn.qkv.bias torch.Size([3072])
blocks.1.attn.proj.weight torch.Size([1024, 1024])
blocks.1.attn.proj.bias torch.Size([1024])
blocks.1.norm2.weight torch.Size([1024])
blocks.1.norm2.bias torch.Size([10

In [8]:
from models.prithvi_eo_v2 import PrithviViT

In [9]:
# Build a standalone PrithviViT of the chosen size, configured for
# 4 frames and 6 input channels.
model_version = "300M"
prithvi_model = PrithviViT(
    model_size=model_version,
    in_chans=6,
    num_frames=4,
)

In [12]:
# Expected location of the checkpoint file for the selected model size.
weights_path = EXTERNAL_DATA_DIR.joinpath("model_weights", f"Prithvi_EO_V2_{model_version}.pt")

In [14]:
# Load the checkpoint from `weights_path`.
# - map_location="cpu": tensors saved from a GPU session still load on a
#   machine without CUDA.
# - weights_only=True: restricts unpickling to tensors and plain containers
#   instead of running the unrestricted pickle loader on the file.
# Raises FileNotFoundError when the checkpoint file is not present.
weights = torch.load(weights_path, map_location="cpu", weights_only=True)

FileNotFoundError: [Errno 2] No such file or directory: '/app/data/external/model_weights/Prithvi_EO_V2_300M.pt'

In [None]:
# Copy the encoder weights from the loaded checkpoint (`weights`) into the
# standalone model. The original loop read an undefined name `state_dict`.
# NOTE(review): assumes `weights` is a flat {parameter name: tensor} mapping
# — confirm against the checkpoint file.
encoder_prefix = "encoder."
encoder_state_dict = {}
for k, v in weights.items():
    if 'pos_embed' in k:
        # Positional-embedding entries are deliberately not copied.
        continue
    if k.startswith(encoder_prefix):
        # Strip only the leading "encoder." so keys match the bare model's names.
        encoder_state_dict[k[len(encoder_prefix):]] = v

# strict=False tolerates key mismatches, so keep the returned report as the
# cell output: it lists the missing and unexpected keys.
load_result = prithvi_model.load_state_dict(encoder_state_dict, strict=False)
load_result

### Inspect outputs of dataset

In [6]:
# Set the datamodule up for the "predict" stage.
# NOTE(review): presumably this is what populates `gelos_datamodule.dataset`,
# which the following cell indexes — confirm in GELOSDataModule.
gelos_datamodule.setup(stage="predict")

In [13]:
# Look at the first dataset sample: for the "image" entry print each sensor's
# array shape; print every other field as-is.
first_sample = gelos_datamodule.dataset[0]
for field, value in first_sample.items():
    if field != "image":
        print(field, value)
        continue
    for sensor, data in value.items():
        print(sensor, data.shape)

S2L2A (4, 96, 96, 12)
S1RTC (4, 96, 96, 2)
DEM (4, 32, 32, 1)
S2L2A torch.Size([12, 4, 96, 96])
S1RTC torch.Size([2, 4, 96, 96])
DEM torch.Size([1, 4, 96, 96])
S2L2A torch.Size([12, 4, 96, 96])
S1RTC torch.Size([2, 4, 96, 96])
DEM torch.Size([1, 4, 96, 96])
filename 000000
file_id 0
