# Should reorder a bit notebook and clean it, and write basis of pipeline

1. Construct `stop_times_zurich` (now it uses the graph to get nodes, but unecessary), cache it
2. Planner, when user asks journey:
    1. Create graph (2 hours before arrival)
    2. Create 2nd graph with only stations and edges reachable from Zürich
    3. Compute paths
    4. Find shortest path

## 1. Setup:

First we setup our spark application and load the necessary data containing information about specific transit information:  
- `stop_times.txt` : arrival and departure times at stops
- `stops.txt` : information about stops 
- `trips.txt` : information about journeys
- `calendar.txt` : information about which servcies are active on which dates

#### Set up spark:

In [2]:
%%configure
{"conf": {
    "spark.app.name": "dslab-group_final"
}}

A session has already been started. If you intend to recreate the session with new configurations, please include the -f argument.


#### Imports:

In [3]:
import networkx as nx
from geopy.distance import distance as geo_distance
from pyspark.sql import Row
from pyspark.sql.functions import *
from pyspark.sql.types import FloatType
from networkx.algorithms.shortest_paths.weighted import dijkstra_path

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

#### Load data:

In [7]:
# Loading data, these are snapshots of the all available data
# Calendar and trips are useful to filter the other dataframe according to the day

stop_times = spark.read.format('orc').load('/data/sbb/timetables/orc/stop_times/000000_0')
stops = spark.read.format('orc').load('/data/sbb/timetables/orc/stops/000000_0')
trips = spark.read.format('orc').load('/data/sbb/timetables/orc/trips/000000_0')
calendar = spark.read.format('orc').load('/data/sbb/timetables/orc/calendar/000000_0')

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

## 2. Pre-processing: 

Here, we pre-process the data to correspond to the following criteria:
- only consider journeys at reasonable hours of the day, and on a typical business day, and assuming the schedule of May 13-17, 2019
- allow short (max 500m "As the Crows Flies") walking distances for transfers between two stations, and assume a walking speed of 50m/1min on a straight line, regardless of obstacles, human-built or natural, such as building, highways, rivers, or lakes
- only consider journeys that start and end on known station coordinates (train station, bus stops, etc.), never from a random location
- only consider stations in a 15km radius of Zürich's train station, Zürich HB (8503000), (lat, lon) = (47.378177, 8.540192)
- only consider stations in the 15km radius that are reachable from Zürich HB, either directly, or via transfers through other stations within the area
- assume that the timetables remain unchanged throughout the 2018 - 2019 period

Filtering criteria for later use (e.g. probability intervals and planner):
- assuming that delays or travel times on the public transport network are uncorrelated with one another
- once a route is computed, a traveller is expected to follow the planned routes to the end, or until it fails (i.e. miss a connection)
- planner will not need to mitigate the traveller's inconvenience if a plan fails

### Stop times during rush-hour: 

Only consider journeys at reasonable hours of the day, thus we take only stop times that are in the window of rush-hour (e.g. from 8 a.m. to 8 p.m.). 

In [8]:
# Filter stop_times to be only in 08:00-19:59:
stop_times = stop_times.where((col('departure_time') >= '08:00:00') 
                              & (col('departure_time') <= '19:59:59'))

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

### Stations around Zürich HB:

Only consider stations in a 15km radius of Zürich's train station (Zürich HB). 

First we get the geolocation of Zürich Hauptbahnhof to be able to calculate the distance of the other stations to the Hauptbahnhof. 

In [14]:
zurich_pos = stops.where(column('stop_name') == 'Zürich HB').select('stop_lat', 'stop_lon').collect()
zurich_pos = (zurich_pos[0][0], zurich_pos[0][1])
print('Location of Zürich Hauptbahnhof (lat, lon) :'+str(zurich_pos))

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

Location of Zürich Hauptbahnhof (lat, lon) :(47.3781762039461, 8.54019357578468)

In [15]:
def zurich_distance(x, y):
    """zurich_distance: returns the distance of a station to Zurich HB
    @input: (lat,lon) of a station
    @output: distance in km to Zurich HB
    """
    return geo_distance(zurich_pos, (x,y)).km

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

Then we create a dataframe `stops_zurich` of the stations where we add a column for the distance to Zurich HB. In that dataframe, we keep only those that are in a radisu of 15km to the HB. The same filter is applied to the `stop_times` df mentioned above. 

In [22]:
# filter stops:
stops_distance = stops.rdd.map(lambda x: (x['stop_id'], zurich_distance(x['stop_lat'], x['stop_lon'])))
stops_distance = spark.createDataFrame(stops_distance.map(lambda r: Row(stop_id=r[0], zurich_distance=r[1])))

stops_distance = stops_distance.filter(column('zurich_distance') <= 15)

# add distance to HB to stops info and keep only in radius of 15km
stops_zurich = stops_distance.join(stops, on='stop_id')

# keep onl stops in raidus of 15km
stop_times = stop_times.join(stops_distance.select('stop_id'), on='stop_id')

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

In [23]:
# Cache it to save time:
stop_times.cache()

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

DataFrame[stop_id: string, trip_id: string, arrival_time: string, departure_time: string, stop_sequence: smallint, pickup_type: tinyint, drop_off_type: tinyint]

### Have a look at the data we have so far: 

#### Stop times: 

In [24]:
stop_times.show(3)

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

+-------+--------------------+------------+--------------+-------------+-----------+-------------+
|stop_id|             trip_id|arrival_time|departure_time|stop_sequence|pickup_type|drop_off_type|
+-------+--------------------+------------+--------------+-------------+-----------+-------------+
|8500926|127.TA.26-301-j19...|    15:11:00|      15:11:00|            9|          0|            0|
|8500926|124.TA.26-301-j19...|    12:41:00|      12:41:00|            9|          0|            0|
|8500926|120.TA.26-301-j19...|    18:11:00|      18:11:00|            9|          0|            0|
+-------+--------------------+------------+--------------+-------------+-----------+-------------+
only showing top 3 rows

#### Trips:

In [25]:
trips.show(3)

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

+-----------+----------+--------------------+------------------+---------------+------------+
|   route_id|service_id|             trip_id|     trip_headsign|trip_short_name|direction_id|
+-----------+----------+--------------------+------------------+---------------+------------+
|1-1-C-j19-1|  TA+b0001|5.TA.1-1-C-j19-1.3.R|Zofingen, Altachen|            108|           1|
|1-1-C-j19-1|  TA+b0001|7.TA.1-1-C-j19-1.3.R|Zofingen, Altachen|            112|           1|
|1-1-C-j19-1|  TA+b0001|9.TA.1-1-C-j19-1.3.R|Zofingen, Altachen|            116|           1|
+-----------+----------+--------------------+------------------+---------------+------------+
only showing top 3 rows

#### Stops:

In [26]:
stops.show(3)

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

+-------+------------+----------------+----------------+-------------+--------------+
|stop_id|   stop_name|        stop_lat|        stop_lon|location_type|parent_station|
+-------+------------+----------------+----------------+-------------+--------------+
|1322000|    Altoggio|46.1672513851495|  8.345807131427|             |              |
|1322001|Antronapiana| 46.060121674738|8.11361957990831|             |              |
|1322002|      Anzola|45.9898698225697|8.34571729989858|             |              |
+-------+------------+----------------+----------------+-------------+--------------+
only showing top 3 rows

#### Calendar:

In [27]:
calendar.show(3)

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

+----------+------+-------+---------+--------+------+--------+------+
|service_id|monday|tuesday|wednesday|thursday|friday|saturday|sunday|
+----------+------+-------+---------+--------+------+--------+------+
|  TA+b0nx9|  true|   true|     true|    true|  true|   false| false|
|  TA+b03bf|  true|   true|     true|    true|  true|   false| false|
|  TA+b0008|  true|   true|     true|    true|  true|   false| false|
+----------+------+-------+---------+--------+------+--------+------+
only showing top 3 rows

## 3. Create a network:

In [8]:
days_dict = {0: 'monday', 1: 'tuesday', 2: 'wednesday', 3: 'thursday', 4: 'friday'}

def day_trips(*day_ids):
    days = [days_dict[day_id] for day_id in day_ids]
    where_clause = " and ".join(days)

    day_services = calendar.where(where_clause).select('service_id')
    return day_services.join(trips, on='service_id').select('trip_id')

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

In [9]:
graph = nx.MultiDiGraph()
nodes = stops_zurich.rdd.map(lambda r: (r[0], {'name': r['stop_name'],
                                              'lat': r['stop_lat'],
                                              'lon': r['stop_lon']})).collect()
graph.add_nodes_from(nodes)

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

In [10]:
@udf
def convertToMinute(s):
    h, m, _ = s.split(':')
    h,m = int(h), int(m)
    
    return h*60+m

# Keep only travels around zurich
nodes_list = list(graph.nodes())
stop_times_zurich = stop_times.filter(column('stop_id').isin(nodes_list))
print('Number of stop times around zurich:', stop_times_zurich.count())
# Convert time information to minutes elapsed since 0am
stop_times_zurich = stop_times_zurich.withColumn('arrival_time', convertToMinute(column('arrival_time')))
stop_times_zurich = stop_times_zurich.withColumn('departure_time', convertToMinute(column('departure_time')))
# Add next stop to dataframe
stop_times_zurich_2 = (stop_times_zurich.withColumn('stop_sequence_prev', column('stop_sequence')-1)
                       .select('trip_id',
                               column('stop_id').alias('next_stop'),
                               column('stop_sequence_prev').alias('stop_sequence'),
                               column('arrival_time').alias('next_arrival_time')))
# Add trip duration
stop_times_zurich = stop_times_zurich.join(stop_times_zurich_2, on=['trip_id', 'stop_sequence']).orderBy('trip_id', 'stop_sequence')
stop_times_zurich = stop_times_zurich.withColumn('trip_duration', column('next_arrival_time')-column('departure_time'))
stop_times_zurich = stop_times_zurich.select('trip_id', 'stop_id', 'arrival_time', 'departure_time', 'next_stop', 'trip_duration')
stop_times_zurich.cache()

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

('Number of stop times around zurich:', 1453316)
DataFrame[trip_id: string, stop_id: string, arrival_time: string, departure_time: string, next_stop: string, trip_duration: double]

In [11]:
max_trip_duration = 2 #duration in hour 
def filter_edge_on_time(edges_df, day_id, hour, minute):
    edges_df = edges_df.join(day_trips(day_id), on='trip_id')
    arrival_minute = hour*60+minute
    min_dep_time = arrival_minute - 60*60*max_trip_duration
    edges_df = edges_df.filter((col('departure_time') > min_dep_time) & 
                                            (col('arrival_time') <= arrival_minute))

    return edges_df

# Example of graph construction: Wednesday arrival at 11:30:00
edges = (filter_edge_on_time(stop_times_zurich, 2, 11, 30)
         .rdd.map(lambda r: (r['stop_id'], r['next_stop'], {'duration': r['trip_duration'],
                                                          'time': float(r['departure_time'])})).collect())
print('Number of edges:', len(edges))

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

('Number of edges:', 104449)

In [12]:
_ = graph.add_edges_from(edges)

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

In [13]:
stop_ids_zurich = stops_zurich.where(col('stop_name') == 'Zürich HB').select('stop_id').rdd.flatMap(lambda x: x).collect()
stop_ids_zurich

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

[u'8503000', u'8503000:0:10', u'8503000:0:11', u'8503000:0:12', u'8503000:0:13', u'8503000:0:14', u'8503000:0:15', u'8503000:0:16', u'8503000:0:17', u'8503000:0:18', u'8503000:0:3', u'8503000:0:31', u'8503000:0:32', u'8503000:0:33', u'8503000:0:34', u'8503000:0:4', u'8503000:0:41/42', u'8503000:0:43/44', u'8503000:0:5', u'8503000:0:6', u'8503000:0:7', u'8503000:0:8', u'8503000:0:9', u'8503000P']

In [14]:
nodes_reachable_from_zurich = set(stop_ids_zurich)

for id_ in stop_ids_zurich:
    nodes_reachable_from_zurich = nodes_reachable_from_zurich.union(nx.descendants(graph, id_))
    
new_edges = [e for e in graph.edges(data=True) if e[0] in nodes_reachable_from_zurich and e[1] in nodes_reachable_from_zurich]

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

In [15]:
graph_zurich = nx.MultiDiGraph()
graph_zurich.add_nodes_from(nodes_reachable_from_zurich)
_ = graph_zurich.add_edges_from(new_edges)

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

In [16]:
nx.descendants(graph_zurich, '8503000:0:10')

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

set([u'8503509:0:4', u'8503011:0:1', u'8503312:0:3', u'8503312:0:2', u'8503310:0:3', u'8503310:0:2', u'8503102:0:2', u'8503102:0:3', u'8503009:0:4', u'8503129:0:3', u'8503508:0:3', u'8503129:0:4', u'8503000:0:34', u'8503000:0:33', u'8503000:0:32', u'8503000:0:31', u'8503104:0:2', u'8503104:0:3', u'8503340:0:1', u'8502209:0:2', u'8503509:0:3', u'8503307:0:2', u'8503147:0:1', u'8503141:0:1', u'8503007:0:3', u'8503126:0:1', u'8503126:0:2', u'8503141:0:2', u'8503311:0:4', u'8503311:0:3', u'8502221:0:2', u'8503313:0:5', u'8503003:0:2', u'8503003:0:3', u'8503003:0:1', u'8503000:0:14', u'8503000:0:16', u'8503000:0:11', u'8503000:0:13', u'8503000:0:12', u'8503147:0:2', u'8503203:0:2', u'8503125:0:3', u'8503125:0:2', u'8503308:0:4', u'8502222:0:3', u'8503103:0:3', u'8503103:0:2', u'8503202:0:3', u'8503340:0:2', u'8503004:0:1', u'8503004:0:2', u'8503306:0:1', u'8503306:0:2', u'8503306:0:3', u'8503007:0:2', u'8503204:0:4', u'8503101:0:3', u'8503101:0:4', u'8503305:0:6', u'8503305:0:5', u'8503305:

In [17]:
dijkstra_path(graph_zurich, '8503000:0:10', '8502208:0:3', weight='duration')

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

['8503000:0:10', u'8503016:0:2', u'8503006:0:4', u'8503000:0:31', u'8503011:0:1', u'8503010:0:1', u'8503009:0:4', u'8503200:0:1', u'8503201:0:1', u'8503202:0:4', u'8502209:0:2', u'8502208:0:3']

In [18]:
gen = nx.all_simple_paths(graph_zurich, '8503000:0:10', '8502208:0:3')

for path in gen:
    print(path, flush=True)

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

An error was encountered:
invalid syntax (<stdin>, line 4)
  File "<stdin>", line 4
    print(path, flush=True)
                     ^
SyntaxError: invalid syntax



In [19]:
graph.out_edges('8503000:0:10', data=True)

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

OutMultiEdgeDataView([('8503000:0:10', u'8503006:0:6', {'duration': 6.0, 'time': 665.0}), ('8503000:0:10', u'8503006:0:6', {'duration': 6.0, 'time': 489.0}), ('8503000:0:10', u'8503016:0:2', {'duration': 9.0, 'time': 575.0})])

In [34]:
def filter_edges(edges, time):
    edges = [edge for edge in edges if edge[2]['time'] >= time]
    
    destinations = set([edge[1] for edge in edges])
    earliest_edges = []
    
    for destination in destinations:
        edges_to_dest = [edge for edge in edges if edge[1] == destination]
        earliest_edge = sorted(edges_to_dest, key=lambda edge: edge[2]['time'] + edge[2]['duration'])[0]
        earliest_edges.append(earliest_edge)
    
    return earliest_edges

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

In [49]:
def get_paths(graph, source, target, time, prev_path=[], num_hops=0, max_num_hops=4):
    if source == target:
        return [prev_path]
    elif num_hops >= max_num_hops:
        return None
    else:
        out_edges = graph.out_edges(source, data=True)
        out_edges = filter_edges(out_edges, time)
        # Should manually check
        print(time)
        print(out_edges)
        paths = []
        
        for out_edge in out_edges:
            new_paths = get_paths(graph, out_edge[1], target, time=out_edge[2]['time']+out_edge[2]['duration'], prev_path=prev_path+[source], num_hops=num_hops+1)
            if new_paths is not None:
                paths = paths + new_paths
                    
        return paths

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

In [37]:
out_edges = graph.out_edges('8503006:0:6', data=True)
out_edges = [edge for edge in out_edges]

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

In [40]:
out_edges[0]

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

('8503006:0:6', u'8503016:0:1', {'duration': 5.0, 'time': u'586'})

In [55]:
dijkstra_path(graph_zurich, '8503000:0:10', '8503011:0:1')

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

['8503000:0:10', u'8503016:0:2', u'8503006:0:4', u'8503000:0:5', u'8503011:0:1']

In [52]:
nx.descendants(graph_zurich, '8503000:0:10')

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

set([u'8503509:0:4', u'8503011:0:1', u'8503312:0:3', u'8503312:0:2', u'8503310:0:3', u'8503310:0:2', u'8503102:0:2', u'8503102:0:3', u'8503009:0:4', u'8503129:0:3', u'8503508:0:3', u'8503129:0:4', u'8503000:0:34', u'8503000:0:33', u'8503000:0:32', u'8503000:0:31', u'8503104:0:2', u'8503104:0:3', u'8503340:0:1', u'8502209:0:2', u'8503509:0:3', u'8503307:0:2', u'8503147:0:1', u'8503141:0:1', u'8503007:0:3', u'8503126:0:1', u'8503126:0:2', u'8503141:0:2', u'8503311:0:4', u'8503311:0:3', u'8502221:0:2', u'8503313:0:5', u'8503003:0:2', u'8503003:0:3', u'8503003:0:1', u'8503000:0:14', u'8503000:0:16', u'8503000:0:11', u'8503000:0:13', u'8503000:0:12', u'8503147:0:2', u'8503203:0:2', u'8503125:0:3', u'8503125:0:2', u'8503308:0:4', u'8502222:0:3', u'8503103:0:3', u'8503103:0:2', u'8503202:0:3', u'8503340:0:2', u'8503004:0:1', u'8503004:0:2', u'8503306:0:1', u'8503306:0:2', u'8503306:0:3', u'8503007:0:2', u'8503204:0:4', u'8503101:0:3', u'8503101:0:4', u'8503305:0:6', u'8503305:0:5', u'8503305:

In [50]:
get_paths(graph_zurich, '8503000:0:10', '8503202:0:4', 570, max_num_hops=6)

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

570
[('8503000:0:10', u'8503006:0:6', {'duration': 6.0, 'time': 665.0}), ('8503000:0:10', u'8503016:0:2', {'duration': 9.0, 'time': 575.0})]
671.0
[(u'8503006:0:6', u'8503016:0:1', {'duration': 5.0, 'time': 676.0}), (u'8503006:0:6', u'8503016:0:2', {'duration': 4.0, 'time': 682.0}), (u'8503006:0:6', u'8503020:0:2', {'duration': 4.0, 'time': 690.0})]
681.0
[]
686.0
[(u'8503016:0:2', u'8503307:0:2', {'duration': 4.0, 'time': 687.0})]
691.0
[]
694.0
[]
584.0
[(u'8503016:0:2', u'8503307:0:2', {'duration': 4.0, 'time': 597.0})]
601.0
[(u'8503307:0:2', u'8503305:0:2', {'duration': 5.0, 'time': 601.0})]
606.0
[(u'8503305:0:2', u'8503304:0:2', {'duration': 3.0, 'time': 617.0})]
[]

In [23]:
# Example for stop_times filtered on wednesday
stop_times_wed = day_trips(2).join(stop_times, on='trip_id')
stop_times_wed.show(5)

stop_times_wed.count()

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

404950

In [6]:
# Can't run, the count makes it timeout. I asked Tao why
#print('Full stop times have', stop_times.count(), 'entries, filtered has', stop_times_wed.count())

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

In [8]:
#for r in stops_zurich.collect():
#    if r['stop_id'] == '8503000':
#        print((r['stop_id'], {'name': r['stop_name'].encode('utf-8'), 'lat': r['stop_lat'], 'lon': r['stop_lon']}))
#        break

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

(u'8503000', {'lat': 47.3781762039461, 'lon': 8.54019357578468, 'name': 'Z\xc3\xbcrich HB'})