In [None]:
#----------------------------------------
# Dataset to load (read with pd.read_json in the next cell).
file = 'Twisted_2D_vdW_Bilayers_DFT.json'
# Directory holding `file`; '' means the current working directory.
Dir = ''
# Folder, relative to the working directory, where figures are saved.
Dir_figures = 'figures_Correlations'
#-----------------------------------



#==================
!pip install pandas
!pip install numpy
!pip install plotly
!pip install scipy
!pip install dash
!pip install kaleido
#==================
import pandas as pd
import numpy as np
import uuid
import os
#================================
import plotly.graph_objects as go
import plotly.express as px
import plotly.io as pio
from plotly.subplots import make_subplots
from plotly.offline import plot
#==============================
import dash
from dash import dcc, html, Input, Output, State
#===============================================
from scipy.stats import gaussian_kde
#===================================================
# Absolute path of the output folder for saved figures; created if missing.
path_folder = os.path.join(os.getcwd(), Dir_figures)
# exist_ok=True replaces the exists()-then-makedirs() check, which could race.
os.makedirs(path_folder, exist_ok=True)
figures_plot = Dir_figures
#=========================

In [None]:
# Load every calculation, then split monolayers from bilayers.
df = pd.read_json(os.path.join(Dir, file))
# .copy() gives independent frames, so the column assignments below act on
# real frames instead of slices (no SettingWithCopyWarning).
df_mono = df[df['number_layers'] == 1].copy()
df = df[df['number_layers'] == 2].copy()

# Monolayer work function: vacuum level minus Fermi level (SOC).
df_mono["work_function_SO"] = df_mono["e_vacuum_SO"] - df_mono["e_fermi_SO"]
# First and second entries of id_layers: the ids of the constituent layers,
# matched against df_mono["id"] when merging monolayer properties.
df["layer1"] = df["id_layers"].apply(lambda x: x[0])
df["layer2"] = df["id_layers"].apply(lambda x: x[1])

# Upper bound of the rotation-angle slider; rotation_angle is also reduced
# modulo this value, so a stored 181° becomes 0°.
max_angle = 181

def merge_mono_info(df, df_mono, layer_col, suffix):
    """Attach monolayer properties to each bilayer row for one of its layers.

    A fixed set of monolayer columns is taken from ``df_mono``. Each one is
    renamed to ``<name>_<suffix>`` and left-joined onto ``df`` by matching
    ``df[layer_col]`` with the monolayer ``id``. The renamed id column, used
    only as the join key, is dropped afterwards.

    Parameters
    ----------
    df : pandas.DataFrame
        Bilayer table containing ``layer_col``.
    df_mono : pandas.DataFrame
        Monolayer table containing ``id`` and the copied property columns.
    layer_col : str
        Column of ``df`` that holds the monolayer id to look up.
    suffix : str
        Suffix appended to every copied column name.

    Returns
    -------
    pandas.DataFrame
        A new frame. Rows of ``df`` with no matching monolayer get NaN in the
        added columns.
    """
    wanted = ["id", "gap_SO", "e_fermi_SO", "e_per_area_SO", "total_energy_SO", "work_function_SO", "lattice_type"]
    join_key = f"id_{suffix}"
    renamed = df_mono[wanted].copy().rename(columns={name: f"{name}_{suffix}" for name in wanted})
    merged = df.merge(renamed, left_on=layer_col, right_on=join_key, how="left")
    return merged.drop(columns=[join_key])

# Attach the monolayer properties of each constituent layer; the copied
# columns get the suffixes _layer1 and _layer2 respectively.
df = merge_mono_info(df, df_mono, "layer1", "layer1")
df = merge_mono_info(df, df_mono, "layer2", "layer2")

# Bilayer band-edge quantities, each measured from the vacuum level (SOC).
df = df.assign(
    work_function_SO=df["e_vacuum_SO"] - df["e_fermi_SO"],
    electron_affinity_SO=df["e_vacuum_SO"] - df["e_cbm_SO"],
    ionization_potential_SO=df["e_vacuum_SO"] - df["e_vbm_SO"],
)

# Mean of the absolute values of the two entries of each record, rounded to
# 5 decimals. The *_perc_mismatch records are read one level deeper
# ([0][0], [0][1]) than the *_change records ([0], [1]).
df["mean_abs_area_perc_mismatch"] = df["area_perc_mismatch"].apply(lambda m: (abs(m[0][0]) + abs(m[0][1])) / 2).round(5)
df["mean_abs_angle_perc_mismatch"] = df["angle_perc_mismatch"].apply(lambda m: (abs(m[0][0]) + abs(m[0][1])) / 2).round(5)
df["mean_abs_perc_area_change"] = df["perc_area_change"].apply(lambda c: (abs(c[0]) + abs(c[1])) / 2).round(5)
df["mean_abs_perc_angle_change"] = df["perc_angle_change"].apply(lambda c: (abs(c[0]) + abs(c[1])) / 2).round(5)

def parse_value(value):
    """Coerce a cell value to ``float`` when possible.

    A single-element list is first unwrapped to its element. The function
    then returns ``float(value)`` if that conversion succeeds. Otherwise it
    returns the (possibly unwrapped) value unchanged, e.g. text, a
    multi-element list, or ``None``.
    """
    if isinstance(value, list) and len(value) == 1:
        value = value[0]
    try:
        return float(value)
    except (TypeError, ValueError, OverflowError):
        # Only conversion failures are tolerated; a bare except would also
        # swallow KeyboardInterrupt/SystemExit.
        return value

# Unwrap single-element lists and convert to float where possible.
columns_to_fix = ["z_separation", "rotation_angle"]
for col in columns_to_fix:
    df[col] = df[col].apply(parse_value)
# Reduce the rotation angle modulo max_angle.
df["rotation_angle"] = df["rotation_angle"].mod(max_angle)

# ---- One integer ID per calculation (1-based, order of first appearance) ----
id_map = {value: position + 1 for position, value in enumerate(df["id"].unique())}
df["id_numeric"] = df["id"].map(id_map)

def classify_bilayer(formula):
    """Label a bilayer formula as ``'Homobilayer'`` or ``'Heterobilayer'``.

    The formula is split on ``+`` and whitespace around each part is ignored,
    consistent with ``normalize_formula``. The result is ``'Homobilayer'``
    only when there are exactly two parts and they are identical. Anything
    else, including a formula without ``+``, is ``'Heterobilayer'``.
    """
    parts = [part.strip() for part in formula.split('+')]
    if len(parts) == 2 and parts[0] == parts[1]:
        return 'Homobilayer'
    return 'Heterobilayer'

# Tag every row as 'Homobilayer' or 'Heterobilayer' from its formula.
df['bilayer_type'] = df['formula'].apply(classify_bilayer)

# Coerce binding energies to numbers (unparseable entries become NaN) and keep
# only rows with e_binding <= 700; NaN rows fail the comparison and are dropped.
# NOTE(review): 700 looks like a cut-off for failed/unphysical values — confirm.
df['e_binding'] = pd.to_numeric(df['e_binding'], errors='coerce')
# .copy() makes the filtered frame independent, so later column assignments
# do not raise SettingWithCopyWarning.
df = df[df['e_binding'] <= 700].copy()

####### ID by formula ###########
def normalize_formula(formula):
    """Return an order-independent key for a ``+``-separated formula.

    The components are split on ``+``, stripped of surrounding whitespace,
    sorted, and joined with ``" + "``. For example, ``"B+A"`` and
    ``"A + B"`` both map to ``"A + B"``.
    """
    return " + ".join(sorted(piece.strip() for piece in formula.split("+")))

# Order-independent formula key ("A + B" == "B + A") and a 1-based integer
# ID per distinct bilayer combination, in order of first appearance.
df["formula_normalized"] = df["formula"].apply(normalize_formula)

id_map = {val: i for i, val in enumerate(df["formula_normalized"].unique(), start=1)}
df["id_formula"] = df["formula_normalized"].map(id_map)

# Pre-formatted, unit-suffixed strings for hover text and the details panel.
df['rotation_angle_str'] = df['rotation_angle'].map(lambda x: f"{x:.1f}°")
df['e_binding_str'] = df['e_binding'].map(lambda x: f"{x:.2f} eV")
df['z_separation_str'] = df['z_separation'].map(lambda x: f"{x:.2f} Å")
df['e_per_area_SO_str'] = df['e_per_area_SO'].map(lambda x: f"{x:.2f} eV/Å²")
df['total_energy_SO_str'] = df['total_energy_SO'].map(lambda x: f"{x:.2f} eV")
# The gap uses 5 decimals, unlike the 2 used for the other energies.
df['gap_SO_str'] = df['gap_SO'].map(lambda x: f"{x:.5f} eV")
df['e_slide_str'] = df['e_slide'].map(lambda x: f"{x:.2f} eV")

# Columns listed, in this order, in the click-details panel
# (display_click_data iterates over this list).
selected_columns = [
    "formula", "rotation_angle_str", "gap_SO_str", "e_binding_str", "lattice_type_layer1",
    "lattice_type_layer2", "number_layers", "type_ions_layers", "number_ions_layers",
    "number_type_ions_layers", "range_ions_layers", "number_ions", "area_perc_mismatch",
    "perc_area_change", "perc_mod_vectors_change", "angle_perc_mismatch", "perc_angle_change",
    "rotation_angle", "shift_plane", "z_separation", "thickness", "total_thickness",
    "lattice_type", "inversion_symmetry", "pseudo_type", "exchange_correlation_functional",
    "vdW", "a1", "a2", "a3", "cell_area", "b1", "b2", "b3", "zb_area", "lorbit", "nk",
    "nb", "ne", "gap", "e_vbm", "e_cbm", "vbm", "cbm", "type_gap", "k_vbm", "k_cbm",
    "e_fermi", "e_vacuum", "total_energy", "e_per_ion", "e_per_area", "e_binding", "e_slide",
    "lorbit_SO", "ispin_SO", "nk_SO", "nb_SO", "ne_SO", "e_vbm_SO", "e_cbm_SO", "vbm_SO",
    "cbm_SO", "type_gap_SO", "k_vbm_SO", "k_cbm_SO", "e_fermi_SO", "e_vacuum_SO",
    "total_energy_SO", "e_per_ion_SO", "e_per_area_SO", "gap_SO_layer1", "e_fermi_SO_layer1",
    "e_per_area_SO_layer1", "total_energy_SO_layer1", "work_function_SO_layer1",
    "gap_SO_layer2", "e_fermi_SO_layer2", "e_per_area_SO_layer2", "total_energy_SO_layer2",
    "work_function_SO_layer2", "work_function_SO"
]

# Plotly colour scale for the scatter markers: seven evenly spaced stops at
# positions 0, 1/6, ..., 1, running from cyan to black.
custom_colorscale = [
    [0/6,  'rgb(0, 255, 255)'],     # Cyan
    [1/6,  'rgb(0, 255, 0)'],       # Green
    [2/6,  'rgb(255, 0, 0)'],       # Red
    [3/6,  'rgb(255, 215, 0)'],     # Gold
    [4/6,  'rgb(255, 0, 255)'],     # Magenta
    [5/6,  'rgb(0, 0, 255)'],       # Blue
    [6/6,  'rgb(0, 0, 0)']          # Black
]

In [None]:
cminimo = 0  # lower end of the marker colour scale (passed as cmin)

app = dash.Dash(__name__)
server = app.server

# Empty placeholder; update_figure supplies the real figure on first render.
fig = go.Figure()

# Style shared by every control panel and by the click-details box.
_PANEL_STYLE = {"whiteSpace": "pre-wrap", "fontFamily": "monospace", "fontSize": "16px", "color": "black", "backgroundColor": "white", "padding": "10px", "margin": "0 auto", "textAlign": "center"}
# Style shared by the three property dropdowns.
_DROPDOWN_STYLE = {"fontSize": "16px", "color": "black", "backgroundColor": "white", "margin": "0 auto", "width": "300px", "textAlign": "left"}

# Properties selectable for both the x and y axes (label shown -> df column).
_AXIS_OPTIONS = [
    {"label": "Angle Change (%)", "value": "mean_abs_perc_angle_change"},
    {"label": "Angle perc Mismatch (%)", "value": "mean_abs_angle_perc_mismatch"},
    {"label": "Area Change (%)", "value": "mean_abs_perc_area_change"},
    {"label": "Area perc Mismatch (%)", "value": "mean_abs_area_perc_mismatch"},
    {"label": "Atom count", "value": "number_ions"},
    {"label": "Band Gap (SOC, eV)", "value": "gap_SO"},
    {"label": "Bilayer combination (ID formula)", "value": "id_formula"},
    {"label": "Bilayer type", "value": "bilayer_type"},
    {"label": "Binding Energy (eV)", "value": "e_binding"},
    {"label": "Electron Affinity (SOC, eV)", "value": "electron_affinity_SO"},
    {"label": "Energy per Area (SOC, eV/Å²)", "value": "e_per_area_SO"},
    {"label": "Energy per Ion (SOC, eV/n° ions)", "value": "e_per_ion_SO"},
    {"label": "Ionization Potential (SOC, eV)", "value": "ionization_potential_SO"},
    {"label": "Rotation Angle (°)", "value": "rotation_angle"},
    {"label": "Single ID (ID)", "value": "id_numeric"},
    {"label": "Sliding Energy (eV)", "value": "e_slide"},
    {"label": "Total Energy (SOC, eV)", "value": "total_energy_SO"},
    {"label": "Work Function (SOC)", "value": "work_function_SO"},
    {"label": "Z Separation (Å)", "value": "z_separation"},
]

# Properties selectable for the marker colour bar.
_COLORBAR_OPTIONS = [
    {"label": "Atom count", "value": "number_ions"},
    {"label": "Rotation Angle (°)", "value": "rotation_angle"},
    {"label": "Z Separation (Å)", "value": "z_separation"},
]


def _panel(title, control):
    """Return a centred panel with a bold *title* label above *control*."""
    return html.Div([html.Label(title, style={"fontWeight": "bold"}), control], style=_PANEL_STYLE)


def _range_row(title, index):
    """Return the min/max numeric inputs (ids range-min-<index>, range-max-<index>) for one density curve."""
    return html.Div(
        [
            html.Label(title, style={"marginRight": "10px"}),
            dcc.Input(id=f"range-min-{index}", type="number", placeholder="min", style={"width": "80px"}),
            dcc.Input(id=f"range-max-{index}", type="number", placeholder="max", style={"width": "80px", "marginLeft": "5px"}),
        ],
        style={"padding": "5px"},
    )


app.layout = html.Div(
    [
        # Filter controls
        _panel("Bilayer Type", dcc.Checklist(id="bilayer-checklist", options=[{"label": "Homobilayer", "value": "Homobilayer"}, {"label": "Heterobilayer", "value": "Heterobilayer"}], value=["Homobilayer", "Heterobilayer"], labelStyle={"display": "inline-block", "marginRight": "10px"}, inputStyle={"marginRight": "5px"}, style={"textAlign": "center"})),
        _panel("Filter by Monolayer Formula", dcc.Dropdown(id="formula-dropdown", options=[{"label": f, "value": f} for f in sorted(df_mono["formula"].unique())], value=None, placeholder="Select a monolayer...", clearable=True, style={"width": "100%"})),
        # Marks every 30° up to max_angle (same spacing as the colour-bar ticks).
        _panel("Filter by Rotation Angle (°)", dcc.RangeSlider(id="angle-slider", min=0, max=max_angle, step=1, marks={i: f"{i}°" for i in range(0, max_angle + 1, 30)}, value=[0, max_angle], tooltip={"placement": "bottom", "always_visible": False})),
        # Property selectors
        _panel("Select property to plot (y-axis)", dcc.Dropdown(id="yaxis-dropdown", options=_AXIS_OPTIONS, value="e_binding", clearable=False, style=_DROPDOWN_STYLE)),
        _panel("Select property to plot (x-axis)", dcc.Dropdown(id="xaxis-dropdown", options=_AXIS_OPTIONS, value="id_formula", clearable=False, style=_DROPDOWN_STYLE)),
        _panel("Select property to plot (colorbar)", dcc.Dropdown(id="colorbar-dropdown", options=_COLORBAR_OPTIONS, value="rotation_angle", clearable=False, style=_DROPDOWN_STYLE)),
        # Colour-bar ranges whose y-distributions get their own density curve
        html.Div(
            id="kde-input-container",
            children=[
                html.Hr(),
                html.Label("Define Density Ranges (based on Colorbar Property)", style={"fontWeight": "bold"}),
                _range_row("Blue Curve Range:", 1),
                _range_row("Red Curve Range: ", 2),
                _range_row("Green Curve Range:", 3),
                html.Hr(),
            ],
            style={"width": "80%", "margin": "0 auto", "textAlign": "center"},
        ),

        dcc.Graph(id="scatter-plot", figure=fig, style={"width": "100%", "backgroundColor": "white"}),

        # Holds the latest figure (written by update_figure) for the save buttons
        dcc.Store(id="figure-store"),
        html.Div([
            html.Button("Save as HTML", id="btn-save-html", n_clicks=0, style={"marginRight": "10px"}),
            html.Button("Save as PDF", id="btn-save-pdf", n_clicks=0),
        ], style={"textAlign": "center", "padding": "10px"}),
        html.Div(id="save-notification", style={"textAlign": "center"}),
        html.Hr(),
        dcc.Markdown(id="click-info", style=_PANEL_STYLE),
    ]
)

@app.callback(Output("click-info", "children"), Input("scatter-plot", "clickData"))
def display_click_data(clickData):
    """Render the details of the clicked bilayer as Markdown.

    ``customdata[0]`` of the clicked point is the row's ``id_numeric`` (set
    by update_figure). Every column of ``selected_columns`` for that row is
    listed, with list/str values in inline-code backticks.

    Returns a hint before any click. Returns ``dash.no_update`` when the
    click has no usable id (e.g. a point on a density curve, which carries no
    customdata) or the id is not present in ``df``.
    """
    if clickData is None:
        return "Click in a point to see details."
    try:
        point_id = clickData["points"][0]["customdata"][0]
        material = df[df["id_numeric"] == point_id].iloc[0]
        lines = [
            f"**{col}**: `{material[col]}`" if isinstance(material[col], (list, str)) else f"**{col}**: {material[col]}"
            for col in selected_columns
        ]
    except (KeyError, IndexError, TypeError):
        # Missing customdata / unknown id / missing column: keep the panel as is.
        return dash.no_update
    return f"==== Material {point_id} ====\n\n" + "\n".join(lines)

# The main callback UPDATES the figure and STORES its data

@app.callback(
    [Output("scatter-plot", "figure"), Output("figure-store", "data")],
    [
        Input("angle-slider", "value"), Input("formula-dropdown", "value"),
        Input("bilayer-checklist", "value"), Input("yaxis-dropdown", "value"),
        Input("xaxis-dropdown", "value"), Input("colorbar-dropdown", "value"),
        Input("scatter-plot", "relayoutData"), Input("range-min-1", "value"),
        Input("range-max-1", "value"), Input("range-min-2", "value"),
        Input("range-max-2", "value"), Input("range-min-3", "value"),
        Input("range-max-3", "value"),
    ]
)
def update_figure(
    angle_range, selected_formula, selected_types, selected_y, selected_x,
    selected_colorbar, relayout_data, min1, max1, min2, max2, min3, max3):
    """Rebuild the scatter plot and its side density panel from the controls.

    Rows of ``df`` are filtered by four things:
    - rotation angle;
    - monolayer formula, if one is chosen;
    - bilayer type, if any is ticked;
    - the current zoom window (``relayout_data``), on numeric axes only.

    The left panel scatters ``selected_y`` against ``selected_x``, coloured
    by ``selected_colorbar``. The right panel shows a Gaussian KDE of
    ``selected_y`` over all visible points ("Total"). It adds one curve per
    valid ``(min, max)`` range of the colour-bar property, scaled by that
    range's share of the visible points. No density curves are drawn when
    ``selected_y`` is not numeric (e.g. bilayer_type).

    Returns
    -------
    list
        ``[fig, fig.to_dict()]``: the figure for ``scatter-plot`` and its
        dict form for ``figure-store`` (read by the save callback).
    """
    label_dict = {
        "e_binding": "Binding Energy (eV)",
        "z_separation": "Z Separation (Å)",
        "e_per_area_SO": "Energy per Area (SOC, eV/Å²)",
        "total_energy_SO": "Total Energy (SOC, eV)",
        "gap_SO": "Band Gap (SOC, eV)",
        "e_slide": "Sliding Energy (eV)",
        "id_formula": "Bilayer combination (ID formula)",
        "rotation_angle": "Rotation Angle (°)",
        "number_ions": "Atom count",
        "bilayer_type": "Bilayer type",
        "id_numeric": "Single ID (ID)",
        "e_per_ion_SO": "Energy per Ion (SOC, eV/n° ions)",
        "work_function_SO": "Work Function (SOC)",
        "ionization_potential_SO": "Ionization Potential (SOC, eV)",
        "electron_affinity_SO": "Electron Affinity (SOC, eV)",
        "mean_abs_area_perc_mismatch": "Area perc Mismatch (%)",
        "mean_abs_angle_perc_mismatch": "Angle perc Mismatch (%)",
        "mean_abs_perc_area_change": "Area Change (%)",
        "mean_abs_perc_angle_change": "Angle Change (%)",
    }

    # ---- Control filters ----
    angle_min, angle_max = angle_range
    filtered_df = df[(df["rotation_angle"] >= angle_min) & (df["rotation_angle"] <= angle_max)]
    if selected_formula:
        # Keep bilayers having the chosen monolayer as one of their '+'-separated parts.
        filtered_df = filtered_df[filtered_df["formula"].apply(lambda f: selected_formula in f.split("+"))]
    if selected_types:
        # NOTE(review): an empty selection applies no filter (all types shown) — confirm intended.
        filtered_df = filtered_df[filtered_df["bilayer_type"].isin(selected_types)]

    x_is_numeric = pd.api.types.is_numeric_dtype(df[selected_x])
    y_is_numeric = pd.api.types.is_numeric_dtype(df[selected_y])

    # ---- Zoom filter ----
    # Keep only points inside the current zoom window, so the density panel and
    # the instance count describe what is visible. Zoom ranges are numbers, so
    # they are applied to numeric columns only: comparing a text column such as
    # bilayer_type with a float raises TypeError.
    if relayout_data:
        if x_is_numeric and "xaxis.range[0]" in relayout_data and "xaxis.range[1]" in relayout_data:
            x_min, x_max = relayout_data["xaxis.range[0]"], relayout_data["xaxis.range[1]"]
            filtered_df = filtered_df[(filtered_df[selected_x] >= x_min) & (filtered_df[selected_x] <= x_max)]
        if y_is_numeric and "yaxis.range[0]" in relayout_data and "yaxis.range[1]" in relayout_data:
            y_min, y_max = relayout_data["yaxis.range[0]"], relayout_data["yaxis.range[1]"]
            filtered_df = filtered_df[(filtered_df[selected_y] >= y_min) & (filtered_df[selected_y] <= y_max)]

    fig = make_subplots(rows=1, cols=2, column_widths=[0.8, 0.1], shared_yaxes=True, specs=[[{"type": "xy"}, {"type": "xy"}]], horizontal_spacing=0.02)

    def kde_trace(values, y_grid, weight, color, name, fill):
        """Return a horizontal KDE curve of *values* on *y_grid*, scaled by *weight*, or None.

        Returns None when *values* has fewer than two distinct non-NaN entries;
        gaussian_kde cannot fit those (singular covariance) and would raise.
        """
        sample = values.dropna()
        if sample.nunique() < 2:
            return None
        density = gaussian_kde(sample)(y_grid) * weight
        return go.Scatter(x=density, y=y_grid, mode='lines', orientation='h', fill=fill, line=dict(color=color, shape='spline'), name=name, legendgroup='kde_group', hovertemplate=f"Density ({name})<extra></extra>")

    # ---- Density panel (right) ----
    n_total = len(filtered_df)
    y_visible = filtered_df[selected_y].dropna()
    if y_is_numeric and not y_visible.empty:
        # One y grid spanning all visible points, shared by every curve.
        y_grid = np.linspace(y_visible.min(), y_visible.max(), 200)
        total_trace = kde_trace(filtered_df[selected_y], y_grid, 1.0, 'lightgray', 'Total', 'tozerox')
        if total_trace is not None:
            fig.add_trace(total_trace, row=1, col=2)
        ranges = [(min1, max1), (min2, max2), (min3, max3)]
        colors = ['blue', 'red', 'green']
        for (min_val, max_val), color in zip(ranges, colors):
            if min_val is not None and max_val is not None and min_val < max_val:
                # Half-open slice [min_val, max_val) of the colour-bar property.
                in_range = (filtered_df[selected_colorbar] >= min_val) & (filtered_df[selected_colorbar] < max_val)
                segment = filtered_df.loc[in_range, selected_y]
                # Weight = share of all visible rows that fall in this slice.
                seg_trace = kde_trace(segment, y_grid, segment.count() / n_total, color, f'{min_val:.2f} - {max_val:.2f}', 'none')
                if seg_trace is not None:
                    fig.add_trace(seg_trace, row=1, col=2)

    # ---- Scatter panel (left) ----
    # Round only numeric y values; text columns are shown as-is.
    y_hover = filtered_df[selected_y].round(5) if y_is_numeric else filtered_df[selected_y]
    customdata = np.stack([filtered_df["id_numeric"], filtered_df["formula"], filtered_df["rotation_angle_str"], filtered_df["gap_SO_str"], y_hover, filtered_df["lattice_type_layer1"], filtered_df["lattice_type_layer2"]], axis=-1)
    ticksuffix_dict = {"rotation_angle": "°", "z_separation": " Å", "number_ions": ""}
    tickvals_dict = {
        "rotation_angle": np.arange(0, max_angle + 1, 30),
        "z_separation": np.arange(0, df["z_separation"].max() + 0.5, 0.5) if not df.empty else [],
        "number_ions": np.arange(0, df["number_ions"].max() + 0.001, 3) if not df.empty else [],
    }
    hovertemplate = (
        "Material ID: %{customdata[0]}<br>"
        "Formula: %{customdata[1]}<br>"
        "Rotation Angle: %{customdata[2]}<br>"
        "Band Gap (SOC): %{customdata[3]}<br>"
        f"{label_dict.get(selected_y, selected_y)}: %{{customdata[4]}}<br>"
        "Lattice Type Layer 1: %{customdata[5]}<br>"
        "Lattice Type Layer 2: %{customdata[6]}"
        "<extra></extra>"  # hide the secondary (trace-name) hover box
    )
    fig.add_trace(
        go.Scatter(
            x=filtered_df[selected_x],
            y=filtered_df[selected_y],
            mode="markers",
            marker=dict(
                color=filtered_df[selected_colorbar] if not filtered_df.empty else [],
                colorscale=custom_colorscale,
                cmin=cminimo,
                cmax=df[selected_colorbar].max() if not df.empty else None,
                size=6,
                colorbar=dict(
                    title=dict(text=label_dict.get(selected_colorbar, selected_colorbar), side="top"),
                    orientation='h', y=1.04, yanchor='bottom', x=0.45, xanchor='center', len=0.9,
                    tickvals=tickvals_dict.get(selected_colorbar, []),
                    ticksuffix=ticksuffix_dict.get(selected_colorbar, ""),
                ),
            ),
            customdata=customdata,
            hovertemplate=hovertemplate,
            showlegend=False,
        ),
        row=1, col=1,
    )
    fig.update_layout(uirevision="keep-zoom", height=600, width=600, plot_bgcolor="white", paper_bgcolor="white", font=dict(family="Times New Roman", size=14), margin=dict(t=100, b=60), title=f"{label_dict.get(selected_y, selected_y)} vs {label_dict.get(selected_x, selected_x)} with Sliced Density Plot: {len(filtered_df)} Instances", title_x=0.5, legend=dict(yanchor="top", y=0.99, xanchor="left", x=0.91, bgcolor='rgba(255,255,255,0.6)'))

    # Axis-title positions (standoff) adjusted so the labels sit close to the axes.
    fig.update_yaxes(title=dict(text=label_dict.get(selected_y, selected_y), standoff=0), gridcolor="lightgray", linecolor="black", mirror=True, row=1, col=1)
    fig.update_xaxes(title=dict(text=label_dict.get(selected_x, selected_x)), gridcolor="lightgray", linecolor="black", mirror=True, row=1, col=1)
    fig.update_xaxes(title=dict(text="Density", standoff=30), row=1, col=2, showticklabels=False, zeroline=False)

    return [fig, fig.to_dict()]

# CALLBACK: save the figure when buttons are clicked
@app.callback(
    Output("save-notification", "children"),
    [Input("btn-save-html", "n_clicks"), Input("btn-save-pdf", "n_clicks")],
    [State("figure-store", "data")],
    prevent_initial_call=True  # do not fire on page load
)
def save_figure(n_html, n_pdf, figure_data):
    """Write the stored figure into ``Dir_figures`` as HTML or PDF.

    ``n_html`` / ``n_pdf`` are the buttons' click counters. Dash passes them
    because they are Inputs; the button that fired is read from the callback
    context instead. Each file gets a random uuid4-based name.

    Returns a status message for the ``save-notification`` div, or
    ``dash.no_update`` when no known button triggered the call.
    """
    if not figure_data:
        return "No figure to save."

    ctx = dash.callback_context
    if not ctx.triggered:
        return dash.no_update
    button_id = ctx.triggered[0]['prop_id'].split('.')[0]

    extension = {"btn-save-html": "html", "btn-save-pdf": "pdf"}.get(button_id)
    if extension is None:
        return dash.no_update

    # Rebuild the figure from the dict stored by update_figure.
    fig = go.Figure(figure_data)
    file_path = os.path.join(Dir_figures, f"plot_{uuid.uuid4().hex}.{extension}")
    try:
        if extension == "html":
            fig.write_html(file_path)
        else:
            # Static (PDF) export requires the kaleido package.
            fig.write_image(file_path, format="pdf")
    except Exception as e:  # UI boundary: report any export failure instead of failing the callback
        return f"Error saving figure: {e}"
    return f"Figure saved to {file_path}"

# Launch the Dash server on port 8050, debug mode off, when run as a script/notebook.
if __name__ == "__main__":
    app.run(port=8050, debug=False)

___