Skip to content

Commit

Permalink
test set_params
Browse files Browse the repository at this point in the history
  • Loading branch information
scottgigante committed Jul 31, 2018
1 parent f475877 commit d9c036a
Show file tree
Hide file tree
Showing 6 changed files with 171 additions and 8 deletions.
24 changes: 16 additions & 8 deletions graphtools/graphs.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,8 @@ def set_params(self, **params):
raise ValueError("Cannot update thresh. Please create a new graph")
if 'n_jobs' in params:
self.n_jobs = params['n_jobs']
if hasattr(self, "_knn_tree"):
self.knn_tree.set_params(n_jobs=self.n_jobs)
if 'random_state' in params:
self.random_state = params['random_state']
if 'verbose' in params:
Expand Down Expand Up @@ -857,6 +859,7 @@ def __init__(self, data, sample_idx,
decay=None,
distance='euclidean',
thresh=1e-4,
n_jobs=1,
**kwargs):
self.beta = beta
self.sample_idx = sample_idx
Expand All @@ -867,6 +870,7 @@ def __init__(self, data, sample_idx,
self.decay = decay
self.distance = distance
self.thresh = thresh
self.n_jobs = n_jobs
self.weighted_knn = self._weight_knn()

if sample_idx is None:
Expand Down Expand Up @@ -950,8 +954,12 @@ def get_params(self):
"""
params = super().get_params()
params.update({'beta': self.beta,
'adaptive_k': self.adaptive_k})
params.update(self.knn_args)
'adaptive_k': self.adaptive_k,
'knn': self.knn,
'decay': self.decay,
'distance': self.distance,
'thresh': self.thresh,
'n_jobs': self.n_jobs})
return params

def set_params(self, **params):
Expand Down Expand Up @@ -990,15 +998,14 @@ def set_params(self, **params):
knn_kernel_args = ['knn', 'decay', 'distance', 'thresh']
knn_other_args = ['n_jobs', 'random_state', 'verbose']
for arg in knn_kernel_args:
if arg in params and (arg not in self.knn_args or
params[arg] != self.knn_args[arg]):
if arg in params and params[arg] != getattr(self, arg):
raise ValueError("Cannot update {}. "
"Please create a new graph".format(arg))
for arg in knn_other_args:
self.__setattr__(arg, params[arg])

# update subgraph parameters
[g.set_params(**knn_other_args) for g in self.subgraphs]
if arg in params:
self.__setattr__(arg, params[arg])
for g in self.subgraphs:
g.set_params(**{arg: params[arg]})

# update superclass parameters
super().set_params(**params)
Expand Down Expand Up @@ -1034,6 +1041,7 @@ def build_kernel(self):
thresh=self.thresh,
verbose=self.verbose,
random_state=self.random_state,
n_jobs=self.n_jobs,
initialize=False)
self.subgraphs.append(graph) # append to list of subgraphs
tasklogger.log_complete("subgraphs")
Expand Down
14 changes: 14 additions & 0 deletions test/test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,3 +183,17 @@ def test_inverse_transform_sparse_no_pca():
assert_raises(ValueError, G.inverse_transform, sp.csr_matrix(G.data)[:, 0])
assert_raises(ValueError, G.inverse_transform,
sp.csr_matrix(G.data)[:, :15])


#############
# Test API
#############


def test_set_params():
G = graphtools.base.Data(data, n_pca=20)
assert G.get_params() == {'n_pca': 20, 'random_state': None}
G.set_params(random_state=13)
assert G.random_state == 13
assert_raises(ValueError, G.set_params, n_pca=10)
G.set_params(n_pca=G.n_pca)
25 changes: 25 additions & 0 deletions test/test_exact.py
Original file line number Diff line number Diff line change
Expand Up @@ -332,7 +332,32 @@ def test_precomputed_interpolate():
G.build_kernel_to_data(data)


####################
# Test API
####################


def test_verbose():
print()
print("Verbose test: Exact")
build_graph(data, decay=10, thresh=0, verbose=True)


def test_set_params():
G = build_graph(data, decay=10, thresh=0)
assert G.get_params() == {'n_pca': 20,
'random_state': 42,
'kernel_symm': '+',
'gamma': None,
'knn': 3,
'decay': 10,
'distance': 'euclidean',
'precomputed': None}
assert_raises(ValueError, G.set_params, knn=15)
assert_raises(ValueError, G.set_params, decay=15)
assert_raises(ValueError, G.set_params, distance='manhattan')
assert_raises(ValueError, G.set_params, precomputed='distance')
G.set_params(knn=G.knn,
decay=G.decay,
distance=G.distance,
precomputed=G.precomputed)
41 changes: 41 additions & 0 deletions test/test_knn.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,48 @@ def test_knn_interpolate():
G.interpolate(pca_data, transitions=transitions)))


####################
# Test API
####################


def test_verbose():
print()
print("Verbose test: kNN")
build_graph(data, decay=None, verbose=True)


def test_set_params():
G = build_graph(data, decay=None)
assert G.get_params() == {
'n_pca': 20,
'random_state': 42,
'kernel_symm': '+',
'gamma': None,
'knn': 3,
'decay': None,
'distance': 'euclidean',
'thresh': 0,
'n_jobs': -1,
'verbose': 0
}
G.set_params(n_jobs=4)
assert G.n_jobs == 4
assert G.knn_tree.n_jobs == 4
G.set_params(random_state=13)
assert G.random_state == 13
G.set_params(verbose=2)
assert G.verbose == 2
G.set_params(verbose=0)
assert_raises(ValueError, G.set_params, knn=15)
assert_raises(ValueError, G.set_params, decay=10)
assert_raises(ValueError, G.set_params, distance='manhattan')
assert_raises(ValueError, G.set_params, thresh=1e-3)
assert_raises(ValueError, G.set_params, gamma=0.99)
assert_raises(ValueError, G.set_params, kernel_symm='*')
G.set_params(knn=G.knn,
decay=G.decay,
thresh=G.thresh,
distance=G.distance,
gamma=G.gamma,
kernel_symm=G.kernel_symm)
28 changes: 28 additions & 0 deletions test/test_landmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
data,
digits,
build_graph,
assert_raises,
raises,
warns,
generate_swiss_roll
Expand Down Expand Up @@ -118,7 +119,34 @@ def test_landmark_mnn_pygsp_graph():
# TODO: add interpolation tests


#############
# Test API
#############

def test_verbose():
print()
print("Verbose test: Landmark")
build_graph(data, decay=None, n_landmark=500, verbose=True).landmark_op


def test_set_params():
G = build_graph(data, n_landmark=500, decay=None)
G.landmark_op
assert G.get_params == {'n_pca': 20,
'random_state': 42,
'kernel_symm': '+',
'gamma': None,
'n_landmark': 500,
'knn': 3,
'decay': None,
'distance':
'euclidean',
'thresh': 0,
'n_jobs': -1,
'verbose': 0}
G.set_params(n_landmark=300)
assert G.landmark_op.shape == (300, 300)
G.set_params(n_landmark=G.n_landmark, n_svd=G.n_svd)
assert hasattr(G, "_landmark_op")
G.set_params(n_svd=50)
assert not hasattr(G, "_landmark_op")
47 changes: 47 additions & 0 deletions test/test_mnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,3 +222,50 @@ def test_verbose():
build_graph(X, sample_idx=sample_idx,
kernel_symm='gamma', gamma=0.5,
n_pca=None, verbose=True)


def test_set_params():
X, sample_idx = generate_swiss_roll()
G = build_graph(X, sample_idx=sample_idx,
kernel_symm='gamma', gamma=0.5,
n_pca=None,
thresh=1e-4)
assert G.get_params() == {
'n_pca': None,
'random_state': 42,
'kernel_symm': 'gamma',
'gamma': 0.5,
'beta': 1,
'adaptive_k': 'sqrt',
'knn': 3,
'decay': 10,
'distance': 'euclidean',
'thresh': 1e-4,
'n_jobs': 1
}
G.set_params(n_jobs=4)
assert G.n_jobs == 4
for graph in G.subgraphs:
assert graph.n_jobs == 4
assert graph.knn_tree.n_jobs == 4
G.set_params(random_state=13)
assert G.random_state == 13
for graph in G.subgraphs:
assert graph.random_state == 13
G.set_params(verbose=2)
assert G.verbose == 2
for graph in G.subgraphs:
assert graph.verbose == 2
G.set_params(verbose=0)
assert_raises(ValueError, G.set_params, knn=15)
assert_raises(ValueError, G.set_params, decay=15)
assert_raises(ValueError, G.set_params, distance='manhattan')
assert_raises(ValueError, G.set_params, thresh=1e-3)
assert_raises(ValueError, G.set_params, beta=0.2)
assert_raises(ValueError, G.set_params, adaptive_k='min')
G.set_params(knn=G.knn,
decay=G.decay,
thresh=G.thresh,
distance=G.distance,
beta=G.beta,
adaptive_k=G.adaptive_k)

0 comments on commit d9c036a

Please sign in to comment.