In [1]:
import pyspark
import pandas as pd
from pyspark.sql import SparkSession
from pyspark.sql.functions import monotonically_increasing_id, col
from pyspark.sql.types import StructField, StructType, StringType, LongType, DoubleType, TimestampType
sc = pyspark.SparkContext(appName='taxi-etl') # create SparkContext 
spark = SparkSession.builder.appName('taxi-etl') \
                    .config("spark.executor.memory", "4g") \
                    .config("spark.driver.memory", "4g") \
                    .getOrCreate()

In [2]:
def read_and_preprocess(spark_session, root_path):

    # Create custom schema for the files
    custom_schema = StructType([StructField('VendorID', LongType(), True),
                                StructField('tpep_pickup_datetime', TimestampType(), True),
                                StructField('tpep_dropoff_datetime', TimestampType(), True),
                                StructField('passenger_count', DoubleType(), True),
                                StructField('trip_distance', DoubleType(), True),
                                StructField('RatecodeID', LongType(), True),
                                StructField('store_and_fwd_flag', StringType(), True),
                                StructField('PULocationID', LongType(), True),
                                StructField('DOLocationID', LongType(), True),
                                StructField('payment_type', LongType(), True),
                                StructField('fare_amount', DoubleType(), True),
                                StructField('extra', DoubleType(), True),
                                StructField('mta_tax', DoubleType(), True),
                                StructField('tip_amount', DoubleType(), True),
                                StructField('tolls_amount', DoubleType(), True),
                                StructField('improvement_surcharge', DoubleType(), True),
                                StructField('total_amount', DoubleType(), True),
                                StructField('congestion_surcharge', DoubleType(), True),
                                StructField('airport_fee', DoubleType(), True),
                                StructField('report_date', StringType(), True)
    ])

    # read the data from root_path
    df = spark_session.read.schema(custom_schema).parquet(root_path)
    
    # clean the column names and perform type casting, add primary key trip_id
    df = df.withColumnRenamed("VendorID", "vendor_id") \
           .withColumnRenamed("PULocationID", "pick_up_location_id") \
           .withColumnRenamed("DOLocationID", "drop_off_location_id") \
           .withColumnRenamed("RatecodeID", "rate_code_id") \
           .withColumnRenamed("payment_type", "payment_type_id") \
           .withColumn("passenger_count", col("passenger_count").cast(LongType())) \
           .withColumn("trip_id", monotonically_increasing_id())

    return df

In [3]:
root_path = "raw_data"
df = read_and_preprocess(spark, root_path)
df.printSchema()

root
 |-- vendor_id: long (nullable = true)
 |-- tpep_pickup_datetime: timestamp (nullable = true)
 |-- tpep_dropoff_datetime: timestamp (nullable = true)
 |-- passenger_count: long (nullable = true)
 |-- trip_distance: double (nullable = true)
 |-- rate_code_id: long (nullable = true)
 |-- store_and_fwd_flag: string (nullable = true)
 |-- pick_up_location_id: long (nullable = true)
 |-- drop_off_location_id: long (nullable = true)
 |-- payment_type_id: long (nullable = true)
 |-- fare_amount: double (nullable = true)
 |-- extra: double (nullable = true)
 |-- mta_tax: double (nullable = true)
 |-- tip_amount: double (nullable = true)
 |-- tolls_amount: double (nullable = true)
 |-- improvement_surcharge: double (nullable = true)
 |-- total_amount: double (nullable = true)
 |-- congestion_surcharge: double (nullable = true)
 |-- airport_fee: double (nullable = true)
 |-- report_date: string (nullable = true)
 |-- trip_id: long (nullable = false)



In [4]:
from pyspark.sql.functions import year, month, dayofmonth, dayofweek, hour

# Construct datetime dimension table
dim_datetime = df.select("tpep_pickup_datetime", "tpep_dropoff_datetime") \
                 .withColumn("pick_up_year", year("tpep_pickup_datetime")) \
                 .withColumn("pick_up_month", month("tpep_pickup_datetime")) \
                 .withColumn("pick_up_day", dayofmonth("tpep_pickup_datetime")) \
                 .withColumn("pick_up_dayofweek", dayofweek("tpep_pickup_datetime")) \
                 .withColumn("pick_up_hour", hour("tpep_pickup_datetime")) \
                 .withColumn("drop_off_year", year("tpep_dropoff_datetime")) \
                 .withColumn("drop_off_month", month("tpep_dropoff_datetime")) \
                 .withColumn("drop_off_day", dayofmonth("tpep_dropoff_datetime")) \
                 .withColumn("drop_off_dayofweek", dayofweek("tpep_dropoff_datetime")) \
                 .withColumn("drop_off_hour", hour("tpep_dropoff_datetime")) \
                 .withColumn("datetime_id", monotonically_increasing_id())
dim_datetime.show(5)

+--------------------+---------------------+------------+-------------+-----------+-----------------+------------+-------------+--------------+------------+------------------+-------------+-----------+
|tpep_pickup_datetime|tpep_dropoff_datetime|pick_up_year|pick_up_month|pick_up_day|pick_up_dayofweek|pick_up_hour|drop_off_year|drop_off_month|drop_off_day|drop_off_dayofweek|drop_off_hour|datetime_id|
+--------------------+---------------------+------------+-------------+-----------+-----------------+------------+-------------+--------------+------------+------------------+-------------+-----------+
| 2023-02-28 19:06:43|  2023-02-28 19:16:43|        2023|            2|         28|                3|          19|         2023|             2|          28|                 3|           19| 8589934592|
| 2023-02-28 19:08:25|  2023-02-28 19:39:30|        2023|            2|         28|                3|          19|         2023|             2|          28|                 3|           19| 85

In [5]:
from itertools import chain
from pyspark.sql.functions import create_map, lit

# Construct rate code dimension table according to mapping at https://www.nyc.gov/site/tlc/about/tlc-trip-record-data.page
rate_code_mapping = {
    1:"Standard rate",
    2:"JFK",
    3:"Newark",
    4:"Nassau or Westchester",
    5:"Negotiated fare",
    6:"Group ride",
    99:"Unknown"
}

mapping_func = create_map([lit(x) for x in chain(*rate_code_mapping.items())])
dim_rate_code = df.select("rate_code_id") \
                  .withColumn("rate_code_type", mapping_func[col("rate_code_id")]) \
                  .groupBy("rate_code_id", "rate_code_type") \
                  .count() \
                  .orderBy(col("rate_code_id").asc()) \
                  .drop("count")     

dim_rate_code.show()

+------------+--------------------+
|rate_code_id|      rate_code_type|
+------------+--------------------+
|           1|       Standard rate|
|           2|                 JFK|
|           3|              Newark|
|           4|Nassau or Westche...|
|           5|     Negotiated fare|
|           6|          Group ride|
|          99|             Unknown|
+------------+--------------------+



In [6]:
# Construct vendor dimension table according to mapping at https://www.nyc.gov/site/tlc/about/tlc-trip-record-data.page
vendor_mapping = {
    1:"Creative Mobile Technologies, LLC;",
    2:"VeriFone Inc.",
    6:"Unknown"
}

mapping_func = create_map([lit(x) for x in chain(*vendor_mapping.items())])
dim_vendor = df.select("vendor_id") \
               .withColumn("vendor_info", mapping_func[col("vendor_id")]) \
               .groupBy("vendor_id", "vendor_info") \
               .count() \
               .orderBy(col("vendor_id").asc()) \
               .drop("count")               

dim_vendor.show()

+---------+--------------------+
|vendor_id|         vendor_info|
+---------+--------------------+
|        1|Creative Mobile T...|
|        2|       VeriFone Inc.|
+---------+--------------------+



In [7]:
# Construct payment type dimension table according to mapping at https://www.nyc.gov/site/tlc/about/tlc-trip-record-data.page
payment_mapping = {
    1:"Credit card",
    2:"Cash",
    3:"No charge",
    4:"Dispute",
    5:"Unknown",
    6:"Voided trip",
}

mapping_func = create_map([lit(x) for x in chain(*payment_mapping.items())])
dim_payment_type = df.select("payment_type_id") \
                     .withColumn("payment_type", mapping_func[col("payment_type_id")]) \
                     .groupBy("payment_type_id", "payment_type") \
                     .count() \
                     .orderBy(col("payment_type_id").asc()) \
                     .drop("count") 

dim_payment_type.show()

+---------------+------------+
|payment_type_id|payment_type|
+---------------+------------+
|              1| Credit card|
|              2|        Cash|
|              3|   No charge|
|              4|     Dispute|
|              5|     Unknown|
+---------------+------------+



In [8]:
# Construct passenger count dimension table 
dim_passenger_count = df.groupBy(col("passenger_count").alias("passenger_count")) \
                        .count() \
                        .orderBy(col("passenger_count").asc()) \
                        .withColumn("passenger_count_id", monotonically_increasing_id()) \
                        .drop("count")

dim_passenger_count.show()

+---------------+------------------+
|passenger_count|passenger_count_id|
+---------------+------------------+
|              0|                 0|
|              1|                 1|
|              2|                 2|
|              3|                 3|
|              4|                 4|
|              5|                 5|
|              6|                 6|
|              7|                 7|
|              8|                 8|
|              9|                 9|
+---------------+------------------+



In [9]:
# Construct fact table from all other tables
fact_trips = df.withColumnRenamed("passenger_count", "passenger_count_id") \
               .join(dim_datetime, on=['tpep_pickup_datetime', 'tpep_dropoff_datetime'], how="inner") \
               .select("trip_id", "passenger_count_id", "vendor_id", "rate_code_id", "payment_type_id", "trip_distance",
                       "datetime_id", "pick_up_location_id", "drop_off_location_id", "store_and_fwd_flag",
                       "fare_amount", "extra", "mta_tax", "tip_amount", "tolls_amount", "improvement_surcharge",
                       "total_amount", "congestion_surcharge", "airport_fee", "report_date")

fact_trips = fact_trips.coalesce(5)
fact_trips.show(10)

+-----------+------------------+---------+------------+---------------+-------------+-----------+-------------------+--------------------+------------------+-----------+-----+-------+----------+------------+---------------------+------------+--------------------+-----------+-----------+
|    trip_id|passenger_count_id|vendor_id|rate_code_id|payment_type_id|trip_distance|datetime_id|pick_up_location_id|drop_off_location_id|store_and_fwd_flag|fare_amount|extra|mta_tax|tip_amount|tolls_amount|improvement_surcharge|total_amount|congestion_surcharge|airport_fee|report_date|
+-----------+------------------+---------+------------+---------------+-------------+-----------+-------------------+--------------------+------------------+-----------+-----+-------+----------+------------+---------------------+------------+--------------------+-----------+-----------+
|60132082359|                 1|        2|           1|              2|          0.0|60132082359|                193|                 26

In [11]:
fact_trips.write.parquet("transformed_data/fact_trips")
dim_passenger_count.write.parquet("transformed_data/dim_passenger_count")
dim_payment_type.write.parquet("transformed_data/dim_payment_type")
dim_vendor.write.parquet("transformed_data/dim_vendor")
dim_rate_code.write.parquet("transformed_data/dim_rate_code")
dim_datetime.write.parquet("transformed_data/dim_datetime")