# 1. Road Network Construction

In [2]:
import geopandas as gpd
import pandas as pd
import numpy as np
import networkx as nx
from shapely.ops import unary_union
from shapely.geometry import Point, LineString
from sklearn.cluster import DBSCAN
import folium

# 1) Load
gdf = gpd.read_file("17122148/itinere_roads.geojson")

# 2) Inspect CRS / projection
print("CRS:", gdf.crs)
print("Rows:", len(gdf))
print("Columns:", list(gdf.columns))
print("Geometry types:\n", gdf.geometry.geom_type.value_counts(dropna=False))

# 3) Show a quick bounds check (helps detect wrong CRS)
print("Total bounds:", gdf.total_bounds)  # [minx, miny, maxx, maxy]

# 4) Decide projected vs geographic versions
# If the CRS is missing, assume EPSG:4326 (WGS84 lon/lat) and set it.
if gdf.crs is None:
    gdf = gdf.set_crs(epsg=4326, allow_override=True)
    print("CRS was missing -> set to EPSG:4326")

# Create two standard views:
# - gdf_ll: geographic (EPSG:4326)
# - gdf_m: metric/projected (EPSG:3395) for distance-based operations
gdf_ll = gdf.to_crs(epsg=4326)
gdf_m = gdf.to_crs(epsg=3395)

# Starting from your already-loaded gdf (EPSG:3395)
gdf_m = gdf  # keep naming consistent: metric CRS

# Explode MultiLineStrings into LineStrings
gdf_lines = gdf_m.explode(index_parts=True).reset_index(drop=True)

print("After explode:")
print("CRS:", gdf_lines.crs)
print("Rows:", len(gdf_lines))
print("Geometry types:\n", gdf_lines.geometry.geom_type.value_counts(dropna=False))

# Quick check: any empties?
print("Empty geometries:", int(gdf_lines.geometry.is_empty.sum()))
print("Null geometries:", int(gdf_lines.geometry.isna().sum()))

# 1) Node + split at intersections + overlaps
u = unary_union(list(gdf_lines.geometry))

print("unary_union result geom_type:", u.geom_type)

# 2) Extract LineStrings into a list
lines = []
if u.geom_type == "LineString":
    lines = [u]
elif u.geom_type == "MultiLineString":
    lines = list(u.geoms)
else:
    # GeometryCollection or other: collect any LineStrings/MultiLineStrings inside
    for gg in getattr(u, "geoms", []):
        if gg.geom_type == "LineString":
            lines.append(gg)
        elif gg.geom_type == "MultiLineString":
            lines.extend(list(gg.geoms))

# 3) Make edges GeoDataFrame
edges_noded = gpd.GeoDataFrame({"geometry": lines}, crs=gdf_lines.crs)

print("Noded edges:")
print("Rows:", len(edges_noded))
print("Geometry types:\n", edges_noded.geometry.geom_type.value_counts(dropna=False))

# 4) Quick sanity checks
print("Empty:", int(edges_noded.geometry.is_empty.sum()))
print("Null:", int(edges_noded.geometry.isna().sum()))

# 5) Optional: total length comparison (should be close-ish, can change due to overlap handling)
print("Input total length:", float(gdf_lines.length.sum()))
print("Noded total length:", float(edges_noded.length.sum()))

# edges_noded from step 3 is now a GeoDataFrame of LineStrings 
# representing the noded edges. 
# Next, we want to build a nodes GeoDataFrame and map edges 
# to node IDs.
# Build node candidates from every edge endpoint
start_pts = edges_noded.geometry.apply(lambda ls: Point(ls.coords[0]))
end_pts   = edges_noded.geometry.apply(lambda ls: Point(ls.coords[-1]))

nodes_raw = gpd.GeoDataFrame(
    {"geometry": pd.concat([start_pts, end_pts], ignore_index=True)},
    crs=edges_noded.crs,
)

# Deduplicate nodes (exact coordinate match)
nodes_raw["wkb"] = nodes_raw.geometry.apply(lambda p: p.wkb_hex)
nodes = nodes_raw.drop_duplicates("wkb").drop(columns="wkb").reset_index(drop=True)
nodes["node_id"] = nodes.index.astype("int64") + 1

print("Nodes (unique endpoints):", len(nodes))

# Map endpoints -> node_id
node_map = {p.wkb_hex: nid for p, nid in zip(nodes.geometry, nodes["node_id"])}

from_ids = []
to_ids = []
for ls in edges_noded.geometry:
    a = Point(ls.coords[0]).wkb_hex
    b = Point(ls.coords[-1]).wkb_hex
    from_ids.append(node_map[a])
    to_ids.append(node_map[b])

edges = edges_noded.copy()
edges["edge_id"] = edges.index.astype("int64") + 1
edges["from_node"] = from_ids
edges["to_node"] = to_ids

# 4) Compute node degree from edges
deg = pd.concat([edges["from_node"], edges["to_node"]]).value_counts()
nodes["degree"] = nodes["node_id"].map(deg).fillna(0).astype(int)

print("Degree summary:")
print(nodes["degree"].describe())

print("Degree counts (top 10):")
print(nodes["degree"].value_counts().head(10))

# 5) Sanity check: edges count and implied degree sum
print("Edges:", len(edges))
print("Sum of degrees:", int(nodes["degree"].sum()))


def simplify_degree2(G: nx.Graph):
    """
    Collapse degree-2 chains into single edges.
    Keeps endpoints (deg==1) and junctions (deg>=3) as 'important' nodes.
    Returns:
      H: simplified graph
      edge_geoms: list of dicts with true polyline geometry
                 [{"u":u, "v":v, "weight":w, "geometry":LineString([...])}, ...]
    """
    # nodes to keep
    important = {n for n, deg in G.degree() if deg != 2}

    H = nx.Graph()
    edge_geoms = []

    # Track which directed adjacency traversals we already used
    seen_dir = set()

    for u in important:
        for v in G.neighbors(u):
            if (u, v) in seen_dir:
                continue

            # start path u -> v
            coords = [u, v]
            total_w = float(G[u][v].get("weight", 0.0))

            seen_dir.add((u, v))
            seen_dir.add((v, u))

            prev = u
            cur = v

            # walk through degree-2 nodes until we reach another important node
            while cur not in important:
                nbrs = [n for n in G.neighbors(cur) if n != prev]
                if len(nbrs) != 1:
                    break
                nxt = nbrs[0]
                total_w += float(G[cur][nxt].get("weight", 0.0))

                prev, cur = cur, nxt
                coords.append(cur)

                seen_dir.add((prev, cur))
                seen_dir.add((cur, prev))

            if cur == u:
                continue

            # add/merge edge in simplified graph
            if H.has_edge(u, cur):
                H[u][cur]["weight"] = min(float(H[u][cur].get("weight", total_w)), total_w)
            else:
                H.add_edge(u, cur, weight=total_w)

            edge_geoms.append({
                "u": u,
                "v": cur,
                "weight": float(total_w),
                "geometry": LineString(coords),
            })

    # Ensure isolated important nodes are kept
    for n in important:
        if n not in H:
            H.add_node(n)

    return H, edge_geoms

def simplify_degree2_until_stable(G: nx.Graph, max_iter: int = 10):
    """
    Repeat degree-2 simplification until no further changes.
    Useful because after simplification, some previously 'important' nodes can become degree-2.
    Returns:
      H: final simplified graph
      edge_geoms: geometries for the last pass (true polylines)
    """
    cur = G
    last_geoms = []

    for _ in range(int(max_iter)):
        nxt, geoms = simplify_degree2(cur)

        if (nxt.number_of_nodes() == cur.number_of_nodes()) and (nxt.number_of_edges() == cur.number_of_edges()):
            return nxt, geoms

        cur = nxt
        last_geoms = geoms

    return cur, last_geoms


# REQUIRED: You should already have simplify_degree2_until_stable from earlier work.
# It must return: (H1, edge_geoms1) where edge_geoms1 has {"u","v","weight","geometry"} with true polylines.
def build_detailed_graph_from_edges(edges_gdf) -> nx.Graph:
    G = nx.Graph()
    for line in edges_gdf.geometry:
        coords = list(line.coords)
        if len(coords) < 2:
            continue
        for (x1, y1, *_), (x2, y2, *__) in zip(coords[:-1], coords[1:]):
            u = (float(x1), float(y1))
            v = (float(x2), float(y2))
            if u == v:
                continue
            w = float(np.hypot(v[0] - u[0], v[1] - u[1]))
            if w <= 0:
                continue
            if G.has_edge(u, v):
                # keep the shortest if duplicates exist
                G[u][v]["weight"] = min(G[u][v]["weight"], w)
            else:
                G.add_edge(u, v, weight=w)
    return G

G = build_detailed_graph_from_edges(edges_noded)
print(f"Detailed graph: {G.number_of_nodes():,} nodes, {G.number_of_edges():,} edges")
H1, edge_geoms1 = simplify_degree2_until_stable(G)

# 1) Edges GeoDataFrame (projected)
edges_simpl_3395 = gpd.GeoDataFrame(edge_geoms1, crs=edges_noded.crs).copy()
edges_simpl_3395 = edges_simpl_3395.drop_duplicates(subset=["u", "v", "weight"]).reset_index(drop=True)
edges_simpl_3395["edge_id"] = edges_simpl_3395.index.astype("int64") + 1

# 2) Nodes GeoDataFrame (projected)
nodes_list = list(H1.nodes())
nodes_simpl_3395 = gpd.GeoDataFrame(
    {
        "node_id": np.arange(1, len(nodes_list) + 1, dtype="int64"),
        "degree": [H1.degree(n) for n in nodes_list],
        "x": [n[0] for n in nodes_list],
        "y": [n[1] for n in nodes_list],
    },
    geometry=gpd.points_from_xy([n[0] for n in nodes_list], [n[1] for n in nodes_list]),
    crs=edges_noded.crs,
)

# 3) Add from/to node IDs to edges (stable mapping)
node_id_map = {n: i+1 for i, n in enumerate(nodes_list)}
edges_simpl_3395["from_node"] = edges_simpl_3395["u"].map(node_id_map).astype("int64")
edges_simpl_3395["to_node"]   = edges_simpl_3395["v"].map(node_id_map).astype("int64")

print("Projected simplified edges:", len(edges_simpl_3395))
print("Projected simplified nodes:", len(nodes_simpl_3395))

# 4) Create unprojected versions (EPSG:4326)
nodes_simpl_4326 = nodes_simpl_3395.to_crs(epsg=4326)
edges_simpl_4326 = edges_simpl_3395.to_crs(epsg=4326)

H1, edge_geoms1 = simplify_degree2_until_stable(G)
print(f"H1 (deg-2 removed): {H1.number_of_nodes():,} nodes, {H1.number_of_edges():,} edges")

def merge_close_intersections_dbscan(H: nx.Graph, radius_m: float = 200.0, min_degree: int = 3):
    """
    Merge nodes with degree>=min_degree within radius_m into centroid nodes.
    Returns:
      Hm: merged graph
      mapping: old_node -> new_node
      cluster_info: dict with diagnostic info for plotting
    """
    candidates = [n for n in H.nodes() if H.degree(n) >= min_degree]
    if len(candidates) == 0:
        return H.copy(), {n: n for n in H.nodes()}, {
            "candidates": [],
            "labels": np.array([], dtype=int),
            "centroids": {},
            "merge_nodes_before": set(),
            "centroid_nodes_after": set(),
        }

    X = np.asarray(candidates, dtype=float)  # (k,2)
    db = DBSCAN(eps=radius_m, min_samples=1).fit(X)
    labels = db.labels_

    # Centroid per cluster label
    centroids = {}
    for lab in np.unique(labels):
        pts = X[labels == lab]
        centroids[int(lab)] = (float(pts[:, 0].mean()), float(pts[:, 1].mean()))

    # Nodes that actually merge = clusters with size > 1
    merge_nodes_before = set()
    centroid_nodes_after = set()
    for lab in np.unique(labels):
        members = [candidates[i] for i in np.where(labels == lab)[0]]
        if len(members) > 1:
            merge_nodes_before.update(members)
            centroid_nodes_after.add(centroids[int(lab)])

    # Mapping old -> new (only candidates move)
    mapping = {n: n for n in H.nodes()}
    for n, lab in zip(candidates, labels):
        mapping[n] = centroids[int(lab)]

    # Build merged graph
    Hm = nx.Graph()
    for u, v, d in H.edges(data=True):
        u2 = mapping[u]
        v2 = mapping[v]
        if u2 == v2:
            continue
        w = float(d.get("weight", 1.0))
        if Hm.has_edge(u2, v2):
            Hm[u2][v2]["weight"] = min(Hm[u2][v2].get("weight", w), w)
        else:
            Hm.add_edge(u2, v2, weight=w)

    # Keep isolated nodes
    for n2 in set(mapping.values()):
        if n2 not in Hm:
            Hm.add_node(n2)

    cluster_info = {
        "candidates": candidates,
        "labels": labels,
        "centroids": centroids,
        "merge_nodes_before": merge_nodes_before,
        "centroid_nodes_after": centroid_nodes_after,
    }
    return Hm, mapping, cluster_info


# Optional: merge roundabouts (topology) to produce an "info" dict for mapping
H2, mapping, info_roundabout = merge_close_intersections_dbscan(H1, radius_m=200.0, min_degree=3)
print("Roundabout merge candidates:", len(info_roundabout.get("candidates", [])))
print("Nodes merged (before):", len(info_roundabout.get("merge_nodes_before", [])))
print("Centroids (after):", len(info_roundabout.get("centroid_nodes_after", [])))

# ----------------------------
# utilities
# ----------------------------
def clip_to_region(gdf, region_poly):
    minx, miny, maxx, maxy = region_poly.bounds
    out = gdf.cx[minx:maxx, miny:maxy]
    return out[out.intersects(region_poly)].copy()

def key_xy(xy, prec_m=0.01):
    return (round(float(xy[0]) / prec_m) * prec_m, round(float(xy[1]) / prec_m) * prec_m)

def endpoints_keys_and_degree(lines_gdf, prec_m=0.01):
    deg = {}
    keys = set()
    for ls in lines_gdf.geometry:
        c = list(ls.coords)
        a = key_xy(c[0], prec_m)
        b = key_xy(c[-1], prec_m)
        keys.add(a); keys.add(b)
        deg[a] = deg.get(a, 0) + 1
        deg[b] = deg.get(b, 0) + 1
    return keys, deg

def keys_to_points_gdf(keys, crs="EPSG:3395"):
    pts = [Point(x, y) for (x, y) in keys]
    return gpd.GeoDataFrame({"n": np.arange(len(pts))}, geometry=pts, crs=crs)

def sample_points_gdf(gdf, max_points, seed=0):
    if max_points is None or len(gdf) <= max_points:
        return gdf
    rng = np.random.default_rng(seed)
    idx = rng.choice(len(gdf), size=int(max_points), replace=False)
    return gdf.iloc[idx].copy()

# ----------------------------
# layers for the 4 operations
# ----------------------------
def nodes_removed_degree2(G, H1, region_poly=None):
    H1_nodes = set(H1.nodes())
    deg = dict(G.degree())
    removed = []
    for n, d in deg.items():
        if d == 2 and n not in H1_nodes:
            if region_poly is None or region_poly.contains(Point(n)):
                removed.append(n)
    return gpd.GeoDataFrame(geometry=[Point(xy) for xy in removed], crs="EPSG:3395")

def nodes_remain_after_degree2(H1, region_poly=None):
    nodes = [n for n in H1.nodes() if region_poly is None or region_poly.contains(Point(n))]
    return gpd.GeoDataFrame(geometry=[Point(xy) for xy in nodes], crs="EPSG:3395")

def nodes_roundabout_layers(info_roundabout, region_poly=None):
    if info_roundabout is None:
        empty = gpd.GeoDataFrame(geometry=[], crs="EPSG:3395")
        return empty, empty

    members = list(info_roundabout.get("merge_nodes_before", []))
    cents   = list(info_roundabout.get("centroid_nodes_after", []))

    if region_poly is not None:
        members = [xy for xy in members if region_poly.contains(Point(xy))]
        cents   = [xy for xy in cents if region_poly.contains(Point(xy))]

    g_members = gpd.GeoDataFrame(geometry=[Point(xy) for xy in members], crs="EPSG:3395")
    g_cents   = gpd.GeoDataFrame(geometry=[Point(xy) for xy in cents], crs="EPSG:3395")
    return g_members, g_cents

def snap_clusters_from_graph(H1, snap_radius_m=30.0, min_degree=3, region_poly=None):
    nodes = [n for n in H1.nodes() if H1.degree(n) >= min_degree]
    if region_poly is not None:
        nodes = [n for n in nodes if region_poly.contains(Point(n))]

    if len(nodes) < 2:
        empty = gpd.GeoDataFrame(geometry=[], crs="EPSG:3395")
        return empty, empty

    X = np.asarray(nodes, dtype=float)
    db = DBSCAN(eps=float(snap_radius_m), min_samples=2).fit(X)  # only true clusters
    labels = db.labels_

    clusters = {}
    for xy, lab in zip(nodes, labels):
        if lab == -1:
            continue
        clusters.setdefault(int(lab), []).append(xy)

    members = []
    centroids = []
    for mem in clusters.values():
        xs = [m[0] for m in mem]
        ys = [m[1] for m in mem]
        members.extend(mem)
        centroids.append((float(np.mean(xs)), float(np.mean(ys))))

    g_members = gpd.GeoDataFrame(geometry=[Point(xy) for xy in members], crs="EPSG:3395")
    g_cents   = gpd.GeoDataFrame(geometry=[Point(xy) for xy in centroids], crs="EPSG:3395")
    return g_members, g_cents

def noding_created_and_remaining_nodes(roads_before, roads_after, prec_m=0.01, min_junction_deg=3):
    before_keys, _ = endpoints_keys_and_degree(roads_before, prec_m=prec_m)
    after_keys, after_deg = endpoints_keys_and_degree(roads_after, prec_m=prec_m)

    created = after_keys - before_keys
    created_gdf = keys_to_points_gdf(created)

    remain_gdf = keys_to_points_gdf(after_keys)

    junction_keys = {k for k, d in after_deg.items() if d >= min_junction_deg}
    junction_gdf = keys_to_points_gdf(junction_keys)

    return created_gdf, remain_gdf, junction_gdf

# ----------------------------
# main: build layers + summary
# ----------------------------
def build_cleaning_layers(
    gdf_lines_3395: gpd.GeoDataFrame,
    edges_noded_3395: gpd.GeoDataFrame,
    G_detailed,
    H1_deg2_removed,
    info_roundabout=None,
    region_poly=None,
    snap_radius_m=30.0,
    min_junction_deg=3,
    endpoint_prec_m=0.01,
    max_points_removed=15000,
    max_points_remain=25000,
    seed=0
):
    roads_before = clip_to_region(gdf_lines_3395, region_poly) if region_poly is not None else gdf_lines_3395.copy()
    roads_after  = clip_to_region(edges_noded_3395, region_poly) if region_poly is not None else edges_noded_3395.copy()

    g_deg2_removed = nodes_removed_degree2(G_detailed, H1_deg2_removed, region_poly)
    g_deg2_remain  = nodes_remain_after_degree2(H1_deg2_removed, region_poly)

    g_deg2_removed = sample_points_gdf(g_deg2_removed, max_points_removed, seed=seed)
    g_deg2_remain  = sample_points_gdf(g_deg2_remain,  max_points_remain,  seed=seed+1)

    g_round_members, g_round_centroids = nodes_roundabout_layers(info_roundabout, region_poly)

    g_snap_members, g_snap_centroids = snap_clusters_from_graph(
        H1_deg2_removed,
        snap_radius_m=snap_radius_m,
        min_degree=min_junction_deg,
        region_poly=region_poly,
    )

    g_noded_created, g_noded_remain, g_noded_junctions = noding_created_and_remaining_nodes(
        roads_before, roads_after,
        prec_m=endpoint_prec_m,
        min_junction_deg=min_junction_deg,
    )
    g_noded_created   = sample_points_gdf(g_noded_created,   max_points_removed, seed=seed+2)
    g_noded_remain    = sample_points_gdf(g_noded_remain,    max_points_remain,  seed=seed+3)
    g_noded_junctions = sample_points_gdf(g_noded_junctions, max_points_remain,  seed=seed+4)

    layers = {
        "roads_before": roads_before,
        "roads_after": roads_after,

        "deg2_removed": g_deg2_removed,
        "deg2_remain":  g_deg2_remain,

        "round_members":   g_round_members,
        "round_centroids": g_round_centroids,

        "snap_members":    g_snap_members,
        "snap_centroids":  g_snap_centroids,

        "noding_created":   g_noded_created,
        "noding_remain":    g_noded_remain,
        "noding_junctions": g_noded_junctions,
    }
    summary = {k: len(v) for k, v in layers.items() if isinstance(v, gpd.GeoDataFrame)}
    return layers, summary

# ----------------------------
# folium map helper
# ----------------------------
def folium_map_before_after_with_nodes(
    layers: dict,
    region_poly=None,
    center_latlon=None,
    zoom_start=6,
    out_html="cleaning_nodes_before_after.html",
):
    if center_latlon is None:
        b = layers["roads_before"].to_crs(4326).total_bounds
        center_latlon = ((b[1]+b[3])/2, (b[0]+b[2])/2)

    def add_geojson(parent, gdf_4326, name, color, weight=2, opacity=0.65):
        if gdf_4326 is None or len(gdf_4326) == 0:
            return
        folium.GeoJson(
            gdf_4326.__geo_interface__,
            name=name,
            style_function=lambda feat: {"color": color, "weight": weight, "opacity": opacity},
        ).add_to(parent)

    def add_points(parent, gdf_4326, name, color, radius=3):
        if gdf_4326 is None or len(gdf_4326) == 0:
            return
        folium.GeoJson(
            gdf_4326.__geo_interface__,
            name=name,
            marker=folium.CircleMarker(radius=radius, fill=True, color=color, fill_opacity=0.9),
        ).add_to(parent)

    m = folium.Map(location=[center_latlon[0], center_latlon[1]], zoom_start=zoom_start, tiles="CartoDB positron")

    fg_before = folium.FeatureGroup(name="ROADS — ORIGINAL (before cleaning)", show=True)
    fg_after  = folium.FeatureGroup(name="ROADS — CLEANED (after cleaning)", show=False)
    m.add_child(fg_before); m.add_child(fg_after)

    add_geojson(fg_before, layers["roads_before"].to_crs(4326), "roads_before", color="#1f77b4", weight=2, opacity=0.65)
    add_geojson(fg_after,  layers["roads_after"].to_crs(4326),  "roads_after",  color="#d62728", weight=2, opacity=0.65)

    if region_poly is not None:
        reg = gpd.GeoDataFrame(geometry=[region_poly], crs="EPSG:3395").to_crs(4326)
        add_geojson(m, reg, "Region boundary", color="#444444", weight=2, opacity=0.55)
        b = reg.total_bounds
        m.fit_bounds([[b[1], b[0]], [b[3], b[2]]])

    # node layers
    add_points(m, layers["deg2_removed"].to_crs(4326), "Nodes REMOVED: degree-2", color="#ff00ff", radius=3)
    add_points(m, layers["deg2_remain"].to_crs(4326),  "Nodes REMAIN: after degree-2 (H1)", color="#00cc96", radius=2)

    add_points(m, layers["round_members"].to_crs(4326),   "Nodes MODIFIED: roundabout members (before)", color="#ff7f0e", radius=4)
    add_points(m, layers["round_centroids"].to_crs(4326), "Nodes REMAIN: roundabout centroids (after)",  color="#2ca02c", radius=6)

    add_points(m, layers["snap_members"].to_crs(4326),   "Nodes MODIFIED: snap members (before)", color="#9467bd", radius=4)
    add_points(m, layers["snap_centroids"].to_crs(4326), "Nodes REMAIN: snap centroids (after)",  color="#17becf", radius=6)

    add_points(m, layers["noding_created"].to_crs(4326),  "Nodes CREATED: by noding", color="#e377c2", radius=3)
    add_points(m, layers["noding_remain"].to_crs(4326),   "Nodes REMAIN: endpoints after noding", color="#7f7f7f", radius=2)
    add_points(m, layers["noding_junctions"].to_crs(4326),"Nodes REMAIN: junction endpoints after noding (deg>=3)", color="#bcbd22", radius=3)

    folium.LayerControl(collapsed=False).add_to(m)
    m.save(out_html)
    return out_html

layers_all, summary_all = build_cleaning_layers(
    gdf_lines_3395=gdf_lines,
    edges_noded_3395=edges_noded,
    G_detailed=G,
    H1_deg2_removed=H1,
    info_roundabout=info_roundabout,  # or None
    region_poly=None,                 # whole dataset
    snap_radius_m=30.0,
    min_junction_deg=3,
    endpoint_prec_m=0.01,
    max_points_removed=1000000,
    max_points_remain=25000,
    seed=0
)

print(summary_all)

out_html = folium_map_before_after_with_nodes(
    layers_all,
    region_poly=None,
    center_latlon=None,
    zoom_start=5,
    out_html="full_cleaning_nodes_before_after.html"
)

print("Saved:", out_html)


CRS: EPSG:3395
Rows: 14769
Columns: ['Name', 'Route_Type', 'Type', 'Lower_Date', 'Low_Date_E', 'Upper_Date', 'Up_Date_E', 'Descriptio', 'Citation', 'Bibliograp', 'Cons_per_e', 'Itinerary', 'Segment_s', 'Avg_Slope', 'passabilit', 'Shape_Leng', 'InLine_FID', 'MaxSimpTol', 'MinSimpTol', 'geometry']
Geometry types:
 MultiLineString    14769
Name: count, dtype: int64
Total bounds: [-1051606.9533      2578001.81270617  4627965.33944631  7518329.0997    ]
After explode:
CRS: EPSG:3395
Rows: 14769
Geometry types:
 LineString    14769
Name: count, dtype: int64
Empty geometries: 0
Null geometries: 0
unary_union result geom_type: MultiLineString
Noded edges:
Rows: 16854
Geometry types:
 LineString    16854
Name: count, dtype: int64
Empty: 0
Null: 0
Input total length: 400175337.99454165
Noded total length: 399914306.6043441
Nodes (unique endpoints): 13860
Degree summary:
count    13860.000000
mean         2.432035
std          1.034891
min          1.000000
25%          2.000000
50%          3.00

# 2. Weighted Network Construction

In [None]:
import numpy as np
import pandas as pd
import geopandas as gpd
import networkx as nx
from shapely.geometry import Point, LineString
from shapely.strtree import STRtree

# ============================================================
# 0) Preconditions / sanity
# ============================================================
need = ["gdf_lines", "edges_noded", "edge_geoms1", "H1", "G"]
missing = [k for k in need if k not in globals()]
if missing:
    raise RuntimeError(f"Missing objects: {missing}. Run the cleaning/build steps first.")

if gdf_lines.crs is None or str(gdf_lines.crs).upper() != "EPSG:3395":
    raise RuntimeError(f"gdf_lines must be EPSG:3395. Found: {gdf_lines.crs}")
if "Avg_Slope" not in gdf_lines.columns:
    raise RuntimeError("gdf_lines must contain Avg_Slope (degrees).")

# Keep only LineStrings
gdf_src = gdf_lines[gdf_lines.geometry.notna() & ~gdf_lines.geometry.is_empty].copy()
gdf_src = gdf_src[gdf_src.geometry.geom_type == "LineString"].copy().reset_index(drop=True)
gdf_src["slope_deg_src"] = pd.to_numeric(gdf_src["Avg_Slope"], errors="coerce").astype(float)

# ============================================================
# 1) Build geometry-preserving cleaned edges (deg-2 removed + optional roundabout merge)
# ============================================================
def undirected_key(a, b):
    return (a, b) if a <= b else (b, a)

def merge_edge_geoms_with_mapping(edge_geoms, mapping):
    """
    Apply node mapping (old node -> new centroid node) to polyline edges.
    Preserves the polyline shape by only moving first/last coordinate to the mapped node.
    If multiple edges collapse to same undirected (u2,v2), keeps the shortest.
    """
    best = {}
    for e in edge_geoms:
        u, v = e["u"], e["v"]
        u2 = mapping.get(u, u)
        v2 = mapping.get(v, v)
        if u2 == v2:
            continue

        coords = list(e["geometry"].coords)
        coords[0] = u2
        coords[-1] = v2
        geom2 = LineString(coords)

        k = undirected_key(u2, v2)
        dist = float(geom2.length)

        if (k not in best) or (dist < best[k]["dist_m"]):
            best[k] = {
                "u": k[0],
                "v": k[1],
                "weight": dist,
                "dist_m": dist,
                "geometry": geom2,
            }
    return list(best.values())

def graph_from_edge_geoms(edge_geoms):
    G = nx.Graph()
    for e in edge_geoms:
        u, v = e["u"], e["v"]
        if u == v:
            continue
        w = float(e.get("weight", e["geometry"].length))
        if G.has_edge(u, v):
            G[u][v]["weight"] = min(float(G[u][v]["weight"]), w)
        else:
            G.add_edge(u, v, weight=w)
    return G

# Decide whether to include roundabout merge.
# If you computed mapping via merge_close_intersections_dbscan(H1,...), use it.
USE_ROUNDABOUT_MERGE = ("mapping" in globals())

if USE_ROUNDABOUT_MERGE:
    edge_geoms2 = merge_edge_geoms_with_mapping(edge_geoms1, mapping)
    G2 = graph_from_edge_geoms(edge_geoms2)
    H_clean, edge_geoms_clean = simplify_degree2_until_stable(G2)
    print(f"Cleaned network: deg-2 removed + roundabout merge -> {H_clean.number_of_nodes():,} nodes, {H_clean.number_of_edges():,} edges")
else:
    edge_geoms_clean = edge_geoms1
    H_clean = H1
    print(f"Cleaned network: deg-2 removed only -> {H_clean.number_of_nodes():,} nodes, {H_clean.number_of_edges():,} edges")

# Turn cleaned edge_geoms into GeoDataFrame (this is what you were missing)
edges_clean_3395 = gpd.GeoDataFrame(edge_geoms_clean, crs="EPSG:3395").copy()
edges_clean_3395 = edges_clean_3395[edges_clean_3395.geometry.notna() & ~edges_clean_3395.geometry.is_empty].copy()
edges_clean_3395 = edges_clean_3395[edges_clean_3395.geometry.geom_type == "LineString"].reset_index(drop=True)
edges_clean_3395["edge_id"] = np.arange(1, len(edges_clean_3395) + 1, dtype="int64")
edges_clean_3395["dist_m"] = edges_clean_3395.geometry.length.astype(float)
edges_clean_3395["dist_km"] = edges_clean_3395["dist_m"] / 1000.0

print("edges_clean_3395 created:", len(edges_clean_3395))


# ============================================================
# 2) Transfer slope from original segments -> cleaned edges (length-weighted overlap)
# ============================================================
def slope_transfer_length_weighted_deg(edges_gdf, src_segments_gdf, buffer_m=5.0, min_intersection_m=1.0):
    """
    For each cleaned edge:
      - query source segments within buffer_m
      - compute intersection length with each candidate
      - slope_deg = weighted mean by intersection length
      - fallback: nearest candidate in buffer (if no overlap)
    """
    edges = edges_gdf.copy()

    src_geoms = list(src_segments_gdf.geometry)
    src_slopes = src_segments_gdf["slope_deg_src"].to_numpy(dtype=float)
    tree = STRtree(src_geoms)

    # For Shapely<2 fallback, we need WKB->idx map
    wkb_to_idx = {g.wkb: i for i, g in enumerate(src_geoms)}

    out = np.full(len(edges), np.nan, dtype=float)

    for i, geom in enumerate(edges.geometry):
        q = geom.buffer(buffer_m)
        res = tree.query(q)

        if len(res) == 0:
            continue

        # Shapely 2.x: indices, else geometries
        if isinstance(res[0], (int, np.integer)):
            cand_idx = [int(x) for x in res]
        else:
            cand_idx = [wkb_to_idx.get(g.wkb) for g in res]
            cand_idx = [j for j in cand_idx if j is not None]

        num = 0.0
        den = 0.0
        best_d = np.inf
        best_s = np.nan

        for j in cand_idx:
            seg = src_geoms[j]
            s = src_slopes[j]
            if np.isnan(s):
                continue

            d = geom.distance(seg)
            if d < best_d:
                best_d = d
                best_s = s

            inter = geom.intersection(seg)
            if inter.is_empty:
                continue
            L = float(inter.length)
            if L < min_intersection_m:
                continue
            num += L * float(s)
            den += L

        if den > 0:
            out[i] = num / den
        else:
            out[i] = float(best_s) if not np.isnan(best_s) else np.nan

    edges["slope_deg"] = out
    return edges

edges_w = slope_transfer_length_weighted_deg(edges_clean_3395, gdf_src, buffer_m=5.0, min_intersection_m=1.0)
print("Slope transferred. NaN slope fraction:", float(edges_w["slope_deg"].isna().mean()))


# ============================================================
# 3) Passability + time weights (Ox cart 2 km/h) using Tobler ratio; threshold 6%
# ============================================================
def tobler_speed_kmh(grad):
    # Tobler hiking function (km/h)
    return float(6.0 * np.exp(-3.5 * abs(grad + 0.05)))

def passability_multiplier_from_slope_deg(slope_deg, threshold_pct=6.0):
    """
    Isotropic passability:
      - convert degrees -> gradient = tan(theta)
      - if grad <= 0.06 => multiplier 1.0
      - else multiplier = v(grad)/v(0), clipped to [0.1, 1.0]
    """
    if slope_deg is None or np.isnan(slope_deg):
        return 1.0
    grad = abs(np.tan(np.deg2rad(float(slope_deg))))
    if grad <= (threshold_pct / 100.0):
        return 1.0
    v0 = tobler_speed_kmh(0.0)
    vs = tobler_speed_kmh(grad)
    mult = (vs / v0) if v0 > 0 else 1.0
    return float(np.clip(mult, 0.1, 1.0))

def add_time_weights(edges_gdf, base_speed_kmh=2.0, threshold_pct=6.0):
    edges = edges_gdf.copy()
    edges["passability"] = edges["slope_deg"].apply(lambda d: passability_multiplier_from_slope_deg(d, threshold_pct))
    edges["speed_kmh_eff"] = (base_speed_kmh * edges["passability"]).astype(float)
    edges["time_h"] = (edges["dist_km"] / edges["speed_kmh_eff"]).astype(float)
    edges["time_s"] = (edges["time_h"] * 3600.0).astype(float)
    edges["time_min"] = (edges["time_s"] / 60.0).astype(float)
    return edges

edges_w = add_time_weights(edges_w, base_speed_kmh=2.0, threshold_pct=6.0)
print(edges_w[["dist_m","slope_deg","passability","speed_kmh_eff","time_min"]].describe())


# ============================================================
# 4) Build time-weighted NetworkX graph (weight = time_s)
# ============================================================
def build_time_weighted_graph(edges_gdf):
    Gt = nx.Graph()
    for row in edges_gdf.itertuples(index=False):
        geom = row.geometry
        coords = list(geom.coords)
        if len(coords) < 2:
            continue
        u = (float(coords[0][0]), float(coords[0][1]))
        v = (float(coords[-1][0]), float(coords[-1][1]))
        if u == v:
            continue

        time_s = float(getattr(row, "time_s"))
        dist_m = float(getattr(row, "dist_m"))
        slope_deg = float(getattr(row, "slope_deg")) if not np.isnan(getattr(row, "slope_deg")) else np.nan
        passability = float(getattr(row, "passability"))
        speed_kmh_eff = float(getattr(row, "speed_kmh_eff"))

        # keep fastest if duplicates exist
        if Gt.has_edge(u, v):
            if time_s < Gt[u][v]["weight_time_s"]:
                Gt[u][v].update(
                    weight_time_s=time_s,
                    dist_m=dist_m,
                    slope_deg=slope_deg,
                    passability=passability,
                    speed_kmh_eff=speed_kmh_eff,
                    geometry=geom,
                )
        else:
            Gt.add_edge(
                u, v,
                weight_time_s=time_s,
                dist_m=dist_m,
                slope_deg=slope_deg,
                passability=passability,
                speed_kmh_eff=speed_kmh_eff,
                geometry=geom,
            )
    return Gt

Gt_3395 = build_time_weighted_graph(edges_w)
print(f"Time-weighted graph (EPSG:3395): {Gt_3395.number_of_nodes():,} nodes, {Gt_3395.number_of_edges():,} edges")

Gu = nx.Graph()
for u, v, d in Gt_3395.edges(data=True):
    w = d.get("weight", 1.0)
    if Gu.has_edge(u, v):
        Gu[u][v]["weight"] = min(Gu[u][v]["weight"], w)
    else:
        Gu.add_edge(u, v, **d)

Gt_3395 = Gu


Cleaned network: deg-2 removed + roundabout merge -> 9,458 nodes, 12,680 edges
edges_clean_3395 created: 12680
Slope transferred. NaN slope fraction: 0.0
             dist_m     slope_deg   passability  speed_kmh_eff      time_min
count  1.268000e+04  12680.000000  12680.000000   12680.000000  1.268000e+04
mean   2.661819e+04      2.827660      0.924007       1.848015  8.770041e+02
std    2.971207e+04      2.234975      0.136214       0.272428  9.834805e+02
min    5.820766e-11      0.000000      0.100000       0.200000  1.746230e-12
25%    6.719448e+03      1.343194      0.805627       1.611255  2.223910e+02
50%    1.819454e+04      2.118602      1.000000       2.000000  6.003567e+02
75%    3.709574e+04      3.533673      1.000000       2.000000  1.211359e+03
max    4.481473e+05     33.574162      1.000000       2.000000  1.344442e+04
Time-weighted graph (EPSG:3395): 9,458 nodes, 12,680 edges


# Check: Nodes and Edges consitency across the two networks

In [None]:
# print("H1 nodes/edges:", H1.number_of_nodes(), H1.number_of_edges())
# # expect ~9624 nodes, ~12874 edges (your earlier numbers)
# deg2_removed_calc = sum(1 for n,d in dict(G.degree()).items() if d==2 and n not in set(H1.nodes()))
# print("deg2_removed_calc:", deg2_removed_calc)
# print("H_clean nodes/edges:", H_clean.number_of_nodes(), H_clean.number_of_edges())
# print("edges_clean_3395 rows:", len(edges_clean_3395))


H1 nodes/edges: 9624 12874
deg2_removed_calc: 988210
H_clean nodes/edges: 9458 12680
edges_clean_3395 rows: 12680


# 3. Weighted Network Statistics

In [8]:
def graph_stats(G: nx.Graph, weight: str = "weight"):
    """
    Compute basic statistics for an (un)weighted NetworkX graph.

    Assumptions:
      - Undirected graph (nx.Graph). If it's a DiGraph, see note below.
      - Edge attribute `weight` exists for weighted degree; if missing, treated as 1.
    """
    n = G.number_of_nodes()
    m = G.number_of_edges()

    # Average degree: 2m / n for undirected (guard n=0)
    avg_deg = (2.0 * m / n) if n else 0.0

    # Weighted degree: sum of incident weights per node (strength)
    # If an edge has no 'weight', NetworkX treats it as 1 for degree(weight=...)
    wdeg = dict(G.degree(weight='weight_time_s'))
    avg_wdeg = (float(np.mean(list(wdeg.values()))) if n else 0.0)

    # Connected components
    comps = list(nx.connected_components(G))
    num_cc = len(comps)
    largest_cc_size = max((len(c) for c in comps), default=0)

    # Optional: edges in largest component + fraction of nodes in it
    if largest_cc_size > 0:
        largest_nodes = max(comps, key=len)
        G_lcc = G.subgraph(largest_nodes)
        lcc_edges = G_lcc.number_of_edges()
        lcc_frac = largest_cc_size / n if n else 0.0
    else:
        lcc_edges = 0
        lcc_frac = 0.0

    out = {
        "nodes": n,
        "edges": m,
        "avg_degree": avg_deg,
        "avg_weighted_degree": avg_wdeg,
        "connected_components": num_cc,
        "largest_component_nodes": largest_cc_size,
        "largest_component_edges": lcc_edges,
        "largest_component_fraction_of_nodes": lcc_frac,
    }
    return out


stats = graph_stats(Gt_3395, weight="weight")
for k, v in stats.items():
    print(f"{k:35s}: {v}")


nodes                              : 9458
edges                              : 12680
avg_degree                         : 2.681327976316346
avg_weighted_degree                : 141092.14290345868
connected_components               : 218
largest_component_nodes            : 4784
largest_component_edges            : 7017
largest_component_fraction_of_nodes: 0.5058151829139353
