# *CS 431 Final Project*

# COVID-19 Dashboard

Name: Namitra Kalicharran

Student ID: 20674483

In [1]:
%matplotlib inline

In [2]:
from pyspark import SparkContext, SparkConf
from pyspark.sql import SparkSession
from pyspark.sql.functions import lit, date_format, col, when, unix_timestamp, from_unixtime, sum as sql_sum
from functools import reduce
import matplotlib.pyplot as plt
from scipy.signal import savgol_filter
import numpy as np
import seaborn as sns
import random
import os
import git
import shutil

from ipywidgets import interact, interactive, fixed, interact_manual
import ipywidgets as widgets

In [3]:
spark = SparkSession.builder.appName("YourTest").master("local[2]").config('spark.ui.port', random.randrange(4000,5000)).getOrCreate()

In [4]:
git_url = "https://github.com/CSSEGISandData/COVID-19"
git.Repo.clone_from(git_url, "COVID_data")

GitCommandError: Cmd('git') failed due to: exit code(128)
  cmdline: git clone -v https://github.com/CSSEGISandData/COVID-19 COVID_data
  stderr: 'fatal: destination path 'COVID_data' already exists and is not an empty directory.
'

# Time Series Data

In [5]:
path = 'COVID_data/csse_covid_19_data/csse_covid_19_time_series/'

confirmed = spark.read.csv(path+'time_series_covid19_confirmed_global.csv', sep=",", inferSchema=True, header=True).cache()
deaths = spark.read.csv(path+'time_series_covid19_deaths_global.csv', sep=",", inferSchema=True, header=True).cache()
recovered = spark.read.csv(path+'time_series_covid19_recovered_global.csv', sep=",", inferSchema=True, header=True).cache()

country_info = spark.read.csv("COVID_data/csse_covid_19_data", sep=',', inferSchema=True, header=True).cache()

# Check if each table has the same number of columns
print(confirmed.columns)
date_cols = confirmed.columns[4:]
norm_dates = [f'norm_{date}' for date in date_cols]

['Province/State', 'Country/Region', 'Lat', 'Long', '1/22/20', '1/23/20', '1/24/20', '1/25/20', '1/26/20', '1/27/20', '1/28/20', '1/29/20', '1/30/20', '1/31/20', '2/1/20', '2/2/20', '2/3/20', '2/4/20', '2/5/20', '2/6/20', '2/7/20', '2/8/20', '2/9/20', '2/10/20', '2/11/20', '2/12/20', '2/13/20', '2/14/20', '2/15/20', '2/16/20', '2/17/20', '2/18/20', '2/19/20', '2/20/20', '2/21/20', '2/22/20', '2/23/20', '2/24/20', '2/25/20', '2/26/20', '2/27/20', '2/28/20', '2/29/20', '3/1/20', '3/2/20', '3/3/20', '3/4/20', '3/5/20', '3/6/20', '3/7/20', '3/8/20', '3/9/20', '3/10/20', '3/11/20', '3/12/20', '3/13/20', '3/14/20', '3/15/20', '3/16/20', '3/17/20', '3/18/20', '3/19/20', '3/20/20', '3/21/20', '3/22/20', '3/23/20', '3/24/20', '3/25/20', '3/26/20', '3/27/20', '3/28/20', '3/29/20', '3/30/20', '3/31/20', '4/1/20', '4/2/20', '4/3/20', '4/4/20', '4/5/20', '4/6/20', '4/7/20', '4/8/20', '4/9/20', '4/10/20', '4/11/20', '4/12/20', '4/13/20', '4/14/20', '4/15/20', '4/16/20', '4/17/20', '4/18/20', '4/19/2

In [6]:
# Aggregate the rows over each country
confirmed = confirmed.select('Country/Region', *date_cols).groupBy('Country/Region').agg(*[sql_sum(date).alias(date) for date in date_cols])
deaths = deaths.select('Country/Region', *date_cols).groupBy('Country/Region').agg(*[sql_sum(date).alias(date) for date in date_cols])
recovered = recovered.select('Country/Region', *date_cols).groupBy('Country/Region').agg(*[sql_sum(date).alias(date) for date in date_cols])

# Make normalized versions of each dataframe
total = confirmed.union(deaths).union(recovered).groupBy('Country/Region').agg(*[sql_sum(date).alias(date) for date in date_cols])
    

# Get each countries total population
country_info = country_info.select('Country_Region', 'Population').groupBy('Country_Region').agg(sql_sum('Population').alias('Total Population'))

In [7]:
countries = [element[0] for element in confirmed.select('Country/Region').distinct().collect()]

In [8]:
@interact(country=countries, scale=['linear', 'log'])
def plotConfirmedCountry(country, scale):
    cases_ts = list(confirmed.select(*date_cols).where(confirmed['Country/Region'] == country).collect()[0])
    death_count = list(deaths.select(*date_cols).where(deaths['Country/Region'] == country).collect()[0])
    recov_count = list(recovered.select(*date_cols).where(recovered['Country/Region'] == country).collect()[0])
    
    # Plot the data for a country
    fig, ax = plt.subplots(figsize=(15, 10))
    plt.title(f"Confirmed Cases ({country})")
    plt.xlabel('Date')
    plt.ylabel(f'# of Cases\n{scale} scale')
    plt.stackplot(date_cols, *[death_count, recov_count, cases_ts], labels=['Deaths', 'Recovered', 'Confirmed'])
    
    # Axis scales
    plt.xticks(date_cols, rotation='vertical')
    plt.yscale(scale)
    plt.legend(loc='upper left')
    


interactive(children=(Dropdown(description='country', options=('Chad', 'Paraguay', 'Russia', 'Yemen', 'Senegal…

In [9]:
@interact(country=countries)
def plotRelAreaCountry(country):    
    # normalize by column for country
    con = confirmed.select(*date_cols).where(confirmed['Country/Region'] == country).withColumn('Type', lit('Confirmed'))
    ded = deaths.select(*date_cols).where(deaths['Country/Region'] == country).withColumn('Type', lit('Deaths'))
    rec = recovered.select(*date_cols).where(recovered['Country/Region'] == country).withColumn('Type', lit('Recovered'))

    result = con.union(ded).union(rec)
    col_sums = list(result.groupBy().sum().collect()[0])

    for i, date in enumerate(date_cols):
        result = result.withColumn(f'norm_{date}', result[date] / col_sums[i]).drop(date)
    result = result.na.fill(0)
    
    # Organize into
    x = date_cols
    
    cases_ts = list(result.select(*norm_dates).where(result['Type'] == 'Confirmed').collect()[0])
    death_count = list(result.select(*norm_dates).where(result['Type'] == 'Deaths').collect()[0])
    recov_count = list(result.select(*norm_dates).where(result['Type'] == 'Recovered').collect()[0])
    
    y = [death_count, recov_count, cases_ts]
    labels = ['Deaths', 'Recovered', 'Confirmed Cases']
    
    # Plot the data for a country
    fig, ax = plt.subplots(figsize=(15, 10))
    plt.title(f"Case Proportions ({country})")
    plt.xlabel('Date')
    plt.ylabel(f'Percentage')
    plt.stackplot(x, y, labels=labels)
    plt.legend(loc='upper left')
    
    # Axis scales
    plt.xticks(date_cols, rotation='vertical')

interactive(children=(Dropdown(description='country', options=('Chad', 'Paraguay', 'Russia', 'Yemen', 'Senegal…

In [10]:
pre_select = ['Canada', 'China', 'Korea, South', 'Chad']
def moving_average(a, n=7):
    ret = np.cumsum(a)
    ret[n:] = ret[n:] - ret[:-n]
    return ret[n - 1:] / n

w_country = widgets.SelectMultiple(
    options=sorted(countries),
    rows = 10,
    value=pre_select,
)


def update_growth_rate(country_list):
    fig, ax = plt.subplots(figsize=(30, 15))
    for country in country_list:
        spark_query = confirmed.select(*date_cols).where(confirmed['Country/Region'] == country).collect()[0]
        x = moving_average(np.array(spark_query))
        y = np.gradient(x)
        plt.plot(x, y, label=country)
    
    plt.xlabel('Confirmed Cases')
    plt.ylabel('Growth Rate')
    plt.yscale('log')
    plt.xscale('log')
    plt.legend(loc='upper left')
    plt.show()
    
interactive(update_growth_rate, country_list=w_country)

interactive(children=(SelectMultiple(description='country_list', index=(32, 36, 90, 34), options=('Afghanistan…