In [None]:
#Utils
import os
import pandas as pd
import numpy as np
from joblib import dump
import joblib
import yaml

# basic plotting
from interpret import set_visualize_provider
from interpret.provider import InlineProvider
from interpret import show
from interpret.glassbox._ebm._research import *

# more detailed plotting
import seaborn as sns
import matplotlib.pyplot as plt
import plotly.express as px
from plotly.subplots import make_subplots
import pprint

In [None]:
# Notebook configuration: pipeline stage, project parameters and output locations.
stage = "train"

# Read the project parameters; the context manager closes the file handle
# instead of leaving it open for the lifetime of the kernel.
with open("/workspace/growthcurves/params.yaml") as params_file:
    params = yaml.safe_load(params_file)
paths = params["config"]["paths"]

# NOTE(review): absolute, run-specific paths; adjust these when pointing the
# notebook at a different results folder.
root_path = '/workspace/data/out/results_2024_09_25/target_12_months_sds/'
data_save_root = f"{root_path}/reports/figures/ebm_explainability/"

# Work relative to the results folder and report the directory actually in use
# (os.curdir is just the constant "." and says nothing about the location).
os.chdir(root_path)
print(os.getcwd())

# Create the figure folder, including any missing parent folders; a no-op if it
# already exists.
os.makedirs(data_save_root, exist_ok=True)

In [None]:

# Locations of the persisted model and of the modelling dataset.
model_path = os.path.join(root_path, paths["models"], "other", "ExplainableBoostingRegressor_model.joblib")
model_data_path = os.path.join(root_path, paths["features"], "modelling_dataset.joblib")

# NOTE(review): joblib files are pickle-based; only load artifacts from a trusted source.
ebm = joblib.load(model_path)
modelling_dataset = joblib.load(model_data_path)

# Unpack the train/test splits stored in the dataset mapping.
x_train, y_train, x_test, y_test = (
    modelling_dataset[split_name]
    for split_name in ('x_train', 'y_train', 'x_test', 'y_test')
)

## Global explanations

In [None]:
# Default interpret visualisation of the global explanation, rendered through
# the inline provider.
inline_provider = InlineProvider()
set_visualize_provider(inline_provider)

global_exp = ebm.explain_global()
show(global_exp)

## Prettier global explanation for manuscript

In [None]:
# Maps raw feature names to the labels shown in the manuscript figures.
#
# Ordering matters: the renaming loop over variable_names_dict.items() applies
# these entries one after another as plain substring replacements, so a key
# that is a prefix of another key (e.g. 'age' vs 'age_gh_start') must come
# AFTER the longer key. Keep that order when adding entries.
variable_names_dict = {
    'height_velocity_gh_start_1': 'Height velocity one year before treatment',
    'igf_1_up': 'IGF-1 increase',
    'gh_dose_proportional_delta_3_m': 'Change in GH dose last 3 months',
    'sd_weight_delta_12_m': 'Change in weight last 12 months',
    'age_gh_start': 'Age at treatment start',
    'ostradiol_imputed_up': 'Estradiol increase',
    'gh_max_stimulation': 'Stimulated GH peak',
    'birth_weight': 'Birth weight',
    'testosteron_imputed': 'Testosterone',
    'testicle_size_imputed': 'Testicle size',
    'perc_change_igf1_gh_dos_date_3m': 'Change in IGF-1 first 3 months on treatment',
    'target_height_deficit_sds_delta_12_m': 'Change in target height deficit last 12 months',
    'age': 'Age',
    'igf_1_sds': 'IGF-1 SDS',
    'gh_dose_proportional': 'GH dose',
    # Trailing space removed from this label so it does not leak into titles.
    'igf_1_igfbp_3': 'IGF-1/IGFBP3',
    'target_height_deficit_sds': 'Target height deficit',
    'height_velocity_1': 'Growth second year of life',
    'height_velocity_0': 'Growth first year of life',
    'birth_length': 'Birth length',
    'ostradiol_imputed': 'Estradiol',
    'sd_weight': 'Weight',
}


In [None]:
# Collect term names and scores from the global explanation into a table,
# ranked from the highest to the lowest score.
explanation_data = global_exp.data()
features = explanation_data['names']
scores = explanation_data['scores']

global_explanation = (
    pd.DataFrame({'names': features, 'scores': scores})
    .sort_values("scores", ascending=False)
)

# Derive the display names from the raw names: every known raw feature name is
# swapped for its label as a literal (non-regex) substring replacement, in the
# order the mapping lists them.
display_names = global_explanation['names']
for raw_name, label in variable_names_dict.items():
    display_names = display_names.str.replace(raw_name, label, regex=False)
global_explanation['display_name'] = display_names


In [None]:
# Display the table of terms with their scores and display names for inspection.
global_explanation

In [None]:
# Bar chart of the 15 highest-scoring terms for the manuscript.
#
# The theme is configured BEFORE drawing and in a single call: seaborn/matplotlib
# settings only affect artists created afterwards, and each separate sns.set(...)
# call resets the options it is not given (style, font scale) to their defaults.
# Setting everything up front makes a fresh run and a re-run of this cell
# produce the same figure.
sns.set_theme(style="whitegrid", font_scale=1.4, rc={'figure.figsize': (4, 4)})

g = sns.barplot(global_explanation[0:15], y="display_name", x="scores",
                color='#005B89')
g.set_xticks([0.1, 0.2, 0.3])
g.set_ylabel('')
g.set_xlabel('Mean Absolute Score (Weighted)')
sns.despine(left=True)
g.figure.savefig(f"{data_save_root}/global_explainability_summary_seaborn.png", bbox_inches='tight')

# Local explanations
## 3 months feature importance

In [None]:
# Features shown in the 3-month figure, in panel order.
variables_to_visualize = [
    'igf_1_sds',
    'height_velocity_gh_start_1',
    'gh_dose_proportional',
    'igf_1_igfbp_3',
    'igf_1_up',
]
variable_names_in_order = global_exp.data()['names']

# Display label of each selected feature (first matching row of the table).
variable_display_names = [
    global_explanation.loc[global_explanation['names'] == raw_name, 'display_name'].iloc[0]
    for raw_name in variables_to_visualize
]

# Insert a line break into the long label so it can be used as a panel title.
wrapped_titles = {
    "Height velocity one year before treatment": "Height velocity one <br>year before treatment",
}
variable_display_names = [wrapped_titles.get(title, title) for title in variable_display_names]

In [None]:
# 3m predictions figure with 3m predictor specific adjustments.
# Layout: one column per selected feature; row 1 holds the score curve of the
# feature, row 2 the histogram of data points.
ytickvals_frequency = [0,40,80]
bigger_plot_lims = [0.3, 1]     # y-domain used for the score panels (row 1)
smaller_plot_lims = [0, 0.1]    # y-domain used for the data-point panels (row 2)
axis_title_font_size = 18
# Features plotted with 'no'/'yes' tick labels at 0/1 instead of numeric ticks.
binary_variables = ['igf_1_up']
n_panels = len(variables_to_visualize)

mainfig = make_subplots(rows=2, cols=n_panels,
                        shared_yaxes = 'rows',
                        shared_xaxes = 'columns',
                        subplot_titles=variable_display_names)

for col, variable_name in enumerate(variables_to_visualize, start=1):
    ind = variable_names_in_order.index(variable_name)
    plot_traces = global_exp.visualize(ind).data
    # The last trace goes to the lower panel and everything before it to the
    # upper panel. For the 4-trace figures this notebook was written against,
    # that is exactly traces 0-2 above and trace 3 below, but it no longer
    # raises IndexError when a figure carries fewer traces.
    # NOTE(review): assumes the data-point histogram is always the last trace.
    mainfig.add_traces(list(plot_traces[:-1]), rows=1, cols=col)
    mainfig.add_traces([plot_traces[-1]], rows=2, cols=col)

# Common axis styling for every panel.
axis_style = dict(
    mirror=False,
    ticks='outside',
    showline=True,
    linecolor='black',
    gridcolor='gray',
    zerolinecolor='black',
)
mainfig.update_xaxes(**axis_style)
mainfig.update_yaxes(**axis_style)

mainfig.update_layout(
    autosize=False,
    minreducedwidth=250,
    minreducedheight=250,
    width=1500,
    height=650,
    plot_bgcolor='white',
    font=dict(size=18))

# Relative panel sizes: tall score panels on top, short data-point panels below.
mainfig.update_yaxes(domain=bigger_plot_lims, row=1)
mainfig.update_yaxes(domain=smaller_plot_lims, tickvals=ytickvals_frequency, row=2)
# Only the left-most panel of each row carries a y-axis title.
mainfig.update_yaxes(title=dict(text='Score', standoff=0, font_size=axis_title_font_size),
                     row=1, col=1)
mainfig.update_yaxes(title=dict(text='# data-<br>points', standoff=5, font_size=axis_title_font_size),
                     row=2, col=1)

# Show tick labels and the feature label under the score panels as well.
mainfig.update_xaxes(showticklabels=True, row=1)
for col, (variable_name, display_name) in enumerate(
        zip(variables_to_visualize, variable_display_names), start=1):
    mainfig.update_xaxes(title=dict(text=display_name, font_size=axis_title_font_size),
                         row=1, col=col)
    if '<br>' in display_name:
        # Two-line label: remove the gap between the axis and its title.
        mainfig.update_xaxes(title_standoff=0, row=1, col=col)
    if variable_name in binary_variables:
        # Applies to both panels of the column.
        mainfig.update_xaxes(showticklabels=True, tickvals=[0, 1], ticktext=['no', 'yes'], col=col)

mainfig.update_annotations(font=dict(family="sans-serif", size=22))

# Prefix every panel title with its letter and left-align it at a fixed
# horizontal position (positions tuned by hand for a five-column figure).
labels = ['A', 'B', 'C', 'D', 'E']
titlepos = [0, .207, .41, .62, 0.83]
for label, title_x, display_name in zip(labels, titlepos, variable_display_names):
    mainfig.update_annotations(selector={"text": display_name},
                               text=f"{label}) {display_name}",
                               x=title_x, xanchor='left', yref='paper')

mainfig.update_traces(marker_color='rgba(0,91,137,1)')
mainfig.update_layout(showlegend=False)
# Saved under a 3m-specific name: the 12-month figure cell writes
# ebm_variable_explanations_for_manuscript_seaborn.png, so sharing that name
# would silently overwrite this figure whenever the whole notebook is run.
mainfig.write_image(f"{data_save_root}/ebm_variable_explanations_3m_for_manuscript_seaborn.png")

mainfig.show()

## 12 months feature importance

In [None]:
# Features shown in the 12-month figure, in panel order.
variables_to_visualize = [
    'igf_1_sds',
    'age',
    'target_height_deficit_sds',
    'birth_length',
    'testosteron_imputed',
]
variable_names_in_order = global_exp.data()['names']

# Display label of each selected feature (first matching row of the table).
variable_display_names = [
    global_explanation.loc[global_explanation['names'] == raw_name, 'display_name'].iloc[0]
    for raw_name in variables_to_visualize
]

# Insert a line break into the long label, should it be among the selection.
wrapped_titles = {
    "Height velocity one year before treatment": "Height velocity one <br>year before treatment",
}
variable_display_names = [wrapped_titles.get(title, title) for title in variable_display_names]

In [None]:
# 12m predictions figure with 12m predictor specific adjustments.
# Layout: one column per selected feature; row 1 holds the score curve of the
# feature, row 2 the histogram of data points.
ytickvals_frequency = [0,40,80]
bigger_plot_lims = [0.3, 1]     # y-domain used for the score panels (row 1)
smaller_plot_lims = [0, 0.1]    # y-domain used for the data-point panels (row 2)
axis_title_font_size = 18
n_panels = len(variables_to_visualize)

mainfig = make_subplots(rows=2, cols=n_panels,
                        shared_yaxes = 'rows',
                        shared_xaxes = 'columns',
                        subplot_titles=variable_display_names)

for col, variable_name in enumerate(variables_to_visualize, start=1):
    ind = variable_names_in_order.index(variable_name)
    plot_traces = global_exp.visualize(ind).data
    # The last trace goes to the lower panel and everything before it to the
    # upper panel. For the 4-trace figures this notebook was written against,
    # that is exactly traces 0-2 above and trace 3 below, but it no longer
    # raises IndexError when a figure carries fewer traces.
    # NOTE(review): assumes the data-point histogram is always the last trace.
    mainfig.add_traces(list(plot_traces[:-1]), rows=1, cols=col)
    mainfig.add_traces([plot_traces[-1]], rows=2, cols=col)

# Common axis styling for every panel.
axis_style = dict(
    mirror=False,
    ticks='outside',
    showline=True,
    linecolor='black',
    gridcolor='gray',
    zerolinecolor='black',
)
mainfig.update_xaxes(**axis_style)
mainfig.update_yaxes(**axis_style)

mainfig.update_layout(
    autosize=False,
    minreducedwidth=250,
    minreducedheight=250,
    width=1500,
    height=650,
    plot_bgcolor='white',
    font=dict(size=18))

# Relative panel sizes: tall score panels on top, short data-point panels below.
mainfig.update_yaxes(domain=bigger_plot_lims, row=1)
mainfig.update_yaxes(domain=smaller_plot_lims, tickvals=ytickvals_frequency, row=2)
# Only the left-most panel of each row carries a y-axis title.
mainfig.update_yaxes(title=dict(text='Score', standoff=0, font_size=axis_title_font_size),
                     row=1, col=1)
mainfig.update_yaxes(title=dict(text='# data-<br>points', standoff=5, font_size=axis_title_font_size),
                     row=2, col=1)

# Show tick labels and the feature label under the score panels as well.
mainfig.update_xaxes(showticklabels=True, row=1)
for col, display_name in enumerate(variable_display_names, start=1):
    mainfig.update_xaxes(title=dict(text=display_name, font_size=axis_title_font_size),
                         row=1, col=col)
# Tick labels are also requested explicitly for the last lower panel.
mainfig.update_xaxes(showticklabels=True, row=2, col=n_panels)

mainfig.update_annotations(font=dict(family="sans-serif", size=22))

# Prefix every panel title with its letter and left-align it at a fixed
# horizontal position (positions tuned by hand for a five-column figure).
labels = ['A', 'B', 'C', 'D', 'E']
titlepos = [0, .207, .41, .62, 0.83]
for label, title_x, display_name in zip(labels, titlepos, variable_display_names):
    mainfig.update_annotations(selector={"text": display_name},
                               text=f"{label}) {display_name}",
                               x=title_x, xanchor='left', yref='paper')

mainfig.update_traces(marker_color='rgba(0,91,137,1)')
mainfig.update_layout(showlegend=False)
mainfig.write_image(f"{data_save_root}/ebm_variable_explanations_for_manuscript_seaborn.png")

mainfig.show()

## Plot variables one at a time

In [None]:

# Axis styling shared by both axes of every single-feature figure.
single_axis_style = dict(
    mirror=False,
    ticks='outside',
    showline=True,
    linecolor='black',
    gridcolor='gray',
)

# One stand-alone figure per currently selected feature.
for variable_name in variables_to_visualize:
    ind = variable_names_in_order.index(variable_name)
    fig = global_exp.visualize(ind)
    name_matches = global_explanation['names'] == variable_name
    display_name = global_explanation.loc[name_matches, 'display_name'].iloc[0]

    fig.update_layout(
        autosize=False,
        minreducedwidth=250,
        minreducedheight=250,
        width=500,
        height=700,
        plot_bgcolor='white',
        font=dict(size=18),
        title=dict(text=display_name),
        xaxis=dict(title=dict(text=display_name)),
        # Hide tick labels and title on the second x-axis.
        xaxis2=dict(showticklabels=False, title=None),
    )
    fig.update_xaxes(**single_axis_style)
    fig.update_yaxes(**single_axis_style)
    fig.update_traces(marker_color='rgba(0,91,137,1)')

    # Export is disabled; uncomment to save each figure.
    #fig.write_image(f"{data_save_root}/global_explainability_{variable_name}_seaborn.png")
    fig.show()


In [None]:
# Export the unmodified interpret figures: first the overall summary ...
fig_global_summary = global_exp.visualize()
fig_global_summary.write_image(f"{data_save_root}/global_explainability_summary.png")

# ... then one figure per entry of the global explanation's selector, named
# after the entry.
feature_names_global = global_exp.selector.Name.tolist()
for key, feature_name in enumerate(feature_names_global):
    fig_global = global_exp.visualize(key)
    fig_global.write_image(f"{data_save_root}/global_explainability_{feature_name}.png")

# Local explanations for the first five test rows; show the first one inline.
local_exp = ebm.explain_local(x_test[:5], y_test[:5])
show(local_exp, 0)

# Save one local-explanation figure per explained row.
n_explained = len(local_exp.selector)
for person_idx in range(n_explained):
    fig_local = local_exp.visualize(person_idx)
    fig_local.write_image(f"{data_save_root}/ebm_local_explainability_person{person_idx}.png")

# Checking importance of specific features

In [None]:
# Features treated as one group when measuring their joint importance.
sex_hormones_group = [
    "ostradiol_imputed",
    "testosteron_imputed",
    "testicle_size_imputed",
    'testosteron_imputed_up',
    'testosteron_imputed_down',
    'ostradiol_imputed_up',
    'ostradiol_imputed_down',
    'testicle_size_delta_12_m',
]

# Importance of the whole group, computed on the test features.
importance = compute_group_importance(sex_hormones_group, ebm, x_test)
print(f"Group: {sex_hormones_group} - Importance: {importance}")

In [None]:
# Print the importance of every model term next to its name.
names = ebm.term_names_
importances = ebm.term_importances()

for term_name, importance in zip(names, importances):
    print(f"Term {term_name} importance: {importance}")