In [5]:
import argparse
import os
import sys
sys.path.append('../')

import numpy as np
from rdkit import Chem
from rdkit import RDLogger
import torch
from tqdm.auto import tqdm
from glob import glob
from collections import Counter

from utils.evaluation import eval_atom_type, scoring_func, analyze, eval_bond_length
from utils import misc, reconstruct, transforms
from utils.evaluation.docking_qvina import QVinaDockingTask
from utils.evaluation.docking_vina import VinaDockingTask


def print_dict(d, logger):
    """Log every entry of ``d`` as ``key:<TAB>value`` via ``logger.info``.

    Args:
        d: Mapping whose values are either ``None`` or support the ``.4f``
            format specification.
        logger: Object exposing an ``info(message)`` method.
    """
    for name, value in d.items():
        if value is None:
            # Missing metrics are reported literally instead of being formatted.
            logger.info(f'{name}:\tNone')
            continue
        logger.info(f'{name}:\t{value:.4f}')


def print_ring_ratio(all_ring_sizes, logger):
    """Log, for each ring size from 3 to 9, the fraction of entries containing it.

    Args:
        all_ring_sizes: Sized collection of containers; only
            ``ring_size in entry`` is evaluated on each element.
            Presumably one ``collections.Counter`` of ring sizes per
            molecule -- confirm against the caller.
        logger: Object exposing an ``info(message)`` method.

    An empty ``all_ring_sizes`` is reported with a single message instead of
    raising ``ZeroDivisionError``.
    """
    n_total = len(all_ring_sizes)
    if n_total == 0:
        # No entries means the ratio is undefined; skip the per-size lines.
        logger.info('ring size ratio: no molecules to analyze')
        return
    for ring_size in range(3, 10):
        # Count entries that contain at least one ring of this size.
        n_mol = sum(1 for counter in all_ring_sizes if ring_size in counter)
        logger.info(f'ring size: {ring_size} ratio: {n_mol / n_total:.3f}')


In [12]:
# parser = argparse.ArgumentParser()
# parser.add_argument('--sample_path', type=str, default='./sampling_results/official/')
# parser.add_argument('--verbose', type=eval, default=False)
# parser.add_argument('--eval_step', type=int, default=-1)
# parser.add_argument('--eval_num_examples', type=int, default=None)
# parser.add_argument('--save', type=eval, default=True)
# # parser.add_argument('--protein_root', type=str, default='./data/crossdocked_v1.1_rmsd1.0')
# parser.add_argument('--protein_root', type=str, default='./data/test_set')
# parser.add_argument('--atom_enc_mode', type=str, default='add_aromatic')
# parser.add_argument('--docking_mode', type=str, default='vina_score', choices=['qvina', 'vina_score', 'vina_dock', 'none'])
# parser.add_argument('--exhaustiveness', type=int, default=16)
# args = parser.parse_args()

class Config:
    """Notebook stand-in for the evaluation script's command-line arguments.

    Each class attribute corresponds to one option of the commented-out
    ``argparse`` parser at the top of this cell and holds the value used for
    this notebook run.
    """
    # Directory globbed for generated ``*.pt`` sample files; the
    # ``eval_results`` directory is created inside it.
    sample_path = '../sampling_results/official/'
    # When False, RDKit logging ('rdApp.*') is disabled.
    verbose = False
    # Not read in the visible cells; presumably selects which sampling step
    # to evaluate (-1 for the last one) -- TODO confirm.
    eval_step = -1
    # None evaluates every sample file; an integer keeps only the first N
    # files of the sorted list.
    eval_num_examples = None
    # Not read in the visible cells; presumably toggles writing results to
    # disk -- TODO confirm.
    save = True
    # Not read in the visible cells; presumably the root directory of the
    # protein structures used for docking -- TODO confirm.
    protein_root = '../data/test_set'
    # Not read in the visible cells; presumably the atom-type encoding used
    # to decode predicted atom types -- TODO confirm.
    atom_enc_mode = 'add_aromatic'
    # One of 'qvina', 'vina_score', 'vina_dock', 'none' (choices listed in
    # the commented-out parser).
    docking_mode = 'vina_score'
    # Not read in the visible cells; presumably the docking search
    # exhaustiveness -- TODO confirm.
    exhaustiveness = 16

# Instance used wherever the original script referenced the parsed ``args``.
args = Config()

In [13]:
# Evaluation outputs go into an `eval_results` subfolder of the sample
# directory; the same folder is handed to the logger as its log_dir.
result_path = os.path.join(args.sample_path, 'eval_results')
os.makedirs(result_path, exist_ok=True)
logger = misc.get_logger('evaluate', log_dir=result_path)
# Silence RDKit's own logging unless verbose output was requested.
if not args.verbose:
    RDLogger.DisableLog('rdApp.*')

# Collect the generated sample files in lexicographic filename order.
# Alternative ordering by the integer after the last '_' in the filename:
# results_fn_list = sorted(results_fn_list, key=lambda x: int(os.path.basename(x)[:-3].split('_')[-1]))
results_fn_list = sorted(glob(os.path.join(args.sample_path, '*.pt')))
# Optionally restrict the evaluation to the first `eval_num_examples` files.
if args.eval_num_examples is not None:
    results_fn_list = results_fn_list[:args.eval_num_examples]
num_examples = len(results_fn_list)
logger.info(f'Load generated data done! {num_examples} examples in total.')


[2023-10-18 04:01:56,602::evaluate::INFO] Load generated data done! 1 examples in total.


In [20]:
# NOTE(review): torch.load unpickles arbitrary Python objects -- only open
# sample files from a trusted source.
res = torch.load(results_fn_list[0])
# Report size and type at each nesting level: file -> first entry -> first record.
print(len(res), type(res))
first_entry = res[0]
print(len(first_entry), type(first_entry))
first_record = first_entry[0]
print(len(first_record), type(first_record))
# List the fields of the first record, adding the length of list-valued fields.
for k, v in first_record.items():
    print(k, type(v))
    if isinstance(v, list):
        print(len(v))

100 <class 'list'>
84 <class 'list'>
7 <class 'dict'>
mol <class 'rdkit.Chem.rdchem.Mol'>
smiles <class 'str'>
ligand_filename <class 'str'>
pred_pos <class 'numpy.ndarray'>
pred_v <class 'numpy.ndarray'>
chem_results <class 'dict'>
vina <class 'dict'>
