In [1]:
#We need to import pyspark to complete our analysis; pyspark is a large library with SQL and machine learning
#It's first necessary to find where it is
#Once found, import pyspark
import findspark
findspark.init()
findspark.find()
import pyspark

#this line is not necessary as the code returns the same output before and after importing pyspark
#findspark.find() 

In [2]:
#this line is not necessary as we have already imported pyspark
#import pyspark

from pyspark.sql import SQLContext, SparkSession #this line imports an additonal libarary necessary for us to use SQL

#It's necessary to define the connection requirements between spark and python
#The first step is stating where the connection should be made (i.e. local)
#Then we limit the memory to be used, as the computer may not have the capacity to handle all of the data
spark = pyspark.sql.SparkSession.builder \
   .master("local") \
   .config("spark.executor.memory", "2gb") \
   .getOrCreate()

#It's then necessary to set the connection, otherwise known as the Context
#We do this for spark, and for sql
sc = spark.sparkContext 
sqlContext = SQLContext(sc)

In [3]:
#Resilient Distributed Datasets (RDD) is a fundamental data structure of Spark. 
#It is an immutable distributed collection of objects. 
#Each dataset in RDD is divided into logical partitions, which may be computed on different nodes of the cluster.
#Formally, an RDD is a read-only, partitioned collection of records.
#Spark makes use of the concept of RDD to achieve faster and efficient MapReduce operations.

#We need to import the taxi data as an RDD in order to analyse using spark
data = sc.textFile("Taxi.csv")

In [9]:
#Here we confirm the "type" of the data that has been imported
#This line has been moved to appear after original import of data rather than after checking the data values
type(data)

pyspark.rdd.RDD

In [10]:
# Take the first n elements of the RDD
data.take(10)

['VendorID,tpep_pickup_datetime,tpep_dropoff_datetime,passenger_count,trip_distance,RatecodeID,store_and_fwd_flag,PULocationID,DOLocationID,payment_type,fare_amount,extra,mta_tax,tip_amount,tolls_amount,improvement_surcharge,total_amount',
 '',
 '1,2018-12-01 00:28:22,2018-12-01 00:44:07,2,2.50,1,N,148,234,1,12,0.5,0.5,3.95,0,0.3,17.25',
 '1,2018-12-01 00:52:29,2018-12-01 01:11:37,3,2.30,1,N,170,144,1,13,0.5,0.5,2.85,0,0.3,17.15',
 '2,2018-12-01 00:12:52,2018-12-01 00:36:23,1,.00,1,N,113,193,2,2.5,0.5,0.5,0,0,0.3,3.8',
 '1,2018-12-01 00:35:08,2018-12-01 00:43:11,1,3.90,1,N,95,92,1,12.5,0.5,0.5,2.75,0,0.3,16.55',
 '1,2018-12-01 00:21:54,2018-12-01 01:15:13,1,12.80,1,N,163,228,1,45,0.5,0.5,9.25,0,0.3,55.55',
 '1,2018-12-01 00:00:38,2018-12-01 00:29:26,1,18.80,1,N,132,97,1,50.5,0.5,0.5,10.35,0,0.3,62.15',
 '1,2018-12-01 00:59:39,2018-12-01 01:09:07,1,1.00,1,N,246,164,1,7.5,0.5,0.5,0.44,0,0.3,9.24',
 '1,2018-12-01 00:19:19,2018-12-01 00:22:19,1,.30,1,N,161,163,4,4,0.5,0.5,0,0,0.3,5.3']

In [11]:
# Explore the columns names
data.first()

'VendorID,tpep_pickup_datetime,tpep_dropoff_datetime,passenger_count,trip_distance,RatecodeID,store_and_fwd_flag,PULocationID,DOLocationID,payment_type,fare_amount,extra,mta_tax,tip_amount,tolls_amount,improvement_surcharge,total_amount'

In [12]:
# this line is irrelevant as we have already explored the column names using the data.first()
#data.top(1)

Text files are difficult to analyse, so using SQL

In [13]:
#Import a subset of data (from a new csv file) in pyspark using the SQL Context previously defined
#This creates a new dataframe that allows us to use SQL commands, unlike the textfile we originally imported to view the data
data2  = sqlContext.read.csv("taxi_dec.csv", header = True)

In [14]:
#This code confirms the 'type' of the data post-import
type(data2)

pyspark.sql.dataframe.DataFrame

In [15]:
#Show column names and data type of each
data2.printSchema()

root
 |-- VendorID: string (nullable = true)
 |-- tpep_pickup_datetime: string (nullable = true)
 |-- tpep_dropoff_datetime: string (nullable = true)
 |-- passenger_count: string (nullable = true)
 |-- trip_distance: string (nullable = true)
 |-- RatecodeID: string (nullable = true)
 |-- store_and_fwd_flag: string (nullable = true)
 |-- PULocationID: string (nullable = true)
 |-- DOLocationID: string (nullable = true)
 |-- payment_type: string (nullable = true)
 |-- fare_amount: string (nullable = true)
 |-- extra: string (nullable = true)
 |-- mta_tax: string (nullable = true)
 |-- tip_amount: string (nullable = true)
 |-- tolls_amount: string (nullable = true)
 |-- improvement_surcharge: string (nullable = true)
 |-- total_amount: string (nullable = true)



In [16]:
#Review the shape of the database before removal of null values; row count and column count
print("Row Count: ")
print(data2.count())
print("Column Count: ")
print(len(data2.columns))

Row Count: 
8173231
Column Count: 
17


In [17]:
#Remove all null values within the database
data2.dropna()

DataFrame[VendorID: string, tpep_pickup_datetime: string, tpep_dropoff_datetime: string, passenger_count: string, trip_distance: string, RatecodeID: string, store_and_fwd_flag: string, PULocationID: string, DOLocationID: string, payment_type: string, fare_amount: string, extra: string, mta_tax: string, tip_amount: string, tolls_amount: string, improvement_surcharge: string, total_amount: string]

In [18]:
#Review the shape of the database following removal of null values; row count, column count and column names
#Cell originally had three lines of code, but so only the last output was shown
#Code adjusted to show each output with a header
print("Row Count: ")
print(data2.count())
print("Column Count: ")
print(len(data2.columns))
print("Column Names: ")
print(data2.columns)

Row Count: 
8173231
Column Count: 
17
Column Names: 
['VendorID', 'tpep_pickup_datetime', 'tpep_dropoff_datetime', 'passenger_count', 'trip_distance', 'RatecodeID', 'store_and_fwd_flag', 'PULocationID', 'DOLocationID', 'payment_type', 'fare_amount', 'extra', 'mta_tax', 'tip_amount', 'tolls_amount', 'improvement_surcharge', 'total_amount']


In [19]:
#This code has been moved down to appear after checking the database shape
#Here we're looking at some stats around the data held; specifically around the number of passengers
data2.select("passenger_count").describe().show()

+-------+------------------+
|summary|   passenger_count|
+-------+------------------+
|  count|           8173231|
|   mean|1.5964102813195908|
| stddev| 1.233920232393633|
|    min|                 0|
|    max|                 9|
+-------+------------------+



In [20]:
#View several stats from the database side by side; compare high level view of passengers, distance and cost of trips
data2.select('passenger_count', 'trip_distance', 'total_amount').describe().show()

+-------+------------------+------------------+------------------+
|summary|   passenger_count|     trip_distance|      total_amount|
+-------+------------------+------------------+------------------+
|  count|           8173231|           8173231|           8173231|
|   mean|1.5964102813195908|2.8926264215460553|16.524363505886303|
| stddev| 1.233920232393633| 3.764338945224816|120.42046343163905|
|    min|                 0|               .00|             -0.31|
|    max|                 9|             99.07|             99.98|
+-------+------------------+------------------+------------------+



In [21]:
#Import additional SQL functions to aid analysis
#Use the new functions to create a table based on the set criteria from the database
#This is used to check whether all null values were in fact removed; anticipate all zeros
from pyspark.sql.functions import isnan, when, count, col
data2.select([count(when(isnan(c), c)).alias(c) for c in data2.columns]).show()

+--------+--------------------+---------------------+---------------+-------------+----------+------------------+------------+------------+------------+-----------+-----+-------+----------+------------+---------------------+------------+
|VendorID|tpep_pickup_datetime|tpep_dropoff_datetime|passenger_count|trip_distance|RatecodeID|store_and_fwd_flag|PULocationID|DOLocationID|payment_type|fare_amount|extra|mta_tax|tip_amount|tolls_amount|improvement_surcharge|total_amount|
+--------+--------------------+---------------------+---------------+-------------+----------+------------------+------------+------------+------------+-----------+-----+-------+----------+------------+---------------------+------------+
|       0|                   0|                    0|              0|            0|         0|                 0|           0|           0|           0|          0|    0|      0|         0|           0|                    0|           0|
+--------+--------------------+-----------------

In [22]:
#This code is irrelvant as we've already dropped null values and then counted the rows
#The code is more efficient so could be used above instead of the existing code
#data2.dropna().count()

In [23]:
#Filter by passenger code, then total amount, then total amount and distance
data = data2.filter((data2.passenger_count <= 4) & (data2.passenger_count > 0))
data = data.filter(data.total_amount >= 2.5) #the law states minimum cost of taxi journey is $2.50
data = data.filter((data.total_amount < 200) & (data.trip_distance < 100)) # total amount is never more than 99, remove?

In [24]:
data = data.filter(data.trip_distance > 0)

In [25]:
data.select('passenger_count', 'total_amount').describe().show()

+-------+------------------+------------------+
|summary|   passenger_count|      total_amount|
+-------+------------------+------------------+
|  count|           5495209|           5495209|
|   mean|  1.34628091488422| 19.36391156930136|
| stddev|0.6915259106620711|14.465422201705064|
|    min|                 1|                10|
|    max|                 4|             99.98|
+-------+------------------+------------------+



In [26]:
data.select('tpep_pickup_datetime', 'tpep_dropoff_datetime').describe().show()

+-------+--------------------+---------------------+
|summary|tpep_pickup_datetime|tpep_dropoff_datetime|
+-------+--------------------+---------------------+
|  count|             5495209|              5495209|
|   mean|                null|                 null|
| stddev|                null|                 null|
|    min| 2003-12-23 05:01:51|  2003-12-23 05:09:00|
|    max| 2020-12-10 20:34:26|  2020-12-10 20:54:46|
+-------+--------------------+---------------------+



In [27]:
pickups = data.select(col('PULocationID').alias('LocationID'))\
                        .groupBy('LocationID')\
                        .count()\
                        .sort('LocationID')\ #irrelevant unless viewing data
                        .select(col('count').alias('pickups'), col('LocationID'))

In [28]:
dropoffs = data.select(col('DOLocationID')\
                        .alias('LocationID'))\
                        .groupBy('LocationID')\
                        .count()\
                        .sort('LocationID')\ #irrelevant unless viewing data
                        .select(col('count').alias('dropoffs'), col('LocationID'))

In [29]:
picks_drops = pickups.join(dropoffs, on= ['LocationID']) #inner join as nothing defined, so default

In [30]:
picks_drops.show()

+----------+-------+--------+
|LocationID|pickups|dropoffs|
+----------+-------+--------+
|       125|  30437|   26820|
|       124|    210|     858|
|        51|    544|    1388|
|         7|   8222|   28315|
|       169|    277|    1530|
|       205|    575|    1398|
|        15|     88|     651|
|       232|   9827|   28553|
|       234| 156298|  125827|
|        54|    202|    1722|
|       155|    470|    1204|
|       132| 165174|   54698|
|       154|     27|      70|
|       200|    227|    2382|
|       101|    159|     544|
|        11|    102|     621|
|       138| 188565|   78204|
|        29|    140|     750|
|        69|    662|    3426|
|       112|   2182|   15268|
+----------+-------+--------+
only showing top 20 rows



In [31]:
#duplicate code, removed
#picks_drops.show()

+----------+-------+--------+
|LocationID|pickups|dropoffs|
+----------+-------+--------+
|       125|  30437|   26820|
|       124|    210|     858|
|        51|    544|    1388|
|         7|   8222|   28315|
|       169|    277|    1530|
|       205|    575|    1398|
|        15|     88|     651|
|       232|   9827|   28553|
|       234| 156298|  125827|
|        54|    202|    1722|
|       155|    470|    1204|
|       132| 165174|   54698|
|       154|     27|      70|
|       200|    227|    2382|
|       101|    159|     544|
|        11|    102|     621|
|       138| 188565|   78204|
|        29|    140|     750|
|        69|    662|    3426|
|       112|   2182|   15268|
+----------+-------+--------+
only showing top 20 rows



In [38]:
picks_drops = picks_drops.toPandas()

AttributeError: 'DataFrame' object has no attribute 'toPandas'

In [None]:
picks_drops.head()

In [None]:
picks_drops['total'] = picks_drops['pickups'] + picks_drops['dropoffs']

In [None]:
import pandas as pd
import numpy as np

In [None]:
import geopandas
import matplotlib as mpl
import matplotlib.pyplot as plt
plt.style.use('ggplot')
%matplotlib inline

In [None]:
NYC_map = geopandas.read_file('taxi_zones/taxi_zones.shp')
NYC_map.crs

In [None]:
NYC_map.tail()

In [None]:
picks_drops['LocationID'] = picks_drops['LocationID'].astype(int)
picks_drops.dtypes

In [None]:
NYC = NYC_map.merge(picks_drops, on = 'LocationID')


In [None]:
NYC.head()

In [None]:
#do not correspond to any plot so irrelevant, removed
#import matplotlib.pyplot as plt
#import seaborn as sns

In [None]:
#line is unnecessary, removed
#fig, ax = plt.subplots(1, 1)

In [None]:
NYC.plot(column = 'total', cmap='OrRd', scheme='quantiles')

In [None]:
top_pickups = NYC[['pickups', 'zone', 'borough']].sort_values('pickups', ascending = False)

In [None]:
top_pickups.head(10)

In [None]:
sc.close()