In [2]:
# %% ------------------ Imports ------------------
import os
import glob
import yaml
import pickle
import itertools
import numpy as np
import pandas as pd
import polars as pl
import geopandas as gpd
from shapely.geometry import Point
from sklearn.neighbors import BallTree
from tqdm.auto import tqdm

# %% ------------------ Functions ------------------
def pick_point(row):
    """Return the representative point of a tile row.

    Uses the row's ``geometry_point`` value when it is a shapely ``Point``;
    otherwise falls back to the centroid of ``row.geometry``.
    """
    candidate = row.get("geometry_point")
    # Anything that is not a Point (missing key -> None, other objects)
    # falls through to the centroid of the main geometry.
    return candidate if isinstance(candidate, Point) else row.geometry.centroid


def build_tile_dataframe(df, crs="EPSG:3857"):
    """Wrap ``df`` in a GeoDataFrame and add a ``geometry_point`` column.

    Args:
        df: Table with a ``geometry`` column (and optionally an existing
            ``geometry_point`` column, which ``pick_point`` prefers).
        crs: CRS passed to the ``GeoDataFrame`` constructor.

    Returns:
        A GeoDataFrame built from a copy of ``df`` whose ``geometry_point``
        column holds, per row, the value chosen by ``pick_point``.
    """
    tiles = gpd.GeoDataFrame(df.copy(), geometry=df["geometry"], crs=crs)
    # Points are computed row by row from the *input* frame, not the copy.
    representative_points = df.apply(pick_point, axis=1)
    tiles["geometry_point"] = representative_points
    return tiles


def load_embedding_metadata(parquet_dir):
    """Read the centre coordinates and ids of every embedding in a directory.

    Args:
        parquet_dir: Directory containing ``*.parquet`` files with the
            columns ``unique_id``, ``centre_lat`` and ``centre_lon``.

    Returns:
        A tuple ``(coords, ids, offsets)`` where ``coords`` is an ``(N, 2)``
        array of ``[centre_lat, centre_lon]`` rows, ``ids`` is the flat list
        of ``unique_id`` values in the same order, and ``offsets`` is a list
        of ``(path, start, end)`` tuples such that ``ids[start:end]`` are the
        ids read from ``path``.

    Raises:
        ValueError: If ``parquet_dir`` contains no ``*.parquet`` files.
    """
    # glob returns files in arbitrary order; sorting keeps the row order and
    # the offsets reproducible between runs and machines.
    parquets = sorted(glob.glob(os.path.join(parquet_dir, "*.parquet")))
    if not parquets:
        # Fail early with a clear message instead of the opaque error that
        # np.vstack([]) would raise at the end.
        raise ValueError(f"No .parquet files found in {parquet_dir!r}")

    coords, ids, offsets = [], [], []
    offset = 0

    for path in tqdm(parquets, desc="Reading coords from parquet"):
        part = pl.read_parquet(path, columns=['unique_id', 'centre_lat', 'centre_lon']).to_pandas()
        coords.append(
            np.column_stack([part["centre_lat"].to_numpy(), part["centre_lon"].to_numpy()])
        )
        ids.extend(part["unique_id"].tolist())
        offsets.append((path, offset, offset + len(part)))
        offset += len(part)

    return np.vstack(coords), ids, offsets


def match_tiles_to_embeddings(gdf, emb_coords, emb_ids):
    """Find the nearest embedding for every tile.

    Adds two columns to ``gdf`` in place and returns it: ``match_id`` (the
    id of the nearest embedding) and ``dist_to_emb`` (haversine distance to
    it, scaled by an Earth radius of 6 371 000 m).

    Args:
        gdf: Frame with a ``geometry_point`` column exposing ``.x`` / ``.y``.
        emb_coords: ``(N, 2)`` array of ``[lat, lon]`` rows.
        emb_ids: Sequence of ids aligned with the rows of ``emb_coords``.
    """
    earth_radius_m = 6_371_000

    # NOTE(review): both inputs go through np.radians, so point.y / point.x
    # are assumed to be latitude / longitude in degrees -- confirm against
    # the CRS of the incoming tiles.
    points = gdf.geometry_point
    query_rad = np.radians(np.vstack([points.y.values, points.x.values]).T)

    index = BallTree(np.radians(emb_coords), metric='haversine')
    angular_dist, nearest = index.query(query_rad, k=1)

    # k=1 yields a single column; flatten it before use.
    nearest_pos = nearest[:, 0]
    metres = angular_dist[:, 0] * earth_radius_m

    gdf["match_id"] = [emb_ids[pos] for pos in nearest_pos]
    gdf["dist_to_emb"] = metres
    return gdf


def load_required_embeddings(needed_ids, file_offsets, emb_ids=None):
    """Load embedding vectors only for the requested ids.

    Args:
        needed_ids: Iterable of ``unique_id`` values whose embeddings are
            required.
        file_offsets: Iterable of ``(path, start, end)`` tuples, one per
            parquet file.
        emb_ids: Optional flat sequence of ids such that
            ``emb_ids[start:end]`` are the ids stored in ``path``. When
            omitted, the ids are read from each file's ``unique_id`` column,
            so the function no longer depends on a variable defined outside
            its own scope.

    Returns:
        A tuple ``(emb_vectors, emb_cols)``. ``emb_vectors`` maps each found
        id to a ``{"emb_<i>": value}`` dict; ``emb_cols`` is the list of
        ``emb_<i>`` names taken from the first file that contributed rows,
        or ``None`` when no requested id was found in any file.
    """
    needed_ids = set(needed_ids)
    emb_vectors = {}
    emb_cols = None

    for path, start, end in tqdm(file_offsets, desc="Loading embeddings blocks"):
        if emb_ids is not None:
            block_ids = emb_ids[start:end]
        else:
            # Reading just the id column is cheap compared with the
            # embedding column and lets us skip files with nothing we need.
            block_ids = pl.read_parquet(path, columns=["unique_id"])["unique_id"].to_list()

        want = needed_ids.intersection(block_ids)
        if not want:
            continue

        part = (
            pl.read_parquet(path, columns=["unique_id", "embedding"])
              .filter(pl.col("unique_id").is_in(list(want)))
              .to_pandas()
        )
        # Each cell of "embedding" holds one vector; stack them into a
        # (rows, dim) matrix.
        mat = np.vstack(part["embedding"].values)
        cols = [f"emb_{i}" for i in range(mat.shape[1])]
        if emb_cols is None:
            emb_cols = cols

        for uid, vec in zip(part["unique_id"], mat):
            emb_vectors[uid] = dict(zip(cols, vec.tolist()))

    return emb_vectors, emb_cols


def attach_embeddings(gdf, emb_vectors, emb_cols):
    """Add one column per embedding component to ``gdf``.

    Args:
        gdf: Frame with a ``match_id`` column.
        emb_vectors: Mapping ``id -> {column_name: value}``.
        emb_cols: Names of the columns to create, or ``None`` / empty when
            no embeddings were loaded.

    Returns:
        ``gdf`` itself (modified in place). Rows whose ``match_id`` is not in
        ``emb_vectors``, or whose vector lacks a column, get ``NaN``.
    """
    # emb_cols is None when nothing was loaded; iterating it would raise
    # TypeError, so return the frame unchanged instead.
    if not emb_cols:
        return gdf

    for c in emb_cols:
        gdf[c] = gdf["match_id"].map(lambda uid: emb_vectors.get(uid, {}).get(c, np.nan))
    return gdf


def add_embeddings_to_tiles(df, parquet_dir):
    """Entry point: attach the nearest embedding to every tile in ``df``.

    Builds the tile GeoDataFrame, reads embedding coordinates from
    ``parquet_dir``, matches each tile to its nearest embedding, loads the
    matched vectors and appends them as ``emb_<i>`` columns. The original
    docstring cites ``emb_0 … emb_22``; the actual number of columns follows
    the embedding length found in the parquet files.
    """
    tiles = build_tile_dataframe(df)

    emb_coords, emb_ids_flat, file_offsets = load_embedding_metadata(parquet_dir)
    tiles = match_tiles_to_embeddings(tiles, emb_coords, emb_ids_flat)

    # Only the embeddings that were actually matched need to be loaded.
    matched_ids = set(tiles["match_id"])
    emb_vectors, emb_cols = load_required_embeddings(matched_ids, file_offsets)

    return attach_embeddings(tiles, emb_vectors, emb_cols)



  from .autonotebook import tqdm as notebook_tqdm


In [None]:
# %% ------------------ Config & Run ------------------
# Load directory paths from config.yaml.
# The encoding is explicit because YAML files are UTF-8 by specification,
# while open() otherwise falls back to the platform default and can
# mis-decode non-ASCII paths.
with open("config.yaml", "r", encoding="utf-8") as f:
    config = yaml.safe_load(f)

processed_path = config["processed_data_dir"]
raw_path = config["raw_data_dir"]
parquet_dir = os.path.join(raw_path, "datasets", "major-tom-core-s2l1c-ssl4eo-amazonia-embeddings", "versions", "1")

# Load the input table.
# NOTE(review): pickle.load can execute arbitrary code embedded in the file;
# only load pickles produced by this project's own pipeline.
input_path = os.path.join(processed_path, "all_tiles_features_with_soil.pkl")
with open(input_path, "rb") as f:
    df = pickle.load(f)

# Build the table with embedding columns attached.
df_with_emb = add_embeddings_to_tiles(df, parquet_dir)

# Save the result.
output_path = os.path.join(processed_path, "all_tiles_features_with_emb.pkl")
with open(output_path, "wb") as f:
    pickle.dump(df_with_emb, f)

print(f"[✓] Saved with embeddings → {output_path}")