In [1]:
!pip install boto3

Looking in indexes: https://pypi.org/simple, https://aws:****@rts-358803043452.d.codeartifact.eu-west-1.amazonaws.com/pypi/rts-lib/simple/


In [2]:
import json
import boto3
from botocore.config import Config

# Location of the source GeoJSON document.
REGION = "af-south-1"
BUCKET = "cct-ds-code-challenge-input-data"
KEY = "city-hex-polygons-8-10.geojson"

# S3 Select expression: properties and geometry of every feature whose resolution is 8.
sql = """
SELECT s.properties, s.geometry
FROM S3Object[*].features[*] s
WHERE s.properties.resolution = 8
"""

s3 = boto3.client(
    "s3",
    region_name=REGION,
    config=Config(s3={"addressing_style": "virtual"}),
)

select_request = {
    "Bucket": BUCKET,
    "Key": KEY,
    "ExpressionType": "SQL",
    "Expression": sql,
    "InputSerialization": {"JSON": {"Type": "DOCUMENT"}},
    # One JSON record per line; a record may still be split across streamed chunks.
    "OutputSerialization": {"JSON": {"RecordDelimiter": "\n"}},
}
resp = s3.select_object_content(**select_request)

# Parsing state used by the streaming loop further down in this cell.
decoder = json.JSONDecoder()
buffer = ""     # text accumulated across streamed chunks that has not been parsed yet
records = []    # decoded JSON objects
n_events = n_decoded = 0

def drain_buffer(buf: str):
    """Decode every complete JSON value found at the head of ``buf``.

    Whitespace between values is skipped. Decoding stops at the first position
    where ``json`` cannot decode a value, which is how a record that is still
    incomplete (more data needed) shows up; malformed text stops decoding the
    same way.

    Args:
        buf: Text containing zero or more whitespace-separated JSON values,
            possibly ending in a partial one.

    Returns:
        ``(objs, remainder)``: the list of decoded values, and the undecoded
        text starting at the first value that could not be decoded (empty
        string when everything was consumed).
    """
    # A decoder local to the call keeps the function independent of notebook-level
    # state, so it still works when cells are re-run in a different order.
    scanner = json.JSONDecoder()
    objs = []
    pos = 0
    size = len(buf)
    while True:
        # raw_decode() does not accept leading whitespace, so skip it here.
        while pos < size and buf[pos].isspace():
            pos += 1
        if pos >= size:
            break
        try:
            obj, next_pos = scanner.raw_decode(buf, pos)
        except json.JSONDecodeError:
            # Need more data to complete the next value.
            break
        objs.append(obj)
        pos = next_pos
    return objs, buf[pos:]

import codecs

# Decode incrementally: a multi-byte UTF-8 character that is split across two
# Records payloads is held back until its remaining bytes arrive, instead of
# being turned into replacement characters by a per-chunk decode.
utf8_decoder = codecs.getincrementaldecoder("utf-8")(errors="replace")
end_seen = False

for event in resp["Payload"]:
    n_events += 1
    if "Records" in event:
        buffer += utf8_decoder.decode(event["Records"]["Payload"])

        # Decode as many complete objects as possible; keep the partial rest.
        objs, buffer = drain_buffer(buffer)
        records.extend(objs)
        n_decoded += len(objs)
    elif "End" in event:
        end_seen = True
        break
    # "Stats" / "Progress" events carry byte counters only and are ignored here.

# Flush bytes still held by the decoder, then make one last parsing attempt.
buffer += utf8_decoder.decode(b"", final=True)
objs, buffer = drain_buffer(buffer)
records.extend(objs)
n_decoded += len(objs)
leftover = buffer.strip()

print(f"Events received: {n_events}")
print(f"Decoded JSON objects: {n_decoded}")

if not end_seen:
    # Without an End event the result set cannot be assumed to be complete.
    print("WARNING: stream finished without an End event; results may be incomplete.")

if leftover:
    print("\n--- Leftover (did not parse to complete JSON object) ---")
    print(leftover[:1000])  # show first 1000 chars for debugging
    print("--------------------------------------------------------")

if records:
    print("First row example:\n", json.dumps(records[0], indent=2)[:1000])
else:
    print("No records decoded.")


Events received: 33
Decoded JSON objects: 3832
First row example:
 {
  "properties": {
    "index": "88ad361801fffff",
    "centroid_lat": -33.859427322761434,
    "centroid_lon": 18.677843311941835,
    "resolution": 8
  },
  "geometry": {
    "type": "Polygon",
    "coordinates": [
      [
        [
          18.6811898997334,
          -33.86330279081797
        ],
        [
          18.683574296194426,
          -33.85928287732969
        ],
        [
          18.68022760998973,
          -33.85540739558428
        ],
        [
          18.67449676770625,
          -33.855551865779326
        ],
        [
          18.672112346191998,
          -33.85957172360946
        ],
        [
          18.675458791982372,
          -33.8634471669068
        ],
        [
          18.6811898997334,
          -33.86330279081797
        ]
      ]
    ]
  }
}


In [3]:
!pip install orjson

Looking in indexes: https://pypi.org/simple, https://aws:****@rts-358803043452.d.codeartifact.eu-west-1.amazonaws.com/pypi/rts-lib/simple/


In [1]:
import time
import logging
import json
import boto3
from botocore.config import Config
from contextlib import contextmanager

# ------------------------- CONFIG -------------------------
REGION = "af-south-1"
BUCKET = "cct-ds-code-challenge-input-data"
KEY    = "city-hex-polygons-8-10.geojson"

OUT_JSONL = "hex8_features.jsonl"   # one Feature per line (ALL fields preserved)

DEBUG_MODE = False                  # flip True for full timings & resource logs
LOG_EVERY_N_EVENTS = 25             # event logging interval (only in debug mode)
SAMPLE_VALIDATE_EVERY = 1000        # ~1 sampled validation per N lines in debug mode

# Logging
logging.basicConfig(
    level=logging.DEBUG if DEBUG_MODE else logging.INFO,
    format="%(asctime)s | %(levelname)-8s | %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
)
log = logging.getLogger("hex8_allfields_json_fast")
# basicConfig() does nothing once the root logger already has handlers (for example
# when this cell is re-run in the same kernel), so a changed DEBUG_MODE would be
# ignored. Setting the level on this logger makes the switch take effect every time.
log.setLevel(logging.DEBUG if DEBUG_MODE else logging.INFO)

# S3 Select expression: every feature (all fields) whose resolution is 8.
SQL = """
SELECT s
FROM S3Object[*].features[*] s
WHERE s.properties.resolution = 8
"""

@contextmanager
def timed(label: str):
    """Context manager that reports how long the wrapped block took.

    The duration is logged at DEBUG level as "<label> took <seconds> s", and only
    when DEBUG_MODE is true. Reporting happens in a ``finally`` clause, so it also
    runs when the wrapped block raises.
    """
    started = time.perf_counter()
    try:
        yield
    finally:
        elapsed = time.perf_counter() - started
        if DEBUG_MODE:
            log.debug("%s took %.3f s", label, elapsed)

def main():
    """Run the S3 Select query and write the returned records to OUT_JSONL.

    The response stream is consumed event by event. Record payloads are appended
    to a byte buffer and everything up to the last newline is written out; the
    partial line after it stays buffered for the next payload. In DEBUG_MODE one
    line is JSON-validated roughly every SAMPLE_VALIDATE_EVERY rows and byte
    statistics are logged at the end. A warning is logged when the stream
    finishes without an End event, because the output may then be incomplete.
    """
    total_start = time.perf_counter()

    with timed("S3 client init"):
        s3 = boto3.client("s3", region_name=REGION, config=Config(s3={"addressing_style": "virtual"}))

    with timed("S3 Select open"):
        resp = s3.select_object_content(
            Bucket=BUCKET,
            Key=KEY,
            ExpressionType="SQL",
            Expression=SQL,
            InputSerialization={"JSON": {"Type": "DOCUMENT"}},
            OutputSerialization={"JSON": {"RecordDelimiter": "\n"}},
        )

    events = 0
    rows = 0
    bad_samples = 0
    bytes_scanned = bytes_processed = bytes_returned = None
    end_seen = False        # set when the End event arrives
    next_sample_row = 0     # row count at which the next debug validation is due

    buf = bytearray()

    with timed("Stream + parse + write"), open(OUT_JSONL, "wb", buffering=8 * 1024 * 1024) as fout:
        for event in resp["Payload"]:
            events += 1

            rec = event.get("Records")
            if rec:
                buf.extend(rec["Payload"])  # payload is bytes

                last_nl = buf.rfind(b"\n")
                if last_nl != -1:
                    to_write = buf[: last_nl + 1]  # copy of all complete lines buffered so far
                    # `rows` grows by a whole chunk at a time, so test against a
                    # threshold; an exact-multiple test would almost never be true.
                    if DEBUG_MODE and SAMPLE_VALIDATE_EVERY > 0 and rows >= next_sample_row:
                        next_sample_row = rows + SAMPLE_VALIDATE_EVERY
                        # Validate the last complete line of this chunk.
                        prev_nl = to_write.rfind(b"\n", 0, len(to_write) - 1)
                        sample = to_write[prev_nl + 1 : -1] if prev_nl != -1 else to_write[:-1]
                        try:
                            json.loads(sample.decode("utf-8"))
                        except Exception as e:
                            bad_samples += 1
                            log.debug("Sample validation failed: %s", e)

                    fout.write(to_write)
                    rows += to_write.count(b"\n")
                    del buf[: last_nl + 1]

            elif "Stats" in event:
                d = event["Stats"]["Details"]
                bytes_scanned   = d.get("BytesScanned")
                bytes_processed = d.get("BytesProcessed")
                bytes_returned  = d.get("BytesReturned")

            elif "End" in event:
                end_seen = True

            if DEBUG_MODE and events % LOG_EVERY_N_EVENTS == 0:
                log.debug("Event %d | total rows=%d | buffer=%d bytes",
                          events, rows, len(buf))

        # Write whatever is still buffered, terminating it with a newline if needed.
        if buf:
            if buf[-1:] != b"\n":
                buf += b"\n"
            if DEBUG_MODE and SAMPLE_VALIDATE_EVERY > 0:
                try:
                    last_nl = buf[:-1].rfind(b"\n")
                    sample = buf[last_nl + 1 : -1] if last_nl != -1 else buf[:-1]
                    json.loads(sample.decode("utf-8"))
                except Exception as e:
                    bad_samples += 1
                    log.debug("Final sample validation failed: %s", e)
            fout.write(buf)
            rows += buf.count(b"\n")

    if not end_seen:
        log.warning("S3 Select stream finished without an End event; %s may be incomplete", OUT_JSONL)

    wall_total = time.perf_counter() - total_start

    # Always-on summary line
    log.info("Rows written: %d | Events: %d | Total wall time: %.3f s", rows, events, wall_total)

    if DEBUG_MODE:
        log.debug("======== SUMMARY (json-fast) ========")
        log.debug("Output file: %s", OUT_JSONL)
        log.debug("S3 Select bytes (MB): scanned=%.2f processed=%.2f returned=%.2f",
                  (bytes_scanned or 0) / (1024**2),
                  (bytes_processed or 0) / (1024**2),
                  (bytes_returned or 0) / (1024**2))
        if bad_samples:
            log.debug("Sampled validation failures: %d", bad_samples)

if __name__ == "__main__":
    main()


2025-08-19 22:07:49 | INFO     | Found credentials in shared credentials file: ~/.aws/credentials
2025-08-19 22:07:52 | INFO     | Rows written: 3832 | Events: 35 | Total wall time: 2.568 s


In [1]:
import time
import logging
import json
import boto3
from botocore.config import Config
from contextlib import contextmanager

# ------------------------- CONFIG -------------------------
REGION = "af-south-1"
BUCKET = "cct-ds-code-challenge-input-data"
KEY    = "city-hex-polygons-8-10.geojson"

OUT_JSONL = "hex8_features.jsonl"   # one Feature per line (ALL fields preserved)

DEBUG_MODE = False                  # detailed resource logs only if True
LOG_EVERY_N_EVENTS = 25             # higher = less logger overhead
SAMPLE_VALIDATE_EVERY = 1000        # ~1 sampled validation per N lines in DEBUG_MODE

# Logging
logging.basicConfig(
    level=logging.DEBUG if DEBUG_MODE else logging.INFO,
    format="%(asctime)s | %(levelname)-8s | %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
)
log = logging.getLogger("hex8_allfields_json_fast")
# basicConfig() does nothing once the root logger already has handlers (for example
# when this cell is re-run in the same kernel), so a changed DEBUG_MODE would be
# ignored. Setting the level on this logger makes the switch take effect every time.
log.setLevel(logging.DEBUG if DEBUG_MODE else logging.INFO)

# S3 Select expression: every feature (all fields) whose resolution is 8.
SQL = """
SELECT s
FROM S3Object[*].features[*] s
WHERE s.properties.resolution = 8
"""

@contextmanager
def timed(label: str):
    """Context manager that logs how long the wrapped block took.

    The duration is always logged at INFO level as "<label> took <seconds> s".
    Logging happens in a ``finally`` clause, so it also runs when the wrapped
    block raises.
    """
    started = time.perf_counter()
    try:
        yield
    finally:
        elapsed = time.perf_counter() - started
        log.info("%s took %.3f s", label, elapsed)

def main():
    """Run the S3 Select query and write the returned records to OUT_JSONL.

    The response stream is consumed event by event. Record payloads are appended
    to a byte buffer and everything up to the last newline is written out in one
    call; the partial line after it stays buffered for the next payload. A short
    summary is always logged. In DEBUG_MODE one line is JSON-validated roughly
    every SAMPLE_VALIDATE_EVERY rows and byte statistics are logged as well. A
    warning is logged when the stream finishes without an End event, because the
    output may then be incomplete.
    """
    total_start = time.perf_counter()

    with timed("S3 client init"):
        s3 = boto3.client("s3", region_name=REGION, config=Config(s3={"addressing_style": "virtual"}))

    with timed("S3 Select open"):
        resp = s3.select_object_content(
            Bucket=BUCKET,
            Key=KEY,
            ExpressionType="SQL",
            Expression=SQL,  # boto3 expects str
            InputSerialization={"JSON": {"Type": "DOCUMENT"}},
            OutputSerialization={"JSON": {"RecordDelimiter": "\n"}},
        )

    events = 0
    rows = 0
    bad_samples = 0
    bytes_scanned = bytes_processed = bytes_returned = None
    end_seen = False        # set when the End event arrives
    next_sample_row = 0     # row count at which the next debug validation is due

    # Keep a bytearray buffer and write complete lines in large chunks
    buf = bytearray()

    # Large file buffer for fewer write syscalls
    with timed("Stream + parse + write"), open(OUT_JSONL, "wb", buffering=8 * 1024 * 1024) as fout:
        for event in resp["Payload"]:
            events += 1
            ev_t0 = time.perf_counter()

            rec = event.get("Records")
            if rec:
                buf.extend(rec["Payload"])  # payload is bytes

                # Find last newline and flush everything up to it
                last_nl = buf.rfind(b"\n")
                if last_nl != -1:
                    to_write = buf[: last_nl + 1]  # copy of all complete lines buffered so far
                    # `rows` grows by a whole chunk at a time, so test against a
                    # threshold; an exact-multiple test would almost never be true.
                    if DEBUG_MODE and SAMPLE_VALIDATE_EVERY > 0 and rows >= next_sample_row:
                        next_sample_row = rows + SAMPLE_VALIDATE_EVERY
                        # Validate the last complete line of this chunk.
                        prev_nl = to_write.rfind(b"\n", 0, len(to_write) - 1)
                        sample = to_write[prev_nl + 1 : -1] if prev_nl != -1 else to_write[:-1]
                        try:
                            json.loads(sample.decode("utf-8"))
                        except Exception as e:
                            bad_samples += 1
                            log.debug("Sample validation failed: %s", e)

                    fout.write(to_write)                # one big write
                    rows += to_write.count(b"\n")       # cheap row count
                    del buf[: last_nl + 1]              # drop written bytes in-place

            elif "Stats" in event:
                d = event["Stats"]["Details"]
                bytes_scanned   = d.get("BytesScanned")
                bytes_processed = d.get("BytesProcessed")
                bytes_returned  = d.get("BytesReturned")

            elif "End" in event:
                end_seen = True

            if DEBUG_MODE and events % LOG_EVERY_N_EVENTS == 0:
                log.debug("Event %d processed in %.4f s | total rows=%d | buffer=%d bytes",
                          events, time.perf_counter() - ev_t0, rows, len(buf))

        # Write any trailing line (if stream didn't end with '\n')
        if buf:
            if buf[-1:] != b"\n":
                buf += b"\n"
            if DEBUG_MODE and SAMPLE_VALIDATE_EVERY > 0:
                try:
                    last_nl = buf[:-1].rfind(b"\n")
                    sample = buf[last_nl + 1 : -1] if last_nl != -1 else buf[:-1]
                    json.loads(sample.decode("utf-8"))
                except Exception as e:
                    bad_samples += 1
                    log.debug("Final sample validation failed: %s", e)
            fout.write(buf)
            rows += buf.count(b"\n")

    if not end_seen:
        log.warning("S3 Select stream finished without an End event; %s may be incomplete", OUT_JSONL)

    wall_total = time.perf_counter() - total_start

    # Minimal always-on summary
    log.info("======== SUMMARY ========")
    log.info("Output file: %s", OUT_JSONL)
    log.info("Rows written: %d | Events: %d", rows, events)
    log.info("Total wall time: %.3f s", wall_total)

    if DEBUG_MODE:
        log.debug("S3 Select bytes (MB): scanned=%.2f processed=%.2f returned=%.2f",
                  (bytes_scanned or 0) / (1024**2),
                  (bytes_processed or 0) / (1024**2),
                  (bytes_returned or 0) / (1024**2))
        if bad_samples:
            log.debug("Sampled validation failures: %d", bad_samples)

if __name__ == "__main__":
    main()


2025-08-19 22:07:29 | INFO     | Found credentials in shared credentials file: ~/.aws/credentials
2025-08-19 22:07:29 | INFO     | S3 client init took 0.115 s
2025-08-19 22:07:29 | INFO     | S3 Select open took 0.119 s
2025-08-19 22:07:32 | INFO     | Stream + parse + write took 2.211 s
2025-08-19 22:07:32 | INFO     | Output file: hex8_features.jsonl
2025-08-19 22:07:32 | INFO     | Rows written: 3832 | Events: 35
2025-08-19 22:07:32 | INFO     | Total wall time: 2.448 s


In [2]:
import time
import logging
import json
import boto3
from botocore.config import Config
from contextlib import contextmanager

# ------------------------- CONFIG -------------------------
REGION = "af-south-1"
BUCKET = "cct-ds-code-challenge-input-data"
KEY    = "city-hex-polygons-8-10.geojson"

OUT_JSONL = "hex8_features.jsonl"   # one Feature per line (ALL fields preserved)

DEBUG_MODE = False                  # flip True for detailed logs, row counts, sampled validation
LOG_EVERY_N_EVENTS = 25
SAMPLE_VALIDATE_EVERY = 1000        # ~1 sampled validation per N lines in debug mode

# Logging
logging.basicConfig(
    level=logging.DEBUG if DEBUG_MODE else logging.INFO,
    format="%(asctime)s | %(levelname)-8s | %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
)
log = logging.getLogger("hex8_allfields_json_tail")
# basicConfig() does nothing once the root logger already has handlers (for example
# when this cell is re-run in the same kernel), so a changed DEBUG_MODE would be
# ignored. Setting the level on this logger makes the switch take effect every time.
log.setLevel(logging.DEBUG if DEBUG_MODE else logging.INFO)

# S3 Select expression: every feature (all fields) whose resolution is 8.
SQL = """
SELECT s
FROM S3Object[*].features[*] s
WHERE s.properties.resolution = 8
"""

@contextmanager
def timed(label: str):
    """Context manager that reports how long the wrapped block took.

    The duration is logged at DEBUG level as "<label> took <seconds> s", and only
    when DEBUG_MODE is true. Reporting happens in a ``finally`` clause, so it also
    runs when the wrapped block raises.
    """
    started = time.perf_counter()
    try:
        yield
    finally:
        elapsed = time.perf_counter() - started
        if DEBUG_MODE:
            log.debug("%s took %.3f s", label, elapsed)

def main():
    """Run the S3 Select query and write the returned records to OUT_JSONL.

    Only the partial line after the last newline is carried from one Records
    payload to the next ("tail"); all complete lines are written immediately.
    Rows are counted and sampled for JSON validity only in DEBUG_MODE. A warning
    is logged when the stream finishes without an End event, because the output
    may then be incomplete.
    """
    total_start = time.perf_counter()

    with timed("S3 client init"):
        s3 = boto3.client("s3", region_name=REGION, config=Config(s3={"addressing_style": "virtual"}))

    with timed("S3 Select open"):
        resp = s3.select_object_content(
            Bucket=BUCKET,
            Key=KEY,
            ExpressionType="SQL",
            Expression=SQL,
            InputSerialization={"JSON": {"Type": "DOCUMENT"}},
            OutputSerialization={"JSON": {"RecordDelimiter": "\n"}},
        )

    events = 0
    rows = 0                # only incremented in DEBUG_MODE
    bad_samples = 0
    end_seen = False        # set when the End event arrives
    next_sample_row = 0     # row count at which the next debug validation is due

    # Keep only the partial line between chunks; by construction it never contains "\n"
    tail = b""

    with timed("Stream + parse + write"), open(OUT_JSONL, "wb", buffering=8 * 1024 * 1024) as fout:
        for event in resp["Payload"]:
            events += 1

            rec = event.get("Records")
            if rec:
                # Previous partial line followed by the new payload (one new bytes object)
                data = tail + rec["Payload"]

                # Find last complete newline
                last_nl = data.rfind(b"\n")
                if last_nl != -1:
                    # `rows` grows by a whole chunk at a time, so test against a
                    # threshold; an exact-multiple test would almost never be true.
                    if DEBUG_MODE and SAMPLE_VALIDATE_EVERY > 0 and rows >= next_sample_row:
                        next_sample_row = rows + SAMPLE_VALIDATE_EVERY
                        # Sample-validate the last full line in this block
                        prev_nl = data.rfind(b"\n", 0, last_nl)
                        sample = data[prev_nl + 1 : last_nl] if prev_nl != -1 else data[:last_nl]
                        try:
                            json.loads(sample.decode("utf-8"))
                        except Exception as e:
                            bad_samples += 1
                            log.debug("Sample validation failed: %s", e)

                    # Write all complete lines through a memoryview to avoid copying them
                    fout.write(memoryview(data)[: last_nl + 1])

                    if DEBUG_MODE:
                        rows += data.count(b"\n", 0, last_nl + 1)  # count only in debug

                    # Keep the remainder (after the last newline) as the new tail
                    tail = data[last_nl + 1 :]
                else:
                    # No newline yet; carry everything as tail
                    tail = data

            elif "Stats" in event and DEBUG_MODE:
                d = event["Stats"]["Details"]
                log.debug("Stats: scanned=%.2fMB processed=%.2fMB returned=%.2fMB",
                          d.get("BytesScanned", 0)/1048576,
                          d.get("BytesProcessed", 0)/1048576,
                          d.get("BytesReturned", 0)/1048576)

            elif "End" in event:
                end_seen = True

            if DEBUG_MODE and events % LOG_EVERY_N_EVENTS == 0:
                log.debug("Event %d | rows=%d | tail=%d bytes", events, rows, len(tail))

        # Flush the trailing partial line. `tail` holds no newline, so it is exactly
        # one unterminated line and always needs the terminator appended.
        if tail:
            fout.write(tail + b"\n")
            if DEBUG_MODE:
                rows += 1

    if not end_seen:
        log.warning("S3 Select stream finished without an End event; %s may be incomplete", OUT_JSONL)

    wall_total = time.perf_counter() - total_start

    # Always-on minimal summary
    if DEBUG_MODE:
        log.info("Rows written: %d | Events: %d | Total wall time: %.3f s",
                rows, events, wall_total)
        if bad_samples:
            log.debug("Sampled validation failures: %d", bad_samples)
    else:
        log.info("Events: %d | Total wall time: %.3f s",
                events, wall_total)

if __name__ == "__main__":
    main()


2025-08-19 22:09:16 | INFO     | Events: 35 | Total wall time: 2.513 s


In [4]:
# CELL 2 — Robust loader: JSONL -> GeoPandas (handles wrapped/columnar shapes)
import json
import pandas as pd
import geopandas as gpd
from shapely.geometry import shape

def gpd_from_jsonl(path: str, debug: bool = False) -> gpd.GeoDataFrame:
    """
    Load a JSON Lines file of features into a GeoDataFrame.

    Accepts JSON Lines where each line is either:
      • a GeoJSON Feature: {"type":"Feature","properties":{...},"geometry":{...}}
      • wrapped: {"s": {Feature...}}
      • columnar: {"_1": properties, "_2": geometry}
      • or another single-key wrapper around a Feature

    Blank lines are ignored. Lines that yield no record with a "geometry" key
    are skipped (the first three are reported when ``debug`` is true).

    Args:
        path: Path of the UTF-8 encoded JSONL file.
        debug: Print skip messages and a one-line load summary.

    Returns:
        A GeoDataFrame (CRS "EPSG:4326") with geometry and flattened properties;
        nested property keys are joined with ".".

    Raises:
        json.JSONDecodeError: A non-blank line is not valid JSON; the message
            names the file and line number.
        ValueError: No line produced a usable record.
    """
    feats = []
    wrapped = colstyle = skipped = 0

    with open(path, "r", encoding="utf-8") as f:
        for ln, line in enumerate(f, 1):
            s = line.strip()
            if not s:
                continue
            try:
                obj = json.loads(s)
            except json.JSONDecodeError as exc:
                # Same exception type, but say which line of which file is broken.
                raise json.JSONDecodeError(f"{path}, line {ln}: {exc.msg}", exc.doc, exc.pos) from exc

            rec = None
            if isinstance(obj, dict):
                if "geometry" in obj and "properties" in obj:
                    rec = obj
                elif isinstance(obj.get("s"), dict):
                    rec = obj["s"]
                    wrapped += 1
                elif "_1" in obj and "_2" in obj:
                    # Map column-style into a proper Feature
                    rec = {"type": "Feature", "properties": obj["_1"], "geometry": obj["_2"]}
                    colstyle += 1
                else:
                    # Try any nested dict that looks like a Feature
                    for v in obj.values():
                        if isinstance(v, dict) and "geometry" in v and "properties" in v:
                            rec = v
                            wrapped += 1
                            break

            if not isinstance(rec, dict) or "geometry" not in rec:
                skipped += 1
                if debug and skipped <= 3:
                    print(f"Skipping line {ln}: no 'geometry' key")
                continue

            feats.append(rec)

    if not feats:
        raise ValueError("No usable Feature records with 'geometry' found. Inspect your JSONL lines.")

    # Build columns. `or {}` also covers an explicit `"properties": null`, which
    # .get("properties", {}) would hand to json_normalize as None.
    props = [feat.get("properties") or {} for feat in feats]
    geoms = [shape(feat["geometry"]) if feat.get("geometry") is not None else None for feat in feats]

    df_props = pd.json_normalize(props, sep=".")
    gdf = gpd.GeoDataFrame(df_props, geometry=geoms, crs="EPSG:4326")

    if debug:
        print(f"Loaded {len(feats)} features | wrapped={wrapped} | columnar={colstyle} | skipped={skipped}")

    return gdf

# --- Use it ---
# JSONL file to load; same file name as OUT_JSONL used by the export cells above.
IN_JSONL = "hex8_features.jsonl"
# `gdf` is the frame the final comparison cell reads as its right-hand side.
gdf = gpd_from_jsonl(IN_JSONL, debug=False)  # flip to True for a brief diagnostics print
print(gdf.shape)
gdf.head(3)


(3832, 5)


Unnamed: 0,index,centroid_lat,centroid_lon,resolution,geometry
0,88ad361801fffff,-33.859427,18.677843,8,"POLYGON ((18.68119 -33.8633, 18.68357 -33.8592..."
1,88ad361803fffff,-33.855696,18.668766,8,"POLYGON ((18.67211 -33.85957, 18.6745 -33.8555..."
2,88ad361805fffff,-33.855263,18.685959,8,"POLYGON ((18.68931 -33.85914, 18.69169 -33.855..."


In [6]:
import pandas as pd
import geopandas as gpd
from pathlib import Path

# Start from this file's directory (or the working directory inside a notebook) and
# walk up until a directory named "ds_code_challenge" is reached. If there is no such
# ancestor the walk stops at the filesystem root.
ROOT = Path(__file__).resolve().parents[0] if "__file__" in globals() else Path().resolve()
while ROOT.name != "ds_code_challenge" and ROOT.parent != ROOT:
    ROOT = ROOT.parent

DATA_DIR = ROOT / "data"

# File name under DATA_DIR -> name of the notebook-level variable that receives it.
file_map = {
    "sr.csv": "df_sr",
    "sr_hex.csv": "df_sr_hex",
    "sr_hex_truncated.csv": "df_sr_hex_truncated",
    "city-hex-polygons-8.geojson": "gdf_city_hex_8"
}

# Load the files.
# Each result is stored straight into globals() under its mapped name. No temporary
# named `gdf` is used: such a temporary would rebind the notebook-wide `gdf` (the frame
# built by gpd_from_jsonl in an earlier cell) to the GeoJSON loaded here, and a later
# comparison of `gdf_city_hex_8` against `gdf` would compare the file with itself.
for file_name, var_name in file_map.items():
    file_path = DATA_DIR / file_name

    if file_path.suffix == ".csv":
        globals()[var_name] = pd.read_csv(file_path)
    elif file_path.suffix == ".geojson":
        globals()[var_name] = gpd.read_file(file_path)




In [20]:
import pandas as pd
import geopandas as gpd
from shapely.geometry.base import BaseGeometry
from typing import Dict

def _ensure_key_col(gdf: gpd.GeoDataFrame, key: str = "index") -> gpd.GeoDataFrame:
    return gdf if key in gdf.columns else gdf.reset_index().rename(columns={"index": key})

def compare_hex_gdfs_simple(
    left: gpd.GeoDataFrame,
    right: gpd.GeoDataFrame,
    key: str = "index",
    geom_tolerance: float = 0.0,  # 0 = strict; >0 means Hausdorff distance <= tol (CRS units) counts as equal
    na_equal: bool = True,
) -> Dict[str, pd.DataFrame]:
    """Compare two hex GeoDataFrames row by row, matched on ``key``.

    Rows are outer-merged on ``key``. For keys present on both sides the
    centroid latitude, centroid longitude and geometry are compared.

    Args:
        left, right: Frames with ``key`` (as a column or as the index),
            "centroid_lat", "centroid_lon" and "geometry" columns.
        key: Name of the join column.
        geom_tolerance: ``<= 0`` compares geometries with ``geom_equals``. A
            positive value treats two geometries as equal when their Hausdorff
            distance is at most this value. The plain minimum distance is not
            used because it is 0 for any two overlapping shapes, however
            different they are.
        na_equal: Treat a value missing on both sides as equal.

    Returns:
        Dict with "only_in_left" / "only_in_right" (keys found on one side),
        "matches" (keys whose three attributes all agree), "mismatches_long"
        (one row per differing attribute: key, column, left, right; geometries
        as WKT) and "mismatches_wide" (one row per differing key).

    Raises:
        ValueError: A required column is missing from ``left`` or ``right``.
    """
    # Make sure key column exists
    left  = _ensure_key_col(left, key)
    right = _ensure_key_col(right, key)

    required = {key, "centroid_lat", "centroid_lon", "geometry"}
    for name, df in (("left", left), ("right", right)):
        missing = sorted(required - set(df.columns))
        if missing:
            raise ValueError(f"{name} GeoDataFrame missing columns: {missing}")

    # Select/clone
    geom_l = left.geometry.name
    geom_r = right.geometry.name
    L = left[[key, "centroid_lat", "centroid_lon", geom_l]].copy()
    R = right[[key, "centroid_lat", "centroid_lon", geom_r]].copy()

    # Align CRS (reproject right -> left if both set and differ)
    if getattr(left, "crs", None) and getattr(right, "crs", None) and left.crs != right.crs:
        R = gpd.GeoDataFrame(R, geometry=geom_r, crs=right.crs).to_crs(left.crs)

    # Merge on key
    m = L.merge(R, on=key, how="outer", suffixes=("_l", "_r"), indicator=True)

    # merge() only adds suffixes to column names that occur on both sides, so the
    # geometry columns are suffixed only when both frames use the same name.
    if geom_l == geom_r:
        geom_l_col, geom_r_col = f"{geom_l}_l", f"{geom_r}_r"
    else:
        geom_l_col, geom_r_col = geom_l, geom_r

    only_in_left  = m.loc[m["_merge"] == "left_only",  [key]].reset_index(drop=True)
    only_in_right = m.loc[m["_merge"] == "right_only", [key]].reset_index(drop=True)
    both = m.loc[m["_merge"] == "both"].copy()

    # Element-wise attr equality
    def _eq(a: pd.Series, b: pd.Series) -> pd.Series:
        out = (a == b)
        return out | (a.isna() & b.isna()) if na_equal else out

    lat_eq = _eq(both["centroid_lat_l"], both["centroid_lat_r"])
    lon_eq = _eq(both["centroid_lon_l"], both["centroid_lon_r"])

    # Element-wise geometry equality
    g1 = gpd.GeoSeries(both[geom_l_col], crs=left.crs)
    g2 = gpd.GeoSeries(both[geom_r_col], crs=left.crs)  # already reprojected if needed

    if geom_tolerance <= 0:
        geom_eq = g1.geom_equals(g2)
        if na_equal:
            geom_eq = geom_eq | (g1.isna() & g2.isna())
    else:
        def _within(a: BaseGeometry, b: BaseGeometry) -> bool:
            if a is None or b is None:
                return na_equal and (a is None and b is None)
            try:
                return a.hausdorff_distance(b) <= geom_tolerance
            except Exception:
                return False
        geom_eq = pd.Series([_within(a, b) for a, b in zip(g1.values, g2.values)], index=both.index)

    # Overall equality
    all_eq = lat_eq & lon_eq & geom_eq
    matches = both.loc[all_eq, [key]].reset_index(drop=True)
    diffs   = both.loc[~all_eq].copy()

    # Long-form diffs: one row per (key, differing attribute)
    rows = []
    for column, eq in (("centroid_lat", lat_eq), ("centroid_lon", lon_eq)):
        if (~eq).any():
            t = both.loc[~eq, [key, f"{column}_l", f"{column}_r"]].copy()
            t.insert(1, "column", column)
            rows.append(t.rename(columns={f"{column}_l": "left", f"{column}_r": "right"}))
    if (~geom_eq).any():
        t = both.loc[~geom_eq, [key, geom_l_col, geom_r_col]].copy()
        t.insert(1, "column", "geometry")
        t["left"]  = gpd.GeoSeries(t.pop(geom_l_col), crs=left.crs).to_wkt()
        t["right"] = gpd.GeoSeries(t.pop(geom_r_col), crs=left.crs).to_wkt()
        rows.append(t[[key, "column", "left", "right"]])

    mismatches_long = pd.concat(rows, ignore_index=True) if rows else pd.DataFrame(columns=[key, "column", "left", "right"])

    # Wide-form for convenience
    if not diffs.empty:
        keep = [key, "centroid_lat_l", "centroid_lat_r", "centroid_lon_l", "centroid_lon_r", geom_l_col, geom_r_col]
        mismatches_wide = diffs[[c for c in keep if c in diffs.columns]].reset_index(drop=True)
    else:
        mismatches_wide = pd.DataFrame(columns=[key])

    return {
        "only_in_left": only_in_left,
        "only_in_right": only_in_right,
        "matches": matches,
        "mismatches_long": mismatches_long,
        "mismatches_wide": mismatches_wide,
    }


In [None]:


# Strict comparison (geom_tolerance=0.0) of the locally loaded hex frame against `gdf`.
# NOTE(review): `gdf` is presumably meant to be the frame built from the S3 Select export;
# verify it has not been reassigned by an intermediate cell before trusting the counts.
res = compare_hex_gdfs_simple(gdf_city_hex_8, gdf, key="index", geom_tolerance=0.0)

for _label, _result_key in (
    ("Only in left:", "only_in_left"),
    ("Only in right:", "only_in_right"),
    ("Matches:", "matches"),
):
    print(_label, len(res[_result_key]))



Only in left: 0
Only in right: 0
Matches: 3832
