In [1]:
import geopandas as gpd
import os
import random
import pandas as pd
from pathlib import Path
import pyarrow as pa
from matplotlib.patches import Patch
import pyarrow.dataset as ds
import numpy as np
import yaml
from gelos import config
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt
from gelos.config import PROJ_ROOT, PROCESSED_DATA_DIR, DATA_VERSION, RAW_DATA_DIR
from gelos.config import REPORTS_DIR, FIGURES_DIR

[32m2025-12-06 19:55:17.649[0m | [1mINFO    [0m | [36mgelos.config[0m:[36m<module>[0m:[36m16[0m - [1mPROJ_ROOT path is: /app[0m


In [2]:
import gelos.tsne_transform

In [5]:
from gelos.tsne_transform import extract_embeddings_from_directory, tsne_from_embeddings, select_embedding_indices, sample_files,plot_from_tsne

In [12]:
# Notebook parameters: the model config file to load from the configs
# directory, and the key of the embedding extraction strategy to look up in it.
yaml_file, extraction_strategy = (
    "prithvi_eo_300m_embedding_generation.yaml",
    "CLS Token",
)

In [15]:
# Directory that holds the per-model YAML configuration files.
yaml_config_directory = PROJ_ROOT / 'gelos' / 'configs'

# Raw data for the configured data version, and the chip tracker stored there.
data_root = RAW_DATA_DIR / DATA_VERSION
chip_tracker_path = data_root / 'gelos_chip_tracker.geojson'
chip_gdf = gpd.read_file(chip_tracker_path)

In [16]:
# Load the selected model's YAML config and echo it for the record.
yaml_filepath = yaml_config_directory / yaml_file
with open(yaml_filepath, "r", encoding="utf-8") as f:
    yaml_config = yaml.safe_load(f)
print(yaml.dump(yaml_config))

# Pull out the fields the rest of the notebook relies on.
model_name = yaml_config['model']['init_args']['model']
model_title = yaml_config['model']['title']
embedding_extraction_strategies = yaml_config['embedding_extraction_strategies']
output_dir = PROCESSED_DATA_DIR / DATA_VERSION / model_name

# Path.iterdir() yields entries in arbitrary order; sort so that
# embeddings_directories[0] refers to the same directory on every run.
embeddings_directories = sorted(item for item in output_dir.iterdir() if item.is_dir())

# Fail with a readable message when the requested strategy is not in the YAML.
if extraction_strategy not in embedding_extraction_strategies:
    raise KeyError(
        f"Extraction strategy {extraction_strategy!r} not found in {yaml_filepath}; "
        f"available: {list(embedding_extraction_strategies)}"
    )
slice_args = embedding_extraction_strategies[extraction_strategy]


data:
  class_path: gelos.gelosdatamodule.GELOSDataModule
  init_args:
    bands:
      S2L2A:
      - BLUE
      - GREEN
      - RED
      - NIR_NARROW
      - SWIR_1
      - SWIR_2
    batch_size: 1
    num_workers: 0
embedding_extraction_strategies:
  All Patches from April to June:
  - start: 37
    step: 1
    stop: 73
  All Steps of Middle Patch:
  - start: 19
    step: 36
    stop: null
  CLS Token:
  - start: 0
    step: 1
    stop: 1
model:
  class_path: terratorch.tasks.EmbeddingGenerationTask
  init_args:
    embed_file_key: filename
    embedding_pooling: null
    has_cls: true
    model: prithvi_eo_v2_300
    model_args:
      backbone: prithvi_eo_v2_300
      backbone_bands:
      - BLUE
      - GREEN
      - RED
      - NIR_NARROW
      - SWIR_1
      - SWIR_2
      backbone_pretrained: true
    output_format: parquet
  title: Prithvi EO V2 300M
seed_everything: 0
trainer:
  accelerator: auto
  callbacks: []
  devices: auto
  max_epochs: 0
  num_nodes: 1
  strategy: auto

### Load one batch of embeddings and test slicing

In [None]:
# Build a parquet dataset from a sample of files in the first embeddings
# directory and keep only its first record batch for the slicing check.
directory = embeddings_directories[0]
n_sample = 100
# create parquet dataset
files = sample_files(directory, n_sample, seed=42)
dataset = ds.dataset(files, format="parquet")
scanner = dataset.scanner(columns=["embedding", "file_id"])
emb_chunks, id_chunks = [], []
batches = scanner.to_batches()
# Take the first batch explicitly. A `for ...: break` loop would leave `batch`
# unbound (or stale from a previous run) if the scanner yielded nothing.
batch = next(iter(batches), None)
if batch is None:
    raise ValueError(f"No record batches found in {directory}")


In [41]:
# select cls token using indexing
cls_token_indexed = pa.compute.list_flatten(batch["embedding"][0][0]).to_numpy()

# select cls token using slicing - if embedding strategy is "CLS Token"
sliced = select_embedding_indices(batch.column("embedding"), slice_args)
flattened = pa.compute.list_flatten(sliced, recursive=True)
emb_np = flattened.to_numpy(zero_copy_only=False).reshape(len(batch), -1)

print(f"CLS Token from slicing equivalent to indexing 1st token: {(emb_np == cls_token_indexed).all()}")

CLS Token from slicing equivalent to indexing 1st token: True


In [81]:
# Sample embedding files from one embeddings directory and recover the chip
# index encoded at the start of each file name.
n_sample = 100
# NOTE(review): `embeddings_directory` is not assigned anywhere in the visible
# notebook, so this cell raised NameError on a fresh kernel. Using the first
# discovered directory, as the batch-loading cell does — confirm this is the
# directory intended here.
directory = embeddings_directories[0]
files = sample_files(directory, n_sample, seed=42)
# The chip index is the integer before the first underscore in the file stem.
chip_indices = [int(file.stem.split('_')[0]) for file in files]
dataset = ds.dataset(files, format='parquet')

In [None]:
# Create the versioned output directories for reports and figures (if saving).
reports_dir = REPORTS_DIR / DATA_VERSION
figures_dir = FIGURES_DIR / DATA_VERSION
for out_dir in (reports_dir, figures_dir):
    # parents/exist_ok make this safe to re-run on an existing tree.
    out_dir.mkdir(parents=True, exist_ok=True)