In [96]:
from pyspark.sql import SparkSession, functions as F
from pyspark.sql.functions import substring

# Create a spark session (which will run spark jobs)
spark = (
    SparkSession.builder.appName("MAST30034 Tutorial 2")
    .config("spark.sql.repl.eagerEval.enabled", True)
    .config("spark.sql.parquet.cacheMetadata", "true")
    .config("spark.sql.session.timeZone", "Etc/UTC")
    .config('spark.driver.memory', '4g')
    .config('spark.executor.memory', '2g')
    .getOrCreate()
)

# Preprocess Weather Data

In [97]:
wdf = spark.read.parquet('../data/raw/weathers/weather-2022')
wdf_schema = wdf.schema

                                                                                

In [98]:
# Combine the weather dataset of year 2022 and 2023
wdf = spark.read.schema(wdf_schema).parquet('../data/raw/weathers/*')

wdf.show(1, vertical=True, truncate=100)

-RECORD 0-------------------
 date | 2022-01-01T00:00:00 
 wnd  | 160,1,N,0046,1      
 cig  | 99999,9,9,N         
 vis  | 016000,1,9,9        
 tmp  | +0106,1             
 dew  | +0078,1             
 slp  | 10141,1             
only showing top 1 row



In [99]:
wdf.count()

22380

## Clean weather data 

In [100]:
from pyspark.sql.functions import to_timestamp, regexp_replace

# replace the character 'T' in the date column with ' ' (space) 
wdf_cleaned = wdf.withColumn("Date", regexp_replace(F.col("Date"), "T", " "))

In [101]:
# convert the date from data type of string to timestamp
wdf = wdf_cleaned.withColumn("Date", to_timestamp(F.col("Date"), "yyyy-MM-dd HH:mm:ss"))

In [102]:
# Define the date range for analysis (inclusive of bound)
dates = ("2022-11-01",  "2023-05-31")
wdf = wdf.where(F.col('date').between(*dates))

In [103]:
wdf.count()

7889

In [104]:
def extract(sdf, start_index, end_index, extracted_column):
    """
    Extract the useful information from each entry accordingly

    Parameters:
    - sdf: Pyspark dataframe ready to be extracted
    - start_index: Starting index of the useful information
    - end_index: Ending index of the useful information
    - extracted_column: Column where the info to be extracted
    """
    sdf = sdf.withColumn(extracted_column, 
        (substring(F.col(extracted_column), start_index, end_index - start_index + 1)).cast("double"))
    return sdf

In [105]:
# Take out the information from each column except the 'date' column
wdf = extract(wdf, 9, 12, "wnd")
wdf = extract(wdf, 1, 5, "cig")
wdf = extract(wdf, 1, 6, "vis")
wdf = extract(wdf, 1, 5, "tmp")
wdf = extract(wdf, 1, 5, "dew")
wdf = extract(wdf, 1, 5, "slp")

In [106]:
wdf.show()

+-------------------+------+-------+--------+------+------+-------+
|               Date|   wnd|    cig|     vis|   tmp|   dew|    slp|
+-------------------+------+-------+--------+------+------+-------+
|2023-01-01 00:15:00|   0.0|   91.0|   805.0|  78.0|  72.0|99999.0|
|2023-01-01 00:51:00|  15.0|   91.0|   805.0|  78.0|  72.0|10098.0|
|2023-01-01 01:51:00|   0.0|   91.0|   805.0|  83.0|  72.0|10086.0|
|2023-01-01 02:38:00|   0.0|  122.0|  3219.0|  89.0|  83.0|99999.0|
|2023-01-01 02:51:00|   0.0|  122.0|  3219.0|  83.0|  78.0|10086.0|
|2023-01-01 03:00:00|   0.0|99999.0|  3200.0|  83.0|  78.0|10086.0|
|2023-01-01 03:49:00|  31.0|  122.0| 12875.0| 120.0| 110.0|99999.0|
|2023-01-01 03:51:00|  31.0|  122.0| 12875.0| 122.0| 111.0|10074.0|
|2023-01-01 04:51:00|  31.0|  122.0| 16093.0| 122.0| 106.0|10073.0|
|2023-01-01 04:59:00|9999.0|99999.0|999999.0|9999.0|9999.0|99999.0|
|2023-01-01 04:59:00|9999.0|99999.0|999999.0|9999.0|9999.0|99999.0|
|2023-01-01 05:19:00|  15.0|  518.0| 16093.0| 11

In [107]:
wdf.printSchema()

root
 |-- Date: timestamp (nullable = true)
 |-- wnd: double (nullable = true)
 |-- cig: double (nullable = true)
 |-- vis: double (nullable = true)
 |-- tmp: double (nullable = true)
 |-- dew: double (nullable = true)
 |-- slp: double (nullable = true)



In [109]:
def cal_missing_propor(sdf,column, num):
    """
    Calculate the proportion of missing values in a specific column

    Parameters:
    - sdf: Pyspark dataframe 
    - column: Column that are going to be examine
    - num: number of figures that represent the missing values 
    """   
    total_num = wdf.count()
    invalid_record = "9" * num
    missing_count = sdf.filter(F.col(column) == int(invalid_record)).count()
    return missing_count / total_num


In [110]:
cal_missing_propor(wdf, "wnd", 4)

                                                                                

0.027633413613892762

In [111]:
cal_missing_propor(wdf, "cig", 5)

0.17049055647103561

In [112]:
cal_missing_propor(wdf, "vis", 6)

0.027633413613892762

In [113]:
cal_missing_propor(wdf, "tmp", 4)

0.027633413613892762

In [114]:
cal_missing_propor(wdf, "dew", 4)

0.027633413613892762

In [115]:
cal_missing_propor(wdf, "slp", 5)

0.14767397642286728

In [116]:
# Discard all the rows that contain missing values
wdf_filtered = wdf.where(
    (F.col('wnd') != 9999)
    & (F.col('cig') != 99999)
    & (F.col('vis') != 999999)
    & (F.col('tmp') != 9999)
    & (F.col('dew') != 9999)
    & (F.col('slp') != 99999)
)

In [117]:
wdf_filtered.show()

+-------------------+----+-------+-------+-----+-----+-------+
|               Date| wnd|    cig|    vis|  tmp|  dew|    slp|
+-------------------+----+-------+-------+-----+-----+-------+
|2023-01-01 00:51:00|15.0|   91.0|  805.0| 78.0| 72.0|10098.0|
|2023-01-01 01:51:00| 0.0|   91.0|  805.0| 83.0| 72.0|10086.0|
|2023-01-01 02:51:00| 0.0|  122.0| 3219.0| 83.0| 78.0|10086.0|
|2023-01-01 03:51:00|31.0|  122.0|12875.0|122.0|111.0|10074.0|
|2023-01-01 04:51:00|31.0|  122.0|16093.0|122.0|106.0|10073.0|
|2023-01-01 05:51:00|31.0|  518.0|16093.0|117.0|100.0|10072.0|
|2023-01-01 06:51:00|51.0| 2438.0|16093.0|122.0|100.0|10067.0|
|2023-01-01 07:51:00|51.0|22000.0|16093.0|122.0| 94.0|10077.0|
|2023-01-01 08:51:00|46.0|22000.0|16093.0|122.0| 89.0|10080.0|
|2023-01-01 09:00:00|46.0|22000.0|16000.0|122.0| 89.0|10080.0|
|2023-01-01 09:51:00|26.0|22000.0|16093.0|117.0| 78.0|10084.0|
|2023-01-01 10:51:00|51.0|22000.0|16093.0|117.0| 72.0|10093.0|
|2023-01-01 11:51:00|41.0|22000.0|16093.0|111.0| 72.0|1

In [118]:
wdf_filtered.count()

5601

In [124]:
# Extract the date and hour into individual column
wdf_filtered = wdf_filtered.withColumn("hour", F.hour("date"))
wdf_filtered = wdf_filtered.withColumn("date", F.to_date("date"))
# Create a column that shows the day in week of a date
wdf_filtered = wdf_filtered.withColumn("day_week", F.dayofweek("date"))

wdf_filtered.show()

+----------+----+-------+-------+-----+-----+-------+----+--------+
|      date| wnd|    cig|    vis|  tmp|  dew|    slp|hour|day_week|
+----------+----+-------+-------+-----+-----+-------+----+--------+
|2023-01-01|15.0|   91.0|  805.0| 78.0| 72.0|10098.0|   0|       1|
|2023-01-01| 0.0|   91.0|  805.0| 83.0| 72.0|10086.0|   1|       1|
|2023-01-01| 0.0|  122.0| 3219.0| 83.0| 78.0|10086.0|   2|       1|
|2023-01-01|31.0|  122.0|12875.0|122.0|111.0|10074.0|   3|       1|
|2023-01-01|31.0|  122.0|16093.0|122.0|106.0|10073.0|   4|       1|
|2023-01-01|31.0|  518.0|16093.0|117.0|100.0|10072.0|   5|       1|
|2023-01-01|51.0| 2438.0|16093.0|122.0|100.0|10067.0|   6|       1|
|2023-01-01|51.0|22000.0|16093.0|122.0| 94.0|10077.0|   7|       1|
|2023-01-01|46.0|22000.0|16093.0|122.0| 89.0|10080.0|   8|       1|
|2023-01-01|46.0|22000.0|16000.0|122.0| 89.0|10080.0|   9|       1|
|2023-01-01|26.0|22000.0|16093.0|117.0| 78.0|10084.0|   9|       1|
|2023-01-01|51.0|22000.0|16093.0|117.0| 72.0|100

In [125]:
# unscaled the features based on the data dictionary
scaled_col = ['wnd', 'tmp', 'dew', 'slp']
for column in scaled_col:
    wdf_filtered = wdf_filtered.withColumn(column, F.col(column)/ 10)

In [126]:
wdf_filtered.show()

+----------+---+-------+-------+----+----+------+----+--------+
|      date|wnd|    cig|    vis| tmp| dew|   slp|hour|day_week|
+----------+---+-------+-------+----+----+------+----+--------+
|2023-01-01|1.5|   91.0|  805.0| 7.8| 7.2|1009.8|   0|       1|
|2023-01-01|0.0|   91.0|  805.0| 8.3| 7.2|1008.6|   1|       1|
|2023-01-01|0.0|  122.0| 3219.0| 8.3| 7.8|1008.6|   2|       1|
|2023-01-01|3.1|  122.0|12875.0|12.2|11.1|1007.4|   3|       1|
|2023-01-01|3.1|  122.0|16093.0|12.2|10.6|1007.3|   4|       1|
|2023-01-01|3.1|  518.0|16093.0|11.7|10.0|1007.2|   5|       1|
|2023-01-01|5.1| 2438.0|16093.0|12.2|10.0|1006.7|   6|       1|
|2023-01-01|5.1|22000.0|16093.0|12.2| 9.4|1007.7|   7|       1|
|2023-01-01|4.6|22000.0|16093.0|12.2| 8.9|1008.0|   8|       1|
|2023-01-01|4.6|22000.0|16000.0|12.2| 8.9|1008.0|   9|       1|
|2023-01-01|2.6|22000.0|16093.0|11.7| 7.8|1008.4|   9|       1|
|2023-01-01|5.1|22000.0|16093.0|11.7| 7.2|1009.3|  10|       1|
|2023-01-01|4.1|22000.0|16093.0|11.1| 7.

In [127]:
# aggregate the weather data so its time granularity decreases to hour
hourly_weather = wdf_filtered\
    .groupBy('date', 'hour', 'day_week')\
    .agg(
        # taking the mean for every feature
        F.mean('wnd').alias('wind_speed'),
        F.mean('CIG').alias('ceiling_height'),
        F.mean('VIS').alias('distance_dimension'),
        F.mean('TMP').alias('temperature'),
        F.mean('DEW').alias('dew_point_temp'),
        F.mean('SLP').alias('atm_pressure'),
    )\
    .orderBy('date', 'hour')
hourly_weather.show()



+----------+----+--------+----------+--------------+------------------+-----------+--------------+------------+
|      date|hour|day_week|wind_speed|ceiling_height|distance_dimension|temperature|dew_point_temp|atm_pressure|
+----------+----+--------+----------+--------------+------------------+-----------+--------------+------------+
|2022-11-01|   0|       3|       2.6|         792.0|           14484.0|       16.1|          12.2|      1016.4|
|2022-11-01|   1|       3|       2.6|         823.0|           16093.0|       16.1|          12.8|      1016.5|
|2022-11-01|   2|       3|       3.1|         975.0|           16093.0|       16.1|          12.8|      1016.3|
|2022-11-01|   3|       3|       3.1|        3048.0|           11265.0|       15.6|          12.8|      1016.3|
|2022-11-01|   4|       3|       3.1|         335.0|           12875.0|       15.6|          12.8|      1016.3|
|2022-11-01|   5|       3|       2.6|         244.0|           12875.0|       15.6|          13.3|      

                                                                                

In [128]:
# save the result
hourly_weather.write.mode('overwrite').parquet('../data/curated/hourly_weather')

                                                                                

# Preprocessing taxi data

In [129]:
ydf = spark.read.parquet('../data/raw/2023-01-yellow')
sdf_schema = ydf.schema
sdf_schema

StructType([StructField('trip_distance', DoubleType(), True), StructField('tpep_pickup_datetime', TimestampNTZType(), True), StructField('tpep_dropoff_datetime', TimestampNTZType(), True), StructField('pulocationid', IntegerType(), True), StructField('dolocationid', IntegerType(), True), StructField('fare_amount', DoubleType(), True), StructField('passenger_count', LongType(), True)])

In [130]:
# Combine all the taxi records together
sdf = spark.read.schema(sdf_schema).parquet('../data/raw/*')

sdf.show(1, vertical=True, truncate=100)

-RECORD 0------------------------------------
 trip_distance         | 2.0                 
 tpep_pickup_datetime  | 2022-12-01 00:37:35 
 tpep_dropoff_datetime | 2022-12-01 00:47:35 
 pulocationid          | 170                 
 dolocationid          | 237                 
 fare_amount           | 8.5                 
 passenger_count       | 1                   
only showing top 1 row



In [131]:
sdf.count()

23313034

In [132]:
sdf.show()

+-------------+--------------------+---------------------+------------+------------+-----------+---------------+
|trip_distance|tpep_pickup_datetime|tpep_dropoff_datetime|pulocationid|dolocationid|fare_amount|passenger_count|
+-------------+--------------------+---------------------+------------+------------+-----------+---------------+
|          2.0| 2022-12-01 00:37:35|  2022-12-01 00:47:35|         170|         237|        8.5|              1|
|          8.4| 2022-12-01 00:34:35|  2022-12-01 00:55:21|         138|         141|       26.0|              0|
|          0.8| 2022-12-01 00:33:26|  2022-12-01 00:37:34|         140|         140|        5.0|              1|
|          3.0| 2022-12-01 00:45:51|  2022-12-01 00:53:16|         141|          79|       10.0|              1|
|         0.76| 2022-12-01 00:49:49|  2022-12-01 00:54:13|         261|         231|        5.0|              1|
|          2.6| 2022-12-01 00:25:25|  2022-12-01 00:35:38|         237|         164|       10.5|

In [133]:
# Define the date range (inclusive of bound)
dates = ("2022-11-01",  "2023-05-31")

sdf.where(
    (F.col('tpep_pickup_datetime').between(*dates))
    & (F.col('tpep_dropoff_datetime').between(*dates))
).count()

                                                                                

22723065

In [134]:
# Define the date range (inclusive of bound)
dates = ("2022-11-01",  "2023-05-31")

# outlier detection and remove the outliers
sdf_filtered = sdf.where(
    (F.col('passenger_count') > 0)  # remove negative passenger count
    & (F.col('trip_distance') > 0.5) # remove negative or short trip distance
    & (F.col('pulocationid') >= 1)   # ensure the location is within the range 1-263
    & (F.col('pulocationid') <= 263)
    & (F.col('dolocationid') >= 1)
    & (F.col('dolocationid') <= 263)
    & (F.col('fare_amount') > 2.5)  # remove negative or invalid fare amount
    & (F.col('tpep_pickup_datetime').between(*dates))   #specify the data range
    & (F.col('tpep_dropoff_datetime').between(*dates))
    & (F.col('tpep_dropoff_datetime') > F.col('tpep_pickup_datetime')) # remove negative trip duration
)

In [135]:
sdf_filtered.show()

+-------------+--------------------+---------------------+------------+------------+-----------+---------------+
|trip_distance|tpep_pickup_datetime|tpep_dropoff_datetime|pulocationid|dolocationid|fare_amount|passenger_count|
+-------------+--------------------+---------------------+------------+------------+-----------+---------------+
|          2.0| 2022-12-01 00:37:35|  2022-12-01 00:47:35|         170|         237|        8.5|              1|
|          0.8| 2022-12-01 00:33:26|  2022-12-01 00:37:34|         140|         140|        5.0|              1|
|          3.0| 2022-12-01 00:45:51|  2022-12-01 00:53:16|         141|          79|       10.0|              1|
|         0.76| 2022-12-01 00:49:49|  2022-12-01 00:54:13|         261|         231|        5.0|              1|
|          2.6| 2022-12-01 00:25:25|  2022-12-01 00:35:38|         237|         164|       10.5|              2|
|         0.94| 2022-12-01 00:05:37|  2022-12-01 00:10:48|          79|         144|        5.5|

In [136]:
sdf_filtered.count()

                                                                                

19989535

In [158]:
# number of outliers being removed 
num_removed = 23313034 - 19989535
num_removed

3323499

In [159]:
# proportion of dataset being removed 
proportion_remove = 3323499 / 23313034
proportion_remove

0.14255969428946916

In [137]:
# Extract the date and hour into individual column
sdf_transformed = sdf_filtered.withColumn("hour", F.hour("tpep_pickup_datetime"))
sdf_transformed = sdf_transformed.withColumn("date", F.to_date("tpep_pickup_datetime"))
# Create a column that shows the day in week of a date
sdf_transformed = sdf_transformed.withColumn("day_week", F.dayofweek("tpep_pickup_datetime"))



In [138]:
sdf_transformed.show()

+-------------+--------------------+---------------------+------------+------------+-----------+---------------+----+--------+----------+
|trip_distance|tpep_pickup_datetime|tpep_dropoff_datetime|pulocationid|dolocationid|fare_amount|passenger_count|hour|day_week|      date|
+-------------+--------------------+---------------------+------------+------------+-----------+---------------+----+--------+----------+
|          2.0| 2022-12-01 00:37:35|  2022-12-01 00:47:35|         170|         237|        8.5|              1|   0|       5|2022-12-01|
|          0.8| 2022-12-01 00:33:26|  2022-12-01 00:37:34|         140|         140|        5.0|              1|   0|       5|2022-12-01|
|          3.0| 2022-12-01 00:45:51|  2022-12-01 00:53:16|         141|          79|       10.0|              1|   0|       5|2022-12-01|
|         0.76| 2022-12-01 00:49:49|  2022-12-01 00:54:13|         261|         231|        5.0|              1|   0|       5|2022-12-01|
|          2.6| 2022-12-01 00:25:2

In [139]:
# Group the data based on date, hour and location to get the hourly demand in a specific region
aggregated_result = sdf_transformed\
    .groupBy('date', 'hour', 'day_week','pulocationid')\
    .agg(
        F.count("*").alias("num_trips"), 
    )\
    .orderBy('date', 'hour', 'pulocationid')

In [140]:
aggregated_result.show()



+----------+----+--------+------------+---------+
|      date|hour|day_week|pulocationid|num_trips|
+----------+----+--------+------------+---------+
|2022-11-01|   0|       3|           4|        4|
|2022-11-01|   0|       3|           7|        2|
|2022-11-01|   0|       3|          13|        2|
|2022-11-01|   0|       3|          24|        2|
|2022-11-01|   0|       3|          41|        4|
|2022-11-01|   0|       3|          42|        2|
|2022-11-01|   0|       3|          43|        4|
|2022-11-01|   0|       3|          45|        1|
|2022-11-01|   0|       3|          48|       89|
|2022-11-01|   0|       3|          50|       10|
|2022-11-01|   0|       3|          66|        1|
|2022-11-01|   0|       3|          68|       48|
|2022-11-01|   0|       3|          70|        2|
|2022-11-01|   0|       3|          74|        2|
|2022-11-01|   0|       3|          75|        2|
|2022-11-01|   0|       3|          79|      120|
|2022-11-01|   0|       3|          80|        3|


                                                                                

In [141]:
aggregated_result.count()

                                                                                

438305

In [142]:
# Create a column to indicate whether a date is public holiday or weekend
holiday = ['2022-11-24', '2022-12-25', '2023-01-01', '2023-01-16', '2023-02-20', '2023-05-29']
aggregated_taxi = aggregated_result.withColumn("is_weekend_holiday",
    F.col("day_week").isin([6,7]) | F.col("date").isin(holiday))

In [143]:
aggregated_taxi.write.mode('overwrite').parquet('../data/curated/agg_taxi')

                                                                                

# Combination of Weather and Taxi Data

In [144]:
# join the weather and taxi data
hourly_demand = aggregated_taxi.join(hourly_weather, 
    on = ['date', 'hour', 'day_week'], how = "inner")\
    .orderBy('date', 'hour','pulocationid')

# Sample data

In [145]:
SAMPLE_SIZE = 0.20

In [146]:
# subsample the data for visualization and modelling
full_df = hourly_demand.sample(SAMPLE_SIZE, seed = 0).toPandas()
full_df.to_parquet('../data/curated/location_sample_data.parquet')

                                                                                

In [147]:
hourly_demand.write.mode('overwrite').parquet('../data/curated/location_demand')

                                                                                

In [148]:
import pandas as pd

In [149]:
# one-hot encoding for location id (categorical attribute)
hourly_demand = pd.get_dummies(hourly_demand.toPandas(), columns = ['pulocationid'])

                                                                                

In [150]:
hourly_demand.head()

Unnamed: 0,date,hour,day_week,num_trips,is_weekend_holiday,wind_speed,ceiling_height,distance_dimension,temperature,dew_point_temp,...,pulocationid_254,pulocationid_255,pulocationid_256,pulocationid_257,pulocationid_258,pulocationid_259,pulocationid_260,pulocationid_261,pulocationid_262,pulocationid_263
0,2022-11-01,0,3,4,False,2.6,792.0,14484.0,16.1,12.2,...,0,0,0,0,0,0,0,0,0,0
1,2022-11-01,0,3,2,False,2.6,792.0,14484.0,16.1,12.2,...,0,0,0,0,0,0,0,0,0,0
2,2022-11-01,0,3,2,False,2.6,792.0,14484.0,16.1,12.2,...,0,0,0,0,0,0,0,0,0,0
3,2022-11-01,0,3,2,False,2.6,792.0,14484.0,16.1,12.2,...,0,0,0,0,0,0,0,0,0,0
4,2022-11-01,0,3,4,False,2.6,792.0,14484.0,16.1,12.2,...,0,0,0,0,0,0,0,0,0,0


In [151]:
hourly_demand.to_parquet('../data/curated/hourly_demand.parquet')

In [152]:
hourly_demand = spark.read.parquet('../data/curated/hourly_demand.parquet')

In [153]:
hourly_demand.count()

438234

In [154]:
# sample the may data which is used for prediction later
may_month = ("2023-05-01",  "2023-05-31")
demand_may = hourly_demand.filter(F.col('Date').between(*may_month))

In [155]:
demand_may.show()

[Stage 326:>                                                        (0 + 1) / 1]

+----------+----+--------+---------+------------------+----------+--------------+------------------+-----------+--------------+------------+--------------+--------------+--------------+--------------+--------------+--------------+--------------+--------------+--------------+---------------+---------------+---------------+---------------+---------------+---------------+---------------+---------------+---------------+---------------+---------------+---------------+---------------+---------------+---------------+---------------+---------------+---------------+---------------+---------------+---------------+---------------+---------------+---------------+---------------+---------------+---------------+---------------+---------------+---------------+---------------+---------------+---------------+---------------+---------------+---------------+---------------+---------------+---------------+---------------+---------------+---------------+---------------+---------------+---------------+-----

                                                                                

In [156]:
demand_may.count()

63869

In [164]:
# sample the data to be used for training data in modelling
train_months = ("2022-11-01", '2023-04-30')
df = hourly_demand.filter(F.col('date').between(*train_months)).sample(SAMPLE_SIZE, seed = 0)
df.write.mode('overwrite').parquet('../data/curated/sample_data.parquet')


                                                                                

In [165]:
df.count()

74806

In [166]:
may_df = demand_may.sample(SAMPLE_SIZE, seed = 0).toPandas()
may_df.to_parquet('../data/curated/may_sample_data.parquet')

                                                                                

In [167]:
print(len(may_df))

12825
