In [None]:
import plotly.graph_objects as go
import plotly.express as px
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import kaleido
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import plotly.express as px

import os

import plotly.graph_objects as go
import plotly.express as px
from plotly.subplots import make_subplots
import pandas as pd
import os

def _add_error_box_traces(fig: "go.Figure",
                          df_errors: pd.DataFrame,
                          group_column: str,
                          palette: list,
                          col: int) -> None:
    """
    Add one box trace per distinct value of ``group_column`` to subplot (1, col).

    Each box shows the ``error`` column of its group, with every individual
    point drawn next to the box and the mean marked.

    Args:
        fig: Figure created with ``make_subplots`` that receives the traces.
        df_errors: Rows to plot. Must contain the columns ``molecule_group``,
            ``molecule``, ``state``, ``dissociation``, ``y_true``, ``y_std``,
            ``error`` and ``material``.
        group_column: Column whose distinct values define the boxes.
        palette: Sequence of colours; it is cycled when there are more groups
            than colours.
        col: 1-based subplot column that receives the traces.
    """
    # The order of this list fixes the customdata[...] indices used below.
    hover_columns = ['molecule_group', 'molecule', 'state', 'dissociation',
                     'y_true', 'y_std', 'error', 'material']
    hover_text = (
        "Material: %{customdata[7]}<br>" +
        "Molecule Group: %{customdata[0]}<br>" +
        "Molecule: %{customdata[1]}<br>" +
        "State: %{customdata[2]}<br>" +
        "Dissociation: %{customdata[3]}<br>" +
        "True Value: %{customdata[4]} eV<br>" +
        "Std Dev: %{customdata[5]} eV<br>" +
        "Error: %{customdata[6]} eV<br><extra></extra>"
    )

    # groupby yields the groups in sorted key order and skips rows whose key
    # is missing, so no empty trace is created for NaN group labels.
    for i, (group, group_data) in enumerate(df_errors.groupby(group_column)):
        label = str(group)
        fig.add_trace(go.Box(
            y=group_data['error'],
            name=label,
            marker_color=palette[i % len(palette)],
            boxmean=True,
            boxpoints="all",
            jitter=0.5,
            line=dict(width=3),
            width=0.2,
            legendgroup=label,
            hovertemplate=hover_text,
            customdata=group_data[hover_columns].to_numpy(),
            showlegend=False
        ), row=1, col=col)


def plot_experiments(directory: str,
                     base_path: str = "/home/tvanhout/oxides_ML/models/Experiments/",
                     output_path: str = "./experiment_plots/",
                     plot_title: str = None) -> None:
    """
    Generate, save and display the performance overview of one experiment.

    Reads ``test_set.csv`` and ``uq.csv`` from ``base_path/directory`` and
    builds a figure with three side-by-side panels:

    1. Scatter of ``Prediction_eV`` against ``True_eV`` from ``test_set.csv``,
       one trace per ``Material``, with a dashed red 1:1 line.
    2. Box plots of the ``error`` column of ``uq.csv`` (rows whose ``split``
       is ``"test"``), one box per ``molecule_group``.
    3. The same rows as panel 2, one box per ``material``.

    The figure is written to ``output_path`` as
    ``<directory>_performance.html`` (interactive, keeps the hover data) and
    ``<directory>_performance.png`` (static), then shown with ``fig.show()``.
    Any ``/`` in ``directory`` is replaced by ``_`` in the file names.

    Args:
        directory: Experiment folder, relative to ``base_path``.
        base_path: Root folder that contains the experiment folders.
        output_path: Folder for the saved figures; created if it is missing.
        plot_title: Figure title. ``None`` or an empty string falls back to
            "Prediction Performance Overview".

    Raises:
        FileNotFoundError: If either CSV file does not exist.
        KeyError: If a required column is missing from a CSV file.
    """
    os.makedirs(output_path, exist_ok=True)

    # Load data
    df_test_set = pd.read_csv(os.path.join(base_path, directory, "test_set.csv"))
    df_uq = pd.read_csv(os.path.join(base_path, directory, "uq.csv"))

    # Panels 2 and 3 show the same held-out rows; filter them only once.
    df_uq_test = df_uq[df_uq["split"] == "test"]

    # Color maps
    color_map_safe = px.colors.qualitative.Safe
    color_map_set2 = px.colors.qualitative.Set2

    # Create subplot
    fig = make_subplots(
        rows=1, cols=3,
        subplot_titles=[
            "True vs Predicted Energy by Material",
            "Error Distribution by Molecular Group",
            "Error Distribution by Material"
        ],
        horizontal_spacing=0.05
    )

    # === Plot 1 ===
    scatter_hover_columns = ['Material', 'Molecule Group', 'Molecule', 'State',
                             'Dissociation', "Error_eV", "Abs_error_eV"]
    for i, (material, subset) in enumerate(df_test_set.groupby("Material")):
        fig.add_trace(go.Scatter(
            x=subset['True_eV'],
            y=subset['Prediction_eV'],
            mode='markers',
            name=str(material),
            marker=dict(size=10, color=color_map_safe[i % len(color_map_safe)]),
            hovertemplate=(
                "<b>Material:</b> %{customdata[0]}<br>" +
                "<b>Molecule Group:</b> %{customdata[1]}<br>" +
                "<b>Molecule:</b> %{customdata[2]}<br>" +
                "<b>State:</b> %{customdata[3]}<br>" +
                "<b>Dissociation:</b> %{customdata[4]}<br>" +
                "<b>True Energy:</b> %{x:.3f} eV<br>" +
                "<b>Predicted Energy:</b> %{y:.3f} eV<br>" +
                "<b>Relative Error:</b> %{customdata[5]} eV<br>" +
                "<b>Absolute Error:</b> %{customdata[6]} eV<br><extra></extra>"
            ),
            customdata=subset[scatter_hover_columns].to_numpy(),
            showlegend=False
        ), row=1, col=1)

    # Add 1:1 line. It spans the range of both axes so that it still reaches
    # points whose prediction lies outside the range of the true values.
    energies = df_test_set[['True_eV', 'Prediction_eV']]
    min_val = energies.min().min()
    max_val = energies.max().max()
    fig.add_trace(go.Scatter(
        x=[min_val, max_val],
        y=[min_val, max_val],
        mode='lines',
        name='1:1 Line',
        line=dict(dash='dash', color='red'),
        showlegend=False
    ), row=1, col=1)

    # === Plot 2 and Plot 3 ===
    _add_error_box_traces(fig, df_uq_test, 'molecule_group', color_map_set2, col=2)
    _add_error_box_traces(fig, df_uq_test, 'material', color_map_set2, col=3)

    # Layout and axes
    fig.update_layout(
        height=700,
        width=2600,
        template="plotly_white",
        title_text=plot_title or "Prediction Performance Overview",
        showlegend=True
    )
    fig.update_xaxes(title_text="True Energy (eV)", row=1, col=1)
    fig.update_xaxes(showticklabels=False, row=1, col=2)
    fig.update_xaxes(showticklabels=False, row=1, col=3)
    fig.update_yaxes(title_text="Predicted Energy (eV)", row=1, col=1)
    # NOTE(review): the plotted column is 'error'; confirm that it really holds
    # absolute errors, as these axis titles state.
    fig.update_yaxes(title_text="Absolute Error [eV]", row=1, col=2)
    fig.update_yaxes(title_text="Absolute Error [eV]", row=1, col=3)

    # Save plots
    safe_dir_name = directory.replace("/", "_")
    html_path = os.path.join(output_path, f"{safe_dir_name}_performance.html")
    png_path = os.path.join(output_path, f"{safe_dir_name}_performance.png")

    # Write the HTML first: it needs no image-export backend, so the
    # interactive figure is kept even if the static export below fails.
    fig.write_html(html_path)
    fig.write_image(png_path)
    fig.show()

    print(f"Saved:{png_path}")
    print(f"Saved:{html_path}")


def plot_all_experiments(directories: list[str],
                         base_path: str = "/home/tvanhout/oxides_ML/models/Experiments/",
                         output_path: str = "./experiment_plots/") -> None:
    """
    Run ``plot_experiments`` once for every experiment directory, in order.

    A progress line is printed before each experiment, and each figure is
    titled "Prediction Performance Overview – <directory>".

    Args:
        directories: Experiment folder names, relative to ``base_path``.
        base_path: Root folder that contains the experiment folders.
        output_path: Folder passed on as the destination for the figures.
    """
    # These two settings are identical for every experiment.
    shared_locations = {"base_path": base_path, "output_path": output_path}

    for experiment in directories:
        print(f"Processing: {experiment}")
        plot_experiments(
            directory=experiment,
            plot_title=f"Prediction Performance Overview – {experiment}",
            **shared_locations,
        )


In [None]:
# Experiments to plot: the numbered runs "Ex1" through "Ex9", then "test".
dirs = [f"Ex{run_number}" for run_number in range(1, 10)] + ["test"]
plot_all_experiments(directories=dirs)