# Daily Covid-19 Data ETL

## Imports and Spark session creation

In [1]:
from pyspark.sql import SparkSession, Window
from pyspark.sql import functions as F

import pandas as pd
from datetime import datetime

In [2]:
# Obtain the Spark session for this ETL job (getOrCreate returns the active
# session if one already exists instead of starting a second one).
spark = (
    SparkSession.builder
    .appName("Data Engineering Capstone Project")
    .getOrCreate()
)

In [3]:
# Make a DataFrame that is the last expression of a cell render its first rows
# as a table (used by the bare `df` / `daily_df` preview cells).
spark.conf.set(key="spark.sql.repl.eagerEval.enabled", value=True)

## Data loading and cleaning

In [4]:
# Load every daily-report CSV in the directory, using each file's first line as
# the header. No schema inference is requested, so every column is a string and
# numeric columns are cast explicitly later.
df = spark.read.option("header", True).csv("covid_19_daily_reports")

In [5]:
# Total number of raw report rows across all files (all countries, before filtering).
df.count()

1590041

In [6]:
# Preview the raw rows (rendered as a table because eager evaluation is enabled).
df

FIPS,Admin2,Province_State,Country_Region,Last_Update,Lat,Long_,Confirmed,Deaths,Recovered,Active,Combined_Key,Incident_Rate,Case_Fatality_Ratio
,,,Afghanistan,2021-02-27 05:22:28,33.93911,67.709953,55696,2442,49285,3969,Afghanistan,143.0731404659654,4.384515943694341
,,,Albania,2021-02-27 05:22:28,41.1533,20.1683,105229,1756,68007,35466,Albania,3656.5779414830768,1.6687415066188978
,,,Algeria,2021-02-27 05:22:28,28.0339,1.6596,112805,2977,77842,31986,Algeria,257.2458766830244,2.639067417224414
,,,Andorra,2021-02-27 05:22:28,42.5063,1.5218,10822,110,10394,318,Andorra,14006.341810651653,1.0164479763444834
,,,Angola,2021-02-27 05:22:28,-11.2027,17.8739,20759,504,19307,948,Angola,63.16202375030837,2.42786261380606
,,,Antigua and Barbuda,2021-02-27 05:22:28,17.0608,-61.7964,701,14,271,416,Antigua and Barbuda,715.832039866024,1.9971469329529243
,,,Argentina,2021-02-27 05:22:28,-38.4161,-63.6167,2098728,51887,1892834,154007,Argentina,4643.637391165993,2.4723070354995977
,,,Armenia,2021-02-27 05:22:28,40.0691,45.0382,171510,3183,163165,5162,Armenia,5787.93304882436,1.855868462480322
,,Australian Capita...,Australia,2021-02-27 05:22:28,-35.4735,149.0124,118,3,115,0,Australian Capita...,27.563653352020555,2.542372881355932
,,New South Wales,Australia,2021-02-27 05:22:28,-33.8688,151.2093,5172,54,0,5118,"New South Wales, ...",63.71027346637103,1.0440835266821349


In [7]:
# Keep only rows whose Country_Region is the United States.
df = df.filter(F.col("Country_Region") == "US")

In [8]:
# Number of rows remaining after restricting to Country_Region == "US".
df.count()

1328965

In [9]:
# Input files are named /path/to/file/MM-dd-yyyy.csv, so the report date is
# recovered from the file path itself: the 10 characters starting 14
# characters from the end of the path ("MM-dd-yyyy" followed by ".csv").

# Province_State values reported under the US that are not one of the 50
# states (D.C., territories and other non-state entries). Excluding all of
# them is what allows Check #2 to expect exactly 50 rows per date.
NON_STATE_REGIONS = [
    "Recovered", "District of Columbia", "Grand Princess",
    "Puerto Rico", "American Samoa", "Northern Mariana Islands",
    "Guam", "Virgin Islands", "Diamond Princess", "Wuhan Evacuee",
]

daily_df = df.select(
    F.col("Province_State").alias("state"),
    F.to_date(F.substring(F.input_file_name(), -14, 10), "MM-dd-yyyy").alias("date"),
    F.col("Confirmed").cast("integer").alias("confirmed"),
    F.col("Deaths").cast("integer").alias("deaths"),
).where(~F.col("state").isin(NON_STATE_REGIONS))

# A report can contain several rows per state (e.g. one per Admin2 area);
# sum them into a single state total per date.
daily_df = daily_df.groupBy("state", "date").agg(
    F.sum("confirmed").alias("confirmed"),
    F.sum("deaths").alias("deaths"),
)

# The reported numbers are cumulative: subtract the previous date's total to
# get daily new counts. For each state's earliest date lag() returns the
# default 0, so that row still holds the full cumulative total. Nothing here
# clamps negatives, so a downward revision in the source totals shows up as a
# negative daily value.
w = Window().partitionBy("state").orderBy("date")
daily_df = daily_df.select(
    "state", "date",
    (F.col("confirmed") - F.lag("confirmed", default=0).over(w)).alias("confirmed"),
    (F.col("deaths") - F.lag("deaths", default=0).over(w)).alias("deaths"),
)

In [None]:
# Preview of the per-state daily new confirmed cases and deaths.
daily_df

## Data validation

### Check #1: data does not contain null values

In [None]:
# Rows where any output column is null, e.g. a report date that failed to
# parse from the file name. There must be none.
data_validation_df1 = daily_df.where(
    F.col("state").isNull()
    | F.col("date").isNull()
    | F.col("confirmed").isNull()
    | F.col("deaths").isNull()
)
null_row_count = data_validation_df1.count()
# Raise explicitly (instead of `assert`) so the check is never stripped by
# `python -O` and the failure says how many rows are affected.
if null_row_count != 0:
    raise ValueError(
        f"Check #1 failed: {null_row_count} row(s) in daily_df contain null values"
    )

### Check #2: each day has data for all 50 states

In [None]:
# daily_df has one row per (state, date), so the row count per date is the
# number of states reported that day; every date must have exactly 50.
data_validation_df2 = daily_df.groupBy("date").count()
data_validation_df2 = data_validation_df2.where(F.col("count") != 50)
bad_dates = data_validation_df2.orderBy("date").collect()
# Raise explicitly (instead of `assert`) so the failure names the offending
# dates. Row.count is a tuple method, so the column is read with ["count"].
if bad_dates:
    sample = ", ".join(f"{row['date']} ({row['count']} states)" for row in bad_dates[:5])
    raise ValueError(
        f"Check #2 failed: {len(bad_dates)} date(s) do not have exactly 50 states, e.g. {sample}"
    )

## Data transformation and preparation

### Hexadecimal color codes used to "bucket" aggregated data

In [643]:
# Ten map colors, ordered from least severe (c1) to most severe (c10).
c1, c2, c3, c4, c5, c6, c7, c8, c9, c10 = (
    "#FFFED1",  # light yellow
    "#FFE98F",
    "#FFDF8F",
    "#FDC979",
    "#FFAA61",
    "#FF9161",
    "#FF7149",
    "#E55C36",
    "#D84E27",
    "#CB2921",  # deep red
)

In [644]:
# Upper end of the color scale: the 95th percentile (not the true maximum) of
# daily counts across all states and dates, presumably so a few extreme days
# do not squeeze every other value into the lowest buckets.
# Both percentiles come from a single aggregation, i.e. one Spark job.
percentiles = daily_df.agg(
    F.expr("percentile(confirmed, 0.95)").alias("confirmed"),
    F.expr("percentile(deaths, 0.95)").alias("deaths"),
).first()
max_confirmed = percentiles.confirmed
max_deaths = percentiles.deaths

In [645]:
# Width of each of the ten color buckets. Clamped to at least 1: a percentile
# below 10 would otherwise give a zero-width interval, which sends every
# positive value straight to the most severe color.
confirmed_bucket_interval = max(max_confirmed // 10, 1)
deaths_bucket_interval = max(max_deaths // 10, 1)

### 7-day and 30-day rollups

In [646]:
# Sliding 7-day average of the daily counts per state, one window per day.
# Each output row is labelled with the FIRST day of its window, and averages
# are truncated to integers by the cast. The source data begins on
# March 22, 2020, but window() also emits windows that start before that date
# (they still overlap the first days), so those rows are filtered out.
seven_day_windows = F.window("date", windowDuration="7 day", slideDuration="1 day")
seven_day_df = (
    daily_df
    .groupBy("state", seven_day_windows)
    .agg(
        F.avg("confirmed").cast("integer").alias("confirmed"),
        F.avg("deaths").cast("integer").alias("deaths"),
    )
    .withColumn("date", F.col("window").start.cast("date"))
    .drop("window")
    .where(F.col("date") >= F.lit(datetime(2020, 3, 22)))
)

In [647]:
# Preview of the 7-day rolling averages: one row per state and window-start date.
seven_day_df

state,confirmed,deaths,date
Hawaii,4,0,2020-06-04
Hawaii,222,1,2020-08-18
Minnesota,771,6,2020-08-31
Minnesota,974,7,2021-03-10
Ohio,510,43,2020-04-25
Arkansas,154,2,2021-04-04
Oregon,761,3,2021-04-30
Texas,4164,130,2020-08-28
Texas,4149,126,2020-08-31
Texas,3676,132,2021-03-16


In [648]:
# Sliding 30-day average of the daily counts per state, one window per day.
# Rows are labelled with the first day of their window; windows starting
# before the first report date (March 22, 2020) are dropped.
thirty_day_windows = F.window("date", windowDuration="30 day", slideDuration="1 day")
thirty_day_df = (
    daily_df
    .groupBy("state", thirty_day_windows)
    .agg(
        F.avg("confirmed").cast("integer").alias("confirmed"),
        F.avg("deaths").cast("integer").alias("deaths"),
    )
    .withColumn("date", (F.col("window").start).cast("date"))
    .drop("window")
    .where(F.col("date") >= F.lit(datetime(2020, 3, 22)))
)

In [649]:
# Preview of the 30-day rolling averages: one row per state and window-start date.
thirty_day_df

state,confirmed,deaths,date
Utah,405,3,2021-03-24
Hawaii,68,0,2021-02-24
Minnesota,1552,10,2021-04-14
Ohio,1259,20,2020-07-04
Ohio,1027,21,2020-08-16
Ohio,1016,18,2020-08-30
Ohio,4171,49,2020-10-18
Ohio,6482,80,2020-10-29
Arkansas,1200,28,2021-01-22
Arkansas,449,6,2021-02-21


In [650]:
# Split data into 10 groups indicating the severity of cases on a given date.

def assign_color(df, metric_col_name, new_column_name, interval):
    """
    Add a column holding the hex color of the severity bucket for a metric.

    Values are split into ten bands of width ``interval``: values
    <= interval get c1, values <= 2 * interval get c2, ..., values
    <= 9 * interval get c9, and anything larger gets c10. Values at or below
    zero (e.g. negative daily corrections) land in c1. Rows whose metric is
    null also fall through to c10, because a comparison with null is never
    true.

    Parameters:
    -----------
    df: DataFrame
    metric_col_name: the name of the metric column to bucket
    new_column_name: the name of the newly created color column
    interval: the width of each bucket

    Returns:
    --------
    DataFrame: ``df`` with ``new_column_name`` added
    """
    metric = F.col(metric_col_name)
    lower_bucket_colors = [c1, c2, c3, c4, c5, c6, c7, c8, c9]

    # Build WHEN metric <= interval * k THEN color_k for k = 1..9; the first
    # matching branch wins, so each branch only sees values above the previous
    # threshold.
    color_expr = F.when(metric <= interval, lower_bucket_colors[0])
    for bucket, color in enumerate(lower_bucket_colors[1:], start=2):
        color_expr = color_expr.when(metric <= interval * bucket, color)

    return df.withColumn(new_column_name, color_expr.otherwise(c10))


def assign_all_colors(df):
    """
    Add both severity color columns to a DataFrame.

    "confirmed_color" is derived from the "confirmed" column using
    confirmed_bucket_interval, then "deaths_color" from the "deaths" column
    using deaths_bucket_interval.

    Parameters:
    -----------
    df: DataFrame with "confirmed" and "deaths" columns
    """
    with_confirmed_color = assign_color(
        df, "confirmed", "confirmed_color", confirmed_bucket_interval
    )
    return assign_color(
        with_confirmed_color, "deaths", "deaths_color", deaths_bucket_interval
    )


def prepare_for_export(df):
    """
    Reshape a metrics DataFrame into the per-date JSON layout the web app parses.

    Color columns are added first. Every column except "date" is then
    serialized into one JSON object string per row, and rows are grouped by
    date, so each output row maps a date to the list of JSON records for all
    states on that date. Output rows are ordered by date; the order of the
    records inside each list is not guaranteed (collect_list).

    Parameters:
    -----------
    df: either daily_df, seven_day_df, or thirty_day_df

    Returns:
    --------
    DataFrame with columns "date" and "values" (an array of JSON strings)
    """
    df = assign_all_colors(df)

    # Everything except the grouping key goes into the per-state JSON record.
    non_date_cols = [name for name in df.columns if name != "date"]

    records = df.select(
        "date",
        F.to_json(F.struct(*non_date_cols)).alias("values"),
    )

    return (
        records
        .groupBy("date")
        .agg(F.collect_list("values").alias("values"))
        .orderBy("date")
    )


def assign_colors_and_write_to_json(df, output_path, mode="errorifexists"):
    """
    Add color columns, reshape for export, and write the result as JSON.

    The data is coalesced to a single partition so Spark writes one JSON part
    file inside ``output_path`` (a directory).

    Parameters:
    -----------
    df: DataFrame (daily_df, seven_day_df, or thirty_day_df)
    output_path: the output directory Spark writes the JSON part file into
    mode: Spark save mode. The default, "errorifexists" (Spark's own default),
        fails if output_path already exists; pass "overwrite" to re-run the
        export in place.
    """
    export_df = prepare_for_export(df)
    export_df.coalesce(1).write.mode(mode).json(output_path)

In [651]:
# Export each granularity to its own output directory, in this order.
exports = [
    (daily_df, "covid_19_daily"),
    (seven_day_df, "covid_19_seven_day"),
    (thirty_day_df, "covid_19_thirty_day"),
]
for export_df, export_path in exports:
    assign_colors_and_write_to_json(export_df, export_path)