# Swiss rail network map

Compare the SBB-specific graph with the nationwide `schienennetz_2056_de` graph. Each map cell below focuses on a single dataset so you can inspect them independently.

In [None]:
from pathlib import Path
import pickle

import pandas as pd
from IPython.display import display

In [None]:

BASE_DIR = Path('datasets/switzerland')
SBB_STATION_PATH = BASE_DIR / 'sbb-dienststellen-gemass-opentransportdataswiss.csv'

# Graphs to render. Each graph is looked up at `path` first, then at every
# entry of `fallback_paths` in order. `node_colors` is
# (station colour, infrastructure colour).
DATASET_CONFIGS = [
    dict(
        key='sbb',
        label='SBB (opentransportdata)',
        path=BASE_DIR / 'sbb_rail_network.gpickle',
        fallback_paths=[BASE_DIR / 'swiss_rail_network.gpickle'],
        node_colors=('#1f77b4', '#ff7f0e'),
        edge_color='#6c757d',
        show=True,
    ),
    dict(
        key='swisstopo',
        label='swisstopo (schienennetz_2056_de)',
        path=BASE_DIR / 'swiss_rail_network_swisstopo.gpickle',
        fallback_paths=[],
        node_colors=('#2ca02c', '#d62728'),
        edge_color='#4b3f72',
        show=True,
    ),
]

# SBB service-point reference table. Only rows flagged as stop points
# contribute the (normalised) abbreviations later used to classify graph
# nodes as stations.
station_metadata = pd.read_csv(SBB_STATION_PATH, sep=';')
station_metadata['abbreviation_clean'] = (
    station_metadata['abbreviation']
    .astype(str)
    .str.strip()
    .str.upper()
)
stop_point_mask = station_metadata['stopPoint'].astype(str).str.lower().eq('true')
station_abbreviation_set = set(
    station_metadata['abbreviation_clean'][stop_point_mask].dropna()
)
# astype(str) turns missing abbreviations into the literal 'nan' ('NAN' after upper()).
station_abbreviation_set.discard('NAN')
print('Loaded station reference table with', len(station_abbreviation_set), 'stop-point abbreviations')


In [None]:

def parse_geopos(value):
    """Parse a ``"lat, lon"`` string into a ``(lat, lon)`` tuple of floats.

    Only the first comma separates the two parts. Returns None when *value*
    is not a string, contains no comma, either part is not a valid float, or
    the latitude is outside [-90, 90] / the longitude outside [-180, 180]
    (NaN values fail the range check too).
    """
    if not isinstance(value, str) or ',' not in value:
        return None
    lat_text, _, lon_text = value.partition(',')
    try:
        lat, lon = float(lat_text.strip()), float(lon_text.strip())
    except ValueError:
        return None
    in_range = -90 <= lat <= 90 and -180 <= lon <= 180
    return (lat, lon) if in_range else None

def infer_station_flag(node_id, node_data, station_abbreviations=None):
    """Decide whether a graph node should be drawn as a station.

    Resolution order:

    1. An explicit ``is_station`` node attribute wins. String values are
       compared case-insensitively against ``'true'`` (the same rule used for
       the ``stopPoint`` CSV column); any other value goes through ``bool()``.
    2. The node's ``abbreviation`` attribute is normalised (strip + upper)
       and looked up in *station_abbreviations*. If that attribute is missing,
       empty or not a string, the node id itself is used instead.
    3. Any ``rows[*]['Station abbreviation']`` string that matches.

    Non-string candidates, such as integer or tuple node ids, are skipped in
    step 2 instead of raising ``AttributeError``.

    Args:
        node_id: The graph node key.
        node_data: The node's attribute dict.
        station_abbreviations: Set of upper-cased stop-point abbreviations.
            Defaults to the module-level ``station_abbreviation_set``.

    Returns:
        True if the node is classified as a station, otherwise False.
    """
    if station_abbreviations is None:
        station_abbreviations = station_abbreviation_set

    flag = node_data.get('is_station')
    if flag is not None:
        if isinstance(flag, str):
            # bool('False') would be True; parse textual flags explicitly.
            return flag.strip().lower() == 'true'
        return bool(flag)

    abbr = node_data.get('abbreviation')
    if not isinstance(abbr, str) or not abbr:
        abbr = node_id
    if isinstance(abbr, str):
        abbr_clean = abbr.strip().upper()
        if abbr_clean and abbr_clean in station_abbreviations:
            return True

    for row in node_data.get('rows') or []:
        abbr_row = row.get('Station abbreviation')
        if isinstance(abbr_row, str) and abbr_row.strip().upper() in station_abbreviations:
            return True
    return False


In [None]:

def _node_coordinates(node_data):
    """Return ``(lat, lon)`` for a graph node, or None if no usable position exists.

    Explicit ``lat``/``lon`` node attributes are used when both are present
    and not NaN. Otherwise the first entry of ``node_data['rows']`` with a
    parseable ``Geopos`` (falling back to ``Geopos_didok``) string wins.
    """
    lat = node_data.get('lat')
    lon = node_data.get('lon')
    # pd.isna covers both None and NaN; NaN coordinates would break folium later.
    if not (pd.isna(lat) or pd.isna(lon)):
        return lat, lon
    for row in node_data.get('rows') or []:
        coords = parse_geopos(row.get('Geopos')) or parse_geopos(row.get('Geopos_didok'))
        if coords:
            return coords
    return None


dataset_contexts = []
for cfg in DATASET_CONFIGS:
    candidate_paths = [cfg['path'], *(cfg.get('fallback_paths') or [])]
    path = next((p for p in candidate_paths if p.exists()), None)
    if path is None:
        expected = ', '.join(str(p) for p in candidate_paths)
        print(f"Skipping {cfg['label']} (no existing graph files among: {expected})")
        continue
    if path != cfg['path']:
        print(f"Using fallback graph {path} for {cfg['label']}")
    # NOTE: unpickling can execute arbitrary code. Only load graph files you
    # generated yourself (e.g. via the exploration notebook), never untrusted ones.
    with path.open('rb') as f:
        G = pickle.load(f)

    node_records = []
    positions = {}
    missing = []
    station_count = 0
    for node_id, data in G.nodes(data=True):
        coords = _node_coordinates(data)
        if coords is None:
            missing.append(node_id)
            continue
        lat, lon = coords
        is_station = infer_station_flag(node_id, data)
        station_count += int(is_station)
        label = data.get('label') or data.get('name') or node_id
        node_records.append({
            'dataset': cfg['key'],
            'dataset_label': cfg['label'],
            'node_id': node_id,
            'label': label,
            'lat': lat,
            'lon': lon,
            'is_station': bool(is_station),
        })
        positions[node_id] = (lat, lon)

    infra_count = len(node_records) - station_count
    dataset_contexts.append({
        'config': cfg,
        'graph': G,
        'graph_path': path,
        'node_records': node_records,
        'positions': positions,
        'missing_nodes': missing,
    })
    print(f"{cfg['label']}: plotted {len(node_records)} nodes (stations {station_count}, infrastructure {infra_count}, missing coords {len(missing)})")

if not dataset_contexts:
    raise RuntimeError('No graph files were found. Generate them in the exploration notebook first.')

context_by_key = {ctx['config']['key']: ctx for ctx in dataset_contexts}
print('Datasets loaded:', ', '.join(context_by_key.keys()))


In [None]:

import folium

def build_dataset_map(ctx, zoom_start=8):
    """Render one loaded dataset context as an interactive folium map.

    Args:
        ctx: Dataset context dict with ``config``, ``graph``, ``positions``
            and ``node_records`` entries (as built by the loading cell).
        zoom_start: Initial zoom level passed to ``folium.Map``.

    Returns:
        A ``folium.Map`` centred on the mean node position. It has toggleable
        edge, station and infrastructure layers plus an expanded layer control.

    Raises:
        ValueError: If the dataset has no nodes with coordinates.
    """
    from html import escape

    cfg = ctx['config']
    positions = ctx['positions']
    node_df = pd.DataFrame(ctx['node_records'])
    if node_df.empty:
        raise ValueError(f"Dataset {cfg['label']} has no plottable nodes.")
    m = folium.Map(
        location=[node_df['lat'].mean(), node_df['lon'].mean()],
        zoom_start=zoom_start,
        tiles='CartoDB Positron',
    )

    # Collect all drawable edges into ONE multi-polyline. A separate PolyLine
    # per edge emits one JS object per edge, which bloats the nationwide map.
    segments = []
    for u, v in ctx['graph'].edges():
        if u not in positions or v not in positions:
            continue
        lat1, lon1 = positions[u]
        lat2, lon2 = positions[v]
        if None in (lat1, lon1, lat2, lon2):
            continue
        segments.append([[lat1, lon1], [lat2, lon2]])
    edges_fg = folium.FeatureGroup(name=f"{cfg['label']} – edges", show=True)
    if segments:  # folium rejects an empty location list
        folium.PolyLine(segments, color=cfg['edge_color'], weight=1, opacity=0.6).add_to(edges_fg)
    edges_fg.add_to(m)

    station_color, infra_color = cfg['node_colors']
    stations_fg = folium.FeatureGroup(name=f"{cfg['label']} – stations", show=True)
    infra_fg = folium.FeatureGroup(name=f"{cfg['label']} – infrastructure", show=True)
    for rec in ctx['node_records']:
        layer = stations_fg if rec['is_station'] else infra_fg
        color = station_color if rec['is_station'] else infra_color
        # Labels/ids come from the graph data and are rendered as HTML by
        # Leaflet, so escape them to avoid broken markup.
        safe_label = escape(str(rec['label']))
        popup = folium.Popup(
            f"<b>{safe_label}</b><br>ID: {escape(str(rec['node_id']))}"
            f"<br>Dataset: {escape(str(cfg['label']))}",
            max_width=260,
        )
        folium.CircleMarker(
            location=[rec['lat'], rec['lon']],
            radius=4 if rec['is_station'] else 3,
            color=color,
            fill=True,
            fill_color=color,
            fill_opacity=0.9,
            weight=1,
            tooltip=safe_label,
            popup=popup,
        ).add_to(layer)
    stations_fg.add_to(m)
    infra_fg.add_to(m)
    folium.LayerControl(collapsed=False).add_to(m)
    return m


In [None]:

# Render the SBB (opentransportdata) graph if the loading cell found it.
sbb_ctx = context_by_key.get('sbb')
if sbb_ctx is not None:
    print('Rendering SBB map from', sbb_ctx['graph_path'])
    sbb_map = build_dataset_map(sbb_ctx, zoom_start=8)
    display(sbb_map)
else:
    print('SBB graph not available. Run the exploration notebook with DATA_SOURCE="sbb" first.')


In [None]:

# Render the nationwide swisstopo graph if the loading cell found it.
swiss_ctx = context_by_key.get('swisstopo')
if swiss_ctx is not None:
    print('Rendering nationwide map from', swiss_ctx['graph_path'])
    swiss_map = build_dataset_map(swiss_ctx, zoom_start=7)
    display(swiss_map)
else:
    print('Nationwide graph not available. Run the exploration notebook with DATA_SOURCE="swisstopo".')


In [None]:

# Zoom in on a few nodes of one dataset. The cell shows their records, their
# markers, and any edges that run directly between two of them.
FOCUS_DATASET_KEY = 'sbb'  # switch to 'swisstopo' to inspect nationwide nodes
FOCUS_NODE_IDS = ['ASZW', 'ASZO', 'ASKO', 'ASSW']  # replace with node ids of interest

ctx = context_by_key.get(FOCUS_DATASET_KEY)
if ctx is None:
    print(f"Dataset '{FOCUS_DATASET_KEY}' is not loaded. Available: {list(context_by_key.keys())}")
elif not FOCUS_NODE_IDS:
    print('Provide at least one node id in FOCUS_NODE_IDS to center the map.')
else:
    focus_ids = set(FOCUS_NODE_IDS)  # built once, reused for every lookup below
    subset = [rec for rec in ctx['node_records'] if rec['node_id'] in focus_ids]
    if not subset:
        print('No matching nodes found for the requested IDs.')
    else:
        found_ids = {rec['node_id'] for rec in subset}
        not_found = [node_id for node_id in FOCUS_NODE_IDS if node_id not in found_ids]
        if not_found:
            # Either absent from the graph or dropped for missing coordinates.
            print('Not found or without coordinates:', not_found)

        subset_df = pd.DataFrame(subset)
        display(subset_df)
        focus_map = folium.Map(
            location=[subset_df['lat'].mean(), subset_df['lon'].mean()],
            zoom_start=13,
            tiles='CartoDB Positron',
        )
        focus_positions = ctx['positions']
        focus_station_color, focus_infra_color = ctx['config']['node_colors']

        # Draw edges before markers so the markers render on top (as in build_dataset_map).
        for u, v in ctx['graph'].edges():
            if u in focus_ids and v in focus_ids and u in focus_positions and v in focus_positions:
                lat1, lon1 = focus_positions[u]
                lat2, lon2 = focus_positions[v]
                folium.PolyLine([[lat1, lon1], [lat2, lon2]], color=ctx['config']['edge_color'], weight=2).add_to(focus_map)

        for rec in subset:
            color = focus_station_color if rec['is_station'] else focus_infra_color
            folium.CircleMarker(
                location=[rec['lat'], rec['lon']],
                radius=6,
                color=color,
                fill=True,
                fill_color=color,
                fill_opacity=0.95,
                tooltip=f"{rec['label']} ({rec['node_id']})",
            ).add_to(focus_map)

        # A bare `focus_map` expression nested in if/else is not auto-rendered
        # by Jupyter (only a cell's final top-level expression is), so display it explicitly.
        display(focus_map)
