In [2]:
import pickle
from pprint import pprint
from posebusters import PoseBusters
import numpy as np
import pandas as pd
from rdkit import Chem
import random
import time
from concurrent.futures import ProcessPoolExecutor
import itertools

In [3]:
# Pickled generation results; later cells treat the loaded object as a dict
# of SMILES -> list of molecules.
path = '/nfs/ap/home/_menuab_/code/3DMolGen/gen_results/2025-05-24-12:36_b1_4e_min_p_sampling12/generation_resutls.pickle'
# NOTE: unpickling can execute arbitrary code; only load files from a trusted source.
with open(path, 'rb') as handle:
    data = pickle.load(handle)

In [7]:
# Draw one (SMILES, molecule list) entry at random for a quick sanity check.
entries = list(data.items())
smi, mols = random.choice(entries)
print(f'SMILES: {smi}')
print(f'Number of molecules: {len(mols)}')


SMILES: CCC(=O)Nc1ccc(S(=O)(=O)Nc2c(NC)c3ccccc3oc2=O)cc1
Number of molecules: 188


In [8]:

print(f"Example SMILES: {smi}")
# Inspect the first element (not index 1) so a SMILES with a single conformer
# does not raise IndexError; an empty list reports None.
print(type(mols[0]) if len(mols) else None, len(mols))

# Run the PoseBusters "mol" checks on every molecule of the sampled SMILES and
# average each report column into a single value per check.
buster = PoseBusters(config="mol")
pb_metrix_dict = buster.bust(mols, None, None, full_report=False).mean().to_dict()


Example SMILES: CCC(=O)Nc1ccc(S(=O)(=O)Nc2c(NC)c3ccccc3oc2=O)cc1
<class 'rdkit.Chem.rdchem.Mol'> 188


In [None]:
# Display the per-check means computed for the sampled SMILES.
pb_metrix_dict

{'mol_pred_loaded': 1.0,
 'sanitization': 1.0,
 'inchi_convertible': 1.0,
 'all_atoms_connected': 1.0,
 'bond_lengths': 1.0,
 'bond_angles': 1.0,
 'internal_steric_clash': 0.9871794871794872,
 'aromatic_ring_flatness': 1.0,
 'non-aromatic_ring_non-flatness': 1.0,
 'double_bond_flatness': 1.0,
 'internal_energy': 0.9871794871794872}

In [None]:
def _bust_smi(smi, mols, config, full_report):
    b = PoseBusters(config=config)
    m = b.bust(mols, None, None, full_report=full_report).mean().to_dict()
    m['smiles'] = smi
    return m

def run_all_posebusters(data, config="mol", full_report=False, max_workers=8):
    """Evaluate every SMILES in ``data`` with PoseBusters using a process pool.

    NOTE: this name is defined again further down in this notebook with two
    extra return values; whichever definition was executed last is in effect.

    Args:
        data: dict whose keys are SMILES strings and whose values are the
            molecules to check for that SMILES.
        config: PoseBusters configuration name forwarded to each worker.
        full_report: forwarded to ``PoseBusters.bust`` in each worker.
        max_workers: number of worker processes in the pool.

    Returns:
        ``(df, summary)``: ``df`` has one row of averaged metrics per SMILES;
        ``summary`` is a single-row frame with the mean of the numeric columns
        of ``df`` and ``'ALL'`` in its ``smiles`` column.
    """
    smiles_keys = data.keys()
    mol_lists = data.values()
    with ProcessPoolExecutor(max_workers=max_workers) as pool:
        per_smiles = pool.map(
            _bust_smi,
            smiles_keys,
            mol_lists,
            itertools.repeat(config),
            itertools.repeat(full_report),
        )
        # Drain the iterator while the pool is still open.
        results = list(per_smiles)

    df = pd.DataFrame(results)
    overall = df.mean(numeric_only=True)
    summary = overall.to_frame().T
    summary.insert(0, 'smiles', 'ALL')
    return df, summary


In [None]:
sum = 0
for _, mols in data.items():
    sum += len(mols)
sum

195379

In [None]:
print("Number of molecules in the dataset:", sum)
print("Number of SMILES in the dataset:", len(data))

Number of molecules in the dataset: 195379
Number of SMILES in the dataset: 932


In [None]:
# `*_` absorbs any extra return values, so this cell works with either
# definition of run_all_posebusters in this notebook (2-tuple or 4-tuple).
metrix, summary, *_ = run_all_posebusters(data, config="mol", full_report=False, max_workers=16)



In [None]:
print(f"Total SMILES in dataset: {len(data)}")
smiles_list = list(data.keys())
# NOTE(review): every key is selected here, yet the saved output below reports
# 100 SMILES -- presumably it comes from an earlier run on a slice; re-run to refresh.
subset = {smi: data[smi] for smi in smiles_list}

# perf_counter() is monotonic, so the elapsed time cannot be skewed by
# system clock adjustments during the run.
start = time.perf_counter()
# `*_` absorbs the extra return values of the 4-tuple run_all_posebusters
# defined later in this notebook, while still accepting the 2-tuple version.
df5_metrics, df5_summary, *_ = run_all_posebusters(subset, max_workers=20)
elapsed = time.perf_counter() - start

n_conf = sum(len(mols) for mols in subset.values())
# Avoid ZeroDivisionError when the subset holds no conformers.
time_per_conf = elapsed / n_conf if n_conf else float('nan')

print(f"Processed {len(subset)} SMILES with {n_conf} conformers")
print(f"Total time: {elapsed:.2f}s, time per conformer: {time_per_conf:.3f}s")

df5_summary

Total SMILES in dataset: 932
Processed 100 SMILES with 5912 conformers
Total time: 59.97s, time per conformer: 0.010s
Processed 100 SMILES with 5912 conformers
Total time: 59.97s, time per conformer: 0.010s


Unnamed: 0,smiles,mol_pred_loaded,sanitization,inchi_convertible,all_atoms_connected,bond_lengths,bond_angles,internal_steric_clash,aromatic_ring_flatness,non-aromatic_ring_non-flatness,double_bond_flatness,internal_energy
0,ALL,1.0,1.0,1.0,1.0,0.906403,0.915126,0.949998,0.917861,0.99,1.0,0.89457


In [6]:
# Benchmark different worker counts with random sampling from the data
sample_size = 10  # adjust sample size as needed
# Clamp to the dataset size: random.sample raises ValueError when asked for
# more items than the population holds.
sample_keys = random.sample(list(data.keys()), min(sample_size, len(data)))
subset_random = {smi: data[smi] for smi in sample_keys}

workers = [1, 2, 4, 8, 12]
benchmark = []
n_conf = sum(len(mols) for mols in subset_random.values())

for w in workers:
    print(f"Running benchmark with {w} workers...")
    print(f"Total conformers in subset: {n_conf}")
    start = time.perf_counter()
    # Only the wall-clock time matters here. The return value is discarded
    # instead of being unpacked into a fixed number of names, so the loop works
    # with both definitions of run_all_posebusters (2-tuple and 4-tuple).
    run_all_posebusters(subset_random, max_workers=w)
    elapsed = time.perf_counter() - start
    benchmark.append({
        'workers': w,
        'total_time_s': elapsed,
        # NaN instead of ZeroDivisionError when the sample has no conformers.
        'time_per_conf_s': elapsed / n_conf if n_conf else float('nan')
    })

df_bench_random = pd.DataFrame(benchmark)
display(df_bench_random)

Running benchmark with 1 workers...
Total conformers in subset: 2923
Running benchmark with 2 workers...
Total conformers in subset: 2923
Running benchmark with 2 workers...
Total conformers in subset: 2923
Running benchmark with 4 workers...
Total conformers in subset: 2923
Running benchmark with 4 workers...
Total conformers in subset: 2923
Running benchmark with 8 workers...
Total conformers in subset: 2923
Running benchmark with 8 workers...
Total conformers in subset: 2923
Running benchmark with 12 workers...
Total conformers in subset: 2923
Running benchmark with 12 workers...
Total conformers in subset: 2923


Running benchmark with 1 workers...
Total conformers in subset: 2923
Running benchmark with 2 workers...
Total conformers in subset: 2923
Running benchmark with 2 workers...
Total conformers in subset: 2923
Running benchmark with 4 workers...
Total conformers in subset: 2923
Running benchmark with 4 workers...
Total conformers in subset: 2923
Running benchmark with 8 workers...
Total conformers in subset: 2923
Running benchmark with 8 workers...
Total conformers in subset: 2923
Running benchmark with 12 workers...
Total conformers in subset: 2923
Running benchmark with 12 workers...
Total conformers in subset: 2923


Unnamed: 0,workers,total_time_s,time_per_conf_s
0,1,248.924565,0.085161
1,2,133.373335,0.045629
2,4,69.544352,0.023792
3,8,68.606421,0.023471
4,12,69.655117,0.02383


In [4]:
def _bust_smi(smi, mols, config, full_report):
    try:
        b = PoseBusters(config=config)
        dfb = b.bust(mols, None, None, full_report=full_report)
        m = dfb.mean().to_dict()
        m['pass_percentage'] = dfb.all(axis=1).mean() * 100
        m['smiles'] = smi
        m['error'] = ''
        return m
    except Exception as e:
        return {'smiles': smi, 'error': str(e)}

def run_all_posebusters(data, config="mol", full_report=False,
                        max_workers=16, fail_threshold=0.0):
    """Run PoseBusters on every SMILES in ``data`` in parallel and summarise.

    Args:
        data: dict whose keys are SMILES strings and whose values are the
            molecules to check for that SMILES.
        config: PoseBusters configuration name forwarded to each worker.
        full_report: forwarded to ``PoseBusters.bust`` in each worker.
        max_workers: number of worker processes in the pool.
        fail_threshold: a SMILES is reported in ``fail_smiles`` when its
            failure rate is strictly greater than this value. When the rate is
            derived from ``pass_percentage`` it is a fraction in [0, 1].

    Returns:
        ``(df, summary, fail_smiles, error_smiles)``:
        ``df`` has one row per SMILES (metrics, ``smiles``, ``error``);
        ``summary`` is a single-row frame with the mean of the numeric columns
        over the rows that did not error, and ``'ALL'`` as ``smiles``;
        ``fail_smiles`` lists SMILES whose failure rate exceeds
        ``fail_threshold``; ``error_smiles`` lists SMILES whose worker
        reported an error.
    """
    with ProcessPoolExecutor(max_workers=max_workers) as executor:
        results = list(executor.map(
            _bust_smi,
            data.keys(),
            data.values(),
            itertools.repeat(config),
            itertools.repeat(full_report),
        ))

    if results:
        df = pd.DataFrame(results)
    else:
        # Empty input: keep the columns used below so indexing does not raise KeyError.
        df = pd.DataFrame(columns=['smiles', 'error'])

    succeeded = df['error'] == ''
    error_smiles = df.loc[~succeeded, 'smiles'].tolist()

    # Use an explicit 'failure_rate' column when the workers provide one;
    # otherwise derive it from 'pass_percentage' (0-100). Without this
    # fallback fail_smiles would always be empty, because the worker defined
    # alongside this function only emits 'pass_percentage'.
    if 'failure_rate' in df.columns:
        failure_rate = df['failure_rate']
    elif 'pass_percentage' in df.columns:
        failure_rate = 1.0 - df['pass_percentage'] / 100.0
    else:
        failure_rate = None

    if failure_rate is None:
        fail_smiles = []
    else:
        # Rows that errored have NaN here; NaN > threshold is False, so they
        # are reported through error_smiles only.
        fail_smiles = df.loc[failure_rate > fail_threshold, 'smiles'].tolist()

    summary = df[succeeded].mean(numeric_only=True).to_frame().T
    summary.insert(0, 'smiles', 'ALL')
    return df, summary, fail_smiles, error_smiles

def load_data(path):
    """Unpickle and return the object stored in the file at ``path``.

    NOTE: unpickling can execute arbitrary code; only use on trusted files.
    """
    with open(path, 'rb') as handle:
        return pickle.load(handle)