Skip to content

Commit

Permalink
Merge pull request #18 from MICS-Lab/dev
Browse files Browse the repository at this point in the history
Dev
  • Loading branch information
quentinblampey committed Jun 7, 2023
2 parents 20abec2 + b790f20 commit 2046e50
Show file tree
Hide file tree
Showing 13 changed files with 369 additions and 159 deletions.
2 changes: 1 addition & 1 deletion docs/advanced/parameters.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
We provide some help to choose [scyan.Scyan][] initialization parameters. We listed below the most important ones.

- `prior_std` is probably one of the most important parameters. Its default value should work for most of the usage, but it can be changed if needed. A low `prior_std` (e.g., `0.2`) will help better separate the populations, but it may be too stringent, and some small populations may disappear. In contrast, a high `prior_std` (e.g., `0.35`) increases the chances of having a large diversity of populations, but their separation may be less clear. We recommend to start with a medium value such as `0.25` or `0.3` and reducing it afterwards if it already captures all the populations.
- Reducing the `temperature` can help better capture small populations (e.g., `0.25`). If it's not enough, it's also possible to use `modulo_temp = 3`.
- Reducing the `temperature` can help better capture small populations (e.g., `0.25`).
- `batch_ref` is the reference batch we use to align distributions. By default, we use the batch where we have the most cells, but you can choose your own reference. For that, please choose a batch that is representative of the diversity of populations you want to annotate; it can help the batch effect correction.
- To improve batch effect correction, we recommend to let the model run longer. This can be done by changing the parameters of the `model.fit()` method: for instance, one can increase the `patience` (e.g., 6) and/or decrease the `min_delta` (e.g., 0.5).
- `continuous_covariates` and `categorical_covariates` can be provided to the model if you have some. For instance, if you changed one antibody, you can add a categorical covariate telling which samples have been measured with which antibody. Any covariate may help the model annotations and batch effect correction.
Expand Down
1 change: 1 addition & 0 deletions docs/advice.md
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
### What should I do if Scyan seems wrong?

- First thing to do is to check your table again. You may have made a typo that could confuse the model. Typically, if you have written `Marker+` for a population that is `Marker-` (or the opposite), it can perturb the prediction toward this population **and** toward other populations.
- Try providing `prior_std=0.35` to the `scyan.Scyan` model. Maybe this parameter was too low. Advises to better choose the parameters can be found [here](../advanced/parameters).
- If one population annotation seems not consistent, or if a group of cells has not been predicted, you can try to target it with [`scyan.tools.PolygonGatingUMAP`][scyan.tools.PolygonGatingUMAP]. Then, use [`scyan.plot.pop_expressions(model, True, key="scyan_selected")`][scyan.plot.pop_expressions] to explore the expressions of this population. Another interesting graph is [`scyan.plot.probs_per_marker(model, True, key="scyan_selected")`][scyan.plot.probs_per_marker]: look for markers that show up dark on the heatmap, it may guide you to find some errors in the knowledge table. Combine it with [a UMAP plot][scyan.plot.umap] or with a [scatter plot][scyan.plot.scatter] to make sure it seems correct, and then read some literature again / update your table.
- One reason for not predicting a population may be an unbalanced knowledge quantity between two related populations. For instance, having 10 values inside the table for `CD4 T CM` cells versus 5 values for `CD4 T EM` cells will probably make the model predict very few `CD4 T CM` cells. Indeed, `CD4 T CM` has many constraints compared to `CD4 T EM`, which becomes the "easy prediction" (indeed, very few constraints are applied to this population). In that case, read the advice related to the scatter plot above again.

Expand Down
4 changes: 4 additions & 0 deletions docs/api/plots.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,10 @@
options:
show_root_heading: true

::: scyan.plot.log_prob_threshold
options:
show_root_heading: true

::: scyan.plot.pop_level
options:
show_root_heading: true
Expand Down
233 changes: 150 additions & 83 deletions docs/tutorials/batch_effect_correction.ipynb

Large diffs are not rendered by default.

10 changes: 5 additions & 5 deletions docs/tutorials/preprocessing.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@
"\n",
"<details class=\"tip\">\n",
" <summary>Click to show an example</summary>\n",
" <p>This short script will concatenate all the FCS inside a specific folder, and save each file name into <code>adata.obs[\"file\"]</code>so that we don't loose information. You can add additional information, e.g. in <code>adata.obs[\"batch\"]</code> if you have different batches.</p>\n",
" <p>This short script will concatenate all the FCS inside a specific folder, and save each file name into <code>adata.obs[\"file\"]</code> so that we don't loose information. You can add additional information, e.g. in <code>adata.obs[\"batch\"]</code> if you have different batches.</p>\n",
" <div class=\"highlight\"><pre><span></span><code><a id=\"__codelineno-6-1\" name=\"__codelineno-6-1\" href=\"#__codelineno-6-1\"></a><span class=\"kn\">import</span> <span class=\"nn\">anndata</span>\n",
"<a id=\"__codelineno-6-2\" name=\"__codelineno-6-2\" href=\"#__codelineno-6-2\"></a><span class=\"kn\">from</span> <span class=\"nn\">pathlib</span> <span class=\"kn\">import</span> <span class=\"n\">Path</span>\n",
"<a id=\"__codelineno-6-3\" name=\"__codelineno-6-3\" href=\"#__codelineno-6-3\"></a>\n",
Expand All @@ -107,7 +107,7 @@
"<a id=\"__codelineno-6-10\" name=\"__codelineno-6-10\" href=\"#__codelineno-6-10\"></a> <span class=\"n\">adata</span><span class=\"o\">.</span><span class=\"n\">obs</span><span class=\"p\">[</span><span class=\"s2\">&quot;batch&quot;</span><span class=\"p\">]</span> <span class=\"o\">=</span> <span class=\"s2\">&quot;NA&quot;</span> <span class=\"c1\"># If you have batches, add their names here</span>\n",
"<a id=\"__codelineno-6-11\" name=\"__codelineno-6-11\" href=\"#__codelineno-6-11\"></a> <span class=\"k\">return</span> <span class=\"n\">adata</span>\n",
"<a id=\"__codelineno-6-12\" name=\"__codelineno-6-12\" href=\"#__codelineno-6-12\"></a>\n",
"<a id=\"__codelineno-6-13\" name=\"__codelineno-6-13\" href=\"#__codelineno-6-13\"></a><span class=\"n\">adata</span> <span class=\"o\">=</span> <span class=\"n\">anndata</span><span class=\"o\">.</span><span class=\"n\">concat</span><span class=\"p\">([</span><span class=\"n\">read_one</span><span class=\"p\">(</span><span class=\"n\">p</span><span class=\"p\">)</span> <span class=\"k\">for</span> <span class=\"n\">p</span> <span class=\"ow\">in</span> <span class=\"n\">fcs_paths</span><span class=\"p\">],</span> <span class=\"n\">label</span><span class=\"o\">=</span><span class=\"s2\">&quot;file&quot;</span><span class=\"p\">,</span> <span class=\"n\">index_unique</span><span class=\"o\">=</span><span class=\"s2\">&quot;-&quot;</span><span class=\"p\">)</span>\n",
"<a id=\"__codelineno-6-13\" name=\"__codelineno-6-13\" href=\"#__codelineno-6-13\"></a><span class=\"n\">adata</span> <span class=\"o\">=</span> <span class=\"n\">anndata</span><span class=\"o\">.</span><span class=\"n\">concat</span><span class=\"p\">([</span><span class=\"n\">read_one</span><span class=\"p\">(</span><span class=\"n\">p</span><span class=\"p\">)</span> <span class=\"k\">for</span> <span class=\"n\">p</span> <span class=\"ow\">in</span> <span class=\"n\">fcs_paths</span><span class=\"p\">],</span> <span class=\"n\">index_unique</span><span class=\"o\">=</span><span class=\"s2\">&quot;-&quot;</span><span class=\"p\">)</span>\n",
"</code></pre></div>\n",
"</details>"
]
Expand Down Expand Up @@ -382,11 +382,11 @@
"metadata": {},
"outputs": [],
"source": [
"# Use all markers to compute the UMAP\n",
"# Option 1: Use all markers to compute the UMAP\n",
"scyan.tools.umap(adata)\n",
"\n",
"# Use only the cell-type markers (recommended), or your choose your own list of markers\n",
"scyan.tools.umap(adata, markers=table.columns) # Remove the 'markers' argument to run on all markers"
"# Option 2: Use only the cell-type markers (recommended), or your choose your own list of markers\n",
"scyan.tools.umap(adata, markers=table.columns)"
]
},
{
Expand Down
103 changes: 45 additions & 58 deletions docs/tutorials/usage.ipynb

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "scyan"
version = "1.4.0"
version = "1.4.1"
description = "Single-cell Cytometry Annotation Network"
documentation = "https://mics-lab.github.io/scyan/"
homepage = "https://mics-lab.github.io/scyan/"
Expand Down
125 changes: 125 additions & 0 deletions scyan/baseline.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
import logging
from typing import Optional

import numpy as np
import pandas as pd
import torch
from anndata import AnnData

from . import utils
from .data import _prepare_data
from .module.distribution import PriorDistribution

log = logging.getLogger(__name__)


class Baseline:
"""Baseline of Scyan (i.e., without the normalizing flow)."""

def __init__(
self,
adata: AnnData,
table: pd.DataFrame,
prior_std: float = 0.3,
):
"""
Args:
adata: `AnnData` object containing the FCS data of $N$ cells. **Warning**: it has to be preprocessed (e.g. `asinh` or `logicle`) and scaled (see https://mics-lab.github.io/scyan/tutorials/preprocessing/).
table: Dataframe of shape $(P, M)$ representing the biological knowledge about markers and populations. The columns names corresponds to marker that must be in `adata.var_names`.
prior_std: Standard deviation $\sigma$ of the cell-specific random variable $H$.
"""
super().__init__()
self.adata, self.table, self.continuum_markers = utils._validate_inputs(
adata, table, []
)
self.prior_std = prior_std
self.n_pops, self.n_markers = self.table.shape

self._prepare_data()

self.prior = PriorDistribution(
torch.tensor(table.values, dtype=torch.float32),
torch.full((self.n_markers,), False),
self.prior_std,
self.n_markers,
)

log.info(f"Initialized {self}")

@property
def pop_names(self) -> pd.Index:
"""Name of the populations considered in the knowledge table"""
return self.table.index.get_level_values(0)

@property
def var_names(self) -> pd.Index:
"""Name of the markers considered in the knowledge table"""
return self.table.columns

def __repr__(self) -> str:
return f"Baseline model with N={self.adata.n_obs} cells, P={self.n_pops} populations and M={len(self.var_names)} markers."

def _prepare_data(self) -> None:
"""Initialize the data"""
self.x, _ = _prepare_data(
self.adata,
self.table.columns,
None,
[],
[],
)

def predict(
self,
key_added: Optional[str] = "baseline_pop",
add_levels: bool = True,
log_prob_th: float = -50,
) -> pd.Series:
"""Model population predictions, i.e. one population is assigned for each cell. Predictions are saved in `adata.obs.scyan_pop` by default.
!!! note
Some cells may not be annotated, if their log probability is lower than `log_prob_th` for all populations. Then, the predicted label will be `np.nan`.
Args:
key_added: Column name used to save the predictions in `adata.obs`. If `None`, then the predictions will not be saved.
add_levels: If `True`, and if [hierarchical population names](../../tutorials/usage/#working-with-hierarchical-populations) were provided, then it also saves the prediction for every population level.
log_prob_th: If the log-probability of the most probable population for one cell is below this threshold, this cell will not be annotated (`np.nan`).
Returns:
Population predictions (pandas `Series` of length $N$ cells).
"""
df = self.predict_proba()

populations = df.iloc[:, : self.n_pops].idxmax(axis=1).astype("category")
populations[df["max_log_prob"] < log_prob_th] = np.nan

if key_added is not None:
self.adata.obs[key_added] = pd.Categorical(
populations, categories=self.pop_names
)
if add_levels and isinstance(self.table.index, pd.MultiIndex):
utils._add_level_predictions(self, key_added)

missing_pops = self.n_pops - len(populations.cat.categories)
if missing_pops:
log.warning(
f"{missing_pops} population(s) were not predicted. It may be due to:\n - Errors in the knowledge table (see https://mics-lab.github.io/scyan/advice/#advice-for-the-creation-of-the-table)\n - The model hyperparameters choice (see https://mics-lab.github.io/scyan/advanced/parameters/)\n - Or maybe these populations are really absent from this dataset."
)

return populations

def predict_proba(self) -> pd.DataFrame:
"""Soft predictions (i.e. an array of probability per population) for each cell.
Returns:
Dataframe of shape `(N, P)` with probabilities for each population.
"""
log_probs = self.prior.log_prob(self.x) - torch.log(torch.tensor(self.n_pops))
probs = torch.softmax(log_probs, dim=1)

df = pd.DataFrame(probs.numpy(force=True), columns=self.pop_names)

max_log_probs = log_probs.max(1)
df["max_log_prob"] = max_log_probs.values.numpy(force=True)

return df
1 change: 1 addition & 0 deletions scyan/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -329,6 +329,7 @@ def predict(
Population predictions (pandas `Series` of length $N$ cells).
"""
df = self.predict_proba()
self.adata.obs["scyan_log_probs"] = df["max_log_prob_u"].values

populations = df.iloc[:, : self.n_pops].idxmax(axis=1).astype("category")
populations[df["max_log_prob_u"] < log_prob_th] = np.nan
Expand Down
2 changes: 1 addition & 1 deletion scyan/plot/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from .density import kde
from .density import kde, log_prob_threshold
from .heatmap import probs_per_marker
from .dot import scatter, umap, pop_level
from .graph import pops_hierarchy
Expand Down
26 changes: 26 additions & 0 deletions scyan/plot/density.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from typing import List, Optional, Union

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
Expand Down Expand Up @@ -86,3 +87,28 @@ def kde(
palette=get_palette_others(df, key),
hue_order=sorted(df[key].unique(), key="Others".__eq__),
)


@plot_decorator(adata=True)
def log_prob_threshold(adata: AnnData, show: bool = True):
"""Plot the number of cells annotated depending on the log probability threshold (below which cells are left non-classified). It can be helpful to determine the best threshold value, i.e. before a significative decrease in term of number of cells annotated.
!!! note
To use this function, you first need to fit a `scyan.Scyan` model and use the `model.predict()` method.
Args:
adata: The `anndata` object used during the model training.
show: Whether or not to display the figure.
"""
assert (
"scyan_log_probs" in adata.obs
), f"Cannot find 'scyan_log_probs' in adata.obs. Have you run model.predict()?"

x = np.sort(adata.obs["scyan_log_probs"])
y = 1 - np.arange(len(x)) / float(len(x))

plt.plot(x, y)
plt.xlim(-100, x.max())
sns.despine(offset=10, trim=True)
plt.ylabel("Ratio of predicted cells")
plt.xlabel("Log density threshold")
10 changes: 5 additions & 5 deletions scyan/tools/biomarkers.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,20 +22,20 @@ def cell_type_ratios(
key: str = "scyan_pop",
among: str = None,
) -> pd.DataFrame:
"""Count for each patient (or group) the number of cells for each population.
"""Computes the ratio of cells per population. This ratio can be provided for each patient (or for any kind of 'group').
Args:
adata: An `AnnData` object.
groupby: Key(s) of `adata.obs` used to create groups (e.g. the patient ID).
normalize: If `False`, returns counts instead of percentages.
normalize: If `False`, returns counts instead of ratios.
key: Key of `adata.obs` containing the population names (or the values to count).
among: Key of `adata.obs` containing the parent population name. For example, if 'T CD4 RM' is found in `adata.obs[key]`, then we may find something like 'T cell' in `adata.obs[among]`. Typically, if using hierarchical populations, you can provide `'scyan_pop_level'` with your level name.
among: Key of `adata.obs` containing the parent population name. Typically, if using hierarchical populations, you can provide `'scyan_pop_level'` with your level name. E.g., if the parent of population of "T CD4 RM" is called "T cells" in `adata.obs[among]`, then this function computes the 'T CD4 RM ratio among T cells'.
Returns:
A DataFrame of counts (one row per group, one column per population).
A DataFrame of ratios or counts (one row per group, one column per population). If `normalize=False`, then each row sums to 1 (for `among=None`).
"""
normalize = among is not None or normalize
column_suffix = "percentage" if normalize else "count"
column_suffix = "ratio" if normalize else "count"

counts = _get_counts(adata, groupby, key, normalize)

Expand Down
9 changes: 4 additions & 5 deletions tests/test_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,7 @@ def test_normalize_cell_populations(adata: AnnData):
df = scyan.tools.cell_type_ratios(adata)

assert all(
df[f"{pop} percentage"][0] == count / 8
for pop, count in zip(pop_names, [1, 3, 2, 2])
df[f"{pop} ratio"][0] == count / 8 for pop, count in zip(pop_names, [1, 3, 2, 2])
)


Expand All @@ -49,9 +48,9 @@ def test_group_cell_populations(adata: AnnData):
def test_cell_populations_among(adata: AnnData):
df = scyan.tools.cell_type_ratios(adata, groupby="id", among="scyan_pop_level")

assert df.loc[1, "c percentage among C"] == 1
assert df.loc[1, "a2 percentage among A"] == 0.75
assert np.isnan(df.loc[2, "a2 percentage among A"])
assert df.loc[1, "c ratio among C"] == 1
assert df.loc[1, "a2 ratio among A"] == 0.75
assert np.isnan(df.loc[2, "a2 ratio among A"])


def test_mean_intensities(adata: AnnData):
Expand Down

0 comments on commit 2046e50

Please sign in to comment.