In [1]:
import geopandas as gpd
import pandas as pd
import numpy as np
import seaborn as sns
import matplotlib as mpl
import matplotlib.pyplot as plt
from shapely.geometry import Point
import mplleaflet
from pyspark.sql import SparkSession
from pyspark.sql import functions as F
from pyspark.sql import types as T

# Create (or reuse) the notebook's Spark session on a local master.
# The builder options are independent of each other, so their order is arbitrary.
spark = (
    SparkSession.builder
    .appName("TFL Notebook")
    .master("local")
    .config("spark.driver.memory", "16G")
    .config("spark.driver.maxResultSize", "10G")
    .config("spark.executor.memory", "8G")
    .getOrCreate()
)

# Load the trip records from the Parquet dataset (path is relative to the notebook).
trips = spark.read.parquet("../data/parquet_trip")

# Register the raw trips as a temp view so the SQL below can select from it.
trips.createOrReplaceTempView("trips")
# Append a `trip_type` column: "Round Trip" when the start and end station
# names are equal, otherwise "Point to Point". Rows where the comparison is
# NULL (a missing station name) also fall through to the ELSE branch.
trips = spark.sql("""
    select *,
    CASE WHEN start_station_name = end_station_name THEN "Round Trip" ELSE "Point to Point" END AS trip_type
    from trips
""")
# Re-register the view so later queries against "trips" see `trip_type`.
trips.createOrReplaceTempView("trips")

BIKE_POINTS_FILE = "../data/bike-points.csv"

# Explicit schema for the bike-points CSV, built from (name, type) pairs.
# Every field is declared with nullable=False.
schema = T.StructType([
    T.StructField(column_name, column_type, False)
    for column_name, column_type in (
        ("idx",       T.IntegerType()),
        ("id",        T.IntegerType()),
        ("name",      T.StringType()),
        ("latitude",  T.DoubleType()),
        ("longitude", T.DoubleType()),
        ("osgb_x",    T.DoubleType()),
        ("osgb_y",    T.DoubleType()),
        ("numdocks",  T.LongType()),
        ("num_bikes", T.LongType()),
        ("num_empty", T.LongType()),
    )
])

# Read the CSV using the schema above, treating the first line as a header,
# with the parser in PERMISSIVE mode; then expose the result to Spark SQL.
bike_points = spark.read.csv(
    BIKE_POINTS_FILE,
    schema=schema,
    header='true',
    mode="PERMISSIVE",
)
bike_points.createOrReplaceTempView("bike_points")

def to_geo(df, x_field='longitude', y_field='latitude'):
    """Wrap ``df`` in a GeoDataFrame with one point geometry per row.

    Args:
        df: Table-like object indexable by column name (the notebook passes
            a pandas DataFrame).
        x_field: Name of the column supplying each point's x coordinate.
        y_field: Name of the column supplying each point's y coordinate.

    Returns:
        A ``GeoDataFrame`` built from ``df`` whose geometry is the list of
        points. No CRS is passed to the constructor.
    """
    points = []
    # Pair the two coordinate columns row by row, x first.
    for x_value, y_value in zip(df[x_field], df[y_field]):
        points.append(Point((x_value, y_value)))
    return gpd.GeoDataFrame(df, geometry=points)

# Aggregate the "trips" view per start station: number of trips and summed `duration`.
# NOTE(review): the unit of `duration` is not visible here — confirm before labelling outputs.
station_totals = spark.sql("""
    select start_station_id, count(*) as trip_count, sum(duration) as duration
    from trips 
    group by start_station_id
""")
station_totals.createOrReplaceTempView("station_totals")

# Attach each station's coordinates to its totals. This is an inner join, so
# stations present in only one of the two inputs are dropped from the result.
station_data = spark.sql("""
    select bp.id, bp.longitude, bp.latitude, st.duration, st.trip_count
    from bike_points bp join station_totals st on bp.id = st.start_station_id
""")
station_data.createOrReplaceTempView("station_data")

In [2]:
# Collect the per-station totals into pandas and add point geometries.
df = to_geo(station_data.toPandas())

fig, ax = plt.subplots(figsize=(10, 10))

# One circular marker per station, coloured by its trip count.
df.plot(
    column='trip_count',
    cmap='coolwarm',
    marker='o',
    markersize=30,
    ax=ax,
)

# Hand the figure to mplleaflet; keep this as the last expression so the
# notebook renders its return value as the cell output.
mplleaflet.display(fig=fig, crs=df.crs, tiles='cartodb_positron')