# **1. CGRTool**

In [None]:
import sys

# Extend the import path two directories up so the `syntemp` package can be
# imported from this notebook's location (presumably the repository root —
# confirm against the project layout).
sys.path.append("../../")
from syntemp.SynUtils.utils import load_database, save_database
import pandas as pd

# Sampled USPTO reactions. The ground-truth step below reads the per-mapper
# SMILES columns ("RXNMapper", "GraphMapper", "LocalMapper") and the matching
# "<mapper>_correct" flag columns from this table.
uspto_3k = pd.read_csv("../../Data/AAM/aam_benchmark/USPTO_sampled.csv")

## 1.1. Generate unbiased ground truth

In [None]:
import pandas as pd
import re


def add_ground_truth_column(df):
    """Add per-mapper atom-map counts and a "Ground Truth" column to ``df``.

    For each of the mappers ``RXNMapper``, ``GraphMapper`` and ``LocalMapper``
    the frame must contain a column with the mapped reaction SMILES and a
    ``<mapper>_correct`` flag column.

    The function adds, in place:

    * ``<mapper>_count`` -- the number of distinct ``:<digits>`` atom-map
      labels in the mapper's SMILES, or 0 when the SMILES is missing or the
      mapper is not flagged as correct;
    * ``Ground Truth`` -- the SMILES of the mapper with the highest count
      (ties go to the first mapper in the order listed above), or ``None``
      when every count is zero.

    A missing flag (NaN/None/``pd.NA``) is treated as "not correct". Empty
    frames are supported and simply gain the new, empty columns.

    Args:
        df: DataFrame holding the mapper and ``<mapper>_correct`` columns.

    Returns:
        The same DataFrame object, with the new columns added.

    Raises:
        KeyError: if one of the required columns is absent.
    """
    mappers = ("RXNMapper", "GraphMapper", "LocalMapper")
    atom_map_pattern = re.compile(r":\d+")

    def count_atom_mappings(smiles_string):
        """Return the number of distinct atom-map labels (0 for a missing value)."""
        if pd.isna(smiles_string):
            return 0
        return len(set(atom_map_pattern.findall(smiles_string)))

    def is_flagged_correct(flag):
        # bool(float("nan")) is True, so a missing flag has to be rejected
        # explicitly instead of relying on plain truthiness.
        return not pd.isna(flag) and bool(flag)

    for mapper in mappers:
        counts = [
            count_atom_mappings(smiles) if is_flagged_correct(flag) else 0
            for smiles, flag in zip(df[mapper], df[mapper + "_correct"])
        ]
        # An explicit index and dtype keep the column well-formed for empty frames.
        df[mapper + "_count"] = pd.Series(counts, index=df.index, dtype="int64")

    smiles_columns = [df[mapper] for mapper in mappers]
    count_columns = [df[mapper + "_count"] for mapper in mappers]

    ground_truth = []
    for smiles_row, count_row in zip(zip(*smiles_columns), zip(*count_columns)):
        # max() keeps the first maximal element, so ties resolve in mapper order.
        best = max(range(len(mappers)), key=count_row.__getitem__)
        ground_truth.append(smiles_row[best] if count_row[best] > 0 else None)

    df["Ground Truth"] = pd.Series(ground_truth, index=df.index, dtype="object")
    return df

In [None]:
# add_ground_truth_column adds its columns to the frame it receives and
# returns that same frame, so `df` and `uspto_3k` refer to one object here.
df = add_ground_truth_column(uspto_3k)
# Persist the annotated table (helper "<mapper>_count" columns included),
# without writing the row index.
df.to_csv("../../Data/AAM/cgrtool_benchmark/USPTO_3K.csv", index=False)

## 1.2. Benchmark with CGRTool

In [None]:
# Load the two CGRTool benchmark tables; the first CSV column is used as the
# row index. Judging by the file names, "old"/"new" are two CGRTool runs on
# the same 3k reactions — TODO confirm what distinguishes them.
df_u1 = pd.read_csv(
    "../../Data/AAM/cgrtool_benchmark/uspto_3k_cgrtool_old.csv", index_col=0
)
df_u2 = pd.read_csv(
    "../../Data/AAM/cgrtool_benchmark/uspto_3k_cgrtool_new.csv", index_col=0
)

In [None]:
# Columns kept from both result tables: the reference mapping, the
# per-mapper correctness flags and the per-mapper CGRTool columns.
benchmark_columns = (
    "Ground Truth",
    "RXNMapper_correct",
    "GraphMapper_correct",
    "LocalMapper_correct",
    "CGRTool_rxnmapper",
    "CGRTool_graphmapper",
    "CGRTool_localmapper",
)

# Same projection for both tables; list() because pandas treats a tuple key
# as a single label rather than a list of columns.
df_u1 = df_u1[list(benchmark_columns)]
df_u2 = df_u2[list(benchmark_columns)]

In [None]:
# Percentage of rows whose "<mapper>_correct" flag is set, per mapper,
# rounded to two decimals. The middle row is labelled "Graphormer" here.
n_reactions = len(df_u1)
ground_data = pd.Series(
    {
        label: round(100 * df_u1[flag_column].sum() / n_reactions, 2)
        for label, flag_column in (
            ("RXNMapper", "RXNMapper_correct"),
            ("Graphormer", "GraphMapper_correct"),
            ("LocalMapper", "LocalMapper_correct"),
        )
    }
).to_frame("Ground Truth (%)")
ground_data

In [None]:
# Percentage of rows flagged in each "CGRTool_<mapper>" column of the first
# result table, rounded to two decimals (one row per mapper).
n_reactions_u1 = len(df_u1)
cgrtool_u1 = pd.Series(
    {
        label: round(100 * df_u1[cgr_column].sum() / n_reactions_u1, 2)
        for label, cgr_column in (
            ("RXNMapper", "CGRTool_rxnmapper"),
            ("Graphormer", "CGRTool_graphmapper"),
            ("LocalMapper", "CGRTool_localmapper"),
        )
    }
).to_frame("CGRTools 1 (%)")
cgrtool_u1

In [None]:
# Percentage of rows flagged in each "CGRTool_<mapper>" column of the second
# result table, rounded to two decimals (one row per mapper).
n_reactions_u2 = len(df_u2)
cgrtool_u2 = pd.Series(
    {
        label: round(100 * df_u2[cgr_column].sum() / n_reactions_u2, 2)
        for label, cgr_column in (
            ("RXNMapper", "CGRTool_rxnmapper"),
            ("Graphormer", "CGRTool_graphmapper"),
            ("LocalMapper", "CGRTool_localmapper"),
        )
    }
).to_frame("CGRTools 2 (%)")
cgrtool_u2

In [None]:
# Put the three percentage tables side by side, then relabel the
# "Graphormer" row as "GraphMapper".
cgr_data = pd.concat([ground_data, cgrtool_u1, cgrtool_u2], axis=1).rename(
    index={"Graphormer": "GraphMapper"}
)

In [None]:
from syntemp.SynAAM.aam_validator import AAMValidator


def _validate_against_ground_truth(frame):
    """Run AAMValidator on ``frame`` with the settings shared by both tables."""
    # NOTE(review): the reference column is spelled "GroundTruth" here, while
    # the column selection earlier in the notebook uses "Ground Truth" —
    # confirm the CSV files provide the column under this name.
    return AAMValidator.validate_smiles(
        data=frame,
        ground_truth_col="GroundTruth",
        mapped_cols=["RXNMapper", "GraphMapper", "LocalMapper"],
        check_method="RC",
        ignore_aromaticity=False,
        n_jobs=4,
        verbose=0,
        ensemble=False,
        strategies=[["rxn_mapper", "graphormer", "local_mapper"]],
        ignore_tautomers=False,
    )


# Reload the complete tables: df_u1/df_u2 were earlier narrowed to a column
# subset that does not include the mapped SMILES columns validated here.
df_u1 = pd.read_csv(
    "../../Data/AAM/cgrtool_benchmark/uspto_3k_cgrtool_old.csv", index_col=0
)
df_u2 = pd.read_csv(
    "../../Data/AAM/cgrtool_benchmark/uspto_3k_cgrtool_new.csv", index_col=0
)

syntemp_u1 = _validate_against_ground_truth(df_u1)
syntemp_u2 = _validate_against_ground_truth(df_u2)

In [None]:
# Turn the first element of each validator output into a table indexed by
# its "mapper" column (the column itself is kept), with "accuracy" renamed
# so the two runs can sit side by side.
temp_u1 = (
    pd.DataFrame(syntemp_u1[0])
    .rename(columns={"accuracy": "syntemp_u1"})
    .set_index("mapper", drop=False)
)
temp_u2 = (
    pd.DataFrame(syntemp_u2[0])
    .rename(columns={"accuracy": "syntemp_u2"})
    .set_index("mapper", drop=False)
)

# Append both accuracy columns to the CGRTool summary, aligned on the mapper name.
accuracy_columns = [temp_u1["syntemp_u1"], temp_u2["syntemp_u2"]]
benchmark_df = pd.concat([cgr_data, *accuracy_columns], axis=1)

benchmark_df

## 1.3. Analyze difference from Ground Truth

In [None]:
# Tabulate the first element of the validator output for the second table;
# its "results" column is compared with the stored correctness flags below.
data_check = pd.DataFrame(syntemp_u2[0])

In [None]:
# Positions where the stored "<mapper>_correct" flag disagrees with the
# validator's per-reaction result. Rows 0 and 1 of data_check presumably
# correspond to RXNMapper and GraphMapper (the order of mapped_cols) —
# confirm against the validator output.
rxn_results = data_check["results"][0]
list_diff_rxn = [
    position
    for position, flag in enumerate(df_u2["RXNMapper_correct"])
    if flag != rxn_results[position]
]

graph_results = data_check["results"][1]
list_diff_graph = [
    position
    for position, flag in enumerate(df_u2["GraphMapper_correct"])
    if flag != graph_results[position]
]
print("Differences in RXNMapper:", list_diff_rxn)
print("Differences in GraphMapper:", list_diff_graph)

In [None]:
from syntemp.SynVis.chemical_reaction_visualizer import ChemicalReactionVisualizer

vis = ChemicalReactionVisualizer()
i = 192
# Draw the reference mapping first and RXNMapper's mapping second, so the
# two can be compared by eye, then print the stored correctness flag.
for column in ("GroundTruth", "RXNMapper"):
    display(
        vis.visualize_reaction(
            df_u2.loc[i, column], img_size=(1000, 300), show_atom_map=True
        )
    )
print(df_u2.loc[i, "RXNMapper_correct"])

In [None]:
i = 2157
# Same side-by-side inspection for another reaction: reference mapping,
# then RXNMapper's mapping, then the stored correctness flag.
for column in ("GroundTruth", "RXNMapper"):
    display(
        vis.visualize_reaction(
            df_u2.loc[i, column], img_size=(1000, 300), show_atom_map=True
        )
    )
print(df_u2.loc[i, "RXNMapper_correct"])