In [2]:
import pandas as pd
import os
import faiss
import json
import copy
import numpy as np
import pickle

from rdkit import Chem
from rdkit.Chem import AllChem, rdShapeHelpers, DataStructs

import torch
from torch_geometric.data import Data

In [None]:
# Gather every molecule's descriptor.json under folder_path into one
# DataFrame (one row per molecule folder).
folder_path = "data/qm9_data"
records = []
# os.listdir() returns entries in arbitrary order; sorting makes the row
# positions (used later as neighbour ids) reproducible across runs/machines.
for i, entry in enumerate(sorted(os.listdir(folder_path))):
    descriptor_path = os.path.join(folder_path, entry, "descriptor.json")
    # Skip stray files (e.g. .DS_Store) and folders without a descriptor
    # instead of aborting the whole scan.
    if not os.path.isfile(descriptor_path):
        continue
    with open(descriptor_path, "r") as f:
        records.append(json.load(f))
    if i % 1000 == 0: print(i)
# Build the frame once from the list of dicts (no per-row Series needed).
df = pd.DataFrame(records)

In [None]:
# Cache the assembled descriptor table so the folder scan does not have to be repeated.
descriptor_cache = "data/descriptors.parquet"
df.to_parquet(descriptor_cache)

In [None]:
# Reload the cached descriptor table.
descriptor_cache = "data/descriptors.parquet"
df = pd.read_parquet(descriptor_cache)

In [None]:
# SCALE DATA
# Z-score every descriptor column in place; "SMILES" and "file_path" are
# identifiers and are left untouched.
for col in df.columns.drop(["SMILES", "file_path"]):
    col_std = df[col].std()
    # A constant column has std 0 (NaN for a single row); dividing by it would
    # fill the column with NaN and corrupt the later distance search.
    # Dividing by 1 instead turns such a column into all zeros.
    if not col_std > 0:
        col_std = 1.0
    df[col] = (df[col] - df[col].mean()) / col_std

In [None]:
# Work on an independent copy of df so the weighting below never touches it,
# and keep only the numeric descriptor columns.
features = copy.deepcopy(df).drop(["SMILES", "file_path"], axis=1)

# Stretch selected descriptors so they weigh more in the distance search.
for column, weight in {
    "MolWt": 8,
    "LogP": 4,
    "HOMO": 2,
    "LUMO": 2,
    "HDonors": 2,
    "HAcceptors": 2,
}.items():
    features[column] *= weight

In [None]:
# FAISS works on C-contiguous float32 matrices; convert the DataFrame
# explicitly instead of relying on the wrapper to coerce it.
feature_matrix = np.ascontiguousarray(features.to_numpy(), dtype=np.float32)

# Exact (brute-force) L2 index over the weighted descriptor vectors.
index = faiss.IndexFlatL2(feature_matrix.shape[1])
index.add(feature_matrix)

# Never request more neighbours than there are vectors: FAISS pads missing
# results with id -1, which would silently select the last row via df.iloc.
n_neighbours = min(100, feature_matrix.shape[0])
distances, indices = index.search(feature_matrix, n_neighbours)

# Turn distances into a similarity in [0, 1]: 1 at distance 0, 0 at or beyond
# the 99th percentile of all returned distances. From here on `distances`
# holds that similarity, not a distance.
max_val = np.percentile(distances, 99)
distances = np.clip((1 - (distances / max_val)), 0, 1)

In [None]:
# Persist the neighbour ids and their similarity scores under data/
# (they are read back later as indices.npy / distances.npy).
for stem, array in (("indices", indices), ("distances", distances)):
    np.save(os.path.join("data", stem), array)

In [None]:
# Reload the cached neighbour search results from data/.
indices = np.load(os.path.join("data", "indices.npy"))
distances = np.load(os.path.join("data", "distances.npy"))

In [None]:
def calculate_similarity(mol1, mol2):
    """Return ``(similarity_2D, similarity_3D)`` for a pair of RDKit molecules.

    Both values are similarities: larger means more alike.

    Args:
        mol1: First RDKit molecule.
        mol2: Second RDKit molecule.

    Returns:
        tuple: ``(similarity_2D, similarity_3D)`` where ``similarity_2D`` is
        the Tanimoto similarity of the radius-2, 2048-bit Morgan fingerprints
        and ``similarity_3D`` is ``1 - ShapeTanimotoDist(mol1, mol2)``.
    """
    # TANIMOTO SIMILARITY
    fp1 = AllChem.GetMorganFingerprintAsBitVect(mol1, radius=2, nBits=2048)
    fp2 = AllChem.GetMorganFingerprintAsBitVect(mol2, radius=2, nBits=2048)
    similarity_2D = DataStructs.TanimotoSimilarity(fp1, fp2)

    # 3D SIMILARITY
    # ShapeTanimotoDist is a *distance* (0 = identical shapes). Callers add
    # this value to other similarity scores, so convert it to a similarity;
    # returning the raw distance would reward dissimilar shapes.
    # NOTE(review): no alignment is done here, so the shapes are compared in
    # whatever coordinates the molecules already have — confirm that is intended.
    similarity_3D = 1.0 - rdShapeHelpers.ShapeTanimotoDist(mol1, mol2)

    return similarity_2D, similarity_3D

In [None]:
def find_most_similar_mol(df, distances, indices, top_k=5):
    """Rank one molecule's neighbour candidates and return the best file paths.

    ``indices[0]`` is the query molecule and ``indices[1:]`` are the
    candidates. Each candidate gets three scores: the 2D and 3D values
    returned by ``calculate_similarity`` and its entry in ``distances``.
    Each score is divided by its maximum over the candidates, the three are
    summed, and the candidates with the highest sums are returned.

    Args:
        df: DataFrame with a ``"file_path"`` column holding mol-file paths,
            addressed positionally by the values in ``indices``.
        distances: Per-neighbour descriptor scores aligned with ``indices``;
            entry 0 (the query itself) is ignored. The values are added to
            the similarity scores, so higher must mean closer.
        indices: Positional row ids into ``df``; slot 0 is the query.
        top_k: Maximum number of paths to return (default 5).

    Returns:
        list: Up to ``top_k`` ``file_path`` values, best first. Empty when no
        candidate could be scored.

    Raises:
        ValueError: If the query molecule's mol file cannot be parsed.
    """
    query_path = df.iloc[indices[0]]["file_path"]
    query_mol = Chem.MolFromMolFile(query_path, removeHs=False)
    if query_mol is None:
        raise ValueError(f"Could not parse query molecule file: {query_path}")

    candidate_paths = []
    descriptor_scores = []
    scores_2d = []
    scores_3d = []
    for idx, descriptor_score in zip(indices[1:], distances[1:]):
        # A negative id is a "no result" placeholder; iloc would silently map
        # it to a row counted from the end of the frame.
        if idx < 0:
            continue
        candidate_path = df.iloc[idx]["file_path"]
        candidate_mol = Chem.MolFromMolFile(candidate_path, removeHs=False)
        # MolFromMolFile returns None for unreadable files; skip those
        # candidates rather than crashing inside the fingerprint code.
        if candidate_mol is None:
            continue
        score_2d, score_3d = calculate_similarity(query_mol, candidate_mol)
        candidate_paths.append(candidate_path)
        descriptor_scores.append(descriptor_score)
        scores_2d.append(score_2d)
        scores_3d.append(score_3d)

    if not candidate_paths:
        return []

    descriptor_scores = np.array(descriptor_scores)
    scores_2d = np.array(scores_2d)
    scores_3d = np.array(scores_3d)

    # Put the three scores on a comparable scale (each divided by its own
    # maximum); the epsilon avoids 0/0 when a score is zero for all candidates.
    descriptor_scores = descriptor_scores / (descriptor_scores.max() + 1e-10)
    scores_2d = scores_2d / (scores_2d.max() + 1e-10)
    scores_3d = scores_3d / (scores_3d.max() + 1e-10)

    scores = scores_2d + scores_3d + descriptor_scores

    # Highest combined score first.
    best = np.argsort(scores)[::-1][:top_k]
    return [candidate_paths[i] for i in best]

In [None]:
# Map every molecule's file path to the file paths of its best-ranked neighbours.
final_dict = {}

for i in range(len(indices)):
    if i % 100 == 0:
        print(i)

    # find_most_similar_mol treats slot 0 as the query. The nearest-neighbour
    # search does not guarantee that row i comes back first (e.g. molecules
    # with identical descriptor vectors), so place i there explicitly and
    # remove it from the candidates.
    not_self = indices[i] != i
    idx = np.concatenate(([i], indices[i][not_self]))
    # The slot-0 score is a placeholder; find_most_similar_mol ignores it.
    dist = np.concatenate(([1.0], distances[i][not_self]))
    res = find_most_similar_mol(df, dist, idx)

    # Key by the molecule actually being processed, not by the first hit.
    file_path = df.iloc[i]["file_path"]
    final_dict[file_path] = res

# Report a summary instead of dumping the whole mapping into the cell output.
print(f"{len(final_dict)} molecules ranked")

In [None]:
# Serialise the molecule -> similar-molecules mapping as JSON.
with open('data/similar_mol.json', "w") as out_file:
    out_file.write(json.dumps(final_dict))

In [None]:
# Read the molecule -> similar-molecules mapping back from JSON.
with open("data/similar_mol.json", "r") as in_file:
    similar_mol_dict = json.loads(in_file.read())

In [None]:
def read_graph(mol_path, device=None):
    """Load one molecule folder as a ``torch_geometric`` ``Data`` graph.

    Reads ``nodes.parquet`` (one row per node) and ``edges.npy`` (one row per
    edge: source, target, then the edge attributes) from ``mol_path``. Every
    edge is stored in both directions with the same attributes.

    Args:
        mol_path: Folder containing ``nodes.parquet`` and ``edges.npy``.
        device: Device for the tensors. ``None`` (default) uses CUDA when it
            is available and falls back to the CPU otherwise.

    Returns:
        Data: Graph with ``x``, ``edge_index`` and ``edge_attr`` tensors.
    """
    if device is None:
        # An unconditional .cuda() fails on machines without a GPU.
        device = "cuda" if torch.cuda.is_available() else "cpu"

    node_df = pd.read_parquet(os.path.join(mol_path, "nodes.parquet"))
    # NOTE(security): allow_pickle=True can execute arbitrary code while
    # loading; only use it on edge files produced by this project.
    edge_matrix = np.load(os.path.join(mol_path, "edges.npy"), allow_pickle=True)

    x = torch.tensor(node_df.to_numpy(), dtype=torch.float, device=device)

    edge_index = []
    edge_attr = []
    for edge in edge_matrix:
        source, target, attributes = edge[0], edge[1], edge[2:]
        # Undirected bond -> two directed edges sharing the same attributes.
        edge_index.extend(([source, target], [target, source]))
        edge_attr.extend((attributes, attributes))

    # NOTE(review): this tensor is laid out as one (source, target) pair per
    # row. PyTorch Geometric layers expect shape [2, num_edges]; the layout is
    # left unchanged here — confirm whether consumers transpose it.
    edge_index = torch.tensor(edge_index, dtype=torch.long, device=device)
    # Stack into one ndarray first: building a tensor straight from a list of
    # ndarrays is very slow.
    edge_attr = torch.tensor(
        np.asarray(edge_attr, dtype=np.float32), dtype=torch.float, device=device
    )

    return Data(x=x, edge_index=edge_index, edge_attr=edge_attr)

In [None]:
# Build one graph per molecule listed in similar_mol_dict and pickle the lot.
final_df_dict = {}
for i, key in enumerate(similar_mol_dict):
    # Stripping "molecule.mol" from the key leaves the folder passed to read_graph.
    final_df_dict[key] = read_graph(key.replace("molecule.mol", ""))
    # Report progress every 1000 molecules instead of flooding the output.
    if i % 1000 == 0: print(i)

# os.path.join instead of a hard-coded "\\": on POSIX the backslash is not a
# separator, so the file would be created in the working directory under the
# literal name "data\final_df_dict.pkl".
# NOTE(review): graphs whose tensors live on a GPU will likely need a GPU to
# be unpickled — confirm that is acceptable for consumers of this file.
with open(os.path.join("data", "final_df_dict.pkl"), "wb") as f:
    pickle.dump(final_df_dict, f)