<a href="https://colab.research.google.com/github/SullyC25/Week4/blob/main/Collaction_J3%2BSRAL.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# 0) Colab Drive + deps
# Mount Google Drive at /content/drive; force_remount=False leaves an existing mount untouched.
from google.colab import drive; drive.mount('/content/drive', force_remount=False)
# Install the notebook's dependencies quietly into the kernel environment (versions are not pinned).
%pip -q install numpy pandas scipy xarray netCDF4 h5netcdf pyarrow fastparquet

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
# 1) Paths & run params
import os, json, datetime, importlib, warnings
import numpy as np, pandas as pd, xarray as xr
from pathlib import Path
from scipy.spatial import cKDTree

# Hide FutureWarning messages for the rest of the session.
warnings.filterwarnings("ignore", category=FutureWarning)

# --- Folder layout on the mounted Drive ---------------------------------
BASE = "/content/drive/MyDrive"
PREP = BASE + "/Preprocessing"      # inputs written by the preprocessing step
OUT  = BASE + "/CollocationV3"      # everything this notebook writes
os.makedirs(OUT, exist_ok=True)

# --- Input files --------------------------------------------------------
S3_PARQ = PREP + "/s3_sral_bangladesh_20hz.parquet"
J3_NC   = PREP + "/processed_sla_2024_bangladesh.nc"  # name must match the preprocessing cell

# --- Study box (same as preprocessing) ----------------------------------
# Order: lon_min, lon_max, lat_min, lat_max; longitudes use the 0–360 convention.
BBOX = (88.0, 93.0, 20.0, 26.0)

# --- Gentle QC settings (only applied to the "qc" outputs) --------------
ABS_SLA_CAP = 3.0     # keep |SLA| up to 3 m on both sides
MAD_CAP_M   = 1.5     # upper bound (m) on the robust ΔSLA threshold

# --- Collocation windows, one dict per run ------------------------------
RUNS = [
    dict(tag="baseline", dist_km=10.0, dt_hours=24.0),
    dict(tag="relaxed",  dist_km=20.0, dt_hours=36.0),
]

In [None]:
# 2) Utilities
# Import the submodule explicitly: a bare `import importlib` does not guarantee that
# `importlib.util` has been loaded, so `importlib.util.find_spec` could raise
# AttributeError when this cell runs outside an environment that preloads it.
import importlib.util

# NetCDF backend handed to xarray: "netcdf4" when the netCDF4 package can be found,
# otherwise "h5netcdf".
ENGINE = "netcdf4" if importlib.util.find_spec("netCDF4") is not None else "h5netcdf"

def to_0360(lon):
    """Wrap longitudes into the [0, 360) convention.

    Parameters
    ----------
    lon : array-like
        Longitudes in degrees.

    Returns
    -------
    numpy.ndarray
        Finite entries are replaced by ``lon mod 360``; non-finite entries
        (NaN, ±inf) are passed through unchanged.
    """
    values = np.asarray(lon)
    wrapped = np.mod(values, 360.0)
    # Only substitute the wrapped value where the input is a real number.
    return np.where(np.isfinite(values), wrapped, values)

def pick_name(ds, options):
    """Return the variable/coordinate name of ``ds`` that best matches ``options``.

    Matching runs in three passes, strictest first, so a loose hit can never
    shadow a better one (e.g. ``lat_bnds`` must not win over ``LAT``):

    1. exact name,
    2. case-insensitive exact name,
    3. case-insensitive substring (the option occurs inside the name).

    Within each pass, earlier entries of ``options`` take priority.

    Parameters
    ----------
    ds : object
        Anything exposing iterable ``variables`` and ``coords`` whose items
        are the names (an xarray Dataset is the caller's use case).
    options : sequence of str
        Candidate names, most preferred first.

    Returns
    -------
    str or None
        The matching name as spelled in ``ds``, or ``None`` if nothing matches.
    """
    names = list(ds.variables) + list(ds.coords)

    # Pass 1: exact spelling.
    known = set(names)
    for want in options:
        if want in known:
            return want

    # Lower-cased lookup; on a case collision the later name wins, as before.
    lowered = {name.lower(): name for name in names}

    # Pass 2: same name, different case.
    for want in options:
        hit = lowered.get(want.lower())
        if hit is not None:
            return hit

    # Pass 3: fuzzy fallback — option appears somewhere inside the name.
    for want in options:
        needle = want.lower()
        for low_name, original in lowered.items():
            if needle in low_name:
                return original
    return None

def haversine_km(lon1, lat1, lon2, lat2):
    """Great-circle distance in kilometres between two points (or arrays of points).

    Inputs are longitudes/latitudes in degrees and may be scalars or
    broadcastable arrays. A spherical Earth of radius 6371.0 km is used.
    """
    lam1, phi1, lam2, phi2 = (np.radians(v) for v in (lon1, lat1, lon2, lat2))
    d_lam = lam2 - lam1
    d_phi = phi2 - phi1
    hav = np.sin(d_phi/2)**2 + np.cos(phi1)*np.cos(phi2)*np.sin(d_lam/2)**2
    # Clamp sqrt(hav) at 1 so rounding can never push arcsin out of its domain.
    central_half_angle = np.arcsin(np.minimum(1, np.sqrt(hav)))
    return 2*6371.0*central_half_angle

def robust_delta_clip(df, col_a="sla_s3", col_b="sla_j3", mad_cap=MAD_CAP_M):
    """Drop rows whose difference ``df[col_a] - df[col_b]`` is an outlier.

    The threshold is ``min(mad_cap, 3 * 1.4826 * MAD)`` where MAD is the median
    absolute deviation of the finite differences. Rows are kept when
    ``|difference| <= threshold``; note the test is on the difference itself,
    not on its deviation from the median.

    Parameters
    ----------
    df : pandas.DataFrame
        Table containing ``col_a`` and ``col_b``.
    col_a, col_b : str
        Columns to difference.
    mad_cap : float
        Upper bound on the threshold, in the units of the columns.

    Returns
    -------
    pandas.DataFrame
        ``df`` itself when it is empty, otherwise a filtered copy. Rows whose
        difference is NaN/inf are always dropped.
    """
    delta = (df[col_a] - df[col_b]).astype("float64")
    if len(delta) == 0:
        return df

    # Estimate the spread from finite differences only: a single NaN would
    # otherwise turn the median (and hence the MAD) into NaN and silently
    # replace the threshold by the fallback value.
    finite = delta.to_numpy()
    finite = finite[np.isfinite(finite)]
    if finite.size == 0:
        return df.iloc[0:0].copy()

    mad = float(np.median(np.abs(finite - np.median(finite))))
    if mad > 0:
        thr = min(mad_cap, 3.0 * 1.4826 * mad)
    else:
        # Degenerate spread: fall back to 1.0 but never exceed the caller's cap.
        thr = min(mad_cap, 1.0)

    # NaN comparisons are False, so rows with a missing difference are removed.
    return df[np.abs(delta) <= thr].copy()

In [None]:
# 3) Load SRAL (20 Hz parquet)
# Fail early with a clear message if the SRAL parquet named by S3_PARQ is missing.
assert os.path.exists(S3_PARQ), f"Missing SRAL parquet: {S3_PARQ}"
# Read the SRAL table; the rest of this cell renames and filters this copy.
s3 = pd.read_parquet(S3_PARQ).copy()

# Standardize columns (accepts variations)
def pick_col(df, wants):
    """Return the first column of ``df`` that matches one of ``wants``.

    An exact match on any candidate is preferred; failing that, candidates are
    compared case-insensitively. Raises ``KeyError`` when nothing matches.
    """
    exact_hits = [name for name in wants if name in df.columns]
    if exact_hits:
        return exact_hits[0]

    # Case-insensitive fallback: lower-cased column name -> original spelling.
    by_lower = {col.lower(): col for col in df.columns}
    for name in wants:
        hit = by_lower.get(name.lower())
        if hit is not None:
            return hit

    raise KeyError(f"None of {wants} in {df.columns[:20].tolist()}…")

# Resolve whichever spelling the parquet uses for each required field.
t_s3 = pick_col(s3, ["time","utc_time","datetime"])
x_s3 = pick_col(s3, ["lon","longitude"])
y_s3 = pick_col(s3, ["lat","latitude"])
z_s3 = pick_col(s3, ["sla","ssha","sea_level_anomaly"])

# Canonical names, parsed timestamps (unparseable -> NaT) and 0–360 longitudes.
s3 = s3.rename(columns={t_s3: "time", x_s3: "lon", y_s3: "lat", z_s3: "sla"})
s3 = s3.assign(
    time=pd.to_datetime(s3["time"], errors="coerce"),
    lon=to_0360(s3["lon"].to_numpy()),
)

# Discard rows missing any of the four required fields.
s3 = s3.dropna(subset=["time", "lon", "lat", "sla"])

# Keep within ROI but don't over-trim (box only); bounds are inclusive.
_in_lon = s3["lon"].between(BBOX[0], BBOX[1])
_in_lat = s3["lat"].between(BBOX[2], BBOX[3])
s3 = s3.loc[_in_lon & _in_lat].reset_index(drop=True)


In [None]:
# 4) Load Jason‑3 DUACS L4 (daily gridded)
assert os.path.exists(J3_NC), f"Missing Jason-3 NetCDF: {J3_NC}"
try:
    ds = xr.open_dataset(J3_NC, engine=ENGINE, decode_times=True, mask_and_scale=True)
except Exception as exc:
    # Best-effort fallback, but say so: with decode_times=False the time axis
    # holds raw numbers, and the datetime matching further down would then be
    # meaningless without the user noticing.
    warnings.warn(
        f"Could not open {J3_NC} with decoded times ({exc!r}); "
        "retrying with decode_times=False. Check the time axis before trusting the collocation.",
        RuntimeWarning,
    )
    ds = xr.open_dataset(J3_NC, engine=ENGINE, decode_times=False, mask_and_scale=True)

# Identify coordinate names
lon_name = pick_name(ds, ["longitude","lon"])
lat_name = pick_name(ds, ["latitude","lat"])
tim_name = pick_name(ds, ["time","utc_time"])
assert all([lon_name, lat_name, tim_name]), "Could not find lon/lat/time in Jason-3 dataset."

# Ensure 0–360 longitude
if float(ds[lon_name].max()) <= 180.0:
    # shift from [-180,180] to [0,360)
    lon_vals = ds[lon_name].values
    lon_vals = np.mod(lon_vals, 360.0)
    ds = ds.assign_coords({lon_name: lon_vals})

# Sort BEFORE slicing: label-based slice selection needs monotonically
# increasing axes. Wrapping longitudes can leave the axis non-monotonic, and a
# descending latitude axis would make slice(lat_min, lat_max) select nothing.
ds = ds.sortby([tim_name, lat_name, lon_name])

# Spatial subsetting to ROI for speed
ds = ds.sel({lat_name: slice(BBOX[2], BBOX[3]),
             lon_name: slice(BBOX[0], BBOX[1])})
assert ds[lat_name].size > 0 and ds[lon_name].size > 0, \
       f"No Jason-3 grid nodes inside BBOX={BBOX}"

# Collect ALL numeric data variables (not just SLA)
j3_data_vars = [v for v in ds.data_vars if np.issubdtype(ds[v].dtype, np.number)]
assert "sla" in j3_data_vars or any("sla" in v.lower() or "ssha" in v.lower() for v in j3_data_vars), \
       f"Expected SLA-like variable among {j3_data_vars[:10]}"

# Cache coord arrays (all sorted ascending by the sortby above)
j3_times = pd.to_datetime(ds[tim_name].values)
j3_lons  = ds[lon_name].values
j3_lats  = ds[lat_name].values

# Build a single KD-tree over the grid-cell centres for fast nearest-neighbour lookup.
# The grid does not change with time, so the same tree serves every time step.
def build_spatial_index_for_time(t_idx):
    """Build a KD-tree over the Jason-3 grid nodes held in ``j3_lons`` / ``j3_lats``.

    ``t_idx`` is accepted but not used: the tree depends only on the grid.

    Returns
    -------
    tree : scipy.spatial.cKDTree
        Tree over the points ``(lon * scale, lat)``, flattened row-major from a
        ``(nlat, nlon)`` mesh, so flat index = ilat * nlon + ilon.
    scale : float
        ``cos`` of the mean grid latitude; query longitudes must be multiplied
        by the same factor.
    shape : tuple
        ``(nlat, nlon)`` of the mesh, for unravelling flat indices.
    """
    lon_grid, lat_grid = np.meshgrid(j3_lons, j3_lats)  # both shaped (nlat, nlon)
    # Shrink longitude by cos(mean latitude) so planar distances in the tree
    # approximate ground distances in both directions.
    lon_scale = np.cos(np.deg2rad(np.nanmean(lat_grid)))
    nodes = np.column_stack((lon_grid.ravel() * lon_scale, lat_grid.ravel()))
    return cKDTree(nodes), lon_scale, lat_grid.shape

# Build once (grid is static over time)
SP_TREE, SP_SCALE, LAT_SHAPE = build_spatial_index_for_time(0)

def nearest_time_index(t):
    """Return the integer position in ``j3_times`` of the time step closest to ``t``.

    When two steps are equally close, the one with the lower index is returned.
    """
    offsets = j3_times.values - np.datetime64(t)
    return int(np.abs(offsets).argmin())

def collocate_once(run, apply_qc=False):
    """Pair each SRAL sample with one Jason-3 grid value and return the accepted pairs.

    Every row of the module-level ``s3`` table is matched to the grid node
    nearest in space (via ``SP_TREE``) and to the time step nearest in time
    (via ``nearest_time_index``). The windows in ``run`` do not widen the
    search; they only reject samples whose single nearest match is too far
    away in distance or time.

    Parameters
    ----------
    run : mapping
        Must provide ``"dist_km"`` (maximum great-circle distance to the grid
        node) and ``"dt_hours"`` (maximum absolute time offset).
    apply_qc : bool
        False: 'min' path, geometry/time windows only.
        True: 'qc' path, additionally keeps only rows where both SLA values
        are finite and within ±ABS_SLA_CAP, then applies ``robust_delta_clip``.
        Both QC steps are skipped when no ``sla_j3`` column was produced.

    Returns
    -------
    pandas.DataFrame
        Columns ``time_s3, lon, lat, sla_s3, j3_time, j3_lon, j3_lat, dist_km,
        dt_hours`` plus ``<name>_j3`` for every variable in ``j3_data_vars``
        that has time, lat and lon dimensions. When no sample passes the
        windows, an empty DataFrame with no columns is returned.
    """
    dist_km = float(run["dist_km"])
    dt_h    = float(run["dt_hours"])
    dt_tol  = pd.Timedelta(hours=dt_h)

    # Map each SRAL point to its nearest J3 time index (one Python call per sample)
    t_idx = np.array([nearest_time_index(t) for t in s3["time"].to_numpy()])
    t_near = j3_times.values[t_idx]
    dt = (s3["time"].to_numpy().astype('datetime64[ns]') - t_near.astype('datetime64[ns]'))
    # Absolute offset; the round trip through milliseconds truncates sub-ms digits
    dt = np.abs(dt).astype('timedelta64[ns]').astype('timedelta64[ms]').astype('timedelta64[ns]')
    within_time = (np.abs(pd.to_timedelta(dt)) <= dt_tol)

    # Spatial nearest grid cell (same for all times since grid is static)
    # Query points use the same longitude scaling as the tree was built with
    q_pts = np.c_[ s3["lon"].to_numpy()*SP_SCALE, s3["lat"].to_numpy() ]
    _, ij_flat = SP_TREE.query(q_pts, k=1)
    # Unravel the row-major flat index of the (nlat, nlon) mesh
    nlat, nlon = LAT_SHAPE
    ilat = ij_flat // nlon
    ilon = ij_flat %  nlon

    # Compute great-circle distance to the cell center (more accurate than KD plane)
    lon_cell = j3_lons[ilon]
    lat_cell = j3_lats[ilat]
    dist = haversine_km(s3["lon"].to_numpy(), s3["lat"].to_numpy(), lon_cell, lat_cell)
    within_dist = (dist <= dist_km)

    keep = within_time & within_dist
    if not np.any(keep):
        # Nothing matched: callers get a column-less frame and must check for "sla_j3"
        return pd.DataFrame(columns=[])

    # Gather Jason-3 fields at (t_idx, ilat, ilon) for the accepted rows only

    take_rows = np.where(keep)[0]
    ti = t_idx[take_rows]
    yi = ilat[take_rows]
    xi = ilon[take_rows]

    # Create output with SRAL columns
    out = pd.DataFrame({
        "time_s3": s3["time"].values[take_rows],
        "lon":     s3["lon"].values[take_rows].astype("float32"),
        "lat":     s3["lat"].values[take_rows].astype("float32"),
        "sla_s3":  s3["sla"].values[take_rows].astype("float32"),
    })
    # Diagnostics: which grid node / time step was matched and how far away it is
    out["j3_time"] = j3_times.values[ti]
    out["j3_lon"]  = j3_lons[xi].astype("float32")
    out["j3_lat"]  = j3_lats[yi].astype("float32")
    out["dist_km"] = dist[take_rows].astype("float32")
    out["dt_hours"]= (np.abs(out["time_s3"] - out["j3_time"]).dt.total_seconds()/3600.0).astype("float32")

    # Attach ALL numeric J3 vars
    for v in j3_data_vars:
        # guard against unexpected dim order; rely on name-based indexing
        da = ds[v]
        # Sanity: ensure dims contain time/lat/lon in some order
        dims = list(da.dims)
        i_t = dims.index(tim_name) if tim_name in dims else None
        i_y = dims.index(lat_name) if lat_name in dims else None
        i_x = dims.index(lon_name) if lon_name in dims else None
        if None in (i_t, i_y, i_x):
            continue  # skip non-(t,y,x) variables
        # Point-wise selection: the three index arrays share the dimension "z",
        # so one value is picked per accepted row rather than an outer product
        vals = da.isel({tim_name: xr.DataArray(ti, dims=("z",)),
                        lat_name: xr.DataArray(yi, dims=("z",)),
                        lon_name: xr.DataArray(xi, dims=("z",))}).values
        # Stored as float32 under "<variable>_j3" regardless of the source dtype
        out[f"{v}_j3"] = vals.astype("float32")

    # Optional gentle QC path
    if apply_qc:
        # SLA cap on both sides if present (also removes non-finite values)
        if "sla_j3" in out.columns:
            mask_amp = (
                np.isfinite(out["sla_s3"]) & np.isfinite(out["sla_j3"]) &
                out["sla_s3"].between(-ABS_SLA_CAP, ABS_SLA_CAP) &
                out["sla_j3"].between(-ABS_SLA_CAP, ABS_SLA_CAP)
            )
            out = out[mask_amp].copy()
        # robust delta clip (S3 − J3)
        if "sla_j3" in out.columns:
            out = robust_delta_clip(out, col_a="sla_s3", col_b="sla_j3", mad_cap=MAD_CAP_M)

    # Final tidy types
    for c in ["lon","lat","j3_lon","j3_lat","dist_km","dt_hours"]:
        if c in out.columns: out[c] = out[c].astype("float32")
    return out.reset_index(drop=True)

In [None]:
# 5) Run collocation
def _pair_stats(pairs):
    """Summarise S3 − J3 agreement for one table of collocated pairs.

    Parameters
    ----------
    pairs : pandas.DataFrame
        Output of ``collocate_once``; needs ``sla_s3`` and ``sla_j3`` columns
        for statistics to be computed.

    Returns
    -------
    (bias, rmse_bf, r) : tuple of float
        Mean difference, RMS of the difference after removing that mean, and
        the Pearson correlation of the two SLA columns. All three use only
        rows where both values are finite; each is NaN when it cannot be
        computed (no ``sla_j3`` column, no finite rows, or fewer than three
        finite rows for the correlation).
    """
    if "sla_j3" not in pairs.columns or len(pairs) == 0:
        return np.nan, np.nan, np.nan

    s3_vals = pairs["sla_s3"].to_numpy(dtype="float64")
    j3_vals = pairs["sla_j3"].to_numpy(dtype="float64")
    # np.corrcoef returns NaN if any input is NaN, so mask explicitly and use
    # the same rows for all three statistics.
    ok = np.isfinite(s3_vals) & np.isfinite(j3_vals)
    if not ok.any():
        return np.nan, np.nan, np.nan

    d = (pairs["sla_s3"] - pairs["sla_j3"]).astype("float64").to_numpy()[ok]
    bias = float(np.mean(d))
    rmse_bf = float(np.sqrt(np.mean((d - bias) ** 2)))
    r = float(np.corrcoef(s3_vals[ok], j3_vals[ok])[0, 1]) if int(ok.sum()) > 2 else np.nan
    return bias, rmse_bf, r


summary_rows = []
for run in RUNS:
    tag = run["tag"]

    # 'min' path: distance/time windows only
    pairs_min = collocate_once(run, apply_qc=False)
    min_path  = f"{OUT}/s3_j3_pairs_{tag}_min.parquet"
    pairs_min.to_parquet(min_path, index=False)
    bias_min, rmse_bf_min, R_min = _pair_stats(pairs_min)

    # 'qc' path: windows plus amplitude cap and robust ΔSLA clip
    pairs_qc = collocate_once(run, apply_qc=True)
    qc_path  = f"{OUT}/s3_j3_pairs_{tag}_qc.parquet"
    pairs_qc.to_parquet(qc_path, index=False)
    bias_qc, rmse_bf_qc, R_qc = _pair_stats(pairs_qc)

    summary_rows.append({
        "tag": tag,
        "N_min": int(len(pairs_min)), "bias_min_m": bias_min, "rmse_bf_min_m": rmse_bf_min, "R_min": R_min,
        "N_qc":  int(len(pairs_qc)),  "bias_qc_m":  bias_qc,  "rmse_bf_qc_m":  rmse_bf_qc,  "R_qc":  R_qc,
        "dist_km": run["dist_km"], "dt_hours": run["dt_hours"]
    })

In [None]:
# 6) Save params + summary
# Run metadata written next to the outputs so a result can be traced to its settings.
params = {
    # datetime.utcnow() is deprecated; take an aware UTC "now" and drop the tzinfo
    # so the stored string keeps the same "<naive ISO>Z" layout as before.
    "generated_utc": datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None).isoformat()+"Z",
    "inputs": {"s3_parquet": S3_PARQ, "j3_netcdf": J3_NC},
    "bbox": BBOX,
    # Describe the clip exactly as robust_delta_clip computes it (3 × 1.4826 × MAD).
    "qc": {"abs_sla_cap_m": ABS_SLA_CAP, "delta_sla_clip": f"±min({MAD_CAP_M} m, 3×1.4826×MAD)"},
    "runs": RUNS,
    "notes": [
        "Longitude convention kept at 0–360 to match Jason-3 preprocessing.",
        "Pairs_min: spatial/time windows only. Pairs_qc: adds gentle amplitude + robust ΔSLA clip.",
        "All numeric Jason-3 variables are attached at matched grid cell and time.",
    ],
}
params_path = f"{OUT}/params_collocation_v3.json"
with open(params_path, "w", encoding="utf-8") as f:
    json.dump(params, f, indent=2)

# One row per run; the explicit column list fixes the column order in the CSV.
summary = pd.DataFrame(summary_rows, columns=[
    "tag","N_min","bias_min_m","rmse_bf_min_m","R_min","N_qc","bias_qc_m","rmse_bf_qc_m","R_qc","dist_km","dt_hours"
])
summary_path = f"{OUT}/summary_collocation_v3.csv"
summary.to_csv(summary_path, index=False)

print("Saved files:")
print("  -", params_path)
for r in RUNS:
    print("  -", f"{OUT}/s3_j3_pairs_{r['tag']}_min.parquet")
    print("  -", f"{OUT}/s3_j3_pairs_{r['tag']}_qc.parquet")
print("  -", summary_path)
summary

Saved files:
  - /content/drive/MyDrive/CollocationV3/params_collocation_v3.json
  - /content/drive/MyDrive/CollocationV3/s3_j3_pairs_baseline_min.parquet
  - /content/drive/MyDrive/CollocationV3/s3_j3_pairs_baseline_qc.parquet
  - /content/drive/MyDrive/CollocationV3/s3_j3_pairs_relaxed_min.parquet
  - /content/drive/MyDrive/CollocationV3/s3_j3_pairs_relaxed_qc.parquet
  - /content/drive/MyDrive/CollocationV3/summary_collocation_v3.csv


Unnamed: 0,tag,N_min,bias_min_m,rmse_bf_min_m,R_min,N_qc,bias_qc_m,rmse_bf_qc_m,R_qc,dist_km,dt_hours
0,baseline,11833,-0.128698,1.030362,0.099556,10880,-0.0213,0.100822,0.754867,10.0,24.0
1,relaxed,11833,-0.128698,1.030362,0.099556,10880,-0.0213,0.100822,0.754867,20.0,36.0
