Skip to content

Commit

Permalink
Updates for speed
Browse files Browse the repository at this point in the history
  • Loading branch information
tomoe committed Jun 13, 2023
1 parent 24b0ba0 commit 13722b2
Show file tree
Hide file tree
Showing 5 changed files with 19 additions and 5 deletions.
2 changes: 1 addition & 1 deletion multiml/database/numpy_database.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def update_data(self, data_id, var_name, idata, phase, index, mode=None):
self._db[data_id][phase][var_name][get_slice(index)] = idata

def get_data(self, data_id, var_name, phase, index):
- if isinstance(index, list): # allow fancy index, experimental feature
+ if isinstance(index, (list, np.ndarray)): # allow fancy index, experimental feature
return np.take(self._db[data_id][phase][var_name], index, axis=0)
else:
return self._db[data_id][phase][var_name][get_slice(index)]
Expand Down
3 changes: 2 additions & 1 deletion multiml/database/zarr_database.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""ZarrDatabase module."""
import tempfile
import zarr
import numpy as np

from multiml import logger
from multiml.database.database import Database
Expand Down Expand Up @@ -52,7 +53,7 @@ def update_data(self, data_id, var_name, idata, phase, index, mode=None):
self._db[data_id][phase][var_name][get_slice(index)] = idata

def get_data(self, data_id, var_name, phase, index):
- if isinstance(index, list): # allow fancy index, experimental feature
+ if isinstance(index, (list, np.ndarray)): # allow fancy index, experimental feature
return self._db[data_id][phase][var_name].oindex[index]
else:
return self._db[data_id][phase][var_name][get_slice(index)]
Expand Down
16 changes: 14 additions & 2 deletions multiml/task/pytorch/datasets/storegate_dataset.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,24 @@
import torch
import torch.utils.data as tdata

import numpy as np


class StoreGateDataset(tdata.Dataset):
"""StoreGate dataset class."""
def __init__(self,
storegate,
phase,
preload=False,
device=None,
preload=None,
input_var_names=None,
true_var_names=None,
callbacks=None):

self._storegate = storegate
self._true_var_names = true_var_names
self._phase = phase
self._device = device
self._preload = preload
self._input_var_names = input_var_names
self._true_var_names = true_var_names
Expand All @@ -26,10 +31,17 @@ def __init__(self,
self._data = None
self._target = None

- if self._preload:
+ if self._preload == 'cpu':
self._data = storegate.get_data(input_var_names, phase)
self._target = storegate.get_data(true_var_names, phase)

elif self._preload == 'cuda':
data = storegate.get_data(input_var_names, phase)
target = storegate.get_data(true_var_names, phase)

self._data = torch.from_numpy(data).to(device)
self._target = torch.from_numpy(target).to(device)

def __len__(self):
return self._size

Expand Down
1 change: 1 addition & 0 deletions multiml/task/pytorch/pytorch_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -606,6 +606,7 @@ def get_storegate_dataset(self, phase, preload=False, callbacks=None):
"""Returns storegate dataset."""
return StoreGateDataset(self.storegate,
phase,
self._device,
input_var_names=self.input_var_names,
true_var_names=self.true_var_names,
preload=preload,
Expand Down
2 changes: 1 addition & 1 deletion multiml/task/pytorch/samplers/simple_batch_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def __next__(self):

self.index += 1

- return self.data[index1:index2].tolist()
+ return self.data[index1:index2]

def __len__(self):
return -(-1 * self.num_samples // self.batch_size)

0 comments on commit 13722b2

Please sign in to comment.