Skip to content

Commit

Permalink
fix solver
Browse files Browse the repository at this point in the history
  • Loading branch information
scottgigante committed Nov 11, 2019
1 parent 7031d36 commit 70a3b44
Show file tree
Hide file tree
Showing 2 changed files with 95 additions and 29 deletions.
54 changes: 39 additions & 15 deletions python/magic/magic.py
Original file line number Diff line number Diff line change
Expand Up @@ -497,7 +497,7 @@ def _parse_genes(self, X, genes):
genes = None
elif isinstance(genes, str) and genes == "pca_only":
if not hasattr(self.graph, "data_pca"):
raise RuntimeError("Cannot return PCA as PCA is not" " performed.")
raise RuntimeError("Cannot return PCA as PCA is not performed.")
elif genes is not None:
genes = np.array([genes]).flatten()
if not issubclass(genes.dtype.type, numbers.Integral):
Expand Down Expand Up @@ -568,8 +568,8 @@ def transform(self, X=None, genes=None, t_max=20, plot_optimal_t=False, ax=None)
)

if X is not None and not utils.matrix_is_equivalent(X, self.graph.data):
extrapolation = True
store_result = False
graph = graphtools.base.Data(X, n_pca=self.n_pca)
warnings.warn(
"Running MAGIC.transform on different "
"data to that which was used for MAGIC.fit may not "
Expand All @@ -578,26 +578,56 @@ def transform(self, X=None, genes=None, t_max=20, plot_optimal_t=False, ax=None)
UserWarning,
)
else:
extrapolation = False
X = self.X
graph = self.graph
data = self.graph
store_result = True

genes = self._parse_genes(X, genes)

if isinstance(genes, str) and genes == "pca_only":
# have to use PCA to return it
solver = "approximate"
else:
if genes is not None and self.X_magic is None:
if len(genes) < self.graph.data_nu.shape[1]:
# faster to skip PCA
solver = "exact"
store_result = False
else:
solver = self.solver

if store_result and self.X_magic is not None:
X_magic = self.X_magic
else:
X_magic = self._impute(graph, t_max=t_max, plot=plot_optimal_t, ax=ax)
if extrapolation:
n_pca = self.n_pca if solver == "approximate" else None
data = graphtools.base.Data(X, n_pca=n_pca)
if solver == "approximate":
# select PCs
X_input = data.data_nu
else:
X_input = scprep.utils.to_array_or_spmatrix(data.data)
if genes is not None and not (
isinstance(genes, str) and genes != "pca_only"
):
X_input = scprep.select.select_cols(X_input, idx=genes)
X_magic = self._impute(X_input, t_max=t_max, plot=plot_optimal_t, ax=ax)
if store_result:
self.X_magic = X_magic

print(X_magic.shape)
# return selected genes
if isinstance(genes, str) and genes == "pca_only":
X_magic = PCA().fit_transform(X_magic)
genes = ["PC{}".format(i + 1) for i in range(X_magic.shape[1])]
else:
X_magic = graph.inverse_transform(X_magic, columns=genes)
# convert back to pandas dataframe, if necessary
elif solver == "approximate":
X_magic = data.inverse_transform(X_magic, columns=genes)
elif genes is not None and len(genes) != X_magic.shape[1]:
# select genes
X_magic = scprep.select.select_cols(X_magic, idx=genes)

# convert back to pandas dataframe, if necessary
X_magic = utils.convert_to_same_format(
X_magic, X, columns=genes, prevent_sparse=True
)
Expand Down Expand Up @@ -696,7 +726,7 @@ def _impute(
Parameters
----------
data : graphtools.Graph, graphtools.Data or array-like
data : array-like
Input data
t_max : int, optional (default: 20)
Maximum value of t to consider for optimal t selection
Expand All @@ -716,13 +746,7 @@ def _impute(
X_magic : array-like, shape=[n_samples, n_pca]
Imputed data
"""

if not isinstance(data, graphtools.base.Data):
if self.solver == "approximate":
data = graphtools.base.Data(data, n_pca=self.n_pca)
elif self.solver == "exact":
data = graphtools.base.Data(data, n_pca=None)
data_imputed = scprep.utils.toarray(data.data_nu)
data_imputed = scprep.utils.toarray(data)

if data_imputed.shape[1] > max_genes_compute_t:
subsample_genes = np.random.choice(
Expand Down
70 changes: 56 additions & 14 deletions python/test/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,44 +15,86 @@
# anndata not installed
pass

import os

def test_scdata():
scdata = scprep.io.load_csv("../data/test_data.csv", cell_names=False)
scdata = scprep.filter.filter_empty_cells(scdata)
scdata = scprep.filter.filter_empty_genes(scdata)
scdata_norm = scprep.normalize.library_size_normalize(scdata)
scdata_norm = scprep.transform.sqrt(scdata_norm)
assert scdata.shape == scdata_norm.shape
data_path = os.path.join("..", "data", "test_data.csv")
if not os.path.isfile(data_path):
data_path = os.path.join("..", data_path)
scdata = scprep.io.load_csv(data_path, cell_names=False)
scdata = scprep.filter.filter_empty_cells(scdata)
scdata = scprep.filter.filter_empty_genes(scdata)
scdata = scprep.filter.filter_duplicates(scdata)
scdata_norm = scprep.normalize.library_size_normalize(scdata)
scdata_norm = scprep.transform.sqrt(scdata_norm)


def test_genes_str_int():
magic_op = magic.MAGIC(t="auto", decay=20, knn=10, verbose=False)
str_gene_magic = magic_op.fit_transform(scdata_norm, genes=["VIM", "ZEB1"])
int_gene_magic = magic_op.fit_transform(
scdata_norm, graph=magic_op.graph, genes=[-2, -1]
)
assert str_gene_magic.shape[0] == scdata_norm.shape[0]
np.testing.assert_array_equal(str_gene_magic, int_gene_magic)


def test_pca_only():
magic_op = magic.MAGIC(t="auto", decay=20, knn=10, verbose=False)
pca_magic = magic_op.fit_transform(scdata_norm, genes="pca_only")
assert pca_magic.shape[0] == scdata_norm.shape[0]
assert pca_magic.shape[1] == magic_op.n_pca

# test DREMI: need numerical precision here
magic_op.set_params(random_state=42)

def test_all_genes():
magic_op = magic.MAGIC(t="auto", decay=20, knn=10, verbose=False)
int_gene_magic = magic_op.fit_transform(scdata_norm, genes=[-2, -1])
magic_all_genes = magic_op.fit_transform(scdata_norm, genes="all_genes")
assert scdata_norm.shape == magic_all_genes.shape
int_gene_magic2 = magic_op.transform(scdata_norm, genes=[-2, -1])
np.testing.assert_allclose(int_gene_magic, int_gene_magic2, rtol=0.007)


def test_all_genes_approx():
magic_op = magic.MAGIC(
t="auto", decay=20, knn=10, verbose=False, solver="approximate"
)
int_gene_magic = magic_op.fit_transform(scdata_norm, genes=[-2, -1])
magic_all_genes = magic_op.fit_transform(scdata_norm, genes="all_genes")
assert scdata_norm.shape == magic_all_genes.shape
int_gene_magic2 = magic_op.transform(scdata_norm, genes=[-2, -1])
np.testing.assert_allclose(int_gene_magic, int_gene_magic2, rtol=0.007)


def test_dremi():
magic_op = magic.MAGIC(t="auto", decay=20, knn=10, verbose=False)
# test DREMI: need numerical precision here
magic_op.set_params(random_state=42)
magic_op.fit(scdata_norm)
dremi = magic_op.knnDREMI("VIM", "ZEB1", plot=True)
np.testing.assert_allclose(dremi, 1.573619, atol=0.0000005)
np.testing.assert_allclose(dremi, 1.591713, atol=0.0000005)


def test_solver():
# Testing exact vs approximate solver
magic_op = magic.MAGIC(t="auto", decay=20, knn=10, solver="exact", verbose=False)
magic_op = magic.MAGIC(
t="auto", decay=20, knn=10, solver="exact", verbose=False, random_state=42
)
data_imputed_exact = magic_op.fit_transform(scdata_norm)
assert np.all(data_imputed_exact >= 0)

magic_op = magic.MAGIC(
t="auto", decay=20, knn=10, solver="approximate", verbose=False
t="auto",
decay=20,
knn=10,
n_pca=150,
solver="approximate",
verbose=False,
random_state=42,
)
# magic_op.set_params(solver='approximate')
data_imputed_apprx = magic_op.fit_transform(scdata_norm)
# make sure they're close-ish
assert np.allclose(data_imputed_apprx, data_imputed_exact, atol=0.05)
np.testing.assert_allclose(data_imputed_apprx, data_imputed_exact, atol=0.15)
# make sure they're not identical
assert np.any(data_imputed_apprx != data_imputed_exact)

Expand All @@ -63,7 +105,7 @@ def test_anndata():
except NameError:
# anndata not installed
return
scdata = anndata.read_csv("../data/test_data.csv")
scdata = anndata.read_csv(data_path)
fast_magic_operator = magic.MAGIC(
t="auto", solver="approximate", decay=None, knn=10, verbose=False
)
Expand Down

0 comments on commit 70a3b44

Please sign in to comment.