# Eddy Tracking Notebook — Inertial Cost with Gap Tolerance

**Project:** Eddy Detection in the Gulf of California (2010–2024)  
**Purpose:** Link daily eddy detections into tracks using an inertial cost (distance to the predicted position, relative diameter change, relative vorticity change, and eddy type), tolerating gaps of up to *G* days (`ALLOWED_GAP_DAYS`).  
**Inputs:** `eddy_detections_hybrid_GoC_2010_2024.nc` (or your detection output)  
**Outputs:** Printed summaries and in-memory tracks (which can optionally be saved later as CSV/NetCDF).  
**Notes:** Clean English-only version of the original tracking script; logic preserved.


In [None]:
import xarray as xr
import numpy as np
from math import radians
from pprint import pprint
from dataclasses import dataclass
from typing import List, Dict, Tuple

# --- Paths and global parameters
# These constants are the default arguments of matching_cost_inertia() and
# track_eddies(); MIN_TRACK_LENGTH is applied when filtering the final tracks.
PATH_EDDIES       = "eddy_detections_hybrid_GoC_2010_2024.nc"  # detection NetCDF
DIST_MAX_KM       = 50.0   # distance scale (km): the distance term of the cost is dist / DIST_MAX_KM (a scale, not a hard cutoff)
COST_THRESHOLD    = 1.2    # a match is accepted only if its total cost is strictly below this value
ALLOWED_GAP_DAYS  = 10     # max days between a track's last detection and a new detection for the track to remain a candidate
MIN_TRACK_LENGTH  = 14     # keep tracks with >= this many detections


In [None]:
def haversine(lon1, lat1, lon2, lat2) -> float:
    """
    Great-circle distance in km between (lon1, lat1) and (lon2, lat2).

    Parameters
    ----------
    lon1, lat1, lon2, lat2 : float or array-like
        Coordinates in decimal degrees. Scalars and NumPy arrays may be
        mixed freely; arrays are combined with the usual broadcasting rules.

    Returns
    -------
    float or numpy.ndarray
        Distance(s) in km on a sphere of radius 6371 km. A scalar is returned
        for scalar inputs, an array for array inputs. Wherever any of the
        inputs is NaN the corresponding result is NaN.
    """
    # No explicit NaN test is needed: every ufunc below propagates NaN, which
    # also lets the function work element-wise on arrays.
    lon1, lat1, lon2, lat2 = (np.radians(v) for v in (lon1, lat1, lon2, lat2))
    dlon = lon2 - lon1
    dlat = lat2 - lat1
    a = np.sin(dlat / 2) ** 2 + np.cos(lat1) * np.cos(lat2) * np.sin(dlon / 2) ** 2
    # Round-off can push `a` marginally outside [0, 1], which would make
    # sqrt/arcsin return NaN for (nearly) identical or antipodal points.
    a = np.clip(a, 0.0, 1.0)
    c = 2 * np.arcsin(np.sqrt(a))
    return 6371.0 * c


def update_velocity(eddy_prev, eddy_current, gap=None) -> None:
    """
    Store on `eddy_current.velocity` the mean drift (vx, vy) in km/day
    between `eddy_prev` and `eddy_current`.

    The elapsed time is `gap` (truncated to an integer number of days) when
    given; otherwise it is the whole-day difference of the two `day` stamps.
    A non-positive elapsed time yields a velocity of (0.0, 0.0).
    Displacements use a flat-earth approximation (111 km per degree, with the
    zonal part scaled by the cosine of the mean latitude).
    """
    prev_lat, prev_lon = eddy_prev.centroid
    cur_lat, cur_lon = eddy_current.centroid

    if gap is None:
        elapsed = eddy_current.day - eddy_prev.day
        n_days = int(elapsed / np.timedelta64(1, "D"))
    else:
        n_days = int(gap)

    if n_days > 0:
        mid_lat = 0.5 * (prev_lat + cur_lat)
        east_km = (cur_lon - prev_lon) * 111.0 * np.cos(np.radians(mid_lat))
        north_km = (cur_lat - prev_lat) * 111.0
        eddy_current.velocity = (east_km / n_days, north_km / n_days)
    else:
        # No time elapsed (or inverted order): no meaningful velocity.
        eddy_current.velocity = (0.0, 0.0)


def predict_position(eddy) -> Tuple[float, float]:
    """
    Extrapolate the eddy one day ahead and return the new (lat, lon).

    The velocity (km/day) is turned into a displacement in degrees with
    111 km per degree of latitude and 111*cos(lat) km per degree of longitude,
    evaluated at the eddy's current latitude.
    """
    east_km_per_day, north_km_per_day = eddy.velocity
    lat, lon = eddy.centroid

    cos_lat = np.cos(np.radians(lat))
    if cos_lat != 0:
        km_per_deg_lon = 111.0 * cos_lat
    else:
        # Degenerate metric: an infinite divisor gives zero zonal shift.
        km_per_deg_lon = np.inf

    new_lat = lat + north_km_per_day / 111.0
    new_lon = lon + east_km_per_day / km_per_deg_lon
    return (new_lat, new_lon)


def matching_cost_inertia(eddy_prev, eddy_current,
                          dist_max=DIST_MAX_KM, alpha=1.0, beta=0.3, gamma=1.0, delta=100.0) -> float:
    """
    Cost of linking `eddy_current` to `eddy_prev` (lower is better).

    The total is the sum of four terms:
      - distance: great-circle distance from the inertial prediction of
        `eddy_prev` to `eddy_current`, divided by `dist_max`, times `alpha`;
      - diameter: relative diameter change times `beta`, with the part above
        0.6 counted twice (a relative change of 1.0 is assumed when either
        diameter is not positive);
      - vorticity: relative vorticity change times `gamma`;
      - type: `delta` when the eddy types differ, otherwise 0.

    A relative vorticity change above 1.0 is a hard reject: the function
    returns 1e6 immediately, whatever the other terms are.
    """
    # --- distance between the predicted and the observed centroid
    pred_lat, pred_lon = predict_position(eddy_prev)
    obs_lat, obs_lon = eddy_current.centroid
    separation_km = haversine(pred_lon, pred_lat, obs_lon, obs_lat)
    distance_term = (separation_km / dist_max) * alpha

    # --- relative diameter change
    diam_a, diam_b = eddy_prev.diameter, eddy_current.diameter
    if diam_a > 0 and diam_b > 0:
        rel_diam = abs(diam_a - diam_b) / np.mean([diam_a, diam_b])
    else:
        rel_diam = 1.0
    excess_diam = max(0, rel_diam - 0.6)  # surcharge on changes beyond 60%
    diameter_term = beta * (rel_diam + excess_diam)

    # --- relative vorticity change
    vort_a, vort_b = eddy_prev.vorticity, eddy_current.vorticity
    if abs(vort_a) + abs(vort_b) > 0:
        vort_scale = np.mean([abs(vort_a), abs(vort_b)])
    else:
        vort_scale = 1e-9  # avoids dividing by zero when both are exactly 0
    rel_vort = abs(vort_a - vort_b) / vort_scale
    if rel_vort > 1.0:
        return 1e6  # hard reject
    vorticity_term = gamma * rel_vort

    # --- eddy type agreement
    if eddy_prev.eddy_type == eddy_current.eddy_type:
        type_term = 0.0
    else:
        type_term = delta

    return distance_term + diameter_term + vorticity_term + type_term


In [None]:
class EddyState:
    """
    One eddy observation on one day, as consumed by the tracker.

    Attributes
    ----------
    day : numpy.datetime64
        Time stamp of the detection.
    centroid : tuple
        (lat, lon) of the eddy centre.
    diameter : float
        Diameter in km.
    major_axis, minor_axis : float
        Axis lengths in km.
    eccentricity : float
    vorticity : float
    eddy_type : str
        "Cyclonic" or "Anticyclonic".
    velocity : tuple
        (vx, vy) drift in km/day; defaults to (0.0, 0.0).
    track_id : int or None
        Identifier of the owning track; None until assigned.
    """
    __slots__ = (
        "day",
        "centroid",
        "diameter",
        "major_axis",
        "minor_axis",
        "eccentricity",
        "vorticity",
        "eddy_type",
        "velocity",
        "track_id",
    )

    def __init__(self, day, centroid, diameter, major_axis, minor_axis,
                 eccentricity, vorticity, eddy_type,
                 velocity=(0.0, 0.0), track_id=None):
        # When and where
        self.day = day
        self.centroid = centroid
        # Geometry
        self.diameter = diameter
        self.major_axis = major_axis
        self.minor_axis = minor_axis
        self.eccentricity = eccentricity
        # Dynamics
        self.vorticity = vorticity
        self.eddy_type = eddy_type
        # Tracking state
        self.velocity = velocity
        self.track_id = track_id

    def __repr__(self):
        lat, lon = self.centroid
        # The text is assembled from one fragment per group of fields.
        pieces = (
            f"EddyState(day={self.day.astype('datetime64[D]')}, ",
            f"centroid=({lat:.2f},{lon:.2f}), diam={self.diameter:.1f}km, ",
            f"major={self.major_axis:.1f}km, minor={self.minor_axis:.1f}km, ",
            f"ecc={self.eccentricity:.2f}, vort={self.vorticity:.2e}, ",
            f"type={self.eddy_type}, vel=({self.velocity[0]:.1f},{self.velocity[1]:.1f}), ",
            f"track={self.track_id})",
        )
        return "".join(pieces)


In [None]:
def load_dataset(path: str) -> xr.Dataset:
    """Open the detection-output NetCDF file at `path` and return it as a dataset."""
    dataset = xr.open_dataset(path)
    return dataset


def build_detections_dict(ds: "xr.Dataset") -> Dict[np.datetime64, List[dict]]:
    """
    Rebuild the per-day detection lists from the gridded detection dataset.

    Parameters
    ----------
    ds : xr.Dataset
        Must provide `time` plus the variables `centroid_lat`, `centroid_lon`,
        `diameter_km`, `major_axis_km`, `minor_axis_km`, `vorticity`, `type`
        and either `eccentricity` (EN) or `excentricidad` (ES). The variables
        are indexed as [time, eddy slot].
        NOTE(review): the (time, eddy) dimension order is assumed, as in the
        positional indexing this function has always used - confirm against
        the detection writer.

    Returns
    -------
    dict
        Maps every time stamp to a list of detection dicts with the keys
        "time", "centroid" (lat, lon), "diameter", "major_axis", "minor_axis",
        "eccentricity", "vorticity" and "type". Slots whose centroid latitude
        or longitude is not finite are skipped; a day with no valid slot maps
        to an empty list.

    Raises
    ------
    KeyError
        If neither eccentricity variable is present.
    """
    # Handle both variable names for eccentricity
    if "eccentricity" in ds.variables:
        ecc_name = "eccentricity"
    elif "excentricidad" in ds.variables:
        ecc_name = "excentricidad"
    else:
        raise KeyError("Eccentricity variable not found (expected 'eccentricity' or 'excentricidad').")

    # Pull every array out of the dataset once instead of once per element.
    times = ds["time"].values
    lat = ds["centroid_lat"].values
    lon = ds["centroid_lon"].values
    diameter = ds["diameter_km"].values
    major = ds["major_axis_km"].values
    minor = ds["minor_axis_km"].values
    ecc = ds[ecc_name].values
    vort = ds["vorticity"].values
    kind = ds["type"].values

    # A slot holds a detection only when both centroid coordinates are finite.
    occupied = np.isfinite(lat) & np.isfinite(lon)

    dets: Dict[np.datetime64, List[dict]] = {}
    for i, t in enumerate(times):
        # Iterate over slot *positions*, not over the values of the `eddy`
        # coordinate, so any slot labelling (e.g. 1-based ids) is handled.
        dets[np.datetime64(t)] = [
            {
                "time": t,
                "centroid": (float(lat[i, j]), float(lon[i, j])),   # (lat, lon)
                "diameter": float(diameter[i, j]),
                "major_axis": float(major[i, j]),
                "minor_axis": float(minor[i, j]),
                "eccentricity": float(ecc[i, j]),
                "vorticity": float(vort[i, j]),
                "type": str(kind[i, j]),
            }
            for j in np.flatnonzero(occupied[i])
        ]
    return dets


def convert_to_states(dets: Dict[np.datetime64, List[dict]]) -> Dict[np.datetime64, List[EddyState]]:
    """
    Turn every detection dict into an EddyState, keeping the per-day grouping.

    The returned dict has the same keys as `dets`; each list keeps the order
    of the detections it was built from. Velocity and track id are left at
    the EddyState defaults.
    """
    states: Dict[np.datetime64, List[EddyState]] = {}
    for day, day_detections in dets.items():
        day_states = []
        for det in day_detections:
            state = EddyState(
                day=det["time"],
                centroid=det["centroid"],
                diameter=det["diameter"],
                major_axis=det["major_axis"],
                minor_axis=det["minor_axis"],
                eccentricity=det["eccentricity"],
                vorticity=det["vorticity"],
                eddy_type=det["type"],
            )
            day_states.append(state)
        states[day] = day_states
    return states


In [None]:
def track_eddies(states_dict: Dict[np.datetime64, List[EddyState]],
                 dist_max=DIST_MAX_KM,
                 cost_threshold=COST_THRESHOLD,
                 allowed_gap=ALLOWED_GAP_DAYS) -> List[List[EddyState]]:
    """
    Build tracks by greedy inertial-cost matching with gap tolerance.

    Days are processed in chronological order. Each detection is compared
    with the last state of every open track (a track whose last detection is
    at most `allowed_gap` days old) and joins the cheapest one if that cost
    is strictly below `cost_threshold`; otherwise it starts a new track.
    Costs come from matching_cost_inertia() and are multiplied by
    1 + 0.1 * (gap - 1) for gaps longer than one day. On ties the track
    created first wins.

    A track can receive at most one detection per time step: a track that was
    already extended (or created) at the current time stamp is not a
    candidate for the remaining detections of that time stamp.

    Parameters
    ----------
    states_dict : dict
        Maps each day to the EddyState objects detected on it. The states are
        modified in place (`track_id`, and `velocity` for matched states).
    dist_max : float
        Distance scale in km passed on to matching_cost_inertia().
    cost_threshold : float
        Upper bound (exclusive) on the cost of an accepted match.
    allowed_gap : int
        Maximum number of days between two consecutive states of a track.

    Returns
    -------
    list of list of EddyState
        One chronologically ordered list per track, in order of creation.

    Notes
    -----
    NOTE(review): matching_cost_inertia() measures the distance to the
    position predicted one day ahead, whatever the gap is; bridging longer
    gaps therefore relies on the 10%-per-day cost penalty only.
    """
    one_day = np.timedelta64(1, "D")
    tracks: List[List[EddyState]] = []
    open_tracks: List[List[EddyState]] = []  # subset of `tracks` still matchable
    next_id = 0

    for day in sorted(states_dict):
        # Days only increase, so a track that is already past the gap limit
        # can never be matched again: drop it from the search list once per
        # day instead of re-testing every track for every detection.
        open_tracks = [tr for tr in open_tracks
                       if int((day - tr[-1].day) / one_day) <= allowed_gap]

        for ed_c in states_dict[day]:
            best_track, best_gap, best_cost = None, None, np.inf
            for tr in open_tracks:
                last = tr[-1]
                if last.day >= day:
                    # Already extended (or started) at this time step;
                    # matching again would put two same-day detections into
                    # one track.
                    continue
                gap = int((day - last.day) / one_day)
                cost = matching_cost_inertia(last, ed_c, dist_max)
                if gap > 1:
                    cost *= (1 + 0.1 * (gap - 1))  # penalize gaps > 1 day
                if cost < best_cost:
                    best_track, best_gap, best_cost = tr, gap, cost

            # accept if under threshold; otherwise start a new track
            if best_track is not None and best_cost < cost_threshold:
                ed_c.track_id = best_track[0].track_id
                update_velocity(best_track[-1], ed_c, best_gap)
                best_track.append(ed_c)
            else:
                ed_c.track_id = next_id
                new_track = [ed_c]
                tracks.append(new_track)
                open_tracks.append(new_track)
                next_id += 1

    return tracks


In [None]:
# --- Driver cell: load detections, build tracks, report.
# The names assigned here (ds, dets, states, tracks, filtered_tracks, lengths)
# stay available as module-level state after the cell has run.
ds = load_dataset(PATH_EDDIES)
print(f"Dataset loaded: {len(ds.time)} times, {len(ds.eddy)} eddy slots.")

# Per-day detection dicts -> per-day EddyState objects
dets   = build_detections_dict(ds)
states = convert_to_states(dets)

# Link the daily states into tracks using the default tracking parameters
tracks = track_eddies(states)
print(f"Total tracks (pre-filter): {len(tracks)}")

# Filter by minimum detections per track
filtered_tracks = [trk for trk in tracks if len(trk) >= MIN_TRACK_LENGTH]
print(f"Tracks with ≥ {MIN_TRACK_LENGTH} detections: {len(filtered_tracks)}")

# Quick summary
# NOTE(review): the mean is taken over all tracks, not over filtered_tracks,
# and is NaN (with a NumPy warning) when no track exists - confirm intent.
lengths = [len(tr) for tr in tracks]
print(f"Mean detections per track: {np.mean(lengths):.1f}")

# Print first 3 tracks (compact)
# NOTE(review): these are the first three unfiltered tracks.
for i, tr in enumerate(tracks[:3]):
    print(f"\nTrack {i} ({len(tr)} detections):")
    pprint(tr)
