# Step 2: 1st preprocessing

## 1. Import necessary modules & start a Spark session

In [None]:
# Import necessary modules
from pyspark.sql import SparkSession
from pyspark.sql import functions as F
import os

In [None]:
# Create (or reuse) the Spark session for this notebook.
# - eagerEval makes a bare DataFrame expression at the end of a cell
#   (e.g. `TLC_data.limit(5)`) render as a table.
# - session.timeZone pins timestamp interpretation/display to UTC.
# NOTE(review): in local mode the executor-memory setting may have no effect;
# confirm the deployment mode before relying on it.
spark = (
    SparkSession.builder.appName('ADS_project_1.py')
    .config('spark.sql.repl.eagerEval.enabled', True)
    .config('spark.sql.parquet.cacheMetadata', 'true')
    .config('spark.sql.session.timeZone', 'Etc/UTC')
    .config('spark.driver.memory', '16g')
    # Previously misspelled as 'spark.executer.memory', a key Spark does not
    # recognise, so the setting never took effect.
    .config('spark.executor.memory', '16g')
    .getOrCreate()
)

## 2. Data import & overview

Import `TLC_data` from directory `data/landing/TLC_data/`

In [None]:
# Glob pattern matching every parquet file in the landing directory; all
# matches are loaded together into one DataFrame.
TLC_data_dir = '../data/landing/TLC_data/*.parquet'
TLC_data = spark.read.format('parquet').load(TLC_data_dir)

`TLC_data` overview

In [None]:
# Record the raw dimensions now; section 5 subtracts the post-preprocessing
# dimensions from these to report how much was removed.
original_num_rows = TLC_data.count()
original_num_cols = len(TLC_data.columns)

for label, value in (
    ('number of rows: ', original_num_rows),
    ('number of cols: ', original_num_cols),
):
    print(label, value)
TLC_data.limit(5)

## 3. Remove features not relevant to research goal

Define the features that we consider unnecessary for the research goal

In [None]:
# Features that are not relevant to the research goal. Their values are still
# used to filter out invalid rows before the columns themselves are dropped.
useless_feature_list = [
    'VendorID',
    'RatecodeID',
    'store_and_fwd_flag',
    'payment_type',
    'fare_amount',
    'mta_tax',
    'improvement_surcharge',
    'tip_amount',
    'total_amount',
]

Remove rows with invalid values in these features before dropping the features themselves

In [None]:
# Keep only rows whose values in the soon-to-be-dropped features are valid.
# This must run before those columns are dropped in the next step.
# NOTE(review): the exact-value fee checks (mta_tax, improvement_surcharge)
# assume one fee schedule across every loaded file; confirm against the TLC
# data dictionary for the period covered by the landing data.
TLC_data = (
    TLC_data
        # keep only 'VendorID' values 1 & 2
        .where(F.col('VendorID').isin([1, 2]))

        # keep only 'RatecodeID' values 1 ~ 6
        .where(F.col('RatecodeID').isin([1, 2, 3, 4, 5, 6]))

        # keep only 'store_and_fwd_flag' values 'Y' & 'N'
        .where(F.col('store_and_fwd_flag').isin(['Y', 'N']))

        # keep only 'payment_type' values 1 & 2
        .where(F.col('payment_type').isin([1, 2]))

        # remove 'fare_amount' values that are smaller than or equal to 0
        .where(F.col('fare_amount') > 0)

        # remove 'mta_tax' values that are not 0.5
        .where(F.col('mta_tax') == 0.5)

        # remove 'improvement_surcharge' values that are not 0.3
        .where(F.col('improvement_surcharge') == 0.3)

        # remove 'tip_amount' values that are smaller than 0
        .where(F.col('tip_amount') >= 0)

        # remove 'total_amount' values that are smaller than 0
        .where(F.col('total_amount') >= 0)
)

Delete these unrelated features, such as `VendorID` and most of the fee features

In [None]:
# Drop every unrelated feature in one call rather than one column at a time.
TLC_data = TLC_data.drop(*useless_feature_list)

TLC_data.limit(5)

## 4. Changes for readability

Rename features

In [None]:
# Rename columns to shorter, more readable names.
# The original chain ended with a dangling `\` continuation onto a blank line,
# which was fragile; a parenthesized chain needs no line continuations.
# Note: '#passenger' contains '#', so it must be backtick-quoted inside Spark
# SQL / F.expr strings (plain F.col('#passenger') works as-is).
TLC_data = (
    TLC_data
    .withColumnRenamed('tpep_pickup_datetime', 'pickup_time')
    .withColumnRenamed('tpep_dropoff_datetime', 'dropoff_time')
    .withColumnRenamed('PULocationID', 'up_location_id')
    .withColumnRenamed('DOLocationID', 'off_location_id')
    .withColumnRenamed('extra', 'extra_fee')
    .withColumnRenamed('tolls_amount', 'toll_fee')
    .withColumnRenamed('passenger_count', '#passenger')
    .withColumnRenamed('congestion_surcharge', 'congestion_fee')
)

TLC_data.limit(5)

Reorder features

In [None]:
# Column order: trip times, locations, passengers/distance, then fee columns.
# Selecting by name also discards any column not listed here.
ordered_feature_list = [
    'pickup_time', 'dropoff_time',
    'up_location_id', 'off_location_id',
    '#passenger', 'trip_distance',
    'congestion_fee', 'extra_fee', 'toll_fee', 'airport_fee',
]
TLC_data = TLC_data.select(*ordered_feature_list)

TLC_data.limit(5)

## 5. Data overview after 1st preprocessing & saving

`TLC_data` overview after 1st preprocessing

In [None]:
# Dimensions after filtering, dropping and reordering.
num_rows_after_1st_preprocessing = TLC_data.count()
num_cols_after_1st_preprocessing = len(TLC_data.columns)

# How much was removed, relative to the counts taken right after loading.
num_removed_rows = original_num_rows - num_rows_after_1st_preprocessing
num_removed_cols = original_num_cols - num_cols_after_1st_preprocessing

for label, value in (
    ('number of rows: ', num_rows_after_1st_preprocessing),
    ('number of cols: ', num_cols_after_1st_preprocessing),
):
    print(label, value)
print('\n')
for label, value in (
    ('number of removed rows: ', num_removed_rows),
    ('number of removed cols: ', num_removed_cols),
):
    print(label, value)

Save `TLC_data` to directory `data/raw/TLC_data/`

In [None]:
# Directory for the 1st-preprocessed data. The output path below is derived
# from it, so the two cannot drift apart.
directory = '../data/raw/TLC_data'
# Create the directory (and any missing parents). exist_ok=True replaces the
# separate exists-check, avoiding a check-then-create race.
os.makedirs(directory, exist_ok=True)

# Save TLC_data; mode 'overwrite' replaces the output of any previous run.
TLC_data.write.mode('overwrite').parquet(os.path.join(directory, 'TLC_data.parquet'))

## 6. Stop spark session

In [None]:
# Stop the Spark session created at the top of the notebook now that the
# preprocessed data has been written.
spark.stop()