Script to produce grid plot of errors for reaction coordinate + solvation sampling:

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from matplotlib.colors import PowerNorm, TwoSlopeNorm

Plot Grid Plots:

In [None]:
# Produce one energy-error grid and one force-error grid per model.
#
# For every model a comparison CSV is read. Its first row is taken as the
# base (non fine-tuned) model; every remaining row is a fine-tuned model whose
# 'Model' name carries a 'solv_<int>' tag and ends in '_rep_<int>' and/or
# '_fine_tuned'. Rows sharing the same name prefix ("run") are averaged, the
# run means are arranged into a grid with one row per 'solv' value, and the
# grid is drawn as a heat map that is saved to a PNG and shown.

# SELECT MODELS:
model_names = ["GO-MACE-23", "MACE-OFF23_medium", "MACE-MP0-128-L1"]

# OTHER SETTINGS
set_relative_to_baseline = True  # colour cells by (fine-tuned - baseline) instead of the absolute error
sig_figs = 3  # NOTE(review): currently unused; cell annotations below are hard-coded to one decimal place
use_log_scale = False  # NOTE: np.log yields NaN for negative values, so avoid combining with relative mode

# NOTE(review): these lists are initialised but never filled in this cell;
# presumably they are meant to collect per-model results - confirm.
all_model_energy_errors, all_model_force_errors = [], []
for model_name in model_names:

    # SELECT INPUT CSV FILE:
    # 1. combined vacuum + react in water test set
    input_csv_file = f"test_set_comparison_endo_DA_final_test_combined_{model_name}_endo_DA_n=8_r_dist_fine_tuned.csv"
    save_name = 'endo_DA_combined'
    col_name = 'Endo DA Combined Test'

    # 2. just react in water test set
    # input_csv_file = f"{CSV_DIR}/test_set_comparison_endo_DA_final_test_combined_react_in_water_{model_name}_endo_DA_n=8_r_dist_fine_tuned.csv"
    # save_name = 'endo_DA_combined_react_in_water'
    # col_name = 'Endo DA Combined Test React in Water'
    # ===========================

    energy_col = f"{col_name} Energies (meV / atom)"
    force_col = f"{col_name} Forces (meV / A)"

    df = pd.read_csv(input_csv_file)

    # The first row holds the base model; keep its errors as the baseline.
    base_row = df.iloc[0]
    base_model_energy_error, base_model_force_error = base_row[energy_col], base_row[force_col]
    # Copy the slice so the column assignments below act on an independent
    # frame rather than on a view of the original.
    df = df.iloc[1:].copy()  # remove base model row
    print("Base Model: ", model_name)
    print("Base Model Energy Error: ", base_model_energy_error)
    print("Base Model Force Error: ", base_model_force_error)

    # group by solv frac
    # sort_values returns a new frame, so the result has to be kept.
    df = df.sort_values('Model')
    # 'run' is the model name with the replicate / fine-tune suffix stripped.
    df['run'] = df['Model'].str.extract(r'^(.*?)(?:_rep_\d+|_fine_tuned)')[0]
    df['solv'] = df['Model'].str.extract(r'solv_(\d+)', expand=False).astype('int32')
    grouped = df.groupby('run')

    mean_df = grouped.mean(numeric_only=True)
    std_df = grouped.std(numeric_only=True)
    # grouped.std() also replaces 'solv' by its within-run standard deviation
    # (0 or NaN), which cannot be used as a grouping key. Restore the real
    # solv value (both frames are indexed by 'run') so the std arrays are laid
    # out exactly like the mean grids.
    std_df['solv'] = mean_df['solv']

    # One grid row per solv value, one column per run with that solv value.
    energy_grid_rows = [np.array(group[energy_col]) for _, group in mean_df.groupby('solv')]
    force_grid_rows = [np.array(group[force_col]) for _, group in mean_df.groupby('solv')]
    abs_energy_grid_rows, abs_force_grid_rows = np.array(energy_grid_rows), np.array(force_grid_rows)

    print('Energy Grid Rows: \n', energy_grid_rows)
    print('Force Grid Rows: \n', force_grid_rows)

    # get stdevs
    energy_stds = np.array([np.array(group[energy_col]) for _, group in std_df.groupby('solv')])
    force_stds = np.array([np.array(group[force_col]) for _, group in std_df.groupby('solv')])
    print('Energy Stds: \n', energy_stds)
    print('Force Stds: \n', force_stds)

    if set_relative_to_baseline:
        energy_grid_rows = abs_energy_grid_rows - base_model_energy_error
        force_grid_rows = abs_force_grid_rows - base_model_force_error
        print('\u0394 Energy Grid Rows: \n', energy_grid_rows)
        print('\u0394 Force Grid Rows: \n', force_grid_rows)
        label_set = ['\u0394 Energy (FT - Baseline) RMSE (meV / atom)', '\u0394 Force (FT - Baseline) RMSE (meV / A)']
    else:
        energy_grid_rows = abs_energy_grid_rows
        force_grid_rows = abs_force_grid_rows
        label_set = ['Energy RMSE (meV / atom)', 'Force Comp. RMSE (meV / A)']

    type_set = ['Energy', 'Force']
    # y_ranges = [(-15, 15), (-730.0, 730.0)]
    y_ranges = [(None, None), (None, None)]  # only used by the alternative PowerNorm below
    abs_grid_rows_set = [abs_energy_grid_rows, abs_force_grid_rows]
    grid_rows_set = [energy_grid_rows, force_grid_rows]
    baseline_errors = [base_model_energy_error, base_model_force_error]
    # 'quantity' (rather than 'type') avoids shadowing the builtin.
    for grid_rows, abs_grid_rows, baseline_error, label, quantity, y_range in zip(grid_rows_set, abs_grid_rows_set, baseline_errors, label_set, type_set, y_ranges):

        if use_log_scale:
            grid_rows = np.log(grid_rows)

        # Diverging colour scale: centred on 0 for differences to the
        # baseline, or on the baseline error itself for absolute values.
        if set_relative_to_baseline:
            tick_range = np.max([abs(np.min(grid_rows)), abs(np.max(grid_rows))])
            if tick_range == 0:
                # TwoSlopeNorm needs vmin < vcenter < vmax; fall back to an
                # arbitrary symmetric range when every cell equals the baseline.
                tick_range = 1.0
            min_tick = -tick_range
            norm = TwoSlopeNorm(vmin=min_tick, vmax=tick_range, vcenter=0)
        else:
            norm = TwoSlopeNorm(vmin=None, vmax=None, vcenter=baseline_error)
        # Alternative: norm = PowerNorm(gamma=0.6, vmin=y_range[0], vmax=y_range[1])

        # Create the plot
        fig, ax = plt.subplots()
        im = ax.imshow(grid_rows, cmap='RdYlBu_r', aspect='equal', norm=norm)

        # Annotate each cell with its absolute (not baseline-relative) value
        for (i, j), val in np.ndenumerate(abs_grid_rows):
            ax.text(j, i, f'{val:.1f}', ha='center', va='center', color='black')

        # Add a colorbar
        cbar = fig.colorbar(im, ax=ax)
        cbar.set_label(label, rotation=270, labelpad=15)  # Customize label and padding

        # White grid lines on the cell borders (drawn via minor ticks)
        ax.set_xticks(np.arange(grid_rows.shape[1]+1) - 0.5, minor=True)
        ax.set_yticks(np.arange(grid_rows.shape[0]+1) - 0.5, minor=True)
        ax.grid(which="minor", color="white", linestyle='-', linewidth=2)
        ax.tick_params(which="minor", bottom=False, left=False)

        # Remove major axis ticks
        ax.set_xticks([])
        ax.set_yticks([])
        ax.set_title(f'{model_name} {quantity}')
        save_fp = f"{model_name}_{save_name}_{quantity}_react_solv_grid.png"
        plt.savefig(save_fp, format='png', dpi=600)

        plt.show()
        # Release the figure so open figures do not pile up across iterations.
        plt.close(fig)
        print(f'Saved plot to {save_fp}')