In [1]:
"""
Merge per-context baseline embeddings into your main dataset.
- Input A: base CSV (e.g., 'Datasets/dataset_final.csv') with apartment rows and 'id'.
- Input B: embeddings CSV from the previous step (e.g., 'apartment_embeddings_per_context.csv'),
           columns: id, emb_<context> (JSON list or empty/NaN).

You can:
1) Keep compact schema: one column per context (`emb_<ctx>` as JSON string).
2) Expand to numeric columns: `emb_<ctx>_d0 ... emb_<ctx>_d11` (12 dims), with optional zero-imputation.

Outputs:
- datasets/with_embeddings_compact.csv
- datasets/with_embeddings_expanded.csv  (if EXPAND = True)
"""

# =============================
# Cell 1 — Imports & settings
# =============================
import os, json, math
from pathlib import Path
from typing import List, Optional, Any

import numpy as np
import pandas as pd

# Paths
BASE_DATASET = Path("Datasets/dataset_final.csv")
EMB_CSV      = Path("apartment_embeddings_per_context.csv")
OUT_DIR      = Path("datasets")
OUT_DIR.mkdir(exist_ok=True)

# Contexts and embedding size
CLASSES = [
    'sport_and_leisure','medical','education_prim','veterinary',
    'food_and_drink_stores','arts_and_entertainment','food_and_drink',
    'park_like','security','religion','education_sup'
]
ALL_CONTEXTS = CLASSES + ['metro','bus']
EMB_DIM = 12

# Behaviors
EXPAND = True          # also create wide numeric columns
IMPUTE_MISSING = False # if True, replace missing with zero vectors in the expanded output


In [2]:
# =============================
# Cell 2 — Helpers
# =============================

def is_nan_like(x: Any) -> bool:
    # pandas may give float('nan') for empty cells
    try:
        return x is None or (isinstance(x, float) and math.isnan(x)) or (isinstance(x, str) and x.strip() == "")
    except Exception:
        return False


def parse_vec(cell: Any, dim: int) -> Optional[List[float]]:
    """Parse a JSON list cell → python list[float]; return None if missing/invalid.
    Ensures length==dim when returned (truncate or pad zeros if needed)."""
    if is_nan_like(cell):
        return None
    try:
        if isinstance(cell, list):
            vec = [float(v) for v in cell]
        elif isinstance(cell, str):
            vec = json.loads(cell)
            vec = [float(v) for v in vec]
        else:
            return None
    except Exception:
        return None
    # normalize length
    if len(vec) > dim:
        vec = vec[:dim]
    elif len(vec) < dim:
        vec = vec + [0.0] * (dim - len(vec))
    return vec


def expand_embeddings(df: pd.DataFrame, ctx_cols: List[str], dim: int, impute_missing: bool) -> pd.DataFrame:
    """Expand emb_<ctx> columns into numeric columns emb_<ctx>_d0..d{dim-1}.
    If impute_missing=True, fill missing with zeros; else leave as NaN."""
    out = df.copy()
    for col in ctx_cols:
        base = col  # e.g., 'emb_medical'
        # prepare target columns
        tgt_cols = [f"{base}_d{i}" for i in range(dim)]
        # initialize with NaN
        for c in tgt_cols:
            out[c] = np.nan
        # fill
        for idx, cell in out[col].items():
            vec = parse_vec(cell, dim)
            if vec is None:
                if impute_missing:
                    out.loc[idx, tgt_cols] = [0.0] * dim
                # else keep NaN
            else:
                out.loc[idx, tgt_cols] = vec
    return out

In [3]:
# =============================
# Cell 3 — Load & merge (compact)
# =============================
base = pd.read_csv(BASE_DATASET)
emb  = pd.read_csv(EMB_CSV)

# sanity
if 'id' not in base.columns:
    raise ValueError("Base dataset must contain 'id' column")
if 'id' not in emb.columns:
    raise ValueError("Embeddings CSV must contain 'id' column")

# Keep only expected emb_* cols (ignore extras if any)
emb_cols = [f"emb_{c}" for c in ALL_CONTEXTS if f"emb_{c}" in emb.columns]
merged = base.merge(emb[['id'] + emb_cols], on='id', how='left')

# Save compact version (JSON strings or empty)
out_compact = OUT_DIR / "dataset_embeddings_compact.csv"
merged.to_csv(out_compact, index=False)
print(f"✅ Saved compact dataset: {out_compact}  shape={merged.shape}")


✅ Saved compact dataset: datasets\dataset_embeddings_compact.csv  shape=(25215, 40)


In [4]:


# =============================
# Cell 4 — Optional: expand to numeric
# =============================
if EXPAND:
    wide = expand_embeddings(merged, emb_cols, EMB_DIM, impute_missing=IMPUTE_MISSING)
    out_expanded = OUT_DIR / "dataset_embeddings_expanded.csv"
    wide.to_csv(out_expanded, index=False)
    print(f"✅ Saved expanded dataset: {out_expanded}  shape={wide.shape}")

    # Tiny health check: report NaN rates per expanded block
    for ctx in ALL_CONTEXTS:
        base = f"emb_{ctx}"
        if base in emb_cols:
            cols = [f"{base}_d{i}" for i in range(EMB_DIM)]
            if all(c in wide.columns for c in cols):
                n_missing_rows = wide[cols].isna().all(axis=1).sum()
                print(f"{base}: missing rows (all NaN) = {n_missing_rows}")


  out[c] = np.nan
  out[c] = np.nan
  out[c] = np.nan
  out[c] = np.nan
  out[c] = np.nan
  out[c] = np.nan
  out[c] = np.nan
  out[c] = np.nan
  out[c] = np.nan
  out[c] = np.nan
  out[c] = np.nan
  out[c] = np.nan
  out[c] = np.nan
  out[c] = np.nan
  out[c] = np.nan
  out[c] = np.nan
  out[c] = np.nan
  out[c] = np.nan
  out[c] = np.nan
  out[c] = np.nan
  out[c] = np.nan
  out[c] = np.nan
  out[c] = np.nan
  out[c] = np.nan
  out[c] = np.nan
  out[c] = np.nan
  out[c] = np.nan
  out[c] = np.nan
  out[c] = np.nan
  out[c] = np.nan
  out[c] = np.nan
  out[c] = np.nan
  out[c] = np.nan
  out[c] = np.nan
  out[c] = np.nan
  out[c] = np.nan
  out[c] = np.nan
  out[c] = np.nan
  out[c] = np.nan
  out[c] = np.nan
  out[c] = np.nan
  out[c] = np.nan
  out[c] = np.nan
  out[c] = np.nan
  out[c] = np.nan
  out[c] = np.nan
  out[c] = np.nan
  out[c] = np.nan
  out[c] = np.nan
  out[c] = np.nan
  out[c] = np.nan
  out[c] = np.nan
  out[c] = np.nan
  out[c] = np.nan
  out[c] = np.nan
  out[c] =

✅ Saved expanded dataset: datasets\dataset_embeddings_expanded.csv  shape=(25215, 196)
emb_sport_and_leisure: missing rows (all NaN) = 92
emb_medical: missing rows (all NaN) = 128
emb_education_prim: missing rows (all NaN) = 696
emb_veterinary: missing rows (all NaN) = 2233
emb_food_and_drink_stores: missing rows (all NaN) = 521
emb_arts_and_entertainment: missing rows (all NaN) = 208
emb_food_and_drink: missing rows (all NaN) = 1010
emb_park_like: missing rows (all NaN) = 940
emb_security: missing rows (all NaN) = 651
emb_religion: missing rows (all NaN) = 3924
emb_education_sup: missing rows (all NaN) = 634
emb_metro: missing rows (all NaN) = 8463
emb_bus: missing rows (all NaN) = 1040


In [16]:
# ======================================
# Merge base dataset + POI + Metro/Bus embeddings
# ======================================
import pandas as pd

# Load datasets
df_base  = pd.read_csv("Datasets/dataset_final.csv")
df_poi   = pd.read_csv("embeddings_poi_from_shards.csv")
df_trans = pd.read_csv("embeddings_metro_bus_from_shards.csv")

print("Base dataset:", df_base.shape)
print("POI embeddings:", df_poi.shape)
print("Transport embeddings:", df_trans.shape)

# Merge all by apartment ID (called 'id' everywhere)
df_all = df_base.merge(df_poi,   on="id", how="left")
df_all = df_all.merge(df_trans, on="id", how="left")

print("Merged dataset:", df_all.shape)

# Save compact version
OUT_CSV = "dataset_embeddings_compact.csv"
df_all.to_csv(OUT_CSV, index=False)
print(f"✅ Saved {OUT_CSV} with {df_all.shape[0]} rows and {df_all.shape[1]} columns")

# Quick sanity check: show first 2 rows of embeddings
sample = df_all.sample(2, random_state=42)
pd.set_option("display.max_colwidth", 120)
print(sample[["id", "emb_medical", "emb_bus", "emb_metro"]])


Base dataset: (25215, 27)
POI embeddings: (25215, 12)
Transport embeddings: (24234, 3)
Merged dataset: (25215, 40)
✅ Saved dataset_embeddings_compact.csv with 25215 rows and 40 columns
               id  \
4048   1584480033   
15723  2863178462   

                                                                                                                   emb_medical  \
4048   [66.0, 1068.72988244259, 285.4775390625, 2368.303466796875, 910.3803405761719, 643.0881731688652, 0.0013646813805907...   
15723  [18.0, 1508.1400451660156, 426.6236267089844, 2368.000732421875, 1925.860595703125, 773.8998472561777, 0.00104769358...   

                                                                                                                       emb_bus  \
4048   [13.0, 272.00402479905347, 71.79084777832031, 377.6366271972656, 234.3451385498047, 87.1877268816821, 0.004460767934...   
15723  [9.0, 235.45552656385632, 89.12711334228516, 391.79022216796875, 200.80645751953125, 89.73231

In [17]:
df = pd.read_csv("Datasets/dataset_embeddings_compact.csv")

In [18]:
df.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 25215 entries, 0 to 25214
Data columns (total 40 columns):
 #   Column                      Non-Null Count  Dtype  
---  ------                      --------------  -----  
 0   id                          25215 non-null  int64  
 1   monto                       25215 non-null  int64  
 2   superficie_t                25215 non-null  float64
 3   dormitorios                 25215 non-null  int64  
 4   dormitorios_faltante        25215 non-null  int64  
 5   banos                       25215 non-null  int64  
 6   banos_faltante              25215 non-null  int64  
 7   antiguedad                  25215 non-null  int64  
 8   antiguedad_faltante         25215 non-null  int64  
 9   Or_N                        25215 non-null  int64  
 10  Or_S                        25215 non-null  int64  
 11  Or_E                        25215 non-null  int64  
 12  Or_O                        25215 non-null  int64  
 13  Or_Faltante                 252

In [19]:
df.iloc[:, -13:].sample(5)

Unnamed: 0,emb_sport_and_leisure,emb_medical,emb_education_prim,emb_veterinary,emb_food_and_drink_stores,emb_arts_and_entertainment,emb_food_and_drink,emb_park_like,emb_security,emb_religion,emb_education_sup,emb_metro,emb_bus
12591,"[27.0, 6.03273868560791, 2.733539342880249, 1.680132269859314, 0.22343476116657257, 0.22653721272945404, 0.001000000...","[60.0, 26.12047004699707, 14.443206787109375, 9.044879913330078, 0.4353411793708801, 0.22627012431621552, 0.01949224...","[15.0, 4.384066581726074, 2.324298143386841, 1.4988279342651367, 0.29227110743522644, 0.2636869549751282, 0.01064758...","[9.0, 2.2140164375305176, 0.8159583806991577, 0.3756501078605652, 0.24600182473659515, 0.17362356185913086, 0.072358...","[70.0, 22.82982063293457, 11.898143768310547, 7.482097625732422, 0.32614028453826904, 0.2522023022174835, 0.00100000...","[141.0, 55.93154525756836, 30.549835205078125, 19.211301803588867, 0.3966776132583618, 0.24354128539562225, 0.004608...","[94.0, 33.053810119628906, 16.894702911376953, 10.392999649047852, 0.35163629055023193, 0.2368180751800537, 0.001000...","[16.0, 4.864889144897461, 2.0395636558532715, 1.1123898029327393, 0.3040555715560913, 0.18714416027069092, 0.0805305...","[25.0, 11.171670913696289, 6.624151229858398, 4.360683441162109, 0.44686684012413025, 0.25549182295799255, 0.0285440...","[7.0, 2.1978070735931396, 0.8873892426490784, 0.4505852162837982, 0.31397244334220886, 0.16790233552455902, 0.143351...","[225.0, 118.62142944335938, 75.36434936523438, 51.919307708740234, 0.5272063612937927, 0.23875948786735535, 0.002381...","[7.0, 0.013779514469206333, 2.8951646527275443e-05, 6.510835959261385e-08, 0.001968502067029476, 0.00051083171274513...","[26.0, 0.10196603834629059, 0.0004630457260645926, 2.5197687136824243e-06, 0.003921770490705967, 0.00155857868958264..."
24731,"[106.0, 29.843183517456055, 14.070626258850098, 8.232409477233887, 0.2815394699573517, 0.2312515825033188, 0.0043511...","[79.0, 28.164804458618164, 15.104934692382812, 9.810725212097168, 0.3565165102481842, 0.25317516922950745, 0.0042942...","[6.0, 1.8973796367645264, 0.7387728095054626, 0.3212207555770874, 0.3162299394607544, 0.15207704901695251, 0.1342272...","[3.0, 1.5634437799453735, 1.148972511291504, 0.9579939246177673, 0.5211479067802429, 0.33375993371009827, 0.12370334...","[41.0, 10.529956817626953, 4.6024088859558105, 2.3702127933502197, 0.2568282186985016, 0.2151584029197693, 0.0028541...","[33.0, 10.406961441040039, 4.621959686279297, 2.4293668270111084, 0.31536245346069336, 0.20150905847549438, 0.014220...","[109.0, 25.659704208374023, 9.441706657409668, 4.360199928283691, 0.23541012406349182, 0.17664438486099243, 0.001482...","[4.0, 0.8875260949134827, 0.2751735746860504, 0.10274406522512436, 0.22188152372837067, 0.13986416161060333, 0.08127...","[4.0, 1.702500581741333, 0.8316766023635864, 0.4534023404121399, 0.42562514543533325, 0.1635921150445938, 0.23315946...","[7.0, 2.3992271423339844, 1.1507759094238281, 0.6292836666107178, 0.3427467346191406, 0.21661308407783508, 0.0503232...","[71.0, 37.08989715576172, 23.824373245239258, 17.304271697998047, 0.5223929286003113, 0.25032010674476624, 0.0157993...","[5.0, 0.008934106677770615, 1.7077310985769145e-05, 3.475332377433915e-08, 0.0017868212889879942, 0.0004719449789263...","[7.0, 0.029441451653838158, 0.00014313982683233917, 8.013990395738801e-07, 0.0042059216648340225, 0.0016609540907666..."
12327,"[15.0, 5.616213798522949, 3.184290885925293, 2.2585856914520264, 0.37441426515579224, 0.2685144543647766, 0.01196209...","[45.0, 15.88559341430664, 7.884194850921631, 4.535528659820557, 0.35301318764686584, 0.22491337358951569, 0.00324547...","[13.0, 5.120848178863525, 2.4622015953063965, 1.3250749111175537, 0.3939113914966583, 0.1850241720676422, 0.11292292...","[6.0, 1.7465453147888184, 0.6612012386322021, 0.26995649933815, 0.29109087586402893, 0.15958166122436523, 0.01187515...","[47.0, 13.251874923706055, 5.611817359924316, 2.905034065246582, 0.2819547951221466, 0.1997545063495636, 0.018063504...","[132.0, 42.261619567871094, 19.22429084777832, 10.489518165588379, 0.3201637864112854, 0.20768658816814423, 0.001000...","[23.0, 7.932838439941406, 4.073561668395996, 2.493381977081299, 0.3449060320854187, 0.24114559590816498, 0.003938904...","[10.0, 2.696364164352417, 1.1931681632995605, 0.6392337083816528, 0.2696364223957062, 0.2159004807472229, 0.04666179...","[25.0, 10.482751846313477, 5.687837600708008, 3.524542808532715, 0.419310063123703, 0.22736001014709473, 0.071512028...","[3.0, 1.2304683923721313, 0.5494815707206726, 0.25970321893692017, 0.41015613079071045, 0.12219854444265366, 0.24499...","[219.0, 62.78380584716797, 24.358396530151367, 11.371350288391113, 0.2866840362548828, 0.17040486633777618, 0.005559...","[4.0, 0.006714826449751854, 1.135989441536367e-05, 1.9379191584789623e-08, 0.0016787066124379635, 0.0001480462233303...","[20.0, 0.08674301207065582, 0.0004941411898471415, 3.7328313737816643e-06, 0.004337150603532791, 0.00242820614948868..."
16914,"[3.0, 0.29402869939804077, 0.051657889038324356, 0.010914384387433529, 0.09800956398248672, 0.08725491911172867, 0.0...","[12.0, 3.9068446159362793, 2.099977731704712, 1.3477599620819092, 0.32557037472724915, 0.2626824676990509, 0.0061264...","[14.0, 4.779664516448975, 2.4229390621185303, 1.4576963186264038, 0.34140461683273315, 0.23771823942661285, 0.001000...","[3.0, 1.0235271453857422, 0.5286986827850342, 0.3331301212310791, 0.34117570519447327, 0.2446058690547943, 0.1463631...","[12.0, 5.583156108856201, 2.88519549369812, 1.5987184047698975, 0.4652630090713501, 0.15480084717273712, 0.227848127...","[4.0, 2.010869264602661, 1.2256022691726685, 0.7987610101699829, 0.5027173161506653, 0.2316805124282837, 0.134778305...","[10.0, 4.993411540985107, 2.853238582611084, 1.7057143449783325, 0.4993411600589752, 0.18968991935253143, 0.05162844...","[2.0, 0.8485768437385559, 0.5328852534294128, 0.372768372297287, 0.42428842186927795, 0.2939761281013489, 0.13031233...","[8.0, 3.3591322898864746, 1.8293230533599854, 1.1303868293762207, 0.4198915362358093, 0.22881539165973663, 0.1064908...","[5.0, 2.1138265132904053, 1.1800000667572021, 0.7693287134170532, 0.42276531457901, 0.23931045830249786, 0.207193180...","[1.0, 0.22403423488140106, 0.05019133910536766, 0.011244578287005424, 0.22403423488140106, 0.0, 0.22403423488140106,...","[1.0, 0.0015809015603736043, 2.4992498310894007e-06, 3.951067917284945e-09, 0.0015809015603736043, 0.0, 0.0015809015...","[20.0, 0.07284502685070038, 0.00027988693909719586, 1.1335534964018734e-06, 0.0036422512494027615, 0.000853435660246..."
9100,"[89.0, 37.40862274169922, 21.186012268066406, 13.606220245361328, 0.42032161355018616, 0.24773943424224854, 0.003912...","[36.0, 13.117101669311523, 6.954886436462402, 4.385624408721924, 0.3643639385700226, 0.24582557380199432, 0.00404919...","[15.0, 6.722784042358398, 3.8764395713806152, 2.4549312591552734, 0.44818559288978577, 0.23991456627845764, 0.034557...","[8.0, 2.1465189456939697, 1.0489513874053955, 0.6221469640731812, 0.2683148682117462, 0.2431585043668747, 0.04641634...","[27.0, 8.01862907409668, 3.801650047302246, 2.1583001613616943, 0.29698625206947327, 0.22934910655021667, 0.01361645...","[29.0, 8.610908508300781, 3.8850324153900146, 2.2461724281311035, 0.29692786931991577, 0.21401046216487885, 0.052954...","[27.0, 9.70600414276123, 4.912251949310303, 2.970088005065918, 0.3594816327095032, 0.22958268225193024, 0.0363308489...","[2.0, 0.3881838321685791, 0.14688241481781006, 0.056279003620147705, 0.19409191608428955, 0.1891283541917801, 0.0049...","[12.0, 5.121079444885254, 3.003354072570801, 2.011220693588257, 0.4267566204071045, 0.26107144355773926, 0.013481559...","[4.0, 1.0175718069076538, 0.4660413861274719, 0.23671062290668488, 0.25439295172691345, 0.22758421301841736, 0.01202...","[88.0, 22.287391662597656, 6.914508819580078, 2.708775043487549, 0.2532658278942108, 0.12012652307748795, 0.00907965...","[4.0, 0.00749985920265317, 1.4608454875997268e-05, 2.9472156271026506e-08, 0.0018749648006632924, 0.0003696223138831...","[13.0, 0.05275449529290199, 0.00023740122560411692, 1.1764184364437824e-06, 0.004058037884533405, 0.0013393886620178..."


In [21]:
import json
import random

# Feature names (baseline schema)
FEATURES = [
    "count_pois", "mean_distance", "min_distance", "max_distance",
    "median_distance", "std_distance",
    "mean_inverse_distance", "max_inverse_distance", "sum_inverse_distance",
    "ratio_within_near_radius", "ratio_within_mid_radius", "ratio_within_far_radius"
]

# Pick a random apartment with a non-null emb_medical
sample_row = df.dropna(subset=["emb_medical"]).sample(1, random_state=random.randint(0,10000)).iloc[0]
apt_id = sample_row["id"]

# Parse the JSON string into a Python list
emb = json.loads(sample_row["emb_medical"])

print(f"Apartment ID: {apt_id}\n")
print("Medical embedding values:")
for i, (name, val) in enumerate(zip(FEATURES, emb)):
    print(f"  dim{i:02d}: {val:.6f} → {name}")


Apartment ID: 2853789086

Medical embedding values:
  dim00: 55.000000 → count_pois
  dim01: 20.414429 → mean_distance
  dim02: 10.583040 → min_distance
  dim03: 6.343233 → max_distance
  dim04: 0.371171 → median_distance
  dim05: 0.233775 → std_distance
  dim06: 0.015978 → mean_inverse_distance
  dim07: 0.930415 → max_inverse_distance
  dim08: 0.930415 → sum_inverse_distance
  dim09: 0.810360 → ratio_within_near_radius
  dim10: 0.036364 → ratio_within_mid_radius
  dim11: 0.309091 → ratio_within_far_radius


Apartment ID: 1591678951
Medical embedding values:
 - dim0: 23.000000 &rarr; Cantidad de POIs
 - dim1: 6.013748 &rarr; Distancia media
 - dim2: 2.323946 &rarr; Distancia mínima
 - dim3: 1.039509 &rarr; Distancia máxima
 - dim4: 0.261467 &rarr; Mediana de la distancia
 - dim5: 0.180765 &rarr; Desviación estandar de la distancia
 - dim6: 0.002084 &rarr; Cercanía media (distancia inversa)
 - dim7: 0.566388 &rarr;
 - dim8: 0.566388 &rarr;
 - dim9: 0.540043 &rarr;
 - dim10: 0.000000 &rarr;
 - dim11: 0.217391 &rarr;

 -  dim0: 9.000000 &rarr; Cantidad de POIs de la clase vinculados al departamento
 -  dim1: 3.402536 &rarr; Distancia media
 -  dim2: 1.822208 &rarr; Distancia Minima
 -  dim3: 1.077151 &rarr; Distancia Maxima
 -  dim4: 0.378060 &rarr; Mediana de la distancia
 -  dim5: 0.244005 &rarr; Desviación estandar de la distancia
 -  dim6: 0.002132 &rarr; Cercanía media (distancia inversa)
 -  dim7: 0.795181 &rarr; Cercanía máxima (POI más cercano)
 -  dim8: 0.795181 &rarr; Cercanía total (suma de distancias inversas)
 -  dim9: 0.625955 &rarr; Proporción dentro del radio cercano (600m) 
 -  dim10: 0.111111 &rarr; Proporción dentro del radio medio (1200m)
 -  dim11: 0.222222 &rarr; Proporción dentro del radio lejano (2400m)

Sobre dim 9, 10 y 11, se refieren a la fracción de POIs que caen dentro de un cierto radio respecto del total de POIs del contexto o clase para ese departamento.