# SBB planner 2.0.

1. Construct `stop_times_zurich` (now it uses the graph to get nodes, but unecessary), cache it
2. Planner, when user asks journey:
    1. Create graph (2 hours before arrival)
    2. Create 2nd graph with only stations and edges reachable from Zürich
    3. Compute paths
    4. Find shortest path

## 1. Setup:

First we setup our spark application and load the necessary data containing information about specific transit information:  
- `stop_times.txt` : arrival and departure times at stops
- `stops.txt` : information about stops 
- `trips.txt` : information about journeys
- `calendar.txt` : information about which servcies are active on which dates

#### Set up spark:

In [1]:
%%configure
{"conf": {
    "spark.app.name": "dslab-group_final"
}}

ID,YARN Application ID,Kind,State,Spark UI,Driver log,Current session?
6512,application_1589299642358_1001,pyspark,idle,Link,Link,
6531,application_1589299642358_1020,pyspark,idle,Link,Link,
6532,application_1589299642358_1021,pyspark,busy,Link,Link,
6538,application_1589299642358_1027,pyspark,idle,Link,Link,
6542,application_1589299642358_1031,pyspark,idle,Link,Link,
6543,application_1589299642358_1032,pyspark,idle,Link,Link,
6545,application_1589299642358_1034,pyspark,idle,Link,Link,
6548,application_1589299642358_1037,pyspark,busy,Link,Link,
6550,application_1589299642358_1039,pyspark,idle,Link,Link,
6552,application_1589299642358_1041,pyspark,idle,Link,Link,


#### Imports:

In [2]:
import networkx as nx
from geopy.distance import distance as geo_distance
from pyspark.sql import Row
from pyspark.sql.functions import col
from pyspark.sql.types import FloatType
from networkx.algorithms.shortest_paths.weighted import dijkstra_path

from heapq import heappush, heappop
from itertools import count

Starting Spark application


ID,YARN Application ID,Kind,State,Spark UI,Driver log,Current session?
6561,application_1589299642358_1050,pyspark,idle,Link,Link,✔


FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

SparkSession available as 'spark'.


FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

##### Complete version:

In [3]:
# Loading data, these are snapshots of the all available data
# Calendar and trips are useful to filter the other dataframe according to the day

stop_times = spark.read.format('orc').load('/data/sbb/timetables/orc/stop_times/000000_0')
stops = spark.read.format('orc').load('/data/sbb/timetables/orc/stops/000000_0')
trips = spark.read.format('orc').load('/data/sbb/timetables/orc/trips/000000_0')
calendar = spark.read.format('orc').load('/data/sbb/timetables/orc/calendar/000000_0')
transfers = spark.read.format('orc').load('/data/sbb/timetables/orc//transfers/000000_0')

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

## 2. Pre-processing: 

Here, we pre-process the data to correspond to the following criteria:
- only consider journeys at reasonable hours of the day, and on a typical business day, and assuming the schedule of May 13-17, 2019
- allow short (max 500m "As the Crows Flies") walking distances for transfers between two stations, and assume a walking speed of 50m/1min on a straight line, regardless of obstacles, human-built or natural, such as building, highways, rivers, or lakes
- only consider journeys that start and end on known station coordinates (train station, bus stops, etc.), never from a random location
- only consider stations in a 15km radius of Zürich's train station, Zürich HB (8503000), (lat, lon) = (47.378177, 8.540192)
- only consider stations in the 15km radius that are reachable from Zürich HB, either directly, or via transfers through other stations within the area
- assume that the timetables remain unchanged throughout the 2018 - 2019 period

Filtering criteria for later use (e.g. probability intervals and planner):
- assuming that delays or travel times on the public transport network are uncorrelated with one another
- once a route is computed, a traveller is expected to follow the planned routes to the end, or until it fails (i.e. miss a connection)
- planner will not need to mitigate the traveller's inconvenience if a plan fails

#### Stop times during rush-hour: 

Only consider journeys at reasonable hours of the day, thus we take only stop times that are in the window of rush-hour (e.g. from 8 a.m. to 8 p.m.). 

In [4]:
# Filter stop_times to be only in 08:00-19:59:
stop_times = stop_times.where((col('departure_time') >= '08:00:00') 
                              & (col('departure_time') <= '19:59:59'))

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

#### Stations around Zürich HB:

Only consider stations in a 15km radius of Zürich's train station (Zürich HB). 

First we get the geolocation of Zürich Hauptbahnhof to be able to calculate the distance of the other stations to the Hauptbahnhof. 

In [5]:
zurich_pos = stops.where(column('stop_name') == 'Zürich HB').select('stop_lat', 'stop_lon').collect()
zurich_pos = (zurich_pos[0][0], zurich_pos[0][1])
print('Location of Zürich Hauptbahnhof (lat, lon) :'+str(zurich_pos))

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

Location of Zürich Hauptbahnhof (lat, lon) :(47.3781762039461, 8.54019357578468)

In [6]:
def zurich_distance(x, y):
    """zurich_distance: returns the distance of a station to Zurich HB
    @input: (lat,lon) of a station
    @output: distance in km to Zurich HB
    """
    return geo_distance(zurich_pos, (x,y)).km

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

Then we create a dataframe `stops_zurich` of the stations where we add a column for the distance to Zurich HB. In that dataframe, we keep only those that are in a radisu of 15km to the HB. The same filter is applied to the `stop_times` df mentioned above. 

In [7]:
# filter stops:
stops_distance = stops.rdd.map(lambda x: (x['stop_id'], zurich_distance(x['stop_lat'], x['stop_lon'])))
stops_distance = spark.createDataFrame(stops_distance.map(lambda r: Row(stop_id=r[0], 
                                                                        zurich_distance=r[1])))

stops_distance = stops_distance.filter(column('zurich_distance') <= 15)

#print('There are '+str(stops_distance.count())+' stops in a radius of 15km around Zurich HB')

# add distance to HB to stops info and keep only in radius of 15km
stops_zurich = stops_distance.join(stops, on='stop_id')

# keep only stop times in radius of 15km of Zurich
stop_times_zurich = stop_times.join(stops_distance.select('stop_id'), on='stop_id')

#print('There are '+str(stop_times_zurich.count())+ ' stop times in a radius of 15km around  Zurich HB')

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

In [8]:
# Cache it to save time:
stop_times_zurich.cache()

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

DataFrame[stop_id: string, trip_id: string, arrival_time: string, departure_time: string, stop_sequence: smallint, pickup_type: tinyint, drop_off_type: tinyint]

### Have a look at the data we have so far: 

#### Stop times in Zurich: 

In [9]:
stop_times_zurich.show(3)

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

+-------+--------------------+------------+--------------+-------------+-----------+-------------+
|stop_id|             trip_id|arrival_time|departure_time|stop_sequence|pickup_type|drop_off_type|
+-------+--------------------+------------+--------------+-------------+-----------+-------------+
|8502508|9.TA.1-303-j19-1.2.R|    19:55:00|      19:55:00|            6|          0|            0|
|8502508|12.TA.1-303-j19-1...|    09:55:00|      09:55:00|            6|          0|            0|
|8502508|13.TA.1-303-j19-1...|    08:25:00|      08:25:00|            6|          0|            0|
+-------+--------------------+------------+--------------+-------------+-----------+-------------+
only showing top 3 rows

#### Trips in Zurich:

In [10]:
trips.show(3)

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

+-----------+----------+--------------------+------------------+---------------+------------+
|   route_id|service_id|             trip_id|     trip_headsign|trip_short_name|direction_id|
+-----------+----------+--------------------+------------------+---------------+------------+
|1-1-C-j19-1|  TA+b0001|5.TA.1-1-C-j19-1.3.R|Zofingen, Altachen|            108|           1|
|1-1-C-j19-1|  TA+b0001|7.TA.1-1-C-j19-1.3.R|Zofingen, Altachen|            112|           1|
|1-1-C-j19-1|  TA+b0001|9.TA.1-1-C-j19-1.3.R|Zofingen, Altachen|            116|           1|
+-----------+----------+--------------------+------------------+---------------+------------+
only showing top 3 rows

#### Stops in Zurich:

In [11]:
stops_zurich.show(3)

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

+-----------+------------------+--------------------+----------------+----------------+-------------+--------------+
|    stop_id|   zurich_distance|           stop_name|        stop_lat|        stop_lon|location_type|parent_station|
+-----------+------------------+--------------------+----------------+----------------+-------------+--------------+
|    8500926|11.510766966884365|Oetwil a.d.L., Sc...|47.4236270123012| 8.4031825286317|             |              |
|    8502186|10.798985488832079|Dietikon Stoffelbach|47.3934058321612|8.39894248049007|             |      8502186P|
|8502186:0:1|10.800041577194426|Dietikon Stoffelbach|47.3934666445388|8.39894248049007|             |      8502186P|
+-----------+------------------+--------------------+----------------+----------------+-------------+--------------+
only showing top 3 rows

#### Calendar:

In [12]:
calendar.show(3)

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

+----------+------+-------+---------+--------+------+--------+------+
|service_id|monday|tuesday|wednesday|thursday|friday|saturday|sunday|
+----------+------+-------+---------+--------+------+--------+------+
|  TA+b0nx9|  true|   true|     true|    true|  true|   false| false|
|  TA+b03bf|  true|   true|     true|    true|  true|   false| false|
|  TA+b0008|  true|   true|     true|    true|  true|   false| false|
+----------+------+-------+---------+--------+------+--------+------+
only showing top 3 rows

## 3. Create a network:

From the pre-processed data, we would like to create a directed network where each node is a station and each edge between two nodes corresponds to a possible trip. 

A node will have the following attributes:
- stop_name: name of the station (e.g. Zurich HB)
- latitude
- longitude

An directed edge will have the following attributes:
- stop_id: the id of the stop the (directed) edge points from
- next_stop: the id of the stop the edge points to
- duration: the duration of the trip from stop_id to next_stop
- departure time: the time from which the service departs from stop_id

First, we create a function that returns all services that operate on certain days (as some of them might not operate on weekends for example). 

In [13]:
days_dict = {0: 'monday', 1: 'tuesday', 2: 'wednesday', 3: 'thursday', 4: 'friday'}

def day_trips(*day_ids):
    """
    day_trips: gives the trip_ids that operate on certain days
    input: a variable number of day ids
    output:s spark dataframe with trip_ids
    
    """
    days = [days_dict[day_id] for day_id in day_ids]
    where_clause = " and ".join(days)

    day_services = calendar.where(where_clause).select('service_id')
    return day_services.join(trips, on='service_id').select('trip_id')

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

### Nodes:

Then we create a **multigraph** (e.g. more than one edge allowed between two nodes). 

In [14]:
graph_long = nx.MultiDiGraph()

nodes = stops_zurich.rdd.map(lambda r: (r[0], {'name': r['stop_name'],
                                              'lat': r['stop_lat'],
                                              'lon': r['stop_lon']})).collect()
graph_long.add_nodes_from(nodes)

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

#### Save nodes to hdfs

In [19]:
%%local
import os
username = os.environ['JUPYTERHUB_USER']

In [20]:
%%send_to_spark -i username -t str -n username

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

Successfully passed 'username' as 'username' to Spark kernel

In [40]:
stops_zurich.write.format("orc").save("/user/{}/nodes.orc".format(username))

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

### Edges:

Right now, we have only nodes that are stations in our graph and we would like to add edges between them showing possible trips at certain times. For this we need to do a few operations. 

First, in our `stop_times_zurich` table, we have the time of arrival and departure but we would like to have an idea of the time elapsed in minutes since 12 p.m. This way times will be easily subtractable and we can get an idea of trip duration in minutes. So we convert those two columns: 

In [15]:
@udf
def convertToMinute(s):
    h, m, _ = s.split(':')
    h,m = int(h), int(m)
    
    return h*60+m

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

In [16]:
# Convert time information to minutes elapsed since 0am
stop_times_zurich = stop_times_zurich.withColumn('arrival_time', 
                                                 convertToMinute(column('arrival_time')))
stop_times_zurich = stop_times_zurich.withColumn('departure_time', 
                                                 convertToMinute(column('departure_time')))
stop_times_zurich.show(3)

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

+-------+--------------------+------------+--------------+-------------+-----------+-------------+
|stop_id|             trip_id|arrival_time|departure_time|stop_sequence|pickup_type|drop_off_type|
+-------+--------------------+------------+--------------+-------------+-----------+-------------+
|8502508|9.TA.1-303-j19-1.2.R|        1195|          1195|            6|          0|            0|
|8502508|12.TA.1-303-j19-1...|         595|           595|            6|          0|            0|
|8502508|13.TA.1-303-j19-1...|         505|           505|            6|          0|            0|
+-------+--------------------+------------+--------------+-------------+-----------+-------------+
only showing top 3 rows

Then we want a dataframe that has the trip duration to the next stop from the current one on the trip. For that, we first create a table with the next stop and arrival time for each stop sequence in a trip. 

In [17]:
stop_times_zurich_2 = (stop_times_zurich.withColumn('stop_sequence_prev', column('stop_sequence')-1)
                       .select('trip_id',
                               column('stop_id').alias('next_stop'),
                               column('stop_sequence_prev').alias('stop_sequence'),
                               column('arrival_time').alias('next_arrival_time')))

stop_times_zurich_2.show(2)

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

+--------------------+---------+-------------+-----------------+
|             trip_id|next_stop|stop_sequence|next_arrival_time|
+--------------------+---------+-------------+-----------------+
|9.TA.1-303-j19-1.2.R|  8502508|            5|             1195|
|12.TA.1-303-j19-1...|  8502508|            5|              595|
+--------------------+---------+-------------+-----------------+
only showing top 2 rows

Then we join this to the `stop_times_zurich` table to have trip duration (in minutes) and next stop information. 

In [18]:
# Add trip duration and next stop: 
stop_times_zurich = stop_times_zurich.join(stop_times_zurich_2, 
                                           on=['trip_id', 'stop_sequence']).orderBy('trip_id', 'stop_sequence')
stop_times_zurich = stop_times_zurich.withColumn('trip_duration', 
                                                 column('next_arrival_time')-column('departure_time'))
stop_times_zurich = stop_times_zurich.select('trip_id', 
                                             'stop_id', 'arrival_time', 'departure_time', 
                                             'next_stop', 'trip_duration').cache()
stop_times_zurich.show(2)

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

+--------------------+-------+------------+--------------+---------+-------------+
|             trip_id|stop_id|arrival_time|departure_time|next_stop|trip_duration|
+--------------------+-------+------------+--------------+---------+-------------+
|1.TA.1-231-j19-1.1.H|8582462|         578|           578|  8572600|          1.0|
|1.TA.1-231-j19-1.1.H|8572600|         579|           579|  8572601|          0.0|
+--------------------+-------+------------+--------------+---------+-------------+
only showing top 2 rows

#### Save edges informations to hdfs

In [21]:
stop_times_zurich.write.format("orc").save("/user/{}/edges.orc".format(username))

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

To make our computations easier, we add a further assumption stating that we are interested only in trips between two stations that are done in under 2 hours. 

For this, we create a function that selects all trips on a certain week-day in a window of 2 hours before a certain arrival time. Thus if for example we want to arrive at 11:30, we are only interested in possible trips that departed 2 hours before 11:30. When creating the edges, we add the walking edges.

In [22]:
%%local
import pandas as pd
walking_times = pd.read_pickle('walking_times.pickle')

In [24]:
%%send_to_spark -i walking_times -t df -m 20000

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

Successfully passed 'walking_times' as 'walking_times' to Spark kernel

#### TODO: Take care of transfers using transfers.txt

In [25]:
edges_walking = walking_times.toPandas()
edges_walking['attrs'] = edges_walking.apply(lambda x: {'time': -1, 'duration': x['duration']+2}, axis=1)
edges_walking = list(edges_walking[['source', 'target', 'attrs']].to_numpy())

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

In [26]:
edges_walking[0]

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

array([u'8500926', u'8590616',
       {'duration': 4.4523214721999995, 'time': -1}], dtype=object)

In [27]:
MAX_TRIP_DURATION = 2 #duration in hour 

def create_edges_for_trip(edges_df, day_id, hour, minute):
    """
    create_edges_for_trip: constructs edges (and thus trips) that exist in a window of two hours before a given input time
    @input:
    - edges_df: df from which we construct the edges
    - day_id: id of week-day (e.g. wednesday is day id 2, see dictionnary above)
    - hour, minute: time at which we want to arrive somewhere (e.g. 11:30)
    @output: data frame of selected edges
    """
    #select only the trips that occur on that day:
    edges_df= edges_df.join(day_trips(day_id), on='trip_id')
    
    arrival_minute = hour*60+minute
    min_dep_time = arrival_minute - 60*60*MAX_TRIP_DURATION
    
    #keep only those in a window of two hours:
    edges_df = edges_df.filter((col('departure_time') > min_dep_time) & 
                                            (col('arrival_time') <= arrival_minute))

    edges = edges_df.rdd.map(lambda r: (r['stop_id'], r['next_stop'], {'duration': r['trip_duration'],
                                                                       'time': float(r['departure_time'])})).collect()
    
    return edges + edges_walking

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

We test this out on Wednesday with arrival at 11:30 

In [28]:
# Example of graph construction: Wednesday arrival at 11:30:00
edges_wed_11_30 = create_edges_for_trip(stop_times_zurich, 2, 11, 30)

print('Number of edges:', len(edges_wed_11_30))

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

('Number of edges:', 116983)

In [29]:
_ = graph_long.add_edges_from(edges_wed_11_30)

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

# Simple graph creation

##### Mini-graph 1: 

In [52]:
## create a simple graph: 
trip_id = '1.TA.1-231-j19-1.1.H'

#select four stops : 
stops_minig = ['8582462','8572600','8572601','8502553']

stops_info = stops_zurich.where(column('stop_id').isin(stops_minig))
stop_times_info = stop_times_zurich.where((column('stop_id').isin(stops_minig))&(column('trip_id') == trip_id))

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

In [48]:
mini_graph = nx.MultiDiGraph()

mini_nodes = stops_info.rdd.map(lambda r: (r[0], {'name': r['stop_name'],
                                              'lat': r['stop_lat'],
                                              'lon': r['stop_lon']})).collect()
mini_graph.add_nodes_from(mini_nodes)


# add artificial edge: 
mini_graph.add_edges_from([('8582462', '8572600',  {'duration': 2.0, 'time': 578.0}), 
                          ('8582462', '8572600',  {'duration': 1.0, 'time': 578.0}),
                          ('8572600', '8572601',  {'duration': 0.0, 'time': 579.0}),
                          ('8572601', '8502553',  {'duration': 4.0, 'time': 579.0}),
                          ('8502553', '8572602',  {'duration': 2.0, 'time': 583.0}),])

#mini_graph.add_edges_from([('8582462', '8502553',  {'duration': 2.0, 'time': 578.0})])


FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

[0, 1, 0, 0, 0]

In [49]:
list(mini_graph.nodes(data=True))

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

[(u'8582462', {'lat': 47.3475576701104, 'lon': 8.34819665008309, 'name': u'Bremgarten AG, Zelgli'}), (u'8502553', {'lat': 47.3221585583935, 'lon': 8.380473118246, 'name': u'Unterlunkhofen, Breiten\xe4cker'}), ('8572602', {}), (u'8572601', {'lat': 47.3417386266265, 'lon': 8.35463757067112, 'name': u'Zufikon, Algier'}), (u'8572600', {'lat': 47.34464822855, 'lon': 8.3519875405826, 'name': u'Zufikon, Emaus'})]

In [50]:
list(mini_graph.edges(data=True))

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

[(u'8582462', '8572600', {'duration': 2.0, 'time': 578.0}), (u'8582462', '8572600', {'duration': 1.0, 'time': 578.0}), (u'8502553', '8572602', {'duration': 2.0, 'time': 583.0}), (u'8572601', '8502553', {'duration': 4.0, 'time': 579.0}), (u'8572600', '8572601', {'duration': 0.0, 'time': 579.0})]

<img src="images/mini_graph.png" width = '350'>

##### Mini-graph 2: 

In [22]:
mini_graph_2 = nx.MultiDiGraph()
mini_graph_2.add_nodes_from(['0','1','2','3','4','5'])
mini_graph_2.add_edges_from([('0', '1',  {'duration': 1.0, 'time': 578.0})])
mini_graph_2.add_edges_from([('0', '2',  {'duration': 1.0, 'time': 579.0})])
mini_graph_2.add_edges_from([('0', '3',  {'duration': 2.0, 'time': 580.0})])
mini_graph_2.add_edges_from([('1', '4',  {'duration': 1.0, 'time': 583.0})])
mini_graph_2.add_edges_from([('1', '2',  {'duration': 1.0, 'time': 580.0})])
mini_graph_2.add_edges_from([('2', '5',  {'duration': 1.0, 'time': 550.0})])
mini_graph_2.add_edges_from([('3', '5',  {'duration': 1.0, 'time': 584.0})])
list(mini_graph_2.edges(data=True))

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

[('1', '2', {'duration': 1.0, 'time': 580.0}), ('1', '4', {'duration': 1.0, 'time': 583.0}), ('0', '1', {'duration': 1.0, 'time': 578.0}), ('0', '3', {'duration': 2.0, 'time': 580.0}), ('0', '2', {'duration': 1.0, 'time': 579.0}), ('3', '5', {'duration': 1.0, 'time': 584.0}), ('2', '5', {'duration': 1.0, 'time': 550.0})]

<img src="images/mini_graph_2.png" width="400">

Function to get the edges from a dijkstra path, problem is that dijkstra only gives nodes and not the edges it chooses: 

## 4. Paths in Network: Dijkstra's algorithms and helping functions

In [30]:
# returns weights of edge: 
def get_weight_custom(graph, source, target, j):
    attr = graph.edges[(source, target, j)]
    return attr['duration']

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

In [31]:
# returns weights of edge: 
def get_time_custom(graph, source, target, j):
    attr = graph.edges[(source, target, j)]
    return attr['time']

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

#### Normal  dijkstra:

In [32]:
def normal_dijkstra(G, first_source, paths=None, cutoff=None, last_target=None):
    
    G_succ = G.succ if G.is_directed() else G.adj
    paths = {first_source: [first_source]}

    push = heappush
    pop = heappop
    dist = {}  # dictionary of final distances
    
    # dictionnary of wthether it's the first time a node is visited
    seen = {first_source: 0}

    c = count()
    fringe = []  # use heapq with (distance,label) tuples
    push(fringe, (0, next(c), first_source))
    
    while fringe:
        #take the node to look at: 
        (d, _, source) = pop(fringe)
        
        # check if node has already been looked at: 
        if source in dist:
            continue  # already searched this node.
        
        # update the distance of the node
        dist[source] = d
        
        #stop if the node we look at is the target obviously
        if source == last_target:
            break
            
        # Look at all direct descendents from the source node: 
        for target, edges in G_succ[source].items():
            # Because it's a multigraph, need to look at all edges between two nodes:
            for edge_id in edges:
                
                # Get the duration between two nodes:
                cost = get_weight_custom(G, source, target, edge_id)
                
                if cost is None:
                        continue
                
                # Add the weight to the current distance of a node
                current_dist = dist[source] + get_weight_custom(G, source, target, edge_id)
                
                # if target has already been visited once and has a final distance:
                if target in dist:
                        # if we find a distance smaller than the actual distance in dic
                        # raise error because dic distances contains only final distances
                        if current_dist < dist[target]:
                            raise ValueError('Contradictory paths found:',
                                             'negative weights?')
                # either node node been seen before or the current distance is smaller than the 
                # proposed distance in seen[target]:
                elif target not in seen or current_dist < seen[target]:
                    # update the seen distance
                    seen[target] = current_dist
                    # push it onto the heap so that we will look at its descendants later
                    push(fringe, (current_dist, next(c), target))
                    
                    # update the paths till target:
                    if paths is not None:
                        paths[target] = paths[source] + [target]
    if paths is not None:
        return (dist, paths)
    return dist

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

Have a look at nodes that are not reachable from Zürich HB in a 15km radius:

In [36]:
paths = normal_dijkstra(graph_long, '8503000')
not_reachable = set(graph_long.nodes) - set(paths[0].keys())
not_reachable

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

set([u'8591025', u'8583692', u'8591027', u'8503896', u'8530647', u'8590314', u'8590665', u'8594310', u'8583855', u'8530696', u'8594307', u'8594304', u'8579869', u'8503988', u'8578679', u'8573171', u'8573176'])

In [37]:
stops.where(col('stop_id').isin(list(not_reachable))).show(truncate=False)

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

+-------+----------------------------+----------------+----------------+-------------+--------------+
|stop_id|stop_name                   |stop_lat        |stop_lon        |location_type|parent_station|
+-------+----------------------------+----------------+----------------+-------------+--------------+
|8503896|Horgenberg, Moorschwand     |47.2527236083507|8.58658257706307|             |              |
|8503988|Freudwil, Im Dörfli         |47.3747939482936|8.73427459294572|             |              |
|8530647|Fällanden (See)             |47.3657593422308|8.65270756513631|             |              |
|8530696|Mönchaltorf (See)           |47.3259825998046|8.69992301647621|             |              |
|8573171|Rifferswil, Unterrifferswil |47.2482356896554|8.49449627727516|             |              |
|8573176|Hausen am Albis, Vollenweid |47.2596193835131|8.51053120509893|             |              |
|8578679|Würenlos, Bettlen           |47.4459883921025|8.37087911101027|          

See that those trips are not reachable in the 15km radius because the go through stations outside the radius to then come back inside the raidus. We ignore those. 

### Dijkstra with time: 

Finds shortest paths between two stations and takes into acccount only paths taht are doable in time. 

##### Algorithm: (Implémenté plus bas) 

Input: 
- `first_source` : node from which we start path
- `last_target`: final target node
- `INPUT_TIME`: time from which we start algo

Output: (dist[last_target], e_paths[last_target])
- distances `first_source`to `last_target`
- shortest path to to `last_target`  from `first_source`

Mark all nodes unvisited: 
- dic of visited nodes: `seen = {}` 
- dic with final distances: `dist = {}`
- `queue = empty`

Start with the `first_source` node: 
- Assign to our source node the input_time as starting distance (instead of 0 as in normal Dijkstra): `dist[first_source] = INPUT_TIME`
- `queue += (d = INPUT_TIME, node = first_source)`


While the queue is not empty: 
- Take first node `source` in `queue` and its distance `d`
- Update `dist[source] = d`
- If `source == last_target of Dijkstra` : STOP 
- For all `target` in its direct neighbours: 
    - For all edges to `target` from `source` (multigraph):
      - if `departure_time[target] >= dist[source]` (e.g the time till the source is enough to take a new connection): `current_distance = departure_time[target]+ duration[target]`. On update la distance comme ça pour moi, parce que si on arrive à prendre cette connection on arrivera à un temps qui sera le temps de départ de la connection + la durée du trajet. Donc les distances dans notre dijkstra sont en fait le temps depuis miniuit jusqu'à l'arrivée à un node. 
      - if target not in `{seen}` (e.g. not been seen before) or  `current_dist < seen[target]` (e.g. there is a quicker way to neighbour): 
         - Update the seen distance to neighbour in `seen` dictionnary: `seen[target] = current_distance`
         - Push it onto queue so that we will look at its descendants later and update the final distance to target: `queue += (current_distance, target)`
         
Note: see code on how paths are updated. I have la flemme to add it. En plus, il doit y avoir un truc qu'on prend dans la queue le node avec la plus petite distance jusqu'à source mais je crois que c'est dans l'implémentation de heappop. 


##### Dijkstra from wiki: 
- Mark all nodes unvisited. Create a set of all the unvisited nodes called the unvisited set.
- Assign to every node a tentative distance value: set it to zero for our initial node and to infinity for all other nodes. Set the initial node as current.[14]
- For the current node, consider all of its unvisited neighbours and calculate their tentative distances through the current node. Compare the newly calculated tentative distance to the current assigned value and assign the smaller one. For example, if the current node A is marked with a distance of 6, and the edge connecting it with a neighbour B has length 2, then the distance to B through A will be 6 + 2 = 8. If B was previously marked with a distance greater than 8 then change it to 8. Otherwise, the current value will be kept.
- When we are done considering all of the unvisited neighbours of the current node, mark the current node as visited and remove it from the unvisited set. A visited node will never be checked again.
- If the destination node has been marked visited (when planning a route between two specific nodes) or if the smallest tentative distance among the nodes in the unvisited set is infinity (when planning a complete traversal; occurs when there is no connection between the initial node and remaining unvisited nodes), then stop. The algorithm has finished.
- Otherwise, select the unvisited node that is marked with the smallest tentative distance, set it as the new "current node", and go back to step 3.


##### A rajouter : 
- retourner les edges du path (mais j'avais déjà implemnté un truc pour dijkstra facile, faut voir si ça marche toujours pour cet algo)
- ajouter les probas
- tester sur un plus grand graph
- faire à l'envers comme Jules pensait ?? 
- extentsion to k shortest paths ? :
    - discussion on k shortest paths in multigraph: [source](https://groups.google.com/forum/#!topic/networkx-discuss/87uC9F0ug8Y)
    - paper on algorithms to find k-shortests paths: [source](https://www.ics.uci.edu/~eppstein/pubs/Epp-SJC-98.pdf)
    - Comment from forum: "I had a look at the single_source_dijkstra() function and with quite an easy adaptation I can get what I want. Instead of storing nodes in the path I would just store edges. Then I would return the list of edges. Then I would just implement David Eppstein's k-shortest paths algorithm. I didn't understand the comment about the algorithm not being implemented multigraphs because it slows down the algorithm too much. As if you look at the special case where a Multigraph is equal to Graph (i.e. only one edge) the two implementations perform the same. I agree that having Multigraph requires more work to find the shortest path as you then have more neighbors to investigate."


##### Implementation:

In [28]:
def minute_to_string(m):
    hour, minute = m//60, m - m//60
    time_string = '{}:{}:00'.format(hour, minute)
    
    return time_string

def string_to_minute(s):
    h, m, _ = s.split(':')
    h,m = int(h), int(m)
    
    return h*60+m

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

In [29]:
def dijkstra_time(G, first_source, INPUT_TIME, paths=None, last_target=None):
    
    G_succ = G.succ if G.is_directed() else G.adj
    
    paths = {first_source: [first_source]}
    e_paths = {first_source: []}

    push = heappush
    pop = heappop
    dist = {}  # dictionary of final distances
    
    # dictionnary of wthether it's the first time a node is visited
    seen = {first_source: INPUT_TIME}
    
    
    c = count()
    fringe = []  # use heapq with (distance,label) tuples
    
    #push(fringe, (0, next(c), first_source))
    push(fringe, (INPUT_TIME, next(c), first_source))
    
    while fringe:
        #take the node to look at: 
        (d, _, source) = pop(fringe)
        #print('Looking at node: '+source)
        
        # check if node has already been looked at: 
        if source in dist:
            continue  # already searched this node.
        
        # update the distance of the node
        dist[source] = d
        
        #stop if the node we look at is the target obviously
        if source == last_target:
            break
        
        # Look at all direct descendents from the source node: 
        for target, edges in G_succ[source].items():
            # Because it's a multigraph, need to look at all edges between two nodes:
            for edge_id in edges:
                dep_time_edge = get_time_custom(G, source, target, edge_id)
                if dep_time_edge == -1:
                    dep_time_edge = d
                
                # Note: checker si chgt de ligne faire +2min
                if dep_time_edge >= dist[source]:
                    # Get the duration between two nodes:
                    duration_cost = get_weight_custom(G, source, target, edge_id)

                    if duration_cost is None:
                            continue

                    # Add the weight to the current distance of a node
                    current_dist = dep_time_edge+ duration_cost

                    # if target has already been visited once and has a final distance:
                    if target in dist:
                            # if we find a distance smaller than the actual distance in dic
                            # raise error because dic distances contains only final distances
                            if current_dist < dist[target]:
                                raise ValueError('Contradictory paths found:',
                                                 'negative weights?')

                    # either node node been seen before or the current distance is smaller than the 
                    # proposed distance in seen[target]:
                    elif target not in seen or current_dist < seen[target]:
                        # update the seen distance
                        seen[target] = current_dist
                        # push it onto the heap so that we will look at its descendants later
                        push(fringe, (current_dist, next(c), target))

                        # update the paths till target:
                        if paths is not None:
                            #paths[target] = paths[source] + [target]
                            e_paths[target] = e_paths[source] + [(source, target, {'departure_time':dep_time_edge, 'duration':duration_cost})]
                            
    if  last_target not in e_paths:
        print('Error: No paths to the source')
        return (0, [])
        #raise ValueError('No paths exist to the source') 
    
    if paths is not None:
        #return (dist, paths, e_paths)
        #return (dist, e_paths)<
        
        #for _ in range(100):
            #Validate path
            #for e in path:
                #sample_gaussian
                #check if miss connection
            #If > 0 connection missed, path missed
        # if 95% must have missed < 5 path
        # if path not validated -> starts with smaller threshold 
        
        arrival_string = minute_to_string(dist[last_target])
        best_path = e_paths[last_target]
        departure_string = minute_to_string(best_path[0][2]['departure_time'])
        stations_id = map(lambda x: x[0], best_path)
        return (dist[last_target], e_paths[last_target])
    return dist

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

In [30]:
# TEST 1: si input time > temps de départ des edges: 
# output voulu: seulement la source, no path
dijkstra_time(G = mini_graph_2, first_source = '0', last_target = '5', INPUT_TIME = 585)

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

Error: No paths to the source
(0, [])

In [31]:
# TEST 2: input time < temps départ edges:
# devrait aller jusqu'au node 5 et donner le chemin à traver node 3 
dijkstra_time(G = mini_graph_2, first_source = '0', last_target = '5', INPUT_TIME = 570)

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

(585.0, [('0', '3', {'duration': 2.0, 'departure_time': 580.0}), ('3', '5', {'duration': 1.0, 'departure_time': 584.0})])

In [53]:
# TEST 3: sur des nodes qui sont des stations:
# devrait trouver un chemin rapide de 3 stations
dijkstra_time(G = mini_graph, first_source = '8582462', last_target = '8572602', INPUT_TIME = 570)

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

(585.0, [('8582462', '8572600', {'duration': 1.0, 'departure_time': 578.0}), ('8572600', '8572601', {'duration': 0.0, 'departure_time': 579.0}), ('8572601', '8502553', {'duration': 4.0, 'departure_time': 579.0}), ('8502553', '8572602', {'duration': 2.0, 'departure_time': 583.0})])

In [54]:
# Test 4 : sur network plus grand de Jules, devrait avoir le même output que juste ci-dessus

## Look at the same stops as before but in the bigger network: 
trip_id = '1.TA.1-231-j19-1.1.H'

#select four stops on that trip: 
stops_test4 = ['8582462','8572600','8572601','8502553']

dijkstra_time(G = graph, first_source = '8582462', last_target = '8572602', INPUT_TIME = 570)

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

(585.0, [('8582462', u'8572600', {'duration': 8.6414038086, 'departure_time': 570}), (u'8572600', u'8572601', {'duration': 0.0, 'departure_time': 579.0}), (u'8572601', u'8502553', {'duration': 4.0, 'departure_time': 579.0}), (u'8502553', u'8572602', {'duration': 2.0, 'departure_time': 583.0})])

In [55]:
# Old Test 5 (without the 2 minutes added)
print(dijkstra_time(G=graph, first_source='8503000', last_target='8591363', INPUT_TIME=9*60+30))
stops_zurich.where(col('stop_id').isin(['8503000:0:41/42', '8503020:0:3', '8591060', '8591177', '8591038', '8591145', '8591236', '8591203', '8591363'])).show(truncate=False)

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

(586.0, [('8503000', u'8503000:0:41/42', {'duration': 0.135259304, 'departure_time': 570}), (u'8503000:0:41/42', u'8503020:0:3', {'duration': 2.0, 'departure_time': 571.0}), (u'8503020:0:3', u'8591060', {'duration': 0.9932111359, 'departure_time': 573.0}), (u'8591060', u'8591177', {'duration': 2.0, 'departure_time': 575.0}), (u'8591177', u'8591038', {'duration': 2.0, 'departure_time': 577.0}), (u'8591038', u'8591145', {'duration': 1.0, 'departure_time': 579.0}), (u'8591145', u'8591236', {'duration': 1.0, 'departure_time': 581.0}), (u'8591236', u'8591203', {'duration': 2.0, 'departure_time': 583.0}), (u'8591203', u'8591363', {'duration': 1.0, 'departure_time': 585.0})])
+---------------+--------------------+----------------------------+----------------+----------------+-------------+--------------+
|stop_id        |zurich_distance     |stop_name                   |stop_lat        |stop_lon        |location_type|parent_station|
+---------------+--------------------+----------------------

In [64]:
# New Test 5, with graph not created from pickle but with complete process
print(dijkstra_time(G=graph_long, first_source='8503000', last_target='8591363', INPUT_TIME=9*60+30))
stops_zurich.where(col('stop_id').isin(['8503000:0:41/42', '8503020:0:3', '8591060', '8591177', '8591038', '8591145', '8591236', '8591203', '8591363'])).show(truncate=False)

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

(673.3953823855, [('8503000', u'8503088:0:21', {'duration': 4.0726364136, 'departure_time': 570}), (u'8503088:0:21', u'8591367', {'duration': 10.1753155518, 'departure_time': 574.0726364136}), (u'8591367', u'8591381', {'duration': 11.5476794434, 'departure_time': 584.2479519654}), (u'8591381', u'8503011:0:2', {'duration': 11.710758667, 'departure_time': 595.7956314088001}), (u'8503011:0:2', u'8591341', {'duration': 9.041846313499999, 'departure_time': 607.5063900758001}), (u'8591341', u'8502572', {'duration': 10.0800109863, 'departure_time': 616.5482363893001}), (u'8502572', u'8591390', {'duration': 8.4916168213, 'departure_time': 626.6282473756}), (u'8591390', u'8591170', {'duration': 9.0659631348, 'departure_time': 635.1198641969}), (u'8591170', u'8591208', {'duration': 7.4547485352, 'departure_time': 644.1858273317}), (u'8591208', u'8591203', {'duration': 11.3318994141, 'departure_time': 651.6405758669}), (u'8591203', u'8591363', {'duration': 10.4229071045, 'departure_time': 662.972