# SBB planner 2.0.

1. Construct `stop_times_zurich` (now it uses the graph to get nodes, but unecessary), cache it
2. Planner, when user asks journey:
    1. Create graph (2 hours before arrival)
    2. Create 2nd graph with only stations and edges reachable from Zürich
    3. Compute paths
    4. Find shortest path

## 1. Setup:

First we setup our spark application and load the necessary data containing information about specific transit information:  
- `stop_times.txt` : arrival and departure times at stops
- `stops.txt` : information about stops 
- `trips.txt` : information about journeys
- `calendar.txt` : information about which servcies are active on which dates

#### Set up spark:

In [84]:
%%configure
{"conf": {
    "spark.app.name": "dslab-group_final"
}}

A session has already been started. If you intend to recreate the session with new configurations, please include the -f argument.


#### Imports:

In [85]:
import networkx as nx
from geopy.distance import distance as geo_distance
from pyspark.sql import Row
from pyspark.sql.functions import *
from pyspark.sql.types import FloatType
from networkx.algorithms.shortest_paths.weighted import dijkstra_path

An error was encountered:
Invalid status code '404' from http://iccluster044.iccluster.epfl.ch:8998/sessions/5878 with error payload: "Session '5878' not found."


#### Load data:

In [49]:
# Loading data, these are snapshots of the all available data
# Calendar and trips are useful to filter the other dataframe according to the day

stop_times = spark.read.format('orc').load('/data/sbb/timetables/orc/stop_times/000000_0')
stops = spark.read.format('orc').load('/data/sbb/timetables/orc/stops/000000_0')
trips = spark.read.format('orc').load('/data/sbb/timetables/orc/trips/000000_0')
calendar = spark.read.format('orc').load('/data/sbb/timetables/orc/calendar/000000_0')

An error was encountered:
Invalid status code '404' from http://iccluster044.iccluster.epfl.ch:8998/sessions/5878 with error payload: "Session '5878' not found."


## 2. Pre-processing: 

Here, we pre-process the data to correspond to the following criteria:
- only consider journeys at reasonable hours of the day, and on a typical business day, and assuming the schedule of May 13-17, 2019
- allow short (max 500m "As the Crows Flies") walking distances for transfers between two stations, and assume a walking speed of 50m/1min on a straight line, regardless of obstacles, human-built or natural, such as building, highways, rivers, or lakes
- only consider journeys that start and end on known station coordinates (train station, bus stops, etc.), never from a random location
- only consider stations in a 15km radius of Zürich's train station, Zürich HB (8503000), (lat, lon) = (47.378177, 8.540192)
- only consider stations in the 15km radius that are reachable from Zürich HB, either directly, or via transfers through other stations within the area
- assume that the timetables remain unchanged throughout the 2018 - 2019 period

Filtering criteria for later use (e.g. probability intervals and planner):
- assuming that delays or travel times on the public transport network are uncorrelated with one another
- once a route is computed, a traveller is expected to follow the planned routes to the end, or until it fails (i.e. miss a connection)
- planner will not need to mitigate the traveller's inconvenience if a plan fails

### Stop times during rush-hour: 

Only consider journeys at reasonable hours of the day, thus we take only stop times that are in the window of rush-hour (e.g. from 8 a.m. to 8 p.m.). 

In [50]:
# Filter stop_times to be only in 08:00-19:59:
stop_times = stop_times.where((col('departure_time') >= '08:00:00') 
                              & (col('departure_time') <= '19:59:59'))

An error was encountered:
Invalid status code '404' from http://iccluster044.iccluster.epfl.ch:8998/sessions/5878 with error payload: "Session '5878' not found."


### Stations around Zürich HB:

Only consider stations in a 15km radius of Zürich's train station (Zürich HB). 

First we get the geolocation of Zürich Hauptbahnhof to be able to calculate the distance of the other stations to the Hauptbahnhof. 

In [51]:
zurich_pos = stops.where(column('stop_name') == 'Zürich HB').select('stop_lat', 'stop_lon').collect()
zurich_pos = (zurich_pos[0][0], zurich_pos[0][1])
print('Location of Zürich Hauptbahnhof (lat, lon) :'+str(zurich_pos))

An error was encountered:
Invalid status code '404' from http://iccluster044.iccluster.epfl.ch:8998/sessions/5878 with error payload: "Session '5878' not found."


In [52]:
def zurich_distance(x, y):
    """zurich_distance: returns the distance of a station to Zurich HB
    @input: (lat,lon) of a station
    @output: distance in km to Zurich HB
    """
    return geo_distance(zurich_pos, (x,y)).km

An error was encountered:
Invalid status code '404' from http://iccluster044.iccluster.epfl.ch:8998/sessions/5878 with error payload: "Session '5878' not found."


Then we create a dataframe `stops_zurich` of the stations where we add a column for the distance to Zurich HB. In that dataframe, we keep only those that are in a radisu of 15km to the HB. The same filter is applied to the `stop_times` df mentioned above. 

In [53]:
# filter stops:
stops_distance = stops.rdd.map(lambda x: (x['stop_id'], zurich_distance(x['stop_lat'], x['stop_lon'])))
stops_distance = spark.createDataFrame(stops_distance.map(lambda r: Row(stop_id=r[0], 
                                                                        zurich_distance=r[1])))

stops_distance = stops_distance.filter(column('zurich_distance') <= 15)

#print('There are '+str(stops_distance.count())+' stops in a radius of 15km around Zurich HB')

# add distance to HB to stops info and keep only in radius of 15km
stops_zurich = stops_distance.join(stops, on='stop_id')

# keep only stop times in radius of 15km of Zurich
stop_times_zurich = stop_times.join(stops_distance.select('stop_id'), on='stop_id')

#print('There are '+str(stop_times_zurich.count())+ ' stop times in a radius of 15km around  Zurich HB')

An error was encountered:
Invalid status code '404' from http://iccluster044.iccluster.epfl.ch:8998/sessions/5878 with error payload: "Session '5878' not found."


In [54]:
# Cache it to save time:
stop_times_zurich.cache()

An error was encountered:
Invalid status code '404' from http://iccluster044.iccluster.epfl.ch:8998/sessions/5878 with error payload: "Session '5878' not found."


### Have a look at the data we have so far: 

#### Stop times in Zurich: 

In [55]:
stop_times_zurich.show(3)

An error was encountered:
Invalid status code '404' from http://iccluster044.iccluster.epfl.ch:8998/sessions/5878 with error payload: "Session '5878' not found."


#### Trips in Zurich:

In [56]:
trips.show(3)

An error was encountered:
Invalid status code '404' from http://iccluster044.iccluster.epfl.ch:8998/sessions/5878 with error payload: "Session '5878' not found."


#### Stops in Zurich:

In [57]:
stops_zurich.show(3)

An error was encountered:
Invalid status code '404' from http://iccluster044.iccluster.epfl.ch:8998/sessions/5878 with error payload: "Session '5878' not found."


#### Calendar:

In [58]:
calendar.show(3)

An error was encountered:
Invalid status code '404' from http://iccluster044.iccluster.epfl.ch:8998/sessions/5878 with error payload: "Session '5878' not found."


## 3. Create a network:

From the pre-processed data, we would like to create a directed network where each node is a station and each edge between two nodes corresponds to a possible trip. 

A node will have the following attributes:
- stop_name: name of the station (e.g. Zurich HB)
- latitude
- longitude

An directed edge will have the following attributes:
- stop_id: the id of the stop the (directed) edge points from
- next_stop: the id of the stop the edge points to
- duration: the duration of the trip from stop_id to next_stop
- departure time: the time from which the service departs from stop_id

First, we create a function that returns all services that operate on certain days (as some of them might not operate on weekends for example). 

In [59]:
days_dict = {0: 'monday', 1: 'tuesday', 2: 'wednesday', 3: 'thursday', 4: 'friday'}

def day_trips(*day_ids):
    """
    day_trips: gives the trip_ids that operate on certain days
    input: ? 
    output:s spark dataframe with trip_ids
    
    """
    days = [days_dict[day_id] for day_id in day_ids]
    where_clause = " and ".join(days)

    day_services = calendar.where(where_clause).select('service_id')
    return day_services.join(trips, on='service_id').select('trip_id')

An error was encountered:
Invalid status code '404' from http://iccluster044.iccluster.epfl.ch:8998/sessions/5878 with error payload: "Session '5878' not found."


### Nodes:

Then we create a **multigraph** (e.g. more than one edge allowed between two nodes). 

In [60]:
graph = nx.MultiDiGraph()

nodes = stops_zurich.rdd.map(lambda r: (r[0], {'name': r['stop_name'],
                                              'lat': r['stop_lat'],
                                              'lon': r['stop_lon']})).collect()
graph.add_nodes_from(nodes)

An error was encountered:
Invalid status code '404' from http://iccluster044.iccluster.epfl.ch:8998/sessions/5878 with error payload: "Session '5878' not found."


### Edges:

Right now, we have only nodes that are stations in our graph and we would like to add edges between them showing possible trips at certain times. For this we need to do a few operations. 

First, in our `stop_times_zurich` table, we have the time of arrival and departure but we would like to have an idea of the time elapsed in minutes since 12 p.m. This way times will be easily subtractable and we can get an idea of trip duration in minutes. So we convert those two columns: 

In [61]:
@udf
def convertToMinute(s):
    """converts seconds to minutes
    """
    h, m, _ = s.split(':')
    h,m = int(h), int(m)
    
    return h*60+m

An error was encountered:
Invalid status code '404' from http://iccluster044.iccluster.epfl.ch:8998/sessions/5878 with error payload: "Session '5878' not found."


In [62]:
# Convert time information to minutes elapsed since 0am
stop_times_zurich = stop_times_zurich.withColumn('arrival_time', 
                                                 convertToMinute(column('arrival_time')))
stop_times_zurich = stop_times_zurich.withColumn('departure_time', 
                                                 convertToMinute(column('departure_time')))
stop_times_zurich.show(3)

An error was encountered:
Invalid status code '404' from http://iccluster044.iccluster.epfl.ch:8998/sessions/5878 with error payload: "Session '5878' not found."


Then we want a dataframe that has the trip duration to the next stop from the current one on the trip. For that, we first create a table with the next stop and arrival time for each stop sequence in a trip. 

In [63]:
stop_times_zurich_2 = (stop_times_zurich.withColumn('stop_sequence_prev', column('stop_sequence')-1)
                       .select('trip_id',
                               column('stop_id').alias('next_stop'),
                               column('stop_sequence_prev').alias('stop_sequence'),
                               column('arrival_time').alias('next_arrival_time')))

stop_times_zurich_2.show(2)

An error was encountered:
Invalid status code '404' from http://iccluster044.iccluster.epfl.ch:8998/sessions/5878 with error payload: "Session '5878' not found."


Then we join this to the `stop_times_zurich` table to have trip duration (in minutes) and next stop information. 

In [64]:
# Add trip duration and next stop: 
stop_times_zurich = stop_times_zurich.join(stop_times_zurich_2, 
                                           on=['trip_id', 'stop_sequence']).orderBy('trip_id', 'stop_sequence')
stop_times_zurich = stop_times_zurich.withColumn('trip_duration', 
                                                 column('next_arrival_time')-column('departure_time'))
stop_times_zurich = stop_times_zurich.select('trip_id', 
                                             'stop_id', 'arrival_time', 'departure_time', 
                                             'next_stop', 'trip_duration')
stop_times_zurich.show(2)

An error was encountered:
Invalid status code '404' from http://iccluster044.iccluster.epfl.ch:8998/sessions/5878 with error payload: "Session '5878' not found."


In [65]:
stop_times_zurich.cache()

An error was encountered:
Invalid status code '404' from http://iccluster044.iccluster.epfl.ch:8998/sessions/5878 with error payload: "Session '5878' not found."


To make our computations easier, we add a further assumption stating that we are interested only in trips between two stations that are done in under 2 hours. 

For this, we create a function that selects all trips on a certain week-day in a window of 2 hours before a certain arrival time. Thus if for example we want to arrive at 11:30, we are only interested in possible trips that departed 2 hours before 11:30.  

(**comment: je sais pas si c'est super pratique ça, parce que c'est possible qu'un utilisateur donne pas l'heure d'arrivée mais l'heure de départ**)

In [66]:
MAX_TRIP_DURATION = 2 #duration in hour 

def filter_edge_on_time(edges_df, day_id, hour, minute):
    """
    filter_edge_on_time: constructs edges (and thus trips) that exist in a window of two hours before a given input time
    @input:
    - edges_df: df from which we construct the edges
    - day_id: id of week-day (e.g. wednesday is day id 2, see dictionnary above)
    - hour, minute: time at which we want to arrive somewhere (e.g. 11:30)
    @output: data frame of selected edges
    """
    #select only the trips that occur on that day:
    edges_df= edges_df.join(day_trips(day_id), on='trip_id')
    
    arrival_minute = hour*60+minute
    min_dep_time = arrival_minute - 60*60*MAX_TRIP_DURATION
    
    #keep only those in a window of two hours:
    edges_df = edges_df.filter((col('departure_time') > min_dep_time) & 
                                            (col('arrival_time') <= arrival_minute))

    return edges_df

An error was encountered:
Invalid status code '404' from http://iccluster044.iccluster.epfl.ch:8998/sessions/5878 with error payload: "Session '5878' not found."


We test this out on Wednesday with arrival at 11:30 

In [67]:
# Example of graph construction: Wednesday arrival at 11:30:00
edges_wed_11_30 = (filter_edge_on_time(stop_times_zurich, 2, 11, 30)
         .rdd.map(lambda r: (r['stop_id'], r['next_stop'], {'duration': r['trip_duration'],
                                                          'time': float(r['departure_time'])})).collect())
print('Number of edges:', len(edges_wed_11_30))

An error was encountered:
Invalid status code '404' from http://iccluster044.iccluster.epfl.ch:8998/sessions/5878 with error payload: "Session '5878' not found."


In [68]:
_ = graph.add_edges_from(edges_wed_11_30)

An error was encountered:
Invalid status code '404' from http://iccluster044.iccluster.epfl.ch:8998/sessions/5878 with error payload: "Session '5878' not found."


#### Additionnal edges: walking distances


**Add** : allow short (max 500m "As the Crows Flies") walking distances for transfers between two stations, and assume a walking speed of 50m/1min on a straight line, regardless of obstacles, human-built or natural, such as building, highways, rivers, or lakes

In [69]:
# COMPLETE

An error was encountered:
Invalid status code '404' from http://iccluster044.iccluster.epfl.ch:8998/sessions/5878 with error payload: "Session '5878' not found."


#### A simple graph


In [70]:
## create a simple graph: 
trip_id = '1.TA.1-231-j19-1.1.H'

#select four stops : 
stops = ['8582462','8572600','8572601','8502553']

stops_info = stops_zurich.where(column('stop_id').isin(stops))
stop_times_info = stop_times_zurich.where((column('stop_id').isin(stops))&(column('trip_id') == trip_id))

An error was encountered:
Invalid status code '404' from http://iccluster044.iccluster.epfl.ch:8998/sessions/5878 with error payload: "Session '5878' not found."


In [71]:
stop_times_info.show(4)

An error was encountered:
Invalid status code '404' from http://iccluster044.iccluster.epfl.ch:8998/sessions/5878 with error payload: "Session '5878' not found."


In [72]:
mini_graph = nx.MultiDiGraph()

mini_nodes = stops_info.rdd.map(lambda r: (r[0], {'name': r['stop_name'],
                                              'lat': r['stop_lat'],
                                              'lon': r['stop_lon']})).collect()
mini_graph.add_nodes_from(mini_nodes)

An error was encountered:
Invalid status code '404' from http://iccluster044.iccluster.epfl.ch:8998/sessions/5878 with error payload: "Session '5878' not found."


In [73]:
mini_graph.nodes['8582462']['name']

An error was encountered:
Invalid status code '404' from http://iccluster044.iccluster.epfl.ch:8998/sessions/5878 with error payload: "Session '5878' not found."


In [74]:
mini_edges = (filter_edge_on_time(stop_times_info, 2, 11, 30)
         .rdd.map(lambda r: (r['stop_id'], r['next_stop'], {'duration': r['trip_duration'],
                                                          'time': float(r['departure_time'])})).collect())
print('Number of edges:', len(mini_edges))

mini_graph.add_edges_from(mini_edges)
#mini_graph.add_weighted_edges_from(mini_edges)

An error was encountered:
Invalid status code '404' from http://iccluster044.iccluster.epfl.ch:8998/sessions/5878 with error payload: "Session '5878' not found."


In [75]:
mini_graph.add_edges_from([('8582462', '8572600',  {'duration': 2.0, 'time': 578.0})])

An error was encountered:
Invalid status code '404' from http://iccluster044.iccluster.epfl.ch:8998/sessions/5878 with error payload: "Session '5878' not found."


In [76]:
mini_graph.edges

An error was encountered:
Invalid status code '404' from http://iccluster044.iccluster.epfl.ch:8998/sessions/5878 with error payload: "Session '5878' not found."


In [77]:
list(mini_graph.nodes(data=True))

An error was encountered:
Invalid status code '404' from http://iccluster044.iccluster.epfl.ch:8998/sessions/5878 with error payload: "Session '5878' not found."


In [78]:
list(mini_graph.edges(data=True))

An error was encountered:
Invalid status code '404' from http://iccluster044.iccluster.epfl.ch:8998/sessions/5878 with error payload: "Session '5878' not found."


In [79]:
nx.single_source_dijkstra(mini_graph, '8582462', '8572602', weight='duration')

An error was encountered:
Invalid status code '404' from http://iccluster044.iccluster.epfl.ch:8998/sessions/5878 with error payload: "Session '5878' not found."


In [80]:
nx.predecessor(mini_graph,'8582462') 

An error was encountered:
Invalid status code '404' from http://iccluster044.iccluster.epfl.ch:8998/sessions/5878 with error payload: "Session '5878' not found."


##### Comment:
- discussion on k shortest paths in multigraph: [source](https://groups.google.com/forum/#!topic/networkx-discuss/87uC9F0ug8Y)
- paper on algorithms to find k-shortests paths: [source](https://www.ics.uci.edu/~eppstein/pubs/Epp-SJC-98.pdf)

Comment: "I had a look at the single_source_dijkstra() function and with quite
an easy adaptation I can get what I want. Instead of storing nodes in
the path I would just store edges. Then I would return the list of
edges. Then I would just implement David Eppstein's k-shortest paths
algorithm. I didn't understand the comment about the algorithm not
being implemented multigraphs because it slows down the algorithm too
much. As if you look at the special case where a Multigraph is equal
to Graph (i.e. only one edge) the two implementations perform the
same. I agree that having Multigraph requires more work to find the
shortest path as you then have more neighbors to investigate."


In [81]:
mini_graph.succ

An error was encountered:
Invalid status code '404' from http://iccluster044.iccluster.epfl.ch:8998/sessions/5878 with error payload: "Session '5878' not found."


In [82]:
#single source dijkstra: Compute shortest paths and lengths in a weighted graph G.
import numpy as np
def single_source_dijkstra_tweaked(G, source, target=None, cutoff=None, weight='weight'):
    """Compute shortest paths and lengths in a weighted graph G.

    Uses Dijkstra's algorithm for shortest paths.

    Parameters
    ----------
    G : NetworkX graph

    source : node label
       Starting node for path

    target : node label, optional
       Ending node for path

    cutoff : integer or float, optional
       Depth to stop the search. Only paths of length <= cutoff are returned.

    Returns
    -------
    distance,path : dictionaries
       Returns a tuple of two dictionaries keyed by node.
       The first dictionary stores distance from the source.
       The second stores the path from the source to that node.


    Examples
    --------
    >>> G=nx.path_graph(5)
    >>> length,path=nx.single_source_dijkstra(G,0)
    >>> print(length[4])
    4
    >>> print(length)
    {0: 0, 1: 1, 2: 2, 3: 3, 4: 4}
    >>> path[4]
    [0, 1, 2, 3, 4]

    Notes
    ---------
    Edge weight attributes must be numerical.
    Distances are calculated as sums of weighted edges traversed.

    Based on the Python cookbook recipe (119466) at
    http://aspn.activestate.com/ASPN/Cookbook/Python/Recipe/119466

    This algorithm is not guaranteed to work if edge weights
    are negative or are floating point numbers
    (overflows and roundoff errors can cause problems).

    See Also
    --------
    single_source_dijkstra_path()
    single_source_dijkstra_path_length()
    """
    if G.is_multigraph():
        get_weight = lambda u, v, data: np.min(eattr.get(weight, 1) for eattr in data.values())
    else:
        get_weight = lambda u, v, data: data.get(weight, 1)

    return dijkstra_tweaked(G, source, get_weight, cutoff=cutoff)

An error was encountered:
Invalid status code '404' from http://iccluster044.iccluster.epfl.ch:8998/sessions/5878 with error payload: "Session '5878' not found."


In [83]:
from collections import deque
from heapq import heappush, heappop
from itertools import count


def dijkstra_tweaked(G, source, get_weight, paths=None, cutoff=None,target=None):
    """Implementation of Dijkstra's algorithm

    Parameters
    ----------
    G : NetworkX graph

    source : node label
       Starting node for path

    get_weight: function
        Function for getting edge weight

    pred: list, optional(default=None)
        List of predecessors of a node

    paths: dict, optional (default=None)
        Path from the source to a target node.

    target : node label, optional
       Ending node for path

    cutoff : integer or float, optional
       Depth to stop the search. Only paths of length <= cutoff are returned.

    Returns
    -------
    distance,path : dictionaries
       Returns a tuple of two dictionaries keyed by node.
       The first dictionary stores distance from the source.
       The second stores the path from the source to that node.
    """
    G_succ = G.succ if G.is_directed() else G.adj

    push = heappush
    pop = heappop
    dist = {}  # dictionary of final distances
    seen = {source: 0}
    c = count()
    fringe = []  # use heapq with (distance,label) tuples
    push(fringe, (0, next(c), source))
    while fringe:
        (d, _, v) = pop(fringe)
        if v in dist:
            continue  # already searched this node.
        dist[v] = d
        if v == target:
            break

        for u, e in G_succ[v].items():
            cost = get_weight(v, u, e)
            if cost is None:
                continue
            print(cost)
            vu_dist = dist[v] + get_weight(v, u, e)
            if cutoff is not None:
                if vu_dist > cutoff:
                    continue
            if u in dist:
                if vu_dist < dist[u]:
                    raise ValueError('Contradictory paths found:',
                                     'negative weights?')
            elif u not in seen or vu_dist < seen[u]:
                seen[u] = vu_dist
                push(fringe, (vu_dist, next(c), u))
                if paths is not None:
                    paths[u] = paths[v] + [u]
                if pred is not None:
                    pred[u] = [v]
            elif vu_dist == seen[u]:
                if pred is not None:
                    pred[u].append(v)

    if paths is not None:
        return (dist, paths)
    if pred is not None:
        return (pred, dist)
    return dist

An error was encountered:
Invalid status code '404' from http://iccluster044.iccluster.epfl.ch:8998/sessions/5878 with error payload: "Session '5878' not found."


In [46]:
get_weight = lambda u, v, data: np.min(eattr.get(weight, 1) for eattr in data.values())


An error was encountered:
Invalid status code '404' from http://iccluster044.iccluster.epfl.ch:8998/sessions/5878 with error payload: "Session '5878' not found."


In [44]:
single_source_dijkstra_tweaked(mini_graph, '8582462', '8572602', weight='duration')

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

An error was encountered:
unsupported operand type(s) for +: 'int' and 'generator'
Traceback (most recent call last):
  File "<stdin>", line 62, in single_source_dijkstra_tweaked
  File "<stdin>", line 60, in dijkstra_tweaked
TypeError: unsupported operand type(s) for +: 'int' and 'generator'



## 4. Search for paths: 
BROUILLON DEPUIS LA

In [23]:
stop_ids_zurich = stops_zurich.where(col('stop_name') == 'Zürich HB').select('stop_id').rdd.flatMap(lambda x: x).collect()
stop_ids_zurich

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

[u'8503000', u'8503000:0:10', u'8503000:0:11', u'8503000:0:12', u'8503000:0:13', u'8503000:0:14', u'8503000:0:15', u'8503000:0:16', u'8503000:0:17', u'8503000:0:18', u'8503000:0:3', u'8503000:0:31', u'8503000:0:32', u'8503000:0:33', u'8503000:0:34', u'8503000:0:4', u'8503000:0:41/42', u'8503000:0:43/44', u'8503000:0:5', u'8503000:0:6', u'8503000:0:7', u'8503000:0:8', u'8503000:0:9', u'8503000P']

In [24]:
nodes_reachable_from_zurich = set(stop_ids_zurich)

for id_ in stop_ids_zurich:
    nodes_reachable_from_zurich = nodes_reachable_from_zurich.union(nx.descendants(graph, id_))
    
new_edges = [e for e in graph.edges(data=True) if e[0] in nodes_reachable_from_zurich and e[1] in nodes_reachable_from_zurich]

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

In [None]:
graph_zurich = nx.MultiDiGraph()
graph_zurich.add_nodes_from(nodes_reachable_from_zurich)
_ = graph_zurich.add_edges_from(new_edges)

In [None]:
nx.descendants(graph_zurich, '8503000:0:10')

In [None]:
dijkstra_path(graph_zurich, '8503000:0:10', '8502208:0:3', weight='duration')

In [None]:
gen = nx.all_simple_paths(graph_zurich, '8503000:0:10', '8502208:0:3')

for path in gen:
    print(path, flush=True)

In [None]:
graph.out_edges('8503000:0:10', data=True)

In [None]:
def filter_edges(edges, time):
    edges = [edge for edge in edges if edge[2]['time'] >= time]
    
    destinations = set([edge[1] for edge in edges])
    earliest_edges = []
    
    for destination in destinations:
        edges_to_dest = [edge for edge in edges if edge[1] == destination]
        earliest_edge = sorted(edges_to_dest, key=lambda edge: edge[2]['time'] + edge[2]['duration'])[0]
        earliest_edges.append(earliest_edge)
    
    return earliest_edges

In [None]:
def get_paths(graph, source, target, time, prev_path=[], num_hops=0, max_num_hops=4):
    if source == target:
        return [prev_path]
    elif num_hops >= max_num_hops:
        return None
    else:
        out_edges = graph.out_edges(source, data=True)
        out_edges = filter_edges(out_edges, time)
        # Should manually check
        print(time)
        print(out_edges)
        paths = []
        
        for out_edge in out_edges:
            new_paths = get_paths(graph, out_edge[1], target, time=out_edge[2]['time']+out_edge[2]['duration'], prev_path=prev_path+[source], num_hops=num_hops+1)
            if new_paths is not None:
                paths = paths + new_paths
                    
        return paths

In [None]:
out_edges = graph.out_edges('8503006:0:6', data=True)
out_edges = [edge for edge in out_edges]

In [None]:
out_edges[0]

In [None]:
dijkstra_path(graph_zurich, '8503000:0:10', '8503011:0:1')

In [None]:
nx.descendants(graph_zurich, '8503000:0:10')

In [None]:
get_paths(graph_zurich, '8503000:0:10', '8503202:0:4', 570, max_num_hops=6)

In [None]:
# Example for stop_times filtered on wednesday
stop_times_wed = day_trips(2).join(stop_times, on='trip_id')
stop_times_wed.show(5)

stop_times_wed.count()

In [None]:
# Can't run, the count makes it timeout. I asked Tao why
#print('Full stop times have', stop_times.count(), 'entries, filtered has', stop_times_wed.count())

In [None]:
#for r in stops_zurich.collect():
#    if r['stop_id'] == '8503000':
#        print((r['stop_id'], {'name': r['stop_name'].encode('utf-8'), 'lat': r['stop_lat'], 'lon': r['stop_lon']}))
#        break