In [1]:
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import os
import argparse
import json
import joblib
import pandas as pd
from sklearn.ensemble import RandomForestClassifier
from sklearn.preprocessing import StandardScaler
from sklearn.linear_model import LogisticRegression


In [2]:

# ---- Candidate feature columns per age group ----
# Full feature lists; the configuration cell uses one of these when top10 is False.
# NOTE(review): the ordering looks like an importance ranking -- confirm before relying on it.
# Names must match the column names stored in the scaler bundle and the KDE quantile JSON files.
features_adults = ['HEI_TS', 'CSES_TS', 'SCL-90 DEP', 'SCL-90 GSI', 'SCL-90 PST', 'SCL-90 TS', 'SCL-90 ANX', 'SCL-90 PSY', 'SCL-90 NST', 'SCL-90 ADD', 'SCL-90 PSDI', 'SCL-90 PAR', 'EMBU-M OI', 'SCL-90 SOM', 'EMBU-F OP', 'EMBU-F EW', 'EMBU-M EW', 'SSRS_TS', 'DES-Ⅱ_AMN', 'DES-Ⅱ_TS', 'SCL-90 OC', 'SSRS_SS', 'SCL-90 IS', 'DES-Ⅱ_ABS', 'EMBU-F OI', 'SCL-90 HOS', 'DES-Ⅱ_DPDR', 'CSQ_FAN', 'SSRS_OS', 'EMBU-F REJ', 'CSQ_RAT', 'EMBU-F PUN', 'CSQ_HS', 'SSRS_SU', 'CSQ_REP', 'SCL-90 PHOB', 'CSQ_PS', 'EMBU-M REJ', 'EMBU-F FS', 'EMBU-M PUN',  'CSQ_SB', 'EMBU-M FS']

features_teens = ['CSES_TS', 'SCL-90 DEP', 'HEI_TS', 'SCL-90 ANX', 'A-DES-Ⅱ_TS', 'A-DES-Ⅱ_PI', 'SCL-90 GSI', 'SCL-90 NST', 'SCL-90 PSY', 'EMBU-F EW', 'A-SSRS_SS', 'SCL-90 ADD', 'A-DES-Ⅱ_DPDR', 'SCL-90 PSDI', 'A-SSRS_TS', 'A-SSRS_SU', 'A-DES-Ⅱ_DA', 'SCL-90 HOS', 'EMBU-F OI', 'SCL-90 SOM', 'EMBU-M EW', 'EMBU-F PUN', 'SCL-90 TS', 'SCL-90 PHOB', 'EMBU-F OP', 'EMBU-M OI', 'SCL-90 PST', 'CSQ_FAN', 'A-SSRS_OS', 'EMBU-M PUN', 'CSQ_REP', 'SCL-90 IS', 'CSQ_PS', 'SCL-90 PAR', 'SCL-90 OC', 'CSQ_HS', 'A-DES-Ⅱ_AII', 'CSQ_SB', 'CSQ_RAT', 'EMBU-M REJ', 'EMBU-F REJ', 'EMBU-M FS', 'EMBU-F FS']

# Children have no SCL-90 scales in their feature set.
features_children = ['CSES_TS', 'HEI_TS', 'A-DES-Ⅱ_TS', 'A-SSRS_TS', 'A-DES-Ⅱ_PI', 'A-SSRS_SU', 'A-SSRS_OS', 'A-DES-Ⅱ_DA', 'A-DES-Ⅱ_AII', 'A-DES-Ⅱ_DPDR', 'CSQ_PS', 'EMBU-M PUN', 'A-SSRS_SS', 'EMBU-M OI', 'EMBU-F EW', 'EMBU-M EW', 'EMBU-M REJ', 'CSQ_SB', 'CSQ_HS', 'EMBU-F OP', 'CSQ_REP', 'EMBU-F FS', 'EMBU-F REJ', 'CSQ_RAT', 'EMBU-F PUN', 'EMBU-F OI', 'CSQ_FAN', 'EMBU-M FS']

# Ten-feature subsets, used instead of the full lists when top10 is True.
top10_features_adults = ['HEI_TS', 'SCL-90 DEP', 'CSES_TS', 'SCL-90 PSY', 'SCL-90 ANX', 'EMBU-F EW', 'DES-Ⅱ_TS', 'DES-Ⅱ_ABS', 'DES-Ⅱ_AMN', 'EMBU-M FS']

top10_features_teens = ['SCL-90 DEP', 'A-DES-Ⅱ_PI', 'SCL-90 ANX', 'CSES_TS', 'EMBU-F EW', 'HEI_TS', 'SCL-90 PSY', 'A-DES-Ⅱ_DPDR', 'A-DES-Ⅱ_AII', 'EMBU-M FS']

top10_features_children = ['CSES_TS', 'HEI_TS', 'A-DES-Ⅱ_PI', 'A-DES-Ⅱ_AII', 'A-DES-Ⅱ_DPDR', 'EMBU-F EW', 'EMBU-F PUN', 'EMBU-M FS', 'CSQ_REP', 'A-SSRS_OS']

In [None]:
### macro parameters (modify as needed) ###
group_name = 'adults'
model_name = 'LogisticRegression'
top10 = False
clf_path = 'ckpt/adults/clean_adults_LogisticRegression_acc_0.91_run_123.pkl'
anomaly_path = 'risk_prob'
### end macro parameters ###

# ---- Derived paths ----
# The scaler bundle sits next to the classifier checkpoint; only its file name
# and the quantile sub-directory depend on the top10 flag.
clf_path_base = os.path.dirname(clf_path)
top10_str = '_top10' if top10 else ''
scaler_name = f'clean_{group_name}_scaler{top10_str}.pkl'
scaler_path = os.path.join(clf_path_base, scaler_name)
anomaly_file_path = os.path.join(anomaly_path, 'top10' if top10 else 'full', 'quantiles')

# ---- KDE quantile tables (after .T, features are the columns) ----
# JSON is UTF-8 by spec; reading with the platform default encoding (e.g. cp1252
# on Windows) would mis-decode non-ASCII feature names such as 'DES-Ⅱ_TS' if the
# files store them unescaped. The `with` blocks also close the file handles.
with open(os.path.join(anomaly_file_path, f'{group_name}_kde_q_low.json'), encoding='utf-8') as fh:
    kde_q_low = pd.DataFrame(json.load(fh)).T
with open(os.path.join(anomaly_file_path, f'{group_name}_kde_q_high.json'), encoding='utf-8') as fh:
    kde_q_high = pd.DataFrame(json.load(fh)).T

# ---- Feature set for the selected group: (full list, top-10 list) ----
exclude_cols = ['Gender']
feature_sets_by_group = {
    'adults': (features_adults, top10_features_adults),
    'teens': (features_teens, top10_features_teens),
    'children': (features_children, top10_features_children),
}
if group_name not in feature_sets_by_group:
    raise ValueError(f"Invalid group name {group_name!r}; expected one of {sorted(feature_sets_by_group)}")
full_feature_set, top10_feature_set = feature_sets_by_group[group_name]
base_features = top10_feature_set if top10 else full_feature_set

# Add demographic features for prediction (Gender is excluded from scaling)
demographic_features = ['Age', 'Gender']
scale_cols = [f for f in base_features + demographic_features if f not in exclude_cols]

# ---- Scaler bundle ----
if not os.path.isfile(scaler_path):
    raise ValueError("Scaler path is invalid or file does not exist.")

# NOTE: joblib.load unpickles arbitrary objects -- only load bundles you trust.
scaler_bundle = joblib.load(scaler_path)
scaler = scaler_bundle['scaler']
# Column list saved with the scaler; falls back to the locally rebuilt list.
saved_scale_cols = scaler_bundle.get('scale_cols', scale_cols)

# CRITICAL: slider features follow the saved (scaler) column order, restricted
# to base_features, so each feature lines up with its own mean/scale entry.
scaler_feature_names = [f for f in saved_scale_cols if f in base_features]
features = scaler_feature_names
all_features = ['Gender'] + saved_scale_cols if not top10 else saved_scale_cols

scaler_names = scaler.feature_names_in_.tolist()
print('mean value of each col:')
for col in saved_scale_cols:
    print(f"  {col}: {scaler.mean_[scaler_names.index(col)]}")


mean value of each col:
  Age: 20.13875
  SCL-90 SOM: 1.293125
  SCL-90 OC: 1.6636250000000001
  SCL-90 IS: 1.5577777777777777
  SCL-90 DEP: 1.5405769230769228
  SCL-90 ANX: 1.435625
  SCL-90 HOS: 1.390625
  SCL-90 PHOB: 1.3014285714285714
  SCL-90 PAR: 1.395625
  SCL-90 PSY: 1.366125
  SCL-90 ADD: 1.5008928571428573
  SCL-90 TS: 130.5525
  SCL-90 GSI: 1.4505833333333333
  SCL-90 PST: 25.755
  SCL-90 NST: 64.245
  SCL-90 PSDI: 2.1235399960108934
  EMBU-M EW: 55.5925
  EMBU-M OI: 30.66625
  EMBU-M REJ: 12.0075
  EMBU-M PUN: 12.20125
  EMBU-M FS: 5.10875
  EMBU-F EW: 52.54625
  EMBU-F PUN: 16.2125
  EMBU-F OI: 17.06
  EMBU-F FS: 4.70625
  EMBU-F REJ: 8.6175
  EMBU-F OP: 10.47625
  CSES_TS: 37.925
  HEI_TS: 5.89875
  CSQ_PS: 0.8411458333333335
  CSQ_SB: 0.426875
  CSQ_HS: 0.683125
  CSQ_FAN: 0.511875
  CSQ_REP: 0.5254545454545454
  CSQ_RAT: 0.5297727272727273
  SSRS_TS: 39.15625
  SSRS_SS: 22.65375
  SSRS_OS: 10.7525
  SSRS_SU: 8.05
  DES-Ⅱ_ABS: 15.645833333333336
  DES-Ⅱ_AMN: 11.853125
 

In [5]:
# Interactive single-sample inference with ipywidgets (compact layout)
# Assumes variables defined earlier: features, saved_scale_cols, scaler, kde_q_low/kde_q_high, clf_path, scaler_path and top10.

import ipywidgets as W
from IPython.display import display, clear_output, HTML
import joblib
import numpy as np
import pandas as pd
import math
import warnings
import matplotlib.pyplot as plt
import seaborn as sns
from anomaly_quantile import compute_signed_anomaly_score, find_uncertain_intervals

# ---- Load trained model (adjust if needed) ----
# NOTE: joblib.load unpickles arbitrary objects -- only load checkpoints you trust.
clf = joblib.load(clf_path)

# Column layout expected by the classifier.
# The classifier was trained with Gender + scaler features, but scaler only has non-Gender features
if top10:
    all_features = [col for col in saved_scale_cols if col != 'Age']  # No demographics in top10 mode
    show_demographics = False
else:
    # Classifier expects: Gender followed by saved_scale_cols, in that order
    all_features = ['Gender'] + saved_scale_cols
    show_demographics = True

# Unscaled per-column statistics from the fitted scaler (scaler column order).
# `scaler` must come from the configuration cell; there is no identity fallback.
means = scaler.mean_
std = scaler.scale_
mins = means - std * 3
maxs = means + std * 3

# Mean of each slider feature, looked up by name: slider features and scaler
# columns need not share positions (e.g. 'Age' is scaled but has no slider).
scaler_names = scaler.feature_names_in_.tolist()
feature_means = {feature: means[scaler_names.index(feature)] for feature in features}


def create_anomaly_heatmap(sample_df, features, kde_q_high, kde_q_low):
    '''Render a one-row heatmap of signed per-feature anomaly scores as an HTML <img>.

    Args:
        sample_df: single-row DataFrame of raw (unscaled) feature values.
        features: feature names to score and plot, left to right.
        kde_q_high, kde_q_low: quantile tables whose columns are feature names.
            Features absent from either table are scored 0.0 (treated as normal).

    Returns:
        str: an ``<img>`` tag embedding the figure as a base64 PNG data URI.
    '''
    import base64
    from io import BytesIO

    # Score each feature with the quantile-based signed anomaly score.
    anomaly_scores = []
    for feat in features:
        if feat in kde_q_high.columns and feat in kde_q_low.columns:
            value = sample_df[feat].iloc[0]
            anomaly_scores.append(compute_signed_anomaly_score(value, kde_q_high[feat], kde_q_low[feat]))
        else:
            anomaly_scores.append(0.0)  # Default to normal if no quantile data
    anomaly_df = pd.DataFrame([anomaly_scores], columns=features)

    fig, ax = plt.subplots(figsize=(max(12, len(features) * 0.8), 2))
    try:
        # Diverging colormap centred on 0; colours saturate at |score| >= 2.
        sns.heatmap(
            anomaly_df,
            ax=ax,
            cmap="coolwarm",
            center=0,
            vmin=-2,
            vmax=2,
            xticklabels=True,
            yticklabels=False,
            cbar_kws={'label': 'Anomaly Score', 'shrink': 0.8},
            annot=True,
            fmt='.2f',
            annot_kws={'size': 8}
        )
        ax.set_xticklabels(ax.get_xticklabels(), rotation=45, ha='right', fontsize=10)
        ax.set_title('Feature Anomaly Scores', fontsize=14, pad=20)

        # Render this specific figure (not pyplot's "current" one) to PNG bytes.
        with BytesIO() as buffer:
            fig.savefig(buffer, format='png', bbox_inches='tight', dpi=100)
            plot_data = buffer.getvalue()
    finally:
        # Always release the figure, also when plotting raises: the notebook's
        # caller catches the exception and carries on, so a leaked figure would
        # otherwise pile up with every failed render.
        plt.close(fig)

    # Encode as base64 for HTML display
    plot_url = base64.b64encode(plot_data).decode()
    return f'<img src="data:image/png;base64,{plot_url}" style="max-width: 100%; height: auto;">'

# ---------- Incremental Scaler Update Functions ----------
def update_scaler_statistics(scaler, new_sample_df, features):
    """
    Fold one new observation into a fitted scaler's running statistics (in place).

    Applies the single-sample step of Welford's online algorithm to the
    population (ddof=0) mean/variance stored on the scaler:
        n'    = n + 1
        mean' = mean + (x - mean) / n'
        var'  = (n * var + (x - mean) * (x - mean')) / n'

    Args:
        scaler: fitted StandardScaler-like object exposing ``mean_``, ``var_``,
            ``scale_`` and ``n_samples_seen_``.
        new_sample_df: DataFrame holding the new observation (one row).
        features: columns to read from ``new_sample_df``; must list every scaler
            column, in the scaler's own column order.

    Returns:
        The same ``scaler`` object with ``mean_``, ``var_``, ``scale_`` and
        ``n_samples_seen_`` replaced.

    Raises:
        ValueError: if the selected values do not line up one-to-one with the
            scaler's columns (numpy broadcasting would otherwise fail cryptically
            or silently produce statistics of the wrong length).
    """
    new_sample = np.asarray(new_sample_df[features].values, dtype=float).ravel()
    current_mean = np.asarray(scaler.mean_, dtype=float)
    if new_sample.shape != current_mean.shape:
        raise ValueError(
            f"Sample has {new_sample.size} value(s) but the scaler tracks "
            f"{current_mean.size} column(s); pass every scaler column in its fitted order."
        )
    current_var = np.asarray(scaler.var_, dtype=float)
    current_n = scaler.n_samples_seen_

    new_n = current_n + 1
    delta = new_sample - current_mean
    new_mean = current_mean + delta / new_n
    delta2 = new_sample - new_mean
    # Clamp tiny negative round-off so the sqrt below never sees a negative value.
    new_var = np.maximum((current_n * current_var + delta * delta2) / new_n, 0.0)

    # Like StandardScaler, give (near-)constant columns a unit scale so that
    # transform() never divides by zero.
    new_scale = np.sqrt(new_var)
    new_scale = np.where(new_scale < 10 * np.finfo(np.float64).eps, 1.0, new_scale)

    scaler.mean_ = new_mean
    scaler.var_ = new_var
    scaler.scale_ = new_scale
    scaler.n_samples_seen_ = new_n
    return scaler

def save_updated_scaler(scaler, scale_cols, scaler_path):
    """Write `scaler` and its column list to `scaler_path` as a joblib bundle.

    The bundle is a dict with the keys 'scaler' and 'scale_cols' (the layout the
    loading cell reads back); the destination path is echoed to stdout.
    """
    bundle = dict(scaler=scaler, scale_cols=scale_cols)
    joblib.dump(bundle, scaler_path)
    print(f"Updated scaler saved to: {scaler_path}")

# ---------- Widget Construction ----------
# Per-feature widget registries, all keyed by feature name.
sliders, text_inputs, value_labels = {}, {}, {}
input_mode = {}  # feature -> True while its free-text editor is showing

def adjust_slider_range(col, new_value):
    """Widen the slider for `col` so that `new_value` lies inside [min, max].

    A value already inside the range is a no-op. Otherwise the range grows on
    the needed side by a padding of max(30% of |new_value|, 20% of the current
    span, 1.0), and the step is reset to 1/200 of the new span (0.1 if the span
    is empty).
    """
    slider = sliders[col]
    lo, hi = slider.min, slider.max
    if not (new_value < lo or new_value > hi):
        return

    pad = max(abs(new_value) * 0.3, (hi - lo) * 0.2, 1.0)
    widened_lo = min(lo, new_value - pad)
    widened_hi = max(hi, new_value + pad)

    # Range first, then a step proportional to the new span.
    slider.min = widened_lo
    slider.max = widened_hi
    slider.step = (widened_hi - widened_lo) / 200.0 if widened_hi > widened_lo else 0.1

for col in features:
    # Slider spans mean ± 3 std of this feature, read from the fitted scaler.
    pos = list(scaler.feature_names_in_).index(col)
    center = float(means[pos])
    spread = float(std[pos])
    low = center - spread * 3
    high = center + spread * 3

    sliders[col] = W.FloatSlider(
        description=col[:14],                  # truncated so long names do not overlap
        value=center,
        min=low,
        max=high,
        step=(high - low) / 200.0 if high > low else 0.1,
        readout=False,                         # value is shown on the button instead
        continuous_update=False,
        layout=W.Layout(width='200px'),
        style={'description_width': '100px'},
    )

    # Free-text editor, hidden until the value button is clicked.
    text_inputs[col] = W.FloatText(
        value=center,
        description='',
        layout=W.Layout(width='80px', display='none'),
    )

    # Button that shows the current value; clicking it opens the editor.
    value_labels[col] = W.Button(
        description=f"{center:.2f}",
        button_style='',
        tooltip=f'Click to edit {col} value directly',
        layout=W.Layout(width='80px', height='28px'),
        disabled=False,
    )

    input_mode[col] = False

# Function to create event handlers for each feature
def create_handlers(col):
    """Build the three widget callbacks for feature `col`.

    Returns (on_slider_change, on_text_change, on_value_click):
      * slider change -> mirror the value into the label button and text box
        (ignored while the text editor is active);
      * text change   -> parse the value, widen the slider range if needed,
        push the value to the slider, then return to slider mode (acts only
        while the editor is active; unparseable input is ignored);
      * label click   -> open the text editor.
    """
    def on_slider_change(change):
        if input_mode[col]:
            return
        latest = change['new']
        value_labels[col].description = f"{latest:.2f}"
        text_inputs[col].value = latest

    def on_text_change(change):
        if not input_mode[col]:
            return
        try:
            typed = float(change['new'])
        except (TypeError, ValueError):
            return
        value_labels[col].description = f"{typed:.2f}"
        adjust_slider_range(col, typed)   # widen first so the slider accepts the value
        sliders[col].value = typed
        switch_to_slider_mode(col)        # editing done: back to the slider

    def on_value_click(button):
        switch_to_text_mode(col)

    return on_slider_change, on_text_change, on_value_click

def switch_to_text_mode(col):
    """Show the free-text editor for `col` in place of its slider.

    The editor is pre-filled with the slider's current value *before* the
    feature is flagged as being in text mode. The text-change handler only acts
    while that flag is set, so filling first keeps this programmatic write from
    being treated as user input (which would immediately apply the value and
    flip the row straight back to slider mode).
    """
    text_inputs[col].value = sliders[col].value
    sliders[col].layout.display = 'none'
    text_inputs[col].layout.display = 'block'
    input_mode[col] = True
    # Note: FloatText has no on_submit; the value is applied as soon as it changes.

def switch_to_slider_mode(col):
    """Leave text-edit mode for `col`: hide the editor and show the slider again."""
    input_mode[col] = False
    for widget, shown_as in ((text_inputs[col], 'none'), (sliders[col], 'block')):
        widget.layout.display = shown_as

# Attach event handlers to each widget: value observers on slider/text box,
# click handler on the value button.
for feat in features:
    slider_cb, text_cb, click_cb = create_handlers(feat)
    sliders[feat].observe(slider_cb, names='value')
    text_inputs[feat].observe(text_cb, names='value')
    value_labels[feat].on_click(click_cb)

# Lay the feature rows out in columns: 3 columns for 15+ features, otherwise 2.
N_COLS = 3 if len(features) >= 15 else 2
per_col = math.ceil(len(features) / N_COLS)


def _feature_row(name):
    """One row: slider and (hidden) text editor stacked in a 200px box, value button on the right."""
    editor_box = W.VBox([sliders[name], text_inputs[name]], layout=W.Layout(width='200px'))
    return W.HBox(
        [editor_box, value_labels[name]],
        layout=W.Layout(align_items='center', justify_content='space-between', width='300px'),
    )


cols = []
for col_no in range(N_COLS):
    rows = [_feature_row(name) for name in features[col_no * per_col:(col_no + 1) * per_col]]
    cols.append(W.VBox(rows, layout=W.Layout(margin='0 10px 0 0')))

slider_panel = W.HBox(cols)

# ---------- Demographic Input Widgets ----------
# Only create demographic widgets if not in top10 mode
if show_demographics:
    # Age input. BoundedIntText (rather than IntText) so the 0-100 limits are
    # actually enforced: IntText defines no min/max traits, so passing them
    # there had no effect and out-of-range ages were accepted.
    age_input = W.BoundedIntText(
        value=25,  # Default age
        description='Age:',
        min=0,
        max=100,
        layout=W.Layout(width='140px'),
        style={'description_width': '45px'}
    )

    # Gender selection; the widget value is the string 'Male' or 'Female'.
    gender_toggle = W.Dropdown(
        options=[('Male', 'Male'), ('Female', 'Female')],
        value='Male',  # Default to Male
        description='Gender:',
        layout=W.Layout(width='140px'),
        style={'description_width': '55px'}
    )

    # Boxed panel: a header line above the Age/Gender row.
    # NOTE(review): `border_radius` and `background_color` do not appear to be
    # ipywidgets Layout traits, so they are likely ignored; the injected CSS rule
    # `.widget-vbox:has(.widget-dropdown)` supplies the rounded/shaded look -- confirm.
    demographic_panel = W.VBox([
        W.HTML(value="<b>📊 Demographic Information</b>", 
               layout=W.Layout(margin='2px 0px 8px 0px')),
        W.HBox([age_input, gender_toggle], 
               layout=W.Layout(justify_content='space-between', align_items='center'))
    ], layout=W.Layout(
        border='1px solid #e0e0e0', 
        padding='12px', 
        margin='8px 0px',
        border_radius='6px',
        background_color='#fafafa'
    ))
else:
    # Top10 mode: no demographic inputs, but keep the names defined.
    demographic_panel = W.HTML(value="")
    age_input = None
    gender_toggle = None

# Collapsible container for the feature sliders; starts collapsed
# (selected_index=None means no panel is open).
accordion = W.Accordion(children=[slider_panel])
accordion.set_title(0, 'Feature Inputs (expand to edit)')
accordion.selected_index = None

# ---- Action buttons ----
predict_button = W.Button(
    description='Predict',
    button_style='success',
    tooltip='Run model prediction',
    icon='play',
)
reset_button = W.Button(
    description='Reset Means',
    button_style='warning',
    tooltip='Reset all inputs to mean values',
    icon='refresh',
)
update_scaler_button = W.Button(
    description='Update Scaler',
    button_style='info',
    tooltip='Add current sample to scaler statistics',
    icon='plus',
)

# ---- Display toggles ----
show_full_toggle = W.ToggleButton(
    value=False,
    description='Show Full Input',
    tooltip='Toggle full feature table display',
    icon='table',
)
show_anomaly_toggle = W.ToggleButton(
    value=True,
    description='Show Anomaly Analysis',
    tooltip='Toggle anomaly analysis display',
    icon='search',
)

# Short usage hint shown above the controls.
instruction_label = W.HTML(
    value="<small><b>Instructions:</b> Use sliders for typical values or click the value display (right of slider) to enter arbitrary values. Values outside slider range will auto-expand the range. Toggle 'Show Anomaly Analysis' to see feature anomaly scores and heatmap.</small>",
    layout=W.Layout(margin='5px 0px'),
)

# Display names for the classifier's class labels (0 / 1).
status_dict = dict(enumerate(('Withdrawal', 'Reentry')))
# Bordered area where prediction / scaler-update results are rendered.
out = W.Output(layout=dict(border='1px solid #ccc', padding='6px'))

# ---------- Helper Functions ----------

def _collect_sample():
    sample_data = {}
    # Collect slider data
    for c in features:
        # Always get value from slider (which is kept in sync with text input)
        sample_data[c] = sliders[c].value
    
    # Add demographic data first (if widgets exist)
    if show_demographics and age_input is not None and gender_toggle is not None:
        # Gender needs to be first to match classifier feature order
        sample_data['Gender'] = 1 if gender_toggle.value == 'Male' else 0  # Male=1, Female=0
        sample_data['Age'] = age_input.value
    
    return pd.DataFrame([sample_data])


def _format_sample(sample: pd.DataFrame, show_full: bool):
    """Render a one-row sample as a two-column (Feature, Value) table.

    Args:
        sample: single-row DataFrame of raw inputs (one column per feature).
        show_full: when False, only the first 15 features are listed, followed
            by a '... (+N more)' row if any were hidden.

    Returns:
        A pandas Styler with compact styling, or an IPython ``HTML`` table when
        the Styler cannot be used (jinja2 missing, or pandas without ``Styler.hide``).
    """
    # Transpose & create tidy 2-column table
    df_t = sample.T.reset_index().rename(columns={'index': 'Feature', 0: 'Value'})
    # The transposed Value column is numeric, but text ('Male'/'Female', '') is
    # written into it below. Casting to object first avoids pandas'
    # incompatible-dtype setitem FutureWarning (slated to become an error).
    df_t['Value'] = df_t['Value'].astype(object)

    # Show Gender as a label rather than its 1/0 encoding (Male=1, Female=0).
    if 'Gender' in df_t['Feature'].values:
        gender_idx = df_t[df_t['Feature'] == 'Gender'].index[0]
        gender_val = df_t.loc[gender_idx, 'Value']
        df_t.loc[gender_idx, 'Value'] = 'Male' if gender_val == 1 else 'Female'

    if not show_full:
        # Show only first 15 (or all if fewer); could adapt to show changed from mean
        df_t = df_t.iloc[:15].copy()
        more = len(sample.columns) - 15
        if more > 0:
            df_t.loc[len(df_t)] = ['... (+{} more)'.format(more), '']

    try:
        # `df_t.style` raises ImportError when jinja2 is missing, and
        # `Styler.hide` raises AttributeError on pandas < 1.4; both fall back.
        style = (df_t.style.set_table_styles([
            {'selector': 'th', 'props': [('font-size', '11px'), ('text-align', 'left')]},
            {'selector': 'td', 'props': [('font-size', '11px'), ('padding', '2px 6px')]},
        ]).hide(axis='index'))
        return style
    except (AttributeError, ImportError):
        # Fallback to HTML table if styling not available
        html_table = df_t.to_html(index=False, table_id='feature_table')
        return HTML(f"""
        <style>
        #feature_table th, #feature_table td {{
            font-size: 11px;
            padding: 2px 6px;
            text-align: left;
        }}
        </style>
        {html_table}
        """)


def on_predict(_):
    """Run the classifier on the current widget inputs and render the results.

    Output goes to the `out` widget (cleared first): the input table, per-class
    probabilities (labels mapped through `status_dict`), the predicted class
    and, when `show_anomaly_toggle` is on, a per-feature anomaly summary plus a
    heatmap (with a text bar-chart fallback if rendering fails).
    The argument is unused; it is the button/observer payload.
    """
    with out:
        clear_output(wait=True)
        sample = _collect_sample()
        
        # Suppress the sklearn feature names warning
        # NOTE(review): simplefilter("ignore") mutes *every* warning raised in this
        # block (scaling + predict_proba), not only the feature-name one.
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            
            # Handle scaling: Gender is not scaled, only the scaler features are scaled
            if show_demographics and 'Gender' in sample.columns:
                # Extract Gender (not scaled)
                gender_data = sample[['Gender']]
                
                # Get scaler features in the exact order the scaler expects
                scaler_sample = sample[saved_scale_cols]  # This ensures correct order

                scaled_data = scaler.transform(scaler_sample)
                
                # Combine Gender + scaled features to match classifier expectations:
                # raw Gender column first, then the scaled columns.
                sample_scaled = np.hstack([gender_data.values, scaled_data])
            else:
                # Top10 mode - no demographics
                sample_scaled = scaler.transform(sample[all_features])
                
            # Class probabilities for the single row, ordered like clf.classes_.
            proba = clf.predict_proba(sample_scaled)[0]
        
        pred_idx = int(np.argmax(proba))
        pred_class = clf.classes_[pred_idx]
        
        # Display basic prediction results
        display(HTML('<b>Input Sample (partial view)</b>' if not show_full_toggle.value else '<b>Input Sample (full)</b>'))
        display(_format_sample(sample, show_full_toggle.value))
        print('Prediction probabilities:')
        for cls, p in zip(clf.classes_, proba):
            print(f'  P({status_dict.get(cls, cls)}) = {p:.4f}')
        print(f'Predicted class: {status_dict.get(pred_class, pred_class)} (index {pred_idx})')
        
        # Compute and display anomaly analysis (only for non-demographic features)
        if show_anomaly_toggle.value:
            print('\n' + '='*60)
            display(HTML('<b>Anomaly Analysis</b>'))
            
            # Compute anomaly scores for each feature
            anomaly_scores = {}
            # NOTE(review): uncertainty_levels is filled below but never displayed.
            uncertainty_levels = {}
            
            for feat in features:
                if feat in kde_q_high.columns and feat in kde_q_low.columns:
                    value = sample[feat].iloc[0]
                    q_high = kde_q_high[feat]
                    q_low = kde_q_low[feat]
                    
                    # Compute anomaly score and uncertainty level
                    anomaly_score = compute_signed_anomaly_score(value, q_high, q_low)
                    uncertainty_level = find_uncertain_intervals(value, q_high, q_low)
                    
                    anomaly_scores[feat] = anomaly_score
                    uncertainty_levels[feat] = uncertainty_level
                else:
                    # No quantile data for this feature: treat it as normal.
                    anomaly_scores[feat] = 0.0
                    uncertainty_levels[feat] = 'low'
            
            # Create summary statistics
            # |score| > 1.0 counts as high; 0.5 < |score| <= 1.0 as medium.
            high_anomaly_features = [feat for feat, score in anomaly_scores.items() if abs(score) > 1.0]
            medium_anomaly_features = [feat for feat, score in anomaly_scores.items() if 0.5 < abs(score) <= 1.0]
            
            print(f"High anomaly features ({len(high_anomaly_features)}): {', '.join(high_anomaly_features[:5])}")
            if len(high_anomaly_features) > 5:
                print(f"  ... and {len(high_anomaly_features)-5} more")
            
            print(f"Medium anomaly features ({len(medium_anomaly_features)}): {', '.join(medium_anomaly_features[:5])}")
            if len(medium_anomaly_features) > 5:
                print(f"  ... and {len(medium_anomaly_features)-5} more")
            
            # Show top 5 most anomalous features with scores
            sorted_anomalies = sorted(anomaly_scores.items(), key=lambda x: abs(x[1]), reverse=True)
            print(f"\nTop 5 most anomalous features:")
            for i, (feat, score) in enumerate(sorted_anomalies[:5]):
                # NOTE: a score of exactly 0 is reported as "negative".
                direction = "positive" if score > 0 else "negative"
                print(f"  {i+1}. {feat}: {score:.3f} ({direction} anomaly)")
            
            # Create and display anomaly heatmap (scores are recomputed inside)
            print(f"\nAnomaly Heatmap:")
            try:
                heatmap_html = create_anomaly_heatmap(sample, features, kde_q_high, kde_q_low)
                display(HTML(heatmap_html))
            except Exception as e:
                print(f"Could not generate heatmap: {e}")
                # Fallback: show text-based representation
                print("Anomaly scores by feature:")
                for feat, score in sorted_anomalies:
                    # One bar character per 0.1 of |score|.
                    bar_length = int(abs(score) * 10)
                    bar_char = "+" if score > 0 else "-"
                    bar = bar_char * bar_length
                    print(f"  {feat:15s}: {score:6.2f} |{bar}")


def on_update_scaler(_):
    """Fold the current widget sample into the scaler statistics and save a copy.

    Button callback. The sample is added with `update_scaler_statistics` using
    the scaler's own column order. The module-level `means`, `std`, `mins`,
    `maxs` and `feature_means` are refreshed. The bundle is written next to the
    original as `<name>_updated_<timestamp><ext>`; the original file is untouched.
    """
    global scaler, means, std, mins, maxs, feature_means

    with out:
        clear_output(wait=True)
        sample = _collect_sample()

        # Show current sample
        display(HTML('<b>Adding Sample to Scaler Statistics:</b>'))
        display(_format_sample(sample, show_full_toggle.value))

        # Running statistics are stored per scaler column, in the scaler's order,
        # which can include columns without a slider (e.g. 'Age'). Passing only
        # the slider `features` mis-sized and misaligned the update, so read every
        # scaler column instead.
        scaler_cols = list(getattr(scaler, 'feature_names_in_', saved_scale_cols))
        missing = [c for c in scaler_cols if c not in sample.columns]
        if missing:
            print(f"Cannot update scaler: sample is missing column(s) {missing}")
            return
        col_index = {name: idx for idx, name in enumerate(scaler_cols)}

        # Store old statistics for comparison
        old_mean = scaler.mean_.copy()
        old_n = scaler.n_samples_seen_

        scaler = update_scaler_statistics(scaler, sample, scaler_cols)

        # Refresh globals derived from the scaler.
        means = scaler.mean_
        std = scaler.scale_
        mins = means - std * 3
        maxs = means + std * 3
        feature_means = {feature: means[col_index[feature]] for feature in features}

        # Save next to the original; splitext only touches the real extension.
        root, ext = os.path.splitext(scaler_path)
        new_scaler_path = f"{root}_updated_{pd.Timestamp.now().strftime('%Y%m%d_%H%M%S')}{ext}"
        # Persist the column list in the scaler's own order so the bundle round-trips.
        save_updated_scaler(scaler, scaler_cols, new_scaler_path)

        # Show statistics update
        print(f"Statistics updated! Sample count: {old_n} → {scaler.n_samples_seen_}")
        print("\nMean changes (first 5 features):")
        for feat in features[:5]:
            idx = col_index[feat]  # scaler position, not the slider position
            old_val = old_mean[idx]
            new_val = means[idx]
            change = new_val - old_val
            print(f"  {feat}: {old_val:.4f} → {new_val:.4f} (Δ{change:+.4f})")

        if len(features) > 5:
            print(f"  ... and {len(features)-5} more features")


def on_reset(_):
    """Put every feature row back into slider mode at its scaler mean.

    Each slider gets value = mean and range mean ± 3 std (undoing any
    auto-expansion). When shown, the demographics return to Age 25 / Male. If
    the output area already has content, the prediction is re-run to match.
    """
    for feat in features:
        pos = scaler.feature_names_in_.tolist().index(feat)
        center = float(means[pos])

        # Leave text-edit mode.
        input_mode[feat] = False
        text_inputs[feat].layout.display = 'none'
        sliders[feat].layout.display = 'block'

        # Move every view of the value back to the mean.
        sliders[feat].value = center
        text_inputs[feat].value = center
        value_labels[feat].description = f"{center:.2f}"

        # Restore the default mean ± 3 std range and its step.
        spread = float(std[pos])
        low = center - spread * 3
        high = center + spread * 3
        sliders[feat].min = low
        sliders[feat].max = high
        sliders[feat].step = (high - low) / 200.0 if high > low else 0.1

    demographics_present = show_demographics and age_input is not None and gender_toggle is not None
    if demographics_present:
        age_input.value = 25
        gender_toggle.value = 'Male'

    # Only refresh the results if something is already displayed.
    if out.outputs:
        on_predict(None)


def on_toggle(_):
    """Observer for the display toggles: any change re-renders the prediction output."""
    on_predict(None)

# ---------- Wire callbacks ----------
predict_button.on_click(on_predict)
reset_button.on_click(on_reset)
update_scaler_button.on_click(on_update_scaler)
for _toggle in (show_full_toggle, show_anomaly_toggle):
    _toggle.observe(on_toggle, names='value')

# ---------- Assemble UI ----------
buttons = W.HBox([predict_button, reset_button, update_scaler_button, show_full_toggle, show_anomaly_toggle])

# Top to bottom: instructions, buttons, [demographics], collapsible sliders, output.
ui_components = [instruction_label, buttons]
if show_demographics:
    ui_components.append(demographic_panel)
ui_components += [accordion, out]

ui = W.VBox(ui_components)

# Inject CSS for better button contrast and clickable value labels
# The stylesheet targets ipywidgets' built-in CSS classes (.widget-hslider,
# .widget-button, .widget-toggle-button, .mod-success/.mod-warning/.mod-info, ...)
# and marks rules !important so they win over the default theme.
# NOTE(review): these class names are ipywidgets/JupyterLab internals and may
# change between versions; `:has()` also needs a reasonably recent browser -- confirm.
custom_css = HTML("""
<style>
    .widget-inline-hbox .widget-label { min-width: 0 !important; }
    
    /* Prevent slider overflow */
    .widget-hslider {
        overflow: hidden !important;
        max-width: 200px !important;
    }
    .widget-hslider .ui-slider {
        max-width: 95px !important;  /* Constrain slider track to prevent overlap */
        margin-left: 5px !important;  /* Add margin to prevent leftward overlap */
    }
    .widget-hslider .ui-slider-handle {
        max-width: 20px !important;
        min-width: 15px !important;
    }
    .widget-hslider .widget-label {
        max-width: 100px !important;  /* Fixed width for feature names */
        min-width: 100px !important;
        overflow: hidden !important;
        text-overflow: ellipsis !important;
        white-space: nowrap !important;
    }
    
    /* Enhanced demographic panel styling */
    .widget-vbox:has(.widget-dropdown) {
        background: linear-gradient(145deg, #f8f9fa, #e9ecef) !important;
        border-radius: 8px !important;
        box-shadow: 0 2px 4px rgba(0,0,0,0.1) !important;
    }
    
    /* Improved dropdown styling */
    .widget-dropdown select {
        border: 1px solid #ced4da !important;
        border-radius: 4px !important;
        padding: 4px 8px !important;
        background-color: white !important;
        font-size: 13px !important;
        transition: border-color 0.15s ease-in-out, box-shadow 0.15s ease-in-out !important;
    }
    
    .widget-dropdown select:focus {
        border-color: #80bdff !important;
        outline: 0 !important;
        box-shadow: 0 0 0 0.2rem rgba(0, 123, 255, 0.25) !important;
    }
    
    /* Improved IntText styling */
    .widget-text input {
        border: 1px solid #ced4da !important;
        border-radius: 4px !important;
        padding: 4px 8px !important;
        font-size: 13px !important;
        transition: border-color 0.15s ease-in-out, box-shadow 0.15s ease-in-out !important;
    }
    
    .widget-text input:focus {
        border-color: #80bdff !important;
        outline: 0 !important;
        box-shadow: 0 0 0 0.2rem rgba(0, 123, 255, 0.25) !important;
    }
    
    /* Styling for Predict button (success style) */
    .widget-button.mod-success { 
        background-color: #28a745 !important;
        color: white !important;
        border-color: #1e7e34 !important;
        font-weight: bold !important;
    }
    .widget-button.mod-success:hover { 
        background-color: #218838 !important;
        color: white !important;
    }
    
    /* Styling for Reset button (warning style) */
    .widget-button.mod-warning { 
        background-color: #ffc107 !important;
        color: #212529 !important;
        border-color: #d39e00 !important;
        font-weight: bold !important;
    }
    .widget-button.mod-warning:hover { 
        background-color: #e0a800 !important;
        color: #212529 !important;
    }
    
    /* Styling for Update Scaler button (info style) */
    .widget-button.mod-info { 
        background-color: #17a2b8 !important;
        color: white !important;
        border-color: #117a8b !important;
        font-weight: bold !important;
    }
    .widget-button.mod-info:hover { 
        background-color: #138496 !important;
        color: white !important;
    }
    
    /* Default styling for value label buttons */
    .widget-button { 
        font-size: 11px !important; 
        border: 1px solid #ccc !important;
        background-color: #f8f9fa !important;
        color: #495057 !important;
    }
    .widget-button:hover { 
        background-color: #e9ecef !important;
        cursor: pointer !important;
        color: #495057 !important;
    }
    
    /* Toggle button styling */
    .widget-toggle-button {
        background-color: #6c757d !important;
        color: white !important;
        border-color: #545b62 !important;
        font-weight: bold !important;
    }
    .widget-toggle-button:hover {
        background-color: #5a6268 !important;
        color: white !important;
    }
    .widget-toggle-button.mod-active {
        background-color: #007bff !important;
        border-color: #0056b3 !important;
        color: white !important;
    }
</style>
""")

# Render the stylesheet and the assembled UI together in this cell's output.
display(custom_css, ui)

VBox(children=(HTML(value="<small><b>Instructions:</b> Use sliders for typical values or click the value displ…