In [95]:
import os

from pyspark.sql.functions import array, col, explode, struct, lit, approx_count_distinct, udf, when

import spark_util

In [96]:
# S3A settings for the MinIO object store backing this notebook.
# The endpoint and credentials can be supplied through environment variables,
# so secrets do not have to live in the notebook. The fallbacks reproduce the
# values previously hardcoded here.
# Dedicated MINIO_* names are used on purpose, so real AWS_* keys present in
# the environment are never sent to the MinIO endpoint.
# NOTE(review): the fallback credentials look like throwaway workshop values.
# Drop them (and require the env vars) before using this outside that setup.
s3_endpoint = os.environ.get("MINIO_ENDPOINT", "http://minio-ml-workshop:9000")
s3_access_key = os.environ.get("MINIO_ACCESS_KEY", "minio")
s3_secret_key = os.environ.get("MINIO_SECRET_KEY", "minio123")

# One space-separated spark-submit argument string. With the default values it
# is identical to the original line-continued literal.
submit_args = " ".join([
    f"--conf spark.hadoop.fs.s3a.endpoint={s3_endpoint}",
    f"--conf spark.hadoop.fs.s3a.access.key={s3_access_key}",
    f"--conf spark.hadoop.fs.s3a.secret.key={s3_secret_key}",
    "--conf spark.hadoop.fs.s3a.path.style.access=true",
    "--conf spark.hadoop.fs.s3a.impl=org.apache.hadoop.fs.s3a.S3AFileSystem",
    "--conf spark.hadoop.fs.s3a.multipart.size=104857600",  # 100 * 1024 * 1024
    "--packages org.apache.hadoop:hadoop-aws:3.2.0",
])

spark = spark_util.getOrCreateSparkSession("Clean flights data", submit_args)
spark.sparkContext.setLogLevel("INFO")

Initializing environment variables for Spark
Creating a spark session...
Spark session created


In [97]:
# Load the raw flight records: every parquet file under the flights/ prefix.
data_location = "s3a://flights-data/flights/*.parquet"
df_flights = spark.read.format("parquet").load(data_location)

In [98]:
# Start the cleaned frame by dropping the IATA code columns from the raw data.
iata_code_columns = ("AL_IATA_CODE", "ORIG_IATA_CODE", "DEST_IATA_CODE")
df_clean = df_flights.drop(*iata_code_columns)

In [99]:
# Find the columns that hold a single distinct value across the whole dataset.
# All per-column counts are computed in one agg(), so the data is scanned once.
# The result is a single wide row, one count per column.
# NOTE(review): approx_count_distinct counts distinct non-null values, so a
# column holding one value plus nulls also reports 1 and will be dropped
# below. Confirm that is intended.
df_distinct = df_clean.agg(*(approx_count_distinct(col(c)).alias(c) for c in df_clean.columns))

cols = df_distinct.columns

# Unpivot the wide row into one (column_name, distinct_count) row per column.
# All counts share the same type, so the structs fit in a single array.
kvs = explode(array([
    struct(lit(c).alias("column_name"), col(c).alias("distinct_count"))
    for c in cols
])).alias("kvs")

distinct_count = df_distinct.select(kvs).select("kvs.column_name", "kvs.distinct_count")

uni_value_fields = distinct_count.filter(col("distinct_count") == 1)

In [100]:
# Bring the names of the single-valued columns back to the driver,
# then drop those columns from the cleaned frame.
single_value_rows = uni_value_fields.select("column_name").collect()
col_names = [str(single_value_row["column_name"]) for single_value_row in single_value_rows]

df_clean = df_clean.drop(*col_names)

In [101]:
# A flight counts as delayed when its departure delay reaches this value.
# NOTE(review): presumably minutes. Confirm against the units of departure_delay.
delay_threshold = 15


def is_delayed(departure_delay, cancelled):
    """Build a Spark Column that is 1 for delayed flights and 0 otherwise.

    Rules, in order:
    - Cancelled flights (``cancelled == 1``) are never delayed and yield 0.
    - Otherwise a flight is delayed (1) when ``departure_delay >= delay_threshold``.
    - Everything else yields 0, including rows whose delay is null.

    This is a native column expression rather than a Python UDF. A null delay
    therefore evaluates to 0 instead of raising ``TypeError`` (``None >= 15``)
    on the executor, and rows are not serialized out to Python workers.

    :param departure_delay: Column holding the departure delay.
    :param cancelled: Column holding the cancellation flag.
    :return: integer Column with values 0 or 1.
    """
    return (
        when(cancelled == 1, 0)
        .when(departure_delay >= delay_threshold, 1)
        .otherwise(0)
    )


df_clean = df_clean.withColumn("DELAYED", is_delayed(df_clean.departure_delay, df_clean.cancelled))

In [102]:
# Write the cleaned dataset back to object storage as parquet.
output_location = "s3a://flights-data/flights-clean"

# cache() is lazy. The write below is the first action, so it fills the cache,
# and the count() afterwards reuses the cached rows instead of re-reading and
# re-transforming the source parquet.
df_clean.cache()
df_clean.write.mode("overwrite").parquet(output_location)

df_clean.count()

5332914

In [103]:
# Shut down the Spark session now that the cleaned data has been written.
spark.stop()