In [75]:
import geopandas as gpd
import os
import random
import pandas as pd
from pathlib import Path
import pyarrow as pa
from matplotlib.patches import Patch
import pyarrow.dataset as ds
import numpy as np
import yaml
from gelos import config
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt
from gelos.config import PROJ_ROOT, PROCESSED_DATA_DIR, DATA_VERSION, RAW_DATA_DIR
from gelos.config import REPORTS_DIR, FIGURES_DIR

In [76]:

def sample_files(directory: str | Path, sample_size: int, *, seed: int | None = None) -> list[Path]:
    rng = random.Random(seed)
    directory = Path(directory)

    files = [Path(entry.path) for entry in os.scandir(directory) if entry.is_file()]
    if sample_size >= len(files):
        return files

    return rng.sample(files, sample_size)

In [77]:
def select_embedding_indices(
        embeddings_column: pa.lib.ListArray, 
        slice_args: list[dict[str, int]]
        ) -> pa.lib.ListArray:
    """Apply a sequence of slice-then-flatten steps to a nested list array.

    Each entry of ``slice_args`` must provide the keys ``start``, ``stop`` and
    ``step``. For every entry the current array is sliced with
    ``pa.compute.list_slice`` and then one nesting level is removed with
    ``pa.compute.list_flatten``, so the n-th entry operates one level deeper
    than the previous one.

    Args:
        embeddings_column: Nested list array to select from.
        slice_args: Ordered slice specifications, one per nesting level.

    Returns:
        The array left after all slice/flatten steps have been applied.
    """
    result = embeddings_column
    for spec in slice_args:
        start, stop, step = spec['start'], spec["stop"], spec["step"]
        sliced = pa.compute.list_slice(result, start=start, stop=stop, step=step)
        # Drop the level that was just sliced so the next spec addresses the
        # level below it.
        result = pa.compute.list_flatten(sliced)
    return result

def extract_embeddings_from_directory(
        directory: Path | str, 
        n_sample: int = 100000, 
        chip_indices: list[int] | None = None, 
        slice_args: list[dict[str, int | None]] | None = None
        ) -> tuple[np.ndarray, list[int]]:
    """Load embeddings from per-chip parquet files into a 2-D numpy array.

    Args:
        directory: Folder containing ``<6-digit id>_embedding.parquet`` files.
        n_sample: Number of files to sample (seed 42) when ``chip_indices`` is
            not given.
        chip_indices: Explicit chip ids to load. When omitted or empty, files
            are sampled from ``directory`` and the ids are parsed from the
            file names.
        slice_args: Slice specifications forwarded to
            ``select_embedding_indices``. Defaults to a single full slice
            (``start=0, stop=None, step=1``).

    Returns:
        ``(embeddings, chip_indices)`` where ``embeddings`` has one row per
        table row and ``chip_indices`` lists the chip ids that were requested.
    """
    # Accept plain strings as advertised by the signature; the `/` operator
    # below only works on Path objects.
    directory = Path(directory)
    # Built per call instead of sharing one mutable default list between calls.
    if slice_args is None:
        slice_args = [{"start": 0, "stop": None, "step": 1}]

    if chip_indices:
        files = [directory / f"{str(chip_id).zfill(6)}_embedding.parquet" for chip_id in chip_indices]
    else:
        files = sample_files(directory, n_sample, seed=42)
        chip_indices = [int(file.stem.split('_')[0]) for file in files]

    dataset = ds.dataset(files, format='parquet')
    table = dataset.to_table(columns=["embedding"])
    n_rows = table.num_rows
    embeddings_column = table.column("embedding").combine_chunks()
    selected_embeddings = select_embedding_indices(embeddings_column, slice_args)
    embeddings_flattened = pa.compute.list_flatten(selected_embeddings, recursive=True)
    # NOTE(review): the reshape and the pairing with chip_indices assume every
    # file contributes exactly one row with the same number of values, and that
    # rows come back in the order of `files` - confirm against the writer.
    embeddings = embeddings_flattened.to_numpy(zero_copy_only=False).reshape(n_rows, -1)
    return embeddings, chip_indices

def tsne_from_embeddings(
        embeddings: np.ndarray,
        perplexity: float = 50,
        random_state: int | None = 42,
        max_iter: int = 1000,
        ) -> np.ndarray:
    """Project embeddings to two dimensions with t-SNE.

    The defaults reproduce the previously hard-coded settings; they are exposed
    so that small samples can use a lower perplexity (scikit-learn requires
    ``perplexity`` to be smaller than the number of rows).

    Args:
        embeddings: 2-D array with one row per sample.
        perplexity: t-SNE perplexity.
        random_state: Seed passed to ``TSNE`` for reproducible layouts.
        max_iter: Maximum number of optimisation iterations.

    Returns:
        Array with one 2-D coordinate pair per input row.
    """
    tsne = TSNE(
        n_components=2,
        random_state=random_state,
        perplexity=perplexity,
        max_iter=max_iter,
    )
    return tsne.fit_transform(embeddings)

def plot_from_tsne(
        embeddings_tsne: np.array,
        chip_gdf: gpd.GeoDataFrame,
        model_name: str,
        extraction_strategy: str,
        embedding_layer: str,
        legend_patches: list[Patch],
        chip_indices: list[int],
        axis_lim: int = 90,
        output_dir: str | Path = None
        ) -> None:
    """
    Plot a t-SNE transform of embeddings colored according to land cover.

    Args:
        embeddings_tsne: Array with two columns of t-SNE coordinates.
        chip_gdf: Frame with a ``color`` column; it is looked up with
            ``.loc[chip_indices]``, so its index must contain those labels.
        model_name: Shown in the figure title and used in the file name.
        extraction_strategy: Shown as subtitle and used in the file name.
        embedding_layer: Used in the file name only.
        legend_patches: Legend handles.
        chip_indices: Index labels of the plotted chips, in row order of
            ``embeddings_tsne``.
        axis_lim: Both axes span ``[-axis_lim, axis_lim]``.
        output_dir: Folder for the PNG. When falsy the figure is shown
            instead of saved.
    """
    colors = chip_gdf.loc[chip_indices]['color']

    fig, ax = plt.subplots(figsize=(10, 8))
    # The second t-SNE column goes on the x axis and the negated first column
    # on the y axis (kept exactly as in the original layout).
    ax.scatter(embeddings_tsne[:, 1], -embeddings_tsne[:, 0], c=colors, s=2)
    fig.suptitle(f"t-SNE Visualization of GELOS Embeddings for {model_name}", fontsize=14)
    ax.set_title(extraction_strategy)
    ax.set_xlabel("t-SNE Dimension 1", fontsize=12)
    ax.set_ylabel("t-SNE Dimension 2", fontsize=12)
    ax.set_xlim([-axis_lim, axis_lim])
    ax.set_ylim([-axis_lim, axis_lim])
    ax.legend(handles=legend_patches, loc="upper left", fontsize=10, framealpha=0.9)

    if output_dir:
        model_slug = model_name.replace(" ", "").lower()
        strategy_slug = extraction_strategy.replace(" ", "").lower()
        layer_slug = embedding_layer.replace(" ", "").lower()
        # Path() lets callers pass a plain string, as the annotation allows.
        target = Path(output_dir) / f"{model_slug}_{strategy_slug}_{layer_slug}_tsneplot.png"
        fig.savefig(target, dpi=600, bbox_inches="tight")
        # Release the figure once it is on disk; this function is called in a
        # loop and open figures would otherwise accumulate.
        plt.close(fig)
    else:
        plt.show()

def save_tsne_as_csv(
        embeddings_tsne: np.array,
        chip_indices: list[int],
        model_name: str,
        extraction_strategy: str,
        embedding_layer: str,
        output_dir: str | Path = None
) -> None:
    model_name = model_name.replace(" ", "").lower()
    extraction_strategy = extraction_strategy.replace(" ", "").lower()
    embedding_layer = embedding_layer.replace(" ", "").lower()
    embeddings_df = pd.DataFrame({
        "id" : chip_indices,
        f"{model_name}_{extraction_strategy}_x" : embeddings_tsne[:, 0],
        f"{model_name}_{extraction_strategy}_y" : embeddings_tsne[:, 0],
    })
    embeddings_df.to_csv(output_dir / f"{model_name}_{extraction_strategy}_{embedding_layer}_tsnetable.csv")

In [78]:

# Legend handles for the t-SNE plots: one colored patch per class label.
legend_patches = [
    Patch(color=hex_color, label=class_label)
    for class_label, hex_color in {
        "Water": "#419bdf",
        "Trees": "#397d49",
        "Crops": "#e49635",
        "Built Area": "#c4281b",
        "Bare Ground": "#a59b8f",
        "Rangeland": "#e3e2c3",
    }.items()
]



In [79]:
yaml_config_directory = config.PROJ_ROOT / 'gelos' / 'configs'

# Inputs and output folders that do not depend on the model config: set them
# up once instead of re-reading the GeoJSON for every YAML file.
data_root = RAW_DATA_DIR / DATA_VERSION
chip_gdf = gpd.read_file(data_root / 'gelos_chip_tracker.geojson')
reports_dir = REPORTS_DIR / DATA_VERSION
reports_dir.mkdir(exist_ok=True, parents=True)
figures_dir = FIGURES_DIR / DATA_VERSION
figures_dir.mkdir(exist_ok=True, parents=True)

for yaml_filepath in yaml_config_directory.glob("*.yaml"):
    with open(yaml_filepath, "r") as f:
        yaml_config = yaml.safe_load(f)
    print(yaml.dump(yaml_config))
    model_name = yaml_config['model']['init_args']['model']
    model_title = yaml_config['model']['title']
    embedding_extraction_strategies = yaml_config['embedding_extraction_strategies']

    output_dir = PROCESSED_DATA_DIR / DATA_VERSION / model_name

    # add variables to yaml config so it can be passed to classes
    yaml_config['data']['init_args']['data_root'] = data_root
    yaml_config['model']['init_args']['output_dir'] = output_dir
    embeddings_directories = [item for item in output_dir.iterdir() if item.is_dir()]

    for embeddings_directory in embeddings_directories:
        # NOTE(review): this `break` skips the whole t-SNE pipeline below and
        # only leaves `embeddings_directory` bound to the first directory,
        # which the next cell relies on. Remove it to actually produce the
        # CSV tables and figures.
        break 
        embedding_layer = embeddings_directory.stem

        for extraction_strategy, slice_args in embedding_extraction_strategies.items():

            embeddings, chip_indices = extract_embeddings_from_directory(embeddings_directory, slice_args=slice_args)

            embeddings_tsne = tsne_from_embeddings(embeddings)

            save_tsne_as_csv(
                embeddings_tsne,
                chip_indices,
                model_name,
                extraction_strategy,
                embedding_layer,
                output_dir
                )

            # chip_indices is a required argument of plot_from_tsne; it was
            # missing from this call, which would raise a TypeError.
            plot_from_tsne(
                embeddings_tsne,
                chip_gdf,
                model_name,
                extraction_strategy,
                embedding_layer,
                legend_patches,
                chip_indices=chip_indices,
                output_dir=figures_dir
                )

data:
  class_path: gelos.gelosdatamodule.GELOSDataModule
  init_args:
    bands:
      DEM:
      - DEM
      S1RTC:
      - VV
      - VH
      S2L2A:
      - COASTAL_AEROSOL
      - BLUE
      - GREEN
      - RED
      - RED_EDGE_1
      - RED_EDGE_2
      - RED_EDGE_3
      - NIR_BROAD
      - NIR_NARROW
      - WATER_VAPOR
      - SWIR_1
      - SWIR_2
    batch_size: 1
    num_workers: 0
    repeat_bands:
      DEM: 4
    target_size: 96
embedding_extraction_strategies:
  All Embeddings:
  - start: 0
    step: 1
    stop: null
  All Patches from April to June:
  - start: 1
    step: 1
    stop: 2
  All Steps of Middle Patch:
  - start: 0
    step: 1
    stop: null
  - start: 18
    step: 1
    stop: 19
model:
  class_path: terratorch.tasks.EmbeddingGenerationTask
  init_args:
    embed_file_key: filename
    embedding_pooling: null
    has_cls: false
    layers:
    - -1
    model: terramind_v1_base
    model_args:
      merge_method: mean
      modalities:
      - S2L2A
      - 

In [81]:
# Ad-hoc inspection: repeat the file-sampling and dataset-opening steps of
# extract_embeddings_from_directory on a small sample so the parquet schema
# can be examined.
# NOTE(review): `embeddings_directory` is not defined in this cell; it is the
# loop variable leaked from the previous cell (bound to the first directory
# because of the `break` there). The cell fails with a NameError on a fresh
# kernel if that loop never ran or found no directories.
n_sample=100
directory = embeddings_directory
files = sample_files(directory, n_sample, seed=42)
# Chip id is the part of the file name before the first underscore.
chip_indices = [int(file.stem.split('_')[0]) for file in files]
dataset = ds.dataset(files, format='parquet')

In [83]:
# Display the Arrow schema of the sampled embedding files (uses `dataset`
# from the previous cell).
dataset.schema

embedding: list<element: list<element: double>>
  child 0, element: list<element: double>
      child 0, element: double
-- schema metadata --
pandas: '{"index_columns": [], "column_indexes": [], "columns": [{"name":' + 223
geo: '{"primary_column": null, "columns": {}, "version": "1.0.0", "creato' + 49