# Dashboard

Present the results in an interpretable manner, with interactive plots, explanations in natural language, and justifications for each possible diagnosis.

In [1]:
# Patient to be studied: this identifier selects the `results/<patient_id>/`
# folder from which the loading cells below read their input files.
patient_id = '126_S_4458'

## Imports and data loading

In [2]:
# Numerical and data manipulation
import numpy as np
import pandas as pd

# Language model and tokenizer
from transformers import AutoModelForCausalLM, AutoTokenizer

# Neuroimaging tools
import nibabel as nib
from nilearn import plotting, datasets

# Visualization tools
from PIL import Image
import plotly.express as px
import matplotlib.pyplot as plt
import plotly.graph_objects as go
from matplotlib.colors import LinearSegmentedColormap

# Dashboard utils
import dash_bootstrap_components as dbc
from dash import Dash, html, dash_table, dcc, callback, \
    callback_context, Output, Input, State

# Other utils
import os
import time
import requests
from tqdm.notebook import tqdm

In [3]:
# Load the per-patient inputs: model probabilities, decision-tree path,
# reference ranges and one explanation table per candidate diagnosis
cat_map = {0: 'Cognitive Normal', 1: 'Mild Cognitive Impairment', 2: "Alzheimer's Disease"}

path = f'results/{patient_id}/'

# Each `<model>-probs.txt` file holds tab-separated floats, one per class index
probs = dict()
for model_key in ('xgboost', 'rf'):
    with open(path + f"{model_key}-probs.txt", 'r') as f:
        probs[model_key] = [float(token) for token in f.read().strip().split('\t')]

# Re-key every probability list by the readable class name
cognitive_confidences = {
    model_key: {cat_map[idx]: prob for idx, prob in enumerate(model_probs)}
    for model_key, model_probs in probs.items()
}

# Decision-tree path: entries separated by ';'
with open(path + "decision_path.txt", "r") as f:
    dt_decisions = f.read().split(";")

# Reference range per feature (headerless, ';'-separated file)
norm_ranges = pd.read_csv(
    'norm_range.csv',
    sep=';',
    names=['feature', 'lower', 'upper'],
)

# Explanation tables for all three candidate diagnoses (CN / MCI / AD)
explanations_cn = pd.read_csv(path + 'explanations-cn.csv')
explanations_mci = pd.read_csv(path + 'explanations-mci.csv')
explanations_ad = pd.read_csv(path + 'explanations-ad.csv')

## Brain graph

In [4]:
# Preprocess explanation dataframe
def preprocessing_brain_df(explanations_df, model):
    """Reduce an explanation table to one importance value per brain region.

    Only rows whose feature name starts with ``aseg`` or ``cerebellum`` are
    kept. The region name is the part of the feature name after ``stats-``
    with its last ``-``-separated token removed. When several rows map to the
    same region, the importance with the largest absolute value is kept
    (sign preserved; the first row wins on ties).

    Parameters
    ----------
    explanations_df : pandas.DataFrame
        Must contain the columns ``feature`` and ``importance-<model>``.
        Its index may be arbitrary; it is neither used nor modified.
    model : str
        Suffix selecting the importance column (``importance-<model>``).

    Returns
    -------
    pandas.DataFrame
        Columns ``region_name`` and ``value``, one row per region, ordered by
        region name, with a fresh 0..n-1 index. Empty (but with both columns)
        when no brain feature is present.

    Raises
    ------
    IndexError
        If a kept feature name does not contain ``stats-``.
    """
    # Select only original feature names and importance weights for brain plots
    df = explanations_df[['feature', f'importance-{model}']].copy()
    df.columns = ['region_name', 'value']

    # Keep brain-derived rows only. The mask is an index-aligned Series so the
    # selection works for any index and for an empty input frame.
    is_brain = pd.Series(
        [name.startswith(('aseg', 'cerebellum')) for name in df['region_name']],
        index=df.index,
        dtype=bool,
    )
    df = df.loc[is_brain].reset_index(drop=True)
    if df.empty:
        return df

    # Remove the starting code (everything up to 'stats-') and the final
    # '-'-separated token so that all rows of one region share the same name
    df['region_name'] = [
        '-'.join(name.split('stats-')[1].split('-')[:-1])
        for name in df['region_name']
    ]

    # Visualize the most relevant statistic per brain region:
    # pick, per region, the row whose |value| is largest
    strongest_rows = df['value'].abs().groupby(df['region_name']).idxmax()
    return df.loc[strongest_rows].reset_index(drop=True)

# One region-level importance table per model ("xgboost", "rf") for each
# candidate diagnosis: cognitively normal (cn), MCI (mci) and Alzheimer's (ad)
df_cn = {model: preprocessing_brain_df(explanations_cn, model) for model in ["xgboost", "rf"]}
df_mci = {model: preprocessing_brain_df(explanations_mci, model) for model in ["xgboost", "rf"]}
df_ad = {model: preprocessing_brain_df(explanations_ad, model) for model in ["xgboost", "rf"]}

In [5]:
# Create the NIfTI image and the slice view plot

# FreeSurfer colour look-up table (label ID -> structure name)
LUT_URL = (
    'https://raw.githubusercontent.com/freesurfer/freesurfer/dev/'
    'distribution/FreeSurferColorLUT.txt'
)

# Parsed look-up tables keyed by URL, so the file is downloaded only once
# per kernel session instead of once per viewer
_LUT_CACHE = {}


def _load_lut(url=LUT_URL, timeout=30):
    """Download and parse the FreeSurfer LUT into a DataFrame (``Label``, ``ID``).

    Blank lines, ``#`` comment lines and lines with fewer than two fields are
    skipped. The result is cached per URL.

    Raises
    ------
    requests.RequestException
        On network failure, timeout (``timeout`` seconds) or HTTP error status.
    """
    if url not in _LUT_CACHE:
        # Without a timeout a stalled connection would block the notebook forever
        response = requests.get(url, timeout=timeout)
        response.raise_for_status()

        records = []
        for line in response.text.splitlines():
            line = line.strip()
            if not line or line.startswith('#'):
                continue
            parts = line.split(None, 6)
            if len(parts) < 2:
                continue
            region_id, label = parts[:2]
            records.append({'Label': label, 'ID': int(region_id)})

        _LUT_CACHE[url] = pd.DataFrame(records)
    return _LUT_CACHE[url]


def surface_representation(df):
    """Paint region importances onto the patient's atlases and return viewer HTML.

    Reads ``aparc.DKTatlas+aseg.deep.mgz`` and ``cerebellum.CerebNet.nii.gz``
    from the module-level ``path`` folder. Every voxel whose label name (per
    the FreeSurfer LUT) equals a ``region_name`` in ``df`` receives that
    region's ``value``.

    Parameters
    ----------
    df : pandas.DataFrame
        Columns ``region_name`` and ``value``. If a region name appears more
        than once, the first row is used.

    Returns
    -------
    str
        HTML of a nilearn interactive slice viewer (red = negative,
        white = zero, green = positive).
    """
    # From the patient atlas image, load data on top
    atlas_img_co = nib.load(os.path.join(path, 'aparc.DKTatlas+aseg.deep.mgz'))
    atlas_data_co = atlas_img_co.get_fdata().astype(int)

    atlas_img_ce = nib.load(os.path.join(path, 'cerebellum.CerebNet.nii.gz'))
    atlas_data_ce = atlas_img_ce.get_fdata().astype(int)

    lut = _load_lut()
    label_to_name = dict(zip(lut['ID'], lut['Label']))

    # One dictionary lookup per label instead of a DataFrame scan per label;
    # setdefault keeps the first occurrence of a repeated region name
    region_values = {}
    for region, value in zip(df['region_name'], df['value']):
        region_values.setdefault(region, float(value))

    # Build the stat map
    # NOTE(review): the cerebellum mask is applied to an array shaped like the
    # cortical atlas, so both volumes are assumed to share one voxel grid - confirm.
    stat_map = np.zeros_like(atlas_data_co, dtype=float)
    for label_idx, region in tqdm(label_to_name.items()):
        if region in region_values:
            stat_map[atlas_data_co == label_idx] += region_values[region]
            stat_map[atlas_data_ce == label_idx] += region_values[region]

    # Make a NIfTI and plot
    stat_img = nib.Nifti1Image(stat_map, atlas_img_co.affine)

    cmap = LinearSegmentedColormap.from_list(
        'red_white_green',
        ['red', 'white', 'green']
    )
    maxabs = np.max(np.abs(stat_map))

    slice_view = plotting.view_img(
        stat_img,
        threshold=0.1*maxabs,  # hide voxels below 10 % of the strongest value
        cmap=cmap,
        vmax=maxabs,
        width_view=1250,
    )

    html_view = slice_view._repr_html_()
    return html_view

# Build one slice viewer per (diagnosis, model) pair: six viewers in total.
# Each call reloads the patient's segmentation volumes from `path`.
html_view_cn = {k: surface_representation(df) for k, df in df_cn.items()}
html_view_mci = {k: surface_representation(df) for k, df in df_mci.items()}
html_view_ad = {k: surface_representation(df) for k, df in df_ad.items()}

  0%|          | 0/1804 [00:00<?, ?it/s]

  a.partition(kth, axis=axis, kind=kind, order=order)


  0%|          | 0/1804 [00:00<?, ?it/s]

  a.partition(kth, axis=axis, kind=kind, order=order)


  0%|          | 0/1804 [00:00<?, ?it/s]

  a.partition(kth, axis=axis, kind=kind, order=order)


  0%|          | 0/1804 [00:00<?, ?it/s]

  a.partition(kth, axis=axis, kind=kind, order=order)


  0%|          | 0/1804 [00:00<?, ?it/s]

  a.partition(kth, axis=axis, kind=kind, order=order)


  0%|          | 0/1804 [00:00<?, ?it/s]

  a.partition(kth, axis=axis, kind=kind, order=order)


## Language model

### Prompt engineering

In [None]:
# System message sent to the language model. Every line of the literal starts
# with one space so that `.replace("\n ", " ")` folds it into a single line.
system_prompt = """
 You are a clinician. The user will provide a dataset in JSON format
 containing a collection of supporting features and counterpoint features.
 For each feature, there is the value and its expected range.
 The text must be based only on the provided data.""".replace("\n ", " ")

def get_user_prompt(explanations_df, prediction, model):
    """Assemble the user message for one (diagnosis, model) combination.

    The message is a fixed task description followed by up to seven
    supporting features (importance > 0.01, strongest first), up to seven
    counterpoint features (importance < -0.01, most negative first) and the
    prediction under discussion. Each feature is rendered with its readable
    name, its value rounded to two decimals and its expected range taken from
    the module-level ``norm_ranges`` table.

    Parameters
    ----------
    explanations_df : pandas.DataFrame
        Needs the columns ``feature``, ``name``, ``value`` and
        ``importance-<model>``.
    prediction : str
        Diagnosis label written on the ``Prediction:`` line.
    model : str
        Suffix selecting the importance column.

    Returns
    -------
    str
        The complete user prompt.

    Raises
    ------
    IndexError
        If a feature has no row in ``norm_ranges``.
    """
    # The task and output format definition are the same for all user prompts
    start = """
     Summarise the patient data into 3 cohesive and medical term-rich paragraphs
     of 100 words each, without bullet points. First, the diagnosis
     support (use only features from the Support section). Second, the
     diagnosis counterpoints (use only features from the Counterpoints
     section). Third, categorise the given diagnosis into possible
     or impossible based on the data and justify.""".replace("\n     ", " ")

    def expected_range(feature_id):
        # First matching row of the reference table supplies both bounds
        bounds = norm_ranges[norm_ranges['feature'] == feature_id]
        return f"[{bounds['lower'].iloc[0]}, {bounds['upper'].iloc[0]}]"

    def serialise(rows):
        # One dict-like string per feature, with all quote characters removed
        records = rows[['name', 'value', 'expected range']].T.to_dict()
        return [str(rec).replace("'", "").replace('"', "") for rec in records.values()]

    # Gather patient data
    importance_col = f'importance-{model}'
    df = explanations_df[['feature', 'name', 'value', importance_col]]
    df = df.rename(columns={importance_col: 'importance'})
    df['value'] = [float(f'{val:.2f}') for val in df['value']]
    df['expected range'] = [expected_range(col) for col in df['feature']]

    # Support data
    supporting = df[df['importance'] > 0.01]
    if supporting.empty:
        p_prompt = 'No significative supporting data'
    else:
        p_prompt = serialise(
            supporting.sort_values(by='importance', ascending=False)[:7]
        )

    # Counterpoint data
    opposing = df[df['importance'] < -0.01]
    if opposing.empty:
        n_prompt = 'No significative counterpoints'
    else:
        n_prompt = serialise(
            opposing.sort_values(by='importance', ascending=True)[:7]
        )

    # Join data
    return f'{start}\n\nSupport: {p_prompt}\nCounterpoints: {n_prompt}\nPrediction: {prediction}'

In [59]:
# Construct final message
def get_messages(explanations_df, prediction, model):
    """Return the two-turn chat (system + user) for one diagnosis and model.

    The system turn is the module-level ``system_prompt``; the user turn is
    produced by ``get_user_prompt`` from the same three arguments.
    """
    user_turn = {
        "role": "user",
        "content": get_user_prompt(explanations_df, prediction, model),
    }
    system_turn = {"role": "system", "content": system_prompt}
    return [system_turn, user_turn]

# Chat messages per model for each candidate diagnosis; the prediction label
# passed to the prompt is the readable class name from `cat_map`
messages_cn = {model: get_messages(explanations_cn, cat_map[0], model) for model in ['xgboost', 'rf']}
messages_mci = {model: get_messages(explanations_mci, cat_map[1], model) for model in ['xgboost', 'rf']}
messages_ad = {model: get_messages(explanations_ad, cat_map[2], model) for model in ['xgboost', 'rf']}

In [60]:
messages_cn['xgboost'][1]

{'role': 'user',
 'content': " Summarise the patient data into 3 cohesive and medical term-rich paragraphs of 100 words each, without bullet points. First, the diagnosis support (use only features from the Support section). Second, the diagnosis counterpoints (use only features from the Counterpoints section). Third, categorise the given diagnosis into possible or impossible based on the data and justify.\n\nSupport: ['{name: Volume in cubic millimeters of the Cerebellum CrusI region in the right hemisphere., value: 9052.63, expected range: [8211.81, 11130.2]}', '{name: Baseline Intracranial volume (mm³), value: 1405700.0, expected range: [1342166.32, 1642541.39]}', '{name: Standard deviation of the normalized intensity of the Cerebellum CrusII region in the right hemisphere., value: 10.2, expected range: [9.83, 11.59]}']\nCounterpoints: ['{name: Baseline Logical Memory Delayed Recall total score, value: 6.0, expected range: [10.05, 16.71]}', '{name: ADNI‐modified Preclinical Alzheimer

In [61]:
# Load fine-tuned model and tokenizer
# NOTE(review): "lfm2-sft" is presumably a local checkpoint directory - confirm.
# NOTE(review): trust_remote_code=True lets the checkpoint run its own Python
# code at load time; only use it with checkpoints from a trusted source.
model_id = "lfm2-sft"
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto",
    torch_dtype="bfloat16",
    trust_remote_code=True,
    
)
tokenizer = AutoTokenizer.from_pretrained(model_id)

In [62]:
import time
# Generate explanation texts
def generate_text(message):
    """Generate the assistant reply for one chat and return it as plain text.

    Uses the module-level ``tokenizer`` and ``model``. Sampling is enabled, so
    repeated calls with the same input may return different texts. The elapsed
    time is printed.

    Parameters
    ----------
    message : list of dict
        Chat turns (``role`` / ``content``) as produced by ``get_messages``.

    Returns
    -------
    str
        Only the newly generated text, without the prompt.
    """
    t0 = time.perf_counter()
    input_ids = tokenizer.apply_chat_template(
        message,
        add_generation_prompt=True,
        return_tensors="pt",
        tokenize=True,
    ).to(model.device)

    output = model.generate(
        input_ids,
        do_sample=True,
        temperature=0.3,
        min_p=0.15,
        repetition_penalty=1.2,
        max_new_tokens=1024,
    )

    # The returned sequence starts with the prompt tokens. Slicing them off by
    # length is robust, whereas splitting the decoded text on 'assistant\n'
    # breaks when that marker occurs in the prompt or in the reply.
    prompt_length = input_ids.shape[-1]
    generated_ids = output[0][prompt_length:]
    explanation_text = tokenizer.decode(generated_ids, skip_special_tokens=True)

    print(f"It has taken {time.perf_counter() - t0} seconds to generate")
    return explanation_text

# Six generations in total (3 diagnoses x 2 models). This is the slowest cell
# of the notebook: the recorded run took roughly 70-180 s per text.
explanation_texts_cn = {model: generate_text(messages_cn[model]) for model in ['xgboost', 'rf']}
explanation_texts_mci = {model: generate_text(messages_mci[model]) for model in ['xgboost', 'rf']}
explanation_texts_ad = {model: generate_text(messages_ad[model]) for model in ['xgboost', 'rf']}

It has taken 100.560387134552 seconds to generate
It has taken 72.66641855239868 seconds to generate
It has taken 178.2550323009491 seconds to generate
It has taken 141.52759432792664 seconds to generate
It has taken 91.74632334709167 seconds to generate
It has taken 116.47807955741882 seconds to generate


In [63]:
# Split every generated text into a list of lines (one dashboard paragraph each).
def _as_paragraphs(text):
    """Return ``text`` split on newlines; pass through values already split.

    The pass-through keeps this cell idempotent: it overwrites its own inputs,
    so a second run would otherwise call ``.split`` on lists and fail.
    """
    return text.split("\n") if isinstance(text, str) else text


explanation_texts_cn = {key: _as_paragraphs(value) for key, value in explanation_texts_cn.items()}
explanation_texts_mci = {key: _as_paragraphs(value) for key, value in explanation_texts_mci.items()}
explanation_texts_ad = {key: _as_paragraphs(value) for key, value in explanation_texts_ad.items()}

## Dashboard

In [48]:
from plotly.subplots import make_subplots

def feature_value_graph(df, model):
    """Horizontal percentile bars, one subplot per feature group.

    Rows without a ``group`` are ignored. Within each group the (at most) ten
    features with the largest absolute importance are shown, strongest first.
    Bar length and colour encode the percentile on a fixed 0-100 scale; the
    importance and the raw value are written to the right of each bar.

    Parameters
    ----------
    df : pandas.DataFrame
        Needs the columns ``feature``, ``name``, ``value``,
        ``importance-<model>``, ``percentile`` and ``group``.
    model : str
        Suffix selecting the importance column.

    Returns
    -------
    plotly.graph_objects.Figure
    """
    max_rows = 10                # bars shown per group
    pct_min, pct_max = 0, 100    # fixed colour range shared by all subplots
    percentile_colours = [
        [0, '#FF5964'], [0.2, '#FFE74C'],
        [0.5, '#6BF178'], [0.8, '#FFE74C'], [1, '#FF5964']]

    importance_col = f'importance-{model}'
    df = df[['feature', 'name', 'value', importance_col, 'percentile', 'group']]
    df = df.rename(columns={importance_col: 'importance'})

    # Shorten segmentation feature names to the part after "stats-"
    df['feature'] = [
        feat.split("stats-")[1] if feat.startswith(("aseg", "cereb")) else feat
        for feat in df['feature']
    ]

    # Prepare plot: only grouped rows take part
    df = df[df["group"].notna()]
    groups, counts = np.unique(df["group"], return_counts=True)
    n_groups = len(groups)

    # Distribute row heights after limiting at max_rows per group
    shown_counts = [min(count, max_rows) for count in counts]
    total = sum(shown_counts)
    row_heights = [count / total for count in shown_counts]

    fig = make_subplots(
        rows=n_groups, cols=1,
        shared_xaxes=True,
        row_heights=row_heights,
        subplot_titles=list(groups),
        vertical_spacing=0.05
    )

    # Add one colored-bar trace per group
    for row_idx, grp in enumerate(groups, start=1):
        top = (
            df[df["group"] == grp]
            .sort_values("importance", key=lambda col: col.abs(), ascending=False)
            .head(max_rows)
        )
        percentiles = top["percentile"]

        bars = go.Bar(
            x=percentiles,
            y=top["feature"],
            orientation="h",
            customdata=top["name"],
            hovertemplate="<b>%{customdata}</b>",
            text=[f"{int(pct)}" for pct in percentiles],
            textposition="outside",
            marker=dict(
                color=percentiles,
                colorscale=percentile_colours,
                cmin=pct_min,
                cmax=pct_max,
            ),
            showlegend=False
        )
        fig.add_trace(bars, row=row_idx, col=1)
        fig.update_yaxes(autorange="reversed", row=row_idx, col=1)

        # Two text columns in the right margin, aligned with the bars through
        # the subplot's own y axis (bar position = row position in `top`)
        axis_ref = f'y{row_idx}'
        for pos, (importance, value) in enumerate(zip(top["importance"], top["value"])):
            fig.add_annotation(
                xref="paper", x=1.055,
                yref=axis_ref, y=pos,
                text=f'{importance:.2f}',
                showarrow=False,
                xanchor="right",
                yanchor="middle"
            )
            fig.add_annotation(
                xref="paper", x=1.125,
                yref=axis_ref, y=pos,
                text=f'{value:.2f}',
                showarrow=False,
                xanchor="left",
                yanchor="middle"
            )

    # Headers of the two text columns
    fig.add_annotation(
        xref="paper", yref="paper",
        x=1.08, y=1,
        text="<b>Importance</b>",
        showarrow=False,
        xanchor="right",
        yanchor="bottom"
    )
    fig.add_annotation(
        xref="paper", yref="paper",
        x=1.115, y=1,
        text="<b>Value</b>",
        showarrow=False,
        xanchor="left",
        yanchor="bottom"
    )

    # Dark theme and wide side margins for feature names and text columns
    fig.update_layout(
        plot_bgcolor='#222222',
        paper_bgcolor='#222222',
        font_color='white',
        height=1200,
        title="Feature percentiles",
        margin=dict(t=100, l=200, r=200, b=50)
    )

    return fig


In [49]:
def feature_importance_graph(df, model, n):
    """Bar chart of the ``n`` most positive and ``n`` most negative importances.

    Features are ordered from the highest to the lowest importance. When the
    table has ``2 * n`` rows or fewer, the two tails would overlap, so every
    feature is shown exactly once instead of being duplicated.

    Parameters
    ----------
    df : pandas.DataFrame
        Needs the columns ``feature``, ``name``, ``value``,
        ``importance-<model>``, ``percentile`` and ``group``.
    model : str
        Suffix selecting the importance column.
    n : int
        Number of bars taken from each end of the ranking; values below 1
        select nothing.

    Returns
    -------
    plotly.graph_objects.Figure
    """
    importance_col = f'importance-{model}'
    df = df[['feature', 'name', 'value', importance_col, 'percentile', 'group']]
    df = df.rename(columns={importance_col: 'importance'})

    # Sort by importance weight
    df = df.sort_values(by='importance', ascending=False).reset_index(drop=True)

    # Keep both tails of the ranking. `len(df) - n` is used rather than `-n`
    # because `iloc[-0:]` would return the whole table.
    n = max(int(n), 0)
    if 2 * n < len(df):
        df = pd.concat([df.iloc[:n], df.iloc[len(df) - n:]])
    df['importance'] = [float(f'{val:.3f}') for val in df['importance']]

    # Adapt feature names: drop the segmentation prefix up to "stats-"
    df['feature'] = [
        feat.split('stats-')[1] if feat.startswith(('aseg', 'cereb')) else feat
        for feat in df['feature']
    ]

    # Create Plotly figure
    fig = go.Figure()

    # Single trace with one bar per selected feature
    fig.add_trace(go.Bar(
        x=df['feature'],
        y=df['importance'],
        name='Importance',
        marker_color='#35A7FF',
        text=df['importance'],
        textposition='outside',
        customdata=df['name'],
        hovertemplate=
            '%{text:.4f}' +
            '<br>%{customdata}'
    ))

    # Layout adjustments
    fig.update_layout(
        barmode='group',
        plot_bgcolor='#222222',
        paper_bgcolor='#222222',
        font_color='white',
        legend=dict(bgcolor='#222222', bordercolor='white'),
        yaxis=dict(
            showgrid=False,
            showticklabels=False,
            title_text='',
        ),
        hovermode="x unified",
        margin=dict(l=20, r=20, t=40, b=40)
    )
    return fig

In [50]:
# Info panel content
info_panel_content = [
    # Section 1: Introduction
    html.H3("Introduction"),
    html.P(["""This dashboard contains the prediction summary 
    for the selected patient.""", html.Br(), html.Br(),
    
    """Created by Èric Sánchez López for the master's thesis on 
    Artificial Intelligence at KU Leuven."""]),
    
    # Section 2: First section
    html.H3("First section"),
    html.P(["""The percentages shown on top of the dashboard
    quantify the predicted probability for the patient to be on
    each of those states. The probability can change depending 
    on the chosen model, as well as the relevant features and 
    explanations described further down. The model can be chosen 
    using the bottom middle button that contains a spinning arrow.
    The available models are:""", html.Br(), html.Br(),

    "XGBoost: 86% normalized accuracy.", html.Br(),
    "Random Forest: 84% normalized accuracy.", html.Br(), html.Br(),
    
    """To choose the cognitive state to get a description from,
    select the corresponding state on the 
    "Category" drop-down menu at the left."""]),
    
    # Section 3: Explanation plot and texts
    html.H3("Diagnosis support"),
    html.P(["""The most interesting part is the description of
    the features and the explanation to support them. Inside
    the section, a horizontal barplot is shown to represent the 
    position of the patient's feature value compared with the 
    rest of the population, as a percentile of the distribution.
    On the right of the bar, the actual value is displayed, 
    together with the importance weight. The importance 
    represents the relevance of that feature's value for the
    model when predicting the selected cognitive state. If the 
    weight is negative, it means that it goes against the 
    selected prediction. To get the full name in natural 
    language, place the cursor on top of the corresponding bar.
    """, html.Br(), html.Br(),

    """
    Additionally, there is an AI-generated explanation of the
    decision based on the previously mentioned values. The
    explanations try to support and tell the counterpoints of
    the selected category, providing a final conclusion. The
    conclusion tends to be safe and generally proposes an
    optimistic diagnosis, so the user is encouraged to read
    each feature and reach their own conclusion.""", 
    html.Br(), html.Br(),

    """ Warning: the AI-generated text can make mistakes,
    the only correct values are those present on the plots.
    """
    ]),
    
    # Section 4: Feature importance plot and brain graph
    html.H3("Final graphs"),
    html.P(["""To conclude the summary, two additional graphs can
    be observed at the bottom of the dashboard. First, an 
    interactive slice view of the brain. The colored sections 
    relate to important features of the model. A green region has
    a positive impact on the prediction, and a red region has a 
    negative one (e.g. when predicting Alzheimer's, a red region 
    implies that the brain is healthier than expected in that 
    section). Not only is volume being represented, but also MRI 
    statistics, like average intensity.""", html.Br(), html.Br(),
    
    """To conclude, there is a feature
    importance bar plot, which represents the importance weights
    that were previously shown to the right of the percentile plot
    and visually on the slice view brain plot.
    With the positive and negative features being represented on 
    the same plot, the relative importance is more easily 
    comparable. Taller bars mean greater impact on the final 
    prediction, with positive values supporting the prediction
    and negative values doing the opposite.
    """]),

    # Section 5: Decision Tree path
    html.H3("Decision Tree path"),
    html.P("""To complement the approximate descriptions given by
    the XGBoost and Random Forest models, an additional white-box
    model Decision Tree has been trained on the same data and 
    provides the exact path taken to the final prediction. With a 
    normalized accuracy of 84%, it has a comparable performance 
    to its black-box counterparts, but it is fully interpretable. 
    A full schema of the decision tree can be visualized and 
    interacted with using the buttons "Show/Hide Decision Tree" 
    and "Zoom In/Out"."""),
]  

In [51]:
# INTERACTIVE DASHBOARD

# Structure dashboard and add all sections
# Display colour per cognitive state (keys are the class names of `cat_map`)
textcolor = {
    'Cognitive Normal': '#6BF178', 
    'Mild Cognitive Impairment': '#FFE74C', 
    "Alzheimer's Disease": '#FF5964'
}
# Class name -> class index; the index selects the entry of the per-category
# lists handed to `create_dash_app` and forms the ids of the probability labels
reverse_cat_map = {
    'Cognitive Normal': 0,
    'Mild Cognitive Impairment': 1,
    "Alzheimer's Disease": 2
}

# Initialize Dash app
def create_dash_app(dfs, explanations, confidences, html_views):
    """Build the Dash application that presents the prediction summary.

    The dashboard opens on the XGBoost model and on its most probable
    category. A drop-down selects the category, a button toggles between
    XGBoost and Random Forest, and a slider sets the number of bars per side
    in the importance chart. The layout also uses the module-level
    ``info_panel_content``, ``dt_decisions``, ``textcolor`` and
    ``reverse_cat_map``.

    Parameters
    ----------
    dfs : list of pandas.DataFrame
        Explanation tables, indexed by category index (``reverse_cat_map``).
    explanations : list of dict
        Per category index, a mapping model key -> list of paragraphs.
    confidences : dict
        Mapping model key -> {category name: probability}.
    html_views : list of dict
        Per category index, a mapping model key -> HTML of the brain viewer.

    Returns
    -------
    dash.Dash
        The configured application (not yet running).
    """
    # Initial values
    model_name = 'xgboost'
    largest_cat = max(confidences[model_name], key=confidences[model_name].get)
    largest_prob = confidences[model_name][largest_cat]
    category_names = ['Cognitive Normal', 'Mild Cognitive Impairment', "Alzheimer's Disease"]

    def explanation_paragraphs(cat, model):
        # One justified paragraph per generated line
        return [
            html.P(expl,
                   style={'fontSize': 17,
                          'margin': 20,
                          "textAlign": "justify"})
            for expl in explanations[reverse_cat_map[cat]][model]
        ]

    app = Dash(external_stylesheets=[dbc.themes.DARKLY, dbc.icons.FONT_AWESOME])
    app.layout = dbc.Container([

        # Title
        dbc.Row([
            dbc.Col(
                html.H1("ALZHEIMER'S PREDICTION SUMMARY", style={'margin-top':30, 'margin-bottom':30}),
                width=True,
            ),
            dbc.Col(
                html.Div([
                    dbc.Button("Info", id="brain-info", n_clicks=0),
                    dbc.Offcanvas(info_panel_content,
                        id="info-panel",
                        title="Information panel",
                        scrollable=True,
                        is_open=False,
                    ),
                ]), style={'margin': 20, 'width': '100px'},
                width="auto",
            ),
        ], className="d-flex justify-content-center align-items-center",),

        # Probabilities (font size grows with the initial model's probability)
        dbc.Card([
            dbc.Row([
                *[dbc.Col([
                    html.Div(f"{key}",
                             style={'fontSize': 20+val*20, 'color': textcolor[key], 'margin':10}),
                    html.Div(f"{val*100: .1f}%", id=f'{reverse_cat_map[key]}-prob',
                             style={'fontSize': 40+val*60, 'color': textcolor[key], 'margin': 10})
                  ], className="d-flex justify-content-center align-items-center",)\
                  for key, val in confidences[model_name].items()]
            ]),
        ], style={'margin':20, 'margin-bottom':30}),

        # Choose category
        dbc.Row([
            dbc.Col(
                dbc.DropdownMenu(
                    label="Category",
                    id='cat-menu',
                    menu_variant="dark",
                    children=[
                        dbc.DropdownMenuItem("Cognitive Normal", id='cat-0'),
                        dbc.DropdownMenuItem("Mild Cognitive Impairment", id='cat-1'),
                        dbc.DropdownMenuItem("Alzheimer's Disease", id='cat-2'),
                    ],
                    toggle_style={"background": "#303030", 'width': '250px'},
                    toggleClassName="border-white",
                    style={'margin-bottom': 20},
                ), width="auto"
            ),
            dbc.Col(
                dbc.Button(
                    ["XGBoost   ", html.I(className="fa-solid fa-arrows-rotate")], 
                    id="model-btn", 
                    n_clicks=0,
                    className="border-white",
                    style={"background": "#303030", 'width': '200px', 'margin-bottom': 20},
                ), width='auto'
            ),
        ]),
        
        # Content for the selected category
        dbc.Row([
            # Section title
            dbc.Row([
                html.H2(
                    f"{largest_cat} ({largest_prob*100:.1f}%)",
                    id="section-title",
                    style={'margin-top': 20}
                ),
            ]),
            # Explanations + graph
            dbc.Card([
                html.H3("Diagnosis Support", 
                        style={'margin-top': 20, 'margin-left':20}),
                dcc.Graph(id='value-bars',
                          figure=feature_value_graph(
                              dfs[reverse_cat_map[largest_cat]], model_name),
                          style={'margin': 20}),
                dbc.Col(explanation_paragraphs(largest_cat, model_name),
                        id='explanations'),
            ], style={'margin':20},
               className="d-flex justify-content-center"
            ),
            # Brain graph iframe
            dbc.Card([
                html.H3("Relevant brain sections",
                        style={'margin-top': 20, 'margin-left':20}),
                html.P("""The green regions represent the features that support the selected
                cognitive state, while red represents the regions that do not support it.
                For example, if a negative cognitive state like Alzheimer's Disease is selected,
                a red region means that it is healthier than expected, while a green region 
                supports the cognitive state, therefore it is in a bad state.""", style={'margin':20}),
                html.Iframe(id='brain-iframe',
                    srcDoc=html_views[reverse_cat_map[largest_cat]][model_name],
                    style={"width": "90vw",
                           "height": "600px",
                           "border": "none"}
                )
            ], style={'margin':20},
                className="d-flex justify-content-center"
            ),
            
            # Feature Importance graph
            dbc.Card([
                dbc.Row([
                    dbc.Col(
                        html.H3("Feature importance to the prediction",
                                style={'margin-top': 20, 'margin-left':20}),
                        width="auto"
                    ),
                    # NOTE(review): "400px" does not look like a value dbc.Col
                    # documents for `width` - confirm the intended sizing.
                    dbc.Col(
                        dcc.Slider(2, 20, 1,
                            value=10,
                            id='nbars-slider'
                        ),
                        width="400px"
                    ),
                ]),
                dcc.Graph(id='importance-bar',
                          figure=feature_importance_graph(
                              dfs[reverse_cat_map[largest_cat]], model_name, n=10),
                          style={'margin':20}),
                ],
                style={'margin':20},
                className="d-flex justify-content-center"
            ),
            
        ]),

        html.Hr(),
        
        # Decision model prediction: every entry but the last is a tree node,
        # the last entry holds the predictions separated by ',\t'
        dbc.Card([
            html.H3("Decision tree predictions",
                   style={'margin': 20}),
            dbc.Row([*[
                html.P(node)
                for node in dt_decisions[:-1]],
                *[
                html.H4(pred)
                for pred in dt_decisions[-1].split(',\t')]],
                style={'margin':20, 'margin-top': 0}),
            dbc.Row([
                dbc.Col(
                    dbc.Button("Show / Hide Decision Tree", 
                               id="btn-toggle", 
                               className='border-white',
                               style={'margin-left':30, 'margin-bottom': 30}
                    ), width="auto"),
                dbc.Col(
                    dbc.Button("Zoom In/Out",
                               id="btn-zoom",
                               className='border-white',
                               style={'margin-left':30, 'margin-bottom': 30}
                    ), className="ml-2", width="auto"),
            ], className="my-2"),
            dcc.Loading(
                id="loading-svg",
                type="default",
                children=dbc.Collapse(
                    html.Div(id="svg-wrapper"),
                    id="collapse-svg",
                    is_open=False
                )
            ),
        ], style={'margin-bottom': 40}),
        # Client-side state shared by the callbacks below
        dcc.Store(id="store-model", data=model_name),
        dcc.Store(id="store-cat", data=largest_cat),
    ])

    # Callbacks
    # Info Offcanvas
    @app.callback(
        Output("info-panel", "is_open"),
        Input("brain-info", "n_clicks"),
        [State("info-panel", "is_open")],
    )
    def toggle_offcanvas(n1, is_open):
        if n1:
            return not is_open
        return is_open

    # Open Decision Tree
    @app.callback(
        [Output("collapse-svg", "is_open"),
         Output("svg-wrapper", "children")],
        Input("btn-toggle", "n_clicks"),
        State("collapse-svg", "is_open"),
        prevent_initial_call=True
    )
    def toggle_svg(n, is_open):
        new_open = not is_open
        if new_open:
            # load the SVG only when opening
            wrapper = html.Img(id="inline-svg", src="/assets/decision_tree.svg",
                               style={"width": "100%", "height": "auto"})
        else:
            wrapper = None
        return new_open, wrapper
    
    # Toggle the "zoomed" class on the tree wrapper
    @app.callback(
        Output("svg-wrapper", "className"),
        Input("btn-zoom", "n_clicks"),
        State("svg-wrapper", "className"),
        prevent_initial_call=True
    )
    def toggle_zoom(n, current_cls):
        return "" if current_cls == "zoomed" else "zoomed"
        
    # Store values
    @app.callback(
        Output("store-cat", "data"),
        [
            Input("cat-0", "n_clicks"),
            Input("cat-1", "n_clicks"),
            Input("cat-2", "n_clicks"),
        ]
    )
    def store_category(n_a, n_b, n_c):
        ctx = callback_context
        if not ctx.triggered:
            return largest_cat
        triggered_id = ctx.triggered[0]["prop_id"].split(".")[0]
        choice_map = {
            "cat-0": "Cognitive Normal",
            "cat-1": "Mild Cognitive Impairment",
            "cat-2": "Alzheimer's Disease"
        }
        # Unknown ids (e.g. the initial call) fall back to the top category
        return choice_map.get(triggered_id, largest_cat)

    @app.callback(
        Output("store-model", "data"),
        Input("model-btn", "n_clicks")
    )
    def store_model(n):
        # Odd click counts select Random Forest, even ones XGBoost
        return "rf" if n and (n % 2 == 1) else "xgboost"

    # Update texts and graphs
    @app.callback(
        Output("model-btn", "children"),
        Input("store-model", "data")
    )
    def update_model_button(name):
        label = "Random Forest   " if name == "rf" else "XGBoost   "
        return [label, html.I(className="fa-solid fa-arrows-rotate")]
        
    # All three probability labels are refreshed together
    @app.callback(
        [
            Output("0-prob", "children"),
            Output("1-prob", "children"),
            Output("2-prob", "children"),
        ],
        Input("store-model", "data")
    )
    def update_probabilities(name):
        return [f"{confidences[name][cat]*100: .1f}%" for cat in category_names]
        
    @app.callback(
        Output("section-title", "children"),
        [
            Input("store-cat", "data"),
            Input("store-model", "data")
        ]
    )
    def update_title(cat, model):
        prob = confidences[model][cat]
        return f"{cat} ({prob*100:.1f}%)"

    @app.callback(
        Output("explanations", "children"),
        [
            Input("store-cat", "data"),
            Input("store-model", "data")
        ]
    )
    def update_explanations(cat, model):
        return explanation_paragraphs(cat, model)

    # The percentile chart does not depend on the slider, so the slider is
    # deliberately not an input here
    @app.callback(
        Output('value-bars', "figure"),
        [
            Input("store-cat", "data"),
            Input("store-model", "data"),
        ]
    )
    def update_value_graph(cat, model):
        return feature_value_graph(dfs[reverse_cat_map[cat]], model)

    @app.callback(
        Output('importance-bar', "figure"),
        [
            Input("store-cat", "data"),
            Input("store-model", "data"),
            Input("nbars-slider", "value"),
        ]
    )
    def update_importance_graph(cat, model, n):
        return feature_importance_graph(dfs[reverse_cat_map[cat]], model, n=n)

    @app.callback(
        Output('brain-iframe', "srcDoc"),
        [
            Input("store-cat", "data"),
            Input("store-model", "data"),
        ]
    )
    def update_brain_view(cat, model):
        return html_views[reverse_cat_map[cat]][model]

    return app   


In [52]:
# Order every per-category list as CN, MCI, AD so that list position equals
# the class index stored in `reverse_cat_map`
explanation_texts = [explanation_texts_cn, explanation_texts_mci, explanation_texts_ad]
explanation_dfs = [explanations_cn, explanations_mci, explanations_ad]
html_views = [html_view_cn, html_view_mci, html_view_ad]

# Build the dashboard and serve it on port 8000
app = create_dash_app(explanation_dfs, explanation_texts, cognitive_confidences, html_views)
app.run(port=8000)