In [1]:
import os
from pyspark.sql import SparkSession
from pyspark.sql.functions import broadcast
import spark_util

In [2]:
# Command-line style options passed to spark_util.getOrCreateSparkSession
# (presumably forwarded to spark-submit — confirm in spark_util): S3A
# filesystem settings for the object store plus the hadoop-aws and PostgreSQL
# JDBC packages.
# NOTE(review): the S3A access/secret keys are embedded in plain text; consider
# sourcing them from the environment before sharing this notebook.
submit_args = " ".join([
    "--conf spark.hadoop.fs.s3a.endpoint=http://minio-ml-workshop:9000",
    "--conf spark.hadoop.fs.s3a.access.key=minio",
    "--conf spark.hadoop.fs.s3a.secret.key=minio123",
    "--conf spark.hadoop.fs.s3a.path.style.access=true",
    "--conf spark.hadoop.fs.s3a.impl=org.apache.hadoop.fs.s3a.S3AFileSystem",
    "--conf spark.hadoop.fs.s3a.multipart.size=104857600",
    "--packages org.apache.hadoop:hadoop-aws:3.2.0,org.postgresql:postgresql:42.3.3",
])

# Obtain the session through the project helper, then raise log verbosity.
spark = spark_util.getOrCreateSparkSession("Enrich flights data", submit_args)
spark.sparkContext.setLogLevel("INFO")
print('Spark context started.')

Initializing environment variables for Spark
Creating a spark session...
Spark session created


Spark context started.


In [3]:
# Read the `flights` table from PostgreSQL over JDBC as 31 parallel range
# queries on the `day` column.
# lowerBound/upperBound only determine the partition stride; they do not filter
# rows. With bounds 1..32 and 31 partitions the stride is 1, so each value
# 1..31 gets its own partition. The previous 0..31 bounds left the first
# partition (day < 1) empty and packed days 30 and 31 into the last one.
# NOTE(review): assumes `day` is the day of month (1-31) — confirm against the table.
df_flights = spark.read \
    .format("jdbc") \
    .option("url", "jdbc:postgresql://pg-flights-data:5432/postgres") \
    .option("dbtable", "flights") \
    .option("user", "postgres") \
    .option("password", "postgres") \
    .option("driver", "org.postgresql.Driver") \
    .option("numPartitions", 31) \
    .option("partitionColumn", "day") \
    .option("lowerBound", 1) \
    .option("upperBound", 32) \
    .load()

print(f"Partition count:{df_flights.rdd.getNumPartitions()}")

# Lookup tables stored as CSV in object storage. Both reads now use the
# correctly spelled `delimiter` option; the airlines read previously passed
# `delimeter`, which Spark ignores as an unknown option.
df_airlines = spark.read \
    .options(delimiter=',', inferSchema='True', header='True') \
    .csv("s3a://airport-data/airlines.csv")
df_airports = spark.read \
    .options(delimiter=',', inferSchema='True', header='True') \
    .csv("s3a://airport-data/airports.csv")

#df_flights.printSchema()
#df_airlines.printSchema()
#df_airports.printSchema()

Partition count:31


In [4]:
from pyspark.sql.functions import col


def _with_prefix(frame, prefix):
    """Return `frame` with every column renamed to `prefix` + original name."""
    renamed = []
    for name in frame.columns:
        renamed.append(col(name).alias(prefix + name))
    return frame.select(renamed)


# Prefix the lookup columns so they stay distinguishable once joined onto the
# flights data; the airports table is used twice (origin and destination).
df_airlines = _with_prefix(df_airlines, "AL_")
df_o_airports = _with_prefix(df_airports, "ORIG_")
df_d_airports = _with_prefix(df_airports, "DEST_")
#df_airlines.printSchema()
#df_o_airports.printSchema()
#df_d_airports.printSchema()

In [5]:
# Join predicates matching each flight to its airline, origin airport and
# destination airport via the prefixed IATA code columns.
airline_match = df_flights.airline == df_airlines.AL_IATA_CODE
origin_match = df_flights.origin_airport == df_o_airports.ORIG_IATA_CODE
destination_match = df_flights.destination_airport == df_d_airports.DEST_IATA_CODE

# Attach the three lookup tables to the flights data; each lookup frame is
# wrapped in broadcast() as a hint to ship it to the executors.
df_flights = (
    df_flights
    .join(broadcast(df_airlines), airline_match)
    .join(broadcast(df_o_airports), origin_match)
    .join(broadcast(df_d_airports), destination_match)
)

#df_flights.printSchema()

root
 |-- year: integer (nullable = true)
 |-- month: integer (nullable = true)
 |-- day: integer (nullable = true)
 |-- day_of_week: integer (nullable = true)
 |-- airline: string (nullable = true)
 |-- flight_number: integer (nullable = true)
 |-- tail_number: string (nullable = true)
 |-- origin_airport: string (nullable = true)
 |-- destination_airport: string (nullable = true)
 |-- scheduled_departure: string (nullable = true)
 |-- departure_time: string (nullable = true)
 |-- departure_delay: integer (nullable = true)
 |-- taxi_out: integer (nullable = true)
 |-- wheels_off: string (nullable = true)
 |-- scheduled_time: integer (nullable = true)
 |-- elapsed_time: integer (nullable = true)
 |-- air_time: integer (nullable = true)
 |-- distance: integer (nullable = true)
 |-- wheels_on: string (nullable = true)
 |-- taxi_in: integer (nullable = true)
 |-- scheduled_arrival: string (nullable = true)
 |-- arrival_time: string (nullable = true)
 |-- arrival_delay: integer (nullable =

In [6]:
# Destination for the enriched flights dataset.
output_location = "s3a://flights-data/flights"

# Cache so the DAG is not recalculated when .count() runs after the write.
df_flights.cache()

# Persist as Parquet, replacing any previous output. The former
# .option("header", "true") was dropped: it is a CSV option and has no effect
# on a Parquet write.
df_flights.write.mode("overwrite") \
    .format("parquet") \
    .save(output_location)

# Display the number of enriched rows as the cell result.
df_flights.count()

5332914

In [16]:
# Stop the Spark session now that the enriched data has been written.
spark.stop()