Skip to content

Commit

Permalink
Support for h5 served data; Dataset function/generator can assign bat…
Browse files Browse the repository at this point in the history
…ch data directly
  • Loading branch information
dsblank committed Sep 7, 2018
1 parent 38d40b0 commit caca8a5
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 11 deletions.
41 changes: 30 additions & 11 deletions conx/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -1759,14 +1759,16 @@ def __getitem__(self, pos):
## now fill the cached values with these:
if self.dataset._load_batch_direct:
if self.dataset._generator:
all_inputs, all_targets = next(self.dataset._generator)
all_data = next(self.dataset._generator)
else:
if self.dataset._pass_self:
all_inputs, all_targets = self.dataset._generate(self.dataset, batch)
all_data = self.dataset._generate(self.dataset, batch)
else:
all_inputs, all_targets = self.dataset._generate(batch)
self.dataset._inputs = all_inputs
self.dataset._targets = all_targets
all_data = self.dataset._generate(batch)
if all_data is not None:
self.dataset._inputs = all_data[0]
self.dataset._targets = all_data[1]
# else, assumes that you did it through self._inputs = ... in function/generator
else:
all_inputs = []
all_targets = []
Expand Down Expand Up @@ -2109,24 +2111,41 @@ class H5Dataset(VirtualDataset):
"""

def __init__(self, f, filename, key, batch_size, input_banks, target_banks,
name=None, description=None, network=None):
name=None, description=None, network=None, load_batch_direct=False,
length=None, endpoint=None, username=None, password=None, api_key=None,
use_session=True, use_cache=False):
"""
>>> def f(self, pos):
... return self.h5[self.key][0][pos], self.h5[self.key][0][pos]
>>> def f_batch(self, batch):
... pos = batch * self._batch_size
... b = self.h5[self.key][0][pos:pos + self._batch_size]
... return [b], [b]
>>> if os.path.exists("fonts.hdf5"):
... ds = cx.H5Dataset(f, "fonts.hdf5", "fonts", 32, 1, 1, name="Fonts",
... description='''
... ds1 = cx.H5Dataset(f, "fonts.hdf5", "fonts", 32, 1, 1, name="Fonts",
... description='''
... Based on: https://erikbern.com/2016/01/21/analyzing-50k-fonts-using-deep-neural-networks.html
... ''')
... ds2 = cx.H5Dataset(f_batch, "fonts.hdf5", "fonts", 32, 1, 1, name="Fonts", load_batch_direct=True,
... description='''
... Based on: https://erikbern.com/2016/01/21/analyzing-50k-fonts-using-deep-neural-networks.html
... ''')
"""
import h5py
self.key = key
self.h5 = h5py.File(filename)
super().__init__(f, len(self.h5[self.key]),
if endpoint is not None:
import h5pyd
self.h5 = h5pyd.File(filename, 'r', endpoint=endpoint, username=username,
password=password, api_key=api_key, use_session=use_session,
use_cache=use_cache)
else:
self.h5 = h5py.File(filename)
super().__init__(f, length if length is not None else len(self.h5[self.key]),
generator_ordered=False,
load_batch_direct=False,
load_batch_direct=load_batch_direct,
batch_size=batch_size,
name=name, description=description, network=network,
## dummy values, for now:
Expand Down
1 change: 1 addition & 0 deletions conx/datasets/_mnist.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ def get_batch(self, batch):
## targets:
y_train = fp[key_target][pos:pos + self._batch_size]
labels = y_train
## FIXME: at least one is mis-labeled:
## labels[10994] = 9
targets = to_categorical(labels)
labels = np.array([str(label) for label in labels], dtype=str)
Expand Down

0 comments on commit caca8a5

Please sign in to comment.