Skip to content

Commit

Permalink
add rolling median filter to variable_genes, switch plot to scatterplot
Browse files Browse the repository at this point in the history
  • Loading branch information
scottgigante committed Oct 4, 2019
1 parent 0861a2f commit d6af617
Show file tree
Hide file tree
Showing 6 changed files with 87 additions and 98 deletions.
6 changes: 4 additions & 2 deletions scprep/filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -375,7 +375,8 @@ def filter_variable_genes(data, *extra_data, span=0.7, interpolate=0.2, kernel_s
cutoff=None, percentile=80):
"""Filter all genes with low variability
Variability is computed as the deviation from a loess fit of the mean-variance curve
Variability is computed as the deviation from a loess fit
to the rolling median of the mean-variance curve
Parameters
----------
Expand Down Expand Up @@ -405,7 +406,8 @@ def filter_variable_genes(data, *extra_data, span=0.7, interpolate=0.2, kernel_s
extra_data : array-like, shape=[any, m_features]
Filtered extra data, if passed.
"""
var_genes = measure.variable_genes(data, span=span, interpolate=interpolate)
var_genes = measure.variable_genes(data, span=span, interpolate=interpolate,
kernel_size=kernel_size)
keep_cells_idx = _get_filter_idx(var_genes,
cutoff, percentile,
keep_cells='above')
Expand Down
13 changes: 9 additions & 4 deletions scprep/measure.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def gene_set_expression(data, genes=None, library_size_normalize=False,


@utils._with_pkg(pkg="statsmodels")
def variable_genes(data, span=0.7, interpolate=0.2, kernel_size=0.05):
def variable_genes(data, span=0.7, interpolate=0.2, kernel_size=0.05, return_means=False):
"""Measure the variability of each gene in a dataset
Variability is computed as the deviation from a loess fit
Expand All @@ -88,6 +88,8 @@ def variable_genes(data, span=0.7, interpolate=0.2, kernel_size=0.05):
kernel_size : float or int, optional (default: 0.05)
Width of rolling median window. If a float, the width is given by
kernel_size * data.shape[1]
return_means : boolean, optional (default: False)
If True, return the gene means
Returns
-------
Expand All @@ -106,10 +108,13 @@ def variable_genes(data, span=0.7, interpolate=0.2, kernel_size=0.05):
lowess = statsmodels.nonparametric.smoothers_lowess.lowess(
data_std_med, data_mean,
delta=delta, frac=span, return_sorted=False)
variability = data_std - lowess
result = data_std - lowess
if columns is not None:
variability = pd.Series(variability, index=columns, name='variability')
return variability
result = pd.Series(result, index=columns, name='variability')
data_mean = pd.Series(data_mean, index=columns, name='mean')
if return_means:
result = result, data_mean
return result


def _get_percentile_cutoff(data, cutoff=None, percentile=None, required=False):
Expand Down
3 changes: 2 additions & 1 deletion scprep/plot/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from .scatter import scatter, scatter2d, scatter3d, rotate_scatter3d
from .histogram import histogram, plot_library_size, plot_gene_set_expression, plot_variable_genes
from .histogram import histogram, plot_library_size, plot_gene_set_expression
from .marker import marker_plot
from .scree import scree_plot
from .jitter import jitter
from .variable_genes import plot_variable_genes
from . import tools, colors
83 changes: 0 additions & 83 deletions scprep/plot/histogram.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,86 +288,3 @@ def plot_gene_set_expression(data, genes=None,
bins=bins, log=log, ax=ax, figsize=figsize,
xlabel=xlabel, title=title, fontsize=fontsize,
filename=filename, dpi=dpi, **kwargs)


@utils._with_pkg(pkg="matplotlib", min_version=3)
def plot_variable_genes(data, span=0.7, interpolate=0.2, kernel_size=0.05,
bins=100, log=False,
cutoff=None, percentile=None,
ax=None, figsize=None,
xlabel='Gene variability',
ylabel='Number of genes',
title=None,
fontsize=None,
filename=None,
dpi=None, **kwargs):
"""Plot the histogram of gene variability
Variability is computed as the deviation from a loess fit of the mean-variance curve
Parameters
----------
data : array-like, shape=[n_samples, n_features]
Input data. Multiple datasets may be given as a list of array-likes.
span : float, optional (default: 0.7)
Fraction of genes to use when computing the loess estimate at each point
interpolate : float, optional (default: 0.2)
Multiple of the standard deviation of variances at which to interpolate
linearly in order to reduce computation time.
kernel_size : float or int, optional (default: 0.05)
Width of rolling median window. If a float, the width is given by
kernel_size * data.shape[1]
bins : int, optional (default: 100)
Number of bins to draw in the histogram
log : bool, or {'x', 'y'}, optional (default: False)
If True, plot both axes on a log scale. If 'x' or 'y',
only plot the given axis on a log scale. If False,
plot both axes on a linear scale.
cutoff : float or `None`, optional (default: `None`)
Absolute cutoff at which to draw a vertical line.
Only one of `cutoff` and `percentile` may be given.
percentile : float or `None`, optional (default: `None`)
Percentile between 0 and 100 at which to draw a vertical line.
Only one of `cutoff` and `percentile` may be given.
library_size_normalize : bool, optional (default: False)
Divide gene set expression by library size
ax : `matplotlib.Axes` or None, optional (default: None)
Axis to plot on. If None, a new axis will be created.
figsize : tuple or None, optional (default: None)
If not None, sets the figure size (width, height)
[x,y]label : str, optional
Labels to display on the x and y axis.
title : str or None, optional (default: None)
Axis title.
fontsize : float or None (default: None)
Base font size.
filename : str or None (default: None)
file to which the output is saved
dpi : int or None, optional (default: None)
The resolution in dots per inch. If None it will default to the value
savefig.dpi in the matplotlibrc file. If 'figure' it will set the dpi
to be the value of the figure. Only used if filename is not None.
**kwargs : additional arguments for `matplotlib.pyplot.hist`
Returns
-------
ax : `matplotlib.Axes`
axis on which plot was drawn
"""
if hasattr(data, 'shape') and len(data.shape) == 2:
var_genes = measure.variable_genes(
data, span=span, interpolate=interpolate)
else:
data_array = utils.to_array_or_spmatrix(data)
if len(data_array.shape) == 2 and data_array.dtype.type is not np.object_:
var_genes = measure.variable_genes(
data_array, span=span, interpolate=interpolate)
else:
var_genes = [measure.variable_genes(
d, span=span, interpolate=interpolate)
for d in data]
return histogram(var_genes,
cutoff=cutoff, percentile=percentile,
bins=bins, log=log, ax=ax, figsize=figsize,
xlabel=xlabel, title=title, fontsize=fontsize,
filename=filename, dpi=dpi, **kwargs)
71 changes: 71 additions & 0 deletions scprep/plot/variable_genes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
from .scatter import scatter
from .. import utils, measure
from ..filter import _get_filter_idx


@utils._with_pkg(pkg="matplotlib", min_version=3)
def plot_variable_genes(data, span=0.7, interpolate=0.2, kernel_size=0.05,
cutoff=None, percentile=90,
ax=None, figsize=None,
xlabel='Gene mean',
ylabel='Standardized variance',
title=None,
fontsize=None,
filename=None,
dpi=None, **kwargs):
"""Plot the histogram of gene variability
Variability is computed as the deviation from a loess fit
to the rolling median of the mean-variance curve
Parameters
----------
data : array-like, shape=[n_samples, n_features]
Input data. Multiple datasets may be given as a list of array-likes.
span : float, optional (default: 0.7)
Fraction of genes to use when computing the loess estimate at each point
interpolate : float, optional (default: 0.2)
Multiple of the standard deviation of variances at which to interpolate
linearly in order to reduce computation time.
kernel_size : float or int, optional (default: 0.05)
Width of rolling median window. If a float, the width is given by
kernel_size * data.shape[1]
cutoff : float or `None`, optional (default: `None`)
Absolute cutoff at which to draw a vertical line.
Only one of `cutoff` and `percentile` may be given.
percentile : float or `None`, optional (default: 90)
Percentile between 0 and 100 at which to draw a vertical line.
Only one of `cutoff` and `percentile` may be given.
ax : `matplotlib.Axes` or None, optional (default: None)
Axis to plot on. If None, a new axis will be created.
figsize : tuple or None, optional (default: None)
If not None, sets the figure size (width, height)
[x,y]label : str, optional
Labels to display on the x and y axis.
title : str or None, optional (default: None)
Axis title.
fontsize : float or None (default: None)
Base font size.
filename : str or None (default: None)
file to which the output is saved
dpi : int or None, optional (default: None)
The resolution in dots per inch. If None it will default to the value
savefig.dpi in the matplotlibrc file. If 'figure' it will set the dpi
to be the value of the figure. Only used if filename is not None.
**kwargs : additional arguments for `matplotlib.pyplot.hist`
Returns
-------
ax : `matplotlib.Axes`
axis on which plot was drawn
"""
variability, means = measure.variable_genes(data, span=span, interpolate=interpolate,
kernel_size=kernel_size, return_means=True)
keep_cells_idx = _get_filter_idx(variability,
cutoff, percentile,
keep_cells='above')
return scatter(means, variability, c=keep_cells_idx,
cmap={True : 'red', False : 'black'},
xlabel=xlabel, ylabel=ylabel, title=title,
fontsize=fontsize, filename=filename, dpi=dpi,
**kwargs)
9 changes: 1 addition & 8 deletions test/test_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -803,14 +803,7 @@ def test_plot_gene_set_expression_single_gene(self):
def test_plot_variable_genes(self):
scprep.plot.plot_variable_genes(
self.X,
color='r')

def test_plot_variable_genes_multiple(self):
scprep.plot.plot_variable_genes([
self.X, scprep.select.select_rows(
self.X, idx=np.arange(self.X.shape[0] // 2))],
filename="test_variable_genes.png",
color=['r', 'b'])
filename="test_variable_genes.png")
assert os.path.exists("test_variable_genes.png")

def test_variable_genes_list_of_lists(self):
Expand Down

0 comments on commit d6af617

Please sign in to comment.