Skip to content

Commit

Permalink
add option of returning drevi matrix
Browse files Browse the repository at this point in the history
  • Loading branch information
scottgigante committed Oct 15, 2018
1 parent 004745f commit 455a882
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 13 deletions.
31 changes: 19 additions & 12 deletions python/scprep/stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ def mutual_information(x, y, bins=8):


def knnDREMI(x, y, k=10, n_bins=20, n_mesh=3, n_jobs=1,
plot=False, **kwargs):
plot=False, return_drevi=False, **kwargs):
"""kNN conditional Density Resampled Estimate of Mutual Information
Calculates k-Nearest Neighbor conditional Density Resampled Estimate of
Expand Down Expand Up @@ -114,12 +114,18 @@ def knnDREMI(x, y, k=10, n_bins=20, n_mesh=3, n_jobs=1,
plot : bool, optional (default: False)
If True, DREMI creates plots of the data like those seen in
Fig 5C/D of van Dijk et al. 2018. (doi:10.1016/j.cell.2018.05.061).
return_drevi : bool, optional (default: False)
If True, return the DREVI normalized density matrix in addition
to the DREMI score.
**kwargs : additional arguments for `scprep.stats.plot_knnDREMI`
Returns
-------
dremi : float
kNN conditional Density resampled estimate of mutual information
drevi : np.ndarray
DREVI normalized density matrix. Only returned if `return_drevi`
is True.
Examples
--------
Expand Down Expand Up @@ -193,10 +199,10 @@ def knnDREMI(x, y, k=10, n_bins=20, n_mesh=3, n_jobs=1,

# Calculate conditional entropy
# NB: not using thresholding here; entr(M) calcs -x*log(x) elementwise
bin_density_norm = bin_density_norm = bin_density / \
drevi = bin_density / \
np.sum(bin_density, axis=0) # columns sum to 1
# calc entropy of each column
cond_entropies = stats.entropy(bin_density_norm, base=2)
cond_entropies = stats.entropy(drevi, base=2)

# Mutual information (not normalized)
marginal_entropy = stats.entropy(
Expand All @@ -209,23 +215,26 @@ def knnDREMI(x, y, k=10, n_bins=20, n_mesh=3, n_jobs=1,
mutual_info = marginal_entropy - conditional_entropy

# DREMI
marginal_entropy_norm = stats.entropy(np.sum(bin_density_norm, axis=1),
marginal_entropy_norm = stats.entropy(np.sum(drevi, axis=1),
base=2)
cond_sums_norm = np.mean(bin_density_norm)
cond_sums_norm = np.mean(drevi)
conditional_entropy_norm = np.sum(cond_entropies * cond_sums_norm)

dremi = marginal_entropy_norm - conditional_entropy_norm

if plot:
plot_knnDREMI(dremi, mutual_info,
x, y, n_bins, n_mesh,
density, bin_density, bin_density_norm, **kwargs)
return dremi
density, bin_density, drevi, **kwargs)
if return_drevi:
return dremi, drevi
else:
return dremi


@plot._with_matplotlib
def plot_knnDREMI(dremi, mutual_info, x, y, n_bins, n_mesh,
density, bin_density, bin_density_norm,
density, bin_density, drevi,
figsize=(12, 3.5), filename=None,
xlabel="Feature 1", ylabel="Feature 2",
title_fontsize=18, label_fontsize=16,
Expand Down Expand Up @@ -279,8 +288,7 @@ def plot_knnDREMI(dremi, mutual_info, x, y, n_bins, n_mesh,
axes[1].set_xlabel(xlabel, fontsize=label_fontsize)

# Plot joint probability
raw_density_data = bin_density
axes[2].imshow(raw_density_data,
axes[2].imshow(bin_density,
cmap="inferno", origin="lower", aspect="auto")
axes[2].set_xticks([])
axes[2].set_yticks([])
Expand All @@ -289,8 +297,7 @@ def plot_knnDREMI(dremi, mutual_info, x, y, n_bins, n_mesh,
axes[2].set_xlabel(xlabel, fontsize=label_fontsize)

# Plot conditional probability
raw_density_data = bin_density_norm
axes[3].imshow(raw_density_data,
axes[3].imshow(drevi,
cmap="inferno", origin="lower", aspect="auto")
axes[3].set_xticks([])
axes[3].set_yticks([])
Expand Down
5 changes: 4 additions & 1 deletion python/test/test_stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,10 @@ def test_knnDREMI():
Y = scprep.stats.knnDREMI(X[:, 0], X[:, 1])
assert isinstance(Y, float)
np.testing.assert_allclose(Y, 0.16238906)
scprep.stats.knnDREMI(X[:, 0], X[:, 1], plot=True)
Y2, drevi = scprep.stats.knnDREMI(X[:, 0], X[:, 1],
plot=True, return_drevi=True)
assert Y2 == Y
assert drevi.shape == (20, 20)
matrix.test_all_matrix_types(
X, utils.assert_transform_equals, Y=Y,
transform=partial(_test_fun_2d, fun=scprep.stats.knnDREMI),
Expand Down

0 comments on commit 455a882

Please sign in to comment.