In [None]:
# Auto-reload frequently changed files
%load_ext autoreload
%autoreload 2
%aimport utils

import pandas as pd
import numpy as np
import altair as alt
from altair import datum
from ipywidgets import interact
from os.path import join
from datetime import date, datetime, timedelta

from constants import COLUMNS
from utils import read_latest_daily_counts_df, preprocess_daily_counts_df_for_vis, add_aligned_time_steps_column, apply_theme

# Dataset

In [None]:
# Load the latest daily-counts data and run the visualization preprocessing.
df = read_latest_daily_counts_df()
df = preprocess_daily_counts_df_for_vis(df)

# Distinct site identifiers found in the data; used for the site dropdowns.
SITE_IDS = df[COLUMNS.SITE_ID].unique().tolist()

# Column labels folded by the wide-form charts (TODO: do we really need this here?)
new_positive_cases = "New positive cases"
patients_in_icu = "Patients in ICU"
new_deaths = "New deaths"

# Observed date range. Series.min()/.max() are vectorized and skip missing
# values, unlike the Python built-ins which iterate element by element.
start_date = df[COLUMNS.DATE].min()
end_date = df[COLUMNS.DATE].max()
duration = (end_date - start_date).days

# One day of padding on each side, formatted as "YYYY-MM-DD" strings so they
# can be passed directly as the x-axis scale domain.
start_date_margin = (start_date - timedelta(days=1)).strftime("%Y-%m-%d")
end_date_margin = (end_date + timedelta(days=1)).strftime("%Y-%m-%d")

# Preview only the first rows instead of dumping the whole frame.
df.head()

# Visualizations

In [None]:
# Fixed label -> color assignment shared by the category color encoding.
# Insertion order of the mapping defines the scale's domain/range order.
_category_colors = {
    new_deaths: "#CA2026",
    new_positive_cases: "#377FB8",
    patients_in_icu: "#60B75D",
}
color_scale = alt.Scale(
    domain=list(_category_colors.keys()),
    range=list(_category_colors.values()),
)

def daily_counts_chart(SiteID):
    """Build the interactive daily-counts chart from the long-form ``df``.

    Draws ``sum(num_patients)`` per date as a line plus circle markers,
    colored by ``category`` using the shared ``color_scale``. Mark size is
    driven by a mouseover selection on the ``category`` field.

    Parameters
    ----------
    SiteID : str
        A site identifier to filter the data on, or ``"All"`` to apply no
        site filter.

    Returns
    -------
    The layered (line + circle) chart after ``apply_theme``, with the
    mouseover selection attached and ``.interactive()`` enabled.
    """
    category_field = "category"
    count_field = "num_patients"

    # Shorthand strings reused by both the encodings and the tooltip
    day_shorthand = f"{COLUMNS.DATE}:T"
    category_shorthand = f"{category_field}:N"
    total_shorthand = f"sum({count_field}):Q"

    # Mouseover effect
    hover = alt.selection_single(on="mouseover", fields=[category_field])

    x_encoding = alt.X(
        day_shorthand,
        axis=alt.Axis(tickCount=7),
        scale=alt.Scale(domain=[start_date_margin, end_date_margin]),
        title=None,
    )
    y_encoding = alt.Y(
        total_shorthand,
        axis=alt.Axis(tickCount=5),
        scale=alt.Scale(domain=[-2, 110]),
        title="Number of patients",
    )

    shared = alt.Chart(df).encode(
        x=x_encoding,
        y=y_encoding,
        color=alt.Color(category_shorthand, title="", scale=color_scale),
        size=alt.condition(~hover, alt.value(4), alt.value(6)),
        tooltip=[day_shorthand, category_shorthand, total_shorthand],
    )

    # Restrict to a single site unless every site was requested
    if SiteID != "All":
        site_filter = alt.FieldEqualPredicate(field=COLUMNS.SITE_ID, equal=SiteID)
        shared = shared.transform_filter(site_filter)

    line_layer = shared.mark_line(size=4).properties(
        width=750,
        height=400,
        title="Number of Positive Cases, Patients in ICU, and Deaths",
    )

    # Circles override the shared size encoding with larger values
    circle_layer = shared.mark_circle().encode(
        size=alt.condition(~hover, alt.value(100), alt.value(150))
    )

    themed = apply_theme(line_layer + circle_layer)
    return themed.add_selection(hover).interactive()

# Render the chart with a dropdown offering "All" followed by every site id.
interact(daily_counts_chart, SiteID=["All", *SITE_IDS])

## Total Daily Counts

In [None]:
def daily_counts_chart(SiteID):
    """Build the total-daily-counts chart with linear reference guidelines.

    Folds the three count columns (``new_positive_cases``,
    ``patients_in_icu``, ``new_deaths``) of ``df`` into ``category`` /
    ``value`` pairs and plots ``sum(value)`` per date as a line plus circle
    markers per category. Two dashed light-gray guidelines are layered
    underneath.

    Parameters
    ----------
    SiteID : str
        A site identifier to filter the data on, or ``"All"`` to apply no
        site filter.

    Returns
    -------
    The layered chart (guidelines, line, circles) after ``apply_theme``,
    with ``.interactive()`` enabled.
    """
    category_field = "category"
    value_field = "value"

    day_shorthand = f"{COLUMNS.DATE}:T"
    category_shorthand = f"{category_field}:N"
    total_shorthand = f"sum({value_field}):Q"

    # Mouseover effect (nearest=True does not appear to work here)
    hover = alt.selection_single(on="mouseover", fields=[category_field])

    def guideline(final_value):
        """Dashed line from 0 at ``start_date`` to ``final_value`` at ``end_date_margin``."""
        endpoints = pd.DataFrame({
            "date": [start_date, end_date_margin],
            "value": [0, final_value],
        })
        return alt.Chart(endpoints).mark_line(
            color="lightgray", strokeDash=[10, 1], size=3
        ).encode(
            x="date:T",
            y="value:Q",
        )

    # Fold three quantitative fields, making three rows from one original row.
    folded = alt.Chart(df).transform_fold(
        fold=[new_positive_cases, patients_in_icu, new_deaths],
        as_=[category_field, value_field],
    )

    shared = folded.encode(
        x=alt.X(
            day_shorthand,
            axis=alt.Axis(tickCount=7),
            scale=alt.Scale(domain=[start_date_margin, end_date_margin]),
            title=None,
        ),
        y=alt.Y(
            total_shorthand,
            axis=alt.Axis(tickCount=5),
            scale=alt.Scale(domain=[-2, 110]),
            title="Number of Patients",
        ),
        color=alt.Color(
            category_shorthand,
            scale=alt.Scale(range=["#5C63A2", "#EC7176", "#F4AB32"]),
        ),
        # Mouse hover visual effect
        size=alt.condition(~hover, alt.value(4), alt.value(6)),
        tooltip=[day_shorthand, category_shorthand, total_shorthand],
    )

    # Restrict to a single site unless every site was requested
    if SiteID != "All":
        shared = shared.transform_filter(
            alt.FieldEqualPredicate(field=COLUMNS.SITE_ID, equal=SiteID)
        )

    line_layer = shared.mark_line(size=4).properties(
        width=750,
        height=400,
        title="Number of Positive Cases, Patients in ICU, and Deaths",
    ).add_selection(hover)

    circle_layer = shared.mark_circle(size=70).encode(
        size=alt.condition(~hover, alt.value(70), alt.value(100))
    )

    # Guidelines ending at `duration` and `duration * 10`; further slopes can
    # be layered the same way, e.g. guideline(duration * 20).
    guide_1x = guideline(duration)
    guide_10x = guideline(duration * 10)

    return apply_theme(guide_1x + guide_10x + line_layer + circle_layer).interactive()

# Render the chart with a dropdown offering "All" followed by every site id.
interact(daily_counts_chart, SiteID=["All", *SITE_IDS])

# Deprecated code below

## Since N Positive Cases

In [None]:
# Deprecated code

def daily_counts_chart(SiteID):
    """Build the daily-counts chart aligned on the ``timestep`` field.

    Folds the three count columns (``new_positive_cases``,
    ``patients_in_icu``, ``new_deaths``) into ``category`` / ``value`` pairs,
    keeps rows with ``0 <= timestep <= 8`` and plots ``sum(value)`` per
    timestep as a line plus circle markers per category.

    Parameters
    ----------
    SiteID : str
        A site identifier to filter the data on, or ``"All"`` to apply no
        site filter.

    Returns
    -------
    The layered (line + circle) chart after ``apply_theme``, with
    ``.interactive()`` enabled.
    """
    # NOTE(review): `df_aligned` is not defined in any visible cell, so this
    # function raises NameError on a fresh kernel. Presumably it was built
    # with `add_aligned_time_steps_column` (imported but unused) — confirm
    # before reviving this cell.
    category = "category"
    value = "value"

    # Mouseover effect (nearest=True does not appear to work here)
    mouseover = alt.selection_single(on="mouseover", fields=[category])

    base = alt.Chart(df_aligned).transform_fold(
        # Fold three quantitative fields, making three rows from one original row.
        fold=[new_positive_cases, patients_in_icu, new_deaths],
        as_=[category, value]
    ).transform_filter(
        # Only show timesteps 0-8: beyond that the totals drop because fewer
        # institutions contribute, which would be misleading.
        (0 <= alt.datum.timestep) & (alt.datum.timestep <= 8)
    ).encode(
        x=alt.X(
            "timestep", axis=alt.Axis(tickCount=7),
            title="Number of days since 1st positive cases",
            scale=alt.Scale(domain=[0, 8])
        ),
        y=alt.Y(
            f"sum({value}):Q", axis=alt.Axis(tickCount=5),
            title="Number of Patients"
        ),
        color=alt.Color(f"{category}:N", scale=alt.Scale(range=["#5C63A2", "#EC7176", "#F4AB32"])),
        # Mouse hover visual effect
        size=alt.condition(~mouseover, alt.value(4), alt.value(6)),
        tooltip=["timestep", f"{category}:N", f"sum({value}):Q"],
    )

    # Restrict to a single site unless every site was requested
    if SiteID != "All":
        base = base.transform_filter(
            alt.FieldEqualPredicate(field=COLUMNS.SITE_ID, equal=SiteID)
        )

    line = base.mark_line(size=4).properties(
        width=750, height=400, title="Number of Positive Cases, Patients in ICU, and Deaths"
    ).add_selection(mouseover)

    circle = base.mark_circle(size=70).encode(
        size=alt.condition(~mouseover, alt.value(70), alt.value(100))
    )

    return apply_theme(line + circle).interactive()

# Render the chart with a dropdown offering "All" followed by every site id.
interact(daily_counts_chart, SiteID=["All", *SITE_IDS])

## By Individual Sites

In [None]:
def daily_counts_site_chart(SiteID):
    """Build the per-site chart of new positive cases over time.

    Plots the ``new_positive_cases`` column of ``df`` against the date, one
    line (with circle markers) per site. With ``"All"`` the sites are colored
    by site id; otherwise every site is drawn in light gray and the requested
    site is overlaid in steel blue with a text label at its latest date.

    Parameters
    ----------
    SiteID : str
        A site identifier to highlight, or ``"All"`` to color every site.

    Returns
    -------
    The layered chart after ``apply_theme``, with ``.interactive()`` enabled.
    """
    # Field names come from COLUMNS, as in the other cells. The bare names
    # `date` / `siteid` used previously resolved to the imported
    # `datetime.date` class and to an undefined name respectively.
    date_field = COLUMNS.DATE
    site_field = COLUMNS.SITE_ID
    site_shorthand = f"{site_field}:N"

    base = alt.Chart(df).encode(
        x=alt.X(
            f"{date_field}:T",
            axis=alt.Axis(tickCount=7),
            title="",
            scale=alt.Scale(domain=[start_date_margin, end_date_margin])
        ),
        y=alt.Y(
            f"{new_positive_cases}:Q",
            axis=alt.Axis(tickCount=5),
            title=None
        )
    ).properties(width=700, height=500, title="Number of Positive Cases")

    # Color encoding
    if SiteID == "All":
        base = base.encode(color=alt.Color(site_shorthand, title="Site"))
    else:
        # De-emphasize every site; the selected one is overlaid below
        base = base.encode(
            color=alt.Color(site_shorthand, scale=alt.Scale(range=["lightgray"]), legend=None)
        )

    line = base.mark_line(size=4)
    circle = line.mark_circle(size=80)

    # Highlight layers: only contain data when an individual site is selected
    hl_line = base.mark_line(size=4).encode(
        color=alt.value("steelblue")
    ).transform_filter(
        alt.FieldEqualPredicate(field=site_field, equal=SiteID)
    )
    hl_circle = hl_line.mark_circle(size=80)

    # Label the highlighted line at its most recent date (rank 1 when sorted
    # by date descending)
    text = hl_line.mark_text(align='left', fontWeight=400, dx=10, dy=1).encode(
        text=site_shorthand,
        size=alt.value(20),
    ).transform_window(
        sort=[alt.SortField(date_field, order="descending")],
        rank=f"rank({date_field})"
    ).transform_filter(alt.datum.rank == 1)

    return apply_theme(line + text + hl_line + circle + hl_circle).interactive()

# Render the chart with a dropdown offering "All" followed by every site id.
interact(daily_counts_site_chart, SiteID=["All", *SITE_IDS])

## Since N Positive Cases

In [None]:
def daily_counts_site_chart(SiteID):
    """Build the per-site chart of new positive cases by aligned timestep.

    Plots the ``new_positive_cases`` column against the ``timestep`` field
    (rows with ``timestep >= 0`` only), one line (with circle markers) per
    site. With ``"All"`` the sites are colored by site id; otherwise every
    site is drawn in light gray and the requested site is overlaid in steel
    blue with a text label at its latest date.

    Parameters
    ----------
    SiteID : str
        A site identifier to highlight, or ``"All"`` to color every site.

    Returns
    -------
    The layered chart after ``apply_theme``, with ``.interactive()`` enabled.
    """
    # NOTE(review): this reads `timestep` from `df`; no visible cell adds that
    # column to `df`. Presumably the frame produced by
    # `add_aligned_time_steps_column` was intended — confirm.

    # Field names come from COLUMNS, as in the other cells. The bare names
    # `date` / `siteid` used previously resolved to the imported
    # `datetime.date` class and to an undefined name respectively.
    date_field = COLUMNS.DATE
    site_field = COLUMNS.SITE_ID
    site_shorthand = f"{site_field}:N"

    base = alt.Chart(df).encode(
        x=alt.X(
            "timestep",
            axis=alt.Axis(tickCount=7),
            title="Number of days since 1st positive cases",
        ),
        y=alt.Y(
            f"{new_positive_cases}:Q",
            axis=alt.Axis(tickCount=5),
            title=None
        )
    ).transform_filter(
        alt.datum.timestep >= 0
    ).properties(width=700, height=500, title="Number of Positive Cases")

    # Color encoding
    if SiteID == "All":
        base = base.encode(color=alt.Color(site_shorthand, title="Site"))
    else:
        # De-emphasize every site; the selected one is overlaid below
        base = base.encode(
            color=alt.Color(site_shorthand, scale=alt.Scale(range=["lightgray"]), legend=None)
        )

    line = base.mark_line(size=4)
    circle = line.mark_circle(size=80)

    # Highlight layers: only contain data when an individual site is selected
    hl_line = base.mark_line(size=4).encode(
        color=alt.value("steelblue")
    ).transform_filter(
        alt.FieldEqualPredicate(field=site_field, equal=SiteID)
    )
    hl_circle = hl_line.mark_circle(size=80)

    # Label the highlighted line at its most recent date (rank 1 when sorted
    # by date descending)
    text = hl_line.mark_text(align='left', fontWeight=400, dx=10, dy=1).encode(
        text=site_shorthand,
        size=alt.value(20),
    ).transform_window(
        sort=[alt.SortField(date_field, order="descending")],
        rank=f"rank({date_field})"
    ).transform_filter(alt.datum.rank == 1)

    return apply_theme(line + text + hl_line + circle + hl_circle).interactive()

# Render the chart with a dropdown offering "All" followed by every site id.
interact(daily_counts_site_chart, SiteID=["All", *SITE_IDS])