Skip to content

Commit

Permalink
Merge fe858d9 into 6b1af72
Browse files Browse the repository at this point in the history
  • Loading branch information
scottgigante committed Feb 10, 2021
2 parents 6b1af72 + fe858d9 commit a90a8e9
Show file tree
Hide file tree
Showing 4 changed files with 33 additions and 5 deletions.
27 changes: 24 additions & 3 deletions Python/phate/phate.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,11 @@ class PHATE(BaseEstimator):
`data[0,0]`. You can override this detection with
`knn_dist='precomputed_distance'` or `knn_dist='precomputed_affinity'`.
knn_max : int, optional, default: None
Maximum number of neighbors for which alpha decaying kernel
is computed for each point. For very large datasets, setting `knn_max`
to a small multiple of `knn` can speed up computation significantly.
mds_dist : string, optional, default: 'euclidean'
Distance metric for MDS. Recommended values: 'euclidean' and 'cosine'
Any metric from `scipy.spatial.distance` can be used. Custom distance
Expand Down Expand Up @@ -179,6 +184,7 @@ def __init__(
n_pca=100,
mds_solver="sgd",
knn_dist="euclidean",
knn_max=None,
mds_dist="euclidean",
mds="metric",
n_jobs=1,
Expand All @@ -203,6 +209,7 @@ def __init__(
self.mds = mds
self.n_pca = n_pca
self.knn_dist = knn_dist
self.knn_max = knn_max
self.mds_dist = mds_dist
self.mds_solver = mds_solver
self.random_state = random_state
Expand Down Expand Up @@ -311,16 +318,19 @@ def _check_params(self):
------
ValueError : unacceptable choice of parameters
"""
utils.check_positive(n_components=self.n_components, k=self.knn)
utils.check_int(n_components=self.n_components, k=self.knn, n_jobs=self.n_jobs)
utils.check_positive(n_components=self.n_components, knn=self.knn)
utils.check_int(
n_components=self.n_components, knn=self.knn, n_jobs=self.n_jobs
)
utils.check_between(-1, 1, gamma=self.gamma)
utils.check_if_not(None, utils.check_positive, a=self.decay)
utils.check_if_not(None, utils.check_positive, decay=self.decay)
utils.check_if_not(
None,
utils.check_positive,
utils.check_int,
n_landmark=self.n_landmark,
n_pca=self.n_pca,
knn_max=self.knn_max,
)
utils.check_if_not("auto", utils.check_positive, utils.check_int, t=self.t)
if not callable(self.knn_dist):
Expand Down Expand Up @@ -460,6 +470,11 @@ def set_params(self, **params):
using `data[0,0]`. You can override this detection with
`knn_dist='precomputed_distance'` or `knn_dist='precomputed_affinity'`.
knn_max : int, optional, default: None
Maximum number of neighbors for which alpha decaying kernel
is computed for each point. For very large datasets, setting `knn_max`
to a small multiple of `knn` can speed up computation significantly.
mds_dist : string, optional, default: 'euclidean'
recommended values: 'euclidean' and 'cosine'
Any metric from `scipy.spatial.distance` can be used
Expand Down Expand Up @@ -578,6 +593,10 @@ def set_params(self, **params):
self.knn = params["knn"]
reset_kernel = True
del params["knn"]
if "knn_max" in params and params["knn_max"] != self.knn_max:
self.knn_max = params["knn_max"]
reset_kernel = True
del params["knn_max"]
if "decay" in params and params["decay"] != self.decay:
self.decay = params["decay"]
reset_kernel = True
Expand Down Expand Up @@ -759,6 +778,7 @@ def _update_graph(self, X, precomputed, n_pca, n_landmark):
self.graph.set_params(
decay=self.decay,
knn=self.knn,
knn_max=self.knn_max,
distance=self.knn_dist,
precomputed=precomputed,
n_jobs=self.n_jobs,
Expand Down Expand Up @@ -824,6 +844,7 @@ def fit(self, X):
distance=self.knn_dist,
precomputed=precomputed,
knn=self.knn,
knn_max=self.knn_max,
decay=self.decay,
thresh=1e-4,
n_jobs=self.n_jobs,
Expand Down
2 changes: 1 addition & 1 deletion Python/phate/version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "1.0.6"
__version__ = "1.0.7"
7 changes: 7 additions & 0 deletions Python/test/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ def test_tree():
n_components=2,
decay=10,
knn=5,
knn_max=15,
t=30,
mds="classic",
knn_dist="euclidean",
Expand All @@ -93,6 +94,12 @@ def test_tree():
n_landmark=None,
verbose=False,
)
phate_operator.fit(M)
assert phate_operator.graph.knn == 5
assert phate_operator.graph.knn_max == 15
assert phate_operator.graph.decay == 10
assert phate_operator.graph.n_jobs == -2
assert phate_operator.graph.verbose == 0

# run phate with classic MDS
print("DLA tree, classic MDS")
Expand Down

0 comments on commit a90a8e9

Please sign in to comment.