Skip to content

Commit

Permalink
Update create_empty in StoreGate
Browse files Browse the repository at this point in the history
  • Loading branch information
tomoe committed May 30, 2023
1 parent 10497e5 commit 44fe5b3
Showing 1 changed file with 13 additions and 9 deletions.
22 changes: 13 additions & 9 deletions multiml/storegate.py
Original file line number Diff line number Diff line change
Expand Up @@ -498,7 +498,7 @@ def create_empty(self, var_names, shape, phase='train', dtype='f4'):
Args:
var_names (str or list): see ``add_data()`` method.
shape (int or tuple): shape of empty data.
shape (tuple): shape of empty data.
phase (str): see ``update_data()`` method.
dtype (str): dtype of empty data. Default float32.
"""
Expand All @@ -507,17 +507,21 @@ def create_empty(self, var_names, shape, phase='train', dtype='f4'):
if isinstance(var_names, str):
var_names = [var_names]

if phase == 'all':
phases = const.PHASES
else:
phases = [phase]
self._check_valid_phase(phase)

ndata = shape[0]
for var_name in var_names:
indices = self._get_phase_indices(phase, ndata)

dummy_data = np.zeros(ndata)
for iphase, phase_data in zip(const.PHASES, np.split(dummy_data, indices)):
if len(phase_data) == 0:
continue

for iphase in phases:
for var_name in var_names:
if var_name in self.get_var_names(iphase):
raise ValueError(f'create_empty: {var_name} arleady exists in {phase}.')

self._db.create_empty(self._data_id, var_name, iphase, shape, dtype)
phase_shape = (len(phase_data), ) + shape[1:]
self._db.create_empty(self._data_id, var_name, iphase, phase_shape, dtype)

def get_data_ids(self):
"""Returns registered data_ids in the backend.
Expand Down

0 comments on commit 44fe5b3

Please sign in to comment.