# Should reorder a bit notebook and clean it, and write basis of pipeline

1. Construct `stop_times_zurich` (now it uses the graph to get nodes, but unecessary), cache it
2. Planner, when user asks journey:
    1. Create graph (2 hours before arrival)
    2. Create 2nd graph with only stations and edges reachable from Zürich
    3. Compute paths
    4. Find shortest path

## Questions
1. Dijkstra
    1. Independent proba? doesn't really make sense
    2. Can we use our own algorithm? There are better algorithms to check different paths
    3. How is this similar to dijkstra if we not only consider the best paths?
2. Probability
    1. Check history, where? `spark.read.orc("hdfs:///data/sbb/orc/istdaten/").registerTempTable("actual")`
        1. Is `stop_id` same?
        2. Compute std for each stop, train_type and hour
    2. Compute std for each train type and hour and assume gaussian distribution around expected time?
3. Walking (small task)
    1. Walking are edges we compute at the beginning with data frame with station distance? (2 minutes + walking time)
    2. When checking possible paths we check edges in graph and walking edges. (Doesn't really make sense to put walking edges in network because they could be taken at any hour)

## Assumptions
- Trains always depart on time

## 1. Setup:

First we setup our spark application and load the necessary data containing information about specific transit information:  
- `stop_times.txt` : arrival and departure times at stops
- `stops.txt` : information about stops 
- `trips.txt` : information about journeys
- `calendar.txt` : information about which servcies are active on which dates

#### Set up spark:

In [16]:
%%configure
{"conf": {
    "spark.app.name": "dslab-group_final"
}}

A session has already been started. If you intend to recreate the session with new configurations, please include the -f argument.


#### Imports:

In [17]:
import networkx as nx
from geopy.distance import distance as geo_distance
from pyspark.sql import Row
import pyspark.sql.functions as f
from pyspark.sql.functions import *
from pyspark.sql.types import FloatType
from networkx.algorithms.shortest_paths.weighted import dijkstra_path

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

#### Load data:

In [18]:
# Loading data, these are snapshots of the all available data
# Calendar and trips are useful to filter the other dataframe according to the day

stop_times = spark.read.format('orc').load('/data/sbb/timetables/orc/stop_times/000000_0')
stops = spark.read.format('orc').load('/data/sbb/timetables/orc/stops/000000_0')
trips = spark.read.format('orc').load('/data/sbb/timetables/orc/trips/000000_0')
calendar = spark.read.format('orc').load('/data/sbb/timetables/orc/calendar/000000_0')

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

## 2. Pre-processing: 

Here, we pre-process the data to correspond to the following criteria:
- only consider journeys at reasonable hours of the day, and on a typical business day, and assuming the schedule of May 13-17, 2019
- allow short (max 500m "As the Crows Flies") walking distances for transfers between two stations, and assume a walking speed of 50m/1min on a straight line, regardless of obstacles, human-built or natural, such as building, highways, rivers, or lakes
- only consider journeys that start and end on known station coordinates (train station, bus stops, etc.), never from a random location
- only consider stations in a 15km radius of Zürich's train station, Zürich HB (8503000), (lat, lon) = (47.378177, 8.540192)
- only consider stations in the 15km radius that are reachable from Zürich HB, either directly, or via transfers through other stations within the area
- assume that the timetables remain unchanged throughout the 2018 - 2019 period

Filtering criteria for later use (e.g. probability intervals and planner):
- assuming that delays or travel times on the public transport network are uncorrelated with one another
- once a route is computed, a traveller is expected to follow the planned routes to the end, or until it fails (i.e. miss a connection)
- planner will not need to mitigate the traveller's inconvenience if a plan fails

### Stop times during rush-hour: 

Only consider journeys at reasonable hours of the day, thus we take only stop times that are in the window of rush-hour (e.g. from 8 a.m. to 8 p.m.). 

In [19]:
# Filter stop_times to be only in 08:00-19:59:
stop_times = stop_times.where((col('departure_time') >= '08:00:00') 
                              & (col('departure_time') <= '19:59:59'))

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

### Stations around Zürich HB:

Only consider stations in a 15km radius of Zürich's train station (Zürich HB). 

First we get the geolocation of Zürich Hauptbahnhof to be able to calculate the distance of the other stations to the Hauptbahnhof. 

In [20]:
zurich_pos = stops.where(column('stop_name') == 'Zürich HB').select('stop_lat', 'stop_lon').collect()
zurich_pos = (zurich_pos[0][0], zurich_pos[0][1])
print('Location of Zürich Hauptbahnhof (lat, lon) :'+str(zurich_pos))

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

Location of Zürich Hauptbahnhof (lat, lon) :(47.3781762039461, 8.54019357578468)

In [21]:
@udf("float")
def compute_distance(x1, y1, x2, y2):
    return geo_distance((x1, y1), (x2,y2)).m

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

In [22]:
def zurich_distance(x, y):
    """zurich_distance: returns the distance of a station to Zurich HB
    @input: (lat,lon) of a station
    @output: distance in km to Zurich HB
    """
    return geo_distance(zurich_pos, (x,y)).km

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

Then we create a dataframe `stops_zurich` of the stations where we add a column for the distance to Zurich HB. In that dataframe, we keep only those that are in a radisu of 15km to the HB. The same filter is applied to the `stop_times` df mentioned above. 

In [23]:
# filter stops:
stops_distance = stops.rdd.map(lambda x: (x['stop_id'], zurich_distance(x['stop_lat'], x['stop_lon'])))
stops_distance = spark.createDataFrame(stops_distance.map(lambda r: Row(stop_id=r[0], 
                                                                        zurich_distance=r[1])))

stops_distance = stops_distance.filter(column('zurich_distance') <= 15)

#print('There are '+str(stops_distance.count())+' stops in a radius of 15km around Zurich HB')

# add distance to HB to stops info and keep only in radius of 15km
stops_zurich = stops_distance.join(stops, on='stop_id')

# keep only stop times in radius of 15km of Zurich
stop_times_zurich = stop_times.join(stops_distance.select('stop_id'), on='stop_id')

#print('There are '+str(stop_times_zurich.count())+ ' stop times in a radius of 15km around  Zurich HB')

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

In [24]:
# Cache it to save time:
stop_times_zurich.cache()

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

DataFrame[stop_id: string, trip_id: string, arrival_time: string, departure_time: string, stop_sequence: smallint, pickup_type: tinyint, drop_off_type: tinyint]

### Have a look at the data we have so far: 

#### Stop times in Zurich: 

In [25]:
stop_times_zurich.show(3)

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

+-------+--------------------+------------+--------------+-------------+-----------+-------------+
|stop_id|             trip_id|arrival_time|departure_time|stop_sequence|pickup_type|drop_off_type|
+-------+--------------------+------------+--------------+-------------+-----------+-------------+
|8502508|9.TA.1-303-j19-1.2.R|    19:55:00|      19:55:00|            6|          0|            0|
|8502508|12.TA.1-303-j19-1...|    09:55:00|      09:55:00|            6|          0|            0|
|8502508|13.TA.1-303-j19-1...|    08:25:00|      08:25:00|            6|          0|            0|
+-------+--------------------+------------+--------------+-------------+-----------+-------------+
only showing top 3 rows

#### Trips in Zurich:

In [26]:
trips.show(3)

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

+-----------+----------+--------------------+------------------+---------------+------------+
|   route_id|service_id|             trip_id|     trip_headsign|trip_short_name|direction_id|
+-----------+----------+--------------------+------------------+---------------+------------+
|1-1-C-j19-1|  TA+b0001|5.TA.1-1-C-j19-1.3.R|Zofingen, Altachen|            108|           1|
|1-1-C-j19-1|  TA+b0001|7.TA.1-1-C-j19-1.3.R|Zofingen, Altachen|            112|           1|
|1-1-C-j19-1|  TA+b0001|9.TA.1-1-C-j19-1.3.R|Zofingen, Altachen|            116|           1|
+-----------+----------+--------------------+------------------+---------------+------------+
only showing top 3 rows

#### Stops in Zurich:

In [27]:
stops_zurich.show(3)

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

+-----------+------------------+--------------------+----------------+----------------+-------------+--------------+
|    stop_id|   zurich_distance|           stop_name|        stop_lat|        stop_lon|location_type|parent_station|
+-----------+------------------+--------------------+----------------+----------------+-------------+--------------+
|    8500926|11.510766966884365|Oetwil a.d.L., Sc...|47.4236270123012| 8.4031825286317|             |              |
|    8502186|10.798985488832079|Dietikon Stoffelbach|47.3934058321612|8.39894248049007|             |      8502186P|
|8502186:0:1|10.800041577194426|Dietikon Stoffelbach|47.3934666445388|8.39894248049007|             |      8502186P|
+-----------+------------------+--------------------+----------------+----------------+-------------+--------------+
only showing top 3 rows

#### Calendar:

In [28]:
calendar.show(3)

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

+----------+------+-------+---------+--------+------+--------+------+
|service_id|monday|tuesday|wednesday|thursday|friday|saturday|sunday|
+----------+------+-------+---------+--------+------+--------+------+
|  TA+b0nx9|  true|   true|     true|    true|  true|   false| false|
|  TA+b03bf|  true|   true|     true|    true|  true|   false| false|
|  TA+b0008|  true|   true|     true|    true|  true|   false| false|
+----------+------+-------+---------+--------+------+--------+------+
only showing top 3 rows

## 3. Create a network:

From the pre-processed data, we would like to create a directed network where each node is a station and each edge between two nodes corresponds to a possible trip. 

A node will have the following attributes:
- stop_name: name of the station (e.g. Zurich HB)
- latitude
- longitude

An directed edge will have the following attributes:
- stop_id: the id of the stop the (directed) edge points from
- next_stop: the id of the stop the edge points to
- duration: the duration of the trip from stop_id to next_stop
- departure time: the time from which the service departs from stop_id

First, we create a function that returns all services that operate on certain days (as some of them might not operate on weekends for example). 

In [29]:
days_dict = {0: 'monday', 1: 'tuesday', 2: 'wednesday', 3: 'thursday', 4: 'friday'}

def day_trips(*day_ids):
    """
    day_trips: gives the trip_ids that operate on certain days
    input: ? 
    output:s spark dataframe with trip_ids
    
    """
    days = [days_dict[day_id] for day_id in day_ids]
    where_clause = " and ".join(days)

    day_services = calendar.where(where_clause).select('service_id')
    return day_services.join(trips, on='service_id').select('trip_id')

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

### Nodes:

Then we create a **multigraph** (e.g. more than one edge allowed between two nodes). 

In [30]:
graph = nx.MultiDiGraph()

nodes = stops_zurich.rdd.map(lambda r: (r[0], {'name': r['stop_name'],
                                              'lat': r['stop_lat'],
                                              'lon': r['stop_lon']})).collect()
graph.add_nodes_from(nodes)

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

### Edges:

Right now, we have only nodes that are stations in our graph and we would like to add edges between them showing possible trips at certain times. For this we need to do a few operations. 

First, in our `stop_times_zurich` table, we have the time of arrival and departure but we would like to have an idea of the time elapsed in minutes since 12 p.m. This way times will be easily subtractable and we can get an idea of trip duration in minutes. So we convert those two columns: 

In [31]:
@udf
def convertToMinute(s):
    """converts seconds to minutes
    """
    h, m, _ = s.split(':')
    h,m = int(h), int(m)
    
    return h*60+m

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

In [32]:
# Convert time information to minutes elapsed since 0am
stop_times_zurich = stop_times_zurich.withColumn('arrival_time', 
                                                 convertToMinute(column('arrival_time')))
stop_times_zurich = stop_times_zurich.withColumn('departure_time', 
                                                 convertToMinute(column('departure_time')))
stop_times_zurich.show(3)

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

+-------+--------------------+------------+--------------+-------------+-----------+-------------+
|stop_id|             trip_id|arrival_time|departure_time|stop_sequence|pickup_type|drop_off_type|
+-------+--------------------+------------+--------------+-------------+-----------+-------------+
|8502508|9.TA.1-303-j19-1.2.R|        1195|          1195|            6|          0|            0|
|8502508|12.TA.1-303-j19-1...|         595|           595|            6|          0|            0|
|8502508|13.TA.1-303-j19-1...|         505|           505|            6|          0|            0|
+-------+--------------------+------------+--------------+-------------+-----------+-------------+
only showing top 3 rows

Then we want a dataframe that has the trip duration to the next stop from the current one on the trip. For that, we first create a table with the next stop and arrival time for each stop sequence in a trip. 

In [33]:
stop_times_zurich_2 = (stop_times_zurich.withColumn('stop_sequence_prev', column('stop_sequence')-1)
                       .select('trip_id',
                               column('stop_id').alias('next_stop'),
                               column('stop_sequence_prev').alias('stop_sequence'),
                               column('arrival_time').alias('next_arrival_time')))

stop_times_zurich_2.show(2)

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

+--------------------+---------+-------------+-----------------+
|             trip_id|next_stop|stop_sequence|next_arrival_time|
+--------------------+---------+-------------+-----------------+
|9.TA.1-303-j19-1.2.R|  8502508|            5|             1195|
|12.TA.1-303-j19-1...|  8502508|            5|              595|
+--------------------+---------+-------------+-----------------+
only showing top 2 rows

Then we join this to the `stop_times_zurich` table to have trip duration (in minutes) and next stop information. 

In [34]:
# Add trip duration and next stop: 
stop_times_zurich = stop_times_zurich.join(stop_times_zurich_2, 
                                           on=['trip_id', 'stop_sequence']).orderBy('trip_id', 'stop_sequence')
stop_times_zurich = stop_times_zurich.withColumn('trip_duration', 
                                                 column('next_arrival_time')-column('departure_time'))
stop_times_zurich = stop_times_zurich.select('trip_id', 
                                             'stop_id', 'arrival_time', 'departure_time', 
                                             'next_stop', 'trip_duration')
stop_times_zurich.show(2)

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

+--------------------+-------+------------+--------------+---------+-------------+
|             trip_id|stop_id|arrival_time|departure_time|next_stop|trip_duration|
+--------------------+-------+------------+--------------+---------+-------------+
|1.TA.1-231-j19-1.1.H|8582462|         578|           578|  8572600|          1.0|
|1.TA.1-231-j19-1.1.H|8572600|         579|           579|  8572601|          0.0|
+--------------------+-------+------------+--------------+---------+-------------+
only showing top 2 rows

In [35]:
#stop_times_zurich.cache()

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

## Routes

In [115]:
from pyspark.sql.types import IntegerType

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

In [116]:
routes = spark.read.format('orc').load('/data/sbb/timetables/orc/routes')

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

In [117]:
trip_id_route_desc = trips.join(routes, 'route_id').select(col('trip_id'), col('route_desc')).distinct().join(stop_times_zurich, 'trip_id')

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

In [118]:
translate_route_desc = {
    'TGV': 'TGV',
    'Eurocity': 'EC',
    'tandseilbahn': 'AT',
    'Regionalzug': 'R',
    'RegioExpress': 'RE',
    'S-Bahn': 'S',
    'Luftseilbahn': '',
    'Sesselbahn': '',
    'Taxi': '',
    'Fähre': '',
    'Tram': 'Tram',
    'ICE': 'ICE',
    'Bus': 'Bus',
    'Gondelbahn': '',
    'Nacht-Zug': '',
    'Standseilbahn': 'AT',
    'Auoreisezug': 'ARZ',
    'Eurostar': 'EC',
    'Schiff': '',
    'Schnellzug': 'TGV',
    'Intercity': 'IC',
    'InterRegio': 'IR',
    'Extrazug': 'EXT',
    'Metro': 'Metro'
}

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

In [119]:
@udf("string")
def translate_dict(text):
    return translate_route_desc[text]

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

In [120]:
%%local
import pandas as pd
delay_distribution = pd.read_pickle('pickle_delay_distribution')

In [121]:
%%send_to_spark -i delay_distribution -t df -m 1000000

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

Successfully passed 'delay_distribution' as 'delay_distribution' to Spark kernel

In [122]:
trip_id_route_desc = trip_id_route_desc.withColumn('route_desc_translated', translate_dict(col('route_desc')))\
                                       .withColumn('hour', (col('arrival_time')/60).cast(IntegerType())).cache()

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

In [123]:
delay_distribution = delay_distribution.select(col('mean'), col('std'), col('hour').alias('hour_2'), col('stop_id').alias('stop_id_2'), col('verkehrsmittel_text'))

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

In [124]:
stop_times_final = trip_id_route_desc.join(delay_distribution, (trip_id_route_desc.hour == delay_distribution.hour_2) &\
                                            (trip_id_route_desc.stop_id == delay_distribution.stop_id_2) &\
                                            (trip_id_route_desc.route_desc_translated == delay_distribution.verkehrsmittel_text), how='left').cache()

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

In [136]:
stop_times_final = stop_times_final.select('trip_id', 'stop_id', col('route_desc').alias('train_type'),\
                                           'arrival_time', 'departure_time', 'next_stop', 'trip_duration', 'mean', 'std')

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

To make our computations easier, we add a further assumption stating that we are interested only in trips between two stations that are done in under 2 hours. 

For this, we create a function that selects all trips on a certain week-day in a window of 2 hours before a certain arrival time. Thus if for example we want to arrive at 11:30, we are only interested in possible trips that departed 2 hours before 11:30.  

(**comment: je sais pas si c'est super pratique ça, parce que c'est possible qu'un utilisateur donne pas l'heure d'arrivée mais l'heure de départ**)

In [139]:
MAX_TRIP_DURATION = 2 #duration in hour 

def filter_edge_on_time(edges_df, day_id, hour, minute):
    """
    filter_edge_on_time: constructs edges (and thus trips) that exist in a window of two hours before a given input time
    @input:
    - edges_df: df from which we construct the edges
    - day_id: id of week-day (e.g. wednesday is day id 2, see dictionnary above)
    - hour, minute: time at which we want to arrive somewhere (e.g. 11:30)
    @output: data frame of selected edges
    """
    #select only the trips that occur on that day:
    edges_df= edges_df.join(day_trips(day_id), on='trip_id')
    
    arrival_minute = hour*60+minute
    min_dep_time = arrival_minute - 60*60*MAX_TRIP_DURATION
    
    #keep only those in a window of two hours:
    edges_df = edges_df.filter((col('departure_time') > min_dep_time) & 
                                            (col('arrival_time') <= arrival_minute))

    return edges_df

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

We test this out on Wednesday with arrival at 11:30 

In [140]:
# Example of graph construction: Wednesday arrival at 11:30:00
edges_wed_11_30_df = filter_edge_on_time(stop_times_final, 2, 11, 30)
edges_wed_11_30 = (edges_wed_11_30_df
                     .rdd.map(lambda r: (r['stop_id'], r['next_stop'], {'duration': r['trip_duration'],
                                                                        'time': float(r['departure_time']),
                                                                        'train_type': r['train_t'],
                                                                        'std': r['std'],
                                                                        'mean': r['mean']})).collect())
print('Number of edges:', len(edges_wed_11_30))

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

('Number of edges:', 104449)

# Walking edges (about 15 minutes to run, -> load pickle)

**Add** : allow short (max 500m "As the Crows Flies") walking distances for transfers between two stations, and assume a walking speed of 50m/1min on a straight line, regardless of obstacles, human-built or natural, such as building, highways, rivers, or lakes

In [75]:
stops_pos = stops.join(stops_distance, 'stop_id').select(col('stop_id'), col('stop_lat'), col('stop_lon'))
stops_pos = stops_pos.select(col('stop_id').alias('stop_id_1'), col('stop_lat').alias('stop_lat_1'), col('stop_lon').alias('stop_lon_1'))
stops_pos = stops_pos.crossJoin(stops_pos.select(col('stop_id_1').alias('stop_id_2'), col('stop_lat_1').alias('stop_lat_2'), col('stop_lon_1').alias('stop_lon_2')))
stops_pos = stops_pos.where(col('stop_id_1') != col('stop_id_2'))
stops_pos.show()

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

+---------+----------------+---------------+-----------+----------------+----------------+
|stop_id_1|      stop_lat_1|     stop_lon_1|  stop_id_2|      stop_lat_2|      stop_lon_2|
+---------+----------------+---------------+-----------+----------------+----------------+
|  8500926|47.4236270123012|8.4031825286317|    8502186|47.3934058321612|8.39894248049007|
|  8500926|47.4236270123012|8.4031825286317|8502186:0:1|47.3934666445388|8.39894248049007|
|  8500926|47.4236270123012|8.4031825286317|8502186:0:2|47.3935274568464|8.39894248049007|
|  8500926|47.4236270123012|8.4031825286317|   8502186P|47.3934058321612|8.39894248049007|
|  8500926|47.4236270123012|8.4031825286317|    8502187|47.3646945560768|8.37709545277724|
|  8500926|47.4236270123012|8.4031825286317|8502187:0:1|47.3647554015789|8.37709545277724|
|  8500926|47.4236270123012|8.4031825286317|8502187:0:2|47.3648162470108|8.37709545277724|
|  8500926|47.4236270123012|8.4031825286317|   8502187P|47.3646945560768|8.37709545277724|

In [76]:
stops_pos_dist = stops_pos.withColumn('distance', compute_distance(col('stop_lat_1'), col('stop_lon_1'), col('stop_lat_2'), col('stop_lon_2')))
stops_pos_dist.cache()
stops_pos_dist.show()

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

+---------+----------------+---------------+-----------+----------------+----------------+---------+
|stop_id_1|      stop_lat_1|     stop_lon_1|  stop_id_2|      stop_lat_2|      stop_lon_2| distance|
+---------+----------------+---------------+-----------+----------------+----------------+---------+
|  8500926|47.4236270123012|8.4031825286317|    8502186|47.3934058321612|8.39894248049007|3375.1602|
|  8500926|47.4236270123012|8.4031825286317|8502186:0:1|47.3934666445388|8.39894248049007|3368.4294|
|  8500926|47.4236270123012|8.4031825286317|8502186:0:2|47.3935274568464|8.39894248049007|3361.6992|
|  8500926|47.4236270123012|8.4031825286317|   8502186P|47.3934058321612|8.39894248049007|3375.1602|
|  8500926|47.4236270123012|8.4031825286317|    8502187|47.3646945560768|8.37709545277724|6841.6157|
|  8500926|47.4236270123012|8.4031825286317|8502187:0:1|47.3647554015789|8.37709545277724| 6835.137|
|  8500926|47.4236270123012|8.4031825286317|8502187:0:2|47.3648162470108|8.37709545277724|6

In [77]:
stops_pos_dist.count()

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

3532520

In [78]:
walking_edges = stops_pos_dist.select(col('stop_id_1').alias('source'), col('stop_id_2').alias('target'), col('distance'))\
                                        .where(col('distance') <= 500)\
                                        .withColumn('duration', col('distance')/50)

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

In [79]:
walking_edges.show()

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

+-----------+-----------+---------+-------------------+
|     source|     target| distance|           duration|
+-----------+-----------+---------+-------------------+
|    8500926|    8590616|122.61607| 2.4523214721679687|
|    8500926|    8590737| 300.6712|  6.013424072265625|
|    8502186|8502186:0:1|6.7610297|0.13522059440612794|
|    8502186|8502186:0:2|13.522052| 0.2704410362243652|
|    8502186|   8502186P|      0.0|                0.0|
|    8502186|    8502270|478.78476|  9.575695190429688|
|    8502186|8502270:0:1|485.54355|  9.710870971679688|
|    8502186|    8590200| 483.6799|  9.673598022460938|
|    8502186|    8590203|455.34915|  9.106983032226562|
|8502186:0:1|    8502186|6.7610297|0.13522059440612794|
|8502186:0:1|8502186:0:2| 6.761022| 0.1352204418182373|
|8502186:0:1|   8502186P|6.7610297|0.13522059440612794|
|8502186:0:1|    8502270| 472.0255|   9.44051025390625|
|8502186:0:1|8502270:0:1| 478.7842|  9.575684204101563|
|8502186:0:1|    8590200| 484.2558|  9.685115966

## Save to local

In [80]:
%%spark -o walking_edges -n 1000000

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

In [81]:
%%local
walking_edges.to_pickle('pickle_walking_times')

## Load from local

In [129]:
%%local
import pandas as pd
walking_edges = pd.read_pickle('pickle_walking_times')

In [130]:
%%send_to_spark -i walking_edges -t df -m 1000000

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

Successfully passed 'walking_edges' as 'walking_edges' to Spark kernel

In [131]:
walking_edges

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

DataFrame[distance: double, duration: double, source: string, target: string]

## Write to pickle edges to create graph

In [None]:
select('trip_id', 'stop_id', 'route_desc', 'arrival_time', 'departure_time', 'next_stop', 'trip_duration', 'mean', 'std')

In [143]:
edges_wed_11_30_df

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

DataFrame[trip_id: string, stop_id: string, train_type: string, arrival_time: string, departure_time: string, next_stop: string, trip_duration: double, mean: double, std: double]

In [141]:
edges_walk_pickle = walking_edges.select(col('source'), col('target'), col('duration').alias('duration'))\
                                 .withColumn('time', lit(-1))\
                                 .withColumn('mean', lit(.0))\
                                 .withColumn('std', lit(.0))\
                                 .withColumn('train_type', lit('walking'))

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

In [144]:
edges_pickle = edges_wed_11_30_df.select(col('stop_id').alias('source'),\
                                         col('next_stop').alias('target'),\
                                         col('trip_duration').alias('duration'),\
                                         col('departure_time').alias('time'),\
                                         col('mean').alias('mean'),\
                                         col('std').alias('std'),\
                                         col('train_type').alias('train_type'))

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

In [145]:
edges_pickle = edges_pickle.union(edges_walk_pickle)

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

In [146]:
edges_pickle.count()

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

116983

In [147]:
edges_pickle.show()

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

+-----------+-------+--------+----+--------------+--------------+----------+
|     source| target|duration|time|          mean|           std|train_type|
+-----------+-------+--------+----+--------------+--------------+----------+
|    8580301|8588553|     1.0| 516| 92.3712036402|242.0808686425|       Bus|
|    8573228|8573226|     1.0| 528| 68.5944875108| 81.7246172461|       Bus|
|    8588553|8573211|     3.0| 517| 92.4037229713|239.4063539951|       Bus|
|    8573232|8573230|     3.0| 522| 77.4878024691| 76.9400443276|       Bus|
|    8573211|8573232|     2.0| 520| 82.3203154337| 99.9239625875|       Bus|
|    8506895|8573228|     1.0| 527| 66.9796543032| 83.1064189911|       Bus|
|    8573226|8506889|     2.0| 529| 53.8281865285|  81.727450431|       Bus|
|    8506889|8506897|     1.0| 531| 67.7848460931| 77.5835308172|       Bus|
|    8573230|8573229|     1.0| 525|  85.143594306| 89.7685750527|       Bus|
|8573205:0:H|8580301|     0.0| 516|          null|          null|       Bus|

## Save to local

In [148]:
%%spark -o edges_pickle -n 1000000

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

In [149]:
%%local
edges_pickle.to_pickle('edges_graph_with_walk')

## Load from local

In [57]:
%%local
import pandas as pd
edges_pickle = pd.read_pickle('edges_graph_with_walk')

In [58]:
%%send_to_spark -i edges_pickle -t df -m 1000000

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

Successfully passed 'test2' as 'test2' to Spark kernel

In [72]:
graph = nx.from_pandas_edgelist(edges_pickle.toPandas(), edge_attr=['time', 'duration'], create_using=nx.MultiDiGraph)

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

In [73]:
out_edges = graph.edges('8503006:0:6', data=True)
out_edges = [edge for edge in out_edges]

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

In [74]:
out_edges[0]

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

('8503006:0:6', u'8591063', {'duration': -1.0, 'time': u'4.1365289307'})

In [None]:
walking_edges_wed = walking_edges\
                    .rdd.map(lambda r: (r['source'], r['target'], {'duration': r['duration'],
                                                                  'time': -1})).collect()

In [None]:
_ = graph.add_edges_from(walking_edges_wed)

# Filter to stations, reachable from Zürich


Filter Amdahl for now

In [None]:
stop_ids_zurich = stops_zurich.where(col('stop_name') == 'Zürich HB').select('stop_id').rdd.flatMap(lambda x: x).collect()
stop_ids_zurich

In [None]:
nodes_reachable_from_zurich = set(stop_ids_zurich)

for id_ in stop_ids_zurich:
    nodes_reachable_from_zurich = nodes_reachable_from_zurich.union(nx.descendants(graph, id_))
    
new_edges = [e for e in graph.edges(data=True) if e[0] in nodes_reachable_from_zurich and e[1] in nodes_reachable_from_zurich]

In [None]:
graph_zurich = nx.MultiDiGraph()
graph_zurich.add_nodes_from(nodes_reachable_from_zurich)
_ = graph_zurich.add_edges_from(new_edges)

## 4. Search for paths: 
BROUILLON DEPUIS LA

In [44]:
from heapq import heappush, heappop
from itertools import count

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

In [46]:
# returns weights of edge: 
def get_time_custom(graph, source, target, j):
    attr = graph.edges[(source, target, j)]
    return attr['time']

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

In [48]:
# returns weights of edge: 
def get_weight_custom(graph, source, target, j):
    attr = graph.edges[(source, target, j)]
    return attr['duration']

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

In [64]:
def dijkstra_time(G, first_source, INPUT_TIME, paths=None, last_target=None):
    
    G_succ = G.succ if G.is_directed() else G.adj
    
    paths = {first_source: [first_source]}
    e_paths = {first_source: []}

    push = heappush
    pop = heappop
    dist = {}  # dictionary of final distances
    
    # dictionnary of wthether it's the first time a node is visited
    seen = {first_source: INPUT_TIME}
    
    
    c = count()
    fringe = []  # use heapq with (distance,label) tuples
    
    #push(fringe, (0, next(c), first_source))
    push(fringe, (INPUT_TIME, next(c), first_source))
    
    while fringe:
        #take the node to look at: 
        (d, _, source) = pop(fringe)
        #print('Looking at node: '+source)
        
        # check if node has already been looked at: 
        if source in dist:
            continue  # already searched this node.
        
        # update the distance of the node
        dist[source] = d
        
        #stop if the node we look at is the target obviously
        if source == last_target:
            break
            
        # Look at all direct descendents from the source node: 
        for target, edges in G_succ[source].items():
            # Because it's a multigraph, need to look at all edges between two nodes:
            for edge_id in edges:
                
                dep_time_edge = get_time_custom(G, source, target, edge_id)
                
                # NOTE : A DECIDER SI ON VEUT >= ou >, certains bus continuent
                # en 0' donc ça doit être >= pour moi: 
                
                if dep_time_edge >= dist[source]:
                    # Get the duration between two nodes:
                    duration_cost = get_weight_custom(G, source, target, edge_id)

                    if duration_cost is None:
                            continue

                    # Add the weight to the current distance of a node
                    current_dist = dep_time_edge+ duration_cost

                    # if target has already been visited once and has a final distance:
                    if target in dist:
                            # if we find a distance smaller than the actual distance in dic
                            # raise error because dic distances contains only final distances
                            if current_dist < dist[target]:
                                raise ValueError('Contradictory paths found:',
                                                 'negative weights?')

                    # either node node been seen before or the current distance is smaller than the 
                    # proposed distance in seen[target]:
                    elif target not in seen or current_dist < seen[target]:
                        # update the seen distance
                        seen[target] = current_dist
                        # push it onto the heap so that we will look at its descendants later
                        push(fringe, (current_dist, next(c), target))

                        # update the paths till target:
                        if paths is not None:
                            #paths[target] = paths[source] + [target]
                            e_paths[target] = e_paths[source] + [(source, target, {'departure_time':dep_time_edge, 'duration':duration_cost})]
    
    if paths is not None:
        return dist
    return dist

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

In [None]:
nx.descendants(graph_zurich, '8503000:0:10')

In [None]:
dijkstra_path(graph_zurich, '8503000:0:10', '8502208:0:3', weight='duration')

In [None]:
gen = nx.all_simple_paths(graph_zurich, '8503000:0:10', '8502208:0:3')

for path in gen:
    print(path, flush=True)

In [None]:
graph.out_edges('8503000:0:10', data=True)

In [None]:
def filter_edges(edges, time):
    edges = [edge for edge in edges if edge[2]['time'] >= time]
    
    destinations = set([edge[1] for edge in edges])
    earliest_edges = []
    
    for destination in destinations:
        edges_to_dest = [edge for edge in edges if edge[1] == destination]
        earliest_edge = sorted(edges_to_dest, key=lambda edge: edge[2]['time'] + edge[2]['duration'])[0]
        earliest_edges.append(earliest_edge)
    
    return earliest_edges

In [None]:
def get_paths(graph, source, target, time, prev_path=[], num_hops=0, max_num_hops=4):
    if source == target:
        return [prev_path]
    elif num_hops >= max_num_hops:
        return None
    else:
        out_edges = graph.out_edges(source, data=True)
        out_edges = filter_edges(out_edges, time)
        # Should manually check
        print(time)
        print(out_edges)
        paths = []
        
        for out_edge in out_edges:
            new_paths = get_paths(graph, out_edge[1], target, time=out_edge[2]['time']+out_edge[2]['duration'], prev_path=prev_path+[source], num_hops=num_hops+1)
            if new_paths is not None:
                paths = paths + new_paths
                    
        return paths

In [None]:
out_edges = graph.out_edges('8503006:0:6', data=True)
out_edges = [edge for edge in out_edges]

In [None]:
out_edges[0]

In [None]:
dijkstra_path(graph_zurich, '8503000:0:10', '8503011:0:1')

In [None]:
nx.descendants(graph_zurich, '8503000:0:10')

In [None]:
get_paths(graph_zurich, '8503000:0:10', '8503202:0:4', 570, max_num_hops=6)

In [None]:
# Example for stop_times filtered on wednesday
stop_times_wed = day_trips(2).join(stop_times, on='trip_id')
stop_times_wed.show(5)

stop_times_wed.count()

In [None]:
# Can't run, the count makes it timeout. I asked Tao why
#print('Full stop times have', stop_times.count(), 'entries, filtered has', stop_times_wed.count())

In [None]:
#for r in stops_zurich.collect():
#    if r['stop_id'] == '8503000':
#        print((r['stop_id'], {'name': r['stop_name'].encode('utf-8'), 'lat': r['stop_lat'], 'lon': r['stop_lon']}))
#        break