In [1]:
import ipytest, pytest, pandas as pd
ipytest.autoconfig()

In [2]:
import logging
logging.basicConfig(
    level=logging.INFO,  
    format="%(asctime)s | %(levelname)-8s | %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
)

logger = logging.getLogger(__name__)

In [3]:
import pandas as pd
import geopandas as gpd
from pathlib import Path

ROOT = Path(__file__).resolve().parents[0] if "__file__" in globals() else Path().resolve()
while ROOT.name != "ds_code_challenge" and ROOT.parent != ROOT:
    ROOT = ROOT.parent

DATA_DIR = ROOT / "data"

file_map = {
    "sr.csv": "df_sr",
    "sr_hex.csv": "df_sr_hex",
    "sr_hex_truncated.csv": "df_sr_hex_truncated",
    "city-hex-polygons-8.geojson": "gdf_city_hex_8"
}

# Load the files
for file_name, var_name in file_map.items():
    file_path = DATA_DIR / file_name
    logger.info(f"\nProcessing {file_path}...")

    if file_path.suffix == ".csv":
        df = pd.read_csv(file_path)
        globals()[var_name] = df

    elif file_path.suffix == ".geojson":
        gdf = gpd.read_file(file_path)
        globals()[var_name] = gdf

logger.info("All files loaded successfully.")


2025-08-19 21:30:10 | INFO     | 
Processing /Users/ben/Documents/test/ds_code_challenge/data/sr.csv...
2025-08-19 21:30:12 | INFO     | 
Processing /Users/ben/Documents/test/ds_code_challenge/data/sr_hex.csv...
2025-08-19 21:30:14 | INFO     | 
Processing /Users/ben/Documents/test/ds_code_challenge/data/sr_hex_truncated.csv...
2025-08-19 21:30:14 | INFO     | 
Processing /Users/ben/Documents/test/ds_code_challenge/data/city-hex-polygons-8.geojson...
2025-08-19 21:30:14 | INFO     | All files loaded successfully.


In [9]:
import os, gc, time, logging
import numpy as np
import pandas as pd

# ----------------- Robust logger setup (no duplicates) -----------------
def get_logger(name="h3_assign", level=logging.INFO) -> logging.Logger:
    lg = logging.getLogger(name)
    lg.setLevel(level)
    lg.propagate = False  # stop bubbling to root (prevents duplicate lines)

    # Remove existing handlers (safe for re-running in notebooks)
    for h in list(lg.handlers):
        lg.removeHandler(h)

    h = logging.StreamHandler()
    h.setLevel(level)
    h.setFormatter(logging.Formatter(
        "%(asctime)s | %(levelname)s | %(name)s | %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
    ))
    lg.addHandler(h)
    return lg

logger = get_logger(level=logging.INFO)   # set to WARNING/ERROR to quiet logs

# ----------------- Optional deps -----------------
try:
    import psutil
except ImportError:
    psutil = None

try:
    import tracemalloc
    _TRACEMALLOC_OK = True
except ImportError:
    tracemalloc = None
    _TRACEMALLOC_OK = False

try:
    import h3
    _H3_OK = True
except ImportError:
    h3 = None
    _H3_OK = False


# ----------------- Helpers -----------------
def _fmt_bytes(n: int) -> str:
    for unit in ["B", "KB", "MB", "GB", "TB"]:
        if n < 1024:
            return f"{n:.1f}{unit}"
        n /= 1024
    return f"{n:.1f}PB"

def _resource_snapshot(note: str = ""):
    """Cheap unless DEBUG; shows RSS/VMS/threads/CPU%/heap/GC."""
    if not logger.isEnabledFor(logging.DEBUG):
        return
    parts = []
    if psutil:
        p = psutil.Process(os.getpid())
        mem = p.memory_info()
        parts.append(f"RSS={_fmt_bytes(mem.rss)} VMS={_fmt_bytes(mem.vms)}")
        parts.append(f"Threads={p.num_threads()} CPU%={p.cpu_percent(interval=None):.1f}")
    if _TRACEMALLOC_OK and tracemalloc.is_tracing():
        cur, peak = tracemalloc.get_traced_memory()
        parts.append(f"Heap={_fmt_bytes(cur)}/{_fmt_bytes(peak)}")
    parts.append(f"GC={gc.get_count()}")
    logger.debug("[%s] %s", note, "; ".join(parts))


# ----------------- Main -----------------
def assign_h3_level8(
    df: pd.DataFrame,
    lat_col: str = "latitude",
    long_col: str = "longitude",
    threshold: float = 0.05,
    resolution: int = 8,
    early_abort: bool = True,
    *,
    log_validation: bool = True,
    log_join: bool = True,
    copy_frame: bool = True,
) -> pd.DataFrame:
    """
    Assign H3 (v4) hex index at `resolution` to points.

    Performance/Logging notes:
      • Resource snapshots only in DEBUG level.
      • `log_validation` / `log_join` toggle INFO prints.
      • Avoids extra O(n) scans by deriving join stats from `valid`.
      • Set `copy_frame=False` to mutate the input frame (fastest).
    """
    if not _H3_OK:
        raise RuntimeError("h3 is not installed")

    if lat_col not in df.columns or long_col not in df.columns:
        raise ValueError(f"Missing required columns: {lat_col}, {long_col}")

    # Enable tracemalloc only if DEBUG
    tracing = False
    if logger.isEnabledFor(logging.DEBUG) and _TRACEMALLOC_OK:
        tracemalloc.start()
        tracing = True

    t0 = time.perf_counter()

    # Choose frame
    df_out = df.copy() if copy_frame else df

    # --- Validation ---
    lat = pd.to_numeric(df_out[lat_col], errors="coerce").to_numpy()
    lon = pd.to_numeric(df_out[long_col], errors="coerce").to_numpy()
    valid = (
        np.isfinite(lat) & np.isfinite(lon) &
        (lat >= -90.0) & (lat <= 90.0) &
        (lon >= -180.0) & (lon <= 180.0)
    )
    n = int(lat.shape[0])
    n_valid = int(valid.sum())
    n_invalid = n - n_valid
    fail_rate_pre = (n_invalid / n) if n else 0.0

    if log_validation:
        logger.info("Validation: rows=%d invalid=%d (%.2f%%)", n, n_invalid, fail_rate_pre * 100)

    if early_abort and fail_rate_pre > threshold:
        if tracing:
            tracemalloc.stop()
        raise RuntimeError(
            f"Failure rate {fail_rate_pre:.2%} exceeds threshold {threshold:.2%}"
        )

    # --- Compute H3 (object dtype; switch to uint64 in future if you want max speed) ---
    out = np.full(n, "0", dtype=object)
    if n_valid:
        out[valid] = [
            h3.latlng_to_cell(la, lo, resolution)
            for la, lo in zip(lat[valid], lon[valid])
        ]

    # Assign without extra copy
    df_out["h3_level8_index"] = out

    # --- Join stats from 'valid' (no rescan of out) ---
    success = n_valid
    failed = n_invalid
    failed_rate = (failed / n) if n else 0.0

    if log_join:
        logger.info("Join stats: success=%d failed=%d (%.2f%%)", success, failed, failed_rate * 100)

    if n and failed_rate > threshold:
        if tracing:
            tracemalloc.stop()
        raise RuntimeError("Too many failed joins")

    # Done
    logger.info("Total elapsed: %.3fs", time.perf_counter() - t0)
    _resource_snapshot("end")

    if tracing:
        tracemalloc.stop()

    return df_out


In [5]:
%%ipytest -q -rA
import __main__ as nb
import pandas as pd
import numpy as np
import pytest

# --- Make sure the notebook-defined function sees required globals ---
def _wire_globals_for_notebook_func():
    fn = nb.assign_h3_level8

    # Provide pandas if not in function globals
    if "pd" not in fn.__globals__:
        fn.__globals__["pd"] = pd

    # Lightweight logger stub
    class _Logger:
        def debug(self, *a, **k): pass
        def info(self, *a, **k): pass
        def error(self, *a, **k): pass

    if "logger" not in fn.__globals__:
        fn.__globals__["logger"] = _Logger()

    
    if "error" not in fn.__globals__:
        fn.__globals__["error"] = lambda *a, **k: None

_wire_globals_for_notebook_func()

# Use real h3 if available; otherwise skip these tests cleanly
h3 = pytest.importorskip("h3")

def test_valid_coords_return_h3_cell():
    df = pd.DataFrame({"latitude": [-33.9249], "longitude": [18.4241]})
    out = nb.assign_h3_level8(df, resolution=8, threshold=0.0, early_abort=False)
    assert out.loc[0, "h3_level8_index"] == h3.latlng_to_cell(-33.9249, 18.4241, 8)
    
def test_known_point_exact_index_level8():
    lat, lon = -33.872839, 18.522488
    expected = "88ad360225fffff"
    df = pd.DataFrame({"latitude": [lat], "longitude": [lon]})
    out = nb.assign_h3_level8(df, resolution=8, threshold=0.0, early_abort=False)
    assert out.loc[0, "h3_level8_index"] == expected


@pytest.mark.parametrize("lat,lon", [
    (np.nan, 18.4),      
    (-33.9, np.nan),     
    (None, 18.4),        
    (-33.9, None),       
])
def test_nan_lat_or_lon_produces_zero(lat, lon):
    df = pd.DataFrame({"latitude": [lat], "longitude": [lon]})
    out = nb.assign_h3_level8(df, threshold=1.0, early_abort=False)  
    assert out.loc[0, "h3_level8_index"] == "0"

def test_out_of_range_coords_produce_zero():
    df = pd.DataFrame({
        "latitude":  [-91.0,  0.0,   45.0],
        "longitude": [  0.0, 200.0,  np.nan],
    })
    out = nb.assign_h3_level8(df, threshold=1.0, early_abort=False)
    assert out["h3_level8_index"].tolist() == ["0", "0", "0"]

def test_precheck_early_abort_on_high_invalid_ratio():
    
    df = pd.DataFrame({"latitude": [100.0, -33.9], "longitude": [0.0, 18.4]})
    with pytest.raises(RuntimeError, match="exceeds threshold"):
        nb.assign_h3_level8(df, threshold=0.4, early_abort=True)

def test_resolution_changes_output():
    df = pd.DataFrame({"latitude": [-33.9249], "longitude": [18.4241]})
    out7 = nb.assign_h3_level8(df, resolution=7, threshold=0.0, early_abort=False)
    out8 = nb.assign_h3_level8(df, resolution=8, threshold=0.0, early_abort=False)
    assert out7.loc[0, "h3_level8_index"] != out8.loc[0, "h3_level8_index"]

def test_missing_columns_raise_value_error():
    df = pd.DataFrame({"latitude": [0.0]})  
    with pytest.raises(ValueError):
        nb.assign_h3_level8(df)


2025-08-19 21:30:14 | INFO     | Validation: rows=1 invalid=0 (0.00%)
2025-08-19 21:30:14 | INFO     | Join stats: success=1 failed=0 (0.00%)
2025-08-19 21:30:14 | INFO     | Total elapsed: 0.001s


[32m.[0m

2025-08-19 21:30:15 | INFO     | Validation: rows=1 invalid=0 (0.00%)
2025-08-19 21:30:15 | INFO     | Join stats: success=1 failed=0 (0.00%)
2025-08-19 21:30:15 | INFO     | Total elapsed: 0.001s


[32m.[0m

2025-08-19 21:30:15 | INFO     | Validation: rows=1 invalid=1 (100.00%)
2025-08-19 21:30:15 | INFO     | Join stats: success=0 failed=1 (100.00%)
2025-08-19 21:30:15 | INFO     | Total elapsed: 0.001s


[32m.[0m

2025-08-19 21:30:15 | INFO     | Validation: rows=1 invalid=1 (100.00%)
2025-08-19 21:30:15 | INFO     | Join stats: success=0 failed=1 (100.00%)
2025-08-19 21:30:15 | INFO     | Total elapsed: 0.001s


[32m.[0m

2025-08-19 21:30:15 | INFO     | Validation: rows=1 invalid=1 (100.00%)
2025-08-19 21:30:15 | INFO     | Join stats: success=0 failed=1 (100.00%)
2025-08-19 21:30:15 | INFO     | Total elapsed: 0.001s


[32m.[0m

2025-08-19 21:30:15 | INFO     | Validation: rows=1 invalid=1 (100.00%)
2025-08-19 21:30:15 | INFO     | Join stats: success=0 failed=1 (100.00%)
2025-08-19 21:30:15 | INFO     | Total elapsed: 0.001s


[32m.[0m

2025-08-19 21:30:15 | INFO     | Validation: rows=3 invalid=3 (100.00%)
2025-08-19 21:30:15 | INFO     | Join stats: success=0 failed=3 (100.00%)
2025-08-19 21:30:15 | INFO     | Total elapsed: 0.001s


[32m.[0m

2025-08-19 21:30:15 | INFO     | Validation: rows=2 invalid=1 (50.00%)


[32m.[0m

2025-08-19 21:30:15 | INFO     | Validation: rows=1 invalid=0 (0.00%)
2025-08-19 21:30:15 | INFO     | Join stats: success=1 failed=0 (0.00%)
2025-08-19 21:30:15 | INFO     | Total elapsed: 0.001s
2025-08-19 21:30:15 | INFO     | Validation: rows=1 invalid=0 (0.00%)
2025-08-19 21:30:15 | INFO     | Join stats: success=1 failed=0 (0.00%)
2025-08-19 21:30:15 | INFO     | Total elapsed: 0.001s


[32m.[0m[32m.[0m[32m                                                                                   [100%][0m
[32m[1m_________________________________ test_valid_coords_return_h3_cell _________________________________[0m
---------------------------------------- Captured log call -----------------------------------------
[32mINFO    [0m __main__:2886585605.py:80 Validation: rows=1 invalid=0 (0.00%)
[32mINFO    [0m __main__:2886585605.py:94 Join stats: success=1 failed=0 (0.00%)
[32mINFO    [0m __main__:2886585605.py:99 Total elapsed: 0.001s
[32m[1m_______________________________ test_known_point_exact_index_level8 ________________________________[0m
---------------------------------------- Captured log call -----------------------------------------
[32mINFO    [0m __main__:2886585605.py:80 Validation: rows=1 invalid=0 (0.00%)
[32mINFO    [0m __main__:2886585605.py:94 Join stats: success=1 failed=0 (0.00%)
[32mINFO    [0m __main__:2886585605.py:99 Total elapsed

In [6]:
import pandas as pd

def compare_h3_by_notification(
    df1: pd.DataFrame,
    df2: pd.DataFrame,
    key: str = "notification_number",
    col: str = "h3_level8_index",
    coerce_to_str: bool = True,   # normalize dtypes before compare
    na_equal: bool = True,        # treat NaN==NaN as equal
):
    """
    Compare df1 vs df2 on a single column (h3_level8_index) by notification key.

    Returns dict with:
      - only_in_df1: keys present only in df1
      - only_in_df2: keys present only in df2
      - matches: keys in both where h3 matches (per na_equal rule)
      - mismatches: keys in both where h3 differs, with left/right values
    """
    if key not in df1.columns or key not in df2.columns:
        raise ValueError(f"Key '{key}' must exist in both DataFrames.")
    if col not in df1.columns or col not in df2.columns:
        raise ValueError(f"Column '{col}' must exist in both DataFrames.")

    left = df1[[key, col]].copy()
    right = df2[[key, col]].copy()

    if coerce_to_str:
        for d in (left, right):
            d[col] = d[col].astype("object").where(d[col].isna(), d[col].astype(str))

    merged = left.merge(
        right,
        on=key,
        how="outer",
        suffixes=("_df1", "_df2"),
        indicator=True,
    )

    only_in_df1 = merged.loc[merged["_merge"] == "left_only", [key]]
    only_in_df2 = merged.loc[merged["_merge"] == "right_only", [key]]

    both = merged.loc[merged["_merge"] == "both", [key, f"{col}_df1", f"{col}_df2"]].copy()

    lv = both[f"{col}_df1"]
    rv = both[f"{col}_df2"]
    equal_mask = (lv == rv)
    if na_equal:
        equal_mask = equal_mask | (lv.isna() & rv.isna())

    matches = both.loc[equal_mask, [key]]
    mismatches = both.loc[~equal_mask, [key, f"{col}_df1", f"{col}_df2"]].rename(
        columns={f"{col}_df1": f"{col}_left", f"{col}_df2": f"{col}_right"}
    )

    return {
        "only_in_df1": only_in_df1.reset_index(drop=True),
        "only_in_df2": only_in_df2.reset_index(drop=True),
        "matches": matches.reset_index(drop=True),
        "mismatches": mismatches.reset_index(drop=True),
    }


In [None]:
df_sr2 = assign_h3_level8(df_sr, lat_col="latitude", long_col="longitude", threshold=0.226, copy_frame=False, log_join=False)
df_sr2.drop(columns=["Unnamed: 0"], inplace=True)


2025-08-19 21:33:48 | INFO | h3_assign | Validation: rows=941634 invalid=212364 (22.55%)
2025-08-19 21:33:48 | INFO | h3_assign | Total elapsed: 0.423s
2025-08-19 21:33:48 | INFO | h3_assign | Validation: rows=941634 invalid=212364 (22.55%)
2025-08-19 21:33:49 | INFO | h3_assign | Total elapsed: 0.581s


Threshold was chosen at 22.6% as 22.55% failed to join. Thus I chose a higher threshold so that the code would not fail early. The default if we don't set a threshold is 5% as this would give us a confidence level of 95%.

In [8]:


result = compare_h3_by_notification(df_sr2, df_sr_hex, key="notification_number", col="h3_level8_index")

print("Only in df1:", len(result["only_in_df1"]))
print("Only in df2:", len(result["only_in_df2"]))
print("Matches    :", len(result["matches"]))
print("Mismatches :", len(result["mismatches"]))


Only in df1: 0
Only in df2: 0
Matches    : 941634
Mismatches : 0
