In [1]:
import os
import sys
import pandas as pd
import numpy as np
import random
from tqdm import tqdm

In [2]:
from rdkit import Chem, DataStructs
from rdkit.Chem import AllChem

In [3]:
PROPERTY_NAME = "abcg2"

df_train_pairs = pd.read_csv(os.path.join(PROPERTY_NAME, "rdkit_train_pairs.txt"), sep=" ", header=None); print(df_train_pairs.shape)
list_smi_src = df_train_pairs.iloc[:,0].values.tolist()
list_smi_tar = df_train_pairs.iloc[:,1].values.tolist()

df_valid = pd.read_csv(os.path.join(PROPERTY_NAME, "rdkit_valid.txt"), header=None); print(df_valid.shape)
list_smi_valid = df_valid.iloc[:,0].values.tolist()

df_test = pd.read_csv(os.path.join(PROPERTY_NAME, "rdkit_test.txt"), header=None); print(df_test.shape)
list_smi_test = df_test.iloc[:,0].values.tolist()

(230619, 2)
(500, 1)
(1000, 1)


In [4]:
list_smi_unique = list_smi_src + list_smi_tar
print(len(list_smi_unique))

461238


In [5]:
list_triplet = []
K = 20

for i, (smi_src, smi_tar) in tqdm(enumerate(zip(list_smi_src, list_smi_tar)), total=len(list_smi_src)):
    batch_list_triplet = []
    ## mol
    mol_src = Chem.MolFromSmiles(smi_src)
    mol_tar = Chem.MolFromSmiles(smi_tar)
    ## fingerprint
    fpt_src = AllChem.GetMorganFingerprintAsBitVect(mol_src, radius=2, nBits=2048, useChirality=False)
    fpt_tar = AllChem.GetMorganFingerprintAsBitVect(mol_tar, radius=2, nBits=2048, useChirality=False)
    
    ## Shuffle
    random.shuffle(list_smi_unique)
    
    for smi_neg in list_smi_unique:
        mol_neg = Chem.MolFromSmiles(smi_neg)
        fpt_neg = AllChem.GetMorganFingerprintAsBitVect(mol_neg, radius=2, nBits=2048, useChirality=False)
        ## Tanimoto
        sim_src = DataStructs.TanimotoSimilarity(fpt_src, fpt_neg)
        sim_tar = DataStructs.TanimotoSimilarity(fpt_tar, fpt_neg)
        ## check
        if sim_src < 0.3 and sim_tar < 0.3:
            batch_list_triplet.append((smi_src, smi_tar, smi_neg))
        ## stop
        if len(batch_list_triplet) == K:
            break
            
    if len(batch_list_triplet) < K:
        print(f"[WARNING] {i} has insufficient data ({len(batch_list_triplet)} < {K})")
        
    list_triplet.extend(batch_list_triplet)

  2%|▏         | 5759/230619 [12:01<7:49:41,  7.98it/s]


KeyboardInterrupt: 

In [None]:
df_triplet = pd.DataFrame(list_triplet)

df_triplet.to_csv(os.path.join(PROPERTY_NAME, "rdkit_train_triplet.txt"), sep=" ", header=None, index=False)