In [1]:
import h5py
import torch
from selected_system import mols, mol_file
from ani_benchmark import ANIBenchmark
import pandas
import os
import tqdm
from IPython.display import display
import itertools
# Disable tqdm's background monitor thread. `monitor_interval` is a class
# attribute of `tqdm.tqdm` (inherited by `tqdm_notebook`); assigning it on the
# `tqdm` *module* would only create an unused module attribute.
tqdm.tqdm.monitor_interval = 0

def _format_ms(value):
    """Format a duration in milliseconds as a short human-readable string.

    The value is truncated to whole milliseconds, then rendered in the
    largest unit that keeps it readable: ms, s, min or h.
    """
    ms = int(value)
    if ms < 1000:
        return '{}ms'.format(ms)
    s = ms / 1000
    if s < 60:
        return '{:.2f}s'.format(s)
    m = s / 60
    if m < 60:
        return '{:.2f}min'.format(m)
    # Durations of an hour or more are reported in hours (previously the
    # hour value was mislabelled with a 'min' suffix).
    return '{:.2f}h'.format(m / 60)


def pretty_print(orig_dict):
    """Pretty print times in a dictionary.

    Args:
        orig_dict: mapping of label -> duration in milliseconds, or None
            when no timing is available.

    Returns:
        A new dict with the same keys (same order) whose values are the
        formatted duration strings; None values are passed through unchanged.
    """
    return {
        label: None if value is None else _format_ms(value)
        for label, value in orig_dict.items()
    }

  from ._conv import register_converters as _register_converters


van Der Waals correction will be unavailable. Please install ased3


In [2]:
# Restrict torch's CPU intra-op parallelism to a single thread.
torch.set_num_threads(1)

# Open the molecule dataset read-only. It is deliberately left open: the
# benchmark cell below reads conformations and species from it.
fm = h5py.File(os.path.join('../', mol_file), "r")

# Device tag -> benchmark runner: 'C' is always the CPU; 'G' (CUDA) is only
# registered when a GPU is available.
benchmarks = {'C': ANIBenchmark(device=torch.device("cpu"))}
if torch.cuda.is_available():
    benchmarks['G'] = ANIBenchmark(device=torch.device("cuda"))

In [3]:
# Benchmark each selected molecule on every (device, mode) combination and
# display one table of formatted timings per molecule.
N_FRAMES = 200           # conformations per molecule that are timed
MOLECULES_PER_GROUP = 1  # molecules benchmarked per atom-count group (was a hidden `break`)
MODES = ('1', 'B')       # '1' -> bench.oneByOne, 'B' -> bench.inBatch
# Every row a result may carry; used so failed runs yield the same row set.
RESULT_KEYS = ('aev', 'energy', 'force', 'forward', 'neighborlist', 'total')

for i in mols:
    print('number of atoms:', i)
    smiles = mols[i]
    for s in itertools.islice(smiles, MOLECULES_PER_GROUP):
        print('Running benchmark on molecule', s)
        key = s.replace('/', '_')
        all_coordinates = torch.from_numpy(fm[key][()])[:N_FRAMES]
        species = fm[key].attrs['species'].split()
        results = {}
        for b, m in tqdm.tqdm_notebook(list(itertools.product(benchmarks, MODES))):
            bench = benchmarks[b]
            # Cast from the untouched tensor for every benchmark so dtype
            # conversions for one device do not leak into the next.
            coordinates = all_coordinates.type(bench.aev_computer.dtype)
            try:
                if m == '1':
                    result = bench.oneByOne(coordinates, species)
                elif m == 'B':
                    result = bench.inBatch(coordinates, species)
                else:
                    raise ValueError('BUG here')
                # Derived rows: forward pass = aev + energy + neighborlist;
                # total additionally includes the force computation.
                result['forward'] = result['aev'] + result['energy'] + result['neighborlist']
                result['total'] = result['forward'] + result['force']
            except RuntimeError as e:
                # Report the failure but keep benchmarking the remaining
                # combinations; the column is filled with None for every row.
                print(e)
                result = dict.fromkeys(RESULT_KEYS)
            results[b + ',' + m] = pretty_print(result)
        df = pandas.DataFrame(results)
        display(df)

number of atoms: 20
Running benchmark on molecule COC(=O)c1ccc([N+](=O)[O-])cc1


A Jupyter Widget




Unnamed: 0,"C,1","C,B","G,1","G,B"
aev,8.14s,391ms,21.39s,154ms
energy,653ms,28ms,1.57s,7ms
force,9.37s,864ms,30.65s,193ms
forward,9.20s,424ms,25.62s,173ms
neighborlist,408ms,4ms,2.65s,11ms
total,18.57s,1.29s,56.27s,367ms


number of atoms: 50
Running benchmark on molecule O=[N+]([O-])c1ccc(NN=Cc2ccc(C=NNc3ccc([N+](=O)[O-])cc3[N+](=O)[O-])cc2)c([N+](=O)[O-])c1


A Jupyter Widget




Unnamed: 0,"C,1","C,B","G,1","G,B"
aev,24.39s,1.70s,1.05min,524ms
energy,1.51s,63ms,3.91s,20ms
force,31.49s,3.88s,1.52min,718ms
forward,26.90s,1.79s,1.23min,588ms
neighborlist,1.00s,35ms,6.70s,43ms
total,58.40s,5.68s,2.76min,1.31s


number of atoms: 10
Running benchmark on molecule N#CCC(=O)N


A Jupyter Widget




Unnamed: 0,"C,1","C,B","G,1","G,B"
aev,4.71s,156ms,12.29s,98ms
energy,339ms,14ms,798ms,4ms
force,5.01s,288ms,17.38s,106ms
forward,5.26s,173ms,14.43s,108ms
neighborlist,205ms,2ms,1.33s,5ms
total,10.27s,461ms,31.80s,214ms


number of atoms: 4,5,6
Running benchmark on molecule C


A Jupyter Widget




Unnamed: 0,"C,1","C,B","G,1","G,B"
aev,834ms,18ms,2.18s,12ms
energy,175ms,6ms,435ms,2ms
force,958ms,49ms,3.37s,21ms
forward,1.11s,25ms,3.30s,17ms
neighborlist,101ms,0ms,686ms,3ms
total,2.07s,74ms,6.67s,39ms


number of atoms: 100
Running benchmark on molecule CC(C)C[C@@H](C(=O)O)NC(=O)C[C@@H]([C@H](CC1CCCCC1)NC(=O)CC[C@@H]([C@H](Cc2ccccc2)NC(=O)OC(C)(C)C)O)O


A Jupyter Widget




Unnamed: 0,"C,1","C,B","G,1","G,B"
aev,55.33s,8.82s,2.48min,1.74s
energy,2.85s,126ms,7.59s,40ms
force,1.45min,22.75s,3.59min,2.24s
forward,1.00min,9.03s,2.84min,1.85s
neighborlist,2.02s,78ms,14.11s,69ms
total,2.45min,31.77s,6.43min,4.09s


number of atoms: 305
Running benchmark on molecule [H]/N=C(/N)\NCCC[C@H](C(=O)N[C@H]([C@@H](C)O)C(=O)N[C@H](Cc1ccc(cc1)O)C(=O)NCCCC[C@@H](C(=O)NCCCC[C@@H](C(=O)NCC(=O)O)NC(=O)[C@H](CCCCNC(=O)[C@@H](Cc2ccc(cc2)O)NC(=O)[C@@H]([C@@H](C)O)NC(=O)[C@@H](CCCN/C(=N\[H])/N)N)NC(=O)[C@@H](Cc3ccc(cc3)O)NC(=O)[C@@H]([C@@H](C)O)NC(=O)[C@@H](CCCN/C(=N\[H])/N)N)NC(=O)[C@@H](Cc4ccc(cc4)O)NC(=O)[C@@H]([C@@H](C)O)NC(=O)[C@@H](CCCN/C(=N\[H])/N)N)N


A Jupyter Widget




Unnamed: 0,"C,1","C,B","G,1","G,B"
aev,2.87min,25.13s,7.85min,9.33s
energy,8.54s,416ms,23.94s,212ms
force,5.27min,3.33min,12.71min,11.43s
forward,3.14min,26.27s,9.05min,10.07s
neighborlist,7.35s,721ms,48.06s,523ms
total,8.40min,3.76min,21.76min,21.49s
