## Import the necessary libraries

In [2]:
import os
import pandas as pd
pd.set_option("display.max_columns", 50)
import matplotlib.pyplot as plt
import warnings
import plotly.express as px
import plotly.graph_objects as go
import pyspark.sql.functions as SFunc
import pyspark.sql.functions as F
from pyspark.sql import Window
warnings.simplefilter(action='ignore', category=UserWarning)
%matplotlib inline

## Set up HIVE, spark and connect to cluster

In [3]:
%load_ext sparkmagic.magics

Cleaning up livy sessions on exit is enabled


In [4]:
import os
from IPython import get_ipython
username = os.environ['RENKU_USERNAME']
server = "http://iccluster044.iccluster.epfl.ch:8998"

# set the application name as "<your_gaspar_id>-homework3"
get_ipython().run_cell_magic(
    'spark',
    line='config', 
    cell="""{{ "name": "{0}-homework3-1", "executorMemory": "4G", "executorCores": 4, "numExecutors": 10, "driverMemory": "4G" }}""".format(username)
)

In [5]:
get_ipython().run_line_magic(
    "spark", f"""add -s {username}-finalproject-1 -l python -u {server} -k"""
)

Starting Spark application


ID,YARN Application ID,Kind,State,Spark UI,Driver log,User,Current session?
4572,application_1680948035106_4188,pyspark,idle,Link,Link,,✔


FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

SparkSession available as 'spark'.


In [6]:
%%spark
print('We are using Spark %s' % spark.version)

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

We are using Spark 2.4.8.7.1.8.0-801

In [7]:
from pyhive import hive

# Set python variables from environment variables
username = os.environ['USERNAME']
hive_host = os.environ['HIVE_SERVER2'].split(':')[0]
hive_port = os.environ['HIVE_SERVER2'].split(':')[1]

# create connection
conn = hive.connect(
    host=hive_host,
    port=hive_port,
    # auth="KERBEROS",
    # kerberos_service_name = "hive"
)

# create cursor
cur = conn.cursor()

print(f"your username is {username}")
print(f"you are connected to {hive_host}:{hive_port}")

your username is digennar
you are connected to iccluster044.iccluster.epfl.ch:10000


## Preprocessing

We need to preprocess the stops and the sbb information

In [8]:
# We start by creating a dataframe querying from sbb_orc the information we need to provide the result required
query = """
    select STOP_ID as id, STOP_NAME as name, STOP_LAT as lat, STOP_LON as lon
    from {0}.allstops_orc
""".format(username)
geostops_df = pd.read_sql(query, conn)

In [9]:
geostops_df[geostops_df["name"] == "Zürich HB"].head(1)

Unnamed: 0,id,name,lat,lon
13238,8503000,Zürich HB,47.378178,8.540212


#### Filtering on location
Let us start by keeping only the stops that are in a 15km radius from Zurich station. We thus need to calculate the distance between Zurich HB and the other stops. For this, we can use the coordinates.
Code from https://www.geeksforgeeks.org/program-distance-two-points-earth/

In [10]:
from math import radians, cos, sin, asin, sqrt
def distance(lat1, lat2, lon1, lon2):
     
    # The math module contains a function named
    # radians which converts from degrees to radians.
    lon1 = radians(lon1)
    lon2 = radians(lon2)
    lat1 = radians(lat1)
    lat2 = radians(lat2)
      
    # Haversine formula
    dlon = lon2 - lon1
    dlat = lat2 - lat1
    a = sin(dlat / 2)**2 + cos(lat1) * cos(lat2) * sin(dlon / 2)**2
 
    c = 2 * asin(sqrt(a))
    
    # Radius of earth in kilometers. Use 3956 for miles
    r = 6371
      
    # calculate the result
    return(c * r)

In [11]:
# Now we filter based on the distance from Zürich HB
lat_Zur = 47.3781762039461
lon_Zur = 8.54021154209037
# max dist in km
max_dist = 15

print("Number of stops before filtering on distance from Zürich HB: ", len(geostops_df))
geostops_df["Distance"] = geostops_df.apply(lambda x: distance(float(x["lat"]), lat_Zur,  float(x["lon"]), lon_Zur), axis=1)
geostops_df = geostops_df[geostops_df["Distance"] <= 15]
print("Number of stops after filtering on distance from Zürich HB: ", len(geostops_df))                          

Number of stops before filtering on distance from Zürich HB:  46690
Number of stops after filtering on distance from Zürich HB:  2122


In [12]:
print("Number of different id's: ", geostops_df["id"].nunique())
print("Number of different names: ", geostops_df["name"].nunique())

Number of different id's:  2122
Number of different names:  1601


In [13]:
# Remove the stops that have a parent
geostops_df = geostops_df[~geostops_df['id'].str.contains("Parent")]
geostops_df = geostops_df[~geostops_df['id'].str.contains(":")]
geostops_df = geostops_df[~geostops_df['id'].str.contains("P")]

# check if all stops have different coordinates now
## what is booleans 
## num_dist_coordinates = booleans.groupby(["lat", "lon"]).ngroups
num_dist_coordinates = geostops_df.groupby(["lat", "lon"]).ngroups
assert(len(geostops_df) == num_dist_coordinates)

# check if all names are distinct
assert(geostops_df["name"].nunique() == num_dist_coordinates)

In [14]:
geostops_df.head()

Unnamed: 0,id,name,lat,lon,Distance
7896,176,Zimmerberg-Basistunnel,47.351677,8.521957,3.251541
10878,8500926,"Oetwil a.d.L., Schweizäcker",47.423626,8.403183,11.484985
12363,8502075,"Zürich Flughafen, Carterminal",47.451023,8.563729,8.291228
12534,8502186,Dietikon Stoffelbach,47.393326,8.39896,10.766797
12538,8502187,Rudolfstetten Hofacker,47.36467,8.376952,12.385806


### Walking
Create a dictionary that for each stop contains all other stops that are reachable by foot (< 0.5 km)

In [15]:
within_walking_distance = {}
import copy
max_dist = 0.5
all_names = set(geostops_df["name"].unique())
for stop in all_names:
    within_walking_distance[stop] = set()
    stop_data = stops_dict[stop]
    lat, lon = float(stop_data["lat"]), float(stop_data["lon"])
    reachable = []
    all_others = copy.deepcopy(all_names)
    all_others.remove(stop)
    for other in all_others:
        other_stop_data = stops_dict[other]
        lat_oth, lon_oth = float(other_stop_data["lat"]), float(other_stop_data["lon"])
        dist = distance(lat, lat_oth, lon, lon_oth)
        if dist < max_dist:
            within_walking_distance[stop].add(other)
    

NameError: name 'stops_dict' is not defined

### Timetable data (calendar, trips, routes & stop_times)

In [16]:
%%spark
# Calendar data
calendar = spark.read.csv("/data/sbb/part_csv/timetables/calendar", header=True, encoding='utf8')
calendar.show(5)

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

+----------+------+-------+---------+--------+------+--------+------+----------+--------+----+-----+---+
|service_id|monday|tuesday|wednesday|thursday|friday|saturday|sunday|start_date|end_date|year|month|day|
+----------+------+-------+---------+--------+------+--------+------+----------+--------+----+-----+---+
|        TA|     1|      1|        1|       0|     0|       0|     1|  20221211|20221214|2022|   12|  7|
|      TA#1|     1|      1|        1|       1|     1|       1|     1|  20211212|20221210|2022|   12|  7|
|  TA+00000|     0|      0|        0|       0|     0|       1|     0|  20211212|20221210|2022|   12|  7|
|  TA+00010|     1|      1|        1|       0|     0|       0|     1|  20211212|20221210|2022|   12|  7|
|  TA+001c0|     1|      1|        1|       1|     1|       0|     0|  20211212|20221210|2022|   12|  7|
+----------+------+-------+---------+--------+------+--------+------+----------+--------+----+-----+---+
only showing top 5 rows

In [17]:
%%spark
# Trip data
trips = spark.read.csv("/data/sbb/part_csv/timetables/trips", header=True, encoding='utf8')
trips.show(5)

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

+-------------+----------+--------------------+--------------------+---------------+------------+----+-----+---+
|     route_id|service_id|             trip_id|       trip_headsign|trip_short_name|direction_id|year|month|day|
+-------------+----------+--------------------+--------------------+---------------+------------+----+-----+---+
|91-10-A-j22-1|     TA+6V|1.TA.91-10-A-j22-...|Oberwil BL, Hüsli...|          21313|           0|2022|   10| 12|
|91-10-A-j22-1|  TA+ndka0|10.TA.91-10-A-j22...|Oberwil BL, Hüsli...|          24673|           0|2022|   10| 12|
|91-10-A-j22-1|  TA+bh200|100.TA.91-10-A-j2...|Oberwil BL, Hüsli...|          21461|           0|2022|   10| 12|
|91-10-A-j22-1|     TA+rV|1000.TA.91-10-A-j...|    Dornach, Bahnhof|          51219|           0|2022|   10| 12|
|91-10-A-j22-1|     TA+rV|1001.TA.91-10-A-j...|    Dornach, Bahnhof|          51723|           0|2022|   10| 12|
+-------------+----------+--------------------+--------------------+---------------+------------

In [18]:
%%spark
# Routes data
routes = spark.read.csv("/data/sbb/part_csv/timetables/routes", header=True, encoding='utf8')
routes.show(5)

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

+-------------+---------+----------------+---------------+----------+----------+----+-----+---+
|     route_id|agency_id|route_short_name|route_long_name|route_desc|route_type|year|month|day|
+-------------+---------+----------------+---------------+----------+----------+----+-----+---+
|91-10-A-j22-1|       37|              10|           null|         T|       900|2022|   12|  7|
|91-10-A-j23-1|       78|             S10|           null|         S|       109|2022|   12|  7|
|91-10-B-j22-1|       78|             S10|           null|         S|       109|2022|   12|  7|
|91-10-B-j23-1|       11|             S10|           null|         S|       109|2022|   12|  7|
|91-10-C-j22-1|       11|             S10|           null|         S|       109|2022|   12|  7|
+-------------+---------+----------------+---------------+----------+----------+----+-----+---+
only showing top 5 rows

In [19]:
%%spark
# Routes data
stop_times = spark.read.csv("/data/sbb/part_csv/timetables/stop_times", header=True, encoding='utf8')
stop_times.show(5)

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

+-------------------+------------+--------------+-------+-------------+-----------+-------------+----+-----+---+
|            trip_id|arrival_time|departure_time|stop_id|stop_sequence|pickup_type|drop_off_type|year|month|day|
+-------------------+------------+--------------+-------+-------------+-----------+-------------+----+-----+---+
|1.TA.91-8-j22-1.1.H|    24:40:00|      24:40:00|8591178|            1|          0|            0|2022|    6| 29|
|1.TA.91-8-j22-1.1.H|    24:41:00|      24:41:00|8591074|            2|          0|            0|2022|    6| 29|
|1.TA.91-8-j22-1.1.H|    24:42:00|      24:42:00|8591131|            3|          0|            0|2022|    6| 29|
|1.TA.91-8-j22-1.1.H|    24:43:00|      24:43:00|8591135|            4|          0|            0|2022|    6| 29|
|1.TA.91-8-j22-1.1.H|    24:44:00|      24:44:00|8580522|            5|          0|            0|2022|    6| 29|
+-------------------+------------+--------------+-------+-------------+-----------+-------------

In [20]:
%%spark
# Stops data
stops = spark.read.csv("/data/sbb/part_csv/timetables/stops", header=True, encoding='utf8')
stops.show(5)

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

+-------+--------------------+----------------+----------------+-------------+--------------+----+-----+---+
|stop_id|           stop_name|        stop_lat|        stop_lon|location_type|parent_station|year|month|day|
+-------+--------------------+----------------+----------------+-------------+--------------+----+-----+---+
|1100008|Zell (Wiesental),...|47.7100842702352|7.85964788274668|         null|          null|2022|   12|  7|
|1100009|Zell (Wiesental),...|47.7131911044794|7.86290876722849|         null|          null|2022|   12|  7|
|1100010|           Atzenbach|47.7146175266411| 7.8723500608659|         null|          null|2022|   12|  7|
|1100011|     Mambach, Brücke|47.7282088873189| 7.8774704579861|         null|          null|2022|   12|  7|
|1100012|  Mambach, Mühlschau|47.7340818684375| 7.8813871126254|         null|          null|2022|   12|  7|
+-------+--------------------+----------------+----------------+-------------+--------------+----+-----+---+
only showing top 5 

In [21]:
%%spark
# inner join for the stops and stop_times dfs on the STOP_ID column
joined_df = stop_times.join(stops, on="STOP_ID", how="inner")

# only select the relevant cols
result_df = joined_df.select(
    "TRIP_ID",
    "ARRIVAL_TIME",
    "DEPARTURE_TIME",
    "STOP_ID",
    "STOP_NAME",
    "STOP_LAT",
    "STOP_LON"
)

result_df.drop_duplicates()
result_df.show(5)

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

+--------------------+------------+--------------+-------+--------------------+----------------+----------------+
|             TRIP_ID|ARRIVAL_TIME|DEPARTURE_TIME|STOP_ID|           STOP_NAME|        STOP_LAT|        STOP_LON|
+--------------------+------------+--------------+-------+--------------------+----------------+----------------+
|1.TA.79-730-5-j20...|    08:58:00|      08:58:00|1100150|Schlächtenhaus, S...|47.6898550981112|7.74233688977719|
|1.TA.79-730-5-j20...|    08:58:00|      08:58:00|1100150|Schlächtenhaus, S...|47.6898550981112|7.74233688977719|
|1.TA.79-730-5-j20...|    08:58:00|      08:58:00|1100150|Schlächtenhaus, S...|47.6898550981112|7.74233688977719|
|1.TA.79-730-5-j20...|    08:58:00|      08:58:00|1100150|Schlächtenhaus, S...|47.6898550981112|7.74233688977719|
|1.TA.79-730-5-j20...|    08:58:00|      08:58:00|1100150|Schlächtenhaus, S...|47.6898550981112|7.74233688977719|
+--------------------+------------+--------------+-------+--------------------+---------

## Istdaten
Load the dataset. Keep only rows from Zürich area and maek sure product ID is not null


In [22]:
####################################################################################
########  Only run if you don't have the table yet, takes 2 minutes ################
####################################################################################
selected_stops = tuple(set(geostops_df["name"].unique().tolist()))

query = """
    drop table if exists {0}.zur_istdaten
""".format(username)

cur.execute(query)

query = """
    create external table {0}.zur_istdaten
    as
    SELECT BETRIEBSTAG as trip_date, FAHRT_BEZICHNER as trip_id, PRODUKT_ID as transport_type, LINIEN_ID as train_number, LINIEN_TEXT as service_type,
    VERKEHRSMITTEL_TEXT as verkehrs_mittel, ZUSATZFAHRT_TF as additional_trip, FAELLT_AUS_TF as cancelled_trip, HALTESTELLEN_NAME as stop_name, 
    unix_timestamp(ANKUNFTSZEIT, 'dd.MM.yyyy HH:mm') as scheduled_ar, unix_timestamp(AN_PROGNOSE,'dd.MM.yyyy hh:mm:ss') as act_ar, 
    unix_timestamp(ABFAHRTSZEIT, 'dd.MM.yyyy HH:mm') as scheduled_dep, unix_timestamp(AB_PROGNOSE,'dd.MM.yyyy hh:mm:ss') as act_dep,
    AN_PROGNOSE_STATUS as measuring_method, DURCHFAHRT_TF as skip_stop
    from {0}.sbb_orc
    where BETRIEBSTAG like '__.__.____' and PRODUKT_ID is not NULL and PRODUKT_ID <> ''
    and HALTESTELLEN_NAME in {1}
    """.format(username, selected_stops)

cur.execute(query)

In [23]:
query = """
    select * from {0}.zur_istdaten
    limit 1
""".format(username)

cur.execute(query)
cur.fetchall()

[('28.12.2022',
  '85:882:129321-18101-1',
  'Bus',
  '85:882:105',
  '5',
  'B',
  'false',
  'false',
  'Brütten, Harossen',
  1672270740,
  1672268400,
  1672270740,
  1672268400,
  'REAL',
  'false')]

In [24]:
## TO BE CORRECTED
%%spark

complete_zur_istdaten = stop_times.join(zur_istdaten, on="STOP_NAME", how="left")

#  Fill missing values in the joined dataframe with values from the stop_times dataframe
complete_zur_istdaten = joined_df.select(
    coalesce(zur_istdaten.stop_name, result_df.STOP_NAME).alias("stop_name"),
    coalesce(zur_istdaten.scheduled_ar, result_df.ARRIVAL_TIME).alias("arrival_time"), # scheduled arrival time
    coalesce(zur_istdaten.act_ar, result_df.ARRIVAL_TIME).alias("act_ar"), # real arrival time, assume it is always on time
    coalesce(zur_istdaten.scheduled_dep, result_df.DEPARTURE_TIME).alias("scheduled_dep"), #scheduled dep time
    coalesce(zur_istdaten.act_dep, result_df.DEPARTURE_TIME).alias("scheduled_dep"), #actual dep time, assume it is on time if its missing
)




UsageError: Line magic function `%%spark` not found.


In [None]:
######################################################
######   INCOMPLETE CODE! DO NOT RUN! ################
######################################################


# LEFT JOIN to join on the trip_id column.
# COALESCE to choose choose the non-null value for each column between dfs. 
# If there is a non-null value in real_df, it will be used, otherwise the value from df_timetable will be used.
# UNION to append entries of 
query = """
    LEFT JOIN df_real ON df_timetable.trip_id = df_real.trip_id
    SELECT 
        COALESCE(df_real.trip_id, df_timetable.trip_id) AS trip_id,
        COALESCE(df_real.arrival_time, df_timetable.arrival_time) AS arrival_time,
        COALESCE(df_real.departure_time, df_timetable.departure_time) AS departure_time,
    
    INSERT OVERWRITE TABLE df_real
    SELECT 
        COALESCE(df_real.trip_id, df_timetable.trip_id) AS trip_id,
        COALESCE(df_real.arrival_time, df_timetable.arrival_time) AS arrival_time,
        COALESCE(df_real.departure_time, df_timetable.departure_time) AS departure_time,
    FROM df_timetable
    UNION
    SELECT * FROM df_real WHERE df_real.trip_id NOT IN (SELECT trip_id FROM df_timetable)
"""

In [None]:
# Get delays for arrival, departure
# Per stop per trip 

# Our approach

# At query time
# Calculate ecdf between two changeovers from all past trips from A to B (ecdf at B for trains/trips that passed through A)
# Calculate the delay in the confidence interval
# If smller than time to change over, trip is invalid

# Possible improvements
# IF many stops, how to compose probabilities (the same at each or different depending on distribution of delays)
#      Idea: Always take the maximum at each stop

# Other notes
# Only calculate delays of arrival
# Assume train after connection leaves on time, see Q3 in faq of readme

# Questions for wednesday
# Calculate ecdf based on trip id, or use same service at different hours
# How to reconcile service id from calendar.txt to istdaten trip_id

In [None]:
%%spark
df = spark.sql("select * from {0}.zur_istdaten".format("digennar"))

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

In [26]:
%%spark

# df is complete_zur_istdaten

def get_ecdf(df, service_id, start_station, end_station):
    # ecdf at end for trains/trips that passed through start
    
    #print("Number of entries in zur_istdaten: " + str(df.count()))
    
    # Get trains with a given service id
    df_service = df.filter(df.train_number == service_id) # Should this be trip id?
    
    #print("Number of entries for service id: " + str(df_service.count()))
    
    # Get trips that pass through end
    #df_trips = df.filter(df_service.stop_name == end_station) # We should also make sure it passes through start (except if we consider a given service id always go through start if it goes through end)
    #df_ids = df.groupby("trip_id")
    
    #print("Number of entries that go through end station: " + str(df_trips.count()))
    
    # Make sure each id is for a train that passed through A and B
    #assert df_trips.count() == 2 * df_ids.count()
    
    # Calculate all observed delays at end
    df_delays = df_service.withColumn("obs_delay", df_service.act_ar - df_service.scheduled_ar)
    df_delays = df_delays.select(df_delays.obs_delay)
    df_delays = df_delays.na.fill(0) # If missing values consider that the train was on time
    
    # Get ecdf from observed delays (as of now, list of delays in ascending order)
    return df_delays.orderBy("obs_delay")

df_example = get_ecdf(df, "85:882:105", "Brütten, Harossen", "Brütten, Zentrum")

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

In [27]:
%%spark

df_example.show(10)

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

+---------+
|obs_delay|
+---------+
|    -2460|
|    -2460|
|    -2460|
|    -2460|
|    -2460|
|    -2460|
|    -2460|
|    -2460|
|    -2460|
|    -2460|
+---------+
only showing top 10 rows

In [None]:
# All past trips that go to B and pass by A for a given service id

# What do we need from part 1
# Istdaten with filled values of scheduled and actual arrival (we don't care about departure so much)


# What do we need from part 2
# List of given trips ids from A to B, B to C and so on until destinations (feasible from timetable)

# Biggest challenge for part 3
# Linking the given trip ids to a service to calculate delays from A to B.

In [28]:
%spark cleanup