In [None]:
# Auto-reload frequently changed files
%load_ext autoreload
%autoreload 2
%aimport utils

import pandas as pd
import numpy as np
import altair as alt
from ipywidgets import interact
from os.path import join

from constants import COLUMNS, ALL_AGE_COLUMNS
from utils import read_latest_demographics_df, apply_theme

# Dataset

In [None]:
# Load the demographics table and record the site ids before any rows are filtered out.
df = read_latest_demographics_df()
SITE_IDS = df[COLUMNS.SITE_ID].unique().tolist()

# Columns
siteid = COLUMNS.SITE_ID
sex = COLUMNS.SEX
total_patients = COLUMNS.TOTAL_PATIENTS

# Use the consistent capitalization.
# The vectorized `.str` accessor leaves missing values as NaN, whereas calling
# `x.capitalize()` element by element raises AttributeError on a NaN entry.
df[sex] = df[sex].str.capitalize()

# remove aggregate rows and columns, then reshape wide to long:
# every remaining column other than site and sex becomes one (age_group, num_patients) row.
not_all_filter = df[sex] != "All"
df = (
    df[not_all_filter]
    .drop(columns=[total_patients])
    .melt(id_vars=[siteid, sex])
    .rename(columns={"variable": "age_group", "value": "num_patients"})
)

# TODO: Use readable category names here
# "age_0to2" to "0-2"


df

# Visualization

In [None]:
# Explicit colour for each sex category used by the chart below.
color_scale = alt.Scale(
    domain=["Male", "Female", "Other"],
    range=["#377FB8", "#CA2026", "gray"],
)
    
def demographics_chart(SiteID, Normalize):
    """Build a bar chart of summed patient counts per age group, coloured by sex.

    Parameters
    ----------
    SiteID : str
        Site to display. The value "All Sites" skips the per-site filter.
    Normalize : str
        "Yes" replaces the y channel with a normalized stack; any other
        value keeps the raw sums.

    Returns
    -------
    The chart after `apply_theme` and `.interactive()` have been applied.
    """
    # Encoding channels, named up front so the chart assembly below stays short.
    age_channel = alt.X('age_group:N', title="Age group", sort="x")
    count_channel = alt.Y("sum(num_patients):Q", title="Number of patients", axis=alt.Axis(tickCount=5))
    sex_channel = alt.Color("sex:N", title=None, scale=color_scale)
    heading = "COVID-19 patients (" + SiteID + ")"

    bars = (
        alt.Chart(df)
        .mark_bar()
        .encode(
            x=age_channel,
            y=count_channel,
            color=sex_channel,
            tooltip=[siteid, sex, "sum(num_patients)"],
        )
        .properties(title=heading, width=500, height=300)
    )

    # Restrict the data to one site unless the aggregate option was chosen.
    if SiteID != "All Sites":
        site_matches = alt.FieldEqualPredicate(field=siteid, equal=SiteID)
        bars = bars.transform_filter(site_matches)

    # Re-encoding y swaps the whole channel definition, including its title.
    if Normalize == "Yes":
        fraction_channel = alt.Y("sum(num_patients):Q", title="Fraction of Patients", stack="normalize")
        bars = bars.encode(y=fraction_channel)

    themed = apply_theme(bars)
    return themed.interactive()

# Each list argument supplies the selectable values for that chart parameter.
interact(
    demographics_chart,
    SiteID=["All Sites"] + SITE_IDS,
    Normalize=["No", "Yes"],
)

In [None]:
# Explicit colour for each sex category; rebinds `color_scale` with a different palette.
color_scale = alt.Scale(
    domain=["Male", "Female", "Other"],
    range=["#0072B2", "#D55E00", "gray"],
)

def apply_theme(base):
    """Return `base` with the notebook's axis and title font settings applied.

    Defining this here rebinds the `apply_theme` name that the first cell
    imports from `utils`. The title is anchored in the middle.
    """
    axis_style = dict(
        labelFontSize=14,
        labelFontWeight=300,
        titleFontSize=18,
        titleFontWeight=300,
    )
    title_style = dict(fontSize=18, fontWeight=400, anchor="middle")

    with_axis = base.configure_axis(**axis_style)
    return with_axis.configure_title(**title_style)
    
def demographics_chart(SiteID, Normalize):
    """Build a bar chart of summed patient counts per age group, coloured by sex.

    Parameters
    ----------
    SiteID : str
        Site to display. The value "All Sites" skips the per-site filter.
    Normalize : str
        "Yes" replaces the y channel with a normalized stack; any other
        value keeps the raw sums.

    Returns
    -------
    The chart after `apply_theme` and `.interactive()` have been applied.
    """
    # Encoding channels, named up front so the chart assembly below stays short.
    age_channel = alt.X('age_group:N', title="Age Group", sort="x")
    count_channel = alt.Y("sum(num_patients):Q", title="Number of Patients", axis=alt.Axis(tickCount=5))
    sex_channel = alt.Color("sex:N", title="Sex", scale=color_scale)
    heading = "COVID-19 Patients (" + SiteID + ")"

    bars = (
        alt.Chart(df)
        .mark_bar()
        .encode(
            x=age_channel,
            y=count_channel,
            color=sex_channel,
            tooltip=[siteid, sex, "sum(num_patients)"],
        )
        .properties(title=heading, width=500, height=300)
    )

    # Restrict the data to one site unless the aggregate option was chosen.
    if SiteID != "All Sites":
        site_matches = alt.FieldEqualPredicate(field=siteid, equal=SiteID)
        bars = bars.transform_filter(site_matches)

    # Re-encoding y swaps the whole channel definition, including its title.
    if Normalize == "Yes":
        fraction_channel = alt.Y("sum(num_patients):Q", title="Fraction of Patients", stack="normalize")
        bars = bars.encode(y=fraction_channel)

    themed = apply_theme(bars)
    return themed.interactive()

# Each list argument supplies the selectable values for that chart parameter.
interact(
    demographics_chart,
    SiteID=["All Sites"] + SITE_IDS,
    Normalize=["No", "Yes"],
)

In [None]:
# Selection on the colour encoding, used by both chart builders in this cell.
click = alt.selection_multi(encodings=["color"])

def apply_theme(base):
    """Return `base` with the notebook's axis and title font settings applied.

    Defining this here rebinds the `apply_theme` name that the first cell
    imports from `utils`. The title is anchored at the start.
    """
    axis_style = dict(
        labelFontSize=14,
        labelFontWeight=300,
        titleFontSize=18,
        titleFontWeight=300,
    )
    title_style = dict(fontSize=18, fontWeight=400, anchor="start")

    with_axis = base.configure_axis(**axis_style)
    return with_axis.configure_title(**title_style)
    
def demographics_chart():
    """Build bars of summed patient counts per site, one facet column per age group.

    The data is filtered by the module-level `click` selection.
    Returns the chart with `.interactive()` applied; no theme is applied here.
    """
    site_channel = alt.X('siteid:N', title="Site", sort="x", axis=None)
    count_channel = alt.Y("sum(num_patients):Q", title="Number of Patients", axis=alt.Axis(tickCount=5))
    # Facet columns follow the order given by ALL_AGE_COLUMNS.
    age_facet = alt.Column('age_group:O', title="", sort=ALL_AGE_COLUMNS)
    site_color = alt.Color("siteid:N", title="Site", scale=alt.Scale(scheme="category20"), legend=None)

    bars = alt.Chart(df).mark_bar().encode(
        x=site_channel,
        y=count_channel,
        column=age_facet,
        color=site_color,
        tooltip=[siteid, "sum(num_patients)"],
    )
    bars = bars.properties(width=70, height=300)
    bars = bars.transform_filter(click)

    return bars.interactive()

def site_chart():
    """Build a column of circles, one y position per site, that carries the `click` selection.

    Circles matching the selection use the per-site colour scale; the rest are gray.
    Returns the chart with `.interactive()` applied.
    """
    site_color = alt.Color("siteid:N", title="Site", scale=alt.Scale(scheme="category20"), legend=None)
    highlight = alt.condition(click, site_color, alt.value('gray'))

    dots = alt.Chart(df).mark_circle(size=100).encode(
        y=alt.Y('siteid:N', title="Site", sort="y"),
        color=highlight,
    )
    # Attach the selection first, then set the layout properties, as two separate steps.
    dots = dots.properties(selection=click)
    dots = dots.properties(title="", height=300)

    return dots.interactive()



# Site selector on the left, faceted bars on the right; theme and title apply to the pair.
apply_theme(
    site_chart() | demographics_chart()
).properties(title="COVID-19 Patients by Age Group and Site")

This version removes the narrow white strokes around the bars by not separating the bars by `siteid` when "All Sites" is selected.

In [None]:
# Explicit colour for each sex category used by the chart below.
sex_color_scale = alt.Scale(
    domain=["Male", "Female", "Other"],
    range=["#0072B2", "#D55E00", "gray"],
)

def apply_theme(base):
    """Return `base` with the notebook's axis and title font settings applied.

    Defining this here rebinds the `apply_theme` name that the first cell
    imports from `utils`. The title is anchored in the middle.
    """
    axis_style = dict(
        labelFontSize=14,
        labelFontWeight=300,
        titleFontSize=18,
        titleFontWeight=300,
    )
    title_style = dict(fontSize=18, fontWeight=400, anchor="middle")

    with_axis = base.configure_axis(**axis_style)
    return with_axis.configure_title(**title_style)
    
def demographics_chart(SiteID, Normalize):
    """Build a bar chart of summed patient counts per age group, coloured by sex.

    The site id is added to the tooltip only when a single site is shown.

    Parameters
    ----------
    SiteID : str
        Site to display. The value "All Sites" skips the per-site filter.
    Normalize : str
        Any value other than "no" replaces the y channel with a normalized stack.

    Returns
    -------
    The chart after `apply_theme` and `.interactive()` have been applied.
    """
    # Encoding channels, named up front so the chart assembly below stays short.
    age_channel = alt.X('age_group:N', title="Age Group", sort="x")
    count_channel = alt.Y("sum(num_patients):Q", title="Number of Patients", axis=alt.Axis(tickCount=5))
    sex_channel = alt.Color("sex:N", title="Sex", scale=sex_color_scale)
    heading = "COVID-19 Patients (" + SiteID + ")"

    bars = (
        alt.Chart(df)
        .mark_bar()
        .encode(x=age_channel, y=count_channel, color=sex_channel)
        .properties(title=heading, width=500, height=300)
    )

    if SiteID == "All Sites":
        # Aggregate view: the tooltip omits the site id.
        bars = bars.encode(tooltip=[sex, "sum(num_patients)"])
    else:
        # Single-site view: filter to that site and include it in the tooltip.
        site_matches = alt.FieldEqualPredicate(field=siteid, equal=SiteID)
        bars = bars.transform_filter(site_matches).encode(
            tooltip=[siteid, sex, "sum(num_patients)"]
        )

    # Re-encoding y swaps the whole channel definition, including its title.
    if Normalize != "no":
        fraction_channel = alt.Y("sum(num_patients):Q", title="Fraction of Patients", stack="normalize")
        bars = bars.encode(y=fraction_channel)

    themed = apply_theme(bars)
    return themed.interactive()

# Each list argument supplies the selectable values for that chart parameter.
interact(
    demographics_chart,
    SiteID=["All Sites"] + SITE_IDS,
    Normalize=["yes", "no"],
)