# **1. Time Benchmark**

In [None]:
import pandas as pd

# Load the mapping times; the first CSV column is used as the row index.
df = pd.read_csv(
    "../../Data/AAM/results_benchmark/mapping_times.csv", index_col=0
).rename(columns={"Unnamed: 0": "dataset"})
# Rescale every timing by 1/60 (presumably seconds -> minutes; confirm against the CSV).
df = df.div(60)
# Reaction counts are assigned positionally, so they must follow the CSV row order.
df = df.assign(**{"Number of Reactions": [273, 382, 3000, 1758, 491]})
df

In [None]:
import copy

# Swap the raw mapper column names for their LaTeX display names.
df = df.rename(
    columns={
        "rxn_mapper": r"$\texttt{RXNMapper}$",
        "graphormer": r"$\texttt{GraphormerMapper}$",
        "local_mapper": r"$\texttt{LocalMapper}$",
        "rdt": r"$\texttt{RDT}$",
    }
)

# Independent snapshot of the timing table; `df` is reused for other data later on.
df_time = df.copy(deep=True)

In [None]:
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from matplotlib.axes import Axes
from matplotlib.figure import Figure
from typing import List, Optional


def _calculate_stats(sub_df: pd.DataFrame, column: str, logarit=True) -> tuple:
    """
    Helper function to calculate the pooled average time per reaction and its spread.

    Parameters:
    - sub_df (pd.DataFrame): Subset of the main DataFrame, one row per dataset. It must
      contain `column` and a "Number of Reactions" column.
    - column (str): Name of the column for which to calculate statistics.
    - logarit (bool): If True (default), work on the -log(time per reaction) scale;
      otherwise work on the raw time-per-reaction scale.

    Returns:
    - tuple: (avg, std). `avg` is the total time divided by the total number of
      reactions (negated log of that ratio when `logarit` is True). `std` is the square
      root of the summed squared deviations of the per-dataset values from `avg`,
      divided by the total number of reactions.
    """
    n_reactions = sub_df["Number of Reactions"]
    total_reactions = n_reactions.sum()

    pooled = sub_df[column].sum() / total_reactions
    per_dataset = sub_df[column] / n_reactions

    if logarit:
        avg = -np.log(pooled)
        # Per-dataset values must live on the same -log(time / reaction) scale as
        # `avg`; otherwise the deviations compare unrelated quantities.
        values = -np.log(per_dataset)
    else:
        avg = pooled
        values = per_dataset

    # NOTE(review): the squared deviations are unweighted but normalised by the total
    # reaction count (not by the number of datasets); kept as in the original.
    std = np.sqrt(np.sum((values - avg) ** 2) / total_reactions)
    return avg, std


def plot_reaction_times(df: pd.DataFrame, ax: Axes) -> None:
    """
    Draw grouped bars of the per-mapper time statistics onto `ax`.

    For every mapper column, three bars are drawn side by side: one computed from
    all rows, one from the first two rows (labelled as the biochemical dataset)
    and one from the remaining rows (labelled as the chemical dataset). Bar heights
    and error bars come from `_calculate_stats` with its default settings, and each
    bar is annotated with its height rounded to two decimals.

    Parameters:
    - df (pd.DataFrame): Timing table. Every column except the last is treated as a
      mapper; the last column is skipped here (it is expected to be the
      'Number of Reactions' column that `_calculate_stats` reads).
    - ax (matplotlib.axes.Axes): Axes that receives the bars, labels and legend.

    Returns:
    - None: The plot is drawn on `ax` in place.

    Notes:
    - LaTeX text rendering is switched on globally, so a LaTeX installation is needed.
    - The seaborn style is also set globally as a side effect.
    """
    # Global rendering configuration (affects every later figure as well).
    plt.rc("text", usetex=True)
    plt.rc("text.latex", preamble=r"\usepackage{amsmath}")
    sns.set(style="darkgrid")

    bar_width = 0.25
    mappers = df.columns[:-1]
    x = np.arange(len(mappers))  # one group of bars per mapper

    # One entry per bar within a group: (x offset, rows used, colour, legend label).
    groups = (
        (-bar_width, df, "#4c92c3", r"$\textit{All dataset}$"),
        (0, df.iloc[:2], "#ffdd57", r"$\textit{Biochemical dataset}$"),
        (bar_width, df.iloc[2:], "#ff6f61", r"$\textit{Chemical dataset}$"),
    )

    for i, mapper in enumerate(mappers):
        stats = [_calculate_stats(rows, mapper) for _, rows, _, _ in groups]

        # Bars first; only the first group carries legend labels so that each
        # category shows up once in the legend.
        for (offset, _, color, label), (avg, std) in zip(groups, stats):
            ax.bar(
                x[i] + offset,
                avg,
                width=bar_width,
                color=color,
                label=label if i == 0 else "",
                yerr=std,
                capsize=5,
            )

        # Then the numeric annotation sitting on top of each bar.
        for (offset, _, _, _), (avg, _) in zip(groups, stats):
            ax.text(
                x[i] + offset,
                avg,
                f"{avg:.2f}",
                ha="center",
                va="bottom",
                color="black",
                fontsize=18,
            )

    # Axis labels, ticks, title and legend.
    ax.set_ylabel(r"$-\log(minutes)$", fontsize=24, weight="bold", color="black")
    ax.set_xticks(x)
    ax.set_xticklabels(mappers, fontsize=20, weight="bold", color="black")
    ax.set_title(r"A. Processing time benchmarking", fontsize=28, weight="bold")
    ax.tick_params(axis="y", labelsize=18, labelcolor="black")
    ax.legend(fontsize=20, loc="upper right", frameon=True, edgecolor="black")
    # ax.set_ylim(0, 7)  # Adjust this limit based on your data range

In [None]:
# Scratch check of the statistics helper. Every iteration overwrites the same
# variables, so after the loop they hold the values for the last mapper column only.
# NOTE(review): the "all" statistics use logarit=False while the bio/chem statistics
# use the default logarit=True, so the three results are on different scales;
# confirm whether that is intentional.
for i, mapper in enumerate(df_time.columns[:-1]):
    all_avg, all_std = _calculate_stats(df_time, mapper, logarit=False)
    bio_avg, bio_std = _calculate_stats(df_time.iloc[:2], mapper)
    chem_avg, chem_std = _calculate_stats(df_time.iloc[2:], mapper)

In [None]:
# Display the raw-scale (logarit=False) spread left over from the last loop iteration.
all_std

In [None]:
import seaborn as sns
import matplotlib.pyplot as plt

# from _analysis._aam_analysis import plot_reaction_times

# Global plot configuration: LaTeX text rendering (needs a LaTeX installation)
# and the seaborn dark-grid theme.
plt.rc("text", usetex=True)
plt.rc("text.latex", preamble=r"\usepackage{amsmath}")
sns.set_theme(style="darkgrid")


# Draw the time benchmark on its own single-axes figure.
fig, ax = plt.subplots(figsize=(12, 6))

plot_reaction_times(df_time, ax)

# Saving of the standalone figure is disabled.
# plt.tight_layout()
# fig.savefig("./fig/aam_time_benchmark.pdf", dpi=600, bbox_inches="tight", pad_inches=0)
plt.show()

In [None]:
# Scratch calculation: natural logarithm of 6.
np.log(6)

# **2. Success rate Benchmark**

In [None]:
import re


def calculate_mapping_failures(data):
    """
    Count, per mapper, how many reactions carry no atom-map annotation.

    A mapping counts as successful when the mapper's string contains at least one
    atom-map tag (a colon followed by digits, e.g. ":12"). A missing key, an empty
    string, or a non-string value such as None counts as a failure.

    Args:
    data (iterable of dicts): One dictionary per reaction, holding the output of each
        mapper under the keys "rxn_mapper", "graphormer", "local_mapper" and "rdt".

    Returns:
    dict: For every mapper key, "<key>_number_fails" (int) and "<key>_success_rate"
        (percentage rounded to two decimals; 0 when `data` is empty).
    """
    # Keys to evaluate
    keys_to_check = ["rxn_mapper", "graphormer", "local_mapper", "rdt"]

    # Regex pattern to find atom maps
    atom_map_pattern = re.compile(r":\d+")

    successes = dict.fromkeys(keys_to_check, 0)
    total_attempts = 0

    for entry in data:
        total_attempts += 1
        for key in keys_to_check:
            reaction_data = entry.get(key)
            # Guard against None / non-string values: searching them would raise
            # TypeError, whereas they simply represent a failed mapping.
            if isinstance(reaction_data, str) and atom_map_pattern.search(
                reaction_data
            ):
                successes[key] += 1

    # Every entry is one attempt for every mapper, so fails = attempts - successes.
    aggregate_results = {}
    for key in keys_to_check:
        total_successes = successes[key]
        success_rate = (total_successes / total_attempts) if total_attempts > 0 else 0

        aggregate_results[f"{key}_number_fails"] = total_attempts - total_successes
        aggregate_results[f"{key}_success_rate"] = round(
            success_rate * 100, 2
        )  # Express as percentage

    return aggregate_results

In [None]:
import sys

# Extend the import path two directory levels up so that the `syntemp` package
# can be imported (presumably the repository root; confirm the notebook location).
sys.path.append("../../")
from syntemp.SynUtils.utils import load_database

# Load the mapped reactions of the "golden" benchmark dataset.
data = load_database(
    "../../Data/AAM/results_benchmark/golden/golden_aam_reactions.json.gz"
)

In [None]:
from syntemp.SynAAM.aam_validator import AAMValidator

# Validate the four mapper columns of the loaded dataset against the
# "ground_truth" column; only the first element of the returned pair is kept.
# NOTE(review): the meaning of check_method="RC" and of the two ensemble
# `strategies` is defined by AAMValidator, not here; confirm in its documentation.
results, _ = AAMValidator.validate_smiles(
    data=data,
    ground_truth_col="ground_truth",
    mapped_cols=["rxn_mapper", "graphormer", "local_mapper", "rdt"],
    check_method="RC",
    ignore_aromaticity=False,
    n_jobs=4,
    verbose=2,
    ensemble=True,
    strategies=[
        ["rxn_mapper", "graphormer", "local_mapper"],
        ["rxn_mapper", "graphormer", "local_mapper", "rdt"],
    ],
    ignore_tautomers=False,
)

In [None]:
import sys

sys.path.append("../..")
from syntemp.SynUtils.utils import load_database

# Location of the mapped reactions of every benchmark dataset.
data_paths = {
    "ecoli": "../../Data/AAM/results_benchmark/ecoli/ecoli_aam_reactions.json.gz",
    "recon3d": "../../Data/AAM/results_benchmark/recon3d/recon3d_aam_reactions.json.gz",
    "uspto_3k": "../../Data/AAM/results_benchmark/uspto_3k/uspto_3k_aam_reactions.json.gz",
    "golden": "../../Data/AAM/results_benchmark/golden/golden_aam_reactions.json.gz",
    "natcomm": "../../Data/AAM/results_benchmark/natcomm/natcomm_aam_reactions.json.gz",
}

# Failure counts and success rates per mapper, keyed by dataset name.
results_dict = {
    dataset_name: calculate_mapping_failures(load_database(filepath))
    for dataset_name, filepath in data_paths.items()
}

# One row per dataset; show it transposed so that the datasets become columns.
df_results = pd.DataFrame.from_dict(results_dict, orient="index")
df_results.T

# **3. Accuracy Benchmark**

In [None]:
import sys
import pandas as pd

sys.path.append("../..")
from syntemp.SynAAM.aam_validator import AAMValidator

In [None]:
# `load_database` is needed below; import it here so this section also runs on its own.
from syntemp.SynUtils.utils import load_database

# Location of the mapped reactions of every benchmark dataset.
data_paths = {
    "ecoli": "../../Data/AAM/results_benchmark/ecoli/ecoli_aam_reactions.json.gz",
    "recon3d": "../../Data/AAM/results_benchmark/recon3d/recon3d_aam_reactions.json.gz",
    "uspto_3k": "../../Data/AAM/results_benchmark/uspto_3k/uspto_3k_aam_reactions.json.gz",
    "golden": "../../Data/AAM/results_benchmark/golden/golden_aam_reactions.json.gz",
    "natcomm": "../../Data/AAM/results_benchmark/natcomm/natcomm_aam_reactions.json.gz",
}


def _validate_mappings(reactions):
    """Run AAMValidator.validate_smiles with the benchmark's fixed settings.

    Only the first element of the returned pair is passed on; the second is
    discarded, as in every call of this benchmark.
    """
    results, _ = AAMValidator.validate_smiles(
        data=reactions,
        ground_truth_col="ground_truth",
        mapped_cols=["rxn_mapper", "graphormer", "local_mapper", "rdt"],
        check_method="RC",
        ignore_aromaticity=False,
        n_jobs=4,
        verbose=2,
        ensemble=True,
        strategies=[
            ["rxn_mapper", "graphormer", "local_mapper"],
            ["rxn_mapper", "graphormer", "local_mapper", "rdt"],
        ],
        ignore_tautomers=False,
    )
    return results


def _load_datasets(names):
    """Concatenate the reactions of the named datasets, in `data_paths` order.

    Each file is read fresh from disk and only if it is requested, so the
    validator always receives untouched records.
    """
    merged = []
    for dataset_name, filepath in data_paths.items():
        if dataset_name in names:
            merged.extend(load_database(filepath))
    return merged


results_dict = {}

# Validate every dataset on its own.
for dataset_name, filepath in data_paths.items():
    results_dict[dataset_name] = _validate_mappings(load_database(filepath))

# Pooled biochemical reactions: ecoli + recon3d.
bio = _load_datasets(("ecoli", "recon3d"))
results_dict["Biochemical"] = _validate_mappings(bio)

# Pooled chemical reactions: uspto_3k + golden + natcomm.
chem = _load_datasets(("golden", "natcomm", "uspto_3k"))
results_dict["Chemical"] = _validate_mappings(chem)

In [None]:
# Initialize an empty DataFrame to hold all data
import pandas as pd

# Build one wide table: a row per mapper, and for every entry of `results_dict`
# a "<name>_accuracy" and a "<name>_success_rate" column.
final_df = pd.DataFrame()

# Process each data type
for data_type, records in results_dict.items():
    # Create a DataFrame
    # NOTE: this rebinds the name `df`, which also held the timing table earlier
    # in the notebook; the timing data remains available as `df_time`.
    df = pd.DataFrame(records)

    # Constant helper columns whose single value becomes the column name after pivoting
    df["accuracy_col"] = data_type + "_accuracy"
    df["success_col"] = data_type + "_success_rate"
    # Accuracy is scaled by 100 here while success_rate is only rounded
    # (presumably it already arrives as a percentage; confirm against AAMValidator).
    df["Accuracy"] = round(df["accuracy"] * 100, 2)
    df["Success Rate"] = round(df["success_rate"], 2)

    # Pivot the DataFrame: index = mapper, one accuracy and one success-rate column
    df_pivot = df.pivot(index="mapper", columns="accuracy_col", values="Accuracy").join(
        df.pivot(index="mapper", columns="success_col", values="Success Rate")
    )

    # Merge with the final DataFrame (joined on the shared mapper index)
    if final_df.empty:
        final_df = df_pivot
    else:
        final_df = final_df.join(df_pivot)

# Reset index to make 'mapper' a column
final_df.reset_index(inplace=True)
# Reorder the rows by their positional labels, then renumber them 0..5.
# NOTE(review): the hard-coded order assumes exactly six mappers in the order the
# pivot produced them; verify that it matches the display labels assigned later.
final_df = final_df.reindex([5, 2, 3, 4, 0, 1])
final_df = final_df.reset_index(drop=True)

In [None]:
# Convert the summary table to a list of per-row dictionaries for saving.
aam_json = final_df.to_dict(orient="records")

In [None]:
from syntemp.SynUtils.utils import save_database

# Write the summary records to disk; section 3.0 reads this file back.
save_database(aam_json, "../../Data/AAM/results_benchmark/aam_benchmark.json.gz")

## 3.0 Load results

In [None]:
import sys
import pandas as pd

sys.path.append("../../")
from syntemp.SynUtils.utils import load_database

# Rebuild the summary table from the saved benchmark file, so the plots below
# do not require re-running the validation.
final_df = pd.DataFrame(
    load_database("../../Data/AAM/results_benchmark/aam_benchmark.json.gz")
)

In [None]:
# Display only the pooled Biochemical / Chemical accuracy and success-rate columns.
final_df[
    [
        "mapper",
        "Biochemical_accuracy",
        "Biochemical_success_rate",
        "Chemical_accuracy",
        "Chemical_success_rate",
    ]
]

## 3.1 Heatmap

In [None]:
# Accuracy columns shown in the heatmap, mapped to their short display names.
accuracy_columns = {
    "ecoli_accuracy": "ecoli",
    "recon3d_accuracy": "recon3d",
    "uspto_3k_accuracy": "uspto_3k",
    "golden_accuracy": "golden",
    "natcomm_accuracy": "natcomm",
    "Biochemical_accuracy": "Biochemical",
    "Chemical_accuracy": "Chemical",
}

# Select the columns and rename them in one step. The non-inplace `rename` returns
# a new DataFrame, so later column assignments on `data_visual` neither trigger
# SettingWithCopyWarning nor touch `final_df`.
data_visual = final_df[["mapper", *accuracy_columns]].rename(columns=accuracy_columns)

In [None]:
# Replace the mapper names with LaTeX display labels. The list is assigned
# positionally, so it must contain one label per row, in row order.
data_visual["mapper"] = [
    r"$\texttt{RXNMapper}$",
    r"$\texttt{GraphormerMapper}$",
    r"$\texttt{LocalMapper}$",
    r"$\texttt{RDT}$",
    r"$\textit{Ensemble_1}$",
    r"$\textit{Ensemble_2}$",
]

In [None]:
# Display the table that feeds the heatmap.
data_visual

In [None]:
import seaborn as sns
import matplotlib.pyplot as plt

# Global plot configuration: LaTeX text rendering (needs a LaTeX installation)
# and the seaborn dark-grid theme.
plt.rc("text", usetex=True)
plt.rc("text.latex", preamble=r"\usepackage{amsmath}")
sns.set_theme(style="darkgrid")
from _analysis._aam_analysis import plot_heatmap

# Draw the accuracy heatmap; the second argument is presumably the output path
# of the saved figure (confirm in plot_heatmap).
fig = plot_heatmap(data_visual, "./fig/FigS1_aam_accuracy_heatmap.pdf")
plt.show()

## 3.2 Barplot

In [None]:
# Work on a copy so that the bar plots cannot modify `final_df`.
df = final_df.copy()

In [None]:
from _analysis._aam_analysis import plot_accuracy_success_rate_subplot
import seaborn as sns
import matplotlib.pyplot as plt

# Global plot configuration: LaTeX text rendering (needs a LaTeX installation)
# and the seaborn dark-grid theme.
plt.rc("text", usetex=True)
plt.rc("text.latex", preamble=r"\usepackage{amsmath}")
sns.set_theme(style="darkgrid")

# Two side-by-side panels; the three lists are given in the same order
# (Chemical first, then Biochemical), one entry per panel.
fig, axes = plt.subplots(1, 2, figsize=(18, 8))
plot_accuracy_success_rate_subplot(
    df,
    ["Chemical_accuracy", "Biochemical_accuracy"],
    ["Chemical_success_rate", "Biochemical_success_rate"],
    [r"$\textit{Chemical dataset}$", r"$\textit{Biochemical dataset}$"],
    axes,
)

In [None]:
from _analysis._aam_analysis import (
    plot_accuracy_success_rate_subplot,
    # plot_reaction_times,
)
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec


# Columns and panel titles for the two bottom panels (Chemical first, then Biochemical).
accuracy_cols = ["Chemical_accuracy", "Biochemical_accuracy"]
success_cols = ["Chemical_success_rate", "Biochemical_success_rate"]
titles = [r"$\textit{B. Chemical dataset}$", r"$\textit{C. Biochemical dataset}$"]

# Copy of the summary table with LaTeX display labels assigned positionally,
# so the list must contain one label per row, in row order.
df = final_df.copy()
df["mapper"] = [
    r"$\texttt{RXNMapper}$",
    r"$\texttt{GraphormerMapper}$",
    r"$\texttt{LocalMapper}$",
    r"$\texttt{RDT}$",
    r"$\textit{Ensemble_1}$",
    r"$\textit{Ensemble_2}$",
]

# Create a 2x2 grid with equal row heights and column widths
fig = plt.figure(figsize=(18, 16))
gs = gridspec.GridSpec(2, 2, height_ratios=[1, 1], width_ratios=[1, 1])

# First row: reaction times plot (spans the entire width); uses the notebook's
# own plot_reaction_times, since the imported one is commented out above.
ax0 = plt.subplot(gs[0, :])
plot_reaction_times(df_time, ax0)

# Second row: accuracy and success rate subplots
ax1 = plt.subplot(gs[1, 0])
ax2 = plt.subplot(gs[1, 1])
plot_accuracy_success_rate_subplot(df, accuracy_cols, success_cols, titles, [ax1, ax2])

# Adjust layout to prevent overlap
plt.tight_layout()

# Save the combined figure as a PDF, then display it
fig.savefig(
    "./fig/Fig6_aam_time_data_benchmark.pdf", dpi=600, bbox_inches="tight", pad_inches=0
)
plt.show()