Skip to content

Commit

Permalink
Add GPU support for StoreGate dataset
Browse files Browse the repository at this point in the history
  • Loading branch information
tomoe committed Jun 14, 2023
1 parent 13722b2 commit 8f74876
Show file tree
Hide file tree
Showing 4 changed files with 26 additions and 7 deletions.
4 changes: 3 additions & 1 deletion multiml/database/numpy_database.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""NumpyDatabase module."""

import torch
import numpy as np

from multiml.database.database import Database
Expand Down Expand Up @@ -44,7 +45,8 @@ def update_data(self, data_id, var_name, idata, phase, index, mode=None):
self._db[data_id][phase][var_name][get_slice(index)] = idata

def get_data(self, data_id, var_name, phase, index):
if isinstance(index, (list, np.ndarray)): # allow fancy index, experimental feature
if isinstance(index,
(list, np.ndarray, torch.Tensor)): # allow fancy index, experimental feature
return np.take(self._db[data_id][phase][var_name], index, axis=0)
else:
return self._db[data_id][phase][var_name][get_slice(index)]
Expand Down
7 changes: 6 additions & 1 deletion multiml/database/zarr_database.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""ZarrDatabase module."""
import tempfile
import zarr
import torch
import numpy as np

from multiml import logger
Expand Down Expand Up @@ -53,7 +54,11 @@ def update_data(self, data_id, var_name, idata, phase, index, mode=None):
self._db[data_id][phase][var_name][get_slice(index)] = idata

def get_data(self, data_id, var_name, phase, index):
if isinstance(index, (list, np.ndarray)): # allow fancy index, experimental feature
if isinstance(index,
(list, np.ndarray, torch.Tensor)): # allow fancy index, experimental feature
if isinstance(index, torch.Tensor):
index = index.numpy()

return self._db[data_id][phase][var_name].oindex[index]
else:
return self._db[data_id][phase][var_name][get_slice(index)]
Expand Down
14 changes: 12 additions & 2 deletions multiml/task/pytorch/pytorch_base.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""PytorchBaseTask module."""
import copy
import math
import multiprocessing as mp

import numpy as np
Expand Down Expand Up @@ -54,6 +55,7 @@ def __init__(self,
dataset_args=None,
dataloader_args=None,
batch_sampler=False,
metric_sample=1,
**kwargs):
"""Initialize the pytorch base task.
Expand All @@ -67,6 +69,7 @@ def __init__(self,
dataset_args (dict): args passed to default DataSet creation.
dataloader_args (dict): args passed to default DataLoader creation.
batch_sampler (bool): use batch_sampler or not.
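metric_sample (float or int): sampling rate for running metrics (must be > 0). Values <= 1 are roughly the fraction of training batches on which metrics are updated; values > 1 are the approximate number of metric updates per epoch.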
"""
super().__init__(**kwargs)

Expand All @@ -89,7 +92,7 @@ def __init__(self,
self._batch_sampler = batch_sampler

self._pbar_args = const.PBAR_ARGS
self._running_step = 1
self._metric_sample = metric_sample
self._pred_index = None
self._early_stopping = False
self._sampler = None
Expand Down Expand Up @@ -459,6 +462,13 @@ def step_epoch(self, epoch, phase, dataloader, label=True):
pbar_args.update(self._pbar_args)
pbar_desc = self._get_pbar_description(epoch, phase)

if self._metric_sample <= 1.0:
metric_step = 1.0 // self._metric_sample
metric_step = math.floor(metric_step)
else:
metric_step = num_batches / self._metric_sample
metric_step = math.ceil(metric_step)

results = {}
with tqdm(**pbar_args) as pbar:
pbar.set_description(pbar_desc)
Expand All @@ -470,7 +480,7 @@ def step_epoch(self, epoch, phase, dataloader, label=True):
epoch_metric.pred(batch_result)

if phase == 'train':
if (ii % self._running_step == 0) or (ii == num_batches - 1):
if (ii % metric_step == 0) or (ii == num_batches - 1):
results.update(epoch_metric(batch_result))
pbar_metrics = metrics.get_pbar_metric(results)
pbar.set_postfix(pbar_metrics)
Expand Down
8 changes: 5 additions & 3 deletions multiml/task/pytorch/samplers/simple_batch_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,18 @@


class SimpleBatchSampler(Sampler):
def __init__(self, num_samples, batch_size, shuffle):
self.data = np.arange(num_samples)
def __init__(self, num_samples, batch_size, shuffle, device='cpu'):
self.data = torch.arange(num_samples, device=device)
self.index = 0
self.num_samples = num_samples
self.batch_size = batch_size
self.shuffle = shuffle
self.device = device

def __iter__(self):
if self.shuffle:
np.random.shuffle(self.data)
self.data = torch.randperm(len(self.data), device=self.device)

self.index = 0
return self

Expand Down

0 comments on commit 8f74876

Please sign in to comment.