From 455a88230c5c15a904516850f392712110ad76ff Mon Sep 17 00:00:00 2001 From: Scott Gigante Date: Mon, 15 Oct 2018 13:15:26 -0400 Subject: [PATCH] add option of returning drevi matrix --- python/scprep/stats.py | 31 +++++++++++++++++++------------ python/test/test_stats.py | 5 ++++- 2 files changed, 23 insertions(+), 13 deletions(-) diff --git a/python/scprep/stats.py b/python/scprep/stats.py index afdb5a3c..b2e078d4 100644 --- a/python/scprep/stats.py +++ b/python/scprep/stats.py @@ -80,7 +80,7 @@ def mutual_information(x, y, bins=8): def knnDREMI(x, y, k=10, n_bins=20, n_mesh=3, n_jobs=1, - plot=False, **kwargs): + plot=False, return_drevi=False, **kwargs): """kNN conditional Density Resampled Estimate of Mutual Information Calculates k-Nearest Neighbor conditional Density Resampled Estimate of @@ -114,12 +114,18 @@ def knnDREMI(x, y, k=10, n_bins=20, n_mesh=3, n_jobs=1, plot : bool, optional (default: False) If True, DREMI create plots of the data like those seen in Fig 5C/D of van Dijk et al. 2018. (doi:10.1016/j.cell.2018.05.061). + return_drevi : bool, optional (default: False) + If True, return the DREVI normalized density matrix in addition + to the DREMI score. **kwargs : additional arguments for `scprep.stats.plot_knnDREMI` Returns ------- dremi : float kNN condtional Density resampled estimate of mutual information + drevi : np.ndarray + DREVI normalized density matrix. Only returned if `return_drevi` + is True. Examples -------- @@ -193,10 +199,10 @@ def knnDREMI(x, y, k=10, n_bins=20, n_mesh=3, n_jobs=1, # Calculate conditional entropy # NB: not using thresholding here; entr(M) calcs -x*log(x) elementwise - bin_density_norm = bin_density_norm = bin_density / \ + drevi = bin_density / \ np.sum(bin_density, axis=0) # columns sum to 1 # calc entropy of each column - cond_entropies = stats.entropy(bin_density_norm, base=2) + cond_entropies = stats.entropy(drevi, base=2) # Mutual information (not normalized) marginal_entropy = stats.entropy( @@ -209,9 +215,9 @@ def knnDREMI(x, y, k=10, n_bins=20, n_mesh=3, n_jobs=1, mutual_info = marginal_entropy - conditional_entropy # DREMI - marginal_entropy_norm = stats.entropy(np.sum(bin_density_norm, axis=1), + marginal_entropy_norm = stats.entropy(np.sum(drevi, axis=1), base=2) - cond_sums_norm = np.mean(bin_density_norm) + cond_sums_norm = np.mean(drevi) conditional_entropy_norm = np.sum(cond_entropies * cond_sums_norm) dremi = marginal_entropy_norm - conditional_entropy_norm @@ -219,13 +225,16 @@ def knnDREMI(x, y, k=10, n_bins=20, n_mesh=3, n_jobs=1, if plot: plot_knnDREMI(dremi, mutual_info, x, y, n_bins, n_mesh, - density, bin_density, bin_density_norm, **kwargs) - return dremi + density, bin_density, drevi, **kwargs) + if return_drevi: + return dremi, drevi + else: + return dremi @plot._with_matplotlib def plot_knnDREMI(dremi, mutual_info, x, y, n_bins, n_mesh, - density, bin_density, bin_density_norm, + density, bin_density, drevi, figsize=(12, 3.5), filename=None, xlabel="Feature 1", ylabel="Feature 2", title_fontsize=18, label_fontsize=16, @@ -279,8 +288,7 @@ def plot_knnDREMI(dremi, mutual_info, x, y, n_bins, n_mesh, axes[1].set_xlabel(xlabel, fontsize=label_fontsize) # Plot joint probability - raw_density_data = bin_density - axes[2].imshow(raw_density_data, + axes[2].imshow(bin_density, cmap="inferno", origin="lower", aspect="auto") axes[2].set_xticks([]) axes[2].set_yticks([]) @@ -289,8 +297,7 @@ def plot_knnDREMI(dremi, mutual_info, x, y, n_bins, n_mesh, axes[2].set_xlabel(xlabel, fontsize=label_fontsize) # Plot conditional probability - raw_density_data = bin_density_norm - axes[3].imshow(raw_density_data, + axes[3].imshow(drevi, cmap="inferno", origin="lower", aspect="auto") axes[3].set_xticks([]) axes[3].set_yticks([]) diff --git a/python/test/test_stats.py b/python/test/test_stats.py index 00abe160..ba8ad784 100644 --- a/python/test/test_stats.py +++ b/python/test/test_stats.py @@ -57,7 +57,10 @@ def test_knnDREMI(): Y = scprep.stats.knnDREMI(X[:, 0], X[:, 1]) assert isinstance(Y, float) np.testing.assert_allclose(Y, 0.16238906) - scprep.stats.knnDREMI(X[:, 0], X[:, 1], plot=True) + Y2, drevi = scprep.stats.knnDREMI(X[:, 0], X[:, 1], + plot=True, return_drevi=True) + assert Y2 == Y + assert drevi.shape == (20, 20) matrix.test_all_matrix_types( X, utils.assert_transform_equals, Y=Y, transform=partial(_test_fun_2d, fun=scprep.stats.knnDREMI),