In [1]:
# import plotly.express as px
from plotly.subplots import make_subplots
import plotly.graph_objects as go
import numpy as np
import pandas as pd
import datetime
import itertools

In [2]:
# Load the tile-grid layout (joined on "code"; provides row/col/name) and the
# per-county R0 estimates, in that order.
wi_grid, r0_county = (
    pd.read_csv(csv_path)
    for csv_path in ("../data/mygrid.csv", "../data/df_gg_ms.csv")
)

In [3]:
# Attach each county's tile position (row/col/name) from the grid layout.
r0_county_grid = pd.merge(r0_county, wi_grid, how="left", on="code")

# --- Hover-tooltip text ---
# "Date" holds epoch milliseconds. Plotly draws numeric dates as UTC, so the
# tooltip date is formatted in UTC as well. datetime.fromtimestamp used the
# machine's local zone and could show a different day than the x axis.
# Missing dates become NaT, which strftime turns into NaN (as before).
r0_county_grid["nice_date"] = (
    pd.to_datetime(r0_county_grid["Date"], unit="ms").dt.strftime("%b %d, %Y")
)

fmt_2dp = "{:,.2f}".format
county_label = "<b>" + r0_county_grid["county"] + " County</b> <br>"

# Background tooltip: county name plus its risk level.
r0_county_grid["bg_hover_text"] = county_label + r0_county_grid["Risk"] + " Risk"
# R0 line tooltip: county, date, mean R0 and its 95% credibility interval.
r0_county_grid["r0_line_hover_text"] = (
    county_label
    + r0_county_grid["nice_date"]
    + "<br>R0: " + r0_county_grid["Mean(R)"].map(fmt_2dp)
    + "<br>95% CI: (" + r0_county_grid["Quantile.0.025(R)"].map(fmt_2dp)
    + ", " + r0_county_grid["Quantile.0.975(R)"].map(fmt_2dp)
    + ")"
)
# TODO: add cases once they are in the data set

In [4]:
# Distinct grid rows and columns that contain a county, in ascending order.
rows, cols = (sorted({*wi_grid[axis].tolist()}) for axis in ("row", "col"))

In [9]:
# Small-multiples figure: one subplot per county, placed on the tile grid from
# wi_grid. Each drawn tile has a risk-colored background, a reference line at
# R0 = 1, the 95% credibility band, and the mean-R0 line.

# Risk levels in priority order (highest first) -> (background fill, line color).
# If a county's rows mention several levels, the first one listed here wins.
RISK_STYLES = {
    "High": ("rgba(255, 0, 0, 0.1)", "rgb(255, 0, 0)"),
    "Medium": ("rgba(255, 255, 0, 0.1)", "rgb(255, 255, 0)"),
    "Low": ("rgba(0, 255, 0, 0.1)", "rgb(0, 255, 0)"),
}
BAND_COLOR = "#CCCCCC"
SMALL_FONT = dict(size=7)

# Limits shared by the background rectangles and the axes. The y axis is log
# scaled, so its lower limit must be strictly positive: 0 cannot be placed on a
# log axis. The rectangles start at the same floor as the axis range.
min_r = 0.01
max_r = np.max(r0_county_grid["Quantile.0.975(R)"])
max_t = np.max(r0_county_grid["Date"])
min_t = np.min(r0_county_grid["Date"])

# Subplot titles are assigned in row-major order over the full 1..max grid.
# Enumerate every cell, including rows/cols with no county, so that each title
# lands on its own tile.
n_rows, n_cols = max(rows), max(cols)
full_grid = pd.DataFrame(
    list(itertools.product(range(1, n_rows + 1), range(1, n_cols + 1))),
    columns=["row", "col"],
)
full_grid = pd.merge(full_grid, wi_grid, how="left", on=["row", "col"])
full_grid = full_grid.fillna("")

# shared_xaxes/shared_yaxes=True would also link the panels but looks bad.
# The axes are linked with `matches` further down instead.
fig = make_subplots(
    rows=n_rows,
    cols=n_cols,
    subplot_titles=[name.replace(" County", "") for name in full_grid["name"].tolist()],
)

for row, col in itertools.product(rows, cols):
    in_tile = (r0_county_grid["row"] == row) & (r0_county_grid["col"] == col)
    subset_df = r0_county_grid[in_tile]
    risks = subset_df["Risk"].tolist()
    if subset_df.empty or "Low-Case Counts" in risks:
        continue

    # Pick the highest risk level present. Skip tiles with no recognized level
    # rather than silently reusing the previous tile's colors.
    risk_level = next((level for level in RISK_STYLES if level in risks), None)
    if risk_level is None:
        continue
    bg_fill, line_col = RISK_STYLES[risk_level]
    # Take the tooltip from a row carrying the chosen level, so the
    # background's text always matches its color.
    risk_text = subset_df.loc[subset_df["Risk"] == risk_level, "bg_hover_text"].iloc[0]

    # Risk-colored background rectangle spanning the full plotted area.
    fig.add_trace(
        go.Scatter(
            x=[min_t, max_t, max_t, min_t, min_t],
            y=[min_r, min_r, max_r, max_r, min_r],
            fill="toself",
            fillcolor=bg_fill,
            hoveron="fills",  # hover is active on the filled area, not the outline
            line_color="rgba(0,0,0,0)",
            text=risk_text,
            hoverinfo="text+x+y",
        ),
        row=row,
        col=col,
    )

    # Horizontal reference line at R0 = 1.
    fig.add_trace(
        go.Scatter(
            x=[min_t, max_t],
            y=[1, 1],
            mode="lines",
            name="",
            line_color=BAND_COLOR,
        ),
        row=row,
        col=col,
    )

    # 95% credibility band: draw the lower bound first, then fill the upper
    # bound down to it ("tonexty" fills to the previous trace).
    fig.add_trace(
        go.Scatter(
            x=subset_df["Date"],
            y=subset_df["Quantile.0.025(R)"],
            fill=None,
            mode="lines",
            line_color=BAND_COLOR,
            hoverinfo="skip",
        ),
        row=row,
        col=col,
    )
    fig.add_trace(
        go.Scatter(
            x=subset_df["Date"],
            y=subset_df["Quantile.0.975(R)"],
            fill="tonexty",
            mode="lines",
            line_color=BAND_COLOR,
            fillcolor=BAND_COLOR,
            hoverinfo="skip",
        ),
        row=row,
        col=col,
    )

    # Mean R0 line, colored by risk level, using the prebuilt tooltip text.
    fig.add_trace(
        go.Scatter(
            x=subset_df["Date"],
            y=subset_df["Mean(R)"],
            text=subset_df["r0_line_hover_text"],
            mode="lines",
            name="",
            line_color=line_col,
            hovertemplate="%{text}",
        ),
        row=row,
        col=col,
    )

# Shrink the subplot titles (annotations) and tick labels to fit the dense grid.
fig.for_each_annotation(lambda a: a.update(font=SMALL_FONT))
# `matches` ties every panel's axes together when zooming or panning.
fig.update_xaxes(
    type="date",
    tickformat="%m/%d",
    tickfont=SMALL_FONT,
    range=[min_t, max_t],
    matches="x",
)
fig.update_yaxes(
    type="log",
    tickfont=SMALL_FONT,
    range=[np.log10(min_r), np.log10(max_r)],  # log-axis ranges are given in log10 units
    matches="y",
)
fig.update_layout(
    showlegend=False,
    plot_bgcolor="rgba(255,255,255,1)",
    width=1200,
    height=1200,
)

fig.show()
