Skip to content

Commit

Permalink
Fixed bug in virtual dataset cached test data; added tests
Browse files Browse the repository at this point in the history
  • Loading branch information
dsblank committed Sep 8, 2018
1 parent 12e298d commit 9db37c8
Showing 1 changed file with 43 additions and 4 deletions.
47 changes: 43 additions & 4 deletions conx/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -1787,11 +1787,11 @@ def __getitem__(self, pos):
self.dataset._current_cache = cache
if original_pos is None:
return super().__getitem__(pos - (cache * self.dataset._cache_size))
else:
pos = original_pos
cache = int(np.floor(pos / self.dataset._cache_size))
else: ## test_ data:
size, num_train, num_test = self.dataset._get_split_sizes()
pos = (original_pos + num_train) % self.dataset._cache_size
self.item = self.item[5:] # "test_" ...
retval = super().__getitem__(pos - (cache * self.dataset._cache_size))
retval = super().__getitem__(pos)
self.item = "test_" + self.item
return retval

Expand Down Expand Up @@ -1842,6 +1842,45 @@ class VirtualDataset(Dataset):
>>> import random
>>> from distutils.version import LooseVersion
>>> def f(self, cache):
... pos1 = cache * self._cache_size
... pos2 = (cache + 1) * self._cache_size
... inputs = np.array([[pos, pos] for pos in range(pos1, pos2)])
... targets = np.array([[pos] for pos in range(pos1, pos2)])
... return [inputs], [targets]
>>> ds = cx.VirtualDataset(f, 1000, [(2,)], [(1,)], [(0,1)], [(0,1)],
... load_cache_direct=True, cache_size=20)
>>> ds.split(.25)
>>> ds.inputs[0], ds.inputs[-1]
([0, 0], [999, 999])
>>> ds.train_inputs[0], ds.train_inputs[-1]
([0, 0], [749, 749])
>>> ds.test_inputs[0], ds.test_inputs[-1]
([750, 750], [999, 999])
>>> def f(self, pos):
... return [[pos], [pos]], [[pos]]
>>> ds = cx.VirtualDataset(f, 1000, [(2,)], [(1,)], [(0,1)], [(0,1)],
... load_cache_direct=False, cache_size=20)
>>> ds.split(.25)
>>> ds.inputs[0], ds.inputs[-1]
([0, 0], [999, 999])
>>> ds.train_inputs[0], ds.train_inputs[-1]
([0, 0], [749, 749])
>>> ds.test_inputs[0], ds.test_inputs[-1]
([750, 750], [999, 999])
>>> def test_dataset(net):
... net.dataset.split(.1)
... net.train(1, accuracy=1.0, report_rate=10, plot=False, batch_size=8)
Expand Down

0 comments on commit 9db37c8

Please sign in to comment.