In [None]:
%load_ext autoreload
%autoreload 2
import numpy as np
import scipy as sp
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.dates as mdates
import seaborn as sns
from viz import viz_interactive
from modeling import exponential_modeling
from bokeh.plotting import figure, show, output_notebook, output_file, save
from functions import merge_data
import load_data

from plotly.offline import init_notebook_mode, iplot

# Load the county-level table, ordered so the counties with the most deaths
# come first, and pick out the project's list of key covariate columns.
df = load_data.load_county_level().sort_values('tot_deaths', ascending=False)
important_vars = load_data.important_keys(df)
print(df.keys())

## How many deaths and cases are there?

In [None]:
summary_cols = ['tot_deaths', 'tot_cases', 'StateName', 'CountyName']
df[summary_cols].head(10)  # first ten rows in df's current order

In [None]:
# Only counties with at least one case AND one death, so log10 is finite.
d = df[(df['tot_deaths'] > 0) & (df['tot_cases'] > 0)]
R, C = 1, 2
panels = [
    ('tot_cases', 'Number of cases (log-scale)'),
    ('tot_deaths', 'Number of deaths (log-scale)'),
]
# One histogram of log10(counts) per panel; the y axis is also log-scaled
# because a few counties dominate the upper tail.
for panel_idx, (col, label) in enumerate(panels, start=1):
    plt.subplot(R, C, panel_idx)
    plt.hist(np.log10(d[col]))
    plt.xlabel(label)
    plt.ylabel('Number of counties')
    plt.yscale('log')
plt.tight_layout()
plt.show()

In [None]:
# Distribution of the per-county hospital count.
hospital_counts = df['#Hospitals']
plt.hist(hospital_counts, bins=100)
plt.xlabel('#Hospitals (per county)')
plt.show()

# Correlations

In [None]:
sns.set(style="white")

# Columns for the correlation matrix: the "important" covariates minus any
# column whose name contains one of the excluded substrings (sex/age-stratified
# population and mortality breakdowns, total population), plus both outcomes.
excluded_substrings = ('PopMale', 'PopFmle', 'MortalityAge', 'PopTotal')
corr_vars = [k for k in important_vars
             if not any(s in k for s in excluded_substrings)]
d = df[corr_vars + ['tot_deaths', 'tot_cases']]

# Pairwise Spearman (rank) correlation between all selected columns.
corrs = d.corr(method='spearman')

# Diverging colormap, centred at 0 below, so positive and negative
# correlations are visually distinct.
cmap = sns.diverging_palette(10, 220, as_cmap=True)

# Full (unmasked) square heatmap of the correlation matrix.
plt.figure(figsize=(8, 8), dpi=300)
sns.heatmap(corrs, cmap=cmap, vmax=1, center=0, square=True, linewidths=.5,
            cbar_kws={"shrink": .5})
plt.tight_layout()
plt.savefig('results/correlations_heatmap.png')
plt.show()

In [None]:
# Spearman correlation of every other selected variable with tot_deaths,
# sorted ascending so the most positively correlated variables are drawn at
# the top of the horizontal bar chart.
# The outcome itself is dropped by label rather than by assuming its
# self-correlation (1.0) sorts last, which fails if any correlation is NaN;
# label-based selection also avoids pandas' deprecated positional fallback
# when indexing a string-indexed Series with integer positions.
corrs_row = corrs.loc['tot_deaths'].drop('tot_deaths').sort_values()
plt.figure(dpi=300, figsize=(6, 5))
plt.barh(corrs_row.index, corrs_row.values)
plt.xlabel('Correlation (spearman) with # deaths')
plt.tight_layout()
plt.savefig('results/correlations.png')
plt.show()

# Interactive plots

In [None]:
# Variables shown in the interactive map tooltips: the important covariates
# minus the sex- and age-stratified population/mortality breakdowns.
ks = [k for k in important_vars
      if all(tag not in k for tag in ('PopMale', 'PopFmle', 'MortalityAge'))]

In [None]:
# One interactive Bokeh county map per state, shaded by total cases; each map
# is shown and also written to results/<state>.html.
for state in ('NY', 'WA', 'CA'):
    in_state = df["StateNameAbbreviation"] == state
    state_df = df[in_state]

    p = viz_interactive.plot_counties(
        state_df,
        variable_to_distribute='tot_cases',
        variables_to_display=ks,
        state=state,
        logcolor=False,
    )

    # output_file sets the destination used by the subsequent show/save calls.
    output_file(f"results/{state}.html", mode='inline')
    show(p)
    save(p)

In [None]:
from urllib.request import urlopen
import json
with urlopen('https://raw.githubusercontent.com/plotly/datasets/master/geojson-counties-fips.json') as response:
    counties = json.load(response)

# See:
# https://plotly.com/python/county-choropleth/
# https://plotly.com/python/choropleth-maps/
# https://plotly.com/python/sliders/
# https://plotly.com/python/reference/#choropleth
# https://plotly.com/python-api-reference/
# TODO: allow filtering by state

import plotly.figure_factory as ff
import plotly.graph_objs as go

# Prediction targets passed to the model; one choropleth trace is drawn per
# target day and a slider switches between them.
target_days = np.array([1, 2, 3, 4])

df_preds = exponential_modeling.estimate_deaths(
    df, mode='predict_future', target_day=target_days
)

# NOTE(review): each entry of this column is indexed positionally by target
# day below; assumed to hold one prediction per element of target_days, in the
# same order — confirm against exponential_modeling.estimate_deaths.
predictions = df_preds['predicted_deaths_exponential']

# Shared colour ceiling for every trace so colours are comparable while
# sliding. Take the max over all target days rather than assuming the last
# day always holds the largest prediction.
n_days = len(target_days)
zmax = predictions.apply(lambda x: max(x[i] for i in range(n_days))).max()

scl = [[0.0, '#ffffff'],[0.2, '#ff9999'],[0.4, '#ff4d4d'], 
       [0.6, '#ff1a1a'],[0.8, '#cc0000'],[1.0, '#4d0000']] # reds

# NOTE(review): confirm these values match the geojson feature ids
# (e.g. zero-padded 5-character FIPS strings); mismatched ids are not drawn.
fips = df_preds['SecondaryEntityOfFile'].tolist()

fig = go.Figure()

for day_idx in range(n_days):
    values = predictions.apply(lambda x: x[day_idx]).tolist()

    fig.add_trace(
        go.Choropleth(
            colorscale=scl,
            # Only the first target day is visible initially; otherwise the
            # map renders empty until the slider is moved.
            visible=(day_idx == 0),
            z=values,
            geojson=counties,
            locations=fips,
            zmin=0,
            zmax=zmax
        )
    )

    # TODO
    # endpts = list(np.linspace(1, 12, len(colorscale) - 1))

    # TODO: text for mouse hover

# One slider step per trace: make trace i visible and hide all the others.
steps = []
for i, day in enumerate(target_days):
    visibility = [False] * len(fig.data)
    visibility[i] = True
    steps.append(dict(
        method="restyle",
        args=["visible", visibility],
        label=str(day),
    ))

sliders = [dict(
    active=0,  # must be a valid index into `steps` (one step per target day)
    currentvalue={"prefix": "Target day: "},
    pad={"t": 50},
    steps=steps
)]

fig.update_layout(
    title_text='Predicted COVID-19 Deaths',
    geo = dict(
        scope='usa',
        projection=go.layout.geo.Projection(type = 'albers usa'),
        showlakes=False, # lakes
        lakecolor='rgb(255, 255, 255)'),
    sliders=sliders
)

fig.show()

## Political leaning by county (Democrat-to-Republican ratio)

In [None]:
# Interactive Bokeh county maps for a few states, shaded by the
# dem_to_rep_ratio column (displayed only, not saved to disk).
for state in ('NY', 'WA', 'CA'):
    state_df = df.loc[df["StateNameAbbreviation"] == state]

    p = viz_interactive.plot_counties(
        state_df,
        variable_to_distribute='dem_to_rep_ratio',
        variables_to_display=ks,
        state=state,
        logcolor=False,
    )
    show(p)