Remove AEV cacher (#361)
* Remove AEV cacher

* more

* more

* more

* flake8

* further cleanup
zasdfgbnm committed Nov 6, 2019
1 parent 7cdd405 commit 86500df
Showing 7 changed files with 15 additions and 270 deletions.
1 change: 0 additions & 1 deletion .github/workflows/runnable_submodules.yml
@@ -32,6 +32,5 @@ jobs:
      run: ./download.sh
    - name: Run submodules
      run: |
        python -m torchani.data.cache_aev tmp dataset/ani1-up_to_gdb4/ani_gdb_s01.h5 256
        python -m torchani.neurochem.trainer --tqdm tests/test_data/inputtrain.ipt dataset/ani1-up_to_gdb4/ani_gdb_s01.h5 dataset/ani1-up_to_gdb4/ani_gdb_s01.h5
        python -m torchani.neurochem.trainer --tqdm tests/test_data/inputtrain.yaml dataset/ani1-up_to_gdb4/ani_gdb_s01.h5 dataset/ani1-up_to_gdb4/ani_gdb_s01.h5
4 changes: 1 addition & 3 deletions docs/api.rst
@@ -29,10 +29,8 @@ Datasets
.. autoclass:: torchani.data.CachedDataset
   :members:
.. autofunction:: torchani.data.load_ani_dataset
.. autofunction:: torchani.data.create_aev_cache
.. autoclass:: torchani.data.BatchedANIDataset
.. autoclass:: torchani.data.AEVCacheLoader
.. automodule:: torchani.data.cache_aev



Utilities
33 changes: 0 additions & 33 deletions tests/test_data.py
@@ -2,7 +2,6 @@
import torch
import torchani
import unittest
from torchani.data.cache_aev import cache_aev, cache_sparse_aev

path = os.path.dirname(os.path.realpath(__file__))
dataset_path = os.path.join(path, '../dataset/ani1-up_to_gdb4')
@@ -87,38 +86,6 @@ def testNoUnnecessaryPadding(self):
        non_padding = (species >= 0)[:, -1].nonzero()
        self.assertGreater(non_padding.numel(), 0)

    def testAEVCacheLoader(self):
        tmpdir = os.path.join(os.getcwd(), 'tmp')
        if not os.path.exists(tmpdir):
            os.makedirs(tmpdir)
        cache_aev(tmpdir, dataset_path2, 64, enable_tqdm=False)
        loader = torchani.data.AEVCacheLoader(tmpdir)
        ds = loader.dataset
        aev_computer_dev = aev_computer.to(loader.dataset.device)
        for _ in range(3):
            for (species_aevs, _), (species_coordinates, _) in zip(loader, ds):
                for (s1, a), (s2, c) in zip(species_aevs, species_coordinates):
                    self._assertTensorEqual(s1, s2)
                    s2, a2 = aev_computer_dev((s2, c))
                    self._assertTensorEqual(s1, s2)
                    self._assertTensorEqual(a, a2)

    def testSparseAEVCacheLoader(self):
        tmpdir = os.path.join(os.getcwd(), 'tmp')
        if not os.path.exists(tmpdir):
            os.makedirs(tmpdir)
        cache_sparse_aev(tmpdir, dataset_path2, 64, enable_tqdm=False)
        loader = torchani.data.SparseAEVCacheLoader(tmpdir)
        ds = loader.dataset
        aev_computer_dev = aev_computer.to(loader.dataset.device)
        for _ in range(3):
            for (species_aevs, _), (species_coordinates, _) in zip(loader, ds):
                for (s1, a), (s2, c) in zip(species_aevs, species_coordinates):
                    self._assertTensorEqual(s1, s2)
                    s2, a2 = aev_computer_dev((s2, c))
                    self._assertTensorEqual(s1, s2)
                    self._assertTensorEqual(a, a2)


if __name__ == '__main__':
    unittest.main()
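The deleted tests above assert one invariant: species/AEV pairs loaded from the cache must equal the pairs produced by running the AEV computer directly. A minimal sketch of that comparison, using only names that appear elsewhere in this diff (`neurochem.Constants`, `aev.AEVComputer`, `species_to_tensor`); the molecule and coordinates are made up:

```python
import torch
from torchani import aev, models, neurochem

ani1x = models.ANI1x()
consts = neurochem.Constants(ani1x.const_file)
aev_computer = aev.AEVComputer(**consts)

# A made-up methane-like input: 1 molecule, 5 atoms.
species = consts.species_to_tensor('CHHHH').unsqueeze(0)
coordinates = torch.randn(1, 5, 3)

# AEV computation is deterministic, so a cached copy of (s1, a1)
# can be compared against a fresh recomputation, as the tests did.
s1, a1 = aev_computer((species, coordinates))
s2, a2 = aev_computer((species, coordinates))
assert torch.equal(s1, s2) and torch.equal(a1, a2)
```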
156 changes: 2 additions & 154 deletions torchani/data/__init__.py
@@ -6,10 +6,7 @@
import os
from ._pyanitools import anidataloader
import torch
from .. import utils, neurochem, aev, models
import pickle
import numpy as np
from scipy.sparse import bsr_matrix
from .. import utils
import warnings
from .new import CachedDataset, ShuffledDataset, find_threshold

@@ -364,153 +361,4 @@ def load_ani_dataset(path, species_tensor_converter, batch_size, shuffle=True,
    return tuple(ret)


class AEVCacheLoader(Dataset):
    """Build a factory for AEVs.

    Computing AEVs is the most time-consuming part of training. Since the
    AEVs never change during training, given fast enough storage (usually
    the case for good SSDs, but not for HDDs) we can cache the computed
    AEVs on disk and load them instead of recomputing them from scratch
    every time they are used.

    Arguments:
        disk_cache (str): Directory storing disk caches.
    """

    def __init__(self, disk_cache=None):
        super(AEVCacheLoader, self).__init__()
        self.disk_cache = disk_cache

        # load dataset from disk cache
        dataset_path = os.path.join(disk_cache, 'dataset')
        with open(dataset_path, 'rb') as f:
            self.dataset = pickle.load(f)

    def __getitem__(self, index):
        _, output = self.dataset.batches[index]
        aev_path = os.path.join(self.disk_cache, str(index))
        with open(aev_path, 'rb') as f:
            species_aevs = pickle.load(f)
        for i, sa in enumerate(species_aevs):
            species, aevs = self.decode_aev(*sa)
            species_aevs[i] = (
                species.to(self.dataset.device),
                aevs.to(self.dataset.device)
            )
        return species_aevs, output

    def __len__(self):
        return len(self.dataset)

    @staticmethod
    def decode_aev(encoded_species, encoded_aev):
        return encoded_species, encoded_aev

    @staticmethod
    def encode_aev(species, aev):
        return species, aev


class SparseAEVCacheLoader(AEVCacheLoader):
    """Build a factory for AEVs from a sparse disk cache.

    The computation of AEVs is the most time-consuming part of training.
    The AEVs never change during training and contain a large number of
    zeros, so we can store them in a sparse representation and load them
    during training rather than compute them from scratch. The storage
    requirement for ``cache_sparse_aev`` is considerably lower than for
    ``cache_aev``.

    Arguments:
        disk_cache (str): Directory storing disk caches.
    """

    @staticmethod
    def decode_aev(encoded_species, encoded_aev):
        species = torch.from_numpy(encoded_species.todense())
        aevs_np = np.stack([np.array(i.todense()) for i in encoded_aev], axis=0)
        aevs = torch.from_numpy(aevs_np)
        return species, aevs

    @staticmethod
    def encode_aev(species, aev):
        encoded_species = bsr_matrix(species.cpu().numpy())
        encoded_aev = [bsr_matrix(i.cpu().numpy()) for i in aev]
        return encoded_species, encoded_aev
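For context, a round-trip sketch of the sparse scheme used by `encode_aev`/`decode_aev` above (the batch shape and values are made up; the `bsr_matrix` encoding and the stack-and-densify decoding mirror the removed code):

```python
import numpy as np
import torch
from scipy.sparse import bsr_matrix

# A mostly-zero AEV batch: (molecules, atoms, AEV length); values made up.
aev = torch.zeros(4, 8, 384)
aev[0, 0, :10] = 1.0

# Encode: one sparse BSR matrix per molecule, as encode_aev did.
encoded = [bsr_matrix(m.numpy()) for m in aev]

# Decode: densify each matrix and stack back into a tensor, as decode_aev did.
decoded = torch.from_numpy(np.stack([np.array(m.todense()) for m in encoded], axis=0))
assert torch.equal(aev, decoded)
```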


ani1x = models.ANI1x()


def create_aev_cache(dataset, aev_computer, output, progress_bar=True, encoder=lambda *x: x):
    """Cache AEVs for the given dataset.

    Arguments:
        dataset (:class:`torchani.data.PaddedBatchChunkDataset`): the dataset to be cached
        aev_computer (:class:`torchani.AEVComputer`): the AEV computer used to compute AEVs
        output (str): path to the directory where the cache will be stored
        progress_bar (bool): whether to show a progress bar
        encoder (:class:`collections.abc.Callable`): a callable
            (species, aev) -> (encoded_species, encoded_aev) that encodes species and AEVs
    """
    # dump out the dataset
    filename = os.path.join(output, 'dataset')
    with open(filename, 'wb') as f:
        pickle.dump(dataset, f)

    if progress_bar:
        import tqdm
        indices = tqdm.trange(len(dataset))
    else:
        indices = range(len(dataset))
    for i in indices:
        input_, _ = dataset[i]
        aevs = [encoder(*aev_computer(j)) for j in input_]
        filename = os.path.join(output, '{}'.format(i))
        with open(filename, 'wb') as f:
            pickle.dump(aevs, f)
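The cache directory written by `create_aev_cache` therefore contains one pickled dataset object plus one pickle per batch index. A sketch of reading such a cache by hand (`cache_dir` is a hypothetical path; the file names come from the code above):

```python
import os
import pickle

cache_dir = 'tmp'  # hypothetical directory previously filled by create_aev_cache

# The dataset object itself, pickled under the fixed name 'dataset'.
with open(os.path.join(cache_dir, 'dataset'), 'rb') as f:
    dataset = pickle.load(f)

# One pickle per batch index, holding that batch's (encoded) species/AEV pairs.
with open(os.path.join(cache_dir, '0'), 'rb') as f:
    species_aevs = pickle.load(f)
```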


def _cache_aev(output, dataset_path, batchsize, device, constfile,
               subtract_sae, sae_file, enable_tqdm, encoder, **kwargs):
    # if output directory does not exist, then create it
    if not os.path.exists(output):
        os.makedirs(output)

    device = torch.device(device)
    consts = neurochem.Constants(constfile)
    aev_computer = aev.AEVComputer(**consts).to(device)

    if subtract_sae:
        energy_shifter = neurochem.load_sae(sae_file)
        transform = (energy_shifter.subtract_from_dataset,)
    else:
        transform = ()

    dataset = load_ani_dataset(
        dataset_path, consts.species_to_tensor, batchsize,
        device=device, transform=transform, **kwargs
    )

    create_aev_cache(dataset, aev_computer, output, enable_tqdm, encoder)


def cache_aev(output, dataset_path, batchsize, device=default_device,
              constfile=ani1x.const_file, subtract_sae=False,
              sae_file=ani1x.sae_file, enable_tqdm=True, **kwargs):
    _cache_aev(output, dataset_path, batchsize, device, constfile,
               subtract_sae, sae_file, enable_tqdm, AEVCacheLoader.encode_aev,
               **kwargs)


def cache_sparse_aev(output, dataset_path, batchsize, device=default_device,
                     constfile=ani1x.const_file, subtract_sae=False,
                     sae_file=ani1x.sae_file, enable_tqdm=True, **kwargs):
    _cache_aev(output, dataset_path, batchsize, device, constfile,
               subtract_sae, sae_file, enable_tqdm,
               SparseAEVCacheLoader.encode_aev, **kwargs)
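Before this commit, the end-to-end workflow (mirroring the deleted test in `tests/test_data.py` above) was to write a cache once and then train from it. A hedged sketch, with the dataset path taken from the CI workflow at the top of this diff:

```python
import os
import torchani
from torchani.data.cache_aev import cache_aev  # module deleted by this commit

tmpdir = 'tmp'
dataset_path = 'dataset/ani1-up_to_gdb4/ani_gdb_s01.h5'
if not os.path.exists(tmpdir):
    os.makedirs(tmpdir)

cache_aev(tmpdir, dataset_path, 64, enable_tqdm=False)  # write the cache once
loader = torchani.data.AEVCacheLoader(tmpdir)           # reload it during training
for species_aevs, output in loader:
    pass  # feed the cached (species, aev) batches straight to the network
```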


__all__ = ['load_ani_dataset', 'BatchedANIDataset', 'AEVCacheLoader',
           'SparseAEVCacheLoader', 'cache_aev', 'cache_sparse_aev',
           'CachedDataset', 'ShuffledDataset', 'find_threshold']
__all__ = ['load_ani_dataset', 'BatchedANIDataset', 'CachedDataset', 'ShuffledDataset', 'find_threshold']
47 changes: 0 additions & 47 deletions torchani/data/cache_aev.py

This file was deleted.

39 changes: 11 additions & 28 deletions torchani/neurochem/__init__.py
@@ -274,28 +274,23 @@ class Trainer:
        tqdm (bool): whether to enable tqdm
        tensorboard (str): Directory to store the tensorboard log file; set to
            ``None`` to disable tensorboard.
        aev_caching (bool): Whether to use AEV caching.
        checkpoint_name (str): Name of the checkpoint file; checkpoints
            will be stored in the network directory with this file name.
    """

    def __init__(self, filename, device=torch.device('cuda'), tqdm=False,
                 tensorboard=None, aev_caching=False,
                 checkpoint_name='model.pt'):
                 tensorboard=None, checkpoint_name='model.pt'):

        from ..data import load_ani_dataset  # noqa: E402
        from ..data import AEVCacheLoader  # noqa: E402

        class dummy:
            pass

        self.imports = dummy()
        self.imports.load_ani_dataset = load_ani_dataset
        self.imports.AEVCacheLoader = AEVCacheLoader

        self.filename = filename
        self.device = device
        self.aev_caching = aev_caching
        self.checkpoint_name = checkpoint_name
        self.weights = []
        self.biases = []
@@ -540,11 +535,7 @@ def init_params(m):

        # initialize weights and biases
        self.nn.apply(init_params)

        if self.aev_caching:
            self.model = self.nn.to(self.device)
        else:
            self.model = Sequential(self.aev_computer, self.nn).to(self.device)
        self.model = Sequential(self.aev_computer, self.nn).to(self.device)

        # loss functions
        self.mse_se = torch.nn.MSELoss(reduction='none')
@@ -556,23 +547,15 @@ def init_params(m):
        self.best_validation_rmse = math.inf

    def load_data(self, training_path, validation_path):
        """Load training and validation datasets from file.

        If AEV caching is enabled, the arguments are paths to the cache
        directories; otherwise they are paths to the datasets.
        """
        if self.aev_caching:
            self.training_set = self.imports.AEVCacheLoader(training_path)
            self.validation_set = self.imports.AEVCacheLoader(validation_path)
        else:
            self.training_set = self.imports.load_ani_dataset(
                training_path, self.consts.species_to_tensor,
                self.training_batch_size, rm_outlier=True, device=self.device,
                transform=[self.shift_energy.subtract_from_dataset])
            self.validation_set = self.imports.load_ani_dataset(
                validation_path, self.consts.species_to_tensor,
                self.validation_batch_size, rm_outlier=True, device=self.device,
                transform=[self.shift_energy.subtract_from_dataset])
        """Load training and validation datasets from file."""
        self.training_set = self.imports.load_ani_dataset(
            training_path, self.consts.species_to_tensor,
            self.training_batch_size, rm_outlier=True, device=self.device,
            transform=[self.shift_energy.subtract_from_dataset])
        self.validation_set = self.imports.load_ani_dataset(
            validation_path, self.consts.species_to_tensor,
            self.validation_batch_size, rm_outlier=True, device=self.device,
            transform=[self.shift_energy.subtract_from_dataset])

    def evaluate(self, dataset):
        """Run the evaluation"""
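With caching removed, the trainer always builds the full pipeline `Sequential(self.aev_computer, self.nn)`. A hedged sketch of that composition, assuming `torchani.nn.Sequential`, `torchani.ANIModel`, and the `aev_length` attribute (the toy per-species networks below are made up; the real ones come from the NeuroChem config):

```python
import torch
import torchani
from torchani import aev, models, neurochem

ani1x = models.ANI1x()
consts = neurochem.Constants(ani1x.const_file)
aev_computer = aev.AEVComputer(**consts)

def make_net():
    # Toy network; the trainer reads the real architecture from the .ipt file.
    return torch.nn.Sequential(
        torch.nn.Linear(aev_computer.aev_length, 64),
        torch.nn.CELU(),
        torch.nn.Linear(64, 1),
    )

nn = torchani.ANIModel([make_net() for _ in range(4)])  # one net per species (H, C, N, O)
model = torchani.nn.Sequential(aev_computer, nn)        # coordinates -> AEVs -> energies

species = consts.species_to_tensor('CHHHH').unsqueeze(0)
coordinates = torch.randn(1, 5, 3)
_, energies = model((species, coordinates))
```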
5 changes: 1 addition & 4 deletions torchani/neurochem/trainer.py
@@ -26,15 +26,12 @@
parser.add_argument('--tensorboard',
                    help='Directory to store tensorboard log files',
                    default=None)
parser.add_argument('--cache-aev', dest='cache_aev', action='store_true',
                    help='Whether to cache AEV', default=None)
parser.add_argument('--checkpoint_name',
                    help='Name of checkpoint file',
                    default='model.pt')
parser = parser.parse_args()

d = torch.device(parser.device)
trainer = Trainer(parser.config_path, d, parser.tqdm, parser.tensorboard,
                  parser.cache_aev, parser.checkpoint_name)
trainer = Trainer(parser.config_path, d, parser.tqdm, parser.tensorboard, parser.checkpoint_name)
trainer.load_data(parser.training_path, parser.validation_path)
trainer.run()
