Skip to content

Commit

Permalink
change pickle only flag to use torch in tune adapter
Browse files Browse the repository at this point in the history
  • Loading branch information
nmatthews-asapp committed Jan 30, 2020
1 parent 49b83cd commit 75e9880
Showing 1 changed file with 11 additions and 3 deletions.
14 changes: 11 additions & 3 deletions flambe/experiment/tune_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
import shutil

import ray
import torch
import dill

from flambe.compile import load_state_from_file, Schema, Component
from flambe.compile.extensions import setup_default_modules, import_modules
Expand Down Expand Up @@ -153,13 +155,19 @@ def _train(self) -> Dict:
def _save(self, checkpoint_dir: str) -> str:
"""Subclasses should override this to implement save()."""
path = os.path.join(checkpoint_dir, "checkpoint.flambe")
self.block.save(path, pickle_only=self.pickle_checkpoints, overwrite=True)
if self.pickle_checkpoints:
torch.save(self.block, path, pickle_module=dill)
else:
self.block.save(path, overwrite=True)
return path

def _restore(self, checkpoint: str) -> None:
"""Subclasses should override this to implement restore()."""
state = load_state_from_file(checkpoint)
self.block.load_state(state)
if self.pickle_checkpoints:
self.block = torch.load(checkpoint, pickle_protocol=dill)
else:
state = load_state_from_file(checkpoint)
self.block.load_state(state)

def _stop(self):
"""Subclasses should override this for any cleanup on stop."""
Expand Down

0 comments on commit 75e9880

Please sign in to comment.