In [None]:
# Notebook: Export MetroStation & BusStop Star Graphs (Python 3.9)
# ========================METRO VERSION===========================
# This file contains two self-contained sections so you can create
# two notebooks if you prefer. Each section follows the same pattern:
# - Threaded export (I/O-bound) with retries
# - Per-batch shard saved into Graph_data/
# - CSV of done apartment IDs for resume
# - PyG Data node features are minimal and consistent per tag
#
# MetroStation:
#   Node features: [is_apartment, is_metro, one-hot(linea) over LINEA_VOCAB]
#   Edge weight: distancia_metros (as provided)
# BusStop:
#   Node features: [is_apartment, is_bus]
#   Edge weight: distancia_metros (as provided)
#
# Both keep `apt_id` stored on the Data object.

In [1]:

# ============================
# Section 0 — Common Imports
# ============================
import os, time, pickle
from typing import Optional, Dict, List, Tuple, Set

import pandas as pd
import numpy as np
import torch
from torch_geometric.data import Data

from neo4j import GraphDatabase
from neo4j.exceptions import ServiceUnavailable, TransientError
from concurrent.futures import ThreadPoolExecutor, as_completed
from pathlib import Path
from datetime import datetime

In [2]:
# Neo4j connection. The defaults match the local development setup; each value
# can be overridden through the environment variable named below, so real
# credentials do not have to be committed inside the notebook.
NEO4J_URI = os.environ.get("NEO4J_URI", "bolt://localhost:7687")
NEO4J_USER = os.environ.get("NEO4J_USER", "neo4j")
NEO4J_PASSWORD = os.environ.get("NEO4J_PASSWORD", "12345678")
DATABASE = os.environ.get("NEO4J_DATABASE", "neo4jads")

# Dataset (CSV with an `id` column holding the apartment ids)
DF_PATH = 'Datasets/dataset_final.csv'

# Threading / retries
MAX_WORKERS = 8            # old PC: 8 cores / 16 threads
MAX_RETRIES = 2            # extra attempts after the first failed one
RETRY_DELAY = 5            # seconds slept between attempts

# Shard dir (shared base folder, different filenames per tag)
SHARD_DIR_BASE = Path("Graph_data")
# parents=True keeps this working if the base folder is later moved to a nested path.
SHARD_DIR_BASE.mkdir(parents=True, exist_ok=True)

# ==================================
# Utilities (shared by both sections)
# ==================================

def atomic_dump(obj: dict, path: Path) -> None:
    """Pickle `obj` to `path` atomically so an interrupt cannot corrupt the target.

    The payload is first written to a sibling ``<name><suffix>.tmp`` file,
    flushed and fsynced, and only then moved over `path` with ``os.replace``.

    Args:
        obj: Object to pickle (the callers in this notebook pass a dict).
        path: Destination file; a ``str`` is accepted as well as a ``Path``.

    Raises:
        Whatever ``pickle.dump`` / the file operations raise. On any failure
        the temporary file is removed and an existing `path` is left untouched.
    """
    path = Path(path)
    tmp = path.with_suffix(path.suffix + ".tmp")
    try:
        with open(tmp, "wb") as f:
            pickle.dump(obj, f, protocol=pickle.HIGHEST_PROTOCOL)
            f.flush()
            os.fsync(f.fileno())
        os.replace(tmp, path)
    except BaseException:
        # Do not leave a half-written temp file behind (also on Ctrl-C).
        try:
            os.remove(tmp)
        except OSError:
            pass
        raise


def run_threaded_over_ids(batch_ids: List[int], worker_fn, max_workers: Optional[int] = None):
    """Run `worker_fn` over `batch_ids` in a thread pool and collect the successes.

    Args:
        batch_ids: Apartment ids to process (any iterable; it is materialised once).
        worker_fn: Callable ``worker_fn(apt_id) -> (apt_id, payload, err_str)``.
            A result counts as successful only when ``err_str`` is None and
            ``payload`` is not None.
        max_workers: Pool size. Defaults to the module-level ``MAX_WORKERS``.

    Returns:
        ``(results, elapsed)`` where ``results`` maps apt_id -> payload for the
        successful ids and ``elapsed`` is the wall time in seconds.

    An exception raised by a worker (or a malformed return value) is reported
    as ``"worker crash: ..."`` for that id and does not abort the batch.
    """
    if max_workers is None:
        max_workers = MAX_WORKERS
    batch_ids = list(batch_ids)
    total = len(batch_ids)

    results: Dict[int, object] = {}
    # perf_counter is monotonic, so the duration is immune to wall-clock adjustments.
    start_batch = time.perf_counter()
    with ThreadPoolExecutor(max_workers=max_workers) as ex:
        futs = {ex.submit(worker_fn, aid): aid for aid in batch_ids}
        for i, fut in enumerate(as_completed(futs), 1):
            aid = futs[fut]
            try:
                apt_id, payload, err = fut.result()
            except Exception as e:
                apt_id, payload, err = aid, None, f"worker crash: {e}"
            ok = (err is None and payload is not None)
            status = "OK" if ok else f"ERR: {err}"
            print(f"[{i}/{total}] apt={apt_id} -> {status}")
            if ok:
                results[apt_id] = payload
    elapsed = time.perf_counter() - start_batch
    return results, elapsed


In [3]:
# =====================================================
# Section A — MetroStation (tag: :MetroStation, rel: CERCA_METRO)
# =====================================================

# ---------- Config (Metro) ----------
# Tag written to every exported Data object as `poi_tag`.
METRO_CLASSES_NAME = "metro"

# Metro artifacts live directly in the shared base folder; the file names
# below (prefix / done-list) are what identify them as metro files.
SHARD_DIR_METRO = SHARD_DIR_BASE
SHARD_PREFIX_METRO = "METROSHARD"
DONE_CSV_METRO = SHARD_DIR_METRO / "done_ids_metro.csv"

# Maximum number of pending apartments taken per run_metro_batch() call
# (adjust freely per run).
BATCH_SIZE_METRO = 24215

# Metro `linea` vocabulary used for the one-hot slice of the node features.
# Add/adjust as needed. Note that export_metro_graph falls back to an "OTHER"
# entry, which is not part of this list, so unknown lines get an all-zero slice.
LINEA_VOCAB: List[str] = ["L1", "L2", "L3", "L4", "L4A", "L5", "L6"]
# Position of each line inside the one-hot slice.
LINEA_INDEX = dict(zip(LINEA_VOCAB, range(len(LINEA_VOCAB))))

# ---------- Queries (Metro) ----------
# Node query for one apartment's star graph. The first branch returns the
# Departamento itself; the second returns every MetroStation it points to via
# CERCA_METRO. Both branches expose the same columns (id, label, apt_id, linea),
# with null in the column that does not apply to that node type. UNION (without
# ALL) drops duplicate rows, so a station linked more than once appears once.
# Parameter: $apt_id.
NODE_QUERY_METRO = """
MATCH (d:Departamento {id:$apt_id})
RETURN elementId(d) AS id, 'Departamento' AS label, d.id AS apt_id, null AS linea
UNION
MATCH (d:Departamento {id:$apt_id})-[:CERCA_METRO]->(m:MetroStation)
RETURN elementId(m) AS id, 'MetroStation' AS label, null AS apt_id, m.linea AS linea
"""

# Relationship query: one row per CERCA_METRO relationship from the apartment,
# using the same elementId() values as the node query for source/target and
# r.distancia_metros as the edge weight. Parameter: $apt_id.
REL_QUERY_METRO = """
MATCH (d:Departamento {id:$apt_id})-[r:CERCA_METRO]->(m:MetroStation)
RETURN elementId(d) AS source, elementId(m) AS target, r.distancia_metros AS weight
"""

# ---------- Resume helpers (Metro) ----------

def list_shards_metro() -> List[Path]:
    """Return all metro shard pickles found in SHARD_DIR_METRO, sorted by path."""
    pattern = f"{SHARD_PREFIX_METRO}_*.pkl"
    shard_paths = list(SHARD_DIR_METRO.glob(pattern))
    shard_paths.sort()
    return shard_paths


def load_done_ids_metro() -> Set[int]:
    """Return the set of apartment ids that were already exported.

    If DONE_CSV_METRO exists it is the only source: every line that consists
    solely of digits is parsed as an id, anything else is ignored.

    Otherwise the ids are rebuilt from the keys of every shard returned by
    list_shards_metro() (unreadable shards are skipped with a warning) and,
    when at least one id was found, written to DONE_CSV_METRO in sorted order.
    """
    if DONE_CSV_METRO.exists():
        with open(DONE_CSV_METRO, "r") as fh:
            tokens = (raw.strip() for raw in fh)
            return {int(tok) for tok in tokens if tok.isdigit()}

    # Fallback: scan shards once and rebuild the done-list from their keys.
    recovered: Set[int] = set()
    for shard in list_shards_metro():
        try:
            with open(shard, "rb") as fh:
                graphs = pickle.load(fh)  # dict[int -> Data]
            recovered.update(set(graphs.keys()))
        except Exception as e:
            print(f"[warn] skipping shard {shard.name}: {e}")

    if recovered:
        with open(DONE_CSV_METRO, "w") as fh:
            fh.writelines(f"{aid}\n" for aid in sorted(recovered))
    return recovered


def append_done_ids_metro(ids: List[int]) -> None:
    """Append `ids`, one per line, to DONE_CSV_METRO; an empty list is a no-op."""
    if ids:
        with open(DONE_CSV_METRO, "a") as fh:
            fh.writelines(f"{aid}\n" for aid in ids)

# ---------- Export (Metro) ----------

def export_metro_graph(apt_id: int, session=None) -> Optional[Data]:
    """Build the apartment -> MetroStation star graph for one apartment.

    Args:
        apt_id: Value matched against `Departamento.id`.
        session: Open Neo4j session to query with. When None, a temporary
            driver and session are created for this call and closed afterwards.

    Returns:
        A `Data` object carrying `x`, `edge_index`, `edge_attr`, `apt_id` and
        `poi_tag`, or None when either the node query or the relationship
        query returns no rows.

    Node features: [is_apartment, is_metro, one-hot(linea) over LINEA_VOCAB].
    Edge attribute: `distancia_metros` as returned by the query, shape (E, 1).
    """
    if session is None:
        # No session supplied: open a short-lived one and re-enter with it.
        own_driver = GraphDatabase.driver(NEO4J_URI, auth=(NEO4J_USER, NEO4J_PASSWORD))
        try:
            with own_driver.session(database=DATABASE) as own_session:
                return export_metro_graph(apt_id, session=own_session)
        finally:
            own_driver.close()

    node_records = session.run(NODE_QUERY_METRO, {"apt_id": apt_id}).data()
    if not node_records:
        return None
    node_df = pd.DataFrame(node_records)
    # Neo4j element id -> row position in node_df (== node index in the graph).
    node_ids = node_df["id"].tolist()
    id_map = dict(zip(node_ids, range(len(node_ids))))

    edge_records = session.run(REL_QUERY_METRO, {"apt_id": apt_id}).data()
    if not edge_records:
        return None
    edge_df = pd.DataFrame(edge_records)

    src_idx = [id_map[s] for s in edge_df["source"]]
    dst_idx = [id_map[t] for t in edge_df["target"]]
    edge_index = torch.tensor([src_idx, dst_idx], dtype=torch.long)
    edge_attr = torch.tensor(edge_df["weight"].values, dtype=torch.float).unsqueeze(1)

    # Features: [is_apartment, is_metro, one-hot LINEA_VOCAB]
    n_lines = len(LINEA_VOCAB)
    feats: List[List[float]] = []
    for label, linea_raw in zip(node_df["label"], node_df["linea"]):
        one_hot = [0.0] * n_lines
        if label == "Departamento":
            feats.append([1.0, 0.0] + one_hot)
            continue

        # MetroStation: normalise the line name before the vocabulary lookup.
        if linea_raw is None:
            key = "OTHER"
        else:
            key = str(linea_raw).strip().upper()
        # Unknown lines fall back to an "OTHER" slot if the vocabulary has one;
        # without it the one-hot slice simply stays all zeros.
        idx = LINEA_INDEX.get(key, LINEA_INDEX.get("OTHER"))
        if idx is not None:
            one_hot[idx] = 1.0
        feats.append([0.0, 1.0] + one_hot)

    x = torch.tensor(np.array(feats), dtype=torch.float)

    return Data(
        x=x,
        edge_index=edge_index,
        edge_attr=edge_attr,
        apt_id=apt_id,
        poi_tag=METRO_CLASSES_NAME,
    )

# ---------- Workers (Metro) ----------

def process_apartment_metro_once(apt_id: int) -> Tuple[int, Optional[Data], Optional[str]]:
    """Make a single export attempt for one apartment; never raises.

    Returns:
        ``(apt_id, graph, None)`` on success, where ``graph`` may be None when
        export_metro_graph found nothing to export, or
        ``(apt_id, None, message)`` when any exception occurred.
    """
    # NOTE(review): a fresh driver is created and closed for every apartment;
    # sharing one driver across workers may be cheaper - confirm against the
    # Neo4j driver documentation before changing.
    try:
        driver = GraphDatabase.driver(NEO4J_URI, auth=(NEO4J_USER, NEO4J_PASSWORD))
        try:
            with driver.session(database=DATABASE) as session:
                graph = export_metro_graph(apt_id, session=session)
        finally:
            driver.close()
    except Exception as exc:  # also covers ServiceUnavailable / TransientError
        return apt_id, None, str(exc)
    return apt_id, graph, None


def process_apartment_metro_with_retries(apt_id: int) -> Tuple[int, Optional[Data], Optional[str]]:
    """Export one apartment, retrying failed attempts.

    Makes one attempt plus up to ``MAX_RETRIES`` retries, sleeping
    ``RETRY_DELAY`` seconds between attempts (not after the last one). A
    negative ``MAX_RETRIES`` is treated as 0, so at least one attempt is made.

    Returns:
        ``(apt_id, graph, None)`` as soon as an attempt succeeds (``graph`` may
        be None), otherwise ``(apt_id, None, last_error_message)``.
    """
    total_attempts = max(0, MAX_RETRIES) + 1
    err: Optional[str] = None
    for attempt in range(1, total_attempts + 1):
        _, g, err = process_apartment_metro_once(apt_id)
        if err is None:
            return apt_id, g, None
        if attempt < total_attempts:
            time.sleep(RETRY_DELAY)
    return apt_id, None, err

# ---------- Batch runner (Metro) ----------

def _persist_metro_results(ok_graphs: Dict[int, Data], succeeded_ids: List[int], shard_path: Path) -> None:
    """Write the collected graphs to `shard_path`, then record the finished ids."""
    # Shard first, done-list second: if we die in between, the ids are simply
    # re-exported on the next run instead of being marked done without data.
    if ok_graphs:
        atomic_dump(ok_graphs, shard_path)
        print(f"Saved {len(ok_graphs)} metro graphs to {shard_path.name}")
    if succeeded_ids:
        append_done_ids_metro(sorted(succeeded_ids))


def run_metro_batch():
    """Export the next batch of pending apartments into one metro shard.

    Reads the apartment ids from DF_PATH (bottom-up), drops those already in
    the done-list, takes up to BATCH_SIZE_METRO of the remainder and exports
    them with a pool of MAX_WORKERS threads. Graphs are saved to a new shard in
    SHARD_DIR_METRO; every apartment that finished without error is appended
    to the done-list, including those with no metro graph.

    If the run is interrupted (Ctrl-C / kernel interrupt) or crashes, queued
    work is cancelled, everything collected so far is still persisted so the
    next call resumes from there, and the original exception is re-raised.
    """
    df = pd.read_csv(DF_PATH)
    done = load_done_ids_metro()
    # bottom-up; dict.fromkeys drops repeated ids while keeping that order, so
    # an apartment listed twice in the dataset is exported only once.
    apt_ids = list(dict.fromkeys(df['id'].tolist()[::-1]))
    pending = [i for i in apt_ids if i not in done]

    batch_ids = pending[:BATCH_SIZE_METRO]
    print(f"Metro: processing {len(batch_ids)} apartments (remaining after: {len(pending)-len(batch_ids)})")
    if not batch_ids:
        print("Metro: nothing to do.")
        return

    ts = datetime.now().strftime("%Y%m%d_%H%M%S")
    shard_name = f"{SHARD_PREFIX_METRO}_{ts}_{batch_ids[0]}-{batch_ids[-1]}.pkl"
    shard_path = SHARD_DIR_METRO / shard_name
    print(f"Metro shard: {shard_path}")

    start_batch = time.time()
    ok_graphs: Dict[int, Data] = {}
    succeeded_ids: List[int] = []   # apartments that finished (graph or None)
    failed_ids: List[int] = []      # apartments that errored out
    interrupted = False

    # The executor is managed by hand (no `with`) so that an interruption can
    # cancel the queued apartments instead of waiting for the whole batch.
    executor = ThreadPoolExecutor(max_workers=MAX_WORKERS)
    try:
        futs = {executor.submit(process_apartment_metro_with_retries, aid): aid for aid in batch_ids}
        for i, fut in enumerate(as_completed(futs), 1):
            aid = futs[fut]
            try:
                apt_id, g, err = fut.result()
            except Exception as e:
                apt_id, g, err = aid, None, f"worker crash: {e}"
            if err is None:
                # success (even if g is None because no metro)
                succeeded_ids.append(apt_id)
                if g is not None:
                    ok_graphs[apt_id] = g
                print(f"[{i}/{len(batch_ids)}] apt={apt_id} -> OK ({'graph' if g is not None else 'none'})")
            else:
                failed_ids.append(apt_id)
                print(f"[{i}/{len(batch_ids)}] apt={apt_id} -> ERR: {err}")
    except BaseException:
        interrupted = True
        # cancel_futures requires Python 3.9+. Apartments already running finish
        # in the background; their results are discarded and redone next run.
        executor.shutdown(wait=False, cancel_futures=True)
        raise
    else:
        executor.shutdown(wait=True)
    finally:
        # Runs on success and on interruption alike: save graphs (only
        # non-None) and mark ALL collected successes done (including None).
        _persist_metro_results(ok_graphs, succeeded_ids, shard_path)

        elapsed = time.time() - start_batch
        headline = "⚠️ Metro batch interrupted after" if interrupted else "✅ Metro batch finished in"
        print(f"{headline} {elapsed:.2f}s "
              f"(ok: {len(succeeded_ids)}, with_graphs: {len(ok_graphs)}, fail: {len(failed_ids)})")





In [4]:
# Export the next batch of pending apartments; ids already listed in the
# metro done-list are skipped, so re-running this cell resumes the export.
run_metro_batch()

Metro: processing 24215 apartments (remaining after: 0)
Metro shard: Graph_data\METROSHARD_20250826_091005_1553843137-1548097259.pkl
[1/24215] apt=2861904484 -> OK (none)
[2/24215] apt=2831965424 -> OK (none)
[3/24215] apt=1553843137 -> OK (none)
[4/24215] apt=1583915323 -> OK (none)
[5/24215] apt=2854009162 -> OK (graph)
[6/24215] apt=2852968410 -> OK (graph)
[7/24215] apt=2863800448 -> OK (graph)
[8/24215] apt=1590619687 -> OK (graph)
[9/24215] apt=2755339950 -> OK (graph)
[10/24215] apt=2862503462 -> OK (graph)
[11/24215] apt=1586715195 -> OK (graph)
[12/24215] apt=1562186953 -> OK (graph)
[13/24215] apt=2843398052 -> OK (graph)
[14/24215] apt=2842876652 -> OK (none)
[15/24215] apt=2854294366 -> OK (graph)
[16/24215] apt=1552415353 -> OK (graph)
[17/24215] apt=2856072040 -> OK (none)
[18/24215] apt=2864123692 -> OK (graph)
[19/24215] apt=2643529294 -> OK (none)
[20/24215] apt=2851480890 -> OK (none)
[21/24215] apt=1586662767 -> OK (graph)
[22/24215] apt=1586204259 -> OK (none)
[23/2

In [5]:
# Optional: quick inspect last shard (Metro)

def inspect_latest_metro_shard(n_last: int = 10):
    """Print tensor shapes for the last `n_last` graphs of the newest metro shard.

    "Newest" is the last path returned by list_shards_metro(), i.e. the last
    one in sorted path order.

    Args:
        n_last: How many trailing entries of the shard to show. Values <= 0
            show no entries (only the shard name is printed).
    """
    shards = list_shards_metro()
    if not shards:
        print("No metro shards.")
        return
    latest = shards[-1]
    print(f"Latest metro shard: {latest.name}")
    # pickle can execute arbitrary code on load; only open shards this
    # notebook produced itself.
    with open(latest, "rb") as f:
        d = pickle.load(f)
    if n_last <= 0:
        # Guard needed because a [-0:] slice would select every entry.
        return
    ids = list(d.keys())[-n_last:]
    for aid in ids:
        g = d[aid]
        print(f"apt {aid}: x={tuple(g.x.shape)}, edges={tuple(g.edge_index.shape)}, attr={tuple(g.edge_attr.shape)}")



In [6]:
# Sanity check: print the tensor shapes of the last 10 graphs in the newest shard.
inspect_latest_metro_shard(10)

Latest metro shard: METROSHARD_20250826_091005_1553843137-1548097259.pkl
apt 1542955233: x=(4, 9), edges=(2, 3), attr=(3, 1)
apt 2794613142: x=(5, 9), edges=(2, 4), attr=(4, 1)
apt 2859481394: x=(3, 9), edges=(2, 2), attr=(2, 1)
apt 2744761146: x=(3, 9), edges=(2, 2), attr=(2, 1)
apt 2863559152: x=(6, 9), edges=(2, 5), attr=(5, 1)
apt 2856615258: x=(3, 9), edges=(2, 2), attr=(2, 1)
apt 2860713782: x=(5, 9), edges=(2, 4), attr=(4, 1)
apt 2853462940: x=(3, 9), edges=(2, 2), attr=(2, 1)
apt 2852083842: x=(5, 9), edges=(2, 4), attr=(4, 1)
apt 2851992762: x=(3, 9), edges=(2, 2), attr=(2, 1)
