In [None]:
import datetime, math 
import config
from pyspark.sql import SparkSession
from pyspark.sql.functions import split, lower, upper, col, trim, year, month, dayofweek, translate, concat, regexp_replace, date_format, date_add, round
from pyspark.sql.types import IntegerType, StringType, IntegerType, DoubleType
from pyspark import SparkConf, SparkContext
from redShiftFunctions import redShiftQuery, redShiftWrite, redShiftRead


# The S3A access key, secret key and endpoint are read from config.py --
# replace the placeholder keys in that file before running the notebook.
conf = SparkConf()
conf.set("spark.hadoop.fs.s3a.access.key", config.s3aAccessKey)
conf.set("spark.hadoop.fs.s3a.secret.key", config.s3aSecretKey)
conf.set("spark.hadoop.fs.s3a.endpoint", config.s3aEndpoint)

# Local Spark session using all available cores, carrying the S3A settings.
spark = (
    SparkSession.builder
    .master("local[*]")
    .appName("Main ETL for the Capstone Project")
    .config(conf=conf)
    .getOrCreate()
)

# Collects one {table name: DataFrame} dict per table built below; the list
# always starts out empty when this cell is (re-)run.
tableNamesArray = []


In [None]:

#  ------  PORT LOCATION TABLE -----------

# Load the port codes file as lines of text and let Spark parse each line as JSON.
textFile = spark.sparkContext.textFile(config.prtlCodes_file)
dfPortCodes = spark.read.json(textFile)

# Discard entries whose name contains "No PORT" or "Collapsed", then remove
# exact duplicate rows.
isPlaceholder = dfPortCodes.name.contains("No PORT") | dfPortCodes.name.contains("Collapsed")
dfPortCodes = dfPortCodes.filter(~isPlaceholder).drop_duplicates()

# Split the name on the comma: the first part becomes the lower-cased
# municipality, the second part the upper-cased state code.
nameSplit = split(dfPortCodes.name, ',')
municipalityCol = trim(lower(nameSplit.getItem(0)))
stateCodeCol = trim(upper(nameSplit.getItem(1)))
dfPortCodes = dfPortCodes.withColumn('municipality', municipalityCol)
dfPortCodes = dfPortCodes.withColumn('statecode', stateCodeCol)

# Keep only rows for which a state code was extracted; "name" is no longer
# needed because it now lives in the two new columns.
dfPortCodes = dfPortCodes.filter(dfPortCodes.statecode.isNotNull()).drop("name")

# State lookup table (CSV with a header row, column types inferred).
dfStateCode = (
    spark.read
    .options(header=True, delimiter=',', inferSchema='True')
    .csv(config.stateCodes_file)
)

# Left join on the state code so ports without a matching "Alpha code" are kept.
stateMatch = dfPortCodes.statecode == dfStateCode["Alpha code"]
dfPortCodes = dfPortCodes.join(dfStateCode, stateMatch, "left")

# "Alpha code" duplicates statecode, so drop it; store the state name lower-cased.
lowerState = lower(dfPortCodes.State)
dfPortCodes = dfPortCodes.drop("Alpha code").withColumn("state", lowerState)

# Expose the frame to Spark SQL under the name PORTCODES. This happens before
# the optional row limit below, so the view is not limited by it.
dfPortCodes.createOrReplaceTempView("PORTCODES")

# Optional small-sample mode controlled from config.py.
if config.limitDFtoTen == True:
    dfPortCodes = dfPortCodes.limit(10)

# Queue the frame under its target table name for the database insert at the end.
tableNamesArray.append({'PORTCODES': dfPortCodes})

In [None]:
#
#  ------  AIRPORTS TABLE
#

# Airport codes CSV; only the header option is set, so every column is read as a string.
dfAirPorts = spark.read.option("header", True).csv(config.airportCodes_file)

# The analysis is US only; heliports and closed airports are excluded as well.
usOpenAirports = (
    (dfAirPorts.iso_country == 'US')
    & (dfAirPorts.type != 'heliport')
    & (dfAirPorts.type != 'closed')
)
dfAirPorts = dfAirPorts.filter(usOpenAirports)

# "coordinates" is split on ',' and "iso_region" on '-'.
latLongSplit = split(dfAirPorts.coordinates, ',')
isoRegSplit = split(dfAirPorts.iso_region, '-')

# First coordinate part -> longitude, second -> latitude (both as doubles);
# the part of iso_region after the dash -> upper-cased state code.
# Municipality and name are lower-cased, elevation becomes an integer.
dfAirPorts = (
    dfAirPorts
    .withColumn('longitude', latLongSplit.getItem(0).cast(DoubleType()))
    .withColumn('latitude', latLongSplit.getItem(1).cast(DoubleType()))
    .withColumn('statecode', upper(isoRegSplit.getItem(1)))
    .withColumn('municipality', lower(col("municipality")))
    .withColumn('name', lower(col("name")))
    .withColumn('elevation_ft', col("elevation_ft").cast(IntegerType()))
)

# Drop iso_region and coordinates (already split into the columns above)
# together with continent, gps_code and local_code. iso_country is kept.
droppedColumns = ["iso_region", "continent", "coordinates", "gps_code", "local_code"]
dfAirPorts = dfAirPorts.drop(*droppedColumns)

# Optional small-sample mode controlled from config.py.
if config.limitDFtoTen == True:
    dfAirPorts = dfAirPorts.limit(10)

# Queue the frame under its target table name.
tableNamesArray.append({'AIRPORTS': dfAirPorts})

In [None]:
#
#  ------  USATEMP  TABLE
#

# Land temperature CSV (header row, column types inferred by Spark).
dfTemps = spark.read.options(header=True, delimiter=',', inferSchema='True').csv(config.globalLandTemp_file)

# Keep US rows only and remove rows without an average temperature.
# dropna() discards both null and NaN values in the listed column. The earlier
# comparison against the string literal 'Nan' depended on how Spark coerces
# that literal to a number: where the cast yields NULL the predicate is NULL
# for every row and the whole table is filtered away.
dfTemps = (
    dfTemps
    .filter(dfTemps.Country == 'United States')
    .dropna(subset=["AverageTemperature"])
)

# Derive year, month and day of week from the date column and lower-case the
# state name so it is easy to join on.
dfTemps = (
    dfTemps
    .withColumn('year', year(dfTemps.dt))
    .withColumn('month', month(dfTemps.dt))
    .withColumn('dayOfweek', dayofweek(dfTemps.dt))
    .withColumn('State', lower(dfTemps.State))
)

# Build a row id from the date (dashes removed) followed by the state, and
# round the temperature columns to 3 decimal places.
# Note: `round` here is pyspark.sql.functions.round, not the Python builtin.
dfTemps = (
    dfTemps
    .withColumn("id", concat(translate(dfTemps["dt"], "-", ""), dfTemps.State))
    .withColumn("AverageTemperature", round("AverageTemperature", 3))
    .withColumn("AverageTemperatureUncertainty", round("AverageTemperatureUncertainty", 3))
)

# dt is represented by the derived columns/id, and Country is constant after
# the US filter, so both are removed.
dfTemps = dfTemps.drop("dt", "Country")

# Optional small-sample mode controlled from config.py.
if config.limitDFtoTen == True:
    dfTemps = dfTemps.limit(10)

# Queue the frame under its target table name.
tableNamesArray.append({'USATEMP': dfTemps})

In [None]:
#
#  ------  USADEMOGRAPHICS  TABLE
#
# Demographics CSV: semicolon-delimited with a header row; no schema inference,
# so columns are cast explicitly further down.
dfCityDemo = spark.read.options(header=True, delimiter=';').csv(config.usDemograph_file)

# Lower-case city and state for joining, and make the count an integer so it
# can be summed in the pivot.
dfCityDemo = (
    dfCityDemo
    .withColumn('City', lower(dfCityDemo.City))
    .withColumn('State', lower(dfCityDemo.State))
    .withColumn('count', col("count").cast(IntegerType()))
)

# The input repeats each city once per race with only race and count differing.
# Pivot on Race so every city becomes a single row with one count column per race.
cityKeys = [
    "City", "State", "Median Age", "Male Population", "Female Population",
    "Total Population", "Number of Veterans", "Foreign-born",
    "Average Household Size", "State Code",
]
dfCityDemo = dfCityDemo.groupBy(*cityKeys).pivot("Race").sum("Count")

# Rename every column to its lower-case target name.
renamedColumns = [
    ('City', 'city'),
    ('State', 'state'),
    ('Median Age', 'median_age'),
    ('Male Population', 'male_population'),
    ('Female Population', 'female_population'),
    ('Total Population', 'total_population'),
    ('Number of Veterans', 'number_of_veterans'),
    ('Foreign-born', 'foreign_born'),
    ('Average Household Size', 'average_household_size'),
    ('State Code', 'state_code'),
    ('American Indian and Alaska Native', 'native'),
    ('Asian', 'asian'),
    ('Black or African-American', 'black_african_america'),
    ('Hispanic or Latino', 'hispanic_latino'),
    ('White', 'white'),
]
for oldName, newName in renamedColumns:
    dfCityDemo = dfCityDemo.withColumnRenamed(oldName, newName)

# Cast each column to its target type (applied in the listed order).
columnTypes = [
    ('median_age', DoubleType()),
    ('male_population', IntegerType()),
    ('female_population', IntegerType()),
    ('total_population', IntegerType()),
    ('number_of_veterans', IntegerType()),
    ('foreign_born', IntegerType()),
    ('average_household_size', DoubleType()),
    ('native', IntegerType()),
    ('asian', IntegerType()),
    ('black_african_america', IntegerType()),
    ('hispanic_latino', IntegerType()),
    ('white', IntegerType()),
]
for columnName, columnType in columnTypes:
    dfCityDemo = dfCityDemo.withColumn(columnName, col(columnName).cast(columnType))

# Keep only cities whose (state code, city) pair also exists in the PORTCODES
# temp view as (statecode, municipality).
dfCityDemo.createOrReplaceTempView("USADEMOGRAPHICS")
dfCityDemo = spark.sql("SELECT * FROM USADEMOGRAPHICS WHERE (state_code, city) in (select statecode, municipality from PORTCODES)")

# Optional small-sample mode controlled from config.py.
if config.limitDFtoTen == True:
    dfCityDemo = dfCityDemo.limit(10)

# Queue the frame under its target table name.
tableNamesArray.append({'USADEMOGRAPHICS': dfCityDemo})


In [None]:
#
#  ------  IMMIGRATION  TABLE
#
# Parquet files carry their own schema, so no schema-inference option is passed.
dfImmigration = spark.read.parquet(config.immig_file)

# Keep only rows with i94mode == 1 (arrival by airplane).
dfImmigration = dfImmigration.filter(dfImmigration["i94mode"] == 1)

# Strip everything from the decimal point onwards and cast the remainder to an
# integer. The pattern is a raw string: '\.' is not a valid Python escape
# sequence and triggers a warning when written in a normal string literal.
fractionPattern = r'\..*$'
dfImmigration = (
    dfImmigration
    .withColumn("arrdate", regexp_replace(dfImmigration["arrdate"], fractionPattern, '').cast(IntegerType()))
    .withColumn("depdate", regexp_replace(dfImmigration["depdate"], fractionPattern, '').cast(IntegerType()))
)

# Temp view consumed by the SQL statement below.
dfImmigration.createOrReplaceTempView("dateTable")

# arrdate and depdate are treated as day offsets from 1960-01-01: add them to
# that date and expose the full date plus its day, month and year parts.
dfImmigration = spark.sql("""
                    SELECT T1.*, 
                    date_add('1960-01-01', arrdate ) as arrival_full,
                    day(date_add('1960-01-01', arrdate ) ) as arrival_day,
                    month(date_add('1960-01-01', arrdate ) ) as arrival_month,
                    year(date_add('1960-01-01', arrdate ) ) as arrival_year,
                    date_add('1960-01-01', depdate ) as dep_full,
                    day(date_add('1960-01-01', depdate) ) as dep_day,
                    month(date_add('1960-01-01', depdate) ) as dep_month,
                    year(date_add('1960-01-01', depdate) ) as dep_year
                    FROM dateTable T1 
                    """)

# Drop columns that are not loaded into the target table, including the raw
# arrdate/depdate offsets now replaced by the derived date columns.
dfImmigration = dfImmigration.drop(
    "depdate", "i94mode", "count", "admnum",
    "entdepa", "entdepd", "entdepu", "matflag", "insnum", "i94mon", "i94yr", "arrdate",
)

# Rename the i94* columns to their target names.
# NOTE(review): i94cit is renamed to 'city' -- confirm that this field really
# holds a city; the name is kept because the target table depends on it.
dfImmigration = (
    dfImmigration
    .withColumnRenamed('i94cit', 'city')
    .withColumnRenamed('i94res', 'residence')
    .withColumnRenamed('i94port', 'port')
    .withColumnRenamed('i94addr', 'address')
    .withColumnRenamed('i94bir', 'age')
    .withColumnRenamed('i94visa', 'visa')
)

# Cast the numeric columns to integers.
dfImmigration = (
    dfImmigration
    .withColumn('cicid', col("cicid").cast(IntegerType()))
    .withColumn('city', col("city").cast(IntegerType()))
    .withColumn('residence', col("residence").cast(IntegerType()))
    .withColumn('age', col("age").cast(IntegerType()))
    .withColumn('visa', col("visa").cast(IntegerType()))
    .withColumn('biryear', col("biryear").cast(IntegerType()))
)

# Optional small-sample mode controlled from config.py.
if config.limitDFtoTen == True:
    dfImmigration = dfImmigration.limit(10)

# Queue the frame under its target table name.
tableNamesArray.append({'IMMIGRATION': dfImmigration})

In [None]:
# Push every collected DataFrame to Redshift under its table name.
# Each list entry is a dict holding exactly one {table name: DataFrame} pair.
for tableEntry in tableNamesArray:
    ((tableName, tableFrame),) = tableEntry.items()
    redShiftWrite(tableFrame, tableName)

In [None]:
# Sanity check: read each table back from Redshift and compare its row count
# with the row count of the DataFrame that was written.
for tableEntry in tableNamesArray:
    ((tableName, sourceFrame),) = tableEntry.items()
    storedFrame = redShiftRead(spark, tableName)
    countsMatch = storedFrame.count() == sourceFrame.count()
    verdict = "Count correct on table : " if countsMatch else "Count incorrect on table : "
    print(verdict + tableName)