In [1]:
import warnings
warnings.filterwarnings('ignore')

In [None]:
from dash import Dash, html, dcc
import dash_bootstrap_components as dbc
from dash.dependencies import Input, Output
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import pandas as pd
import numpy as np
import os
import json
from tqdm import tqdm

from sqlalchemy import create_engine, Column, Integer, BigInteger, String, Float, Date
from sqlalchemy.orm import declarative_base
from sqlalchemy.orm import sessionmaker


## Load data

In [7]:
# Main results table read from data/risk_excluded_le.csv.
le_calculated_risk_excluded_all_countries_df = pd.read_csv(
    os.path.join('data', 'risk_excluded_le.csv')
)

# Two views of the main table: the per-risk difference (E_x_diff) and the
# per-location/age/sex life-expectancy value (E_x_val).
risk_impact = le_calculated_risk_excluded_all_countries_df[
    ['location_id', 'age', 'sex_id', 'rei_id', 'E_x_diff']
]
life_expectancy = le_calculated_risk_excluded_all_countries_df[
    ['location_id', 'age', 'sex_id', 'E_x_val']
]

# Manageable risks: parent -> child hierarchy table, and the list of names.
risks_parents_names_manageable = pd.read_csv(os.path.join('data', 'risks_parents_names_manageable.csv'))
risks_names_manageable = pd.read_csv(os.path.join('data', 'risks_names_manageable.csv'))

# Full risk hierarchy (rei_id / parent_id) and its 3-level manageable subset.
rei_ierarchy = pd.read_csv(os.path.join('data', 'rei_ierarchy.csv'))
rei_ierarchy_3_level_manageable = pd.read_csv(
    os.path.join('data', 'rei_ierarchy_3_level_manageable.csv')
)

# Country name / GBD id / ISO code / centroid coordinates.
gbd_country_name_id_iso_centroid = pd.read_csv(
    os.path.join('data', 'gbd_country_name_id_iso_centroid.csv')
)

# GBD code book (name <-> id for sexes, risks, ...). The first data row and the
# first column are dropped — NOTE(review): presumably a description row and a
# saved index column; confirm against the CSV.
code_book = pd.read_csv(os.path.join('data', 'code_book.csv')).iloc[1:, 1:]

# Risk name -> plot colour, built from the two-column CSV rows.
rei_color_map = {
    rei_name: color
    for rei_name, color in pd.read_csv(os.path.join('data', 'rei_color_map.csv')).values
}


### Set configs for plots

In [8]:
# Mapping: manageable parent risk name -> list of its child risk names, both in
# first-appearance order. The isinstance guards make this cell safe to re-run:
# after the first run these names already hold the transformed dict/list, and
# transforming them a second time would raise (a dict has no .query, a list
# has no .values).
if isinstance(risks_parents_names_manageable, pd.DataFrame):
    risks_parents_names_manageable = {
        parent_name: children.unique().tolist()
        for parent_name, children in (
            risks_parents_names_manageable
            .groupby('rei_parent_name', sort=False)['rei_name']
        )
    }

# Flatten the risk-name table into a plain list of names.
if isinstance(risks_names_manageable, pd.DataFrame):
    risks_names_manageable = list(np.concatenate(risks_names_manageable.values))

# risk id -> parent risk id (from the GBD risk hierarchy).
risk_id_to_parent_id = {
    int(rei_id): int(parent_id)
    for rei_id, parent_id in rei_ierarchy[['rei_id', 'parent_id']].values
}

# sex name -> sex id (and the inverse).
# NOTE(review): `.values[1:]` skips the first non-null row on top of the
# `.iloc[1:, 1:]` applied when code_book was loaded — confirm that row is
# meant to be excluded.
sex_name_to_id = {
    key: int(value)
    for key, value in code_book[['sex_label', 'sex_id']].dropna().values[1:]
}

sex_id_to_name = {v: k for k, v in sex_name_to_id.items()}

# risk name -> risk id (and the inverse).
risks_name_to_id = {
    key: int(value)
    for key, value in code_book[['rei_name', 'rei_id']].dropna().values
}

risks_id_to_name = {v: k for k, v in risks_name_to_id.items()}

# country name -> GBD location id (and the inverse).
location_name_to_id = {
    key: int(value)
    for key, value in
    gbd_country_name_id_iso_centroid[['location_name', 'location_id']].values
}

location_id_to_name = {v: k for k, v in location_name_to_id.items()}

# GBD location id -> ISO country code (used as choropleth locations).
gbd_id_to_iso_code_map = {
    int(k): v for k, v in
    gbd_country_name_id_iso_centroid[['location_id', 'iso_code']].values
}

# GBD location id -> [latitude, longitude] of the country centroid.
gbd_country_id_to_centroid_map = {
    int(x[0]): [x[1], x[2]]
    for x in
    gbd_country_name_id_iso_centroid
    [['location_id', 'latitude', 'longitude']]
    .values
}

# Bar colours per series and sex; the extension series uses a more opaque shade.
color_mapping = {
    'Default life expectancy': {
        'Male': 'rgba(89, 52, 235, 0.5)',
        'Female': 'rgba(235, 52, 155, 0.5)'
    },
    'Estimated life extension': {
        'Male': 'rgba(89, 52, 235, 0.9)',
        'Female': 'rgba(235, 52, 155, 0.9)'
    }
}


### Prepare data function

In [9]:
def _build_sex_report(
    life_expectancy_filtered: pd.DataFrame,
    risk_impact_cur_age: pd.DataFrame,
    sex_id: int,
    sex_name: str,
    age: int,
) -> pd.DataFrame:
    """Build the one-row (index ``'val'``) life-expectancy summary for one sex.

    Args:
        life_expectancy_filtered: ``sex_id`` / ``E_x_val`` rows already filtered
            to a single location and age.
        risk_impact_cur_age: per-risk ``E_x_diff`` rows for the current age.
        sex_id: id of the sex to summarise.
        sex_name: label written into the ``sex_name`` column.
        age: current age; added to ``E_x_val`` to obtain the default life
            expectancy.

    Returns:
        DataFrame with columns 'Default life expectancy',
        'Estimated life extension', 'Extended life expectancy' and 'sex_name'.
    """
    default_life_expectancy = (
        life_expectancy_filtered
        .query(f'sex_id == {sex_id}')
        [['E_x_val']]
        .values[0]
        + age
    )
    estimated_life_extension = (
        risk_impact_cur_age
        .query(f'sex_id == {sex_id}')
        [['E_x_diff']]
        .sum()
        .values[0]
    )
    report_for_sex = pd.DataFrame(
        {
            'Default life expectancy': default_life_expectancy,
            'Estimated life extension': estimated_life_extension,
        },
        index=['val'],
    )
    report_for_sex['Extended life expectancy'] = (
        report_for_sex['Default life expectancy']
        + report_for_sex['Estimated life extension']
    )
    report_for_sex['sex_name'] = sex_name
    return report_for_sex


def prepare_data(
    location_name: str,
    age: int,
    sex_name: str,
    risk_factors_names: list,
    risks_parents_names_manageable: dict,
    risk_impact: pd.DataFrame,
    life_expectancy: pd.DataFrame,
    risk_id_to_parent_id: dict,
    location_name_to_id: dict,
    sex_name_to_id: dict,
    risks_name_to_id: dict,
    gbd_id_to_iso_code_map: dict,
    is_dietary_risks_groupped: bool=True,
    round_n_decimals: int=2,
) -> tuple:
    """Compute every table and caption the dashboard plots need.

    Args:
        location_name: country name (key of ``location_name_to_id``).
        age: age in years for the "current age" views.
        sex_name: 'Male' or 'Female' (key of ``sex_name_to_id``).
        risk_factors_names: selected risk names; a name that is a key of
            ``risks_parents_names_manageable`` is expanded into its children.
        risks_parents_names_manageable: parent risk name -> child risk names.
        risk_impact: rows with location_id, age, sex_id, rei_id, E_x_diff.
        life_expectancy: rows with location_id, age, sex_id, E_x_val.
        risk_id_to_parent_id: risk id -> parent risk id.
        location_name_to_id: country name -> location id.
        sex_name_to_id: sex name -> sex id; must contain 'Male' and 'Female'.
        risks_name_to_id: risk name -> risk id (also used for parent names).
        gbd_id_to_iso_code_map: location id -> ISO code.
        is_dietary_risks_groupped: if True, children of "Dietary risks" are
            summed into a single "Dietary risks" row.
        round_n_decimals: decimals used for the rounded outputs.

    Returns:
        A 10-tuple: total_extension, risk_impact_by_countries,
        risk_impact_sum_for_all_countries, risk_impact_filtered_cur_age,
        report, risk_impact_filtered_dietary_groupped, and the by-risk,
        by-sex, by-country and by-age suptitles (in that order).
    """
    ################################################
    # Keep the caller's selection (before parent groups are expanded) so the
    # captions count what was actually selected. The previous code copied the
    # notebook global `risks_names_manageable` here instead of the parameter.
    risk_factors_names_source = list(risk_factors_names)

    # Expand parent groups into their child risks; plain names pass through.
    risk_factors_names = [
        child_name
        for name in risk_factors_names_source
        for child_name in risks_parents_names_manageable.get(name, [name])
    ]

    # Names -> ids.
    location_id = location_name_to_id[location_name]
    risk_factors_id = [risks_name_to_id[name] for name in risk_factors_names]
    sex_id = sex_name_to_id[sex_name]
    risk_id_to_name = {v: k for k, v in risks_name_to_id.items()}

    ################################################
    # Selected risks for the selected location (all ages, both sexes),
    # annotated with parent id and risk / parent names.
    risk_impact_filtered = (
        risk_impact
        .query(f'location_id == {location_id} and rei_id in @risk_factors_id')
        .assign(
            rei_parent_id=lambda frame: frame.rei_id.map(risk_id_to_parent_id),
            rei_name=lambda frame: frame.rei_id.map(risk_id_to_name),
            rei_parent_name=lambda frame: frame.rei_parent_id.map(risk_id_to_name),
        )
    )

    life_expectancy_filtered = life_expectancy.query(
        f'location_id == {location_id} and age == {age}'
    )[['sex_id', 'E_x_val']]

    risk_impact_filtered_cur_age = (
        risk_impact_filtered
        .query(f'age == {age}')
        .round(decimals=round_n_decimals)
    )

    if is_dietary_risks_groupped:
        # Collapse all "Dietary risks" children into one row per sex/age.
        groupped_dietary_risks = (
            risk_impact_filtered
            .query('rei_parent_name == "Dietary risks"')
            .groupby(by=['sex_id', 'age', 'rei_parent_name'])[['E_x_diff']]
            .sum()
            .reset_index()
            .rename(columns={'rei_parent_name': 'rei_name'})
        )

        risk_impact_filtered_dietary_groupped = pd.concat(
            [
                risk_impact_filtered
                .query('rei_parent_name != "Dietary risks"')
                [['sex_id', 'age', 'rei_name', 'E_x_diff']],
                groupped_dietary_risks,
            ],
            axis=0,
        ).sort_values(by=['sex_id', 'age', 'E_x_diff'], ascending=False)

        risk_impact_filtered_dietary_groupped_cur_age = (
            risk_impact_filtered_dietary_groupped.query(f'age == {age}')
        )
    else:
        risk_impact_filtered_dietary_groupped = risk_impact_filtered.copy()
        risk_impact_filtered_dietary_groupped_cur_age = risk_impact_filtered_cur_age.copy()

    risk_impact_filtered_dietary_groupped = (
        risk_impact_filtered_dietary_groupped.query('sex_id == @sex_id')
    )

    ################################################
    # Per-sex summary: default, estimated extension and extended life expectancy.
    report = pd.concat([
        _build_sex_report(
            life_expectancy_filtered,
            risk_impact_filtered_dietary_groupped_cur_age,
            sex_name_to_id[report_sex_name],
            report_sex_name,
            age,
        )
        for report_sex_name in ('Male', 'Female')
    ])
    report = (
        report
        .reset_index()
        .set_index(['sex_name', 'index'])
        .round(decimals=round_n_decimals)
        .reset_index()
    )

    ################################################
    # Total extension per country for the selected age and sex.
    location_id_to_name = {v: k for k, v in location_name_to_id.items()}
    risk_impact_by_countries = (
        risk_impact
        .query('rei_id in @risk_factors_id and age == @age and sex_id == @sex_id')
        .assign(
            iso_code=lambda frame: frame['location_id'].map(
                gbd_id_to_iso_code_map, na_action='ignore'
            )
        )
        .groupby(by=['iso_code', 'location_id'])[['E_x_diff']]
        .sum()
        .reset_index()
    )
    risk_impact_by_countries['location_name'] = (
        risk_impact_by_countries['location_id'].map(location_id_to_name)
    )
    risk_impact_by_countries.columns = ['iso_code', 'location_id', 'Years', 'location_name']

    # Summed impact of all selected risks per location/age/sex, then the
    # cross-country mean and 5% / 95% quantiles per age and sex.
    risk_manageable_impact_summary = (
        risk_impact
        .query('rei_id in @risk_factors_id')
        .groupby(by=['location_id', 'age', 'sex_id'])[['E_x_diff']]
        .sum()
        .reset_index()
    )
    by_age_sex = risk_manageable_impact_summary.groupby(by=['age', 'sex_id'])[['E_x_diff']]
    risk_impact_sum_for_all_countries = (
        by_age_sex.mean()
        .join(by_age_sex.quantile(q=0.05), rsuffix='_q05', how='left')
        .join(by_age_sex.quantile(q=0.95), rsuffix='_q95', how='left')
        .reset_index()
        .sort_values(by='age')
    )
    risk_impact_sum_for_all_countries['sex_name'] = (
        risk_impact_sum_for_all_countries['sex_id'].map({v: k for k, v in sex_name_to_id.items()})
    )

    ################################################
    # Headline number: sum of the (already rounded) per-risk rows at the
    # current age for the selected sex.
    total_extension = round(
        risk_impact_filtered_cur_age.query('sex_id == @sex_id').E_x_diff.sum(),
        ndigits=round_n_decimals,
    )

    ################################################
    # Markdown suptitles for the plots.
    life_expectancy_extension_by_country_suptitle = (
        f' ###### Extension of life expectancy'
        f' by {len(gbd_id_to_iso_code_map.keys())} countries,'
        f' with excluding {len(risk_factors_names_source)} manageable risk factors'
        f' for {sex_name} aged {age} y.o.'
    )

    life_expectancy_extension_by_risk_suptitle = (
        f' ###### Extension of life expectancy,'
        f' by excluded {len(risk_factors_names_source)} manageable risk factors,'
        f' for {sex_name},'
        f' aged {age} y.o.,'
        f' in {location_name}'
    )

    life_expectancy_extension_by_sex_suptitle = (
        f' ###### Extension of life expectancy by sex'
        f' for age {age} y.o.,'
        f' in {location_name}'
    )

    life_expectancy_extension_by_age_suptitle = (
        f' ###### Extension of estimated life expectancy by age,'
        f' for {sex_name} in {location_name}'
    )

    return (
        total_extension,
        risk_impact_by_countries,
        risk_impact_sum_for_all_countries,
        risk_impact_filtered_cur_age,
        report,
        risk_impact_filtered_dietary_groupped,
        life_expectancy_extension_by_risk_suptitle,
        life_expectancy_extension_by_sex_suptitle,
        life_expectancy_extension_by_country_suptitle,
        life_expectancy_extension_by_age_suptitle,
    )


### Get user settings

In [74]:
# User selection that drives every table and plot below.
location_name, sex_name, age = 'Russian Federation', 'Male', 37


In [75]:
# Compute every table and caption needed below for the current
# location / sex / age selection and all manageable risk factors.
(
    total_extension,                          # headline number on the map
    risk_impact_by_countries,                 # choropleth by country
    risk_impact_sum_for_all_countries,        # cross-country mean and 5%-95% band by age
    risk_impact_filtered_cur_age,             # per-risk rows at the current age
    report,                                   # per-sex summary
    risk_impact_filtered_dietary_groupped,    # per-risk rows (dietary grouped), selected sex
    life_expectancy_extension_by_risk_suptitle,
    life_expectancy_extension_by_sex_suptitle,
    life_expectancy_extension_by_country_suptitle,
    life_expectancy_extension_by_age_suptitle,
) = prepare_data(
    location_name=location_name,
    sex_name=sex_name,
    age=age,
    risk_factors_names=risks_names_manageable,
    risks_parents_names_manageable=risks_parents_names_manageable,
    risk_impact=risk_impact,
    life_expectancy=life_expectancy,
    location_name_to_id=location_name_to_id,
    sex_name_to_id=sex_name_to_id,
    risks_name_to_id=risks_name_to_id,
    risk_id_to_parent_id=risk_id_to_parent_id,
    gbd_id_to_iso_code_map=gbd_id_to_iso_code_map,
    round_n_decimals=2,
    is_dietary_risks_groupped=True,
)

In [12]:
# Per-sex summary returned by prepare_data: default life expectancy, estimated
# extension and extended life expectancy, rounded to round_n_decimals.
report

Unnamed: 0,sex_name,index,Default life expectancy,Estimated life extension,Extended life expectancy
0,Male,val,70.97,8.74,79.71
1,Female,val,79.25,6.17,85.42


### Plot by countries

In [71]:
def life_expectancy_extension_by_country_plotter(
    risk_impact_by_countries: pd.DataFrame,
    gbd_country_id_to_centroid_map,
    location_name_to_id,
    location_name,
    total_extension,
) -> go.Figure:
    """World choropleth of the estimated extension, with the selected country labelled.

    Args:
        risk_impact_by_countries: frame with ``iso_code``, ``Years`` and
            ``location_name`` columns (as built by ``prepare_data``).
        gbd_country_id_to_centroid_map: location id -> [latitude, longitude].
        location_name_to_id: country name -> location id.
        location_name: country whose centroid receives the text label.
        total_extension: value printed in the label.

    Returns:
        plotly ``go.Figure``.
    """
    fig = px.choropleth(
        risk_impact_by_countries,
        locations="iso_code",
        color="Years",
        hover_name="location_name",
        title="",
        color_continuous_scale=px.colors.sequential.dense
    )

    fig.update_layout(
        geo=dict(
            showframe=False,
            showcoastlines=False,
            projection_type='natural earth',
        ),
        height=400,
        margin=dict(t=0, l=0, r=0, b=0),
        coloraxis=dict(
            colorbar=dict(
                orientation='v',
                thickness=12,
                tickfont=dict(size=12),
                len=0.8,
                yanchor='middle',
            )
        ),
    )
    fig["layout"].pop("updatemenus")

    centroid = gbd_country_id_to_centroid_map[location_name_to_id[location_name]]
    latitude, longitude = centroid[0], centroid[1]
    label = f'<b>{total_extension}<br>years</b>'

    # The label is drawn twice at the same point: first a slightly larger
    # white-marker / black-text layer, then the red layer on top — presumably
    # so the red text gets a contrasting outline on the map.
    label_layers = (
        # (marker size, marker colour, text colour, text size)
        (7, 'white', 'black', 15.02),
        (6, 'rgb(250, 46, 35)', 'rgb(250, 46, 35)', 15),
    )
    for marker_size, marker_color, text_color, text_size in label_layers:
        fig.add_trace(
            go.Scattergeo(
                lat=[latitude],
                lon=[longitude],
                mode='markers+text',
                marker=dict(size=marker_size, color=marker_color),
                text=[label],
                textfont=dict(color=text_color, size=text_size),
                textposition='top center',
                hoverinfo='skip',
                showlegend=False,
            )
        )

    return fig


In [76]:
# Choropleth of the extension by country, selected country labelled.
life_expectancy_extension_by_country_plotter(
    risk_impact_by_countries,
    gbd_country_id_to_centroid_map,
    location_name_to_id,
    location_name,
    total_extension,
)

### Plot by risk factors

In [77]:
def life_expectancy_extension_by_risk_plotter(
    risk_impact_filtered_dietary_groupped: pd.DataFrame,
    rei_color_map: dict,
    age: int,
) -> go.Figure:
    """Donut chart of each excluded risk's contribution to the extension at ``age``.

    Args:
        risk_impact_filtered_dietary_groupped: rows with ``age``, ``rei_name``
            and ``E_x_diff``.
        rei_color_map: risk name -> slice colour; must cover every plotted risk.
        age: only rows with this age are plotted.

    Returns:
        plotly ``go.Figure`` with the total shown in the hole.
    """
    slices = (
        risk_impact_filtered_dietary_groupped
        .query('age == @age')[['E_x_diff', 'rei_name']]
        .sort_values(by='E_x_diff', ascending=False)
        .round(2)
    )

    total_sum = round(slices['E_x_diff'].sum(), 1)
    slice_colors = [rei_color_map[name] for name in slices['rei_name']]

    donut = go.Pie(
        labels=slices['rei_name'],
        values=slices['E_x_diff'],
        textinfo='value',
        textfont=dict(size=15),
        hole=0.65,                           # inner radius as a fraction of the outer radius
        marker=dict(colors=slice_colors),    # one colour per risk
        sort=False,                          # keep the descending order computed above
    )

    # Fixed paper coordinates for the two centre labels: total, then unit below.
    label_x, label_y = 0.5, 0.55
    centre_labels = [
        dict(
            x=label_x,
            y=label_y,
            showarrow=False,
            text=f'<b>+{total_sum}</b>',
            font=dict(size=35, color='black'),
        ),
        dict(
            x=label_x,
            y=label_y - 0.1,
            showarrow=False,
            text='<b>years</b>',
            font=dict(size=20, color='black'),
        ),
    ]

    fig = go.Figure(data=[donut])
    fig.update_layout(
        annotations=centre_labels,
        template='plotly_white',
        height=300,
        margin=dict(t=0, l=0, r=0, b=0),
        legend=dict(title="Excluded risk factors"),
    )
    return fig


In [81]:
# Donut of per-risk contributions at the selected age.
life_expectancy_extension_by_risk = life_expectancy_extension_by_risk_plotter(
    risk_impact_filtered_dietary_groupped, rei_color_map, age
)
life_expectancy_extension_by_risk

### Plot by sex

In [44]:
def life_expectancy_extension_by_sex_plotter(
    report: pd.DataFrame,
    age: int,
    color_mapping: dict,
    width: float=0.4
) -> go.Figure:
    """Stacked bars (default life expectancy + estimated extension) per sex.

    Args:
        report: indexed by sex name ('Male', 'Female') with columns
            'Default life expectancy', 'Estimated life extension' and
            'Extended life expectancy' (e.g. ``report.set_index('sex_name')``).
        age: current age; used as the lower bound of the y-axis.
        color_mapping: {'Default life expectancy' | 'Estimated life extension':
            {sex name: colour}}.
        width: bar width.

    Returns:
        plotly ``go.Figure`` with one stacked bar per sex.
    """
    sex_names = ['Male', 'Female']

    # Bar heights use the unrounded values.
    x_data = {
        sex_name: {
            'Default life expectancy': report.loc[sex_name, 'Default life expectancy'],
            'Estimated life extension': report.loc[sex_name, 'Estimated life extension'],
        }
        for sex_name in sex_names
    }

    maximum_le = max(
        report.loc[sex_name, 'Extended life expectancy'] for sex_name in sex_names
    )

    # 1-decimal copy used only for the text labels.
    report_rounded = round(report.copy(), 1)

    fig = go.Figure()

    for sex_name in ['Female', 'Male']:
        default_le = report_rounded.loc[sex_name, 'Default life expectancy']
        extension = report_rounded.loc[sex_name, 'Estimated life extension']

        # Both labels are shares of the whole stacked bar (default + extension),
        # so they refer to the same total and sum to 100 %. Previously the
        # extension was divided by the default alone, which made the
        # "100 - extension %" default label inconsistent with its bar.
        percent_extension = round(100 * extension / (default_le + extension), 1)
        percent_default = round(100 - percent_extension, 1)

        fig.add_trace(
            go.Bar(
                y=[x_data[sex_name]['Default life expectancy']],
                x=[sex_name],
                orientation='v',
                marker=dict(
                    color=color_mapping['Default life expectancy'][sex_name],
                    line=dict(width=1)
                ),
                text=(
                    f"{default_le}"
                    f"<br>({percent_default} %)"
                ),
                textfont=dict(color='white'),
                insidetextanchor='end',
                hovertemplate='Default life expectancy<br> is: %{y} years',
                name=f'{sex_name} default life expectancy',
                width=width,
                legendgroup=sex_name,
                showlegend=True
            )
        )

        fig.add_trace(
            go.Bar(
                y=[x_data[sex_name]['Estimated life extension']],
                x=[sex_name],
                orientation='v',
                marker=dict(
                    color=color_mapping['Estimated life extension'][sex_name],
                    line=dict(width=1)
                ),
                text=(
                    f"{extension}"
                    f"<br>({percent_extension} %)"
                ),
                textfont=dict(color='white'),
                textposition="inside",
                insidetextanchor='middle',
                hovertemplate='Extension of life expectancy<br> is: %{y} years',
                name=f'{sex_name} extension',
                width=width,
                legendgroup=sex_name,
                showlegend=True
            )
        )

    fig.update_layout(
        template='plotly_white',
        height=400,
        yaxis=dict(
            # Start the axis at the current age: years already lived are not shown.
            range=(age, int(maximum_le * 1.05)),
            tickvals=list(range(age, int(maximum_le // 1 + 1), 5)),
            zeroline=False,
            showgrid=False,
            showline=False,
            domain=[0.15, 1],
            title='Life expectancy (Years)'
        ),
        xaxis=dict(
            showgrid=False,
            showline=False,
            showticklabels=False,
            zeroline=False,
        ),
        legend=dict(
            y=0.1,
            orientation="h"
        ),
        barmode='stack',
        bargap=0.001,
        showlegend=True,
        margin=dict(l=0, r=0, t=0, b=0),
    )

    return fig


In [45]:
# Stacked default + extension bars for both sexes.
life_expectancy_extension_by_sex_plotter(
    report=report.set_index('sex_name'),
    age=age,
    color_mapping=color_mapping,
)

### Plot by age

In [46]:
def life_expectancy_extension_by_age_plotter(
    risk_impact_filtered_dietary_groupped: pd.DataFrame,
    rei_color_map: dict,
    age: int,
):
    """Stacked-area chart of life-expectancy extension by age, one band per risk.

    Args:
        risk_impact_filtered_dietary_groupped: rows with ``age``, ``rei_name``
            and ``E_x_diff`` — presumably already filtered to one location and
            sex, as returned by ``prepare_data``.
        rei_color_map: risk name -> colour of its band and marker.
        age: current age; when not None, a vertical marker line is drawn at
            this age through the cumulative band boundaries.

    Returns:
        plotly ``go.Figure``.
    """
    age_arr = np.sort(risk_impact_filtered_dietary_groupped.age.unique())

    # Risks ordered by total contribution, largest first (bottom band).
    rei_sorted_by_val_sum = (
        risk_impact_filtered_dietary_groupped
        .groupby(by='rei_name')['E_x_diff']
        .sum()
        .sort_values(ascending=False)
        .index
    )

    fig = go.Figure()

    # Running top of the stack, one entry per age in ``age_arr``.
    y = np.zeros(len(age_arr), dtype='float64')

    for rei_name in rei_sorted_by_val_sum:
        # Align this risk's values on the full age grid so a risk missing some
        # ages contributes 0 there instead of breaking / misaligning the sum.
        cur_y = (
            risk_impact_filtered_dietary_groupped
            .query('rei_name == @rei_name')
            .groupby('age')['E_x_diff']
            .sum()
            .reindex(age_arr, fill_value=0.0)
            .to_numpy(dtype='float64')
        )

        # Rebind rather than `+=`: the array handed to this trace is never
        # modified by later iterations.
        y = y + cur_y

        fig.add_trace(
            go.Scatter(
                customdata=cur_y,
                x=age_arr,
                y=y,
                line=dict(
                    width=0.1,
                    color=rei_color_map[rei_name]
                ),
                fill='tonexty',
                hovertemplate=(
                    f'{rei_name},<br> '
                    'age: %{x:}, <br>extension: %{customdata:.2f}<extra></extra>'
                ),
                name=rei_name
            )
        )

    if age is not None:
        # Cumulative band boundaries at the current age, in stacking order;
        # risks absent at this age contribute 0 instead of raising KeyError.
        cur_age_y = (
            risk_impact_filtered_dietary_groupped
            .query('age == @age')
            .groupby('rei_name')['E_x_diff']
            .sum()
            .reindex(rei_sorted_by_val_sum, fill_value=0.0)
            .cumsum()
        )

        fig.add_trace(
            go.Scatter(
                x=[age] * (len(cur_age_y.values) + 1),
                y=[0] + list(cur_age_y.values),
                mode='lines+markers',
                marker=dict(
                    # NOTE(review): there is one colour fewer than points (the
                    # leading 0 baseline has no dedicated entry), so colours are
                    # matched to points by position — confirm intended alignment.
                    color=[rei_color_map[rei_name] for rei_name in cur_age_y.index],
                    size=5,
                ),
                line=dict(
                    width=0.3,
                    color='red'
                ),
                hoverinfo='skip',
                showlegend=False,
            )
        )

    fig.update_layout(
        template='plotly_white',
        xaxis=dict(
            showgrid=False,
            showline=False,
            zeroline=False,
            title='Age(years)',
            tickvals=list(
                range(
                    risk_impact_filtered_dietary_groupped.age.min(),
                    risk_impact_filtered_dietary_groupped.age.max(),
                    5
                )
            )
        ),
        yaxis=dict(
            range=(0, risk_impact_filtered_dietary_groupped.groupby(by='age')['E_x_diff'].sum().max() * 1.05),
            showgrid=False,
            showline=False,
            zeroline=False,
            title='Life expectancy extension (years)',
        ),
        legend=dict(
            orientation="h",
            y=-0.1,
        ),
        margin=dict(l=0, r=0, t=0, b=0)
    )

    return fig


In [47]:
# Stacked extension by age, with a marker line at the selected age.
life_expectancy_extension_by_age_plotter(
    risk_impact_filtered_dietary_groupped,
    rei_color_map,
    age,
)

## Summary graphs

Life expectancy extension by age

In [13]:
# Per-risk impact restricted to the risk ids present in risk_impact_filtered_cur_age.
risk_manageable_impact = risk_impact.query(
    'rei_id in @risk_impact_filtered_cur_age.rei_id.unique()'
).copy()

# Summed impact of those risks per location / age / sex.
risk_manageable_impact_summary = (
    risk_manageable_impact
    .groupby(by=['location_id', 'age', 'sex_id'])[['E_x_diff']]
    .sum()
    .reset_index()
)

# Cross-country mean plus 5th / 95th percentiles for each age and sex.
data_frame = (
    risk_manageable_impact_summary
    .groupby(by=['age', 'sex_id'])[['E_x_diff']]
    .pipe(
        lambda by_age_sex: (
            by_age_sex.mean()
            .join(by_age_sex.quantile(q=0.05), rsuffix='_q05', how='left')
            .join(by_age_sex.quantile(q=0.95), rsuffix='_q95', how='left')
        )
    )
    .reset_index()
    .sort_values(by='age')
)

data_frame['sex_name'] = data_frame['sex_id'].map(sex_id_to_name)

In [14]:
# Two panels sharing the x-axis (Male in row 1, the other sex in row 2); each
# shows the cross-country mean extension and a shaded 5%-95% band.
fig = make_subplots(
    cols=1,
    rows=2,
    shared_xaxes=True,
    shared_yaxes=True,
)

for sex in data_frame['sex_name'].unique():
    sex_rows = data_frame.query('sex_name == @sex')
    panel_row = 1 if sex == 'Male' else 2

    # Upper bound first (no fill), then the lower bound filled 'tonexty' up to it.
    for estimation in ['E_x_diff_q95', 'E_x_diff_q05']:
        fig.add_trace(
            go.Scatter(
                x=sex_rows['age'],
                y=sex_rows[estimation],
                line=dict(
                    width=0.001
                ),
                fill='tonexty' if estimation == 'E_x_diff_q05' else None,
                fillcolor=color_mapping['Default life expectancy'][sex].replace('0.5)', '0.1)'),
                legendgroup=sex,
                name=f'{sex} 5%-95% percentile',
                showlegend=False if estimation == 'E_x_diff_q95' else True,
            ),
            row=panel_row,
            col=1
        )

    fig.add_trace(
        go.Scatter(
            x=sex_rows['age'],
            y=sex_rows['E_x_diff'],
            name=f'{sex} mean',
            line=dict(
                color=color_mapping['Default life expectancy'][sex]
            ),
            legendgroup=sex
        ),
        row=panel_row,
        col=1
    )

fig.update_xaxes(
    title='Age years',
    title_font_size=10,
)

fig.update_yaxes(
    title='life expectancy extension, years',
    title_font_size=10,
)

fig.update_layout(
    title=dict(
        text=(
            'Distribution by age summary life expectancy extension by excluding '
            # Count the countries actually in the estimate instead of hardcoding it.
            '<br>manageable risk factors (estimated by '
            f'{risk_manageable_impact_summary["location_id"].nunique()} countries)'
        ),
        font=dict(
            size=15
        ),
    ),
    template='plotly_white',
    width=800,
    height=600
)

fig.show()


Life expectancy extension vs life expectancy

In [56]:
location_ierarchy = pd.read_excel(
    os.path.join('data', 'IHME_GBD_2019_ALL_LOCATION_HIERARCHIES_Y2022M06D29.XLSX')
)

# Normalise headers (lower case, spaces -> underscores) so they work in .query().
location_ierarchy.columns = [x.lower().replace(' ', '_') for x in location_ierarchy.columns]

# Parent location ids of the SDI groups.
# NOTE(review): ids only — confirm their names against location_ierarchy.
sdi_parents = [
    44637,
    44636,
    44639,
    44634,
    44635,
]

# Parent location ids (level-1 rows) of the four World Bank income groups.
income_parents = [
    44575,
    44578,
    44577,
    44576,
]

# Income group id -> short label ("World Bank <X> Income" -> "<X>").
income_id_to_income_name = {
    location_id: (
        location_ierarchy
        .query('level == 1 and location_id == @location_id')
        .location_name.values[0]
        .replace('World Bank ', '')
        .replace(' Income', '')
    )
    for location_id in income_parents
}

# Countries = level-2 rows whose parent is an income group. Only rows without
# an income level are dropped (presumably level-2 rows of other hierarchies);
# a bare dropna() would also discard countries that merely have a missing
# value in some unrelated column.
countries_by_income = (
    location_ierarchy
    .query('level == 2')
    .assign(
        world_bank_income_level=lambda frame: frame['parent_id'].map(income_id_to_income_name)
    )
    .dropna(subset=['world_bank_income_level'])
)

In [61]:
# One row per (location, sex) at birth: life expectancy (E_x_val), its
# extension (E_x_diff), country name and World Bank income level.
life_expectancy_at_birth_vs_life_extension = (
    life_expectancy
    .query('age == 0')
    .set_index(['sex_id', 'location_id'])[['E_x_val']]
    .join(
        risk_manageable_impact_summary
        .query('age == 0')
        .set_index(['sex_id', 'location_id'])[['E_x_diff']],
        how='left',
        lsuffix='_life_expectancy',
        rsuffix='_extension',
    )
    .reset_index()
    .set_index('location_id')
    .join(
        pd.Series(
            list(location_name_to_id.keys()),
            index=list(location_name_to_id.values()),
            name='location_name',
        )
    )
    .join(
        countries_by_income.set_index('location_id')[['world_bank_income_level']],
        how='left',
    )
)

In [64]:
# World Bank income levels (low -> high) and their marker colours; the dict
# order also fixes the trace / legend order.
INCOME_LEVEL_COLORS = {
    'Low': 'rgb(66, 158, 245)',
    'Lower Middle': 'rgb(66, 245, 105)',
    'Upper Middle': 'rgb(247, 200, 27)',
    'High': 'rgb(245, 66, 66)',
}


def plot_life_expectancy_vs_extension(
    data: pd.DataFrame,
    sex_name: str,
    sex_name_to_id: dict,
    income_level_colors: dict = None,
) -> go.Figure:
    """Scatter + box plot of life expectancy at birth vs. its extension, by income level.

    Wrapping the plot in a function keeps loop variables local, so the notebook
    globals ``sex_name`` / ``sex_id`` (the user selection) are not overwritten.

    Args:
        data: rows with ``sex_id``, ``E_x_val``, ``E_x_diff``,
            ``location_name`` and ``world_bank_income_level``.
        sex_name: sex to plot; also used in the title.
        sex_name_to_id: sex name -> sex id.
        income_level_colors: income level -> colour; defaults to
            ``INCOME_LEVEL_COLORS``.

    Returns:
        Figure with a scatter (left, 3/4 width) and per-income box plots (right).
    """
    if income_level_colors is None:
        income_level_colors = INCOME_LEVEL_COLORS
    sex_id = sex_name_to_id[sex_name]

    fig = make_subplots(
        subplot_titles=['', ''],
        cols=2,
        rows=1,
        column_widths=[3, 1],
    )

    for income_level, color in income_level_colors.items():
        income_rows = data.query(
            'world_bank_income_level == @income_level'
            ' and sex_id == @sex_id'
        )

        fig.add_trace(
            go.Scatter(
                x=income_rows['E_x_val'],
                y=income_rows['E_x_diff'],
                text=income_rows['location_name'],
                mode='markers',
                marker=dict(color=color),
                name=income_level,
                legendgroup=income_level,
                showlegend=True,
            ),
            col=1,
            row=1,
        )

        fig.add_trace(
            go.Box(
                y=income_rows['E_x_diff'],
                marker_color=color,
                line_width=1,
                name=income_level,
                legendgroup=income_level,
                showlegend=False,
            ),
            col=2,
            row=1,
        )

    fig.update_layout(
        template='plotly_white',
        width=1000,
        height=400,
        title=dict(
            # NOTE(review): "26" and "204" are hardcoded — confirm they match the data.
            text=(
                f'{sex_name} life expectancy vs. life expectancy extension at birth'
                ' by exclusion 26 <br>manageable risk factors for 204 countries colored by income level'
            ),
            font_size=18,
            x=0.5,
            y=0.96,
            xanchor='center',
        ),
        xaxis=dict(
            title=dict(
                text='Life expectancy at birth (years)',
            )
        ),
        yaxis=dict(
            title=dict(
                text='Life expectancy extension at birth (years)',
            )
        ),
        legend=dict(
            title='World bank income level',
            orientation="h",
            y=1.2,
        ),
    )
    return fig


fig = plot_life_expectancy_vs_extension(
    life_expectancy_at_birth_vs_life_extension,
    'Male',
    sex_name_to_id,
)
fig.show()

In [65]:
# Females: life expectancy at birth (x) vs. its extension from excluding the
# manageable risk factors (y), one scatter trace per World Bank income level
# (left), plus a box plot of the extension per income level (right).
sex_name = 'Female'
sex_id = sex_name_to_id[sex_name]

# Income levels in plot order, with their marker/box colours.
income_level_colors = {
    'Low': 'rgb(66, 158, 245)',
    'Lower Middle': 'rgb(66, 245, 105)',
    'Upper Middle': 'rgb(247, 200, 27)',
    'High': 'rgb(245, 66, 66)',
}

# Shared y-axis so the box plots line up with the scatter points they summarise.
fig = make_subplots(
    rows=1,
    cols=2,
    column_widths=[3, 1],
    shared_yaxes=True,
)

for income_level, color in income_level_colors.items():
    df = life_expectancy_at_birth_vs_life_extension.query(
        'world_bank_income_level == @income_level'
        ' and sex_id == @sex_id'
    )

    fig.add_trace(
        go.Scatter(
            x=df['E_x_val'],
            y=df['E_x_diff'],
            text=df['location_name'],
            mode='markers',
            marker=dict(color=color),
            name=income_level,
            legendgroup=income_level,
            showlegend=True,
        ),
        row=1,
        col=1,
    )

    # The legend entry comes from the scatter; the box joins its legend group.
    fig.add_trace(
        go.Box(
            y=df['E_x_diff'],
            marker_color=color,
            line_width=1,
            name=income_level,
            legendgroup=income_level,
            showlegend=False,
        ),
        row=1,
        col=2,
    )

fig.update_layout(
    template='plotly_white',
    width=1000,
    height=400,
    title=dict(
        text=(
            'Female life expectancy vs. life expectancy extension at birth'
            ' from excluding 26 <br>manageable risk factors in 204 countries, colored by income level'
        ),
        font_size=18,
        x=0.5,
        y=0.96,
        xanchor='center',
    ),
    xaxis=dict(
        title=dict(
            text='Life expectancy at birth (years)',
        )
    ),
    yaxis=dict(
        title=dict(
            text='Life expectancy extension at birth (years)',
        )
    ),
    legend=dict(
        title='World bank income level',
        orientation="h",
        y=1.2,
    ),
)

fig.show()

Life expectancy extension distribution by risk factors

In [62]:
# rei_ids of the risks listed under the 'Dietary risks' parent; these are
# collapsed into a single 'Dietary risks' value in the aggregations below.
dietary_risks_ids = [
    risks_name_to_id[dietary_risk_name]
    for dietary_risk_name in risks_parents_names_manageable['Dietary risks']
]

In [63]:
# Mean (over all countries) impact of each risk factor on life expectancy at birth,
# per sex. The individual dietary risks are collapsed into one 'Dietary risks' row
# whose val/lower/upper are the sums of the per-risk means.
dietary_risks_parent_id = risks_name_to_id['Dietary risks']

# Non-dietary risks. The 'Dietary risks' parent id is excluded as well, because it
# is re-added below as the aggregate of its children; a raw row for it would give
# two rows with the same rei_id per sex.
non_dietary_risks_impact_at_birth_mean_by_country = (
    risk_impact.query(
        'rei_id in @risk_impact_filtered_cur_age.rei_id.unique()'
        ' and age == 0'
        ' and rei_id not in @dietary_risks_ids'
        ' and rei_id != @dietary_risks_parent_id'
    )
    .groupby(by=['age', 'sex_id', 'rei_id'])
    [['val', 'lower', 'upper']].mean()
    .reset_index()
)

# Dietary risks: mean over countries per risk, then summed over the dietary risks.
dietary_risks_impact_at_birth_mean_by_country = (
    risk_impact.query(
        'rei_id in @dietary_risks_ids'
        ' and age == 0'
    )
    .groupby(by=['age', 'sex_id', 'rei_id'])
    [['val', 'lower', 'upper']].mean()
    .reset_index()
    .groupby(by=['age', 'sex_id'])
    [['val', 'lower', 'upper']].sum()
    .reset_index()
)

dietary_risks_impact_at_birth_mean_by_country['rei_id'] = dietary_risks_parent_id

# Combine, attach readable sex and risk names, and order by sex then risk id
# (this order drives the trace/legend order of the plots).
risk_impact_at_birth_mean_by_country = (
    pd.concat(
        [
            non_dietary_risks_impact_at_birth_mean_by_country,
            dietary_risks_impact_at_birth_mean_by_country
        ],
        axis=0
    )
    .assign(
        sex_name=lambda d: d.sex_id.map({v: k for k, v in sex_name_to_id.items()}),
        rei_name=lambda d: d.rei_id.map({v: k for k, v in risks_name_to_id.items()}),
    )
    .sort_values(by=['sex_id', 'rei_id'])
)

In [69]:
# One stacked bar per sex: each risk factor's share (%) of the summed mean
# life-expectancy extension at birth for that sex.
fig = go.Figure()

for sex_name in ['Male', 'Female']:
    # Rows for this sex and the share denominator; both are the same for every
    # risk, so they are computed once per sex instead of once per risk.
    df = risk_impact_at_birth_mean_by_country.query(
        'sex_name == @sex_name'
    )
    total_extension = df.val.sum()

    for rei_name in risk_impact_at_birth_mean_by_country.rei_name.unique():
        fig.add_trace(
            go.Bar(
                x=[sex_name],
                y=100 * (
                    df.query(
                        'rei_name == @rei_name'
                    ).val.values
                    /
                    total_extension
                ),
                marker=dict(
                    color=rei_color_map[rei_name],
                    line=dict(width=0.5, color='black')
                ),
                width=0.99,
                name=rei_name,
                legendgroup=rei_name,
                # One legend entry per risk: only the first sex's traces show it.
                showlegend=sex_name == 'Male'
            ),
        )

fig.update_traces(opacity=0.6)

fig.update_layout(
    template='plotly_white',
    height=600,
    title=dict(
        text=(
            'Comparing risk factors impact to the mean extension '
            'of life expectancy at birth by sex among 204 countries'
        ),
        font_size=15,
    ),
    barmode='stack',
    bargap=0.01,
    bargroupgap=0.001,
    yaxis=dict(
        title='Share of years of life expectancy extension (%)'
    ),
    xaxis=dict(
        title='Sex'
    ),
    legend=dict(
        title='Risk factor'
    )
)

fig.show()

In [67]:
# Mean (over the countries of each World Bank income level) impact of each risk
# factor on life expectancy at birth, per sex. The individual dietary risks are
# collapsed into one 'Dietary risks' row (sum of the per-risk means).
dietary_risks_parent_id = risks_name_to_id['Dietary risks']

risk_impact_at_birth_mean_by_income_list = []

for income_level in countries_by_income.world_bank_income_level.unique():

    locations_ids = countries_by_income.query("world_bank_income_level == @income_level").location_id.values

    # Non-dietary risks. The 'Dietary risks' parent id is excluded as well, because
    # it is re-added below as the aggregate of its children; a raw row for it would
    # give two rows with the same rei_id per sex and income level.
    non_dietary_impact = (
        risk_impact.query(
            'rei_id in @risk_impact_filtered_cur_age.rei_id.unique()'
            ' and age == 0'
            ' and location_id in @locations_ids'
            ' and rei_id not in @dietary_risks_ids'
            ' and rei_id != @dietary_risks_parent_id'
        )
        .groupby(by=['age', 'sex_id', 'rei_id'])
        [['val', 'lower', 'upper']].mean()
        .reset_index()
    )

    non_dietary_impact['world_bank_income_level'] = income_level

    risk_impact_at_birth_mean_by_income_list.append(non_dietary_impact)

    # Dietary risks: mean over the income level's countries per risk, then summed.
    dietary_risk_impact_at_birth_mean_by_income = (
        risk_impact.query(
            'rei_id in @dietary_risks_ids'
            ' and age == 0'
            ' and location_id in @locations_ids'
        )
        .groupby(by=['age', 'sex_id', 'rei_id'])
        [['val', 'lower', 'upper']].mean()
        .reset_index()
        .groupby(by=['age', 'sex_id'])
        [['val', 'lower', 'upper']].sum()
        .reset_index()
    )

    dietary_risk_impact_at_birth_mean_by_income['world_bank_income_level'] = income_level

    dietary_risk_impact_at_birth_mean_by_income['rei_id'] = dietary_risks_parent_id

    risk_impact_at_birth_mean_by_income_list.append(dietary_risk_impact_at_birth_mean_by_income)

# Combine, attach readable sex and risk names, and order by sex then risk id
# (this order drives the trace/legend order of the plots).
risk_impact_at_birth_mean_by_income = (
    pd.concat(
        risk_impact_at_birth_mean_by_income_list,
        axis=0
    )
    .assign(
        sex_name=lambda d: d.sex_id.map({v: k for k, v in sex_name_to_id.items()}),
        rei_name=lambda d: d.rei_id.map({v: k for k, v in risks_name_to_id.items()}),
    )
    .sort_values(by=['sex_id', 'rei_id'])
)

In [70]:
# Males: one stacked bar per World Bank income level showing each risk factor's
# share (%) of the summed mean life-expectancy extension at birth.
fig = go.Figure()

sex_name = 'Male'

for income_level in ['Low', 'Lower Middle', 'Upper Middle', 'High']:
    # `.assign` builds a new frame instead of writing a column into a query()
    # result (chained assignment).
    df = (
        risk_impact_at_birth_mean_by_income
        .query(
            'sex_name == @sex_name'
            ' and world_bank_income_level == @income_level'
        )
        .assign(val_normed=lambda d: 100 * (d['val'] / d['val'].sum()))
    )

    for rei_name in risk_impact_at_birth_mean_by_income.rei_name.unique():
        fig.add_trace(
            go.Bar(
                x=[income_level],
                y=df.query(
                    'rei_name == @rei_name'
                ).val_normed.values,
                marker=dict(
                    color=rei_color_map[rei_name],
                    line=dict(width=0.5, color='black')
                ),
                width=0.99,
                name=rei_name,
                legendgroup=rei_name,
                # One legend entry per risk: only the 'High' bar's traces show it.
                showlegend=income_level == 'High'
            ),
        )

fig.update_traces(opacity=0.6)

fig.update_layout(
    template='plotly_white',
    height=600,
    title=dict(
        text=(
            'Comparing risk factors impact to the mean extension '
            'of males life expectancy at birth by income level among 204 countries'
        ),
        font_size=15,
    ),
    barmode='relative',
    bargap=0.01,
    bargroupgap=0.001,
    yaxis=dict(
        title='Share of years of life expectancy extension (%)'
    ),
    xaxis=dict(
        title='World bank income level'
    ),
    # Reversed so the legend reads top-down in the same order as the stack.
    legend=dict(
        title='Risk factor',
        traceorder="reversed",
        tracegroupgap=20
    )
)

fig.show()

In [71]:
# Females: one stacked bar per World Bank income level showing each risk factor's
# share (%) of the summed mean life-expectancy extension at birth.
fig = go.Figure()

sex_name = 'Female'

for income_level in ['Low', 'Lower Middle', 'Upper Middle', 'High']:
    # `.assign` builds a new frame instead of writing a column into a query()
    # result (chained assignment).
    df = (
        risk_impact_at_birth_mean_by_income
        .query(
            'sex_name == @sex_name'
            ' and world_bank_income_level == @income_level'
        )
        .assign(val_normed=lambda d: 100 * (d['val'] / d['val'].sum()))
    )

    for rei_name in risk_impact_at_birth_mean_by_income.rei_name.unique():
        fig.add_trace(
            go.Bar(
                x=[income_level],
                y=df.query(
                    'rei_name == @rei_name'
                ).val_normed.values,
                marker=dict(
                    color=rei_color_map[rei_name],
                    line=dict(width=0.5, color='black')
                ),
                width=0.99,
                name=rei_name,
                legendgroup=rei_name,
                # One legend entry per risk: only the 'High' bar's traces show it.
                showlegend=income_level == 'High'
            ),
        )

fig.update_traces(opacity=0.6)

fig.update_layout(
    template='plotly_white',
    height=600,
    title=dict(
        text=(
            'Comparing risk factors impact to the mean extension '
            'of female life expectancy at birth by income level among 204 countries'
        ),
        font_size=15,
    ),
    # 'relative' keeps any negative shares below zero instead of overlapping the
    # positive stack, matching the male by-income chart.
    barmode='relative',
    bargap=0.01,
    bargroupgap=0.001,
    yaxis=dict(
        title='Share of years of life expectancy extension (%)'
    ),
    xaxis=dict(
        title='World bank income level'
    ),
    # Reversed so the legend reads top-down in the same order as the stack.
    legend=dict(
        title='Risk factor',
        traceorder="reversed",
        tracegroupgap=20
    )
)

fig.show()