Skip to content

Commit

Permalink
code reformat
Browse files Browse the repository at this point in the history
  • Loading branch information
shengzeang committed Apr 27, 2022
1 parent 27b1ff7 commit 7c631a6
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 6 deletions.
5 changes: 3 additions & 2 deletions examples/test_nas.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import torch
from openbox.optimizer.generic_smbo import SMBO

from sgl.dataset.planetoid import Planetoid
from sgl.search.search_config import ConfigManager
from openbox.optimizer.generic_smbo import SMBO

dataset = Planetoid("cora", "./", "official")
device = torch.device(f"cuda:{0}" if torch.cuda.is_available() else "cpu")
Expand All @@ -21,7 +22,7 @@
surrogate_type='prf',
acq_type='ehvi',
acq_optimizer_type='local_random',
initial_runs=2*(dim+1),
initial_runs=2 * (dim + 1),
init_strategy='sobol',
ref_point=[-1, 0.00001],
time_limit_per_trial=5000,
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

setuptools.setup(
name="sgl-dair",
version="0.1.0",
version="0.1.2",
author="DAIR Lab @PKU",
description="Graph Neural Network (GNN) toolkit targeting scalable graph learning",
long_description=long_description,
Expand Down
17 changes: 14 additions & 3 deletions sgl/search/search_config.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,18 @@
import numpy as np
<<<<<<< Updated upstream
from sgl.search.search_models import SearchModel
from sgl.search.auto_search import SearchManager
=======
>>>>>>> Stashed changes
from openbox.utils.config_space import ConfigurationSpace, UniformIntegerHyperparameter

from sgl.search.auto_search import SearchManager
from sgl.search.search_models import SearchModel


class ConfigManager():
def __init__(self, arch, prop_steps=[1,10], prop_types=[0,1], mesg_types=[0,8], num_layers=[1,10], post_steps=[1,10], post_types=[0,1], pmsg_types=[0,5]):
def __init__(self, arch, prop_steps=[1, 10], prop_types=[0, 1], mesg_types=[0, 8], num_layers=[1, 10],
post_steps=[1, 10], post_types=[0, 1], pmsg_types=[0, 5]):
super(ConfigManager, self).__init__()

self.__initial_arch = arch
Expand All @@ -16,7 +24,9 @@ def __init__(self, arch, prop_steps=[1,10], prop_types=[0,1], mesg_types=[0,8],
self.__post_steps = UniformIntegerHyperparameter("post_steps", post_steps[0], post_steps[1])
self.__post_types = UniformIntegerHyperparameter("post_types", post_types[0], post_types[1])
self.__pmsg_types = UniformIntegerHyperparameter("pmsg_types", pmsg_types[0], pmsg_types[1])
self.__config_space.add_hyperparameters([self.__prop_steps, self.__prop_types, self.__mesg_types, self.__num_layers, self.__post_steps, self.__post_types, self.__pmsg_types])
self.__config_space.add_hyperparameters(
[self.__prop_steps, self.__prop_types, self.__mesg_types, self.__num_layers, self.__post_steps,
self.__post_types, self.__pmsg_types])

def _setParameters(self, dataset, device, hiddim, epochs, lr, wd):
self.__dataset = dataset
Expand All @@ -31,7 +41,8 @@ def _configSpace(self):

def _configTarget(self, arch):
model = SearchModel(arch, self.__dataset.num_features, int(self.__dataset.num_classes), self.__hiddim)
acc_res, time_res = SearchManager(self.__dataset, model, lr=self.__lr, weight_decay=self.__wd, epochs=self.__epochs, device=self.__device)._execute()
acc_res, time_res = SearchManager(self.__dataset, model, lr=self.__lr, weight_decay=self.__wd,
epochs=self.__epochs, device=self.__device)._execute()
result = dict()
result['objs'] = np.stack([-acc_res, time_res], axis=-1)
return result
Expand Down

0 comments on commit 7c631a6

Please sign in to comment.