In [None]:
import pandas as pd
from hestia.similarity import *
from hestia.partition import ccpart
from hestia import HestiaGenerator
from hestia.clustering import _connected_components_clustering
from tqdm import tqdm

# Names used later in the notebook that were otherwise only reachable through
# the star import above; imported after it so the star import cannot shadow them.
from concurrent.futures import ThreadPoolExecutor
from typing import List

import numpy as np


In [None]:
# Load the raw table; later cells read its `seq` and `smiles` columns.
df = pd.read_csv("data/plumber.csv", sep=",")
df.head()

In [None]:
# Entity 1 is keyed by the sequence column, entity 2 by the SMILES column.
# The column names are defined once and reused in the similarity arguments
# below so the two cannot drift out of sync.
df1_field = 'seq'
df2_field = 'smiles'

# Similarity backends (provided by the `hestia.similarity` star import).
sim_fun_ent1 = sequence_similarity_mmseqs
sim_fun_ent2 = molecular_similarity

sim_args_ent1 = {
    "field_name": df1_field,
    # Similarity cutoff passed to the sequence-similarity backend.
    "threshold": 0.3,
    "verbose": 3
}
sim_args_ent2 = {
    "field_name": df2_field,
    "fingerprint": "ecfp",
    "radius": 2,
    # Read by the molecule comparison cell as the Tanimoto cutoff.
    "threshold": 0.3,
    "verbose": 3,
    "bits": 1024
}

In [None]:
# Partition the *unique* sequences so identical sequences can never end up on
# both sides of the split.
unique_df1 = df.drop_duplicates(df1_field).reset_index(drop=True)
# Sequence value -> array of `df` row labels carrying that sequence.
# `sort=False` keeps groups in first-appearance order, i.e. the same order as
# `unique_df1` (assuming the column has no missing values, which groupby drops).
df1_to_df = df.groupby(df1_field, sort=False).apply(lambda g: g.index.to_numpy())

sim_df_1 = sim_fun_ent1(df_query=unique_df1, **sim_args_ent1)
# Partition with the same cutoff used for the similarity computation rather
# than a second, independently hard-coded value.
train, test, clusters = ccpart(
    df=unique_df1, sim_df=sim_df_1, threshold=sim_args_ent1["threshold"], verbose=True
)


In [None]:
# Expand the unique-sequence split back to every row of `df`.
# `train` / `test` hold row positions of `unique_df1` (equal to its labels, since
# it has a fresh RangeIndex). `df1_to_df` is indexed by *sequence value*, so each
# position is first translated to its sequence and then looked up by key.
# Indexing `df1_to_df` with the raw integer would silently fall back to
# positional access, which only matches if both share the same row order.
unique_seqs = unique_df1[df1_field]
train_indices, test_indices = [], []
for indx in test:
    test_indices.extend(df1_to_df.loc[unique_seqs.iloc[indx]])
for indx in train:
    train_indices.extend(df1_to_df.loc[unique_seqs.iloc[indx]])

# `reset_index()` keeps the original `df` row label in an `index` column.
test_df, train_df = df.iloc[test_indices].reset_index(), df.iloc[train_indices].reset_index()
# Sanity check: the split must not share any sequence.
assert not set(test_df[df1_field]) & set(train_df[df1_field]), "train/test share sequences"

# One row per distinct molecule on each side.
u_test = test_df.drop_duplicates(df2_field).reset_index(drop=True)
u_train = train_df.drop_duplicates(df2_field).reset_index(drop=True)

u_test_to_df = test_df.groupby(df2_field).apply(lambda g: g.index.to_numpy())

unique_test_mols, unique_train_mols = u_test[df2_field], u_train[df2_field]
print(
    f"Unique molecules (millions) - test: {len(unique_test_mols) / 1e6:.4f}, "
    f"train: {len(unique_train_mols) / 1e6:.4f}"
)



In [None]:
# RDKit is required for the molecule-level split below. If it is missing, fail
# with a clear message that stays chained to the original import error.
try:
    from rdkit import Chem
    from rdkit.Chem import rdFingerprintGenerator, rdMolDescriptors
    from rdkit.DataStructs import (
        BulkTanimotoSimilarity, BulkDiceSimilarity,
        BulkSokalSimilarity, BulkRogotGoldbergSimilarity,
        BulkCosineSimilarity)
    from rdkit import RDLogger
    from rdkit import rdBase

    def disable_rdkit_log():
        """Disable all rdkit logs.

        Iterates the log levels listed in ``RDLogger._levels`` (a private RDKit
        attribute) and disables each via ``rdBase.DisableLog``, presumably to
        keep SMILES-parsing messages out of the notebook output.
        """
        for log_level in RDLogger._levels:
            rdBase.DisableLog(log_level)

    disable_rdkit_log()

except ModuleNotFoundError as err:
    raise ImportError("This function requires RDKit to be installed.") from err

# Morgan fingerprint generator. The radius and bit length come from
# `sim_args_ent2` so the molecule settings are configured in one place.
radius, bits = sim_args_ent2["radius"], sim_args_ent2["bits"]
fpgen = rdFingerprintGenerator.GetMorganGenerator(
    radius=radius, fpSize=bits
)
# Informational only: the comparison loop calls BulkTanimotoSimilarity directly.
sim_function = 'tanimoto'

def _get_fp(smile: str):
    """Return the fingerprint of a SMILES string, using the module-level ``fpgen``.

    If RDKit cannot parse the string, the first and last characters are
    stripped and parsing is retried, repeating until it succeeds. Each
    substitution is printed. Runs iteratively, so very long unparsable strings
    cannot exhaust the recursion limit.

    NOTE(review): trimming characters changes the molecule; this looks intended
    for strings wrapped in stray quote/bracket characters. Confirm.

    Raises:
        ValueError: if the string is trimmed down to empty and still fails to parse.
    """
    current = smile
    mol = Chem.MolFromSmiles(current, sanitize=True)
    while mol is None:
        if not current:
            # Trimming an empty string is a no-op; stop instead of looping forever.
            raise ValueError(f"SMILES `{smile}` could not be processed, even after trimming.")
        fallback = current[1:-1]
        print(f"SMILES: `{current}` could not be processed. Will be substituted by `{fallback}`")
        current = fallback
        mol = Chem.MolFromSmiles(current, sanitize=True)

    return fpgen.GetFingerprint(mol)

def _parallel_fps(mols: List[str], mssg: str, max_workers: int = 10) -> list:
    """Compute ``_get_fp`` for many SMILES on a thread pool, preserving input order.

    Args:
        mols: SMILES strings (any iterable, e.g. a pandas Series).
        mssg: Label shown on the tqdm progress bar.
        max_workers: Thread-pool size (defaults to the previously hard-coded 10).

    Returns:
        List of fingerprints in the same order as ``mols``.

    Raises:
        RuntimeError: wrapping, and chained to, the exception raised by the
            first failing job in input order.
    """
    fps = []
    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        # Submit everything up front, then collect in submission order so the
        # result stays aligned with `mols`.
        jobs = [executor.submit(_get_fp, mol) for mol in mols]
        pbar = tqdm(jobs, desc=mssg, unit_scale=True,
                    mininterval=0.5, maxinterval=2)
        try:
            for job in pbar:
                exc = job.exception()  # blocks until this job has finished
                if exc is not None:
                    raise RuntimeError(exc) from exc
                fps.append(job.result())
        finally:
            # Close the bar even when a job fails.
            pbar.close()
    return fps

# Fingerprint each unique test molecule once; the comparison loop reuses these.
test_fps = _parallel_fps(mols=unique_test_mols, mssg="Test mols")

In [None]:
from itertools import islice
from rdkit.DataStructs import ConvertToNumpyArray
def batched(iterable, n, *, strict=False):
    """Yield consecutive tuples of ``n`` items from ``iterable``.

    Back-port of ``itertools.batched``: e.g. ``batched('ABCDEFG', 3)`` yields
    ``('A','B','C')``, ``('D','E','F')``, ``('G',)``. The last tuple may be
    shorter unless ``strict`` is true, in which case a short final tuple raises
    ``ValueError``. ``n < 1`` also raises ``ValueError``, on first iteration.
    """
    if n < 1:
        raise ValueError('n must be at least one')
    source = iter(iterable)
    while True:
        chunk = tuple(islice(source, n))
        if not chunk:
            return
        if strict and len(chunk) != n:
            raise ValueError('batched(): incomplete batch')
        yield chunk

def compare(mol, test_fps):
    """Tanimoto similarities between one SMILES string and a list of fingerprints.

    ``mol`` is fingerprinted with ``_get_fp``, and the list returned by
    ``BulkTanimotoSimilarity`` is passed straight through (one score per entry
    of ``test_fps``). Not called elsewhere in the visible notebook.
    """
    return BulkTanimotoSimilarity(_get_fp(mol), test_fps)

# Find every test molecule whose Tanimoto similarity to at least one train
# molecule is >= threshold; these SMILES are collected in `all_results`.
# A test molecule only needs to be matched once, so matched molecules are
# dropped from the candidate pool and later train molecules compare against
# fewer fingerprints.
import copy

all_results = set()
threshold = sim_args_ent2['threshold']

test_size = len(unique_test_mols)
pbar = tqdm(unique_train_mols, unit_scale=True)
# Not-yet-matched test fingerprints and their SMILES, kept index-aligned.
# They are rebuilt (never mutated in place) each iteration, so shallow copies suffice.
tmp_fps = list(test_fps)
tmp_u_test = list(unique_test_mols)

for idx, mol in enumerate(pbar):
    fp = _get_fp(mol)
    sims = BulkTanimotoSimilarity(fp, tmp_fps)

    # Single pass: record matches, keep the rest for the next train molecule.
    kept_fps, kept_mols = [], []
    for test_mol, test_fp, sim in zip(tmp_u_test, tmp_fps, sims):
        if sim >= threshold:
            all_results.add(test_mol)
        else:
            kept_fps.append(test_fp)
            kept_mols.append(test_mol)
    tmp_fps, tmp_u_test = kept_fps, kept_mols

    if idx % 100 == 0:
        pbar.set_description(f"Include: {len(tmp_fps):,} / {test_size:,}")
    # Every test molecule has been matched; further comparisons cannot add anything.
    if not tmp_fps:
        pbar.set_description(f"Include: {len(tmp_fps):,} / {test_size:,}")
        print("Saturation")
        break
pbar.close()

In [None]:
# Split the row-level test set by molecule: "loose" rows carry a molecule found
# similar to some train molecule (collected in `all_results`), "strict" rows do not.
is_similar_to_train = test_df[df2_field].isin(all_results)
loose_test = test_df[is_similar_to_train]
strict_test = test_df[~is_similar_to_train]
if len(test_df):
    print(
        f"loose fraction: {len(loose_test) / len(test_df):.4f}, "
        f"strict fraction: {len(strict_test) / len(test_df):.4f}"
    )
else:
    print("Test set is empty; nothing to split.")