In [None]:
from __future__ import annotations

# --- Stdlib ---
import contextlib
import gc
import gzip
import json
import logging
import os
import shutil
import time
from pathlib import Path
from typing import Dict, Hashable, Iterable, Literal, Mapping, Optional, Sequence, Tuple

# --- Main ---
import __main__ as nb

# --- Data Science ---
import numpy as np
import pandas as pd
import geopandas as gpd
import h3

# --- Testing ---
import ipytest
import pytest

# --- Geo ---
from shapely.geometry import shape
from shapely.geometry.base import BaseGeometry

# --- Optional Resources ---
try:
    import psutil as _psutil
except Exception:
    _psutil = None

try:
    import tracemalloc as _tracemalloc
except Exception:
    _tracemalloc = None

# --- Config ---
ipytest.autoconfig()

# Notebook-wide logging: timestamped, level-padded records at INFO and above.
logging.basicConfig(
    format="%(asctime)s | %(levelname)-8s | %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

# --- Utilities ---
def ensure_dir(path: str | os.PathLike) -> None:
    os.makedirs(path if isinstance(path, str) else str(path), exist_ok=True)

def parent_dir_of(path: str | os.PathLike) -> str:
    return os.path.dirname(str(path)) or "."

def fmt_bytes(n: int | float) -> str:
    try:
        n = float(n)
    except Exception:
        return str(n)
    for unit in ("B", "KB", "MB", "GB", "TB", "PB"):
        if abs(n) < 1024 or unit == "PB":
            return f"{n:.1f}{unit}"
        n /= 1024

def gunzip_file(src: str, dst: str) -> None:
    """Decompress the gzip file *src* into *dst*, creating dst's parent directory first."""
    ensure_dir(parent_dir_of(dst))
    with gzip.open(src, "rb") as compressed:
        with open(dst, "wb") as plain:
            # Streamed copy: the decompressed payload is never held fully in memory.
            shutil.copyfileobj(compressed, plain)

def client_error_msg(e) -> str:
    """
    Format an exception as ``"[Code] Message"`` when it carries a botocore-style
    ``response={"Error": {"Code": ..., "Message": ...}}`` payload, else ``str(e)``.

    ``e.response`` may be absent, ``None``, or a non-mapping object (e.g. an HTTP
    response object on other libraries' exceptions). Those cases fall back to
    ``str(e)`` instead of raising AttributeError inside the error handler.
    """
    from collections.abc import Mapping as _AbcMapping

    response = getattr(e, "response", None)
    error = response.get("Error") if isinstance(response, _AbcMapping) else None
    if not isinstance(error, _AbcMapping):
        error = {}
    code = error.get("Code")
    msg = error.get("Message")
    return f"[{code}] {msg}" if code or msg else str(e)

@contextlib.contextmanager
def timed(label: str, enabled: bool = True):
    """
    Context manager that logs ``"<label> took N.NNN s"`` at INFO when the block exits.

    Timing is reported even if the block raises; the exception still propagates.
    Pass ``enabled=False`` to suppress the log line.
    """
    started = time.perf_counter()
    try:
        yield
    finally:
        elapsed = time.perf_counter() - started
        if enabled:
            logger.info("%s took %.3f s", label, elapsed)

def resource_snapshot(note: str = "") -> None:
    """
    Log a one-line resource snapshot at DEBUG level (no-op unless DEBUG is enabled).

    Reports, when available:
      * ``rss`` and ``cpu%`` of this process via psutil (if it imported), and
      * current/peak traced Python memory via tracemalloc (only while tracing).

    The psutil ``Process`` handle is cached on the function object. psutil's
    ``cpu_percent(interval=0.0)`` measures CPU use since the previous call on the
    *same* handle, so the old per-call handle always reported 0.0. Now only the
    very first snapshot shows that placeholder 0.0.

    Probes are best-effort: failures are logged at DEBUG and never propagate.
    """
    if not logger.isEnabledFor(logging.DEBUG):
        return
    parts: list[str] = []
    if _psutil:
        try:
            proc = getattr(resource_snapshot, "_proc", None)
            # Rebuild after a fork: a cached handle would still point at the parent pid.
            if proc is None or proc.pid != os.getpid():
                proc = _psutil.Process(os.getpid())
                resource_snapshot._proc = proc  # type: ignore[attr-defined]
            parts.append(f"rss={fmt_bytes(proc.memory_info().rss)}")
            parts.append(f"cpu%~{proc.cpu_percent(interval=0.0):.1f}")
        except Exception as exc:  # diagnostics must never break the caller
            logger.debug("resource_snapshot: psutil probe failed: %s", exc)
    if _tracemalloc and _tracemalloc.is_tracing():
        try:
            cur, peak = _tracemalloc.get_traced_memory()
            parts.append(f"py_mem={fmt_bytes(cur)}/{fmt_bytes(peak)}(peak)")
        except Exception as exc:  # diagnostics must never break the caller
            logger.debug("resource_snapshot: tracemalloc probe failed: %s", exc)
    if parts:
        logger.debug("RES%s %s", f"[{note}]" if note else "", " ".join(parts))

# --- Project defaults ---
DEST_ROOT = "../data"

def resolve_path(rel: str | None, filename: str) -> str:
    if rel:
        return str(Path(rel) / filename)
    return str(Path(DEST_ROOT) / filename)

def load_project_files(
    file_map: dict[str, str],
    project_root_name: str = "ds_code_challenge",
    inject_globals: bool = True,
):
    """
    Load files from ROOT/data based on a {file_name: variable_name} map.

    ROOT is located by walking upward from this file's directory (or from the
    current working directory when ``__file__`` is undefined, as in a notebook)
    until a directory named `project_root_name` is found. If no such directory
    exists, a warning is logged and ``DEST_ROOT`` (the project's default data
    folder) is used. Previously the walk silently ended at the filesystem root
    and read from ``/data``.

    Supports:
      - .csv / .csv.gz  -> pandas DataFrame
      - .geojson / .geojson.gz -> GeoDataFrame
      - .ods -> Excel (odf engine)

    Parameters
    ----------
    file_map : dict[str, str]
        Maps a file name under the data directory to the variable name to store it as.
    project_root_name : str
        Directory name that marks the project root.
    inject_globals : bool
        If True, also bind each loaded frame into this module's globals under its variable name.

    Returns
    -------
    dict[str, pd.DataFrame | gpd.GeoDataFrame]
        Loaded frames keyed by variable name. Unsupported file types are skipped with a warning.
    """
    with timed("load_project_files (resolve ROOT/DATA_DIR)"):
        start = Path(__file__).resolve().parent if "__file__" in globals() else Path().resolve()
        root = next((p for p in (start, *start.parents) if p.name == project_root_name), None)
        if root is not None:
            data_dir = root / "data"
        else:
            data_dir = Path(DEST_ROOT).resolve()
            logger.warning(
                "Project root %r not found above %s; falling back to %s",
                project_root_name, start, data_dir,
            )

    results: dict[str, pd.DataFrame | gpd.GeoDataFrame] = {}
    skipped: list[str] = []
    resource_snapshot("load_project_files:start")

    for file_name, var_name in file_map.items():
        file_path = data_dir / file_name
        logger.info("Processing %s...", file_path)

        with timed(f"read:{file_path.name}"):
            # Joined suffixes so compound extensions such as ".csv.gz" are recognised.
            suffix = "".join(file_path.suffixes).lower()

            if suffix.endswith((".csv", ".csv.gz")):
                frame = pd.read_csv(file_path)
                logger.info("→ %s loaded as DataFrame shape=%s", var_name, getattr(frame, "shape", None))

            elif suffix.endswith((".geojson", ".geojson.gz")):
                frame = gpd.read_file(file_path)
                logger.info("→ %s loaded as GeoDataFrame len=%d", var_name, len(frame))

            elif suffix.endswith(".ods"):
                frame = pd.read_excel(file_path, engine="odf")
                logger.info("→ %s loaded from ODS shape=%s", var_name, getattr(frame, "shape", None))

            else:
                logger.warning("Skipping unsupported file type: %s", file_path)
                skipped.append(file_name)
                continue

            results[var_name] = frame
            if inject_globals:
                globals()[var_name] = frame

            resource_snapshot(f"after_load:{file_path.name}")

    if skipped:
        # Don't claim full success when some entries were ignored.
        logger.warning(
            "Loaded %d file(s); skipped %d unsupported: %s",
            len(results), len(skipped), ", ".join(skipped),
        )
    else:
        logger.info("All files loaded successfully.")
    resource_snapshot("load_project_files:end")
    return results




In [None]:
# Each CSV under <project>/data and the global name it is injected as.
file_map = {
    "sr_hex.csv": "df_sr_hex",
    "sr.csv": "df_sr",
}
datasets = load_project_files(file_map=file_map)

In [None]:
# %%
import numpy as np

def assign_h3_level8_sr(
    df: pd.DataFrame,
    lat_col: str = "latitude",
    lon_col: str = "longitude",
    *,
    resolution: int = 8,
    out_col: str | None = None,
    invalid_token: str = "0",
    validate_bounds: bool = True,
    threshold: float = 0.05,
    early_abort: bool = True,
    copy_frame: bool = True,
    log_validation: bool = True,
    log_join: bool = True,
) -> pd.DataFrame:
    """
    Attach an H3 (v4) index at the given `resolution` to each (lat, lon) row.

    Rows whose coordinates are missing, non-numeric, non-finite or, when
    `validate_bounds` is set, outside [-90, 90] x [-180, 180] receive
    `invalid_token` instead of a cell id. The invalid share is checked against
    `threshold`:
      • `early_abort=True`: checked before any H3 work is done.
      • `early_abort=False`: checked after the indexes are computed, so the
        join-stats log line is still emitted.
    In both cases an exceeding call raises RuntimeError without writing to the
    caller's `df`, even when `copy_frame=False`.

    Instrumentation: wrapped in `timed()`; `resource_snapshot()` at
    start/abort/end when DEBUG logging is enabled.

    Parameters
    ----------
    df : pd.DataFrame
        Input frame containing latitude/longitude columns.
    lat_col, lon_col : str
        Column names for coordinates. Non-numeric values and nullable-dtype
        missing values (pd.NA) are treated as invalid.
    resolution : int, default 8
        H3 resolution (0-15). Name is kept generic; function still called assign_h3_level8.
    out_col : str | None
        Destination column name. Defaults to f"h3_level{resolution}_index".
    invalid_token : str, default "0"
        Value to assign when (lat, lon) are invalid/missing/out-of-bounds.
    validate_bounds : bool, default True
        If True, require -90<=lat<=90 and -180<=lon<=180 in addition to finite checks.
    threshold : float, default 0.05
        Max tolerated share of invalid rows before failing (strict `>` comparison).
    early_abort : bool, default True
        If True, raise before computing any H3 index when the invalid share exceeds `threshold`.
    copy_frame : bool, default True
        If True, operate on a copy; otherwise mutate `df` in place (only on success).
    log_validation, log_join : bool
        Toggle INFO-level summary logs.

    Returns
    -------
    pd.DataFrame
        `df` itself (copy_frame=False) or a copy, with `out_col` (object dtype)
        holding the values returned by `h3.latlng_to_cell` or `invalid_token`.

    Raises
    ------
    ValueError
        If `lat_col`/`lon_col` are missing or `resolution` is outside 0..15.
    RuntimeError
        If the invalid share exceeds `threshold`.
    """
    if lat_col not in df.columns or lon_col not in df.columns:
        raise ValueError(f"Missing required columns: {lat_col!r}, {lon_col!r}")
    # H3 defines resolutions 0..15. Check up front: h3 itself only complains when
    # at least one valid row reaches latlng_to_cell.
    if not 0 <= int(resolution) <= 15:
        raise ValueError(f"H3 resolution must be between 0 and 15, got {resolution!r}")

    # Output column default derived from resolution
    out_col = out_col or f"h3_level{resolution}_index"

    with timed(f"assign_h3_level{resolution}"):
        if logger.isEnabledFor(logging.DEBUG):
            resource_snapshot("start")

        # --- Validation (single pass to plain float64 arrays) ---
        # na_value/dtype force nullable dtypes (Float64/Int64 holding pd.NA) into NaN floats.
        # Otherwise .to_numpy() yields an object array that np.isfinite rejects.
        lat = pd.to_numeric(df[lat_col], errors="coerce").to_numpy(dtype=float, na_value=np.nan)
        lon = pd.to_numeric(df[lon_col], errors="coerce").to_numpy(dtype=float, na_value=np.nan)

        valid = np.isfinite(lat) & np.isfinite(lon)
        if validate_bounds:
            valid &= (lat >= -90.0) & (lat <= 90.0) & (lon >= -180.0) & (lon <= 180.0)

        n = int(lat.size)
        n_valid = int(valid.sum())
        n_invalid = n - n_valid
        fail_rate = (n_invalid / n) if n else 0.0

        if log_validation:
            logger.info(
                "Validation: rows=%d valid=%d failed_join=%d (%.2f%% invalid)",
                n, n_valid, n_invalid, fail_rate * 100,
            )

        if early_abort and fail_rate > threshold:
            if logger.isEnabledFor(logging.DEBUG):
                resource_snapshot("abort")
            raise RuntimeError(
                f"Failure rate {fail_rate:.2%} exceeds threshold {threshold:.2%}"
            )

        # --- Compute H3 (fill invalids with token; compute only on valid mask) ---
        out = np.empty(n, dtype=object)
        if n_invalid:
            out[~valid] = invalid_token
        if n_valid:
            res = int(resolution)
            out[valid] = [
                h3.latlng_to_cell(float(la), float(lo), res)
                for la, lo in zip(lat[valid], lon[valid])
            ]

        # Join stats equal the validation stats: every invalid row gets the token.
        if log_join:
            logger.info(
                "Join stats: success=%d failed=%d (%.2f%%)",
                n_valid, n_invalid, fail_rate * 100,
            )

        # Late gate for early_abort=False. It runs before the write below, so a
        # failed call never leaves the caller's frame half-modified.
        if n and fail_rate > threshold:
            raise RuntimeError(
                f"Too many failed joins: {fail_rate:.2%} exceeds threshold {threshold:.2%}"
            )

        df_out = df.copy() if copy_frame else df
        df_out[out_col] = out

        if logger.isEnabledFor(logging.DEBUG):
            resource_snapshot("end")

        return df_out


Converting coordinates directly with H3 was much faster than performing a spatial join, and it no longer produces any mismatches.

The function below compares two DataFrames on one or more key columns across the selected value columns.

In [None]:
# %%
import numpy as np
import pandas as pd
from typing import Sequence, Mapping, Hashable, Literal, Optional

def compare_dfs_by_keys(
    df1: pd.DataFrame,
    df2: pd.DataFrame,
    keys: str | Sequence[str],
    cols: Optional[Sequence[str]] = None,
    *,
    # Equality controls
    na_equal: bool = True,
    coerce_to_str: bool = False,
    case_insensitive: bool = False,
    numeric_atol: float | None = None,            # absolute tol
    numeric_rtol: float | None = None,            # relative tol (np.isclose-style)
    per_column_atol: Optional[Mapping[str, float]] = None,
    datetime_tol: pd.Timedelta | None = None,     # e.g., pd.Timedelta("1ms")
    # Join & prep
    drop_duplicate_keys: bool = False,
    how: Literal["outer", "inner", "left", "right"] = "outer",
    # Logging
    log_summary: bool = True,
) -> dict[str, pd.DataFrame]:
    """
    Compare two DataFrames by `keys` across one or many `cols`.

    Rows are aligned with a merge on `keys` (`how` selects the join type).
    Each column in `cols` is compared only for rows present on both sides.

    Equality rules, per column:
      • Datetime on both sides and `datetime_tol` set: |Δ| <= `datetime_tol`.
      • Numeric on both sides and any of `numeric_atol` / `numeric_rtol` /
        `per_column_atol[c]` set: np.isclose (the per-column atol overrides the global one).
      • Otherwise exact `==`.
    Values for which `isna()` is true count as equal on both sides when
    `na_equal` is set. A missing value on only one side is always a mismatch.
    That includes pd.NA from nullable dtypes, which would otherwise leave NA in
    the masks.

    Optional preprocessing:
      • `coerce_to_str` stringifies non-missing values.
      • `case_insensitive` casefolds columns whose df1 side is object/string dtype.

    Instrumentation uses `timed()` and, at DEBUG, `resource_snapshot()`. A warning
    is logged when duplicate keys exist and `drop_duplicate_keys` is False,
    because the merge then multiplies rows.

    Returns
    -------
    dict[str, pd.DataFrame]
        only_in_df1 / only_in_df2 : keys present on one side only.
        both_rows       : keys plus `<col>_df1` / `<col>_df2` for rows on both sides.
        matches_all     : keys whose every compared column is equal
                          (all rows of `both_rows` when `cols` is empty).
        mismatches_long : one row per (key, column) mismatch with `left`, `right`,
                          and `abs_diff` (numeric/datetime only, else NA).
        summary         : per-column rows_in_both / matches / mismatches / match_rate.

    Raises
    ------
    ValueError
        If `keys` is empty, a key or column is missing from either frame, or
        `cols` is None and the frames share no non-key columns.
    """

    def _to_str_preserve_na(s: pd.Series) -> pd.Series:
        # keeps NA as NA, stringifies everything else
        return s.astype("object").where(s.isna(), s.astype(str))

    def _plain_bool(mask: pd.Series) -> pd.Series:
        # `==` on nullable dtypes (Int64, string, boolean) yields pd.NA where a side
        # is missing. NA inside a mask breaks `.loc`/`~`, so treat it as "not equal".
        return pd.Series(mask.to_numpy(dtype=bool, na_value=False), index=mask.index)

    def _is_subtractable_numeric(s: pd.Series) -> bool:
        # bool counts as numeric to pandas, but numpy refuses `-` on booleans.
        return pd.api.types.is_numeric_dtype(s) and not pd.api.types.is_bool_dtype(s)

    with timed("compare_dfs_by_keys"):
        if logger.isEnabledFor(logging.DEBUG):
            resource_snapshot("start")

        # --- Normalize & guards ---
        keys = [keys] if isinstance(keys, str) else list(keys)
        if not keys:
            raise ValueError("At least one key column is required.")

        for k in keys:
            if k not in df1.columns or k not in df2.columns:
                raise ValueError(f"Key '{k}' must exist in both DataFrames.")

        if cols is None:
            shared = set(df1.columns).intersection(df2.columns) - set(keys)
            if not shared:
                raise ValueError("No shared columns to compare (excluding keys). Provide 'cols'.")
            cols = sorted(shared)
        else:
            # Materialise first so one-shot iterables are not consumed by validation.
            cols = list(cols)
            for c in cols:
                if c not in df1.columns or c not in df2.columns:
                    raise ValueError(f"Column '{c}' must exist in both DataFrames.")

        # --- Slice & (optional) dedupe on keys ---
        left = df1[keys + cols].copy()
        right = df2[keys + cols].copy()
        if drop_duplicate_keys:
            left = left.drop_duplicates(subset=keys, keep="first")
            right = right.drop_duplicates(subset=keys, keep="first")
        elif left.duplicated(subset=keys).any() or right.duplicated(subset=keys).any():
            logger.warning(
                "Duplicate keys detected; the merge will multiply rows. "
                "Pass drop_duplicate_keys=True to keep the first occurrence."
            )

        # --- Optional coercions ---
        if coerce_to_str:
            for c in cols:
                left[c] = _to_str_preserve_na(left[c])
                right[c] = _to_str_preserve_na(right[c])

        if case_insensitive:
            for c in cols:
                if (pd.api.types.is_object_dtype(left[c]) or pd.api.types.is_string_dtype(left[c])):
                    left[c]  = left[c].where(left[c].isna(),  left[c].astype(str).str.casefold())
                    right[c] = right[c].where(right[c].isna(), right[c].astype(str).str.casefold())

        # --- Merge (configurable 'how') ---
        merged = left.merge(right, on=keys, how=how, suffixes=("_df1", "_df2"), indicator=True)

        only_in_df1 = merged.loc[merged["_merge"] == "left_only",  keys].reset_index(drop=True)
        only_in_df2 = merged.loc[merged["_merge"] == "right_only", keys].reset_index(drop=True)
        both_mask = merged["_merge"] == "both"

        both_cols = keys + [f"{c}_df1" for c in cols] + [f"{c}_df2" for c in cols]
        both = merged.loc[both_mask, both_cols].copy()

        # --- Column-wise equality (tolerant for numeric/datetime) ---
        equal_masks: dict[str, pd.Series] = {}
        for c in cols:
            l = both[f"{c}_df1"]
            r = both[f"{c}_df2"]
            na_both = (l.isna() & r.isna()) if na_equal else pd.Series(False, index=l.index)
            non_na = ~(l.isna() | r.isna())

            # Datetimes with tolerance if requested
            if (
                datetime_tol is not None
                and pd.api.types.is_datetime64_any_dtype(l)
                and pd.api.types.is_datetime64_any_dtype(r)
            ):
                eq_dt = pd.Series(False, index=l.index)
                if non_na.any():
                    eq_dt.loc[non_na] = (l[non_na] - r[non_na]).abs() <= datetime_tol
                equal_masks[c] = na_both | eq_dt
                continue

            # Numeric with np.isclose
            has_tol = (
                numeric_atol is not None
                or numeric_rtol is not None
                or bool(per_column_atol and c in per_column_atol)
            )
            if has_tol and pd.api.types.is_numeric_dtype(l) and pd.api.types.is_numeric_dtype(r):
                eq_num = pd.Series(False, index=l.index)
                if non_na.any():
                    atol = per_column_atol.get(c) if per_column_atol else None
                    atol = numeric_atol if atol is None else atol
                    rtol = numeric_rtol if numeric_rtol is not None else 0.0
                    eq_num.loc[non_na] = np.isclose(l[non_na].astype(float), r[non_na].astype(float),
                                                    atol=(0.0 if atol is None else float(atol)),
                                                    rtol=float(rtol),
                                                    equal_nan=False)
                equal_masks[c] = na_both | eq_num
                continue

            # Fallback: exact equality (with NA-equality if requested)
            try:
                eq = l == r
            except TypeError:
                # e.g. categoricals with different category sets refuse direct comparison
                eq = l.astype(object) == r.astype(object)
            if na_equal:
                eq = eq | na_both
            equal_masks[c] = _plain_bool(eq)

        # --- Keys where ALL compared columns match ---
        if cols:
            all_equal_mask = np.logical_and.reduce([equal_masks[c].to_numpy() for c in cols])
        else:
            # Nothing to compare: every key present on both sides trivially matches.
            all_equal_mask = np.ones(len(both), dtype=bool)

        matches_all = both.loc[all_equal_mask, keys].reset_index(drop=True)

        # --- Long-form mismatches ---
        rows = []
        for c in cols:
            neq = ~equal_masks[c]
            if not neq.any():
                continue
            sub = both.loc[neq, keys + [f"{c}_df1", f"{c}_df2"]].copy()
            sub.insert(len(keys), "column", c)
            sub = sub.rename(columns={f"{c}_df1": "left", f"{c}_df2": "right"})

            # abs_diff for numeric (non-bool) or datetime pairs; NA otherwise
            if _is_subtractable_numeric(sub["left"]) and _is_subtractable_numeric(sub["right"]):
                sub["abs_diff"] = (sub["left"] - sub["right"]).abs()
            elif pd.api.types.is_datetime64_any_dtype(sub["left"]) and pd.api.types.is_datetime64_any_dtype(sub["right"]):
                sub["abs_diff"] = (sub["left"] - sub["right"]).abs()
            else:
                sub["abs_diff"] = pd.NA

            rows.append(sub)

        mismatches_long = (
            pd.concat(rows, ignore_index=True)
            if rows
            else pd.DataFrame(columns=[*keys, "column", "left", "right", "abs_diff"])
        )

        # --- Summary per column (fixed schema, even when cols is empty) ---
        total_both = len(both)
        summary = pd.DataFrame(
            [
                {
                    "column": c,
                    "rows_in_both": total_both,
                    "matches": int(equal_masks[c].sum()),
                    "mismatches": int((~equal_masks[c]).sum()),
                    "match_rate": (float(equal_masks[c].mean()) if total_both else np.nan),
                }
                for c in cols
            ],
            columns=["column", "rows_in_both", "matches", "mismatches", "match_rate"],
        )

        if log_summary:
            logger.info(
                "Compare summary: rows_in_both=%d | avg_match_rate=%.2f%%",
                total_both,
                100.0 * (summary["match_rate"].mean() if not summary.empty else np.nan),
            )

        if logger.isEnabledFor(logging.DEBUG):
            resource_snapshot("end")

        return {
            "only_in_df1": only_in_df1,
            "only_in_df2": only_in_df2,
            "both_rows": both.reset_index(drop=True),
            "matches_all": matches_all,
            "mismatches_long": mismatches_long.reset_index(drop=True),
            "summary": summary.reset_index(drop=True),
        }


In [None]:
# NOTE: copy_frame=False mutates df_sr in place, so df_sr_h3_l8 is the same object as df_sr.
df_sr_h3_l8 = assign_h3_level8_sr(
    df_sr,
    lat_col="latitude",
    lon_col="longitude",
    threshold=0.226,
    copy_frame=False,
    log_join=False,
)
# errors="ignore" keeps the cell re-runnable: after the first run the column is already
# gone from df_sr itself, and a plain drop would raise KeyError.
df_sr_h3_l8.drop(columns=["Unnamed: 0"], inplace=True, errors="ignore")



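The threshold was set to 22.6% because 22.55% of rows could not be assigned an H3 index (missing or invalid coordinates); a slightly higher threshold keeps the function from aborting early. If no threshold is passed, the default is 5%, i.e. at least 95% of rows must receive a valid H3 index. (This is a tolerance on the failure rate, not a statistical confidence level.)
H3 indices were assigned directly with `h3.latlng_to_cell` instead of spatially joining against the GeoJSON hexagon file, because the direct conversion was much faster.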

In [None]:


# Check the directly computed level-8 indexes against the reference sr_hex file.
comparison_df = compare_dfs_by_keys(
    df_sr_h3_l8,
    df_sr_hex,
    keys="notification_number",
    cols=["h3_level8_index"],
)
summary_table = comparison_df["summary"]
print(summary_table)



Unit tests for the functions above


In [None]:
%%ipytest -q -rA


# --- Make sure the notebook-defined function sees required globals ---
def _wire_globals_for_notebook_func():
    """
    Backfill the globals ``nb.assign_h3_level8_sr`` looks up when called.

    In a normal notebook run, earlier cells already define every name below, so
    each fallback is a no-op. The fallbacks only matter when the function is
    exercised without those cells having run. Each fallback must implement the
    interface the function actually uses: ``logger.isEnabledFor`` /
    ``logger.info``, ``timed(...)`` as a context manager, and
    ``resource_snapshot(note)``. The previous hand-written logger stub lacked
    ``isEnabledFor``.
    """
    g = nb.assign_h3_level8_sr.__globals__

    g.setdefault("pd", pd)
    g.setdefault("np", np)
    g.setdefault("logging", logging)

    # A real, silenced logger satisfies the full Logger interface.
    if "logger" not in g:
        quiet = logging.getLogger("assign_h3_level8_sr.tests")
        quiet.addHandler(logging.NullHandler())
        quiet.propagate = False
        g["logger"] = quiet

    g.setdefault("timed", lambda label, enabled=True: contextlib.nullcontext())
    g.setdefault("resource_snapshot", lambda note="": None)

    # Kept for compatibility with the earlier version of this helper; the function does not use it.
    g.setdefault("error", lambda *a, **k: None)


_wire_globals_for_notebook_func()

# h3 is already imported at the top of the notebook, so this just rebinds the same module.
# If h3 were missing, importorskip would raise pytest's Skipped outcome in this cell.
h3 = pytest.importorskip("h3")

def test_valid_coords_return_h3_cell():
    lat, lon = -33.9249, 18.4241
    frame = pd.DataFrame({"latitude": [lat], "longitude": [lon]})
    result = nb.assign_h3_level8_sr(frame, resolution=8, threshold=0.0, early_abort=False)
    assert result.loc[0, "h3_level8_index"] == h3.latlng_to_cell(lat, lon, 8)
    
def test_known_point_exact_index_level8():
    # Fixed reference point with a pinned expected level-8 cell id.
    frame = pd.DataFrame({"latitude": [-33.872839], "longitude": [18.522488]})
    result = nb.assign_h3_level8_sr(frame, resolution=8, threshold=0.0, early_abort=False)
    assert result.loc[0, "h3_level8_index"] == "88ad360225fffff"


@pytest.mark.parametrize(
    "lat,lon",
    [
        (np.nan, 18.4),
        (-33.9, np.nan),
        (None, 18.4),
        (-33.9, None),
    ],
)
def test_nan_lat_or_lon_produces_zero(lat, lon):
    # threshold=1.0 tolerates the single invalid row so the token can be inspected.
    frame = pd.DataFrame({"latitude": [lat], "longitude": [lon]})
    result = nb.assign_h3_level8_sr(frame, threshold=1.0, early_abort=False)
    assert result.loc[0, "h3_level8_index"] == "0"

def test_out_of_range_coords_produce_zero():
    # Latitude below -90, longitude above 180, and a missing longitude.
    lats = [-91.0, 0.0, 45.0]
    lons = [0.0, 200.0, np.nan]
    frame = pd.DataFrame({"latitude": lats, "longitude": lons})
    result = nb.assign_h3_level8_sr(frame, threshold=1.0, early_abort=False)
    assert result["h3_level8_index"].tolist() == ["0"] * 3

def test_precheck_early_abort_on_high_invalid_ratio():
    # One of two rows is invalid (lat=100): 50% > 40% must abort before H3 work.
    frame = pd.DataFrame({"latitude": [100.0, -33.9], "longitude": [0.0, 18.4]})
    with pytest.raises(RuntimeError, match="exceeds threshold"):
        nb.assign_h3_level8_sr(frame, threshold=0.4, early_abort=True)

def test_resolution_changes_output():
    frame = pd.DataFrame({"latitude": [-33.9249], "longitude": [18.4241]})

    cells = {}
    for res in (7, 8):
        result = nb.assign_h3_level8_sr(frame, resolution=res, threshold=0.0, early_abort=False)
        column = f"h3_level{res}_index"
        assert column in result.columns
        cells[res] = result.loc[0, column]

    assert cells[7] != cells[8]

def test_missing_columns_raise_value_error():
    # No "longitude" column at all.
    latitude_only = pd.DataFrame({"latitude": [0.0]})
    with pytest.raises(ValueError):
        nb.assign_h3_level8_sr(latitude_only)
