Skip to content

Commit

Permalink
Imporove speed
Browse files Browse the repository at this point in the history
  • Loading branch information
tomoe committed Jun 12, 2023
1 parent fe09427 commit 90c58c1
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 4 deletions.
2 changes: 1 addition & 1 deletion multiml/database/numpy_database.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def update_data(self, data_id, var_name, idata, phase, index, mode=None):

def get_data(self, data_id, var_name, phase, index):
if isinstance(index, list): # allow fancy index, experimental feature
return self._db[data_id][phase][var_name][index]
return np.take(self._db[data_id][phase][var_name], index, axis=0)
else:
return self._db[data_id][phase][var_name][get_slice(index)]

Expand Down
15 changes: 13 additions & 2 deletions multiml/task/pytorch/pytorch_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ def __init__(self,
gpu_ids=None,
torchinfo=False,
amp=False,
torch_compile=False,
dataset_args=None,
dataloader_args=None,
batch_sampler=False,
Expand All @@ -62,6 +63,7 @@ def __init__(self,
``gpu_ids`` is given.
torchinfo (bool): show torchinfo summary after model compile.
amp (bool): *(expert option)* enable amp mode.
torch_compile (bool): *(expert option)* enable torch.compile.
dataset_args (dict): args passed to default DataSet creation.
dataloader_args (dict): args passed to default DataLoader creation.
batch_sampler (bool): user batch_sampler or not.
Expand All @@ -81,6 +83,7 @@ def __init__(self,
self._gpu_ids = gpu_ids
self._torchinfo = torchinfo
self._amp = amp
self._torch_compile = torch_compile
self._dataset_args = dataset_args
self._dataloader_args = dataloader_args
self._batch_sampler = batch_sampler
Expand Down Expand Up @@ -132,6 +135,9 @@ def compile_model(self):

self.ml.model = util.compile(self._model, self._model_args, modules)

if self._torch_compile:
self.ml.model = torch.compile(self.ml.model)

if self.pred_var_names is not None:
self._pred_index = self.get_pred_index()

Expand Down Expand Up @@ -459,12 +465,17 @@ def step_epoch(self, epoch, phase, dataloader, label=True):

for ii, data in enumerate(dataloader):
batch_result = self.step_batch(data, phase, label)
results.update(epoch_metric(batch_result))

if phase == 'test':
epoch_metric.pred(batch_result)

if (ii % self._running_step == 0) or (ii == num_batches - 1):
if phase == 'train':
if (ii % self._running_step == 0) or (ii == num_batches - 1):
results.update(epoch_metric(batch_result))
pbar_metrics = metrics.get_pbar_metric(results)
pbar.set_postfix(pbar_metrics)
else:
results.update(epoch_metric(batch_result))
pbar_metrics = metrics.get_pbar_metric(results)
pbar.set_postfix(pbar_metrics)
pbar.update(1)
Expand Down
2 changes: 1 addition & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ tensorflow =
tensorflow==2.4.0
numpy==1.19.5
pytorch =
torch==1.13.1
torch==2.0.1
torchinfo==1.7.1
tqdm==4.48.2

Expand Down

0 comments on commit 90c58c1

Please sign in to comment.