In [14]:
import torch
from torch_geometric.data import Data
from neo4j import GraphDatabase
import numpy as np
import pandas as pd
import random, time, pickle, os

In [7]:
"""
Test script: Export one (apartment × POI class) subgraph from Neo4j
to a PyTorch Geometric Data object.

Steps:
1. Query apartment + POI nodes and edges from Neo4j.
2. Build node feature matrix (x), edge_index, edge_attr.
3. Package into torch_geometric.data.Data.
"""

# ---------------- Config ---------------- #
# Connection settings are read from the environment so credentials do not have
# to live in the notebook. The fallbacks are the previous hardcoded values, so
# behaviour is unchanged when the variables are not set.
# NOTE(review): the fallback password is still visible here; rotate it and
# drop the default once NEO4J_PASSWORD is exported in your environment.
NEO4J_URI = os.environ.get("NEO4J_URI", "bolt://localhost:7687")
NEO4J_USER = os.environ.get("NEO4J_USER", "neo4j")
NEO4J_PASSWORD = os.environ.get("NEO4J_PASSWORD", "12345678")
DATABASE = os.environ.get("NEO4J_DATABASE", "neo4jads")

# Radii per category (meters); used by REL_QUERY to turn a distance into a weight.
RADII = {"CERCA_DE_CAT1": 600.0, "CERCA_DE_CAT2": 1200.0, "CERCA_DE_CAT3": 2400.0}

# Example
TEST_APT_ID = 1548097259        # <-- replace with real id from your df
TEST_CLASS = "medical"    # one of the 11 classes

# ---------------- Cypher Queries ---------------- #
# Returns one row for the apartment itself plus one row per POI of the
# requested class linked to it by any CERCA_DE_CAT* relationship.
# Both halves of the UNION expose the same columns; fields that do not apply
# to a node type are returned as null.
NODE_QUERY = """
MATCH (d:Departamento {id:$apt_id})
RETURN elementId(d) AS id, 'Departamento' AS label,
       d.id AS apt_id, d.latitude AS lat, d.longitude AS lon, null AS cat
UNION
MATCH (d:Departamento {id:$apt_id})-[r:CERCA_DE_CAT1|CERCA_DE_CAT2|CERCA_DE_CAT3]->(p:POI)
WHERE p.class = $class
RETURN elementId(p) AS id, 'POI' AS label,
       null AS apt_id, null AS lat, null AS lon, p.cat AS cat
"""

# Returns one (source, target, weight) row per apartment -> POI relationship.
# weight = 1 - distance / radius of the relationship's category, floored at
# 0.001 so that no edge ends up with a zero or negative weight.
REL_QUERY = """
MATCH (d:Departamento {id:$apt_id})-[r:CERCA_DE_CAT1|CERCA_DE_CAT2|CERCA_DE_CAT3]->(p:POI)
WHERE p.class = $class
WITH elementId(d) AS source, elementId(p) AS target, type(r) AS t, r.distancia_metros AS dist
WITH source, target,
  CASE t
    WHEN 'CERCA_DE_CAT1' THEN CASE WHEN 1.0 - dist / $r_cat1 > 0.001 THEN 1.0 - dist / $r_cat1 ELSE 0.001 END
    WHEN 'CERCA_DE_CAT2' THEN CASE WHEN 1.0 - dist / $r_cat2 > 0.001 THEN 1.0 - dist / $r_cat2 ELSE 0.001 END
    ELSE CASE WHEN 1.0 - dist / $r_cat3 > 0.001 THEN 1.0 - dist / $r_cat3 ELSE 0.001 END
  END AS weight
RETURN source, target, weight
"""

# ---------------- Export Function ---------------- #
def export_apartment_class_graph(apt_id: int, class_name: str) -> "Data | None":
    """Export one (apartment, POI class) neighbourhood from Neo4j as a PyG ``Data``.

    Runs ``NODE_QUERY`` and ``REL_QUERY`` for the given apartment and class,
    remaps the Neo4j element ids to consecutive local indices and builds:

    * ``x``: one row per node with 4 columns ``[is_apartment, cat1, cat2, cat3]``.
    * ``edge_index``: ``(2, E)`` tensor of local indices; every edge is
      directed apartment -> POI, as returned by ``REL_QUERY``.
    * ``edge_attr``: ``(E, 1)`` tensor holding the ``weight`` column of
      ``REL_QUERY``.

    Args:
        apt_id: Value of the ``id`` property of the ``Departamento`` node.
        class_name: Value matched against the ``class`` property of POI nodes.

    Returns:
        A ``Data`` object that also carries ``apt_id`` and ``poi_class``, or
        ``None`` (after printing a message) when the queries return no nodes
        or no edges.
    """
    # The driver is used as a context manager so its connection pool is closed
    # on every exit path, including the early returns below. Without this,
    # each call (11 per apartment in the batch loop) leaked a driver.
    with GraphDatabase.driver(NEO4J_URI, auth=(NEO4J_USER, NEO4J_PASSWORD)) as driver:
        with driver.session(database=DATABASE) as session:
            # --- Nodes
            nodes = session.run(
                NODE_QUERY,
                {"apt_id": apt_id, "class": class_name},
            ).data()
            if not nodes:
                print(f"No nodes found for apartment {apt_id}, class {class_name}")
                return None

            # --- Edges
            edges = session.run(
                REL_QUERY,
                {
                    "apt_id": apt_id, "class": class_name,
                    "r_cat1": RADII["CERCA_DE_CAT1"],
                    "r_cat2": RADII["CERCA_DE_CAT2"],
                    "r_cat3": RADII["CERCA_DE_CAT3"],
                },
            ).data()
            if not edges:
                print(f"No edges found for apartment {apt_id}, class {class_name}")
                return None

    # Everything below works on the fetched records only; the connection is
    # already released at this point.
    node_df = pd.DataFrame(nodes)
    edge_df = pd.DataFrame(edges)

    # Map neo4j ids -> local indices (row position in node_df)
    id_map = {nid: i for i, nid in enumerate(node_df["id"].tolist())}

    # Remap to local indices
    edge_index = torch.tensor([
        [id_map[s] for s in edge_df["source"]],
        [id_map[t] for t in edge_df["target"]],
    ], dtype=torch.long)
    edge_attr = torch.tensor(edge_df["weight"].values, dtype=torch.float).unsqueeze(1)

    # --- Node features: [is_apartment, cat1, cat2, cat3]
    feats = []
    for label, cat in zip(node_df["label"], node_df["cat"]):
        row_feats = [0.0, 0.0, 0.0, 0.0]
        if label == "Departamento":
            row_feats[0] = 1.0
        else:
            try:
                idx = int(cat)
            except (ValueError, TypeError):
                # Missing (None/NaN) or non-numeric category: leave the row all zeros.
                idx = None
            # Column 0 is reserved for the apartment flag, so only 1..3 are
            # valid category slots; a POI with cat == 0 must not be marked
            # as an apartment.
            if idx is not None and 1 <= idx < len(row_feats):
                row_feats[idx] = 1.0
        feats.append(row_feats)
    x = torch.tensor(feats, dtype=torch.float)

    # Build Data object
    return Data(
        x=x,
        edge_index=edge_index,
        edge_attr=edge_attr,
        apt_id=apt_id,
        poi_class=class_name,
    )



In [9]:
# Apartment table; its "id" column supplies the apartment ids exported below.
df_deptos = pd.read_csv(filepath_or_buffer="Datasets/dataset_final.csv")

In [10]:
# ---------------- POI classes ---------------- #
# The 11 POI classes; one subgraph is exported per (apartment, class) pair,
# in this order.
CLASSES = [
    "sport_and_leisure", "medical", "education_prim", "veterinary",
    "food_and_drink_stores", "arts_and_entertainment", "food_and_drink",
    "park_like", "security", "religion", "education_sup",
]


In [15]:


OUTPUT_FILE = "apartment_graphs.pkl"
BATCH_SIZE = 10   # change this number freely

# Load progress if exists
# NOTE: pickle.load can execute arbitrary code; only load a checkpoint that
# this notebook wrote itself.
if os.path.exists(OUTPUT_FILE):
    with open(OUTPUT_FILE, "rb") as f:
        all_graphs = pickle.load(f)
    done_ids = set(all_graphs.keys())
    print(f"Resuming: {len(done_ids)} apartments already done.")
else:
    all_graphs = {}
    done_ids = set()

# All apartments, de-duplicated in first-seen order so a repeated id in the
# CSV is neither processed twice nor counted twice in the progress total.
apt_ids = list(dict.fromkeys(df_deptos["id"].tolist()))
pending_ids = [i for i in apt_ids if i not in done_ids]

# Limit to one batch
batch_ids = pending_ids[:BATCH_SIZE]
print(f"Processing {len(batch_ids)} apartments in this batch "
      f"(remaining after this: {len(pending_ids)-len(batch_ids)})")

for idx, apt_id in enumerate(batch_ids, 1):
    print(f"\n[{idx}/{len(batch_ids)}] Apartment ID={apt_id}")
    start = time.time()

    # One entry per class; the value is None when the class has no nodes/edges.
    graphs = {cls: export_apartment_class_graph(apt_id, cls) for cls in CLASSES}

    all_graphs[apt_id] = graphs

    # Save progress after each apartment.
    # Write to a temporary file and swap it in with os.replace so that an
    # interrupt in the middle of pickle.dump cannot truncate the only
    # checkpoint (which would lose all progress and break the resume path).
    tmp_file = OUTPUT_FILE + ".tmp"
    with open(tmp_file, "wb") as f:
        pickle.dump(all_graphs, f)
    os.replace(tmp_file, OUTPUT_FILE)

    print(f"Done apt {apt_id} in {time.time()-start:.2f} sec "
          f"(progress: {len(all_graphs)}/{len(apt_ids)})")

print("✅ Batch completed and saved.")


Processing 10 apartments in this batch (remaining after this: 25205)

[1/10] Apartment ID=1548097259
Done apt 1548097259 in 28.55 sec (progress: 1/25215)

[2/10] Apartment ID=2732119230
Done apt 2732119230 in 27.99 sec (progress: 2/25215)

[3/10] Apartment ID=1592811585
Done apt 1592811585 in 27.90 sec (progress: 3/25215)

[4/10] Apartment ID=1536666991
No edges found for apartment 1536666991, class education_prim
Done apt 1536666991 in 28.19 sec (progress: 4/25215)

[5/10] Apartment ID=2803401906
Done apt 2803401906 in 27.93 sec (progress: 5/25215)

[6/10] Apartment ID=1543740023
Done apt 1543740023 in 28.07 sec (progress: 6/25215)

[7/10] Apartment ID=1521489447
No edges found for apartment 1521489447, class religion
Done apt 1521489447 in 28.70 sec (progress: 7/25215)

[8/10] Apartment ID=1588684845
Done apt 1588684845 in 28.49 sec (progress: 8/25215)

[9/10] Apartment ID=2849646914
No edges found for apartment 2849646914, class security
No edges found for apartment 2849646914, clas

In [16]:

# Read the checkpoint back: {apartment_id: {poi_class: Data or None}}
with open("apartment_graphs.pkl", mode="rb") as f:
    graphs = pickle.load(f)

print(f"Loaded {len(graphs)} apartments")

# Print first 3 keys (apartment IDs)
apartment_ids = list(graphs)
print("Some apartment IDs:", apartment_ids[:3])

# Inspect one apartment: the first one stored in the file
apt_id = apartment_ids[0]
apt_data = graphs[apt_id]

print(f"\nApartment {apt_id}:")
for cls, data in apt_data.items():
    if data is not None:
        # Show only the tensor shapes, not the tensors themselves
        print(f"  {cls}: Data(x={data.x.shape}, edge_index={data.edge_index.shape}, edge_attr={data.edge_attr.shape})")
        continue
    print(f"  {cls}: None (no POIs)")


Loaded 10 apartments
Some apartment IDs: [1548097259, 2732119230, 1592811585]

Apartment 1548097259:
  sport_and_leisure: Data(x=torch.Size([90, 4]), edge_index=torch.Size([2, 89]), edge_attr=torch.Size([89, 1]))
  medical: Data(x=torch.Size([32, 4]), edge_index=torch.Size([2, 31]), edge_attr=torch.Size([31, 1]))
  education_prim: Data(x=torch.Size([9, 4]), edge_index=torch.Size([2, 8]), edge_attr=torch.Size([8, 1]))
  veterinary: Data(x=torch.Size([2, 4]), edge_index=torch.Size([2, 1]), edge_attr=torch.Size([1, 1]))
  food_and_drink_stores: Data(x=torch.Size([13, 4]), edge_index=torch.Size([2, 12]), edge_attr=torch.Size([12, 1]))
  arts_and_entertainment: Data(x=torch.Size([26, 4]), edge_index=torch.Size([2, 25]), edge_attr=torch.Size([25, 1]))
  food_and_drink: Data(x=torch.Size([26, 4]), edge_index=torch.Size([2, 25]), edge_attr=torch.Size([25, 1]))
  park_like: Data(x=torch.Size([7, 4]), edge_index=torch.Size([2, 6]), edge_attr=torch.Size([6, 1]))
  security: Data(x=torch.Size([8, 

In [18]:
import pickle
import torch
import torch.nn.functional as F
from torch_geometric.nn import SAGEConv, global_mean_pool

# ---------------- Load the pickle file ---------------- #
# Mapping written by the batch export: {apartment_id: {poi_class: Data or None}}
with open("apartment_graphs.pkl", mode="rb") as f:
    apartment_graphs = pickle.load(file=f)

# ---------------- Small GNN model ---------------- #
class SimpleGraphSAGE(torch.nn.Module):
    """Two-layer GraphSAGE encoder that mean-pools node outputs per graph.

    Args:
        in_channels: Number of input node features.
        hidden_channels: Output size of the first ``SAGEConv`` layer.
        out_channels: Output size of the second ``SAGEConv`` layer, i.e. the
            size of the pooled vector returned for each graph.
    """

    def __init__(self, in_channels, hidden_channels, out_channels):
        super().__init__()
        self.conv1 = SAGEConv(in_channels, hidden_channels)
        self.conv2 = SAGEConv(hidden_channels, out_channels)

    def forward(self, x, edge_index, edge_attr=None, batch=None):
        """Return one pooled vector per graph.

        Args:
            x: Node feature matrix.
            edge_index: Edge list passed to both ``SAGEConv`` layers.
            edge_attr: Accepted for call-site compatibility but not used;
                neither layer receives it.
            batch: Graph assignment of each node. When ``None`` all nodes are
                treated as belonging to a single graph.
        """
        # Two-layer GraphSAGE
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = self.conv2(x, edge_index)
        if batch is None:
            # Single graph -> assign every node to graph 0. The vector is
            # created on x's device so pooling also works when the inputs
            # live on a GPU (a CPU-only vector would raise a device mismatch).
            batch = torch.zeros(x.size(0), dtype=torch.long, device=x.device)
        return global_mean_pool(x, batch)  # -> one vector per graph


# ---------------- Pick one apartment ---------------- #
apt_id = next(iter(apartment_graphs.keys()))  # just the first apartment
graphs = apartment_graphs[apt_id]

# Feature dimension (4 in your case: is_apartment, cat1, cat2, cat3).
# Entries can be None for classes without POIs, so read the dimension from the
# first class that actually has a graph instead of blindly using the first key.
first_graph = next((g for g in graphs.values() if g is not None), None)
if first_graph is None:
    raise ValueError(f"Apartment {apt_id} has no graph for any POI class")
in_channels = first_graph.x.shape[1]

# The model is untrained, so its output depends only on the random weight
# initialisation; fixing the seed makes the printed embeddings reproducible.
torch.manual_seed(42)
model = SimpleGraphSAGE(in_channels, hidden_channels=16, out_channels=8)
model.eval()  # inference only

# ---------------- Compute embeddings ---------------- #
embeddings = {}
with torch.no_grad():
    for cls, g in graphs.items():
        if g is None:
            embeddings[cls] = None
            continue
        emb = model(g.x, g.edge_index)
        # The model returns shape (1, out_channels) for a single graph;
        # drop the graph axis to get a flat vector.
        embeddings[cls] = emb.squeeze(0).numpy()

# ---------------- Inspect ---------------- #
print(f"Apartment {apt_id} embeddings:")
for cls, emb in embeddings.items():
    print(cls, emb)


Apartment 1548097259 embeddings:
sport_and_leisure [-0.18648088  0.5579107  -0.4131222   0.27406204 -0.35570735  0.40606898
  0.14326589 -0.5241446 ]
medical [-0.20881331  0.31410986 -0.40781352  0.36022988 -0.28139758  0.50777537
  0.04728226 -0.36782062]
education_prim [-0.34600687  0.42414403 -0.4898538   0.15145272 -0.4337186   0.6039778
  0.08607332 -0.5473764 ]
veterinary [-0.27915993  0.22826114 -0.191151    0.27866125 -0.18302271  0.27510607
  0.21248034 -0.57624197]
food_and_drink_stores [-0.27390623  0.4828838  -0.44862312  0.20662722 -0.39463884  0.5065929
  0.11644609 -0.5412748 ]
arts_and_entertainment [-0.17523065  0.2991605  -0.37270257  0.40328738 -0.24081656  0.4592786
  0.05530962 -0.3462455 ]
food_and_drink [-0.32072455  0.48084846 -0.51296467  0.15983051 -0.4509474   0.6042404
  0.08255335 -0.5402578 ]
park_like [-0.34960094  0.40333438 -0.4733036   0.15413487 -0.42034778  0.59179056
  0.09157699 -0.5501465 ]
security [-0.20843989  0.20837197 -0.3466561   0.4006459 

In [13]:
# file: notebooks/test_single_embedding.ipynb
"""
Test script: Generate **one embedding** for **one apartment + one POI class** using Neo4j GDS FastRP.
- Uses the apartments dataframe (with column "id") for apartment IDs.
- Writes the embedding back to the apartment node as a property (e.g., `medical_ContextEmbedding`).
"""

from neo4j import GraphDatabase
import pandas as pd
import numpy as np
import time

# ------------------- Config ------------------- #
# Connection settings are read from the environment so credentials do not have
# to live in the notebook. The fallbacks are the previous hardcoded values, so
# behaviour is unchanged when the variables are not set.
# NOTE(review): the fallback password is still visible here; rotate it and
# drop the default once NEO4J_PASSWORD is exported in your environment.
NEO4J_URI = os.environ.get("NEO4J_URI", "bolt://localhost:7687")
NEO4J_USER = os.environ.get("NEO4J_USER", "neo4j")
NEO4J_PASSWORD = os.environ.get("NEO4J_PASSWORD", "12345678")
DATABASE = os.environ.get("NEO4J_DATABASE", "neo4jads")

EMBEDDING_DIM = 4
RANDOM_SEED = 42

# Radii per category (meters)
RADII = {"CERCA_DE_CAT1": 600.0, "CERCA_DE_CAT2": 1200.0, "CERCA_DE_CAT3": 2400.0}

# Example: test one class
TEST_CLASS = "medical"
PROP_NAME = "medical_ContextEmbedding"

# ------------------- Cypher ------------------- #
# NOTE: NODE_QUERY and REL_QUERY below rebind the names used by the PyG export
# cell earlier in the notebook, with different columns (no `label`/`cat`).
# Re-run that export cell's definitions before calling it after this cell.
# Debug-friendly: add RETURNs to check matches
NODE_QUERY = """
MATCH (d:Departamento {id:$apt_id})
RETURN id(d) AS id, labels(d) AS labels
UNION
MATCH (d:Departamento {id:$apt_id})-[:CERCA_DE_CAT1|CERCA_DE_CAT2|CERCA_DE_CAT3]->(p:POI)
WHERE p.class = $class
RETURN id(p) AS id, labels(p) AS labels
"""

# weight = 1 - distance / radius of the relationship's category, clamped at 0.0.
REL_QUERY = """
MATCH (d:Departamento {id:$apt_id})-[r:CERCA_DE_CAT1|CERCA_DE_CAT2|CERCA_DE_CAT3]->(p:POI)
WHERE p.class = $class
WITH id(d) AS source, id(p) AS target, type(r) AS t, r.distancia_metros AS dist
WITH source, target,
  CASE t
    WHEN 'CERCA_DE_CAT1' THEN CASE WHEN (1.0 - dist / $r_cat1) < 0.0 THEN 0.0 ELSE (1.0 - dist / $r_cat1) END
    WHEN 'CERCA_DE_CAT2' THEN CASE WHEN (1.0 - dist / $r_cat2) < 0.0 THEN 0.0 ELSE (1.0 - dist / $r_cat2) END
    ELSE CASE WHEN (1.0 - dist / $r_cat3) < 0.0 THEN 0.0 ELSE (1.0 - dist / $r_cat3) END
  END AS weight
RETURN source, target, weight
"""

# Drops the named in-memory graph only when it exists (no error otherwise).
DROP_GRAPH = "CALL gds.graph.exists($graph_name) YIELD exists WITH exists WHERE exists CALL gds.graph.drop($graph_name) YIELD graphName RETURN graphName"

# Projects the subgraph described by the node/relationship queries, which are
# passed in as parameters together with the values they reference.
PROJECT_GRAPH = """
CALL gds.graph.project.cypher(
  $graph_name,
  $node_query,
  $rel_query,
  {
    parameters: {apt_id:$apt_id, class:$class, r_cat1:$r_cat1, r_cat2:$r_cat2, r_cat3:$r_cat3}
  }
)
YIELD graphName, nodeCount, relationshipCount
RETURN graphName, nodeCount, relationshipCount
"""

# Streams FastRP embeddings and stores only the Departamento node's vector
# under the property named by $prop_name.
FASTRP_STREAM = """
CALL gds.fastRP.stream($graph_name, {
  embeddingDimension: $dim,
  randomSeed: $seed,
  relationshipWeightProperty: 'weight'
})
YIELD nodeId, embedding
WITH gds.util.asNode(nodeId) AS n, embedding
WHERE n:Departamento
SET n[$prop_name] = embedding
RETURN n.id AS apartment_id, size(embedding) AS dim
"""

# ------------------- Test Run ------------------- #

def run_single_embedding(apt_id: int, class_name: str, prop_name: str):
    """Compute a FastRP embedding for one apartment and one POI class.

    Projects a temporary GDS graph named ``g_test_<apt_id>_<class_name>`` from
    ``NODE_QUERY`` / ``REL_QUERY``, runs ``FASTRP_STREAM`` (which writes the
    embedding to the apartment node under ``prop_name``) and drops the
    projection again. Progress and failures are reported with ``print``.

    Args:
        apt_id: Value of the ``id`` property of the ``Departamento`` node.
        class_name: Value matched against the ``class`` property of POI nodes.
        prop_name: Node property that receives the embedding.

    Returns:
        None.
    """
    graph_name = f"g_test_{apt_id}_{class_name}"

    # The driver is used as a context manager so it is closed on every exit
    # path; the early returns below previously skipped driver.close().
    with GraphDatabase.driver(NEO4J_URI, auth=(NEO4J_USER, NEO4J_PASSWORD)) as driver:
        with driver.session(database=DATABASE) as session:
            # Clean old graph if exists
            session.run(DROP_GRAPH, {"graph_name": graph_name}).consume()

            # Project graph
            try:
                proj = session.run(
                    PROJECT_GRAPH,
                    {
                        "graph_name": graph_name,
                        "node_query": NODE_QUERY,
                        "rel_query": REL_QUERY,
                        "apt_id": apt_id,
                        "class": class_name,
                        "r_cat1": RADII["CERCA_DE_CAT1"],
                        "r_cat2": RADII["CERCA_DE_CAT2"],
                        "r_cat3": RADII["CERCA_DE_CAT3"],
                    },
                ).single()
            except Exception as e:
                print(f"Projection failed for apt {apt_id}, class {class_name}: {e}")
                return

            if proj is None:
                print(f"No nodes found for apartment {apt_id} and class {class_name}")
                return

            # From here on a projection exists in the GDS catalog; always drop
            # it, even if FastRP raises, so failed runs do not pile up graphs.
            try:
                print(f"Projected: {proj['graphName']} with {proj['nodeCount']} nodes and {proj['relationshipCount']} rels")

                # Run FastRP stream and write embedding
                result = session.run(
                    FASTRP_STREAM,
                    {
                        "graph_name": graph_name,
                        "dim": EMBEDDING_DIM,
                        "seed": RANDOM_SEED,
                        "prop_name": prop_name,
                    },
                ).single()
                if result:
                    print(f"Apartment {result['apartment_id']} got embedding of dim {result['dim']}")
            finally:
                # Drop the graph
                session.run(DROP_GRAPH, {"graph_name": graph_name}).consume()


# ------------------- Example Execution ------------------- #
# Pick one apartment ID from the apartments dataframe loaded earlier.
# The notebook defines `df_deptos`, not `df`; the old name only worked through
# leftover kernel state and fails with NameError after Restart & Run All.
test_apartment_id = int(df_deptos.loc[0, "id"])
print(f"Testing apartment id: {test_apartment_id}")
run_single_embedding(test_apartment_id, TEST_CLASS, PROP_NAME)


Testing apartment id: 1548097259




Projected: g_test_1548097259_medical with 32 nodes and 31 rels
Apartment 1548097259 got embedding of dim 4
