In [None]:
import operator

In [None]:
from itertools import combinations

In [None]:
import pandas as pd

In [None]:
import numpy as np
import matplotlib.pyplot as plt

In [None]:
from scipy.spatial import Voronoi
from scipy.sparse.linalg import svds
from scipy.optimize import curve_fit
from scipy.spatial import Delaunay

In [None]:
import networkx as nx

In [None]:
from rdkit import Chem
from rdkit.Chem import Draw

In [None]:
from IPython.display import display

In [None]:
import plotly.express as px

In [None]:
from preprocessing.utils import get_reagent_statistics, build_pmi_dict, pmi_dict_to_sparse_matrix

In [None]:
from umap import UMAP

In [None]:
def voronoi_finite_polygons_2d(vor, radius=None):
    """
    Reconstruct infinite voronoi regions in a 2D diagram to finite
    regions.

    The output has one region per input point, in input order: ``regions[i]``
    is the (clipped) cell of ``vor.points[i]``.

    Parameters
    ----------
    vor : Voronoi
        Input diagram
    radius : float, optional
        Distance to 'points at infinity'. Defaults to the peak-to-peak range
        of all point coordinates (taken over the flattened coordinate array).

    Returns
    -------
    regions : list of lists of int
        Indices of vertices in each revised Voronoi region. Regions that were
        already finite are copied as given by ``vor.regions``; reconstructed
        regions are sorted by angle around their vertex centroid
        (counterclockwise).
    vertices : ndarray
        Coordinates for revised Voronoi vertices. Same as coordinates
        of input vertices, with 'points at infinity' appended to the
        end.

    Raises
    ------
    ValueError
        If the diagram is not two-dimensional.
    """

    if vor.points.shape[1] != 2:
        raise ValueError("Requires 2D input")

    new_regions = []
    new_vertices = vor.vertices.tolist()

    center = vor.points.mean(axis=0)
    if radius is None:
        # np.ptp rather than the ndarray.ptp method, which was removed in NumPy 2.0.
        # Like the original recipe, this is the range over all coordinates (flattened).
        radius = np.ptp(vor.points)

    # Construct a map containing all ridges for a given point
    all_ridges = {}
    for (p1, p2), (v1, v2) in zip(vor.ridge_points, vor.ridge_vertices):
        all_ridges.setdefault(p1, []).append((p2, v1, v2))
        all_ridges.setdefault(p2, []).append((p1, v1, v2))

    # Reconstruct infinite regions
    for p1, region in enumerate(vor.point_region):
        vertices = vor.regions[region]

        if all(v >= 0 for v in vertices):
            # finite region: copy so callers cannot mutate vor.regions through the result
            new_regions.append(list(vertices))
            continue

        # reconstruct a non-finite region
        ridges = all_ridges[p1]
        new_region = [v for v in vertices if v >= 0]

        for p2, v1, v2 in ridges:
            # Normalise so that v1 is the (possibly) infinite end (-1) of the ridge
            if v2 < 0:
                v1, v2 = v2, v1
            if v1 >= 0:
                # finite ridge: already in the region
                continue

            # Compute the missing endpoint of an infinite ridge
            tangent = vor.points[p2] - vor.points[p1]
            # Out-of-place division: also works if the points array has an integer dtype
            tangent = tangent / np.linalg.norm(tangent)
            normal = np.array([-tangent[1], tangent[0]])

            # Orient the normal away from the centre of all points
            midpoint = vor.points[[p1, p2]].mean(axis=0)
            direction = np.sign(np.dot(midpoint - center, normal)) * normal
            far_point = vor.vertices[v2] + direction * radius

            new_region.append(len(new_vertices))
            new_vertices.append(far_point.tolist())

        # sort region counterclockwise (by angle around the vertex centroid)
        vs = np.asarray([new_vertices[v] for v in new_region])
        c = vs.mean(axis=0)
        angles = np.arctan2(vs[:, 1] - c[1], vs[:, 0] - c[0])
        new_region = np.array(new_region)[np.argsort(angles)]

        # finish
        new_regions.append(new_region.tolist())

    return new_regions, np.asarray(new_vertices)

In [None]:
def read_and_filter_reagent_data(path, min_count):
    """Read reagent sets from a file and keep only reagents that occur often enough.

    Parameters
    ----------
    path : str
        Headerless file read with ``pd.read_csv(path, header=None)``; its first
        column holds, per row, reagent SMILES joined by ``";"``.
    min_count : int
        Minimum occurrence count (as reported by ``get_reagent_statistics``)
        for a reagent to be kept.

    Returns
    -------
    smiles_table : pd.DataFrame
        Columns ``"smiles"`` and ``"count"``, one row per kept reagent, in
        ``most_common()`` order (decreasing count). Empty, but with both
        columns, when no reagent reaches ``min_count``.
    filtered_reagent_smiles : pd.Series
        For every input row, the list of its reagents that were kept.
    """
    reagent_smiles = pd.read_csv(path, header=None)[0]
    # assumes get_reagent_statistics returns a collections.Counter-like object (it is used via most_common())
    reagent_occurrence_counter = get_reagent_statistics(reagent_smiles, separator=";")
    kept = [(smi, count) for smi, count in reagent_occurrence_counter.most_common() if count >= min_count]
    # Passing explicit column names keeps the schema even when nothing is kept.
    smiles_table = pd.DataFrame(kept, columns=["smiles", "count"])

    kept_smiles = set(smiles_table["smiles"])
    filtered_reagent_smiles = reagent_smiles.apply(
        # Non-string cells (e.g. NaN from read_csv) have no reagents to keep
        lambda x: [r for r in x.split(";") if r in kept_smiles] if isinstance(x, str) else []
    )
    return smiles_table, filtered_reagent_smiles

In [None]:
def get_distributed_representations(unique_smiles, reagent_smiles, emb_dim):
    """Embed reagents by a truncated SVD of their PMI co-occurrence matrix.

    Parameters
    ----------
    unique_smiles : iterable of str
        Reagent SMILES; position ``i`` becomes row/column ``i`` of the PMI matrix.
    reagent_smiles : pd.Series
        Reagent lists passed to ``build_pmi_dict``.
    emb_dim : int
        Number of singular vectors (``k``) requested from ``svds``.

    Returns
    -------
    np.ndarray
        The left singular vectors (one row per reagent), each row divided by
        its L2 norm (or by 1e-7 if the norm is smaller than that).
    """
    pmi_dict = build_pmi_dict(reagent_smiles)
    position_of = {smi: position for position, smi in enumerate(unique_smiles)}
    pmi_matrix = pmi_dict_to_sparse_matrix(pmi_dict, reagent_to_index=position_of)

    left_vectors = svds(pmi_matrix, k=emb_dim)[0]
    row_lengths = np.linalg.norm(left_vectors, axis=1, keepdims=True)
    return left_vectors / np.maximum(row_lengths, 1e-7)

In [None]:
def reagent_report(smiles_table: pd.DataFrame,
                   embs: np.ndarray,
                   standard_rgs: pd.DataFrame,
                   umap_object) -> pd.DataFrame:
    """Project reagent embeddings to 2D and annotate them with roles and names.

    Parameters
    ----------
    smiles_table : pd.DataFrame
        Must have ``"smiles"`` and ``"count"`` columns; row ``i`` describes ``embs[i]``.
    embs : np.ndarray
        Embedding matrix with one row per reagent.
    standard_rgs : pd.DataFrame
        Reference table with ``"smiles"``, ``"class"`` and ``"name"`` columns.
    umap_object
        Any object with a ``fit_transform`` method returning two coordinates
        per row (e.g. a ``UMAP`` instance; it is fitted here).

    Returns
    -------
    pd.DataFrame
        Columns ``x, y, smiles, count, class, name``, indexed like ``smiles_table``.
        Reagents absent from ``standard_rgs`` get class ``"unk"`` and name ``"???"``.

    Raises
    ------
    ValueError
        If the projection does not have exactly two columns.
    """
    xy = np.asarray(umap_object.fit_transform(embs))
    if xy.ndim != 2 or xy.shape[1] != 2:
        raise ValueError(f"Expected a projection with 2 columns, got shape {xy.shape}")

    lookup = standard_rgs.set_index("smiles")
    smiles = smiles_table["smiles"]
    roles = smiles.map(dict(lookup["class"]), na_action='ignore').fillna("unk")
    names = smiles.map(dict(lookup["name"]), na_action='ignore').fillna("???")

    # Assemble positionally: pd.concat would align on the index and silently
    # misalign rows whenever smiles_table does not have a default RangeIndex.
    return pd.DataFrame({
        "x": xy[:, 0],
        "y": xy[:, 1],
        "smiles": smiles.to_numpy(),
        "count": smiles_table["count"].to_numpy(),
        "class": roles.to_numpy(),
        "name": names.to_numpy(),
    }, index=smiles_table.index)

In [None]:
def get_role_colors(roles, palette=None):
    """Map every distinct role to a colour.

    Parameters
    ----------
    roles : pd.Series
        Role labels; colours are assigned in sorted order of the unique labels.
    palette : sequence of str, optional
        Colours to use; defaults to ``px.colors.qualitative.Light24``. If there
        are more roles than colours, the palette is cycled instead of raising
        IndexError.

    Returns
    -------
    dict
        ``{role: colour}`` for every unique role.
    """
    if palette is None:
        palette = px.colors.qualitative.Light24
    ordered_roles = sorted(roles.unique())
    return {role: palette[i % len(palette)] for i, role in enumerate(ordered_roles)}

In [None]:
def umap_plot(points, roles, save_path=None, fontsize=None):
    """Scatter-plot 2D UMAP coordinates coloured by role, with a role legend.

    Parameters
    ----------
    points : np.ndarray
        Coordinates indexable as ``points[:, 0]`` (x) and ``points[:, 1]`` (y).
    roles : pd.Series
        Role label of every point, in the same order as ``points``.
    save_path : str, optional
        If given, the figure is also saved there (300 dpi, tight bounding box).
    fontsize : int, optional
        Base font size (title; axis labels use ``fontsize - 2``). Defaults to
        the notebook-level ``FONTSIZE``, looked up at call time.
    """
    if fontsize is None:
        fontsize = FONTSIZE

    role_colors = get_role_colors(roles)
    fig, ax = plt.subplots(figsize=(10, 10))
    ax.scatter(points[:, 0], points[:, 1], c=roles.map(role_colors), edgecolors="k")

    legend_handles = [
        plt.Line2D([], [], marker="s", color="w", markerfacecolor=color, markeredgecolor="k",
                   ms=10, alpha=1, linewidth=0, label=role)
        for role, color in role_colors.items()
    ]
    # Legend placed outside the axes, to the right
    ax.legend(handles=legend_handles, loc="center left", bbox_to_anchor=(1, 0.5), frameon=False)

    ax.set_title("UMAP projection of reagent embeddings", fontdict={"size": fontsize})
    ax.set_xlabel("UMAP axis 1", fontdict={"size": fontsize - 2})
    ax.set_ylabel("UMAP axis 2", fontdict={"size": fontsize - 2})
    # No ticks or tick labels on either axis
    ax.set_xticks([])
    ax.set_yticks([])

    if save_path is not None:
        fig.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.show()

# Global variables

In [None]:
# Base font size for figure text; derived sizes use offsets such as FONTSIZE - 2.
FONTSIZE = 16
# Reagents occurring fewer than MIN_COUNT times are dropped (see read_and_filter_reagent_data).
MIN_COUNT = 100
# Number of singular vectors (embedding dimensions) taken from the PMI matrix factorization.
EMB_DIM = 50
# Reference table of known reagents; used below via its "smiles", "class" and "name" columns.
STANDARD_REAGENTS_PATH = "../data/standard_reagents.csv"

In [None]:
standard_reagents = pd.read_csv(STANDARD_REAGENTS_PATH, index_col=[0])

# Read data

Reagents determined by atom mapping

In [None]:
# Reagent lists determined by atom mapping; read_and_filter_reagent_data expects one ";"-joined list per row.
uspto_aam_reagents_path = "../data/uspto_aam_reagents/reagents-1128297.txt"

In [None]:
# smiles_table_aam: kept reagents with counts; reagent_smiles_aam: per-row lists of kept reagents.
smiles_table_aam, reagent_smiles_aam = read_and_filter_reagent_data(uspto_aam_reagents_path, min_count=MIN_COUNT)

Reagents determined by fingerprints where possible

In [None]:
# Disabled: uncomment these cells (and the other *_mixed cells) to load the "mixed" reagent set,
# which the "Exclusive reagents" comparison at the end of the notebook needs.
# uspto_mixed_reagents_path = "../data/uspto_mixed_reagents/reagents-1131934.txt"

In [None]:
# smiles_table_mixed, reagent_smiles_mixed = read_and_filter_reagent_data(uspto_mixed_reagents_path, min_count=MIN_COUNT)

# Get embeddings

In [None]:
# One unit-normalised row per kept reagent, in the row order of smiles_table_aam.
embeddings_aam = get_distributed_representations(smiles_table_aam["smiles"], reagent_smiles_aam, emb_dim=EMB_DIM)

In [None]:
# embeddings_mixed = get_distributed_representations(smiles_table_mixed["smiles"], reagent_smiles_mixed, emb_dim=EMB_DIM)

# Project embeddings to the plane, get reports

In [None]:
umap_aam = UMAP(random_state=12345)

In [None]:
# umap_mixed = UMAP(random_state=12345)

In [None]:
r_aam = reagent_report(smiles_table_aam, embeddings_aam, standard_reagents, umap_aam)

In [None]:
# r_mixed = reagent_report(smiles_table_mixed, embeddings_mixed, standard_reagents, umap_mixed)

# Visualize the UMAP projections

In [None]:
points_aam = r_aam[["x", "y"]].values

In [None]:
umap_plot(points_aam, r_aam["class"], save_path="../figures/umap_aam_rgs.png")

In [None]:
# Reagents determined by the fingerprint procedure when possible
# points_mixed = r_mixed[["x", "y"]].values
# umap_plot(points_mixed, r_mixed["class"])

# Highlight the regions of the same role using a Voronoi diagram

In [None]:
# compute Voronoi tessellation of the 2D UMAP coordinates
vor = Voronoi(points_aam)

role_colors_aam = get_role_colors(r_aam["class"])
color_aam = r_aam["class"].map(role_colors_aam)

# clip the unbounded cells; regions[k] is the polygon of points_aam[k]
regions, vertices = voronoi_finite_polygons_2d(vor)

plt.figure(figsize=(10, 10))
# colorize every cell with the colour of its point's role
for point_idx, cell in enumerate(regions):
    cell_xs, cell_ys = zip(*vertices[cell])
    plt.fill(cell_xs, cell_ys, color=color_aam[point_idx], alpha=1)

plt.scatter(points_aam[:, 0], points_aam[:, 1], c=color_aam, edgecolors="grey")
pad = 0.1
plt.xlim(vor.min_bound[0] - pad, vor.max_bound[0] + pad)
plt.ylim(vor.min_bound[1] - pad, vor.max_bound[1] + pad)

# Numbered cluster annotations as (figure x, figure y, description). The positions are
# hard-coded figure fractions: they do not follow the data and need manual re-tuning
# whenever the projection changes.
annotated_clusters = [
    (0.14, 0.75, "Peptide coupling activators"),
    (0.85, 0.30, "Phosphorus-based ligands"),
    (0.70, 0.20, "Pd catalysts for cross-coupling"),
    (0.68, 0.12, "Chelators"),
    (0.41, 0.19, "Hydrogenation and Cu catalysts"),
    (0.75, 0.30, "Mitsunobu reaction reagents"),
    (0.31, 0.65, "Chlorinating agents"),
    (0.48, 0.83, "Aliphatic amine bases"),
    (0.66, 0.76, "Borohydrides"),
]
examples = {number: text for number, (_, _, text) in enumerate(annotated_clusters, start=1)}

number_labels_fontdict = {"size": FONTSIZE + 2, 'weight': 'bold'}
for number, (fig_x, fig_y, _) in enumerate(annotated_clusters, start=1):
    plt.figtext(fig_x, fig_y, str(number), fontdict=number_labels_fontdict)

# Key for the numbers, stacked downwards along the right edge of the figure
for number, text in examples.items():
    plt.figtext(0.93, 0.85 - (number - 1) * 0.04, f'{number}: {text}', fontdict={"size": FONTSIZE - 2})

legend_handles = [
    plt.Line2D([], [], marker="s", color="w", markerfacecolor=color, ms=10, alpha=1,
               linewidth=0, label=role, markeredgecolor="k")
    for role, color in role_colors_aam.items()
]
legend_kwargs_ = dict(loc="best", bbox_to_anchor=(1, 0.5), frameon=False, fontsize=FONTSIZE - 2)
plt.legend(handles=legend_handles, **legend_kwargs_)
plt.title("Voronoi diagram of the UMAP projection of reagent embeddings", fontdict={"size": FONTSIZE})
plt.xlabel("UMAP axis 1", fontdict={"size": FONTSIZE - 2})
plt.ylabel("UMAP axis 2", fontdict={"size": FONTSIZE - 2})
plt.xticks([])
plt.yticks([])
plt.tick_params(axis='x', which='both', bottom=False, top=False)  # Remove ticks on x-axis
plt.tick_params(axis='y', which='both', left=False, right=False)  # Remove ticks on y-axis

plt.savefig("../figures/Fig-5_umap_aam_rgs_voronoi.png", dpi=300, bbox_inches='tight')
plt.show()

# Using the Delaunay triangulation to visualize regions of the same color in the Voronoi diagram as connected components of a graph

In [None]:
def contiguous_role_regions(points, color_seq):
    """Build a graph whose connected components are same-colour regions of the Voronoi diagram.

    Each point is a node; two points are joined when they share a Delaunay edge
    and have the same colour. The Delaunay triangulation is the dual of the
    Voronoi diagram, so Delaunay neighbours are points whose cells touch.

    Parameters
    ----------
    points : array-like of shape (n, 2)
        Point coordinates.
    color_seq : sequence
        Colour (or any comparable label) of each point, in the order of ``points``.
        Looked up by position, so a pandas Series with any index works.

    Returns
    -------
    nx.Graph
        Graph with nodes ``0 .. n-1``.
    """
    colors = list(color_seq)  # positional lookup, independent of any Series index
    triangulation = Delaunay(points=points)
    graph = nx.Graph()
    graph.add_nodes_from(range(len(points)))
    for a, b, c in triangulation.simplices:
        for u, v in ((a, b), (b, c), (a, c)):
            if colors[u] == colors[v]:
                graph.add_edge(u, v)
    return graph

In [None]:
role_colors_aam = get_role_colors(r_aam["class"])
color_seq_aam = r_aam["class"].map(role_colors_aam)
G = contiguous_role_regions(points_aam, color_seq_aam)

plt.figure(figsize=(10, 10))
# nx.draw falls back to an unseeded spring layout; seed it so the figure is reproducible.
layout = nx.spring_layout(G, seed=12345)
nx.draw(G, pos=layout, with_labels=False, node_size=20, node_color=list(color_seq_aam))
print("Connected components:", nx.number_connected_components(G))
plt.title("Regions of the same role in the Voronoi diagram represented as connected components of a graph")
plt.show()

# Reagent occurrence counts

## In logarithmic scale

In [None]:
fig, ax = plt.subplots(figsize=(10, 10))
# r_aam rows are ordered by decreasing count, so this is a rank-ordered curve.
ax.plot(np.log10(r_aam["count"]), linewidth=2)
# Only reagents with at least MIN_COUNT occurrences are kept, so derive the number from it.
ax.set_title(f"Occurrence distribution (reagents with at least {MIN_COUNT} occurrences) in the USPTO dataset",
             fontdict={"size": FONTSIZE}, y=1.05)
ax.set_xlabel("Unique reagent index", fontdict={"size": FONTSIZE})
ax.set_ylabel("Decimal logarithm of the number of occurrences", fontdict={"size": FONTSIZE})
ax.tick_params(axis="both", labelsize=FONTSIZE - 2)
ax.grid(axis="y")
fig.tight_layout()
fig.savefig("../figures/Fig-2_occurrences_log.png", dpi=300, bbox_inches='tight')
plt.show()

## In linear scale

In [None]:
fig, ax = plt.subplots(nrows=3, ncols=1, figsize=(15, 10))
THRESHOLD_1, THRESHOLD_2 = 100, 400
# Panels show the counts from reagent position 0, THRESHOLD_1 and THRESHOLD_2 onwards.
for panel, start in zip(ax, (0, THRESHOLD_1, THRESHOLD_2)):
    panel.plot(r_aam["count"].iloc[start:], linewidth=3)
    panel.grid(axis='y')

# Red "zoom" boxes linking a region of one panel to the top of the panel below.
# All coordinates are hand-tuned figure fractions, NOT derived from THRESHOLD_1/THRESHOLD_2
# or the data: re-tune them if the data, figure size or subplot spacing change.
line_color = "red"
line_width = 2
box_right = 0.868


def zoom_box_lines(left, low, high, target_top_left_x, target_top_right_x, target_top_y):
    """Return the 4 edges of a zoom box plus 2 connectors to the top of the target panel."""
    segments = [
        ([left, box_right], [high, high]),                        # top edge
        ([left, box_right], [low, low]),                          # bottom edge
        ([left, left], [low, high]),                              # left edge
        ([box_right, box_right], [low, high]),                    # right edge
        ([box_right, target_top_right_x], [low, target_top_y]),   # right connector
        ([left, target_top_left_x], [low, target_top_y]),         # left connector
    ]
    return [plt.Line2D(xs, ys, transform=fig.transFigure, color=line_color, linewidth=line_width)
            for xs, ys in segments]


# First bounding box -> top of graph 2
for line in zoom_box_lines(left=0.27, low=0.64, high=0.7,
                           target_top_left_x=0.127, target_top_right_x=0.9005, target_top_y=0.596):
    fig.add_artist(line)

# Second bounding box -> top of graph 3
for line in zoom_box_lines(left=0.6, low=0.36, high=0.42,
                           target_top_left_x=0.127, target_top_right_x=0.9005, target_top_y=0.3114):
    fig.add_artist(line)

for panel in ax:
    panel.tick_params(axis="both", labelsize=FONTSIZE - 2)
ax[1].set_ylabel("Number of occurrences", fontdict={"size": FONTSIZE})
ax[2].set_xlabel("Unique reagent index", fontdict={"size": FONTSIZE})
fig.subplots_adjust(hspace=0.4)
fig.suptitle(f"Number of occurrences for every unique reagent with at least {MIN_COUNT} occurrences "
             "in the USPTO dataset",
             y=0.92,
             fontsize=FONTSIZE)
# fig.savefig("../figures/occurrences.png", dpi=300, bbox_inches='tight')
plt.show()

## Fitting a power law to the decay of reagent counts with rank

In [None]:
# Relative frequency of every kept reagent against its rank (1 = most frequent);
# r_aam rows are already ordered by decreasing count.
FIT_START = 100   # 0-based position of the first reagent used in the fit
PLOT_START = 50   # 0-based position of the first reagent shown in the plot

rank_freq = (r_aam["count"] / r_aam["count"].sum()).to_numpy()
ranks = np.arange(1, len(rank_freq) + 1)


def reciprocal_func(x, a, b):
    """Power-law decay ``a / x**b``."""
    return a / x ** b


# Fit only the tail (ranks > FIT_START)
params, covariance = curve_fit(reciprocal_func, ranks[FIT_START:], rank_freq[FIT_START:])
fitted_a, fitted_b = params
fitted_a_err, fitted_b_err = np.sqrt(np.diag(covariance))

plt.figure(figsize=(8, 6))
plt.scatter(ranks[PLOT_START:], rank_freq[PLOT_START:], label='Original data')
plt.plot(ranks[PLOT_START:], reciprocal_func(ranks[PLOT_START:], fitted_a, fitted_b),
         color='red', label='Fitted curve')
plt.xlabel('Reagent rank (1 = most frequent)')
plt.ylabel('Relative frequency')
plt.legend()
plt.title('Fitting Reciprocal Function')

# .3g keeps significant digits even when the coefficient is small
print(f"Best fit: {fitted_a:.3g} / x^{fitted_b:.3g} "
      f"(std. errors: a ±{fitted_a_err:.2g}, b ±{fitted_b_err:.2g}; fitted on ranks > {FIT_START})")

plt.show()

## Distribution of roles

In [None]:
r_aam["class"] = r_aam["class"].apply(lambda x: {"cat": "catalyst",
                                                 "ox": "ox. agent",
                                                 "red": "red. agent"}.get(x, x).capitalize())
role_distr_aam = r_aam["class"].value_counts(normalize=True).sort_index()

In [None]:
# role_distr_mixed = r_mixed["class"].value_counts(normalize=True).sort_index()

In [None]:
role_colors = get_role_colors(r_aam["class"])

plt.figure(figsize=(8, 8))
plt.pie(role_distr_aam, colors=px.colors.qualitative.Light24, autopct='%1.1f%%', labels=role_distr_aam.index,
        textprops={"size": FONTSIZE - 2})
plt.suptitle("Distribution of roles of the reagents in the USPTO dataset", fontdict={"size": FONTSIZE}, y=0.9)
plt.savefig("../figures/Fig-3_pie.png", dpi=300, bbox_inches='tight')
plt.show()

# Reagent pairs with largest PMI scores

In [None]:
TOP_SCORES = 15
pmi_scores_dict_aam = build_pmi_dict(reagent_smiles_aam)
top_pmi_scores_aam = sorted(list(pmi_scores_dict_aam.items()), key=operator.itemgetter(1), reverse=True)[:TOP_SCORES]

In [None]:
for (r1, r2), score in top_pmi_scores_aam:
    print(f"Molecules: {r1} & {r2}. PMI score: {score}")
    display(Draw.MolsToGridImage(
        [Chem.MolFromSmiles(r1), Chem.MolFromSmiles(r2)],
        molsPerRow=2,
        subImgSize=(300, 300)
    ))

# Orphan reagents

Counting reagent pairs

In [None]:
n_smiles_aam = smiles_table_aam["smiles"].shape[0]
reagent_to_index = {smi: i for i, smi in enumerate(smiles_table_aam["smiles"])}
reagent_smiles_aam_ids = reagent_smiles_aam.apply(lambda x: [reagent_to_index[s] for s in x if s in reagent_to_index])

count_pair = np.zeros((n_smiles_aam, n_smiles_aam))
for entry in reagent_smiles_aam_ids:
    for rgs_i_1, rgs_i_2 in map(sorted, combinations(entry, 2)):
        count_pair[rgs_i_1, rgs_i_2] += 1
        count_pair[rgs_i_2, rgs_i_1] += 1

In [None]:
# Which reagents were always the single reagents in a reaction
orphan = pd.Series(count_pair.sum(1) == 0)

In [None]:
orphan_reagents_aam = r_aam[orphan]

In [None]:
orphan_reagents_aam

In [None]:
print("Orphan reagents (atom mapping)")
display(Draw.MolsToGridImage(
    [Chem.MolFromSmiles(i) for i in orphan_reagents_aam["smiles"]],
    molsPerRow=3,
    legends=[f"{i}: {r}" for i, r in zip(orphan_reagents_aam.index, orphan_reagents_aam["class"])],
    subImgSize=(300, 300)
))

# Percentage of reactions using rare reagents (outside the 50 most common ones)

In [None]:
frequent_50 = set(r_aam["smiles"].head(50))

In [None]:
# Reactions in which all reagents are from the 50 most common ones
reagent_smiles_aam.apply(lambda x: all([r in frequent_50 for r in x])).value_counts(normalize=True)[True]

In [None]:
# Reactions in which at least one reagent is from the 50 most common ones
reagent_smiles_aam.apply(lambda x: any([r in frequent_50 for r in x])).value_counts(normalize=True)[True]

# Exclusive reagents in both reagent determination procedures

In [None]:
aam_only_reagents = r_aam[r_aam["name"].apply(lambda x: x not in set(r_mixed["name"]))]
mixed_only_reagents = r_mixed[r_mixed["name"].apply(lambda x: x not in set(r_aam["name"]))]

In [None]:
print("Number of all reagents")
print("AAM:", len(r_aam))
print("Mixed:", len(r_mixed))
print()
print("Number of exclusive reagents")
print("AAM:", len(aam_only_reagents))
print("Mixed:", len(mixed_only_reagents))

Exclusive reagents mostly have the "reactant" role

In [None]:
aam_only_reagents["class"].value_counts()

In [None]:
mixed_only_reagents["class"].value_counts()

In [None]:
aam_only_reagents

In [None]:
mixed_only_reagents