### Data Types and Column Schema

In [1]:
# Display the `spark` session object. It is not created in this notebook, so it is
# presumably injected by the runtime (e.g. a managed Spark kernel) — confirm before
# running elsewhere.
spark

In [2]:
import io
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

In [3]:
# Import some modules we will need later on
from pyspark.sql.functions import col, isnan, when, count, udf, to_date, year, month, date_format, size, split
from pyspark.ml.stat import Correlation
from pyspark.ml.feature import VectorAssembler, StringIndexer, OneHotEncoder, MinMaxScaler
from pyspark.sql import SparkSession
from google.cloud import storage

In [4]:
# Load only the cleaned 2019 trip files; the "*" wildcard picks up every
# monthly file whose name matches the pattern.
trips_2019_glob = "XXX/fhvhv_tripdata_2019-*.parquet_cleaned.parquet"
sdf = spark.read.parquet(trips_2019_glob)

                                                                                

In [5]:
# Print column names, types and nullability of the freshly loaded data.
sdf.printSchema()

root
 |-- hvfhs_license_num: string (nullable = true)
 |-- dispatching_base_num: string (nullable = true)
 |-- request_datetime: timestamp_ntz (nullable = true)
 |-- pickup_datetime: timestamp_ntz (nullable = true)
 |-- dropoff_datetime: timestamp_ntz (nullable = true)
 |-- PULocationID: long (nullable = true)
 |-- DOLocationID: long (nullable = true)
 |-- trip_miles: double (nullable = true)
 |-- trip_time: long (nullable = true)
 |-- base_passenger_fare: double (nullable = true)
 |-- tolls: double (nullable = true)
 |-- bcf: double (nullable = true)
 |-- sales_tax: double (nullable = true)
 |-- congestion_surcharge: double (nullable = true)
 |-- tips: double (nullable = true)
 |-- driver_pay: double (nullable = true)
 |-- shared_request_flag: string (nullable = true)
 |-- shared_match_flag: string (nullable = true)
 |-- wav_request_flag: string (nullable = true)



In [6]:
# Column names as a plain Python list (handy for copy/paste into later cells).
sdf.columns

['hvfhs_license_num',
 'dispatching_base_num',
 'request_datetime',
 'pickup_datetime',
 'dropoff_datetime',
 'PULocationID',
 'DOLocationID',
 'trip_miles',
 'trip_time',
 'base_passenger_fare',
 'tolls',
 'bcf',
 'sales_tax',
 'congestion_surcharge',
 'tips',
 'driver_pay',
 'shared_request_flag',
 'shared_match_flag',
 'wav_request_flag']

In [7]:
# Total number of rows: about 207 million (206,972,656 in the recorded run).
sdf.count()

                                                                                

206972656

### List of columns by dtype


In [8]:
# Column names grouped by the dtype they are meant to have in the modelling steps.
#
# NOTE: 'year', 'month', 'yearmonth', 'dayofweek' and 'weekend' are not part of the
# loaded schema; they are derived from pickup_datetime in the date-time feature
# step further down, so these lists only match `sdf` after that step has run.

# String
string_sdf = ['hvfhs_license_num', 'dispatching_base_num', 'shared_request_flag',
              'shared_match_flag', 'wav_request_flag', 'yearmonth', 'dayofweek']

# Integer
int_sdf = ['PULocationID', 'DOLocationID', 'year', 'month']

# Datetime
date_sdf = ['pickup_datetime', 'dropoff_datetime', 'request_datetime']

# Float/double
# - 'trip_time' is loaded as long and cast to double in the next step.
# - 'weekend' only takes the binary values 1.0 or 0.0.
# - 'airport_fee' was removed from this list: the 2019 schema printed above has no
#   such column, so selecting by this list would fail.
float_sdf = ['trip_miles', 'trip_time', 'base_passenger_fare', 'tolls', 'bcf',
             'sales_tax', 'congestion_surcharge', 'tips', 'driver_pay', 'weekend']


In [9]:
from pyspark.sql.functions import col

# trip_time is loaded as a long; convert it to double so it has the same type as
# the other numeric measures. Only trip_time is cast here.
sdf = sdf.withColumn("trip_time", sdf["trip_time"].cast("double"))

# Print the schema and a sample of rows to verify the cast took effect.
sdf.printSchema()
sdf.show()

root
 |-- hvfhs_license_num: string (nullable = true)
 |-- dispatching_base_num: string (nullable = true)
 |-- request_datetime: timestamp_ntz (nullable = true)
 |-- pickup_datetime: timestamp_ntz (nullable = true)
 |-- dropoff_datetime: timestamp_ntz (nullable = true)
 |-- PULocationID: long (nullable = true)
 |-- DOLocationID: long (nullable = true)
 |-- trip_miles: double (nullable = true)
 |-- trip_time: double (nullable = true)
 |-- base_passenger_fare: double (nullable = true)
 |-- tolls: double (nullable = true)
 |-- bcf: double (nullable = true)
 |-- sales_tax: double (nullable = true)
 |-- congestion_surcharge: double (nullable = true)
 |-- tips: double (nullable = true)
 |-- driver_pay: double (nullable = true)
 |-- shared_request_flag: string (nullable = true)
 |-- shared_match_flag: string (nullable = true)
 |-- wav_request_flag: string (nullable = true)



                                                                                

+-----------------+--------------------+-------------------+-------------------+-------------------+------------+------------+----------+---------+-------------------+-----+----+---------+--------------------+----+----------+-------------------+-----------------+----------------+
|hvfhs_license_num|dispatching_base_num|   request_datetime|    pickup_datetime|   dropoff_datetime|PULocationID|DOLocationID|trip_miles|trip_time|base_passenger_fare|tolls| bcf|sales_tax|congestion_surcharge|tips|driver_pay|shared_request_flag|shared_match_flag|wav_request_flag|
+-----------------+--------------------+-------------------+-------------------+-------------------+------------+------------+----------+---------+-------------------+-----+----+---------+--------------------+----+----------+-------------------+-----------------+----------------+
|           HV0003|              B02867|2019-02-01 00:01:26|2019-02-01 00:05:18|2019-02-01 00:14:57|         245|         251|      2.45|    579.0|          

### Extracting date time features

In [10]:
# Derive calendar features from pickup_datetime: year, numeric month, a
# "yyyy-MM" label, the day-of-week name and a weekend flag.
pickup_ts = col("pickup_datetime")
sdf = (
    sdf
    .withColumn("year", year(pickup_ts))
    .withColumn("month", month(pickup_ts))                        # numeric month, e.g. 11
    .withColumn("yearmonth", date_format(pickup_ts, "yyyy-MM"))   # e.g. 2019-02
    .withColumn("dayofweek", date_format(pickup_ts, "EEEE"))      # 'Monday', 'Tuesday', ...
)

# weekend is 1.0 for Saturday/Sunday and 0 otherwise (displayed as 0.0 in the output).
is_weekend = col("dayofweek").isin("Saturday", "Sunday")
sdf = sdf.withColumn("weekend", when(is_weekend, 1.0).otherwise(0))

# Eyeball the new columns to confirm the values look sensible.
sdf.select("year", "month", "yearmonth", "dayofweek", "weekend").show()

[Stage 5:>                                                          (0 + 1) / 1]

+----+-----+---------+---------+-------+
|year|month|yearmonth|dayofweek|weekend|
+----+-----+---------+---------+-------+
|2019|    2|  2019-02|   Friday|    0.0|
|2019|    2|  2019-02|   Friday|    0.0|
|2019|    2|  2019-02|   Friday|    0.0|
|2019|    2|  2019-02|   Friday|    0.0|
|2019|    2|  2019-02|   Friday|    0.0|
|2019|    2|  2019-02|   Friday|    0.0|
|2019|    2|  2019-02|   Friday|    0.0|
|2019|    2|  2019-02|   Friday|    0.0|
|2019|    2|  2019-02|   Friday|    0.0|
|2019|    2|  2019-02|   Friday|    0.0|
|2019|    2|  2019-02|   Friday|    0.0|
|2019|    2|  2019-02|   Friday|    0.0|
|2019|    2|  2019-02|   Friday|    0.0|
|2019|    2|  2019-02|   Friday|    0.0|
|2019|    2|  2019-02|   Friday|    0.0|
|2019|    2|  2019-02|   Friday|    0.0|
|2019|    2|  2019-02|   Friday|    0.0|
|2019|    2|  2019-02|   Friday|    0.0|
|2019|    2|  2019-02|   Friday|    0.0|
|2019|    2|  2019-02|   Friday|    0.0|
+----+-----+---------+---------+-------+
only showing top

                                                                                

In [11]:
# Schema after adding the derived date-time columns
# (year, month, yearmonth, dayofweek, weekend).
sdf.printSchema()

root
 |-- hvfhs_license_num: string (nullable = true)
 |-- dispatching_base_num: string (nullable = true)
 |-- request_datetime: timestamp_ntz (nullable = true)
 |-- pickup_datetime: timestamp_ntz (nullable = true)
 |-- dropoff_datetime: timestamp_ntz (nullable = true)
 |-- PULocationID: long (nullable = true)
 |-- DOLocationID: long (nullable = true)
 |-- trip_miles: double (nullable = true)
 |-- trip_time: double (nullable = true)
 |-- base_passenger_fare: double (nullable = true)
 |-- tolls: double (nullable = true)
 |-- bcf: double (nullable = true)
 |-- sales_tax: double (nullable = true)
 |-- congestion_surcharge: double (nullable = true)
 |-- tips: double (nullable = true)
 |-- driver_pay: double (nullable = true)
 |-- shared_request_flag: string (nullable = true)
 |-- shared_match_flag: string (nullable = true)
 |-- wav_request_flag: string (nullable = true)
 |-- year: integer (nullable = true)
 |-- month: integer (nullable = true)
 |-- yearmonth: string (nullable = true)
 |-- 

### String Indexer

In [12]:
# String-typed columns that are converted to numeric category indices.
input_cols = ['hvfhs_license_num', 'dispatching_base_num', 'shared_request_flag',
              'shared_match_flag', 'wav_request_flag', 'yearmonth', 'dayofweek']

# One '<column>_index' output per input. The loop variable is called `name`
# (not `col`) so it does not shadow pyspark's `col` function imported above.
output_cols = [name + '_index' for name in input_cols]

# Fit the StringIndexer once and keep the fitted model: it holds the
# label -> index mappings (labelsArray), and reusing it avoids a second
# full pass over the data just to inspect them.
indexer = StringIndexer(inputCols=input_cols, outputCols=output_cols)
indexer_model = indexer.fit(sdf)
indexed_sdf = indexer_model.transform(sdf)

# Show the result
indexed_sdf.show()

24/11/14 22:33:13 WARN SparkStringUtils: Truncated the string representation of a plan since it was too large. This behavior can be adjusted by setting 'spark.sql.debug.maxToStringFields'.


+-----------------+--------------------+-------------------+-------------------+-------------------+------------+------------+----------+---------+-------------------+-----+----+---------+--------------------+----+----------+-------------------+-----------------+----------------+----+-----+---------+---------+-------+-----------------------+--------------------------+-------------------------+-----------------------+----------------------+---------------+---------------+
|hvfhs_license_num|dispatching_base_num|   request_datetime|    pickup_datetime|   dropoff_datetime|PULocationID|DOLocationID|trip_miles|trip_time|base_passenger_fare|tolls| bcf|sales_tax|congestion_surcharge|tips|driver_pay|shared_request_flag|shared_match_flag|wav_request_flag|year|month|yearmonth|dayofweek|weekend|hvfhs_license_num_index|dispatching_base_num_index|shared_request_flag_index|shared_match_flag_index|wav_request_flag_index|yearmonth_index|dayofweek_index|
+-----------------+--------------------+--------

                                                                                

In [13]:
# Names of the index columns produced by the StringIndexer.
print(output_cols)

['hvfhs_license_num_index', 'dispatching_base_num_index', 'shared_request_flag_index', 'shared_match_flag_index', 'wav_request_flag_index', 'yearmonth_index', 'dayofweek_index']


In [14]:
# Show only the newly indexed columns (first 20 rows by default).
indexed_sdf.select(output_cols).show()

+-----------------------+--------------------------+-------------------------+-----------------------+----------------------+---------------+---------------+
|hvfhs_license_num_index|dispatching_base_num_index|shared_request_flag_index|shared_match_flag_index|wav_request_flag_index|yearmonth_index|dayofweek_index|
+-----------------------+--------------------------+-------------------------+-----------------------+----------------------+---------------+---------------+
|                    0.0|                      20.0|                      1.0|                    0.0|                   0.0|            7.0|            1.0|
|                    0.0|                      18.0|                      0.0|                    0.0|                   0.0|            7.0|            1.0|
|                    1.0|                       0.0|                      0.0|                    1.0|                   0.0|            7.0|            1.0|
|                    1.0|                       0.0|

                                                                                

In [15]:
### Inspect which original label each index value stands for
# Fit the StringIndexer to obtain the StringIndexerModel that holds the labels.
# NOTE(review): this scans the whole of sdf again; if a model fitted earlier on
# the same data is still available, reuse it instead of refitting.
indexer_model = indexer.fit(sdf)

# labelsArray holds one tuple of labels per input column, in input_cols order.
# Look the position up by name instead of hard-coding it, so the lookup stays
# correct if input_cols is ever reordered or extended.
yearmonth_position = input_cols.index('yearmonth')
yearmonth_labels = indexer_model.labelsArray[yearmonth_position]

# A label's position in the tuple is its index value (first label -> 0.0, ...).
# In the recorded run the tuple has 11 entries and '2019-01' is not among them.
print(f"Mapping for yearmonth_index: {yearmonth_labels}")




Mapping for yearmonth_index: ('2019-03', '2019-05', '2019-12', '2019-04', '2019-11', '2019-06', '2019-10', '2019-02', '2019-09', '2019-07', '2019-08')


                                                                                

### One-Hot Encoding + Vector Assembler

In [17]:
# Columns to one-hot encode: raw integer columns plus the StringIndexer outputs.
# NOTE(review): PULocationID, DOLocationID, year and month are raw integers, not
# 0-based indices. OneHotEncoder treats each value as a category index, so a
# column like year presumably yields a very wide, almost empty vector — confirm
# this is intended.
onehot_input_cols = ['PULocationID', 'DOLocationID', 'year', 'month',
                     'hvfhs_license_num_index', 'dispatching_base_num_index',
                     'shared_request_flag_index', 'shared_match_flag_index',
                     'wav_request_flag_index', 'yearmonth_index', 'dayofweek_index']

# Each encoded column is named '<input>_vector'. Deriving the names keeps the
# encoder outputs, the assembler inputs and the final select in sync.
onehot_output_cols = [name + '_vector' for name in onehot_input_cols]

# Numeric columns that go into the feature vector as-is (unscaled).
numeric_feature_cols = ['trip_miles', 'trip_time', 'base_passenger_fare', 'tolls', 'bcf',
                        'sales_tax', 'congestion_surcharge', 'tips', 'driver_pay', 'weekend']

# OneHotEncoder to encode the indexed columns and integer columns
encoder = OneHotEncoder(
    inputCols=onehot_input_cols,
    outputCols=onehot_output_cols,
    dropLast=False  # Keep all categories
)

# Apply OneHotEncoder transformation
encoded_sdf = encoder.fit(indexed_sdf).transform(indexed_sdf)

# VectorAssembler to combine the encoded columns and the numeric columns
# into a single vector column (encoded columns first, then numeric ones).
assembler = VectorAssembler(
    inputCols=onehot_output_cols + numeric_feature_cols,
    outputCol="features"
)

# Apply VectorAssembler transformation
assembled_sdf = assembler.transform(encoded_sdf)

# Show the encoded columns and the assembled feature vector (a portion)
assembled_sdf.select(onehot_output_cols + ['features']).show(truncate=False)

[Stage 18:>                                                         (0 + 1) / 1]

+-------------------+-------------------+-------------------+--------------+------------------------------+---------------------------------+--------------------------------+------------------------------+-----------------------------+----------------------+----------------------+---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
|PULocationID_vector|DOLocationID_vector|year_vector        |month_vector  |hvfhs_license_num_index_vector|dispatching_base_num_index_vector|shared_request_flag_index_vector|shared_match_flag_index_vector|wav_request_flag_index_vector|yearmonth_index_vector|dayofweek_index_vector|features                                                                                                                                                                                     |
+-------------------+-------------------+---------------

                                                                                

### Min-Max Scaling

In [18]:
from pyspark.ml.feature import VectorAssembler, MinMaxScaler
from pyspark.sql.functions import col

# List of 'double' columns to scale
double_columns = ['trip_miles', 'trip_time', 'base_passenger_fare', 'tolls',
                  'bcf', 'sales_tax', 'congestion_surcharge', 'tips', 'driver_pay']

# Make the cell safe to re-run: indexed_sdf is overwritten below, so on a second
# execution the '<col>Vector' / '<col>_scaled' columns would already exist and
# the transformers would reject the duplicate output column. Dropping by name
# is a no-op for columns that are not present, so the first run is unaffected.
generated_cols = [col_name + suffix
                  for col_name in double_columns
                  for suffix in ("Vector", "_scaled")]
indexed_sdf = indexed_sdf.drop(*generated_cols)

# Apply VectorAssembler and MinMaxScaler for each 'double' column.
# MinMaxScaler only accepts vector input, hence the single-column assembler.
# NOTE(review): each scaler.fit is a separate pass over the data; the scaled
# columns are added to indexed_sdf and are not part of the 'features' vector
# assembled earlier — confirm that is the intent.
for col_name in double_columns:
    vector_col = col_name + "Vector"

    # Assemble each column into a vector
    assembler = VectorAssembler(inputCols=[col_name], outputCol=vector_col)
    indexed_sdf = assembler.transform(indexed_sdf)

    # Scale the vector column (default output range of MinMaxScaler)
    scaler = MinMaxScaler(inputCol=vector_col, outputCol=col_name + "_scaled")
    indexed_sdf = scaler.fit(indexed_sdf).transform(indexed_sdf)

# Show the resulting DataFrame with both original and scaled columns
indexed_sdf.show(truncate=False)


[Stage 46:>                                                         (0 + 1) / 1]

+-----------------+--------------------+-------------------+-------------------+-------------------+------------+------------+----------+---------+-------------------+-----+----+---------+--------------------+----+----------+-------------------+-----------------+----------------+----+-----+---------+---------+-------+-----------------------+--------------------------+-------------------------+-----------------------+----------------------+---------------+---------------+----------------+-----------------------+---------------+-----------------------+-------------------------+--------------------------+-----------+-----------------------+---------+-----------------------+---------------+-----------------------+--------------------------+---------------------------+----------+-----------------------+----------------+----------------------+
|hvfhs_license_num|dispatching_base_num|request_datetime   |pickup_datetime    |dropoff_datetime   |PULocationID|DOLocationID|trip_miles|trip_time|bas

                                                                                