# Swiss rail network exploration\n
\n
**NOTE: This notebook primarily analyzes the nationwide Swisstopo dataset for the Final Report.**\n
\n
It can load and analyse either:\n
\n
1. The **SBB infrastructure dataset** (Legacy/Incomplete) - Not used for final robustness analysis.\n
2. The **Nationwide `schienennetz_2056_de` geodatabase** from Swisstopo (Primary Source).\n
\n
Set `DATA_SOURCE` in the configuration cell below (the cell right after the imports) to `"sbb"` or `"swisstopo"` to switch between them. For the final report, we use `swisstopo`.

In [1]:
from pathlib import Path

import pandas as pd
import geopandas as gpd
import networkx as nx
import pickle


In [2]:
# Dataset the graph is built from: "sbb" selects the SBB CSV tables,
# "swisstopo" selects the nationwide geodatabase.
DATA_SOURCE = "swisstopo"  # options: "sbb" or "swisstopo"
# When True, the export cell pickles the finished graph to GRAPH_OUTPUTS[DATA_SOURCE].
EXPORT_GRAPH = True

# All input and output files live under this directory (relative to the notebook).
BASE_DIR = Path("../datasets/switzerland")
SBB_LINE_PATH = BASE_DIR / "sbb-linie-mit-betriebspunkten.csv"
SBB_STATION_PATH = BASE_DIR / "sbb-dienststellen-gemass-opentransportdataswiss.csv"
SWISSTOPO_GDB_PATH = BASE_DIR / "schienennetz_2056_de.gdb"

# Export target per data source.
GRAPH_OUTPUTS = {
    "sbb": BASE_DIR / "sbb_rail_network.gpickle",
    "swisstopo": BASE_DIR / "swiss_rail_network_swisstopo.gpickle",
}

# Older export locations; only consulted to print a notice when a legacy file
# exists but the new export does not.
LEGACY_GRAPH_PATHS = {
    "sbb": BASE_DIR / "swiss_rail_network.gpickle",
}

# Fail fast with a readable message instead of a bare KeyError further down.
if DATA_SOURCE not in GRAPH_OUTPUTS:
    raise ValueError(
        f"Unknown DATA_SOURCE {DATA_SOURCE!r}; expected one of {sorted(GRAPH_OUTPUTS)}"
    )

# Station reference table; used to decide which graph nodes are passenger stops.
station_metadata = pd.read_csv(SBB_STATION_PATH, sep=';')

# Normalise abbreviations (trimmed, upper-case) so lookups are case-insensitive.
raw_abbreviations = station_metadata['abbreviation'].astype(str)
station_metadata['abbreviation_clean'] = raw_abbreviations.str.strip().str.upper()

# 'stopPoint' is compared as text so both boolean and string encodings match.
stop_point_flags = station_metadata['stopPoint'].astype(str).str.lower()
stop_point_mask = stop_point_flags == 'true'

# Abbreviations of all rows flagged as stop points.
# NOTE(review): 'NAN' is removed explicitly, presumably because missing
# abbreviations become that text after the str/upper normalisation above.
stop_point_abbreviations = (
    station_metadata.loc[stop_point_mask, 'abbreviation_clean'].dropna()
)
station_abbreviation_set = set(stop_point_abbreviations.tolist()) - {'NAN'}

# Resolve the export target for the selected data source, plus the legacy
# location (if one is registered for that source).
GRAPH_OUTPUT_PATH = GRAPH_OUTPUTS[DATA_SOURCE]
LEGACY_PATH = LEGACY_GRAPH_PATHS.get(DATA_SOURCE)

# Only mention the legacy file when it is the sole export on disk.
if LEGACY_PATH and LEGACY_PATH.exists():
    if not GRAPH_OUTPUT_PATH.exists():
        print(f"Legacy graph detected at {LEGACY_PATH}. New exports will use {GRAPH_OUTPUT_PATH}.")

print(f"Using data source: {DATA_SOURCE}")
print(f"Graph exports will be written to: {GRAPH_OUTPUT_PATH}")

Using data source: swisstopo
Graph exports will be written to: ../datasets/switzerland/swiss_rail_network_swisstopo.gpickle


In [3]:
# Load the raw tables for the selected source and print a quick shape/column
# overview. Anything other than "sbb" takes the Swisstopo geodatabase path.
if DATA_SOURCE != "sbb":
    # Two layers of the geodatabase: track segments and the nodes they connect.
    net_segments = gpd.read_file(SWISSTOPO_GDB_PATH, layer='Netzsegment')
    net_nodes = gpd.read_file(SWISSTOPO_GDB_PATH, layer='Netzknoten')

    print(f"Segments table shape: {net_segments.shape}")
    print(f"Nodes table shape: {net_nodes.shape}")
    print("Segment columns:")
    print(sorted(net_segments.columns))
    print("Node columns:")
    print(sorted(net_nodes.columns))
else:
    sbb_line_data = pd.read_csv(SBB_LINE_PATH, sep=';')
    sbb_stations_data = pd.read_csv(SBB_STATION_PATH, sep=';')

    # Left join keeps every line row; station columns that clash with line
    # columns receive the "_didok" suffix.
    sbb_line_and_station_data = pd.merge(
        sbb_line_data,
        sbb_stations_data,
        how="left",
        left_on="Didok number",
        right_on="number",
        suffixes=("", "_didok"),
    )

    print(f"Line table shape: {sbb_line_data.shape}")
    print(f"Stations table shape: {sbb_stations_data.shape}")
    print("Combined columns:")
    print(sorted(sbb_line_and_station_data.columns)[:10], "...")

Segments table shape: (3424, 17)
Nodes table shape: (3210, 11)
Segment columns:
['AnzahlStreckengleise', 'BearbeitungsDatum', 'BeginnGueltigkeit', 'Elektrifizierung', 'EndeGueltigkeit', 'Infrastrukturbetr_TUAbkuerzung', 'Infrastrukturbetr_TUNummer', 'KmAnfang', 'KmEnde', 'Name', 'Spurweite', 'Stand', 'geometry', 'rAnfangsknoten', 'rEndknoten', 'rKmLinie', 'xtf_id']
Node columns:
['BearbeitungsDatum', 'BeginnGueltigkeit', 'Betriebspunkt_Abkuerzung', 'Betriebspunkt_DatenherrAbkuer', 'Betriebspunkt_Name', 'Betriebspunkt_Nummer', 'EndeGueltigkeit', 'Stand', 'geometry', 'rUebergeordnet', 'xtf_id']


In [4]:
def parse_geopos(value):
    """Parse a ``"lat,lon"`` string into a ``(lat, lon)`` tuple of floats.

    Returns ``None`` when ``value`` is not a string, contains no comma, has a
    part that is not a valid float, or falls outside the ranges
    -90..90 (latitude) and -180..180 (longitude).
    """
    if not isinstance(value, str) or ',' not in value:
        return None

    # Split on the first comma only; any further comma makes the second part
    # unparsable and therefore yields None below.
    lat_text, lon_text = value.split(',', 1)
    try:
        lat, lon = float(lat_text.strip()), float(lon_text.strip())
    except ValueError:
        return None

    within_bounds = -90 <= lat <= 90 and -180 <= lon <= 180
    return (lat, lon) if within_bounds else None


def classify_station(abbreviation, fallback_rows=None):
    """Return True when a node should be treated as a station (stop point).

    Parameters
    ----------
    abbreviation:
        Abbreviation of the node. Matched case-insensitively, ignoring
        surrounding whitespace, against ``station_abbreviation_set``.
        Non-string values (``None``, NaN, numbers) are treated as "no
        abbreviation" instead of raising.
    fallback_rows:
        Optional iterable of mapping-like rows. If the primary abbreviation
        does not match, each row's ``'Station abbreviation'`` entry is tried
        the same way.

    Returns
    -------
    bool
        True if any candidate abbreviation is in ``station_abbreviation_set``.
    """
    # Guard the primary value the same way the fallback rows are guarded, so a
    # truthy non-string (e.g. a float NaN from pandas) cannot raise.
    abbr = abbreviation.strip().upper() if isinstance(abbreviation, str) else ""
    if abbr and abbr in station_abbreviation_set:
        return True

    for row in fallback_rows or ():
        row_abbr = row.get('Station abbreviation')
        if isinstance(row_abbr, str) and row_abbr.strip().upper() in station_abbreviation_set:
            return True
    return False


def flatten_lines(geom):
    """Flatten a line geometry into one list of ``(y, x)`` tuples.

    Accepts objects exposing ``is_empty`` and ``geom_type``; a 'LineString'
    contributes its ``coords``, a 'MultiLineString' the coords of every member
    of ``geoms`` in order. Each point is emitted with its first two ordinates
    swapped (presumably turning lon/lat into lat/lon for WGS84 input); any
    further ordinates are dropped. ``None``, empty geometries and all other
    geometry types yield an empty list.
    """
    if geom is None or geom.is_empty:
        return []

    kind = geom.geom_type
    if kind == 'LineString':
        parts = [geom]
    elif kind == 'MultiLineString':
        parts = geom.geoms
    else:
        return []

    return [(pt[1], pt[0]) for part in parts for pt in part.coords]


# Build the rail network graph for the selected data source.
# G             - undirected graph; nodes carry label/abbreviation/lat/lon/
#                 is_station/rows/source, edges carry lines/segments/source.
# node_records  - flat per-node dicts, turned into node_summary_df at the end.
# positions     - node id -> (lat, lon) lookup.
G = nx.Graph()
node_records = []
positions = {}

if DATA_SOURCE == "sbb":
    # Column names of the merged SBB line/station table used below.
    LINE_COL = "Line"
    STATION_COL = "Station abbreviation"
    ORDER_COL = "KM"

    # --- Nodes: one per station abbreviation, built from all of its rows. ---
    node_groups = sbb_line_and_station_data.dropna(subset=[STATION_COL]).groupby(STATION_COL)
    for station, group in node_groups:
        row_dicts = group.to_dict('records')
        # Take the first row with a parsable position, preferring 'Geopos'
        # over 'Geopos_didok' (presumably the station table's column after the
        # "_didok" merge suffix).
        coords = None
        for row in row_dicts:
            coords = parse_geopos(row.get('Geopos')) or parse_geopos(row.get('Geopos_didok'))
            if coords:
                break
        # Nodes without a usable position are kept, with lat/lon set to None.
        lat, lon = coords if coords else (None, None)
        # Label: first non-empty string 'Stop name', else the abbreviation itself.
        label = next((row.get('Stop name') for row in row_dicts if isinstance(row.get('Stop name'), str) and row.get('Stop name')), station)
        is_station = classify_station(station, row_dicts)
        node_attrs = {
            'label': label,
            'abbreviation': station,
            'lat': lat,
            'lon': lon,
            'is_station': is_station,
            'rows': row_dicts,
            'source': 'sbb',
        }
        G.add_node(station, **node_attrs)
        positions[station] = (lat, lon)
        node_records.append({
            'node_id': station,
            'label': label,
            'lat': lat,
            'lon': lon,
            'is_station': is_station,
            'source': 'sbb',
        })

    # --- Edges: connect consecutive stations along each line, ordered by KM. ---
    ordered_df = (
        sbb_line_and_station_data
        .dropna(subset=[LINE_COL, STATION_COL, ORDER_COL])
        .sort_values([LINE_COL, ORDER_COL])
    )

    for line_id, group in ordered_df.groupby(LINE_COL):
        stops = group[STATION_COL].tolist()
        kms = group[ORDER_COL].tolist()
        row_dicts = group.to_dict('records')

        # NOTE(review): if the same station occupies two consecutive rows of a
        # line, u == v and a self-loop is added - confirm whether that can
        # happen in the data and whether it is intended.
        for idx, (u, v) in enumerate(zip(stops[:-1], stops[1:])):
            segment_meta = {
                'line_id': line_id,
                'order_index': idx,
                'from_station': u,
                'to_station': v,
                'from_km': kms[idx],
                'to_km': kms[idx + 1],
                'from_row': row_dicts[idx],
                'to_row': row_dicts[idx + 1],
            }
            # A station pair served by several lines shares one edge; the line
            # ids are collected in a set and every segment record is kept.
            if G.has_edge(u, v):
                G[u][v]['lines'].add(line_id)
                G[u][v]['segments'].append(segment_meta)
            else:
                G.add_edge(
                    u,
                    v,
                    lines={line_id},
                    segments=[segment_meta],
                    source='sbb',
                )

    # Replace the temporary sets with sorted lists once all edges exist.
    for u, v, data in G.edges(data=True):
        data['lines'] = sorted(data['lines'])
else:
    # Reproject to EPSG:4326 so that geometry x/y can be stored as lon/lat.
    nodes_gdf = net_nodes.to_crs(4326)
    # segments_gdf keeps the original CRS; below it is only used as the source
    # of the reprojected copy.
    segments_gdf = net_segments
    segments_wgs84 = segments_gdf.to_crs(4326)

    # --- Nodes: one per 'Netzknoten' row, keyed by its xtf_id. ---
    for _, row in nodes_gdf.iterrows():
        node_id = row['xtf_id']
        abbr = row.get('Betriebspunkt_Abkuerzung')
        # Label preference: abbreviation, then name, then the node id.
        # NOTE(review): relies on a missing abbreviation being falsy (None or
        # ''); a float NaN would be truthy and end up as the label - confirm
        # how the reader encodes missing text.
        label = abbr or row.get('Betriebspunkt_Name') or node_id
        lat = row.geometry.y
        lon = row.geometry.x
        # A node counts as a station only if its abbreviation is a known stop point.
        is_station = False
        if isinstance(abbr, str) and abbr.strip():
            is_station = abbr.strip().upper() in station_abbreviation_set
        node_attrs = {
            'label': label,
            'abbreviation': abbr,
            'lat': lat,
            'lon': lon,
            'is_station': is_station,
            'rows': [row.drop(labels='geometry').to_dict()],
            'source': 'swisstopo',
        }
        G.add_node(node_id, **node_attrs)
        positions[node_id] = (lat, lon)
        node_records.append({
            'node_id': node_id,
            'label': label,
            'lat': lat,
            'lon': lon,
            'is_station': is_station,
            'source': 'swisstopo',
        })

    # --- Edges: one per pair of start/end node references. ---
    for _, row in segments_wgs84.iterrows():
        u = row['rAnfangsknoten']
        v = row['rEndknoten']
        # Skip segments with a missing endpoint reference ...
        if pd.isna(u) or pd.isna(v):
            continue
        # ... or whose endpoints are not among the loaded nodes.
        if u not in G.nodes or v not in G.nodes:
            continue
        lines = [row['Name']] if isinstance(row.get('Name'), str) else []
        segment_meta = {
            'segment_id': row['xtf_id'],
            'line_name': row.get('Name'),
            'track_count': row.get('AnzahlStreckengleise'),
            'gauge': row.get('Spurweite'),
            'electrified': row.get('Elektrifizierung'),
            'coords_wgs84': flatten_lines(row.geometry),
        }
        # Several segments between the same two nodes are merged onto one
        # edge: line names are unioned (kept sorted), segment records appended.
        if G.has_edge(u, v):
            combined = set(G[u][v]['lines'])
            combined.update(lines)
            G[u][v]['lines'] = sorted(combined)
            G[u][v]['segments'].append(segment_meta)
        else:
            G.add_edge(
                u,
                v,
                lines=sorted(lines),
                segments=[segment_meta],
                source='swisstopo',
            )

# One row per node. NOTE(review): the printed count is the number of nodes,
# including any whose lat/lon are None in the "sbb" branch.
node_summary_df = pd.DataFrame(node_records)
print(f"Nodes prepared with coordinates: {len(node_summary_df)}")

Nodes prepared with coordinates: 3210


In [5]:
# Headline size figures for the freshly built graph.
node_count = G.number_of_nodes()
edge_count = G.number_of_edges()
component_count = nx.number_connected_components(G)

print(f"Nodes: {node_count:,}")
print(f"Edges: {edge_count:,}")
print(f"Connected components: {component_count:,}")

Nodes: 3,210
Edges: 3,377
Connected components: 52


In [6]:
# Size table of all connected components, largest first.
components = list(nx.connected_components(G))

# Rank with Python's stable sort on component size - the same ordering rule
# the satellite-component cell uses (sorted(..., key=len, reverse=True)) - so
# components of equal size receive the same rank in both tables. Sorting the
# frame with sort_values' default (non-stable) algorithm could order ties
# differently.
rows = []
for rank, nodeset in enumerate(sorted(components, key=len, reverse=True), start=1):
    sub = G.subgraph(nodeset)
    rows.append({
        'rank': rank,
        'num_nodes': sub.number_of_nodes(),
        'num_edges': sub.number_of_edges(),
    })

# Explicit columns keep the frame well-formed even for a graph with no nodes.
cc_df = pd.DataFrame(rows, columns=['rank', 'num_nodes', 'num_edges'])
cc_df

Unnamed: 0,rank,num_nodes,num_edges
0,1,1687,1843
1,2,253,272
2,3,249,268
3,4,171,174
4,5,126,131
5,6,99,109
6,7,53,52
7,8,43,42
8,9,38,37
9,10,36,35


In [7]:
# List every node that lies outside the largest connected component.
components_by_size = sorted(components, key=len, reverse=True)
# Everything after the largest component; empty when there are 0 or 1 components.
small_components = components_by_size[1:]

rows = []
# component rank -> {'num_nodes', 'num_edges'}, reused by the summary cell.
component_meta = {}
# Ranks start at 2 because rank 1 is the largest component.
for rank, nodeset in enumerate(small_components, start=2):
    sub = G.subgraph(nodeset)
    component_meta[rank] = {
        'num_nodes': sub.number_of_nodes(),
        'num_edges': sub.number_of_edges(),
    }
    for node_id in sorted(nodeset):
        data = G.nodes[node_id]
        rows.append({
            'component_rank': rank,
            'node_id': node_id,
            'label': data.get('label', node_id),
            'abbreviation': data.get('abbreviation'),
            'is_station': data.get('is_station', False),
            'source': data.get('source', DATA_SOURCE),
        })

# Explicit columns: without them an empty `rows` list produces a frame with no
# columns and sort_values raises KeyError, so the `small_cc_df.empty` check in
# the following cell could never be reached for a fully connected graph.
small_cc_df = (
    pd.DataFrame(
        rows,
        columns=['component_rank', 'node_id', 'label', 'abbreviation', 'is_station', 'source'],
    )
    .sort_values(['component_rank', 'label'])
    .reset_index(drop=True)
)

small_cc_df

Unnamed: 0,component_rank,node_id,label,abbreviation,is_station,source
0,2,ch14uvag00092988,"Dübendorf, Giessen",,False,swisstopo
1,2,ch14uvag00092989,"Dübendorf, Ringwiesen",,False,swisstopo
2,2,ch14uvag00092963,"Glattbrugg, Bahnhof",,False,swisstopo
3,2,ch14uvag00092962,"Glattbrugg, Lindberghplatz",,False,swisstopo
4,2,ch14uvag00092965,"Glattbrugg, Unterriet",,False,swisstopo
...,...,...,...,...,...,...
1518,48,ch14uvag00240301,EYSI,EYSI,False,swisstopo
1519,49,ch14uvag00139699,IO,IO,True,swisstopo
1520,50,ch14uvag00240293,TAFR,TAFR,False,swisstopo
1521,51,ch14uvag00139704,BSRB,BSRB,False,swisstopo


In [8]:
# Summarise each satellite component: components with at least two nodes are
# represented by their first two nodes (by label); single-node components are
# listed separately.
pair_rows = []
singleton_rows = []

if small_cc_df.empty:
    print("No satellite components detected outside the largest component.")
else:
    for comp_rank, group in small_cc_df.groupby('component_rank'):
        group = group.sort_values('label')
        # Fall back to a tree-shaped estimate (n - 1 edges) if the rank is
        # missing from component_meta.
        meta = component_meta.get(comp_rank, {'num_nodes': len(group), 'num_edges': max(len(group) - 1, 0)})
        if len(group) < 2:
            singleton_rows.append({
                'component_rank': comp_rank,
                'num_nodes': meta['num_nodes'],
                'num_edges': meta['num_edges'],
                'node_id': group.iloc[0]['node_id'],
                'label': group.iloc[0]['label'],
            })
            continue
        a, b = group.iloc[0], group.iloc[1]
        pair_rows.append({
            'component_rank': comp_rank,
            'num_nodes': meta['num_nodes'],
            'num_edges': meta['num_edges'],
            'node1': f"{a['node_id']} ({a['label']})",
            'node2': f"{b['node_id']} ({b['label']})",
        })

# Explicit columns: a frame built from an empty list has no columns, so
# sort_values('component_rank') would raise KeyError whenever there are no
# pairs or no singletons (including the "no satellite components" case).
small_pairs_df = (
    pd.DataFrame(
        pair_rows,
        columns=['component_rank', 'num_nodes', 'num_edges', 'node1', 'node2'],
    )
    .sort_values('component_rank')
    .reset_index(drop=True)
)

singleton_components_df = (
    pd.DataFrame(
        singleton_rows,
        columns=['component_rank', 'num_nodes', 'num_edges', 'node_id', 'label'],
    )
    .sort_values('component_rank')
    .reset_index(drop=True)
)

print("Pair summary:")
print(small_pairs_df if not small_pairs_df.empty else "None")
print("Singleton components:")
print(singleton_components_df if not singleton_components_df.empty else "None")

Pair summary:
    component_rank  num_nodes  num_edges  \
0                2        253        272   
1                3        249        268   
2                4        171        174   
3                5        126        131   
4                6         99        109   
5                7         53         52   
6                8         43         42   
7                9         38         37   
8               10         36         35   
9               11         32         31   
10              12         30         29   
11              13         26         26   
12              14         24         23   
13              15         22         21   
14              16         22         21   
15              17         21         20   
16              18         18         17   
17              19         17         16   
18              20         16         15   
19              21         16         15   
20              22         15         14   
21              23

In [9]:
# Report how much of the graph the largest connected component covers.
if not components_by_size:
    print("Graph has no components.")
else:
    largest_nodes = components_by_size[0]
    largest_size = len(largest_nodes)
    print(f"Largest component nodes: {largest_size:,}")
    print(f"Largest component share: {largest_size / G.number_of_nodes():.2%}")

Largest component nodes: 1,687
Largest component share: 52.55%


In [10]:
# Optionally pickle the graph to the path chosen in the configuration cell.
if not EXPORT_GRAPH:
    print("Skipping graph export.")
else:
    # Create the target directory on first export.
    GRAPH_OUTPUT_PATH.parent.mkdir(parents=True, exist_ok=True)
    with open(GRAPH_OUTPUT_PATH, 'wb') as graph_file:
        pickle.dump(G, graph_file)
    print(f"Serialized graph to {GRAPH_OUTPUT_PATH}")

Serialized graph to ../datasets/switzerland/swiss_rail_network_swisstopo.gpickle
