# High-level Preprocessing Steps:
    1. Change column names and types
    2. Cleaning: date outside range, negative duration, duration > 18000s, null values
    3. Add Features from external dataset: hour, day, is_school_holiday, weather features (temperature, precipitation, wind, etc.)
    4. Check and deal with null values    
    

In [1]:
from pyspark.sql.functions import to_timestamp, date_format, hour, dayofweek
from pyspark.sql.functions import isnan, when, count, col, split, concat, lit
from pyspark.sql.functions import to_date, create_map
from itertools import chain
from pyspark.sql import SparkSession, Window, functions as F
import pandas as pd
import seaborn as sns
from datetime import datetime
import matplotlib.pyplot as plt
%matplotlib inline


In [2]:
# Create a spark session (which will run spark jobs)
spark = (
    SparkSession.builder.appName("Preprocess Data")
    .config("spark.sql.repl.eagerEval.enabled", True) 
    .config("spark.sql.parquet.cacheMetadata", "true")
    .config("spark.sql.session.timeZone", "Etc/UTC")
    .getOrCreate()
)

Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).


22/08/20 18:02:21 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable


In [3]:
# read datasets
taxi = spark.read.parquet('../data/raw/tlc_data')
weather = spark.read.csv('../data/raw/other_data/jfk_weather.csv', header=True)
school_holiday = spark.read.csv('../data/raw/other_data/nyc_school_holiday.csv',
                                header=True)
zones = pd.read_csv("../data/raw/taxi_zones/taxi+_zone_lookup.csv")

test_taxi = spark.read.parquet('../data/raw/test_data/test_data.parquet')
test_weather = spark.read.csv('../data/raw/test_data/weather_test_data.csv', header=True)

                                                                                

In [4]:
weather.limit(5)

STATION,DATE,SOURCE,REPORT_TYPE,CALL_SIGN,QUALITY_CONTROL,AA1,DEW,KB1,MA1,MD1,MF1,OC1,RH1,TMP,VIS,WND
74486094789,2018-01-01T00:00:00,4,FM-12,99999,V020,,-2001,,999999102591,"3,1,014,1,+999,9",,,,-1111,016000199,"310,1,N,0093,1"
74486094789,2018-01-01T00:51:00,7,FM-15,KJFK,V030,1000095.0,-2005,,102715102625,,,,,-1175,"016093,5,N,5","320,5,N,0067,5"
74486094789,2018-01-01T01:51:00,7,FM-15,KJFK,V030,1000095.0,-2065,,102715102625,,,,,-1175,"016093,5,N,5","330,5,N,0093,5"
74486094789,2018-01-01T02:51:00,7,FM-15,KJFK,V030,1000095.0,-2005,,102715102625,"3,9,001,9,+999,9",,,,-1225,"016093,5,N,5","310,5,N,0093,5"
74486094789,2018-01-01T03:00:00,4,FM-12,99999,V020,,-2001,,999999102591,"3,1,001,1,+999,9",,,,-1221,016000199,"310,1,N,0093,1"


In [5]:
taxi.printSchema()

root
 |-- VendorID: long (nullable = true)
 |-- tpep_pickup_datetime: timestamp (nullable = true)
 |-- tpep_dropoff_datetime: timestamp (nullable = true)
 |-- passenger_count: long (nullable = true)
 |-- trip_distance: double (nullable = true)
 |-- RatecodeID: long (nullable = true)
 |-- store_and_fwd_flag: string (nullable = true)
 |-- PULocationID: long (nullable = true)
 |-- DOLocationID: long (nullable = true)
 |-- payment_type: long (nullable = true)
 |-- fare_amount: double (nullable = true)
 |-- extra: double (nullable = true)
 |-- mta_tax: double (nullable = true)
 |-- tip_amount: double (nullable = true)
 |-- tolls_amount: double (nullable = true)
 |-- improvement_surcharge: double (nullable = true)
 |-- total_amount: double (nullable = true)
 |-- congestion_surcharge: double (nullable = true)
 |-- airport_fee: double (nullable = true)



In [6]:
# taxi Dataset
# renaming a few columns
field_name_change = {"PULocationID": "pu_location_id", 
                     "DOLocationID": "do_location_id"}
for old, new in field_name_change.items():
    taxi = taxi.withColumnRenamed(old, new)
    test_taxi = test_taxi.withColumnRenamed(old, new)
    
# converting columns type
for field in ('pu_location_id', 'do_location_id'):
    taxi = taxi.withColumn(
                field,
                F.col(field).cast('INT')
             )
    test_taxi = test_taxi.withColumn(
                field,
                F.col(field).cast('INT')
             )

In [7]:
columns_interest = ['tpep_pickup_datetime', 'tpep_dropoff_datetime', 
                    'pu_location_id', 'do_location_id', 'total_amount']
taxi = taxi[columns_interest]
test_taxi = test_taxi[columns_interest]

weather_col_interest = ['DATE', 'TMP', 'DEW', 'AA1', 'WND', 'MA1', 'VIS']
weather = weather[weather_col_interest]
test_weather = test_weather[weather_col_interest]

In [8]:
weather.count(), taxi.count()

                                                                                

(27170, 187469831)

In [9]:
taxi.printSchema()

root
 |-- tpep_pickup_datetime: timestamp (nullable = true)
 |-- tpep_dropoff_datetime: timestamp (nullable = true)
 |-- pu_location_id: integer (nullable = true)
 |-- do_location_id: integer (nullable = true)
 |-- total_amount: double (nullable = true)



In [10]:
weather.printSchema()

root
 |-- DATE: string (nullable = true)
 |-- TMP: string (nullable = true)
 |-- DEW: string (nullable = true)
 |-- AA1: string (nullable = true)
 |-- WND: string (nullable = true)
 |-- MA1: string (nullable = true)
 |-- VIS: string (nullable = true)



In [11]:
weather.limit(5)

DATE,TMP,DEW,AA1,WND,MA1,VIS
2018-01-01T00:00:00,-1111,-2001,,"310,1,N,0093,1",999999102591,016000199
2018-01-01T00:51:00,-1175,-2005,1000095.0,"320,5,N,0067,5",102715102625,"016093,5,N,5"
2018-01-01T01:51:00,-1175,-2065,1000095.0,"330,5,N,0093,5",102715102625,"016093,5,N,5"
2018-01-01T02:51:00,-1225,-2005,1000095.0,"310,5,N,0093,5",102715102625,"016093,5,N,5"
2018-01-01T03:00:00,-1221,-2001,,"310,1,N,0093,1",999999102591,016000199


In [12]:
# Weather Dataset

# change column names
field_name_change = {"DATE": "date_time", "TMP": "temperature",
                     "DEW": "dew_point", "AA1": "precipitation", 
                     "WND": "wind_direction", "MA1": "pressure",
                     "VIS": "visibility"}
for old, new in field_name_change.items():
    weather = weather.withColumnRenamed(old, new)
    test_weather = test_weather.withColumnRenamed(old, new)

# change column data types
weather = weather.withColumn('date_time', F.col("date_time").cast('TIMESTAMP'))
test_weather = test_weather.withColumn('date_time', F.col("date_time").cast('TIMESTAMP'))

for field in ('temperature', 'dew_point'):
    weather = weather.withColumn(field, 
                                 concat(weather[field].substr(1, 1), 
                                        weather[field].substr(2, 4), 
                                        lit('.'),weather[field].substr(7, 1)) \
                                 .cast('double'))
    test_weather = test_weather.withColumn(field, 
                                 concat(test_weather[field].substr(1, 1), 
                                        test_weather[field].substr(2, 4), 
                                        lit('.'),test_weather[field].substr(7, 1)) \
                                 .cast('double'))

for field in (['wind_direction', 'visibility']):
    weather = weather.withColumn(field, 
                                 split(weather[field], ',') \
                                 .getItem(0).cast('INT'))
    test_weather = test_weather.withColumn(field, 
                                 split(test_weather[field], ',') \
                                 .getItem(0).cast('INT'))

weather = weather.withColumn('precipitation', 
                             split(weather['precipitation'], ',') \
                             .getItem(1).cast('INT'))
test_weather = test_weather.withColumn('precipitation', 
                             split(test_weather['precipitation'], ',') \
                             .getItem(1).cast('INT'))

weather = weather.withColumn('pressure', 
                             split(weather['pressure'], ',') \
                             .getItem(2).cast('INT'))
test_weather = test_weather.withColumn('pressure', 
                             split(test_weather['pressure'], ',') \
                             .getItem(2).cast('INT'))



In [13]:
weather.limit(5)

date_time,temperature,dew_point,precipitation,wind_direction,pressure,visibility
2018-01-01 00:00:00,-111.1,-200.1,,310,10259,16000
2018-01-01 00:51:00,-117.5,-200.5,0.0,320,10262,16093
2018-01-01 01:51:00,-117.5,-206.5,0.0,330,10262,16093
2018-01-01 02:51:00,-122.5,-200.5,0.0,310,10262,16093
2018-01-01 03:00:00,-122.1,-200.1,,310,10259,16000


---
# Weather Data Cleaning
#### 1. Change 9s to null

In [14]:
# found many entries where there are 9s in multiple columns
# documentation stated these are missing values
weather.filter(F.col('visibility') == 999999).limit(5)

date_time,temperature,dew_point,precipitation,wind_direction,pressure,visibility
2018-01-01 04:59:00,9999.9,9999.9,0.0,999,,999999
2018-01-01 04:59:00,9999.9,9999.9,,999,,999999
2018-01-02 04:59:00,9999.9,9999.9,0.0,999,,999999
2018-01-03 04:59:00,9999.9,9999.9,0.0,999,,999999
2018-01-04 04:59:00,9999.9,9999.9,0.0,999,,999999


In [15]:
# set features to null when equals to 9s
miss_values = {"precipitation": 9999, "temperature": 9999.9, 
               "dew_point": 9999.9, "wind_direction": 999, 
               "pressure": 99999, "visibility":999999}
for field, val in miss_values.items():
    weather = weather.withColumn(field, 
                             when(col(field) == val, None) \
                             .otherwise(col(field)))
    test_weather = test_weather.withColumn(field, 
                             when(col(field) == val, None) \
                             .otherwise(col(field)))

#### 2. Fill null values by using the median of 6 hours timeframe segment

In [16]:
weather.count

<bound method DataFrame.count of +-------------------+-----------+---------+-------------+--------------+--------+----------+
|          date_time|temperature|dew_point|precipitation|wind_direction|pressure|visibility|
+-------------------+-----------+---------+-------------+--------------+--------+----------+
|2018-01-01 00:00:00|     -111.1|   -200.1|         null|           310|   10259|     16000|
|2018-01-01 00:51:00|     -117.5|   -200.5|            0|           320|   10262|     16093|
|2018-01-01 01:51:00|     -117.5|   -206.5|            0|           330|   10262|     16093|
|2018-01-01 02:51:00|     -122.5|   -200.5|            0|           310|   10262|     16093|
|2018-01-01 03:00:00|     -122.1|   -200.1|         null|           310|   10259|     16000|
|2018-01-01 03:51:00|     -122.5|   -194.5|            0|           310|   10266|     16093|
|2018-01-01 04:51:00|     -128.5|   -194.5|            0|           330|   10262|     16093|
|2018-01-01 04:59:00|       null|    

In [17]:
dict_null = {col:weather.filter(weather[col].isNull()).count() 
             for col in weather.columns}
dict_null

{'date_time': 0,
 'temperature': 776,
 'dew_point': 777,
 'precipitation': 6214,
 'wind_direction': 1880,
 'pressure': 962,
 'visibility': 756}

In [18]:
weather.count()

27170

In [19]:
weather = weather.withColumn(
    "DATE",
    to_date(col("date_time"),"yyyy-MM-dd")
)
test_weather = test_weather.withColumn(
    "DATE",
    to_date(col("date_time"),"yyyy-MM-dd")
)

weather = weather.withColumn('hour', hour(weather.date_time))
test_weather = test_weather.withColumn('hour', hour(test_weather.date_time))


In [20]:
# partition 24 hours into group of 6 (4 hours each)
segment = {}
seg = 0
for h in range(0, 24):
    if h % 6 == 0:
        seg += 1
    segment[h] = seg

# map dictionary to a new column in weather dataframe
mapping_expr = create_map([lit(x) for x in chain(*segment.items())])

weather = weather.withColumn("time_segment", mapping_expr[col("hour")])
test_weather = test_weather.withColumn("time_segment", mapping_expr[col("hour")])

In [21]:
weather.columns

['date_time',
 'temperature',
 'dew_point',
 'precipitation',
 'wind_direction',
 'pressure',
 'visibility',
 'DATE',
 'hour',
 'time_segment']

In [22]:
window = Window.partitionBy(["date", "time_segment"]).orderBy("date")

fields = ['temperature', 'dew_point', 'precipitation', 
          'wind_direction', 'visibility', 'pressure']

for field in fields:
    weather = (
        weather.withColumn("median", F.expr(f'percentile_approx({field}, 0.5)').over(window)) \
            .withColumn(field, F.when(F.col(field).isNull(), 
                                              F.col("median")) \
                        .otherwise(F.col(field))).drop("median")
    )
    test_weather = (
        test_weather.withColumn("median", F.expr(f'percentile_approx({field}, 0.5)').over(window)) \
            .withColumn(field, F.when(F.col(field).isNull(), 
                                              F.col("median")) \
                        .otherwise(F.col(field))).drop("median")
    )


In [23]:
dict_null = {col:weather.filter(weather[col].isNull()).count() 
             for col in weather.columns}
dict_null

{'date_time': 0,
 'temperature': 0,
 'dew_point': 0,
 'precipitation': 32,
 'wind_direction': 0,
 'pressure': 0,
 'visibility': 0,
 'DATE': 0,
 'hour': 0,
 'time_segment': 0}

In [24]:
dict_null = {col:test_weather.filter(test_weather[col].isNull()).count() 
             for col in test_weather.columns}
dict_null

{'date_time': 0,
 'temperature': 0,
 'dew_point': 0,
 'precipitation': 172,
 'wind_direction': 8,
 'pressure': 0,
 'visibility': 0,
 'DATE': 0,
 'hour': 0,
 'time_segment': 0}

In [25]:
test_weather.filter(test_weather['wind_direction'].isNull()).groupby('DATE').agg({'date_time':'count'})

DATE,count(date_time)
2020-11-09,8


In [26]:
test_weather.filter(test_weather['precipitation'].isNull()).groupby('DATE').agg({'date_time':'count'})

DATE,count(date_time)
2020-09-07,24
2020-09-09,24
2020-09-11,24
2020-09-08,25
2020-09-12,24
2020-09-14,24
2020-09-10,27


This null values dates does not matter since we will only merge with Jan-2020 taxi data

In [27]:
test_weather.filter(test_weather['precipitation'].isNull()).groupby(['DATE','hour']).agg({'date_time':'count'}).orderBy(['DATE','hour'])

DATE,hour,count(date_time)
2020-09-07,6,2
2020-09-07,7,1
2020-09-07,8,1
2020-09-07,9,2
2020-09-07,10,1
2020-09-07,11,1
2020-09-07,12,2
2020-09-07,13,1
2020-09-07,14,1
2020-09-07,15,2


It appears that there are 2 wide time gaps where precipitation data are null. By the past weather data from https://www.timeanddate.com/weather/usa/new-york/historic?month=12&year=2019, it appear that there is no rain within these hours. Hence, we will change these null values to 0 mm precipitation.

In [28]:
weather = weather.withColumn("precipitation", F.when(F.col("precipitation").isNull(), 0) \
                        .otherwise(F.col("precipitation"))).drop("time_segment")


In [29]:
dict_null = {col:weather.filter(weather[col].isNull()).count() 
             for col in weather.columns}
dict_null

{'date_time': 0,
 'temperature': 0,
 'dew_point': 0,
 'precipitation': 0,
 'wind_direction': 0,
 'pressure': 0,
 'visibility': 0,
 'DATE': 0,
 'hour': 0}

---
# Taxi Data Cleaning
#### 1. Remove dates that are outside range

In [30]:
taxi = taxi.withColumn("tpep_pickup_datetime",
                   to_timestamp(col("tpep_pickup_datetime"))) \
         .withColumn("pu_year", date_format(col("tpep_pickup_datetime"), "y"))
test_taxi = test_taxi.withColumn("tpep_pickup_datetime",
                   to_timestamp(col("tpep_pickup_datetime"))) \
         .withColumn("pu_year", date_format(col("tpep_pickup_datetime"), "y"))

In [31]:
taxi.agg({'tpep_pickup_datetime': 'max'}).show()
taxi.agg({'tpep_pickup_datetime': 'min'}).show()


                                                                                

+-------------------------+
|max(tpep_pickup_datetime)|
+-------------------------+
|      2090-12-31 06:41:26|
+-------------------------+





+-------------------------+
|min(tpep_pickup_datetime)|
+-------------------------+
|      2001-01-01 00:01:48|
+-------------------------+



                                                                                

In [32]:
taxi = taxi.filter(F.col('pu_year').isin([2018, 2019])).drop('pu_year')

test_taxi = test_taxi.withColumn("tpep_pickup_datetime",
                   to_timestamp(col("tpep_pickup_datetime"))) \
         .withColumn("pu_month", date_format(col("tpep_pickup_datetime"), "M"))

test_taxi = test_taxi.filter((F.col('pu_year') == 2020) &
                             (F.col('pu_month') == 1)).drop('pu_year', 'pu_month')

#### 2. Remove if drop-off timestamp is < pick-up timestamp

In [33]:
# create new column to calculate taxi duration in seconds
taxi = taxi.withColumn(
    'trip_duration',
    (col("tpep_dropoff_datetime").cast("long") - 
     col('tpep_pickup_datetime').cast("long")))

test_taxi = test_taxi.withColumn(
    'trip_duration',
    (col("tpep_dropoff_datetime").cast("long") - 
     col('tpep_pickup_datetime').cast("long")))

In [34]:
taxi = taxi.where((F.col('trip_duration') > 0))

#### 3. Remove if duration is more than 1 day (do - pu > 18000)

In [35]:
taxi.where((F.col('trip_duration') > 18000)).limit(5)

                                                                                

tpep_pickup_datetime,tpep_dropoff_datetime,pu_location_id,do_location_id,total_amount,trip_duration
2018-03-01 00:02:42,2018-03-01 23:52:20,48,143,6.8,85778
2018-03-01 00:05:51,2018-03-01 23:53:55,234,232,14.63,85684
2018-03-01 00:04:21,2018-03-01 23:28:33,43,107,15.96,84252
2018-03-01 00:34:36,2018-03-02 00:05:46,163,7,15.8,84670
2018-03-01 00:15:23,2018-03-01 23:39:52,114,79,6.8,84269


In [36]:
taxi = taxi.where((F.col('trip_duration') < 18000))
test_taxi = test_taxi.where((F.col('trip_duration') < 18000))

#### 4. Only select pick up location from Airports

In [37]:
zones = zones.dropna()

In [38]:
zones.dropna().loc[zones['Zone'].str.contains('Airport')]

Unnamed: 0,LocationID,Borough,Zone,service_zone
0,1,EWR,Newark Airport,EWR
131,132,Queens,JFK Airport,Airports
137,138,Queens,LaGuardia Airport,Airports


In [49]:
# average selling price for airport pickup trips
airport = taxi.where((F.col('pu_location_id').isin([1, 132, 138])))
airport.agg({
    'total_amount': 'sum',
    'pu_location_id': 'count'
})

                                                                                

sum(total_amount),count(pu_location_id)
513761515.9809384,10308204


In [48]:
# average selling price for all taxi trips
taxi.agg({
    'total_amount': 'sum',
    'pu_location_id': 'count'
})

                                                                                

sum(total_amount),count(pu_location_id)
3297405116.683793,186841578


In [53]:
# proportion of airport trip by total_amount 
5.137615159809384E8/3.297405116683793E9 * 100

15.580782397087725

In [54]:
# proportion of airport trip by number of trips 

10308204/186841578

0.05517082498628865

In [39]:
taxi = taxi.where((F.col('pu_location_id').isin([1, 132, 138])))
test_taxi = test_taxi.where((F.col('pu_location_id').isin([1, 132, 138])))

In [40]:
taxi.count()

                                                                                

10308204

In [41]:
test_taxi.count()

                                                                                

346338

#### 5. Check for null values

In [42]:
dict_null = {col:taxi.filter(taxi[col].isNull()).count() for col in taxi.columns}
dict_null

                                                                                

{'tpep_pickup_datetime': 0,
 'tpep_dropoff_datetime': 0,
 'pu_location_id': 0,
 'do_location_id': 0,
 'total_amount': 0,
 'trip_duration': 0}

In [43]:
dict_null = {col:test_taxi.filter(test_taxi[col].isNull()).count() for col in test_taxi.columns}
dict_null

                                                                                

{'tpep_pickup_datetime': 0,
 'tpep_dropoff_datetime': 0,
 'pu_location_id': 0,
 'do_location_id': 0,
 'total_amount': 0,
 'trip_duration': 0}

---
# Create new features

#### 1. From existing taxi dataset

In [44]:
# create new column pickup day of week
for abbr in ('pu', 'do'):
    if abbr == 'pu':
        long = "pickup"
    else:
        long = "dropoff"
    taxi = taxi.withColumn(f"tpep_{long}_datetime",
                       to_timestamp(col(f"tpep_{long}_datetime"))) \
             .withColumn(f"{abbr}_dow", dayofweek(col(f"tpep_{long}_datetime")))
    # create new column pickup hour
    taxi = taxi.withColumn(f'{abbr}_hour', hour(taxi[f'tpep_{long}_datetime']))
    # create helper column pickup date format
    taxi = taxi.withColumn(
                            f"{abbr}_date",
                            to_date(col(f"tpep_{long}_datetime"),"yyyy-MM-dd")
                          )
    
    
    test_taxi = test_taxi.withColumn(f"tpep_{long}_datetime",
                       to_timestamp(col(f"tpep_{long}_datetime"))) \
             .withColumn(f"{abbr}_dow", dayofweek(col(f"tpep_{long}_datetime")))
    # create new column pickup hour
    test_taxi = test_taxi.withColumn(f'{abbr}_hour', hour(test_taxi[f'tpep_{long}_datetime']))
    test_taxi = test_taxi.withColumn(
                            f"{abbr}_date",
                            to_date(col(f"tpep_{long}_datetime"),"yyyy-MM-dd")
                          )



#### 2a. From external dataset (School Holidays)

In [45]:
sch_hol = pd.read_csv("../data/raw/other_data/nyc_school_holiday.csv", sep=";")
sch_hol['DATE'] = pd.to_datetime(sch_hol['DATE'], format='%d/%m/%y')
sch_hol_date = sch_hol['DATE'].dt.date.tolist()

# create new column to identify if that day is school holiday
taxi = taxi.withColumn(
    'is_school_holiday',
    F.when(
        (F.col('pu_date').isin(sch_hol_date)),
        1
    ).otherwise(0)
)

test_taxi = test_taxi.withColumn(
    'is_school_holiday',
    F.when(
        (F.col('pu_date').isin(sch_hol_date)),
        1
    ).otherwise(0)
)

In [46]:
sch_hol.head()

Unnamed: 0,DATE,EVENT
0,2018-01-01,Winter Recess (Schools closed)
1,2018-01-15,Dr. Martin Luther King Jr. Day (schools closed)
2,2018-02-16,Lunar New Year (schools closed)
3,2018-02-19,Midwinter Recess (includes Washington’s Birthd...
4,2018-02-20,Midwinter Recess (includes Washington’s Birthd...


In [47]:
# check if 01-01-2018 is marked as holiday
taxi.filter(F.col('pu_date') == "2018-01-01").limit(5)

                                                                                

tpep_pickup_datetime,tpep_dropoff_datetime,pu_location_id,do_location_id,total_amount,trip_duration,pu_dow,pu_hour,pu_date,do_dow,do_hour,do_date,is_school_holiday
2018-01-01 00:49:19,2018-01-01 01:11:58,138,238,42.06,1359,2,0,2018-01-01,2,1,2018-01-01,1
2018-01-01 00:32:15,2018-01-01 00:59:44,132,61,33.3,1649,2,0,2018-01-01,2,0,2018-01-01,1
2018-01-01 00:37:31,2018-01-01 01:24:24,132,265,148.56,2813,2,0,2018-01-01,2,1,2018-01-01,1
2018-01-01 00:58:09,2018-01-01 01:35:58,132,220,90.66,2269,2,0,2018-01-01,2,1,2018-01-01,1
2018-01-01 00:52:28,2018-01-01 01:28:50,132,262,72.8,2182,2,0,2018-01-01,2,1,2018-01-01,1


#### 2b. From external dataset (Hourly Weather)

In [48]:
clean_weather = weather \
                    .groupBy(['date', 'hour']) \
                    .agg(
                        F.mean("temperature").alias("temperature"),
                        F.mean("dew_point").alias("dew_point"),
                        F.mean("precipitation").alias("precipitation"),
                        F.mean("wind_direction").alias("wind_direction"),
                        F.mean("visibility").alias("visibility"),
                        F.mean("pressure").alias("pressure"),
                    ) \
                    .orderBy(["date", 'hour'])

clean_test_weather = test_weather \
                    .groupBy(['date', 'hour']) \
                    .agg(
                        F.mean("temperature").alias("temperature"),
                        F.mean("dew_point").alias("dew_point"),
                        F.mean("precipitation").alias("precipitation"),
                        F.mean("wind_direction").alias("wind_direction"),
                        F.mean("visibility").alias("visibility"),
                        F.mean("pressure").alias("pressure"),
                    ) \
                    .orderBy(["date", 'hour'])

clean_weather.show()

+----------+----+-----------+-------------------+-------------+-----------------+----------+--------+
|      date|hour|temperature|          dew_point|precipitation|   wind_direction|visibility|pressure|
+----------+----+-----------+-------------------+-------------+-----------------+----------+--------+
|2018-01-01|   0|     -114.3|             -200.3|          0.0|            315.0|   16046.5| 10260.5|
|2018-01-01|   1|     -117.5|             -206.5|          0.0|            330.0|   16093.0| 10262.0|
|2018-01-01|   2|     -122.5|             -200.5|          0.0|            310.0|   16093.0| 10262.0|
|2018-01-01|   3|     -122.3|             -197.3|          0.0|            310.0|   16046.5| 10262.5|
|2018-01-01|   4|     -124.5|-198.23333333333335|          0.0|316.6666666666667|   16093.0| 10262.0|
|2018-01-01|   5|     -128.5|             -194.5|          0.0|            330.0|   16093.0| 10266.0|
|2018-01-01|   6|     -130.8|             -194.3|          0.0|            320.0| 

In [49]:
taxi.columns

['tpep_pickup_datetime',
 'tpep_dropoff_datetime',
 'pu_location_id',
 'do_location_id',
 'total_amount',
 'trip_duration',
 'pu_dow',
 'pu_hour',
 'pu_date',
 'do_dow',
 'do_hour',
 'do_date',
 'is_school_holiday']

In [50]:
grouped_taxi = (taxi
                .groupby(['pu_date', 'pu_hour', 'pu_dow',
                          'is_school_holiday'])
                .agg({'pu_date':'count'}).orderBy(['pu_date', 'pu_hour'])
                .withColumnRenamed('count(pu_date)', 'trip_freq'))

grouped_test_taxi = (test_taxi
                .groupby(['pu_date', 'pu_hour', 'pu_dow',
                          'is_school_holiday'])
                .agg({'pu_date':'count'}).orderBy(['pu_date', 'pu_hour'])
                .withColumnRenamed('count(pu_date)', 'trip_freq'))

In [51]:
grouped_test_taxi.limit(5)

                                                                                

pu_date,pu_hour,pu_dow,is_school_holiday,trip_freq
2020-01-01,0,4,1,211
2020-01-01,1,4,1,85
2020-01-01,2,4,1,46
2020-01-01,3,4,1,35
2020-01-01,4,4,1,51


In [52]:
grouped_taxi.limit(5)

                                                                                

pu_date,pu_hour,pu_dow,is_school_holiday,trip_freq
2018-01-01,0,2,1,240
2018-01-01,1,2,1,141
2018-01-01,2,2,1,33
2018-01-01,3,2,1,26
2018-01-01,4,2,1,50


In [53]:
# merge taxi and weather dataset on date and hour
# remove taxi data if weather is not available (inner join)
sdf = (grouped_taxi \
           .join(clean_weather, on=[grouped_taxi['pu_date'] == clean_weather['date'], 
                 grouped_taxi['pu_hour'] == clean_weather['hour']], how='inner')
           .drop(clean_weather['date'])
           .drop(clean_weather['hour'])
      )

test_sdf = (grouped_test_taxi \
           .join(clean_test_weather, on=[grouped_test_taxi['pu_date'] == clean_test_weather['date'], 
                 grouped_test_taxi['pu_hour'] == clean_test_weather['hour']], how='inner')
           .drop(clean_test_weather['date'])
           .drop(clean_test_weather['hour'])
      )

In [54]:
# taxi number of rows before merge vs after weather merge
(grouped_taxi.count(), sdf.count())

                                                                                

(17518, 17518)

In [55]:
(grouped_test_taxi.count(), test_sdf.count())

                                                                                

(744, 744)

In [56]:
# check if merge matches perfectly
clean_weather.filter((F.col('date') == "2018-03-01") & (F.col('hour').isin([0,1,2]))).limit(3)

date,hour,temperature,dew_point,precipitation,wind_direction,visibility,pressure
2018-03-01,0,67.3,44.3,0.0,200.0,16046.5,10118.5
2018-03-01,1,83.5,44.5,0.0,200.0,16093.0,10117.0
2018-03-01,2,83.5,50.5,0.0,170.0,16093.0,10114.0


In [57]:
grouped_taxi.filter((F.col('pu_date') == "2018-03-01") & (F.col('pu_hour').isin([0,1,2]))).limit(3)

                                                                                

pu_date,pu_hour,pu_dow,is_school_holiday,trip_freq
2018-03-01,0,5,0,310
2018-03-01,1,5,0,59
2018-03-01,2,5,0,6


In [58]:
sdf.filter((F.col('pu_date') == "2018-03-01") & (F.col('pu_hour').isin([0,1,2]))).limit(3)

                                                                                

pu_date,pu_hour,pu_dow,is_school_holiday,trip_freq,temperature,dew_point,precipitation,wind_direction,visibility,pressure
2018-03-01,2,5,0,6,83.5,50.5,0.0,170.0,16093.0,10114.0
2018-03-01,0,5,0,310,67.3,44.3,0.0,200.0,16046.5,10118.5
2018-03-01,1,5,0,59,83.5,44.5,0.0,200.0,16093.0,10117.0


In [59]:
dict_null = {col:sdf.filter(sdf[col].isNull()).count() for col in sdf.columns}
dict_null

                                                                                

{'pu_date': 0,
 'pu_hour': 0,
 'pu_dow': 0,
 'is_school_holiday': 0,
 'trip_freq': 0,
 'temperature': 0,
 'dew_point': 0,
 'precipitation': 0,
 'wind_direction': 0,
 'visibility': 0,
 'pressure': 0}

### Sample for visualization

In [60]:
SAMPLE_SIZE = 0.05
df = taxi.sample(SAMPLE_SIZE, seed=0).toPandas()
df.to_parquet('../data/curated/taxi_sample.parquet')

                                                                                

### Save final sdf to curated data folder

In [61]:
sdf.write.parquet("../data/curated/merged_df.parquet")

                                                                                

In [62]:
test_sdf.write.parquet("../data/curated/merged_test_df.parquet")

                                                                                

In [63]:
clean_weather.write.parquet("../data/curated/clean_weather.parquet")

In [64]:
taxi.write.parquet("../data/curated/taxi.parquet")

                                                                                