# Info

tl;dr: This notebook constructs the schedule network. Load it from `/user/anmaier/schedule_network.orc`; it is partitioned by `hour`.
You can access the stops near Zurich HB from `/user/anmaier/df_stops_near.orc`.

This notebook constructs the schedule network for a user-defined date (`year-month-day`), set below.
We assume that the timetable is the same for the whole year. The notebook also computes the relevant stops near Zurich HB (i.e. the ones that are part of a trip that is at least partially within a 15 km radius of Zurich HB).

## Schedule network

The schedule network is constructed from the `istdaten` dataset for the specified date, using in particular the scheduled arrival and departure times. It is partitioned by hour, and saved in `/user/anmaier/schedule_network.orc`. It is a Spark DataFrame where each row represents an edge with the following schema:
* `src_timestamp (TimestampType)`: timestamp of the source
* `src_stop_name (StringType)`: name of the source station
* `dst_timestamp (TimestampType)`: timestamp of the destination
* `dst_stop_name (StringType)`: name of the destination station
* `route_desc (StringType)`: type of the means of transport (*Zug/Bus/Tram/Schiff/...* for a *trip* edge, *walking* for a walking edge and *waiting* for a waiting edge)
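* `trip_id (StringType)`: id of the trip (for *walking* edges, the id of the trip arriving at the source)
* `dst_trip_id (StringType)`: id of the trip at the destination (same as `trip_id` for *trip* and *waiting* edges; for *walking* edges, the id of the trip departing from the destination)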
* `distance (DoubleType)`: distance of the edge in km
* `duration (IntegerType)`: duration of the edge in min
* `probability (DoubleType)`: probability of the edge (relevant only for *walking* edges). !!! All the probabilities are set to 1 for now; they are computed afterwards in `probabilities.py` !!!
* `walking_duration (DoubleType)`: time it would take to walk from source to destination with a delay of 2 minutes and a speed of 50m/min
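
As a quick sanity check of the `walking_duration` definition above, here is a minimal plain-Python sketch (no Spark needed) that evaluates the formula for a few distances. The helper name `walking_duration_min` and the constant names are hypothetical and only used for illustration; the values come from the description above.

```python
# Constants taken from the schema description: 2 min delay, 50 m/min walking speed
LEAVE_DURATION_MIN = 2.0
VELOCITY_KM_PER_MIN = 0.05  # 50 m/min expressed in km/min


def walking_duration_min(distance_km: float) -> float:
    """Hypothetical helper: minutes needed to walk `distance_km`, including the 2-minute delay."""
    return LEAVE_DURATION_MIN + distance_km / VELOCITY_KM_PER_MIN


# Print the walking duration for a few example distances
for d in (0.0, 0.1, 0.5):
    print(d, "km ->", walking_duration_min(d), "min")
```

A 500 m walk therefore takes about 12 minutes under these assumptions.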

__Nomenclature__: 
Data in the `istdaten` dataset represent a transport arriving at a station (arrival) and leaving the station (departure). For the schedule network we instead use source and destination, which do not necessarily correspond to a departure and an arrival of a transport but to the source and destination of an edge. There are three types of edges:
* A *waiting* edge is when we wait in a station without leaving the means of transport
* A *trip* edge is the travel between two consecutive stations without changing the means of transport
* A *walking* edge is any situation where we leave the means of transport and walk to another one (either another platform of the same station or another station)
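
For readers consuming the network, the edge type can be recovered from `route_desc` alone, as described in the schema above. The following is a small plain-Python sketch of that mapping; `edge_kind` is a hypothetical helper name, not something defined in this notebook.

```python
def edge_kind(route_desc: str) -> str:
    """Hypothetical helper: map a route_desc value to one of the three edge types."""
    # 'walking' and 'waiting' are literal markers; anything else is a means of transport
    if route_desc in ("walking", "waiting"):
        return route_desc
    return "trip"


sample = ["Zug", "waiting", "Tram", "walking"]
kinds = [edge_kind(r) for r in sample]
print(kinds)
```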

## Stops near Zurich HB

We also compute the stops near Zurich HB (i.e. the ones that are part of a trip that is at least partially within a 15 km radius of Zurich HB). They are saved in `/user/anmaier/df_stops_near.orc`. It is a Spark DataFrame with the following schema:
* `stop_name (StringType)`: name of the stop
* `stop_lat (DoubleType)`: latitude of the stop
* `stop_lon (DoubleType)`: longitude of the stop
* `in_radius (BooleanType)`: whether the stop is within a 15 km radius from Zurich HB

__Nomenclature__:
A stop is *near* if it is part of a trip that has at least one stop within a 15 km radius from Zurich HB. A stop is additionally *in radius* if it is within a 15 km radius from Zurich HB.
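
To make the distinction between *near* and *in radius* concrete, here is a minimal plain-Python sketch on toy data. The distances and trips below are made up for illustration (they are not real measurements), and all variable names are hypothetical.

```python
RADIUS_KM = 15.0

# Toy data (illustrative values only): distance of each stop from Zurich HB in km
distance_km = {"Zürich HB": 0.0, "Zürich Oerlikon": 4.0, "Winterthur": 20.0,
               "Bern": 95.0, "Olten": 50.0}
# Toy trips: trip_id -> ordered stop names
trips = {"t1": ["Zürich HB", "Zürich Oerlikon", "Winterthur"],
         "t2": ["Bern", "Olten"]}

# Stops inside the radius
in_radius = {s for s, d in distance_km.items() if d <= RADIUS_KM}
# Trips with at least one stop inside the radius
near_trips = {t for t, stops in trips.items() if any(s in in_radius for s in stops)}
# A stop is near if it belongs to one of those trips
near = {s for t in near_trips for s in trips[t]}

print(sorted(in_radius))
print(sorted(near))
```

In this toy example Winterthur is *near* (it shares trip `t1` with stops inside the radius) but not *in radius*, while Bern and Olten are neither.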

# Starting up the Spark runtime

In [None]:
%load_ext sparkmagic.magics

In [None]:
%spark cleanup

In [None]:
%spark add -l python -s groupAD -u http://iccluster044.iccluster.epfl.ch:8998 -k

# Load preprocessed data

In [None]:
%%spark
df_schedule_network = spark.read.orc("/user/anmaier/schedule_network.orc").cache()
df_stops_near = spark.read.orc("/user/anmaier/df_stops_near.orc").cache()
df_timetable_near = spark.read.orc("/user/anmaier/df_timetable_near.orc").cache()
df_edges = spark.read.orc("/user/anmaier/df_edges.orc").cache()
df_waiting_edges = spark.read.orc("/user/anmaier/df_waiting_edges.orc").cache()
df_trip_edges = spark.read.orc("/user/anmaier/df_trip_edges.orc").cache()
df_walking_edges = spark.read.orc("/user/anmaier/df_walking_edges.orc").cache()

In [None]:
%%spark
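import pyspark.sql.functions as F  # needed here because the main import cell comes later

# Caution: this overwrites the stored schedule network regardless of `store`
(df_schedule_network
     .withColumn('hour', F.hour('dst_timestamp')) # use dst_timestamp for hour instead
     .write.partitionBy('hour').orc("/user/anmaier/schedule_network.orc", mode="overwrite"))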

In [None]:
%%spark
# !!!! Think before modifying it
# If True, will overwrite the dataframes on hdfs
store = False

In [None]:
%%spark
from pyspark.sql import DataFrame
import pyspark.sql.functions as F
import pyspark.sql.types as T
from pyspark.sql.window import Window

import graphframes as gf

from math import sin, cos, sqrt, atan2, radians

# Date of reference

In [None]:
%%spark
# Specify the date of reference
# Not all days are available in the timetables dataset
year = "2022"
month = "4" # 1 to 12, do not write a '0' prefix, i.e. write 4 not 04
day = "27" # Idem, write 4 not 04 for example
# 4th anniversary of the Panmunjom Declaration \o/

# Get string date in yyyy-MM-dd format, ensuring that month and day are 2-digit numbers
date = year + "-" + month.zfill(2) + "-" + day.zfill(2)
date_col = F.to_date(F.lit(date))

# Toolbox
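
The `distance` helper defined in the next cell relies on the haversine formula. As a sanity check outside Spark, the same formula can be sketched in plain Python; `haversine_km` is a hypothetical name, and the radius value is the one used in this notebook.

```python
from math import sin, cos, sqrt, atan2, radians

R_KM = 6378.0  # same Earth radius as used in the notebook's distance function


def haversine_km(lat1: float, lon1: float, lat2: float, lon2: float) -> float:
    """Great-circle distance in km between two points given in degrees."""
    lat1, lon1, lat2, lon2 = map(radians, (lat1, lon1, lat2, lon2))
    dlat = lat2 - lat1
    dlon = lon2 - lon1
    a = sin(dlat / 2) ** 2 + cos(lat1) * cos(lat2) * sin(dlon / 2) ** 2
    return R_KM * 2 * atan2(sqrt(a), sqrt(1 - a))


# One degree of latitude along a meridian should be about R * pi / 180 km
one_degree = haversine_km(47.0, 8.5, 48.0, 8.5)
print(one_degree)
```

With this radius, one degree of latitude comes out at roughly 111 km, which is a convenient plausibility check for the 15 km radius used later.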

In [None]:
%%spark
@F.udf(returnType=T.DoubleType())
def distance(lat1, lon1, lat2, lon2):
    """
    Calculate the great-circle distance between two points on the Earth's surface using the haversine formula (spherical Earth).
    
    Args:
        lat1 (float): Latitude of the first point in degrees.
        lon1 (float): Longitude of the first point in degrees.
        lat2 (float): Latitude of the second point in degrees.
        lon2 (float): Longitude of the second point in degrees.
        
    Returns:
        float: The distance between the two points in kilometers.
    """
    # Approximate (equatorial) radius of the Earth in km; the mean radius is about 6371 km
    R = 6378.0

    lat1 = radians(lat1)
    lon1 = radians(lon1)
    lat2 = radians(lat2)
    lon2 = radians(lon2)
    
    dlat = lat2 - lat1
    dlon = lon2 - lon1
    
    # Haversine formula
    a = sin(dlat / 2)**2 + cos(lat1) * cos(lat2) * sin(dlon / 2)**2
    
    c = 2 * atan2(sqrt(a), sqrt(1 - a))
    
    return R * c

# !! It is NOT a UDF, it's on purpose: it builds a Column expression
def duration(src_timestamp, dst_timestamp):
    """Compute the time diff between destination and source in whole minutes (as an integer Column)"""
    return ((F.unix_timestamp(dst_timestamp) - F.unix_timestamp(src_timestamp))/60).cast(T.IntegerType())

# Data Preprocessing

## Load datasets

In [None]:
%%spark
df_istdaten = spark.read.orc("/data/sbb/part_orc/istdaten/year=" + year + "/month=" + month)
df_stops = spark.read.orc("/data/sbb/part_orc/timetables/stops/year=" + year + "/month=" + month + "/day=" + day)

In [None]:
%%spark
def print_df(df: DataFrame, df_name: str = "", n: int = 3, truncate: bool = True) -> None:
    """Prints the schema and the first n rows of a DataFrame."""
    if df_name:
        print(df_name + ":")
    print([a[1] for a in df.dtypes])
    df.show(n=n, truncate=truncate)
    return None

# Print datasets
print_df(df_istdaten, "istdaten")
print_df(df_stops, "stops")

## Remove non-standard rows, rename and select columns

In [None]:
%%spark
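# window function to add a stop_sequence column that is a stop counter for each trip_id
# i.e. the first stop has stop_sequence=1, the 2nd stop stop_sequence=2, etc 
# up to the last stop of the trip (F.rank() starts at 1; stops of a trip that share
# the same arrival_time get the same stop_sequence)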
w_stop_sequence = Window.partitionBy('trip_id').orderBy('arrival_time')

_df_istdaten = (df_istdaten
               # Remove non-standard rows
               .filter(F.col('ZUSATZFAHRT_TF') == F.lit("false"))
               .filter(F.col('DURCHFAHRT_TF') == F.lit("false"))
               .filter((F.col('ab_prognose_status') != '') & (F.col('an_prognose_status') != ''))
               # Rename columns
               .withColumnRenamed('BETRIEBSTAG', 'date')
               .withColumnRenamed('HALTESTELLEN_NAME', 'stop_name')
               .withColumnRenamed('FAHRT_BEZEICHNER', 'trip_id')
               .withColumnRenamed('ANKUNFTSZEIT', 'arrival_time')
               .withColumnRenamed('ABFAHRTSZEIT', 'departure_time')
               .withColumnRenamed('PRODUKT_ID', 'route_desc')
               # Format date and timestamps
               .withColumn('date', F.to_date('date', 'dd.MM.yyyy'))
               .withColumn('arrival_time', F.to_timestamp('arrival_time', 'dd.MM.yyyy HH:mm'))
               .withColumn('departure_time', F.to_timestamp('departure_time', 'dd.MM.yyyy HH:mm'))
               # Filter the date
               .filter(F.col('date') == F.to_date(F.lit(date)))
               # Remove rows where both arrival and departure time are null or nan
               .filter(~((F.col('arrival_time').isNull()) & (F.col('departure_time').isNull())))
               # Add a stop_sequence column
               .withColumn('stop_sequence', F.rank().over(w_stop_sequence))
               # Select only the columns we need
               .select('date', 'stop_name', 'trip_id', 'arrival_time', 'departure_time', 'route_desc', 'stop_sequence'))

## Keep only stops that are near Zurich HB

In [None]:
%%spark
# Add a column to indicate if a stop is within the `distance_max` radius centered on Zurich HB
distance_max = 15.0 # maximum distance from Zurich HB in km
zurich_HB_lat = 47.378177
zurich_HB_lon = 8.540211
_df_stops = (df_stops
            .withColumn('in_radius', 
                        distance(F.lit(zurich_HB_lat), F.lit(zurich_HB_lon), 
                                 F.col('stop_lat'), F.col('stop_lon')) 
                        <= distance_max))

# Get the id of all trips that have at least one stop within a `distance_max` km radius centered on Zurich HB
df_trip_ids_near = (_df_istdaten
                    .join(_df_stops.filter(F.col('in_radius') == True), 
                          on='stop_name', how='leftsemi')
                    .select('trip_id')
                    .distinct())

# Filter istdaten keeping only trips that enter the radius
df_timetable_near = (_df_istdaten
                    .join(df_trip_ids_near, 'trip_id', 'leftsemi')
                    .cache())
# and get their corresponding stops
df_stops_near = (_df_stops
                 .join(df_timetable_near, 'stop_name', 'leftsemi')
                 .dropDuplicates(['stop_name'])
                 .select('stop_name', 'stop_lat', 'stop_lon', 'in_radius')
                 .cache())
# Finally we filter out the stops in the timetable for which we don't have the latitude and longitude
df_timetable_near = df_timetable_near.join(df_stops_near, 'stop_name', 'left_semi')

## Cache and store

In [None]:
%%spark
# show() prints the rows itself and returns None, so call it outside print()
df_timetable_near.show(n=3)
print(df_timetable_near.count())
df_stops_near.show(n=3)
print(df_stops_near.count())
if store:
    df_timetable_near.write.orc("/user/anmaier/df_timetable_near.orc", mode="overwrite")
    df_stops_near.write.orc("/user/anmaier/df_stops_near.orc", mode="overwrite")

# Schedule network

In [None]:
%%spark
edge_schema = T.StructType([
    T.StructField('src_timestamp', T.TimestampType(), False),
    T.StructField('src_stop_name', T.StringType(), False),
    T.StructField('dst_timestamp', T.TimestampType(), False),
    T.StructField('dst_stop_name', T.StringType(), False),
    T.StructField('route_desc', T.StringType(), False),
    T.StructField('trip_id', T.StringType(), False),
    T.StructField('dst_trip_id', T.StringType(), False), # extra col for dst_trip_id
    T.StructField('distance', T.DoubleType(), False), # in km
    T.StructField('duration', T.IntegerType(), False), # in min
    T.StructField('probability', T.DoubleType(), False)
])
# Create the dataframe that will contain all the edges
df_edges = spark.createDataFrame([], edge_schema)

## Waiting edges

In [None]:
%%spark
# Waiting edges, i.e. waiting in a transport at a stop without leaving the transport
df_waiting_edges = (df_timetable_near
                    # Remove start of journeys (null arrival time) and end of journeys (null departure time)
                    .filter(~F.col('arrival_time').isNull())
                    .filter(~F.col('departure_time').isNull())
                    # Rename times to match schema
                    .withColumnRenamed('arrival_time', 'src_timestamp')
                    .withColumnRenamed('departure_time', 'dst_timestamp')
                    # Add source and destination stop name
                    .withColumn('src_stop_name', F.col('stop_name'))
                    .withColumn('dst_stop_name', F.col('stop_name'))
                    # Add route_desc
                    .withColumn('route_desc', F.lit('waiting'))
                    # Add 0km distance
                    .withColumn('distance', F.lit(0.0))
                    # Add duration
                    .withColumn('duration', duration('src_timestamp', 'dst_timestamp'))
                    # Add proba
                    .withColumn('probability', F.lit(1.))
                    # Filter and order columns to match the schema
                    .select('src_timestamp', 
                            'src_stop_name',
                            'dst_timestamp', 
                            'dst_stop_name',
                            'route_desc', 
                            'trip_id', 
                            'distance',
                            'duration',
                            'probability')
                    .cache())

### Cache and store

In [None]:
%%spark
df_waiting_edges.show(n=3)
print(df_waiting_edges.count())
if store:
    df_waiting_edges.write.orc("/user/anmaier/df_waiting_edges.orc", mode="overwrite")

## Trip edges
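
The cell below pairs each departure with the next arrival of the same trip by incrementing `stop_sequence` on the source side and joining on `(trip_id, stop_sequence)`. Here is a minimal plain-Python sketch of that pairing on a toy trip; the rows and variable names are hypothetical, and times are kept as strings for brevity.

```python
# Toy timetable rows: (trip_id, stop_sequence, stop_name, arrival_time, departure_time)
rows = [
    ("t1", 1, "A", None, "08:00"),     # start of journey: no arrival
    ("t1", 2, "B", "08:10", "08:12"),
    ("t1", 3, "C", "08:20", None),     # end of journey: no departure
]

# Sources: rows with a departure, keyed by the stop_sequence of the NEXT stop
sources = {(t, seq + 1): (name, dep) for t, seq, name, arr, dep in rows if dep is not None}
# Destinations: rows with an arrival, keyed by their own stop_sequence
destinations = {(t, seq): (name, arr) for t, seq, name, arr, dep in rows if arr is not None}

# Inner join on (trip_id, stop_sequence)
trip_edges = [
    (sources[k][0], sources[k][1], destinations[k][0], destinations[k][1])
    for k in sorted(sources.keys() & destinations.keys())
]
print(trip_edges)
```

Each resulting tuple reads as (source stop, departure time, destination stop, arrival time).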

In [None]:
%%spark
df_timetable_near = spark.read.orc("/user/anmaier/df_timetable_near.orc")
df_stops_near = spark.read.orc("/user/anmaier/df_stops_near.orc")
# A trip edge connects a departure from a location to an arrival at another location
# Two vertices are connected if they belong to the same trip 
# and if the destination stop has stop_sequence incremented by 1 compared to the source stop.

# Format sources and destinations with infos needed for constructing trip edges
df_sources = (df_timetable_near
                 # Exclude end of journeys (null departure time)
                 .filter(~F.col('departure_time').isNull())
                 # Take latitude and longitude of the stop
                 .join(df_stops_near, on='stop_name')
                 .withColumnRenamed('stop_lat', 'src_lat')
                 .withColumnRenamed('stop_lon', 'src_lon')
                 # To distinguish between sources and destinations columns
                 .withColumnRenamed('departure_time', 'src_timestamp')
                 .withColumnRenamed('stop_name', 'src_stop_name')
                 .distinct()
                 .select('*'))
df_destinations = (df_timetable_near
                 # Exclude start of journeys (null arrival time)
                 .filter(~F.col('arrival_time').isNull())
                 # Take latitude and longitude of the stop
                 .join(df_stops_near, on='stop_name')
                 .withColumnRenamed('stop_lat', 'dst_lat')
                 .withColumnRenamed('stop_lon', 'dst_lon')
                 # To distinguish between sources and destinations columns
                 .withColumnRenamed('arrival_time', 'dst_timestamp')
                 .withColumnRenamed('stop_name', 'dst_stop_name')
                 .distinct()
                 .select('dst_timestamp', 'dst_stop_name', 'stop_sequence', 
                         'trip_id', 'dst_lat', 'dst_lon'))

# Then we construct the trip edges
df_trip_edges = (df_sources
                 # Increment stop_sequence of departures to match the next arrival stop_sequence
                 .withColumn('stop_sequence', F.col('stop_sequence') + 1)
                 # Join departures to arrival that have the same trip_id and stop_sequence 
                 .join(df_destinations, on=['trip_id', 'stop_sequence'])
                 # Add distance, duration and proba
                 .withColumn('distance', distance('src_lat', 'src_lon', 'dst_lat', 'dst_lon'))
                 .withColumn('duration', duration('src_timestamp', 'dst_timestamp'))
                 .withColumn('probability', F.lit(1.))
                 # Filter and order columns to match the schema
                 .select('src_timestamp', 
                         'src_stop_name', 
                         'dst_timestamp', 
                         'dst_stop_name', 
                         'route_desc', 
                         'trip_id', 
                         'distance', 
                         'duration',
                         'probability')
                 .distinct()
                 .cache())

### Cache and store

In [None]:
%%spark
df_trip_edges.show(n=3)
print(df_trip_edges.count())
if store:
    df_trip_edges.write.orc("/user/anmaier/df_trip_edges.orc", mode="overwrite")

## Walking edges
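
The reachability conditions on distance and duration used in the cells below can be summarised as a single predicate. This is a minimal plain-Python sketch; `is_walkable` and the upper-case constant names are hypothetical, and the values mirror the constants defined in the next cell.

```python
MAX_DISTANCE_KM = 0.5       # at most 500 m between arrival and departure stops
LEAVE_DURATION_MIN = 2      # time to leave the transport
VELOCITY_KM_PER_MIN = 0.05  # walking speed, i.e. 50 m/min
MAX_DURATION_MIN = 30       # longest transfer we allow


def is_walkable(distance_km: float, duration_min: float) -> bool:
    """Return True if a departure `duration_min` after the arrival can be reached on foot."""
    if distance_km > MAX_DISTANCE_KM:
        return False
    needed = LEAVE_DURATION_MIN + distance_km / VELOCITY_KM_PER_MIN
    return needed <= duration_min <= MAX_DURATION_MIN


print(is_walkable(0.2, 7))   # 200 m needs about 6 min, 7 min are available
print(is_walkable(0.2, 5))   # not enough time
print(is_walkable(0.6, 20))  # too far
```

The limit on the number of destinations per stop is applied separately, after all candidate edges have been collected.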

In [None]:
%%spark
# A walking edge connects an arrival to all departures that are reachable by walking.
# A departure is reachable if these conditions hold:
# 1. The distance between the arrival and departure is at most 500m
# 2. The departure time is at least 2 minutes (leaving the transport) + 1min/50m (walking) away from the arrival time
# 3. The departure time is at most `max_duration` minutes away from the arrival time
# 4. For each stop name an arrival is connected to, we only keep the `max_n_destinations_per_stop` ones that are closest in time
max_distance = 0.5 # in km
leave_duration = 2 # in min
velocity = 0.05 # in km/min
max_duration = 30 # Max duration we allow to walk in minutes
max_n_destinations_per_stop = 3 # Maximum number of destinations per stop we connect an arrival to
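# Partition per arrival (source timestamp, stop and trip) and per destination stop name,
# so that the limit of condition 4 applies to each arrival rather than to the whole day
w_max_n_destinations_per_stop = (Window
                                 .partitionBy('src_timestamp', 'src_stop_name', 'trip_id', 'dst_stop_name')
                                 .orderBy(duration('src_timestamp', 'dst_timestamp')))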

# For computational reasons, we will split the day in bins of `bin_duration` minutes 
# and do the cross join between 1 sources bin and all the necessary destinations bins 
# so that the last timestamp in the sources bin can still reach departures up to `max_duration` minutes later.
# For example if max_duration = 30 minutes and bin_duration = 10 minutes
# then we will cross join one sources bin with 3+1 destinations bins
bin_duration = 10 # Duration of the bins in minutes
n_bins = int(24*60 / bin_duration) # Number of bins per day

### Sources

In [None]:
%%spark
# Format sources and destinations with infos needed for constructing walking edges
start_day = F.to_timestamp(F.lit(date)) # Beginning of the day
df_sources = (df_timetable_near
              # Exclude start of journeys (null arrival time)
              .filter(~F.col('arrival_time').isNull())
              # Take latitude and longitude of the stop
              .join(df_stops_near, on='stop_name')
              .withColumnRenamed('stop_lat', 'src_lat')
              .withColumnRenamed('stop_lon', 'src_lon')
              # To distinguish between sources and destinations columns
              .withColumnRenamed('arrival_time', 'src_timestamp')
              .withColumnRenamed('stop_name', 'src_stop_name')
              # We indicate to which nth bin they belong
              .withColumn('src_bin', F.floor(duration(date_col, 'src_timestamp') / bin_duration))
              .select('src_timestamp', 'src_stop_name', 'src_lat', 'src_lon', 'src_bin', 'trip_id') # keep the trip_id of the source trip as trip_id
              .distinct()
              .cache())
if store:
    df_sources.write.orc("/user/anmaier/df_sources.orc", mode="overwrite")

### Destinations

In [None]:
%%spark
df_destinations = (df_timetable_near
                  # Exclude end of journeys (null departure time)
                  .filter(~F.col('departure_time').isNull())
                  # Take latitude and longitude of the stop
                  .join(df_stops_near, on='stop_name')
                  .withColumnRenamed('stop_lat', 'dst_lat')
                  .withColumnRenamed('stop_lon', 'dst_lon')
                  # To distinguish between sources and destinations columns
                  .withColumnRenamed('departure_time', 'dst_timestamp')
                  .withColumnRenamed('stop_name', 'dst_stop_name')
                  .withColumnRenamed('trip_id', 'dst_trip_id')
                  # We indicate to which nth bin they belong
                  .withColumn('dst_bin', F.floor(duration(date_col, 'dst_timestamp') / bin_duration))
                  .select('dst_timestamp', 'dst_stop_name', 'dst_lat', 'dst_lon', 'dst_bin', 'dst_trip_id') # keep the trip_id of the dest trip as dst_trip_id
                  .distinct()
                  .cache())
if store:
    df_destinations.write.orc("/user/anmaier/df_destinations.orc", mode="overwrite")

### Construct the walking edges
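
The binning arithmetic used in the loop below can be checked in plain Python. This is a small sketch under the parameter values set earlier (`bin_duration = 10`, `max_duration = 30`); `bin_index`, `destination_bins` and `lookahead` are hypothetical names introduced for illustration.

```python
bin_duration = 10   # minutes per bin
max_duration = 30   # maximum transfer duration in minutes

n_bins = int(24 * 60 / bin_duration)          # number of bins per day
lookahead = int(max_duration / bin_duration)  # extra destination bins to scan


def bin_index(hour: int, minute: int) -> int:
    """Bin of a timestamp, counted in `bin_duration`-minute steps from midnight."""
    return (hour * 60 + minute) // bin_duration


def destination_bins(i: int) -> list:
    """Destination bins cross-joined with source bin `i` (inclusive on both ends)."""
    return list(range(i, i + lookahead + 1))


source_bins = range(0, n_bins - lookahead)  # source bins processed by the loop
print(n_bins, bin_index(8, 7), destination_bins(bin_index(8, 7)), source_bins[-1])
```

Note that, with these values, sources arriving in the last `lookahead` bins of the day are not processed by the loop.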

In [None]:
%%spark
# Store each bin individually
walking_edges_bins = {}
for i in range(0, n_bins - int(max_duration/bin_duration)):
    walking_edges_bins[i] = (df_sources
                            .filter(F.col('src_bin') == i)
                            .crossJoin(df_destinations
                                       .filter((F.col('dst_bin') >= i) & (F.col('dst_bin') <= i + int(max_duration/bin_duration))))
                            # Filter the combinations that are reachable by walking
                            .withColumn('distance', distance('src_lat', 'src_lon', 'dst_lat', 'dst_lon'))
                            .filter(F.col('distance') <= max_distance)
                            .withColumn('duration', duration('src_timestamp', 'dst_timestamp'))
                            .filter(F.col('duration') >= leave_duration + F.col('distance') / velocity)
                            .filter(F.col('duration') <= max_duration)
                            # Add route_desc, trip_id
                            .withColumn('route_desc', F.lit("walking"))
                            #.withColumn('trip_id', F.lit("")) # no longer want this to be empty
                            .withColumn('probability', F.lit(1.))
                            # Filter and order columns to match the schema
                            .select('src_timestamp', 
                                    'src_stop_name', 
                                    'dst_timestamp', 
                                    'dst_stop_name', 
                                    'route_desc', 
                                    'trip_id',
                                    'dst_trip_id', # keep dst_trip_id as well
                                    'distance', 
                                    'duration',
                                    'probability'))
    if store:
        walking_edges_bins[i].write.orc("/user/anmaier/walking_edges/bin=" + str(i), mode="overwrite")

In [None]:
%%spark
# Union into slices of 10 bins
walking_edges_slices = {}
for i in range(0, n_bins - int(max_duration/bin_duration), 10):
    walking_edges_slices[i] = spark.createDataFrame([], edge_schema)
    for j in range(i, min(i + 10, n_bins - int(max_duration/bin_duration))):
        walking_edges_slices[i] = walking_edges_slices[i].union(spark.read.orc("/user/anmaier/walking_edges/bin=" + str(j)))
    if store:
        walking_edges_slices[i].write.orc("/user/anmaier/walking_edges/slice=" + str(i), mode="overwrite")

In [None]:
%%spark
# Union of all the slices
df_walking_edges = spark.createDataFrame([], edge_schema)
for i in range(0, n_bins - int(max_duration/bin_duration), 10):
    df_walking_edges = df_walking_edges.union(spark.read.orc("/user/anmaier/walking_edges/slice=" + str(i)))
# Remove walking edges when too many are connected to the same stop_name
df_walking_edges = (df_walking_edges
                    .withColumn('rank', F.rank().over(w_max_n_destinations_per_stop))
                    .filter(F.col('rank') <= max_n_destinations_per_stop)
                    .drop('rank'))
if store:
    df_walking_edges.write.orc("/user/anmaier/df_walking_edges.orc", mode="overwrite")

## Union all edges

In [None]:
%%spark
df_waiting_edges = spark.read.orc("/user/anmaier/df_waiting_edges.orc")
df_trip_edges = spark.read.orc("/user/anmaier/df_trip_edges.orc")

# create columns for dst_trip_id in the waiting and trip edges (duplicate of trip_id so it's easier to join later)
df_waiting_edges = df_waiting_edges.withColumn('dst_trip_id', F.col("trip_id"))
df_trip_edges = df_trip_edges.withColumn('dst_trip_id', F.col("trip_id"))

df_walking_edges = spark.read.orc("/user/anmaier/df_walking_edges.orc")
# Union by column name (requires Spark >= 2.3): dst_trip_id is the last column of the waiting
# and trip edges but follows trip_id in the walking edges, so a positional union would misalign columns
df_edges = (df_waiting_edges
            .unionByName(df_trip_edges)
            .unionByName(df_walking_edges)
            .distinct())
# Add a column with the time it would take to walk from src to dst
df_edges = df_edges.withColumn('walking_duration', 2. + F.col('distance') / 0.05)
if store:
    df_edges.write.orc("/user/anmaier/df_edges.orc", mode="overwrite")

In [None]:
%%spark
# Partition by hour
if store:
    (df_edges
     .withColumn('hour', F.hour('dst_timestamp')) # use dst_timestamp for hour instead
     .write.partitionBy('hour').orc("/user/anmaier/schedule_network.orc", mode="overwrite"))