In [None]:
# Auto-reload frequently changed files
%load_ext autoreload
%autoreload 2
%aimport utils

import pandas as pd
import numpy as np
import altair as alt
from altair_saver import save
from os.path import join
from web import for_website

from constants import COLUMNS
from utils import (
    read_combined_demographics_df, read_combined_by_country_demographics_df, read_combined_by_site_demographics_df,
    apply_grouped_bar_theme, apply_pyramid_theme
)

# Required Setup
- All datasets should be placed in `../data/combined` (e.g., `../data/combined/Demographics-Combinedyymmdd.csv`).

In [None]:
"""
Common info that should be defined every time before rendering visualizations
"""
# Site IDs found in the by-site demographics data; used here for the site count in the subtitle.
SITES = read_combined_by_site_demographics_df()[COLUMNS.SITE_ID].unique()

# Titles
NUM_SITES = len(SITES)
DATA_DATE = "2020-04-10"
VIS_DATE = "2020-04-10"
SUBTITLE = f"Data as of {DATA_DATE} | {NUM_SITES} Sites | Plots generated on {VIS_DATE}"

# Where to save visualization *.PNG files
SAVE_DIR = join("..", "output")

# Colors
COMBINED = "All countries"
COMBINED_COLOR = "#444444"

COUNTRIES = ["France", "Germany", "Italy", "Singapore", "USA"]
COUNTRY_COLOR = ["#0072B2", "#E69F00", "#009E73", "#CC79A7", "#D55E00"]
COLOR_BY_COUNTRY = dict(zip(COUNTRIES, COUNTRY_COLOR))

# Site-level colors
# NOTE: from here on SITES is the hard-coded list below, replacing the data-derived
# array assigned above (NUM_SITES keeps the data-derived count).
SITES = [
    'APHP', 'FRBDX',
    "UMM", 'UKER', 'UKFR',
    "HPG23", 'ICSM1', 'ICSM20', 'ICSM5', 'POLIMI',
    "NUH",
    'BCH', 'BIDMC', 'CHOP', 'KUMC', 'MAYOC', 'MGB', 'MUSC', 'UCLA', 'UMICH', 'UPenn', 'UTSW', "UNC"
]
SITES_ANONYMOUS = [f"SITE {rank:02d}" for rank in range(1, len(SITES) + 1)]
SITES_TO_ANONYMOUS = dict(zip(SITES, SITES_ANONYMOUS))

# Country of each site, position-aligned with SITES
SITES_COUNTRY = [
    "France", "France",
    "Germany", "Germany", "Germany",
    "Italy", "Italy", "Italy", "Italy", "Italy",
    "Singapore",
    "USA", "USA", "USA", "USA", "USA", "USA", "USA", "USA", "USA", "USA", "USA", "USA"
]
SITE_COLOR = [COLOR_BY_COUNTRY[country] for country in SITES_COUNTRY]
COLOR_BY_SITE = dict(zip(SITES, SITE_COLOR))

# Same country palette with the combined ("All countries") entry placed first
COUNTRIES_AND_COMBINED = [COMBINED, *COUNTRIES]
COUNTRY_AND_COMBINED_COLOR = [COMBINED_COLOR, *COUNTRY_COLOR]
COLOR_BY_COUNTRY_AND_COMBINED = dict(zip(COUNTRIES_AND_COMBINED, COUNTRY_AND_COMBINED_COLOR))

# Generic 20-color categorical palette
COLOR20 = [
    "#3366cc", "#dc3912", "#ff9900", "#109618", "#990099", "#0099c6",
    "#dd4477", "#66aa00", "#b82e2e", "#316395", "#994499", "#22aa99",
    "#aaaa11", "#6633cc", "#e67300", "#8b0707", "#651067", "#329262", "#5574a6", "#3b3eac"
]

# Data Preprocessing

In [None]:
def preprocess_demo_df(df_dm):
    """
    Reshape a wide demographics table into long form and add a percentage column.

    Parameters
    ----------
    df_dm : pandas.DataFrame
        Wide table that contains the site ID, sex, and the masked/unmasked
        bookkeeping columns dropped below. Every remaining column is treated
        as an age-group count column by the melt.

    Returns
    -------
    pandas.DataFrame
        A new frame (the input is not modified) with one row per
        site / sex / age group: the site ID and sex columns,
        ``COLUMNS.AGE_GROUP`` (the original column name),
        ``COLUMNS.NUM_PATIENTS`` (the cell value) and ``"per_patients"``,
        the share of that row within its site/sex group in percent.

    Raises
    ------
    KeyError
        If any of the columns listed for dropping is missing from ``df_dm``.
    """
    # Everything except site, sex and the per-age-group counts is unused.
    # The masked upper bounds were only needed for error bars, which are no
    # longer drawn, so they are dropped up front instead of being carried
    # through the melt as id columns.
    df_long = df_dm.drop(columns=[
        COLUMNS.UNMASKED_SITES_TOTAL_PATIENTS,
        COLUMNS.UNMASKED_SITES_AGE_0TO2,
        COLUMNS.UNMASKED_SITES_AGE_3TO5,
        COLUMNS.UNMASKED_SITES_AGE_6TO11,
        COLUMNS.UNMASKED_SITES_AGE_12TO17,
        COLUMNS.UNMASKED_SITES_AGE_18TO25,
        COLUMNS.UNMASKED_SITES_AGE_26TO49,
        COLUMNS.UNMASKED_SITES_AGE_50TO69,
        COLUMNS.UNMASKED_SITES_AGE_70TO79,
        COLUMNS.UNMASKED_SITES_AGE_80PLUS,
        COLUMNS.MASKED_SITES_TOTAL_PATIENTS,
        COLUMNS.MASKED_SITES_AGE_0TO2,
        COLUMNS.MASKED_SITES_AGE_3TO5,
        COLUMNS.MASKED_SITES_AGE_6TO11,
        COLUMNS.MASKED_SITES_AGE_12TO17,
        COLUMNS.MASKED_SITES_AGE_18TO25,
        COLUMNS.MASKED_SITES_AGE_26TO49,
        COLUMNS.MASKED_SITES_AGE_50TO69,
        COLUMNS.MASKED_SITES_AGE_70TO79,
        COLUMNS.MASKED_SITES_AGE_80PLUS,
        COLUMNS.MASKED_UPPER_BOUND_TOTAL_PATIENTS,
        COLUMNS.TOTAL_PATIENTS,
        COLUMNS.MASKED_UPPER_BOUND_AGE_0TO2,
        COLUMNS.MASKED_UPPER_BOUND_AGE_3TO5,
        COLUMNS.MASKED_UPPER_BOUND_AGE_6TO11,
        COLUMNS.MASKED_UPPER_BOUND_AGE_12TO17,
        COLUMNS.MASKED_UPPER_BOUND_AGE_18TO25,
        COLUMNS.MASKED_UPPER_BOUND_AGE_26TO49,
        COLUMNS.MASKED_UPPER_BOUND_AGE_50TO69,
        COLUMNS.MASKED_UPPER_BOUND_AGE_70TO79,
        COLUMNS.MASKED_UPPER_BOUND_AGE_80PLUS,
    ])

    # Wide to long: one row per (site, sex, age group)
    df_long = pd.melt(
        df_long,
        id_vars=[COLUMNS.SITE_ID, COLUMNS.SEX],
    ).rename(columns={"variable": COLUMNS.AGE_GROUP, "value": COLUMNS.NUM_PATIENTS})

    # Normalize sex labels (e.g. "male" -> "Male"); the .str accessor leaves
    # missing values untouched instead of raising.
    df_long[COLUMNS.SEX] = df_long[COLUMNS.SEX].str.capitalize()

    # Percentage of each age group within its (site, sex) group. transform()
    # broadcasts the group total back onto every row, so the column is always
    # created, even when the frame has no rows.
    group_total = df_long.groupby(
        [COLUMNS.SITE_ID, COLUMNS.SEX]
    )[COLUMNS.NUM_PATIENTS].transform("sum")
    df_long["per_patients"] = df_long[COLUMNS.NUM_PATIENTS] / group_total * 100

    return df_long

# Load and reshape the per-country and the all-countries demographics
df_dm = preprocess_demo_df(read_combined_by_country_demographics_df())
df_dm_combined = preprocess_demo_df(read_combined_demographics_df())

# Stack both tables into a single frame
df_dm = pd.concat([df_dm, df_dm_combined])

# Show the combined rows under a readable name
df_dm[COLUMNS.SITE_ID] = df_dm[COLUMNS.SITE_ID].replace("Combined", COMBINED)

# Raw age-group keys -> labels shown in the charts (insertion order is youngest to oldest)
readable_age_group = {
    COLUMNS.AGE_0TO2: "0 - 2",
    COLUMNS.AGE_3TO5: "3 - 5",
    COLUMNS.AGE_6TO11: "6 - 11",
    COLUMNS.AGE_12TO17: "12 - 17",
    COLUMNS.AGE_18TO25: "18 - 25",
    COLUMNS.AGE_26TO49: "26 - 49",
    COLUMNS.AGE_50TO69: "50 - 69",
    COLUMNS.AGE_70TO79: "70 - 79",
    COLUMNS.AGE_80PLUS: "80+"
}

# Unknown keys raise KeyError rather than silently becoming NaN
df_dm[COLUMNS.AGE_GROUP] = df_dm[COLUMNS.AGE_GROUP].map(
    lambda raw_group: readable_age_group[raw_group]
)

df_dm

# Visualizations

In [None]:
def demographics(is_percent=False, is_comparison=False):
    """
    Build a grouped bar chart of patients per age group from the global `df_dm`.

    Parameters
    ----------
    is_percent : bool
        If True, plot the "per_patients" column (percentage) instead of the
        patient count and add it to the tooltip.
    is_comparison : bool
        If False, bars are split and colored by sex and a dropdown selects the
        country. If True, bars are split and colored by country, a dropdown
        selects the sex, and the legend acts as a country filter.

    Returns
    -------
    alt.Chart
        Chart faceted into one column per age group, with the selections
        for the chosen mode attached.
    """
    country_dropdown = alt.binding_select(options=COUNTRIES_AND_COMBINED)
    country_selection = alt.selection_single(
        fields=[COLUMNS.SITE_ID],
        bind=country_dropdown,
        name="Country",
        init={COLUMNS.SITE_ID: COUNTRIES_AND_COMBINED[0]}
    )
    sex_dropdown = alt.binding_select(options=["All", "Male", "Female"])
    sex_selection = alt.selection_single(
        fields=[COLUMNS.SEX],
        bind=sex_dropdown,
        name="Sex",
        init={COLUMNS.SEX: "All"}
    )
    legend_selection = alt.selection_multi(fields=[COLUMNS.SITE_ID], bind="legend")

    # Filter
    filtered_chart = alt.Chart(df_dm).transform_filter(
        alt.datum[COLUMNS.SEX] != "Other"
    )

    if is_comparison:
        filtered_chart = filtered_chart.transform_filter(
            sex_selection
        ).transform_filter(
            legend_selection
        )
    else:
        filtered_chart = filtered_chart.transform_filter(
            country_selection
        )

    tooltip = [
        alt.Tooltip(COLUMNS.SITE_ID, title="Country"),
        alt.Tooltip(COLUMNS.SEX, title="Sex"),
        alt.Tooltip(COLUMNS.AGE_GROUP, title="Age group"),
        alt.Tooltip(COLUMNS.NUM_PATIENTS, title="# of patients")
    ]

    y_field = COLUMNS.NUM_PATIENTS
    if is_percent:
        y_field = "per_patients"
        tooltip.append(alt.Tooltip("per_patients", title="% of patients", format=".1f"))

    # Render
    # NOTE(review): rows with sex "All" (by-sex mode) and the combined group
    # (comparison mode) are not filtered out above but are absent from the
    # color domains below - confirm this is intended.
    if is_comparison:
        color_scale = alt.Scale(domain=COUNTRIES, range=COUNTRY_COLOR)
        separate_field = COLUMNS.SITE_ID
    else:
        color_scale = alt.Scale(domain=["Male", "Female"], range=COLOR20[:2])
        separate_field = COLUMNS.SEX
    y_title = "Percentage of patients (%)" if is_percent else "Number of patients"

    bar = filtered_chart.mark_bar().encode(
        x=alt.X(f"{separate_field}:N", title=None, axis=None),
        y=alt.Y(f"{y_field}:Q", title=y_title, axis=alt.Axis(tickCount=5)),
        color=alt.Color(f"{separate_field}:N", title=None, scale=color_scale),
        tooltip=tooltip
    ).properties(width=67, height=400)

    # df_dm stores the readable labels from `readable_age_group` ("0 - 2", ...),
    # so the facet order has to be spelled in those labels; a list of raw
    # "age_0to2"-style keys matches no value and leaves the order unspecified.
    age_group_order = list(readable_age_group.values())
    result_vis = bar.encode(
        column=alt.Column(
            f"{COLUMNS.AGE_GROUP}:O",
            sort=age_group_order,
            header=alt.Header(labelOrient="bottom", title="Age group", titleOrient="bottom")
        )
    )

    if is_comparison:
        result_vis = result_vis.add_selection(
            sex_selection
        ).add_selection(
            legend_selection
        )
    else:
        result_vis = result_vis.add_selection(
            country_selection
        )

    return result_vis

## Demographics by country

In [None]:
# Percentage chart split by sex, with a dropdown to choose the country
demo = apply_grouped_bar_theme(
    demographics(is_percent=True),
    strokeColor="lightgray",
).properties(
    title=dict(
        text="Demographics by sex",
        subtitle=SUBTITLE,
        subtitleColor="gray",
        anchor="start",
        dx=60,
    )
)
demo.display()

for_website(demo, "Demographics", "Demographics by sex")  # TODO: Remove this before deploying notebook
# save(demo, join(SAVE_DIR, f"demographics.png".lower()), scalefactor=2.0) # Uncomment this to save *.png files

## Comparison between countries

In [None]:
# Percentage chart split by country, with a dropdown to choose the sex
demo = apply_grouped_bar_theme(
    demographics(is_percent=True, is_comparison=True),
    strokeColor="lightgray",
).properties(
    title=dict(
        text="Demographics by country",
        subtitle=SUBTITLE,
        subtitleColor="gray",
        anchor="start",
        dx=60,
    )
)
demo.display()

for_website(demo, "Demographics", "Demographics by country")  # TODO: Remove this before deploying notebook
# save(demo, join(SAVE_DIR, f"demographics.png".lower()), scalefactor=2.0) # Uncomment this to save *.png files

## Pyramid Plot (WIP)
Known issues: https://github.com/vega/vega-lite/issues/4680

In [None]:
def demographics_pyramid(is_percent=False):
    """
    Build the (work-in-progress) pyramid plot from the global `df_dm`.

    Two horizontally concatenated bar charts are produced, each faceted into
    one row per age group with one bar per country; the legend acts as a
    country filter.

    NOTE(review): both panels currently show the "Male" rows and the combined
    chart is titled "Male"; presumably one side is meant to become "Female"
    once the plot is finished - confirm.

    Parameters
    ----------
    is_percent : bool
        If True, plot the "per_patients" column (percentage) instead of the
        patient count and add it to the tooltip.

    Returns
    -------
    alt.HConcatChart
        The left and right panels with independent x, y and color scales.
    """
    legend_selection = alt.selection_multi(fields=[COLUMNS.SITE_ID], bind="legend")

    # Filter
    filtered_chart = alt.Chart(df_dm).transform_filter(
        alt.datum[COLUMNS.SEX] == "Male"
    ).transform_filter(
        legend_selection
    )

    tooltip = [
        alt.Tooltip(COLUMNS.SITE_ID, title="Country"),
        alt.Tooltip(COLUMNS.SEX, title="Sex"),
        alt.Tooltip(COLUMNS.AGE_GROUP, title="Age group"),
        alt.Tooltip(COLUMNS.NUM_PATIENTS, title="# of patients")
    ]

    value_field = "per_patients" if is_percent else COLUMNS.NUM_PATIENTS
    if is_percent:
        tooltip.append(alt.Tooltip("per_patients", title="% of patients", format=".1f"))

    value_title = "Percentage of patients (%)" if is_percent else "Number of patients"

    # df_dm stores the readable labels from `readable_age_group` ("0 - 2", ...),
    # so the row order has to be spelled in those labels; a list of raw
    # "age_0to2"-style keys matches no value and leaves the order unspecified.
    age_group_order = list(readable_age_group.values())

    def _panel(label_orient):
        # One side of the pyramid; the sides differ only in where the
        # age-group row labels are placed.
        bar = filtered_chart.mark_bar().encode(
            x=alt.X(f"{value_field}:Q", title=value_title, axis=alt.Axis(tickCount=5)),
            y=alt.Y(f"{COLUMNS.SITE_ID}:N", title=None, axis=None),
            color=alt.Color(
                f"{COLUMNS.SITE_ID}:N",
                title=None,
                scale=alt.Scale(domain=COUNTRIES, range=COUNTRY_COLOR)
            ),
            tooltip=tooltip
        ).properties(width=400, height=67)

        return bar.encode(
            row=alt.Row(
                f"{COLUMNS.AGE_GROUP}:O",
                sort=age_group_order,
                header=alt.Header(labelOrient=label_orient, title=None, titleOrient="bottom")
            )
        )

    left = _panel("left")
    right = _panel("right")

    result_vis = alt.hconcat(left, right).resolve_scale(
        x="independent", y="independent", color="independent"
    ).properties(title="Male").add_selection(
        legend_selection
    )

    # .properties(title={
    #     "text": "Demographics",
    #     "subtitle": SUBTITLE,
    #     "subtitleColor": "gray",
    #     "anchor": "start",
    #     "dx": 60
    # })

    return result_vis

In [None]:
# Percentage-based pyramid plot with the legend on the right
demo = apply_pyramid_theme(
    demographics_pyramid(is_percent=True),
    legend_orient="right",
    strokeColor="lightgray",
)
demo.display()

for_website(demo, "Demographics", "Demographics by country and sex")  # TODO: Remove this before deploying notebook
# save(demo, join(SAVE_DIR, f"demographics.png".lower()), scalefactor=2.0) # Uncomment this to save *.png files