Skip to content

Commit

Permalink
Adapt to ray 1.1.0
Browse files Browse the repository at this point in the history
I got some warnings about deprecated override functions and changed them to the suggested ones. Also, we need the `mode` parameter to guide the direction of the search.
  • Loading branch information
howardlau1999 committed Jan 23, 2021
1 parent 4242acd commit fb89102
Showing 1 changed file with 6 additions and 6 deletions.
12 changes: 6 additions & 6 deletions examples/mnist_pytorch_trainable.py
Expand Up @@ -27,7 +27,7 @@
# yapf: disable
# __trainable_example_begin__
class TrainMNIST(tune.Trainable):
def _setup(self, config):
def setup(self, config):
use_cuda = config.get("use_gpu") and torch.cuda.is_available()
self.device = torch.device("cuda" if use_cuda else "cpu")
self.train_loader, self.test_loader = get_data_loaders()
Expand All @@ -37,18 +37,18 @@ def _setup(self, config):
lr=config.get("lr", 0.01),
momentum=config.get("momentum", 0.9))

def _train(self):
def step(self):
self.current_ip()
train(self.model, self.optimizer, self.train_loader, device=self.device)
acc = test(self.model, self.test_loader, self.device)
return {"mean_accuracy": acc}

def _save(self, checkpoint_dir):
def save_checkpoint(self, checkpoint_dir):
checkpoint_path = os.path.join(checkpoint_dir, "model.pth")
torch.save(self.model.state_dict(), checkpoint_path)
return checkpoint_path

def _restore(self, checkpoint_path):
def load_checkpoint(self, checkpoint_path):
self.model.load_state_dict(torch.load(checkpoint_path))

# this is currently needed to handle Cori GPU multiple interfaces
Expand All @@ -63,7 +63,7 @@ def current_ip(self):
# ip_head and redis_passwords are set by ray cluster shell scripts
print(os.environ["ip_head"], os.environ["redis_password"])
ray.init(address='auto', _node_ip_address=os.environ["ip_head"].split(":")[0], _redis_password=os.environ["redis_password"])
sched = ASHAScheduler(metric="mean_accuracy")
sched = ASHAScheduler(metric="mean_accuracy", mode="max")
analysis = tune.run(TrainMNIST,
scheduler=sched,
stop={"mean_accuracy": 0.99,
Expand All @@ -74,4 +74,4 @@ def current_ip(self):
config={"lr": tune.uniform(0.001, 1.0),
"momentum": tune.uniform(0.1, 0.9),
"use_gpu": True})
print("Best config is:", analysis.get_best_config(metric="mean_accuracy"))
print("Best config is:", analysis.get_best_config(metric="mean_accuracy", mode="max"))

0 comments on commit fb89102

Please sign in to comment.