In [1]:
import itertools
import os

import h5py
import pandas
import torch
import tqdm
from IPython.display import display

from ani_benchmark import NeighborBenchmark, FreeNeighborBenchmark, NoNeighborBenchmark
from selected_system import mols, mol_file

def pretty_print(orig_dict):
    """Format a dict of millisecond durations as human-readable strings.

    Each value is truncated to whole milliseconds, then rendered in the
    largest unit that keeps it readable:

    * below 1 second  -> ``'<n>ms'``
    * below 1 minute  -> ``'<x.xx>s'``
    * below 1 hour    -> ``'<x.xx>min'``
    * otherwise       -> ``'<x.xx>h'``

    ``None`` values (e.g. a benchmark that failed) are passed through as ``None``.

    Parameters
    ----------
    orig_dict : dict
        Mapping of label -> duration in milliseconds, or ``None``.

    Returns
    -------
    dict
        A new dict with the same keys (same order) and formatted string
        (or ``None``) values. ``orig_dict`` is not modified.
    """
    def _format(value):
        # Format a single duration; None means "no measurement".
        if value is None:
            return None
        ms = int(value)
        s = ms / 1000
        m = s / 60
        h = m / 60
        if ms < 1000:
            return '{}ms'.format(ms)
        if s < 60:
            return '{:.2f}s'.format(s)
        if m < 60:
            return '{:.2f}min'.format(m)
        # `h` is in hours, so the unit label must be 'h' (not 'min').
        return '{:.2f}h'.format(h)

    return {key: _format(value) for key, value in orig_dict.items()}

  from ._conv import register_converters as _register_converters


In [2]:
# Restrict PyTorch to a single intra-op thread (presumably so CPU timings
# are single-threaded and comparable across benchmark variants).
torch.set_num_threads(1)

# HDF5 file holding the molecule conformations; it is left open because the
# benchmark loop reads datasets from it by key.
fm = h5py.File(os.path.join('../', mol_file), "r")

# Keys are '<variant>,<device>': N = NeighborBenchmark,
# F = FreeNeighborBenchmark, X = NoNeighborBenchmark; C = CPU.
# To also benchmark on GPU, add the same three classes with
# device=torch.device("cuda") under the keys 'N,G', 'F,G' and 'X,G'.
cpu_device = torch.device("cpu")
benchmarks = {}
benchmarks['N,C'] = NeighborBenchmark(device=cpu_device)
benchmarks['F,C'] = FreeNeighborBenchmark(device=cpu_device)
benchmarks['X,C'] = NoNeighborBenchmark(device=cpu_device)

In [3]:
# Run every benchmark variant twice per molecule: one conformation at a time
# (",1" columns) and all conformations as one batch (",B" columns).

# How many molecule sizes, and how many molecules per size, to benchmark.
# The defaults of 1 run only the first molecule of the first size.
# Set either to None to run everything.
MAX_SIZES = 1
MAX_MOLECULES_PER_SIZE = 1

# Placeholder timings used when a benchmark raises RuntimeError, so the
# results table still gets a (None-filled) column for that variant.
FAILED_RESULT = {'aev': None, 'energy': None, 'force': None}


def run_or_none(method, coordinates, species):
    """Call ``method(coordinates, species)``; on RuntimeError print it and
    return a copy of FAILED_RESULT instead of propagating."""
    try:
        return method(coordinates, species)
    except RuntimeError as e:
        print(e)
        return dict(FAILED_RESULT)


for i in itertools.islice(mols, MAX_SIZES):
    print('number of atoms:', i)
    for s in itertools.islice(mols[i], MAX_MOLECULES_PER_SIZE):
        print('Running benchmark on molecule', s)
        key = s.replace('/', '_')
        dataset = fm[key]
        coordinates = torch.from_numpy(dataset[()])
        species = dataset.attrs['species'].split()
        results = {}
        for b in tqdm.tqdm_notebook(benchmarks):
            bench = benchmarks[b]
            # Cast from the untouched tensor every time; reassigning
            # `coordinates` would feed an earlier benchmark's (possibly
            # lower-precision) cast into every later benchmark.
            bench_coordinates = coordinates.type(bench.aev_computer.dtype)
            results[b + ',1'] = pretty_print(
                run_or_none(bench.oneByOne, bench_coordinates, species))
            results[b + ',B'] = pretty_print(
                run_or_none(bench.inBatch, bench_coordinates, species))
        df = pandas.DataFrame(results)
        display(df)

number of atoms: 20
Running benchmark on molecule COC(=O)c1ccc([N+](=O)[O-])cc1


KeyboardInterrupt: 