# initial

In [None]:
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import os
import ast

# ============================================================================
# CONFIGURATION - Replace these paths with your data locations
# Or set environment variables: HYPERPARAMETER_OPT_DIR
# ============================================================================

# Directory holding the initial hyperparameter-optimization runs. The
# HYPERPARAMETER_OPT_DIR environment variable overrides the default path.
root_dir = os.getenv(
    "HYPERPARAMETER_OPT_DIR",
    "./models/hyperparameter_optimization/initial_test/initial",
)

print(f"Hyperparameter optimization directory: {root_dir}")

# Name of the swept hyperparameter; it is used both as a sub-directory name
# (root_dir/<dataset>/<QoI>) and as the x-axis column in the plots.
QoI = "architecture_dim"
independent_variable = QoI

datasets = ["Set1", "Set2", "Set3"]  # Add more dataset names as needed

# Per-dataset summary frames are appended here by the loading loop.
all_data = []

for dataset in datasets:
    parent_dir = os.path.join(root_dir, dataset, QoI)
    summary_data = []
    
    for model_folder in os.listdir(parent_dir):
        model_path = os.path.join(parent_dir, model_folder)
        
        if os.path.isdir(model_path):
            summary_entry = {'Dataset': dataset, 'Model': model_folder}

            # Load test set
            test_set_path = os.path.join(model_path, "test_set.csv")
            if os.path.exists(test_set_path):
                df_test = pd.read_csv(test_set_path)
                summary_entry['MAE'] = df_test['Abs_error_eV'].mean()

            # Load hyperparameters from .txt file
            hyperparam_path = os.path.join(model_path, "input.txt")
            if os.path.exists(hyperparam_path):
                with open(hyperparam_path, "r") as file:
                    hyperparams = ast.literal_eval(file.read())

                def flatten_dict(d, parent_key='', sep='_'):

In [None]:
# Display the row with the lowest MAE, restricted to the dataset, the model
# folder name, the swept hyperparameter value and the MAE itself.
df_combined.loc[df_combined['MAE'].idxmin(), ["Dataset", "Model", independent_variable, "MAE"]]

# augment

In [None]:
# Directory holding the runs trained on the augmented data. The
# HYPERPARAMETER_OPT_DIR_AUGMENT environment variable overrides the default.
root_dir = os.getenv(
    "HYPERPARAMETER_OPT_DIR_AUGMENT",
    "./models/hyperparameter_optimization/initial_test/augment",
)

print(f"Hyperparameter optimization (augmented) directory: {root_dir}")

# Dataset sub-directories to scan under root_dir.
datasets = ["Set1", "Set2", "Set3", "Set15"]

# Start from an empty list so that re-running the cell does not accumulate
# frames from a previous run.
all_data = []

def flatten_dict(d, parent_key='', sep='_'):
    """Flatten a nested dict into a single level.

    Keys of nested dicts are joined to their parent key with ``sep``
    (e.g. ``{"train": {"lr": 1}}`` becomes ``{"train_lr": 1}``).

    Defined once, ahead of the loops, instead of being re-created for every
    model folder that contains an ``input.txt``.
    """
    items = []
    for k, v in d.items():
        new_key = f"{parent_key}{sep}{k}" if parent_key else k
        if isinstance(v, dict):
            items.extend(flatten_dict(v, new_key, sep=sep).items())
        else:
            items.append((new_key, v))
    return dict(items)


# Build one summary DataFrame per dataset from root_dir/<dataset>/<QoI>/<model>/
for dataset in datasets:
    parent_dir = os.path.join(root_dir, dataset, QoI)
    summary_data = []

    for model_folder in os.listdir(parent_dir):
        model_path = os.path.join(parent_dir, model_folder)

        # Only sub-directories are model runs; skip stray files.
        if not os.path.isdir(model_path):
            continue

        summary_entry = {'Dataset': dataset, 'Model': model_folder}

        # MAE = mean of the per-sample absolute errors in the test set.
        # A run without test_set.csv simply gets no 'MAE' value.
        test_set_path = os.path.join(model_path, "test_set.csv")
        if os.path.exists(test_set_path):
            df_test = pd.read_csv(test_set_path)
            summary_entry['MAE'] = df_test['Abs_error_eV'].mean()

        # input.txt holds a Python-literal dict of hyperparameters; it is
        # parsed with ast.literal_eval (literals only, no code execution)
        # and flattened so that every hyperparameter becomes its own column.
        hyperparam_path = os.path.join(model_path, "input.txt")
        if os.path.exists(hyperparam_path):
            with open(hyperparam_path, "r") as file:
                hyperparams = ast.literal_eval(file.read())

            flat_hyperparams = flatten_dict(hyperparams)
            summary_entry.update(flat_hyperparams)

        summary_data.append(summary_entry)

    df_summary_augment = pd.DataFrame(summary_data)
    all_data.append(df_summary_augment)

# Combine all datasets into a single DataFrame
df_combined_augment = pd.concat(all_data, ignore_index=True)

# Sort by dataset, then by the swept hyperparameter. drop=True discards the
# pre-sort positional index; without it reset_index() would keep that index
# as a spurious extra column named "index".
df_combined_augment = df_combined_augment.sort_values(
    ["Dataset", independent_variable], ascending=True
).reset_index(drop=True)

# Average MAE across the datasets for each value of the swept hyperparameter
df_avg_augment = df_combined_augment.groupby([independent_variable], as_index=False)['MAE'].mean()
df_avg_augment['Model'] = 'Average'

# Plot MAE across datasets (augmented runs): grouped bars per model on the
# left, MAE versus the swept hyperparameter on the right.

# One colour per dataset, shared by both subplots
palette = sns.color_palette('hsv', n_colors=len(datasets))

plt.figure(figsize=(20, 6))

# First subplot: one bar per (model, dataset)
plt.subplot(1, 2, 1)
sns.barplot(
    data=df_combined_augment, x='Model', y='MAE', hue='Dataset',
    palette=palette, edgecolor='black', legend=False, zorder=10,
)
plt.title('Mean Absolute Error (MAE) Across Models and Datasets')
plt.ylabel('Mean Absolute Error (MAE) [eV]')
plt.xticks(rotation=45)
plt.grid(True)

# Second subplot: MAE as a function of the swept hyperparameter
plt.subplot(1, 2, 2)
for i, dataset in enumerate(datasets):
    df_subset_augment = df_combined_augment[df_combined_augment['Dataset'] == dataset].sort_values(by=independent_variable)

    # Line connecting points within the dataset
    sns.lineplot(
        data=df_subset_augment, x=independent_variable, y='MAE',
        label=f"{dataset} (line)", color=palette[i], linewidth=1.5, zorder=5,
    )

    # Scatter plot without custom label (avoid conflict with style)
    sns.scatterplot(
        data=df_subset_augment, x=independent_variable, y='MAE',
        style='Model', color=palette[i], s=100, edgecolor='black', zorder=10,
    )

# Average over datasets, drawn on top of the per-dataset curves
df_avg_sorted_augment = df_avg_augment.sort_values(by=independent_variable)
sns.lineplot(
    data=df_avg_sorted_augment, x=independent_variable, y='MAE',
    color='black', linewidth=2.5, label='Average', zorder=12,
)
# '_nolegend_' keeps the markers out of the legend; the line above already
# provides the 'Average' entry.
sns.scatterplot(
    data=df_avg_sorted_augment, x=independent_variable, y='MAE',
    color='black', marker='X', s=200, label='_nolegend_', zorder=15,
)

plt.xlabel(independent_variable)
plt.ylabel('Mean Absolute Error (MAE) [eV]')
plt.title(f'MAE vs. {independent_variable} Across Datasets')
plt.grid(True)
plt.legend(title='Legend', bbox_to_anchor=(1.01, 1), loc='upper left')

# Applies to the current (second) subplot only
if independent_variable == "train_minlr":
    plt.xscale('log')

# Set the horizontal gap between the two subplots once (the original cell
# issued this identical call twice).
plt.subplots_adjust(wspace=0.15)
plt.show()

In [None]:
# Display the augmented-data row with the lowest MAE, restricted to the
# dataset, the model folder name, the swept hyperparameter value and the MAE.
df_combined_augment.loc[df_combined_augment['MAE'].idxmin(), ["Dataset", "Model", independent_variable, "MAE"]]

# Compare augment vs initial

In [None]:
def _plot_mae_panel(axis, df_runs, df_mean, x_col, dataset_names, colors, title):
    """Draw one 'MAE vs. hyperparameter' panel on ``axis``.

    Parameters:
        axis: Matplotlib Axes to draw on.
        df_runs: Per-run frame with 'Dataset', 'Model', 'MAE' and ``x_col`` columns.
        df_mean: Frame with the across-dataset average MAE per ``x_col`` value,
            already sorted by ``x_col``.
        x_col (str): Name of the column plotted on the x axis.
        dataset_names: Dataset names, in the same order as ``colors``.
        colors: One colour per entry of ``dataset_names``.
        title (str): Panel title.
    """
    for color, dataset in zip(colors, dataset_names):
        subset = df_runs[df_runs['Dataset'] == dataset].sort_values(by=x_col)
        # The dataset list is shared by both panels, so a dataset may be
        # absent from one of the frames; there is nothing to draw for it.
        if subset.empty:
            continue

        sns.lineplot(
            data=subset, x=x_col, y='MAE',
            label=f"{dataset} (line)", color=color, linewidth=1.5,
            zorder=5, ax=axis
        )
        sns.scatterplot(
            data=subset, x=x_col, y='MAE',
            style='Model', color=color, s=100, edgecolor='black',
            zorder=10, ax=axis
        )

    # Average line and scatter ('_nolegend_' keeps the markers out of the legend)
    sns.lineplot(data=df_mean, x=x_col, y='MAE',
                 color='black', linewidth=2.5, label='Average', zorder=12, ax=axis)
    sns.scatterplot(data=df_mean, x=x_col, y='MAE',
                    color='black', marker='X', s=200, label='_nolegend_', zorder=15, ax=axis)

    axis.set_xlabel(x_col)
    axis.set_ylabel('Mean Absolute Error (MAE) [eV]')
    axis.set_title(title)
    axis.grid(True)


# Use a consistent color palette
palette = sns.color_palette('hsv', n_colors=len(datasets))

# Create subplots (shared y axis so both panels are directly comparable)
fig, ax = plt.subplots(1, 2, figsize=(20, 6), sharey=True)

# First subplot: runs on the initial dataset
df_avg_sorted = df_avg.sort_values(by=independent_variable)
_plot_mae_panel(
    ax[0], df_combined, df_avg_sorted, independent_variable, datasets, palette,
    f'MAE vs. {independent_variable} Across Datasets (initial dataset)',
)

# Hide legend on first plot; get_legend() returns None when no legend was
# created, so guard before removing it.
first_legend = ax[0].get_legend()
if first_legend is not None:
    first_legend.remove()

# Second subplot: runs on the augmented dataset
df_avg_sorted_augment = df_avg_augment.sort_values(by=independent_variable)
_plot_mae_panel(
    ax[1], df_combined_augment, df_avg_sorted_augment, independent_variable, datasets, palette,
    f'MAE vs. {independent_variable} Across Datasets (augmented dataset)',
)

# Add legend only to second plot
ax[1].legend(title='Molecule Group', bbox_to_anchor=(1.01, 1), loc='upper left')

# Layout: tight_layout() recomputes the subplot spacing itself, so a
# preceding subplots_adjust(wspace=...) would have no effect and is omitted.
plt.tight_layout()
plt.show()


In [None]:
import os
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import ast

# ============================================================================
# CONFIGURATION - Replace this path with your data location
# Or set environment variable: HYPERPARAMETER_OPT_BASE_DIR
# ============================================================================

# Root directory containing one sub-directory per database; the
# HYPERPARAMETER_OPT_BASE_DIR environment variable overrides the default.
hyperparameter_opt_base_dir = os.getenv(
    "HYPERPARAMETER_OPT_BASE_DIR",
    "./models/hyperparameter_optimization",
)

print(f"Base hyperparameter optimization directory: {hyperparameter_opt_base_dir}")

def _flatten_dict(d, parent_key='', sep='_'):
    """Return a one-level copy of a nested dict, joining nested keys with ``sep``."""
    flat = {}
    for key, value in d.items():
        full_key = f"{parent_key}{sep}{key}" if parent_key else key
        if isinstance(value, dict):
            flat.update(_flatten_dict(value, full_key, sep=sep))
        else:
            flat[full_key] = value
    return flat


def _collect_run_records(root_dir, QoI):
    """Walk ``root_dir/<run>/<dataset>/...`` and return one record per model directory.

    For ``QoI == "none"`` the model directory is ``<dataset>/base``; otherwise
    every sub-directory of ``<dataset>/<QoI>`` is a model directory. Each
    record holds 'Dataset', 'Run' and 'Model', plus 'MAE' when
    ``input_config/test_set.csv`` exists and the flattened hyperparameters
    when ``input_config/input.txt`` exists.
    """
    is_baseline = QoI == "none"
    records = []

    for run_folder in os.listdir(root_dir):
        run_path = os.path.join(root_dir, run_folder)
        if not os.path.isdir(run_path):
            continue

        for dataset in sorted(os.listdir(run_path)):
            model_path = os.path.join(run_path, dataset, "base" if is_baseline else QoI)
            if not os.path.isdir(model_path):
                continue

            if is_baseline:
                candidates = [("base", model_path)]
            else:
                candidates = [
                    (f"{QoI}_{param_val}", os.path.join(model_path, param_val))
                    for param_val in os.listdir(model_path)
                ]

            for model_name, full_path in candidates:
                if not os.path.isdir(full_path):
                    continue

                input_path = os.path.join(full_path, "input_config")
                entry = {"Dataset": dataset, "Run": run_folder, "Model": model_name}

                # MAE = mean of the per-sample absolute errors of the test set
                test_csv = os.path.join(input_path, "test_set.csv")
                if os.path.exists(test_csv):
                    entry["MAE"] = pd.read_csv(test_csv)["Abs_error_eV"].mean()

                # input.txt is a Python-literal dict; literal_eval parses
                # literals only and never executes code.
                param_txt = os.path.join(input_path, "input.txt")
                if os.path.exists(param_txt):
                    with open(param_txt, "r") as f:
                        entry.update(_flatten_dict(ast.literal_eval(f.read())))

                records.append(entry)

    return records


def plot_mae_barplot_with_whiskers(QoI: str, database: str):
    """
    Plot barplots of MAE with min/max whiskers for a given QoI or baseline across datasets.

    Bar height is the mean MAE over the run folders; the whiskers span the
    smallest and largest MAE among those runs.

    Parameters:
        QoI (str): Name of the hyperparameter or 'none' for baseline.
        database (str): Name of the subdirectory under ``hyperparameter_opt_base_dir``.

    Raises:
        KeyError: If no run provides both an MAE and the hyperparameter column
            (for example when the directory tree is empty).
        FileNotFoundError: If the database directory does not exist
            (raised by ``os.listdir``).
    """
    root_dir = os.path.join(hyperparameter_opt_base_dir, database)
    is_baseline = QoI == "none"
    # Hyphens in the QoI name are replaced by underscores to obtain the column name.
    independent_variable = "Model" if is_baseline else QoI.replace("-", "_")

    # --- Traverse directory and collect data ---
    df_combined = pd.DataFrame(_collect_run_records(root_dir, QoI))

    # Fail with an explicit message instead of an opaque KeyError from the
    # sort below when nothing usable was found.
    missing = [col for col in ("MAE", independent_variable) if col not in df_combined.columns]
    if missing:
        raise KeyError(
            f"No usable runs under {root_dir!r} for QoI {QoI!r}: missing column(s) {missing}"
        )

    df_combined = df_combined.sort_values(["Dataset", independent_variable]).reset_index(drop=True)

    # --- Plotting ---
    sns.set(style="whitegrid", font_scale=1.2)
    palette = sns.color_palette("Set2", n_colors=df_combined["Dataset"].nunique())

    x_data = "Dataset" if is_baseline else independent_variable

    plt.figure(figsize=(12, 6))

    # seaborn aggregates the per-run rows itself: mean for the bar height and
    # a (min, max) interval for the whiskers. Drawing the whiskers through
    # seaborn keeps them centred on their bars, including the per-dataset
    # offset applied when several hue levels share one x category.
    # (err_kws requires seaborn >= 0.13.)
    sns.barplot(
        data=df_combined,
        x=x_data,
        y="MAE",
        hue="Dataset",
        palette=palette,
        estimator="mean",
        errorbar=lambda values: (values.min(), values.max()),
        capsize=0.2,
        err_kws={"color": "black", "linewidth": 2.5},
    )

    plt.xlabel(QoI if QoI != "none" else "Model")
    plt.ylabel("Mean Absolute Error (MAE) [eV]")
    plt.title(f"MAE vs. {QoI if QoI != 'none' else 'Baseline Model'} Across Datasets")
    plt.grid(True)

    if QoI != "none":
        plt.legend(title="Legend", bbox_to_anchor=(1.01, 1), loc="upper left")

    # NOTE: the x axis is categorical (bars sit at positions 0, 1, 2, ...),
    # so no log scale is applied for learning-rate sweeps: position 0 cannot
    # be represented on a log axis.

    plt.tight_layout()
    plt.show()

In [None]:
# Baseline ("none" = no hyperparameter sweep) MAE plots for the two databases.
plot_mae_barplot_with_whiskers(QoI="none", database="CN_database_1")
plot_mae_barplot_with_whiskers(QoI="none", database="BS_database_1")