In [1]:
# Cell 2 — Imports & Config
import os, time, pickle, random
import pandas as pd
import numpy as np
import torch
from torch_geometric.data import Data
from neo4j import GraphDatabase
from neo4j.exceptions import ServiceUnavailable, TransientError
from concurrent.futures import ThreadPoolExecutor, as_completed


# Neo4j connection
NEO4J_URI = "bolt://localhost:7687"
NEO4J_USER = "neo4j"
NEO4J_PASSWORD = "12345678"
DATABASE = "neo4jads"


# Radii per category (meters)
RADII = {"CERCA_DE_CAT1": 600.0, "CERCA_DE_CAT2": 1200.0, "CERCA_DE_CAT3": 2400.0}


# Classes (11)
CLASSES = [
'sport_and_leisure','medical','education_prim','veterinary','food_and_drink_stores',
'arts_and_entertainment','food_and_drink','park_like','security','religion','education_sup'
]

In [2]:
# Cell 3 — Cypher queries
NODE_QUERY = """
MATCH (d:Departamento {id:$apt_id})
RETURN elementId(d) AS id, 'Departamento' AS label,
d.id AS apt_id, d.latitude AS lat, d.longitude AS lon, null AS cat
UNION
MATCH (d:Departamento {id:$apt_id})-[:CERCA_DE_CAT1|CERCA_DE_CAT2|CERCA_DE_CAT3]->(p:POI)
WHERE p.class = $class
RETURN elementId(p) AS id, 'POI' AS label,
null AS apt_id, null AS lat, null AS lon, p.cat AS cat
"""


REL_QUERY = """
MATCH (d:Departamento {id:$apt_id})-[r:CERCA_DE_CAT1|CERCA_DE_CAT2|CERCA_DE_CAT3]->(p:POI)
WHERE p.class = $class
WITH elementId(d) AS source, elementId(p) AS target, type(r) AS t, r.distancia_metros AS dist
WITH source, target,
CASE t
WHEN 'CERCA_DE_CAT1' THEN CASE WHEN 1.0 - dist / $r_cat1 > 0.001 THEN 1.0 - dist / $r_cat1 ELSE 0.001 END
WHEN 'CERCA_DE_CAT2' THEN CASE WHEN 1.0 - dist / $r_cat2 > 0.001 THEN 1.0 - dist / $r_cat2 ELSE 0.001 END
ELSE CASE WHEN 1.0 - dist / $r_cat3 > 0.001 THEN 1.0 - dist / $r_cat3 ELSE 0.001 END
END AS weight
RETURN source, target, weight
"""

In [3]:
# Cell 4 — Export function (session-aware)
def export_apartment_class_graph(apt_id: int, class_name: str, session=None) -> Data | None:
    """Return PyG Data for (apt × class), or None if no nodes/edges.
    In threaded mode we pass a session; for ad-hoc tests session can be None.
    """
    # If no session provided, open a temporary one (kept for developer convenience)
    if session is None:
        driver = GraphDatabase.driver(NEO4J_URI, auth=(NEO4J_USER, NEO4J_PASSWORD))
        try:
            with driver.session(database=DATABASE) as s:
                return export_apartment_class_graph(apt_id, class_name, session=s)
        finally:
            driver.close()


    nodes = session.run(NODE_QUERY, {"apt_id": apt_id, "class": class_name}).data()
    if not nodes:
        return None
    node_df = pd.DataFrame(nodes)
    id_map = {nid: i for i, nid in enumerate(node_df["id"].tolist())}


    edges = session.run(
        REL_QUERY,
        {
            "apt_id": apt_id, "class": class_name,
            "r_cat1": RADII["CERCA_DE_CAT1"],
            "r_cat2": RADII["CERCA_DE_CAT2"],
            "r_cat3": RADII["CERCA_DE_CAT3"],
        },
    ).data()
    if not edges:
        print(f"No edges found for apartment {apt_id}, class {class_name}")
        return None
    edge_df = pd.DataFrame(edges)


    edge_index = torch.tensor([
        [id_map[s] for s in edge_df["source"]],
        [id_map[t] for t in edge_df["target"]],
    ], dtype=torch.long)
    edge_attr = torch.tensor(edge_df["weight"].values, dtype=torch.float).unsqueeze(1)


    feats: list[list[float]] = []
    for _, row in node_df.iterrows():
        if row["label"] == "Departamento":
            feats.append([1.0, 0.0, 0.0, 0.0])
        else:
            onehot = [0.0, 0.0, 0.0, 0.0]
            cat_val = row.get("cat")
            if cat_val is not None:
                try:
                    idx = int(cat_val)
                    if 0 <= idx < 4:
                        onehot[idx] = 1.0
                except (ValueError, TypeError):
                    pass
            feats.append(onehot)


    x = torch.tensor(np.array(feats), dtype=torch.float)
    return Data(x=x, edge_index=edge_index, edge_attr=edge_attr, apt_id=apt_id, poi_class=class_name)

In [None]:
# Batch/parallel config
OUTPUT_FILE = "apartment_graphs_bottom.pkl" # this machine's file
BATCH_SIZE = 1000 # change freely
MAX_WORKERS = 8 # adjust to CPU/DB capacity
MAX_RETRIES = 2 # apartment-level retries
RETRY_DELAY = 5

In [27]:
# Cell 5 — Load dataset & prepare pending IDs (bottom-up)
df_deptos = pd.read_csv('Datasets/dataset_final.csv')

if os.path.exists(OUTPUT_FILE):
    with open(OUTPUT_FILE, 'rb') as f:
        all_graphs: dict[int, dict[str, Data | None]] = pickle.load(f)
    done_ids = set(all_graphs.keys())
    print(f"Resuming: {len(done_ids)} apartments already done (bottom-up file).")
else:
    all_graphs = {}
    done_ids = set()


apt_ids = df_deptos['id'].tolist()[::-1] # bottom-up
pending_ids = [i for i in apt_ids if i not in done_ids]
batch_ids = pending_ids[:BATCH_SIZE]
print(f"Processing {len(batch_ids)} apartments in this batch (remaining after this: {len(pending_ids)-len(batch_ids)})")

Resuming: 4950 apartments already done (bottom-up file).
Processing 5000 apartments in this batch (remaining after this: 15265)


In [28]:
def process_apartment_once(apt_id: int) -> tuple[int, dict[str, Data | None] | None, str | None]:
    """One attempt. Returns (apt_id, graphs_dict, error_str)."""
    try:
        driver = GraphDatabase.driver(NEO4J_URI, auth=(NEO4J_USER, NEO4J_PASSWORD))
        try:
            with driver.session(database=DATABASE) as session:
                graphs: dict[str, Data | None] = {}
                for cls in CLASSES:
                    g = export_apartment_class_graph(apt_id, cls, session=session)
                    graphs[cls] = g # may be None if no POIs
                return apt_id, graphs, None
        finally:
            driver.close()
    except (ServiceUnavailable, TransientError, Exception) as e:
        return apt_id, None, str(e)


def process_apartment_with_retries(apt_id: int) -> tuple[int, dict[str, Data | None] | None, str | None]:
    for attempt in range(1, MAX_RETRIES + 2):
        aid, graphs, err = process_apartment_once(apt_id)
        if err is None:
            return aid, graphs, None
        if attempt <= MAX_RETRIES:
            time.sleep(RETRY_DELAY)
    return aid, None, err # final failure

In [29]:
start_batch = time.time()


results: list[tuple[int, dict[str, Data | None] | None, str | None]] = []
with ThreadPoolExecutor(max_workers=MAX_WORKERS) as ex:
    futs = {ex.submit(process_apartment_with_retries, aid): aid for aid in batch_ids}
    for i, fut in enumerate(as_completed(futs), 1):
        aid = futs[fut]
        try:
            apt_id, graphs, err = fut.result()
        except Exception as e:
            apt_id, graphs, err = aid, None, f"worker crash: {e}"
        ok = (err is None and graphs is not None)
        status = "OK" if ok else f"ERR: {err}"
        print(f"[{i}/{len(batch_ids)}] apt={apt_id} -> {status}")
        if ok:
            all_graphs[apt_id] = graphs
            with open(OUTPUT_FILE, 'wb') as f:
                pickle.dump(all_graphs, f)


print(f"✅ Batch finished in {time.time()-start_batch:.2f}s. Total saved so far: {len(all_graphs)}")

[1/5000] apt=2859780596 -> OK
[2/5000] apt=2844409034 -> OK
[3/5000] apt=2848023442 -> OK
[4/5000] apt=1569588671 -> OK
[5/5000] apt=1537274739 -> OK
[6/5000] apt=2849630118 -> OK
[7/5000] apt=2861588902 -> OK
[8/5000] apt=2863530320 -> OK
[9/5000] apt=2859410746 -> OK
[10/5000] apt=2836626950 -> OK
[11/5000] apt=2849514214 -> OK
[12/5000] apt=2854169188 -> OK
[13/5000] apt=1584098509 -> OK
[14/5000] apt=1589143899 -> OK
[15/5000] apt=2843365020 -> OK
[16/5000] apt=1589193779 -> OK
[17/5000] apt=2858092214 -> OK
[18/5000] apt=2855701542 -> OK
No edges found for apartment 1589105543, class religion
No edges found for apartment 2842128396, class religion
[19/5000] apt=2843439822 -> OK
No edges found for apartment 2742373722, class veterinary
No edges found for apartment 2742373722, class food_and_drink_stores
No edges found for apartment 2815058264, class religion
No edges found for apartment 2742373722, class food_and_drink
No edges found for apartment 1586650997, class veterinary
No ed

KeyboardInterrupt: 

In [21]:
N_LAST = 10 # change as needed
import pickle


with open(OUTPUT_FILE, 'rb') as f:
    saved: dict[int, dict[str, Data | None]] = pickle.load(f)


print(f"Total saved apartments: {len(saved)}")


# Dicts preserve insertion order; take the last N inserted IDs
last_ids = list(saved.keys())[-N_LAST:]
print(f"Inspecting last {len(last_ids)} apartments (most recent first shown below):")


for aid in last_ids:
    graphs = saved[aid]
    # sanity: ensure we have all 11 class keys (some may be None by design)
    missing_keys = [c for c in CLASSES if c not in graphs]
    none_cnt = sum(1 for v in graphs.values() if v is None)
    print(f"Apartment {aid}: classes={len(graphs)} (missing_keys={missing_keys}, none={none_cnt})")
    for cls in CLASSES:
        g = graphs.get(cls)
        if g is None:
            print(f" - {cls:<24} -> None")
        else:
            print(f" - {cls:<24} -> x={tuple(g.x.shape)}, edges={tuple(g.edge_index.shape)}, attr={tuple(g.edge_attr.shape)}")

Total saved apartments: 150
Inspecting last 10 apartments (most recent first shown below):
Apartment 2829175322: classes=11 (missing_keys=[], none=0)
 - sport_and_leisure        -> x=(79, 4), edges=(2, 78), attr=(78, 1)
 - medical                  -> x=(13, 4), edges=(2, 12), attr=(12, 1)
 - education_prim           -> x=(8, 4), edges=(2, 7), attr=(7, 1)
 - veterinary               -> x=(5, 4), edges=(2, 4), attr=(4, 1)
 - food_and_drink_stores    -> x=(15, 4), edges=(2, 14), attr=(14, 1)
 - arts_and_entertainment   -> x=(12, 4), edges=(2, 11), attr=(11, 1)
 - food_and_drink           -> x=(8, 4), edges=(2, 7), attr=(7, 1)
 - park_like                -> x=(7, 4), edges=(2, 6), attr=(6, 1)
 - security                 -> x=(6, 4), edges=(2, 5), attr=(5, 1)
 - religion                 -> x=(3, 4), edges=(2, 2), attr=(2, 1)
 - education_sup            -> x=(13, 4), edges=(2, 12), attr=(12, 1)
Apartment 2856021976: classes=11 (missing_keys=[], none=0)
 - sport_and_leisure        -> x=(8, 4)