In [None]:
# Make the project root (the parent of the current working directory) importable,
# so that the `src` package imported in the next cell can be found.
import sys
from pathlib import Path
sys.path.append(
    str(Path.cwd().parent.resolve())
)

In [None]:
import argparse
import subprocess
from multiprocessing import cpu_count
from typing import List, Set
import operator

import pandas as pd
import numpy as np

from rdkit import RDLogger
from rdkit.Chem import rdChemReactions
from rdkit.DataStructs.cDataStructs import ConvertToNumpyArray

from openTSNE import TSNE

import matplotlib.pyplot as plt
import matplotlib
import seaborn as sns

from tqdm import tqdm

import src.utils as ut
from src import root
from src.prediction import MTReagentPredictor, MTProductPredictor
from src.preprocessing import HeuristicRoleClassifier, SOLVENTS

In [None]:
from IPython.display import display

In [None]:
# Register tqdm's progress-bar methods (e.g. `progress_apply`) on pandas objects.
tqdm.pandas()

In [None]:
# Silence RDKit's log output on all `rdApp.*` channels (e.g. SMILES parsing warnings).
RDLogger.DisableLog('rdApp.*')

In [None]:
def remove_redundant_separators(smi: str) -> str:
    """
    Drops empty fragments from a dot-separated SMILES string.

    Leading, trailing and repeated '.' separators are collapsed,
    e.g. '.CCO..O.' becomes 'CCO.O'.
    :param smi: A SMILES string encoding several molecules
    :return: The same molecules joined by single '.' separators
    """
    fragments = smi.split('.')
    # filter(None, ...) keeps only non-empty fragments
    return ".".join(filter(None, fragments))


def standardize_pd_pph3(smi: str) -> str:
    """
    Rewrites Pd(PPh3)4 written as one bonded species into five separate
    species: one [Pd] atom and four PPh3 molecules.
    :param smi: A SMILES string that may contain the bonded Pd(PPh3)4 form
    :return: The SMILES string with every occurrence of the bonded form replaced;
    other strings are returned unchanged
    """
    pd_pph3_bonded = "c1ccc([PH](c2ccccc2)(c2ccccc2)[Pd]([PH](c2ccccc2)(c2ccccc2)c2ccccc2)([PH](c2ccccc2)(c2ccccc2)c2ccccc2)[PH](c2ccccc2)(c2ccccc2)c2ccccc2)cc1"
    pd_pph3_separated = "[Pd].c1ccc(P(c2ccccc2)c2ccccc2)cc1.c1ccc(P(c2ccccc2)c2ccccc2)cc1.c1ccc(P(c2ccccc2)c2ccccc2)cc1.c1ccc(P(c2ccccc2)c2ccccc2)cc1"
    if pd_pph3_bonded not in smi:
        return smi
    return smi.replace(pd_pph3_bonded, pd_pph3_separated)


def match_accuracy(tgt_and_pred: 'pd.Series', mode: str) -> List[int]:
    """
    Cumulative top-N match between a target reagent set and ranked predictions.

    Meant to be applied row-wise to a DataFrame of shape (M, N+1): M reactions,
    the ground truth in the first column and N predictions ordered by priority after it.
    Molecules are compared as sets, so their order inside a SMILES string is ignored.
    :param tgt_and_pred: Target SMILES followed by the N predicted SMILES strings.
    :param mode: 'exact' - a prediction matches when its molecule set equals the target set;
    'partial' - a prediction matches when it shares at least one molecule with the target.
    :return: N zeros/ones; element k is 1 if any of the first k+1 predictions matches.
    :raises ValueError: If ``mode`` is neither 'exact' nor 'partial'.
    """

    def _to_set(smi: str) -> Set[str]:
        # Empty fragments come from leading/trailing/repeated dots
        return {mol for mol in smi.split(".") if mol}

    if mode == 'exact':
        def _is_match(t_set: Set[str], p_set: Set[str]) -> bool:
            return t_set == p_set
    elif mode == 'partial':
        def _is_match(t_set: Set[str], p_set: Set[str]) -> bool:
            return not t_set.isdisjoint(p_set)
    else:
        raise ValueError(f"The only allowed values for the 'mode' argument are 'exact' and 'partial', not {mode}")

    target, *predictions = tgt_and_pred
    target_set = _to_set(target)

    hits: List[int] = []
    hit_so_far = False
    for prediction in predictions:
        prediction_set = _to_set(prediction)
        # Once a higher-priority prediction matched, every larger N counts as a hit
        hit_so_far = hit_so_far or _is_match(target_set, prediction_set)
        hits.append(int(hit_so_far))
    return hits


def recall(tgt_and_pred: 'pd.Series', empty_target_value: float = float('nan')) -> List[float]:
    """
    Calculates the cumulative top-N recall of reagent predictions.
    :param tgt_and_pred: The series, where the first row contains the target string,
    and all the other contain predictions ordered by priority.
    The function is meant to be applied to a DataFrame of shape (M, N+1), where M is the number of reactions
    and N is the number of predictions.
    :param empty_target_value: Value reported for every N when the target contains no molecules,
    in which case recall is undefined. The default NaN is skipped by pandas' ``mean()``.
    :return: List of length N; element k is the fraction of target molecules that appear
    in at least one of the first k+1 predictions.
    """
    tgt_mols, *pred_mols = tgt_and_pred
    tgt_set = {s for s in tgt_mols.split(".") if s}
    if not tgt_set:
        # Nothing to recall: avoid dividing by zero below
        return [empty_target_value] * len(pred_mols)

    n_target = len(tgt_set)
    not_found = set(tgt_set)
    scores = []
    for p in pred_mols:
        # Molecules found by any prediction so far stay found for larger N
        not_found -= {s for s in p.split(".") if s}
        scores.append(1 - len(not_found) / n_target)
    return scores


def plot_confidence_distributions(confs: 'pd.Series', match: 'pd.Series'):
    """
    Draws two panels that share the y axis: a violin plot of all confidence
    scores (left) and box plots of the same scores grouped by ``match`` (right).
    :param confs: Confidence score for each reaction
    :param match: Grouping for the right panel, one value per reaction (e.g. 0/1 hit flags)
    :return: The figure and the array of its two axes
    """
    fig, ax = plt.subplots(ncols=2, nrows=1, sharey=True, figsize=(10, 10))
    left_ax, right_ax = ax

    sns.violinplot(y=confs, scale="width", ax=left_ax)
    sns.boxplot(x=match, y=confs, ax=right_ax)
    for panel in (left_ax, right_ax):
        panel.grid(axis='y')

    return fig, ax

In [None]:
# Fixed plot colour (hex) for each of the ten general reaction classes of USPTO 50K.
# Lookups elsewhere in the notebook fall back to another colour for unlisted classes.
CLASS_COLORS = {
    "Acylation and related processes": "#C0392B",
    "Heteroatom alkylation and arylation": "#E67E22",
    "C-C bond formation": "#27AE60",
    "Heterocycle formation": "#F1C40F",
    "Protections": "#1186F3",
    "Deprotections": "#707B7C",
    "Reductions": "#40C4DE",
    "Oxidations": "#DC40DE",
    "Functional group interconversion (FGI)": "#CCAF1C",
    "Functional group addition (FGA)": "#8E44AD"
}


class ReactionFPS:
    """
    Thin wrapper around RDKit's reaction fingerprint functions that returns
    the fingerprint as a numpy array.
    """

    # Name accepted by ``calculate(fp_type=...)`` -> RDKit fingerprint type enum member
    FP_TYPES = {
        "AtomPairFP": rdChemReactions.FingerprintType.AtomPairFP,
        "MorganFP": rdChemReactions.FingerprintType.MorganFP,
        "PatternFP": rdChemReactions.FingerprintType.PatternFP,
        "RDKitFP": rdChemReactions.FingerprintType.RDKitFP,
        "TopologicalTorsion": rdChemReactions.FingerprintType.TopologicalTorsion,
    }

    def calculate(self,
                  rx_smi: str,
                  fp_method: str,
                  n_bits: int,
                  fp_type: str,
                  include_agents: bool,
                  agent_weight: int,
                  non_agent_weight: int,
                  bit_ratio_agents: float = 0.2
                  ) -> 'np.array':
        """
        Builds a reaction fingerprint for ``rx_smi`` and converts it to a numpy array.
        Parameter reference: https://www.rdkit.org/docs/cppapi/structRDKit_1_1ReactionFingerprintParams.html
        :param rx_smi: Reaction SMILES to fingerprint.
        :param fp_method: Either 'difference' or 'structural'.
        :param n_bits: Length of the fingerprint vector.
        :param fp_type: Key of FP_TYPES selecting the fingerprint algorithm, e.g. 'AtomPairFP'.
            Only AtomPairFP, TopologicalTorsion and MorganFP are supported by difference fingerprints.
        :param include_agents: Whether the agents of the reaction contribute to the fingerprint.
        :param agent_weight: Difference fingerprints only: weight of the agents.
        :param non_agent_weight: Difference fingerprints only: weight of reactants and products.
        :param bit_ratio_agents: Structural fingerprints only: share of bits reserved for the agents.
        :return: The fingerprint as a numpy array.
        :raises KeyError: If ``fp_type`` is not a key of FP_TYPES.
        :raises ValueError: If ``fp_method`` is neither 'difference' nor 'structural'.
        """
        fp_params = rdChemReactions.ReactionFingerprintParams()
        fp_params.fpSize = n_bits
        fp_params.includeAgents = include_agents
        fp_params.fpType = self.FP_TYPES[fp_type]

        reaction = rdChemReactions.ReactionFromSmarts(rx_smi, useSmiles=True)

        if fp_method == "difference":
            # Difference fingerprints are not binary
            fp_params.agentWeight = agent_weight
            fp_params.nonAgentWeight = non_agent_weight
            fingerprint = rdChemReactions.CreateDifferenceFingerprintForReaction(reaction, fp_params)
        elif fp_method == "structural":
            # Structural fingerprints are binary
            fp_params.bitRatioAgents = bit_ratio_agents
            fingerprint = rdChemReactions.CreateStructuralFingerprintForReaction(reaction, fp_params)
        else:
            raise ValueError("Invalid fp_method. Allowed are 'difference' and 'structural'")

        # Placeholder destination; ConvertToNumpyArray fills it with the fingerprint
        as_array = np.zeros((1,))
        ConvertToNumpyArray(fingerprint, as_array)
        return as_array


def diff_fp(smi: str) -> 'np.array':
    """
    Difference Morgan fingerprint (2048 bits, agents included, unit weights for
    agents and non-agents) of a reaction SMILES; used as the features of the TSNE maps.
    """
    settings = dict(
        fp_method="difference",
        n_bits=2048,
        fp_type="MorganFP",
        include_agents=True,
        agent_weight=1,
        non_agent_weight=1,
    )
    return ReactionFPS().calculate(smi, **settings)


def plot_2d_distribution(x,
                         y,
                         save_path=None,
                         ax=None,
                         title=None,
                         draw_legend=True,
                         colors=None,
                         legend_kwargs=None,
                         label_order=None,
                         **kwargs) -> None:
    """
    Scatter-plots 2D embeddings (e.g. TSNE), one colour per label.
    :param x: Array of points; its first two columns are used as coordinates.
    :param y: Label of each point (same length as ``x``).
    :param save_path: If given, the figure that contains ``ax`` is saved there
        (dpi=400, transparent background); otherwise ``plt.show()`` is called.
    :param ax: Axes to draw on; a new 20x20 figure is created if omitted.
    :param title: Optional axes title.
    :param draw_legend: Whether to draw a legend with one square marker per label.
    :param colors: Mapping label -> colour. Defaults to matplotlib's colour cycle.
    :param legend_kwargs: Overrides for the keyword arguments of ``ax.legend``.
    :param label_order: Order of the labels for colour assignment and the legend;
        must contain every label present in ``y``.
    :param kwargs: Only ``alpha`` (default 0.6) and ``s`` (marker size, default 1) are used.
    :raises ValueError: If ``label_order`` misses a label of ``y``, or ``colors``
        has no colour for a label of ``y``.
    """
    if ax is None:
        _, ax = plt.subplots(figsize=(20, 20))

    if title is not None:
        ax.set_title(title)

    plot_params = {"alpha": kwargs.get("alpha", 0.6), "s": kwargs.get("s", 1)}

    present_labels = np.unique(y)
    if label_order is not None:
        unknown = present_labels[~np.isin(present_labels, label_order)]
        if unknown.size:
            raise ValueError(f"label_order is missing labels present in y: {list(unknown)}")
        classes = [l for l in label_order if l in present_labels]
    else:
        classes = present_labels
    if colors is None:
        default_colors = matplotlib.rcParams["axes.prop_cycle"]
        # Calling a Cycler yields an endless iterator over its entries
        colors = {k: v["color"] for k, v in zip(classes, default_colors())}

    # A missing colour would otherwise reach matplotlib as None and fail obscurely
    missing_colors = [label for label in present_labels if label not in colors]
    if missing_colors:
        raise ValueError(f"No colour given for labels: {missing_colors}")
    point_colors = [colors[label] for label in y]

    ax.scatter(x[:, 0], x[:, 1], c=point_colors, rasterized=True, **plot_params)

    # Hide ticks and axis
    ax.set_xticks([])
    ax.set_yticks([])
    ax.axis("off")

    if draw_legend:
        legend_handles = [
            matplotlib.lines.Line2D(
                [],
                [],
                marker="s",
                color="w",
                markerfacecolor=colors[yi],
                ms=10,
                alpha=1,
                linewidth=0,
                label=yi,
                markeredgecolor="k",
            )
            for yi in classes
        ]
        legend_kwargs_ = dict(loc="best", bbox_to_anchor=(1, 0.5), frameon=False)
        if legend_kwargs is not None:
            legend_kwargs_.update(legend_kwargs)
        ax.legend(handles=legend_handles, **legend_kwargs_)

    if save_path is not None:
        # Save the figure owning ``ax``: plt.savefig would save whichever figure is current,
        # which differs from ax's figure when the caller passes an axes of another figure.
        ax.figure.savefig(save_path, dpi=400, transparent=True)
    else:
        plt.show()


def plot_rxn_class_proportions(classes_in: 'pd.Series',
                               classes_out: 'pd.Series',
                               save_path=None,
                               label_in: str = "USPTO 50K",
                               label_out: str = "Reaxys Test") -> None:
    """
    Plot the graph comparing the proportions of reaction classes between the
    in-distribution data and the out-of-distribution data.
    Bars follow the class order of ``classes_in``.
    :param classes_in: Proportions of reaction classes for the in-distribution data, indexed by class name.
    :param classes_out: Proportions of reaction classes for the out-of-distribution data, indexed by class name.
        Classes of ``classes_in`` missing here are drawn with height 0; extra classes are not drawn.
    :param save_path: Path to save the image. If None, the figure is shown instead.
    :param label_in: Legend label of the in-distribution bars.
    :param label_out: Legend label of the out-of-distribution bars.
    :return: None
    """
    positions = np.arange(len(classes_in))
    width = 0.3
    # Align to the in-distribution class order; reindex (unlike []) tolerates classes
    # that do not occur in the out-of-distribution data.
    classes_out_aligned = classes_out.reindex(classes_in.index, fill_value=0)
    # Tick label: the abbreviation in parentheses if present ("... (FGI)" -> "FGI"), else the full name
    tick_labels = [name.split("(")[-1].strip(')') for name in classes_in.index]

    plt.figure(figsize=(15, 8))
    plt.bar(positions, classes_in, width, color='#abebc6', alpha=1, label=label_in)
    plt.bar(positions + width, classes_out_aligned, width, color='#cd6155', alpha=1, label=label_out)
    plt.legend()
    plt.xticks(positions + width / 2, tick_labels, rotation=35, fontsize=16)
    plt.title("Proportion of reaction classes")
    plt.grid(axis='y')

    plt.tight_layout()
    if save_path is not None:
        plt.savefig(save_path, dpi=300)
    else:
        plt.show()

In [None]:
# Seed NumPy's global random generator for reproducibility.
np.random.seed(123456)

### Arguments

In [None]:
# Data, vocabulary and model locations, relative to the notebook's working directory
USPTO50_PATH = "../uspto50k.csv"

TEST_PATH = "../data/raw/reaxys_test_full_nice_classes.csv"
VOCAB = "../data/vocabs/reag_no_MIT_src_vocab.json"
MODEL = "../experiments/trained_models/reag_no_MIT_model_step_80000.pt"
# Alternative checkpoints:
# MODEL = "../experiments/trained_models/reag_no_MIT_model_step_80000_shared_voc.pt"
# MODEL = "../experiments/trained_models/simple_tok_with_rep_aug_model_step_520000.pt"
# Base name for the tokenized-input and prediction files of this run
RUN_NAME = "reaxys_rgs_80k_test_full_poster"
# Column of the test CSV with the original source reactions
SRC_COLUMN = "ReactionRgsSource"
# GPU index passed to the reagent predictor
GPU = 0

# Loading the test set and USPTO 50K for visual comparison

In [None]:
# Data which is in-distribution to the training set
data_in = pd.read_csv("../uspto50k.csv")

# Test data
data_out = pd.read_csv(TEST_PATH).fillna('')

In [None]:
# (rows, columns) of the test data
data_out.shape

In [None]:
# Relative frequency of each general reaction class in USPTO 50K...
stat_in = data_in["General class"].value_counts(normalize=True)

# ...and in the test data
stat_out = data_out["General class"].value_counts(normalize=True)

In [None]:
# Class proportions in USPTO 50K
stat_in

In [None]:
# Class proportions in the test data
stat_out

### Comparing reaction class proportions

In [None]:
# sns.set_theme(style='white', font_scale=1.5)
# Bar chart of class proportions, saved to proportion.png (plt.show is not called when save_path is given)
plot_rxn_class_proportions(stat_in, stat_out, save_path="proportion.png")

### Comparing the TSNE embeddings of points in USPTO 50K and the test set

In [None]:
# openTSNE t-SNE with Euclidean distance, all CPU cores and a fixed seed
tsne = TSNE(
        perplexity=30,
        metric="euclidean",
        n_jobs=cpu_count(),
        random_state=123456,
        verbose=True
    )

In [None]:
# Difference fingerprints of the USPTO 50K reactions (one row per reaction)
# and the t-SNE map fitted on them
fp_in = np.vstack(
        data_in["Reaction"].apply(diff_fp)
    )
embedding_in = tsne.fit(fp_in)

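The TSNE map is fitted on USPTO 50K only. The test reactions are then placed into this fixed map with `transform`, which optimizes the positions of the new points against the existing USPTO 50K embedding without moving the USPTO 50K points.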

In [None]:
# Fingerprints of the full test reactions, placed into the existing USPTO 50K map
fp_out = np.vstack(
        data_out["FullR"].apply(diff_fp)
    )
embedding_out = embedding_in.transform(fp_out)

In [None]:
# USPTO 50K map coloured by reaction class, saved to uspto_map.png
plot_2d_distribution(embedding_in,
                     data_in["General class"],
                     colors=CLASS_COLORS,
                     draw_legend=False,
                     save_path="uspto_map.png",
                     s=15, 
                     alpha=0.5)

In [None]:
# Test-set map coloured by reaction class; classes without a fixed colour are drawn in #00FF00
plot_2d_distribution(embedding_out,
                     data_out["General class"],
                     save_path="reaxys_map.png",
                     draw_legend=False,
                     s=15,
                     alpha=0.5,
                     colors={i: CLASS_COLORS.get(i, "#00FF00") for i in data_out["General class"].unique()})

In [None]:
# Both maps overlaid and coloured by data source (matplotlib's default colour cycle)
plot_2d_distribution(
        np.vstack((embedding_in, embedding_out)),
        np.array(["USPTO 50K"] * embedding_in.shape[0] + ["Reaxys Test"] * embedding_out.shape[0]),
        s=10,
        alpha=0.4,
        draw_legend=False,
        save_path="reaxys_and_uspto_map.png"
    )

In [None]:
sns.set_theme(style='white', font_scale=1.9)
# Classes occurring in both datasets; computed once instead of on every loop iteration
shared_classes = set(data_in["General class"].unique()) & set(data_out["General class"].unique())
for c in CLASS_COLORS:
    if c not in shared_classes:
        continue
    # TSNE points of this class in USPTO 50K (_u) and in the test set (_r)
    _u = embedding_in[data_in["General class"] == c]
    _r = embedding_out[data_out["General class"] == c]
    both = np.vstack((_u, _r))
    source_labels = np.array(["USPTO 50K"] * _u.shape[0] + ["Reaxys Test"] * _r.shape[0])
    h = sns.jointplot(
        x=both[:, 0],
        y=both[:, 1],
        height=10,
        kind='scatter',
        alpha=0.7,
        palette=["#000000", CLASS_COLORS[c]],
        hue=source_labels,
    )
    h.set_axis_labels("TSNE 1", "TSNE 2")
    plt.suptitle(c)
    plt.tight_layout()
    plt.savefig(f"tsne_{c}.png", dpi=300)
    plt.show()

# Loading the test set for reagent prediction

In [None]:
# If True, the inference cell below runs the model; if False it is skipped and the
# predictions are only loaded via `load_predictions`.
PREDICT = False

In [None]:
# Test set for reagent prediction (same file as `data_out`); empty cells become ''
test = pd.read_csv(TEST_PATH).fillna('')

In [None]:
# First rows of the test set
test.head()

In [None]:
# (rows, columns) of the test set
test.shape

In [None]:
# Class proportions in the test set
test["General class"].value_counts(normalize=True)

In [None]:
# Initial sources and target reagents taken straight from the CSV columns;
# both are recomputed from the role-reassigned full reactions further down.
src_for_rgs = test[SRC_COLUMN]
tgt_for_rgs = test["Reagents"]
print(SRC_COLUMN)

In [None]:
# Source reactions as read from the CSV
src_for_rgs

# Preprocessing the Reaxys input for reagent prediction: reassigning roles and sorting molecules

In [None]:
def extract_src(smi: str) -> str:
    """
    Removes the middle (agents) field of a 'reactants>agents>products' reaction SMILES.
    :param smi: Reaction SMILES with exactly two '>' separators
    :return: 'reactants>>products'
    :raises ValueError: If ``smi`` does not split into exactly three fields
    """
    reactants, _agents, products = smi.split(">")
    return f"{reactants}>>{products}"


def extract_tgt(smi: str) -> str:
    """
    Returns the middle (agents) field of a 'reactants>agents>products' reaction SMILES.
    :param smi: Reaction SMILES with exactly two '>' separators
    :return: The agents field (possibly empty)
    :raises ValueError: If ``smi`` does not split into exactly three fields
    """
    _reactants, agents, _products = smi.split(">")
    return agents

In [None]:
# Reassigning reactants-reagents
reassigned_fullr = ut.parallelize_on_rows(test["FullR"],
                                          ut.assign_reaction_roles_schneider,
                                          num_of_processes=cpu_count(),
                                          use_tqdm=True)
fullrs = pd.concat((test["FullR"], reassigned_fullr), axis=1)
fullrs.columns = ['before_reass', 'after_reass']
fullrs["after_reass"] = fullrs["after_reass"].apply(lambda x: np.nan if (x.startswith(">") or ">>" in x) else x)
fullrs["after_reass"].fillna(fullrs["before_reass"], inplace=True)
print("Full reaction did not change roles:")
display((fullrs["after_reass"].apply(ut.order_molecules) == fullrs["before_reass"].apply(ut.order_molecules)).value_counts())

src_for_rgs = fullrs["after_reass"].apply(extract_src)
tgt_for_rgs = fullrs["after_reass"].apply(extract_tgt)

In [None]:
# Sorting molecules in the source
src_for_rgs = src_for_rgs.apply(ut.order_molecules)

In [None]:
# Role-reassigned, sorted source reactions
src_for_rgs

In [None]:
# How many processed sources coincide with the original SRC_COLUMN strings
(src_for_rgs == test[SRC_COLUMN]).value_counts()

In [None]:
# How many reactions have no target reagents after reassignment
(tgt_for_rgs == '').value_counts()

In [None]:
# Mark this run as using ordered sources and reassigned roles. The guard keeps the
# cell idempotent: re-running it must not stack the prefix again.
if not RUN_NAME.startswith("ord_reass_"):
    RUN_NAME = "ord_reass_" + RUN_NAME

# Reagent prediction model

The transformer predicts the reagents for a reaction in the form of a SMILES string. We trained our models to predict the reagents for a broad range of reactions without restrictions: neither the number of reagents nor their particular roles are predetermined. The model yields the output sequentially, token by token, treating the choice of each token as a multiclass classification problem and conditioning that choice on the input sequence and the previously decoded tokens. Because the molecules in the training target sequences are arranged by their roles in a reaction, the model is trained to first predict metal catalysts (if they are necessary) conditioned on the input sequence, then redox agents conditioned on the input sequence and the predicted catalysts, then acids and bases, then uncategorized molecules and ions such as alkali metal cations, and finally solvents conditioned on everything predicted before them. Solvents come last because they are the most interchangeable and require the most context to predict. A similar reagent role ordering was also used by Gao et al. \cite{Gao2018}.

### Prediction on Reaxys with source ordering and role reassignment

In [None]:
# Reagent predictor; the tokenized-source and prediction file names are derived from RUN_NAME.
# n_best=5 matches the five prediction columns evaluated below.
reag_predictor = MTReagentPredictor(
        vocabulary_path=VOCAB,
        model_path=MODEL,
        tokenized_path=str((root / "data" / "test" / ("src_" + RUN_NAME)).with_suffix(".txt")),
        output_path=str((root / "experiments" / "results" / RUN_NAME).with_suffix(".txt")),
        beam_size=5,
        n_best=5,
        batch_size=32,
        gpu=GPU
    )

In [None]:
%%time
# Inference only runs when PREDICT is True; otherwise this cell does nothing.
if PREDICT:
    reag_predictor.predict(src_for_rgs)

### Loading predictions

In [None]:
# Load the predictions and their confidence scores
# (presumably read from the output_path given above — verify in MTReagentPredictor)
reag_predictor.load_predictions()
pred_rgs = reag_predictor.predictions
pred_rgs_conf = reag_predictor.pred_probs

### Cleaning up predictions a bit

In [None]:
# Normalize every prediction column: drop empty fragments, reassemble ions with
# ut.IonAssembler and split bonded Pd(PPh3)4 into [Pd] + four PPh3.
# NOTE(review): `pred_rgs` may be the same object as `reag_predictor.predictions`,
# in which case the predictor's copy is modified too — verify if it is reused.
for c in pred_rgs.columns:
    pred_rgs[c] = pred_rgs[c].apply(remove_redundant_separators)
    pred_rgs[c] = pred_rgs[c].apply(ut.IonAssembler.run)
    pred_rgs[c] = pred_rgs[c].apply(standardize_pd_pph3)

### Printing the number of predicted SMILES which are invalid

The model captures the structure of the SMILES language well: ***practically*** all reagent predictions are valid SMILES strings. 

In [None]:
# Number of empty prediction strings per column
# (presumably invalid SMILES are stored as '' — verify in MTReagentPredictor)
for c in pred_rgs.columns:
    print(c, (pred_rgs[c] == '').sum())

# Calculating accuracy

We analyze the model's performance from several different perspectives:  
  * Exact match accuracy:   
    The proportion of reactions for which the molecules of at least one of the top-N predicted SMILES strings exactly match the ground truth molecules.
  * Partial match accuracy:   
    The proportion of reactions for which at least one molecule of at least one of the top-N predictions is in the ground truth set. 
  * Recall: 
    The fraction of the ground truth molecules that appear in at least one of the top-N predictions. 

In [None]:
# One row per reaction: the ground-truth reagents followed by the N predictions in
# priority order — the layout expected by `match_accuracy` and `recall`.
rgs_and_predictions = pd.concat(
        (tgt_for_rgs,
         pred_rgs),
        axis=1
    )
rgs_and_predictions.columns = ["ground_truth"] + list(rgs_and_predictions.columns[1:])
rgs_and_predictions.head()

### Exact match accuracy

In [None]:
# Cumulative top-N exact match (0/1) for every reaction, then its mean over reactions
topn_exact_match_acc = rgs_and_predictions.apply(
        lambda x: match_accuracy(x, 'exact'),
        axis=1
    )
topn_exact_match_acc = pd.DataFrame(topn_exact_match_acc.to_list())
topn_exact_match_acc.columns = [f"top_{i + 1}_exact" for i in range(pred_rgs.shape[1])]

print("Top-N exact match accuracy")
for c in topn_exact_match_acc.columns:
    print(c, topn_exact_match_acc[c].mean())

### Partial match accuracy

In [None]:
# Cumulative top-N partial match (0/1) for every reaction, then its mean over reactions
topn_partial_match_acc = rgs_and_predictions.apply(
        lambda x: match_accuracy(x, 'partial'),
        axis=1
    )
topn_partial_match_acc = pd.DataFrame(topn_partial_match_acc.to_list())
topn_partial_match_acc.columns = [f"top_{i + 1}_partial" for i in range(pred_rgs.shape[1])]

print("Top-N partial match accuracy")
for c in topn_partial_match_acc.columns:
    print(c, topn_partial_match_acc[c].mean())

### Recall

In [None]:
# Cumulative top-N recall for every reaction; `recall` is applied row by row
topn_recall = pd.DataFrame(
    rgs_and_predictions.apply(recall, axis=1).to_list()
)
topn_recall.columns = [f"top_{k}_recall" for k in range(1, pred_rgs.shape[1] + 1)]

In [None]:
# Cumulative top-N recall per reaction
topn_recall

In [None]:
# Number of reactions (y) for each distinct top-5 recall value (x)
plt.plot(
    topn_recall["top_5_recall"].value_counts().sort_index(),
    '--o',
);

In [None]:
# Share of reactions whose ground-truth reagents are fully covered by the top-N predictions
for c in topn_recall.columns:
    print(c, (topn_recall[c] == 1).mean())

### Accuracy without solvents

In [None]:
# Ground truth and predictions, solvents included
rgs_and_predictions

In [None]:
def remove_solvents(smi: str) -> str:
    """
    Removes the molecules listed in SOLVENTS from a dot-separated sequence of molecules.
    :param smi: Molecules separated by dots
    :return: The remaining molecules, in their original order, joined by dots
    """
    kept = (mol for mol in smi.split('.') if mol not in SOLVENTS)
    return '.'.join(kept)

In [None]:
# Same table with every molecule listed in SOLVENTS removed from each column
rgs_and_predictions_no_solvents = pd.DataFrame()
for c in rgs_and_predictions.columns:
    rgs_and_predictions_no_solvents[c] = rgs_and_predictions[c].apply(remove_solvents)

In [None]:
# First rows with solvents removed
rgs_and_predictions_no_solvents.head()

In [None]:
# Cumulative top-N exact match per reaction, with solvents removed
topn_exact_match_acc_no_solvents = rgs_and_predictions_no_solvents.apply(
        lambda x: match_accuracy(x, 'exact'),
        axis=1
    )
topn_exact_match_acc_no_solvents = pd.DataFrame(topn_exact_match_acc_no_solvents.to_list())
topn_exact_match_acc_no_solvents.columns = [f"top_{i + 1}_exact" for i in range(pred_rgs.shape[1])]

In [None]:
# Mean exact match accuracy without solvents
print("Top-N exact match accuracy")
for c in topn_exact_match_acc_no_solvents.columns:
    print(c, topn_exact_match_acc_no_solvents[c].mean())

In [None]:
# Cumulative top-N partial match per reaction with solvents removed, and its mean
topn_partial_match_acc_no_solvents = rgs_and_predictions_no_solvents.apply(
        lambda x: match_accuracy(x, 'partial'),
        axis=1
    )
topn_partial_match_acc_no_solvents = pd.DataFrame(topn_partial_match_acc_no_solvents.to_list())
topn_partial_match_acc_no_solvents.columns = [f"top_{i + 1}_partial" for i in range(pred_rgs.shape[1])]
print("Top-N partial match accuracy")
for c in topn_partial_match_acc_no_solvents.columns:
    print(c, topn_partial_match_acc_no_solvents[c].mean())

# How many unique reactions have at least one reagent set predicted right?

In [None]:
# Number of unique reactions in the test set
src_for_rgs.nunique()

In [None]:
# How many duplicates does every unique reaction have
src_for_rgs.value_counts().value_counts().sort_index()

In [None]:
# Per-reaction 0/1 scores (exact, partial, full recall) next to the source reaction and
# its class, grouped by source so that duplicates of one reaction can be merged.
# NOTE: `unique_rxn_perf` is a GroupBy object after this cell; the next cell rebinds it
# to a DataFrame, so the two cells must be re-run together and in this order.
unique_rxn_perf = pd.concat(
    (
        src_for_rgs,
        topn_exact_match_acc,
        topn_partial_match_acc,
        (topn_recall == 1).astype(int),
        test["General class"]
    ), 
    axis=1
)
unique_rxn_perf.columns = ["src"] + list(unique_rxn_perf.columns[1:])
unique_rxn_perf = unique_rxn_perf.groupby(by="src")

In [None]:
# A unique reaction counts as a hit if any of its duplicates is a hit (max of the 0/1
# scores); its class is taken from the first duplicate.
# NOTE: re-running this cell alone fails — it needs the GroupBy built by the previous cell,
# and it rebinds `unique_rxn_perf` to a DataFrame.
unique_exact_matches = [unique_rxn_perf[c].max() for c in topn_exact_match_acc.columns]
unique_partial_matches = [unique_rxn_perf[c].max() for c in topn_partial_match_acc.columns]
unique_recalls = [unique_rxn_perf[c].max() for c in topn_recall.columns]
unique_rxn_types = [unique_rxn_perf["General class"].apply(lambda x: x.iloc[0])]
unique_rxn_perf = pd.concat(
    unique_exact_matches + unique_partial_matches + unique_recalls + unique_rxn_types,
    axis=1
).reset_index()

In [None]:
# Scores over unique reactions, in percent: exact match, partial match, full recall
for c in topn_exact_match_acc.columns:
    print(c, round(100 * unique_rxn_perf[c].mean(), 1))
print()  
for c in topn_partial_match_acc.columns:
    print(c, round(100 * unique_rxn_perf[c].mean(), 1))
print()
for c in topn_recall.columns:
    print(c, round(100 * unique_rxn_perf[c].mean(), 1))

# Confidence of predictions

In [None]:
# Left: distribution of the top-1 confidence scores;
# right: the same scores split by top-1 exact miss (0) / hit (1)
fig, ax = plot_confidence_distributions(
    pred_rgs_conf["p_reagents_1_conf"],
    topn_exact_match_acc["top_1_exact"]
)
# fig.suptitle("Distribution of the confidence scores of the model")
ax[0].set_title("Overall distribution of confidence scores", fontsize=16)
ax[0].set_xlabel("Proportion of reactions", fontsize=14)
ax[0].set_ylabel("Confidence of the first prediction", fontsize=16)
plt.sca(ax[0])
plt.yticks(fontsize=14)

ax[1].set_title("Confidence scores for reactions\n predicted correctly and not", fontsize=16)
ax[1].set_ylabel(None)
plt.sca(ax[1])
plt.xticks(ticks=range(2), labels=["Top-1 (exact) miss", "Top-1 (exact) hit"], fontsize=14)
plt.xlabel(None)

plt.tight_layout()
# plt.savefig("confidence.png", dpi=300)
plt.show()

# Performance across reaction classes

In [None]:
def plot_performance_across_classes(classes_col: 'pd.Series',
                                    exact_accuracies: 'pd.DataFrame',
                                    partial_accuracies: 'pd.DataFrame',
                                    width: float = 0.2,
                                    **kwargs):
    """
    Plots the performance of reagent prediction across reaction classes.

    Top panel: per-class mean of the top-1, top-3 and top-5 exact and partial match
    scores. They are drawn at the same positions in the order 5, 3, 1, so with
    cumulative top-N scores the lower-N bars overlay the higher-N ones.
    Bottom panel: share of each class.
    Classes are ordered from the most to the least frequent one.
    :param classes_col: The column with a reaction class name for each reaction
    :param exact_accuracies: The dataframe with exact match accuracies.
    Columns are thought to be ordered by priority in descending order;
    at least five columns are required (columns 1, 3 and 5 are plotted).
    :param partial_accuracies: The dataframe with partial match accuracies, same layout.
    :param width: The width of bars to plot
    :param kwargs: ``fontsize`` (default 16) and ``figsize`` (default (8, 8)).
    :return: The figure and the array of its two axes.
    """
    fontsize = kwargs.get("fontsize", 16)
    tick_fontsize = int(fontsize * 0.8)
    fig, ax = plt.subplots(nrows=2, ncols=1, sharex=True, sharey=False, figsize=kwargs.get("figsize", (8, 8)))

    colors_exact = {1: "#239b56", 3: "#58d68d", 5: "#abebc6"}
    colors_partial = {1: "#922b21", 3: "#cd6155", 5: "#e6b0aa"}

    # value_counts() sorts by count, so classes go from most to least frequent
    class_order = classes_col.value_counts().index
    positions = np.arange(len(class_order))

    _accs_classes = pd.concat((classes_col, exact_accuracies, partial_accuracies), axis=1)
    class_col = _accs_classes.columns[0]
    # Group once and reuse the grouping for every score column
    by_class = _accs_classes.groupby(by=[class_col])

    for k in (5, 3, 1):
        exact_col = exact_accuracies.columns[k - 1]
        partial_col = partial_accuracies.columns[k - 1]
        ax[0].bar(
            positions,
            by_class[exact_col].mean()[class_order],
            width,
            color=colors_exact[k],
            alpha=1,
            label=f"Top-{k} exact"
        )
        ax[0].bar(
            positions + width,
            by_class[partial_col].mean()[class_order],
            width,
            color=colors_partial[k],
            alpha=1,
            label=f"Top-{k} partial"
        )

    ax[0].tick_params(length=0)
    ax[0].grid(axis='y', alpha=0.5)
    ax[0].set_title(
        "Reagent prediction scores across reaction classes",
        fontsize=fontsize,
        loc='right',
        pad=20)
    # Scores lie in [0, 1]: include the 1.0 tick (np.arange(0.0, 1.0, 0.1) stopped at 0.9)
    ax[0].set_yticks(np.linspace(0.0, 1.0, 11))
    ax[0].tick_params(axis='y', labelsize=tick_fontsize)

    class_shares = _accs_classes[class_col].value_counts(normalize=True)[class_order]
    ax[1].bar(
        positions + width / 2,
        class_shares,
        width,
        color=[CLASS_COLORS.get(i, "#00FF00") for i in class_shares.index],
        alpha=0.6,
    )
    ax[1].tick_params(length=0)
    # X tick marks and labels are hidden on the bottom panel
    ax[1].tick_params(axis='x', which='both', bottom=False, top=False, labelbottom=False)
    ax[1].tick_params(axis='y', labelsize=tick_fontsize)
    ax[1].set_xlabel(None)
    ax[1].grid(axis='y', alpha=0.5)
    ax[1].set_title("Distribution of reaction classes in the test", fontsize=fontsize, loc="right", pad=20)

    return fig, ax

In [None]:
# Class -> colour for every class present in the test data; classes without a
# fixed colour in CLASS_COLORS get #00FF00.
colors = {i: CLASS_COLORS.get(i, "#00FF00") for i in data_out["General class"].unique()}

In [None]:
# Class -> colour mapping for the classes present in `test` (display only)
{i: CLASS_COLORS.get(i, "#00FF00") for i in test["General class"].unique()}

In [None]:
# Fixed class colours (display only)
CLASS_COLORS

In [None]:
# Per-class scores over unique reactions; the trailing comments name the per-row alternatives
plot_performance_across_classes(
    unique_rxn_perf["General class"], # test["General class"]
    unique_rxn_perf[topn_exact_match_acc.columns], # topn_exact_match_acc
    unique_rxn_perf[topn_partial_match_acc.columns], # topn_partial_match_acc 
    figsize=(16, 15),
    width=0.39,
    fontsize=20
)
plt.savefig("reaxys_classes_new.png", dpi=300)
plt.show()

The per-class performance correlates with how similar the TSNE maps of the corresponding reaction class are between USPTO 50K and the test set.

# Performance across reagent roles

### Roles of the target reagents

In [None]:
# Target reagents split by role: classify_to_str yields an '&'-separated string that is
# expanded into one column per role (the column assignment below requires exactly
# len(HeuristicRoleClassifier.types) fields).
test_r_roles = pd.concat(
        (
            test["General class"],
            tgt_for_rgs.apply(HeuristicRoleClassifier.classify_to_str).str.split("&", expand=True)
        ),
        axis=1
    )
test_r_roles.columns = ["General class"] + [f"tgt_{r}" for r in HeuristicRoleClassifier.types]
test_r_roles.head()

In [None]:
# A few rows further into the table
test_r_roles.iloc[10000:10005]

### Roles of the predicted reagents for each of the N predictions

In [None]:
# For each of the n predictions (1-based key), split its role-classified string
# into one column per role, named pred_<role>.
n = pred_rgs.shape[1]
top_n_roles = {}
for i in range(1, n + 1):
    _temp = pred_rgs[f"p_reagents_{i}"].apply(
        HeuristicRoleClassifier.classify_to_str
    ).str.split("&",
                expand=True)
    _temp.columns = [f"pred_{r}" for r in HeuristicRoleClassifier.types]
    top_n_roles[i] = _temp

In [None]:
# First rows of the role table of every prediction
for i in top_n_roles:
    print(f"Pred {i}")
    display(top_n_roles[i].head())

### Top-N exact match of predictions for each role

In [None]:
# For each role: cumulative top-N exact match restricted to that role's molecules
# (target role column followed by the same role of the n predictions).
role_performance = {}
for r in HeuristicRoleClassifier.types:
    
    tgt_and_pred_roles = pd.concat(
        [test_r_roles[f"tgt_{r}"]] + [top_n_roles[i][f"pred_{r}"] for i in range(1, n + 1)],
        axis=1
    )
    perf = tgt_and_pred_roles.apply(
        lambda x: match_accuracy(x, 'exact'),
        axis=1
    )
    perf = pd.DataFrame(perf.to_list())
    perf.columns = [f"top_{i}_exact" for i in range(1, n + 1)]
    role_performance[r] = perf

In [None]:
# Top-1 / Top-5 role accuracy over all reactions and over reactions whose ground truth
# has at least one molecule in that role (an empty target role matches an empty prediction).
role_perf_summary = {}
for r in HeuristicRoleClassifier.types:
    nonempty_target = test_r_roles[test_r_roles[f"tgt_{r}"] != ''].index
    role_perf_summary[r] = {
        "Top-1": role_performance[r]["top_1_exact"].mean(),
        "Top-1\nnon-empty\nground truth": role_performance[r].loc[nonempty_target]["top_1_exact"].mean(),
        "Top-5": role_performance[r]["top_5_exact"].mean(),
        "Top-5\nnon-empty\nground truth": role_performance[r].loc[nonempty_target]["top_5_exact"].mean()
    }

In [None]:
# Role -> metric -> value
role_perf_summary

In [None]:
# As a DataFrame: one column per role, one row per metric
role_perf_summary = pd.DataFrame.from_dict(role_perf_summary)
role_perf_summary

In [None]:
# Transposed: one row per role, one column per metric
role_perf_summary = role_perf_summary.T
role_perf_summary

In [None]:
# Heatmap of role accuracies (rows: roles, columns: Top-1/Top-5 metrics), saved to perf_class_1_new.png
plt.figure(figsize=(8, 8))
sns.heatmap(role_perf_summary, 
            annot=True, 
            annot_kws={"fontsize": 13},
            cmap="mako", 
            ax=plt.gca())
# plt.title(
#     "Proportion of test examples on which the prediction\nmatches the ground truth exactly in each reagent role.",
#     fontsize=14)
plt.yticks(rotation=0, fontsize=14)
plt.xticks(fontsize=14)
# plt.xticks(fontsize=16)
plt.tight_layout()
plt.savefig("perf_class_1_new.png", dpi=300)
plt.show()

# Performance on reaction classes vs reagent roles

In [None]:
def _role_vs_class_stats(role: str):
    """Per reaction class: top-1 match rate, top-1 match rate on non-empty GT, and non-empty GT count for one role."""
    # Ground-truth reagents of this role side by side with the per-reaction match flags
    stacked = pd.concat(
        (test_r_roles[["General class", f"tgt_{role}"]], role_performance[role]),
        axis=1,
    )
    # An empty ground-truth string means no reagent of this role, so track it separately
    stacked["tgt_nonempty"] = stacked[f"tgt_{role}"].apply(lambda smi: smi != '')
    by_class = stacked.groupby(by="General class")
    match_rate = by_class["top_1_exact"].mean()
    match_rate_nonempty = (
        stacked[stacked["tgt_nonempty"]].groupby(by="General class")["top_1_exact"].mean()
    )
    nonempty_count = by_class["tgt_nonempty"].sum()
    return match_rate, match_rate_nonempty, nonempty_count


def _stack_by_role(per_role_series) -> pd.DataFrame:
    """Concatenate per-role Series column-wise, label the columns by role and put roles on the rows."""
    frame = pd.concat(per_role_series, axis=1)
    frame.columns = HeuristicRoleClassifier.types
    return frame.T


_role_stats = [_role_vs_class_stats(role) for role in HeuristicRoleClassifier.types]
role_top_1_match_by_class = _stack_by_role([stats[0] for stats in _role_stats])
role_top_1_match_by_class_nonempty = _stack_by_role([stats[1] for stats in _role_stats])
role_nonempty_count = _stack_by_role([stats[2] for stats in _role_stats])

In [None]:
# Multi-line tick labels for the reaction classes on the heatmaps below.
# NOTE(review): the labels are matched to heatmap columns by position, so this order must follow the
# column order of the groupby("General class") results (pandas sorts group keys by default) — confirm.
xticklabels = [
    'Acylation\nand related\nprocesses', 
    'C-C bond\nformation',
    'Deprotections', 
    'Functional\ngroup\naddition\n(FGA)',
    'Functional\ngroup\ninterconversion\n(FGI)',
    'Heteroatom\nalkylation\nand arylation', 
    'Heterocycle\nformation',
    'Oxidations', 
    'Protections', 
    'Reductions'
]

In [None]:
fig, ax = plt.subplots(ncols=1, nrows=3, sharex=True, figsize=(20, 20))
angle = 0
annot_size = 13
title_size = 16
tick_size = 14
pad = 20

# (data, number format, annot_kws key, panel title, show class names above this panel)
_panels = [
    (role_top_1_match_by_class, '.3g', "fontsize",
     "Top-1 accuracy of predictions by roles across classes", True),
    (role_top_1_match_by_class_nonempty, '.3g', "size",
     "Top-1 accuracy of predictions by roles across classes for nonempty ground truth", False),
    (role_nonempty_count, 'd', "size",
     "Number of non-empty ground truth SMILES", False),
]

# Each panel is drawn and then styled before the next one, as the shared x-axis is sensitive to order
for panel_ax, (panel_data, panel_fmt, annot_key, panel_title, labels_on_top) in zip(ax, _panels):
    sns.heatmap(panel_data,
                annot=True,
                ax=panel_ax,
                annot_kws={annot_key: annot_size},
                xticklabels=xticklabels,
                fmt=panel_fmt,
                cmap="mako")
    plt.sca(panel_ax)
    plt.ylabel(None)
    plt.xlabel(None)
    # Hide all x tick marks and bottom labels; only the top panel shows class names, above the heatmap
    tick_options = dict(axis='x', which='both', bottom=False, top=False, labelbottom=False)
    if labels_on_top:
        tick_options["labeltop"] = True
    plt.tick_params(**tick_options)
    if labels_on_top:
        plt.xticks(rotation=angle, fontsize=tick_size)
    plt.title(panel_title, fontsize=title_size, loc="right", pad=pad, fontweight='bold')
    plt.yticks(fontsize=tick_size)

plt.tight_layout()
plt.savefig("perf_class_2_new.png", dpi=300)
plt.show()

# Comparing the number of predicted molecules with the number of molecules in the ground truth (GT)

In [None]:
# Number of '.'-separated fragments in each ground-truth reagent string
n_tgt_mols = test["Reagents"].apply(lambda smi: smi.count('.') + 1)

In [None]:
def label_difference(n: int):
    """
    Label the sign of a molecule-count difference.
    :param n: A count difference (below it is predicted minus ground-truth molecules)
    :return: "Same" for zero, "Less" for negative values, "More" otherwise
    """
    return "Same" if n == 0 else ("Less" if n < 0 else "More")

In [None]:
def _count_molecules(smi: str) -> int:
    """Number of non-empty '.'-separated fragments in a SMILES string ("" -> 0, "A..B." -> 2)."""
    return sum(1 for mol in smi.split('.') if mol)


# Named so it does not clash with the `order` function defined later in the notebook
difference_order = ["Less", "Same", "More"]
n_preds = pred_rgs.shape[1]
# squeeze=False keeps `ax` two-dimensional even when there is a single prediction column
fig, ax = plt.subplots(nrows=1, ncols=n_preds, sharey=True, figsize=(14, 4), squeeze=False)
# Count ground-truth molecules with the same rule as predictions so both sides are comparable
tgt_counts = test["Reagents"].apply(_count_molecules)
for i in range(n_preds):
    n_pred_mols = pred_rgs[f"p_reagents_{i + 1}"].apply(_count_molecules)
    shares = (
        (n_pred_mols - tgt_counts)
        .apply(label_difference)
        .value_counts(normalize=True)
        # reindex instead of .loc: a category that never occurs gets 0 instead of raising KeyError
        .reindex(difference_order, fill_value=0)
    )
    shares.plot(kind='bar', ax=ax[0, i])
    ax[0, i].set_title(f"Prediction {i + 1}")

fig.suptitle("Number of predicted molecules compared to the number of molecules in GT")
plt.show()

The model tends to predict fewer molecules than the ground truth contains.

# Unique reagents in test

In [None]:
# Reagent statistics over the test set; values are compared against a count threshold below.
# NOTE(review): presumably a mapping reagent SMILES -> number of occurrences — confirm in src.utils.
unique_reagents = ut.get_reagent_statistics(test["Reagents"])

In [None]:
# Number of distinct reagents in the test set
print("Unique reagents in the test set")
len(unique_reagents)

In [None]:
# Number of distinct reagents whose count is at least 20 (the threshold is inclusive,
# so the message says "at least", matching the `>=` comparison)
print("Unique reagents if they are repeated at least 20 times")
sum(1 for count in unique_reagents.values() if count >= 20)

# Performance across target lengths

### Across the number of molecules in target

In [None]:
# Per-reaction top-N exact-match flags computed earlier in the notebook (used below via top_1_exact / top_5_exact)
topn_exact_match_acc

In [None]:
# Top-1 exact-match accuracy grouped by the number of molecules in the ground-truth reagents.
# Empty fragments are not counted, so stray '.' separators do not change a reaction's group.
n_gt_mols = test["Reagents"].apply(lambda x: sum(1 for mol in x.split('.') if mol))
exact_by_n_mols = pd.concat(
    (n_gt_mols,
     topn_exact_match_acc),
     axis=1
).groupby(by="Reagents")["top_1_exact"].mean()
bar_ax = exact_by_n_mols.plot(kind='bar')
bar_ax.set_xlabel("Number of molecules in ground-truth reagents")
bar_ax.set_ylabel("Top-1 exact match accuracy");

In [None]:
# Top-1 partial-match accuracy grouped by the number of molecules in the ground-truth reagents.
# Empty fragments are not counted, so stray '.' separators do not change a reaction's group.
n_gt_mols = test["Reagents"].apply(lambda x: sum(1 for mol in x.split('.') if mol))
partial_by_n_mols = pd.concat(
    (n_gt_mols,
     topn_partial_match_acc),
     axis=1
).groupby(by="Reagents")["top_1_partial"].mean()
bar_ax = partial_by_n_mols.plot(kind='bar')
bar_ax.set_xlabel("Number of molecules in ground-truth reagents")
bar_ax.set_ylabel("Top-1 partial match accuracy");

### Across the target string length

In [None]:
# Top-1 exact-match accuracy as a function of the ground-truth reagent string length (in characters)
plt.figure(figsize=(13, 5))

exact_by_length = pd.concat(
    (test["Reagents"].apply(len),
     topn_exact_match_acc),
     axis=1
).groupby(by="Reagents")["top_1_exact"].mean()
plt.plot(exact_by_length, '-x')
plt.xlabel("Length of ground-truth reagent SMILES (characters)")
plt.ylabel("Top-1 exact match accuracy")
plt.title("Top-1 exact match accuracy vs. target string length")

plt.show()

In [None]:
# Top-1 partial-match accuracy as a function of the ground-truth reagent string length (in characters)
plt.figure(figsize=(13, 5))

partial_by_length = pd.concat(
    (test["Reagents"].apply(len),
     topn_partial_match_acc),
     axis=1
).groupby(by="Reagents")["top_1_partial"].mean()
plt.plot(partial_by_length, '-x')
plt.xlabel("Length of ground-truth reagent SMILES (characters)")
plt.ylabel("Top-1 partial match accuracy")
plt.title("Top-1 partial match accuracy vs. target string length")

plt.show()

The improvement is small, about 3 percentage points. Measuring performance this way improves the score; more duplicates are needed.

# Number of unique predictions and unique GT sequences for each reaction class

In [None]:
# Reaction class prepended to the rgs_and_predictions table built earlier in the notebook
# (presumably targets plus ranked predictions; it has at least a p_reagents_1 column).
compar = pd.concat((test["General class"], rgs_and_predictions), axis=1)

In [None]:
is_reduction = compar["General class"] == "Reductions"
compar[is_reduction]["p_reagents_1"].nunique()

In [None]:
# Number of distinct strings in every column after the class column, per reaction class
class_groups = compar.groupby(by="General class")
unique_counts = pd.concat([class_groups[col].nunique() for col in compar.columns[1:]], axis=1)
unique_counts.style.background_gradient(cmap='Blues', axis=1)

The model's predictions are not particularly diverse.

# Examples of incorrect predictions

Here the difference between the train and test distributions clearly manifests itself.

In [None]:
# Test reactions whose top-5 exact / partial match flag is 0
exact_miss = topn_exact_match_acc["top_5_exact"] == 0
partial_miss = topn_partial_match_acc["top_5_partial"] == 0
failed_in_exact = test.loc[topn_exact_match_acc[exact_miss].index]
failed_in_partial = test.loc[topn_partial_match_acc[partial_miss].index]

In [None]:
# Reaction class whose failed predictions are browsed below
c = "C-C bond formation"

In [None]:
in_class = failed_in_partial["General class"] == c
fail = failed_in_partial.loc[in_class]
fail_pred = pred_rgs.loc[fail.index]

In [None]:
# Iterator over the failed reactions of class c; each run of the next cell advances it by one
ix = iter(fail.index)

In [None]:
i = next(ix)
failed_rxn = fail.loc[i]
print(i)
print("GT:", failed_rxn["Reagents"])
display(ut.draw_reaction_smarts(failed_rxn["FullR"]))
# Rebuild the reaction with the top-1 predicted reagents placed between the two sides
left, right = failed_rxn["ReactionRgsSource"].split(">>")
center = fail_pred.loc[i]["p_reagents_1"]
pr = f"{left}>{center}>{right}"
print("Prediction:", center)
display(ut.draw_reaction_smarts(pr))

### Reactions with zero top-5 recall

In [None]:
zero_recall = topn_recall["top_5_recall"] == 0.0
no_recall_index = topn_recall[zero_recall].index

In [None]:
# Test reactions and their ranked predictions restricted to the zero-recall index
no_recall = test.loc[no_recall_index]
no_recall_pred = pred_rgs.loc[no_recall_index]

In [None]:
# Number of test reactions per reaction class
test["General class"].value_counts()

In [None]:
# Share of each class's test reactions that have zero top-5 recall, sorted ascending.
class_counts = test["General class"].value_counts()
# reindex instead of .loc: a class without zero-recall reactions gets 0 instead of raising KeyError
failing_classes = (
    no_recall["General class"].value_counts().reindex(class_counts.index, fill_value=0) / class_counts
)
failing_classes = failing_classes.sort_values()

In [None]:
# Share of zero-recall reactions per class, ascending
failing_classes

In [None]:
plt.figure(figsize=(20, 5))
plt.bar(
    list(failing_classes.index),
    list(failing_classes)
)
plt.xticks(rotation=15, fontsize=11)
plt.grid(axis='y')
plt.ylabel("Share of reactions with zero top-5 recall")
plt.title("Proportion of reactions in a class with zero top-5 recall within the test set")
# Saved next to the notebook like the other figures, not to a machine-specific absolute path
plt.savefig("no_recall.png")
plt.show()

In [None]:
# Keep only the zero-recall reactions of one class for manual inspection below
is_fga = no_recall["General class"] == 'Functional group addition (FGA)'
no_recall = no_recall.loc[is_fga]

In [None]:
# Iterator over the selected zero-recall reactions; each run of the next cell advances it by one
ix = iter(no_recall.index)

In [None]:
i = next(ix)
missed_rxn = no_recall.loc[i]
print(i)
print("GT:", missed_rxn["Reagents"])
display(ut.draw_reaction_smarts(missed_rxn["FullR"]))
# Rebuild the reaction with the top-1 predicted reagents placed between the two sides
left, right = missed_rxn["ReactionRgsSource"].split(">>")
center = no_recall_pred.loc[i]["p_reagents_1"]
pr = f"{left}>{center}>{right}"
print("Prediction:", center)
display(ut.draw_reaction_smarts(pr))

# Attention patterns

In [None]:
# Test reactions of the "Reductions" class
test[test["General class"] == "Reductions"]

In [None]:
# Targets and predictions for a hand-picked range of test indices
# NOTE(review): magic index range, presumably chosen interactively
rgs_and_predictions.loc[12766:12788]

In [None]:
# Reaction used for the attention example
# NOTE(review): index 12786 is a hand-picked magic number
att_sample_rxn = test.loc[12786]["ReactionRgsSource"]
print(att_sample_rxn)
ut.draw_reaction_smarts(att_sample_rxn)

In [None]:
# Tokenized input for the attention example, and the file the translation output goes to
sample_path_in = "../data/test/att_sample.txt"
sample_path_out = "../experiments/results/att_sample.txt"

In [None]:
sample_pred = MTReagentPredictor(VOCAB, None, None, None)
sample_tokens = sample_pred.tokenizer_source.tokenize([att_sample_rxn])[0]
# Drop the first and last tokens (presumably special begin/end markers — confirm) before writing
tokenized_sample = " ".join(sample_tokens[1:-1])
Path(sample_path_in).write_text(tokenized_sample)

In [None]:
# Space-joined tokens that were written to sample_path_in
tokenized_sample

In [None]:
# Run the project's translate.py with the kernel's own interpreter and capture its stdout
# (which includes the attention dump requested by -attn_debug).
# The project root is the notebook's parent directory, the same assumption the relative ../data paths make.
translate_script = Path.cwd().parent / "translate.py"
command = [
    sys.executable, str(translate_script),
    "-model", MODEL,
    "-src", sample_path_in,
    "-output", sample_path_out,
    "-batch_size", "32",
    "-max_length", "200",
    "-beam_size", "5",
    "-n_best", "5", 
    "-attn_debug", "-replace_unk", "-fast",
    "-gpu", "0"]

att_w = subprocess.check_output(command)

In [None]:
# Decode the captured stdout into lines. The guard keeps the cell re-runnable once att_w
# has already been converted from bytes to a list of lines.
if isinstance(att_w, bytes):
    att_w = att_w.decode('utf-8').split("\n")

In [None]:
# First line of the attention dump; its tokens become the heatmap columns below
source_tokens = att_w[0].split()

In [None]:
# Every remaining non-empty line except the "PRED" lines: an output token followed by one
# weight per source token; '*' markers on the weights are stripped.
lines = [row.split() for row in att_w[1:] if row and not row.startswith("PRED")]

output_tokens = [fields[0] for fields in lines]
scores = pd.DataFrame(
    [[float(weight.strip('*')) for weight in fields[1:]] for fields in lines],
    index=output_tokens,
)
scores.columns = source_tokens

In [None]:
attention_fig = plt.figure(figsize=(15, 10))
axx = attention_fig.gca()
sns.heatmap(scores, cmap="rocket_r", ax=axx, xticklabels=True, yticklabels=True)
# Source-token labels go on top; tick marks are hidden
axx.tick_params(axis='both', which='both', length=0, labelsize=10,
                labelbottom=False, bottom=False, top=False, labeltop=True)

axx.set_xlabel("Source sequence", fontsize=16)
axx.set_ylabel("Output sequence", fontsize=16)
axx.xaxis.set_label_position('top')
plt.xticks(rotation=0, fontsize=9)
plt.yticks(rotation=0, fontsize=9)
plt.show()

In [None]:
source_rxn = "".join(source_tokens)
left, right = source_rxn.split(">>")
# Decoded reagents go between the two sides; the last output token is left out
decoded_reaction = ">".join((left, "".join(output_tokens[:-1]), right))

In [None]:
# Print and draw the reaction rebuilt with the decoded reagents
print(decoded_reaction)
ut.draw_reaction_smarts(decoded_reaction)

# Inspecting the replaced reagents in the MIT subset of USPTO

In [None]:
# Base directory of the tokenized MIT_separated data; variant directories append a suffix to it
path = "../data/tokenized/MIT_separated"
# Separator placed between each source line and its target below
arrow = ">"

In [None]:
# Dataset split to inspect
subset = "train" # or "val"

In [None]:
def _read_detokenized(file_path: str) -> List[str]:
    """Read a tokenized file and return its stripped lines with the token-separating spaces removed."""
    with open(file_path) as handle:
        return [line.strip().replace(" ", "") for line in handle.readlines()]


# All five source variants are read before any name is bound, so a missing file leaves none of them set
base, no_rgs, top, mix, role = [
    _read_detokenized(path + f"{suffix}/src-{subset}.txt")
    for suffix in ("", "_no_reags", "_reags_top1", "_reags_top1_and_rdkit", "_reags_role_voting")
]
tgt = _read_detokenized(path + f"/tgt-{subset}.txt")

In [None]:
# Append the product to every source line to form full reaction SMILES.
# The no-reagent sources use the full '>>' arrow (presumably because they lack a reagent section).
# NOTE: these lines overwrite the lists read above; re-running this cell appends the product twice.
base = [i + arrow + j for i, j in zip(base, tgt)]
no_rgs = [i + ">>" + j for i, j in zip(no_rgs, tgt)]
top = [i + arrow + j for i, j in zip(top, tgt)]
mix = [i + arrow + j for i, j in zip(mix, tgt)]
role = [i + arrow + j for i, j in zip(role, tgt)]

Examples of indices of reactions with replaced reagents in train:  
1,   3,   6,   8,  17,  22,  25,  26,  33,  38,  39,  42,  47,
            51,  56,  59,  60,  64,  68,  76,  81,  83,  84,  87,  91,  92,
             94,  98,  99, 100, 102, 103, 105, 119, 121, 124, 129, 130, 133,
            136, 143, 144, 148, 151, 153, 154, 156, 162, 164, 168, 169, 178,
            180, 181, 182, 195, 200, 201, 203, 208, 209, 210, 211, 216, 222,
            228, 233, 243, 246, 249, 250, 254, 256, 257, 263, 266, 268, 273,
            274, 275, 277, 279, 284, 288, 289, 293, 294, 301, 309, 310, 321,
            324, 327, 328, 330, 332, 333, 334, 338, 342  
 
Examples of indices of reactions with replaced reagents in val:  
0,   3,   5,   6,   8,  22,  24,  28,  35,  47,  48,  51,  53,
             59,  63,  67,  68,  72,  73,  76,  78,  79,  85,  88,  93,  97,
            106, 107, 108, 110, 118, 129, 134, 138, 140, 143, 145, 146, 148,
            156, 166, 169, 172, 173, 174, 186, 189, 194, 195, 199, 202, 209,
            211, 213, 214, 224, 226, 230, 231, 238, 243, 244, 245, 249, 250,
            251, 253, 262, 271, 272, 276, 277, 283, 290, 291, 292, 296, 302,
            309, 314, 316, 321, 327, 332, 334, 338, 342, 347, 351, 352, 364,
            372, 373, 377, 378, 390, 393, 395, 397, 405


In [None]:
# Walk through the reactions of the chosen split, one per run of the next cell
ix = iter(range(len(base)))

In [None]:
i = next(ix)
print(i)
# Draw the same reaction as it appears in each dataset variant
for label, reactions in (("base", base), ("mix", mix), ("top", top), ("no", no_rgs), ("role", role)):
    print(label)
    display(ut.draw_reaction_smarts(reactions[i]))

# Products prediction

### Separated setting

In [None]:
# Full reaction SMILES (reactants>reagents>product) before and after role reassignment
fullr = test["Reactants"] + '>' + test["Reagents"] + '>' + test["Product"]
reassigned_fullr = ut.parallelize_on_rows(
    fullr,
    ut.assign_reaction_roles_schneider,
    num_of_processes=12,
    use_tqdm=True,
)
fullrs = pd.concat([fullr, reassigned_fullr], axis=1)

In [None]:
fullrs.columns = ["before_reass", "after_reass"]
# Discard a reassignment that is not a string, left no reactants (starts with '>') or no reagents ('>>');
# discarded entries fall back to the original reaction.
fullrs["after_reass"] = fullrs["after_reass"].apply(
    lambda x: np.nan if (not isinstance(x, str) or x.startswith(">") or ">>" in x) else x
)
# Plain assignment instead of chained `fillna(..., inplace=True)`, which does not update the
# parent frame under pandas Copy-on-Write.
fullrs["after_reass"] = fullrs["after_reass"].fillna(fullrs["before_reass"])

In [None]:
def extract_separated_src(smi: str) -> str:
    """
    Drop the product side of a reaction SMILES.
    :param smi: Reaction SMILES with exactly two '>' separators, "reactants>reagents>product"
    :return: "reactants>reagents"
    """
    reactants, reagents, _product = smi.split(">")
    return f"{reactants}>{reagents}"


# Source side ("reactants>reagents") of every reaction after role reassignment
precursors = fullrs["after_reass"].apply(extract_separated_src)

In [None]:
# Earlier ways of building the precursors, kept for reference only (not executed):
# precursors = (test["Reactants"] + '.' + test["Reagents"]).str.strip('.')
# precursors = test["Reactants"] + '>' + test["Reagents"]

In [None]:
# Inspect the precursor strings
precursors

In [None]:
def order(s: str) -> str:
    """
    Put the molecules of a "reactants>reagents" string into a deterministic order.
    Each side is stripped, split on '.', and sorted by (length, SMILES) in descending order.
    :param s: A string with exactly one '>' separating reactants and reagents
    :return: The reordered "reactants>reagents" string
    """
    def _sorted_side(side: str) -> str:
        molecules = side.strip().split(".")
        molecules.sort(key=lambda mol: (len(mol), mol), reverse=True)
        return ".".join(molecules)

    reactants, reagents = s.split(">")
    return ">".join((_sorted_side(reactants), _sorted_side(reagents)))

# Precursors with the molecules on both sides put into a deterministic order
precursors_ordered = precursors.apply(order)

In [None]:
# Inspect the ordered precursors
precursors_ordered

In [None]:
# Distinct precursor strings before and after ordering, next to the total counts
precursors.nunique(), precursors.shape, precursors_ordered.nunique(), precursors_ordered.shape

In [None]:
# Directory holding the trained product-prediction checkpoints (machine-specific; adjust when running elsewhere)
TRAINED_MODELS_DIR = "/home/mandronov/work/reagents/experiments/trained_models"

# MODELS key suffix -> checkpoint file stem of that training-data variant
_VARIANT_STEMS = {
    "": "MIT_separated",
    "_repl": "MIT_separated_reags_top1_and_rdkit",
    "_no_rg": "MIT_separated_no_reags",
    "_repl_sort": "MIT_separated_reags_top1_and_rdkit_sorted",
}

_ALL_VARIANTS = ("", "_repl", "_no_rg", "_repl_sort")

# Training step (thousands) -> variants with a checkpoint at that step, in key order
_AVAILABLE_CHECKPOINTS = {
    40: ("", "_repl", "_no_rg"),
    60: ("", "_repl", "_no_rg"),
    80: ("", "_repl", "_no_rg"),
    100: ("", "_repl"),
    120: ("", "_repl"),
    160: ("", "_repl"),
    200: ("", "_repl"),
    220: ("", "_repl", "_no_rg"),
    260: _ALL_VARIANTS,
    300: _ALL_VARIANTS,
    320: ("", "_repl", "_repl_sort"),
    350: _ALL_VARIANTS,
    370: _ALL_VARIANTS,
    400: _ALL_VARIANTS,
    450: _ALL_VARIANTS,
    500: _ALL_VARIANTS,
}

# Model name -> checkpoint path, e.g. "mit_sep_40k_repl" -> ".../MIT_separated_reags_top1_and_rdkit_model_1gpu_step_40000.pt"
MODELS = {
    f"mit_sep_{k_steps}k{suffix}":
        f"{TRAINED_MODELS_DIR}/{_VARIANT_STEMS[suffix]}_model_1gpu_step_{k_steps * 1000}.pt"
    for k_steps, suffixes in _AVAILABLE_CHECKPOINTS.items()
    for suffix in suffixes
}

In [None]:
# One product predictor per checkpoint; each reads/writes its own tokenized input and output file
product_predictors = {
    name: MTProductPredictor(
        model_path=model,
        tokenized_path=f"../data/test/src_{name}.txt",
        output_path=f"../experiments/results/test_reaxys_{name}.txt",
        beam_size=5,
        n_best=5,
        gpu=0,
    )
    for name, model in MODELS.items()
}

In [None]:
# Inspect the configured predictors
product_predictors

In [None]:
def model_label(name: str) -> str:
    """
    Short label for the training-data variant encoded in the suffix of a MODELS key.
    :param name: A model key such as "mit_sep_40k_repl"
    :return: "new", "no_rg", "new_sort" or "base"
    """
    # Checked in this order; the first matching suffix wins
    suffix_labels = (("repl", "new"), ("no_rg", "no_rg"), ("repl_sort", "new_sort"))
    for suffix, label in suffix_labels:
        if name.endswith(suffix):
            return label
    return "base"
    
def step_num(name: str) -> int:
    """
    Extracts the training step (in thousands) from a model key.

    The step is the first "_"-separated token of the form "<digits>k",
    e.g. "mit_sep_400k_repl_sort" -> 400, "mit_mix_40k" -> 40.
    (The previous implementation used ``str.lstrip("mit_sep")``, which strips a
    *set of characters* rather than a prefix and failed for other prefixes.)

    :param name: A model key
    :return: The step number in thousands
    :raises ValueError: If the key contains no "<digits>k" token
    """
    for token in name.split("_"):
        if token.endswith("k") and token[:-1].isdigit():
            return int(token[:-1])
    raise ValueError(f"Cannot find a '<N>k' step token in model name {name!r}")

In [None]:
%%time
# No need to execute it twice if the predictions are already stored on the disk
# (they are read back with `load_predictions` in the next cell).
# Only models labelled "no_rg" and "new_sort" are predicted here; models with other
# labels are skipped (their `predict(precursors)` call is disabled).
for name, predictor in product_predictors.items():
    label = model_label(name)
    if label == "no_rg":
        # No-reagent models receive only the part of each precursor string before the first '>'.
        predictor.predict(precursors.apply(lambda x: x.split(">")[0]))
    elif label == "new_sort":
        predictor.predict(precursors_ordered)

In [None]:
# Read every separated-setting model's stored predictions back into memory.
for predictor in product_predictors.values():
    predictor.load_predictions()

In [None]:
# Reduce every predicted SMILES to its longest '.'-separated fragment (presumably the main
# product) and canonicalize it, so predictions can be matched against the target product.
for predictor in product_predictors.values():
    for col in predictor.predictions.columns:
        predictor.predictions[col] = predictor.predictions[col].progress_apply(
            lambda smi: ut.canonicalize_smiles(max(smi.split('.'), key=len))
        )

In [None]:
# Accumulator: model names plus the summed top-1..top-5 exact-match counts per model.
# Every key gets its own fresh list.
prod_pred_acc = {"Model": [], **{f"top_{rank}_exact": [] for rank in range(1, 6)}}

In [None]:
# Score each separated-setting model: pair the target product with the model's ranked
# predictions, compute per-reaction top-k indicators via match_accuracy('exact'),
# and record their sums over the test set.
for name, predictor in product_predictors.items():
    prod_pred_acc["Model"].append(name)
    target_with_preds = pd.concat((test["Product"], predictor.predictions), axis=1)
    per_reaction = target_with_preds.apply(lambda row: match_accuracy(row, 'exact'), axis=1)

    prod_topn_exact_match_acc = pd.DataFrame(per_reaction.to_list())
    prod_topn_exact_match_acc.columns = [f"top_{rank}_exact" for rank in range(1, 6)]

    for col in prod_topn_exact_match_acc.columns:
        prod_pred_acc[col].append(prod_topn_exact_match_acc[col].sum())

In [None]:
# One row per separated-setting model with its summed top-k exact-match counts.
# (Built directly as a DataFrame; binding the name to the raw accumulator dict first
# left a non-DataFrame behind whenever this cell was run on its own.)
prod_pred_acc_df = pd.DataFrame.from_dict(prod_pred_acc)

In [None]:
# Preview the first rows of the separated-setting accuracy table.
prod_pred_acc_df.head()

In [None]:
# Derive the training step (thousands) and the data-variant label from each model key.
prod_pred_acc_df = prod_pred_acc_df.assign(
    Step=prod_pred_acc_df["Model"].apply(step_num),
    Model_Name=prod_pred_acc_df["Model"].apply(model_label),
)
prod_pred_acc_df

In [None]:
# Separated setting: top-k exact-match counts vs. training step, one panel per k.
# The per-variant frames are kept as notebook-level variables.
_base = prod_pred_acc_df[prod_pred_acc_df["Model_Name"] == "base"]
_new = prod_pred_acc_df[prod_pred_acc_df["Model_Name"] == "new"]
_no_rg = prod_pred_acc_df[prod_pred_acc_df["Model_Name"] == "no_rg"]
_new_sort = prod_pred_acc_df[prod_pred_acc_df["Model_Name"] == "new_sort"]

fig, ax = plt.subplots(nrows=1, ncols=5, sharey=True, figsize=(28, 10))
for i, axis in enumerate(ax):
    k = i + 1
    for frame, label in ((_base, "Base"), (_new, "New"), (_no_rg, "No reag."), (_new_sort, "New sort")):
        axis.plot(frame.Step, frame[f"top_{k}_exact"], '-o', label=f'{label} top {k}')
    axis.set_title(f"Top-{k} exact match")
    axis.set_xlabel("Training step (thousands)")
    axis.legend()
    axis.grid()
ax[0].set_ylabel("Exact matches on the test set (count)")
plt.show()

In [None]:
# Full separated-setting accuracy table (counts, step and variant label per model).
prod_pred_acc_df

In [None]:
# NOTE(review): 81855 is a hard-coded count with no visible source in this cell — presumably
# an exact-match count copied from the accuracy table above. TODO derive it from prod_pred_acc_df.
81855 / len(test["Product"])

In [None]:
# For each k, show the five separated-setting models with the most top-k exact matches.
for k in range(1, 6):
    col = f"top_{k}_exact"
    ranked = prod_pred_acc_df[["Model", col]].sort_values([col], ascending=False)
    display(ranked.head())

### Mixed setting

In [None]:
# Mixed setting: reactants and reagents are joined into a single '.'-separated precursor string;
# strip('.') removes the dangling separator when one of the two parts is empty.
precursors_mixed = (
    test["Reactants"]
    .add('.')
    .add(test["Reagents"])
    .str.strip('.')
)

In [None]:
# Preview the combined precursor strings.
precursors_mixed

In [None]:
# Number of distinct precursor strings vs. total number of rows (reveals duplicated inputs).
precursors_mixed.nunique(), precursors_mixed.shape

In [None]:
# Mixed-setting checkpoints: for every training step (in thousands) there is a base model
# ("mit_mix_<N>k") and a model trained on replaced reagents ("mit_mix_<N>k_repl").
# NOTE(review): the checkpoint directory is an absolute, user-specific path — consider
# making it configurable.
MODELS_MIXED = {
    f"mit_mix_{step}k{key_suffix}": (
        "/home/mandronov/work/reagents/experiments/trained_models/"
        f"MIT_mixed{file_infix}_model_1gpu_step_{step}000.pt"
    )
    for step in (40, 60, 80, 100, 120, 140, 160, 180, 200, 220, 250, 300, 320, 350, 370, 400, 450, 500)
    for key_suffix, file_infix in (("", ""), ("_repl", "_reags_top1_and_rdkit"))
}


In [None]:
# One product predictor per mixed-setting checkpoint, configured like the separated ones
# (own tokenized-input and output files, beam size 5, 5 best hypotheses, gpu=0).
product_predictors_mixed = {
    name: MTProductPredictor(
        model_path=checkpoint,
        tokenized_path=f"../data/test/src_{name}.txt",
        output_path=f"../experiments/results/test_reaxys_{name}.txt",
        beam_size=5,
        n_best=5,
        gpu=0,
    )
    for name, checkpoint in MODELS_MIXED.items()
}

In [None]:
# Inspect the configured mixed-setting predictors (name -> MTProductPredictor).
product_predictors_mixed

In [None]:
%%time
# No need to execute it twice if the predictions are already stored on the disk
# (they are read back with `load_predictions` in the next cell).
# Every mixed-setting model receives the same combined precursor strings.
for predictor in product_predictors_mixed.values():
    predictor.predict(precursors_mixed)

In [None]:
# Read every mixed-setting model's stored predictions back into memory.
for predictor in product_predictors_mixed.values():
    predictor.load_predictions()

In [None]:
# Reduce every predicted SMILES to its longest '.'-separated fragment (presumably the main
# product) and canonicalize it, so predictions can be matched against the target product.
for predictor in product_predictors_mixed.values():
    for col in predictor.predictions.columns:
        predictor.predictions[col] = predictor.predictions[col].progress_apply(
            lambda smi: ut.canonicalize_smiles(max(smi.split('.'), key=len))
        )

In [None]:
# Accumulator for the mixed setting: model names plus summed top-1..top-5 exact-match counts.
# Every key gets its own fresh list.
prod_pred_acc_mixed = {"Model": [], **{f"top_{rank}_exact": [] for rank in range(1, 6)}}

In [None]:
# Score each mixed-setting model: pair the target product with the model's ranked
# predictions, compute per-reaction top-k indicators via match_accuracy('exact'),
# and record their sums over the test set.
for name, predictor in product_predictors_mixed.items():
    prod_pred_acc_mixed["Model"].append(name)
    target_with_preds = pd.concat((test["Product"], predictor.predictions), axis=1)
    per_reaction = target_with_preds.apply(lambda row: match_accuracy(row, 'exact'), axis=1)

    prod_topn_exact_match_acc_mixed = pd.DataFrame(per_reaction.to_list())
    prod_topn_exact_match_acc_mixed.columns = [f"top_{rank}_exact" for rank in range(1, 6)]

    for col in prod_topn_exact_match_acc_mixed.columns:
        prod_pred_acc_mixed[col].append(prod_topn_exact_match_acc_mixed[col].sum())

In [None]:
# One row per mixed-setting model with its summed top-k exact-match counts.
prod_pred_acc_mixed_df = pd.DataFrame(prod_pred_acc_mixed)

In [None]:
# Training step (thousands) is the "<N>k" token of the model key, e.g. "mit_mix_40k_repl" -> 40.
# Extracted with a regex: str.strip("_repl") strips a *set of characters* from both ends and
# breaks for any other suffix (e.g. "_repl_sort").
prod_pred_acc_mixed_df["Step"] = (
    prod_pred_acc_mixed_df["Model"]
    .str.extract(r"_(\d+)k(?:_|$)", expand=False)
    .astype(int)
)
# Mixed-setting variants: "new" (replaced reagents, "_repl" suffix) or "base".
prod_pred_acc_mixed_df["Model_Name"] = prod_pred_acc_mixed_df["Model"].apply(lambda x: "new" if x.endswith("repl") else "base")
prod_pred_acc_mixed_df

In [None]:
# Top-5 exact-match accuracy of the final (500k-step) mixed checkpoints, selected by model key
# instead of hard-coded row positions that silently depend on the order of MODELS_MIXED.
prod_pred_acc_mixed_df.set_index("Model").loc[["mit_mix_500k", "mit_mix_500k_repl"], "top_5_exact"] / test.shape[0]

In [None]:
# Mixed setting: top-k exact-match counts vs. training step, one panel per k.
# The per-variant frames are kept as notebook-level variables.
_base = prod_pred_acc_mixed_df[prod_pred_acc_mixed_df["Model_Name"] == "base"]
_new = prod_pred_acc_mixed_df[prod_pred_acc_mixed_df["Model_Name"] == "new"]

fig, ax = plt.subplots(nrows=1, ncols=5, sharey=True, figsize=(18, 5))
for i, axis in enumerate(ax):
    k = i + 1
    for frame, label in ((_base, "Base"), (_new, "New")):
        axis.plot(frame.Step, frame[f"top_{k}_exact"], '-o', label=f'{label} top {k}')
    axis.set_title(f"Mixed: top-{k} exact match")
    axis.set_xlabel("Training step (thousands)")
    axis.legend()
    axis.grid()
ax[0].set_ylabel("Exact matches on the test set (count)")
plt.show()

In [None]:
# For each k, show the five mixed-setting models with the highest top-k exact-match accuracy.
for k in range(1, 6):
    col = f"top_{k}_exact"
    ranked = prod_pred_acc_mixed_df[["Model", col]].sort_values([col], ascending=False)
    display(ranked.set_index("Model").head() / test.shape[0])

### All together

In [None]:
# Separated and mixed settings on one figure: top-k exact-match counts vs. training step.
# Each curve is drawn from the frame that matches its label. Reusing `_base`/`_new` here
# would plot the mixed frames (reassigned in the mixed plot cell) under the "Base"/"New"
# labels and drop the separated base/new curves.
_base_sep = prod_pred_acc_df[prod_pred_acc_df["Model_Name"] == "base"]
_new_sep = prod_pred_acc_df[prod_pred_acc_df["Model_Name"] == "new"]
_no_rg = prod_pred_acc_df[prod_pred_acc_df["Model_Name"] == "no_rg"]
_new_sort_sep = prod_pred_acc_df[prod_pred_acc_df["Model_Name"] == "new_sort"]
_base_mix = prod_pred_acc_mixed_df[prod_pred_acc_mixed_df["Model_Name"] == "base"]
_new_mix = prod_pred_acc_mixed_df[prod_pred_acc_mixed_df["Model_Name"] == "new"]

_curves = (
    (_base_sep, "Base", {}),
    (_new_sep, "New", {}),
    (_no_rg, "No reag.", {"c": "k"}),  # no-reagent baseline drawn in black
    (_new_sort_sep, "New sort", {}),
    (_base_mix, "Base mix", {}),
    (_new_mix, "New mix", {}),
)

fig, ax = plt.subplots(nrows=1, ncols=5, sharey=True, figsize=(28, 10))
for i, axis in enumerate(ax):
    k = i + 1
    for frame, label, style in _curves:
        axis.plot(frame.Step, frame[f"top_{k}_exact"], '-o', label=f'{label} top {k}', **style)
    axis.set_title(f"Separated vs. mixed: top-{k} exact match")
    axis.set_xlabel("Training step (thousands)")
    axis.legend()
    axis.grid()
ax[0].set_ylabel("Exact matches on the test set (count)")
plt.show()

In [None]:
# Inspect the rows of the no-reagent (separated setting) models.
_no_rg

In [None]:
# NOTE(review): 74808 is a hard-coded count with no visible source in this cell — presumably
# an exact-match count copied from one of the tables above. TODO derive it from the accuracy DataFrames.
74808 / len(test)