Skip to content

Commit

Permalink
fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
scottgigante-immunai committed May 11, 2022
1 parent 53bd15a commit 0c34556
Show file tree
Hide file tree
Showing 3 changed files with 71 additions and 24 deletions.
32 changes: 17 additions & 15 deletions scprep/stats.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import scipy.sparse

from . import plot
from . import select
from . import utils
Expand Down Expand Up @@ -84,24 +86,24 @@ def pairwise_correlation(X, Y, ignore_nan=False):
Y_colsums = utils.matrix_sum(Y, axis=0, ignore_nan=ignore_nan)
# Basically there are four parts in the formula. We would compute them
# one-by-one
N_times_sum_xy = utils.toarray(N * Y.T.dot(X))
sum_x_times_sum_y = X_colsums * Y_colsums[:, None]
var_x = (
N
* utils.matrix_sum(
utils.matrix_transform(X, np.power, 2), axis=0, ignore_nan=ignore_nan
)
- X_colsums**2
X_sq_colsums = utils.matrix_sum(
utils.matrix_transform(X, np.power, 2), axis=0, ignore_nan=ignore_nan
)
var_y = (
N
* utils.matrix_sum(
utils.matrix_transform(Y, np.power, 2), axis=0, ignore_nan=ignore_nan
)
- Y_colsums**2
Y_sq_colsums = utils.matrix_sum(
utils.matrix_transform(Y, np.power, 2), axis=0, ignore_nan=ignore_nan
)
var_x = N * X_sq_colsums - X_colsums**2
var_y = N * Y_sq_colsums - Y_colsums**2
if ignore_nan:
# now that we have the variance computed we can fill in the NaNs
X = utils.fillna(X, 0)
Y = utils.fillna(Y, 0)
N_times_sum_xy = utils.toarray(N * Y.T.dot(X))
sum_x_times_sum_y = X_colsums * Y_colsums[:, None]
# Finally compute Pearson Correlation Coefficient as 2D array
cor = (N_times_sum_xy - sum_x_times_sum_y) / np.sqrt(var_x * var_y[:, None])
with warnings.catch_warnings():
warnings.filterwarnings("ignore", message="invalid value encountered in true_divide", category=RuntimeWarning)
cor = (N_times_sum_xy - sum_x_times_sum_y) / np.sqrt(var_x * var_y[:, None])
return cor.T


Expand Down
27 changes: 26 additions & 1 deletion scprep/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -378,6 +378,31 @@ def matrix_transform(data, fun, *args, **kwargs):
return data


def fillna(data, fill, copy=True):
return_cls = None
if isinstance(data, (sparse.lil_matrix, sparse.dok_matrix)):
return_cls = type(data)
assert copy, f"Cannot fillna in-place for {return_cls.__name__}"
data = data.tocsr()
elif copy:
data = data.copy()
if sparse.issparse(data):
data.data[np.isnan(data.data)] = fill
if return_cls is not None:
data = return_cls(data)
else:
data[np.isnan(data)] = fill
return data



def _nansum(data, axis=None):
if sparse.issparse(data):
return np.sum(fillna(data, 0), axis=axis)
else:
return np.nansum(data, axis=axis)


def matrix_sum(data, axis=None, ignore_nan=False):
"""Get the column-wise, row-wise, or total sum of values in a matrix.
Expand All @@ -396,7 +421,7 @@ def matrix_sum(data, axis=None, ignore_nan=False):
sums : array-like or float
Sums along desired axis.
"""
sum_fn = np.nansum if ignore_nan else np.sum
sum_fn = _nansum if ignore_nan else np.sum
if axis not in [0, 1, None]:
raise ValueError("Expected axis in [0, 1, None]. Got {}".format(axis))
if isinstance(data, pd.DataFrame):
Expand Down
36 changes: 28 additions & 8 deletions test/test_stats.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
from functools import partial

import scipy.sparse
from parameterized import parameterized
from scipy import stats
from tools import data
Expand Down Expand Up @@ -93,7 +95,7 @@ def test_fun(X, *args, **kwargs):
)

D = data.generate_positive_sparse_matrix(shape=(500, 100), seed=42, poisson_mean=5)
Y = test_fun(D)
Y = np.corrcoef(D.T)[:,:10]
assert Y.shape == (D.shape[1], 10)
assert np.allclose(Y[(np.arange(10), np.arange(10))], 1, atol=0)
matrix.test_all_matrix_types(
Expand Down Expand Up @@ -127,13 +129,31 @@ def test_fun(X, *args, **kwargs):


def test_pairwise_correlation_nan():
D = np.array([np.arange(10), np.arange(0, 20, 2)]).astype(float)
D[:, 3] = np.nan
C = scprep.stats.pairwise_correlation(D, D)
assert np.all(np.isnan(C))
C = scprep.stats.pairwise_correlation(D, D, ignore_nan=True)
assert not np.any(np.isnan(C))
np.testing.assert_equal(C, 1)
D = np.array([np.arange(10), np.arange(0, 20, 2), np.zeros(10)]).astype(float).T
D[3,:] = np.nan

def test_with_nan(D):
C = scprep.stats.pairwise_correlation(D, D)
assert np.all(np.isnan(C))

matrix.test_all_matrix_types(
D,
test_with_nan,
)

def test_with_ignore_nan(D):
C = scprep.stats.pairwise_correlation(D, D, ignore_nan=True)
# should still be NaN on samples that have no variance
assert np.all(np.isnan(C[-1]))
assert np.all(np.isnan(C[:,-1]))
# but shouldn't be NaN on samples that have some NaNs
assert not np.any(np.isnan(C[:2][:,:2]))
np.testing.assert_equal(C[:2][:,:2], 1)

matrix.test_all_matrix_types(
D,
test_with_ignore_nan,
)


def shan_entropy(c):
Expand Down

0 comments on commit 0c34556

Please sign in to comment.