# In [None]:
import geopandas as gpd
import networkx as nx
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from sklearn.metrics import (
    cohen_kappa_score,
    f1_score,
    precision_score,
    recall_score,
)
from sklearn.model_selection import GroupKFold
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader
from torch_geometric.nn import GCNConv
from torch_geometric.utils import from_networkx, k_hop_subgraph

# =============================================================================
# 1. Load and preprocess geospatial data
# =============================================================================
mesh = gpd.read_file("/home/stagiaire/Téléchargements/PR/D/mesh_rj.shp")

# Cells passing the shared filter below are labelled 1 when favelas > 0.9 and
# 0 when favelas == 0; every other cell gets NaN and is dropped.
# NOTE(review): "favelas" is presumably a favela-coverage fraction — confirm.
candidate = (
    (mesh["vegetation"] <= 0.95)
    & (mesh["ghsl"] >= 0.5)
    & (mesh["osm"] <= 0.1)
)
mesh["label"] = np.select(
    [candidate & (mesh["favelas"] > 0.9), candidate & (mesh["favelas"] == 0)],
    [1, 0],
    default=np.nan,
)
dataset = mesh[mesh["label"].notna()].copy()

# Assign each labelled cell to the zone that contains its centroid; these
# zones later serve as the groups for spatial cross-validation.
zones = gpd.read_file("/home/stagiaire/Téléchargements/PR/D/zones.shp")
dataset["centroid"] = dataset.geometry.centroid
points_zones = gpd.sjoin(
    dataset.set_geometry("centroid"), zones[["fid", "geometry"]],
    how="left", predicate="within"
)
# If zones overlap, a centroid matches several zones and sjoin repeats that
# row; keep the first match so the index is unique for the assignment below.
points_zones = points_zones[~points_zones.index.duplicated(keep="first")]
dataset["zone"] = points_zones["fid"]
dataset.drop(columns=["centroid"], inplace=True)
# Cells whose centroid lies in no zone cannot be grouped, so drop them.
dataset = dataset[dataset["zone"].notna()].reset_index(drop=True)

# =============================================================================
# 2. Build planar adjacency graph
# =============================================================================
feature_cols = [
    "vegetation", "slope", "profile_co", "entropy",
    "nodes", "roads", "mean_conne", "min_connex", "max_connex",
]
# Cast the feature block explicitly to float32: a row yielded by ``iterrows``
# on a GeoDataFrame is an object-dtype Series, which ``torch.tensor`` rejects.
node_feats = torch.from_numpy(dataset[feature_cols].to_numpy(dtype=np.float32))
cell_ids = dataset["id"].astype(int).tolist()
cell_labels = dataset["label"].astype(int).tolist()

# One node per labelled cell, keyed by its "id", carrying features and label.
G = nx.Graph()
for cid, feats, lab in zip(cell_ids, node_feats, cell_labels):
    G.add_node(cid, x=feats, label=lab)

# Edges connect cells whose geometries touch (shared edge or vertex). A
# spatial self-join uses the spatial index instead of an O(n^2) pairwise scan.
cells = dataset[["id", "geometry"]]
touching = gpd.sjoin(cells, cells, how="inner", predicate="touches")
G.add_edges_from(
    zip(
        touching["id_left"].astype(int).tolist(),
        touching["id_right"].astype(int).tolist(),
    )
)

# Convert to PyG; node order in pyg_data follows G.nodes().
pyg_data = from_networkx(G, group_node_attrs=["x"])
pyg_data.y = torch.tensor(
    [G.nodes[n]["label"] for n in G.nodes()], dtype=torch.long
)
# Zone id per node, aligned with pyg_data's node order, used as CV groups.
zone_map = dict(zip(cell_ids, dataset["zone"].astype(int).tolist()))
node_ids = list(G.nodes())
groups = np.array([zone_map[n] for n in node_ids])

# =============================================================================
# 3. Ego‑graph extraction
# =============================================================================
def extract_ego_graph(data, center, hops=1):
    """Return the ``hops``-hop ego-graph around node ``center`` as ``Data``.

    Subgraph nodes are relabelled so the centre is local node 0; the other
    nodes keep the relative order returned by ``k_hop_subgraph``.

    Args:
        data: source graph providing ``x``, ``edge_index``, ``y`` and
            ``num_nodes``.
        center: global index of the centre node.
        hops: neighbourhood radius in hops.

    Returns:
        ``Data`` holding the subgraph's ``x`` and ``edge_index``, ``y`` set to
        ``data.y[center]`` with a leading dimension added, and
        ``center_idx = 0``.
    """
    # Pass num_nodes explicitly. Otherwise it is inferred from edge_index, and
    # an isolated centre whose index exceeds every edge endpoint falls outside
    # that range, which makes k_hop_subgraph fail.
    subset, edge_idx, mapping, _ = k_hop_subgraph(
        center, hops, data.edge_index, relabel_nodes=True,
        num_nodes=data.num_nodes,
    )
    mapping = int(mapping)
    local = torch.arange(subset.numel(), device=subset.device)
    # perm[new] = old local index, with the centre moved to the front.
    perm = torch.cat([local[mapping:mapping + 1], local[local != mapping]])
    # inv[old] = new, so both edge endpoints are relabelled in one indexing op.
    inv = torch.empty_like(perm)
    inv[perm] = local

    g = Data(x=data.x[subset[perm]], edge_index=inv[edge_idx])
    g.y = data.y[center].unsqueeze(0)
    g.center_idx = 0
    return g

# =============================================================================
# 4. Undersampling to balance classes
# =============================================================================
def undersample(indices, labels, rng=None):
    """Randomly undersample ``indices`` so classes 0 and 1 are equally sized.

    Args:
        indices: node indices to draw from.
        labels: per-node labels (torch tensor on CPU or numpy array), indexed
            by ``indices``.
        rng: optional ``numpy.random.Generator``/``RandomState`` for
            reproducible draws; defaults to the global ``np.random`` state.

    Returns:
        A shuffled numpy array with ``min(n0, n1)`` indices of each class.
        The input (as an array) is returned unchanged when it is empty or
        lacks one of the two classes.
    """
    rand = np.random if rng is None else rng
    idx = np.asarray(indices)
    if idx.size == 0:
        return idx
    labs = np.asarray(labels)[idx]
    c0, c1 = idx[labs == 0], idx[labs == 1]
    if len(c0) == 0 or len(c1) == 0:
        return idx
    m = min(len(c0), len(c1))
    res = np.concatenate([
        rand.choice(c0, m, replace=False),
        rand.choice(c1, m, replace=False),
    ])
    rand.shuffle(res)
    return res

# =============================================================================
# 5. GCN model definition
# =============================================================================
class EgoGCN(nn.Module):
    """Two-layer GCN that classifies each ego-graph by its centre node.

    Args:
        in_dim: number of input node features.
        hid_dim: width of both GCN layers.
        num_classes: number of output classes.

    ``forward`` returns log-probabilities of shape
    ``(num_graphs, num_classes)``: one row per graph in the batch.
    """

    def __init__(self, in_dim, hid_dim, num_classes):
        super().__init__()
        self.conv1 = GCNConv(in_dim, hid_dim)
        self.conv2 = GCNConv(hid_dim, hid_dim)
        self.classifier = nn.Linear(hid_dim, num_classes)

    def forward(self, data):
        x, edge_idx = data.x, data.edge_index
        x = F.relu(self.conv1(x, edge_idx))
        x = F.relu(self.conv2(x, edge_idx))
        # One centre per graph. In a collated batch ``ptr[:-1]`` holds each
        # graph's first global node index. A lone Data object has no ``ptr``
        # and is treated as a single graph starting at node 0.
        ptr = getattr(data, "ptr", None)
        if ptr is None:
            starts = torch.zeros(1, dtype=torch.long, device=x.device)
        else:
            starts = ptr[:-1]
        # Add the per-graph local centre index (``center_idx``, 0 if absent).
        offset = getattr(data, "center_idx", 0)
        if torch.is_tensor(offset):
            offset = offset.view(-1).to(starts.device)
        out = self.classifier(x[starts + offset])
        return F.log_softmax(out, dim=1)

def train_epoch(loader, model, optimizer, criterion, device):
    """Run one optimisation pass over ``loader``.

    Returns the loss summed as ``batch_loss * num_graphs`` over all batches,
    divided by ``len(loader.dataset)``. This is the mean per-graph loss,
    assuming ``criterion`` averages over the batch.
    """
    model.train()
    running = 0
    for batch in loader:
        batch = batch.to(device)
        optimizer.zero_grad()
        output = model(batch)
        target = batch.y.view(-1)
        batch_loss = criterion(output, target)
        batch_loss.backward()
        optimizer.step()
        running += batch_loss.item() * batch.num_graphs
    n_graphs = len(loader.dataset)
    return running / n_graphs

def test_epoch(loader, model, device):
    """Return the classification accuracy of ``model`` over ``loader``.

    Runs in eval mode without autograd. Returns ``nan`` when the loader's
    dataset is empty instead of raising ZeroDivisionError.
    """
    model.eval()
    correct = 0
    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        for batch in loader:
            batch = batch.to(device)
            pred = model(batch).argmax(dim=1)
            correct += (pred == batch.y.view(-1)).sum().item()
    total = len(loader.dataset)
    return correct / total if total else float("nan")

# =============================================================================
# 6. Serial spatial cross‑validation (10×5 folds)
# =============================================================================
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
in_dim = pyg_data.x.size(1)
hid_dim = 64
num_classes = len(torch.unique(pyg_data.y))

n_iters = 10
n_folds = 5
metrics_iters = np.zeros((n_iters, n_folds, 4))  # P, R, F1, Kappa

for it in range(n_iters):
    print(f"\n=== Iteration {it + 1}/{n_iters} ===")
    # NOTE(review): GroupKFold with default settings is deterministic, so every
    # iteration reuses the same zone-based folds. Iterations differ only via
    # the random undersampling and the model's random initialisation.
    gkf = GroupKFold(n_splits=n_folds)
    for fold, (train_idx, test_idx) in enumerate(
        gkf.split(np.arange(pyg_data.num_nodes), pyg_data.y.numpy(), groups),
        start=1,
    ):
        # Balance classes, then build one ego-graph per selected node.
        # NOTE(review): the test fold is undersampled too, so metrics are
        # measured on a class-balanced subset of the held-out zones.
        t_idx = undersample(train_idx, pyg_data.y)
        v_idx = undersample(test_idx, pyg_data.y)
        train_graphs = [extract_ego_graph(pyg_data, int(i)) for i in t_idx]
        test_graphs = [extract_ego_graph(pyg_data, int(i)) for i in v_idx]

        train_loader = DataLoader(train_graphs, batch_size=32, shuffle=True)
        test_loader = DataLoader(test_graphs, batch_size=32)

        model = EgoGCN(in_dim, hid_dim, num_classes).to(device)
        opt = optim.Adam(model.parameters(), lr=0.005, weight_decay=5e-4)
        # EgoGCN ends with log_softmax, so NLLLoss is the matching criterion;
        # CrossEntropyLoss would apply log_softmax a second time.
        crit = nn.NLLLoss()

        # train
        for epoch in range(1, 401):
            loss = train_epoch(train_loader, model, opt, crit, device)
            if epoch % 100 == 0:
                acc = test_epoch(test_loader, model, device)
                print(f"Fold {fold} | Epoch {epoch:03d} | Loss: {loss:.4f} | Acc: {acc:.4f}")

        # evaluate: explicit eval mode, no autograd, batches on the model's device
        model.eval()
        all_preds, all_trues = [], []
        with torch.no_grad():
            for batch in test_loader:
                batch = batch.to(device)
                all_preds.append(model(batch).argmax(dim=1).cpu())
                all_trues.append(batch.y.cpu())
        preds = torch.cat(all_preds).numpy()
        trues = torch.cat(all_trues).numpy()

        # zero_division=0 consistently: a fold with no predicted positives
        # yields 0 instead of an UndefinedMetricWarning.
        p = precision_score(trues, preds, zero_division=0)
        r = recall_score(trues, preds, zero_division=0)
        f1 = f1_score(trues, preds, zero_division=0)
        k = cohen_kappa_score(trues, preds)

        print(f"Fold {fold} | P:{p:.3f} R:{r:.3f} F1:{f1:.3f} K:{k:.3f}")
        metrics_iters[it, fold - 1] = [p, r, f1, k]

# =============================================================================
# 7. Aggregate serial CV results
# =============================================================================
# Axis 0 of metrics_iters is the iteration: collapse it to get, per fold,
# the mean and spread of each metric (P, R, F1, Kappa) across iterations.
mean_fold = np.mean(metrics_iters, axis=0)
std_fold = np.std(metrics_iters, axis=0)

metric_tags = ("P", "R", "F1", "K")


def _format_metrics(means, stds):
    """Render 'P m±s R m±s F1 m±s K m±s' with three decimals per value."""
    return " ".join(
        f"{tag} {m:.3f}±{s:.3f}" for tag, m, s in zip(metric_tags, means, stds)
    )


print("\n--- Per-fold averages over iterations ---")
for fold in range(n_folds):
    print(f"Fold {fold+1}: " + _format_metrics(mean_fold[fold], std_fold[fold]))

# Global summary over the per-fold means.
# NOTE(review): the ± here is the spread of the fold means, not the std over
# all iteration×fold runs — confirm this is the intended dispersion measure.
global_mean = np.mean(mean_fold, axis=0)
global_std = np.std(mean_fold, axis=0)
print("\n--- Global averages ---\n" + _format_metrics(global_mean, global_std))