Skip to content

Commit

Permalink
set random seed for swiss roll generation
Browse files Browse the repository at this point in the history
  • Loading branch information
scottgigante committed Sep 18, 2018
1 parent 264bd2d commit 56e16a4
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 3 deletions.
2 changes: 1 addition & 1 deletion test/load_tests/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def generate_swiss_roll(n_samples=1000, noise=0.5, seed=42):
t = 1.5 * np.pi * (1 + 2 * generator.rand(1, n_samples))
x = t * np.cos(t)
y = t * np.sin(t)
sample_idx = np.random.choice([0, 1], n_samples, replace=True)
sample_idx = generator.choice([0, 1], n_samples, replace=True)
z = sample_idx
t = np.squeeze(t)
X = np.concatenate((x, y))
Expand Down
4 changes: 2 additions & 2 deletions test/test_mnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ def test_mnn_graph_float_theta():
distance=metric, sample_idx=sample_idx, thresh=0,
use_pygsp=True)
assert G.N == G2.N
assert np.all(G.d == G2.d)
np.testing.assert_array_equal(G.dw, G2.dw)
assert (G.W != G2.W).nnz == 0
assert (G2.W != G.W).sum() == 0
assert isinstance(G2, graphtools.graphs.MNNGraph)
Expand Down Expand Up @@ -202,7 +202,7 @@ def test_mnn_graph_matrix_theta():
distance=metric, sample_idx=sample_idx, thresh=0,
use_pygsp=True)
assert G.N == G2.N
assert np.all(G.d == G2.d)
np.testing.assert_array_equal(G.dw, G2.dw)
assert (G.W != G2.W).nnz == 0
assert (G2.W != G.W).sum() == 0
assert isinstance(G2, graphtools.graphs.MNNGraph)
Expand Down

0 comments on commit 56e16a4

Please sign in to comment.