Skip to content

Commit

Permalink
clean up logic
Browse files Browse the repository at this point in the history
  • Loading branch information
scottgigante committed Feb 24, 2020
1 parent 02664d2 commit f802646
Showing 1 changed file with 70 additions and 82 deletions.
152 changes: 70 additions & 82 deletions graphtools/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,24 +223,34 @@ def __init__(

def set_params(self, **params):
    """Update estimator attributes and pass the same values on to the graph.

    Each keyword is compared with the attribute of the same name on
    ``self``; the attribute is only reassigned when the value differs, so
    an unchanged value never triggers an assignment. All keywords,
    changed or not, are then handed to ``self._set_graph_params``.

    Parameters
    ----------
    **params
        Attribute names mapped to their new values. Each name must
        already exist as an attribute on ``self``.

    Raises
    ------
    AttributeError
        If a keyword does not name an existing attribute.
    """
    for name, value in params.items():
        # Only write when the value really changed; an unconditional
        # assignment before this check would make the check pointless.
        if getattr(self, name) != value:
            setattr(self, name, value)
    self._set_graph_params(**params)

def _set_graph_params(self, **params):
if self.graph is not None:
if "n_landmark" in params:
n_landmark = params["n_landmark"]
del params["n_landmark"]
try:
self.graph.set_params(**params)
except ValueError as e:
_logger.debug("Reset graph due to {}".format(str(e)))
self.graph = None
else:
try:
# special way to reset the graph here
self.n_landmark = n_landmark
except NameError:
pass

@abc.abstractmethod
def _reset_graph(self):
"""Trigger a reset of self.graph
Any downstream effects of resetting the graph should override this function
"""
pass
raise NotImplementedError

def _detect_precomputed_matrix_type(self, X):
if isinstance(X, sparse.coo_matrix):
Expand All @@ -259,72 +269,54 @@ def _parse_n_landmark(self, X):
def _parse_input(self, X):
    """Classify the input and work out how a graph should be obtained.

    Three kinds of input are handled:

    * ``base.BaseGraph`` -- the graph is kept as ``self.graph``, its
      ``data`` becomes ``X``, and the graph's ``knn``, ``decay``,
      ``distance`` and ``thresh`` overwrite the estimator's own values.
    * ``pygsp.graphs.Graph`` -- any stored graph is dropped and the
      weight matrix ``X.W`` is used as a precomputed adjacency matrix.
    * anything else -- treated as a data matrix (AnnData objects are
      unwrapped to ``X.X``); ``self.distance`` decides whether it is
      precomputed, and ``self.n_pca`` is ignored when it is ``None`` or
      not smaller than the smallest dimension of ``X``.

    Parameters
    ----------
    X
        Graph object or data matrix, as described above.

    Returns
    -------
    tuple
        ``(X, n_pca, n_landmark, precomputed, update_graph)`` where
        ``n_landmark`` is the result of ``self._parse_n_landmark(X)`` and
        ``update_graph`` is ``True`` only for the data-matrix case.

    Raises
    ------
    ValueError
        If ``self.distance`` starts with ``"precomputed"`` but is not one
        of the recognised spellings.
    """
    # passing graphs as input
    if isinstance(X, base.BaseGraph):
        # we can keep this graph
        self.graph = X
        X = X.data
        # immutable graph properties override operator
        # These must be read from self.graph: X now refers to the graph's
        # data, which does not carry the graph's parameters.
        n_pca = self.graph.n_pca
        self.knn = self.graph.knn
        self.decay = self.graph.decay
        self.distance = self.graph.distance
        self.thresh = self.graph.thresh
        update_graph = False
        if isinstance(self.graph, graphs.TraditionalGraph):
            precomputed = self.graph.precomputed
        else:
            precomputed = None
    elif isinstance(X, pygsp.graphs.Graph):
        # convert pygsp to graphtools
        self.graph = None
        X = X.W
        precomputed = "adjacency"
        update_graph = False
        n_pca = None
    else:
        # data matrix
        update_graph = True
        if utils.is_Anndata(X):
            X = X.X
        if not callable(self.distance) and self.distance.startswith("precomputed"):
            if self.distance == "precomputed":
                # automatic detection
                precomputed = self._detect_precomputed_matrix_type(X)
            elif self.distance in ["precomputed_affinity", "precomputed_distance"]:
                precomputed = self.distance.split("_")[1]
            else:
                raise ValueError(
                    "distance {} not recognized. Did you mean "
                    "'precomputed_distance', "
                    "'precomputed_affinity', or 'precomputed' "
                    "(automatically detects distance or affinity)?".format(
                        self.distance
                    )
                )
            n_pca = None
        else:
            precomputed = None
            if self.n_pca is None or self.n_pca >= np.min(X.shape):
                n_pca = None
            else:
                n_pca = self.n_pca
    return X, n_pca, self._parse_n_landmark(X), precomputed, update_graph

def _update_graph(self, X, precomputed, n_pca, n_landmark, **kwargs):
Expand All @@ -335,26 +327,22 @@ def _update_graph(self, X, precomputed, n_pca, n_landmark, **kwargs):
"""
self.graph = None
else:
try:
self.graph.set_params(
n_pca=n_pca,
precomputed=precomputed,
n_landmark=n_landmark,
random_state=self.random_state,
knn=self.knn,
decay=self.decay,
distance=self.distance,
n_svd=self.n_svd,
n_jobs=self.n_jobs,
thresh=self.thresh,
verbose=self.verbose,
**(self.kwargs)
)
self._set_graph_params(
n_pca=n_pca,
precomputed=precomputed,
n_landmark=n_landmark,
random_state=self.random_state,
knn=self.knn,
decay=self.decay,
distance=self.distance,
n_svd=self.n_svd,
n_jobs=self.n_jobs,
thresh=self.thresh,
verbose=self.verbose,
**(self.kwargs)
)
if self.graph is not None:
_logger.info("Using precomputed graph and diffusion operator...")
except ValueError as e:
# something changed that should have invalidated the graph
_logger.debug("Reset graph due to {}".format(str(e)))
self.graph = None

def fit(self, X, **kwargs):
"""Computes the graph
Expand Down

0 comments on commit f802646

Please sign in to comment.