In [None]:
# %pip (unlike `!python -m pip`) installs into the environment of the running
# kernel, which may differ from the `python` that is first on PATH.
# Pinned to the version recorded in this cell's output; -q suppresses the
# progress-bar log. Restart the kernel if plotly was already imported.
%pip install -q plotly==6.1.1

Collecting plotly
  Using cached plotly-6.1.1-py3-none-any.whl.metadata (6.9 kB)
Collecting narwhals>=1.15.1 (from plotly)
  Using cached narwhals-1.41.0-py3-none-any.whl.metadata (11 kB)
Using cached plotly-6.1.1-py3-none-any.whl (16.1 MB)
Using cached narwhals-1.41.0-py3-none-any.whl (357 kB)
Installing collected packages: narwhals, plotly

   ---------------------------------------- 0/2 [narwhals]
   ---------------------------------------- 0/2 [narwhals]
   ---------------------------------------- 0/2 [narwhals]
   ---------------------------------------- 0/2 [narwhals]
   ---------------------------------------- 0/2 [narwhals]
   ---------------------------------------- 0/2 [narwhals]
   ---------------------------------------- 0/2 [narwhals]
   ---------------------------------------- 0/2 [narwhals]
   ---------------------------------------- 0/2 [narwhals]
   ---------------------------------------- 0/2 [narwhals]
   ---------------------------------------- 0/2 [narwhals]
   ---



In [None]:
# %pip (unlike `!python -m pip`) installs into the environment of the running
# kernel, which may differ from the `python` that is first on PATH.
# Pinned to the dash version recorded in this cell's output; -q suppresses the
# progress-bar log. Restart the kernel if dash was already imported.
%pip install -q dash==3.0.4

Collecting dash
  Using cached dash-3.0.4-py3-none-any.whl.metadata (10 kB)
Collecting Flask<3.1,>=1.0.4 (from dash)
  Using cached flask-3.0.3-py3-none-any.whl.metadata (3.2 kB)
Collecting Werkzeug<3.1 (from dash)
  Using cached werkzeug-3.0.6-py3-none-any.whl.metadata (3.7 kB)
Collecting importlib-metadata (from dash)
  Using cached importlib_metadata-8.7.0-py3-none-any.whl.metadata (4.8 kB)
Collecting retrying (from dash)
  Using cached retrying-1.3.4-py3-none-any.whl.metadata (6.9 kB)
Collecting itsdangerous>=2.1.2 (from Flask<3.1,>=1.0.4->dash)
  Using cached itsdangerous-2.2.0-py3-none-any.whl.metadata (1.9 kB)
Collecting click>=8.1.3 (from Flask<3.1,>=1.0.4->dash)
  Using cached click-8.2.1-py3-none-any.whl.metadata (2.5 kB)
Collecting blinker>=1.6.2 (from Flask<3.1,>=1.0.4->dash)
  Using cached blinker-1.9.0-py3-none-any.whl.metadata (1.6 kB)
Collecting zipp>=3.20 (from importlib-metadata->dash)
  Using cached zipp-3.22.0-py3-none-any.whl.metadata (3.6 kB)
Using cached dash-3.0



In [9]:
import dash
from dash import html, dcc, Input, Output
import pandas as pd
import pickle
import shap
import plotly.express as px
import warnings

# NOTE(review): with no category/module filter this silences *every* warning
# for the rest of the kernel session (including any raised while unpickling
# the models below) — consider narrowing it to the specific noisy categories.
warnings.filterwarnings("ignore")

# Load encoded data: every column except the target is used as a feature.
df = pd.read_csv('../data/processed/mental_health_cleaned.csv')
X = df.drop(columns=['treatment_Yes'])
y = df['treatment_Yes']

# Ensure features are numeric (important for SHAP)
X_numeric = X.astype(float)


def _load_pickle(path):
    """Open *path* in binary mode and return the unpickled object.

    Security: ``pickle.load`` can execute arbitrary code embedded in the file,
    so only point this at model artifacts you produced or otherwise trust.
    """
    with open(path, 'rb') as fh:
        return pickle.load(fh)


# Load models
log_reg = _load_pickle('../models/logistic_regression_model.pkl')
rf_clf = _load_pickle('../models/random_forest_model.pkl')
xgb_clf = _load_pickle('../models/xgboost_model.pkl')

# Initialize SHAP explainers; X_numeric is passed as the masker / background
# data each explainer compares a record against.
explainer_log = shap.Explainer(log_reg, X_numeric)
explainer_rf = shap.Explainer(rf_clf, X_numeric)
explainer_xgb = shap.Explainer(xgb_clf, X_numeric)

# Dash app setup
app = dash.Dash(__name__)
app.title = "Mental Health Prediction Dashboard"

# Two controls (model, record) drive the callback that fills the prediction
# summary div and the SHAP bar graph. Both dropdowns are non-clearable: a
# cleared Dropdown sends None to the callback, which cannot select a record
# from it and would otherwise fall through to the default model branch.
app.layout = html.Div([
    html.H1("Mental Health Treatment Prediction Dashboard", style={'textAlign': 'center'}),

    html.Label("Select Model:"),
    dcc.Dropdown(
        id='model-dropdown',
        # Values must match the strings update_dashboard compares against.
        options=[
            {'label': 'Logistic Regression', 'value': 'log_reg'},
            {'label': 'Random Forest', 'value': 'rf_clf'},
            {'label': 'XGBoost', 'value': 'xgb_clf'}
        ],
        value='xgb_clf',
        clearable=False,
        style={'width': '300px'}
    ),

    html.Br(),

    html.Label("Select a record to explain:"),
    dcc.Dropdown(
        id='record-dropdown',
        # Values are positional row numbers; the callback uses them with .iloc.
        options=[{'label': f'Index {i}', 'value': i} for i in range(len(X))],
        value=0,
        clearable=False,
        style={'width': '300px'}
    ),

    html.Br(),

    html.Div(id='prediction-output'),
    dcc.Graph(id='shap-bar-plot')
])

@app.callback(
    [Output('prediction-output', 'children'),
     Output('shap-bar-plot', 'figure')],
    [Input('model-dropdown', 'value'),
     Input('record-dropdown', 'value')]
)
def update_dashboard(selected_model, record_index):
    """Show the selected model's prediction and SHAP explanation for one record.

    Args:
        selected_model: Model-dropdown value: 'log_reg', 'rf_clf', or anything
            else (treated as XGBoost).
        record_index: Record-dropdown value, a positional row number into
            ``X_numeric`` / ``y``.

    Returns:
        A ``(summary, figure)`` pair: an ``html.Div`` with the model name,
        predicted class, class-1 probability and true label, and a horizontal
        Plotly bar chart of per-feature SHAP values ordered by magnitude with
        the largest at the top. If either input is None (e.g. a cleared
        dropdown), both outputs are ``dash.no_update`` so the current view stays.
    """
    # A cleared dropdown delivers None; X_numeric.iloc[[None]] would raise.
    if selected_model is None or record_index is None:
        return dash.no_update, dash.no_update

    # Double brackets keep a one-row DataFrame for predict / SHAP.
    sample = X_numeric.iloc[[record_index]]
    true_label = y.iloc[record_index]

    if selected_model == 'log_reg':
        model, explainer, model_name = log_reg, explainer_log, "Logistic Regression"
    elif selected_model == 'rf_clf':
        model, explainer, model_name = rf_clf, explainer_rf, "Random Forest"
    else:
        model, explainer, model_name = xgb_clf, explainer_xgb, "XGBoost"

    pred = model.predict(sample)[0]
    # NOTE(review): column 1 is P(needs treatment) only if classes_ is
    # [0, 1] / [False, True] — confirm against the trained models.
    prob = model.predict_proba(sample)[0][1]

    explanation = explainer(sample)
    row_shap = explanation.values[0]
    # For some models (e.g. scikit-learn forest classifiers) SHAP returns one
    # column per class, shape (n_features, n_classes); a 2-D column would make
    # the DataFrame constructor below raise. Keep class 1 so the explanation
    # matches `prob` above.
    if row_shap.ndim == 2:
        row_shap = row_shap[:, 1]

    shap_df = pd.DataFrame({
        'Feature': sample.columns,
        'SHAP Value': row_shap
    }).sort_values(by='SHAP Value', key=abs, ascending=False)

    fig = px.bar(
        shap_df,
        x='SHAP Value',
        y='Feature',
        orientation='h',
        title=f'SHAP Explanation ({model_name}) for Record Index {record_index}'
    )
    # Horizontal bars plot the first row at the bottom; reverse the category
    # axis so the largest-|SHAP| feature appears at the top.
    fig.update_yaxes(autorange='reversed')

    result = html.Div([
        html.H4(f"Model: {model_name}"),
        html.P(f"Predicted: {'Needs Treatment' if pred else 'Does Not Need Treatment'}"),
        html.P(f"Prediction Probability: {prob:.2f}"),
        html.P(f"True Label: {'Needs Treatment' if true_label else 'Does Not Need Treatment'}")
    ])

    return result, fig

# Start the Dash server with debug mode on. A Jupyter kernel also executes
# cells with __name__ == "__main__", so this runs whenever the cell is run.
if __name__ == "__main__":
    app.run(debug=True)

