In [27]:
# Notebook: Unsupervised GNN embeddings for POI / Metro / Bus (Python 3.9)
# ======================================================================
# Goals
# - Train a tiny self-supervised encoder on exported star graphs
# - Produce D-dim embeddings (8–16) per (apartment × context)
# - Save a CSV: one row per apartment, 13 columns with JSON-serialized vectors
# - Keep None where a context has no graph (we'll impute later if desired)
#
# Contexts: 11 POI classes (from 'shard_*.pkl') + Metro (METROSHARD_*.pkl) + Bus (BUSSHARD_*.pkl)
# Python 3.9 compatible typing & APIs.

# ============================
# Cell 1 — Imports & Config
# ============================
import os, json, math, time, pickle
from typing import Optional, Dict, List, Tuple, Set
from pathlib import Path

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader
from torch_geometric.nn import GraphConv

# Device
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Device:', DEVICE)

# Embedding size (8–16)
EMB_DIM = 12
HIDDEN = 32
EPOCHS = 5          # small & fast; increase if needed
BATCH_SIZE = 128    # number of graphs per batch (increase if RAM allows)
LR = 1e-3

# Input shards directory
SHARD_DIR = Path('Graph_data')

# Provide explicit shard file names (as given by the user)
GENERAL_SHARDS = [
    'shard_20250823_193931_1555433697-1539080245.pkl',
    'shard_20250823_195306_2853468746-2862582048.pkl',
    'shard_20250823_212915_1575719893-1579917449.pkl',
    'shard_20250823_224953_1586671181-1586639873.pkl',
    'shard_20250824_001026_1584809495-1548097259.pkl',
]
METRO_SHARDS = [
    'METROSHARD_20250825_211657_1555433697-2854216564.pkl',
    'METROSHARD_20250826_091005_1553843137-1548097259.pkl',
]
BUS_SHARDS = [
    'BUSSHARD_20250826_122800_1555433697-1584388845.pkl',
    'BUSSHARD_20250826_153430_2862820058-1548097259.pkl',
]

# Output CSV path
OUT_CSV = 'apartment_embeddings.csv'

# Context keys (13 total)
CLASSES = [
    'sport_and_leisure','medical','education_prim','veterinary','food_and_drink_stores',
    'arts_and_entertainment','food_and_drink','park_like','security','religion','education_sup'
]
CONTEXT_KEYS = CLASSES + ['metro', 'bus']

Device: cuda


In [48]:
# ================================
# Cell 2 — Load & unify graph items
# ================================
# We will build a list of items: (apt_id, context_key, Data)
# Data may have different feature widths; we'll pad to the max later.

def load_pickle(path: Path):
    with open(path, 'rb') as f:
        return pickle.load(f)


def collect_items() -> List[Tuple[int, str, Data]]:
    items: List[Tuple[int, str, Data]] = []

    # General POI shards: dict[int -> dict[str -> Data|None]]
    for name in GENERAL_SHARDS:
        p = SHARD_DIR / name
        if not p.exists():
            print(f"[warn] missing general shard: {p}")
            continue
        part = load_pickle(p)
        for apt_id, gdict in part.items():
            for cls in CLASSES:
                g = gdict.get(cls)
                if g is None:
                    continue
                # ensure a context key for downstream
                g.poi_tag = cls
                items.append((apt_id, cls, g))

    # Metro shards: dict[int -> Data]
    for name in METRO_SHARDS:
        p = SHARD_DIR / name
        if not p.exists():
            print(f"[warn] missing metro shard: {p}")
            continue
        part = load_pickle(p)
        for apt_id, g in part.items():
            if isinstance(g, Data):
                g.poi_tag = 'metro'
                items.append((apt_id, 'metro', g))

    # Bus shards: dict[int -> Data]
    for name in BUS_SHARDS:
        p = SHARD_DIR / name
        if not p.exists():
            print(f"[warn] missing bus shard: {p}")
            continue
        part = load_pickle(p)
        for apt_id, g in part.items():
            if isinstance(g, Data):
                g.poi_tag = 'bus'
                items.append((apt_id, 'bus', g))

    print(f"Collected items: {len(items)}")
    return items


items = collect_items()

# Determine common feature width for padding
feat_dims = [g.x.size(1) for (_, _, g) in items]
COMMON_F = max(feat_dims) if feat_dims else 0
# Load apt coords and standardize
DF_DEPTOS = pd.read_csv('Datasets/dataset_final.csv')[['id', 'latitud', 'longitud']].dropna()
lat_mean, lat_std = DF_DEPTOS['latitud'].mean(), DF_DEPTOS['latitud'].std() + 1e-8
lon_mean, lon_std = DF_DEPTOS['longitud'].mean(), DF_DEPTOS['longitud'].std() + 1e-8

APT2COORD = {int(r.id): ((r.latitud - lat_mean)/lat_std, (r.longitud - lon_mean)/lon_std)
             for r in DF_DEPTOS.itertuples(index=False)}

# New input width = base COMMON_F + 1 (distance scalar) + 2 (lat_z, lon_z)
INPUT_DIM = COMMON_F + 1 + 2
print("COMMON_F:", COMMON_F, "| INPUT_DIM:", INPUT_DIM)


Collected items: 307255
COMMON_F: 9 | INPUT_DIM: 12


In [49]:
# ================================
# Cell 3 — Dataset & collate utils
# ================================
def augment_with_distance_feature(g: Data) -> Data:
    num_nodes = g.num_nodes
    dist_feat = torch.zeros(num_nodes, 1, dtype=g.x.dtype)
    if g.edge_attr is not None and g.edge_index is not None:
        _, dst = g.edge_index
        w = g.edge_attr.view(-1)
        dist_feat[dst] = w.unsqueeze(1)
    g.x = torch.cat([g.x, dist_feat], dim=1)
    return g

def augment_with_apt_coords(g: Data, apt_id: int, apt_idx: int) -> Data:
    """Append 2-dim feature; fill apt row with (lat_z, lon_z), zeros for others."""
    coords = APT2COORD.get(int(apt_id), (0.0, 0.0))
    add = torch.zeros((g.num_nodes, 2), dtype=g.x.dtype)
    add[apt_idx, 0] = float(coords[0])
    add[apt_idx, 1] = float(coords[1])
    g.x = torch.cat([g.x, add], dim=1)
    return g

class StarItemDataset(Dataset):
    def __init__(self, items: List[Tuple[int, str, Data]], common_f: int):
        self.items = items
        self.common_f = common_f

    def __len__(self) -> int:
        return len(self.items)

    def __getitem__(self, idx: int) -> Data:
        apt_id, ctx, g = self.items[idx]
        # pad to COMMON_F
        if g.x.size(1) < self.common_f:
            pad_w = self.common_f - g.x.size(1)
            pad = torch.zeros((g.x.size(0), pad_w), dtype=g.x.dtype)
            g = Data(x=torch.cat([g.x, pad], dim=1),
                     edge_index=g.edge_index, edge_attr=g.edge_attr)
        # locate apartment node
        apt_mask = (g.x[:, 0] > 0.5)
        apt_idx = int(torch.nonzero(apt_mask, as_tuple=False)[0].item()) if apt_mask.any() else 0
        # add distance scalar (+1)
        g = augment_with_distance_feature(g)
        # add coords (+2) on apt node only
        g = augment_with_apt_coords(g, int(apt_id), apt_idx)
        # enforce exact INPUT_DIM (pad/trunc for safety)
        f = g.x.size(1)
        if f < INPUT_DIM:
            pad = torch.zeros((g.num_nodes, INPUT_DIM - f), dtype=g.x.dtype)
            g.x = torch.cat([g.x, pad], dim=1)
        elif f > INPUT_DIM:
            g.x = g.x[:, :INPUT_DIM]
        # metadata
        g.apt_id = int(apt_id)
        g.poi_tag = ctx
        g.apt_idx = apt_idx
        return g

dataset = StarItemDataset(items, COMMON_F)
loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)

In [50]:
# ============================
# Cell 4 — Model (tiny GCN)
# ============================
class TinyGraphConv(nn.Module):
    def __init__(self, in_dim: int, hidden: int, out_dim: int):
        super().__init__()
        self.conv1 = GraphConv(in_dim, hidden)
        self.conv2 = GraphConv(hidden, out_dim)
        self.act = nn.ReLU()
        self.dropout = nn.Dropout(p=0.1)
        self.W = nn.Linear(out_dim, out_dim, bias=False)  # bilinear decoder for edges
        # auxiliary heads from apartment embedding:
        self.head_deg   = nn.Linear(out_dim, 1)
        self.head_meanw = nn.Linear(out_dim, 1)

    def forward(self, data):
        x, ei = data.x, data.edge_index
        ew = data.edge_attr.view(-1) if data.edge_attr is not None else None
        h = self.conv1(x, ei, edge_weight=ew)
        h = self.act(h)
        h = self.dropout(h)
        h = self.conv2(h, ei, edge_weight=ew)
        return h

    def score(self, h_src: torch.Tensor, h_dst: torch.Tensor) -> torch.Tensor:
        return (self.W(h_src) * h_dst).sum(dim=-1)

model = TinyGraphConv(INPUT_DIM, HIDDEN, EMB_DIM).to(DEVICE)
opt = torch.optim.Adam(model.parameters(), lr=LR, weight_decay=1e-4)

In [39]:
assert getattr(model.conv1.lin_rel, 'weight').shape[1] == INPUT_DIM, \
    f"Model expects {getattr(model.conv1.lin_rel, 'weight').shape[1]} features, but INPUT_DIM={INPUT_DIM}"

In [51]:
# ======================================
# Cell 5 — Train (edge weight reconstruction)
# ======================================
def train_one_epoch() -> float:
    model.train()
    total = 0.0
    count = 0
    for big in loader:  # big: Batch
        big = big.to(DEVICE)
        opt.zero_grad()

        H = model(big)   # [N_total, EMB_DIM]
        src, dst = big.edge_index

        # apartment global indices per graph
        num_graphs = big.ptr.numel() - 1
        apt_global = []
        for i in range(num_graphs):
            base = int(big.ptr[i].item())
            aidx = int(big.apt_idx[i].item())
            apt_global.append(base + aidx)
        apt_global = torch.tensor(apt_global, device=big.x.device, dtype=src.dtype)  # [G]

        # mask edges whose src is the apartment in its graph
        mask = (src.view(-1, 1) == apt_global.view(1, -1)).any(dim=1)
        src_pos = src[mask]
        dst_pos = dst[mask]
        if src_pos.numel() == 0:
            continue

        targets_all = edge_weights_for_training(big)  # [E] in [0,1]
        targets_pos = targets_all[mask]

        # positive edge loss
        scores_pos = model.score(H[src_pos], H[dst_pos])
        preds_pos  = torch.sigmoid(scores_pos)
        loss_pos = F.mse_loss(preds_pos, targets_pos)

        # negatives: sample non-neighbors per graph
        K = 5
        neg_src_list, neg_dst_list = [], []
        for i in range(num_graphs):
            base = int(big.ptr[i].item())
            end  = int(big.ptr[i+1].item())
            a_gi = base + int(big.apt_idx[i].item())
            nei_mask = (src_pos == a_gi)
            true_dsts = set(dst_pos[nei_mask].tolist())
            candidates = [j for j in range(base, end) if j != a_gi and j not in true_dsts]
            if not candidates:
                continue
            pick = np.random.choice(candidates, size=min(K, len(candidates)), replace=False)
            neg_src_list.extend([a_gi] * len(pick))
            neg_dst_list.extend([int(x) for x in pick])

        if neg_src_list:
            neg_src = torch.tensor(neg_src_list, device=big.x.device, dtype=src.dtype)
            neg_dst = torch.tensor(neg_dst_list, device=big.x.device, dtype=dst.dtype)
            scores_neg = model.score(H[neg_src], H[neg_dst])
            preds_neg  = torch.sigmoid(scores_neg)
            loss_neg = F.mse_loss(preds_neg, torch.zeros_like(preds_neg))
        else:
            loss_neg = torch.tensor(0.0, device=big.x.device)

        # --- auxiliary graph-level targets from apartment embedding
        # deg: count of outgoing apt edges per graph
        deg_targets = []
        meanw_targets = []
        for i in range(num_graphs):
            base = int(big.ptr[i].item())
            a_gi = base + int(big.apt_idx[i].item())
            sel = (src_pos == a_gi)
            deg_i = int(sel.sum().item())
            deg_targets.append(deg_i)
            if deg_i > 0:
                meanw_targets.append(float(targets_pos[sel].mean().item()))
            else:
                meanw_targets.append(0.0)

        deg_targets   = torch.tensor(deg_targets, device=big.x.device, dtype=torch.float).view(-1, 1)
        meanw_targets = torch.tensor(meanw_targets, device=big.x.device, dtype=torch.float).view(-1, 1)

        # normalize degree for stability (log1p)
        deg_targets_norm = torch.log1p(deg_targets) / 4.0  # rough scale
        apt_h = H[apt_global]  # [G, EMB_DIM]

        pred_deg   = model.head_deg(apt_h)
        pred_meanw = torch.sigmoid(model.head_meanw(apt_h))

        loss_deg   = F.mse_loss(pred_deg, deg_targets_norm)
        loss_meanw = F.mse_loss(pred_meanw, meanw_targets)

        loss = loss_pos + 0.5 * loss_neg + 0.2 * loss_deg + 0.2 * loss_meanw

        loss.backward()
        opt.step()

        total += float(loss.item())
        count += 1

    return total / max(count, 1)

for epoch in range(1, EPOCHS + 1):
    t0 = time.time()
    loss = train_one_epoch()
    print(f"Epoch {epoch:02d} | loss={loss:.6f} | {time.time()-t0:.2f}s")


Epoch 01 | loss=0.031822 | 662.06s
Epoch 02 | loss=0.027477 | 500.53s
Epoch 03 | loss=0.027400 | 519.82s
Epoch 04 | loss=0.027402 | 550.55s
Epoch 05 | loss=0.027386 | 545.43s


In [52]:
# ======================================
# Cell 6 — Inference: per (apt, context) embedding
# ======================================
@torch.no_grad()
def embed_graph(g: Data, apt_id: int) -> np.ndarray:
    # pad to COMMON_F
    if g.x.size(1) < COMMON_F:
        pad_w = COMMON_F - g.x.size(1)
        pad = torch.zeros((g.x.size(0), pad_w), dtype=g.x.dtype)
        g = Data(x=torch.cat([g.x, pad], dim=1),
                 edge_index=g.edge_index, edge_attr=g.edge_attr)
    # apartment index
    apt_mask = (g.x[:, 0] > 0.5)
    aidx = int(torch.nonzero(apt_mask, as_tuple=False)[0].item()) if apt_mask.any() else 0
    # add distance and coords
    g = augment_with_distance_feature(g)
    g = augment_with_apt_coords(g, int(apt_id), aidx)
    # enforce INPUT_DIM
    f = g.x.size(1)
    if f < INPUT_DIM:
        pad = torch.zeros((g.num_nodes, INPUT_DIM - f), dtype=g.x.dtype)
        g.x = torch.cat([g.x, pad], dim=1)
    elif f > INPUT_DIM:
        g.x = g.x[:, :INPUT_DIM]

    g = g.to(DEVICE)
    H = model(g)
    return H[aidx].cpu().numpy()


# We want one embedding per (apartment × context). Some contexts may be missing.
# We'll collect into: emb_map[apt_id][context] = list or None
emb_map: Dict[int, Dict[str, Optional[List[float]]]] = {}

# Prepare a per-apartment buckets for faster inference
from collections import defaultdict
apt_ctx_graphs: Dict[int, Dict[str, Data]] = defaultdict(dict)
for apt_id, ctx, g in items:
    # pad if needed
    if g.x.size(1) < COMMON_F:
        pad_w = COMMON_F - g.x.size(1)
        pad = torch.zeros((g.x.size(0), pad_w), dtype=g.x.dtype)
        g = Data(x=torch.cat([g.x, pad], dim=1), edge_index=g.edge_index, edge_attr=g.edge_attr)
        g.apt_id = apt_id
        g.poi_tag = ctx
    apt_ctx_graphs[apt_id][ctx] = g

# Collect all apartment IDs seen across shards
all_apt_ids: Set[int] = set([aid for (aid, _, _) in items])

print('Total apartments with at least one context:', len(all_apt_ids))

# Inference loop (single-threaded, stable)
model.eval()
for aid in sorted(all_apt_ids):
    emb_map[aid] = {}
    for ctx in CONTEXT_KEYS:
        g = apt_ctx_graphs[aid].get(ctx)
        if g is None:
            emb_map[aid][ctx] = None
        else:
            vec = embed_graph(g, aid)
            # store as python list for JSON serialization
            emb_map[aid][ctx] = [float(x) for x in vec.tolist()]

Total apartments with at least one context: 25211


In [53]:
# ======================================
# Cell 7 — Build CSV with 13 columns of JSON vectors
# ======================================

rows = []
for aid in sorted(emb_map.keys()):
    row = {'id': aid}
    for ctx in CONTEXT_KEYS:
        val = emb_map[aid].get(ctx)
        row[f'emb_{ctx}'] = None if val is None else json.dumps(val)
    rows.append(row)

out_df = pd.DataFrame(rows)
print('Output shape:', out_df.shape)
out_df.to_csv(OUT_CSV, index=False)
print('Saved CSV ->', OUT_CSV)

# Quick peek
print(out_df.head(3))


Output shape: (25211, 14)
Saved CSV -> apartment_embeddings.csv
           id                              emb_sport_and_leisure  \
0  1359204515  [0.11949535459280014, 8.156523108482361e-05, -...   
1  1366496843  [0.14379562437534332, 0.0005593057721853256, -...   
2  1367599797  [0.2016446441411972, -1.3947486877441406e-05, ...   

                                         emb_medical  \
0  [0.11949535459280014, 8.156523108482361e-05, -...   
1  [0.14379562437534332, 0.0005593057721853256, -...   
2  [0.2016446441411972, -1.3947486877441406e-05, ...   

                                  emb_education_prim  \
0  [0.11949535459280014, 8.156523108482361e-05, -...   
1  [0.14379562437534332, 0.0005593057721853256, -...   
2  [0.2016446441411972, -1.3947486877441406e-05, ...   

                                      emb_veterinary  \
0  [0.11949535459280014, 8.156523108482361e-05, -...   
1  [0.14379562437534332, 0.0005593057721853256, -...   
2  [0.2016446441411972, -1.3947486877441406e-

In [58]:
out_df['emb_medical'].unique()

array(['[0.11949535459280014, 8.156523108482361e-05, -0.018419355154037476, 0.1900814026594162, -1.2534259557724, -0.4860982894897461, 0.4763750731945038, 0.7971335649490356, 1.1067334413528442, 0.41799676418304443, -0.143602192401886, 0.6868269443511963]',
       '[0.14379562437534332, 0.0005593057721853256, -0.047350674867630005, 0.1632624715566635, -1.1872875690460205, -0.4749646484851837, 0.45391929149627686, 0.7937610149383545, 1.0362839698791504, 0.36910170316696167, -0.07874393463134766, 0.6427090167999268]',
       '[0.2016446441411972, -1.3947486877441406e-05, -0.0971163809299469, 0.1933460235595703, -1.1318022012710571, -0.5013339519500732, 0.46635499596595764, 0.7998327016830444, 1.011732816696167, 0.3695237636566162, 0.008165597915649414, 0.6156975030899048]',
       ...,
       '[0.16911165416240692, -0.0002922527492046356, -0.08437836170196533, 0.08540749549865723, -1.1322500705718994, -0.4804706275463104, 0.48595064878463745, 0.7938352823257446, 0.9984339475631714, 0.306

In [59]:
out_df['emb_bus'].unique()

array(['[0.11949535459280014, 8.156523108482361e-05, -0.018419355154037476, 0.1900814026594162, -1.2534259557724, -0.4860982894897461, 0.4763750731945038, 0.7971335649490356, 1.1067334413528442, 0.41799676418304443, -0.143602192401886, 0.6868269443511963]',
       '[0.14379562437534332, 0.0005593057721853256, -0.047350674867630005, 0.1632624715566635, -1.1872875690460205, -0.4749646484851837, 0.45391929149627686, 0.7937610149383545, 1.0362839698791504, 0.36910170316696167, -0.07874393463134766, 0.6427090167999268]',
       '[0.2016446441411972, -1.3947486877441406e-05, -0.0971163809299469, 0.1933460235595703, -1.1318022012710571, -0.5013339519500732, 0.46635499596595764, 0.7998327016830444, 1.011732816696167, 0.3695237636566162, 0.008165597915649414, 0.6156975030899048]',
       ...,
       '[0.16911165416240692, -0.0002922527492046356, -0.08437836170196533, 0.08540749549865723, -1.1322500705718994, -0.4804706275463104, 0.48595064878463745, 0.7938352823257446, 0.9984339475631714, 0.306