In [1]:
# Load IPython's autoreload extension and trigger a reload of the imported modules.
%load_ext autoreload
%autoreload
# Install pyarrow into the environment (presumably needed by the pd.read_orc call further down — TODO confirm).
!pip install pyarrow



In [2]:
import pandas as pd
import plotly.graph_objects as go
import ipywidgets as widgets
from ipywidgets import interact, interactive, Layout, Box, FloatText, Textarea, Label
from IPython.display import display, HTML
import numpy as np
import pandas as pd
import datetime
from hdfs3 import HDFileSystem 

from inference import get_stops_data, run_search

In [3]:
# Stop table returned by the inference module, without its 'in_radius' column.
df_stops = get_stops_data().drop(columns='in_radius')

hdfs = HDFileSystem(user="ebouille")

# Read every ORC part file of the hourly partitions first and concatenate once:
# calling pd.concat inside the loop copies the accumulated frame for every file.
schedule_parts = []
for path in hdfs.glob("/user/anmaier/schedule_network.orc/hour=*/*.orc"):
    with hdfs.open(path) as f:
        schedule_parts.append(pd.read_orc(f))

# pd.concat raises on an empty list, so keep the empty-frame result when nothing matched the glob.
df_schedule_network = (
    pd.concat(schedule_parts, ignore_index=True) if schedule_parts else pd.DataFrame()
)
del schedule_parts  # drop the extra references to the per-file frames

def top_stations(top_n=50):
    """Return the names of the ``top_n`` source stops with the most schedule rows.

    Rows of ``df_schedule_network`` are grouped by ``src_stop_name`` and the groups
    are ranked by their count of non-null ``src_timestamp`` values, largest first.
    """
    rows_per_stop = df_schedule_network.groupby('src_stop_name').count()
    ranked = rows_per_stop.sort_values('src_timestamp', ascending=False)
    return ranked.index[:top_n].tolist()

In [4]:
def _assign_stop_coordinates(path, name_col, prefix, coords, silent_names=()):
    """Fill ``<prefix>_lat`` / ``<prefix>_lon`` of ``path`` in place from ``coords``.

    Parameters
    ----------
    path : pandas.DataFrame
        Itinerary table; ``name_col`` holds the stop names to look up.
    name_col : str
        Column of ``path`` containing stop names.
    prefix : str
        ``'src'`` or ``'dst'``; selects which coordinate columns are written.
    coords : pandas.DataFrame
        Lookup indexed by unique stop name with ``stop_lat`` / ``stop_lon`` columns.
    silent_names : collection
        Names for which no "not found" message is printed.
    """
    names = path[name_col]
    known = names.isin(coords.index).to_numpy()
    # Only touch the columns when something matched, so rows without a match
    # keep whatever value they already had (or NaN for a freshly created column).
    if known.any():
        matched = names[known]
        path.loc[known, prefix + '_lat'] = matched.map(coords['stop_lat']).to_numpy()
        path.loc[known, prefix + '_lon'] = matched.map(coords['stop_lon']).to_numpy()
    for name in names[~known]:
        if name not in silent_names:
            print('The coordinates for ', name, ' were not found in df_stops')


def add_coordinates(examples):
    """Add stop coordinates to the paths found by the algorithm.

    For every element of ``examples`` the ``path_list`` DataFrame gets the columns
    ``src_lat``/``src_lon`` and ``dst_lat``/``dst_lon`` filled from ``df_stops``
    (matching ``src_stop_name`` / ``dst_stop_name`` against ``stop_name``). When a
    stop name appears several times in ``df_stops`` the first occurrence is used.
    A message is printed for every name that cannot be found, except for the
    destination placeholder ``'Arrival'``.

    The frames are modified in place and re-assigned to ``path_list``; nothing is returned.
    """
    # Build the name -> coordinates lookup once instead of scanning df_stops
    # (and re-filtering it twice) for every row of every path.
    coords = (
        df_stops.dropna(subset=['stop_name'])
        .drop_duplicates(subset='stop_name')
        .set_index('stop_name')[['stop_lat', 'stop_lon']]
    )
    for example in examples:
        path = example.path_list
        _assign_stop_coordinates(path, 'src_stop_name', 'src', coords)
        _assign_stop_coordinates(path, 'dst_stop_name', 'dst', coords, silent_names=('Arrival',))
        example.path_list = path

In [5]:
def construct_segments(examples):
    """Turn every itinerary into a list of drawable segments, one per path row.

    Each segment is a dict with the stringified source/destination coordinates,
    the ``route`` (taken from ``route_desc``) and a ``transfer_node``: the tuple
    ``(src_stop_name, src_timestamp)`` on the first row of a new non-empty
    ``trip_id``, ``None`` otherwise.

    Returns one list of segments per element of ``examples``, in the same order.
    """
    def _segments_of(path_df):
        previous_trip = ''
        built = []
        for _, leg in path_df.iterrows():
            trip = leg['trip_id']
            # A boarding happens when a non-empty trip id differs from the last one seen.
            if trip != '' and trip != previous_trip:
                boarding = (leg['src_stop_name'], leg['src_timestamp'])
                previous_trip = trip
            else:
                boarding = None
            built.append({
                'src_lat': str(leg['src_lat']),
                'src_lon': str(leg['src_lon']),
                'dst_lat': str(leg['dst_lat']),
                'dst_lon': str(leg['dst_lon']),
                'route': leg['route_desc'],
                'transfer_node': boarding,
            })
        return built

    return [_segments_of(examples[k].path_list) for k in range(len(examples))]

In [6]:
def figure(examples, all_paths_segments):
    """Build the map widget showing every itinerary.

    Parameters
    ----------
    examples : sequence
        Itineraries; each element exposes ``path_list`` (a DataFrame with the columns
        ``src_stop_name``, ``src_timestamp``, ``src_lat`` and ``src_lon``) and
        ``probability`` (stored as the ``text`` of its traces).
    all_paths_segments : sequence
        One list of segment dicts per itinerary, as produced by ``construct_segments``.

    Returns
    -------
    (trace_names, fig)
        ``trace_names`` holds the unique ``'Route <n>'`` names (computed before the
        departure/arrival markers are added); ``fig`` is the ``go.FigureWidget``.

    Raises
    ------
    ValueError
        If ``examples`` is empty (there is no path to take the departure/arrival from).
    """
    if len(examples) == 0:
        raise ValueError("figure() needs at least one itinerary to draw")

    line_styles = {
        'walking': {'color': 'black', 'width': 2},
        'Tram': {'color': 'blue', 'width': 3},
        'Zug': {'color': 'red', 'width': 3},
        'Bus': {'color': 'green', 'width': 3},
    }
    # Used for any route without a dedicated style (e.g. 'waiting'), so such a segment
    # no longer inherits the previous segment's style or fails on an unbound variable.
    fallback_line = {'color': 'grey', 'width': 2}

    def _clock(ts):
        # Zero-padded HH:MM, so 08:05 is not rendered as "8:5".
        return f"{ts.hour:02d}:{ts.minute:02d}"

    fig = go.FigureWidget()

    for i, example in enumerate(examples):
        route_name = 'Route ' + str(i + 1)

        # Tracks which transport types already own a legend entry for this itinerary;
        # 'waiting' and 'early_arrival' start as True so they never get one.
        legend_added = {
            'walking': False,
            'Tram': False,
            'Zug': False,
            'Bus': False,
            'waiting': True,
            'early_arrival': True,
        }

        for seg in all_paths_segments[i]:
            start_lat = seg['src_lat']
            start_lon = seg['src_lon']
            end_lat = seg['dst_lat']
            end_lon = seg['dst_lon']
            route = seg['route']
            transfer_node = seg['transfer_node']

            # Segments start hidden; their visibility is driven by the caller's widgets.
            fig.add_trace(go.Scattermapbox(
                mode="markers+lines",
                lon=[start_lon, end_lon],
                lat=[start_lat, end_lat],
                line=line_styles.get(route, fallback_line),
                text=example.probability,
                name=route_name,
                visible=False,
                legendgroup=route,
                legendgrouptitle_text=route,
                showlegend=not legend_added.get(route, False)))

            legend_added[route] = True

            if transfer_node is not None:
                # Marker at the stop where a new trip is boarded.
                fig.add_trace(go.Scattermapbox(
                    mode='markers',
                    lat=[start_lat],
                    lon=[start_lon],
                    text=example.probability,
                    name=route_name,
                    visible=True,
                    marker=dict(
                        size=12,
                        color='red',
                        opacity=0.7,
                    ),
                    hovertext=f"Take {route} from {transfer_node[0]} at {_clock(transfer_node[1])}",
                    showlegend=False
                ))

    # Collect the route names for the widgets before adding departure and arrival.
    trace_names = pd.Series([trace.name for trace in fig.data]).unique()

    # Departure and arrival are read from the last itinerary: the first and last rows
    # of its path give the two end points.
    path = examples[-1].path_list
    departure_name = path['src_stop_name'].values[0]
    arrival_name = path['src_stop_name'].values[-1]
    departure_time = pd.Timestamp(path['src_timestamp'].values[0])
    arrival_time = pd.Timestamp(path['src_timestamp'].values[-1])
    dep_lat = float(path['src_lat'].values[0])
    dep_lon = float(path['src_lon'].values[0])
    arr_lat = float(path['src_lat'].values[-1])
    arr_lon = float(path['src_lon'].values[-1])

    # Add departure (text '1' keeps the marker above any confidence threshold up to 1).
    fig.add_trace(go.Scattermapbox(
        mode='markers',
        lat=[dep_lat],
        lon=[dep_lon],
        text='1',
        name='Departure at ' + departure_name,
        visible=True,
        marker=dict(
            size=12,
            color='orange',
            opacity=0.7,
        ),
        hovertext=f"Departure stop: {departure_name}<br>Departure time: {_clock(departure_time)}"
    ))

    # Add arrival.
    fig.add_trace(go.Scattermapbox(
        mode='markers',
        lat=[arr_lat],
        lon=[arr_lon],
        text='1',
        name='Arrival at ' + arrival_name,
        visible=True,
        marker=dict(
            size=12,
            color='green',
            opacity=0.7,
        ),
        hovertext=f"Arrival stop: {arrival_name}<br>Arrival time: {_clock(arrival_time)}"
    ))

    # Centre the map on the mean position of the path's source stops.
    fig.update_layout(
        margin={'l': 0, 't': 0, 'b': 0, 'r': 0},
        mapbox={
            'style': "open-street-map",
            'center': dict(lat=path['src_lat'].astype(float).mean(),
                           lon=path['src_lon'].astype(float).mean()),
            'zoom': 13})

    return trace_names, fig

In [7]:
import random

# Create the interface for the user.
top_stops = list(top_stations(100))
data_stops = df_stops.values
display(widgets.HTML(value="<h4> Some of the busiest stops to chose from : </h4>"))
# random.sample raises ValueError when asked for more items than the population holds,
# so cap the sample size at the number of stops actually available.
print(random.sample(top_stops, min(50, len(top_stops))))

# Free-text entry and its validation button for the departure stop.
text_dep_widget = widgets.Text(
    placeholder='Choose a departure stop'
)

button_dep = widgets.Button(description='Search')

# Free-text entry and its validation button for the arrival stop.
text_arr_widget = widgets.Text(
    placeholder='Choose an arrival stop'
)

button_arr = widgets.Button(description='Search')

# Time pickers for the departure and arrival times.
time_picker_dep = widgets.TimePicker(
    description='',
    disabled=False
)

time_picker_arr = widgets.TimePicker(
    description='',
    disabled=False
)

def dep_time_change(change):
    """Store the departure time selected in the departure time picker.

    ``change`` is the notification passed by ``observe``; its ``new`` attribute is the
    picker's new value. The result is written to the globals ``departure_hour`` and
    ``departure_minute``.
    """
    global departure_hour, departure_minute
    selected_time = change.new
    if selected_time is None:
        # An emptied picker reports None; fall back to the same 00:00 default the
        # notebook uses before any time has been picked instead of failing on `.hour`.
        selected_time = datetime.time.min
    departure_hour = selected_time.hour
    departure_minute = selected_time.minute

# Call dep_time_change whenever the departure picker's 'value' trait changes.
time_picker_dep.observe(dep_time_change, names='value')

def arr_time_change(change):
    """Store the arrival time selected in the arrival time picker.

    ``change`` is the notification passed by ``observe``; its ``new`` attribute is the
    picker's new value. The result is written to the globals ``arrival_hour`` and
    ``arrival_minute``.
    """
    global arrival_hour, arrival_minute
    selected_time = change.new
    if selected_time is None:
        # An emptied picker reports None; fall back to the same 00:00 default the
        # notebook uses before any time has been picked instead of failing on `.hour`.
        selected_time = datetime.time.min
    arrival_hour = selected_time.hour
    arrival_minute = selected_time.minute

# Call arr_time_change whenever the arrival picker's 'value' trait changes.
time_picker_arr.observe(arr_time_change, names='value')


# Shared output area: the callbacks of this cell run their bodies inside `with output:`.
output=widgets.Output()

def dep_button_click(_):
    """Validate the typed departure stop and remember it in ``departure_stop_widget``.

    The stripped text of ``text_dep_widget`` is accepted only when it equals a value of
    the ``stop_name`` column of ``df_stops``. Checking against the whole
    ``df_stops.values`` array would also accept a value from any other column.
    Feedback is printed in the shared ``output`` area; a rejected or empty input leaves
    the previous selection untouched.
    """
    global departure_stop_widget
    user_input = text_dep_widget.value.strip()
    with output:
        if user_input == '':
            # Nothing typed: handled first so an empty string can never be selected.
            print("")
        elif df_stops['stop_name'].eq(user_input).any():
            output.clear_output()
            print(f"Selected departure : {user_input}")
            departure_stop_widget = user_input
        else:
            print("Not available in SBB timetable stops, please try again")
            
def arr_button_click(_):
    """Validate the typed arrival stop and remember it in ``arrival_stop_widget``.

    The stripped text of ``text_arr_widget`` is accepted only when it equals a value of
    the ``stop_name`` column of ``df_stops``. Checking against the whole
    ``df_stops.values`` array would also accept a value from any other column.
    Feedback is printed in the shared ``output`` area; a rejected or empty input leaves
    the previous selection untouched.
    """
    global arrival_stop_widget
    user_input = text_arr_widget.value.strip()
    with output:
        if user_input == '':
            # Nothing typed: handled first so an empty string can never be selected.
            print("")
        elif df_stops['stop_name'].eq(user_input).any():
            output.clear_output()
            print(f"Selected arrival : {user_input}")
            arrival_stop_widget = user_input
        else:
            print("Not available in SBB timetable stops, please try again")
            
            
# Slider for the minimum confidence, from 0 to 1 in steps of 0.01, starting at 0.3.
confidence = widgets.FloatSlider(
    value=0.3,
    min=0,
    max=1.0,
    step=0.01,
    description='',
    readout_format='.2f',
)

def confidence_slider(level):
    """Copy the current value of the ``confidence`` slider into the global ``robustness``.

    ``level`` is the change notification passed by ``observe``; it is not read, the
    value is taken from the slider itself.
    """
    global robustness
    with output:
        current_value = confidence.value
        robustness = current_value

# Hook the callbacks up to their widgets.
confidence.observe(confidence_slider, names="value")
button_dep.on_click(dep_button_click)
button_arr.on_click(arr_button_click)

# Values used until the user interacts with the widgets:
# no stop selected, both times at 00:00, confidence threshold 0.3.
departure_stop_widget = arrival_stop_widget = ''
departure_hour = arrival_hour = datetime.time.min.hour
departure_minute = arrival_minute = datetime.time.min.minute
robustness = 0.3

# One row: departure text / time / button, then arrival text / time / button.
stops_widgets = widgets.HBox([
    text_dep_widget,
    time_picker_dep,
    button_dep,
    text_arr_widget,
    time_picker_arr,
    button_arr,
])

display(widgets.HTML(value="<h2> Choose the time and stops, press Search buttons and run the next cell: </h2>"))
display(stops_widgets, output)
display(widgets.HTML(value="<div> Minimum confidence: </div>"))
display(confidence)

HTML(value='<h4> Some of the busiest stops to chose from : </h4>')

['Zürich, Zypressenstrasse', 'Zürich, Fischerweg', 'Zürich, Saalsporthalle', 'Zürich, Schwert', 'Zürich, Albisriederplatz', 'Zürich, Morgental', 'Zürich, Lindenplatz', 'Zürich, Kreuzplatz', 'Zürich, Messe/Hallenstadion', 'Zürich, Hubertus', 'Zürich, Farbhof', 'Zürich, Wartau', 'Zürich, Schaffhauserplatz', 'Zürich, Stauffacher', 'Zürich Flughafen, OPC', 'Zürich, Bahnhofquai/HB', 'Zürich, Altes Krematorium', 'Zürich, Sihlstrasse', 'Zürich, Kunsthaus', 'Zürich, Feldeggstrasse', 'Zürich, Kinkelstrasse', 'Zürich, Zwielplatz', 'Zürich, Helmhaus', 'Zürich, Seilbahn Rigiblick', 'Zürich, Schaufelbergerstrasse', 'Zürich, Meierhofplatz', 'Zürich, Haldenegg', 'Zürich, Schiffbau', 'Zürich, Waffenplatzstrasse', 'Zürich, Beckenhof', 'Zürich, Milchbuck', 'Zürich, Röslistrasse', 'Zürich HB', 'Zürich, Kantonalbank', 'Zürich, Löwenplatz', 'Zürich, Waldgarten', 'Zürich, Universität Irchel', 'Zürich, Escher-Wyss-Platz', 'Zürich, Bezirksgebäude', 'Zürich, Schmiede Wiedikon', 'Zürich, Winkelriedstrasse', 'Zü

HTML(value='<h2> Choose the time and stops, press Search buttons and run the next cell: </h2>')

HBox(children=(Text(value='', placeholder='Choose a departure stop'), TimePicker(value=None, step=60.0), Butto…

Output()

HTML(value='<div> Minimum confidence: </div>')

FloatSlider(value=0.3, max=1.0, step=0.01)

In [9]:
# Retrieve the parameters chosen in the interface, run the search and display the
# figure if itineraries are found.
start_time = datetime.time(departure_hour, departure_minute)
end_time = datetime.time(arrival_hour, arrival_minute)

# A search only makes sense when the arrival is not earlier than the departure.
# (The previous `max_duration>=0` test was always true, so incoherent times still
# launched a search with max_duration=0.)
times_coherent = end_time >= start_time
max_duration = 0
if times_coherent:
    today = datetime.date.today()
    duration = datetime.datetime.combine(today, end_time) - datetime.datetime.combine(today, start_time)
    max_duration = duration.total_seconds() / 60  # minutes between departure and arrival
else:
    display(widgets.HTML(value="<div> Please select coherent arrival and departure times</div>"))

same_stop = departure_stop_widget == arrival_stop_widget

examples = []
if times_coherent and not same_stop:
    examples = run_search(departure_stop_widget, arrival_stop_widget, arrival_hour, arrival_minute, robustness,
                          max_transfers=10, max_duration=max_duration, top_n=10, verbose=True,
                          missing_prob_handle_method='keep')
elif same_stop and departure_stop_widget != '':
    display(widgets.HTML(value="<div> Departure station cannot be the same as arrival one, please try another itinerary</div>"))

if examples:
    add_coordinates(examples)
    trace_names, fig_ = figure(examples, construct_segments(examples))

    def _is_endpoint(trace):
        """True for the departure/arrival markers, which must always stay visible."""
        return trace.name.startswith('Departure') or trace.name.startswith('Arrival')

    def slider_proba(level):
        """Show the traces whose probability (stored in ``text``) reaches ``level``.

        Returns the resulting visibility of every trace, in ``fig_.data`` order.
        """
        visible = []
        for trace in fig_.data:
            shown = _is_endpoint(trace) or float(trace.text) >= level
            trace.visible = shown
            visible.append(shown)
        return visible

    def update_fig(selected_route, level):
        """Show only the selected route, and only if it reaches the confidence ``level``."""
        passes_level = slider_proba(level)
        for trace, shown in zip(fig_.data, passes_level):
            if _is_endpoint(trace):
                trace.visible = True
            else:
                # Among the traces that reach the level, keep only the chosen route.
                trace.visible = shown and trace.name == selected_route
        fig_.show()

    selector = widgets.Dropdown(
        options=trace_names,
        value='Route 1',
        description=' ',
        rows=4
    )

    slider = widgets.FloatSlider(value=0.8, min=0, max=1.0,
                                 step=0.01, description='Confidence', readout_format='.2f')

    box_layout = Layout(
        display='flex',
        flex_flow='row',
        justify_content='space-between',
        width='auto'
    )

    ui = widgets.Box([selector, slider], layout=box_layout)

    display(widgets.HTML(value="<h3> Select a route : </h3>"))
    # Re-run update_fig every time the dropdown or the slider changes.
    out = widgets.interactive_output(update_fig, {'selected_route': selector, 'level': slider})

    display(ui, out)

elif times_coherent and not same_stop:
    # A search was actually run and returned nothing.
    display(widgets.HTML(value="<div> No routes found for these parameters </div>"))