Skip to content

Commit

Permalink
Added _get_data, _set_data, Moved apply_transform to Batch
Browse files Browse the repository at this point in the history
  • Loading branch information
roman-kh committed Jun 14, 2017
1 parent dc20fe4 commit 3ecdf8b
Show file tree
Hide file tree
Showing 3 changed files with 52 additions and 38 deletions.
43 changes: 43 additions & 0 deletions dataset/batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,18 @@ def data(self):
self.load(self._preloaded)
return self._data

def _get_data(self, i, value):
""" Return i-th component of a data tuple
Do not call if self.data is not a tuple """
return self.data[i] if self.data is not None else None

def _set_data(self, i, value):
""" Put a new value into i-th component of a data tuple
Do not call if self._data is not a tuple """
data = list(self.data)
data[i] = value
self._data = tuple(data)

def __getitem__(self, item):
if isinstance(self.data, tuple):
res = tuple(data_item[item] if data_item is not None else None for data_item in self.data)
Expand Down Expand Up @@ -100,13 +112,44 @@ def get_errors(self, all_res):
@action
def load(self, src, fmt=None):
""" Load data from a file or another data source """
if fmt is None:
if isinstance(src, tuple):
self._data = tuple(src[i][self.indices] for i in range(len(src)))
else:
self._data = src[self.indices]
else:
raise ValueError("Unsupported format:", fmt)
return self

@action
def dump(self, dst, fmt=None):
""" Save batch data to disk """
return self

@action
@inbatch_parallel(init='indices')
def apply_transform(self, ix, src, dst, func, *args, **kwargs):
""" Apply a function to each item of the batch """
dst_attr = getattr(self, dst)
pos = self.index.get_pos(ix)
if src is None:
all_args = args
else:
src_attr = getattr(self, src)
all_args = tuple(src_attr[pos], *args)
dst_attr[pos] = func(*all_args, **kwargs)

@action
def apply_transform_all(self, src, dst, func, *args, **kwargs):
""" Apply a function all item of the batch """
if src is None:
all_args = args
else:
src_attr = getattr(self, src)
all_args = tuple(src_attr, *args)
setattr(self, dst, func(*all_args, **kwargs))
return self


class ArrayBatch(Batch):
""" Base Batch class for array-like datasets """
Expand Down
4 changes: 2 additions & 2 deletions dataset/batch_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@


class BaseBatch:
""" Basr class for batches
Required to solve circulal module dependencies
""" Base class for batches
Required to solve circular module dependencies
"""
def __init__(self, index):
self.index = index
Expand Down
43 changes: 7 additions & 36 deletions dataset/batch_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,38 +38,32 @@ def data(self):
@property
def images(self):
""" Images """
return self.data[0] if self.data is not None else None
return self._get_data(0)

@images.setter
def images(self, value):
""" Set images """
data = list(self.data)
data[0] = value
self._data = data
self._set_data(0, value)

@property
def labels(self):
""" Labels for images """
return self.data[1] if self.data is not None else None
return self._get_data(1)

@labels.setter
def labels(self, value):
""" Set labels """
data = list(self.data)
data[1] = value
self._data = data
self._set_data(1, value)

@property
def masks(self):
""" Masks for images """
return self.data[2] if self.data is not None else None
return self._get_data(2)

@masks.setter
def masks(self, value):
""" Set masks """
data = list(self.data)
data[3] = value
self._data = data
self._set_data(2, value)

def assemble(self, all_res, *args, **kwargs):
""" Assemble the batch after a parallel action """
Expand Down Expand Up @@ -131,29 +125,6 @@ def load(self, src, fmt=None):

@action
def dump(self, dst, fmt=None):
""" Saves data to a file or array """
""" Saves data to a file or a memory object """
_ = dst, fmt
return self

@action
@inbatch_parallel(init='indices')
def apply_transform(self, ix, src, dst, func, *args, **kwargs):
""" Apply a function to each item of the batch """
dst_attr = getattr(self, dst)
pos = self.index.get_pos(ix)
if src is None:
all_args = args
else:
src_attr = getattr(self, src)
all_args = tuple(src_attr[pos], *args)
dst_attr[pos] = func(*all_args, **kwargs)

@action
def apply_transform_all(self, src, dst, func, *args, **kwargs):
""" Apply a function all item of the batch """
if src is None:
all_args = args
else:
src_attr = getattr(self, src)
all_args = tuple(src_attr, *args)
setattr(self, dst, func(*all_args, **kwargs))

0 comments on commit 3ecdf8b

Please sign in to comment.