In [3]:
# pandas and plotting libraries for visualizations
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

# module containing functions for manipulating pyspark dataframes
import pyspark.sql.functions as f

# class which will let us create spark objects
from pyspark.sql import SparkSession

# helper functions for the class
from helpers import display, read_df, write_df

## [PySpark SQL docs](https://spark.apache.org/docs/latest/api/python/pyspark.sql.html)

## Create a Spark Session

In [2]:
# Configure the session builder: application name plus a local master
# ('local[2]' runs Spark in-process; per the Spark docs the bracketed number
# is the worker-thread count).
session_builder = SparkSession.builder.appName('data_exploration').master('local[2]')

# getOrCreate() hands back an already-running session if there is one,
# otherwise it starts a new one with the settings above.
spark = session_builder.getOrCreate()

## Read in data file

In [4]:
# Location of the 2016 taxi data, relative to this notebook.
taxi_path = '../taxi_2016.parquet'

# `read_df` comes from the local `helpers` module; presumably it loads the
# parquet file into a Spark DataFrame using the given session -- TODO confirm.
df = read_df(spark, taxi_path)

In [5]:
# Preview the loaded data. `display` here is the function imported from the
# local `helpers` module. The recorded output shows 10 rows -- presumably
# the helper's default row limit; verify against `helpers.display`.
display(df)

Unnamed: 0,trip_id,taxi_id,start_time,end_time,trip_miles,pickup_census_tract,dropoff_census_tract,fare,tips,trip_total,payment_type,company
0,10b8ae23847aeff11587cf33d3d9040631051749,deec568f18d18fb40292b918b82311ac8006ada6a8a1ff...,2016-01-03 14:00:00,2016-01-03 14:15:00,1.00,1.703132e+10,1.703128e+10,6.50,0.00,6.50,Cash,
1,10b8b00d0f54e49fa09ba66b9d1fe5c3a44bd639,22c2c92ba4cc6458780df495b72b3fb22a83cfd79728bf...,2016-05-09 19:30:00,2016-05-09 19:45:00,0.00,1.703128e+10,1.703184e+10,5.00,2.00,7.00,Credit Card,Taxi Affiliation Services
2,10b8b20f6e87bfaf0a658b2725b9a5ce29cb1a43,3a4af34907634686a9444651f65638134c2a972bfc5994...,2016-03-09 08:45:00,2016-03-09 08:45:00,0.90,1.703184e+10,1.703108e+10,5.75,2.00,7.75,Credit Card,Northwest Management LLC
3,10b8b21fba6a350d71ab815febecae503187490b,e6ec7ccfda661f2050f932f672395ff77cd1359b3ae62b...,2016-04-20 00:30:00,2016-04-20 00:45:00,1.60,1.703183e+10,1.703132e+10,7.50,2.00,9.50,Credit Card,Taxi Affiliation Services
4,10b8b22261d36a0c51deab377237260e171765d6,f837696d781bfe4f3da410734cbfa51b47f328de974a96...,2016-04-12 14:30:00,2016-04-12 14:30:00,0.80,,,5.75,0.00,5.75,Cash,Taxi Affiliation Services
5,10b8b2a042e9bc7e8ca948cff76f959bcd81655b,cae7b41a1af6467dbadb3446672faf7eed1cf0079be5ac...,2016-07-29 00:30:00,2016-07-29 00:30:00,2.20,,,9.00,0.00,10.00,Cash,Taxi Affiliation Services
6,10b8b2c847a032eba3500dfe134baf3472e6d1bc,2c2846e06a06b51bf72e041bce82b95d5471941476ff25...,2016-01-27 16:45:00,2016-01-27 17:30:00,15.50,,,41.00,0.00,45.00,Cash,Choice Taxi Association
7,10b8b33db913f0b34ed75945748cebc5cc6c5b4f,af7c294a5ab2a7fcd6fb4b06820fe8abf65e649d814526...,2016-02-07 17:15:00,2016-02-07 17:30:00,3.11,1.703108e+10,1.703183e+10,13.00,0.00,13.00,Cash,
8,10b8b470339ffada95f7dc0c9180058b7eb99df4,6778d44f5164d57ce76c95242f4f3af1dd0986082f1be3...,2016-01-30 02:15:00,2016-01-30 02:15:00,0.00,,,7.00,0.00,8.00,Cash,Taxi Affiliation Services
9,10b8b62e54671f95d079fea70268e7208e90c7b5,493b6af5931ea2c7c6a82d9d6e2e17af3abae397edd72e...,2016-04-05 16:00:00,2016-04-05 16:45:00,1.10,1.703198e+10,1.703132e+10,46.50,10.30,61.80,Credit Card,Taxi Affiliation Services


In [None]:
# Show the column names of `df` (left as the cell's last expression so the
# notebook renders the value).
df.columns

In [None]:
# Preview `df` again via the `helpers.display` function.
# NOTE(review): the second argument is presumably the number of rows to show -- confirm in `helpers`.
display(df, 10)

In [None]:
# Total number of rows in `df`. Stored in a variable because later cells
# divide per-column counts by it.
total_rows = df.count() # ~3 million trips
print(total_rows)

In [None]:
# How many different taxis appear in the data? Aggregate the whole frame
# down to a single distinct count of `taxi_id`, then show it.
distinct_taxis = df.agg(f.countDistinct('taxi_id'))
display(distinct_taxis)

In [None]:
# Share of rows that have a `trip_miles` value: `f.count` on a column skips
# nulls, so dividing by the total row count gives the non-null fraction.
non_null_share = f.count('trip_miles') / total_rows
display(df.agg(non_null_share.alias('trip_miles')))

In [None]:
# Same non-null fraction as for `trip_miles`, computed for every column in
# one aggregation. Each result column keeps the name of its source column.
completeness_exprs = [
    (f.count(column_name) / total_rows).alias(column_name)
    for column_name in df.columns
]
display(df.agg(*completeness_exprs))

In [None]:
# Number of rows (trips) recorded for each taxi; the grouped count lands in
# a column named `count`.
trips_per_taxi = (
    df
    .groupBy('taxi_id')
    .count()
)

In [None]:
# Preview the per-taxi trip counts via the `helpers.display` function.
# NOTE(review): the second argument is presumably a row limit -- confirm in `helpers`.
display(trips_per_taxi, 10)

In [None]:
# Distribution of the number of trips per taxi.
# NOTE(review): `sns.distplot` is deprecated in seaborn >= 0.11; it is kept
# here so the figure stays identical to the original.
plt.figure()

# Pull only the `count` column to the driver as a pandas DataFrame for plotting.
taxi_trip_counts = trips_per_taxi.select('count').toPandas()

ax = sns.distplot(taxi_trip_counts)
# The trailing semicolon suppresses the text repr of the returned title object.
ax.set_title('Trips Per Taxi');

In [None]:
# Total miles for each trip id, ignoring rows with no recorded mileage.
has_mileage = f.col('trip_miles').isNotNull()
total_miles = f.sum('trip_miles').alias('miles')

# `filter` is an alias of `where`; the result has columns `trip_id` and `miles`.
distance_traveled_per_trip = df.filter(has_mileage).groupBy('trip_id').agg(total_miles)

In [None]:
# Preview the per-trip mileage totals via the `helpers.display` function.
# NOTE(review): the second argument is presumably a row limit -- confirm in `helpers`.
display(distance_traveled_per_trip, 10)

In [None]:
# Distribution of miles traveled per trip.
# `distance_traveled_per_trip` is grouped by `trip_id`, so each plotted value
# is one trip's mileage. The title therefore says "Per Trip"; it previously
# read "Per Taxi", which mislabeled the figure.
plt.figure()
(
    sns
    # Only the `miles` column is collected to the driver as a pandas DataFrame.
    .distplot(distance_traveled_per_trip.select('miles').toPandas())
    .set_title('Miles Traveled Per Trip')
);  # trailing semicolon suppresses the text repr of the title object

In [None]:
# When do most trips occur? Count trips by the hour of day of `start_time`
# and list the busiest hours first.
trips_by_start_hour = df.groupBy(f.hour('start_time')).count()
display(trips_by_start_hour.orderBy('count', ascending=False))
# Finding noted by the author: most trips start in the early evening, 5-7pm.

In [None]:
# what's the most common length for a trip?

In [None]:
# Trip duration in minutes: difference of the two timestamps in epoch
# seconds, divided by 60.
start_seconds = f.unix_timestamp(f.col('start_time'))
end_seconds = f.unix_timestamp(f.col('end_time'))
minutes_col = f.col('minutes')

# Count how many trips share each duration, then keep only durations that
# are present, strictly positive, and below the cutoff.
# NOTE(review): the 1048-minute cutoff is not explained anywhere in this
# notebook -- confirm where it comes from.
trip_length = (
    df
    .withColumn('minutes', (end_seconds - start_seconds) / 60)
    .groupBy('minutes')
    .count()
    .filter(minutes_col.isNotNull() & (minutes_col < 1048) & (minutes_col > 0))
)

# Most common durations first.
display(trip_length.orderBy('count', ascending=False), 10)