In [None]:
import pyspark.sql.functions as F
from pyspark.sql.functions import lit, col, when
from math import radians, cos, sin, asin, sqrt
from pyspark.sql import SparkSession

In [None]:
import os
from IPython import get_ipython
# Who is running the notebook, and which Spark server the session will be added to.
username = os.environ['RENKU_USERNAME']
server = "http://iccluster044.iccluster.epfl.ch:8998"

# JSON session settings handed to the `%%spark config` cell magic.
# The doubled braces are literal braces under str.format; {0} is the user name.
session_config = """{{ "name": "{0}-final-project", "executorMemory": "4G", "executorCores": 4, "numExecutors": 10, "driverMemory": "4G" }}""".format(username)

get_ipython().run_cell_magic('spark', line='config', cell=session_config)

In [None]:
# Add a per-user session through the `%spark add` line magic, then obtain the
# SparkSession object used by every following cell.
add_session_args = f"""add -s {username}-final-project -l python -u {server} -k"""
get_ipython().run_line_magic("spark", add_session_args)

spark = SparkSession.builder.appName("final-project").getOrCreate()

### Select stops that are close to Zurich Hbf (within 15 km)

In [None]:
# All known stops; the row named 'Zürich HB' provides the reference coordinates.
stops = spark.read.orc('/data/sbb/orc/allstops')

zurich_HB_location = (
    stops
    .where(col('stop_name') == 'Zürich HB')
    .select('stop_lat', 'stop_lon')
    .first()
)

# Wrap the reference Row in a broadcast variable for use by the distance UDF.
bc_zhb_location = spark.sparkContext.broadcast(zurich_HB_location)

In [None]:
@F.udf(returnType='double')
def distance(lat1, lon1, lat2 = bc_zhb_location.value.stop_lat, lon2 = bc_zhb_location.value.stop_lon):
    """Great-circle (haversine) distance in kilometres between two points.

    Args:
        lat1, lon1: latitude / longitude of the first point, in degrees.
        lat2, lon2: latitude / longitude of the second point, in degrees.
            They default to the Zürich HB coordinates read from
            ``bc_zhb_location`` once, when this function is defined.

    Returns:
        The distance in km as a float, or ``None`` (SQL null) when any
        coordinate is missing.
    """
    # A null coordinate would otherwise raise TypeError and fail the whole job.
    if lat1 is None or lon1 is None or lat2 is None or lon2 is None:
        return None

    R = 6371  # Earth radius in km

    d_lat = radians(lat2 - lat1)
    d_lon = radians(lon2 - lon1)

    a = sin(d_lat / 2) ** 2 + cos(radians(lat1)) * cos(radians(lat2)) * sin(d_lon / 2) ** 2
    # Rounding can push `a` marginally above 1 for near-antipodal points,
    # which would make asin() raise ValueError; clamp it.
    c = 2 * asin(sqrt(min(1.0, a)))

    return R * c

In [None]:
distance_limit = 15.0  # km, same unit as the value returned by distance()

# With only two arguments, distance() measures against its default point (Zürich HB).
distance_to_hb = distance(col('stop_lat'), col('stop_lon'))
stops = stops.where(distance_to_hb <= distance_limit)
stops.count()

### Join routes with stop times and filter stop times

In [None]:
# Timetable tables for the 2022-05-25 partition.
routes = spark.read.orc('/data/sbb/part_orc/timetables/routes/year=2022/month=5/day=25')
trips = spark.read.orc('/data/sbb/part_orc/timetables/trips/year=2022/month=5/day=25')
stop_times = spark.read.orc('/data/sbb/part_orc/timetables/stop_times/year=2022/month=5/day=25')

# Keep only stop events that depart at or after 06:00 and arrive by 18:00,
# then attach the trip and route attributes.
in_daytime_window = (stop_times.departure_time >= "06:00:00") & (stop_times.arrival_time <= "18:00:00")
stop_times = (
    stop_times
    .where(in_daytime_window)
    .join(trips, 'trip_id')
    .join(routes, 'route_id')
)

# Restrict to the previously selected stops and keep only the columns used later.
kept_columns = ['trip_id', 'arrival_time', 'departure_time', 'stop_id', 'stop_sequence', 'route_id', 'route_desc']
stop_times = stops.join(stop_times, 'stop_id').select(*kept_columns)
stop_times.show()

# Keep only the trips that still have at least one remaining stop event.
trips = trips.join(stop_times, 'trip_id', 'leftsemi')
trips.show()

### Create connections
`connections` contains the pairs of consecutive stops for every selected trip.

In [None]:
# Self-join stop_times: 'b' is the departure stop event and 'a' is the stop event
# that immediately follows it (next stop_sequence) within the same trip.
arrival_side = stop_times.alias('a')
departure_side = stop_times.alias('b')

is_next_stop = col('a.stop_sequence') == (col('b.stop_sequence') + 1)
is_same_trip = col('a.trip_id') == col('b.trip_id')
arrives_after_departure = col('a.arrival_time') > col('b.departure_time')

connections = (
    arrival_side
    .crossJoin(departure_side)
    .where(is_next_stop & is_same_trip & arrives_after_departure)
    .select(
        col('a.trip_id'),
        col('b.departure_time'),
        col('a.arrival_time'),
        col('b.stop_id').alias('departure_stop'),
        col('a.stop_id').alias('arrival_stop'),
        col('a.route_desc'),
    )
)
connections.cache()
connections.show()

### Create transfers

In [None]:
# Walking model for synthetic footpaths. distance() returns kilometres, so
# km * 1000 / walk_speed is a duration in minutes only if the speed is in metres
# per minute (the original comment said m/s, which contradicts the formula).
walk_speed = 50  # m/min
walk_spped = walk_speed  # original misspelled name, kept as an alias for any cell still using it
walking_time_limit = 10  # minutes

# Transfers from the timetable, limited to pairs whose two endpoints are both
# among the selected stops (the inner joins with `stops` act as filters).
# min_transfer_time is divided by 60 -- presumably seconds -> minutes; TODO confirm against the data.
transfers_raw = spark.read.orc('/data/sbb/part_orc/timetables/transfers')
transfers = (
    transfers_raw
    .join(stops.alias('a'), transfers_raw.from_stop_id == col('a.stop_id'))
    .join(stops.alias('b'), transfers_raw.to_stop_id == col('b.stop_id'))
    .withColumn('transfer_time', transfers_raw.min_transfer_time / 60)
    .drop('min_transfer_time')
    .select('from_stop_id', 'to_stop_id', 'transfer_time')
    .dropDuplicates(['from_stop_id', 'to_stop_id'])
)

# add transfer for nearby stops within 10 min. walk
stops_without_transfer = stops.join(transfers, stops.stop_id == transfers.from_stop_id, 'leftanti')
footpaths = (
    stops_without_transfer.alias('a')
    .crossJoin(stops_without_transfer.alias('b'))
    .withColumn(
        'transfer_time',
        distance(col('a.stop_lat'), col('a.stop_lon'), col('b.stop_lat'), col('b.stop_lon')) * 1000 / walk_speed,
    )
    .select(col('a.stop_id').alias('from_stop_id'), col('b.stop_id').alias('to_stop_id'), col('transfer_time'))
    .filter(col('transfer_time') <= walking_time_limit)
    # a stop paired with itself gets a fixed 2-minute transfer instead of 0
    .withColumn('transfer_time', when(col('from_stop_id') == col('to_stop_id'), 2).otherwise(col('transfer_time')))
)

# add additional footpaths for transferring at the same stop
# Built directly in Spark: the original collected the rows to the driver and
# rebuilt a DataFrame, and its column list ['from_stop_id' 'to_stop_id'] lacked a
# comma, so Python concatenated the two names into a single string.
# NOTE(review): stops without a timetable transfer already have a (stop, stop)
# row with transfer_time 2 from the cross join above, so they end up with both a
# 2-minute and a 0-minute self-transfer -- confirm which one is intended.
identity = stops.select(
    col('stop_id').alias('from_stop_id'),
    col('stop_id').alias('to_stop_id'),
    lit(0).alias('transfer_time'),
)
footpaths = footpaths.union(identity)

# union() matches columns by position; all three frames use the order
# (from_stop_id, to_stop_id, transfer_time).
transfers = transfers.union(footpaths)

In [None]:
# Convert the Spark DataFrame of connections into a pandas DataFrame
# (toPandas brings every row to the driver), then display it.
connections_df = connections.toPandas()
connections_df

In [None]:
# Convert the combined transfers (timetable transfers plus footpaths) into a
# pandas DataFrame (toPandas brings every row to the driver), then display it.
transfers_df = transfers.toPandas()
transfers_df

In [16]:
# Write both pandas frames as CSV files, without the index column.
for frame, csv_path in (
    (connections_df, '../data/connections.csv'),
    (transfers_df, '../data/transfers.csv'),
):
    frame.to_csv(csv_path, index=False)