In [1]:
from pyspark.sql import SparkSession, Window
from pyspark.sql import types
from pyspark.sql import functions as F
from pyspark.sql.functions import col
from pyspark.sql.window import Window
import os

In [2]:
# Local Spark session on all available cores (`local[*]`); getOrCreate()
# returns the already-active session if this cell is re-run.
spark = (
    SparkSession.builder
    .master('local[*]')
    .appName('data_transformation')
    .getOrCreate()
)

sc = spark.sparkContext

Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
24/09/06 03:58:31 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable


In [3]:
# Root directory of the raw parquet files, relative to the working directory.
data_path = "../raw_data"

In [4]:
# Target names and types for the raw trip columns. The casting step pairs these
# fields with the source parquet columns by position, so the order here must
# follow the source file's column order.
# NOTE(review): charge columns use FloatType (32-bit). If the source stores
# doubles, precision on large amounts may be lost; consider DoubleType/DecimalType.
new_schema = types.StructType([
    types.StructField('VendorID', types.IntegerType(), True), 
    types.StructField('tpep_pickup_datetime', types.TimestampNTZType(), True), 
    types.StructField('tpep_dropoff_datetime', types.TimestampNTZType(), True), 
    types.StructField('passenger_count', types.IntegerType(), True), 
    types.StructField('trip_distance', types.FloatType(), True), 
    types.StructField('RatecodeID', types.IntegerType(), True), 
    types.StructField('store_and_fwd_flag', types.StringType(), True), 
    types.StructField('PULocationID', types.IntegerType(), True), 
    types.StructField('DOLocationID', types.IntegerType(), True), 
    types.StructField('payment_type', types.IntegerType(), True), 
    types.StructField('fare_amount', types.FloatType(), True), 
    types.StructField('extra', types.FloatType(), True), 
    types.StructField('mta_tax', types.FloatType(), True), 
    types.StructField('tip_amount', types.FloatType(), True), 
    types.StructField('tolls_amount', types.FloatType(), True), 
    types.StructField('improvement_surcharge', types.FloatType(), True), 
    types.StructField('total_amount', types.FloatType(), True), 
    types.StructField('congestion_surcharge', types.FloatType(), True), 
    # A charge like every other *_amount/surcharge column. An IntegerType cast
    # would truncate fractional fees (e.g. 1.25 -> 1).
    types.StructField('airport_fee', types.FloatType(), True)])

In [5]:
# years = os.listdir(data_path)
# for year in years:
#     months = os.listdir(f'{data_path}/{year}')
#     spark.read.option('headers', True).parquet(f'{data_path}/{year}/*/*.parquet')

In [6]:
# Load one month (2019-01) of raw trip files. Parquet files carry their own
# column names and types, so no header option applies here. The former
# `.option('headers', True)` was not a recognised option and had no effect.
df = spark.read.parquet(f'{data_path}/2019/01/*.parquet')

                                                                                

In [7]:
# Snapshot of the schema exactly as read from parquet. Its field names are the
# source columns when casting to new_schema below.
old_schema = df.schema

In [8]:
# Cast every source column to the type declared in new_schema, pairing columns
# by position and naming each result after the new_schema field. A single
# select builds one projection instead of growing the plan with one
# withColumn per column.
if len(old_schema.fields) != len(new_schema.fields):
    # zip() would silently drop the unmatched trailing columns; fail loudly instead.
    raise ValueError(
        f'Column count mismatch: source has {len(old_schema.fields)} columns, '
        f'new_schema declares {len(new_schema.fields)}'
    )

df = df.select([
    col(old_field.name).cast(new_field.dataType).alias(new_field.name)
    for old_field, new_field in zip(old_schema.fields, new_schema.fields)
])

In [9]:
# Rename the raw column names to snake_case, applied in the order listed.
column_renames = {
    'VendorID': 'vendor_id',
    'RatecodeID': 'ratecode_id',
    'tpep_pickup_datetime': 'pickup_datetime',
    'tpep_dropoff_datetime': 'dropoff_datetime',
    'PULocationID': 'pickup_location_id',
    'DOLocationID': 'dropoff_location_id',
}
for old_name, new_name in column_renames.items():
    df = df.withColumnRenamed(old_name, new_name)

In [42]:
# Drop records that cannot be valid trips.
# NOTE(review): the previous rule kept only `extra > 0`, which also discarded
# every trip with no extra charge (extra == 0). Only negative extras are
# rejected now; confirm this matches the intended cleaning rule.
# Rows whose ratecode_id is null fail the `<= 6` comparison and are dropped,
# as before.
df = df.filter(
    (col('fare_amount') > 0)
    & (col('trip_distance') > 0)
    & (col('extra') >= 0)
    & (col('ratecode_id') <= 6)
)

# A missing surcharge is treated as no surcharge. fillna keeps each column's type.
df = df.fillna(0, subset=['congestion_surcharge', 'airport_fee'])

In [None]:
# Replace zero passenger counts with the median of the valid counts.
#
# - An aggregate replaces the former global row_number() window. A window with
#   no partitionBy moves every row into one partition just to find the middle.
# - Nulls and zeros are excluded. Zeros are the values being imputed, and the
#   old rank positions came from df.count(), which included nulls (sorted
#   first) and so shifted the "middle" rows.
# - F.median interpolates between the two middle values for an even count,
#   like the previous averaging did.
median_value = (
    df.filter(col('passenger_count') > 0)
    .agg(F.median('passenger_count'))
    .first()[0]
)

if median_value is not None:  # None when there are no valid counts at all
    # passenger_count is an integer column. Round the (possibly x.5) median so
    # the column is not widened to double with fractional passengers.
    # Python round() rounds halves to even.
    median_passengers = int(round(median_value))
    df = df.withColumn(
        'passenger_count',
        F.when(col('passenger_count') == 0, F.lit(median_passengers))
        .otherwise(col('passenger_count'))
        .cast(types.IntegerType()),
    )

In [27]:
# Pickup date/time dimension: one row per distinct pickup timestamp, with a
# generated id and the timestamp's calendar parts broken out.
# NOTE(review): per the Spark docs, monotonically_increasing_id() is unique but
# not consecutive, is 64-bit, and depends on partitioning. This frame is lazy,
# so ids may differ between actions unless it is persisted first.
pickup_datetime_dim = (
    df.select('pickup_datetime')
    .distinct()
    .select(
        F.monotonically_increasing_id().alias('pickup_datetime_id'),
        'pickup_datetime',
        F.hour(col('pickup_datetime')).alias('pickup_hour'),
        F.dayofmonth(col('pickup_datetime')).alias('pickup_day'),
        F.month(col('pickup_datetime')).alias('pickup_month'),
        F.year(col('pickup_datetime')).alias('pickup_year'),
        F.date_format(col('pickup_datetime'), 'EEEE').alias('pickup_weekday'),
    )
)

In [28]:
# Attach each trip's pickup_datetime_id from pickup_datetime_dim.
# A left join replaces the former collect-to-driver dict + broadcast + Python
# UDF. The lookup stays distributed, and the id keeps its native 64-bit type:
# the UDF declared IntegerType, but monotonically_increasing_id() values exceed
# the 32-bit range for every partition other than the first.
# NOTE(review): a trip with a null pickup_datetime gets a null id (an equi-join
# never matches null keys).
pickup_datetime_keys = pickup_datetime_dim.select('pickup_datetime', 'pickup_datetime_id')

# Drop any id from a previous run of this cell so re-running stays idempotent.
trips_without_pickup_id = df.drop('pickup_datetime_id')
df = (
    trips_without_pickup_id
    .join(pickup_datetime_keys, on='pickup_datetime', how='left')
    # A join on a column name moves the key to the front; restore column order.
    .select(*trips_without_pickup_id.columns, 'pickup_datetime_id')
)

                                                                                

In [29]:
# Dropoff date/time dimension: one row per distinct dropoff timestamp, with a
# generated id and the timestamp's calendar parts broken out.
# NOTE(review): per the Spark docs, monotonically_increasing_id() is unique but
# not consecutive, is 64-bit, and depends on partitioning. This frame is lazy,
# so ids may differ between actions unless it is persisted first.
dropoff_datetime_dim = (
    df.select('dropoff_datetime')
    .distinct()
    .select(
        F.monotonically_increasing_id().alias('dropoff_datetime_id'),
        'dropoff_datetime',
        F.hour(col('dropoff_datetime')).alias('dropoff_hour'),
        F.dayofmonth(col('dropoff_datetime')).alias('dropoff_day'),
        F.month(col('dropoff_datetime')).alias('dropoff_month'),
        F.year(col('dropoff_datetime')).alias('dropoff_year'),
        F.date_format(col('dropoff_datetime'), 'EEEE').alias('dropoff_weekday'),
    )
)

In [30]:
# Attach each trip's dropoff_datetime_id from dropoff_datetime_dim.
# A left join replaces the former collect-to-driver dict + broadcast + Python
# UDF. The lookup stays distributed, and the id keeps its native 64-bit type:
# the UDF declared IntegerType, but monotonically_increasing_id() values exceed
# the 32-bit range for every partition other than the first.
# NOTE(review): a trip with a null dropoff_datetime gets a null id (an
# equi-join never matches null keys).
dropoff_datetime_keys = dropoff_datetime_dim.select('dropoff_datetime', 'dropoff_datetime_id')

# Drop any id from a previous run of this cell so re-running stays idempotent.
trips_without_dropoff_id = df.drop('dropoff_datetime_id')
df = (
    trips_without_dropoff_id
    .join(dropoff_datetime_keys, on='dropoff_datetime', how='left')
    # A join on a column name moves the key to the front; restore column order.
    .select(*trips_without_dropoff_id.columns, 'dropoff_datetime_id')
)

                                                                                

In [31]:
# Preview the first 20 rows, truncating long cell values.
df.show(n=20, truncate=True)



+---------+-------------------+-------------------+---------------+-------------+-----------+------------------+------------------+-------------------+------------+-----------+-----+-------+----------+------------+---------------------+------------+--------------------+-----------+------------------+-------------------+
|vendor_id|    pickup_datetime|   dropoff_datetime|passenger_count|trip_distance|ratecode_id|store_and_fwd_flag|pickup_location_id|dropoff_location_id|payment_type|fare_amount|extra|mta_tax|tip_amount|tolls_amount|improvement_surcharge|total_amount|congestion_surcharge|airport_fee|pickup_datetime_id|dropoff_datetime_id|
+---------+-------------------+-------------------+---------------+-------------+-----------+------------------+------------------+-------------------+------------+-----------+-----+-------+----------+------------+---------------------+------------+--------------------+-----------+------------------+-------------------+
|        1|2019-01-01 00:46:40|201

                                                                                

In [34]:
# Keep only the columns of the trip table, in output order: ids first, then
# trip attributes, then the individual charges and the total.
df = df.select(
    'vendor_id',
    'pickup_datetime_id',
    'dropoff_datetime_id',
    'pickup_location_id',
    'dropoff_location_id',
    'ratecode_id',
    'passenger_count',
    'trip_distance',
    'payment_type',
    'store_and_fwd_flag',
    'fare_amount',
    'extra',
    'mta_tax',
    'tip_amount',
    'tolls_amount',
    'improvement_surcharge',
    'congestion_surcharge',
    'airport_fee',
    'total_amount',
)

In [43]:
# Sanity check that the null-filling step left no nulls in the surcharge
# columns. All counts come from one aggregation, so the uncached lineage is
# evaluated once rather than once per column. The printed output is unchanged.
null_check_columns = ['congestion_surcharge', 'airport_fee']
null_counts = df.select([
    F.count(F.when(col(c).isNull(), 1)).alias(c) for c in null_check_columns
]).first()
for c in null_check_columns:
    print(c, null_counts[c])

                                                                                

congestion_surcharge 0
airport_fee 0


                                                                                