In [1]:
import json, numpy as np, pandas as pd
from pathlib import Path
import pickle

# Directory that is globbed for the pickled graph shards.
SHARD_DIR = Path("Graph_data")
# Destination file for the merged per-id embedding table.
OUT_CSV = "embeddings_metro_bus_from_shards.csv"

# Per-context distance radii (R1, R2, R3), in meters. They parameterise the
# "fraction of distances within radius" features of each embedding.
THRESHOLDS = dict(
    metro=(400.0, 600.0, 800.0),
    bus=(200.0, 300.0, 400.0),
)

# Added to every distance before inversion so a zero distance cannot divide by zero.
EPS = 1e-6


def baseline_from_meters(d, r1, r2, r3):
    """Summarise a collection of distances as a fixed 12-value feature list.

    Args:
        d: Sized collection of distances that ``np.asarray`` can convert to
            floats (the caller treats them as meters).
        r1, r2, r3: Radii used for the three "fraction within radius"
            features, in the same unit as ``d``.

    Returns:
        ``None`` when ``d`` is empty; otherwise a list of 12 Python floats in
        this order: count, mean, min, max, median, population standard
        deviation, mean / max / sum of the inverse distances
        ``1 / (distance + EPS)``, and the fractions of distances that are
        ``<= r1``, ``<= r2`` and ``<= r3``.
    """
    if len(d) == 0:
        return None

    dist = np.sort(np.asarray(d, dtype=float))
    closeness = 1.0 / (dist + EPS)

    def share_within(radius):
        # Mean of a boolean mask == fraction of entries satisfying it.
        return (dist <= radius).mean()

    features = (
        len(dist),
        dist.mean(),
        dist.min(),
        dist.max(),
        np.median(dist),
        dist.std(),
        closeness.mean(),
        closeness.max(),
        closeness.sum(),
        share_within(r1),
        share_within(r2),
        share_within(r3),
    )
    return [float(value) for value in features]

def collect_context(shard_glob, ctx_key_guess, ctx_name):
    """Build per-id baseline feature vectors from pickled graph shards.

    Every file under ``SHARD_DIR`` matching ``shard_glob`` is unpickled, in
    sorted path order, and iterated as a mapping ``id -> value``. A value is
    either a graph object exposing ``edge_attr`` or a dict of such graphs.

    Args:
        shard_glob: Glob pattern, relative to ``SHARD_DIR``, selecting shards.
        ctx_key_guess: Key tried first when a value is a dict of graphs. If
            it is missing or maps to ``None``, the first non-``None`` value
            of that dict is used instead.
        ctx_name: Key into ``THRESHOLDS`` selecting the three radii.

    Returns:
        Dict mapping each id to the list returned by
        ``baseline_from_meters``, or to ``None`` when the id has no graph,
        the graph has no edge attributes, or the edge attributes are empty.
        If an id occurs in several shards, the last shard read wins.

    Raises:
        KeyError: If ``ctx_name`` is not a key of ``THRESHOLDS``.
    """
    r1, r2, r3 = THRESHOLDS[ctx_name]
    rows = {}
    for shard_path in sorted(SHARD_DIR.glob(shard_glob)):
        # SECURITY: unpickling can execute arbitrary code embedded in the
        # file. Only read shards produced by a trusted pipeline.
        with open(shard_path, "rb") as f:
            shard = pickle.load(f)  # dict[int -> graphs-dict or Data]

        for aid, value in shard.items():
            if isinstance(value, dict):
                # Try the explicit key first, else the first non-None graph.
                graph = value.get(ctx_key_guess)
                if graph is None:
                    graph = next(
                        (g for g in value.values() if g is not None), None
                    )
            else:
                graph = value

            # A graph stored without edge attributes carries no distances;
            # record it as missing instead of crashing the whole run.
            edge_attr = None if graph is None else graph.edge_attr
            if edge_attr is None:
                rows[aid] = None
                continue

            # detach(): .numpy() refuses tensors that require grad.
            # reshape(): unlike view(), also flattens non-contiguous tensors.
            # Edge attributes are treated as distances in meters.
            meters = edge_attr.detach().reshape(-1).cpu().numpy()
            rows[aid] = baseline_from_meters(meters, r1, r2, r3)
    return rows

metro = collect_context("METROSHARD_*.pkl", "metro", "metro")
bus   = collect_context("BUSSHARD_*.pkl",   "bus",   "bus")

# Merge both contexts into one table keyed by id. Each vector is stored as a
# JSON string; an id that is missing from a context (or has no vector there)
# gets None, which to_csv writes as an empty cell.
# NOTE: json.dumps writes non-finite floats as NaN/Infinity, which strict
# JSON parsers reject.
all_ids = sorted(set(metro) | set(bus))
out = []
for aid in all_ids:
    metro_vec = metro.get(aid)
    bus_vec = bus.get(aid)
    out.append({
        "id": int(aid),
        "emb_metro": json.dumps(metro_vec) if metro_vec is not None else None,
        "emb_bus": json.dumps(bus_vec) if bus_vec is not None else None,
    })

# Pin the columns so that an empty result (no shard matched) still yields a
# well-formed frame; otherwise the summary below raises KeyError.
df = pd.DataFrame(out, columns=["id", "emb_metro", "emb_bus"])
df.to_csv(OUT_CSV, index=False)
print(f"✅ wrote {OUT_CSV} with {len(df)} rows "
      f"(metro non-null: {df['emb_metro'].notna().sum()}, bus non-null: {df['emb_bus'].notna().sum()})")


✅ wrote embeddings_metro_bus_from_shards.csv with 24234 rows (metro non-null: 16752, bus non-null: 24175)
