In [5]:
import argparse
import os
import re
import sys
from collections import Counter
from glob import glob

import numpy as np
import torch
from rdkit import Chem
from rdkit import RDLogger
from tqdm.auto import tqdm

# Make the repository root importable for the project-local `utils` package.
sys.path.append('../')
# from utils.evaluation import eval_atom_type, scoring_func, analyze, eval_bond_length
from utils import misc#, reconstruct, transforms
# from utils.evaluation.docking_qvina import QVinaDockingTask
# from utils.evaluation.docking_vina import VinaDockingTask


def print_dict(d, logger, fmt='.4f'):
    """Log every ``key: value`` pair of ``d`` on its own line via ``logger.info``.

    Args:
        d: Mapping of names to values to report.
        logger: Any object with an ``info(msg)`` method (e.g. a ``logging.Logger``).
        fmt: Format spec applied to each non-None value (default ``'.4f'``).

    ``None`` values are logged literally as ``None``. Values that do not accept
    ``fmt`` (strings, arrays, arbitrary objects) are logged with ``str(value)``
    instead of aborting the whole dump with an exception.
    """
    for k, v in d.items():
        if v is None:
            text = 'None'
        else:
            try:
                text = format(v, fmt)
            except (TypeError, ValueError):
                # format('abc', '.4f') -> ValueError; objects without numeric
                # formatting support (e.g. arrays) -> TypeError.
                text = str(v)
        logger.info(f'{k}:\t{text}')


def print_ring_ratio(all_ring_sizes, logger, ring_sizes=range(3, 10)):
    """Log, for each ring size, the fraction of molecules containing such a ring.

    Args:
        all_ring_sizes: Sized, re-iterable collection with one container per
            molecule; ``size in container`` must report whether that molecule has
            a ring of ``size`` atoms (presumably a ``collections.Counter`` of ring
            sizes -- TODO confirm against the caller).
        logger: Any object with an ``info(msg)`` method.
        ring_sizes: Ring sizes to report (default 3..9 inclusive, as before).

    An empty ``all_ring_sizes`` logs a single notice instead of raising
    ZeroDivisionError.
    """
    n_total = len(all_ring_sizes)
    if n_total == 0:
        logger.info('ring size ratio: no molecules to evaluate')
        return
    for ring_size in ring_sizes:
        n_mol = sum(1 for counter in all_ring_sizes if ring_size in counter)
        logger.info(f'ring size: {ring_size} ratio: {n_mol / n_total:.3f}')


In [6]:
class Config:
    """Notebook stand-in for parsed command-line ``args`` of the evaluation run.

    All settings are class attributes, so they read the same from ``Config``
    itself and from the ``args`` instance below.
    """

    # --- input / output locations ---
    # Directory globbed for *.pt sampling results; an eval_results/ folder is created inside it.
    sample_path: str = '../sampling_results/official/'
    protein_root: str = '../data/test_set'
    save: bool = True

    # --- evaluation scope / logging ---
    eval_step: int = -1
    eval_num_examples = None  # None -> keep every results file; an int keeps only the first N
    verbose: bool = False  # False disables RDKit's 'rdApp.*' log output

    # --- chemistry / docking ---
    atom_enc_mode: str = 'add_aromatic'
    docking_mode: str = 'vina_score'  # the metrics cell branches on 'qvina', 'vina_score', 'vina_dock'
    exhaustiveness: int = 16


args = Config()

In [7]:
# Set up the evaluation output directory and logger, then collect the sampled result files.
result_path = os.path.join(args.sample_path, 'eval_results')
os.makedirs(result_path, exist_ok=True)
logger = misc.get_logger('evaluate', log_dir=result_path)
if not args.verbose:
    RDLogger.DisableLog('rdApp.*')


def _natural_sort_key(path):
    """Sort key ordering embedded integers numerically ('x_2.pt' before 'x_10.pt')."""
    # re.split with a capture group alternates text chunks (even positions) and digit runs
    # (odd positions), so ints and strs always line up and the lists compare without TypeError.
    # Names without digits simply sort as plain text.
    chunks = re.split(r'(\d+)', os.path.basename(path))
    parts = [int(tok) if i % 2 else tok for i, tok in enumerate(chunks)]
    return parts, path  # full path as a deterministic tie-breaker (e.g. 'a01' vs 'a1')


# Load generated data (natural order, so eval_num_examples truncation keeps the lowest-numbered files)
results_fn_list = sorted(glob(os.path.join(args.sample_path, '*.pt')), key=_natural_sort_key)
if args.eval_num_examples is not None:
    results_fn_list = results_fn_list[:args.eval_num_examples]
num_examples = len(results_fn_list)
logger.info(f'Load generated data done! {num_examples} examples in total.')
if num_examples == 0:
    # Later cells index results_fn_list[0]; surface the root cause here instead.
    logger.warning(f'No *.pt result files found under {args.sample_path}')


[2023-10-19 07:00:02,612::evaluate::INFO] Load generated data done! 1 examples in total.


In [9]:
# Inspect the nesting of the first results file before aggregating metrics.
# NOTE(review): torch.load unpickles arbitrary Python objects -- only load trusted files.
# Recent torch releases default to weights_only=True, which may reject the rdkit/numpy
# objects stored here; pass weights_only=False explicitly if loading fails.
results = torch.load(results_fn_list[0])

# Descend three levels (outer container -> its first element -> that element's first
# element), printing the size and type at each level.
node = results
for depth in range(3):
    print(len(node), type(node))
    if depth < 2:
        node = node[0]

# For the innermost record, show each field's type and, for list fields, their length.
for k, v in node.items():
    print(k, type(v))
    if isinstance(v, list):
        print(len(v))

100 <class 'list'>
84 <class 'list'>
7 <class 'dict'>
mol <class 'rdkit.Chem.rdchem.Mol'>
smiles <class 'str'>
ligand_filename <class 'str'>
pred_pos <class 'numpy.ndarray'>
pred_v <class 'numpy.ndarray'>
chem_results <class 'dict'>
vina <class 'dict'>


In [10]:
# Summarise QED / SA / Vina statistics over the generated molecules.
# `results` as loaded is nested: outer list -> inner list of per-molecule dicts (see the
# inspection output above). Flatten only while that nesting is still present, so re-running
# this cell does not index one level deeper and crash; `results` keeps the same meaning
# for later cells.
# NOTE(review): by default only the first outer entry is evaluated (presumably one entry per
# target); set POCKET_INDEX = None to pool the molecules of every entry -- TODO confirm scope.
POCKET_INDEX = 0
if results and isinstance(results[0], list):
    if POCKET_INDEX is None:
        results = [r for group in results for r in group]
    else:
        results = results[POCKET_INDEX]

qed = [r['chem_results']['qed'] for r in results]
sa = [r['chem_results']['sa'] for r in results]
logger.info('QED:   Mean: %.3f Median: %.3f' % (np.mean(qed), np.median(qed)))
logger.info('SA:    Mean: %.3f Median: %.3f' % (np.mean(sa), np.median(sa)))
if args.docking_mode == 'qvina':
    vina = [r['vina'][0]['affinity'] for r in results]
    logger.info('Vina:  Mean: %.3f Median: %.3f' % (np.mean(vina), np.median(vina)))
elif args.docking_mode in ['vina_dock', 'vina_score']:
    # Both Vina modes store score-only and minimized results; 'vina_dock' also stores docking.
    vina_score_only = [r['vina']['score_only'][0]['affinity'] for r in results]
    vina_min = [r['vina']['minimize'][0]['affinity'] for r in results]
    logger.info('Vina Score:  Mean: %.3f Median: %.3f' % (np.mean(vina_score_only), np.median(vina_score_only)))
    logger.info('Vina Min  :  Mean: %.3f Median: %.3f' % (np.mean(vina_min), np.median(vina_min)))
    if args.docking_mode == 'vina_dock':
        vina_dock = [r['vina']['dock'][0]['affinity'] for r in results]
        logger.info('Vina Dock :  Mean: %.3f Median: %.3f' % (np.mean(vina_dock), np.median(vina_dock)))


[2023-10-19 07:01:24,463::evaluate::INFO] QED:   Mean: 0.350 Median: 0.327
[2023-10-19 07:01:24,466::evaluate::INFO] SA:    Mean: 0.546 Median: 0.540
[2023-10-19 07:01:24,467::evaluate::INFO] Vina Score:  Mean: -3.437 Median: -5.543
[2023-10-19 07:01:24,467::evaluate::INFO] Vina Min  :  Mean: -5.262 Median: -6.543
