# Analysis of the results of the ablation study

This notebook was written to run on the TU Ilmenau cluster with my specific setup.
However, it can also be used simply to display the results of the ablation study:
a summary of all results is stored in the directory `research/ablation_data`, and it is loaded from there whenever the cluster paths are not available.

In [None]:
from __future__ import annotations

import os
import re
import json
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
import matplotlib.ticker as mticker
from copy import deepcopy

import optimetal.factory as factory
from optimetal.utils import get_model_parameters, load_plot_style
# apply the shared plot style of the optimetal package to all figures of this notebook
load_plot_style()

# manually designed OptiMetal2B architecture (similar to OptiMate)
# This is the starting architecture of the ablation study: the single-layer variations replace one of
# the *_dict entries below (see get_model_parameters_from_2b_perm), and the run whose permuted entry
# equals the value given here serves as the reference loss in the plots of Part 4.
BASE_CONFIG_2B = {
    "architecture": {
        "type": "optimetal_2b",
        # node embedding built from the element's group and period
        "node_embedding_dict": {
            "type": "group_period"
        },
        # Gaussian expansion of the edges: 64 basis functions of width 2.0, envelope disabled
        "edge_embedding_dict": {
            "type": "gaussian",
            "num_basis": 64,
            "basis_width": 2.0,
            "apply_envelope": False,
        },
        # two GATv2 message passing layers with 4 attention heads
        "message_passing_dict": {
            "num_layers": 2,
            "type": "gatv2",
            "heads": 4,
            "hidden_multiplier": 4,
        },
        # graph-level readout
        "pooling_dict": {
            "type": "vector_attention"
        },
        "hidden_dim": 256,
        "spectra_dim": 1024,
        "depth": 2,
        "activation": "relu",
        # presumably the neighbor cutoff radius for two-body edges (units not visible here)
        "twobody_cutoff": 5.5,
    },
}

def make_perm_name(perm: dict) -> str:
    """
    Create a directory name from a permutation dictionary.

    The items are sorted by key, so the name does not depend on insertion order.
    Underscores are removed from the keys and each item becomes ``key=value``;
    the fragments are joined with ``_``. Afterwards ``/`` is replaced by ``-``,
    and spaces and dots are dropped, so the result is a single path component, e.g.
    ``{"num_basis": 64, "basis_width": 2.0}`` -> ``"basiswidth=20_numbasis=64"``.

    Note: because dots are dropped, values such as ``2.0`` and ``20`` map to the
    same fragment, so the name is not guaranteed to be unique.
    """
    # Single quotes inside the f-string: reusing the enclosing double quotes is only
    # legal from Python 3.12 on (PEP 701) and is a SyntaxError on older interpreters.
    fragments = [f"{k.replace('_', '')}={v!s}" for k, v in sorted(perm.items())]
    return "_".join(fragments).replace("/", "-").replace(" ", "").replace(".", "")

def load_variation(study_path: str, model_type: str, variation: str) -> pd.DataFrame:
    """
    Load all permutations of a specific model layer variation, average the results across
    all random seeds, and then return the results as a DataFrame sorted by validation loss.

    The traversal expects the layout
    ``<study_path>/<dir containing "optimetal<model_type>">/<dir containing variation>/<run>/_done``,
    where each ``_done`` file is JSON with the keys ``perm``, ``seed``, ``val_loss`` and
    ``lr_wd`` (holding ``lr`` and ``wd``). Runs without a ``_done`` file are skipped.

    Args:
        study_path: root directory of the ablation study.
        model_type: model tag, e.g. ``"2b"`` or ``"3b"`` (matched as ``optimetal<model_type>``).
        variation: substring selecting the layer-variation directories, e.g. ``"node"`` or ``"mp"``.

    Returns:
        One row per unique (permutation, lr, wd) with the columns ``perm_str`` (the permutation
        dict serialized with sorted keys), ``lr``, ``wd``, ``mean_val_loss``, ``std_val_loss``
        (NaN when only one seed finished) and ``n_seeds``, sorted by ``mean_val_loss``.
        If no finished run is found, an empty DataFrame with these columns is returned
        instead of raising ``KeyError``.
    """
    columns = ["perm_str", "lr", "wd", "mean_val_loss", "std_val_loss", "n_seeds"]

    def _subdirs(path: str) -> list[str]:
        # names of the immediate sub-directories of `path` (plain files are ignored)
        return [d for d in os.listdir(path) if os.path.isdir(os.path.join(path, d))]

    model_tag = f"optimetal{model_type:s}"
    records = []
    for model_dir in _subdirs(study_path):
        if model_tag not in model_dir:
            continue
        model_path = os.path.join(study_path, model_dir)
        for perm_dir in _subdirs(model_path):
            if variation not in perm_dir:
                continue
            perm_path = os.path.join(model_path, perm_dir)
            for perm in _subdirs(perm_path):
                perm_file = os.path.join(perm_path, perm, "_done")
                # only finished runs write the `_done` file
                if not os.path.exists(perm_file):
                    continue
                with open(perm_file, "r") as f:
                    data = json.load(f)
                records.append({
                    "perm": data["perm"],
                    "seed": data["seed"],
                    "val_loss": data["val_loss"],
                    "lr": data["lr_wd"]["lr"],
                    "wd": data["lr_wd"]["wd"],
                })

    if not records:
        return pd.DataFrame(columns=columns)

    df = pd.DataFrame(records)
    # dicts are unhashable; a key-sorted JSON string gives a stable, groupable identifier
    df["perm_str"] = df["perm"].apply(lambda d: json.dumps(d, sort_keys=True))
    df = df.drop(columns=["perm"])

    # average over seeds
    return (
        df
        .groupby(["perm_str", "lr", "wd"], as_index=False)
        .agg(
            mean_val_loss=("val_loss", "mean"),
            std_val_loss=("val_loss", "std"),
            n_seeds=("val_loss", "count"),
        )
        .sort_values("mean_val_loss")
    )

def generate_factorial_configs(
    base_config: dict,
    node_perms: list[dict],
    edge_perms: list[dict],
    mp_perms: list[dict],
    pool_perms: list[dict],
) -> list[dict]:
    """
    Generate all combinations of architecture configurations based on the provided
    permutations for node, edge, message passing, and pooling layers.

    Every returned config is an independent deep copy of ``base_config``. Its
    ``architecture`` entry has ``node_embedding_dict``, ``edge_embedding_dict``,
    ``message_passing_dict`` and ``pooling_dict`` replaced by (copies of) one choice each.
    The per-layer dicts are copied as well, so configs never share mutable state with
    each other or with the input lists.

    The order is the Cartesian product: the node choice varies slowest and the pooling
    choice fastest, giving ``len(node_perms) * len(edge_perms) * len(mp_perms) *
    len(pool_perms)`` configs.
    """
    from itertools import product

    configs = []
    for node_cfg, edge_cfg, mp_cfg, pool_cfg in product(node_perms, edge_perms, mp_perms, pool_perms):
        cfg = deepcopy(base_config)
        arch = cfg["architecture"]
        arch["node_embedding_dict"] = deepcopy(node_cfg)
        arch["edge_embedding_dict"] = deepcopy(edge_cfg)
        arch["message_passing_dict"] = deepcopy(mp_cfg)
        arch["pooling_dict"] = deepcopy(pool_cfg)
        configs.append(cfg)
    return configs

def load_variation_interaction(study_path: str) -> pd.DataFrame:
    """
    Load all permutations from the layer interaction part of the ablation study,
    average the results over all random seeds, and return them as a DataFrame sorted by the validation loss.

    The traversal expects the layout
    ``<study_path>/<dir containing "optimetal2b">/<arch dir with a number>/_done``,
    where each ``_done`` file is JSON with the keys ``perm``, ``seed``, ``val_loss`` and
    ``lr_wd`` (holding ``lr`` and ``wd``). The first number in the architecture
    directory name is stored as ``index``.

    Runs are skipped when they have no ``_done`` file, or when the directory name
    contains no number (no index can be assigned).

    Returns:
        One row per unique (index, permutation, lr, wd) with the columns ``index``,
        ``perm_str``, ``lr``, ``wd``, ``mean_val_loss``, ``std_val_loss`` (NaN for a single
        seed) and ``n_seeds``, sorted by ``mean_val_loss``. If no finished run is found,
        an empty DataFrame with these columns is returned instead of raising ``KeyError``.
    """
    columns = ["index", "perm_str", "lr", "wd", "mean_val_loss", "std_val_loss", "n_seeds"]

    def _subdirs(path: str) -> list[str]:
        # names of the immediate sub-directories of `path` (plain files are ignored)
        return [d for d in os.listdir(path) if os.path.isdir(os.path.join(path, d))]

    records = []
    for model_dir in _subdirs(study_path):
        if "optimetal2b" not in model_dir:
            continue
        model_path = os.path.join(study_path, model_dir)
        for arch_dir in _subdirs(model_path):
            perm_file = os.path.join(model_path, arch_dir, "_done")
            # only finished runs write the `_done` file
            if not os.path.exists(perm_file):
                continue
            index_match = re.search(r"\d+", arch_dir)
            if index_match is None:
                continue
            with open(perm_file, "r") as f:
                data = json.load(f)
            records.append({
                "index": int(index_match.group()),
                "perm": data["perm"],
                "seed": data["seed"],
                "val_loss": data["val_loss"],
                "lr": data["lr_wd"]["lr"],
                "wd": data["lr_wd"]["wd"],
            })

    if not records:
        return pd.DataFrame(columns=columns)

    df = pd.DataFrame(records)
    # dicts are unhashable; a key-sorted JSON string gives a stable, groupable identifier
    df["perm_str"] = df["perm"].apply(lambda d: json.dumps(d, sort_keys=True))
    df = df.drop(columns=["perm"])

    # average over seeds
    return (
        df
        .groupby(["index", "perm_str", "lr", "wd"], as_index=False)
        .agg(
            mean_val_loss=("val_loss", "mean"),
            std_val_loss=("val_loss", "std"),
            n_seeds=("val_loss", "count"),
        )
        .sort_values("mean_val_loss")
    )

# Part 1 - OptiMetal2B Layer Variations

In [None]:
# the full ablation study was run on the OptiMetal2B architecture
study_path = "/scratch/magr4985/Ablation"
model_type = "2b"

# show more rows and wider cells when displaying the result tables
pd.options.display.max_rows = 100
pd.options.display.max_colwidth = 1000

# the aggregated results are cached as JSON files in this directory
output_dir = "./ablation_data"
os.makedirs(output_dir, exist_ok=True)

# use the raw study if it is reachable (i.e. on the cluster), otherwise fall back to the cached JSON files
analysis_flag = os.path.exists(study_path)
if analysis_flag:
    print(f"Study path {study_path:s} exists, loading results from there")
else:
    print("Loading results from JSON files")

In [None]:
# node embedding variations: aggregate the raw runs when the study is reachable, otherwise read the cached JSON
json_path = os.path.join(output_dir, "2b_node_embedding_results.json")
if not analysis_flag:
    df_node_avg = pd.read_json(json_path, orient="records")
else:
    df_node_avg = load_variation(study_path=study_path, model_type=model_type, variation="node")
    # cache the aggregated table so the notebook also works off the cluster
    df_node_avg.to_json(json_path, orient="records", indent=4)
display(df_node_avg)

In [None]:
# load the edge embedding variations and print the average results
json_path = os.path.join(output_dir, "2b_edge_embedding_results.json")
if analysis_flag:
    df_edge_avg = load_variation(study_path=study_path, model_type=model_type, variation="edge")
    # save the edge embedding results to a JSON file
    df_edge_avg.to_json(json_path, orient="records", indent=4)
else:
    # off the cluster: read the results cached by a previous run
    df_edge_avg = pd.read_json(json_path, orient="records")
display(df_edge_avg)

In [None]:
# load the message passing variations and print the average results
json_path = os.path.join(output_dir, "2b_mp_results.json")
if analysis_flag:
    df_mp_avg = load_variation(study_path=study_path, model_type=model_type, variation="mp")
    # save the message passing results to a JSON file
    df_mp_avg.to_json(json_path, orient="records", indent=4)
else:
    # off the cluster: read the results cached by a previous run
    df_mp_avg = pd.read_json(json_path, orient="records")
display(df_mp_avg)

In [None]:
# load the pooling variations and print the average results
json_path = os.path.join(output_dir, "2b_pool_results.json")
if analysis_flag:
    df_pool_avg = load_variation(study_path=study_path, model_type=model_type, variation="pool")
    # save the pooling results to a JSON file
    df_pool_avg.to_json(json_path, orient="records", indent=4)
else:
    # off the cluster: read the results cached by a previous run
    df_pool_avg = pd.read_json(json_path, orient="records")
display(df_pool_avg)

# Part 2 - OptiMetal2B Layer Interaction

In [None]:
# take the best permutations of each single-layer study (the tables are sorted by mean validation loss):
# the top two for node embedding, message passing and pooling, but only the top one for the edge
# embedding, as there is only one variation
top_node_perms, top_edge_perms, top_mp_perms, top_pool_perms = (
    [json.loads(perm_str) for perm_str in df_avg["perm_str"].head(n_top)]
    for df_avg, n_top in (
        (df_node_avg, 2),
        (df_edge_avg, 1),
        (df_mp_avg, 2),
        (df_pool_avg, 2),
    )
)

# full factorial design over the selected layer choices
configs = generate_factorial_configs(
    base_config=BASE_CONFIG_2B,
    node_perms=top_node_perms,
    edge_perms=top_edge_perms,
    mp_perms=top_mp_perms,
    pool_perms=top_pool_perms,
)

# write the interaction configurations next to the other results
config_path = "2b_interaction_configs.json"
output_path = os.path.join(output_dir, config_path)
with open(output_path, "w") as f:
    json.dump(configs, f, indent=4)

In [None]:
# check the results of the layer interaction part of the ablation study
study_path = "/scratch/magr4985/Ablation_Interaction"

# nicer display of dataframes
pd.options.display.max_rows = 100
pd.options.display.max_colwidth = 1000

# directory to save the results
output_dir = "./ablation_data"
os.makedirs(output_dir, exist_ok=True)

# check if the study path exists, else just load in the results already stored in JSON files
if os.path.exists(study_path):
    print(f"Study path {study_path:s} exists, loading results from there")
    analysis_flag = True
else:
    print(f"Loading results from JSON files")
    analysis_flag = False

# load and print the average results
json_path = os.path.join(output_dir, "2b_interaction_results.json")
if analysis_flag:
    df_avg_interaction = load_variation_interaction(study_path=study_path)
    # save the node embedding results to a JSON file
    df_avg_interaction.to_json(json_path, orient="records", indent=4)
else:
    df_avg_interaction = pd.read_json(json_path, orient="records")
display(df_avg_interaction)

# Part 3 - OptiMetal3B Message Passing Variations

In [None]:
# OptiMetal3B architecture was designed based on the results of the ablation study for OptiMetal2B
# (taken from the results of the OptiMetal2B interaction study, see 'research/ablation_results.ipynb')
# It is also the base configuration whose message passing layers are varied in the 3B ablation
# (see get_model_parameters_from_3b_mp_perm).
BASE_CONFIG_3B = {
    "architecture": {
        "type": "optimetal_3b",
        "node_embedding_dict": {
            "type": "group_period"
        },
        # compared to BASE_CONFIG_2B: 32 instead of 64 basis functions, width 4.0, envelope enabled
        "edge_embedding_dict": {
            "type": "gaussian",
            "num_basis": 32,
            "basis_width": 4.0,
            "apply_envelope": True,
        },
        # presumably the maximum angular degree of the angle (three-body) embedding -- TODO confirm
        "angle_embedding_dict": {
            "l_max": 3,
        },
        # two triplet blocks, each with message passing on the edge graph and on the node graph
        "triplet_block_dict": {
            "num_layers": 2,
            "edge_graph_mp_dict": {
                "type": "cgconv",
                "hidden_multiplier": 6,
            },
            "node_graph_mp_dict": {
                "type": "cgconv",
                "hidden_multiplier": 6,
            }
        },
        "pooling_dict": {
            "type": "vector_attention",
        },
        "hidden_dim": 256,
        "spectra_dim": 1024,
        "depth": 2,
        "activation": "relu",
        "twobody_cutoff": 5.5,
    },
} 

In [None]:
# we conducted an ablation study on the message passing layer of the OptiMetal3B architecture
# (the other layers were fixed based on the results from Part 2 of the ablation study, i.e., the layer interaction part)
study_path = "/scratch/magr4985/Ablation"
model_type = "3b"

# nicer display of dataframes
pd.options.display.max_rows = 100
pd.options.display.max_colwidth = 1000

# directory to save the results
output_dir = "./ablation_data"
os.makedirs(output_dir, exist_ok=True)

# load the message passing variations and print the average results
# check if the study path exists, else just load in the results already stored in JSON file
json_path = os.path.join(output_dir, "3b_mp_results.json")
if os.path.exists(study_path):
    print(f"Study path {study_path:s} exists, loading results from there")
    df_mp_avg_3b = load_variation(study_path=study_path, model_type=model_type, variation="mp")
    # save the OptiMetal3B message passing results to a JSON file
    df_mp_avg_3b.to_json(json_path, orient="records", indent=4)
else:
    print("Loading results from JSON files")
    df_mp_avg_3b = pd.read_json(json_path, orient="records")
display(df_mp_avg_3b)

# Part 4 - Analyze the influence of increased model capacity on the results

In [None]:
"""
Helper functions to extract the number of parameters using the permutation strings.
"""

def get_model_parameters_from_2b_perm(perm_type: str, perm_str: str) -> int:
    perm_name = {
        "node": "node_embedding_dict",
        "edge": "edge_embedding_dict",
        "mp": "message_passing_dict",
        "pool": "pooling_dict",
    }.get(perm_type)
    perm = json.loads(perm_str)
    arch_dict = deepcopy(BASE_CONFIG_2B["architecture"])
    arch_dict[perm_name] = perm
    return get_model_parameters(factory.create_model(arch_dict))

def annotate_2b(df: pd.DataFrame, perm_type: str) -> pd.DataFrame:
    """
    Return a copy of an OptiMetal2B single-layer result table with extra columns.

    Adds ``num_params`` (parameter count of BASE_CONFIG_2B with the permuted layer swapped in)
    and ``type`` (the permutation's ``"type"`` entry). For ``perm_type == "mp"`` it also adds
    ``hidden_multiplier``. The input frame is left untouched.
    """
    annotated = df.copy()
    perm_strs = annotated["perm_str"]
    annotated["num_params"] = perm_strs.apply(
        lambda perm_str: get_model_parameters_from_2b_perm(perm_type, perm_str)
    )
    annotated["type"] = perm_strs.apply(lambda perm_str: json.loads(perm_str).get("type"))
    if perm_type == "mp":
        annotated["hidden_multiplier"] = perm_strs.apply(
            lambda perm_str: json.loads(perm_str).get("hidden_multiplier")
        )
    return annotated

def get_model_parameters_from_2b_interaction(arch_str: str) -> int:
    """
    Count the parameters of the model described by a serialized interaction config.

    ``arch_str`` is a JSON string of a full config with an ``"architecture"`` entry
    (presumably one of the configs written by the interaction cell), as stored in the
    ``perm_str`` column of the interaction results.
    """
    config = json.loads(arch_str)
    model = factory.create_model(config["architecture"])
    return get_model_parameters(model)

def annotate_2b_interaction(df: pd.DataFrame) -> pd.DataFrame:
    """
    Return a copy of the layer-interaction table with ``num_params`` and ``ranking`` columns.

    ``ranking`` is the average of three scores, one each for node embedding, message passing
    and pooling. A score is 1 when the layer is the hard-coded first choice and 2 otherwise;
    the first choices are the ``group_period`` node embedding, ``transformer`` message passing
    and ``vector_attention`` pooling. So 1.0 means "all first choices" and 2.0 "all second
    choices". The edge embedding is not scored.
    """
    annotated = df.copy()
    annotated["num_params"] = annotated["perm_str"].apply(get_model_parameters_from_2b_interaction)

    def _average_rank(arch: dict) -> float:
        # one boolean per scored layer: is it the first choice?
        first_choice = (
            arch["node_embedding_dict"] == {"type": "group_period"},
            arch["message_passing_dict"]["type"] == "transformer",
            arch["pooling_dict"] == {"type": "vector_attention"},
        )
        return sum(1 if is_first else 2 for is_first in first_choice) / 3

    annotated["ranking"] = [
        _average_rank(json.loads(perm_str)["architecture"]) for perm_str in annotated["perm_str"]
    ]
    return annotated

def get_model_parameters_from_3b_mp_perm(perm_str: str) -> int:
    """
    Count the parameters of BASE_CONFIG_3B with both triplet-block message passing dicts
    (edge graph and node graph) replaced by the permutation in ``perm_str``.
    """
    mp_dict = json.loads(perm_str)
    arch_dict = deepcopy(BASE_CONFIG_3B["architecture"])
    triplet_block = arch_dict["triplet_block_dict"]
    # NOTE(review): the very same dict object is placed in both slots; this is harmless unless
    # factory.create_model mutates the layer dicts -- confirm.
    triplet_block["edge_graph_mp_dict"] = mp_dict
    triplet_block["node_graph_mp_dict"] = mp_dict
    return get_model_parameters(factory.create_model(arch_dict))

def annotate_3b(df: pd.DataFrame) -> pd.DataFrame:
    """
    Return a copy of the OptiMetal3B message passing table with the extra columns
    ``num_params``, ``type`` and ``hidden_multiplier`` (the latter two taken from the
    permutation dict). The input frame is left untouched.
    """
    annotated = df.copy()
    perm_strs = annotated["perm_str"]
    annotated["num_params"] = perm_strs.apply(get_model_parameters_from_3b_mp_perm)
    annotated["type"] = perm_strs.apply(lambda perm_str: json.loads(perm_str).get("type"))
    annotated["hidden_multiplier"] = perm_strs.apply(
        lambda perm_str: json.loads(perm_str).get("hidden_multiplier")
    )
    return annotated

def ablation_plot(
    df: pd.DataFrame,
    ref_loss: tuple[float, float],
    color_key: str,
    title: str | None = None,
    leg_loc: str | None = None,
    save_path: str | None = None,
) -> None:
    """
    Plot the validation loss against the number of parameters for one ablation.

    The loss of the reference architecture is drawn as a gray band (mean +/- std) with a
    solid line. Every row of ``df`` is drawn as a marker with its std over seeds as error bar.

    Args:
        df: annotated results with the columns ``num_params``, ``mean_val_loss``,
            ``std_val_loss`` and ``color_key``.
        ref_loss: ``(mean, std)`` of the reference loss. Any two-element sequence works,
            e.g. the NumPy row returned by ``.values[0]``.
        color_key: column used to color the markers.
            - ``"ranking"``: colors are interpolated from orange (lowest ranking value) to
              black (highest), and a dashed line marks the lowest mean loss.
            - Any other column is treated as categorical, with one palette color and one
              legend entry per unique value. The palette is reused when there are more
              categories than colors.
        title: legend title.
        leg_loc: legend location, ``"upper right"`` if None.
        save_path: if given, the figure is saved to this path.

    Raises:
        ValueError: if ``color_key`` is None. This is checked before any figure is created.
    """
    if color_key is None:
        raise ValueError("A 'color_key' must be given")
    # helper variables
    ms = 4  # marker size
    ytick_format = "%.2f"
    palette = ["k", "tab:orange", "tab:blue", "tab:green"]
    type_dict = {
        "group_period": r"Group $\vert\vert$ Period",
        "atom": "Atomic number",
        "gaussian": "Gaussian",
        "bessel": "Bessel",
        "cgconv": "CGConv",
        "gatv2": "GATv2Conv",
        "transformer": "TransformerConv",
        "mean": "Mean",
        "set2set": "Set2Set",
        "scalar_attention": "Scalar Attention",
        "vector_attention": "Vector Attention",
        2: r"$m_\mathrm{MLP}=2$",
        4: r"$m_\mathrm{MLP}=4$",
        6: r"$m_\mathrm{MLP}=6$",
    }
    # figure setup
    fig, ax = plt.subplots(figsize=(3, 3))
    # highlight the loss of the starting architecture
    ref_mean, ref_std = ref_loss
    ax.axhspan(ref_mean - ref_std, ref_mean + ref_std, color="lightgray", alpha=0.5, zorder=0)
    ax.axhline(ref_mean, color="dimgray", linestyle="-", linewidth=1.5, zorder=1)
    if color_key == "ranking":
        vals = pd.to_numeric(df[color_key], errors="coerce")
        vmin, vmax = float(vals.min()), float(vals.max())
        denom = (vmax - vmin) if (vmax - vmin) > 0 else 1.0
        black_rgb = np.array(mcolors.to_rgb("k"))
        orange_rgb = np.array(mcolors.to_rgb("tab:orange"))
        ax.axhline(y=df["mean_val_loss"].min(), color="tab:blue", linestyle="--", zorder=-1)
        for idx, row in df.iterrows():
            # t = 1 for the lowest ranking value (orange), t = 0 for the highest (black)
            t = (vmax - float(vals.loc[idx])) / denom
            c = black_rgb * (1.0 - t) + orange_rgb * t
            # only the two extreme rankings get a legend entry
            if row[color_key] < 1.001:
                label = r"1st---1st"
            elif row[color_key] > 1.999:
                label = r"2nd---2nd"
            else:
                label = None
            ax.errorbar(
                x=row["num_params"],
                y=row["mean_val_loss"],
                yerr=row["std_val_loss"],
                fmt="o",
                markersize=ms,
                markeredgecolor=c,
                markerfacecolor=c,
                ecolor=c,
                capsize=4,
                linestyle="none",
                label=label,
            )
    else:
        # differentiate between layer types through different colors
        uniq = sorted(df[color_key].unique())
        color_of = {lab: palette[i % len(palette)] for i, lab in enumerate(uniq)}
        for lab in uniq:
            sub = df[df[color_key] == lab]
            ax.errorbar(
                x=sub["num_params"],
                y=sub["mean_val_loss"],
                yerr=sub["std_val_loss"],
                fmt="o",
                markersize=ms,
                markeredgecolor=color_of[lab],
                markerfacecolor=color_of[lab],
                ecolor=color_of[lab],
                capsize=4,
                linestyle="none",
                # fall back to the raw value for labels without a pretty name
                label=type_dict.get(lab, str(lab)),
            )
    # legend
    legend_loc = "upper right" if leg_loc is None else leg_loc
    ax.legend(title=title, handlelength=1.25, frameon=True, loc=legend_loc)
    # axis label and ticks
    ax.set_xlabel(r"$N$")
    ax.set_ylabel(r"$L_\mathrm{val}$")
    ax.yaxis.set_major_formatter(mticker.FormatStrFormatter(ytick_format))
    # finish the figure
    fig.tight_layout()
    fig.align_labels()
    if save_path is not None:
        fig.savefig(save_path)

# directory where to store the figures
fig_dir = "./ablation_data/figures"
os.makedirs(fig_dir, exist_ok=True)

In [None]:
# OptiMetal2B single-layer ablations: add parameter counts and the layer type (plus the hidden multiplier for mp)
df_2b_node, df_2b_edge, df_2b_mp, df_2b_pool = (
    annotate_2b(df_avg, perm_type)
    for df_avg, perm_type in (
        (df_node_avg, "node"),
        (df_edge_avg, "edge"),
        (df_mp_avg, "mp"),
        (df_pool_avg, "pool"),
    )
)

# OptiMetal2B layer interaction: parameter counts and the average ranking of the layer choices
df_2b_interaction = annotate_2b_interaction(df_avg_interaction)

# OptiMetal3B message passing ablation: parameter counts, layer type and hidden multiplier
df_3b_mp = annotate_3b(df_mp_avg_3b)

In [None]:
def _reference_loss(df: pd.DataFrame, perm_str: str) -> np.ndarray:
    """
    Return ``[mean_val_loss, std_val_loss]`` of the row of *df* whose ``perm_str`` equals *perm_str*.

    Raises ValueError naming the missing permutation, instead of an opaque IndexError.
    """
    match = df.loc[df["perm_str"] == perm_str, ["mean_val_loss", "std_val_loss"]]
    if match.empty:
        raise ValueError(f"Reference permutation not found in the results: {perm_str}")
    return match.values[0]


# OptiMetal2B ablation: the reference of each single-layer plot is the row whose permuted layer
# equals the corresponding entry of BASE_CONFIG_2B (serialized with sorted keys like perm_str)
ref_loss_2b_node = _reference_loss(df_2b_node, '{"type": "group_period"}')
ablation_plot(df_2b_node, ref_loss_2b_node, color_key="type", title="Node Embedding", save_path=os.path.join(fig_dir, "2b_node.pdf"))

ref_loss_2b_edge = _reference_loss(
    df_2b_edge, '{"apply_envelope": false, "basis_width": 2.0, "num_basis": 64, "type": "gaussian"}'
)
ablation_plot(df_2b_edge, ref_loss_2b_edge, color_key="type", title="Edge Embedding", save_path=os.path.join(fig_dir, "2b_edge.pdf"))

ref_loss_2b_mp = _reference_loss(
    df_2b_mp, '{"heads": 4, "hidden_multiplier": 4, "num_layers": 2, "type": "gatv2"}'
)
ablation_plot(df_2b_mp, ref_loss_2b_mp, color_key="type", title="Message Passing", save_path=os.path.join(fig_dir, "2b_mp_1.pdf"))
ablation_plot(df_2b_mp, ref_loss_2b_mp, color_key="hidden_multiplier", title="Message Passing", save_path=os.path.join(fig_dir, "2b_mp_2.pdf"))

ref_loss_2b_pool = _reference_loss(df_2b_pool, '{"type": "vector_attention"}')
ablation_plot(df_2b_pool, ref_loss_2b_pool, color_key="type", title="Pooling", leg_loc="upper center", save_path=os.path.join(fig_dir, "2b_pool.pdf"))

# OptiMetal2B ablation interaction: element-wise average of the four single-layer references
# (node, edge, message passing, pooling -- each layer counted exactly once)
ref_loss_2b_interaction = np.mean(
    np.array([ref_loss_2b_node, ref_loss_2b_edge, ref_loss_2b_mp, ref_loss_2b_pool]), axis=0
)
ablation_plot(df_2b_interaction, ref_loss_2b_interaction, color_key="ranking", title="Top 2 Interaction", save_path=os.path.join(fig_dir, "2b_interaction.pdf"))

# OptiMetal3B message passing ablation
ref_loss_3b = _reference_loss(df_3b_mp, '{"hidden_multiplier": 6, "num_layers": 1, "type": "cgconv"}')
ablation_plot(df_3b_mp, ref_loss_3b, color_key="type", title="3B Message Passing", save_path=os.path.join(fig_dir, "3b_mp_1.pdf"))
ablation_plot(df_3b_mp, ref_loss_3b, color_key="hidden_multiplier", title="3B Message Passing", save_path=os.path.join(fig_dir, "3b_mp_2.pdf"))

In [None]:
"""
Ablation plots for the OptiMetal2B.
"""

# figure setup
fig, axes = plt.subplots(2, 2, figsize=(6.5, 4.5))

# helper variables
ms = 4 # marker size
ytick_format = "%.2f"
palette = ["k", "tab:orange", "tab:blue", "tab:green"]
type_dict = {
    "group_period": r"Group $\vert\vert$ Period",
    "atom": "Atomic Number",
    "gaussian": "Gaussian",
    "bessel": "Bessel",
    "cgconv": "CGConv",
    "gatv2": "GATv2Conv",
    "transformer": "TransformerConv",
    "mean": "Mean",
    "set2set": "Set2Set",
    "scalar_attention": "Scalar Attention",
    "vector_attention": "Vector Attention",
    2: r"$m_\mathrm{MLP}=2$",
    4: r"$m_\mathrm{MLP}=4$",
    6: r"$m_\mathrm{MLP}=6$",
}

"""
OptiMetal2B - Node Embedding
"""

ax = axes[0, 0]
df = df_2b_node
color_key = "type"
title = "Node Embedding"
labels = df[color_key]
uniq = sorted(labels.unique())
color_of = {lab: palette[i] for i, lab in enumerate(uniq)}
ref_mean, ref_std = ref_loss_2b_node
ax.axhspan(ref_mean - ref_std, ref_mean + ref_std, color="lightgray", alpha=0.5, zorder=0)
ax.axhline(ref_mean, color="dimgray", linestyle="-", linewidth=1.5, zorder=1)
for lab in uniq:
    sub = df[df[color_key] == lab]
    ax.errorbar(
        sub["num_params"],
        sub["mean_val_loss"],
        yerr=sub["std_val_loss"],
        fmt="o",
        markersize=4,
        markeredgecolor=color_of[lab],
        markerfacecolor=color_of[lab],
        ecolor=color_of[lab],
        capsize=4,
        linestyle="none",
        label=type_dict[lab],
    ) 
ax.legend(title=title, handlelength=1.25, frameon=True, loc="upper right")
ax.yaxis.set_major_formatter(mticker.FormatStrFormatter(ytick_format))
ax.set_xlabel(r"$N$")
ax.set_ylabel(r"$L_\mathrm{val}$")
    
"""
OptiMetal2B - Edge Embedding
"""

ax = axes[0, 1]
df = df_2b_edge
color_key = "type"
title = "Edge Embedding"
labels = df[color_key]
uniq = sorted(labels.unique())
color_of = {lab: palette[i] for i, lab in enumerate(uniq)}
ref_mean, ref_std = ref_loss_2b_edge
ax.axhspan(ref_mean - ref_std, ref_mean + ref_std, color="lightgray", alpha=0.5, zorder=0)
ax.axhline(ref_mean, color="dimgray", linestyle="-", linewidth=1.5, zorder=1)
for lab in uniq:
    sub = df[df[color_key] == lab]
    ax.errorbar(
        sub["num_params"],
        sub["mean_val_loss"],
        yerr=sub["std_val_loss"],
        fmt="o",
        markersize=4,
        markeredgecolor=color_of[lab],
        markerfacecolor=color_of[lab],
        ecolor=color_of[lab],
        capsize=4,
        linestyle="none",
        label=type_dict[lab],
    ) 
ax.legend(title=title, handlelength=1.25, frameon=True, loc="upper right")
ax.yaxis.set_major_formatter(mticker.FormatStrFormatter(ytick_format))
ax.set_xlabel(r"$N$")
ax.set_ylabel(r"$L_\mathrm{val}$")
    
"""
OptiMetal2B - Message Passing (Type)
"""

ax = axes[1, 0]
df = df_2b_mp
color_key = "type"
title = "Message Passing"
labels = df[color_key]
uniq = sorted(labels.unique())
color_of = {lab: palette[i] for i, lab in enumerate(uniq)}
ref_mean, ref_std = ref_loss_2b_mp
ax.axhspan(ref_mean - ref_std, ref_mean + ref_std, color="lightgray", alpha=0.5, zorder=0)
ax.axhline(ref_mean, color="dimgray", linestyle="-", linewidth=1.5, zorder=1)
for lab in uniq:
    sub = df[df[color_key] == lab]
    ax.errorbar(
        sub["num_params"],
        sub["mean_val_loss"],
        yerr=sub["std_val_loss"],
        fmt="o",
        markersize=4,
        markeredgecolor=color_of[lab],
        markerfacecolor=color_of[lab],
        ecolor=color_of[lab],
        capsize=4,
        linestyle="none",
        label=type_dict[lab],
    ) 
ax.legend(title=title, handlelength=1.25, frameon=True, loc="upper right")
ax.yaxis.set_major_formatter(mticker.FormatStrFormatter(ytick_format))
ax.set_xlabel(r"$N$")
ax.set_ylabel(r"$L_\mathrm{val}$")
    
"""
OptiMetal2B - Pooling
"""

ax = axes[1, 1]
df = df_2b_pool
color_key = "type"
title = "Pooling"
labels = df[color_key]
uniq = sorted(labels.unique())
color_of = {lab: palette[i] for i, lab in enumerate(uniq)}
ref_mean, ref_std = ref_loss_2b_pool
ax.axhspan(ref_mean - ref_std, ref_mean + ref_std, color="lightgray", alpha=0.5, zorder=0)
ax.axhline(ref_mean, color="dimgray", linestyle="-", linewidth=1.5, zorder=1)
for lab in uniq:
    sub = df[df[color_key] == lab]
    ax.errorbar(
        sub["num_params"],
        sub["mean_val_loss"],
        yerr=sub["std_val_loss"],
        fmt="o",
        markersize=4,
        markeredgecolor=color_of[lab],
        markerfacecolor=color_of[lab],
        ecolor=color_of[lab],
        capsize=4,
        linestyle="none",
        label=type_dict[lab],
    ) 
ax.legend(title=title, handlelength=1.25, frameon=True, loc="upper center")
ax.yaxis.set_major_formatter(mticker.FormatStrFormatter(ytick_format))
ax.set_xlabel(r"$N$")
ax.set_ylabel(r"$L_\mathrm{val}$")

# finish the figure
fig.tight_layout()
fig.align_labels()
fig.subplots_adjust(wspace=0.2, hspace=0.25)
fig.savefig(os.path.join(fig_dir, "optimate2b.pdf"))

In [None]:
""" 
Show the effect of the hidden multiplier on message passing in the OptiMate2B ablation study.
"""

# figure setup
fig, axes = plt.subplots(1, 2, figsize=(6.5, 3))

# helper variables
ms = 4 # marker size
ytick_format = "%.2f"
palette = ["k", "tab:orange", "tab:blue", "tab:green"]
type_dict = {
    "group_period": r"Group $\vert\vert$ Period",
    "atom": "Atomic Number",
    "gaussian": "Gaussian",
    "bessel": "Bessel",
    "cgconv": "CGConv",
    "gatv2": "GATv2Conv",
    "transformer": "TransformerConv",
    "mean": "Mean",
    "set2set": "Set2Set",
    "scalar_attention": "Scalar Attention",
    "vector_attention": "Vector Attention",
    2: r"$m_\mathrm{MLP}=2$",
    4: r"$m_\mathrm{MLP}=4$",
    6: r"$m_\mathrm{MLP}=6$",
}

"""
OptiMetal2B - Message Passing (Type)
"""

ax = axes[0]
df = df_2b_mp
color_key = "type"
title = "Message Passing"
labels = df[color_key]
uniq = sorted(labels.unique())
color_of = {lab: palette[i] for i, lab in enumerate(uniq)}
ref_mean, ref_std = ref_loss_2b_mp
ax.axhspan(ref_mean - ref_std, ref_mean + ref_std, color="lightgray", alpha=0.5, zorder=0)
ax.axhline(ref_mean, color="dimgray", linestyle="-", linewidth=1.5, zorder=1)
for lab in uniq:
    sub = df[df[color_key] == lab]
    ax.errorbar(
        sub["num_params"],
        sub["mean_val_loss"],
        yerr=sub["std_val_loss"],
        fmt="o",
        markersize=4,
        markeredgecolor=color_of[lab],
        markerfacecolor=color_of[lab],
        ecolor=color_of[lab],
        capsize=4,
        linestyle="none",
        label=type_dict[lab],
    ) 
ax.legend(title=title, handlelength=1.25, frameon=True, loc="upper right")
ax.yaxis.set_major_formatter(mticker.FormatStrFormatter(ytick_format))
ax.set_xlabel(r"$N$")
ax.set_ylabel(r"$L_\mathrm{val}$")
    
"""
OptiMetal2B - Message Passing (Hidden Multiplier)
"""

ax = axes[1]
df = df_2b_mp
color_key = "hidden_multiplier"
title = "Message Passing"
labels = df[color_key]
uniq = sorted(labels.unique())
color_of = {lab: palette[i] for i, lab in enumerate(uniq)}
ref_mean, ref_std = ref_loss_2b_mp
ax.axhspan(ref_mean - ref_std, ref_mean + ref_std, color="lightgray", alpha=0.5, zorder=0)
ax.axhline(ref_mean, color="dimgray", linestyle="-", linewidth=1.5, zorder=1)
for lab in uniq:
    sub = df[df[color_key] == lab]
    ax.errorbar(
        sub["num_params"],
        sub["mean_val_loss"],
        yerr=sub["std_val_loss"],
        fmt="o",
        markersize=4,
        markeredgecolor=color_of[lab],
        markerfacecolor=color_of[lab],
        ecolor=color_of[lab],
        capsize=4,
        linestyle="none",
        label=type_dict[lab],
    ) 
ax.legend(title=title, handlelength=1.25, frameon=True, loc="upper right")
ax.yaxis.set_major_formatter(mticker.FormatStrFormatter(ytick_format))
ax.set_xlabel(r"$N$")
ax.set_ylabel(r"$L_\mathrm{val}$")

# finish the figure
fig.tight_layout()
fig.align_labels()
fig.subplots_adjust(wspace=0.2, hspace=0.25)
fig.savefig(os.path.join(fig_dir, "optimate2b_mp_hidden_multiplier.pdf"))

In [None]:
"""
Show the effect of the message-passing type and the hidden multiplier on message passing
in the OptiMetal3B ablation study.
"""

# figure setup
fig, axes = plt.subplots(1, 2, figsize=(6.5, 3))

# helper variables
ms = 4  # marker size
ytick_format = "%.2f"
palette = ["k", "tab:orange", "tab:blue", "tab:green"]
type_dict = {
    "group_period": r"Group $\vert\vert$ Period",
    "atom": "Atomic Number",
    "gaussian": "Gaussian",
    "bessel": "Bessel",
    "cgconv": "CGConv",
    "gatv2": "GATv2Conv",
    "transformer": "TransformerConv",
    "mean": "Mean",
    "set2set": "Set2Set",
    "scalar_attention": "Scalar Attention",
    "vector_attention": "Vector Attention",
    2: r"$m_\mathrm{MLP}=2$",
    4: r"$m_\mathrm{MLP}=4$",
    6: r"$m_\mathrm{MLP}=6$",
}

"""
OptiMetal3B - Message Passing
Left panel: colored by message-passing type.
Right panel: colored by hidden multiplier.
Both panels show the same data (df_3b_mp) against the same reference (ref_loss_3b).
"""

title = "3B Message Passing"
ref_mean, ref_std = ref_loss_3b  # (mean, std) of the reference model's validation loss
for ax, color_key in zip(axes, ("type", "hidden_multiplier")):
    # reference model: mean ± std band and mean line, drawn behind the data points
    ax.axhspan(ref_mean - ref_std, ref_mean + ref_std, color="lightgray", alpha=0.5, zorder=0)
    ax.axhline(ref_mean, color="dimgray", linestyle="-", linewidth=1.5, zorder=1)

    # one colored error-bar series per distinct value of color_key;
    # the palette is cycled so that more categories than colors do not raise IndexError
    uniq = sorted(df_3b_mp[color_key].unique())
    color_of = {lab: palette[i % len(palette)] for i, lab in enumerate(uniq)}
    for lab in uniq:
        sub = df_3b_mp[df_3b_mp[color_key] == lab]
        ax.errorbar(
            sub["num_params"],
            sub["mean_val_loss"],
            yerr=sub["std_val_loss"],
            fmt="o",
            markersize=ms,
            markeredgecolor=color_of[lab],
            markerfacecolor=color_of[lab],
            ecolor=color_of[lab],
            capsize=4,
            linestyle="none",
            label=type_dict[lab],
        )

    ax.legend(title=title, handlelength=1.25, frameon=True, loc="upper right")
    ax.yaxis.set_major_formatter(mticker.FormatStrFormatter(ytick_format))
    ax.set_xlabel(r"$N$")
    ax.set_ylabel(r"$L_\mathrm{val}$")

# finish the figure
# (subplots_adjust runs after tight_layout, so its wspace/hspace override the tight_layout spacing)
fig.tight_layout()
fig.align_labels()
fig.subplots_adjust(wspace=0.2, hspace=0.25)
fig.savefig(os.path.join(fig_dir, "optimate3b_mp_hidden_multiplier.pdf"))

In [None]:
"""
Plot all ablation results into one figure
"""

# figure setup
fig, axes = plt.subplots(4, 2, figsize=(6.5, 8.5))

# helper variables (read by the plotting helpers below)
ms = 4  # marker size
ytick_format = "%.2f"
palette = ["k", "tab:orange", "tab:blue", "tab:green"]
type_dict = {
    "group_period": r"Group $\vert\vert$ Period",
    "atom": "Atomic Number",
    "gaussian": "Gaussian",
    "bessel": "Bessel",
    "cgconv": "CGConv",
    "gatv2": "GATv2Conv",
    "transformer": "TransformerConv",
    "mean": "Mean",
    "set2set": "Set2Set",
    "scalar_attention": "Scalar Attention",
    "vector_attention": "Vector Attention",
    2: r"$m_\mathrm{MLP}=2$",
    4: r"$m_\mathrm{MLP}=4$",
    6: r"$m_\mathrm{MLP}=6$",
}


def _draw_reference(ax, ref_loss):
    """
    Shade the reference band (mean ± std) and draw the reference mean line on `ax`.
    `ref_loss` is a `(mean, std)` pair; both are drawn behind the data points.
    """
    ref_mean, ref_std = ref_loss
    ax.axhspan(ref_mean - ref_std, ref_mean + ref_std, color="lightgray", alpha=0.5, zorder=0)
    ax.axhline(ref_mean, color="dimgray", linestyle="-", linewidth=1.5, zorder=1)


def _style_axis(ax, title, legend_loc="upper right"):
    """
    Add the titled legend, the y-tick format and the axis labels shared by all panels.
    """
    ax.legend(title=title, handlelength=1.25, frameon=True, loc=legend_loc)
    ax.yaxis.set_major_formatter(mticker.FormatStrFormatter(ytick_format))
    ax.set_xlabel(r"$N$")
    ax.set_ylabel(r"$L_\mathrm{val}$")


def _plot_categorical_panel(ax, df, color_key, title, ref_loss, legend_loc="upper right"):
    """
    Plot validation loss vs. parameter count for one ablation variation.

    One colored error-bar series (mean_val_loss ± std_val_loss) is drawn per distinct
    value of `df[color_key]`, on top of the reference band. Every category value must
    be a key of `type_dict`, which provides the legend labels.
    """
    _draw_reference(ax, ref_loss)
    uniq = sorted(df[color_key].unique())
    # cycle the palette so that more categories than colors do not raise IndexError
    color_of = {lab: palette[i % len(palette)] for i, lab in enumerate(uniq)}
    for lab in uniq:
        sub = df[df[color_key] == lab]
        ax.errorbar(
            sub["num_params"],
            sub["mean_val_loss"],
            yerr=sub["std_val_loss"],
            fmt="o",
            markersize=ms,
            markeredgecolor=color_of[lab],
            markerfacecolor=color_of[lab],
            ecolor=color_of[lab],
            capsize=4,
            linestyle="none",
            label=type_dict[lab],
        )
    _style_axis(ax, title, legend_loc)


def _plot_gradient_panel(ax, df, color_key, title, ref_loss, first_label, last_label):
    """
    Plot each row of `df` with a color interpolated between black and orange.

    The color depends on the numeric value of `df[color_key]`: the maximum maps to
    black and the minimum to orange. Only the first and last rows get legend labels
    (`first_label` / `last_label`); a single-row frame gets `first_label`.
    """
    _draw_reference(ax, ref_loss)
    vals = pd.to_numeric(df[color_key], errors="coerce")
    vmin, vmax = float(vals.min()), float(vals.max())
    # avoid division by zero when all values are equal (everything is then black)
    denom = (vmax - vmin) if (vmax - vmin) > 0 else 1.0
    black_rgb = np.array(mcolors.to_rgb("k"))
    orange_rgb = np.array(mcolors.to_rgb("tab:orange"))
    last = len(df) - 1
    for j, (_, row) in enumerate(df.iterrows()):
        # positional lookup stays correct even if df has a duplicated index
        t = (vmax - float(vals.iloc[j])) / denom
        c = black_rgb * (1.0 - t) + orange_rgb * t
        if j == 0:
            label = first_label
        elif j == last:
            label = last_label
        else:
            label = None
        ax.errorbar(
            x=row["num_params"],
            y=row["mean_val_loss"],
            yerr=row["std_val_loss"],
            fmt="o",
            markersize=ms,
            markeredgecolor=c,
            markerfacecolor=c,
            ecolor=c,
            capsize=4,
            linestyle="none",
            label=label,
        )
    _style_axis(ax, title)


# OptiMetal2B - Node Embedding
_plot_categorical_panel(axes[0, 0], df_2b_node, "type", "Node Embedding", ref_loss_2b_node)

# OptiMetal2B - Edge Embedding
_plot_categorical_panel(axes[0, 1], df_2b_edge, "type", "Edge Embedding", ref_loss_2b_edge)

# OptiMetal2B - Message Passing (Type)
_plot_categorical_panel(axes[1, 0], df_2b_mp, "type", "Message Passing", ref_loss_2b_mp)

# OptiMetal2B - Message Passing (Hidden Multiplier)
_plot_categorical_panel(axes[1, 1], df_2b_mp, "hidden_multiplier", "Message Passing", ref_loss_2b_mp)

# OptiMetal2B - Pooling
_plot_categorical_panel(
    axes[2, 0], df_2b_pool, "type", "Pooling", ref_loss_2b_pool, legend_loc="upper center"
)

# OptiMetal2B - Interaction
# NOTE(review): the legend labels are attached by row position, so this assumes the first and
# last rows of df_2b_interaction are the 1st---1st and 2nd---2nd configurations; confirm
# against the ordering of the DataFrame (load_variation sorts by validation loss).
_plot_gradient_panel(
    axes[2, 1],
    df_2b_interaction,
    "ranking",
    "Interaction",
    ref_loss_2b_interaction,
    first_label=r"1st---1st",
    last_label=r"2nd---2nd",
)

# OptiMetal3B - Message Passing (Type)
_plot_categorical_panel(axes[3, 0], df_3b_mp, "type", "3B Message Passing", ref_loss_3b)

# OptiMetal3B - Message Passing (Hidden Multiplier)
_plot_categorical_panel(axes[3, 1], df_3b_mp, "hidden_multiplier", "3B Message Passing", ref_loss_3b)

# finish the figure
# (subplots_adjust runs after tight_layout, so its wspace/hspace override the tight_layout spacing)
fig.tight_layout()
fig.align_labels()
fig.subplots_adjust(wspace=0.2, hspace=0.25)
fig.savefig(os.path.join(fig_dir, "all.pdf"))

# Part 5 - Evaluate model improvements

In [None]:
# best performing models (lowest mean validation loss in each final ablation stage)
best_2b = float(df_2b_interaction["mean_val_loss"].min())
best_3b = float(df_3b_mp["mean_val_loss"].min())


def _improvement(new_loss, old_loss):
    """
    Return the relative reduction of `new_loss` with respect to `old_loss`, in percent.
    """
    return 100 * (1 - (new_loss / old_loss))


# ref_loss_* are (mean, std) pairs; only the mean ([0]) is compared.
# The report lines are built once so the printed and saved text cannot drift apart.
# NOTE(review): the labels say "OptiMate*" while the plots say "OptiMetal*"; confirm the intended name.
improvement_lines = [
    f"OptiMate2B improvement (architecture optimization)        = {_improvement(best_2b, ref_loss_2b_interaction[0]):.2f}%",
    f"OptiMate3B improvement (message passing optimization)     = {_improvement(best_3b, ref_loss_3b[0]):.2f}%",
    f"OptiMate3B improvement compared to unoptimized OptiMate2B = {_improvement(best_3b, ref_loss_2b_interaction[0]):.2f}%",
    f"OptiMate3B improvement compared to optimized OptiMate2B   = {_improvement(best_3b, best_2b):.2f}%",
]

# print the performance improvements
for line in improvement_lines:
    print(line)

# save the results in a file (one line per improvement, newline-terminated)
with open(os.path.join(output_dir, "performance_improvement.txt"), "w", encoding="utf-8") as f:
    f.writelines(f"{line}\n" for line in improvement_lines)