In [1]:
import pandas as pd

In [2]:
import numpy as np
import matplotlib.pyplot as plt
from scipy.spatial import Voronoi

In [3]:
from rdkit import Chem
from rdkit.Chem import Draw

In [4]:
from IPython.display import display

In [5]:
import plotly.express as px

In [6]:
def voronoi_finite_polygons_2d(vor, radius=None):
    """
    Reconstruct infinite voronoi regions in a 2D diagram to finite
    regions.

    Every region that touches the "point at infinity" (vertex index -1) is
    closed by adding, for each of its infinite ridges, an artificial vertex
    placed `radius` away from the ridge's finite endpoint, pointing away
    from the centroid of the input points.

    Parameters
    ----------
    vor : Voronoi
        Input diagram
    radius : float, optional
        Distance to 'points at infinity'. Defaults to the peak-to-peak
        range (max - min) taken over all input coordinates.

    Returns
    -------
    regions : list of lists of int
        Indices of vertices in each revised Voronoi region, one region per
        input point, in the same order as `vor.points`.
    vertices : ndarray of shape (n_vertices, 2)
        Coordinates for revised Voronoi vertices. Same as coordinates
        of input vertices, with 'points at infinity' appended to the
        end.

    Raises
    ------
    ValueError
        If the input points are not two-dimensional.

    """

    if vor.points.shape[1] != 2:
        raise ValueError("Requires 2D input")

    new_regions = []
    new_vertices = vor.vertices.tolist()

    center = vor.points.mean(axis=0)
    if radius is None:
        # Use the np.ptp function: the ndarray.ptp method was removed in
        # NumPy 2.0. The value (global max - global min) is unchanged.
        radius = np.ptp(vor.points)

    # Construct a map containing all ridges for a given point
    all_ridges = {}
    for (p1, p2), (v1, v2) in zip(vor.ridge_points, vor.ridge_vertices):
        all_ridges.setdefault(p1, []).append((p2, v1, v2))
        all_ridges.setdefault(p2, []).append((p1, v1, v2))

    # Reconstruct infinite regions
    for p1, region in enumerate(vor.point_region):
        vertices = vor.regions[region]

        if all(v >= 0 for v in vertices):
            # finite region: copy so callers cannot mutate vor.regions
            new_regions.append(list(vertices))
            continue

        # reconstruct a non-finite region
        ridges = all_ridges[p1]
        new_region = [v for v in vertices if v >= 0]

        for p2, v1, v2 in ridges:
            if v2 < 0:
                # normalise so that a missing endpoint (-1) is always v1
                v1, v2 = v2, v1
            if v1 >= 0:
                # finite ridge: already in the region
                continue

            # Compute the missing endpoint of an infinite ridge

            t = vor.points[p2] - vor.points[p1]  # tangent
            t /= np.linalg.norm(t)
            n = np.array([-t[1], t[0]])  # normal

            # orient the normal so that it points away from the centroid
            midpoint = vor.points[[p1, p2]].mean(axis=0)
            direction = np.sign(np.dot(midpoint - center, n)) * n
            far_point = vor.vertices[v2] + direction * radius

            new_region.append(len(new_vertices))
            new_vertices.append(far_point.tolist())

        # sort region counterclockwise
        vs = np.asarray([new_vertices[v] for v in new_region])
        c = vs.mean(axis=0)
        angles = np.arctan2(vs[:, 1] - c[1], vs[:, 0] - c[0])
        new_region = np.array(new_region)[np.argsort(angles)]

        # finish
        new_regions.append(new_region.tolist())

    return new_regions, np.asarray(new_vertices)

In [7]:
# Base font size for figure titles and axis labels; tick and legend text
# in the plotting cells use FONTSIZE - 2.
FONTSIZE = 16

## Reading data

In [125]:
# Load the reagent table; later cells read its "count", "class", "x" and "y"
# columns.
# The path is relative, like the "../figures" and "../top_pmi_scores.json"
# paths used elsewhere in this notebook, instead of an absolute path into one
# user's home directory.
# NOTE(review): assumes the notebook lives one directory below the repository
# root, next to the "data" folder's parent - confirm before merging.
d = pd.read_csv("../data/uspto_aam_rgs_min_count_100_d_50.csv", sep=",")

## Reagent occurrence counts

In [None]:
# Occurrence counts of every unique reagent, drawn three times on stacked
# panels: the full curve, the curve from index THRESHOLD_1 onward, and the
# curve from index THRESHOLD_2 onward. Red boxes with connector lines are
# drawn between consecutive panels (presumably to mark the part of a curve
# that the panel below repeats at a larger scale).
THRESHOLD_1, THRESHOLD_2 = 100, 400

line_color = "red"
line_width = 2


def _draw_zoom_box(figure, left, right, low, high,
                   target_left_x, target_right_x, target_y):
    """Draw a rectangle plus two connectors, all in figure-fraction coordinates.

    The rectangle spans [left, right] x [low, high]; its two bottom corners
    are joined to (target_left_x, target_y) and (target_right_x, target_y).
    """
    segments = [
        ([left, right], [high, high]),            # top edge
        ([left, right], [low, low]),              # bottom edge
        ([left, left], [low, high]),              # left edge
        ([right, right], [low, high]),            # right edge
        ([right, target_right_x], [low, target_y]),  # right connector
        ([left, target_left_x], [low, target_y]),    # left connector
    ]
    for xs, ys in segments:
        # add_artist is the public way to attach a figure-level artist
        figure.add_artist(
            plt.Line2D(xs, ys, transform=figure.transFigure,
                       color=line_color, linewidth=line_width)
        )


fig, ax = plt.subplots(nrows=3, ncols=1, figsize=(15, 10))
for panel, first_index in zip(ax, (0, THRESHOLD_1, THRESHOLD_2)):
    panel.plot(d["count"][first_index:], linewidth=3)
    panel.grid(axis='y')
    panel.tick_params(axis="both", labelsize=FONTSIZE - 2)

# Hard-coded figure-fraction coordinates; they match figsize=(15, 10) and
# hspace=0.4 below and need re-tuning if either changes.
box_right = 0.868

# First bounding box (on the top panel, connected to the middle panel)
box_1_low = 0.64
box_1_high = 0.7
box_1_left = 0.27
graph_2_top_right_x = 0.9005
graph_2_top_y = 0.596
graph_2_top_left_x = 0.127

_draw_zoom_box(fig, box_1_left, box_right, box_1_low, box_1_high,
               graph_2_top_left_x, graph_2_top_right_x, graph_2_top_y)

# Second bounding box (on the middle panel, connected to the bottom panel)
box_2_low = 0.36
box_2_high = 0.42
box_2_left = 0.6
graph_3_top_right_x = 0.9005
graph_3_top_y = 0.3114
graph_3_top_left_x = 0.127

_draw_zoom_box(fig, box_2_left, box_right, box_2_low, box_2_high,
               graph_3_top_left_x, graph_3_top_right_x, graph_3_top_y)

ax[1].set_ylabel("Number of occurrences", fontdict={"size": FONTSIZE})
ax[2].set_xlabel("Unique reagent index", fontdict={"size": FONTSIZE})

fig.subplots_adjust(hspace=0.4)
fig.suptitle("Number of occurrences (truncated to 100) for every unique reagent in the USPTO dataset",
             y=0.92,
             fontsize=FONTSIZE)
fig.savefig("../figures/occurrences.png", dpi=300)
plt.show()

## Distribution of roles

In [None]:
# Fraction of rows per reagent role ("class" column), ordered by role label.
role_distr = d["class"].value_counts(normalize=True).sort_index()

fig_pie, ax_pie = plt.subplots(figsize=(10, 10))
ax_pie.pie(
    role_distr,
    labels=role_distr.index,
    colors=px.colors.qualitative.Light24,
    autopct='%1.1f%%',                      # one decimal place on each wedge
    textprops={"size": FONTSIZE - 2},
)
ax_pie.set_title("Distribution of roles of the reagents in the USPTO dataset",
                 fontdict={"size": FONTSIZE})
fig_pie.savefig("../figures/pie.png", dpi=300)
plt.show()

## Embeddings projection

In [287]:
# Give every role an integer label equal to its rank among the sorted role
# names, then turn each row's label into a Light24 palette colour.
sorted_roles = sorted(d["class"].unique())
role_to_label = dict(zip(sorted_roles, range(len(sorted_roles))))
numerical_role_label = d["class"].map(role_to_label)

light24 = px.colors.qualitative.Light24
color = [light24[label] for label in numerical_role_label]

In [None]:
# Role name -> palette colour, using the same integer labels as the per-row
# colour list so legends agree with the plotted points.
role_to_color = {}
for role_name, label in role_to_label.items():
    role_to_color[role_name] = px.colors.qualitative.Light24[label]
role_to_color

In [281]:
# 2D coordinates of every reagent as an (n_rows, 2) array: column 0 is "x",
# column 1 is "y".
points = d.loc[:, ["x", "y"]].values

In [None]:
# Scatter plot of the 2D reagent coordinates, one colour per role.
fig_umap, ax_umap = plt.subplots(figsize=(10, 10))
ax_umap.scatter(points[:, 0], points[:, 1], c=color, edgecolors="k")

# One square legend marker per role; the empty data lists mean the handles
# exist only for the legend and draw nothing on the axes.
legend_handles = [
    plt.Line2D(
        [],
        [],
        marker="s",
        color="w",
        markerfacecolor=role_color,
        ms=10,
        alpha=1,
        linewidth=0,
        label=role_name,
        markeredgecolor="k",
    )
    for role_name, role_color in role_to_color.items()
]
# Legend sits outside the axes, centred on their right edge.
legend_kwargs_ = dict(loc="center left", bbox_to_anchor=(1, 0.5), frameon=False)
ax_umap.legend(handles=legend_handles, **legend_kwargs_)

ax_umap.set_title("UMAP projection of reagent embeddings", fontdict={"size": FONTSIZE})
ax_umap.set_xlabel("UMAP axis 1", fontdict={"size": FONTSIZE - 2})
ax_umap.set_ylabel("UMAP axis 2", fontdict={"size": FONTSIZE - 2})

# Hide ticks and tick marks on both axes.
ax_umap.set_xticks([])
ax_umap.set_yticks([])
ax_umap.tick_params(axis="both", which="both",
                    bottom=False, top=False, left=False, right=False)

# bbox_inches="tight" grows the saved canvas to include the legend, which is
# anchored outside the axes and would otherwise be cut off in the PNG.
fig_umap.savefig("../figures/umap_aam_rgs.png", dpi=300, bbox_inches="tight")
plt.show()

## Voronoi diagram

In [None]:
# compute Voronoi tessellation of the 2D reagent coordinates
vor = Voronoi(points)

# close the unbounded cells so that every cell can be filled as a polygon
regions, vertices = voronoi_finite_polygons_2d(vor)

fig_vor, ax_vor = plt.subplots(figsize=(10, 10))

# colorize: regions come back in input-point order, so the i-th region
# takes the colour of the i-th point
for region, region_color in zip(regions, color):
    polygon = vertices[region]
    ax_vor.fill(polygon[:, 0], polygon[:, 1], color=region_color, alpha=1)

ax_vor.scatter(points[:, 0], points[:, 1], c=color, edgecolors="k")

# crop to the bounding box of the input points (plus a 0.1 margin) so the
# artificially extended outer cells do not dominate the view
ax_vor.set_xlim(vor.min_bound[0] - 0.1, vor.max_bound[0] + 0.1)
ax_vor.set_ylim(vor.min_bound[1] - 0.1, vor.max_bound[1] + 0.1)

# One square legend marker per role; the empty data lists mean the handles
# exist only for the legend and draw nothing on the axes.
legend_handles = [
    plt.Line2D(
        [],
        [],
        marker="s",
        color="w",
        markerfacecolor=role_color,
        ms=10,
        alpha=1,
        linewidth=0,
        label=role_name,
        markeredgecolor="k",
    )
    for role_name, role_color in role_to_color.items()
]
# Legend sits outside the axes, centred on their right edge.
legend_kwargs_ = dict(loc="center left", bbox_to_anchor=(1, 0.5), frameon=False)
ax_vor.legend(handles=legend_handles, **legend_kwargs_)

ax_vor.set_title("Voronoi diagram of the UMAP projection of reagent embeddings",
                 fontdict={"size": FONTSIZE})
ax_vor.set_xlabel("UMAP axis 1", fontdict={"size": FONTSIZE - 2})
ax_vor.set_ylabel("UMAP axis 2", fontdict={"size": FONTSIZE - 2})

# Hide ticks and tick marks on both axes.
ax_vor.set_xticks([])
ax_vor.set_yticks([])
ax_vor.tick_params(axis="both", which="both",
                   bottom=False, top=False, left=False, right=False)

# bbox_inches="tight" grows the saved canvas to include the legend, which is
# anchored outside the axes and would otherwise be cut off in the PNG.
fig_vor.savefig("../figures/umap_aam_rgs_voronoi.png", dpi=300, bbox_inches="tight")
plt.show()

## Largest PMI scores

In [90]:
import json

# Read the precomputed PMI table; the display loop later in the notebook
# iterates it as ((smiles_1, smiles_2), score) entries.
with open("../top_pmi_scores.json") as pmi_file:
    top_pmi_scores = json.loads(pmi_file.read())

In [None]:
# For every reagent pair: print both SMILES strings, render the two
# molecules side by side, then print the pair's score.
for smiles_pair, pair_score in top_pmi_scores:
    first_smiles, second_smiles = smiles_pair
    print(first_smiles, second_smiles)

    pair_mols = [Chem.MolFromSmiles(smiles) for smiles in (first_smiles, second_smiles)]
    grid_image = Draw.MolsToGridImage(
        pair_mols,
        molsPerRow=2,
        subImgSize=(300, 300),
    )
    display(grid_image)
    print(pair_score)