Add split to new dataset API (#299)
* split

* clean

* docs

* docs

* Update new.py
yueyericardo authored and zasdfgbnm committed Aug 22, 2019
1 parent b9e2c25 commit 9639d71
Showing 3 changed files with 78 additions and 16 deletions.
1 change: 1 addition & 0 deletions docs/api.rst
@@ -27,6 +27,7 @@ Datasets
 .. autofunction:: torchani.data.find_threshold
 .. autofunction:: torchani.data.ShuffledDataset
 .. autoclass:: torchani.data.CachedDataset
+    :members:
 .. autofunction:: torchani.data.load_ani_dataset
 .. autofunction:: torchani.data.create_aev_cache
 .. autoclass:: torchani.data.BatchedANIDataset
18 changes: 13 additions & 5 deletions tests/test_data_new.py
@@ -54,6 +54,12 @@ def testLoadDataset(self):
         for i, _ in enumerate(self.ds):
             pbar.update(i)

+    def testSplitDataset(self):
+        print('=> test splitting dataset')
+        train_ds, val_ds = torchani.data.ShuffledDataset(dspath, batch_size=batch_size, chunk_threshold=chunk_threshold, num_workers=2, validation_split=0.1)
+        frac = len(val_ds) / (len(val_ds) + len(train_ds))
+        self.assertLess(abs(frac - 0.1), 0.05)
+
     def testNoUnnecessaryPadding(self):
         print('=> checking No Unnecessary Padding')
         for i, chunk in enumerate(self.chunks):
@@ -91,11 +97,13 @@ def testTensorShape(self):

     def testLoadDataset(self):
         print('=> test loading all dataset')
-        pbar = pkbar.Pbar('loading and processing dataset into cpu memory, total '
-                          + 'batches: {}, batch_size: {}'.format(len(self.ds), batch_size),
-                          len(self.ds))
-        for i, _ in enumerate(self.ds):
-            pbar.update(i)
+        self.ds.load()

+    def testSplitDataset(self):
+        print('=> test splitting dataset')
+        train_dataset, val_dataset = self.ds.split(0.1)
+        frac = len(val_dataset) / len(self.ds)
+        self.assertLess(abs(frac - 0.1), 0.05)
+
     def testNoUnnecessaryPadding(self):
         print('=> checking No Unnecessary Padding')
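Note on the 0.05 tolerance in both testSplitDataset variants: len() on these objects counts batches, not conformations, so the achieved validation fraction is quantized and only approximates the requested 0.1. A tiny illustration of the rounding involved (hypothetical numbers, no torchani required):

    # 77 total batches, 10% requested: int() truncation gives 7 validation batches
    total_batches = 77
    val_batches = int(0.1 * total_batches)   # 7
    frac = val_batches / total_batches       # 0.0909...
    assert abs(frac - 0.1) < 0.05            # comfortably within the test tolerance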
75 changes: 64 additions & 11 deletions torchani/data/new.py
@@ -76,8 +76,8 @@ def __init__(self, file_path,

         anidata = anidataloader(file_path)
         anidata_size = anidata.group_size()
-        enable_pkbar = anidata_size > 5 and PKBAR_INSTALLED
-        if enable_pkbar:
+        self.enable_pkbar = anidata_size > 5 and PKBAR_INSTALLED
+        if self.enable_pkbar:
             pbar = pkbar.Pbar('=> loading h5 dataset into cpu memory, total molecules: {}'.format(anidata_size), anidata_size)

         for i, molecule in enumerate(anidata):
@@ -92,7 +92,7 @@ def __init__(self, file_path,
                 self_energies = np.array(sum([self_energies_dict[x] for x in molecule['species']]))
                 self.data_self_energies += list(np.tile(self_energies, (num_conformations, 1)))

-            if enable_pkbar:
+            if self.enable_pkbar:
                 pbar.update(i)

         if subtract_self_energies:
@@ -172,6 +172,43 @@ def __getitem__(self, index):
     def __len__(self):
         return self.length

+    def split(self, validation_split):
+        """Split the dataset into training and validation.
+        Arguments:
+            validation_split (float): Float between 0 and 1. Fraction of the dataset to be used
+                as validation data.
+        """
+        val_size = int(validation_split * len(self))
+        train_size = len(self) - val_size
+
+        ds = []
+        if self.enable_pkbar:
+            message = ('=> processing, splitting and caching dataset into cpu memory: \n'
+                       + 'total batches: {}, train batches: {}, val batches: {}, batch_size: {}')
+            pbar = pkbar.Pbar(message.format(len(self), train_size, val_size, self.batch_size),
+                              len(self))
+        for i, _ in enumerate(self):
+            ds.append(self[i])
+            if self.enable_pkbar:
+                pbar.update(i)
+
+        train_dataset = ds[:train_size]
+        val_dataset = ds[train_size:]
+
+        return train_dataset, val_dataset
+
+    def load(self):
+        """Cache the dataset into CPU memory. If not called, the dataset will be cached during the first epoch.
+        """
+        if self.enable_pkbar:
+            pbar = pkbar.Pbar('=> processing and caching dataset into cpu memory: \ntotal '
+                              + 'batches: {}, batch_size: {}'.format(len(self), self.batch_size),
+                              len(self))
+        for i, _ in enumerate(self):
+            if self.enable_pkbar:
+                pbar.update(i)
+
     @staticmethod
     def sort_list_with_index(inputs, index):
         return [inputs[i] for i in index]
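A minimal usage sketch of the two new CachedDataset methods. The file path and the batch_size keyword below are assumptions for illustration (only file_path is visible in the constructor signature above); note that split() returns plain Python lists of already-processed batches, not new dataset objects:

    from torchani.data import CachedDataset

    ds = CachedDataset('data/sample.h5', batch_size=256)  # assumed arguments

    # Optional: process and cache every batch up front rather than lazily
    # during the first training epoch.
    ds.load()

    # 90/10 split; each half is a list of ready-to-use batches.
    train_batches, val_batches = ds.split(0.1)
    print(len(train_batches), len(val_batches))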
@@ -229,6 +266,7 @@ def release_h5(self):

 def ShuffledDataset(file_path,
                     batch_size=1000, num_workers=0, shuffle=True, chunk_threshold=20,
+                    validation_split=0.0,
                     species_order=['H', 'C', 'N', 'O'],
                     subtract_self_energies=False,
                     self_energies=[-0.600953, -38.08316, -54.707756, -75.194466]):
@@ -242,6 +280,8 @@ def ShuffledDataset(file_path,
         shuffle (bool): whether to shuffle.
         chunk_threshold (int): threshold to split batches into chunks. Setting to ``None``
             will not split chunks.
+        validation_split (float): Float between 0 and 1. Fraction of the dataset to be used
+            as validation data.
         species_order (list): a list which specifies how species are transformed to int,
             for example: ``['H', 'C', 'N', 'O']`` means ``{'H': 0, 'C': 1, 'N': 2, 'O': 3}``.
         subtract_self_energies (bool): whether to subtract self energies from ``energies``.
@@ -273,14 +313,27 @@ def my_collate_fn(data, chunk_threshold=chunk_threshold):
     def my_collate_fn(data, chunk_threshold=chunk_threshold):
         return collate_fn(data, chunk_threshold)

-    data_loader = torch.utils.data.DataLoader(dataset=dataset,
-                                              batch_size=batch_size,
-                                              shuffle=shuffle,
-                                              num_workers=num_workers,
-                                              pin_memory=False,
-                                              collate_fn=my_collate_fn)
-
-    return data_loader
+    val_size = int(validation_split * len(dataset))
+    train_size = len(dataset) - val_size
+    train_dataset, val_dataset = torch.utils.data.random_split(dataset, [train_size, val_size])
+
+    train_data_loader = torch.utils.data.DataLoader(dataset=train_dataset,
+                                                    batch_size=batch_size,
+                                                    shuffle=shuffle,
+                                                    num_workers=num_workers,
+                                                    pin_memory=False,
+                                                    collate_fn=my_collate_fn)
+    if val_size == 0:
+        return train_data_loader
+
+    val_data_loader = torch.utils.data.DataLoader(dataset=val_dataset,
+                                                  batch_size=batch_size,
+                                                  shuffle=False,
+                                                  num_workers=num_workers,
+                                                  pin_memory=False,
+                                                  collate_fn=my_collate_fn)
+
+    return train_data_loader, val_data_loader


 class TorchData(torch.utils.data.Dataset):
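The return type of ShuffledDataset now depends on validation_split: the default 0.0 still yields a single DataLoader, so existing call sites are unaffected, while any nonzero split yields a (train, val) tuple whose validation loader is built with shuffle=False. A sketch of both call patterns (the file path is an assumption):

    import torchani.data

    # Backward compatible: no split requested, one DataLoader is returned.
    loader = torchani.data.ShuffledDataset('data/sample.h5', batch_size=256)

    # New behaviour: a nonzero validation_split returns a (train, val) pair.
    train_loader, val_loader = torchani.data.ShuffledDataset(
        'data/sample.h5', batch_size=256, validation_split=0.1)

    for batch in val_loader:  # validation batches come back in a fixed order
        pass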
