In [1]:
import pyspark
from pyspark.sql import SparkSession, Window
from pyspark.sql import types
from pyspark.sql import functions as F
from pyspark.sql.functions import col
from pyspark.sql.window import Window
import os
import sys

In [2]:
year = sys.argv[1]
month = sys.argv[2]

In [3]:
minio_access_key = os.getenv('MINIO_ROOT_USER')
minio_secret_key = os.getenv('MINIO_ROOT_PASSWORD')
s3_path = 's3a://nyc-project/raw_data/'

In [7]:
spark.stop()

In [15]:
spark = SparkSession.builder \
    .master('local[*]') \
    .appName('data_transformation') \
    .config("spark.hadoop.fs.s3a.access.key", minio_access_key) \
    .config("spark.hadoop.fs.s3a.secret.key", minio_secret_key) \
    .getOrCreate()

sc = spark.sparkContext

In [8]:
from pyspark.sql import SparkSession

# Initialize Spark session
spark = SparkSession.builder \
    .master('local[*]') \
    .appName('data_transformation') \
    .config("spark.hadoop.fs.s3a.access.key", minio_access_key) \
    .config("spark.hadoop.fs.s3a.secret.key", minio_secret_key) \
    .config("spark.hadoop.fs.s3a.endpoint", "http://object-storage:9000") \
    .config("spark.hadoop.fs.s3a.path.style.access", "true") \
    .config("spark.hadoop.fs.s3a.connection.ssl.enabled", "false") \
    .config("spark.hadoop.fs.s3a.impl", "org.apache.hadoop.fs.s3a.S3AFileSystem") \
    .config("spark.sql.catalog.my_catalog", "org.apache.iceberg.spark.SparkCatalog") \
    .config("spark.sql.catalog.my_catalog.type", "hadoop") \
    .config("spark.sql.catalog.my_catalog.warehouse", "s3a://iceberg/warehouse") \
    .getOrCreate()

# Create SparkContext
sc = spark.sparkContext

In [45]:
df.printSchema()

root
 |-- vendor_id: integer (nullable = true)
 |-- pickup_datetime_id: string (nullable = true)
 |-- dropoff_datetime_id: string (nullable = true)
 |-- pickup_location_id: integer (nullable = true)
 |-- dropoff_location_id: integer (nullable = true)
 |-- ratecode_id: integer (nullable = true)
 |-- passenger_count: double (nullable = true)
 |-- trip_distance: float (nullable = true)
 |-- payment_type_id: integer (nullable = true)
 |-- store_and_fwd_flag: string (nullable = true)
 |-- fare_amount: float (nullable = true)
 |-- extra: float (nullable = true)
 |-- mta_tax: float (nullable = true)
 |-- tip_amount: float (nullable = true)
 |-- tolls_amount: float (nullable = true)
 |-- improvement_surcharge: float (nullable = true)
 |-- congestion_surcharge: float (nullable = true)
 |-- airport_fee: integer (nullable = true)
 |-- total_amount: float (nullable = true)



In [77]:
# Create an Iceberg table in the catalog with the defined schema
spark.sql("""
    CREATE OR REPLACE TABLE my_catalog.db_name.my_table (
        vendor_id INT,
        pickup_datetime_id INT,
        dropoff_datetime_id INT,
        pickup_location_id INT,
        dropoff_location_id INT,
        ratecode_id INT,
        passenger_count INT,
        trip_distance FLOAT,
        payment_type_id INT,
        store_and_fwd_flag STRING,
        fare_amount FLOAT,
        extra FLOAT,
        mta_tax FLOAT,
        tip_amount FLOAT,
        tolls_amount FLOAT,
        improvement_surcharge FLOAT,
        congestion_surcharge FLOAT,
        airport_fee INT,
        total_amount FLOAT
    )
    USING iceberg
""")

DataFrame[]

In [27]:
new_schema = types.StructType([
    types.StructField('VendorID', types.IntegerType(), True), 
    types.StructField('tpep_pickup_datetime', types.TimestampType(), True), 
    types.StructField('tpep_dropoff_datetime', types.TimestampType(), True), 
    types.StructField('passenger_count', types.IntegerType(), True), 
    types.StructField('trip_distance', types.FloatType(), True), 
    types.StructField('RatecodeID', types.IntegerType(), True), 
    types.StructField('store_and_fwd_flag', types.StringType(), True), 
    types.StructField('PULocationID', types.IntegerType(), True), 
    types.StructField('DOLocationID', types.IntegerType(), True), 
    types.StructField('payment_type', types.IntegerType(), True), 
    types.StructField('fare_amount', types.FloatType(), True), 
    types.StructField('extra', types.FloatType(), True), 
    types.StructField('mta_tax', types.FloatType(), True), 
    types.StructField('tip_amount', types.FloatType(), True), 
    types.StructField('tolls_amount', types.FloatType(), True), 
    types.StructField('improvement_surcharge', types.FloatType(), True), 
    types.StructField('total_amount', types.FloatType(), True), 
    types.StructField('congestion_surcharge', types.FloatType(), True), 
    types.StructField('airport_fee', types.IntegerType(), True)])

In [57]:
df = spark.read.option('headers', True).parquet(f'{s3_path}/2019/01/*.parquet')

In [58]:
old_schema = df.schema

In [59]:
for old_field, new_field in zip(old_schema.fields, new_schema.fields):
    df = df.withColumn(new_field.name, col(old_field.name).cast(new_field.dataType))

In [60]:
df = df.withColumnRenamed('VendorID', 'vendor_id') \
    .withColumnRenamed('RatecodeID', 'ratecode_id') \
    .withColumnRenamed('payment_type', 'payment_type_id') \
    .withColumnRenamed('tpep_pickup_datetime', 'pickup_datetime') \
    .withColumnRenamed('tpep_dropoff_datetime', 'dropoff_datetime') \
    .withColumnRenamed('PULocationID', 'pickup_location_id') \
    .withColumnRenamed('DOLocationID', 'dropoff_location_id')

In [61]:
df = df.filter((col('fare_amount') > 0) \
               & (col('trip_distance') > 0) \
               & (col('extra') > 0))

df = df.filter((col('ratecode_id') <= 6))

df = df.withColumn('congestion_surcharge', F.when(col('congestion_surcharge').isNull(), 0).otherwise(col('congestion_surcharge')))
df = df.withColumn('airport_fee', F.when(col('airport_fee').isNull(), 0).otherwise(col('airport_fee')))

In [62]:
window_spec = Window.orderBy('passenger_count')

df_rn = df.select(['passenger_count']).withColumn('rn', F.row_number().over(window_spec))
total_rows = df.count()
                                                         
if total_rows % 2 == 0:
    lower_mid = total_rows // 2
    upper_mid = lower_mid + 1
else:
    lower_mid = total_rows // 2 + 1
    upper_mid = lower_mid

median_df = df_rn.filter((col('rn') == lower_mid) | (col('rn') == upper_mid))

median_value = median_df.agg(F.avg(col('passenger_count'))).collect()[0][0]

df = df.withColumn('passenger_count', F.when(col('passenger_count') == 0, median_value).otherwise(col('passenger_count')))

24/10/03 19:22:00 WARN WindowExec: No Partition Defined for Window operation! Moving all data to a single partition, this can cause serious performance degradation.
24/10/03 19:22:00 WARN WindowExec: No Partition Defined for Window operation! Moving all data to a single partition, this can cause serious performance degradation.
24/10/03 19:22:02 WARN WindowExec: No Partition Defined for Window operation! Moving all data to a single partition, this can cause serious performance degradation.
24/10/03 19:22:02 WARN WindowExec: No Partition Defined for Window operation! Moving all data to a single partition, this can cause serious performance degradation.
                                                                                

In [63]:
def index_id(date_column):
    year = F.year(date_column)
    month = F.lpad(F.month(date_column).cast("string"), 2, "0")
    day = F.lpad(F.dayofmonth(date_column).cast("string"), 2, "0")
    hour = F.lpad(F.hour(date_column).cast("string"), 2, "0")
    minute = F.lpad(F.minute(date_column).cast("string"), 2, "0")
    second = F.lpad(F.second(date_column).cast("string"), 2, "0")
    index = F.concat(year, month, day, hour, minute, second)
    return index.cast("int")

In [64]:
pickup_datetime_dim = df.select(['pickup_datetime']) \
    .distinct() \
    .withColumn('pickup_datetime_id', index_id(col('pickup_datetime'))) \
    .withColumn('pickup_hour', F.hour(col('pickup_datetime'))) \
    .withColumn('pickup_day', F.dayofmonth(col('pickup_datetime'))) \
    .withColumn('pickup_month', F.month(col('pickup_datetime'))) \
    .withColumn('pickup_year', F.year(col('pickup_datetime'))) \
    .withColumn('pickup_weekday', F.date_format(col('pickup_datetime'), 'EEEE'))

pickup_datetime_dim = pickup_datetime_dim.select(
    'pickup_datetime_id',
    'pickup_datetime',
    'pickup_hour',
    'pickup_day',
    'pickup_month',
    'pickup_year',
    'pickup_weekday'
)

In [65]:
dropoff_datetime_dim = df.select(['dropoff_datetime']) \
    .distinct() \
    .withColumn('dropoff_datetime_id', index_id(col('dropoff_datetime'))) \
    .withColumn('dropoff_hour', F.hour(col('dropoff_datetime'))) \
    .withColumn('dropoff_day', F.dayofmonth(col('dropoff_datetime'))) \
    .withColumn('dropoff_month', F.month(col('dropoff_datetime'))) \
    .withColumn('dropoff_year', F.year(col('dropoff_datetime'))) \
    .withColumn('dropoff_weekday', F.date_format(col('dropoff_datetime'), 'EEEE'))

dropoff_datetime_dim = dropoff_datetime_dim.select(
    'dropoff_datetime_id',
    'dropoff_datetime',
    'dropoff_hour',
    'dropoff_day',
    'dropoff_month',
    'dropoff_year',
    'dropoff_weekday'
)

In [66]:
df = df.withColumn('dropoff_datetime_id', index_id(col('dropoff_datetime'))) \
    .withColumn('pickup_datetime_id', index_id(col('pickup_datetime')))

In [67]:
df = df.select(['vendor_id',
 'pickup_datetime_id',
 'dropoff_datetime_id',
 'pickup_location_id',
 'dropoff_location_id',
 'ratecode_id',
 'passenger_count',
 'trip_distance',
 'payment_type_id',
 'store_and_fwd_flag',
 'fare_amount',
 'extra',
 'mta_tax',
 'tip_amount',
 'tolls_amount',
 'improvement_surcharge',
 'congestion_surcharge',
 'airport_fee',
 'total_amount'])

In [69]:
df.show(5)



+---------+------------------+-------------------+------------------+-------------------+-----------+---------------+-------------+---------------+------------------+-----------+-----+-------+----------+------------+---------------------+--------------------+-----------+------------+
|vendor_id|pickup_datetime_id|dropoff_datetime_id|pickup_location_id|dropoff_location_id|ratecode_id|passenger_count|trip_distance|payment_type_id|store_and_fwd_flag|fare_amount|extra|mta_tax|tip_amount|tolls_amount|improvement_surcharge|congestion_surcharge|airport_fee|total_amount|
+---------+------------------+-------------------+------------------+-------------------+-----------+---------------+-------------+---------------+------------------+-----------+-----+-------+----------+------------+---------------------+--------------------+-----------+------------+
|        1|              null|               null|               151|                239|          1|            1.0|          1.5|              

                                                                                

In [78]:
df.writeTo("my_catalog.db_name.my_table").append()

                                                                                

In [81]:
df_sql = spark.sql("SELECT * FROM my_catalog.db_name.my_table LIMIT 10")

In [82]:
df_sql.show()

+---------+------------------+-------------------+------------------+-------------------+-----------+---------------+-------------+---------------+------------------+-----------+-----+-------+----------+------------+---------------------+--------------------+-----------+------------+
|vendor_id|pickup_datetime_id|dropoff_datetime_id|pickup_location_id|dropoff_location_id|ratecode_id|passenger_count|trip_distance|payment_type_id|store_and_fwd_flag|fare_amount|extra|mta_tax|tip_amount|tolls_amount|improvement_surcharge|congestion_surcharge|airport_fee|total_amount|
+---------+------------------+-------------------+------------------+-------------------+-----------+---------------+-------------+---------------+------------------+-----------+-----+-------+----------+------------+---------------------+--------------------+-----------+------------+
|        1|              null|               null|               151|                239|          1|              1|          1.5|              

                                                                                