In [None]:
# SST (lat, lon, timestamp, sst) -> gradiente de SST
# Salida: <repo_root>/transform/sst/gradient  (Parquet particionado por year/month)

from pathlib import Path
import math
import numpy as np
import pandas as pd
from tqdm.auto import tqdm
import pyarrow as pa
import pyarrow.parquet as pq
import pyarrow.dataset as pads

# =========================
# Localización carpetas
# =========================
def find_sst_parquet_dir(start: Path | None = None) -> Path:
    start = (start or Path.cwd()).resolve()
    # ruta donde guardaste el parquet de SST
    for parent in [start, *start.parents]:
        cand = parent / "transform" / "sst" / "sample"
        if cand.is_dir():
            return cand
    raise FileNotFoundError(f"No se encontró 'transform/sst/parquet' desde {start}")

# Input directory (.../transform/sst/sample), found by searching upward from the CWD.
IN_PARQ = find_sst_parquet_dir()
# Repository root: three levels above the input directory (sample -> sst -> transform -> root).
REPO_ROOT = IN_PARQ.parent.parent.parent
# Output dataset directory; created if missing.
OUT_PARQ = REPO_ROOT.joinpath("transform", "sst", "gradient")
OUT_PARQ.mkdir(parents=True, exist_ok=True)

print(f"📥 Input : {IN_PARQ}")
print(f"📤 Output: {OUT_PARQ} (partitioned by year/month)")

# =========================
# Config
# =========================
# KNN / local plane fit: sst ≈ a*x + b*y + c  => grad = (a, b) [units: sst per km]
K_NEIGHBORS = 8         # neighbours per point, as returned by the KD-tree query (the point itself is among them); must be >= 3
MAX_RADIUS_KM = None    # optional: neighbours farther than this radius are excluded from the fit; a point left with fewer than 3 gets no gradient (None = disabled)
FRACTION = 1.0          # random subsampling fraction applied per timestamp (0 < frac <= 1)
RANDOM_SEED = 42
CHUNK_SOLVE = 50_000    # number of points per KD-tree query batch, to bound memory use

# NOTE(review): `rng` is not referenced in the visible code (subsampling passes RANDOM_SEED directly) — confirm before removing.
rng = np.random.default_rng(RANDOM_SEED)

# =========================
# Utilidades parquet
# =========================
def write_parquet_block(df: pd.DataFrame):
    """Append one block of gradient rows to the year/month-partitioned dataset.

    The ``timestamp`` column is parsed as UTC to derive the ``year`` (int16)
    and ``month`` (int8) partition columns. The columns lat, lon, timestamp,
    sst, dTdx_km, dTdy_km and sst_grad are then written under ``OUT_PARQ``
    with snappy compression. The caller's DataFrame is not modified.

    Rows whose timestamp cannot be parsed are dropped: ``errors="coerce"``
    turns them into NaT, and casting the resulting NaN year/month to an
    integer dtype would otherwise raise. If no row has a valid timestamp,
    nothing is written.

    Args:
        df: Frame containing at least the seven data columns listed above.
    """
    data_cols = ["lat", "lon", "timestamp", "sst", "dTdx_km", "dTdy_km", "sst_grad"]

    ts = pd.to_datetime(df["timestamp"], utc=True, errors="coerce")
    valid = ts.notna()
    if not valid.any():
        return
    ts = ts[valid]

    # Assign plain arrays so values are placed positionally, independent of the index.
    block = df.loc[valid, data_cols].assign(
        year=ts.dt.year.astype("int16").to_numpy(),
        month=ts.dt.month.astype("int8").to_numpy(),
    )
    table = pa.Table.from_pandas(block, preserve_index=False)
    pq.write_to_dataset(
        table,
        root_path=OUT_PARQ,
        partition_cols=["year", "month"],
        compression="snappy",
    )

# =========================
# Geom helpers
# =========================
def lonlat_to_local_km(lon: np.ndarray, lat: np.ndarray):
    """
    Project lon/lat (degrees) to local planar (x, y) coordinates in km.

    Uses an approximate local equirectangular projection centred on the
    NaN-ignoring medians of the inputs, with an Earth radius of 6371 km.
    Returns (x, y) as float32 arrays: x is the east-west offset scaled by
    the cosine of the centre latitude, y is the north-south offset.
    """
    earth_radius_km = 6371.0
    deg_to_rad = np.pi / 180.0

    center_lon = np.nanmedian(lon)
    center_lat = np.nanmedian(lat)

    east_km = earth_radius_km * np.cos(center_lat * deg_to_rad) * (lon - center_lon) * deg_to_rad
    north_km = earth_radius_km * (lat - center_lat) * deg_to_rad
    return east_km.astype(np.float32), north_km.astype(np.float32)

def fit_plane_gradients(x: np.ndarray, y: np.ndarray, z: np.ndarray,
                        k: int, max_radius_km: float | None,
                        chunk: int):
    """
    Estima gradiente por punto usando KNN y ajuste de plano local (LS):
      z ≈ a*x + b*y + c  => grad = (a, b)
    Devuelve dZ/dx (km^-1), dZ/dy (km^-1) y |grad|.
    """
    from scipy.spatial import cKDTree

    n = x.size
    tree = cKDTree(np.c_[x, y])

    dZdx = np.full(n, np.nan, dtype="float32")
    dZdy = np.full(n, np.nan, dtype="float32")

    # Resolve en lotes para controlar memoria/tiempo
    for i0 in tqdm(range(0, n, chunk), desc="  -> resolviendo gradientes", leave=False):
        i1 = min(i0 + chunk, n)
        pts = np.c_[x[i0:i1], y[i0:i1]]
        # vecinos (incluye el propio punto)
        dist, idx = tree.query(pts, k=min(k, n), workers=-1)

        # asegurar 2D shape cuando k=1
        if dist.ndim == 1:
            dist = dist[:, None]; idx = idx[:, None]

        for j in range(i1 - i0):
            neigh = idx[j]
            # opcional: filtrar por radio máximo (si se definió)
            if max_radius_km is not None:
                ok = dist[j] <= max_radius_km
                # garantizar al menos 3 vecinos
                if ok.sum() < 3:
                    continue
                neigh = neigh[ok]

            if neigh.size < 3:
                continue

            xi = x[neigh].astype("float64")
            yi = y[neigh].astype("float64")
            zi = z[neigh].astype("float64")

            # centrar en el punto objetivo para robustez numérica
            x0, y0 = pts[j]
            Xi = xi - x0
            Yi = yi - y0

            A = np.c_[Xi, Yi, np.ones_like(Xi)]
            try:
                coef, *_ = np.linalg.lstsq(A, zi, rcond=None)  # [a, b, c]
                dZdx[i0 + j] = coef[0]
                dZdy[i0 + j] = coef[1]
            except Exception:
                # deja NaN si algo falla
                pass

    grad = np.sqrt(dZdx**2 + dZdy**2, dtype="float32")
    return dZdx, dZdy, grad

# =========================
# Pipeline
# =========================
# Pipeline: for every distinct timestamp in the input dataset, fit local SST
# gradients and append the valid rows to the partitioned output dataset.
ds = pads.dataset(IN_PARQ, format="parquet", partitioning="hive")
print("🔎 Escaneando timestamps únicos...")
timestamps = set()

# Stream only the 'timestamp' column. Dataset.scanner() returns a Scanner,
# which is the object that exposes to_batches() / to_table().
scan_ts = ds.scanner(columns=["timestamp"])
for batch in tqdm(scan_ts.to_batches(), desc="  -> leyendo 'timestamp'", leave=False):
    # De-duplicate inside Arrow first so only distinct values become Python objects.
    distinct = batch.column(0).unique()
    timestamps.update(v for v in distinct.to_pylist() if v is not None)

timestamps = sorted(timestamps)
print(f"🕒 Timestamps distintos: {len(timestamps)}")

# Each timestamp is processed independently of the others.
for ts in tqdm(timestamps, desc="Procesando timestamps", unit="ts"):
    # Read only the rows of this timestamp.
    filt = pads.field("timestamp") == pa.scalar(ts)
    scanner = ds.scanner(columns=["lat", "lon", "timestamp", "sst"], filter=filt)
    df = scanner.to_table().to_pandas()

    # Clean: treat +/-inf as missing and drop rows lacking position or SST.
    df = df.replace([np.inf, -np.inf], np.nan).dropna(subset=["lat", "lon", "sst"])
    if df.empty:
        continue

    # Optional subsampling to speed things up.
    if FRACTION < 1.0:
        df = df.sample(frac=FRACTION, random_state=RANDOM_SEED).reset_index(drop=True)

    # A plane fit needs at least 3 points; with fewer, no row could get a gradient.
    if len(df) < 3:
        continue

    # Local planar coordinates in km.
    x, y = lonlat_to_local_km(df["lon"].to_numpy(), df["lat"].to_numpy())
    z = df["sst"].to_numpy(dtype="float32")

    # Local KNN plane-fit gradient.
    dZdx, dZdy, grad = fit_plane_gradients(
        x, y, z,
        k=K_NEIGHBORS,
        max_radius_km=MAX_RADIUS_KM,
        chunk=CHUNK_SOLVE,
    )

    # The arrays are assigned positionally; sst_grad is the magnitude in sst units per km.
    out = df.assign(dTdx_km=dZdx, dTdy_km=dZdy, sst_grad=grad)

    # Write only the rows that obtained a valid gradient.
    out = out.dropna(subset=["dTdx_km", "dTdy_km", "sst_grad"])
    if not out.empty:
        write_parquet_block(out)

print(f"\n✅ Dataset de gradientes listo en: {OUT_PARQ}")
