# Basis graph construction

In this notebook, we build the graph that will be the basis of our journey planner algorithm.

## Setup PySpark
We start by loading the Spark session. Because some of our queries are heavy and this notebook only needs to be run once, we allow ourselves a few more resources than usual.

In [1]:
%load_ext sparkmagic.magics

In [2]:
import os
import warnings
import pandas as pd

warnings.simplefilter(action='ignore', category=UserWarning)

username = os.environ['RENKU_USERNAME']
print(username)

verardo


In [3]:
server = "http://iccluster029.iccluster.epfl.ch:8998"
from IPython import get_ipython
get_ipython().run_cell_magic('spark', line="config",
                             cell="""{{ "name":"{0}-aces", "executorMemory":"10G", "executorCores":8, "numExecutors":10 }}""".format(username))

In [4]:
get_ipython().run_line_magic(
    "spark", "add -s {0}-aces -l python -u {1} -k".format(username, server) 
)

Starting Spark application


ID,YARN Application ID,Kind,State,Spark UI,Driver log,User,Current session?
9439,application_1652960972356_5268,pyspark,idle,Link,Link,,✔



SparkSession available as 'spark'.


In [5]:
%%spark
from pyspark.sql.functions import *


## Preprocessing
### Filter on services that are available every weekday

We choose the week of May 8th to build our model of the public transport infrastructure, as this week does not contain any Swiss bank holiday. Since our planner only provides information for weekdays (treating every weekday as identical), we start by filtering out the services that do not run on every weekday.

In [6]:
%%spark
calendar = spark.read.csv('/data/sbb/csv/calendar/2019/05/08/calendar.txt', sep=',', header=True, inferSchema=True)

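from functools import reduce  # builtin in Python 2, needs an import in Python 3

# Rename all the columns to lower case
old_columns = calendar.schema.names
new_columns = [c.lower() for c in old_columns]

# Thread the accumulator `data` through the fold so that every rename is kept
calendar = reduce(lambda data, idx: data.withColumnRenamed(old_columns[idx], new_columns[idx]), range(len(old_columns)), calendar)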


In [7]:
%%spark

# Keep only the services that run on every weekday (Monday to Friday)
calendar_week = calendar.filter((calendar.monday == 1)
                & (calendar.tuesday  == 1)
                & (calendar.wednesday == 1)
                & (calendar.thursday == 1)
                & (calendar.friday == 1)).select('service_id')


Using this result, we discard the trips that do not run on every weekday.

In [8]:
%%spark
trips = spark.read.csv('/data/sbb/csv/trips/2019/05/08/trips.txt', sep=',', header=True, inferSchema=True)

# Rename all the columns to lower case
old_columns = trips.schema.names
new_columns = [c.lower() for c in old_columns]

# Thread the accumulator `data` through the fold so that every rename is kept
trips = reduce(lambda data, idx: data.withColumnRenamed(old_columns[idx], new_columns[idx]), range(len(old_columns)), trips)


In [9]:
%%spark
# Join with the trips dataframe to get the list of trips that run on every weekday
trip_id = calendar_week.join(trips, on='service_id', how='inner').select('trip_id', 'route_id').distinct()

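The two steps above (keep the services that run Monday to Friday, then keep their trips) can be sketched without Spark on a few toy rows. This is only an illustration: the rows and the helper names `weekday_services` and `weekday_trips` are made up and are not part of the notebook.

```python
# Toy rows standing in for calendar.txt and trips.txt (all values are made up).
WEEKDAYS = ("monday", "tuesday", "wednesday", "thursday", "friday")

calendar_rows = [
    {"service_id": "S1", "monday": 1, "tuesday": 1, "wednesday": 1, "thursday": 1, "friday": 1},
    {"service_id": "S2", "monday": 1, "tuesday": 0, "wednesday": 1, "thursday": 1, "friday": 1},
]
trip_rows = [
    {"trip_id": "T1", "route_id": "R1", "service_id": "S1"},
    {"trip_id": "T2", "route_id": "R1", "service_id": "S2"},
    {"trip_id": "T3", "route_id": "R2", "service_id": "S1"},
]


def weekday_services(calendar):
    """Return the ids of the services that run on all five weekdays."""
    return {row["service_id"] for row in calendar if all(row[day] == 1 for day in WEEKDAYS)}


def weekday_trips(trips, services):
    """Inner join on service_id, keeping the distinct (trip_id, route_id) pairs."""
    return sorted({(t["trip_id"], t["route_id"]) for t in trips if t["service_id"] in services})


services = weekday_services(calendar_rows)
print(weekday_trips(trip_rows, services))
```

Service `S2` does not run on Tuesday, so trip `T2` is discarded while `T1` and `T3` are kept.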

### Add the stop information
We load the stop times from HDFS and merge them with the trip information loaded above.

In [10]:
%%spark
stop_times = spark.read.csv('/data/sbb/csv/stop_times/2019/05/08/stop_times.txt', sep=',', header=True, inferSchema=True)

# Rename all the columns to lower case
old_columns = stop_times.schema.names
new_columns = [c.lower() for c in old_columns]

# Thread the accumulator `data` through the fold so that every rename is kept
stop_times = reduce(lambda data, idx: data.withColumnRenamed(old_columns[idx], new_columns[idx]), range(len(old_columns)), stop_times)


In [11]:
%%spark
# Get all the nodes 
nodes = trip_id.join(stop_times, on='trip_id', how='inner')


Finally, we load the detailed stop information and merge everything into a single dataframe.

In [12]:
%%spark
from pyspark.sql.types import StructType, StructField, IntegerType, StringType, FloatType
# Note that we need to specify the schema only for this file, as PySpark cannot infer it.
schema = StructType([StructField("stop_id", StringType(), True),
                     StructField("stop_name", StringType(), True),
                     StructField("stop_lat", FloatType(), True),
                     StructField("stop_lon", FloatType(), True),
                     StructField("location_type", IntegerType(), True),
                     StructField("parent_location", StringType(), True)])
stops = spark.read.csv('/data/sbb/csv/allstops/stop_locations.csv', sep=',', header=False, schema=schema)


In [13]:
%%spark
node_info = stops.join(nodes, on="stop_id",how="inner").cache() # Cache the result so that we do not need to compute this series of joins again


### Decoupling of arrival and departure nodes

We split every stop time into two nodes:
- We add a new column `is_arrival` whose purpose is simply to tell us whether the node is the arrival or the departure of the current 'edge'.
- We start by keeping only the arrival time of the node and fill the `is_arrival` column with ones.
- Then, we do the same with the departure time and fill the `is_arrival` column with zeros.
- Finally, we take the union of the two results.

In [14]:
%%spark
from pyspark.sql.functions import *
# Duplicate the nodes so that we have one arrival node and one departure node for each of the stop times
all_nodes_arr = node_info.drop(node_info.departure_time).withColumnRenamed("arrival_time","time")
all_nodes_arr = all_nodes_arr.withColumn("is_arrival",lit(1))
all_nodes_dep = node_info.drop(node_info.arrival_time).withColumnRenamed("departure_time","time")
all_nodes_dep = all_nodes_dep.withColumn("is_arrival", lit(0))
all_nodes = all_nodes_arr.union(all_nodes_dep)

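The duplication can be illustrated on a single row in plain Python. The helper name `split_node` and the sample row are hypothetical; the sketch only mirrors the drop, rename and union logic of the cell above.

```python
def split_node(stop_time):
    """Turn one stop_times row into an arrival node and a departure node."""
    # Everything except the two time columns is shared by both nodes.
    common = {k: v for k, v in stop_time.items() if k not in ("arrival_time", "departure_time")}
    arrival = dict(common, time=stop_time["arrival_time"], is_arrival=1)
    departure = dict(common, time=stop_time["departure_time"], is_arrival=0)
    return [arrival, departure]


# Made-up row: the vehicle arrives at 10:45 and leaves one minute later.
row = {"stop_id": "A", "trip_id": "T1", "arrival_time": "10:45:00", "departure_time": "10:46:00"}
for node in split_node(row):
    print(node["time"], node["is_arrival"])
```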

### Keep only nodes around Zürich

We define the [haversine distance](https://en.wikipedia.org/wiki/Haversine_formula) function, which computes the distance between two coordinates (in meters).

In [15]:
%%spark
import pyspark.sql.functions as F
from math import radians, cos, sin, asin, sqrt, atan2

@F.udf
def haversine(lat1, lon1, lat2=47.378177, lon2=8.540192):
    """
    Compute the haversine distance between two coordinates
    :param lat1: the latitude of the first point
    :param lon1: the longitude of the first point
    :param lat2: the latitude of the second point, by default the one of Zurich HB
    :param lon2: the longitude of the second point, by default the one of Zurich HB
    :return: the distance between the two points, in meters
    """
    # https://www.movable-type.co.uk/scripts/latlong.html
    # Zurich HB coordinates by default
    
    R = 6371e3
    phi1, phi2, delta_phi, delta_lambda = map(radians, [lat1, lat2, lat2 - lat1, lon2 - lon1])
    a = sin(0.5 * delta_phi)**2 + cos(phi1) * cos(phi2) * sin(0.5 * delta_lambda)**2
    c = 2 * atan2(sqrt(a), sqrt(1 - a))
    
    return c * R

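As a quick sanity check, the same formula can be evaluated outside Spark. The sketch below is a plain-Python copy of the UDF body; the name `haversine_m` and the constants `ZURICH_HB` and `EARTH_RADIUS_M` are introduced here for illustration only.

```python
from math import radians, sin, cos, sqrt, atan2

ZURICH_HB = (47.378177, 8.540192)  # same default coordinates as in the UDF above
EARTH_RADIUS_M = 6371e3            # mean Earth radius in meters


def haversine_m(lat1, lon1, lat2=ZURICH_HB[0], lon2=ZURICH_HB[1]):
    """Haversine distance in meters between (lat1, lon1) and (lat2, lon2)."""
    phi1, phi2 = radians(lat1), radians(lat2)
    d_phi, d_lambda = radians(lat2 - lat1), radians(lon2 - lon1)
    a = sin(d_phi / 2) ** 2 + cos(phi1) * cos(phi2) * sin(d_lambda / 2) ** 2
    return 2 * EARTH_RADIUS_M * atan2(sqrt(a), sqrt(1 - a))


# A point one degree of latitude south of Zurich HB is roughly 111 km away.
print(round(haversine_m(46.378177, 8.540192) / 1000))
```

One degree of latitude corresponds to about 111 km, which gives a simple way to check that the result is expressed in meters and not in kilometers.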

Using this function, we can now filter out every node that is more than 15 km away from Zurich HB. To be safe, we add a margin of 2.5 km. This margin is a trade-off between being able to reach more nodes and not having too many walking edges. In the same step, we keep only the nodes whose time falls between 06:00 and 20:59.

In [16]:
%%spark
# Filter out data that are too far
MARGIN = 2.5e3
RADIUS = 15e3

# Filter to keep nodes only during business hours
nodes_zurich = (all_nodes.filter(haversine(all_nodes.stop_lat, all_nodes.stop_lon) < RADIUS + MARGIN)
                .filter((hour(to_timestamp(col('time'), format="HH:mm:ss")) >= 6)
                                                 & (hour(to_timestamp(col('time'), format="HH:mm:ss")) <= 20))
                .cache())


In [17]:
%%spark
nodes_zurich.show(5), nodes_zurich.count()


+-----------+----------+---------+--------+-------------+---------------+--------------------+-----------+--------+-------------+-----------+-------------+----------+
|    stop_id| stop_name| stop_lat|stop_lon|location_type|parent_location|             trip_id|   route_id|    time|stop_sequence|pickup_type|drop_off_type|is_arrival|
+-----------+----------+---------+--------+-------------+---------------+--------------------+-----------+--------+-------------+-----------+-------------+----------+
|    8503064|  Scheuren|47.322598|8.659553|         null|           null|1.TA.26-18-j19-1.1.H|26-18-j19-1|10:41:00|            1|          0|            0|         1|
|8503065:0:1|     Forch|47.325336|8.647973|         null|  Parent8503065|1.TA.26-18-j19-1.1.H|26-18-j19-1|10:45:00|            2|          0|            0|         1|
|    8503074|Neue Forch|47.325813| 8.63784|         null|           null|1.TA.26-18-j19-1.1.H|26-18-j19-1|10:46:00|            3|          0|            0|         1

Some of the stop ids contain `:0:x`, as described in the data cookbook, where `x` is the platform number. These are stops that belong to a larger station: Zurich main station, for instance, has multiple platforms, whose ids all start with the id of the main station followed by `:0:x`. According to the cookbook, the `0` could in fact be replaced by another number, so let us check whether this is the case.

In [18]:
%%spark
nodes_zurich.filter(nodes_zurich.stop_id.contains(":0:")).count(), nodes_zurich.filter(nodes_zurich.stop_id.contains(":")).count()


(50151, 50151)

Both counts are equal, so there is no node whose id contains `:y:x` with $y\neq 0$. The `:0` part therefore carries no information for our purposes and we can safely remove it from the stop ids that contain it.

We also give a unique identifier to each node so that we can identify it in the graph. This unique identifier has the following shape:
$$
STOPID||\_||TIME||\_||TRIPID||\_||ISARRIVAL
$$
where $||$ denotes the concatenation operation.
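For illustration, the identifier can be built in plain Python as follows. The helper name `full_stop_id` is hypothetical; the input values are taken from the sample rows displayed above.

```python
import re


def full_stop_id(stop_id, time, trip_id, is_arrival):
    """Concatenate the cleaned stop id, the time, the trip id and the arrival flag."""
    parts = [
        re.sub(":0", "", stop_id),  # drop the uninformative ":0" part of the stop id
        time.replace(":", "-"),     # 10:45:00 becomes 10-45-00
        trip_id,
        str(is_arrival),
    ]
    return "_".join(parts)


print(full_stop_id("8503065:0:1", "10:45:00", "1.TA.26-18-j19-1.1.H", 1))
# 8503065:1_10-45-00_1.TA.26-18-j19-1.1.H_1
```

Replacing the colons of the time by dashes keeps `:` as a separator that only appears inside the stop id.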

In [19]:
%%spark
# Use concat_ws to concatenate all the given columns with '_' as separator
nodes_zurich = nodes_zurich.withColumn("full_stop_id",
                                       F.concat_ws("_", F.regexp_replace("stop_id",":0",""),
                                                F.regexp_replace("time",":","-"),
                                                nodes_zurich.trip_id,
                                                nodes_zurich.is_arrival)).cache()
nodes_zurich.count()


697279

The number of nodes is odd, which may seem strange since we duplicate every stop time into an arrival node and a departure node. However, if a bus arrives at a station at 20:59 and continues its trip at 21:01, the hour filter keeps the arrival node and drops the departure node, hence the odd number.

Since we now have our final list of nodes, we can save it on HDFS.
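The effect of the hour filter on such a pair of nodes can be reproduced in a few lines of plain Python. The helper `in_time_window` is a hypothetical stand-in for the Spark filter on `hour(...)`, and the two nodes are made up.

```python
def in_time_window(time, first_hour=6, last_hour=20):
    """Mirror the hour filter: keep times whose hour lies in [first_hour, last_hour]."""
    hour = int(time.split(":")[0])
    return first_hour <= hour <= last_hour


# (time, is_arrival) for one stop: arrival at 20:59, departure at 21:01.
nodes = [("20:59:00", 1), ("21:01:00", 0)]
kept = [node for node in nodes if in_time_window(node[0])]
print(kept)
```

Only the arrival node survives, so this stop contributes one node instead of two.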

In [20]:
%%spark
nodes_zurich.write.save("/group/aces/graph/nodes_final.orc", format="orc", mode='overwrite')


## Build edges
In this part, we build the different edges of our graph representation of the Zurich public transport network. Besides the transport edges that follow the timetable, we consider:
1. Intra-station edges: for nodes that share the same non-null parent station and belong to trips that are no more than 15 minutes apart, we add an edge with 2 minutes of walking time, as indicated in the statement.
2. Walking edges: we add walking edges between stops that are no more than 15 minutes apart.


### Add the timetable information to the nodes

In [21]:
%%spark
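# Keep is_arrival so that the filters below can use it (it is already encoded in full_stop_id, so distinct() is unaffected)
id_nodes_zh = nodes_zurich.select("stop_id", "trip_id","route_id","full_stop_id","time","is_arrival").distinct()

# Create the dataframes containing all the Zurich nodes together with their timetable information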
stop_times_zh_arr = (stop_times.join(
    id_nodes_zh.filter(col("is_arrival")==1).select(
        col("stop_id"),
        col("trip_id"),
        col("route_id"),
        col("full_stop_id"),
        col("time").alias("arrival_time")
), on=["stop_id", "trip_id","arrival_time"], how="inner"))
stop_times_zh_dep = (stop_times.join(
    id_nodes_zh.filter(col("is_arrival")==0).select(
        col("stop_id"),
        col("trip_id"),
        col("route_id"),
        col("full_stop_id"),
        col("time").alias("departure_time")
), on=["stop_id", "trip_id","departure_time"], how="inner"))

stop_times_zh = stop_times_zh_dep.union(stop_times_zh_arr).cache()


### Load the routes information
We can now load the route information, such as the line number, the means of transport and the route id.

In [22]:
%%spark
routes = spark.read.csv('/data/sbb/csv/routes/2019/05/08/routes.txt', sep=',', header=True, inferSchema=True)

# Rename all the columns in the routes dataframe to lower case
old_columns = routes.schema.names
new_columns = [c.lower() for c in old_columns]

# Thread the accumulator `data` through the fold so that every rename is kept
routes = reduce(lambda data, idx: data.withColumnRenamed(old_columns[idx], new_columns[idx]), range(len(old_columns)), routes)


In [23]:
%%spark
routes.select('route_desc').distinct().collect()


[Row(route_desc=u'TGV'), Row(route_desc=u'Eurocity'), Row(route_desc=u'Standseilbahn'), Row(route_desc=u'Regionalzug'), Row(route_desc=u'RegioExpress'), Row(route_desc=u'S-Bahn'), Row(route_desc=u'Luftseilbahn'), Row(route_desc=u'Sesselbahn'), Row(route_desc=u'Taxi'), Row(route_desc=u'F\xe4hre'), Row(route_desc=u'Tram'), Row(route_desc=u'ICE'), Row(route_desc=u'Bus'), Row(route_desc=u'Gondelbahn'), Row(route_desc=u'Nacht-Zug'), Row(route_desc=u'Auoreisezug'), Row(route_desc=u'Eurostar'), Row(route_desc=u'Schiff'), Row(route_desc=u'Schnellzug'), Row(route_desc=u'Intercity'), Row(route_desc=u'InterRegio'), Row(route_desc=u'Extrazug'), Row(route_desc=u'Metro')]

There are many different means of transport, but they often share a common base: InterRegio, Intercity, Extrazug, etc. all refer to trains. We therefore first map each of them to one of the following categories: _Train, Tram, Bus or Other_.

In [24]:
%%spark

transport_mapping = {
    "Auoreisezug": "Other",
    "Bus": "Bus",
    "Eurostar": "Train",
    "Eurocity": "Train",
    "Extrazug": "Train",
    u"F\xe4hre": "Other",
    "Gondelbahn": "Other",
    "ICE": "Train",
    "Intercity": "Train",
    "InterRegio": "Train",
    "Luftseilbahn": "Other",
    "Metro": "Other",
    "Nacht-Zug": "Train",
    "RegioExpress": "Train",
    "Regionalzug": "Train",
    "S-Bahn": "Train",
    "Schiff": "Other",
    "Schnellzug": "Train",
    "Sesselbahn": "Other",
    "Standseilbahn": "Other", 
    "Taxi": "Other",
    "TGV": "Train",
    "Tram": "Tram"
}

@F.udf
def map_transport(x):
    return transport_mapping[x]

routes = routes.withColumn('route_desc', map_transport(routes.route_desc)).cache()


### Merge with nodes\_zurich
We can now merge all the information into a single dataframe.

In [25]:
%%spark
nodes_zh_all_infos = nodes_zurich.join(routes, on="route_id").select("full_stop_id","stop_name","stop_lat","stop_lon", "time", "route_short_name","route_desc")
nodes_zh_all_infos.write.save("/group/aces/graph/nodes_all_info_final.orc", format="orc", mode='overwrite')
nodes_zh_all_infos.count()


697279

### Build the transport edges

In this section, we will build the different transport edges. Each edge will link a departure node of one stop to the arrival node of the subsequent stop in the trip.

In [26]:
%%spark
from pyspark.sql.window import Window
from pyspark.sql.functions import to_timestamp, col

# Window over each trip, ordered along the trip; defined once and reused for the three shifted columns
trip_window = Window.partitionBy('trip_id').orderBy([col('stop_sequence').asc(), col("departure_time").asc(), col('full_stop_id').desc()])

# We shift stop_id, arrival_time and full_stop_id up by one row (lag with count=-1) so that we obtain pairs of subsequent stops
stop_times_zh_pairs = stop_times_zh.withColumn('stop_id_dest', F.lag('stop_id', count=-1).over(trip_window))
stop_times_zh_pairs = stop_times_zh_pairs.withColumn('arrival_time_dest', F.lag('arrival_time', count=-1).over(trip_window))
stop_times_zh_pairs = stop_times_zh_pairs.withColumn('full_stop_id_dest', F.lag('full_stop_id', count=-1).over(trip_window))
# We drop the arrival_time column and rename arrival_time_dest to arrival_time
stop_times_zh_pairs = stop_times_zh_pairs.drop('arrival_time').withColumnRenamed('arrival_time_dest', 'arrival_time')
# Drop the rows that have no following stop (the last row of each trip)
stop_times_zh_pairs = stop_times_zh_pairs.dropna(subset='stop_id_dest')
# Compute the expected travel time in seconds, i.e. the arrival time at the destination stop minus the departure time at the origin stop
stop_times_zh_pairs = stop_times_zh_pairs.withColumn('expected_travel_time', to_timestamp(stop_times_zh_pairs.arrival_time, 'HH:mm:ss').cast('long') - to_timestamp(stop_times_zh_pairs.departure_time, 'HH:mm:ss').cast('long'))

# Drop the pairs that start and end at the same stop id (the arrival node and the departure node of one stop)
stop_times_zh_pairs = stop_times_zh_pairs.filter(stop_times_zh_pairs.stop_id != stop_times_zh_pairs.stop_id_dest)

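The pairing of subsequent stops can be sketched in plain Python for a single trip. This is a simplified illustration rather than the notebook's method: it works on one row per stop instead of separate arrival and departure nodes, the helper names `to_seconds` and `transport_edges` are hypothetical, and the departure times in the toy trip are assumed.

```python
def to_seconds(hms):
    """Convert an HH:MM:SS string to a number of seconds."""
    hours, minutes, seconds = (int(part) for part in hms.split(":"))
    return 3600 * hours + 60 * minutes + seconds


def transport_edges(stop_times):
    """Link each stop of one trip to the next one and compute the travel time."""
    ordered = sorted(stop_times, key=lambda row: row["stop_sequence"])
    edges = []
    # zip(ordered, ordered[1:]) plays the role of the lag of -1 over the trip window.
    for origin, dest in zip(ordered, ordered[1:]):
        duration = to_seconds(dest["arrival_time"]) - to_seconds(origin["departure_time"])
        edges.append((origin["stop_id"], dest["stop_id"], duration))
    return edges


# Toy trip, deliberately unsorted; departure times are assumed equal to arrival times.
trip = [
    {"stop_id": "8503074", "stop_sequence": 3, "arrival_time": "10:46:00", "departure_time": "10:46:00"},
    {"stop_id": "8503064", "stop_sequence": 1, "arrival_time": "10:41:00", "departure_time": "10:41:00"},
    {"stop_id": "8503065:1", "stop_sequence": 2, "arrival_time": "10:45:00", "departure_time": "10:45:00"},
]
print(transport_edges(trip))
```

The last stop of the trip has no successor and therefore produces no edge, which corresponds to the `dropna` step of the Spark version.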

In [27]:
%%spark

# Add the routes information
stop_times_zh_pairs = stop_times_zh_pairs.join(routes.select('route_id', 'route_desc', 'route_short_name'), on='route_id')


In [28]:
%%spark

# Keep transports only during working hours
stop_times_zh_pairs = stop_times_zh_pairs.filter((hour(to_timestamp(col('departure_time'), format="HH:mm:ss")) >= 6)
                                                 & (hour(to_timestamp(col('arrival_time'), format="HH:mm:ss")) <= 20))

# Select only necessary columns
stop_times_zh_pairs = stop_times_zh_pairs.select(
    col("full_stop_id").alias("start_id"),
    col("full_stop_id_dest").alias("end_id"),
    col("expected_travel_time").alias("duration"),
    col("route_desc").alias("transport"),
    col('route_short_name').alias("line_number"),
    F.lit(1).alias("is_trip"),
    F.lit(0).alias("waiting_time"),
    hour(to_timestamp(stop_times_zh_pairs.arrival_time, 'HH:mm:ss')).alias("hour")
)


The format of the edge table is the following (the three delay columns are not part of the transport edges selected above):
- `start_id`: the full stop id of the starting node.
- `end_id`: the full stop id of the ending node.
- `mean_delay`: the inverse of the rate parameter of the exponential distribution used to compute the probability of success of this edge (0 for walking edges).
- `std_delay`: the standard deviation of the delay of the transports between the two stops (1.0 for walking edges).
- `median_delay`: the median of the delay of the transports between the two stops (0 for walking edges).
- `duration`: the expected travel time for transport edges or the walking time for walking edges, in seconds.
- `transport`: the type of transport used, or `feet` for walking edges.
- `line_number`: the transport line number, or -1 for walking edges.
- `is_trip`: boolean with value 1 for transport edges and 0 for walking edges.
- `waiting_time`: 0 for transport edges, or the waiting time in seconds for walking edges.
- `hour`: the hour at which the transport is taken (-1 for walking edges).

In [29]:
%%spark
stop_times_zh_pairs.show(5, False)


+-----------------------------------------+-----------------------------------------+--------+---------+-----------+-------+------------+----+
|start_id                                 |end_id                                   |duration|transport|line_number|is_trip|waiting_time|hour|
+-----------------------------------------+-----------------------------------------+--------+---------+-----------+-------+------------+----+
|8503064_10-41-00_1.TA.26-18-j19-1.1.H_0  |8503065:1_10-45-00_1.TA.26-18-j19-1.1.H_1|240     |Train    |18         |1      |0           |10  |
|8503065:1_10-45-00_1.TA.26-18-j19-1.1.H_0|8503074_10-46-00_1.TA.26-18-j19-1.1.H_1  |60      |Train    |18         |1      |0           |10  |
|8503074_10-46-00_1.TA.26-18-j19-1.1.H_0  |8503068_10-47-00_1.TA.26-18-j19-1.1.H_1  |60      |Train    |18         |1      |0           |10  |
|8503068_10-47-00_1.TA.26-18-j19-1.1.H_0  |8503066_10-48-00_1.TA.26-18-j19-1.1.H_1  |60      |Train    |18         |1      |0           |10  |

In [30]:
%%spark

stop_times_zh_pairs.write.save("/group/aces/graph/edges_transport_final.orc", format="orc", mode='overwrite')


### Intra-station edges
Here we add the edges needed to link arrival nodes and departure nodes within the same station. We add only the following edges:
- Pairs of nodes that have the same non-null parent_location (i.e. the same station) but different platforms (different stop_id), with a default walking time of 2 minutes.
- Pairs of nodes that have the same stop_id, with a walking time of 0. This covers 2 types of edges:
    - Edges within the same (non-null) parent station and on the same platform (same stop_id), i.e. the traveller stays on the same platform.
    - Edges with a null parent station but the same stop_id (e.g. a bus stop), i.e. the traveller stays on the same platform.
    
We assume that the maximal acceptable waiting time is 15 minutes. Therefore, we remove any edge that does not satisfy this constraint.

In [31]:
%%spark
from pyspark.sql.functions import col 

acceptable_waiting_time = 15*60
walking_time_in_station = 2*60

# keep only the useful attributes from the nodes in the Zurich area
nodes_zurich_tmp = nodes_zurich.select(
    nodes_zurich.is_arrival,
    nodes_zurich.full_stop_id,
    to_timestamp(nodes_zurich.time, 'HH:mm:ss').cast('long').alias('time'),
    nodes_zurich.parent_location,
    nodes_zurich.stop_id
)

# The filter must be applied directly after the cross join so that Spark can optimize it
# Here we select the pairs of nodes that belong to the same station but are on different platforms.
all_same_station_edges_cross_diff_platforms = (
                            # Select only arrival node
                            nodes_zurich_tmp
                            .filter(nodes_zurich_tmp.is_arrival == 1)
                            .select(nodes_zurich_tmp.full_stop_id.alias("start_id"),
                                      nodes_zurich_tmp.time.alias("arr_time"),
                                      nodes_zurich_tmp.parent_location.alias("arr_parent"),
                                      nodes_zurich_tmp.stop_id.alias("start_stop_id"))
                            .crossJoin(
                                # Select only departure node 
                                nodes_zurich_tmp
                                        .filter(nodes_zurich_tmp.is_arrival == 0)
                                        .select(
                                            nodes_zurich_tmp.full_stop_id.alias("end_id"),
                                            nodes_zurich_tmp.time.alias("dep_time"),
                                            nodes_zurich_tmp.parent_location.alias("dep_parent"),
                                            nodes_zurich_tmp.stop_id.alias("end_stop_id"))
                            ).filter(
                                # Filter on pair of nodes whose both stops parent id are not null and are the same 
                                col("dep_parent").isNotNull() & col("arr_parent").isNotNull() & (col("arr_parent") == col("dep_parent")) & (col('start_stop_id') != col('end_stop_id'))
                            ))

# Here, we build the edges for the same stop and subsequent trips.
# This dataframe links every arrival node of a stop to every departure node of the same stop
all_same_station_edges_cross_same_stop = (nodes_zurich_tmp
                            # Select only arrival node
                            .filter(nodes_zurich_tmp.is_arrival == 1)
                            .select(nodes_zurich_tmp.full_stop_id.alias("start_id"),
                                      nodes_zurich_tmp.time.alias("arr_time"),
                                    nodes_zurich_tmp.stop_id.alias("start_stop_id"))
                            .crossJoin(
                                # Select only departure node
                                nodes_zurich_tmp
                                        .filter(nodes_zurich_tmp.is_arrival == 0)
                                        .select(
                                            nodes_zurich_tmp.full_stop_id.alias("end_id"),
                                            nodes_zurich_tmp.time.alias("dep_time"),
                                            nodes_zurich_tmp.stop_id.alias("end_stop_id"))
                            ).filter(
                                # Keep the edges with the same stopID, i.e. in the same station
                                col('start_stop_id') == col('end_stop_id')
                            ))


In [32]:
%%spark

# Keep nodes with same parent stations but different platforms
all_same_station_edges_diff_platforms = (all_same_station_edges_cross_diff_platforms
                             # Default walking time between platforms
                            .withColumn("duration", F.lit(walking_time_in_station))
                            .withColumn("waiting_time", col("dep_time")-col("arr_time")-col("duration"))
                            .filter((col("waiting_time") < acceptable_waiting_time) & (col("waiting_time") >= 0)) # Filter out the edges that do not match the time constraint of the problem
                            .select(col("start_id"),
                                      col("end_id"),
                                      F.lit(0).alias("mean_delay"), # Default value for the mean (since the time on walking/waiting edge is deterministic)
                                      F.lit(1.0).alias("std_delay"), # Default value for the std (since the time on walking/waiting edge is deterministic)
                                      F.lit(0).alias("median_delay"), # Default value for the median (since the time on walking/waiting edge is deterministic)
                                      col('duration'), 
                                      F.lit("feet").alias("transport"),
                                      F.lit("-1").alias("line_number"), # Default value for the line number when we are not in a line (useful for probability computation)
                                      F.lit(0).alias("is_trip"), # is_trip is 1 only for transport edges
                                      col("waiting_time"),
                                      F.lit(-1).alias("hour")))

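The selection of platform-change edges can be summarised by the following plain-Python sketch. The nested loop stands in for the cross join, and the node dictionaries, their keys and the helper name `transfer_edges` are made up for the example; times are expressed in seconds.

```python
ACCEPTABLE_WAITING_TIME = 15 * 60  # seconds
WALKING_TIME_IN_STATION = 2 * 60   # seconds


def transfer_edges(arrivals, departures):
    """Pair arrival and departure nodes of one station that are on different platforms."""
    edges = []
    for arr in arrivals:
        for dep in departures:
            same_station = arr["parent"] is not None and arr["parent"] == dep["parent"]
            if not same_station or arr["stop_id"] == dep["stop_id"]:
                continue
            # Time left once the traveller has walked to the other platform.
            waiting = dep["time"] - arr["time"] - WALKING_TIME_IN_STATION
            if 0 <= waiting < ACCEPTABLE_WAITING_TIME:
                edges.append((arr["id"], dep["id"], WALKING_TIME_IN_STATION, waiting))
    return edges


arrivals = [{"id": "a1", "stop_id": "P:1", "parent": "P", "time": 36000}]
departures = [
    {"id": "d1", "stop_id": "P:2", "parent": "P", "time": 36300},  # kept: 180 s of waiting
    {"id": "d2", "stop_id": "P:2", "parent": "P", "time": 36060},  # dropped: not enough time to walk
    {"id": "d3", "stop_id": "P:3", "parent": "P", "time": 37200},  # dropped: waiting time too long
    {"id": "d4", "stop_id": "P:1", "parent": "P", "time": 36300},  # skipped: same platform
    {"id": "d5", "stop_id": "Q:1", "parent": "Q", "time": 36300},  # skipped: other station
]
print(transfer_edges(arrivals, departures))
```

The waiting time is measured after the 2 minutes of walking, so a connection that leaves 1 minute after the arrival is rejected even though it departs later than the arrival.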

In [33]:
%%spark

# We need to filter out the stops that are missing from the stop_times file
line_numbers_zh = stop_times_zh_pairs.select('line_number').rdd.map(lambda x: x.line_number).distinct().collect()
trip_id2line_nb = routes.join(trips, on='route_id', how='inner').select('trip_id', 'route_short_name').distinct().filter(col('route_short_name').isin(line_numbers_zh)).collect()
# We build a dictionary mapping the trip id to the route short name (the line number)
trip_id2line_nb = {r['trip_id']: r['route_short_name'] for r in trip_id2line_nb}


In [34]:
%%spark

@F.udf
def map_line_nb(arr_id, end_id):
    """
    Give a line number to the intra-station walking edges
    if the edge is in the continuity of a trip. This is needed
    to correctly compute the confidence of a given path.
    :param arr_id: the full stop id of the arrival node
    :param end_id: the full stop id of the departure node
    :return: -1 if the two given ids do not belong to the same stop and trip,
             -2 if the trip is not valid,
             the line number otherwise
    """
    arr_stop_id = arr_id.split('_')[0]
    arr_trip_id = arr_id.split('_')[2]
    
    end_stop_id = end_id.split('_')[0]
    end_trip_id = end_id.split('_')[2]
    
    if (arr_stop_id == end_stop_id) and (arr_trip_id == end_trip_id):
        if arr_trip_id in trip_id2line_nb:
            return trip_id2line_nb[arr_trip_id]
        else:
            # If the trip id is not in the dictionary, we consider the trip as invalid
            return "-2"
    else:
        return "-1"

# Keep the pairs of nodes that have the same stop id, i.e. no platform change
all_same_station_edges_same_platform = (all_same_station_edges_cross_same_stop                                        
                            # The walking time is 0 since the platform does not change
                            .withColumn("duration", F.lit(0))
                            .withColumn("waiting_time", col("dep_time")-col("arr_time")-col("duration"))
                            .filter((col("waiting_time") < acceptable_waiting_time) & (col("waiting_time") >= 0)) # Filter out the edges whose waiting time exceeds the maximal waiting time
                            .select(col("start_id"),
                                      col("end_id"),
                                      F.lit(0).alias("mean_delay"), # Default value for the mean (since the time on walking/waiting edge is deterministic)
                                      F.lit(1.0).alias("std_delay"), # Default value for the std (since the time on walking/waiting edge is deterministic)
                                      F.lit(0).alias("median_delay"), # Default value for the median (since the time on walking/waiting edge is deterministic)
                                      col('duration'),
                                      F.lit("feet").alias("transport"),
                                      F.lit("-1").alias("line_number"), # Default value for the line number when we are not in a line (useful for probability computation)
                                      F.lit(0).alias("is_trip"), # is_trip is 1 only for transport edges
                                      col("waiting_time"),
                                      F.lit(-1).alias("hour")) # The hour is only used to look up the mean delay; walking/waiting edges have no mean delay, so we use -1
                            )

# remove all edges whose station is not on a valid trip_id (not in the stop_times file)
all_same_station_edges_same_platform = all_same_station_edges_same_platform.withColumn('line_number', map_line_nb(col('start_id'), col('end_id')))
all_same_station_edges_same_platform = all_same_station_edges_same_platform.filter(col('line_number') != '-2')

In [35]:
%%spark
# Make the union of all the edges
all_same_station_edges = all_same_station_edges_diff_platforms.union(all_same_station_edges_same_platform).cache()

In [36]:
%%spark
# Save it to HDFS
all_same_station_edges.write.save("/group/aces/graph/all_same_station_edges_final.orc", format="orc", mode='overwrite')

In [37]:
%%spark
all_same_station_edges.count() # 3_260_206

3260206

In [38]:
%%spark
nodes_zurich.count() # 697_279

697279
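
The same-platform waiting edges built above can be illustrated in plain Python. This is a minimal sketch, not the Spark job itself: the node id layout (stop id in the first `_`-separated field, trip id in the third) is inferred from the `split('_')` calls in `map_line_nb`, while the sample ids, the contents of `trip_id2line_nb`, and the 600-second value of `ACCEPTABLE_WAITING_TIME` are hypothetical.

```python
# Hypothetical values, for illustration only
ACCEPTABLE_WAITING_TIME = 600      # seconds; the notebook defines its own value elsewhere
trip_id2line_nb = {"tripA": "S3"}  # hypothetical mapping: trip id -> line number


def line_number(arr_id, end_id, mapping):
    """Mirror of map_line_nb: "-1" if stop or trip differ, "-2" if the trip is unknown."""
    arr_stop, arr_trip = arr_id.split('_')[0], arr_id.split('_')[2]
    end_stop, end_trip = end_id.split('_')[0], end_id.split('_')[2]
    if arr_stop != end_stop or arr_trip != end_trip:
        return "-1"
    return mapping.get(arr_trip, "-2")


def waiting_time(arr_time, dep_time, duration=0):
    """Seconds spent waiting after the transfer; None if the edge would be filtered out."""
    wait = dep_time - arr_time - duration
    return wait if 0 <= wait < ACCEPTABLE_WAITING_TIME else None


# Staying on the same trip at the same stop keeps the line number of that trip
print(line_number("8503000_arr_tripA", "8503000_dep_tripA", trip_id2line_nb))
# Arrival at t=100 s and departure at t=220 s on the same platform
print(waiting_time(100, 220))
```

Edges for which the sketch returns `"-2"` correspond to the ones dropped by the final `filter(col('line_number') != '-2')`.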

### Build the walking edges
Here we add the walking edges between nearby stations. An edge is added when it satisfies the following conditions:
- End points are at most 500 m apart (from the problem statement)
- End points must:
    - Either have both null parent station and different stop_id
    - Or have different stop_id and different parent station (one potentially null).
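
The conditions above can be sketched as a plain-Python predicate. This is an illustration only: `haversine_m` is a standard great-circle formula written for this example (the notebook relies on its own `haversine` helper defined earlier), the `Stop` tuple is a hypothetical structure, and the coordinates are made up.

```python
from collections import namedtuple
from math import asin, cos, radians, sin, sqrt

MAX_DISTANCE = 500        # meters, as in the problem statement
EARTH_RADIUS_M = 6371000  # mean Earth radius (approximation)

# Hypothetical stop record: parent is None when the stop has no parent station
Stop = namedtuple("Stop", ["stop_id", "lat", "lon", "parent"])


def haversine_m(lat1, lon1, lat2, lon2):
    """Great-circle distance in meters between two (lat, lon) points given in degrees."""
    phi1, phi2 = radians(lat1), radians(lat2)
    dphi = phi2 - phi1
    dlam = radians(lon2 - lon1)
    a = sin(dphi / 2) ** 2 + cos(phi1) * cos(phi2) * sin(dlam / 2) ** 2
    return 2 * EARTH_RADIUS_M * asin(sqrt(a))


def can_walk(a, b):
    """True if a walking edge between stops a and b satisfies the listed conditions."""
    if a.stop_id == b.stop_id:
        return False
    # Reject only when both parents are set and identical (same station)
    if a.parent is not None and a.parent == b.parent:
        return False
    return haversine_m(a.lat, a.lon, b.lat, b.lon) <= MAX_DISTANCE


# Made-up stops roughly 111 m apart, belonging to different parent stations
stop_a = Stop("A", 47.3780, 8.5400, "P1")
stop_b = Stop("B", 47.3790, 8.5400, "P2")
print(can_walk(stop_a, stop_b))
```

Pairs sharing a non-null parent station are excluded here because transfers inside one station are already covered by the same-station edges.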

In [39]:
%%spark

max_distance = 500
meter_second = 60.0 / 50 # Seconds needed to walk one meter (50 m per minute). We divide a float by an integer, since the interpreter in the backend evaluates 60/50 as 1

# Keep only the nodes that are within the specified perimeter
nodes_zh = node_info.filter(haversine(all_nodes.stop_lat, all_nodes.stop_lon) < RADIUS + MARGIN).cache()

# Select all the pairs of nodes such that the ids are not the same and their distance is smaller than the max_distance
all_pairs_stop_id = (nodes_zh.select(nodes_zh.stop_id.alias("start_id"),
                          nodes_zh.stop_lat.alias("arr_lat"),
                          nodes_zh.stop_lon.alias("arr_lon"),
                          nodes_zh.parent_location.alias("arr_parent"))
                 .distinct()
                 .crossJoin(nodes_zh
                            .select(
                                nodes_zh.stop_id.alias("end_id"),
                                nodes_zh.stop_lat.alias("dep_lat"),
                                nodes_zh.stop_lon.alias("dep_lon"),
                                nodes_zh.parent_location.alias("dep_parent"))
                            .distinct()
                   )
                .filter(
                    (col('start_id') != col('end_id')) & 
                    ((col('dep_parent').isNull() & col('arr_parent').isNull()) |
                     (col('dep_parent').isNotNull() & col('arr_parent').isNull()) |
                     (col('dep_parent').isNull() & col('arr_parent').isNotNull()) |
                     (col('dep_parent') != col('arr_parent')))
                )
                .withColumn("distance", haversine(col("arr_lat"), col("arr_lon"), col("dep_lat"), col("dep_lon"))) # Compute the distance between the two stops
                .filter(col("distance") <= max_distance) # Keep only the pairs of stops within the maximal walking distance
                .withColumn("walking_time", col("distance") * meter_second)) # Compute walking time

In [40]:
%%spark
# Finally, we create all the walking edges between the pairs of node ids selected before.
# We must be careful to select the arrival node for the left end of the edge and the departure node for the right end
all_walking_edges = (all_pairs_stop_id.join(
                    nodes_zurich.filter(nodes_zurich.is_arrival == 1)
                                .select(
                                        to_timestamp(nodes_zurich.time, 'HH:mm:ss').cast('long').alias("arr_time"),
                                        nodes_zurich.stop_id.alias("start_id"),
                                        col("time").alias("arr_time_h"),
                                        nodes_zurich.full_stop_id.alias("full_arr_id")), on=["start_id"], how='inner')
                .join(
                    nodes_zurich.filter(nodes_zurich.is_arrival == 0)
                                .select(
                                        to_timestamp(nodes_zurich.time, 'HH:mm:ss').cast('long').alias("dep_time"),
                                        nodes_zurich.stop_id.alias("end_id"),
                                        col("time").alias("dep_time_h"),
                                        nodes_zurich.full_stop_id.alias("full_dep_id")), on=["end_id"], how='inner')
                # Filter out all the pairs that do not meet the time requirement
                .filter(((col("dep_time") - col("arr_time")) >= col("walking_time")) & ((col("dep_time")-col("arr_time")-col("walking_time")) < acceptable_waiting_time))
                # select the correct column in the correct order
                .select(col("full_arr_id").alias("start_id"),
                      col("full_dep_id").alias("end_id"),
                      F.lit(0).alias("mean_delay"), # Default value for the mean (since the time on walking/waiting edge is deterministic)
                      F.lit(1.0).alias("std_delay"), # Default value for the std (since the time on walking/waiting edge is deterministic)
                      F.lit(0).alias("median_delay"), # Default value for the median (since the time on walking/waiting edge is deterministic)
                      col('walking_time').alias('duration'),
                      F.lit("feet").alias("transport"),
                      F.lit("-1").alias("line_number"), # Default value for the line number when we are not in a line (useful for probability computation)
                      F.lit(0).alias("is_trip"), # is_trip is 1 only for transport edges
                      (col("dep_time")-col("arr_time")-col("walking_time")).alias("waiting_time"),
                      F.lit(-1).alias('hour'))).cache()


In [41]:
%%spark
# Save on disk
all_walking_edges.write.save("/group/aces/graph/all_walking_edges_final.orc", format="orc", mode='overwrite')

In [42]:
%%spark
all_walking_edges.count() # 12_754_187 with 500m

12754187
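
The time requirement applied to the walking edges can be sketched in plain Python. This is a hedged illustration rather than the Spark code: `to_seconds` is a hypothetical helper standing in for `to_timestamp(...).cast('long')` (only time differences matter here), and the 600-second `ACCEPTABLE_WAITING_TIME` is an assumed value, since the notebook defines `acceptable_waiting_time` elsewhere.

```python
SECONDS_PER_METER = 60.0 / 50  # same constant as meter_second: 1.2 s per meter
ACCEPTABLE_WAITING_TIME = 600  # seconds; assumed value for this example


def to_seconds(hhmmss):
    """Convert an 'HH:MM:SS' string to seconds since midnight."""
    h, m, s = (int(part) for part in hhmmss.split(":"))
    return 3600 * h + 60 * m + s


def walking_edge(arr_time, dep_time, distance):
    """Return (duration, waiting_time) in seconds, or None if the transfer is infeasible."""
    duration = distance * SECONDS_PER_METER
    waiting = to_seconds(dep_time) - to_seconds(arr_time) - duration
    # Keep the edge only if the walk fits before departure and the wait is short enough
    if 0 <= waiting < ACCEPTABLE_WAITING_TIME:
        return duration, waiting
    return None


# Arrive at 08:00:00, walk 100 m, and catch a departure at 08:05:00
print(walking_edge("08:00:00", "08:05:00", 100))
```

With this walking speed, the maximal 500 m walk takes 500 × 1.2 = 600 seconds.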