# **1. Time Benchmark**

In [None]:
import pandas as pd

# Per-dataset mapping times (one row per dataset, one column per mapper); the
# first CSV column holds the dataset identifiers and becomes the index.
df = pd.read_csv('../../Data/AAM/unbalance/mapping_times.csv', index_col=0)
# index_col=0 already consumes the first column, so renaming a column called
# 'Unnamed: 0' was a no-op; name the index instead.
df = df.rename_axis('dataset')
# Reaction counts are assigned by position: this order must match the CSV row
# order (presumably ecoli, recon3d, uspto_3k, golden, natcomm — TODO confirm).
# The plotting cells treat rows 0-1 as biochemical and rows 2+ as chemical.
df['Number of Reactions'] = [273, 382, 3000, 1758, 491]
df

In [None]:
# Swap the internal mapper keys for LaTeX-typeset display names; these column
# names are later used directly as x tick labels in the timing plots.
df = df.rename(
    columns={
        'rxn_mapper': r'$\texttt{RXNMapper}$',
        'graphormer': r'$\texttt{GraphormerMapper}$',
        'local_mapper': r'$\texttt{LocalMapper}$',
        'rdt': r'$\texttt{RDT}$',
    }
)


In [None]:
import copy
# Independent deep copy of the timing table; the function-based plot further
# below is called on this copy (`plot_reaction_times(df_time, ax)`).
df_time = copy.deepcopy(df)

In [None]:
import pandas as pd
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt

from pathlib import Path

# Enable LaTeX rendering in matplotlib (labels below use \texttt / \textit).
plt.rc('text', usetex=True)
plt.rc('text.latex', preamble=r'\usepackage{amsmath}')  # Ensure amsmath is loaded

# Bar colours for the All / Biochemical / Chemical groups.
colors = ["#4c92c3", "#ffdd57", "#ff6f61"]

plt.figure(figsize=(12, 6))
sns.set(style="darkgrid")

bar_width = 0.25
label_offset = 0.05  # vertical gap between a bar top and its value label
mappers = df.columns[:-1]  # every column except the trailing 'Number of Reactions'
x = np.arange(len(mappers))  # one x position per mapper

# (legend label, dataset rows, horizontal bar offset, colour).
# Rows 0-1 are the biochemical datasets and rows 2+ the chemical ones; this
# relies on the CSV row order established in the loading cell.
groups = [
    (r'$\textit{All dataset}$', slice(None), -bar_width, colors[0]),
    (r'$\textit{Biochemical dataset}$', slice(None, 2), 0.0, colors[1]),
    (r'$\textit{Chemical dataset}$', slice(2, None), bar_width, colors[2]),
]

for i, mapper in enumerate(mappers):
    for label, rows, offset, color in groups:
        times = df[mapper].iloc[rows]
        counts = df['Number of Reactions'].iloc[rows]
        # Pooled average time per reaction = total time / total reactions.
        avg = times.sum() / counts.sum()
        # Reaction-weighted standard deviation of the per-dataset rates around
        # the pooled mean. Dividing *unweighted* squared deviations by the total
        # reaction count (as before) is not a standard deviation and shrinks the
        # error bars by roughly sqrt(reactions per dataset).
        std = np.sqrt(np.average((times / counts - avg) ** 2, weights=counts))

        plt.bar(x[i] + offset, avg, width=bar_width, color=color,
                label=label if i == 0 else "", yerr=std, capsize=5)
        plt.text(x[i] + offset, avg + label_offset, f'{avg:.2f}',
                 ha='center', va='bottom', color='black', fontsize=18)

plt.ylabel('Average Time per Reaction (seconds)', fontsize=20, weight='bold', color='black')
plt.xticks(x, mappers, rotation=45, fontsize=18, weight='bold', color='black')
plt.yticks(fontsize=18, weight='bold', color='black')
plt.legend(fontsize=18, loc='upper left', frameon=True, edgecolor='black')
plt.ylim(0, 7)  # fixed 0-7 s range; raise it if bars or error bars get clipped

plt.tight_layout()
Path('./fig').mkdir(parents=True, exist_ok=True)  # savefig fails if the folder is missing
plt.savefig('./fig/aam_time_benchmark.pdf', dpi = 600)
plt.show()


In [None]:
import pandas as pd
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt

def plot_reaction_times(df, ax):
    """
    Draw the processing-time benchmark as grouped bars on ``ax``.

    For each mapper column (every column of ``df`` except the last, which is
    expected to be 'Number of Reactions') three bars are drawn: the pooled
    average time per reaction over all rows, over rows 0-1 ("Biochemical") and
    over rows 2 onwards ("Chemical"). Error bars show the reaction-weighted
    standard deviation of the per-dataset rates around that pooled average.

    Parameters:
        df (pd.DataFrame): One row per dataset. Mapper columns hold the time
            spent on each dataset (presumably total seconds — confirm against
            the source CSV); the last column is 'Number of Reactions'.
        ax (matplotlib.axes.Axes): Axes to draw into.

    Side effects:
        Enables LaTeX text rendering globally via ``plt.rc``.
    """
    # Enable LaTeX rendering in matplotlib (labels use \texttt / \textit).
    plt.rc('text', usetex=True)
    plt.rc('text.latex', preamble=r'\usepackage{amsmath}')

    colors = ["#4c92c3", "#ffdd57", "#ff6f61"]
    bar_width = 0.25
    label_offset = 0.05  # vertical gap between a bar top and its value label

    mappers = df.columns[:-1]  # all columns except the trailing 'Number of Reactions'
    x = np.arange(len(mappers))  # one x position per mapper
    n_reactions = df['Number of Reactions']

    # (legend label, dataset rows, horizontal bar offset, colour); rows 0-1 are
    # the biochemical datasets and rows 2+ the chemical ones (CSV row order).
    groups = [
        (r'$\textit{All dataset}$', slice(None), -bar_width, colors[0]),
        (r'$\textit{Biochemical dataset}$', slice(None, 2), 0.0, colors[1]),
        (r'$\textit{Chemical dataset}$', slice(2, None), bar_width, colors[2]),
    ]

    for i, mapper in enumerate(mappers):
        for label, rows, offset, color in groups:
            times = df[mapper].iloc[rows]
            counts = n_reactions.iloc[rows]
            # Pooled average time per reaction = total time / total reactions.
            avg = times.sum() / counts.sum()
            # Reaction-weighted std of per-dataset rates around the pooled mean
            # (the unweighted sum divided by the total reaction count used
            # previously is not a standard deviation).
            std = np.sqrt(np.average((times / counts - avg) ** 2, weights=counts))

            ax.bar(x[i] + offset, avg, width=bar_width, color=color,
                   label=label if i == 0 else "", yerr=std, capsize=5)
            ax.text(x[i] + offset, avg + label_offset, f'{avg:.2f}',
                    ha='center', va='bottom', color='black', fontsize=18)

    ax.set_ylabel('Average Time per Reaction (seconds)', fontsize=24, weight='bold', color='black')
    ax.set_xticks(x)
    ax.set_xticklabels(mappers, rotation=45, fontsize=18, weight='bold', color='black')
    ax.set_title(r'A. Processing time benchmarking', fontsize=28, weight='bold')
    ax.tick_params(axis='y', labelsize=18, labelcolor='black')
    ax.legend(fontsize=20, loc='upper left', frameon=True, edgecolor='black')
    ax.set_ylim(0, 7)  # fixed range; raise it if bars or error bars get clipped

# Draw the timing panel on a standalone 12x6 figure, using the `df_time` snapshot.
fig, ax = plt.subplots(figsize=(12, 6))

plot_reaction_times(df_time, ax)



# **2. Success Rate Benchmark**

In [None]:
import re

def calculate_mapping_failures(data, keys_to_check=("rxn_mapper", "graphormer", "local_mapper", "rdt")):
    """
    Count failed mappings per mapper, where a mapping fails when its reaction
    string contains no atom-map number (no ``:<digits>`` token).

    Args:
        data (iterable of dict): One dict per reaction; each mapper's output is
            read from the key of the same name. Missing keys and non-string
            values (e.g. ``None`` for a mapper that produced nothing) count as
            failures instead of raising ``TypeError``.
        keys_to_check (sequence of str): Mapper keys to evaluate. Defaults to the
            four mappers benchmarked in this notebook.

    Returns:
        dict: For every mapper ``key``, ``f"{key}_number_fails"`` (int) and
        ``f"{key}_success_rate"`` (percentage rounded to 2 decimals, 0 when
        ``data`` is empty).
    """
    atom_map_pattern = re.compile(r':\d+')

    counts = {key: {'fails': 0, 'successes': 0} for key in keys_to_check}

    # Single pass over `data`, so generators are supported as well as lists.
    for entry in data:
        for key in keys_to_check:
            reaction = entry.get(key)
            mapped = isinstance(reaction, str) and atom_map_pattern.search(reaction) is not None
            counts[key]['successes' if mapped else 'fails'] += 1

    aggregate_results = {}
    for key in keys_to_check:
        total_fails = counts[key]['fails']
        total_successes = counts[key]['successes']
        total_attempts = total_fails + total_successes
        success_rate = (total_successes / total_attempts) if total_attempts > 0 else 0

        aggregate_results[f"{key}_number_fails"] = int(total_fails)
        aggregate_results[f"{key}_success_rate"] = round(success_rate * 100, 2)  # percentage

    return aggregate_results


In [None]:
import sys
# Make the repository root (two levels up) importable so `SynTemp` resolves.
sys.path.append('../../')
from SynTemp.SynUtils.utils import load_database
# Golden-set benchmark records; the cells below expect each record to hold one
# mapped reaction string per mapper key plus 'ground_truth'.
data = load_database('../../Data/AAM/unbalance/golden/golden_aam_reactions.json.gz')

In [None]:
# Scratch inspection of the raw RDT outputs for the currently loaded `data`
# (not referenced by the other cells shown here).
test  =[value['rdt'] for value in data]

In [None]:
import re
def mapping_success_rate(list_mapping_data):
    """
    Percentage of entries that contain at least one atom-map number (``:<digits>``).

    NOTE: a later cell imports ``mapping_success_rate`` from
    ``SynTemp.SynUtils.chemutils``, which rebinds this name in the notebook.

    Parameters:
        list_mapping_data (list of str): Reaction strings to inspect. Non-string
            entries (e.g. ``None`` for a mapper that produced no output) count as
            unmapped instead of raising ``TypeError``.

    Returns:
        float: The success rate as a percentage, rounded to 2 decimals.

    Raises:
        ValueError: If the input list is empty.
    """
    if not list_mapping_data:
        raise ValueError("The input list is empty, cannot calculate success rate.")

    atom_map_pattern = re.compile(r':\d+')
    success = sum(
        1
        for entry in list_mapping_data
        if isinstance(entry, str) and atom_map_pattern.search(entry)
    )
    return round(100 * (success / len(list_mapping_data)), 2)


In [None]:
import networkx as nx
import pandas as pd
from typing import Dict, List, Tuple, Union, Optional
from rdkit import Chem
from operator import eq
from joblib import Parallel, delayed
from networkx.algorithms.isomorphism import generic_node_match, generic_edge_match
from SynTemp.SynITS.its_construction import ITSConstruction
from SynTemp.SynITS.its_extraction import ITSExtraction
from SynTemp.SynChemistry.mol_to_graph import MolToGraph
from SynTemp.SynRule.rules_extraction import RuleExtraction
from SynTemp.SynUtils.chemutils import enumerate_tautomers, mapping_success_rate
from itertools import combinations


class AMMValidator:
    """
    Validate atom-atom mappings (AAM) produced by several mappers against a
    ground-truth mapping.

    Two mapped reactions are treated as equivalent when their reaction-centre
    graphs (``check_method="RC"``) or their full ITS graphs (any other
    ``check_method``, conventionally ``"ITS"``) are isomorphic. Every method is
    static; the class carries no state.
    """

    def __init__(self):
        """Initializes the AMMValidator class (no state is stored)."""
        pass

    @staticmethod
    def graph_from_smiles(smiles: str) -> nx.Graph:
        """
        Constructs a graph representation from a SMILES string.

        Parameters:
            smiles (str): A SMILES string representing a molecule or a set of
                molecules (e.g. one side of a reaction SMILES).

        Returns:
            nx.Graph: The graph built by ``MolToGraph().mol_to_graph`` with
            ``drop_non_aam=True`` (presumably dropping atoms without an atom-map
            number — confirm in MolToGraph).
        """
        mol = Chem.MolFromSmiles(smiles)
        graph = MolToGraph().mol_to_graph(mol, drop_non_aam=True)
        return graph

    @staticmethod
    def check_equivariant_graph(
        its_graphs: List[nx.Graph],
    ) -> Tuple[List[Tuple[int, int]], int]:
        """
        Checks for isomorphism among a list of ITS graphs and
        identifies all pairs of isomorphic graphs.

        Parameters:
        - its_graphs (List[nx.Graph]): A list of ITS (or reaction-centre) graphs.

        Returns:
        - List[Tuple[int, int]]: Pairs ``(i, j)`` with ``i < j`` of indices of
                isomorphic graphs.
        - int: The number of such pairs.
        """
        # Nodes match when their "typesGH" attribute is equal ("*" if absent);
        # edges match when their "order" attribute is equal (1 if absent).
        # generic_node_match zips names/defaults/operators, so only one
        # default/operator per listed attribute is meaningful.
        node_match = generic_node_match(["typesGH"], ["*"], [eq])
        edge_match = generic_edge_match("order", 1, eq)

        # Each unordered pair of graphs is tested exactly once.
        classified = [
            (i, j)
            for i, j in combinations(range(len(its_graphs)), 2)
            if nx.is_isomorphic(
                its_graphs[i], its_graphs[j], node_match=node_match, edge_match=edge_match
            )
        ]
        return classified, len(classified)

    @staticmethod
    def smiles_check(
        mapped_smile: str,
        ground_truth: str,
        check_method: str = "RC",  # or 'ITS'
        ignore_aromaticity: bool = False,
    ) -> bool:
        """
        Checks the equivalence of mapped SMILES against ground truth
        using the reaction-centre (RC) or ITS graph method.

        Parameters:
            mapped_smile (str): The mapped reaction SMILES ("reactants>>products").
            ground_truth (str): The ground-truth reaction SMILES.
            check_method (str): "RC" compares reaction-centre graphs; any other
                value compares the full ITS graphs.
            ignore_aromaticity (bool): Passed to ITS graph construction.

        Returns:
            bool: True if the mapped SMILES is equivalent to the ground truth,
                False otherwise (including when any step raises; the error is
                printed).
        """
        its_graphs = []
        rules_graphs = []
        try:
            for rsmi in [mapped_smile, ground_truth]:
                reactants_side, products_side = rsmi.split(">>")
                G = AMMValidator.graph_from_smiles(reactants_side)  # Reactants graph
                H = AMMValidator.graph_from_smiles(products_side)  # Products graph

                ITS = ITSConstruction.ITSGraph(G, H, ignore_aromaticity)
                its_graphs.append(ITS)

                # Element 2 of the extracted rules is used as the reaction-centre
                # graph (NOTE(review): confirm against RuleExtraction).
                rules = RuleExtraction.extract_reaction_rules(G, H, ITS, extend=False)
                rules_graphs.append(rules[2])

            _, equivariant = AMMValidator.check_equivariant_graph(
                rules_graphs if check_method == "RC" else its_graphs
            )
            # Exactly two graphs give a single pair, so 1 means "isomorphic".
            return equivariant == 1

        except Exception as e:
            # Malformed or unmappable input counts as a mismatch instead of
            # aborting the whole benchmark.
            print("An error occurred:", str(e))
            return False

    @staticmethod
    def smiles_check_tautomer(
        mapped_smile: str,
        ground_truth: str,
        check_method: str = "RC",  # or 'ITS'
        ignore_aromaticity: bool = False,
    ) -> Optional[bool]:
        """
        Determines whether a mapped reaction SMILES is equivalent to any tautomer
        of the ground-truth reaction SMILES.

        The tautomers of ``ground_truth`` are enumerated with
        ``enumerate_tautomers`` and each one is compared with
        ``AMMValidator.smiles_check``; the first match short-circuits.

        Args:
            mapped_smile (str): The mapped reaction SMILES to check.
            ground_truth (str): The reference reaction SMILES whose tautomers are
                enumerated.
            check_method (str): "RC" compares reaction-centre graphs; any other
                value (conventionally "ITS") compares full ITS graphs.
            ignore_aromaticity (bool): Passed to ITS graph construction.

        Returns:
            Optional[bool]: True if any tautomer matches, False if none does, and
            None if an error occurs (the error is printed, not raised).
        """
        try:
            ground_truth_tautomers = enumerate_tautomers(ground_truth)
            return any(
                AMMValidator.smiles_check(mapped_smile, t, check_method, ignore_aromaticity)
                for t in ground_truth_tautomers
            )
        except Exception as e:
            print(f"An error occurred: {e}")
            return None

    @staticmethod
    def check_pair(
        mapping: Dict[str, str],
        mapped_col: str,
        ground_truth_col: str,
        check_method: str = "RC",
        ignore_aromaticity: bool = False,
        ignore_tautomers: bool = True
    ) -> Optional[bool]:
        """
        Checks the equivalence between the mapped and ground truth values of a
        single record.

        Parameters:
        - mapping (Dict[str, str]): The record to check.
        - mapped_col (str): Key of the mapped reaction SMILES.
        - ground_truth_col (str): Key of the ground-truth reaction SMILES.
        - check_method (str, optional): "RC" or "ITS". Defaults to "RC".
        - ignore_aromaticity (bool, optional): Passed to ITS graph construction.
                                    Defaults to False.
        - ignore_tautomers (bool, optional): If True, compare against the ground
                                    truth only; if False, accept a match with any
                                    tautomer of the ground truth. Defaults to True.

        Returns:
        - Optional[bool]: The comparison result. May be None when
                ``ignore_tautomers`` is False and tautomer enumeration fails.
        """
        if ignore_tautomers:
            return AMMValidator.smiles_check(
                mapping[mapped_col],
                mapping[ground_truth_col],
                check_method,
                ignore_aromaticity,
            )
        return AMMValidator.smiles_check_tautomer(
            mapping[mapped_col],
            mapping[ground_truth_col],
            check_method,
            ignore_aromaticity,
        )

    @staticmethod
    def _run_checks(
        mappings: List[Dict[str, str]],
        mapped_col: str,
        ground_truth_col: str,
        check_method: str,
        ignore_aromaticity: bool,
        ignore_tautomers: bool,
        n_jobs: int,
        verbose: int,
    ) -> List[Optional[bool]]:
        """Run ``check_pair`` on every record with joblib; results keep input order."""
        return Parallel(n_jobs=n_jobs, verbose=verbose)(
            delayed(AMMValidator.check_pair)(
                mapping,
                mapped_col,
                ground_truth_col,
                check_method,
                ignore_aromaticity,
                ignore_tautomers,
            )
            for mapping in mappings
        )

    @staticmethod
    def validate_smiles(
        data: Union[pd.DataFrame, List[Dict[str, str]]],
        id_col: str = "R-id",
        ground_truth_col: str = "ground_truth",
        mapped_cols: List[str] = ["rxn_mapper", "graphormer", "local_mapper"],
        check_method: str = "RC",
        ignore_aromaticity: bool = False,
        n_jobs: int = 1,
        verbose: int = 0,
        ensemble=False,
        strategies=[
            ["rxn_mapper", "graphormer", "local_mapper"],
            ["rxn_mapper", "graphormer", "local_mapper", "rdt"],
        ],
        ignore_tautomers=True
    ) -> Tuple[List[Dict[str, Union[str, float, List[bool]]]], Optional[List[Dict[str, str]]]]:
        """
        Validates collections of mapped SMILES against their ground truths
        for multiple mappers (and optionally mapper ensembles) and computes
        accuracy and success rate for each.

        Parameters:
            data: DataFrame or list of record dicts holding mapped and
                ground-truth reaction SMILES.
            id_col (str): Column/key holding the reaction ID.
            ground_truth_col (str): Column/key holding the ground-truth SMILES.
            mapped_cols (List[str]): Columns/keys holding each mapper's output.
                (The list defaults are never mutated.)
            check_method (str): "RC" or "ITS".
            ignore_aromaticity (bool): Passed to ITS graph construction.
            n_jobs (int): Number of joblib workers.
            verbose (int): joblib verbosity.
            ensemble (bool): Also evaluate the mapper ensembles in ``strategies``.
            strategies (List[List[str]]): Mapper groups; for each one,
                ``ITSExtraction.parallel_process_smiles`` is run with
                ``threshold = len(group) - 1``, and the records it returns are
                scored using the group's last mapper's output.
            ignore_tautomers (bool): See ``check_pair``.

        Returns:
            Tuple: (results, data_ensemble)
            - results: one dict per mapper/ensemble with "mapper", "accuracy"
              (percentage), "results" (per-record outcomes) and "success_rate"
              (percentage). Accuracy and success rate are 0 for empty input.
            - data_ensemble: the records scored for the *last* strategy, or
              None when ``ensemble`` is False or ``strategies`` is empty.

        Raises:
            ValueError: If ``data`` is neither a DataFrame nor a list.
        """
        if isinstance(data, pd.DataFrame):
            mappings = data.to_dict("records")
        elif isinstance(data, list):
            mappings = data
        else:
            raise ValueError(
                "Data must be either a pandas DataFrame or a list of dictionaries."
            )

        validation_results = []

        for mapped_col in mapped_cols:
            results = AMMValidator._run_checks(
                mappings, mapped_col, ground_truth_col, check_method,
                ignore_aromaticity, ignore_tautomers, n_jobs, verbose,
            )
            # check_pair may return None (tautomer enumeration error); count it
            # as a failure instead of letting sum() raise TypeError.
            n_correct = sum(1 for result in results if result)
            accuracy = n_correct / len(mappings) if mappings else 0
            mapped_data = [value[mapped_col] for value in mappings]
            validation_results.append(
                {
                    "mapper": mapped_col,
                    "accuracy": round(100 * accuracy, 2),
                    "results": results,
                    # Guard the empty case: the success-rate helper rejects empty lists.
                    "success_rate": mapping_success_rate(mapped_data) if mapped_data else 0,
                }
            )

        data_ensemble = None
        if ensemble:
            for key, strategy_cols in enumerate(strategies):
                ensemble_col = f"ensemble_{key+1}"
                its_graph, _ = ITSExtraction.parallel_process_smiles(
                    mappings,
                    strategy_cols,
                    threshold=len(strategy_cols) - 1,
                    n_jobs=n_jobs,
                    verbose=verbose,
                    export_full=False,
                    check_method=check_method,
                )
                # Keep only records whose id appears in ITSExtraction's output.
                # NOTE(review): assumes ids in `id_col` are unique within `data`
                # and that ITSExtraction's output keeps the input records' keys.
                consensus_ids = {value[id_col] for value in its_graph}
                data_ensemble = [
                    {
                        id_col: value[id_col],
                        # The ensemble is scored on the strategy's last mapper.
                        ensemble_col: value[strategy_cols[-1]],
                        ground_truth_col: value[ground_truth_col],
                    }
                    for value in mappings
                    if value[id_col] in consensus_ids
                ]
                results = AMMValidator._run_checks(
                    data_ensemble, ensemble_col, ground_truth_col, check_method,
                    ignore_aromaticity, ignore_tautomers, n_jobs, verbose,
                )
                n_correct = sum(1 for result in results if result)
                accuracy = n_correct / len(data_ensemble) if data_ensemble else 0
                validation_results.append(
                    {
                        "mapper": ensemble_col,
                        "accuracy": round(100 * accuracy, 2),
                        "results": results,
                        # Share of all records retained for this ensemble.
                        "success_rate": (
                            round(100 * len(data_ensemble) / len(mappings), 2)
                            if mappings
                            else 0
                        ),
                    }
                )

        return validation_results, data_ensemble


In [None]:
# Benchmark the currently loaded `data` with the AMMValidator currently in scope:
# reaction-centre ("RC") comparison, ground-truth tautomers considered
# (ignore_tautomers=False), plus two ensemble strategies, on 4 joblib workers.
results, _ = AMMValidator.validate_smiles(
            data=data,
            ground_truth_col="ground_truth",
            mapped_cols=["rxn_mapper", "graphormer", "local_mapper", "rdt"],
            check_method="RC",
            ignore_aromaticity=False,
            n_jobs=4,
            verbose=2,
            ensemble=True,
            strategies = [["rxn_mapper", "graphormer", "local_mapper"], ["rxn_mapper", "graphormer", "local_mapper", "rdt"]],
            ignore_tautomers=False
    )

In [None]:
# One row per mapper/ensemble: accuracy (%), per-reaction `results`, success_rate (%).
pd.DataFrame(results)

In [None]:
import sys
# Make the repository root importable so `SynTemp` resolves.
sys.path.append('../..')
from SynTemp.SynUtils.utils import load_database

# Mapper outputs for each benchmark dataset (.json.gz record files).
data_paths = {
    'ecoli': '../../Data/AAM/unbalance/ecoli/ecoli_aam_reactions.json.gz',
    'recon3d': '../../Data/AAM/unbalance/recon3d/recon3d_aam_reactions.json.gz',
    'uspto_3k': '../../Data/AAM/unbalance/uspto_3k/uspto_3k_aam_reactions.json.gz',
    'golden': '../../Data/AAM/unbalance/golden/golden_aam_reactions.json.gz',
    'natcomm': '../../Data/AAM/unbalance/natcomm/natcomm_aam_reactions.json.gz'
}

# Dictionary to hold the results for each dataset
results_dict = {}

# Count failed / successful mappings per mapper for every dataset.
# NOTE: this rebinds the notebook-level names `data` and `results`.
for dataset_name, filepath in data_paths.items():
    # Load the data
    data = load_database(filepath)
    
    # Calculate fails and success rates
    results = calculate_mapping_failures(data)
    
    # Store the results
    results_dict[dataset_name] = results

# One row per dataset, one column per '<mapper>_number_fails' / '<mapper>_success_rate'
df_results = pd.DataFrame.from_dict(results_dict, orient='index')

# Transposed for display: metrics as rows, datasets as columns
df_results.T

# **3. Accuracy Benchmark**

In [None]:
import sys
import pandas as pd
sys.path.append('../..')
from SynTemp.SynAAM.aam_validator import AMMValidator

In [None]:
from SynTemp.SynUtils.utils import load_database  # explicit: don't rely on an earlier cell

data_paths = {
    'ecoli': '../../Data/AAM/unbalance/ecoli/ecoli_aam_reactions.json.gz',
    'recon3d': '../../Data/AAM/unbalance/recon3d/recon3d_aam_reactions.json.gz',
    'uspto_3k': '../../Data/AAM/unbalance/uspto_3k/uspto_3k_aam_reactions.json.gz',
    'golden': '../../Data/AAM/unbalance/golden/golden_aam_reactions.json.gz',
    'natcomm': '../../Data/AAM/unbalance/natcomm/natcomm_aam_reactions.json.gz'
}

# Pooled groups; member datasets are concatenated in `data_paths` order.
BIOCHEMICAL_DATASETS = ('ecoli', 'recon3d')
CHEMICAL_DATASETS = ('uspto_3k', 'golden', 'natcomm')

# Settings shared by every validate_smiles call in this cell.
validation_kwargs = dict(
    ground_truth_col="ground_truth",
    mapped_cols=["rxn_mapper", "graphormer", "local_mapper", "rdt"],
    check_method="RC",
    ignore_aromaticity=False,
    n_jobs=4,
    verbose=2,
    ensemble=True,
    strategies=[["rxn_mapper", "graphormer", "local_mapper"],
                ["rxn_mapper", "graphormer", "local_mapper", "rdt"]],
    ignore_tautomers=False,
)

# Read each file once (it was previously loaded three times).
datasets = {name: load_database(path) for name, path in data_paths.items()}

# Evaluation order: each dataset, then the pooled Biochemical and Chemical sets.
# NOTE(review): validate_smiles filters ensemble records by reaction id; if ids
# repeat across datasets, the pooled groups can mix records — confirm uniqueness.
evaluation_groups = dict(datasets)
evaluation_groups['Biochemical'] = [
    record for name in data_paths if name in BIOCHEMICAL_DATASETS for record in datasets[name]
]
evaluation_groups['Chemical'] = [
    record for name in data_paths if name in CHEMICAL_DATASETS for record in datasets[name]
]

results_dict = {}
for group_name, records in evaluation_groups.items():
    results, _ = AMMValidator.validate_smiles(data=records, **validation_kwargs)
    results_dict[group_name] = results


In [None]:
from itertools import combinations

# Every ensemble of 2, 3 and 4 mappers, enumerated in the fixed mapper order
# below (pairs first, then triples, then all four).
strategies = [
    list(group)
    for size in range(2, 5)
    for group in combinations(["rxn_mapper", "graphormer", "rdt", "local_mapper"], size)
]

In [None]:
from SynTemp.SynUtils.utils import load_database  # explicit: don't rely on an earlier cell

# NOTE(review): unlike the other cells, these paths omit the 'unbalance/'
# directory — confirm which dataset version this benchmark should use.
data_paths = {
    'ecoli': '../../Data/AAM/ecoli/ecoli_aam_reactions.json.gz',
    'recon3d': '../../Data/AAM/recon3d/recon3d_aam_reactions.json.gz',
    'uspto_3k': '../../Data/AAM/uspto_3k/uspto_3k_aam_reactions.json.gz',
    'golden': '../../Data/AAM/golden/golden_aam_reactions.json.gz',
    'natcomm': '../../Data/AAM/natcomm/natcomm_aam_reactions.json.gz'
}

# Pooled groups; member datasets are concatenated in `data_paths` order.
BIOCHEMICAL_DATASETS = ('ecoli', 'recon3d')
CHEMICAL_DATASETS = ('uspto_3k', 'golden', 'natcomm')

# One comparison method for every group. Previously the per-dataset runs used
# "ITS" while the pooled Biochemical/Chemical runs used "RC", so the pooled
# columns were not comparable with the per-dataset ones.
validation_kwargs = dict(
    ground_truth_col="ground_truth",
    mapped_cols=["rxn_mapper", "graphormer", "local_mapper", "rdt"],
    check_method="ITS",
    ignore_aromaticity=False,
    n_jobs=4,
    verbose=2,
    ensemble=True,
    strategies=strategies,
)

# Read each file once (it was previously loaded three times).
datasets = {name: load_database(path) for name, path in data_paths.items()}

# Evaluation order: each dataset, then the pooled Biochemical and Chemical sets.
evaluation_groups = dict(datasets)
evaluation_groups['Biochemical'] = [
    record for name in data_paths if name in BIOCHEMICAL_DATASETS for record in datasets[name]
]
evaluation_groups['Chemical'] = [
    record for name in data_paths if name in CHEMICAL_DATASETS for record in datasets[name]
]

results_dict = {}
for group_name, records in evaluation_groups.items():
    results, _ = AMMValidator.validate_smiles(data=records, **validation_kwargs)
    results_dict[group_name] = results


    

In [None]:
# Build one wide table with a row per mapper and `<group>_accuracy` /
# `<group>_success_rate` columns for every group in `results_dict`.
import pandas as pd
final_df = pd.DataFrame()

for data_type, records in results_dict.items():
    # Dedicated name (instead of `df`) so the timing table `df` from the first
    # section is not clobbered.
    group_df = pd.DataFrame(records)

    group_df['accuracy_col'] = data_type + '_accuracy'
    group_df['success_col'] = data_type + '_success_rate'
    # NOTE(review): 'accuracy' is multiplied by 100 while 'success_rate' is not.
    # The AMMValidator class defined in this notebook already reports accuracy as
    # a percentage; this scaling presumably matches the library AMMValidator
    # imported from SynTemp.SynAAM.aam_validator — confirm before trusting values.
    group_df['Accuracy'] = round(group_df['accuracy'] * 100, 2)
    group_df['Success Rate'] = round(group_df['success_rate'], 2)

    # One row per mapper (pivot sorts mappers alphabetically), one column per metric.
    group_pivot = group_df.pivot(index='mapper', columns='accuracy_col', values='Accuracy').join(
        group_df.pivot(index='mapper', columns='success_col', values='Success Rate')
    )
    final_df = group_pivot if final_df.empty else final_df.join(group_pivot)

# Restore the order in which validate_smiles reported the mappers (single
# mappers, then ensembles) by name. A hard-coded positional permutation silently
# dropped or mislabelled rows whenever the number of strategies changed.
mapper_order = list(dict.fromkeys(
    record['mapper'] for records in results_dict.values() for record in records
))
final_df = final_df.reindex(pd.Index(mapper_order, name='mapper')).reset_index()

In [None]:
# List of row dicts (one per mapper) for serialisation with save_database.
aam_json = final_df.to_dict(orient='records')

In [None]:
from SynTemp.SynUtils.utils import save_database, load_database
# Persist the benchmark table so later cells can reload it without recomputing.
save_database(aam_json, '../../Data/AAM/unbalance/aam_benchmark.json.gz')

In [None]:
import sys
sys.path.append('../../')
from SynTemp.SynUtils.utils import save_database, load_database
# Reload the saved benchmark table (one record per mapper) so the cells below can
# run without recomputing the validation. Relies on `pd` imported in an earlier cell.
final_df = pd.DataFrame(load_database('../../Data/AAM/unbalance/aam_benchmark.json.gz'))

In [None]:
# Pooled accuracy and success rate per mapper for the two dataset categories.
final_df.loc[:, ['mapper',
                 'Biochemical_accuracy', 'Biochemical_success_rate',
                 'Chemical_accuracy', 'Chemical_success_rate']]

## 3.1 Heatmap

In [None]:
# Heatmap input: the mapper column plus one accuracy column per dataset, with
# the `_accuracy` suffix stripped from the column names.
# rename() (without inplace) returns a new frame, so `data_visual` is not a
# slice of final_df and later assignments to it do not raise SettingWithCopyWarning.
heatmap_columns = ['ecoli', 'recon3d', 'uspto_3k', 'golden', 'natcomm', 'Biochemical', 'Chemical']
data_visual = final_df[['mapper'] + [f'{name}_accuracy' for name in heatmap_columns]].rename(
    columns={f'{name}_accuracy': name for name in heatmap_columns}
)

In [None]:
# Replace the raw mapper identifiers with LaTeX display names.
# assign() returns a new frame, so this never writes into a slice of final_df.
# The list is positional: it must follow the row order fixed when final_df was built.
# NOTE(review): with text.usetex enabled, a bare '_' inside \textit{...} is a
# LaTeX error in text mode; 'Ensemble\_1' may be required — confirm on render.
data_visual = data_visual.assign(mapper=[
    r'$\texttt{RXNMapper}$',
    r'$\texttt{GraphormerMapper}$',
    r'$\texttt{LocalMapper}$',
    r'$\texttt{RDT}$',
    r'$\textit{Ensemble_1}$',
    r'$\textit{Ensemble_2}$'
])


In [None]:
# Inspect the heatmap input: one row per mapper, one accuracy column per dataset.
data_visual

In [None]:
import seaborn as sns
import matplotlib.pyplot as plt

import pathlib

# Rows are mappers, columns are datasets, cell values are accuracies in percent.
heatmap_data = data_visual.set_index('mapper')

plt.figure(figsize=(16, 8))
# Annotated heatmap; keep the returned Axes instead of digging it out of gca().
ax = sns.heatmap(heatmap_data, annot=True, cmap='coolwarm', fmt=".1f", linewidths=0.3, linecolor='white',
                 cbar=True, cbar_kws={'orientation': 'vertical'}, annot_kws={"size": 18})

# Colour bar: tick font size and the (LaTeX-escaped) label, set in one place.
cbar = ax.collections[0].colorbar
cbar.ax.tick_params(labelsize=18)
cbar.set_label(r'Accuracy (\%)', size=18)

# The y tick labels already name the mappers, so drop the axis label.
ax.set_ylabel(None)
plt.xticks(rotation=45, ha='right', fontsize=18, fontweight='bold', color='darkred')
plt.yticks(rotation=0, fontsize=18, fontweight='bold', color='darkgreen')
plt.tight_layout()

# savefig does not create missing directories.
pathlib.Path('./fig').mkdir(parents=True, exist_ok=True)
plt.savefig('./fig/aam_accuracy_heatmap.pdf', dpi = 600)
plt.show()


## 3.2 Barplot

In [None]:
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np

def plot_accuracy_success_rate_seaborn(df, accuracy_col, success_col,
                                       title='Chemical reaction database'):
    """
    Plot a grouped bar chart of accuracy and success rate for each mapper.

    Parameters:
    - df: DataFrame with a 'mapper' column plus the two metric columns;
      it is not modified
    - accuracy_col: string, name of the column with accuracy data
    - success_col: string, name of the column with success rate data
    - title: string, figure title (defaults to the previously hard-coded title)
    """
    # Treat mappers as plain strings for the categorical x-axis, on a copy so
    # the caller's DataFrame is left untouched.
    plot_df = df.assign(mapper=df['mapper'].astype(str))

    # Long format: one row per (mapper, metric) pair, as seaborn's hue= expects.
    temp_df = plot_df[['mapper', accuracy_col, success_col]].melt(
        id_vars=['mapper'], var_name='Metric', value_name='Percentage')
    temp_df['Metric'] = temp_df['Metric'].map({accuracy_col: 'Accuracy', success_col: 'Success Rate'})

    plt.figure(figsize=(14, 8))
    ax = sns.barplot(x='mapper', y='Percentage', hue='Metric', data=temp_df, palette='viridis')

    # Value labels just above each bar with a positive height.
    for p in ax.patches:
        height = p.get_height()
        if height > 0:
            ax.annotate(format(height, '.2f'),
                        (p.get_x() + p.get_width() / 2., height),
                        ha='center', va='center',
                        xytext=(0, 9),
                        textcoords='offset points')

    plt.title(title, fontsize=18, fontweight='bold')
    plt.xlabel(None)
    # With text.usetex on, a bare '%' starts a LaTeX comment and swallows the
    # rest of the label, so escape it in that case only.
    percent = r'\%' if plt.rcParams['text.usetex'] else '%'
    plt.ylabel(f'Percentage ({percent})', fontsize=16, fontweight='semibold')

    plt.xticks(rotation=45, fontsize=12, fontweight='bold')
    plt.yticks(fontsize=12, fontweight='bold')

    plt.grid(True, which='both', linestyle='--', linewidth=0.5, color='gray', alpha=0.5)

    # Legend outside the axes so it does not cover the bars.
    plt.legend(title=None, bbox_to_anchor=(1.01, 1), loc=2, borderaxespad=0.)

    plt.tight_layout()
    plt.show()

# Single-panel chart for the pooled chemical datasets, drawn from a copy of the table.
df = final_df.copy()
plot_accuracy_success_rate_seaborn(df, accuracy_col='Chemical_accuracy', success_col='Chemical_success_rate')


In [None]:
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np

def plot_accuracy_success_rate_subplot(df, accuracy_cols, success_cols, titles, axes):
    """
    Plot grouped accuracy / success-rate bar charts, one per supplied Axes,
    and add a single shared legend to the figure that owns those Axes.

    Parameters:
    - df: DataFrame with a 'mapper' column and the metric columns; not modified
    - accuracy_cols: list of strings, accuracy column for each subplot
    - success_cols: list of strings, success-rate column for each subplot
    - titles: list of strings, titles for each subplot
    - axes: sequence of matplotlib Axes; axes[i] shows accuracy_cols[i],
      success_cols[i] and titles[i]
    """
    from matplotlib.ticker import PercentFormatter

    # Convert mapper names to strings once, on a copy, instead of mutating the
    # caller's DataFrame inside the loop.
    plot_df = df.assign(mapper=df['mapper'].astype(str))

    ax = None
    for idx, ax in enumerate(axes):
        acc_col, succ_col = accuracy_cols[idx], success_cols[idx]

        # Long format (mapper, Metric, Percentage) for seaborn's hue grouping.
        temp_df = plot_df[['mapper', acc_col, succ_col]].melt(
            id_vars=['mapper'], var_name='Metric', value_name='Percentage')
        temp_df['Metric'] = temp_df['Metric'].map({acc_col: 'Accuracy', succ_col: 'Success Rate'})

        sns.barplot(x='mapper', y='Percentage', hue='Metric', data=temp_df, palette='coolwarm', ax=ax)

        # Value labels just above each bar with a positive height.
        for p in ax.patches:
            height = p.get_height()
            if height > 0:
                ax.annotate(format(height, '.1f'),
                            (p.get_x() + p.get_width() / 2., height),
                            ha='center', va='center',
                            xytext=(0, 9),
                            textcoords='offset points',
                            fontsize=18)

        ax.set_title(titles[idx], fontsize=28, fontweight='bold')
        ax.set_xlabel(None)
        ax.set_ylabel(r'Percentage (\%)', fontsize=24, fontweight='semibold')

        # Style the existing category labels in place; re-setting them with
        # set_xticklabels() without set_xticks() triggers a FixedFormatter warning.
        plt.setp(ax.get_xticklabels(), rotation=45, fontsize=18, fontweight='bold')
        # PercentFormatter keeps the y labels in sync with the final tick
        # positions and escapes '%' itself when text.usetex is on (a bare '%'
        # would start a LaTeX comment).
        ax.yaxis.set_major_formatter(PercentFormatter(xmax=100, decimals=0))
        plt.setp(ax.get_yticklabels(), fontsize=18, fontweight='bold')

        ax.grid(True, which='both', linestyle='--', linewidth=0.5, color='gray', alpha=0.5)

        # Hide the per-axes legend; one figure-level legend is added below.
        ax.legend([], [], frameon=False)

    if ax is None:
        # No axes supplied: nothing was drawn, so there is no legend to add.
        return

    handles, _ = ax.get_legend_handles_labels()
    # Use the figure that owns the axes rather than a global `fig` variable.
    fig = ax.get_figure()
    fig.legend(handles, ['Accuracy', 'Success Rate'], loc='center left', bbox_to_anchor=(0.45, 0.05), fontsize=18)

    fig.tight_layout()

# df = final_df.copy()

# df['mapper'] = [
#     r'$\texttt{RXNMapper}$', 
#     r'$\texttt{GraphormerMapper}$', 
#     r'$\texttt{LocalMapper}$', 
#     r'$\texttt{RDT}$', 
#     r'$\textit{Ensemble_1}$', 
#     r'$\textit{Ensemble_2}$'
# ]

# fig, axes = plt.subplots(1, 2, figsize=(18, 8))
# plot_accuracy_success_rate_subplot(df, ['Chemical_accuracy', 'Biochemical_accuracy'], ['Chemical_success_rate', 'Biochemical_success_rate'], 
#                                    [r'$\textit{Chemical dataset}$', r'$\textit{Biochemical dataset}$'], axes)


In [None]:
# Side-by-side panels: chemical datasets on the left, biochemical on the right.
fig, axes = plt.subplots(1, 2, figsize=(18, 8))
plot_accuracy_success_rate_subplot(
    df,
    accuracy_cols=['Chemical_accuracy', 'Biochemical_accuracy'],
    success_cols=['Chemical_success_rate', 'Biochemical_success_rate'],
    titles=[r'$\textit{Chemical dataset}$', r'$\textit{Biochemical dataset}$'],
    axes=axes,
)

In [None]:
import pandas as pd
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt

def plot_accuracy_success_rate_subplot(df, accuracy_cols, success_cols, titles, axes):
    """
    Plot grouped accuracy / success-rate bar charts, one per supplied Axes,
    add a single shared legend to the figure that owns them, and show it.

    Parameters:
    - df: DataFrame with a 'mapper' column and the metric columns; not modified
    - accuracy_cols: list of strings, accuracy column for each subplot
    - success_cols: list of strings, success-rate column for each subplot
    - titles: list of strings, titles for each subplot
    - axes: sequence of matplotlib Axes; axes[i] shows accuracy_cols[i],
      success_cols[i] and titles[i]
    """
    from matplotlib.ticker import PercentFormatter

    # Convert mapper names to strings once, on a copy, instead of mutating the
    # caller's DataFrame inside the loop.
    plot_df = df.assign(mapper=df['mapper'].astype(str))

    ax = None
    for idx, ax in enumerate(axes):
        acc_col, succ_col = accuracy_cols[idx], success_cols[idx]

        # Long format (mapper, Metric, Percentage) for seaborn's hue grouping.
        temp_df = plot_df[['mapper', acc_col, succ_col]].melt(
            id_vars=['mapper'], var_name='Metric', value_name='Percentage')
        temp_df['Metric'] = temp_df['Metric'].map({acc_col: 'Accuracy', succ_col: 'Success Rate'})

        sns.barplot(x='mapper', y='Percentage', hue='Metric', data=temp_df, palette='coolwarm', ax=ax)

        # Value labels just above each bar with a positive height.
        for p in ax.patches:
            height = p.get_height()
            if height > 0:
                ax.annotate(format(height, '.1f'),
                            (p.get_x() + p.get_width() / 2., height),
                            ha='center', va='center',
                            xytext=(0, 9),
                            textcoords='offset points',
                            fontsize=18)

        ax.set_title(titles[idx], fontsize=28, fontweight='bold')
        ax.set_xlabel(None)
        ax.set_ylabel(r'Percentage (\%)', fontsize=24, fontweight='semibold')

        # Style the existing category labels in place; re-setting them with
        # set_xticklabels() without set_xticks() triggers a FixedFormatter warning.
        plt.setp(ax.get_xticklabels(), rotation=45, fontsize=18, fontweight='bold')
        # PercentFormatter keeps the y labels in sync with the final tick
        # positions and escapes '%' itself when text.usetex is on (a bare '%'
        # would start a LaTeX comment).
        ax.yaxis.set_major_formatter(PercentFormatter(xmax=100, decimals=0))
        plt.setp(ax.get_yticklabels(), fontsize=18, fontweight='bold')

        ax.grid(True, which='both', linestyle='--', linewidth=0.5, color='gray', alpha=0.5)

        # Hide the per-axes legend; one figure-level legend is added below.
        ax.legend([], [], frameon=False)

    if ax is None:
        # No axes supplied: nothing was drawn, so there is no legend to add.
        return

    handles, _ = ax.get_legend_handles_labels()
    # Use the figure that owns the axes rather than whatever plt.gcf() returns.
    fig = ax.get_figure()
    fig.legend(handles, ['Accuracy', 'Success Rate'], loc='center left', bbox_to_anchor=(0.45, 0.02), fontsize=20)

    fig.tight_layout()
    plt.show()

# accuracy_cols = ['Chemical_accuracy', 'Biochemical_accuracy']
# success_cols = ['Chemical_success_rate', 'Biochemical_success_rate']
# titles = [r'$\textit{Chemical dataset}$', r'$\textit{Biochemical dataset}$']

# fig, axes = plt.subplots(1, 2, figsize=(18, 8))

    
# df = final_df.copy()

# df['mapper'] = [
#     r'$\texttt{RXNMapper}$', 
#     r'$\texttt{GraphormerMapper}$', 
#     r'$\texttt{LocalMapper}$', 
#     r'$\texttt{RDT}$', 
#     r'$\textit{Ensemble_1}$', 
#     r'$\textit{Ensemble_2}$'
# ]


# plot_accuracy_success_rate_subplot(df, accuracy_cols, success_cols, titles, axes)



In [None]:
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec

import pathlib

# Combined benchmark figure: reaction times across the top, chemical and
# biochemical accuracy / success-rate panels below.
# NOTE(review): plot_reaction_times and df_time are not defined in this part of
# the notebook — confirm the cells defining them run before this one.

accuracy_cols = ['Chemical_accuracy', 'Biochemical_accuracy']
success_cols = ['Chemical_success_rate', 'Biochemical_success_rate']
titles = [r'$\textit{B. Chemical dataset}$', r'$\textit{C. Biochemical dataset}$']

# Fresh copy of the benchmark table with LaTeX display names; the list is
# positional and must follow the row order of final_df.
# NOTE(review): with text.usetex on, '_' inside \textit{...} is a LaTeX error in
# text mode; 'Ensemble\_1' may be required — confirm on render.
df = final_df.copy()
df['mapper'] = [
    r'$\texttt{RXNMapper}$',
    r'$\texttt{GraphormerMapper}$',
    r'$\texttt{LocalMapper}$',
    r'$\texttt{RDT}$',
    r'$\textit{Ensemble_1}$',
    r'$\textit{Ensemble_2}$'
]

fig = plt.figure(figsize=(18, 16))
# 2x2 grid whose top row is merged into one wide panel.
gs = gridspec.GridSpec(2, 2, figure=fig)

# Top row: reaction times across the full width.
ax0 = fig.add_subplot(gs[0, :])
plot_reaction_times(df_time, ax0)

# Bottom row: accuracy and success rate per dataset category.
ax1 = fig.add_subplot(gs[1, 0])
ax2 = fig.add_subplot(gs[1, 1])
plot_accuracy_success_rate_subplot(df, accuracy_cols, success_cols, titles, [ax1, ax2])

# Lay out this figure explicitly: the helper may call plt.show(), after which
# (e.g. with the inline backend) plt.* calls may no longer target `fig`.
fig.tight_layout()

# savefig does not create missing directories.
pathlib.Path('./fig').mkdir(parents=True, exist_ok=True)
fig.savefig('./fig/aam_time_data_benchmark.pdf', dpi = 600)
plt.show()
