In [1]:
# Helper libraries
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import matplotlib as mpl
# import matplotlib.pyplot as plt
import matplotlib.colors as colors
import seaborn as sns

import time
from sklearn.manifold import TSNE
from sklearn.metrics import pairwise_distances

from rdkit import Chem
from rdkit.Chem import Draw
from rdkit.Chem.Draw import IPythonConsole
IPythonConsole.ipython_useSVG = True

In [2]:
def cal_distance(x, y, metric='euclidean'):
    """
    Pairwise distances between the rows of x and the rows of y.

    x, y: 2-D array-likes (one sample per row), or a pandas Series holding a
          single sample, which is reshaped into one row
    metric: metric name handed unchanged to sklearn's pairwise_distances
    return: the distance matrix produced by pairwise_distances(x, y)
    """
    # isinstance instead of an exact type comparison, so Series subclasses
    # are reshaped as well
    if isinstance(x, pd.Series):
        x = x.values.reshape(1, -1)
    if isinstance(y, pd.Series):
        y = y.values.reshape(1, -1)
    return pairwise_distances(x, y, metric=metric)

In [3]:
def print_closest_words(x_embedding, x_query, n=5, add_vec=None):
    """
    Find the n rows of x_embedding nearest to a query (Euclidean distance,
    computed through cal_distance).

    x_embedding: dataframe of embeddings, one row per entry, entry label as index
    x_query: index label of the query row
    n: number of neighbours to return
    add_vec: optional vector added to the query embedding before the search;
             when it is None the query's own row is left out of the result
    return: dict with 'smiles' (index labels of the neighbours, nearest first)
            and 'dis' (the matching distances)
    """
    query_vec = x_embedding.loc[x_query].values.reshape(1, -1)
    if add_vec is not None:
        # Out-of-place add: cannot write into the embedding table and does
        # not depend on the dtypes being compatible for an in-place update.
        query_vec = query_vec + add_vec
    # One query row -> one distance column; flatten it to a 1-D vector.
    dists = cal_distance(x=x_embedding.values, y=query_vec)[:, 0]
    # Row positions ordered nearest first; stable so ties keep table order.
    order = np.argsort(dists, kind='stable')
    if add_vec is None:
        # Drop the query by label rather than assuming it sorts to rank 0:
        # with an identical embedding elsewhere (or rounding in the distance
        # computation) another row can come first, and the query itself would
        # then be reported as its own neighbour.
        is_query = np.asarray(x_embedding.index == x_query)
        order = order[~is_query[order]]
    nearest = order[:n]
    return {'smiles': list(x_embedding.index[nearest]),
            'dis': list(dists[nearest])}

In [4]:
def get_minus_result(x_embedding, x, y):
    """
    Difference between two embedding rows.

    x_embedding: dataframe of embeddings with the entry label as index
    x, y: index labels of the two rows
    return: embedding(x) - embedding(y), each reshaped to a single row first
    """
    minuend, subtrahend = (
        x_embedding.loc[label].values.reshape(1, -1) for label in (x, y)
    )
    return minuend - subtrahend

In [5]:
def draw_mol_by_smiles(smiles, size=(200, 200)):
    """
    Render a single molecule given as a SMILES string.

    smiles: SMILES string of the molecule
    size: (width, height) passed to Draw.MolToImage; defaults to the
          previously hard-coded (200, 200)
    return: the image produced by Draw.MolToImage
    raises: ValueError if the SMILES string cannot be parsed
    """
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        # MolFromSmiles reports a parse failure by returning None; fail here
        # with a readable message instead of deep inside the drawing code.
        raise ValueError('Cannot parse SMILES: {!r}'.format(smiles))
    return Draw.MolToImage(mol, size=size)

In [6]:
def draw_multiple_mol(smiles_list, mols_per_row=4, file_path=None, legends=None):
    """
    Draw several molecules on one SVG grid and optionally save it.

    smiles_list: SMILES strings, one per grid cell
    mols_per_row: maximum number of molecules per row (capped at the number
                  of molecules)
    file_path: if given, the SVG text is written to this path
    legends: optional list of captions, one per molecule
    return: the object returned by Draw.MolsToGridImage
    raises: ValueError if smiles_list is empty
    """
    smiles_list = list(smiles_list)
    if not smiles_list:
        # An empty list would give molsPerRow == 0 below.
        raise ValueError('smiles_list must contain at least one SMILES string')
    mols = [Chem.MolFromSmiles(smiles) for smiles in smiles_list]
    mols_per_row = min(len(mols), mols_per_row)
    # legends=None is MolsToGridImage's own default, so a single call covers
    # both the captioned and the uncaptioned case.
    img = Draw.MolsToGridImage(mols, molsPerRow=mols_per_row, subImgSize=(220, 120),
                               useSVG=True, legends=legends)
    if file_path:
        # Inside IPython the grid comes back as a display object whose SVG
        # text is in `.data`; otherwise RDKit returns the SVG text itself.
        svg_text = img.data if hasattr(img, 'data') else img
        with open(file_path, 'w') as f_handle:
            f_handle.write(svg_text)
    return img

In [7]:
def show_each_md(x_reduced, frag_info, file_path=''):
    """
    Draw one scatter panel per column of frag_info, colouring the 2-D points
    by the value found in that column.

    x_reduced: 2 dimensions x with fragment as index, a dataframe whose
               columns are labelled 0 and 1
    frag_info: the number of each MD with fragment as index, a dataframe
    file_path: where to save the figure; when empty the figure is shown
               instead of being saved
    """
    md_names = frag_info.columns.to_list()
    n_cols = 4
    # Keep the original 2 x 4 layout and only add rows when frag_info has
    # more than 8 columns (which used to overflow the axes array).
    n_rows = max(2, -(-len(md_names) // n_cols))
    fig, ax = plt.subplots(n_rows, n_cols, figsize=(24, 6 * n_rows))
    ax = ax.flatten()
    # Align both frames on their shared fragments. Index.intersection is used
    # because recent pandas refuses a Python set as a .loc indexer.
    common_index = x_reduced.index.intersection(frag_info.index)
    x_reduced = x_reduced.loc[common_index, :].copy()
    frag_info = frag_info.loc[common_index, :].copy()
    for i, md in enumerate(md_names):
        current_labels = frag_info.iloc[:, i]
        unique_labels = sorted(current_labels.unique())
        # One shade of blue per distinct value, light to dark.
        cc = sns.color_palette('Blues', len(unique_labels))
        for j, label in enumerate(unique_labels):
            current_nodes = (current_labels == label)
            # A single explicit colour per group: vmin/vmax only apply to
            # colour-mapped data, so they are not passed.
            ax[i].scatter(x_reduced.loc[current_nodes, 0], x_reduced.loc[current_nodes, 1],
                          color=colors.rgb2hex(cc[j]), s=10, label=str(label))
        ax[i].set_title(md, fontsize=12)
        ax[i].legend()
    fig.tight_layout()
    if file_path:
        fig.savefig(file_path, bbox_inches='tight', transparent=True)
    else:
        # The default empty path cannot be saved to; display the figure instead.
        plt.show()
    plt.close(fig)

In [8]:
def reduce_by_tsne(x):
    """
    Project x to 2 dimensions with t-SNE and report the shape and run time.

    x: samples to embed, passed unchanged to TSNE.fit_transform
    return: the 2-D embedding returned by TSNE.fit_transform
    """
    import inspect  # local import: only needed for the signature check below

    # Newer scikit-learn releases renamed TSNE's `n_iter` argument to
    # `max_iter`; pick whichever name the installed version accepts.
    tsne_params = inspect.signature(TSNE).parameters
    iter_kwarg = 'max_iter' if 'max_iter' in tsne_params else 'n_iter'

    t0 = time.perf_counter()  # monotonic clock, suited to measuring durations
    tsne = TSNE(n_components=2, n_jobs=4, learning_rate=200,
                early_exaggeration=20, random_state=42, init='pca', verbose=1,
                **{iter_kwarg: 2000})
    X_reduced_tsne = tsne.fit_transform(x)
    print(X_reduced_tsne.shape)
    t1 = time.perf_counter()
    print("t-SNE took {:.1f}s.".format(t1 - t0))
    return X_reduced_tsne

In [9]:
# Embedding table read from the parallel model's CSV export; the first CSV
# column becomes the index.
mol2vec = pd.read_csv(
    './parallel/all_x_after_trained_parallel_model.csv',
    index_col=0,
)
mol2vec.head(2)

Unnamed: 0_level_0,0.1,1,2,3,4,5,6,7,8,9,...,20,21,22,23,24,25,26,27,28,29
0,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
id6,-1.499594,4.795064,-1.682328,3.634081,2.648775,-1.290213,1.098553,-1.721439,0.142399,-1.757469,...,1.611947,-0.746514,-0.910748,0.334813,-1.757399,1.627908,0.670884,-1.455732,5.117005,-1.61404
id8,-1.740047,8.122423,-1.756263,1.207755,0.971331,-1.728168,4.42106,-1.726883,-1.115949,-1.740837,...,-1.353415,-1.085281,3.865592,-1.224507,-1.75805,-1.690312,-0.845013,-1.326752,11.165923,-1.721184


In [10]:
# Lookup table with a `smiles` column; the first CSV column becomes the index.
cid2smiles = pd.read_csv(
    './parallel/cid2smiles_all_in_train_test.csv',
    index_col=0,
)
cid2smiles.head(2)

Unnamed: 0_level_0,smiles
0,Unnamed: 1_level_1
id6,CC1Oc2ccc(Cl)cc2N(CC(O)CO)C1=O
id8,COc1ccccc1OC(=O)Oc1ccccc1OC


In [22]:
# Fixed list of ten query ids. A random draw that could produce such a list:
# selected_cid = np.random.choice(cid2smiles.index.to_list(), 10, replace=False)
selected_cid = [
    'id1482417',
    'id724907',
    'id349132',
    'id197300',
    'id375924',
    'id50751',
    'id57340',
    'id68181',
    'id96142',
    'id587880',
]

In [23]:
# id -> SMILES mapping for the selected query ids.
selected_cid2smiles = cid2smiles['smiles'].loc[selected_cid].to_dict()
selected_cid2smiles

{'id1482417': 'Cc1cc(NC(=O)c2cnn(CC(C)C)c2C)ccc1-n1cnnn1',
 'id724907': 'Cc1ccc(F)cc1S(=O)(=O)N1CCCC(C(=O)NC2CC2)C1',
 'id349132': 'CC(C)C(C)NC(=O)CCNS(=O)(=O)c1ccc(F)c(F)c1',
 'id197300': 'C#CCn1c(=NC(=O)c2ccno2)sc2c(OC)ccc(OC)c21',
 'id375924': 'N#Cc1cccc(C(=O)Nc2ccn(-c3ccccc3)n2)c1',
 'id50751': 'CNC(=O)NC(C)(C)CCCCC(C)(C)NC(=O)NC',
 'id57340': 'FC(F)(Br)CC(F)(F)Br',
 'id68181': 'CC(=O)OC(CC(=O)C(C)(C)C)C(Cl)(Cl)Cl',
 'id96142': 'CC(=O)NC(NC(C)=O)C(O)C(O)C(O)C(O)CO',
 'id587880': 'CNC(=O)CNC(=O)C(C)NC(=O)OC(C)(C)C'}

In [11]:
# Attach the `smiles` column to every embedding row by matching the two indexes.
mol2vec = pd.merge(
    mol2vec,
    cid2smiles,
    left_index=True,
    right_index=True,
)
mol2vec.head(2)

Unnamed: 0_level_0,0.1,1,2,3,4,5,6,7,8,9,...,21,22,23,24,25,26,27,28,29,smiles
0,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
id6,-1.499594,4.795064,-1.682328,3.634081,2.648775,-1.290213,1.098553,-1.721439,0.142399,-1.757469,...,-0.746514,-0.910748,0.334813,-1.757399,1.627908,0.670884,-1.455732,5.117005,-1.61404,CC1Oc2ccc(Cl)cc2N(CC(O)CO)C1=O
id8,-1.740047,8.122423,-1.756263,1.207755,0.971331,-1.728168,4.42106,-1.726883,-1.115949,-1.740837,...,-1.085281,3.865592,-1.224507,-1.75805,-1.690312,-0.845013,-1.326752,11.165923,-1.721184,COc1ccccc1OC(=O)Oc1ccccc1OC


In [12]:
# Re-key the embedding table by SMILES string. Reassigning instead of using
# inplace=True, plus the guard, makes the cell safe to re-run: once the index
# is already `smiles` the column no longer exists and set_index would raise.
if mol2vec.index.name != 'smiles':
    mol2vec = mol2vec.set_index('smiles')
mol2vec.head(2)

Unnamed: 0_level_0,0.1,1,2,3,4,5,6,7,8,9,...,20,21,22,23,24,25,26,27,28,29
smiles,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
CC1Oc2ccc(Cl)cc2N(CC(O)CO)C1=O,-1.499594,4.795064,-1.682328,3.634081,2.648775,-1.290213,1.098553,-1.721439,0.142399,-1.757469,...,1.611947,-0.746514,-0.910748,0.334813,-1.757399,1.627908,0.670884,-1.455732,5.117005,-1.61404
COc1ccccc1OC(=O)Oc1ccccc1OC,-1.740047,8.122423,-1.756263,1.207755,0.971331,-1.728168,4.42106,-1.726883,-1.115949,-1.740837,...,-1.353415,-1.085281,3.865592,-1.224507,-1.75805,-1.690312,-0.845013,-1.326752,11.165923,-1.721184


In [13]:
# Fragment embedding table; the first CSV column becomes the index.
frag2vec_new = pd.read_csv(
    './parallel/frag_embedding_reg.csv',
    index_col=0,
)
frag2vec_new.head(2)

Unnamed: 0_level_0,0,1,2,3,4,5,6,7,8,9,...,20,21,22,23,24,25,26,27,28,29
fragment,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
CC,-0.558245,-1.438177,0.099167,-1.690964,0.609527,0.275151,-0.131262,-1.386071,-1.124998,-0.88389,...,0.888008,-1.322027,1.935024,2.144108,-0.883948,0.212668,0.572107,0.954083,-0.496624,-1.354179
CN,-0.458151,-1.483617,0.175892,-1.543361,0.74027,-0.030037,0.70227,-1.455249,-1.184582,-1.250804,...,0.650653,-1.315698,1.569044,3.021253,0.22389,0.342611,0.79345,0.123334,-1.163844,-1.464272


#### query with double bond

In [14]:
# Offset vector: embedding('C=S') minus embedding('CS').
double_bond = get_minus_result(
    x_embedding=frag2vec_new,
    x='C=S',
    y='CS',
)

In [24]:
# Nearest neighbours of every selected molecule in the unmodified embedding.
selected_cid2nn = {
    cid: print_closest_words(x_embedding=mol2vec, x_query=smiles)
    for cid, smiles in selected_cid2smiles.items()
}

In [25]:
# One SVG per query: the query molecule first (captioned with its id),
# followed by its neighbours captioned with their distances.
for cid, query_smiles in selected_cid2smiles.items():
    neighbours = selected_cid2nn[cid]
    draw_multiple_mol(
        smiles_list=[query_smiles] + neighbours['smiles'],
        mols_per_row=6,
        file_path='./figures/nn_{}.svg'.format(cid),
        legends=[cid] + ['{:.2f}'.format(d) for d in neighbours['dis']],
    )

In [26]:
# Nearest neighbours of every selected molecule after adding the
# double-bond offset to its embedding.
selected_cid2nn_with_double_bond = {
    cid: print_closest_words(x_embedding=mol2vec, x_query=smiles, add_vec=double_bond)
    for cid, smiles in selected_cid2smiles.items()
}

In [27]:
# One SVG per query: the query molecule first (captioned with its id),
# followed by the neighbours of the shifted embedding with their distances.
for cid, query_smiles in selected_cid2smiles.items():
    neighbours = selected_cid2nn_with_double_bond[cid]
    draw_multiple_mol(
        smiles_list=[query_smiles] + neighbours['smiles'],
        mols_per_row=6,
        file_path='./figures/nn_{}_with_double_bond.svg'.format(cid),
        legends=[cid] + ['{:.2f}'.format(d) for d in neighbours['dis']],
    )

#### triple bond

In [29]:
# Offset vector: embedding('C#C') minus embedding('CC').
tri_bond = get_minus_result(
    x_embedding=frag2vec_new,
    x='C#C',
    y='CC',
)

In [30]:
# Nearest neighbours of every selected molecule after adding the
# triple-bond offset to its embedding.
selected_cid2nn_with_tri_bond = {
    cid: print_closest_words(x_embedding=mol2vec, x_query=smiles, add_vec=tri_bond)
    for cid, smiles in selected_cid2smiles.items()
}

In [31]:
# One SVG per query: the query molecule first (captioned with its id),
# followed by the neighbours of the shifted embedding with their distances.
for cid, query_smiles in selected_cid2smiles.items():
    neighbours = selected_cid2nn_with_tri_bond[cid]
    draw_multiple_mol(
        smiles_list=[query_smiles] + neighbours['smiles'],
        mols_per_row=6,
        file_path='./figures/nn_{}_with_tri_bond.svg'.format(cid),
        legends=[cid] + ['{:.2f}'.format(d) for d in neighbours['dis']],
    )

#### N atom

In [32]:
# Offset vector: embedding('C1CNC1') minus embedding('C1CC1').
n = get_minus_result(
    x_embedding=frag2vec_new,
    x='C1CNC1',
    y='C1CC1',
)

In [33]:
# Nearest neighbours of every selected molecule after adding the
# N-atom offset to its embedding.
selected_cid2nn_with_n = {
    cid: print_closest_words(x_embedding=mol2vec, x_query=smiles, add_vec=n)
    for cid, smiles in selected_cid2smiles.items()
}

In [34]:
# One SVG per query: the query molecule first (captioned with its id),
# followed by the neighbours of the shifted embedding with their distances.
for cid, query_smiles in selected_cid2smiles.items():
    neighbours = selected_cid2nn_with_n[cid]
    draw_multiple_mol(
        smiles_list=[query_smiles] + neighbours['smiles'],
        mols_per_row=6,
        file_path='./figures/nn_{}_with_n.svg'.format(cid),
        legends=[cid] + ['{:.2f}'.format(d) for d in neighbours['dis']],
    )