In [1]:
import os
import pyspark
import pyspark.sql.functions as F
root_folder = "/home/trungdc/unimelb/MAST30024/asm/mast30034_2021_s2_project_1-alexdang02-1"
data_dir = os.path.join(root_folder, "Data")
SQLOutput_dir = os.path.join(root_folder, "code/SparkSQL_Output")
plot_dir = os.path.join(root_folder, "Plots")

In [2]:
import pyspark.sql.functions as F
from pyspark.sql.functions import *
from pyspark.sql.types import *

In [3]:
from pyspark.sql import SparkSession
import warnings
warnings.filterwarnings("ignore")

spark = SparkSession.builder.getOrCreate()

In [4]:
def weekend(dayofweek):
  if   dayofweek <= 5: 
      return 0
  else:
      return 1
udfWeekendFunc = F.udf(weekend, IntegerType())

def workingHour(hour, weekend):
    if weekend <= 5:
        if 7 <= hour <= 19:
            return 1
        else:
            return 0
    else:
        return 0
udfworkingHour = F.udf(workingHour, IntegerType())

In [6]:
def preprocess(sdf):
    sdf = sdf.withColumnRenamed("duration(m)", "duration")  \
    .withColumnRenamed('expected_total_distance(miles)', "expected_total_distance") \
    .withColumnRenamed('expected_total_duration(s)', "expected_total_duration") 
    sdf = sdf.withColumn("expected_total_duration", sdf.expected_total_duration/60)
    sdf = sdf.filter(sdf.passenger_count <= 6) \
        .filter(sdf.PULocationID != sdf.DOLocationID) \
        .filter( ~ ((sdf.RatecodeID == 2) & (sdf.fare_amount <= 50))) \
        .filter(sdf.duration < 500) \
        .filter(sdf.total_amount >= 2.5) \
        .filter(sdf.tolls_amount >= 0) \
        .filter(sdf.VendorID.isin([1,2])) \
        .filter(sdf.RatecodeID.isin([1,2])) \
        .filter(sdf.payment_type.isin([1,2])) \
        .filter(sdf.tip_amount <= 25) \
    .withColumn("DayofWeek", dayofweek(sdf.tpep_pickup_datetime))   \
    .withColumn("Weekend", udfWeekendFunc(col("DayofWeek"))) \
    .withColumn("Month", month(sdf.tpep_pickup_datetime))   \
    .withColumn("Hour", hour(sdf.tpep_pickup_datetime)) 

    sdf = sdf.withColumn("WorkingHour", udfworkingHour(col("Hour"), col("DayofWeek"))) \
    .withColumn("trip_distance",F.when(sdf.trip_distance==0, sdf.expected_total_distance).otherwise(sdf.trip_distance)) \
    .withColumn("passenger_count", F.when(sdf.passenger_count==0, 1).when(sdf.passenger_count==6, 5).otherwise(sdf.passenger_count)) \
    .withColumn("tip_amount", F.when(sdf.tip_amount<0, 0).otherwise(sdf.tip_amount)) \
    .withColumn("fare_amount", F.when(sdf.fare_amount<2.5, 2.5).otherwise(sdf.fare_amount)) \
    .drop("expected_total_distance", "expected_total_duration", "tpep_pickup_datetime")
    return sdf

sdf = spark.read.format("csv").option("header", "true").option("inferSchema", "true").load(os.path.join(data_dir,"Merge", "train.csv"))
sdf = preprocess(sdf)
sdf.createOrReplaceTempView("trip")
sdf.limit(5)

DataFrame[VendorID: int, passenger_count: int, trip_distance: double, RatecodeID: int, PULocationID: int, DOLocationID: int, payment_type: int, fare_amount: double, extra: double, mta_tax: double, tip_amount: double, tolls_amount: double, improvement_surcharge: double, total_amount: double, duration: double, tempMax: double, tempMin: double, tempAvg: double, tempDeparture: double, hdd: double, cdd: double, precipitation: double, newSnow: double, snowDepth: double, DayofWeek: int, Weekend: int, Month: int, Hour: int, WorkingHour: int]

In [8]:
for col in sdf.dtypes:
    print((col[0], col[1]))

('VendorID', 'int')
('passenger_count', 'int')
('trip_distance', 'double')
('RatecodeID', 'int')
('PULocationID', 'int')
('DOLocationID', 'int')
('payment_type', 'int')
('fare_amount', 'double')
('extra', 'double')
('mta_tax', 'double')
('tip_amount', 'double')
('tolls_amount', 'double')
('improvement_surcharge', 'double')
('total_amount', 'double')
('duration', 'double')
('tempMax', 'double')
('tempMin', 'double')
('tempAvg', 'double')
('tempDeparture', 'double')
('hdd', 'double')
('cdd', 'double')
('precipitation', 'double')
('newSnow', 'double')
('snowDepth', 'double')
('DayofWeek', 'int')
('Weekend', 'int')
('Month', 'int')
('Hour', 'int')
('WorkingHour', 'int')
