Skip to content

Commit

Permalink
make tests pass
Browse files Browse the repository at this point in the history
  • Loading branch information
scottgigante committed Nov 22, 2018
1 parent 5104b74 commit 76ba4f0
Show file tree
Hide file tree
Showing 7 changed files with 34 additions and 13 deletions.
28 changes: 17 additions & 11 deletions graphtools/graphs.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,24 +64,25 @@ class kNNGraph(DataGraph):
def __init__(self, data, knn=5, decay=None,
bandwidth=None, distance='euclidean',
thresh=1e-4, n_pca=None, **kwargs):
self.knn = knn
self.decay = decay
self.bandwidth = bandwidth
self.distance = distance
self.thresh = thresh

if decay is not None and thresh <= 0:
raise ValueError("Cannot instantiate a kNNGraph with `decay=None` "
"and `thresh=0`. Use a TraditionalGraph instead.")
if knn > data.shape[0]:
warnings.warn("Cannot set knn ({k}) to be greater than "
"data.shape[0] ({n}). Setting knn={n}".format(
"n_samples ({n}). Setting knn={n}".format(
k=knn, n=data.shape[0]))
knn = data.shape[0]
if n_pca is None and data.shape[1] > 500:
warnings.warn("Building a kNNGraph on data of shape {} is "
"expensive. Consider setting n_pca.".format(
data.shape), UserWarning)

self.knn = knn
self.decay = decay
self.bandwidth = bandwidth
self.distance = distance
self.thresh = thresh
super().__init__(data, n_pca=n_pca, **kwargs)

def get_params(self):
Expand Down Expand Up @@ -232,7 +233,7 @@ def build_kernel_to_data(self, Y, knn=None, bandwidth=None):
bandwidth = self.bandwidth
if knn > self.data.shape[0]:
warnings.warn("Cannot set knn ({k}) to be greater than "
"data.shape[0] ({n}). Setting knn={n}".format(
"n_samples ({n}). Setting knn={n}".format(
k=knn, n=self.data.shape[0]))

Y = self._check_extension_shape(Y)
Expand Down Expand Up @@ -675,15 +676,20 @@ def __init__(self, data,
n_pca=None,
thresh=1e-4,
precomputed=None, **kwargs):
if decay is None and precomputed not in ['affinity', 'adjacency']:
# decay high enough is basically a binary kernel
raise ValueError("`decay` must be provided for a TraditionalGraph"
". For kNN kernel, use kNNGraph.")
if precomputed is not None and n_pca is not None:
# the data itself is a matrix of distances / affinities
n_pca = None
warnings.warn("n_pca cannot be given on a precomputed graph."
" Setting n_pca=None", RuntimeWarning)
if decay is None and precomputed not in ['affinity', 'adjacency']:
# decay high enough is basically a binary kernel
raise ValueError("`decay` must be provided for a TraditionalGraph"
". For kNN kernel, use kNNGraph.")
if knn > data.shape[0]:
warnings.warn("Cannot set knn ({k}) to be greater than or equal to"
" n_samples ({n}). Setting knn={n}".format(
k=knn, n=data.shape[0] - 1))
knn = data.shape[0] - 1
if precomputed is not None:
if precomputed not in ["distance", "affinity", "adjacency"]:
raise ValueError("Precomputed value {} not recognized. "
Expand Down
1 change: 1 addition & 0 deletions test/test_api.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from __future__ import print_function
from load_tests import (
nose2,
data,
Expand Down
1 change: 1 addition & 0 deletions test/test_data.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from __future__ import print_function
from load_tests import (
np,
sp,
Expand Down
10 changes: 10 additions & 0 deletions test/test_exact.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from __future__ import print_function
from load_tests import (
graphtools,
np,
Expand Down Expand Up @@ -83,6 +84,15 @@ def test_duplicate_data():
thresh=0)


@warns(UserWarning)
def test_k_too_large():
build_graph(data,
n_pca=20,
decay=10,
knn=len(data) + 1,
thresh=0)


#####################################################
# Check kernel
#####################################################
Expand Down
1 change: 1 addition & 0 deletions test/test_knn.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from __future__ import print_function
from load_tests import (
graphtools,
np,
Expand Down
5 changes: 3 additions & 2 deletions test/test_landmark.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from __future__ import print_function
from load_tests import (
graphtools,
np,
Expand Down Expand Up @@ -43,7 +44,7 @@ def test_landmark_exact_graph():
assert(isinstance(G, graphtools.graphs.TraditionalGraph))
assert(isinstance(G, graphtools.graphs.LandmarkGraph))
assert(G.transitions.shape == (data.shape[0], n_landmark))
assert(G.clusters.shape == data.shape[0])
assert(G.clusters.shape == (data.shape[0],))
assert(len(np.unique(G.clusters)) <= n_landmark)
signal = np.random.normal(0, 1, [n_landmark, 10])
interpolated_signal = G.interpolate(signal)
Expand Down Expand Up @@ -72,7 +73,7 @@ def test_landmark_mnn_graph():
thresh=1e-5, n_pca=None,
decay=10, knn=5, random_state=42,
sample_idx=sample_idx)
assert(G.clusters.shape == data.shape[0])
assert(G.clusters.shape == (X.shape[0],))
assert(G.landmark_op.shape == (n_landmark, n_landmark))
assert(isinstance(G, graphtools.graphs.MNNGraph))
assert(isinstance(G, graphtools.graphs.LandmarkGraph))
Expand Down
1 change: 1 addition & 0 deletions test/test_mnn.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from __future__ import print_function
from load_tests import (
graphtools,
np,
Expand Down

0 comments on commit 76ba4f0

Please sign in to comment.