In [None]:
%load_ext autoreload
%autoreload 2
%aimport utils

import pandas as pd
import numpy as np
import altair as alt
from altair_saver import save
from os.path import join
from web import for_website

from constants import COLUMNS
from utils import (
    read_combined_demographics_df,
    read_combined_by_country_demographics_df,
    read_combined_by_site_demographics_df,
    get_visualization_subtitle,
    apply_theme
)

# Required Setup
- All datasets should be placed in `../data/combined` (e.g., `../data/combined/Demographics-Combinedyymmdd.csv`).

In [None]:
# Output directory for exported *.PNG charts
SAVE_DIR = join("..", "output")

# Country labels and the color used for each one in the charts.
# COUNTRIES and COUNTRY_COLOR are parallel lists (same length, same order).
ALL_COUNTRY = "All countries"
ALL_COUNTRY_COLOR = "#444444"
COUNTRIES = ["France", "Germany", "Italy", "Singapore", "USA"]
COUNTRY_COLOR = ["#0072B2", "#E69F00", "#009E73", "#CC79A7", "#D55E00"]
COLOR_BY_COUNTRY = dict(zip(COUNTRIES, COUNTRY_COLOR))

# Data Preprocessing

In [None]:
def preprocess_demo_df(df_dm):
    """Reshape a wide demographics table into long format for plotting.

    Args:
        df_dm: wide DataFrame with COLUMNS.SITE_ID and COLUMNS.SEX identifier
            columns, one count column per age bucket (COLUMNS.AGE_0TO2 ...
            COLUMNS.AGE_80PLUS), and the masked/unmasked/upper-bound breakdown
            columns listed in ``unused_columns`` below, which are discarded.
            The input frame is not modified.

    Returns:
        A new DataFrame with one row per (input row, age bucket) and columns:
          - COLUMNS.SITE_ID: site id, with "Combined" renamed to ALL_COUNTRY
          - COLUMNS.SEX: capitalized sex label (e.g. "male" -> "Male")
          - COLUMNS.AGE_GROUP: readable bucket label such as "0 - 2" or "80+"
          - COLUMNS.NUM_PATIENTS: the bucket's patient count
          - "per_patients": the row's share of its (site, sex) total, in percent

    Raises:
        KeyError: if an expected column is missing, or if a column left after
            dropping the unused ones is not one of the nine known age buckets.
    """
    # Per-site breakdowns, totals and masked upper bounds are not plotted.
    # The upper bounds were only needed for the (retired) error bars.
    unused_columns = [
        COLUMNS.UNMASKED_SITES_TOTAL_PATIENTS,
        COLUMNS.UNMASKED_SITES_AGE_0TO2,
        COLUMNS.UNMASKED_SITES_AGE_3TO5,
        COLUMNS.UNMASKED_SITES_AGE_6TO11,
        COLUMNS.UNMASKED_SITES_AGE_12TO17,
        COLUMNS.UNMASKED_SITES_AGE_18TO25,
        COLUMNS.UNMASKED_SITES_AGE_26TO49,
        COLUMNS.UNMASKED_SITES_AGE_50TO69,
        COLUMNS.UNMASKED_SITES_AGE_70TO79,
        COLUMNS.UNMASKED_SITES_AGE_80PLUS,
        COLUMNS.MASKED_SITES_TOTAL_PATIENTS,
        COLUMNS.MASKED_SITES_AGE_0TO2,
        COLUMNS.MASKED_SITES_AGE_3TO5,
        COLUMNS.MASKED_SITES_AGE_6TO11,
        COLUMNS.MASKED_SITES_AGE_12TO17,
        COLUMNS.MASKED_SITES_AGE_18TO25,
        COLUMNS.MASKED_SITES_AGE_26TO49,
        COLUMNS.MASKED_SITES_AGE_50TO69,
        COLUMNS.MASKED_SITES_AGE_70TO79,
        COLUMNS.MASKED_SITES_AGE_80PLUS,
        COLUMNS.MASKED_UPPER_BOUND_TOTAL_PATIENTS,
        COLUMNS.TOTAL_PATIENTS,
        COLUMNS.MASKED_UPPER_BOUND_AGE_0TO2,
        COLUMNS.MASKED_UPPER_BOUND_AGE_3TO5,
        COLUMNS.MASKED_UPPER_BOUND_AGE_6TO11,
        COLUMNS.MASKED_UPPER_BOUND_AGE_12TO17,
        COLUMNS.MASKED_UPPER_BOUND_AGE_18TO25,
        COLUMNS.MASKED_UPPER_BOUND_AGE_26TO49,
        COLUMNS.MASKED_UPPER_BOUND_AGE_50TO69,
        COLUMNS.MASKED_UPPER_BOUND_AGE_70TO79,
        COLUMNS.MASKED_UPPER_BOUND_AGE_80PLUS,
    ]

    # Wide to long: every remaining column other than site and sex is an age bucket.
    df_long = pd.melt(
        df_dm.drop(columns=unused_columns),
        id_vars=[COLUMNS.SITE_ID, COLUMNS.SEX],
        var_name=COLUMNS.AGE_GROUP,
        value_name=COLUMNS.NUM_PATIENTS,
    )

    df_long[COLUMNS.SEX] = df_long[COLUMNS.SEX].str.capitalize()

    # Share of each age bucket within its (site, sex) group, in percent.
    # transform("sum") broadcasts each group's total back onto that group's rows
    # in a single pass, instead of re-filtering the frame once per group.
    group_totals = (
        df_long.groupby([COLUMNS.SITE_ID, COLUMNS.SEX])[COLUMNS.NUM_PATIENTS]
        .transform("sum")
    )
    df_long["per_patients"] = df_long[COLUMNS.NUM_PATIENTS] / group_totals * 100

    # Use readable names
    df_long.loc[df_long[COLUMNS.SITE_ID] == "Combined", COLUMNS.SITE_ID] = ALL_COUNTRY
    readable_age_group = {
        COLUMNS.AGE_0TO2: "0 - 2",
        COLUMNS.AGE_3TO5: "3 - 5",
        COLUMNS.AGE_6TO11: "6 - 11",
        COLUMNS.AGE_12TO17: "12 - 17",
        COLUMNS.AGE_18TO25: "18 - 25",
        COLUMNS.AGE_26TO49: "26 - 49",
        COLUMNS.AGE_50TO69: "50 - 69",
        COLUMNS.AGE_70TO79: "70 - 79",
        COLUMNS.AGE_80PLUS: "80+",
    }
    # Fail loudly on unexpected columns rather than silently producing NaN labels.
    unknown = set(df_long[COLUMNS.AGE_GROUP]).difference(readable_age_group)
    if unknown:
        raise KeyError(f"Unexpected age group column(s): {sorted(map(str, unknown))}")
    df_long[COLUMNS.AGE_GROUP] = df_long[COLUMNS.AGE_GROUP].map(readable_age_group)

    return df_long

# Load both demographics tables and reshape each into long format.
df_dm_by_country = preprocess_demo_df(read_combined_by_country_demographics_df())
df_dm_combined = preprocess_demo_df(read_combined_demographics_df())

# Merge. Both inputs are indexed from 0, so ignore_index avoids duplicate row labels.
df_dm = pd.concat([df_dm_by_country, df_dm_combined], ignore_index=True)

df_dm.head()

# Visualizations

In [None]:
def demographics(is_percent=False, by_country=False):
    """Build the interactive demographics bar chart, faceted by age group.

    Reads the module-level ``df_dm`` built by ``preprocess_demo_df``. Rows whose
    sex is "Other" are always excluded. Clicking legend entries filters the bars.

    Args:
        is_percent: plot "per_patients" (share within each site/sex, in %) instead
            of COLUMNS.NUM_PATIENTS, and add the percentage to the tooltip.
        by_country: if True, color bars by country and bind a "Sex" dropdown; the
            aggregate ALL_COUNTRY rows are hidden. Otherwise color bars by sex and
            bind a "Country" dropdown; rows whose sex is "All" are hidden.

    Returns:
        An Altair chart with one column facet per age group.
    """
    # Facet (column) order. These must be the readable labels that
    # preprocess_demo_df writes into COLUMNS.AGE_GROUP, not the raw column names.
    age_group_order = ["0 - 2", "3 - 5", "6 - 11", "12 - 17", "18 - 25", "26 - 49", "50 - 69", "70 - 79", "80+"]

    # Selection components
    country_dropdown = alt.binding_select(options=[ALL_COUNTRY] + COUNTRIES)
    country_selection = alt.selection_single(fields=[COLUMNS.SITE_ID], bind=country_dropdown, name="Country", init={COLUMNS.SITE_ID: ALL_COUNTRY})
    sex_dropdown = alt.binding_select(options=["All", "Male", "Female"])
    sex_selection = alt.selection_single(fields=[COLUMNS.SEX], bind=sex_dropdown, name="Sex", init={COLUMNS.SEX: "All"})
    color_field = COLUMNS.SITE_ID if by_country else COLUMNS.SEX
    legend_selection = alt.selection_multi(fields=[color_field], bind="legend")

    # Filter
    filtered_chart = alt.Chart(df_dm).transform_filter(
        alt.datum[COLUMNS.SEX] != "Other"
    ).transform_filter(
        legend_selection
    )

    if by_country:
        # Compare individual countries for the sex picked in the dropdown.
        filtered_chart = filtered_chart.transform_filter(
            sex_selection
        ).transform_filter(
            alt.datum[COLUMNS.SITE_ID] != ALL_COUNTRY
        )
    else:
        # Compare Male vs Female for the country picked in the dropdown.
        filtered_chart = filtered_chart.transform_filter(
            country_selection
        ).transform_filter(
            alt.datum[COLUMNS.SEX] != "All"
        )

    tooltip = [
        alt.Tooltip(COLUMNS.SITE_ID, title="Country"),
        alt.Tooltip(COLUMNS.SEX, title="Sex"),
        alt.Tooltip(COLUMNS.AGE_GROUP, title="Age group"),
        alt.Tooltip(COLUMNS.NUM_PATIENTS, title="Number of patients")
    ]

    y_field = "per_patients" if is_percent else COLUMNS.NUM_PATIENTS
    if is_percent:
        tooltip.append(alt.Tooltip("per_patients", title="Percentage of patients (%)", format=".1f"))

    # Render
    if by_country:
        color_scale = alt.Scale(domain=COUNTRIES, range=COUNTRY_COLOR)
    else:
        color_scale = alt.Scale(domain=["Male", "Female"], range=["#3366cc", "#dc3912"])
    y_title = "Percentage of patients (%)" if is_percent else "Number of patients"
    bar = filtered_chart.mark_bar().encode(
        x=alt.X(f"{color_field}:N", title=None, axis=None),
        y=alt.Y(f"{y_field}:Q", title=y_title, axis=alt.Axis(tickCount=5)),
        color=alt.Color(f"{color_field}:N", title=None, scale=color_scale),
        tooltip=tooltip
    ).properties(width=67, height=400)

    result_vis = bar.encode(
        column=alt.Column(
            f"{COLUMNS.AGE_GROUP}:O",
            sort=age_group_order,
            header=alt.Header(labelOrient="bottom", title="Age group", titleOrient="bottom")
        )
    ).add_selection(
        legend_selection
    )

    # Only the dropdown that drives a filter above is attached to the chart.
    if by_country:
        result_vis = result_vis.add_selection(
            sex_selection
        )
    else:
        result_vis = result_vis.add_selection(
            country_selection
        )

    return result_vis

## Demographics by sex

In [None]:
# Patient share per age group, colored by sex, with a country dropdown.
demo = apply_theme(
    demographics(is_percent=True),
    legend_stroke_color="lightgray",
    axis_title_font_size=18,
).properties(title={
    "text": "Demographics by Sex",
    "subtitle": get_visualization_subtitle(),
    "subtitleColor": "gray",
    "anchor": "start",
    "dx": 60
})
demo.display()

for_website(demo, "Demographics", "Demographics by sex") # TODO: Remove this before deploying notebook
# Each chart gets its own file so enabling every save does not overwrite earlier ones.
# save(demo, join(SAVE_DIR, "demographics_by_sex.png"), scalefactor=2.0) # Uncomment this to save *.png files

## Demographics by country

In [None]:
# Patient share per age group, colored by country, with a sex dropdown.
demo = apply_theme(
    demographics(is_percent=True, by_country=True),
    legend_stroke_color="lightgray",
).properties(title={
    "text": "Demographics by Country",
    "subtitle": get_visualization_subtitle(),
    "subtitleColor": "gray",
    "anchor": "start",
    "dx": 60
})
demo.display()

for_website(demo, "Demographics", "Demographics by country") # TODO: Remove this before deploying notebook
# Each chart gets its own file so enabling every save does not overwrite earlier ones.
# save(demo, join(SAVE_DIR, "demographics_by_country.png"), scalefactor=2.0) # Uncomment this to save *.png files

## Pyramid Plot (WIP)
Known Vega-Lite issue: https://github.com/vega/vega-lite/issues/4680

Work in progress: both halves of the pyramid currently plot male patients only.

In [None]:
def demographics_pyramid(is_percent=False):
    """Build the work-in-progress population-pyramid chart.

    Two horizontal bar charts are placed side by side. Each has one facet row per
    age group and one bar per country. The left half shows its age labels on the
    left and the right half on the right. Both halves currently use the same data
    (male patients only), so this is not yet a male/female pyramid. See
    https://github.com/vega/vega-lite/issues/4680 for the known layout issue.

    Reads the module-level ``df_dm`` built by ``preprocess_demo_df``.

    Args:
        is_percent: plot "per_patients" instead of COLUMNS.NUM_PATIENTS and add
            the percentage to the tooltip.

    Returns:
        An Altair hconcat chart titled "Male"; clicking legend entries filters
        countries.
    """
    # Facet (row) order. These are the readable labels that preprocess_demo_df
    # writes into COLUMNS.AGE_GROUP, not the raw column names.
    age_group_order = ["0 - 2", "3 - 5", "6 - 11", "12 - 17", "18 - 25", "26 - 49", "50 - 69", "70 - 79", "80+"]

    legend_selection = alt.selection_multi(fields=[COLUMNS.SITE_ID], bind="legend")

    # Filter
    filtered_chart = alt.Chart(df_dm).transform_filter(
        alt.datum[COLUMNS.SEX] == "Male"
    ).transform_filter(
        legend_selection
    )

    tooltip = [
        alt.Tooltip(COLUMNS.SITE_ID, title="Country"),
        alt.Tooltip(COLUMNS.SEX, title="Sex"),
        alt.Tooltip(COLUMNS.AGE_GROUP, title="Age group"),
        alt.Tooltip(COLUMNS.NUM_PATIENTS, title="# of patients")
    ]

    x_field = "per_patients" if is_percent else COLUMNS.NUM_PATIENTS
    if is_percent:
        tooltip.append(alt.Tooltip("per_patients", title="% of patients", format=".1f"))
    x_title = "Percentage of patients (%)" if is_percent else "Number of patients"

    def faceted_half(label_orient):
        # One pyramid half: bars per country, one facet row per age group,
        # with the age-group labels on side `label_orient`.
        bar = filtered_chart.mark_bar().encode(
            x=alt.X(f"{x_field}:Q", title=x_title, axis=alt.Axis(tickCount=5)),
            y=alt.Y(f"{COLUMNS.SITE_ID}:N", title=None, axis=None),
            color=alt.Color(f"{COLUMNS.SITE_ID}:N", title=None, scale=alt.Scale(domain=COUNTRIES, range=COUNTRY_COLOR)),
            tooltip=tooltip
        ).properties(width=400, height=67)
        return bar.encode(
            row=alt.Row(
                f"{COLUMNS.AGE_GROUP}:O",
                sort=age_group_order,
                header=alt.Header(labelOrient=label_orient, title=None, titleOrient="bottom")
            )
        )

    result_vis = alt.hconcat(
        faceted_half("left"), faceted_half("right")
    ).resolve_scale(
        x="independent", y="independent", color="independent"
    ).properties(title="Male").add_selection(
        legend_selection
    )

    return result_vis

In [None]:
# NOTE(review): apply_pyramid_theme is not in the `from utils import (...)` list, so
# the bare name is undefined on a fresh kernel. It is looked up on the `utils` module
# loaded by `%aimport utils`; this assumes utils defines it — confirm.
demo = utils.apply_pyramid_theme(demographics_pyramid(is_percent=True), legend_orient="right", strokeColor="lightgray")
demo.display()

for_website(demo, "Demographics", "Demographics by country and sex") # TODO: Remove this before deploying notebook
# Each chart gets its own file so enabling every save does not overwrite earlier ones.
# save(demo, join(SAVE_DIR, "demographics_pyramid.png"), scalefactor=2.0) # Uncomment this to save *.png files