In [None]:
import pandas as pd
import numpy as np
import xarray as xr
import matplotlib.pyplot as plt
import pltkit
import plotly.express as px
import seaborn as sns
import statsmodels.api as sm
from statsmodels.discrete.discrete_model import NegativeBinomial
from statsmodels.stats.outliers_influence import variance_inflation_factor as VIF
from process_files import secondary_operations
from sklearn.model_selection import train_test_split

# Root folder of the project (absolute Windows path on a mapped drive).
# NOTE(review): hard-coded, machine-specific path — consider making it configurable.
wdir = 'X:\\user\\liprandicn\\Health Impacts Model'

# Open the grouped country-level NetCDF file as an xarray Dataset.
country_data = xr.open_dataset(f'{wdir}\\Analysis\\grouped_country_data_1980-2019.nc')
# Project helper from `process_files`; presumably adds the derived variables used below — TODO confirm.
country_data = secondary_operations(country_data)
# Remove the "Taiwan" entry along the `country` dimension for every downstream cell.
country_data = country_data.drop_sel(country="Taiwan")

## Statistical Analysis

In [None]:
# Distribution of the response for one cause / age-group slice. Printing the mean
# next to the variance lets the reader compare the two directly.
y = country_data.loc[{'age_group':'young', 'cause_name':'Diabetes and kidney diseases'}]['relative_mortality']
y.plot.hist(bins=50)
# The label says "Variance", so report the variance (the previous code printed
# the standard deviation under that label).
print('Mean:', y.mean().values, 'Variance:', y.var().values)

In [None]:
# Bivariate analysis: correlation of each temperature statistic with relative mortality

country_df = country_data.to_dataframe().reset_index()

# Keep populous countries, the oldest age group and cardiovascular deaths only.
is_populous = country_df['population'] > 1e6
is_oldest = country_df['age_group'] == 'oldest'
is_cardiovascular = country_df['cause_name'] == 'Cardiovascular diseases'
pca_df = country_df[is_populous & is_oldest & is_cardiovascular]
# pca_df = pca_df.drop(columns=['country', 'year', 'age_group', 'cause_name', 'total_mor', 'population'])

# Candidate temperature-derived covariates.
temperature_vars = [
    'CDD_20', 'CDD_20_squared', 'CDD_23_3', 'CDD_25', 'HDD_15', 'HDD_18_3', 'HDD_20', 'HDD_20_squared',
    'degree_days', 'temperature_kurtosis', 'temperature_mean', 'temperature_skewness', 'temperature_std',
    'climatology', 'mean_relative_humidity', 'CDD_25_rh', 'log_kurtosis', 'log_CDD_20', 'log_HDD_20',
    'CDD_20_sq', 'HDD_20_sq', 'CDD_20_k', 'HDD_20_k', 'CDD_20_k_sq', 'HDD_20_k_sq', 'CDD_25_rh_k',
    'temperature_mean_sq', 'CDD_23_3_st', 'temperature_mean_cen', 'temperature_mean_sq_cen',
    'temperature_mean_3', 'temperature_mean_4', 'temperature_mean_loggdppc', 'temperature_mean_sq_loggdppc',
    'temperature_kurtosis_sq', 'degree_days_loggdppc', 'degree_days_sq', 'degree_days_sq_loggdppc',
    'degree_days_std', 'degree_days_std_sq', 'degree_days_std_loggdppc', 'CDD_20_k_loggdppc', 'HDD_20_k_loggdppc',
]

# Socio-economic covariates (defined here, not used in this cell).
socioeconomic_vars = [
    'gdp', 'gdppc', 'loggdppc', 'health_expenditure', 'medical_doctors', 'HDI', 'GINI', 'education_expenditure', 'schooling_years',
    'log_health_exp', 'urban_share', 'log_health_expenditure',
]

# Full correlation matrix, then only the column against relative mortality, sorted high to low.
correlations = pca_df[temperature_vars + ['relative_mortality']].corr()
mort_corr = correlations[['relative_mortality']].sort_values(by='relative_mortality', ascending=False)

# Figure height grows with the number of rows so every label stays readable.
plt.figure(figsize=(6, len(mort_corr) * 0.4))
sns.heatmap(
    mort_corr,
    annot=True,
    cmap='RdBu',
    center=0,
    cbar_kws={'label': 'Correlation'},
)
plt.title('Correlation between Temperature Statistics and Relative Mortality')
plt.tight_layout()
plt.show()

In [None]:
import matplotlib.pyplot as plt

# Select the series for Russia, cardiovascular diseases, oldest age group.
datax = country_data.sel(
    country=['Russia'],
    cause_name='Cardiovascular diseases',
    age_group='oldest'
).loggdppc.values

datay = country_data.sel(
    country=['Russia'],
    cause_name='Cardiovascular diseases',
    age_group='oldest'
).relative_mortality.values

# Scatter of relative mortality against loggdppc. The previous call passed
# `cmap` without any `c=` values, which matplotlib ignores with a warning, and
# its title described a time series although the x axis is loggdppc.
plt.scatter(datax, datay, s=50, alpha=0.7)
plt.xlabel('loggdppc')
plt.ylabel('relative_mortality')
plt.title("Relative mortality vs. loggdppc (Russia)")
plt.show()



In [None]:
# Flatten the dataset, then keep cardiovascular deaths of the oldest group in populous countries.
country_df = country_data.to_dataframe().reset_index()
selection_mask = (
    (country_df['cause_name'] == 'Cardiovascular diseases')
    & (country_df['age_group'] == 'oldest')
    & (country_df['population'] > 1e6)
)  # optional extra filter: country_data['cause_name'].isin(pltkit.diseases_level3.values())
test = country_df[selection_mask]

# Interactive scatter coloured by year; hovering a point shows its country.
fig = px.scatter(
    test,
    x='temperature_mean_loggdppc',
    y='relative_mortality',
    color='year',
    hover_data='country',
    height=500,
    width=700,
)
fig.update_layout(showlegend=False)
fig.show()

## OLS Regression

In [None]:
# OLS of relative mortality on a hand-picked set of predictors, followed by RSS and VIFs.
predictors = ['temperature_mean', 'temperature_std', 'temperature_mean_loggdppc', 'gdppc', 'schooling_years']

df = (
    country_data.loc[{'cause_name':'Cardiovascular diseases', 'age_group': 'oldest'}]
    .to_dataframe()
    .reset_index()
    .dropna()
    .drop(['country', 'age_group', 'cause_name'], axis=1)
)
y = df.pop('relative_mortality')

# Design matrix with an intercept column.
X = sm.add_constant(df[predictors])
model = sm.OLS(y, X)
results = model.fit()
print(results.summary())
print('RSS:', pltkit.rss(results.fittedvalues, y))

# Variance inflation factor of every design-matrix column (intercept included).
vif_data = pd.DataFrame({
    "variable": X.columns,
    "VIF": [VIF(X.values, col_idx) for col_idx in range(X.shape[1])],
})
print(vif_data)

In [None]:
# Single-predictor OLS used to inspect the leverage of every observation.
variable = 'temperature_mean'
df = country_data.loc[{'cause_name':'Cardiovascular diseases', 'age_group': 'oldest'}].to_dataframe().reset_index().dropna()
y = df['relative_mortality']
X = sm.add_constant(df[[variable]])

# Fit the model and read the diagonal of the hat matrix.
model = sm.OLS(y, X).fit()
influence = model.get_influence()
leverage = influence.hat_matrix_diag

# Usual leverage threshold 2p/n:
# n = number of observations, p = number of parameters (constant + variables).
n, p = X.shape[0], X.shape[1]
umbral = 2 * p / n

# Leverage against the predictor, with the threshold drawn as a dashed line.
plt.scatter(df['temperature_mean'], leverage)
plt.axhline(umbral, color='red', linestyle='--', label=f'Umbral (2p/n ≈ {umbral:.3f})')
plt.title("Leverage por observación")
plt.xlabel(variable)
plt.ylabel("Leverage")
plt.legend()
plt.show()



In [None]:
# Influence diagnostics of the model fitted in the previous cell.
influence = model.get_influence()
summary = influence.summary_frame()

# Attach the diagnostics to a copy of the data (aligned on the index).
df_out = df.assign(
    leverage=summary["hat_diag"],
    studentized_resid=summary["student_resid"],
    cooks_d=summary["cooks_d"],
)

# Potentially influential observations (Cook's distance above 0.5, or some other threshold).
is_influential = df_out["cooks_d"] > 0.5
influyentes = df_out[is_influential]

# Leverage vs. studentized residual; marker area scales with Cook's distance.
marker_sizes = 100 * df_out["cooks_d"]
plt.scatter(df_out["leverage"], df_out["studentized_resid"], s=marker_sizes, alpha=0.8)
plt.title("Influencia: leverage vs residuo")
plt.xlabel("Leverage")
plt.ylabel("Residuo studentizado")
# Reference lines at 0 and at +/-3 studentized residuals.
plt.axhline(0, color='black', linestyle='--')
plt.axhline(3, color='red', linestyle='--', label='±3')
plt.axhline(-3, color='red', linestyle='--')
plt.legend()
plt.show()

In [None]:
# Forward stepwise selection of OLS predictors for relative mortality
# (cardiovascular diseases, oldest age group).
df = country_data.loc[{'cause_name':'Cardiovascular diseases', 'age_group': 'oldest'}].to_dataframe().reset_index().dropna()

# Columns removed from the candidate pool before the search starts.
worng_cols = ['country', 'age_group', 'cause_name', 'log_rel_mor', 'total_mortality', 'log_mortality', 'health_expenditure', 'health_expenditure_k', 'population', 'gdppc', 'log_kurtosis', 'HDD_20', 'CDD_20', 'degree_days_std_sq',
              'temperature_mean_cen', 'temperature_mean_sq', 'log_health_expenditure', 'temperature_mean_sq_cen', 'degree_days_std', 'gdp', 'HDD_20_squared', 'degree_days_std_loggdppc', 'temperature_mean_4',
              'degree_days_sq_loggdppc', 'HDD_18_3', 'temperature_mean_3', 'CDD_23_3_st', 'degree_days_loggdppc', 'CDD_25', 'temperature_mean_sq_loggdppc', 'log_population', 'log_HDD_20', 'bin_15_20', 'HDI', 'temperature_mean_loggdppc', 'year',
              'CDD_20_k_loggdppc', 'bin_20_30', 'bin_10_20', 'bin_0_10', 'bin_20_25', 'CDD_25_rh', 'HDD_20_k_loggdppc', 'bin_m10_0', 'medical_doctors', 'CDD_23_3', 'bin_5_10', 'bin_0_5', 'bin_10_15', 'HDD_15', 'urban_share', 'temperature_mean',
              'log_health_exp', 'bin_m5_0', 'bin_25', 'CDD_20_sq', 'CDD_20_squared', 'log_CDD_20', 'bin_m10_m5', 'education_expenditure', 'bin_m15_m10', 'bin_m20_m10', 'climatology', 'bin_40', 'bin_30', 'bin_40_45', 'GINI',
              'bin_25_30', 'bin35_loggdppc', 'temperature_std', 'bin_m20_m15', 'bin_m20', 'bin_m15', 'bin_m10', 'schooling_years', 'CDD_20_k_sq', 'CDD_25_rh_k', 'bin_35_40', 'HDD_20_sq', 'bin_30_35', 'bin_30_40', 'HDD_20_k_sq', 'mean_relative_humidity',
              'temperature_kurtosis_sq', 'HDD_20_k']
df = df.drop(worng_cols, axis=1)
y = df.pop('relative_mortality')

# rss_matrix[i, c]: RSS at step i of the current model extended with candidate c.
rss_matrix = pd.DataFrame(columns=df.columns, dtype=float)
predictors = pd.DataFrame(index=df.index)   # retained predictors (no intercept column)
predictor_names = []                        # candidates in the order they were selected

for i in range(1,len(df.columns)):
    print(i)

    # Score every remaining candidate; the intercept is part of each trial model.
    for column in df.columns:
        X = pd.concat([predictors, df[column]], axis=1)
        X = sm.add_constant(X)
        model = sm.OLS(y, X)
        results = model.fit()
        rss_matrix.loc[i,column] = pltkit.rss(results.fittedvalues, y)

    # Pick the candidate with the lowest RSS. Casting to float keeps `idxmin`
    # away from object dtype; candidates already removed stay NaN and are skipped.
    best_column = rss_matrix.loc[i].astype(float).idxmin()
    predictor_names.append(best_column)
    predictors = pd.concat([predictors, df[best_column]], axis=1)
    df = df.drop(best_column, axis=1)

    # Prune predictors that are not significant at the 5 % level. The test model
    # carries an intercept, like the trial models above and the final model in
    # the next cell; the intercept itself is never a pruning candidate.
    model_pvalues = sm.OLS(y, sm.add_constant(predictors))
    results = model_pvalues.fit()
    slope_pvalues = results.pvalues.drop('const', errors='ignore')
    significant_cols = slope_pvalues[slope_pvalues <= 0.05].index
    predictors = predictors[significant_cols]

# VIFs on the design matrix including the intercept, as in the other VIF cells of this notebook.
vif_design = sm.add_constant(predictors)
vif_data = pd.DataFrame()
vif_data["variable"] = vif_design.columns
vif_data["VIF"] = [VIF(vif_design.values, k) for k in range(vif_design.shape[1])]

predictors

In [None]:
# Refit the OLS model on the predictors retained by the selection cell above,
# this time with an intercept column, and print its summary followed by the
# VIF table computed in that cell.
predictors = sm.add_constant(predictors)
model_pvalues = sm.OLS(y,predictors)
results = model_pvalues.fit()
print(results.summary())
print(vif_data)

In [31]:
def ols_errors(predictors, disease, age_group, n_iterations=15, test_size=0.2, random_state=None):
    """Estimate the hold-out RMSE of several OLS specifications by repeated random splits.

    Parameters
    ----------
    predictors : list of list of str
        One list of column names per model specification. An empty list gives
        an intercept-only model.
    disease : str
        Value of `cause_name` selected from the global `country_data`.
    age_group : str
        Value of `age_group` selected from the global `country_data`.
    n_iterations : int, default 15
        Number of random train/validation splits per specification.
    test_size : float, default 0.2
        Share of rows held out for validation in every split.
    random_state : int or None, default None
        When given, split `j` uses seed `random_state + j`, so every
        specification is scored on the same splits and the run is reproducible.
        `None` keeps the splits random.

    Returns
    -------
    pandas.DataFrame
        Rows are iterations (0 .. n_iterations-1), columns are the positions of
        the specifications in `predictors`; values are RMSE rounded to 3 decimals.
        (Earlier versions allocated 25 rows while filling only 15, leaving
        all-NaN rows at the bottom.)
    """
    # The data slice does not depend on the split or the specification: extract it once.
    data = country_data.loc[{'cause_name':disease, 'age_group':age_group}].to_dataframe().reset_index().dropna()

    errors_df = pd.DataFrame(index = range(n_iterations), columns = range(len(predictors)))

    for i, predictor in enumerate(predictors):
        for j in range(n_iterations):

            seed = None if random_state is None else random_state + j
            train_df, val_df = train_test_split(data, test_size=test_size, random_state=seed)

            # Fit on the training part (intercept added to the design matrix).
            X_train = sm.add_constant(train_df[predictor])
            y_train = train_df['relative_mortality']
            model = sm.OLS(y_train, X_train).fit()

            # Predict the held-out part with the same design layout.
            X_valid = sm.add_constant(val_df[predictor])
            y_valid = val_df['relative_mortality']
            y_pred = model.predict(X_valid)

            errors_df.loc[j,i] = np.round(pltkit.rmse(y_valid, y_pred),3)

    return errors_df

In [None]:
# Candidate OLS specifications; the empty list is the intercept-only baseline.
predictors = [[],['temperature_mean', 'temperature_std', 'gdppc'],
              ['temperature_mean', 'temperature_std', 'temperature_mean_loggdppc', 'gdppc'],
              ['temperature_mean', 'temperature_std', 'temperature_mean_loggdppc', 'loggdppc'],
              ['temperature_mean', 'temperature_std', 'temperature_mean_loggdppc', 'gdppc', 'schooling_years', 'GINI'],
              ['temperature_mean', 'temperature_std', 'gdppc', 'schooling_years', 'GINI'],
              ['degree_days', 'degree_days_std', 'gdppc'],
              ['HDD_20_squared', 'CDD_20_squared', 'gdppc'],
              ['HDD_20', 'CDD_20', 'loggdppc']
              ]
errors_df = ols_errors(predictors, 'Cardiovascular diseases', 'oldest')
labels = ['Intercept' if len(p) == 0 else ' + '.join(p) for p in predictors]

# One line per specification: validation error across the random splits.
for column,label in zip(errors_df.columns,labels):
    plt.plot(errors_df.index, errors_df[column], label=label)

plt.legend(loc='upper center', bbox_to_anchor=(0.5, -0.15), ncol=1)
# `ols_errors` stores pltkit.rmse values, so the axis is RMSE (it was labelled MAPE).
plt.ylabel('RMSE')
plt.xlabel('Iteration')

## Negative Binomial Regression

In [13]:
# Inputs for the negative binomial model: one row per observation of chronic
# respiratory diseases in the oldest age group, rows with missing values dropped.
X = country_data.loc[{'cause_name':'Chronic respiratory diseases', 'age_group': 'oldest'}].to_dataframe().reset_index().dropna()
# Response: the `total_mortality` column (removed from X by `pop`).
y = X.pop('total_mortality')
# Covariates used by the model cells below.
predictors = ['CDD_20_k', 'HDD_20_k', 'loggdppc', 'log_health_expenditure', 'climatology', 'GINI', 'urban_share']
# predictors = ['CDD_20_k', 'HDD_20_k']

In [56]:
# df = country_data.loc[{'cause_name':'Chronic respiratory diseases', 'age_group': 'oldest'}].to_dataframe().reset_index().dropna()
# df = df.drop(columns=['cause_name', 'age_group'])
# X = df.pipe(lambda df: pd.get_dummies(df, columns=['year', 'country']).astype(float))
# y = X.pop('total_mortality')

# fixed_effects = [col for col in X.columns if 'country' in col or 'year' in col]
# predictors = ['CDD_23_3', 'HDD_18_3', 'CDD_20_k_sq', 'HDD_20_k_sq', 'loggdppc', 'log_health_expenditure', 'climatology', 'GINI']
# predictors = list(set(fixed_effects + predictors))

In [None]:
### Estimate alpha
# Both fits share one exposure offset. Previously the auxiliary fit used
# log(population) while the GLM used log(population + 1).
log_population = np.log(X["population"])

nb = NegativeBinomial(y, X[predictors].values, offset=log_population.values,)
result_nb = nb.fit()
# The dispersion parameter is the last fitted parameter. Read it by position
# through a plain array so the lookup does not depend on the params container.
alpha = np.asarray(result_nb.params)[-1]
print(alpha)

# Negative binomial GLM with the dispersion fixed at the estimate above.
model = sm.GLM(
    y,
    X[predictors],
    offset=log_population,
    family=sm.families.NegativeBinomial(alpha=alpha),
)
result = model.fit()
print(result.summary())

# VIFs on the covariates plus an intercept column. `assign` returns a new frame,
# so no column is written into a slice of X.
Xi = X[predictors].assign(intercept=1)

vif_data = pd.DataFrame()
vif_data["variable"] = Xi.columns
vif_data["VIF"] = [VIF(Xi.values, i) for i in range(Xi.shape[1])]

print(vif_data)

print(f"AIC: {result.aic}")
print(f"Deviance: {result.deviance}")

In [None]:
# Observed relative mortality against the fitted counts rescaled by population * 1e5.
observed_rate = X['relative_mortality']
fitted_rate = result.fittedvalues/X['population']*1e5

plt.plot(observed_rate, fitted_rate, 'o')
plt.plot(observed_rate, observed_rate, '--', label='y = x')
plt.ylabel("fitted value")
plt.xlabel("observed value")
plt.legend()
# The metric is computed outside the f-string: reusing single quotes inside a
# single-quoted f-string is a SyntaxError before Python 3.12.
# NOTE(review): `mape` is not imported in the visible cells — confirm where it comes from.
mape_value = mape(observed_rate, fitted_rate)
plt.title(f'MAPE = {mape_value}')
# plt.xscale('log')
# plt.yscale('log')
plt.show()

In [None]:
# Response residuals (left) and Pearson residuals (right) against the observed counts.
f, axes = plt.subplots(1, 2, figsize=(17, 6))
axes[0].plot(y, result.resid_response, 'o')
axes[0].set_ylabel("Residuals")
axes[0].set_xlabel("$y$")
axes[1].plot(y, result.resid_pearson, 'o')
# Raw string: '\p' is not a valid escape sequence in a normal string literal.
axes[1].axhline(y=-1, linestyle=':', color='black', label=r'$\pm 1$')
axes[1].axhline(y=+1, linestyle=':', color='black')
axes[1].set_ylabel("Standardized residuals")
axes[1].set_xlabel("$y$")
# The only labelled artist lives on the right-hand axes.
axes[1].legend()
plt.show()

In [None]:
# Pearson chi-square statistic divided by the residual degrees of freedom of the fitted GLM.
R = result.pearson_chi2 / result.df_resid
print(R)  # 21.88

In [18]:
import warnings

def nb_errors(predictors, diseases, fixed_effects=False):
    """Fit a negative binomial GLM per (specification, age group, disease) and collect errors.

    For every combination the dispersion parameter is first estimated with
    `NegativeBinomial`, then a `sm.GLM` with that fixed dispersion and a
    log(population) offset is fitted to `total_mortality`. Fitted counts are
    rescaled by population * 1e5 and compared with `relative_mortality`.

    Parameters
    ----------
    predictors : list of list of str
        One list of covariate names per specification.
    diseases : iterable of str
        Values of `cause_name` to fit, read from the global `country_data`.
    fixed_effects : bool, default False
        When True, country and year dummy columns are added to the design matrix.

    Returns
    -------
    errors_df : pandas.DataFrame
        Index is the position of the specification; columns are a MultiIndex
        (metric in {'RMSE', 'MAPE'}, age group, disease), values rounded to 3 decimals.
    params_df : pandas.DataFrame
        Indexed by the covariates of `predictors[0]`; columns are a MultiIndex
        ({'params', 'std', 'p-value'}, age group, disease). Results are aligned
        on covariate name, so only covariates of the first specification are
        kept, and with several specifications the last fit overwrites earlier ones.

    Notes
    -----
    `mape` and `rmse` are looked up as globals.
    NOTE(review): they are not imported in the visible cells — confirm their origin.
    """
    age_groups = ['oldest', 'old', 'young']
    diseases = list(diseases)

    errors_df = pd.DataFrame(index = range(len(predictors)), columns=pd.MultiIndex.from_product([['RMSE', 'MAPE'], age_groups, diseases]))
    params_df = pd.DataFrame(index = predictors[0], columns=pd.MultiIndex.from_product([['params', 'std', 'p-value'], age_groups, diseases]))

    for i, predictor in enumerate(predictors):

        for age_group in age_groups:
            for disease in diseases:

                df = country_data.loc[{'cause_name':disease, 'age_group': age_group}].to_dataframe().reset_index().dropna()

                if fixed_effects:
                    # One-hot encode country and year. Dummies are recognised by the
                    # get_dummies prefix, so covariates such as 'schooling_years'
                    # are not mistaken for year effects.
                    X = pd.get_dummies(df.drop(columns=['cause_name', 'age_group']), columns=['country', 'year']).astype(float)
                    dummy_cols = [col for col in X.columns if col.startswith(('country_', 'year_'))]
                    # Built fresh per fit: neither the flag nor the specification is
                    # overwritten, so dummies never leak into the next iteration.
                    design_cols = list(dict.fromkeys(list(predictor) + dummy_cols))
                else:
                    X = df.drop(columns=['country', 'year'])
                    # Copy so that `pop` below acts on an independent frame, not a slice.
                    X = X[X['cause_name'].isin(list(pltkit.diseases_level3.values())[:-2])].copy()
                    design_cols = list(predictor)

                y = X.pop('total_mortality')
                log_population = np.log(X["population"])

                with warnings.catch_warnings():
                    warnings.simplefilter("ignore")

                    # Estimate the dispersion parameter (last fitted parameter, read by position).
                    nb = NegativeBinomial(y, X[design_cols].values, offset=log_population.values,)
                    result_nb = nb.fit()
                    alpha = np.asarray(result_nb.params)[-1]

                    # Main negative binomial regression model
                    model = sm.GLM(
                        y,
                        X[design_cols],
                        offset=log_population,
                        family=sm.families.NegativeBinomial(alpha=alpha),
                        )
                    result = model.fit()

                # Observed and fitted rates come from the same frame, so their rows always align.
                observed_rate = X['relative_mortality']
                fitted_rate = result.fittedvalues/X['population']*1e5

                # errors_df.loc[i,('deviance', age_group, disease)] = np.round(result.deviance,3)
                # errors_df.loc[i,('akaike', age_group, disease)] = np.round(result.aic, 3)
                errors_df.loc[i,('MAPE', age_group, disease)] = np.round(mape(observed_rate, fitted_rate),3)
                errors_df.loc[i,('RMSE', age_group, disease)] = np.round(rmse(observed_rate, fitted_rate),3)

                params_df.loc[:,('params',age_group,disease)] = result.params
                params_df.loc[:,('std',age_group,disease)] = result.bse
                params_df.loc[:,('p-value',age_group,disease)] = result.pvalues

    return errors_df, params_df

In [None]:
# Specifications to evaluate (a single one here).
predictors_set = [
    ['CDD_20_k', 'HDD_20_k', 'CDD_20_k_sq', 'HDD_20_k_sq', 'loggdppc', 'log_health_expenditure', 'climatology', 'GINI'],
]

errors_df, params_df = nb_errors(predictors_set, list(pltkit.diseases_level3.values())[:-2], fixed_effects=False)

# One heatmap per metric, side by side, sharing the y axis.
fig, axes = plt.subplots(1, 2, figsize=(25,8), sharey=True)

for ax, metric in zip(axes, ['RMSE', 'MAPE']):
    sub_df = errors_df[metric]

    # Colour scale limits taken from this metric's own range.
    vmin, vmax = sub_df.min().min(), sub_df.max().max()

    sns.heatmap(
        sub_df.astype(float),
        ax=ax,
        annot=True,
        annot_kws={"size": 8, "rotation": 90},
        fmt=".1f",
        cmap='YlGnBu',
        vmin=vmin,
        vmax=vmax,
        cbar=True,
    )
    ax.set_title(f"{metric}")

# Row labels spell out each specification as "a + b + ...".
row_labels = [' + '.join(spec) for spec in predictors_set]
axes[0].set_yticklabels(row_labels,rotation = 0)
                        #  r'$\alpha \cdot CDD_{20} + \beta \cdot CDD^2_{20} + \gamma \cdot HDD_{20} + \delta + \cdot + HDD^2_{20} + \epsilon \cdot T_{clim} + \zeta \cdot \text{log(GDPpc)} + \eta \cdot \text{log(HEpc)} + \theta \cdot \text{GINI}$'
                        #  ],rotation = 0)

plt.tight_layout()
plt.show()

In [None]:
# Write the parameter table returned by nb_errors to a CSV file under the project's Analysis folder.
params_df.to_csv(f'{wdir}\\Analysis\\params_1.csv')

In [None]:
# Repeated random 80/20 validation of one negative binomial specification,
# per disease and age group.
predictor = ['CDD_20_k', 'HDD_20_k', 'CDD_20_k_sq', 'HDD_20_k_sq', 'loggdppc', 'log_health_expenditure', 'climatology', 'GINI']

errors_len = 20
cv_diseases = list(pltkit.diseases_level3.values())[:-2]
errors_df = pd.DataFrame(index = range(errors_len), columns=pd.MultiIndex.from_product([['RMSE', 'MAPE'], ['oldest', 'old', 'young'], cv_diseases]))

# Iterate per disease
for disease in cv_diseases:
    for age_group in ['oldest', 'old', 'young']:

        # Select disease and age_group
        data_disease = country_data.loc[{'cause_name':disease, 'age_group':age_group}].to_dataframe().reset_index().dropna()

        for j in range(errors_len):

            # Randomly select the training rows; the remaining rows form the validation set
            data_train = data_disease.sample(frac=0.8)
            data_valid = data_disease.loc[~data_disease.index.isin(data_train.index)]

            # Copy so that `pop` acts on an independent frame rather than a slice.
            X = data_train[data_train['cause_name'].isin(cv_diseases)].copy()
            # Column names follow the rest of the notebook (`total_mortality`,
            # `relative_mortality`); this cell used `total_mor` / `rel_mor` before.
            y = X.pop('total_mortality')

            with warnings.catch_warnings():
                warnings.simplefilter("ignore")

                # Estimate dispersion parameter (last fitted parameter, read by position)
                nb = NegativeBinomial(y, X[predictor].values, offset=np.log(X["population"].values),)
                result_nb = nb.fit()
                alpha = np.asarray(result_nb.params)[-1]

                # Main negative binomial regression model
                model = sm.GLM(y, X[predictor], offset=np.log(X["population"]), family=sm.families.NegativeBinomial(alpha=alpha),)
                result = model.fit()

                # Predict with the specification fitted above (`predictor`), not
                # with the unrelated global `predictors`.
                X_valid = data_valid[predictor]
                y_valid = result.predict(X_valid, offset=np.log(data_valid['population']))
                y_real = data_valid['relative_mortality']

            # NOTE(review): `rmse` and `mape` are not imported in the visible cells — confirm their origin.
            errors_df.loc[j, ('RMSE', age_group, disease)] = np.round(rmse(y_real, y_valid/data_valid['population']*1e5),3)
            errors_df.loc[j, ('MAPE', age_group, disease)] = np.round(mape(y_real, y_valid/data_valid['population']*1e5),3)

In [None]:
# Subplot grid: one row, one panel per error metric.
nrow = 1
ncol = 2
fig = plt.figure(1, figsize=(10,6))

# Diseases paired with fixed colours; zip stops at the shorter of the two lists.
disease_colors = list(zip(list(pltkit.diseases_level3.values())[:-2], ['C0', 'C1', 'C2', 'C3']))

for i, error in enumerate(['RMSE', 'MAPE']):
    # Activate the panel of this metric.
    plt.subplot(nrow, ncol, i+1)

    # Age groups are told apart by opacity (1.0, 0.6, 0.2).
    for j, age_group in enumerate(['oldest', 'old', 'young']):
        for disease, color in disease_colors:
            plt.scatter(
                errors_df.index,
                errors_df.loc[:,(error, age_group, disease)],
                alpha=1-j*0.4,
                color=color,
                label=f'{age_group}-{disease}',
            )
    plt.title(error)

# Legend is attached to the last active panel and placed below the figure.
plt.legend(loc='lower center', bbox_to_anchor=(-0.1, -.4), ncol=3)
# plt.tight_layout()
# plt.tight_layout()

## Plots

In [None]:
def equation(X, params, intercept=False):
    """Evaluate a linear predictor from covariate values and coefficients.

    The function used to be defined twice in this cell; the second definition
    (no intercept) silently replaced the first. Both variants now live in one
    function, and the default reproduces the definition that was in effect.

    Parameters
    ----------
    X : iterable
        Covariate values (scalars or arrays).
    params : sequence
        Coefficients. With ``intercept=True`` the first element is the
        intercept and the remaining ones are paired with ``X``.
    intercept : bool, default False
        Whether ``params[0]`` is an intercept term.

    Returns
    -------
    The sum of ``param * x`` over the paired elements (plus the intercept when
    requested). Pairing stops at the shorter of the two inputs.
    """
    if intercept:
        b0 = params[0]
        betas = params[1:]
        return b0 + sum(beta * x for beta, x in zip(betas, X))
    return sum(param * x for param, x in zip(params, X))

# Interactive scatter of `rel_mor` against `temperature_mean`, with year, country and loggdppc on hover.
# NOTE(review): `country_data` is created with xr.open_dataset in the loading cell, and other
# cells read the response as 'relative_mortality'; confirm that this cell still matches the
# current data layout (a DataFrame with a 'rel_mor' column appears to be expected here).
fig = px.scatter(country_data, x='temperature_mean', y='rel_mor', hover_data=['year', 'country', 'loggdppc'], height=500, width=800)

# t_mean = np.linspace(country_data['temperature_mean'].min(), country_data['temperature_mean'].max(), 100)

# for loggdppc in range(14, 15):
#     y_vals = equation(([loggdppc]), list(model.params.values[:-172]))
    
#     fig.add_trace(go.Scatter(x=t_mean, y=y_vals, mode='lines', name=f'kurtosis={loggdppc}'))

# fig.update_layout(showlegend=True)
fig.show()

In [None]:
# Pairwise correlation of the temperature statistics.
# `country_data` is an xarray Dataset, which has no `.corr()`; the selected
# variables are converted to a DataFrame first.
temperature_stats = ['temperature_mean', 'temperature_std', 'temperature_skewness', 'temperature_kurtosis', 'CDD_20', 'HDD_20',
                     'CDD_23_3', 'HDD_18_3', 'CDD_25', 'HDD_15', 'degree_days']
sns.heatmap(country_data[temperature_stats].to_dataframe().corr(), annot=True, cmap = 'RdBu')
# sns.heatmap(country_data[['loggdppc', 'health_expenditure', 'medical_doctors', 'GINI', 'schooling_years', 'education_expenditure']].corr(), annot=True, cmap='RdBu')

In [None]:
# Histogram of CDD_25 shifted by 3. `kind='hist'` is the pandas plotting API;
# an xarray DataArray exposes histograms as `.plot.hist`, as used earlier in this notebook.
(country_data['CDD_25']+3).plot.hist(bins=10)

In [None]:
# Grid of scatter plots (3 x 2 axes), one per disease in pltkit.diseases_level3:
# mean temperature against relative mortality for the oldest age group.
# NOTE(review): the boolean-mask indexing and `.unique()` below treat `country_data` as a pandas
# DataFrame with a 'rel_mor' column, while the loading cell creates an xarray Dataset and other
# cells use 'relative_mortality' — confirm against the current data layout.
# NOTE(review): only six axes are created; confirm that diseases_level3 has at most six entries.
fig, axs = plt.subplots(3,2, figsize=(12,16))
axs = axs.flatten()

for i, disease in enumerate(pltkit.diseases_level3.values()):
    # Rows of the current disease, oldest age group only.
    gbd_disease = country_data[(country_data['cause_name'] == disease) & (country_data['age_group']=='oldest')].copy()
    
    # fig = px.scatter(gbd_disease, x='temperature_mean', y='rel_mor', color='year',  hover_name='country', title=disease,
    #                  labels={'temperature_mean': 'Mean Temperature (°C)', 'rel_mor': 'Relative Mortality', 'year': 'Year', 'country': 'Country'},
    #                  template='plotly_white', height=500, width=700)
    # fig.show()
    
    # The label is the first distinct year found in the subset.
    axs[i].scatter(gbd_disease['temperature_mean'], gbd_disease['rel_mor'], label=gbd_disease['year'].unique()[0], alpha=0.8, zorder=2, c='C2')
    # Project helper applying labels, title, grid and background colour to the axes.
    pltkit.stylize_axes(axs[i], xlabel='Mean Temperature (°C)', ylabel='Relative Mortality', title=disease, grid=True, grid_kwargs= {'color':'white', 'zorder':'0'}, facecolor='whitesmoke')
    
plt.tight_layout()

In [None]:
# Age profile for enteric infections: sum over countries, average over years,
# and drop the last age group.
# NOTE(review): this cell reads `rel_mor` / `total_mor`, while other cells use
# 'relative_mortality' / 'total_mortality' — confirm the variable names.
data = country_data.sum(dim='country') \
    .sel(cause_name='Enteric infections') \
    .mean(dim='year') \
    .isel(age_group=slice(None, -1))

fig, ax1 = plt.subplots()

# Left axis: relative mortality.
data.rel_mor.plot(ax=ax1, color='tab:blue', label='Relative Mortality')
ax1.set_xlabel('Age group')
ax1.set_ylabel('Relative Mortality', color='tab:blue')
ax1.tick_params(axis='y', labelcolor='tab:blue')
ax1.set_xticklabels(data.age_group.values, rotation=45)
# ax1.set_title('Mortality by Age Group')

# Right axis (shared x): total mortality.
ax2 = ax1.twinx()

data.total_mor.plot(ax=ax2, color='tab:red')
ax2.set_ylabel('Total mortality', color='tab:red')
ax2.tick_params(axis='y', labelcolor='tab:red')

plt.tight_layout()
# `plt.show` without parentheses only referenced the function; call it.
plt.show()

In [None]:
# Heatmap of every remaining variable, min-max scaled per column, by year, for
# China / oldest age group / cardiovascular diseases.
# NOTE(review): the boolean-mask indexing, `.drop(columns=...)` and `.set_index(...)` treat
# `country_data` as a pandas DataFrame, while the loading cell creates an xarray Dataset; the
# dropped names 'GDP' and 'total_mor' also differ from the 'gdp' / 'total_mortality' names used
# in other cells — confirm against the current data layout.
df_heatmap = country_data[(country_data['country']=='China') & (country_data['age_group'] == 'oldest')
                            & (country_data['cause_name'] == 'Cardiovascular diseases')]
df_heatmap = df_heatmap.drop(columns=['country', 'age_group', 'cause_name', 'GDP', 'total_mor'])
df_heatmap = df_heatmap.set_index('year')
# Column-wise min-max scaling to [0, 1]; a constant column divides by zero here.
df_normalized = df_heatmap.apply(lambda x: (x - x.min()) / (x.max() - x.min()), axis=0)

sns.heatmap(df_normalized, annot=False)