In [None]:
# EKE (C3S DUACS L4) → lat, lon, timestamp, eke
# Incremental write to Parquet partitioned by year/month
#
# - Handles disguised .nc that are actually ZIP/GZIP (no renaming; temp extraction).
# - Prefers velocity anomalies (ugosa, vgosa). Fallback to absolute (ugos, vgos).
# - Uses xarray+netCDF4/h5netcdf/scipy; writes Parquet with pyarrow.
#
# Requirements:
#   pip install xarray netCDF4 h5netcdf h5py pyarrow tqdm pandas numpy

from pathlib import Path
import os, zipfile, gzip, tempfile, shutil, importlib
from datetime import timezone
import numpy as np
import pandas as pd
import xarray as xr
from tqdm.auto import tqdm
import pyarrow as pa
import pyarrow.parquet as pq

# --------------------------
# Locate input/output
# --------------------------
def find_eke_sample_dir(start: Path | None = None) -> Path:
    start = (start or Path.cwd()).resolve()
    for parent in [start, *start.parents]:
        cand = parent / "downloads" / "eke" / "sample"
        if cand.is_dir():
            return cand
    raise FileNotFoundError(f"Could not find 'downloads/eke/sample' starting from {start}")

SAMPLE_DIR = find_eke_sample_dir()
# SAMPLE_DIR is <repo>/downloads/eke/sample, so the repository root sits three
# levels up: parents[0]=eke, parents[1]=downloads, parents[2]=<repo>.
REPO_ROOT  = SAMPLE_DIR.parents[2]
# Parquet output lives under <repo>/transform/eke/sample; it is created up
# front so the partitioned writes later in the notebook can assume it exists.
OUT_PARQ   = REPO_ROOT.joinpath("transform", "eke", "sample")
OUT_PARQ.mkdir(parents=True, exist_ok=True)

print(
    f"📂 Input:  {SAMPLE_DIR}",
    f"🧺 Output Parquet (partitioned): {OUT_PARQ}",
    sep="\n",
)

# --------------------------
# Config
# --------------------------
# Velocity variable pairs, tried in order: anomalies first, then absolute
# velocities as a fallback (see choose_uv).
PREFERRED_UV = [("ugosa","vgosa"), ("ugos","vgos")]
LAT_CANDS = ("latitude","lat")    # accepted latitude coordinate names, in priority order
LON_CANDS = ("longitude","lon")   # accepted longitude coordinate names, in priority order
TIME_NAME = "time"                # required time coordinate / dimension name

FRACTION = 1.0                # 0<frac<=1 to randomly keep a fraction of points (per time slice)
RANDOM_SEED = 42              # seeds `rng`; subsampling is reproducible for a given file order
MAX_POINTS_PER_FILE = None    # cap rows per time slice (after FRACTION), e.g., 1_000_000; None = no cap
CHUNK_ROWS = 2_000_000        # max rows per Parquet write call, to control memory; None = one write per slice
COMPRESSION = "snappy"        # snappy|zstd|gzip

# Fail fast on invalid settings instead of deep inside the processing loop
# (e.g. CHUNK_ROWS=0 would otherwise blow up in range(), FRACTION<=0 would
# silently keep a single point per slice).
if not 0 < FRACTION <= 1:
    raise ValueError(f"FRACTION must be in (0, 1], got {FRACTION!r}")
if MAX_POINTS_PER_FILE is not None and MAX_POINTS_PER_FILE < 1:
    raise ValueError(f"MAX_POINTS_PER_FILE must be None or >= 1, got {MAX_POINTS_PER_FILE!r}")
if CHUNK_ROWS is not None and CHUNK_ROWS < 1:
    raise ValueError(f"CHUNK_ROWS must be None or >= 1, got {CHUNK_ROWS!r}")

rng = np.random.default_rng(RANDOM_SEED)

# --------------------------
# Helpers
# --------------------------
def sniff_magic(path: Path) -> str:
    """Classify *path* by its leading magic bytes, regardless of file extension.

    Returns one of "HDF5" (NetCDF4 files are HDF5 containers), "NETCDF3",
    "ZIP", "GZIP", or "UNKNOWN" when none of the signatures match.
    """
    signatures = (
        (b"\x89HDF\r\n\x1a\n", "HDF5"),    # NetCDF4/HDF5
        (b"CDF", "NETCDF3"),              # classic NetCDF3
        (b"PK", "ZIP"),
        (b"\x1f\x8b", "GZIP"),
    )
    with path.open("rb") as fh:
        prefix = fh.read(16)
    return next(
        (label for magic, label in signatures if prefix.startswith(magic)),
        "UNKNOWN",
    )

def extract_temp_nc(src_path: Path, kind: str) -> tuple[Path, tempfile.TemporaryDirectory]:
    """Extract the NetCDF payload of a ZIP/GZIP archive into a fresh temp directory.

    For ZIP archives the first member whose name ends in ``.nc``
    (case-insensitive) is extracted; if there is none, the first non-directory
    member is used. GZIP archives are decompressed as a single stream. The
    output file is named ``<src stem>__real.nc``.

    Args:
        src_path: archive on disk (typically a ``.nc`` file that sniffed as ZIP/GZIP).
        kind: "ZIP" or "GZIP" (as returned by sniff_magic).

    Returns:
        (nc_path, tmpdir). The caller owns *tmpdir* and must call
        ``tmpdir.cleanup()`` once it is done with *nc_path*.

    Raises:
        ValueError: if *kind* is unsupported or a ZIP has no file members.
        Archive read errors (e.g. a corrupt ZIP/GZIP) propagate unchanged.
        In every failure case the temp directory is removed before raising.
    """
    # Validate before creating anything on disk.
    if kind not in ("ZIP", "GZIP"):
        raise ValueError(f"Unsupported archive kind: {kind}")

    tmpdir = tempfile.TemporaryDirectory()
    out_nc = Path(tmpdir.name) / (src_path.stem + "__real.nc")
    try:
        if kind == "ZIP":
            with zipfile.ZipFile(src_path) as zf:
                # Directory entries (names ending in "/") are not payloads.
                members = [m for m in zf.namelist() if not m.endswith("/")]
                if not members:
                    raise ValueError(f"ZIP archive has no file members: {src_path}")
                nc_members = [m for m in members if m.lower().endswith(".nc")]
                pick = nc_members[0] if nc_members else members[0]
                with zf.open(pick) as src, out_nc.open("wb") as dst:
                    shutil.copyfileobj(src, dst)
        else:  # GZIP
            with gzip.open(src_path, "rb") as src, out_nc.open("wb") as dst:
                shutil.copyfileobj(src, dst)
    except BaseException:
        # Don't leak the temp dir (or a half-written file) when extraction fails.
        tmpdir.cleanup()
        raise
    return out_nc, tmpdir

def try_open_xr(nc_path: Path):
    """Open *nc_path* with the first xarray backend engine that succeeds.

    Engines are tried in the order netcdf4 -> h5netcdf -> scipy. An engine whose
    backing package is not importable is skipped without an attempt, so its
    "unknown engine" error cannot mask the real failure of another engine.

    Returns:
        (dataset, engine_name). The dataset is opened with CF decoding and
        mask/scale applied; the caller is responsible for closing it.

    Raises:
        RuntimeError: if no engine could open the file; the last backend error
            is chained as ``__cause__``.
    """
    # `import importlib` alone does not guarantee the `util` submodule is loaded.
    import importlib.util

    # (xarray engine name, importable package that provides it)
    backends = (("netcdf4", "netCDF4"), ("h5netcdf", "h5netcdf"), ("scipy", "scipy"))
    tried, not_installed, last_err = [], [], None
    for eng, package in backends:
        if importlib.util.find_spec(package) is None:
            not_installed.append(eng)
            continue
        try:
            ds = xr.open_dataset(nc_path, engine=eng, decode_cf=True, mask_and_scale=True)
        except Exception as e:
            tried.append(eng)
            last_err = e
            continue
        return ds, eng

    msg = f"xarray failed with engines {tried}. Last error: {last_err}"
    if not_installed:
        msg += f" (skipped, not installed: {not_installed})"
    raise RuntimeError(msg) from last_err

def choose_coords(ds: xr.Dataset):
    """Find the latitude/longitude coordinates of *ds*.

    Names are looked up in ``ds.coords`` (not data variables), trying
    LAT_CANDS and LON_CANDS in order; the first present name wins.

    Returns:
        (lat_values, lon_values, lat_name, lon_name), where the values are
        the coordinates' ``.values`` arrays.

    Raises:
        KeyError: if either coordinate is missing.
    """
    names = []
    for candidates in (LAT_CANDS, LON_CANDS):
        match = None
        for cand in candidates:
            if cand in ds.coords:
                match = cand
                break
        names.append(match)

    lat_name, lon_name = names
    if None in names:
        raise KeyError("latitude/longitude coords not found")
    return ds[lat_name].values, ds[lon_name].values, lat_name, lon_name

def choose_uv(ds: xr.Dataset) -> tuple[str,str,str]:
    """Pick the velocity variable pair to use from *ds*.

    Pairs are tried in PREFERRED_UV order; the first pair whose two names are
    both data variables wins.

    Returns:
        (u_name, v_name, kind), where kind is "anomaly" when both names end in
        "osa" and "absolute" otherwise.

    Raises:
        KeyError: if no configured pair is fully present.
    """
    available = ds.data_vars
    for u_name, v_name in PREFERRED_UV:
        if u_name not in available or v_name not in available:
            continue
        both_anomaly = u_name.endswith("osa") and v_name.endswith("osa")
        if both_anomaly:
            return u_name, v_name, "anomaly"
        return u_name, v_name, "absolute"
    raise KeyError("No suitable velocity variables found (tried ugosa/vgosa and ugos/vgos)")

def write_parquet_block(df: pd.DataFrame, root_path: Path | None = None):
    """Append the rows of *df* to the year/month-partitioned Parquet dataset.

    *df* must have columns ``lat``, ``lon``, ``timestamp`` (anything
    ``pd.to_datetime`` can parse; interpreted as UTC) and ``eke``. Integer
    partition keys ``year`` and ``month`` are derived from ``timestamp``; the
    caller's DataFrame is not modified.

    Args:
        df: rows to write.
        root_path: dataset root directory; defaults to OUT_PARQ.

    Raises:
        ValueError: if any timestamp cannot be parsed (its partition would be undefined).

    NOTE(review): with pyarrow's default file naming, each call appears to add
    new files instead of replacing earlier ones, so re-running the whole
    notebook duplicates rows. Confirm this, and clear root_path before a full rerun.
    """
    ts = pd.to_datetime(df["timestamp"], utc=True, errors="coerce")
    bad = ts.isna()
    if bad.any():
        example = df.loc[bad, "timestamp"].iloc[0]
        raise ValueError(f"unparseable timestamp in block, e.g. {example!r}")

    # Build a new frame rather than adding helper columns to the caller's df.
    out = df[["lat", "lon", "timestamp", "eke"]].assign(
        year=ts.dt.year.astype("int16"),
        month=ts.dt.month.astype("int8"),
    )
    table = pa.Table.from_pandas(out, preserve_index=False)
    pq.write_to_dataset(
        table,
        root_path=OUT_PARQ if root_path is None else root_path,
        partition_cols=["year", "month"],
        compression=COMPRESSION,
    )

# --------------------------
# Processing
# --------------------------
files = sorted(SAMPLE_DIR.glob("*.nc"))
if not files:
    # Explicit raise rather than `assert`, which is stripped under `python -O`.
    raise FileNotFoundError(f"No .nc files found in {SAMPLE_DIR}")

# skipped: (file name, reason) pairs; processed: files that finished without error
skipped, processed = [], 0

def _cleanup_tmpdir(tmp) -> None:
    """Best-effort removal of an extraction TemporaryDirectory; no-op for None."""
    if tmp is None:
        return
    try:
        tmp.cleanup()
    except OSError as err:
        # A leftover temp dir is not worth aborting the run for, but report it
        # instead of swallowing the error silently.
        tqdm.write(f"⚠️ could not remove temp dir {tmp.name}: {err}")


def _ensure_tyx(a: xr.DataArray, lat_name: str, lon_name: str) -> xr.DataArray:
    """Return *a* with its dimensions ordered (TIME_NAME, lat_name, lon_name)."""
    want = (TIME_NAME, lat_name, lon_name)
    return a if tuple(a.dims) == want else a.transpose(*want)


def _write_rows(lat, lon, eke, ts_iso: str) -> None:
    """Write one time slice's points through write_parquet_block.

    Rows are written in chunks of CHUNK_ROWS to bound DataFrame memory. There
    is a single write when CHUNK_ROWS is None or the slice is no larger than it.
    """
    n = lat.size
    if n == 0:
        return
    step = n if CHUNK_ROWS is None else CHUNK_ROWS
    for i0 in range(0, n, step):
        sl = slice(i0, i0 + step)
        write_parquet_block(pd.DataFrame({
            "lat": lat[sl],
            "lon": lon[sl],
            "timestamp": ts_iso,
            "eke": eke[sl],
        }))


for src in tqdm(files, desc="Files", unit="file"):
    tmpdir = None
    work = src
    kind = sniff_magic(src)
    # Files named .nc may really be ZIP/GZIP archives; extract to a temp copy.
    if kind in ("ZIP", "GZIP"):
        try:
            work, tmpdir = extract_temp_nc(src, kind)
        except Exception as e:
            skipped.append((src.name, f"extract {kind} failed: {e}"))
            continue

    try:
        ds, eng = try_open_xr(work)
    except Exception as e:
        skipped.append((src.name, f"xarray open failed: {e}"))
        _cleanup_tmpdir(tmpdir)
        continue

    try:
        # Coordinates
        lat1d, lon1d, lat_name, lon_name = choose_coords(ds)
        if TIME_NAME not in ds.coords:
            raise KeyError("time coord not found")
        times = pd.to_datetime(ds[TIME_NAME].values).tz_localize("UTC", nonexistent='shift_forward', ambiguous='NaT')

        # Velocity fields, forced into (time, lat, lon) order
        u_name, v_name, vel_kind = choose_uv(ds)
        u = _ensure_tyx(ds[u_name], lat_name, lon_name)
        v = _ensure_tyx(ds[v_name], lat_name, lon_name)

        # Lat/lon grid as (Y, X) arrays: meshgrid 1-D axes, else use as given
        if lat1d.ndim == 1 and lon1d.ndim == 1:
            lon2d, lat2d = np.meshgrid(lon1d, lat1d)
        else:
            lat2d, lon2d = lat1d, lon1d

        # The grid does not change between time slices: flatten it and
        # compute its validity mask once per file.
        lat_flat = lat2d.ravel()
        lon_flat = lon2d.ravel()
        coords_ok = np.isfinite(lat_flat) & np.isfinite(lon_flat)

        # One iteration per time step (usually one per file). If a later slice
        # fails, slices already written for this file remain in the dataset.
        for ti, tval in enumerate(times):
            u2d = np.asarray(u.isel({TIME_NAME: ti}).values)
            v2d = np.asarray(v.isel({TIME_NAME: ti}).values)

            # EKE = 0.5*(u'^2 + v'^2); uses anomalies when choose_uv found them
            eke_flat = (0.5 * (u2d**2 + v2d**2)).ravel()

            # Keep only points where coordinates and EKE are all finite
            keep = coords_ok & np.isfinite(eke_flat)
            latf, lonf, ekef = lat_flat[keep], lon_flat[keep], eke_flat[keep]
            n = latf.size
            if n == 0:
                continue

            # Optional random downsample (shared, seeded rng)
            if FRACTION < 1.0:
                k = max(1, int(np.ceil(n * FRACTION)))
                idx = rng.choice(n, size=k, replace=False)
                latf, lonf, ekef = latf[idx], lonf[idx], ekef[idx]
                n = k

            # NOTE: despite its name, this cap is applied per time slice.
            if (MAX_POINTS_PER_FILE is not None) and (n > MAX_POINTS_PER_FILE):
                idx = rng.choice(n, size=MAX_POINTS_PER_FILE, replace=False)
                latf, lonf, ekef = latf[idx], lonf[idx], ekef[idx]
                n = MAX_POINTS_PER_FILE

            ts_iso = pd.to_datetime(tval).tz_convert("UTC").isoformat()
            _write_rows(latf, lonf, ekef, ts_iso)

        processed += 1

    except Exception as e:
        skipped.append((src.name, f"process failed: {e}"))

    finally:
        # Close before removing the temp dir: an open handle can block deletion.
        try:
            ds.close()
        except Exception as err:
            tqdm.write(f"⚠️ could not close dataset for {src.name}: {err}")
        _cleanup_tmpdir(tmpdir)

# Final report: output location, success count, then per-file skip reasons.
print(
    f"\n✅ Parquet dataset written to: {OUT_PARQ}",
    f"   Files processed: {processed}/{len(files)}",
    sep="\n",
)
if skipped:
    print("\n⚠️ Skipped:")
    print(*(f"  - {name}: {reason}" for name, reason in skipped), sep="\n")


📂 Input:  C:\Users\Crist\Desktop\NASA\tag-and-satellite-data-model\downloads\eke\sample
🧺 Output Parquet (partitioned): C:\Users\Crist\Desktop\NASA\tag-and-satellite-data-model\transform\eke\parquet


Files:   0%|          | 0/30 [00:00<?, ?file/s]


✅ Parquet dataset written to: C:\Users\Crist\Desktop\NASA\tag-and-satellite-data-model\transform\eke\parquet
   Files processed: 30/30
