In [1]:
# file: notebooks/per_context_embeddings.py
"""
Build 13 per-context embeddings (11 POI classes + metro + bus) for each apartment.
- Default: FAST BASELINE (no training). Always distinct across contexts when data differs.
- Optional: Per-context GNN with adjustable EPOCHS and NEG_K (slower).

Outputs a CSV with columns: id, emb_<context> (JSON vector or null).
Looks for shards in Graph_data/ with these patterns:
  - shard_*.pkl           (general POI classes)
  - METROSHARD_*.pkl      (metro)
  - BUSSHARD_*.pkl        (bus)
"""

# ==============================
# Cell 1 — Imports & switches
# ==============================
import os, json, time, pickle, glob
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Set
import numpy as np
import pandas as pd

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader
from torch_geometric.nn import GraphConv


In [8]:
# ---- Mode: "baseline" (fast, no training) or "gnn" (train per context)
MODE = "baseline"   # change to "gnn" to train encoders per context

In [9]:
# ==============================
# Cell 2 — Paths & hyperparams
# ==============================
ROOT = Path("Graph_data")  # folder containing exported shards
OUT_CSV = "apartment_embeddings_per_context.csv"

# Contexts
CLASSES = [
    'sport_and_leisure','medical','education_prim','veterinary',
    'food_and_drink_stores','arts_and_entertainment','food_and_drink',
    'park_like','security','religion','education_sup'
]
ALL_CONTEXTS = CLASSES + ['metro','bus']

# GNN knobs (used only when MODE=="gnn")
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
EMB_DIM = 12
HIDDEN  = 32
EPOCHS  = 8         # increase for better quality
BATCH_SZ = 512      # larger is faster
LR = 1e-3
WEIGHT_DECAY = 1e-4
NEG_K = 8           # increase for stronger contrast

print({"mode": MODE, "device": str(DEVICE)})

{'mode': 'baseline', 'device': 'cuda'}


In [10]:
# ======================================
# Cell 3 — Load exported shard helpers
# ======================================

def load_pickle(p: Path):
    with open(p, "rb") as f:
        return pickle.load(f)

# General shards hold: dict[apt_id] -> dict[class] -> Data|None
# Metro/Bus shards hold: dict[apt_id] -> Data|None

def find_shards() -> Tuple[List[Path], List[Path], List[Path]]:
    gens  = sorted((ROOT).glob("shard_*.pkl"))
    metro = sorted((ROOT).glob("METROSHARD_*.pkl"))
    bus   = sorted((ROOT).glob("BUSSHARD_*.pkl"))
    print(f"found: general={len(gens)}, metro={len(metro)}, bus={len(bus)}")
    return gens, metro, bus

GEN_SHARDS, METRO_SHARDS, BUS_SHARDS = find_shards()


def load_general_items_for_class(ctx: str) -> List[Tuple[int, Data]]:
    items: List[Tuple[int, Data]] = []
    for p in GEN_SHARDS:
        part = load_pickle(p)
        for aid, gdict in part.items():
            g = gdict.get(ctx)
            if isinstance(g, Data):
                items.append((int(aid), g))
    return items


def load_simple_items(shards: List[Path]) -> List[Tuple[int, Data]]:
    items: List[Tuple[int, Data]] = []
    for p in shards:
        part = load_pickle(p)
        for aid, g in part.items():
            if isinstance(g, Data):
                items.append((int(aid), g))
    return items

found: general=5, metro=2, bus=2


In [11]:
# ===================================
# Cell 4 — FAST baseline (no train)
# ===================================

def weights_from_graph(g: Data) -> Optional[torch.Tensor]:
    if g is None or g.edge_attr is None or g.edge_attr.numel() == 0:
        return None
    w = g.edge_attr.view(-1).clone()
    if (w > 1.0).any():  # meters → squash
        w = 1.0 / (1.0 + torch.clamp(w, min=0.0))
    return torch.clamp(w, 0.0, 1.0)


def star_stats_12d(w: torch.Tensor) -> np.ndarray:
    E = float(w.numel())
    if E == 0:
        return np.full(12, np.nan, dtype=np.float32)
    s1  = w.sum(); s2 = (w**2).sum(); s3 = (w**3).sum()
    mean = float(s1 / E)
    var  = float(s2 / E - mean**2)
    std  = float(np.sqrt(max(var, 0.0)))
    w_np = w.cpu().numpy()
    w_sorted = np.sort(w_np)[::-1]
    top1 = float(w_sorted[0])
    top3_mean = float(w_sorted[:min(3, w_sorted.size)].mean())
    tail = float((w_np > 0.75).mean())
    mn = float(w_np.min()); mx = float(w_np.max())
    bin2 = float((w_np > 0.5).mean())
    return np.array([E, float(s1), float(s2), float(s3), mean, std, mn, mx, top1, top3_mean, tail, bin2], dtype=np.float32)


def build_baseline_embeddings(out_csv: str) -> pd.DataFrame:
    per_ctx: Dict[str, List[Tuple[int, Data]]] = {}
    for ctx in CLASSES:
        per_ctx[ctx] = load_general_items_for_class(ctx)
        print(f"[collect] {ctx}: {len(per_ctx[ctx])}")
    per_ctx['metro'] = load_simple_items(METRO_SHARDS); print(f"[collect] metro: {len(per_ctx['metro'])}")
    per_ctx['bus']   = load_simple_items(BUS_SHARDS);   print(f"[collect] bus:   {len(per_ctx['bus'])}")

    all_ids: Set[int] = set()
    for ctx, items in per_ctx.items():
        all_ids |= {aid for aid, _ in items}
    all_ids = sorted(all_ids)
    print("apartments total:", len(all_ids))

    ctx_embs: Dict[str, Dict[int, Optional[str]]] = {c: {} for c in ALL_CONTEXTS}
    t0 = time.time()
    for ctx in ALL_CONTEXTS:
        items = dict(per_ctx.get(ctx, []))  # apt_id -> Data
        for aid in all_ids:
            g = items.get(aid)
            if g is None:
                ctx_embs[ctx][aid] = None
            else:
                w = weights_from_graph(g)
                if w is None:
                    ctx_embs[ctx][aid] = None
                else:
                    vec = star_stats_12d(w)
                    ctx_embs[ctx][aid] = json.dumps([float(x) for x in vec.tolist()])
    print(f"baseline built in {time.time()-t0:.2f}s")

    rows = []
    for aid in all_ids:
        row = {"id": aid}
        for ctx in ALL_CONTEXTS:
            row[f"emb_{ctx}"] = ctx_embs[ctx].get(aid)
        rows.append(row)
    df = pd.DataFrame(rows)
    df.to_csv(out_csv, index=False)
    print("saved:", out_csv, "shape:", df.shape)
    return df

In [6]:
# ======================================
# Cell 5 — Optional: per-context GNN
# ======================================
class TinyGraphConv(nn.Module):
    def __init__(self, in_dim: int, hidden: int, out_dim: int):
        super().__init__()
        self.conv1 = GraphConv(in_dim, hidden)
        self.conv2 = GraphConv(hidden, out_dim)
        self.act = nn.ReLU(); self.drop = nn.Dropout(0.1)
        self.W = nn.Linear(out_dim, out_dim, bias=False)      # for edges
        self.h_deg = nn.Linear(out_dim, 1)
        self.h_mean = nn.Linear(out_dim, 1)

    def forward(self, data: Data):
        x, ei = data.x, data.edge_index
        ew = data.edge_attr.view(-1) if data.edge_attr is not None else None
        h = self.conv1(x, ei, edge_weight=ew); h = self.act(h); h = self.drop(h)
        h = self.conv2(h, ei, edge_weight=ew)
        return h

    def score(self, h_src, h_dst):
        return (self.W(h_src) * h_dst).sum(dim=-1)


def add_distance_scalar(g: Data) -> Data:
    n = g.num_nodes
    dist = torch.zeros((n,1), dtype=g.x.dtype)
    if g.edge_attr is not None and g.edge_index is not None:
        _, dst = g.edge_index
        w = g.edge_attr.view(-1)
        dist[dst] = w.unsqueeze(1) if w.dim()==1 else w
    g.x = torch.cat([g.x, dist], dim=1)
    return g


def common_feat_dim(items: List[Tuple[int, Data]]) -> int:
    return max((g.x.size(1) for _, g in items), default=0)

class StarDataset(Dataset):
    def __init__(self, items: List[Tuple[int, Data]], base_f: int):
        self.items = items; self.base_f = base_f; self.input_dim = base_f + 1
    def __len__(self): return len(self.items)
    def __getitem__(self, i: int) -> Data:
        aid, g = self.items[i]
        if g.x.size(1) < self.base_f:
            pad = torch.zeros((g.num_nodes, self.base_f - g.x.size(1)), dtype=g.x.dtype)
            g = Data(x=torch.cat([g.x, pad], dim=1), edge_index=g.edge_index, edge_attr=g.edge_attr)
        g = add_distance_scalar(g)
        if g.x.size(1) > self.input_dim:
            g.x = g.x[:, :self.input_dim]
        elif g.x.size(1) < self.input_dim:
            pad = torch.zeros((g.num_nodes, self.input_dim - g.x.size(1)), dtype=g.x.dtype)
            g.x = torch.cat([g.x, pad], dim=1)
        apt_idx = int(torch.nonzero(g.x[:,0] > 0.5, as_tuple=False)[0].item()) if (g.x[:,0] > 0.5).any() else 0
        g.apt_idx = apt_idx; g.apt_id = int(aid)
        return g


def edge_weights_for_training(g: Data) -> torch.Tensor:
    w = g.edge_attr.view(-1)
    if (w > 1.0).any():
        w = 1.0 / (1.0 + torch.clamp(w, min=0.0))
    return torch.clamp(w, 0.0, 1.0)


def train_context(items: List[Tuple[int, Data]], name: str) -> Tuple[Optional[TinyGraphConv], Optional[StarDataset]]:
    if not items:
        print(f"[{name}] no graphs; skip")
        return None, None
    base_f = common_feat_dim(items)
    ds = StarDataset(items, base_f)
    dl = DataLoader(ds, batch_size=BATCH_SZ, shuffle=True)
    model = TinyGraphConv(ds.input_dim, HIDDEN, EMB_DIM).to(DEVICE)
    opt = torch.optim.Adam(model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)

    def train_epoch() -> float:
        model.train(); tot=0.0; cnt=0
        for big in dl:
            big = big.to(DEVICE); opt.zero_grad()
            H = model(big); src, dst = big.edge_index
            # apartment global indices
            num_g = big.ptr.numel()-1
            apt_gl = []
            for i in range(num_g):
                base = int(big.ptr[i].item()); aidx = int(big.apt_idx[i].item())
                apt_gl.append(base + aidx)
            apt_gl = torch.tensor(apt_gl, device=big.x.device, dtype=src.dtype)
            mask = (src.view(-1,1) == apt_gl.view(1,-1)).any(dim=1)
            if not mask.any():
                continue
            src_pos = src[mask]; dst_pos = dst[mask]
            tgt_all = edge_weights_for_training(big); tgt_pos = tgt_all[mask]
            # pos
            pred_pos = torch.sigmoid(model.score(H[src_pos], H[dst_pos]))
            loss_pos = F.mse_loss(pred_pos, tgt_pos)
            # negs
            neg_s, neg_d = [], []
            for i in range(num_g):
                base = int(big.ptr[i].item()); end=int(big.ptr[i+1].item())
                a_gi = base + int(big.apt_idx[i].item())
                true = set(dst_pos[(src_pos==a_gi)].tolist())
                cands = [j for j in range(base,end) if j!=a_gi and j not in true]
                if not cands: continue
                pick = np.random.choice(cands, size=min(NEG_K,len(cands)), replace=False)
                neg_s += [a_gi]*len(pick); neg_d += [int(x) for x in pick]
            if neg_s:
                neg_s = torch.tensor(neg_s, device=big.x.device, dtype=src.dtype)
                neg_d = torch.tensor(neg_d, device=big.x.device, dtype=dst.dtype)
                pred_neg = torch.sigmoid(model.score(H[neg_s], H[neg_d]))
                loss_neg = F.mse_loss(pred_neg, torch.zeros_like(pred_neg))
            else:
                loss_neg = torch.tensor(0.0, device=big.x.device)
            # aux: degree + mean weight from apt embedding
            deg_t, mean_t = [], []
            for i in range(num_g):
                base = int(big.ptr[i].item()); a_gi = base + int(big.apt_idx[i].item())
                sel = (src_pos == a_gi); d = int(sel.sum().item()); deg_t.append(d)
                mean_t.append(float(tgt_pos[sel].mean().item()) if d>0 else 0.0)
            deg_t = torch.tensor(deg_t, device=big.x.device, dtype=torch.float).view(-1,1)
            mean_t= torch.tensor(mean_t,device=big.x.device, dtype=torch.float).view(-1,1)
            deg_n = torch.log1p(deg_t)/4.0
            apt_h = H[apt_gl]
            loss_deg = F.mse_loss(model.h_deg(apt_h), deg_n)
            loss_mean= F.mse_loss(torch.sigmoid(model.h_mean(apt_h)), mean_t)
            loss = loss_pos + 0.5*loss_neg + 0.2*loss_deg + 0.2*loss_mean
            loss.backward(); opt.step(); tot += float(loss.item()); cnt += 1
        return tot/max(cnt,1)

    print(f"[{name}] graphs={len(ds)} in_dim={ds.input_dim}")
    for ep in range(1, EPOCHS+1):
        t0=time.time(); L=train_epoch();
        print(f"[{name}] epoch {ep:02d} loss={L:.6f} ({time.time()-t0:.1f}s)")
    return model, ds

@torch.no_grad()
def embed_with_model(model: TinyGraphConv, g: Data, input_dim: int) -> np.ndarray:
    # ensure exact input_dim
    if g.x.size(1) < input_dim:
        pad = torch.zeros((g.num_nodes, input_dim - g.x.size(1)), dtype=g.x.dtype)
        g.x = torch.cat([g.x, pad], dim=1)
    elif g.x.size(1) > input_dim:
        g.x = g.x[:, :input_dim]
    g = g.to(DEVICE)
    H = model(g)
    aidx = int(torch.nonzero(g.x[:,0] > 0.5, as_tuple=False)[0].item()) if (g.x[:,0] > 0.5).any() else 0
    return H[aidx].cpu().numpy()


def build_gnn_embeddings(out_csv: str) -> pd.DataFrame:
    # collect items
    per_ctx: Dict[str, List[Tuple[int, Data]]] = {}
    for ctx in CLASSES:
        per_ctx[ctx] = load_general_items_for_class(ctx)
        print(f"[collect] {ctx}: {len(per_ctx[ctx])}")
    per_ctx['metro'] = load_simple_items(METRO_SHARDS); print(f"[collect] metro: {len(per_ctx['metro'])}")
    per_ctx['bus']   = load_simple_items(BUS_SHARDS);   print(f"[collect] bus:   {len(per_ctx['bus'])}")

    all_ids: Set[int] = set()
    for ctx, items in per_ctx.items():
        all_ids |= {aid for aid, _ in items}
    all_ids = sorted(all_ids)
    print("apartments total:", len(all_ids))

    ctx_embs: Dict[str, Dict[int, Optional[str]]] = {}
    for ctx in ALL_CONTEXTS:
        items = per_ctx.get(ctx, [])
        if not items:
            ctx_embs[ctx] = {aid: None for aid in all_ids}
            continue
        model, ds = train_context(items, ctx)
        base_f = common_feat_dim(items); input_dim = base_f + 1
        by_id = {aid: g for aid,g in items}
        model.eval(); emb_map: Dict[int, Optional[str]] = {}
        for aid in all_ids:
            g = by_id.get(aid)
            if g is None:
                emb_map[aid] = None
            else:
                # replicate train pipeline
                if g.x.size(1) < base_f:
                    pad = torch.zeros((g.num_nodes, base_f - g.x.size(1)), dtype=g.x.dtype)
                    g = Data(x=torch.cat([g.x, pad], dim=1), edge_index=g.edge_index, edge_attr=g.edge_attr)
                g = add_distance_scalar(g)
                vec = embed_with_model(model, g, input_dim)
                emb_map[aid] = json.dumps([float(x) for x in vec.tolist()])
        ctx_embs[ctx] = emb_map

    rows = []
    for aid in all_ids:
        row = {"id": aid}
        for ctx in ALL_CONTEXTS:
            row[f"emb_{ctx}"] = ctx_embs[ctx].get(aid)
        rows.append(row)
    df = pd.DataFrame(rows)
    df.to_csv(out_csv, index=False)
    print("saved:", out_csv, "shape:", df.shape)
    return df

In [12]:
# ======================================
# Cell 6 — Run
# ======================================
if MODE == "baseline":
    df_out = build_baseline_embeddings(OUT_CSV)
else:
    df_out = build_gnn_embeddings(OUT_CSV)

print(df_out.head())


[collect] sport_and_leisure: 25123
[collect] medical: 25087
[collect] education_prim: 24519
[collect] veterinary: 22982
[collect] food_and_drink_stores: 24694
[collect] arts_and_entertainment: 25007
[collect] food_and_drink: 24205
[collect] park_like: 24275
[collect] security: 24564
[collect] religion: 21291
[collect] education_sup: 24581
[collect] metro: 16752
[collect] bus:   24175
apartments total: 25211
baseline built in 74.44s
saved: apartment_embeddings_per_context.csv shape: (25211, 14)
           id                              emb_sport_and_leisure  \
0  1359204515  [60.0, 16.075674057006836, 6.5722761154174805,...   
1  1366496843  [12.0, 5.705898284912109, 3.5092594623565674, ...   
2  1367599797  [78.0, 26.780254364013672, 13.415040969848633,...   
3  1391886163  [11.0, 5.413773536682129, 3.2894861698150635, ...   
4  1408157926  [33.0, 8.578935623168945, 2.700268268585205, 0...   

                                         emb_medical  \
0  [38.0, 15.997368812561035, 9.4231

In [1]:
# POI baseline embeddings (12-dim per class) → CSV (JSON columns, not expanded)

import json, pickle
from pathlib import Path
from typing import Dict, List, Optional
import numpy as np
import pandas as pd

# Where your POI shards live:
SHARD_DIR = Path("WebScraper/Graph_data")

# Output CSV (one row per apartment; 11 JSON columns, one per class)
OUT_CSV = "embeddings_poi_from_shards.csv"

# The 11 POI classes (column order will follow this)
CLASSES: List[str] = [
    "sport_and_leisure",
    "medical",
    "education_prim",
    "veterinary",
    "food_and_drink_stores",
    "arts_and_entertainment",
    "food_and_drink",
    "park_like",
    "security",
    "religion",
    "education_sup",
]

# POI thresholds (meters) used in the 12-dim baseline schema
R1, R2, R3 = 600.0, 1200.0, 2400.0
EPS = 1e-6

def baseline_from_meters(d: np.ndarray) -> Optional[List[float]]:
    """Return 12-dim baseline vector from raw distances in meters; None if empty."""
    if d.size == 0:
        return None
    d = np.asarray(d, dtype=float)
    d_sorted = np.sort(d)
    inv = 1.0 / (d_sorted + EPS)
    return [
        float(d_sorted.size),               # 0: count
        float(d_sorted.mean()),             # 1: mean
        float(d_sorted.min()),              # 2: min
        float(d_sorted.max()),              # 3: max
        float(np.median(d_sorted)),         # 4: median
        float(d_sorted.std()),              # 5: std
        float(inv.mean()),                  # 6: mean_inv
        float(inv.max()),                   # 7: max_inv
        float(inv.sum()),                   # 8: sum_inv
        float((d_sorted <= R1).mean()),     # 9: frac <= 600m
        float((d_sorted <= R2).mean()),     # 10: frac <= 1200m
        float((d_sorted <= R3).mean()),     # 11: frac <= 2400m
    ]

def iter_shards():
    # Only use the POI shards (pattern provided); adapt if needed.
    for p in sorted(SHARD_DIR.glob("shard_*.pkl")):
        yield p

def extract_poi_distances_from_graph(g) -> np.ndarray:
    """
    Get raw meters from a PyG Data 'g'.
    Assumes updated shards stored meters in g.edge_attr (shape [E, 1]).
    """
    if g is None or getattr(g, "edge_attr", None) is None:
        return np.array([], dtype=float)
    ea = g.edge_attr
    if ea.dim() == 2 and ea.size(1) >= 1:
        arr = ea.view(-1).detach().cpu().numpy()
        return arr
    # Fallback: treat as empty if shape is unexpected
    return np.array([], dtype=float)

# Collect per-apartment per-class baselines
# Structure: results[apt_id][class] = list[12] or None
results: Dict[int, Dict[str, Optional[List[float]]]] = {}

count_shards = 0
for shard in iter_shards():
    with open(shard, "rb") as f:
        d = pickle.load(f)  # dict[int -> dict[class -> Data|None]]
    count_shards += 1

    for apt_id, class_map in d.items():
        # Ensure we have a dict for this apartment
        if apt_id not in results:
            results[apt_id] = {cls: None for cls in CLASSES}

        # class_map should be dict[class -> Data|None]; compute per class
        for cls in CLASSES:
            g = class_map.get(cls, None)
            if g is None:
                # leave as None
                continue
            meters = extract_poi_distances_from_graph(g)
            vec = baseline_from_meters(meters)
            results[apt_id][cls] = vec

print(f"Processed {count_shards} shards. Apartments gathered: {len(results)}")

# Build output rows (keep JSON-encoded vectors; None stays None)
rows = []
for apt_id in sorted(results.keys()):
    rec = {"id": int(apt_id)}
    for cls in CLASSES:
        vec = results[apt_id][cls]
        rec[f"emb_{cls}"] = json.dumps(vec) if vec is not None else None
    rows.append(rec)

df = pd.DataFrame(rows, columns=["id"] + [f"emb_{c}" for c in CLASSES])
df.to_csv(OUT_CSV, index=False)
print(
    f"✅ wrote {OUT_CSV} with {len(df)} rows. "
    + "Non-null counts: "
    + ", ".join([f"{c}={df[f'emb_{c}'].notna().sum()}" for c in CLASSES])
)


Processed 0 shards. Apartments gathered: 0
✅ wrote embeddings_poi_from_shards.csv with 0 rows. Non-null counts: sport_and_leisure=0, medical=0, education_prim=0, veterinary=0, food_and_drink_stores=0, arts_and_entertainment=0, food_and_drink=0, park_like=0, security=0, religion=0, education_sup=0


In [2]:
# POI baseline embeddings (12-dim per class) → CSV (JSON columns, not expanded)

import json, pickle, os, sys
from pathlib import Path
from typing import Dict, List, Optional
import numpy as np
import pandas as pd

# Ensure PyG Data is importable for unpickling
from torch_geometric.data import Data  # noqa: F401

# --- Config ---
# Primary (as you said): shards live in WebScraper/Graph_data
PREFERRED_DIR = Path("WebScraper/Graph_data")
OUT_CSV = "embeddings_poi_from_shards.csv"

CLASSES: List[str] = [
    "sport_and_leisure",
    "medical",
    "education_prim",
    "veterinary",
    "food_and_drink_stores",
    "arts_and_entertainment",
    "food_and_drink",
    "park_like",
    "security",
    "religion",
    "education_sup",
]

# POI thresholds (meters)
R1, R2, R3 = 600.0, 1200.0, 2400.0
EPS = 1e-6

def baseline_from_meters(d: np.ndarray) -> Optional[List[float]]:
    if d.size == 0:
        return None
    d = np.asarray(d, dtype=float)
    d_sorted = np.sort(d)
    inv = 1.0 / (d_sorted + EPS)
    return [
        float(d_sorted.size),               # 0: count
        float(d_sorted.mean()),             # 1: mean
        float(d_sorted.min()),              # 2: min
        float(d_sorted.max()),              # 3: max
        float(np.median(d_sorted)),         # 4: median
        float(d_sorted.std()),              # 5: std
        float(inv.mean()),                  # 6: mean_inv
        float(inv.max()),                   # 7: max_inv
        float(inv.sum()),                   # 8: sum_inv
        float((d_sorted <= R1).mean()),     # 9: frac <= 600m
        float((d_sorted <= R2).mean()),     # 10: frac <= 1200m
        float((d_sorted <= R3).mean()),     # 11: frac <= 2400m
    ]

def extract_poi_distances_from_graph(g) -> np.ndarray:
    # expects g.edge_attr to store raw meters (shape [E,1] or [E])
    if g is None or getattr(g, "edge_attr", None) is None:
        return np.array([], dtype=float)
    ea = g.edge_attr
    try:
        return ea.view(-1).detach().cpu().numpy()
    except Exception:
        return np.array([], dtype=float)

def find_shards() -> list[Path]:
    # Try multiple roots to be robust to current working directory
    roots = [
        PREFERRED_DIR,
        Path("Graph_data"),
        Path.cwd() / "WebScraper" / "Graph_data",
        Path.cwd() / "Graph_data",
        Path("..") / "WebScraper" / "Graph_data",
        Path("..") / "Graph_data",
    ]
    candidates: set[Path] = set()
    for r in roots:
        if r.exists():
            candidates.update(r.glob("shard_*.pkl"))
    # Fallback: recursive search from CWD if still nothing
    if not candidates:
        candidates = set(Path(".").rglob("shard_*.pkl"))
    shards = sorted(candidates, key=lambda p: p.stat().st_mtime)
    # Debug print
    print(f"CWD: {Path.cwd()}")
    print("Search roots checked:")
    for r in roots:
        print(" -", r.resolve())
    print(f"Found {len(shards)} shards:")
    for s in shards[:5]:
        print("  •", s)
    if len(shards) > 5:
        print("  • ...")
    return shards

# ---- Run ----
shards = find_shards()

results: Dict[int, Dict[str, Optional[List[float]]]] = {}
count_shards = 0

for shard in shards:
    with open(shard, "rb") as f:
        d = pickle.load(f)  # dict[int -> dict[class -> Data|None]]
    count_shards += 1

    for apt_id, class_map in d.items():
        # ensure apartment dict exists
        if apt_id not in results:
            results[apt_id] = {cls: None for cls in CLASSES}

        # compute per class
        if isinstance(class_map, dict):
            for cls in CLASSES:
                g = class_map.get(cls, None)
                if g is None:
                    continue
                meters = extract_poi_distances_from_graph(g)
                vec = baseline_from_meters(meters)
                results[apt_id][cls] = vec
        else:
            # Unexpected structure; skip safely
            continue

print(f"Processed {count_shards} shards. Apartments gathered: {len(results)}")

rows = []
for apt_id in sorted(results.keys()):
    rec = {"id": int(apt_id)}
    for cls in CLASSES:
        vec = results[apt_id][cls]
        rec[f"emb_{cls}"] = json.dumps(vec) if vec is not None else None
    rows.append(rec)

df = pd.DataFrame(rows, columns=["id"] + [f"emb_{c}" for c in CLASSES])
df.to_csv(OUT_CSV, index=False)
print(
    f"✅ wrote {OUT_CSV} with {len(df)} rows. "
    + "Non-null counts: "
    + ", ".join([f"{c}={df[f'emb_{c}'].notna().sum()}" for c in CLASSES])
)


CWD: c:\Users\Pc-ADS\Documents\Cosas Universidad\UNAB\2024 sem 2\Seminario licenciatura\WebScraper
Search roots checked:
 - WebScraper\Graph_data
 - C:\Users\Pc-ADS\Documents\Cosas Universidad\UNAB\2024 sem 2\Seminario licenciatura\WebScraper\Graph_data
 - C:\Users\Pc-ADS\Documents\Cosas Universidad\UNAB\2024 sem 2\Seminario licenciatura\WebScraper\WebScraper\Graph_data
 - C:\Users\Pc-ADS\Documents\Cosas Universidad\UNAB\2024 sem 2\Seminario licenciatura\WebScraper\Graph_data
 - C:\Users\Pc-ADS\Documents\Cosas Universidad\UNAB\2024 sem 2\Seminario licenciatura\WebScraper\Graph_data
 - C:\Users\Pc-ADS\Documents\Cosas Universidad\UNAB\2024 sem 2\Seminario licenciatura\Graph_data
Found 9 shards:
  • c:\Users\Pc-ADS\Documents\Cosas Universidad\UNAB\2024 sem 2\Seminario licenciatura\WebScraper\Graph_data\shard_20250829_183814_2862820058-1535195651.pkl
  • Graph_data\shard_20250829_183814_2862820058-1535195651.pkl
  • ..\WebScraper\Graph_data\shard_20250829_183814_2862820058-1535195651.pkl
 