In [None]:
# Notebook: Export MetroStation & BusStop Star Graphs (Python 3.9)
# ==========================BUS VERSION============================
# The template defines two self-contained sections (MetroStation and
# BusStop) so each can live in its own notebook; this copy contains only
# the BusStop section. Each section follows the same pattern:
# - Threaded export (I/O-bound) with retries
# - Per-batch shard saved into Graph_data/
# - CSV of done apartment IDs for resume
# - PyG Data node features are minimal and consistent per tag
#
# MetroStation:
#   Node features: [is_apartment, is_metro]
#   Edge weight: distancia_metros (as provided)
# BusStop:
#   Node features: [is_apartment, is_bus]
#   Edge weight: distancia_metros (as provided)
#
# Both keep `apt_id` stored on the Data object.

In [1]:

# ============================
# Section 0 — Common Imports
# ============================
import os, time, pickle
from typing import Optional, Dict, List, Tuple, Set

import pandas as pd
import numpy as np
import torch
from torch_geometric.data import Data

from neo4j import GraphDatabase
from neo4j.exceptions import ServiceUnavailable, TransientError
from concurrent.futures import ThreadPoolExecutor, as_completed
from pathlib import Path
from datetime import datetime

In [2]:
# Neo4j connection. Each setting can be overridden via an environment variable
# (NEO4J_URI, NEO4J_USER, NEO4J_PASSWORD, NEO4J_DATABASE) so credentials need
# not be edited into, or committed with, the notebook. The literals below are
# only local-development fallbacks.
NEO4J_URI = os.environ.get("NEO4J_URI", "bolt://localhost:7687")
NEO4J_USER = os.environ.get("NEO4J_USER", "neo4j")
NEO4J_PASSWORD = os.environ.get("NEO4J_PASSWORD", "12345678")
DATABASE = os.environ.get("NEO4J_DATABASE", "neo4jads")

# Dataset: CSV whose `id` column lists the apartments to export
DF_PATH = 'Datasets/dataset_final.csv'

# Threading / retries
MAX_WORKERS = 8            # thread-pool size; old PC: 8 cores / 16 threads
MAX_RETRIES = 2            # extra attempts after a failed one (MAX_RETRIES + 1 in total)
RETRY_DELAY = 5            # seconds to wait between attempts

# Shard dir (shared base folder, different filenames per tag)
SHARD_DIR_BASE = Path("Graph_data")
SHARD_DIR_BASE.mkdir(parents=True, exist_ok=True)

# ==================================
# Utilities (shared by both sections)
# ==================================

def atomic_dump(obj: dict, path: Path) -> None:
    """Pickle ``obj`` to ``path`` atomically to avoid corruption on interrupts.

    The data is first written and fsync'ed to a sibling ``<name><suffix>.tmp``
    file, which is then moved over ``path`` with ``os.replace``. An interrupt
    therefore leaves either the previous file or the complete new one in place,
    never a truncated pickle. If anything fails before the replace (unpicklable
    object, disk full, KeyboardInterrupt, ...), the temporary file is removed
    and the exception is re-raised.

    Args:
        obj: object to pickle (the shard dicts in this notebook).
        path: destination file.
    """
    tmp = path.with_suffix(path.suffix + ".tmp")
    try:
        with open(tmp, "wb") as f:
            pickle.dump(obj, f, protocol=pickle.HIGHEST_PROTOCOL)
            f.flush()
            os.fsync(f.fileno())
        os.replace(tmp, path)
    except BaseException:
        # Don't leave a half-written .tmp file behind on failure.
        tmp.unlink(missing_ok=True)
        raise


def run_threaded_over_ids(batch_ids: List[int], worker_fn, max_workers: Optional[int] = None):
    """Run ``worker_fn`` over ``batch_ids`` in a thread pool and collect successes.

    Args:
        batch_ids: ids to process; each is submitted to ``worker_fn`` once.
        worker_fn: callable ``worker_fn(apt_id) -> (apt_id, data_or_dict, err_str)``.
            A call counts as successful only when ``err_str`` is None AND the
            payload is not None. Exceptions raised by ``worker_fn`` are caught
            and reported as failures ("worker crash: ...").
        max_workers: thread-pool size; defaults to the module-level
            ``MAX_WORKERS`` when omitted.

    Returns:
        ``(results, elapsed)``: ``results`` maps the apt_id returned by the
        worker to its payload, for successful calls only; ``elapsed`` is the
        wall-clock duration of the batch in seconds.

    One progress line is printed per completed id, in completion order.
    """
    pool_size = MAX_WORKERS if max_workers is None else max_workers
    results: Dict[int, object] = {}
    total = len(batch_ids)
    start_batch = time.time()
    with ThreadPoolExecutor(max_workers=pool_size) as ex:
        futs = {ex.submit(worker_fn, aid): aid for aid in batch_ids}
        for i, fut in enumerate(as_completed(futs), 1):
            aid = futs[fut]
            try:
                apt_id, payload, err = fut.result()
            except Exception as e:
                # A raising worker is reported like any other failure instead
                # of aborting the whole batch.
                apt_id, payload, err = aid, None, f"worker crash: {e}"
            ok = err is None and payload is not None
            status = "OK" if ok else f"ERR: {err}"
            print(f"[{i}/{total}] apt={apt_id} -> {status}")
            if ok:
                results[apt_id] = payload
    elapsed = time.time() - start_batch
    return results, elapsed


In [5]:
# =====================================================
# Section B — BusStop (tag: :BusStop, rel: CERCA_BUS)
# =====================================================

# ---------- Config (Bus) ----------
# Tag stored on every exported Data object as `poi_tag`.
BUS_CLASSES_NAME = "bus"

# Bus shards share the common base folder and are told apart by their prefix.
SHARD_DIR_BUS = SHARD_DIR_BASE
SHARD_PREFIX_BUS = "BUSSHARD"
# Resume log: one finished apartment id per line.
DONE_CSV_BUS = SHARD_DIR_BUS.joinpath("done_ids_bus.csv")

# Max apartments handled per run_bus_batch() call (adjust freely per run).
BATCH_SIZE_BUS = 25115

# ---------- Queries (Bus) ----------
# Star-graph nodes for one apartment: the Departamento row itself plus every
# BusStop it reaches through a CERCA_BUS relationship. Both UNION branches
# return the same columns (id, label, apt_id); apt_id is only filled for the
# apartment row. UNION (not UNION ALL) drops duplicate rows, so a stop linked
# more than once still yields a single node.
NODE_QUERY_BUS = """
MATCH (d:Departamento {id:$apt_id})
RETURN elementId(d) AS id, 'Departamento' AS label, d.id AS apt_id
UNION
MATCH (d:Departamento {id:$apt_id})-[:CERCA_BUS]->(b:BusStop)
RETURN elementId(b) AS id, 'BusStop' AS label, null AS apt_id
"""

# Star-graph edges: one row per CERCA_BUS relationship, directed
# apartment -> bus stop, weighted by the raw `distancia_metros` property.
REL_QUERY_BUS = """
MATCH (d:Departamento {id:$apt_id})-[r:CERCA_BUS]->(b:BusStop)
RETURN elementId(d) AS source, elementId(b) AS target, r.distancia_metros AS weight
"""

# ---------- Resume helpers (Bus) ----------

def list_shards_bus() -> List[Path]:
    """Return the Bus shard files in SHARD_DIR_BUS, sorted by file name."""
    pattern = f"{SHARD_PREFIX_BUS}_*.pkl"
    shards = list(SHARD_DIR_BUS.glob(pattern))
    shards.sort()
    return shards


def load_done_ids_bus() -> Set[int]:
    """Return the set of apartment ids already handled by the Bus export.

    When the done-CSV (``DONE_CSV_BUS``) exists it is the only source used:
    every line that is all digits after stripping is parsed as an id, and all
    other lines are ignored. Otherwise ids are recovered from the keys of every
    readable Bus shard (unreadable shards are skipped with a warning). If any
    ids were found, the CSV is written so later runs take the fast path.
    """
    if DONE_CSV_BUS.exists():
        with open(DONE_CSV_BUS, "r") as fh:
            tokens = (raw.strip() for raw in fh)
            return {int(tok) for tok in tokens if tok.isdigit()}

    recovered: Set[int] = set()
    for shard in list_shards_bus():
        try:
            with open(shard, "rb") as fh:
                payload = pickle.load(fh)
            recovered.update(payload.keys())
        except Exception as e:
            print(f"[warn] skipping shard {shard.name}: {e}")

    if recovered:
        with open(DONE_CSV_BUS, "w") as fh:
            fh.writelines(f"{aid}\n" for aid in sorted(recovered))
    return recovered


def append_done_ids_bus(ids: List[int]) -> None:
    """Append ``ids`` to the Bus done-CSV, one per line; no-op for an empty list."""
    if ids:
        with open(DONE_CSV_BUS, "a") as fh:
            fh.writelines(f"{aid}\n" for aid in ids)

# ---------- Export (Bus) ----------

def export_bus_graph(apt_id: int, session=None) -> Optional[Data]:
    """Build the Bus star graph around one apartment.

    Args:
        apt_id: value matched against ``Departamento.id`` in Neo4j.
        session: an open Neo4j session to run the queries on. When omitted, a
            driver and session are created for this single call and the driver
            is closed afterwards.

    Returns:
        A ``Data`` object with
          * ``x``: one float row per node, ``[1, 0]`` for the Departamento and
            ``[0, 1]`` for every BusStop;
          * ``edge_index``: ``(2, E)`` long tensor of apartment -> stop edges,
            indexed by the row order of the node query result;
          * ``edge_attr``: ``(E, 1)`` raw ``distancia_metros`` values;
          * ``apt_id`` and ``poi_tag`` (= ``BUS_CLASSES_NAME``) attributes.
        Returns ``None`` when the node query or the edge query yields no rows.
    """
    if session is None:
        drv = GraphDatabase.driver(NEO4J_URI, auth=(NEO4J_USER, NEO4J_PASSWORD))
        try:
            with drv.session(database=DATABASE) as own_session:
                return export_bus_graph(apt_id, session=own_session)
        finally:
            drv.close()

    node_rows = session.run(NODE_QUERY_BUS, {"apt_id": apt_id}).data()
    if not node_rows:
        return None
    nodes = pd.DataFrame(node_rows)
    # Neo4j elementId -> row position, used as the node index in edge_index.
    index_of = dict(zip(nodes["id"].tolist(), range(len(nodes))))

    edge_rows = session.run(REL_QUERY_BUS, {"apt_id": apt_id}).data()
    if not edge_rows:
        return None
    edges = pd.DataFrame(edge_rows)

    src = [index_of[s] for s in edges["source"]]
    dst = [index_of[t] for t in edges["target"]]
    edge_index = torch.tensor([src, dst], dtype=torch.long)
    edge_attr = torch.tensor(edges["weight"].values, dtype=torch.float).unsqueeze(1)

    # One-hot node type: column 0 = apartment, column 1 = bus stop.
    is_apt = (nodes["label"] == "Departamento").to_numpy(dtype=np.float64)
    x = torch.tensor(np.column_stack((is_apt, 1.0 - is_apt)), dtype=torch.float)

    return Data(x=x, edge_index=edge_index, edge_attr=edge_attr,
                apt_id=apt_id, poi_tag=BUS_CLASSES_NAME)

# ---------- Workers (Bus) ----------

def process_apartment_bus_once(apt_id: int) -> Tuple[int, Optional[Data], Optional[str]]:
    """Export one apartment's Bus graph in a single attempt, never raising.

    Opens a dedicated Neo4j driver and session for this call and closes the
    driver afterwards.

    Returns:
        ``(apt_id, graph, None)`` on success, where ``graph`` may be None when
        ``export_bus_graph`` found no nodes or edges. Returns
        ``(apt_id, None, "<ExceptionType>: <message>")`` if any exception
        occurred.
    """
    try:
        driver = GraphDatabase.driver(NEO4J_URI, auth=(NEO4J_USER, NEO4J_PASSWORD))
        try:
            with driver.session(database=DATABASE) as session:
                return apt_id, export_bus_graph(apt_id, session=session), None
        finally:
            driver.close()
    except Exception as e:
        # Catches everything, connectivity errors such as ServiceUnavailable /
        # TransientError included, so one bad apartment can't kill the pool.
        # The type name is kept because str(e) alone can be empty or
        # ambiguous (e.g. KeyError('x') -> "'x'").
        return apt_id, None, f"{type(e).__name__}: {e}"


def process_apartment_bus_with_retries(apt_id: int) -> Tuple[int, Optional[Data], Optional[str]]:
    """Call ``process_apartment_bus_once`` up to ``MAX_RETRIES + 1`` times.

    Returns the result of the first attempt that reports no error. Between
    failed attempts it sleeps ``RETRY_DELAY`` seconds; there is no sleep after
    the final attempt. If every attempt fails, returns
    ``(apt_id, None, <error of the last attempt>)``.
    """
    attempts_allowed = MAX_RETRIES + 1
    for attempt in range(1, attempts_allowed + 1):
        last_aid, graph, last_err = process_apartment_bus_once(apt_id)
        if last_err is None:
            return last_aid, graph, None
        if attempt < attempts_allowed:
            time.sleep(RETRY_DELAY)
    return last_aid, None, last_err

# ---------- Batch runner (Bus) ----------

def run_bus_batch():
    """Export the next batch of pending apartments' Bus graphs into a new shard.

    1. Read apartment ids from the ``id`` column of ``DF_PATH``, starting from
       the last row (bottom-up). Drop ids already in the done set
       (``load_done_ids_bus``).
    2. Export up to ``BATCH_SIZE_BUS`` of them in a pool of ``MAX_WORKERS``
       threads. Each apartment is retried via
       ``process_apartment_bus_with_retries``.
    3. Pickle the non-None graphs into
       ``<SHARD_PREFIX_BUS>_<timestamp>_<first>-<last>.pkl``. Append every
       apartment that finished without error, with or without a graph, to the
       done-CSV. Apartments that errored are not recorded, so a later call
       picks them up again.

    If the run is interrupted (KeyboardInterrupt), apartments that have not
    started yet are cancelled. The results collected so far are still saved as
    in step 3, and then the KeyboardInterrupt is re-raised.
    """
    # Only the id column is needed; avoid loading the whole dataset.
    df = pd.read_csv(DF_PATH, usecols=["id"])
    done = load_done_ids_bus()
    apt_ids = df['id'].tolist()[::-1]  # bottom-up
    pending = [i for i in apt_ids if i not in done]

    batch_ids = pending[:BATCH_SIZE_BUS]
    print(f"Bus: processing {len(batch_ids)} apartments (remaining after: {len(pending)-len(batch_ids)})")
    if not batch_ids:
        print("Bus: nothing to do.")
        return

    ts = datetime.now().strftime("%Y%m%d_%H%M%S")
    shard_name = f"{SHARD_PREFIX_BUS}_{ts}_{batch_ids[0]}-{batch_ids[-1]}.pkl"
    shard_path = SHARD_DIR_BUS / shard_name
    print(f"Bus shard: {shard_path}")

    start_batch = time.time()
    total = len(batch_ids)
    ok_graphs: Dict[int, Data] = {}
    succeeded_ids: List[int] = []   # apartments that finished (graph or None)
    failed_ids: List[int] = []      # apartments that errored out
    interrupt: Optional[BaseException] = None

    # The executor is managed by hand (not `with`): its __exit__ waits for every
    # queued task, so after Ctrl-C the kernel would stay busy until the whole
    # batch ran and the collected results would then be thrown away.
    ex = ThreadPoolExecutor(max_workers=MAX_WORKERS)
    try:
        futs = {ex.submit(process_apartment_bus_with_retries, aid): aid for aid in batch_ids}
        for i, fut in enumerate(as_completed(futs), 1):
            aid = futs[fut]
            try:
                apt_id, g, err = fut.result()
            except Exception as e:
                apt_id, g, err = aid, None, f"worker crash: {e}"
            if err is None:
                # success (even if g is None because no bus stops)
                succeeded_ids.append(apt_id)
                if g is not None:
                    ok_graphs[apt_id] = g
                print(f"[{i}/{total}] apt={apt_id} -> OK ({'graph' if g is not None else 'none'})")
            else:
                failed_ids.append(apt_id)
                print(f"[{i}/{total}] apt={apt_id} -> ERR: {err}")
    except KeyboardInterrupt as exc:
        interrupt = exc
        print("Bus: interrupted - cancelling queued apartments and saving partial results")
    finally:
        # cancel_futures (Python 3.9+) drops tasks that have not started.
        # Tasks already running are awaited, but their results are discarded;
        # they are not marked done, so the next run redoes them.
        ex.shutdown(wait=True, cancel_futures=True)

    # Save graphs (only non-None) and mark ALL successes done (including None).
    # The shard is written before the CSV so an id is never marked done
    # without its graph being on disk.
    if ok_graphs:
        atomic_dump(ok_graphs, shard_path)
        print(f"Saved {len(ok_graphs)} bus graphs to {shard_path.name}")

    if succeeded_ids:
        append_done_ids_bus(sorted(succeeded_ids))

    elapsed = time.time() - start_batch
    print(f"✅ Bus batch finished in {elapsed:.2f}s "
          f"(ok: {len(succeeded_ids)}, with_graphs: {len(ok_graphs)}, fail: {len(failed_ids)})")

    if interrupt is not None:
        skipped = total - len(succeeded_ids) - len(failed_ids)
        print(f"Bus: {skipped} apartments left unprocessed; rerun to resume.")
        raise interrupt





In [6]:
run_bus_batch()

Bus: processing 25115 apartments (remaining after: 0)
Bus shard: Graph_data\BUSSHARD_20250826_153430_2862820058-1548097259.pkl
[1/25115] apt=1572173539 -> OK (graph)
[2/25115] apt=2862820058 -> OK (graph)
[3/25115] apt=2833913964 -> OK (graph)
[4/25115] apt=2857961014 -> OK (graph)
[5/25115] apt=1590737915 -> OK (graph)
[6/25115] apt=1540403013 -> OK (graph)
[7/25115] apt=2861515108 -> OK (graph)
[8/25115] apt=2860590024 -> OK (graph)
[9/25115] apt=2744371160 -> OK (graph)
[10/25115] apt=1583717711 -> OK (graph)
[11/25115] apt=1591516075 -> OK (graph)
[12/25115] apt=2843808890 -> OK (graph)
[13/25115] apt=2854484284 -> OK (graph)
[14/25115] apt=1572273345 -> OK (graph)
[15/25115] apt=1587866625 -> OK (graph)
[16/25115] apt=2800860534 -> OK (graph)
[17/25115] apt=2848726482 -> OK (graph)
[18/25115] apt=1584056011 -> OK (graph)
[19/25115] apt=1586351337 -> OK (graph)
[20/25115] apt=1587835167 -> OK (graph)
[21/25115] apt=1591789513 -> OK (graph)
[22/25115] apt=2839803586 -> OK (graph)
[2

In [7]:
# Optional: quick inspect last shard (Bus)

def inspect_latest_bus_shard(n_last: int = 10):
    """Print tensor shapes for the last ``n_last`` graphs of the newest Bus shard.

    The newest shard is the last path from ``list_shards_bus()``, which sorts
    by file name. Shard names carry a ``%Y%m%d_%H%M%S`` timestamp right after
    the fixed prefix, so that is the most recently written shard. "Last"
    follows the shard dict's insertion order. ``n_last <= 0`` prints no
    graphs.
    """
    shards = list_shards_bus()
    if not shards:
        print("No bus shards.")
        return
    latest = shards[-1]
    print(f"Latest bus shard: {latest.name}")
    if n_last <= 0:
        # Guard: slicing with [-0:] would select every graph, not none.
        return
    # Only load shards this notebook wrote: unpickling can run arbitrary code.
    with open(latest, "rb") as f:
        d = pickle.load(f)
    for aid in list(d)[-n_last:]:
        g = d[aid]
        print(f"apt {aid}: x={tuple(g.x.shape)}, edges={tuple(g.edge_index.shape)}, attr={tuple(g.edge_attr.shape)}")


# =============================
# Usage examples (one-liners):
# =============================
# (The *_metro helpers live in the MetroStation notebook; only the Bus ones
#  are defined here.)
# BATCH_SIZE_METRO = 700; run_metro_batch(); inspect_latest_metro_shard(10)
# BATCH_SIZE_BUS   = 700; run_bus_batch();   inspect_latest_bus_shard(10)

In [8]:
inspect_latest_bus_shard(10)

Latest bus shard: BUSSHARD_20250826_153430_2862820058-1548097259.pkl
apt 2849646914: x=(9, 2), edges=(2, 8), attr=(8, 1)
apt 1518864795: x=(10, 2), edges=(2, 9), attr=(9, 1)
apt 1588684845: x=(8, 2), edges=(2, 7), attr=(7, 1)
apt 1521489447: x=(12, 2), edges=(2, 11), attr=(11, 1)
apt 1543740023: x=(7, 2), edges=(2, 6), attr=(6, 1)
apt 2803401906: x=(7, 2), edges=(2, 6), attr=(6, 1)
apt 1536666991: x=(6, 2), edges=(2, 5), attr=(5, 1)
apt 1592811585: x=(5, 2), edges=(2, 4), attr=(4, 1)
apt 2732119230: x=(12, 2), edges=(2, 11), attr=(11, 1)
apt 1548097259: x=(9, 2), edges=(2, 8), attr=(8, 1)
