In [1]:
import configparser
from datetime import datetime
from datetime import timedelta
import os

from pyspark.sql import SparkSession, SQLContext, GroupedData, HiveContext
from pyspark.sql.functions import udf, col, lit, expr, when, regexp_replace, floor, split, abs, concat, round
from pyspark.sql.functions import year, month, quarter, dayofmonth, hour, weekofyear, date_format, date_add, mean
from pyspark.sql.functions import monotonically_increasing_id

from pyspark.sql.types import StructField
from pyspark.sql.types import StructType
from pyspark.sql.types import IntegerType, StringType, DoubleType, DateType

#####

import pandas as pd
import os
import configparser
import datetime as dt
import time

#from pyspark.sql.functions import isnan, when, count, col, udf, dayofmonth, dayofweek, month, year, weekofyear, avg, monotonically_increasing_id


In [2]:
# Guidance from Udacity Knowledge - https://knowledge.udacity.com/questions/911823

def create_spark_session():
    """
    Create spark session, utilising saurfang's Spark SAS package
    
    OUTPUT: returns spark session to main
    """
    spark = SparkSession.builder.\
    config("spark.jars.repositories", "https://repos.spark-packages.org/").\
    config("spark.jars.packages", "saurfang:spark-sas7bdat:2.0.0-s_2.11").\
    enableHiveSupport().getOrCreate()
    
    return spark

In [3]:
def process_airport_data(spark, input_data, output_data):
    """
    Loads airport data from local storage and creates:
    - 'airports' table
    and stores it locally in parquet format
    
    INPUTS: 
    spark - spark session
    input_data - local location of dataset
    output_data - local location for parquet table
    
    OUTPUTS:
    None
    """
    
    # get filepath to airport data file
    airport_data = input_data

    schema = StructType([
        StructField("ident", StringType(), True),
        StructField("type", StringType(), True),
        StructField("name", StringType(), True),
        StructField("elevation_ft", DoubleType(), True),
        StructField("continent", StringType(), True),
        StructField("iso_country", StringType(), True),
        StructField("iso_region", StringType(), True),
        StructField("municipality", StringType(), True),
        StructField("gps_code", StringType(), True),
        StructField("iata_code", StringType(), True),
        StructField("local_code", StringType(), True),
        StructField("coordinates", StringType(), True)
        ]
    )
    
    
    # read airport data file
    df = spark.read.format("csv").option("header", "True").load(airport_data)
    
    # Two-letter state codes are given in the format 'US-xx'
    # Extract two-letter code from this column
    df = df.withColumn('state_code', when(df.iso_region.startswith('US-'), regexp_replace(df.iso_region,'US-','')))
    
    # Latitude and longitude are given as a tuple
    # Separate into distinct columns 
    df = df.withColumn('lat', split(df['coordinates'],',').getItem(0).cast('double'))
    df = df.withColumn('long', split(df['coordinates'],',').getItem(1).cast('double'))
    
    # drop rows with nulls in latitude or longitude
    df = df.na.drop(subset=['lat','long'])
    
    # Process latitude and longitude figures into the format given for temperature data
    # E.g., from -56.52352135 to 56.52S
    #
    # Guidance taken from stack overflow user Daniel de Paula for regexp_replace use
    # https://stackoverflow.com/questions/37038014/pyspark-replace-strings-in-spark-dataframe-column
    #
    # Guidance taken from Udacity GPT for startswith use
    df = df.withColumn('lat_2dp', round(df.lat, 2))
    df = df.withColumn('abs_lat', when(df.lat_2dp.startswith('-'), regexp_replace(df.lat_2dp,'-','')).otherwise(df.lat_2dp))
    df = df.withColumn('latitude_direction', when(df.lat.startswith('-'), 'S').otherwise('N'))
    df = df.withColumn('long_2dp', round(df.long, 2))
    df = df.withColumn('abs_long', when(df.long_2dp.startswith('-'), regexp_replace(df.long_2dp,'-','')).otherwise(df.long_2dp))
    df = df.withColumn('longitude_direction', when(df.long.startswith('-'), 'W').otherwise('E'))
    
    
    # extract columns to create airports table
    airport_table = df.select(col('municipality').alias('city'),
                              'continent', 
                              'iso_country',
                              'state_code',
                              concat(df.abs_lat, df.latitude_direction).alias('latitude'),
                              concat(df.abs_long, df.longitude_direction).alias('longitude')) \
                                .dropDuplicates()
    
    # Save table to parquet file
    airport_table.write.mode('overwrite') \
        .parquet(output_data)
        
    # Create or replace airports view; print confirmation
    print('Creating airports table')
    airport_table.createOrReplaceTempView('airports')
    print('---Done---')
    
    return airport_table

In [4]:
def process_demographics_data(spark, input_data, output_data):
    """
    Loads demographics data from local storage and creates:
    - demographics table
    and stores it locally in parquet format
    
    INPUTS: 
    spark - spark session
    input_data - local location of dataset
    output_data - local location for parquet table
    
    OUTPUTS:
    None
    """
    # get filepath to airport data file
    demographics_data = input_data
    
    schema = StructType() \
        .add("City", StringType(), True) \
        .add("State", StringType(), True) \
        .add("Median Age", DoubleType(), True) \
        .add("Male Population", IntegerType(), True) \
        .add("Female Population", IntegerType(), True) \
        .add("Total Population", IntegerType(), True) \
        .add("Number of Veterans", IntegerType(), True) \
        .add("Foreign-born", IntegerType(), True) \
        .add("Average Household Size", IntegerType(), True) \
        .add("State Code", StringType(), True) \
        .add("Race", StringType(), True) \
        .add("Count", IntegerType(), True)
    
    # read demographics data file
    df = spark.read.format("csv").option("delimiter", ";").option("header", "True").load(demographics_data)
    
    # extract columns to create demographics table
    demographics_table = df.select(col('City').alias('city'),
                                col('Median Age').alias('median_age'),
                                col('Male Population').alias('male_population'),
                                col('Female Population').alias('female_population'),
                                col('Total Population').alias('total_population'),
                                col('Number of Veterans').alias('veteran_population'),
                                col('Foreign-born').alias('foreign_born_population'),
                                col('Average Household Size').alias('avg_household_size'),
                                col('State Code').alias('state_code')).dropDuplicates()
    
    # extract columns to create race table
    race_table = df.select(col('City').alias('city'),
                          col('State Code').alias('state_code'),
                          col('Race').alias('race'),
                          col('Count').alias('count')).dropDuplicates()
    
    # extract columns to create state_code table
    state_code_table = df.select(col('State Code').alias('state_code'),
                                col('State').alias('state')).dropDuplicates()
    
        
    # Save tables to parquet files
    demographics_table.write.mode('overwrite') \
        .parquet(output_data)
    race_table.write.mode('overwrite') \
        .parquet(output_data)
    state_code_table.write.mode('overwrite') \
        .parquet(output_data)
    
    # Create or replace demographics, race and state_code tables
    print('Creating demographics table')
    demographics_table.createOrReplaceTempView('demographics')
    print('---Done---')
    print('Creating race table')
    race_table.createOrReplaceTempView('race_counts')
    print('---Done---')
    print('Creating state_code table')
    state_code_table.createOrReplaceTempView('state_table')
    print('---Done---')
    
    return demographics_table

In [5]:
def process_temperature_data(spark, input_data, output_data):
    """
    Loads temperature data from local storage and creates:
    - temperature table
    and stores it locally in parquet format
    
    INPUTS: 
    spark - spark session
    input_data - local location of dataset
    output_data - local location for parquet table
    
    OUTPUTS:
    None
    """
    # get filepath to temperature data file
    temperature_data = input_data
    
    schema = StructType() \
        .add("dt", DateType(), True) \
        .add("AverageTemperature", DoubleType(), True) \
        .add("AverageTemperatureUncertainty", DoubleType(), True) \
        .add("City", StringType(), True) \
        .add("Country", StringType(), True) \
        .add("Latitude", StringType(), True) \
        .add("Longitude", StringType(), True)
    
    
    # read demographics data file
    df = spark.read.format("csv").option("header", "True").load(temperature_data)
    
    # keep only rows from the last 10 years
    df = df.filter(col('dt')>lit('2003-09-01'))
    
    # extract columns to create temperature table
    # extract month from the date column
    # group by all columns while averaging the temperature
    temp_avg_table = df.select('dt','AverageTemperature', 'City', 'Country', 'Latitude', 'Longitude') \
                            .groupby(month('dt').alias('month'),
                                     col('City').alias('city'), 
                                     col('Country').alias('country'), 
                                     col('Latitude').alias('latitude'), 
                                     col('Longitude').alias('longitude')
                                    ) \
                            .agg(mean('AverageTemperature').alias('avg_temp'))
    
    # Save table to parquet file
    temp_avg_table.write.mode('overwrite').partitionBy('month') \
        .parquet(output_data)
    
    # Create or replace temperature table
    print('Creating temperature table')
    temp_avg_table.createOrReplaceTempView('temperature')
    print('---Done---')
    
    return temp_avg_table

In [6]:
def process_immigration_data(spark, input_data, output_data):
    """
    Loads immigration data from local storage and creates:
    - immigration table
    and stores it locally in parquet format
    
    INPUTS: 
    spark - spark session
    input_data - local location of dataset
    output_data - local location for parquet table
    
    OUTPUTS:
    None
    """
    # get filepath to immigration data file
    immigration_data = input_data
    
    schema = StructType() \
        .add("cicid", IntegerType(), True) \
        .add("i94yr", IntegerType(), True) \
        .add("i94mon", IntegerType(), True) \
        .add("i94cit", IntegerType(), True) \
        .add("i94res", IntegerType(), True) \
        .add("i94port", StringType(), True) \
        .add("arrdate", IntegerType(), True) \
        .add("i94mode", IntegerType(), True) \
        .add("i94addr", StringType(), True) \
        .add("depdate", IntegerType(), True) \
        .add("i94bir", IntegerType(), True) \
        .add("i94visa", IntegerType(), True) \
        .add("count", IntegerType(), True) \
        .add("dtadfile", IntegerType(), True) \
        .add("visapost", StringType(), True) \
        .add("occup", StringType(), True) \
        .add("entdepa", StringType(), True) \
        .add("entdepd", StringType(), True) \
        .add("entdepu", StringType(), True) \
        .add("matflag", StringType(), True) \
        .add("biryear", IntegerType(), True) \
        .add("dtaddto", StringType(), True) \
        .add("gender", StringType(), True) \
        .add("insnum", DoubleType(), True) \
        .add("airline", StringType(), True) \
        .add("admnum", DoubleType(), True) \
        .add("fltno", StringType(), True) \
        .add("visatype", StringType(), True)

    
    # UDF to obtain arrival date by adding 'arrdate' number of days to 01/01/1960
    ### From Udacity GPT ###
    def add_days(days):
        new_date = datetime.strptime('1960-01-01', '%Y-%m-%d') + timedelta(days=days)
        return new_date.date()

    add_days_udf = udf(add_days, DateType())
    #######################
    
    
   
    # read immigration data file
    df = spark.read.format('com.github.saurfang.sas.spark').load(input_data)
    
    # set arr_date to be integer format
    df = df.withColumn('arrdate_int', floor(col('arrdate'))) \
           .withColumn('depdate_int', floor(col('depdate')))
    # apply UDF
    df = df.withColumn('arrival_date', add_days_udf(df.arrdate_int))
    
    # extract columns to create immigration table
    immigration_table = df.select(col('cicid').alias('cicid').cast("integer"),
                                  col('i94yr').alias('I94_year').cast("integer"),
                                  col('i94mon').alias('I94_month').cast("integer"),
                                  col('i94cit').alias('cit_code').cast("integer"),
                                  col('i94res').alias('res_code').cast("integer"),
                                  col('i94port').alias('arrival_port'),
                                  col('i94addr').alias('state_code'),
                                  col('arrdate_int'),
                                  col('arrival_date'),
                                  col('depdate_int'),
                                  col('i94bir').alias('age').cast("integer"),
                                  col('i94visa').alias('visa_code').cast("integer"),
                                  col('matflag').alias('match_flag'),
                                  col('biryear').alias('birth_year').cast("integer"),
                                  col('dtaddto').alias('allowed_until'),
                                  col('gender'),
                                  col('airline')) \
                                    .withColumn('immigrant_id', monotonically_increasing_id()) \
                                    .dropDuplicates()

    # extract columns to create arrival_dates table
    date_table = df.select(col('arrdate_int'),
                           col('arrival_date')) \
                            .withColumn('day', dayofmonth('arrival_date')) \
                            .withColumn('week', weekofyear('arrival_date')) \
                            .withColumn('month', month('arrival_date')) \
                            .withColumn('quarter', quarter('arrival_date')) \
                            .withColumn('year', year('arrival_date')) \
                            .dropDuplicates()
    
    # Save tables to parquet file
    immigration_table.write.mode('overwrite').partitionBy('I94_year', 'I94_month') \
        .parquet(output_data)
    date_table.write.mode('overwrite') \
        .parquet(output_data)
    
    # Create or replace immigration and dates tables
    print('Creating immigration table')
    immigration_table.createOrReplaceTempView('immigration')
    print('---Done---')    
    print('Creating date table')
    date_table.createOrReplaceTempView('arrival_dates')
    print('---Done---')
    
    return immigration_table, date_table

In [7]:
def gather_port_data(spark):
    """
    Gather port codes from local storage and creates:
    - port_codes table
    and stores it locally in parquet format
    
    INPUTS: 
    spark - spark session
    input_data - local location of dataset
    output_data - local location for parquet table
    
    OUTPUTS:
    None
    """
    
    
    file_path = 'I94_SAS_Labels_Descriptions.SAS'
    
    # load data and split rows into elements of a list
    with open(file_path) as file:
        i94_label_data = file.readlines()
        i94_label_data = [item.split("\n")[0] for item in i94_label_data]
    
    # isolate port codes
    portcode_start = [i for i, item in enumerate(i94_label_data) if 'I94PORT' in item][0]
    portcode_end   = [i for i, item in enumerate(i94_label_data) if 'ARRDATE' in item][0]
    port_code_list = i94_label_data[portcode_start+2:portcode_end-3]
    
    # extract port code, port name and state code from each element
    port_codes = []
    for item in port_code_list:
        code_item = item.split('\t=\t')
        code_item = [code_item[0].replace("'",""), code_item[1].replace("'","")]
        code_item = [code_item[0].replace(" ",""), code_item[1].replace(" ","")]
        code_item = [code_item[0]]+code_item[1].split(",")
        
        # some (~10) port names contain commas, which we ignore
        if len(code_item) == 3:
            port_codes = port_codes+[code_item]
        
    
    df = pd.DataFrame(data=port_codes,columns=['port_code','city','state_code'])
    
    DF = spark.createDataFrame(df)
    
    # extract columns to create port_codes table
    port_table = DF.select(col('port_code'),col('city'),col('state_code')) \
                        .dropDuplicates()

    
    # Create or replace port_codes table
    print('Creating port code table')
    port_table.createOrReplaceTempView('port_codes')
    print('---Done---')
    
    return port_table

In [8]:
def gather_cit_res_data(spark):
    """
    Gathers citizen and resident codes from local storage and creates:
    - cit_res table
    and stores it locally in parquet format
    
    INPUTS: 
    spark - spark session
    input_data - local location of dataset
    output_data - local location for parquet table
    
    OUTPUTS:
    None
    """
    
    
    file_path = 'I94_SAS_Labels_Descriptions.SAS'
    
    # load data and split rows into elements of a list
    with open(file_path) as file:
        i94_label_data = file.readlines()
        i94_label_data = [item.split("\n")[0] for item in i94_label_data]
    
    # isolate cit/res codes
    citres_start = [i for i, item in enumerate(i94_label_data) if 'i94cntyl' in item][0]
    citres_end   = [i for i, item in enumerate(i94_label_data) if 'I94PORT' in item][0]
    citres_list = i94_label_data[citres_start+1:citres_end-2]
        
    # extract cit/res code and country from each element
    citres_codes = []
    for item in citres_list:
        code_item = item.split('=')
        code_item = [code_item[0].replace("'",""), code_item[1].replace("'","")]
        code_item = [code_item[0].replace(" ",""), code_item[1].replace(" ","")]
        
        citres_codes = citres_codes+[code_item]
    
    df = pd.DataFrame(data=citres_codes,columns=['cit_res_code','country'])
    
    DF = spark.createDataFrame(df)
    
    # extract columns to create port_codes table
    citres_table = DF.select(col('cit_res_code'),col('country')) \
                        .dropDuplicates()
    
    # Create or replace citres_codes table
    print('Creating cit/res code table')
    citres_table.createOrReplaceTempView('citres_codes')
    print('---Done---')
    
    return citres_table

In [9]:
def gather_visa_data(spark):
    """
    Gathers visa codes from local storage and creates:
    - visa_code table
    and stores it locally in parquet format
    
    INPUTS: 
    spark - spark session
    input_data - local location of dataset
    output_data - local location for parquet table
    
    OUTPUTS:
    None
    """
    
    
    file_path = 'I94_SAS_Labels_Descriptions.SAS'
    
    # load data and split rows into elements of a list
    with open(file_path) as file:
        i94_label_data = file.readlines()
        i94_label_data = [item.split("\n")[0] for item in i94_label_data]
    
    # isolate visa codes
    visa_start = [i for i, item in enumerate(i94_label_data) if 'I94VISA' in item][0]
    visa_end   = [i for i, item in enumerate(i94_label_data) if '/* COUNT' in item][0]
    visa_list = i94_label_data[visa_start+1:visa_end-3]
    
    # extract visa code and country from each element
    visa_codes = []
    for item in visa_list:
        code_item = item.split(' = ')
        code_item = [code_item[0].replace("'",""), code_item[1].replace("'","")]
        code_item = [code_item[0].replace(" ",""), code_item[1].replace(" ","")]
                
        visa_codes = visa_codes+[code_item]
        
    df = pd.DataFrame(data=visa_codes,columns=['visa_code','visa_type'])
    
    DF = spark.createDataFrame(df)
    
    # extract columns to create port_codes table
    visa_table = DF.select(col('visa_code'),col('visa_type')) \
                        .dropDuplicates()
    
    # Create or replace visa_codes table
    print('Creating visa code table')
    visa_table.createOrReplaceTempView('visa_codes')
    print('---Done---')
    
    return visa_table

In [10]:
def valid_coord_check(temperature_table):
    """
    This function completes a validity check on the temperature table
    It checks whether the obtained values for latitude and longitude are valid
    
    INPUT: temperature_table (PySpark DF)
    OUTPUT: None    
    """
    
    # latitude and longitude values should be in the range [-90, 90] and [-180, 180] respectively, 
    # or, as we have formatted these values, [90S, 90N] and [180W, 180E]. 
    # We should check that:
    #     abs(latitude) <= 90
    #     abs(longitude) <= 180
    lat_long_check = temperature_table.withColumn('latitude_num', expr("substring(latitude, 1, length(latitude)-1)")) \
                                      .withColumn('longitude_num', expr("substring(longitude, 1, length(longitude)-1)"))
    lat_long_check = lat_long_check.withColumn('latitude_num', lat_long_check.latitude_num.cast('double')) \
                                   .withColumn('longitude_num', lat_long_check.longitude_num.cast('double'))
    max_lat = lat_long_check.agg({"latitude_num": "max"}).collect()[0]['max(latitude_num)']
    max_long = lat_long_check.agg({"longitude_num": "max"}).collect()[0]['max(longitude_num)']
    if max_lat > 90:
        print('Invalid latitude in temperature data: {}'.format(max_lat))
    if max_long > 180:
        print('Invalid longitude in temperature data: {}'.format(max_long))
    

In [11]:
def existence_check(table, data_file):
    """
    This function completes a validity check on a table
    It checks whether the constructed table is empty, indicating that an issue has occured in loading the data
    
    INPUT: PySpark DF
    OUTPUT: None    
    """
    
    table_count = table.count()
    if table_count == 0:
        print("Table is empty - check source data: {}",format(data_file))
    else: 
        print("{} loaded successfully with {} rows".format(data_file, table_count))

In [12]:
def main():
    """
    Main function in which:
        - a spark session is created
        - data source locations are defined
        - datasets are processed into tables
        - validity checks are completed
        - output locations are defined
    """
    
    spark = create_spark_session()
    airport_data = 'airport-codes_csv.csv'
    aiport_output = 'airport_parquet.parquet'
    demographics_data = 'us-cities-demographics.csv'
    demographics_output = 'demographics_parquet.parquet'
    temperature_data = '../../data2/GlobalLandTemperaturesByCity.csv'
    temperature_output = 'temperature_parquet.parquet'
    immigration_data = '../../data/18-83510-I94-Data-2016/i94_apr16_sub.sas7bdat'
    immigration_output = 'immigration_parquet.parquet'
    
    port_table = gather_port_data(spark)
    citres_table = gather_cit_res_data(spark)
    visa_table = gather_visa_data(spark)
    
    airport_table = process_airport_data(spark, airport_data, aiport_output)
    demographics_table = process_demographics_data(spark, demographics_data, demographics_output)
    temperature_table = process_temperature_data(spark, temperature_data, temperature_output)
    immigration_table, arrival_date_table = process_immigration_data(spark, immigration_data, immigration_output)
    
    data_sets = [[airport_table, airport_data],
                 [demographics_table, demographics_data],
                 [temperature_table, temperature_data],
                 [immigration_table, immigration_data]]
    for data in data_sets:
        existence_check(data[0],data[1])
    valid_coord_check(temperature_table)
    
    
    # Example queries
    
    #spark.sql("""SELECT im.cicid, im.arrdate_int, im.arrival_port, dt.arrdate_int, dt.arrival_date, dt.month, temp.avg_temp FROM immigration im
    #          LEFT JOIN arrival_dates dt
    #          ON dt.arrdate_int = im.arrdate_int
    #          LEFT JOIN temperature temp
    #          ON temp.month = dt.month
    #          LIMIT 10
    #         """).show(10)
    
    spark.sql("""SELECT * FROM airports LIMIT 10""").show(10)
    
    #spark.sql("""SELECT * FROM temperature LIMIT 10""").show(10)
    
if __name__ == "__main__":
    main()

Creating port code table
---Done---
Creating cit/res code table
---Done---
Creating visa code table
---Done---
Creating airports table
---Done---
Creating demographics table
---Done---
Creating race table
---Done---
Creating state_code table
---Done---
Creating temperature table
---Done---
Creating immigration table
---Done---
Creating date table
---Done---
airport-codes_csv.csv loaded successfully with 53818 rows
us-cities-demographics.csv loaded successfully with 596 rows
../../data2/GlobalLandTemperaturesByCity.csv loaded successfully with 42120 rows
../../data/18-83510-I94-Data-2016/i94_apr16_sub.sas7bdat loaded successfully with 3096313 rows
+--------------+---------+-----------+----------+--------+---------+
|          city|continent|iso_country|state_code|latitude|longitude|
+--------------+---------+-----------+----------+--------+---------+
|      Tonasket|       NA|         US|        WA| 119.32S|   48.75E|
|        Boston|       NA|         US|        MA|  71.13S|   42.37E|
