In [None]:
# Install the notebook's dependencies listed in ../requirements.txt.
# %pip (rather than !pip) installs into the environment of the running kernel.
%pip install -r ../requirements.txt

In [None]:
import geopandas as gpd
import matplotlib.pyplot as plt
import folium
import networkx as nx
import pickle
from pathlib import Path
from shapely.geometry import Point, LineString

In [None]:
# Input GeoJSON layers (paths relative to this notebook)
station_path = '../datasets/japan/N02-24_GML/UTF-8/N02-24_Station.geojson'
railroad_path = '../datasets/japan/N02-24_GML/UTF-8/N02-24_RailroadSection.geojson'


def _load_layer(path, label):
    """Read one GeoJSON layer with geopandas, reporting progress before and after."""
    print(f"Loading {label}...")
    layer = gpd.read_file(path)
    print(f"Loaded {len(layer)} {label}.")
    return layer


gdf_stations = _load_layer(station_path, "stations")
gdf_railroads = _load_layer(railroad_path, "railroad sections")

## 1. Understanding the Data Structure

**Difference between the files:**
*   **`N02-24_Station.geojson`**: Contains **LineString** geometries representing the platforms of stations. Attributes include Station Name (`N02_005`) and Operator (`N02_004`).
*   **`N02-24_RailroadSection.geojson`**: Contains **LineString** geometries representing the physical railway tracks. Attributes include Line Name (`N02_003`) and Operator (`N02_004`).

**The Connectivity Challenge:**
For a NetworkX graph, we need explicit "Node A -> Node B" connections. This dataset provides only line geometries, not explicit connectivity. To build a graph, we must check whether the endpoints of the `RailroadSection` lines spatially coincide with the `Station` platform lines.

In [None]:
# Peek at the first rows of each layer to confirm the attribute columns.
#
# Station layer attributes:
#   N02_001   Railway classification - differentiation by type of railway line
#   N02_002   Business type          - differentiation by railway line operators
#   N02_003   Route name             - name of the railway line
#   N02_004   Operating company      - company that operates the railway line
#   N02_005   Station name
#   N02_005c  Station code           - unique number assigned by sorting stations by
#                                      latitude in descending order
#   N02_005g  Group code             - stations within 300m sharing the same name form a
#                                      group, identified by the station code closest to
#                                      the group's center of gravity
#
# Railroad-section layer attributes: N02_001 .. N02_004, with the same meanings as above.
for layer in (gdf_stations, gdf_railroads):
    display(layer.head(3))

## 2. Visualizing the Network
We first plot the whole network statically to get an overview of coverage and to spot gross misalignments between tracks and stations.

In [None]:
# Static overview of the entire network (no spatial filtering).
# Station features are LineStrings (platforms), so they are drawn as short red lines:
# `markersize` only affects point geometries, so a line width is set instead.
fig, ax = plt.subplots(figsize=(15, 15))
gdf_railroads.plot(ax=ax, color='gray', linewidth=0.5, alpha=0.7, label='Railroads')
gdf_stations.plot(ax=ax, color='red', linewidth=1.5, label='Stations')
ax.set_title("Japanese Railway Network (Static View)")
ax.set_xlabel("Longitude")
ax.set_ylabel("Latitude")
ax.legend()
plt.show()

## 3. Interactive Map (Folium)
Let's zoom in to see if the lines actually touch the stations.

In [None]:
# Interactive spot check around Tokyo Station. Only features whose bounding boxes
# intersect TOKYO_BBOX are rendered: embedding the entire national dataset in one map
# makes the notebook output very large and the map sluggish.
TOKYO_BBOX = (139.60, 35.60, 139.95, 35.78)  # (min_lon, min_lat, max_lon, max_lat)
min_lon, min_lat, max_lon, max_lat = TOKYO_BBOX

# Create a map centered on Tokyo
m = folium.Map(location=[35.6812, 139.7671], zoom_start=12, tiles='CartoDB Positron')

sample_rail = gdf_railroads.cx[min_lon:max_lon, min_lat:max_lat]
sample_stations = gdf_stations.cx[min_lon:max_lon, min_lat:max_lat]

folium.GeoJson(
    sample_rail,
    name='Railroads',
    style_function=lambda x: {'color': 'blue', 'weight': 2, 'opacity': 0.6},
    tooltip=folium.GeoJsonTooltip(fields=['N02_003', 'N02_004'], aliases=['Line', 'Operator']),
).add_to(m)

# Station features are LineStrings (platforms), so they are styled as thick red lines;
# a 'radius' style key has no effect on line geometries.
folium.GeoJson(
    sample_stations,
    name='Stations',
    style_function=lambda x: {'color': 'red', 'weight': 5, 'opacity': 0.9},
    tooltip=folium.GeoJsonTooltip(fields=['N02_005', 'N02_004'], aliases=['Station', 'Operator']),
).add_to(m)

folium.LayerControl().add_to(m)
m

## 4. Topological Assessment
We check whether station platform vertices coincide *exactly* (no tolerance) with the start or end points of railroad sections.

In [None]:
# Extract the first and last vertex of every railroad LineString
rail_endpoints = []
for line in gdf_railroads.geometry:
    if line is not None and line.geom_type == 'LineString':
        coords = list(line.coords)
        rail_endpoints.append(coords[0])   # Start
        rail_endpoints.append(coords[-1])  # End

rail_endpoints_set = set(rail_endpoints)

# Count each station feature at most once. A platform LineString has several vertices;
# counting vertex hits instead of stations would double-count platforms whose two ends
# both touch a track and could push the percentage above 100%.
total_stations = len(gdf_stations)
exact_matches = sum(
    1
    for platform in gdf_stations.geometry
    if platform is not None and any(coord in rail_endpoints_set for coord in platform.coords)
)

match_pct = exact_matches / total_stations * 100 if total_stations else 0.0
print(f"Total Stations: {total_stations}")
print(f"Stations exactly matching rail endpoints: {exact_matches}")
print(f"Percentage: {match_pct:.2f}%")

## 5. Building the NetworkX Graph

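We'll now construct a NetworkX graph by:
1. Grouping station platforms by their group code (`N02_005g`) to create station nodes
2. Merging station groups that share platform coordinates (interchanges)
3. Mapping railroad-section endpoints to stations by exact coordinate match, falling back to snapping within ~20 cm; endpoints that match no station become infrastructure nodes
4. Exporting the graph in `.gpickle` (pickled NetworkX graph) format for later analysis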

In [None]:
from collections import defaultdict
import math

# Use the pre-computed group code (N02_005g) for clustering. Per the dataset's field
# description, the group code groups stations within 300m that share the same name.

print("Step 1: Group station platforms by Group Code (N02_005g)")

group_code_groups = defaultdict(list)
for idx, station in gdf_stations.iterrows():
    group_code = station['N02_005g']
    group_code_groups[group_code].append({
        'idx': idx,
        'geometry': station.geometry,
        'name': station['N02_005'],
        'operator': station.get('N02_004'),
        'station_code': station.get('N02_005c'),
    })

print(f"Found {len(group_code_groups)} unique group codes")

# Step 2: Sanity check -- flag multi-platform groups whose platform vertices span an
# unexpectedly large bounding box (larger of the north-south and east-west extents).
print("\nStep 2: Sanity check - verify group codes are spatially coherent:")
SPREAD_THRESHOLD_KM = 1
KM_PER_DEG_LAT = 111  # approx. length of one degree of latitude

problematic_groups = []
for group_code, platforms in group_code_groups.items():
    if len(platforms) < 2:
        continue

    all_coords = [c for p in platforms for c in p['geometry'].coords]
    lons = [c[0] for c in all_coords]
    lats = [c[1] for c in all_coords]

    # A degree of longitude shrinks with cos(latitude); evaluate it at the group's
    # mid-latitude rather than using a fixed factor such as 0.8 (≈ cos 37°).
    mid_lat = (max(lats) + min(lats)) / 2
    lat_spread_km = (max(lats) - min(lats)) * KM_PER_DEG_LAT
    lon_spread_km = (max(lons) - min(lons)) * KM_PER_DEG_LAT * math.cos(math.radians(mid_lat))
    max_spread_km = max(lat_spread_km, lon_spread_km)

    if max_spread_km > SPREAD_THRESHOLD_KM:
        names = set(p['name'] for p in platforms)
        problematic_groups.append((group_code, max_spread_km, len(platforms), names))

if problematic_groups:
    print(f"WARNING: {len(problematic_groups)} group codes have platforms spread > {SPREAD_THRESHOLD_KM}km apart:")
    for group_code, spread, count, names in sorted(problematic_groups, key=lambda x: -x[1])[:10]:
        print(f"    Group '{group_code}': {spread:.2f}km spread, {count} platforms, names: {names}")
else:
    print("✓ All group codes are spatially coherent")

In [None]:
from shapely.ops import unary_union
from collections import Counter

# Step 3: Create one station node per group code
print("Step 3: Create station nodes")
station_groups = {}
coord_to_station = {}

for group_code, platforms in group_code_groups.items():
    station_id = group_code  # Use group code as station ID

    all_platform_geoms = [p['geometry'] for p in platforms]
    combined_platforms = unary_union(all_platform_geoms)
    centroid = combined_platforms.centroid

    # Unique names/operators, sorted (key=str tolerates mixed types) so attribute order
    # is identical on every run -- iterating a set of str depends on hash randomization.
    names = sorted({p['name'] for p in platforms if p['name']}, key=str)
    operators = sorted({p['operator'] for p in platforms if p['operator']}, key=str)
    # Unique platform vertices, in first-seen order
    coords = list(dict.fromkeys(coord for p in platforms for coord in p['geometry'].coords))

    # Use the most common name as the display name
    name_counts = Counter(p['name'] for p in platforms if p['name'])
    display_name = max(name_counts, key=name_counts.get) if name_counts else f"Station_{group_code}"

    station_groups[station_id] = {
        'centroid': centroid,
        'coords': coords,
        'geometry': combined_platforms,
        'lat': centroid.y,
        'lon': centroid.x,
        'platforms': combined_platforms,
        'name': display_name,
        'all_names': names,
        'operators': operators,
        'platform_count': len(platforms),
        'group_code': group_code,
    }

# Map every platform vertex to its station. A vertex shared by several groups ends up
# mapped to the last group visited; the conflicts are collected and summarized instead
# of printing one line per shared vertex.
shared_coord_conflicts = []
for station_id, station in station_groups.items():
    for coord in station['coords']:
        if coord in coord_to_station:
            shared_coord_conflicts.append((coord, coord_to_station[coord]))
        coord_to_station[coord] = station_id

MAX_CONFLICTS_SHOWN = 10
for coord, previous_station in shared_coord_conflicts[:MAX_CONFLICTS_SHOWN]:
    print(f"Warning: Coordinate {coord} already mapped to station {previous_station}")
if len(shared_coord_conflicts) > MAX_CONFLICTS_SHOWN:
    print(f"... and {len(shared_coord_conflicts) - MAX_CONFLICTS_SHOWN} more shared coordinates")

print(f"Created {len(station_groups)} station nodes")
print(f"Total platform coordinates mapped: {len(coord_to_station)}")

In [None]:
# Step 4: Detect and merge interchange stations (shared coordinates)
print("Step 4: Detect and merge interchange stations")

# Map each platform vertex to the set of station groups containing it
coord_to_groups = defaultdict(set)
for group_code, data in station_groups.items():
    for coord in data['coords']:
        coord_to_groups[coord].add(group_code)

# Coordinates shared by more than one group indicate an interchange
shared_coords = {coord: groups for coord, groups in coord_to_groups.items() if len(groups) > 1}

if shared_coords:
    groups_with_shared_coords = set().union(*shared_coords.values())
    print(f"Found {len(groups_with_shared_coords)} station groups with shared coordinates")

    # Union-Find over group codes: groups linked by any chain of shared coordinates
    # end up with the same root.
    parent = {code: code for code in station_groups}

    def find(x):
        """Return the root of x's set, compressing the path (iterative: no recursion limit)."""
        root = x
        while parent[root] != root:
            root = parent[root]
        while parent[x] != root:
            parent[x], x = root, parent[x]
        return root

    def union(x, y):
        """Merge the sets containing x and y."""
        px, py = find(x), find(y)
        if px != py:
            parent[px] = py

    # Joining every group at a coordinate to its first member is equivalent to joining
    # every pair, without the quadratic pair enumeration.
    for groups in shared_coords.values():
        first, *others = groups
        for other in others:
            union(first, other)

    # Group by root
    root_to_members = defaultdict(set)
    for code in station_groups:
        root_to_members[find(code)].add(code)

    merged_count = 0
    for member_codes in root_to_members.values():
        if len(member_codes) < 2:
            continue
        merged_count += 1

        # Sorted so the merge is reproducible: str hashing is randomized per process,
        # so set iteration order (and therefore min() tie-breaking) can change between runs.
        members = sorted(member_codes, key=str)
        # Shortest display name wins; the group code breaks ties
        representative = min(members, key=lambda c: (len(station_groups[c]['name']), str(c)))

        # Collect all data from member groups
        all_coords = list(dict.fromkeys(
            coord for code in members for coord in station_groups[code]['coords']
        ))
        all_names = sorted({name for code in members for name in station_groups[code]['all_names']}, key=str)
        all_operators = sorted({op for code in members for op in station_groups[code]['operators']}, key=str)
        combined_platforms = unary_union([station_groups[code]['platforms'] for code in members])
        centroid = combined_platforms.centroid
        total_platforms = sum(station_groups[code]['platform_count'] for code in members)

        # Update representative with merged data
        station_groups[representative] = {
            'centroid': centroid,
            'lat': centroid.y,
            'lon': centroid.x,
            'geometry': combined_platforms,
            'platforms': combined_platforms,
            'name': station_groups[representative]['name'],
            'all_names': all_names,
            'operators': all_operators,
            'platform_count': total_platforms,
            'coords': all_coords,
            'group_code': representative,
            'merged_from': members,
        }

        # Point every merged vertex at the representative
        for coord in all_coords:
            coord_to_station[coord] = representative

        # Remove non-representative groups
        for code in members:
            if code != representative:
                del station_groups[code]

        print(f"  Merged {members} -> '{station_groups[representative]['name']}' ({representative})")

    print(f"\nMerged {merged_count} interchange groups")
    print(f"Station groups after merging: {len(station_groups)}")
else:
    print("✓ No stations share coordinates - no merging needed")

print(f"\nFinal station count: {len(station_groups)}")

In [None]:
# Step 5: Build the NetworkX graph with spatial snapping for small imprecision in coordinates
print("Step 5: Build NetworkX graph")

import math
from collections import defaultdict

from scipy.spatial import cKDTree
import numpy as np

G = nx.Graph()

# Add all station nodes (one per station group from Steps 3-4)
for station_id, data in station_groups.items():
    G.add_node(
        station_id,
        node_type='station',
        lat=data['lat'],
        lon=data['lon'],
        name=data['name'],
        all_names=data.get('all_names', [data['name']]),
        operators=data.get('operators', []),
        platform_count=data.get('platform_count', 1),
        group_code=data.get('group_code', station_id),
        merged_from=data.get('merged_from', []),
    )

print(f"Added {len(station_groups)} station nodes")

# KD-tree over every platform vertex; coord_to_station_list[i] is the station owning row i
all_station_coords = []
coord_to_station_list = []
for station_id, data in station_groups.items():
    for coord in data['coords']:
        all_station_coords.append(coord)
        coord_to_station_list.append(station_id)

station_coord_array = np.array(all_station_coords)
station_tree = cKDTree(station_coord_array)

# Snapping threshold in degrees: 1 degree of latitude ≈ 111 km, so 2e-6° ≈ 22 cm
# north-south (somewhat less east-west, where degrees of longitude are shorter).
SNAP_THRESHOLD_DEG = 2e-6  # ~20cm
print(f"Spatial snapping threshold: {SNAP_THRESHOLD_DEG}° (~{SNAP_THRESHOLD_DEG * 111000 * 100:.1f}cm)")

infra_node_counter = 0
infra_nodes = {}  # coord that created the node -> infra node id
# Spatial hash of infra nodes with cell size == SNAP_THRESHOLD_DEG: any node within the
# threshold of a query point lies in the query's cell or one of its 8 neighbours. This
# replaces rebuilding a cKDTree after every new node (plus an O(n) list() per lookup),
# which was quadratic in the number of infrastructure nodes.
infra_grid = defaultdict(list)  # (ix, iy) -> [(coord, infra_node_id), ...]

matched_endpoints = 0
snapped_endpoints = 0
unmatched_endpoints = 0
skipped_self_loops = 0
skipped_non_linestrings = 0


def _infra_grid_cell(coord):
    """Return the spatial-hash cell (ix, iy) containing coord."""
    return (math.floor(coord[0] / SNAP_THRESHOLD_DEG), math.floor(coord[1] / SNAP_THRESHOLD_DEG))


def get_or_create_infra_node(coord):
    """Return the nearest infra node within SNAP_THRESHOLD_DEG of coord, creating one if none.

    A new node is added to G with node_type='infrastructure' and lat/lon taken from coord.
    """
    global infra_node_counter

    cell_x, cell_y = _infra_grid_cell(coord)
    best_node, best_dist = None, None
    for dx in (-1, 0, 1):
        for dy in (-1, 0, 1):
            for other, node_id in infra_grid.get((cell_x + dx, cell_y + dy), ()):
                dist = math.hypot(coord[0] - other[0], coord[1] - other[1])
                if dist <= SNAP_THRESHOLD_DEG and (best_dist is None or dist < best_dist):
                    best_node, best_dist = node_id, dist
    if best_node is not None:
        return best_node

    # Create new infra node
    infra_node_id = f"INFRA_{infra_node_counter}"
    infra_node_counter += 1
    infra_nodes[coord] = infra_node_id
    infra_grid[(cell_x, cell_y)].append((coord, infra_node_id))
    G.add_node(infra_node_id, node_type='infrastructure', lat=coord[1], lon=coord[0])
    return infra_node_id


def resolve_endpoint(coord):
    """Map a railroad endpoint to a station (exact, then snapped) or to an infra node."""
    global matched_endpoints, snapped_endpoints, unmatched_endpoints

    # Exact match
    if coord in coord_to_station:
        matched_endpoints += 1
        return coord_to_station[coord]

    # Spatial snapping to stations
    dist, nearest = station_tree.query(coord)
    if dist <= SNAP_THRESHOLD_DEG:
        snapped_endpoints += 1
        return coord_to_station_list[nearest]

    # Create or find infrastructure node (with snapping)
    unmatched_endpoints += 1
    return get_or_create_infra_node(coord)


for _, rail in gdf_railroads.iterrows():
    geom = rail.geometry
    if geom is None or geom.geom_type != 'LineString':
        skipped_non_linestrings += 1
        continue

    coords = list(geom.coords)
    start_node = resolve_endpoint(coords[0])
    end_node = resolve_endpoint(coords[-1])

    if start_node == end_node:
        skipped_self_loops += 1
        continue

    # Series.get returns None when the column is absent
    line_name = rail.get('N02_003')
    operator = rail.get('N02_004')

    if G.has_edge(start_node, end_node):
        if line_name and line_name not in G[start_node][end_node]['lines']:
            G[start_node][end_node]['lines'].append(line_name)
    else:
        G.add_edge(start_node, end_node, lines=[line_name] if line_name else [], operator=operator)

print(f"\nGraph construction complete:")
print(f"  Total nodes: {G.number_of_nodes():,}")
print(f"  - Station nodes: {len(station_groups):,}")
print(f"  - Infrastructure nodes: {len(infra_nodes):,}")
print(f"  Total edges: {G.number_of_edges():,}")
print(f"  Skipped self-loops: {skipped_self_loops:,}")
print(f"  Skipped non-LineString sections: {skipped_non_linestrings:,}")
print(f"\nEndpoint matching:")
print(f"  Exact matches: {matched_endpoints:,}")
print(f"  Snapped (within {SNAP_THRESHOLD_DEG * 111000 * 100:.0f}cm): {snapped_endpoints:,}")
print(f"  Unmatched (infrastructure): {unmatched_endpoints:,}")
total_endpoints = matched_endpoints + snapped_endpoints + unmatched_endpoints
if total_endpoints:
    print(f"  Station match rate: {(matched_endpoints + snapped_endpoints) / total_endpoints * 100:.1f}%")

In [None]:
print("Step 6: Analyze graph structure")

# One traversal: sort the components largest-first and take the count from the same list
components = sorted(nx.connected_components(G), key=len, reverse=True)
num_components = len(components)

print(f"Connected components: {num_components}")
if num_components > 0:
    largest = components[0]
    print(f"Largest component: {len(largest):,} nodes ({len(largest)/G.number_of_nodes()*100:.1f}%)")

    # Stations with no edges at all form single-node components
    single_station_components = [
        c for c in components
        if len(c) == 1 and G.nodes[next(iter(c))].get('node_type') == 'station'
    ]
    if single_station_components:
        print(f"\nWARNING: {len(single_station_components)} stations are isolated (single-node components):")
        for c in single_station_components[:10]:
            node = next(iter(c))
            node_data = G.nodes[node]
            print(f"    '{node}' ({node_data.get('name', 'unknown')})")
        if len(single_station_components) > 10:
            print(f"    ... and {len(single_station_components) - 10} more")

In [None]:
print("Step 6b: Diagnosing isolated stations")

from scipy.spatial import cKDTree
import numpy as np
import math

METERS_PER_DEG_LAT = 111000  # approx. length of one degree of latitude

# Get all isolated stations (stations forming a single-node component)
isolated_stations = [list(c)[0] for c in components if len(c) == 1 and G.nodes[list(c)[0]].get('node_type') == 'station']

# Unique railroad endpoints (first/last vertex of each LineString), in first-seen order
unique_endpoints = list(dict.fromkeys(
    coord
    for geom in gdf_railroads.geometry
    if geom is not None and geom.geom_type == 'LineString'
    for coord in (geom.coords[0], geom.coords[-1])
))
endpoint_array = np.array(unique_endpoints)
tree = cKDTree(endpoint_array)

print(f"Isolated stations: {len(isolated_stations)}")
print(f"Unique rail endpoints: {len(unique_endpoints)}\n")

# Analyze each isolated station
for station_id in isolated_stations[:10]:  # First 10
    station_data = station_groups[station_id]
    station_coords = station_data['coords']

    print(f"Station '{station_id}' ({station_data['name']}):")
    print(f"  Platform coords: {len(station_coords)}")

    # Query all platform vertices at once; argmin returns the first minimum, like a strict '<' scan
    dists, idxs = tree.query(np.asarray(station_coords))
    best = int(np.argmin(dists))
    min_dist = float(dists[best])
    nearest_endpoint = unique_endpoints[int(idxs[best])]
    nearest_station_coord = station_coords[best]

    # Local equirectangular approximation: scale the longitude delta by cos(latitude)
    dlat_m = (nearest_endpoint[1] - nearest_station_coord[1]) * METERS_PER_DEG_LAT
    dlon_m = ((nearest_endpoint[0] - nearest_station_coord[0]) * METERS_PER_DEG_LAT
              * math.cos(math.radians(nearest_station_coord[1])))
    dist_meters = math.hypot(dlat_m, dlon_m)

    print(f"  Nearest rail endpoint: {min_dist:.8f}° (~{dist_meters:.1f}m)")
    print(f"    Station coord: {nearest_station_coord}")
    print(f"    Rail endpoint: {nearest_endpoint}")
    print()

In [None]:
# Step 7: Export the graph
# NetworkX 3.x removed nx.write_gpickle, so the graph object is pickled directly.
# Load it back with pickle.load -- only from trusted files, since unpickling can run
# arbitrary code.
output_dir = Path('../datasets/japan')
output_path = output_dir / 'japan_rail_network.gpickle'
output_dir.mkdir(parents=True, exist_ok=True)

print(f"Exporting graph to {output_path}...")

with output_path.open('wb') as f:
    pickle.dump(G, f)

print("✓ Graph exported successfully!")

In [None]:
# Step 8: Visualize the graph (Swiss-style with station/infrastructure layers)
import html

print("Creating interactive map...")


def _has_position(attrs):
    """Return True when a node has both 'lat' and 'lon' (0.0 counts as a valid value)."""
    return attrs.get('lat') is not None and attrs.get('lon') is not None


# Calculate center of all positioned nodes
positioned = [data for _, data in G.nodes(data=True) if _has_position(data)]
center_lat = sum(d['lat'] for d in positioned) / len(positioned)
center_lon = sum(d['lon'] for d in positioned) / len(positioned)

# Create base map
m = folium.Map(location=[center_lat, center_lon], zoom_start=6, tiles='CartoDB Positron')

# Define colors (similar to Swiss map)
STATION_COLOR = '#1f77b4'  # Blue for stations
INFRA_COLOR = '#ff7f0e'    # Orange for infrastructure
EDGE_COLOR = '#6c757d'     # Gray for edges
ISOLATED_COLOR = '#d62728' # Red for isolated nodes

# Create feature groups for layers
edges_fg = folium.FeatureGroup(name='Edges', show=True)
stations_fg = folium.FeatureGroup(name='Stations', show=True)
infra_fg = folium.FeatureGroup(name='Infrastructure', show=False)
isolated_fg = folium.FeatureGroup(name='Isolated Stations', show=True)

# Identify isolated stations (single-node components)
isolated_nodes = {
    next(iter(comp)) for comp in components
    if len(comp) == 1 and G.nodes[next(iter(comp))].get('node_type') == 'station'
}

# Add all edges as one multi-polyline: one PolyLine object per edge bloats the map HTML
edge_segments = [
    [[G.nodes[u]['lat'], G.nodes[u]['lon']], [G.nodes[v]['lat'], G.nodes[v]['lon']]]
    for u, v in G.edges()
    if _has_position(G.nodes[u]) and _has_position(G.nodes[v])
]
if edge_segments:
    folium.PolyLine(
        edge_segments,
        color=EDGE_COLOR,
        weight=1,
        opacity=0.6,
    ).add_to(edges_fg)

# Add nodes
for node, data in G.nodes(data=True):
    if not _has_position(data):
        continue

    node_type = data.get('node_type', 'unknown')
    is_isolated = node in isolated_nodes

    if is_isolated:
        color = ISOLATED_COLOR
        layer = isolated_fg
        radius = 6
    elif node_type == 'station':
        color = STATION_COLOR
        layer = stations_fg
        radius = 4
    else:
        color = INFRA_COLOR
        layer = infra_fg
        radius = 3

    # Escape dataset text before embedding it in popup/tooltip HTML
    label = html.escape(str(data.get('name', node)))
    node_id = html.escape(str(node))

    # Create popup with node info
    popup_html = f"""
    <b>{label}</b><br>
    ID: {node_id}<br>
    Type: {node_type}<br>
    """
    if node_type == 'station':
        operators = data.get('operators') or ['N/A']  # empty list -> 'N/A', not a blank
        popup_html += f"Operators: {html.escape(', '.join(map(str, operators)))}<br>"
        popup_html += f"Platforms: {data.get('platform_count', 'N/A')}<br>"
        all_names = data.get('all_names') or []
        if len(all_names) > 1:
            popup_html += f"Alt names: {html.escape(', '.join(map(str, all_names)))}<br>"
    popup_html += f"Degree: {G.degree(node)}"
    if is_isolated:
        popup_html += "<br><b style='color:red'>⚠️ ISOLATED</b>"

    folium.CircleMarker(
        location=[data['lat'], data['lon']],
        radius=radius,
        color=color,
        fill=True,
        fill_color=color,
        fill_opacity=0.9 if is_isolated else 0.7,
        weight=1,
        tooltip=f"{label} ({node_id})",
        popup=folium.Popup(popup_html, max_width=300),
    ).add_to(layer)

# Add layers to map
edges_fg.add_to(m)
stations_fg.add_to(m)
infra_fg.add_to(m)
isolated_fg.add_to(m)

# Add layer control
folium.LayerControl(collapsed=False).add_to(m)

# Summary
station_count = sum(1 for _, d in G.nodes(data=True) if d.get('node_type') == 'station')
infra_count = sum(1 for _, d in G.nodes(data=True) if d.get('node_type') == 'infrastructure')

print(f"✓ Map created:")
print(f"  Stations: {station_count:,} (isolated: {len(isolated_nodes)})")
print(f"  Infrastructure: {infra_count:,}")
print(f"  Edges: {G.number_of_edges():,}")

m