In [None]:
#!/usr/bin/env python3
import numpy as np
import pandas as pd
import torch
from torch_geometric.data import Data
from pathlib import Path
from scipy import sparse
from sklearn.preprocessing import OneHotEncoder, OrdinalEncoder
import os
# NOTE(review): nothing below allocates CUDA memory; confirm this allocator
# setting is still needed by whatever runs in the same process afterwards.
os.environ.update({"PYTORCH_CUDA_ALLOC_CONF": "expandable_segments:True"})

# ---------------------------------------------------------------------------
# Configuration: where inputs are read from and where outputs are written
# ---------------------------------------------------------------------------
DATA_DIR = Path("raw_data")   # raw input CSV files live here
ART_DIR = Path("artifacts")   # every generated artifact is written here
ART_DIR.mkdir(exist_ok=True)  # re-runs are fine; parent dirs are NOT created

print("--- Starting Data Preprocessing ---")

# 1) Load people.csv, normalise header names, and index rows by their id
print("Loading people.csv...")
# Name parts are forced to str so values such as "0" are not parsed as numbers.
name_dtypes = dict.fromkeys(("first_name", "middle_name", "patronym", "surname"), str)
people = pd.read_csv(DATA_DIR / "people.csv", dtype=name_dtypes, low_memory=False)
people = people.rename(columns=lambda c: c.lower().strip()).set_index("id")

# 2) Canonicalize names & flags
print("Canonicalizing names and flags...")
NAME_COLS = ["first_name", "middle_name", "patronym", "surname"]
for col in NAME_COLS:
    if col not in people.columns:
        people[col] = ""

# full_name = "first middle patronym surname", lower-cased. The parts are
# joined with single spaces, then whitespace runs are collapsed and the ends
# trimmed. Blank or whitespace-only parts therefore leave no double spaces.
# Done column-wise instead of a per-row apply over every record.
name_parts = people[NAME_COLS].fillna("").astype(str)
full_name = name_parts[NAME_COLS[0]]
for col in NAME_COLS[1:]:
    full_name = full_name + " " + name_parts[col]
people["full_name"] = (
    full_name.str.lower().str.replace(r"\s+", " ", regex=True).str.strip()
)

# Unparseable / missing years and source ids are encoded as 0.
people["birthyear"] = pd.to_numeric(people["birthyear"], errors="coerce").fillna(0).astype(int)
people["heimild"] = pd.to_numeric(people["heimild"], errors="coerce").fillna(0).astype(int)
# A "karl" value (case/whitespace-insensitive) sets sex_male=1. Anything
# else, including a missing value, gives 0.
people["sex_male"] = people["sex"].map(
    lambda x: int(isinstance(x, str) and x.strip().lower() == "karl")
).astype(int)
# 1/0 flags: does the record name a partner / father / mother?
for rel in ("partner", "father", "mother"):
    people[f"has_{rel}"] = people[rel].notna().astype(int)

# 3) Load and join manntol IDs
print("Loading manntol_einstaklingar_new.csv and merging ID columns…")
mann = pd.read_csv(
    DATA_DIR / "manntol_einstaklingar_new.csv",
    dtype=str,
    usecols=["id", "bi_sokn", "bi_hreppur", "bi_sysla", "thsk_maki", "thsk_fadir", "thsk_modir"],
).rename(columns={
    "bi_sokn":    "parish_id",
    "bi_hreppur": "district_id",
    "bi_sysla":   "county_id",
    "thsk_maki":  "partner_mann",
    "thsk_fadir": "father_mann",
    "thsk_modir": "mother_mann",
})
# Cast the key to the dtype of people's index so the index join can match.
mann["id"] = mann["id"].astype(people.index.dtype)
mann = mann.set_index("id")
n_dup = int(mann.index.duplicated().sum())
if n_dup:
    # A left join against repeated keys emits one output row per match.
    print(f"Warning: {n_dup} duplicate ids in manntol file; matching people rows will be duplicated")
people = people.join(
    mann[["parish_id", "district_id", "county_id", "partner_mann", "father_mann", "mother_mann"]],
    how="left",
)
# Fill relations missing from people.csv with the manntol values. Then
# refresh the has_* flags computed in step 2 so the features reflect the fill.
for rel in ("partner", "father", "mother"):
    people[rel] = people[rel].fillna(people[f"{rel}_mann"])
    people[f"has_{rel}"] = people[rel].notna().astype(int)
people = people.drop(columns=["partner_mann", "father_mann", "mother_mann"])

# 4) Merge geography
print("Merging geography data via manntol IDs…")
# Each entry: (geo CSV, id column in that CSV, id column on people, new column).
# Series.map is used instead of DataFrame.merge because a column-on-column
# merge discards people's row-id index. It also duplicates rows whenever a
# geo file repeats an id.
for fname, idcol, merge_on, newcol in [
    ("parishes.csv",  "id", "parish_id",   "parish_full"),
    ("districts.csv", "id", "district_id", "district_name"),
    ("counties.csv",  "id", "county_id",   "county_name"),
]:
    if merge_on not in people.columns:
        print(f"Warning: '{merge_on}' missing; skipping {fname}")
        continue
    try:
        geo = pd.read_csv(DATA_DIR / fname, low_memory=False)
    except FileNotFoundError:
        print(f"Warning: {fname} not found; skipping")
        continue
    except (pd.errors.EmptyDataError, pd.errors.ParserError) as e:
        print(f"Warning: error merging {fname}: {e}")
        continue
    geo.columns = geo.columns.str.lower().str.strip()
    # Parishes carry a "full_name"; districts and counties carry a "name".
    col_ren = "full_name" if newcol == "parish_full" else "name"
    missing = [c for c in (idcol, col_ren) if c not in geo.columns]
    if missing:
        print(f"Warning: error merging {fname}: missing columns {missing}")
        continue
    # Keys are compared as strings: the manntol id columns were read with dtype=str.
    geo[idcol] = geo[idcol].astype(str)
    lookup = geo.drop_duplicates(subset=idcol).set_index(idcol)[col_ren]
    people[newcol] = people[merge_on].map(lookup)

# 5) Impute static attributes
print("Imputing static attributes…")


def _fill_within_person(values, keys, missing, sentinel):
    """Forward- then back-fill ``values`` within each group of ``keys``.

    ``keys`` is an array aligned positionally with ``values``. Entries where
    the boolean array ``missing`` is True are treated as gaps. Gaps that no
    other record of the same person can fill are reset to ``sentinel``. The
    result keeps the index and dtype of ``values``.
    """
    gaps = values.mask(missing)
    filled = gaps.groupby(keys).ffill().groupby(keys).bfill()
    return filled.fillna(sentinel).astype(values.dtype)


if "person" in people.columns:
    # Only rows linked to a person take part. Grouping on "person" drops NaN
    # keys, so transform() returned NaN for unlinked rows. The fills below
    # then replaced those rows' real values with "unknown"/0.
    linked = people["person"].notna().to_numpy()
    if linked.any():
        keys = people.loc[linked, "person"].to_numpy()
        sub = people.loc[linked]
        # Step 2 stores "missing" as a sentinel ("" / 0) rather than NaN.
        # The gaps must be rebuilt explicitly or ffill/bfill has nothing to fill.
        rules = {
            "full_name": ((sub["full_name"] == "").to_numpy(), ""),
            "birthyear": ((sub["birthyear"] == 0).to_numpy(), 0),
        }
        if "sex" in sub.columns:
            # sex_male is 0 for both female and unknown, so the raw "sex"
            # column decides what counts as a gap.
            unknown_sex = ~sub["sex"].map(lambda x: isinstance(x, str)).to_numpy(dtype=bool)
            rules["sex_male"] = (unknown_sex, 0)
        for col, (missing, sentinel) in rules.items():
            people.loc[linked, col] = _fill_within_person(
                sub[col], keys, missing, sentinel
            ).to_numpy()
else:
    print("Warning: 'person' column not found. Skipping imputation.")
people["full_name"] = people["full_name"].fillna("unknown")
people["birthyear"] = people["birthyear"].fillna(0).astype(int)
people["sex_male"] = people["sex_male"].fillna(0).astype(int)

# 6) **Include ALL rows** for ML features and graph
print("Including ALL rows for ML features and graph…")
people_ml = people.copy()
if "person" not in people_ml.columns:
    # Step 5 tolerates a missing "person" column. Treat every row as unlinked
    # so the exports below and the graph step still run instead of raising KeyError.
    people_ml["person"] = np.nan
print(f"→ {len(people_ml)} total rows")

# Export full row_labels.csv (astype(str) writes unlinked rows as "nan")
pd.DataFrame({
    "row_id": people_ml.index,
    "person": people_ml["person"].astype(str),
}).to_csv(ART_DIR / "row_labels.csv", index=False)

# Export just the linked subset, keyed by row id
has_person = people_ml["person"].notna()
people_ml.loc[has_person, ["person"]].to_csv(
    ART_DIR / "rows_with_person.csv", index_label="row_id"
)

# 7) Feature groups: numeric, low-cardinality categorical (one-hot),
#    high-cardinality categorical (ordinal)
print("Defining feature sets…")
NUM_COLS = ["birthyear", "heimild", "sex_male", "has_partner", "has_father", "has_mother"]
CAT_LOW = ["status", "marriagestatus", "district_name", "county_name"]
CAT_HIGH = ["parish_full"]
# Any absent column is created with a neutral constant (0 for numeric, ""
# for categorical) so the encoders below always see the full schema.
for feature in NUM_COLS + CAT_LOW + CAT_HIGH:
    if feature in people_ml.columns:
        continue
    people_ml[feature] = "" if feature not in NUM_COLS else 0

# 8) Dense float32 matrix of the numeric features, columns in NUM_COLS order
print("Creating numeric features…")
X_num = people_ml.loc[:, NUM_COLS].to_numpy(dtype=np.float32)

# 9) Sparse one-hot encoding of the low-cardinality categoricals
print("One-hot encoding low-cardinality…")
low_frame = people_ml[CAT_LOW].fillna("").astype(str)
ohe = OneHotEncoder(sparse_output=True, handle_unknown="ignore")
X_low = ohe.fit(low_frame).transform(low_frame)
low_cols = ohe.get_feature_names_out(CAT_LOW)  # one name per one-hot column

# 10) Integer codes (stored as float32) for the high-cardinality categoricals
print("Ordinal encoding high-cardinality…")
high_frame = people_ml[CAT_HIGH].fillna("").astype(str)
ord_enc = OrdinalEncoder(unknown_value=-1, handle_unknown="use_encoded_value")
X_high = np.asarray(ord_enc.fit_transform(high_frame), dtype=np.float32)

# 11) Stack numeric | one-hot | ordinal column-wise into one CSR matrix and save it
print("Saving sparse features...")
feature_blocks = [
    sparse.csr_matrix(X_num),   # NUM_COLS
    X_low,                      # one-hot columns (see low_cols)
    sparse.csr_matrix(X_high),  # CAT_HIGH ordinal codes
]
X_all = sparse.hstack(feature_blocks, format="csr")
sparse.save_npz(ART_DIR / "iceid_ml_ready.npz", X_all)

# 12) Create temporal graph
print("Creating temporal graph…")
# Each person's records are chained in heimild order, with one edge in each
# direction between consecutive records. Node ids in edge_index are row
# *positions* in people_ml (0..n-1), and node_id stores the matching index
# labels. Positions are used directly, not a label->position dict, so
# duplicate row labels cannot collapse or misroute edges.
person_col = people_ml["person"].to_numpy()
row_pos = np.flatnonzero(pd.notna(person_col))
chain = pd.DataFrame({
    "person":  person_col[row_pos],
    "heimild": people_ml["heimild"].to_numpy()[row_pos],
    "pos":     row_pos,
}).sort_values(["person", "heimild"], kind="stable")  # ties keep file order
pos = chain["pos"].to_numpy()
persons = chain["person"].to_numpy()
same_person = persons[:-1] == persons[1:]   # adjacent rows of the same person
src, dst = pos[:-1][same_person], pos[1:][same_person]
# Shape (2, 2E); an empty graph naturally yields shape (2, 0).
edge_index = torch.from_numpy(
    np.vstack([np.concatenate([src, dst]), np.concatenate([dst, src])]).astype(np.int64)
)
# num_nodes is given explicitly so it is never inferred from edge_index.
graph = Data(edge_index=edge_index, num_nodes=len(people_ml))
graph.node_id = torch.tensor(people_ml.index.values, dtype=torch.long)
# save only the *structure* (edges + node IDs), not x
torch.save({
    "edge_index": graph.edge_index,
    "node_id":    graph.node_id,
}, ART_DIR / "temporal_graph.pt")
print(f"Graph: {graph.num_nodes} nodes, {graph.num_edges//2} undirected edges.")
print("--- Data Preprocessing Finished ---")

--- Starting Data Preprocessing ---
Loading people.csv...
Canonicalizing names and flags...
Loading manntol_einstaklingar_new.csv and merging ID columns…
Merging geography data via manntol IDs…
Imputing static attributes…
Including ALL rows for ML features and graph…
→ 984028 total rows
Defining feature sets…
Creating numeric features…
One-hot encoding low-cardinality…
Ordinal encoding high-cardinality…
Saving sparse features...
Creating temporal graph…
Graph: 984028 nodes, 266817 undirected edges.
--- Data Preprocessing Finished ---
