In [1]:
# import plotly.express as px
from plotly.subplots import make_subplots
import plotly.graph_objects as go
import numpy as np
import pandas as pd
import datetime
import itertools

In [2]:
# Tile-grid layout; merged below on `code` and read for its `row`, `col` and `name` columns.
wi_grid = pd.read_csv(filepath_or_buffer="../data/mygrid.csv")

# Per-county R0 estimates (alternative input kept for reference: "../data/df_gg_ms.csv").
r0_county = pd.read_csv(filepath_or_buffer="../data/df_gg.csv")

In [12]:
def risk_text_mapping(row):
    """Turn a row's ``'Risk (Interval)'`` label into tooltip-friendly text.

    Labels look like ``"<a>-<b>"`` (presumably the risk level of the lower and
    upper credible-interval bounds). The three identical pairs collapse to a
    single word (``"Low-Low"`` -> ``"Low"``, likewise Medium/High); every other
    label has each ``-`` replaced by ``" to "`` (``"Low-High"`` -> ``"Low to High"``).

    Parameters
    ----------
    row : mapping
        Typically a DataFrame row passed by ``DataFrame.apply(..., axis=1)``;
        must contain the key ``'Risk (Interval)'``.

    Returns
    -------
    str, or the original value unchanged when it is not a string (e.g. NaN for
    a missing label) -- previously ``.replace`` raised ``AttributeError`` there
    and aborted the whole ``apply``.
    """
    label = row['Risk (Interval)']
    if not isinstance(label, str):
        # Let missing values propagate as NaN, like the other hover-text columns.
        return label
    if label in ('Low-Low', 'Medium-Medium', 'High-High'):
        return label.split('-')[0]
    return label.replace("-", " to ")



In [13]:
r0_county_grid = pd.merge(r0_county, wi_grid, how="left", on="code")

# --- Hover-tooltip text columns ---
# `Date` holds epoch milliseconds (NaN where missing). Plotly draws numeric values on a
# date axis as UTC, so the tooltip date is formatted in UTC too; the previous
# datetime.fromtimestamp() used the machine's local time zone, which can shift the
# label to the neighbouring day. Missing dates stay NaN (NaT -> NaN in strftime).
r0_county_grid["nice_date"] = (
    pd.to_datetime(r0_county_grid["Date"], unit="ms").dt.strftime("%b %d, %Y")
)
r0_county_grid['Risk Text'] = r0_county_grid.apply(risk_text_mapping, axis=1)

_fmt_count = '{:,.0f}'.format   # 1234.0 -> "1,234"
_fmt_r0 = '{:,.2f}'.format      # 1.2345 -> "1.23"
# Shared pieces of both tooltips; a NaN county propagates NaN into the whole text.
_county_header = "<b>" + r0_county_grid["county"] + " County</b> <br>"
_cases_text = "Cases: " + r0_county_grid["cases"].map(_fmt_count)

# Tooltip for the per-county background rectangle.
r0_county_grid["bg_hover_text"] = (
    _county_header + r0_county_grid["Risk Text"] + " Risk<br>" + _cases_text
)
# Per-point tooltip for the R0 line: date, mean R0 and its 95% interval, cases.
r0_county_grid["r0_line_hover_text"] = (
    _county_header
    + r0_county_grid["nice_date"]
    + "<br>R<sub>0</sub>: " + r0_county_grid["Mean R0"].map(_fmt_r0)
    + " (95% CI: " + r0_county_grid["Quantile.0.025(R)"].map(_fmt_r0)
    + ", " + r0_county_grid["Quantile.0.975(R)"].map(_fmt_r0)
    + ")<br>" + _cases_text
)

In [6]:
# Distinct tile-grid coordinates in ascending order, as Python scalars
# (``tolist`` converts the numpy values before de-duplication).
rows, cols = (sorted(set(wi_grid[axis].tolist())) for axis in ("row", "col"))

In [27]:
# --- Limits for the per-county background rectangles (shared by every subplot) ---
# NOTE(review): min_r = 0 is not representable on the log y-axis configured below;
# confirm plotly still fills the rectangle down to the axis bottom as intended.
min_r = 0
max_r = np.max(r0_county_grid["Quantile.0.975(R)"])
max_t = np.max(r0_county_grid["Date"])
min_t = np.min(r0_county_grid["Date"])

# Cut-offs on the R0 credible-interval bounds used by the coloured line masks
# ("Low" < R0_LOW_CUTOFF <= "Medium" < R0_HIGH_CUTOFF <= "High").
R0_LOW_CUTOFF = 1
R0_HIGH_CUTOFF = 1.2
# Counties whose maximum case count is below this get a grey "Low Case Counts" background.
LOW_CASE_THRESHOLD = 10
LOW_CASE_FILL = "rgba(77,77,77, 0.1)"

# Background colour: the first "Risk (Interval)" label in this list that occurs in a
# county's data wins.
# NOTE(review): the orange entry is keyed on "High-Medium"; if the data names that
# level "Medium-High" (cf. the "med-high" line colour below) those counties get
# DEFAULT_FILL -- confirm the label against the input CSV.
BG_FILL_BY_RISK = [
    ("High-High", "rgba(255, 25, 28, 0.1)"),
    ("High-Medium", "rgba(253,174,97, 0.1)"),
    ("Medium-Medium", "rgba(255,255,191, 0.1)"),
    ("Low-Medium", "rgba(166,217,106, 0.1)"),
    ("Low-Low", "rgba(26,150,65, 0.1)"),
    ("Low-High", "rgba(254, 231, 163, 0.1)"),
]
# Used when none of the labels above occur. The former if/elif chain had no else,
# so bg_fill silently kept the previous county's colour (or raised NameError on
# the first county).
DEFAULT_FILL = "rgba(0,0,0,0)"


def _background_fill(subset_df):
    """Return the background fill colour for one county's subplot."""
    if np.max(subset_df["cases"]) < LOW_CASE_THRESHOLD:
        return LOW_CASE_FILL
    labels = set(subset_df["Risk (Interval)"].tolist())
    for label, fill in BG_FILL_BY_RISK:
        if label in labels:
            return fill
    return DEFAULT_FILL


def _background_hover_text(subset_df):
    """Return the tooltip shown when hovering one county's background rectangle."""
    if np.max(subset_df["cases"]) < LOW_CASE_THRESHOLD:
        return ("<b>" + np.unique(subset_df["county"])[0] + " County</b> <br>"
                + "Low Case Counts" + "<br>Cases: " + str(np.max(subset_df["cases"])))
    # NOTE(review): np.unique sorts, so this is the lexicographically first tooltip;
    # if case counts vary by date it is not necessarily the latest one -- confirm intent.
    return np.unique(subset_df["bg_hover_text"])[0]


def _r0_line_trace(subset_df, y, color):
    """Return one Mean-R0 line trace (y possibly NaN-masked) with the per-point tooltip."""
    return go.Scatter(
        x=subset_df["Date"],
        y=y,
        text=subset_df["r0_line_hover_text"],
        mode="lines",
        name="",
        line_color=color,
        hovertemplate="%{text}",
    )


def _add_county_traces(fig, subset_df, row, col):
    """Draw one county's subplot: background, R0 = 1 reference, 95% band, coloured R0 line."""
    mean_r0 = subset_df["Mean R0"]
    lower = subset_df["Quantile.0.025(R)"]
    upper = subset_df["Quantile.0.975(R)"]

    # Background rectangle coloured by risk level; hovering it shows the county summary.
    fig.add_trace(
        go.Scatter(
            x=[min_t, max_t, max_t, min_t, min_t],
            y=[min_r, min_r, max_r, max_r, min_r],
            fill="toself",
            fillcolor=_background_fill(subset_df),
            hoveron="fills",  # select where hover is active
            line_color="rgba(0,0,0,0)",
            text=_background_hover_text(subset_df),
            hoverinfo="text+x+y",
        ),
        row=row,
        col=col,
    )

    # Horizontal reference line at R0 = 1.
    fig.add_trace(
        go.Scatter(
            x=[min_t, max_t],
            y=[1, 1],
            mode="lines",
            name="",
            line_color="#CCCCCC",
            hoverinfo="skip",
            line=dict(width=0.5),
        ),
        row=row,
        col=col,
    )

    # 95% credibility band: the upper trace fills down to the preceding (lower) trace.
    fig.add_trace(
        go.Scatter(
            x=subset_df["Date"],
            y=lower,
            fill=None,
            mode="lines",
            line_color="#CCCCCC",
            hoverinfo="skip",
        ),
        row=row,
        col=col,
    )
    fig.add_trace(
        go.Scatter(
            x=subset_df["Date"],
            y=upper,
            fill="tonexty",
            mode="lines",
            line_color="#CCCCCC",
            fillcolor="#CCCCCC",
            hoverinfo="skip",
        ),
        row=row,
        col=col,
    )

    # R0 line: a full yellow base line, then colour segments laid on top in this order.
    # Each segment is Mean R0 with the points outside its category masked to NaN.
    line_layers = [
        # base (medium-medium shows through here)
        (mean_r0, "#ffffbf"),
        # low-med: shown where lower < 1
        (mean_r0.mask(lower >= R0_LOW_CUTOFF, np.nan), "#a6d96a"),
        # med-high: shown where lower >= 1 AND upper >= 1.2 (red covers high-high later).
        # Previously masked with AND instead of OR ("not masking enough"), which also
        # painted medium-medium and low-high points orange.
        (mean_r0.mask((lower < R0_LOW_CUTOFF) | (upper < R0_HIGH_CUTOFF), np.nan), "#fdae61"),
        # low-low: shown where upper < 1
        (mean_r0.mask(upper >= R0_LOW_CUTOFF, np.nan), "#1a9641"),
        # high-high: shown where lower >= 1.2
        (mean_r0.mask(lower < R0_HIGH_CUTOFF, np.nan), "#d7191c"),
        # low-high (very wide interval): shown where lower < 1 AND upper >= 1.2
        (mean_r0.mask((lower >= R0_LOW_CUTOFF) | (upper < R0_HIGH_CUTOFF), np.nan),
         "rgba(255, 255, 191, 0.7)"),
    ]
    for y, color in line_layers:
        fig.add_trace(_r0_line_trace(subset_df, y, color), row=row, col=col)


# Every grid position needs a subplot title (blank where there is no county), listed
# row-major over rows 1..max(rows) x cols 1..max(cols) to match make_subplots' layout
# (building it only from the observed rows/cols shifted titles if the grid had a gap).
grid_positions = itertools.product(range(1, max(rows) + 1), range(1, max(cols) + 1))
full_grid = pd.DataFrame(list(grid_positions), columns=["row", "col"])
full_grid = pd.merge(full_grid, wi_grid, how="left", on=["row", "col"])
full_grid = full_grid.fillna("")

# Create the figure with one subplot per grid position.
# (shared_xaxes="all", shared_yaxes="all" can be added here, but it looks bad.)
fig = make_subplots(
    rows=max(rows),
    cols=max(cols),
    subplot_titles=[y.replace(" County", "") for y in full_grid["name"].tolist()],
    x_title="Date",
    y_title="Instantaneous R<sub>0</sub>",
)
# Shrink every annotation (subplot titles and the shared axis titles).
fig.for_each_annotation(lambda a: a.update(font=dict(size=7)))

for row, col in itertools.product(rows, cols):
    subset_df = r0_county_grid[(r0_county_grid["row"] == row) & (r0_county_grid["col"] == col)]
    # Skip empty grid cells and counties whose Risk is "Low-Case Counts".
    if subset_df.empty or "Low-Case Counts" in subset_df["Risk"].tolist():
        continue
    _add_county_traces(fig, subset_df, row, col)

fig.for_each_xaxis(lambda a: a.update(
    type="date",
    tickformat="%m/%d",  # originally "%b %d"
    tickfont=dict(size=7),
    range=[min_t, max_t],
))
fig.for_each_yaxis(lambda a: a.update(
    tickfont=dict(size=7),
    range=[np.log10(0.01), np.log10(max_r)],  # log-axis ranges are given in log10 units
    type="log",
))

# TODO: tie the axes together when zooming/panning; fig.update_xaxes(matches='x', overwrite=True)
# did not work (see https://plotly.com/python/axes/#synchronizing-axes-in-subplots-with-matches).

fig.update_layout(
    showlegend=False,
    plot_bgcolor="rgba(255,255,255,1)",
    title="Instantaneous R<sub>0</sub> Infection Rates of COVID-19 in Wisconsin",
    width=1200,
    height=1200,
)

fig.show()
fig.write_html("../results/instantaneous_r0_plotly.html")
