In [40]:
# Cell 1 — Imports & Config
import os, time, pickle, random
import pandas as pd
import numpy as np
import torch
from torch_geometric.data import Data
from neo4j import GraphDatabase
from neo4j.exceptions import ServiceUnavailable, TransientError
from concurrent.futures import ThreadPoolExecutor, as_completed

# Neo4j connection
NEO4J_URI = "bolt://localhost:7687"
NEO4J_USER = "neo4j"
NEO4J_PASSWORD = "12345678"
DATABASE = "neo4jads"

# Radii per category (meters)
RADII = {"CERCA_DE_CAT1": 600.0, "CERCA_DE_CAT2": 1200.0, "CERCA_DE_CAT3": 2400.0}

# Classes (11)
CLASSES = [
'sport_and_leisure','medical','education_prim','veterinary','food_and_drink_stores',
'arts_and_entertainment','food_and_drink','park_like','security','religion','education_sup'
]

In [41]:
# Cell 2 — Cypher queries
NODE_QUERY = """
MATCH (d:Departamento {id:$apt_id})
RETURN elementId(d) AS id, 'Departamento' AS label,
d.id AS apt_id, d.latitude AS lat, d.longitude AS lon, null AS cat
UNION
MATCH (d:Departamento {id:$apt_id})-[:CERCA_DE_CAT1|CERCA_DE_CAT2|CERCA_DE_CAT3]->(p:POI)
WHERE p.class = $class
RETURN elementId(p) AS id, 'POI' AS label,
null AS apt_id, null AS lat, null AS lon, p.cat AS cat
"""

REL_QUERY = """
MATCH (d:Departamento {id:$apt_id})-[r:CERCA_DE_CAT1|CERCA_DE_CAT2|CERCA_DE_CAT3]->(p:POI)
WHERE p.class = $class
WITH elementId(d) AS source, elementId(p) AS target, type(r) AS t, r.distancia_metros AS dist
WITH source, target,
CASE t
WHEN 'CERCA_DE_CAT1' THEN CASE WHEN 1.0 - dist / $r_cat1 > 0.001 THEN 1.0 - dist / $r_cat1 ELSE 0.001 END
WHEN 'CERCA_DE_CAT2' THEN CASE WHEN 1.0 - dist / $r_cat2 > 0.001 THEN 1.0 - dist / $r_cat2 ELSE 0.001 END
ELSE CASE WHEN 1.0 - dist / $r_cat3 > 0.001 THEN 1.0 - dist / $r_cat3 ELSE 0.001 END
END AS weight
RETURN source, target, weight
"""

In [None]:
# Cell 3 — Export function (session-aware)
def export_apartment_class_graph(apt_id: int, class_name: str, session=None) -> Data | None:
    """Return PyG Data for (apt × class), or None if no nodes/edges.
    In threaded mode we pass a session; for ad-hoc tests session can be None.
    """
    # If no session provided, open a temporary one (kept for developer convenience)
    if session is None:
        driver = GraphDatabase.driver(NEO4J_URI, auth=(NEO4J_USER, NEO4J_PASSWORD))
        try:
            with driver.session(database=DATABASE) as s:
                return export_apartment_class_graph(apt_id, class_name, session=s)
        finally:
            driver.close()

    nodes = session.run(NODE_QUERY, {"apt_id": apt_id, "class": class_name}).data()
    if not nodes:
        return None
    node_df = pd.DataFrame(nodes)
    id_map = {nid: i for i, nid in enumerate(node_df["id"].tolist())}

    edges = session.run(
        REL_QUERY,
        {
            "apt_id": apt_id, "class": class_name,
            "r_cat1": RADII["CERCA_DE_CAT1"],
            "r_cat2": RADII["CERCA_DE_CAT2"],
            "r_cat3": RADII["CERCA_DE_CAT3"],
        },
    ).data()
    if not edges:
        print(f"No edges found for apartment {apt_id}, class {class_name}")
        return None
    edge_df = pd.DataFrame(edges)

    edge_index = torch.tensor([
        [id_map[s] for s in edge_df["source"]],
        [id_map[t] for t in edge_df["target"]],
    ], dtype=torch.long)
    edge_attr = torch.tensor(edge_df["weight"].values, dtype=torch.float).unsqueeze(1)

    feats: list[list[float]] = []
    for _, row in node_df.iterrows():
        if row["label"] == "Departamento":
            feats.append([1.0, 0.0, 0.0, 0.0])
        else:
            onehot = [0.0, 0.0, 0.0, 0.0]
            cat_val = row.get("cat")
            if cat_val is not None:
                try:
                    idx = int(cat_val)
                    if 0 <= idx < 4:
                        onehot[idx] = 1.0
                except (ValueError, TypeError):
                    pass
            feats.append(onehot)

    x = torch.tensor(np.array(feats), dtype=torch.float)
    return Data(x=x, edge_index=edge_index, edge_attr=edge_attr, apt_id=apt_id, poi_class=class_name)

In [None]:
# Cell 4 - Pickle helpers (no .bak) ---
def safe_load_pickle(path: str) -> dict:
    import os, pickle
    if not os.path.exists(path) or os.path.getsize(path) == 0:
        print(f"[warn] {path} missing/empty; starting fresh.")
        return {}
    try:
        with open(path, "rb") as f:
            return pickle.load(f)
    except (EOFError, pickle.UnpicklingError) as e:
        print(f"[warn] failed to load {path}: {e}. Starting fresh.")
        return {}

def atomic_pickle_dump(obj: dict, path: str) -> None:
    import os, pickle
    tmp = path + ".tmp"
    with open(tmp, "wb") as f:
        pickle.dump(obj, f, protocol=pickle.HIGHEST_PROTOCOL)
        f.flush(); os.fsync(f.fileno())
    os.replace(tmp, path)  # atomic replace

In [None]:
# Cell 5 — Shard + CSV helpers (per-batch pickles)
from pathlib import Path
import glob

SHARD_DIR = Path("Graph_data")
SHARD_DIR.mkdir(exist_ok=True)
DONE_CSV = SHARD_DIR / "done_ids.csv"

def list_shards() -> list[Path]:
    return sorted(SHARD_DIR.glob("shard_*.pkl"))

def load_shard_keys(p: Path) -> set[int]:
    with open(p, "rb") as f:
        d = pickle.load(f)
    return set(d.keys())

def load_done_ids() -> set[int]:
    # Prefer CSV for speed; fallback to scanning shards once
    if DONE_CSV.exists():
        s: set[int] = set()
        with open(DONE_CSV, "r") as f:
            for line in f:
                line = line.strip()
                if line.isdigit():
                    s.add(int(line))
        return s
    done: set[int] = set()
    for p in list_shards():
        try:
            done |= load_shard_keys(p)
        except Exception as e:
            print(f"[warn] skipping shard {p.name}: {e}")
    if done:
        with open(DONE_CSV, "w") as f:
            for aid in sorted(done):
                f.write(f"{aid}\n")
    return done

def append_done_ids(ids: list[int]) -> None:
    if not ids: return
    with open(DONE_CSV, "a") as f:
        for aid in ids:
            f.write(f"{aid}\n")

def atomic_dump(obj: dict, path: Path) -> None:
    tmp = path.with_suffix(path.suffix + ".tmp")
    with open(tmp, "wb") as f:
        pickle.dump(obj, f, protocol=pickle.HIGHEST_PROTOCOL)
        f.flush(); os.fsync(f.fileno())
    os.replace(tmp, path)

In [None]:
# Cell 6
# Batch/parallel config
OUTPUT_FILE = "apartment_graphs_bottom.pkl" # this machine's file
BATCH_SIZE = 9715 # change freely
MAX_WORKERS = 10 # adjust to CPU/DB capacity
MAX_RETRIES = 2 # apartment-level retries
RETRY_DELAY = 5

In [None]:
# Cell 7 — Load dataset & prepare pending IDs (bottom-up, sharded)
from datetime import datetime

df_deptos = pd.read_csv('Datasets/dataset_final.csv')
done_ids = load_done_ids()
print(f"Resume: {len(done_ids)} apartments already done (from shards/CSV).")

apt_ids = df_deptos['id'].tolist()[::-1]  # bottom-up
pending_ids = [i for i in apt_ids if i not in done_ids]

# Batch window
BATCH_SIZE = min(BATCH_SIZE, len(pending_ids))  # keep your earlier BATCH_SIZE variable
batch_ids = pending_ids[:BATCH_SIZE]
print(f"Processing {len(batch_ids)} apartments (remaining after this: {len(pending_ids) - len(batch_ids)})")

# Name this batch's shard file
if batch_ids:
    ts = datetime.now().strftime("%Y%m%d_%H%M%S")
    shard_name = f"shard_{ts}_{batch_ids[0]}-{batch_ids[-1]}.pkl"
    SHARD_PATH = SHARD_DIR / shard_name
    print(f"Shard file: {SHARD_PATH}")
else:
    SHARD_PATH = None


Resume: 15500 apartments already done (from shards/CSV).
Processing 9715 apartments (remaining after this: 0)
Shard file: Graph_data\shard_20250824_001026_1584809495-1548097259.pkl


In [None]:
# Cell 8 — Worker + retry (Python 3.10+ typing)
def process_apartment_once(apt_id: int) -> tuple[int, dict[str, Data | None] | None, str | None]:
    """One attempt. Returns (apt_id, graphs_dict, error_str)."""
    try:
        driver = GraphDatabase.driver(NEO4J_URI, auth=(NEO4J_USER, NEO4J_PASSWORD))
        try:
            with driver.session(database=DATABASE) as session:
                graphs: dict[str, Data | None] = {}
                for cls in CLASSES:
                    g = export_apartment_class_graph(apt_id, cls, session=session)
                    graphs[cls] = g  # may be None if no POIs
                return apt_id, graphs, None
        finally:
            driver.close()
    except (ServiceUnavailable, TransientError, Exception) as e:
        return apt_id, None, str(e)

def process_apartment_with_retries(apt_id: int) -> tuple[int, dict[str, Data | None] | None, str | None]:
    for attempt in range(1, MAX_RETRIES + 2):
        aid, graphs, err = process_apartment_once(apt_id)
        if err is None:
            return aid, graphs, None
        if attempt <= MAX_RETRIES:
            time.sleep(RETRY_DELAY)
    return aid, None, err  # final failure


In [None]:
# Cell 9 — Run threaded batch and save one shard + append done CSV
if not batch_ids:
    print("Nothing to do. You're up to date.")
else:
    start_batch = time.time()
    batch_graphs: dict[int, dict[str, Data | None]] = {}

    with ThreadPoolExecutor(max_workers=MAX_WORKERS) as ex:
        futs = {ex.submit(process_apartment_with_retries, aid): aid for aid in batch_ids}
        for i, fut in enumerate(as_completed(futs), 1):
            aid = futs[fut]
            try:
                apt_id, graphs, err = fut.result()
            except Exception as e:
                apt_id, graphs, err = aid, None, f"worker crash: {e}"
            ok = (err is None and graphs is not None)
            status = "OK" if ok else f"ERR: {err}"
            print(f"[{i}/{len(batch_ids)}] apt={apt_id} -> {status}")
            if ok:
                batch_graphs[apt_id] = graphs

    # Save one pickle per batch (atomic) and append IDs to CSV
    if batch_graphs:
        atomic_dump(batch_graphs, SHARD_PATH)
        append_done_ids(sorted(batch_graphs.keys()))
        print(f"Saved shard {SHARD_PATH.name} with {len(batch_graphs)} apartments.")

    print(f"✅ Batch finished in {time.time()-start_batch:.2f}s "
          f"(ok: {len(batch_graphs)}, fail: {len(batch_ids)-len(batch_graphs)})")


[1/9715] apt=1492835705 -> OK
[2/9715] apt=1571036265 -> OK
[3/9715] apt=2858448180 -> OK
[4/9715] apt=2853351326 -> OK
[5/9715] apt=2842135468 -> OK
[6/9715] apt=2748151904 -> OK
[7/9715] apt=1584809495 -> OK
[8/9715] apt=1591757711 -> OK
[9/9715] apt=1591526513 -> OK
[10/9715] apt=2861870392 -> OK
[11/9715] apt=2849650700 -> OK
[12/9715] apt=1583773493 -> OK
[13/9715] apt=2853358930 -> OK
[14/9715] apt=2859523048 -> OK
[15/9715] apt=1550510703 -> OK
[16/9715] apt=2800373468 -> OK
[17/9715] apt=1573161813 -> OK
[18/9715] apt=2491360614 -> OK
[19/9715] apt=1537384469 -> OK
[20/9715] apt=2276173886 -> OK
[21/9715] apt=2822237194 -> OK
[22/9715] apt=2784546284 -> OK
[23/9715] apt=2853399828 -> OK
[24/9715] apt=1506693795 -> OK
[25/9715] apt=2840283430 -> OK
[26/9715] apt=1492811657 -> OK
[27/9715] apt=2707862820 -> OK
[28/9715] apt=1570052119 -> OK
[29/9715] apt=2853403412 -> OK
[30/9715] apt=1366496843 -> OK
[31/9715] apt=1592723991 -> OK
[32/9715] apt=2855589112 -> OK
[33/9715] apt=158

In [None]:
# Cell 10 — Verify last N shards (quick summary)
N_SHARDS = 2  # inspect the last N shards
shards = list_shards()
print(f"Total shards: {len(shards)}")
for p in shards[-N_SHARDS:]:
    d = pickle.load(open(p, "rb"))
    print(f"\nShard {p.name}: {len(d)} apartments")
    # show first 2 apts from this shard
    for aid in list(d.keys())[:2]:
        graphs = d[aid]
        none_cnt = sum(1 for v in graphs.values() if v is None)
        print(f"  Apt {aid}: classes={len(graphs)}, none={none_cnt}")


Total shards: 5

Shard shard_20250823_224953_1586671181-1586639873.pkl: 5000 apartments
  Apt 1586671181: classes=11, none=0
  Apt 2783190686: classes=11, none=0

Shard shard_20250824_001026_1584809495-1548097259.pkl: 9715 apartments
  Apt 1492835705: classes=11, none=0
  Apt 1571036265: classes=11, none=0


In [None]:
# Cell 11 — Verify last N apartments from the latest shard in Graph_data/
from pathlib import Path
import pickle
import torch

SHARD_DIR = Path("Graph_data")
N_LAST = 5  # change as needed

# Find newest shard by modification time
shards = sorted(SHARD_DIR.glob("shard_*.pkl"), key=lambda p: p.stat().st_mtime)
if not shards:
    print("No shard files found in Graph_data/.")
else:
    latest = shards[-1]
    print(f"Latest shard: {latest.name} (size: {latest.stat().st_size} bytes)")
    with open(latest, "rb") as f:
        d = pickle.load(f)  # dict[int -> dict[class -> Data|None]]

    print(f"Apartments in shard: {len(d)}")
    # Dict preserves insertion order (likely completion order). Take the last N.
    last_ids = list(d.keys())[-N_LAST:]
    print(f"\nInspecting last {len(last_ids)} apartments in this shard:\n")

    for aid in last_ids:
        graphs = d[aid]
        none_cnt = sum(1 for v in graphs.values() if v is None)
        print(f"Apartment {aid}: classes={len(graphs)}, none={none_cnt}")
        for cls in CLASSES:
            g = graphs.get(cls)
            if g is None:
                print(f"  - {cls:<24} -> None")
            else:
                print(f"  - {cls:<24} -> x={tuple(g.x.shape)}, edges={tuple(g.edge_index.shape)}, attr={tuple(g.edge_attr.shape)}")
        print()


Latest shard: shard_20250824_001026_1584809495-1548097259.pkl (size: 195825707 bytes)
Apartments in shard: 9715

Inspecting last 5 apartments in this shard:

Apartment 2803401906: classes=11, none=0
  - sport_and_leisure        -> x=(122, 4), edges=(2, 121), attr=(121, 1)
  - medical                  -> x=(16, 4), edges=(2, 15), attr=(15, 1)
  - education_prim           -> x=(2, 4), edges=(2, 1), attr=(1, 1)
  - veterinary               -> x=(2, 4), edges=(2, 1), attr=(1, 1)
  - food_and_drink_stores    -> x=(13, 4), edges=(2, 12), attr=(12, 1)
  - arts_and_entertainment   -> x=(21, 4), edges=(2, 20), attr=(20, 1)
  - food_and_drink           -> x=(6, 4), edges=(2, 5), attr=(5, 1)
  - park_like                -> x=(8, 4), edges=(2, 7), attr=(7, 1)
  - security                 -> x=(3, 4), edges=(2, 2), attr=(2, 1)
  - religion                 -> x=(2, 4), edges=(2, 1), attr=(1, 1)
  - education_sup            -> x=(9, 4), edges=(2, 8), attr=(8, 1)

Apartment 1536666991: classes=11, non

#### que los embeddings no le vayan a dar mas importancia a por ejemplo sports and leisure por el mero hecho de haber mas 