In [4]:
# file: notebooks/train_per_context_embeddings.ipynb (single cell version for brevity)
# ========================= Imports & config =========================
import os, json, time, pickle
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Set
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader
from torch_geometric.nn import GraphConv

DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print("Device:", DEVICE)

# Embedding shape & train knobs
EMB_DIM   = 12
HIDDEN    = 32
# QUICK DIAL-DOWN (put near your hyperparams)
EPOCHS = 2        # was 6–12
BATCH_SZ = 512    # larger batches reduce steps
LR = 2e-3         # a tad higher to still move in 1 epoch
NEG_K = 1         # was 5–10, cheaper negatives
HIDDEN = 16       # smaller model
EMB_DIM = 8       # smaller output

# Shards you listed (under Graph_data/)
ROOT = Path("Graph_data")
GENERAL_SHARDS = [
    'shard_20250823_193931_1555433697-1539080245.pkl',
    'shard_20250823_195306_2853468746-2862582048.pkl',
    'shard_20250823_212915_1575719893-1579917449.pkl',
    'shard_20250823_224953_1586671181-1586639873.pkl',
    'shard_20250824_001026_1584809495-1548097259.pkl',
]
METRO_SHARDS = [
    'METROSHARD_20250825_211657_1555433697-2854216564.pkl',
    'METROSHARD_20250826_091005_1553843137-1548097259.pkl',
]
BUS_SHARDS = [
    'BUSSHARD_20250826_122800_1555433697-1584388845.pkl',
    'BUSSHARD_20250826_153430_2862820058-1548097259.pkl',
]

CLASSES = [
    'sport_and_leisure','medical','education_prim','veterinary',
    'food_and_drink_stores','arts_and_entertainment','food_and_drink',
    'park_like','security','religion','education_sup'
]
ALL_CONTEXTS = CLASSES + ['metro','bus']

OUT_CSV = "apartment_embeddings_per_context_no_coords.csv"

# ========================= Load helpers =========================
def load_pickle(p: Path):
    with open(p, "rb") as f: return pickle.load(f)

def load_general_items_for_class(ctx: str) -> List[Tuple[int, Data]]:
    """Return [(apt_id, Data)] for a single general class context."""
    items: List[Tuple[int, Data]] = []
    for name in GENERAL_SHARDS:
        p = ROOT / name
        if not p.exists(): 
            print(f"[warn] missing general shard {p}"); 
            continue
        part = load_pickle(p)  # dict[apt_id] -> dict[class]->Data or None
        for aid, gdict in part.items():
            g = gdict.get(ctx)
            if isinstance(g, Data):
                items.append((int(aid), g))
    return items

def load_metro_items() -> List[Tuple[int, Data]]:
    items: List[Tuple[int, Data]] = []
    for name in METRO_SHARDS:
        p = ROOT / name
        if not p.exists(): 
            print(f"[warn] missing metro shard {p}")
            continue
        part = load_pickle(p)  # dict[apt_id] -> Data
        for aid, g in part.items():
            if isinstance(g, Data):
                items.append((int(aid), g))
    return items

def load_bus_items() -> List[Tuple[int, Data]]:
    items: List[Tuple[int, Data]] = []
    for name in BUS_SHARDS:
        p = ROOT / name
        if not p.exists(): 
            print(f"[warn] missing bus shard {p}")
            continue
        part = load_pickle(p)  # dict[apt_id] -> Data
        for aid, g in part.items():
            if isinstance(g, Data):
                items.append((int(aid), g))
    return items

# ========================= Dataset (no coords; +distance scalar) =========================
from torch.utils.data import Dataset

def add_distance_scalar(g: Data) -> Data:
    """Append 1-d feature: copy edge weight to POI node; apt node stays 0."""
    n = g.num_nodes
    dist_feat = torch.zeros((n, 1), dtype=g.x.dtype)
    if g.edge_attr is not None and g.edge_index is not None:
        _, dst = g.edge_index
        w = g.edge_attr.view(-1)
        dist_feat[dst] = w.unsqueeze(1)
    g.x = torch.cat([g.x, dist_feat], dim=1)
    return g

def common_feat_dim(items: List[Tuple[int, Data]]) -> int:
    if not items: return 0
    return max(g.x.size(1) for _, g in items)

class StarDataset(Dataset):
    def __init__(self, items: List[Tuple[int, Data]], base_f: int):
        self.items = items
        self.base_f = base_f  # COMMON_F per context (before +1 distance)
        self.input_dim = base_f + 1

    def __len__(self): return len(self.items)

    def __getitem__(self, i: int) -> Data:
        aid, g = self.items[i]
        # pad to base_f
        if g.x.size(1) < self.base_f:
            pad_w = self.base_f - g.x.size(1)
            pad = torch.zeros((g.num_nodes, pad_w), dtype=g.x.dtype)
            g = Data(x=torch.cat([g.x, pad], dim=1),
                     edge_index=g.edge_index, edge_attr=g.edge_attr)
        # add distance scalar
        g = add_distance_scalar(g)  # +1
        # enforce exactly input_dim
        if g.x.size(1) < self.input_dim:
            pad = torch.zeros((g.num_nodes, self.input_dim - g.x.size(1)), dtype=g.x.dtype)
            g.x = torch.cat([g.x, pad], dim=1)
        elif g.x.size(1) > self.input_dim:
            g.x = g.x[:, :self.input_dim]
        # meta: apt index via feature[0] == 1.0
        apt_mask = (g.x[:, 0] > 0.5)
        g.apt_idx = int(torch.nonzero(apt_mask, as_tuple=False)[0].item()) if apt_mask.any() else 0
        g.apt_id  = int(aid)
        return g

# ========================= Model (GraphConv + aux heads) =========================
class TinyGraphConv(nn.Module):
    def __init__(self, in_dim: int, hidden: int, out_dim: int):
        super().__init__()
        self.conv1 = GraphConv(in_dim, hidden)
        self.conv2 = GraphConv(hidden, out_dim)
        self.act = nn.ReLU()
        self.dropout = nn.Dropout(p=0.1)
        self.W = nn.Linear(out_dim, out_dim, bias=False)  # bilinear for edges
        self.head_deg   = nn.Linear(out_dim, 1)
        self.head_meanw = nn.Linear(out_dim, 1)

    def forward(self, data: Data):
        x, ei = data.x, data.edge_index
        ew = data.edge_attr.view(-1) if data.edge_attr is not None else None
        h = self.conv1(x, ei, edge_weight=ew)
        h = self.act(h)
        h = self.dropout(h)
        h = self.conv2(h, ei, edge_weight=ew)
        return h

    def score(self, h_src: torch.Tensor, h_dst: torch.Tensor) -> torch.Tensor:
        return (self.W(h_src) * h_dst).sum(dim=-1)

Device: cuda


In [2]:
# ========================= Train one context =========================
def edge_weights_for_training(g: Data) -> torch.Tensor:
    """Return edge weights in [0,1]: general already in [0,1]; metro/bus might be meters → 1/(1+dist)."""
    w = g.edge_attr.view(-1)
    if (w > 1.0).any():
        w = 1.0 / (1.0 + torch.clamp(w, min=0.0))
    return torch.clamp(w, 0.0, 1.0)

def train_context(items: List[Tuple[int, Data]], context_name: str) -> Tuple[TinyGraphConv, StarDataset]:
    if not items:
        print(f"[{context_name}] no graphs; skipping.")
        return None, None
    base_f = common_feat_dim(items)
    ds = StarDataset(items, base_f)
    dl = DataLoader(ds, batch_size=BATCH_SZ, shuffle=True)
    model = TinyGraphConv(ds.input_dim, HIDDEN, EMB_DIM).to(DEVICE)
    opt = torch.optim.Adam(model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)

    def train_epoch() -> float:
        model.train()
        tot, cnt = 0.0, 0
        for big in dl:
            big = big.to(DEVICE)
            opt.zero_grad()

            H = model(big)
            src, dst = big.edge_index

            # apartment global indices
            num_g = big.ptr.numel() - 1
            apt_global = []
            for i in range(num_g):
                base = int(big.ptr[i].item())
                aidx = int(big.apt_idx[i].item())
                apt_global.append(base + aidx)
            apt_global = torch.tensor(apt_global, device=big.x.device, dtype=src.dtype)

            # edges whose src is that apartment
            mask = (src.view(-1,1) == apt_global.view(1,-1)).any(dim=1)
            if not mask.any(): 
                continue
            src_pos = src[mask]; dst_pos = dst[mask]
            targets_all = edge_weights_for_training(big)
            targets_pos = targets_all[mask]

            # pos loss
            preds_pos = torch.sigmoid(model.score(H[src_pos], H[dst_pos]))
            loss_pos = F.mse_loss(preds_pos, targets_pos)

            # negatives per graph
            neg_src_list, neg_dst_list = [], []
            for i in range(num_g):
                base = int(big.ptr[i].item()); end = int(big.ptr[i+1].item())
                a_gi = base + int(big.apt_idx[i].item())
                true_dsts = set(dst_pos[(src_pos == a_gi)].tolist())
                cands = [j for j in range(base, end) if j != a_gi and j not in true_dsts]
                if not cands: continue
                pick = np.random.choice(cands, size=min(NEG_K, len(cands)), replace=False)
                neg_src_list += [a_gi]*len(pick)
                neg_dst_list += [int(x) for x in pick]
            if neg_src_list:
                neg_src = torch.tensor(neg_src_list, device=big.x.device, dtype=src.dtype)
                neg_dst = torch.tensor(neg_dst_list, device=big.x.device, dtype=dst.dtype)
                preds_neg = torch.sigmoid(model.score(H[neg_src], H[neg_dst]))
                loss_neg = F.mse_loss(preds_neg, torch.zeros_like(preds_neg))
            else:
                loss_neg = torch.tensor(0.0, device=big.x.device)

            # aux graph-level targets from apt embedding
            deg_t, meanw_t = [], []
            for i in range(num_g):
                base = int(big.ptr[i].item())
                a_gi = base + int(big.apt_idx[i].item())
                sel = (src_pos == a_gi)
                deg_i = int(sel.sum().item()); deg_t.append(deg_i)
                meanw_t.append(float(targets_pos[sel].mean().item()) if deg_i>0 else 0.0)
            deg_t   = torch.tensor(deg_t, device=big.x.device, dtype=torch.float).view(-1,1)
            meanw_t = torch.tensor(meanw_t, device=big.x.device, dtype=torch.float).view(-1,1)
            deg_tn  = torch.log1p(deg_t)/4.0
            apt_h   = H[apt_global]
            loss_deg   = F.mse_loss(model.head_deg(apt_h), deg_tn)
            loss_meanw = F.mse_loss(torch.sigmoid(model.head_meanw(apt_h)), meanw_t)

            loss = loss_pos + 0.5*loss_neg + 0.2*loss_deg + 0.2*loss_meanw
            loss.backward(); opt.step()
            tot += float(loss.item()); cnt += 1
        return tot/max(cnt,1)

    print(f"[{context_name}] training on {len(ds)} graphs; input_dim={ds.input_dim}")
    for ep in range(1, EPOCHS+1):
        t0 = time.time()
        los = train_epoch()
        print(f"[{context_name}] epoch {ep:02d}  loss={los:.6f}  ({time.time()-t0:.1f}s)")
    return model, ds

In [3]:

@torch.no_grad()
def embed_graph(model: TinyGraphConv, g: Data, input_dim: int) -> np.ndarray:
    # ensure input_dim
    if g.x.size(1) < input_dim:
        pad = torch.zeros((g.num_nodes, input_dim - g.x.size(1)), dtype=g.x.dtype)
        g.x = torch.cat([g.x, pad], dim=1)
    elif g.x.size(1) > input_dim:
        g.x = g.x[:, :input_dim]
    g = g.to(DEVICE)
    H = model(g)
    # apartment index
    apt_mask = (g.x[:, 0] > 0.5)
    aidx = int(torch.nonzero(apt_mask, as_tuple=False)[0].item()) if apt_mask.any() else 0
    return H[aidx].cpu().numpy()

# ========================= Run per context & build CSV =========================
def train_and_embed_all() -> pd.DataFrame:
    # collect items per context
    per_ctx_items: Dict[str, List[Tuple[int, Data]]] = {}
    for cls in CLASSES:
        per_ctx_items[cls] = load_general_items_for_class(cls)
        print(f"[collect] {cls}: {len(per_ctx_items[cls])} graphs")
    per_ctx_items['metro'] = load_metro_items(); print(f"[collect] metro: {len(per_ctx_items['metro'])} graphs")
    per_ctx_items['bus']   = load_bus_items();   print(f"[collect] bus:   {len(per_ctx_items['bus'])} graphs")

    # union of all apartment IDs present anywhere
    all_ids: Set[int] = set()
    for ctx, items in per_ctx_items.items():
        all_ids |= set(aid for aid,_ in items)
    all_ids = sorted(all_ids)
    print("Total distinct apartments across contexts:", len(all_ids))

    # For each context: train separate model; then embed all apartments for that context
    ctx_embeddings: Dict[str, Dict[int, Optional[List[float]]]] = {}
    for ctx in ALL_CONTEXTS:
        items = per_ctx_items.get(ctx, [])
        if not items:
            print(f"[{ctx}] no items; skipping")
            ctx_embeddings[ctx] = {aid: None for aid in all_ids}
            continue

        # Determine base_f for this context (for inference too)
        base_f = common_feat_dim(items)
        input_dim = base_f + 1  # + distance scalar
        model, ds = train_context(items, ctx)
        if model is None:  # no graphs at all
            ctx_embeddings[ctx] = {aid: None for aid in all_ids}
            continue

        # map apt->graph for fast lookup
        by_id: Dict[int, Data] = {aid: g for aid, g in items}

        # embed
        model.eval()
        emb_map: Dict[int, Optional[List[float]]] = {}
        for aid in all_ids:
            g = by_id.get(aid)
            if g is None:
                emb_map[aid] = None
            else:
                # build the same input pipeline used in training: pad->+distance
                if g.x.size(1) < base_f:
                    pad = torch.zeros((g.num_nodes, base_f - g.x.size(1)), dtype=g.x.dtype)
                    g = Data(x=torch.cat([g.x, pad], dim=1), edge_index=g.edge_index, edge_attr=g.edge_attr)
                g = add_distance_scalar(g)
                vec = embed_graph(model, g, input_dim)
                emb_map[aid] = [float(x) for x in vec.tolist()]
        ctx_embeddings[ctx] = emb_map

    # Build CSV
    rows = []
    for aid in all_ids:
        row = {'id': aid}
        for ctx in ALL_CONTEXTS:
            v = ctx_embeddings[ctx].get(aid)
            row[f"emb_{ctx}"] = None if v is None else json.dumps(v)
        rows.append(row)
    df = pd.DataFrame(rows)
    df.to_csv(OUT_CSV, index=False)
    print("Saved:", OUT_CSV, "| shape:", df.shape)
    return df

df_out = train_and_embed_all()
df_out.head()


[collect] sport_and_leisure: 25123 graphs
[collect] medical: 25087 graphs
[collect] education_prim: 24519 graphs
[collect] veterinary: 22982 graphs
[collect] food_and_drink_stores: 24694 graphs
[collect] arts_and_entertainment: 25007 graphs
[collect] food_and_drink: 24205 graphs
[collect] park_like: 24275 graphs
[collect] security: 24564 graphs
[collect] religion: 21291 graphs
[collect] education_sup: 24581 graphs
[collect] metro: 16752 graphs
[collect] bus:   24175 graphs
Total distinct apartments across contexts: 25211
[sport_and_leisure] training on 25123 graphs; input_dim=5
[sport_and_leisure] epoch 01  loss=0.128077  (29.5s)
[sport_and_leisure] epoch 02  loss=0.019926  (27.5s)
[sport_and_leisure] epoch 03  loss=0.018856  (27.3s)
[sport_and_leisure] epoch 04  loss=0.018185  (30.2s)
[sport_and_leisure] epoch 05  loss=0.017456  (27.6s)
[sport_and_leisure] epoch 06  loss=0.016883  (27.6s)
[medical] training on 25087 graphs; input_dim=5
[medical] epoch 01  loss=0.067843  (24.2s)
[medic

Unnamed: 0,id,emb_sport_and_leisure,emb_medical,emb_education_prim,emb_veterinary,emb_food_and_drink_stores,emb_arts_and_entertainment,emb_food_and_drink,emb_park_like,emb_security,emb_religion,emb_education_sup,emb_metro,emb_bus
0,1359204515,"[0.5447400808334351, -0.8694881200790405, 0.97...","[0.9175235033035278, 0.24008315801620483, 0.54...","[-0.2462320327758789, -0.24796371161937714, -0...","[0.5069181323051453, 0.23975862562656403, 0.73...","[0.39807450771331787, 0.5578294396400452, -0.2...","[-0.34782326221466064, -0.3198614716529846, -0...","[-0.809662938117981, 0.24182315170764923, -0.7...","[-0.31603437662124634, -0.06535714864730835, 0...","[-0.03100641816854477, -0.22747787833213806, -...","[-0.524695098400116, 0.09304139018058777, -0.2...","[-0.7214534282684326, -0.43536949157714844, -0...","[-0.6309113502502441, 0.600777268409729, 0.978...","[1.7241809368133545, -0.9379951357841492, -0.6..."
1,1366496843,"[0.5447400808334351, -0.8694881200790405, 0.97...","[0.9175235033035278, 0.24008315801620483, 0.54...","[-0.2462320327758789, -0.24796371161937714, -0...","[0.5069181323051453, 0.23975862562656403, 0.73...","[0.39807450771331787, 0.5578294396400452, -0.2...","[-0.34782326221466064, -0.3198614716529846, -0...","[-0.809662938117981, 0.24182315170764923, -0.7...","[-0.31603437662124634, -0.06535713374614716, 0...","[-0.03100641816854477, -0.22747787833213806, -...","[-0.524695098400116, 0.09304139763116837, -0.2...","[-0.7214534282684326, -0.43536949157714844, -0...","[-0.6309113502502441, 0.600777268409729, 0.978...","[1.7241809368133545, -0.9379951357841492, -0.6..."
2,1367599797,"[0.5447400808334351, -0.8694881200790405, 0.97...","[0.9175235033035278, 0.24008315801620483, 0.54...","[-0.2462320327758789, -0.24796371161937714, -0...","[0.5069181323051453, 0.23975862562656403, 0.73...","[0.39807450771331787, 0.5578294396400452, -0.2...","[-0.34782326221466064, -0.3198614716529846, -0...","[-0.809662938117981, 0.24182315170764923, -0.7...","[-0.31603437662124634, -0.06535714864730835, 0...","[-0.03100641816854477, -0.22747787833213806, -...","[-0.524695098400116, 0.09304139763116837, -0.2...","[-0.7214534282684326, -0.43536949157714844, -0...","[-0.6309113502502441, 0.600777268409729, 0.978...","[1.7241809368133545, -0.9379951357841492, -0.6..."
3,1391886163,"[0.5447400808334351, -0.8694881200790405, 0.97...","[0.9175235033035278, 0.24008315801620483, 0.54...","[-0.2462320327758789, -0.24796371161937714, -0...","[0.5069181323051453, 0.23975862562656403, 0.73...","[0.39807450771331787, 0.5578294396400452, -0.2...","[-0.34782326221466064, -0.3198614716529846, -0...","[-0.809662938117981, 0.24182315170764923, -0.7...","[-0.31603437662124634, -0.06535713374614716, 0...","[-0.03100641816854477, -0.22747787833213806, -...","[-0.524695098400116, 0.09304139763116837, -0.2...","[-0.7214534282684326, -0.43536949157714844, -0...","[-0.6309113502502441, 0.600777268409729, 0.978...","[1.7241809368133545, -0.9379951357841492, -0.6..."
4,1408157926,"[0.5447400808334351, -0.8694881200790405, 0.97...","[0.9175235033035278, 0.24008315801620483, 0.54...","[-0.2462320327758789, -0.24796371161937714, -0...","[0.5069181323051453, 0.23975862562656403, 0.73...","[0.39807450771331787, 0.5578294396400452, -0.2...","[-0.34782326221466064, -0.3198614716529846, -0...","[-0.809662938117981, 0.24182315170764923, -0.7...","[-0.31603437662124634, -0.06535714864730835, 0...","[-0.03100641816854477, -0.22747787833213806, -...",,"[-0.7214534282684326, -0.43536949157714844, -0...","[-0.6309113502502441, 0.600777268409729, 0.978...","[1.7241809368133545, -0.9379951357841492, -0.6..."


In [5]:
def weights_from_graph(g: Data) -> Optional[torch.Tensor]:
    if g is None or g.edge_attr is None or g.edge_attr.numel() == 0:
        return None
    w = g.edge_attr.view(-1).clone()
    # If looks like meters, squash to [0,1]
    if (w > 1.0).any():
        w = 1.0 / (1.0 + torch.clamp(w, min=0.0))
    return torch.clamp(w, 0.0, 1.0)

def star_stats_12d(w: torch.Tensor) -> np.ndarray:
    # w in [0,1], shape [E]
    E = float(w.numel())
    if E == 0:
        return np.full(12, np.nan, dtype=np.float32)

    s1  = w.sum()
    s2  = (w**2).sum()
    s3  = (w**3).sum()
    mean = float(s1 / E)
    var  = float(s2 / E - mean**2)
    std  = float(np.sqrt(max(var, 0.0)))
    w_np = w.cpu().numpy()
    w_sorted = np.sort(w_np)[::-1]  # desc by proximity
    top1 = float(w_sorted[0])
    top3_mean = float(w_sorted[:min(3, w_sorted.size)].mean())
    tail = float((w_np > 0.75).mean())  # share of very close POIs
    mn = float(w_np.min()); mx = float(w_np.max())
    # two coarse bins
    bin1 = float(((w_np > 0.0) & (w_np <= 0.5)).mean())
    bin2 = float((w_np > 0.5).mean())

    vec = np.array([
        E, s1, s2, s3, mean, std, mn, mx, top1, top3_mean, tail, bin2
    ], dtype=np.float32)
    return vec

def build_baseline_embeddings_csv(
    general_shards, metro_shards, bus_shards, root: Path, out_csv: str
) -> pd.DataFrame:
    # Gather items just like before
    from collections import defaultdict
    def load_general_items_for_class(ctx: str):
        items = []
        for name in general_shards:
            p = root / name
            if not p.exists(): continue
            part = load_pickle(p)
            for aid, gdict in part.items():
                g = gdict.get(ctx)
                if isinstance(g, Data):
                    items.append((int(aid), g))
        return items

    def load_simple(shard_list):
        items = []
        for name in shard_list:
            p = root / name
            if not p.exists(): continue
            part = load_pickle(p)  # dict[apt_id] -> Data or None
            for aid, g in part.items():
                if isinstance(g, Data):
                    items.append((int(aid), g))
        return items

    CLASSES = [
        'sport_and_leisure','medical','education_prim','veterinary',
        'food_and_drink_stores','arts_and_entertainment','food_and_drink',
        'park_like','security','religion','education_sup'
    ]
    ALL_CTX = CLASSES + ['metro','bus']

    per_ctx = {}
    for ctx in CLASSES:
        per_ctx[ctx] = load_general_items_for_class(ctx)
        print(f"[collect] {ctx}: {len(per_ctx[ctx])} graphs")
    per_ctx['metro'] = load_simple(metro_shards)
    per_ctx['bus']   = load_simple(bus_shards)
    print(f"[collect] metro: {len(per_ctx['metro'])} | bus: {len(per_ctx['bus'])}")

    all_ids = set()
    for ctx, items in per_ctx.items():
        all_ids |= {aid for aid, _ in items}
    all_ids = sorted(all_ids)
    print("Total distinct apartments:", len(all_ids))

    # Build fast embeddings
    ctx_embs = {ctx: {} for ctx in ALL_CTX}
    for ctx in ALL_CTX:
        items = dict(per_ctx.get(ctx, []))  # apt_id -> Data
        for aid in all_ids:
            g = items.get(aid)
            if g is None:
                ctx_embs[ctx][aid] = None
            else:
                w = weights_from_graph(g)
                if w is None:
                    ctx_embs[ctx][aid] = None
                else:
                    vec = star_stats_12d(w)
                    # JSON string for CSV; keep None if no graph
                    ctx_embs[ctx][aid] = json.dumps([float(x) for x in vec.tolist()])

    # Assemble CSV
    rows = []
    for aid in all_ids:
        row = {'id': aid}
        for ctx in ALL_CTX:
            row[f"emb_{ctx}"] = ctx_embs[ctx][aid]  # None or JSON
        rows.append(row)
    df = pd.DataFrame(rows)
    df.to_csv(out_csv, index=False)
    print("Saved:", out_csv, "| shape:", df.shape)
    return df

# ---- run it
df_fast = build_baseline_embeddings_csv(
    GENERAL_SHARDS, METRO_SHARDS, BUS_SHARDS, ROOT,
    out_csv="apartment_embeddings_per_context_FAST.csv"
)
df_fast.head()

[collect] sport_and_leisure: 25123 graphs
[collect] medical: 25087 graphs
[collect] education_prim: 24519 graphs
[collect] veterinary: 22982 graphs
[collect] food_and_drink_stores: 24694 graphs
[collect] arts_and_entertainment: 25007 graphs
[collect] food_and_drink: 24205 graphs
[collect] park_like: 24275 graphs
[collect] security: 24564 graphs
[collect] religion: 21291 graphs
[collect] education_sup: 24581 graphs
[collect] metro: 16752 | bus: 24175
Total distinct apartments: 25211
Saved: apartment_embeddings_per_context_FAST.csv | shape: (25211, 14)


Unnamed: 0,id,emb_sport_and_leisure,emb_medical,emb_education_prim,emb_veterinary,emb_food_and_drink_stores,emb_arts_and_entertainment,emb_food_and_drink,emb_park_like,emb_security,emb_religion,emb_education_sup,emb_metro,emb_bus
0,1359204515,"[60.0, 16.075674057006836, 6.5722761154174805,...","[38.0, 15.997368812561035, 9.423164367675781, ...","[13.0, 3.977703809738159, 1.8542356491088867, ...","[6.0, 2.277141571044922, 1.3051209449768066, 0...","[27.0, 9.532333374023438, 4.8009796142578125, ...","[14.0, 5.213222026824951, 2.8603620529174805, ...","[41.0, 11.050660133361816, 5.310342311859131, ...","[3.0, 1.726252555847168, 0.9990460276603699, 0...","[6.0, 1.7289631366729736, 1.121602177619934, 0...","[3.0, 0.8515275120735168, 0.3497370481491089, ...","[9.0, 3.1892800331115723, 1.4277830123901367, ...","[1.0, 0.002815012587234378, 7.924295459815767e...","[10.0, 0.03267307206988335, 0.0001179291430162..."
1,1366496843,"[12.0, 5.705898284912109, 3.5092594623565674, ...","[22.0, 9.402782440185547, 5.164984226226807, 3...","[9.0, 3.0975704193115234, 1.4777079820632935, ...","[8.0, 2.286616325378418, 0.9103704690933228, 0...","[14.0, 6.050302982330322, 3.726116895675659, 2...","[18.0, 5.941456317901611, 2.87703275680542, 1....","[16.0, 7.021714687347412, 4.83206844329834, 3....","[6.0, 2.8207437992095947, 1.765191674232483, 1...","[5.0, 3.2117748260498047, 2.646604061126709, 2...","[5.0, 1.2922619581222534, 0.6314557194709778, ...","[28.0, 3.475991725921631, 0.8413114547729492, ...","[4.0, 0.007331489585340023, 1.392192643834278e...","[11.0, 0.062020767480134964, 0.000560429005417..."
2,1367599797,"[78.0, 26.780254364013672, 13.415040969848633,...","[41.0, 12.142561912536621, 5.951859474182129, ...","[11.0, 4.3323211669921875, 2.161160469055176, ...","[7.0, 2.1093690395355225, 0.9442524313926697, ...","[23.0, 6.6719489097595215, 2.913447618484497, ...","[30.0, 9.518875122070312, 4.2043681144714355, ...","[25.0, 7.744740009307861, 3.39424204826355, 1....","[2.0, 0.5252547264099121, 0.23184120655059814,...","[11.0, 4.0218987464904785, 2.126202344894409, ...","[4.0, 2.0272626876831055, 1.0776073932647705, ...","[100.0, 25.64299964904785, 9.471589088439941, ...","[3.0, 0.004855748265981674, 8.05325817054836e-...","[12.0, 0.0572342574596405, 0.00030509982025250..."
3,1391886163,"[11.0, 5.413773536682129, 3.2894861698150635, ...","[23.0, 8.883554458618164, 4.721118450164795, 2...","[10.0, 3.3879919052124023, 1.6319491863250732,...","[8.0, 1.9532856941223145, 0.6671395897865295, ...","[13.0, 5.776050090789795, 3.2777152061462402, ...","[17.0, 6.013828277587891, 2.922703742980957, 1...","[16.0, 6.098169803619385, 3.082339286804199, 1...","[7.0, 2.3664896488189697, 1.018655776977539, 0...","[5.0, 3.004030704498291, 2.287025213241577, 1....","[8.0, 2.3232483863830566, 0.9722421765327454, ...","[35.0, 4.483249664306641, 0.968339204788208, 0...","[3.0, 0.005734436679631472, 1.103953763958998e...","[9.0, 0.043397918343544006, 0.0002502724819350..."
4,1408157926,"[33.0, 8.578935623168945, 2.700268268585205, 0...","[18.0, 6.049106121063232, 2.7733232975006104, ...","[6.0, 2.608611822128296, 1.2793641090393066, 0...","[7.0, 2.411283254623413, 1.3507696390151978, 0...","[11.0, 2.56019926071167, 0.9970377683639526, 0...","[15.0, 6.22273063659668, 3.067983388900757, 1....","[19.0, 4.564805507659912, 2.0996875762939453, ...","[3.0, 1.3252400159835815, 0.6238857507705688, ...","[7.0, 2.4722888469696045, 1.1537070274353027, ...",,"[4.0, 2.344971179962158, 1.6418575048446655, 1...","[1.0, 0.0015795428771525621, 2.49495565185498e...","[6.0, 0.09124578535556793, 0.00285489996895194..."
