### Notebook for the integration of all cell compartments with `scANVI`

- **Developed by:** Carlos Talavera-López Ph.D
- **Würzburg Institute for Systems Immunology - Faculty of Medicine - Julius-Maximilians-Universität Würzburg**
- v231212

### Import required modules

In [1]:
import scvi
import torch
import anndata
import warnings
import numpy as np
import scanpy as sc
import pandas as pd
import plotnine as p
from pywaffle import Waffle
import matplotlib.pyplot as plt
from scib_metrics.benchmark import Benchmarker

  self.seed = seed
  self.dl_pin_memory_gpu_training = (


### Set up working environment

In [2]:
# Session-wide scanpy configuration: verbosity level, a printed record of the
# installed package versions, and the figure defaults used by the plots below.
sc.settings.verbosity = 3
sc.logging.print_versions()
sc.settings.set_figure_params(
    dpi = 180,
    dpi_save = 300,
    color_map = 'magma_r',
    format = 'svg',
    vector_friendly = True,
)

-----
anndata     0.9.2
scanpy      1.9.4
-----
PIL                 10.0.0
absl                NA
aiohttp             3.8.5
aiosignal           1.3.1
annotated_types     0.5.0
anyio               NA
asttokens           NA
async_timeout       4.0.3
attr                23.1.0
backcall            0.2.0
backoff             2.2.1
bs4                 4.12.2
certifi             2023.07.22
charset_normalizer  3.2.0
chex                0.1.7
click               8.1.7
comm                0.1.4
contextlib2         NA
croniter            NA
cycler              0.10.0
cython_runtime      NA
dateutil            2.8.2
debugpy             1.6.7.post1
decorator           5.1.1
deepdiff            6.3.1
docrep              0.3.2
etils               1.4.1
executing           1.2.0
fastapi             0.103.0
flax                0.7.2
frozenlist          1.4.0
fsspec              2023.6.0
h5py                3.9.0
idna                3.4
igraph              0.10.8
importlib_resources NA
ipykernel         

In [3]:
# Silence every warning category for the rest of the session.
warnings.simplefilter(action = 'ignore')
# Fix the scvi-tools global seed (the cell output reports "Global seed set to 1712").
scvi.settings.seed = 1712
# Inline figure rendering: white figure background and retina-format images.
%config InlineBackend.print_figure_kwargs = {'facecolor' : "w"}
%config InlineBackend.figure_format = 'retina'
# NOTE(review): 'medium' lowers float32 matmul precision, presumably to speed up
# training — confirm the reduced precision is acceptable for these models.
torch.set_float32_matmul_precision('medium')

Global seed set to 1712


In [4]:
# Architecture settings collected in one place.
# NOTE(review): nothing in the visible cells passes this dict to a model
# constructor — confirm whether it is meant to be used.
arches_params = {
    "use_layer_norm": "both",
    "use_batch_norm": "none",
    "encode_covariates": True,
    "dropout_rate": 0.2,
    "n_layers": 3,
}

### Read in and format individual cell compartments

- Read in Cardiomyocytes

In [5]:
# Load the cardiomyocyte compartment.
CMC_scANVI = sc.read_h5ad('../data/heart_mm_nuclei-23-0092_scANVI-CMC_ctl231128.raw.h5ad')
# Keep only the part of each observation name that precedes its last two
# '-'-separated fields.
CMC_scANVI.obs_names = list(CMC_scANVI.obs_names.str.rsplit('-', n = 2).str[0])
CMC_scANVI

AnnData object with n_obs × n_vars = 10414 × 32285
    obs: 'cell_source', 'cell_type', 'donor', 'n_counts', 'n_genes', 'percent_mito', 'percent_ribo', 'region', 'sample', 'scrublet_score', 'cell_states', 'seed_labels', 'genotype', 'batch', 'doublet_scores', 'n_genes_by_counts', 'total_counts', 'total_counts_mt', 'pct_counts_mt', 'total_counts_ribo', 'pct_counts_ribo', 'percent_mt2', 'percent_chrY', 'XIST-counts', 'S_score', 'G2M_score', '_scvi_batch', '_scvi_labels', 'C_scANVI', 'leiden'
    var: 'gene_ids', 'feature_types', 'genome', 'mt', 'ribo', 'n_cells_by_counts', 'mean_counts', 'pct_dropout_by_counts', 'total_counts'

- Read in Fibroblasts

In [6]:
# Load the fibroblast compartment.
FB_scANVI = sc.read_h5ad('../data/heart_mm_nuclei-23-0092_FB_ctl231128.raw.h5ad')
# Keep only the part of each observation name that precedes its last two
# '-'-separated fields.
FB_scANVI.obs_names = list(FB_scANVI.obs_names.str.rsplit('-', n = 2).str[0])
FB_scANVI

AnnData object with n_obs × n_vars = 9026 × 32285
    obs: 'cell_source', 'cell_type', 'donor', 'n_counts', 'n_genes', 'percent_mito', 'percent_ribo', 'region', 'sample', 'scrublet_score', 'cell_states', 'seed_labels', 'genotype', 'batch', 'doublet_scores', 'n_genes_by_counts', 'total_counts', 'total_counts_mt', 'pct_counts_mt', 'total_counts_ribo', 'pct_counts_ribo', 'percent_mt2', 'percent_chrY', 'XIST-counts', 'S_score', 'G2M_score', '_scvi_batch', '_scvi_labels', 'C_scANVI', 'leiden'
    var: 'gene_ids', 'feature_types', 'genome', 'mt', 'ribo', 'n_cells_by_counts', 'mean_counts', 'pct_dropout_by_counts', 'total_counts'

- Read in Vascular cells

In [7]:
# Load the vascular compartment.
Vascular_scANVI = sc.read_h5ad('../data/heart_mm_nuclei-23-0092_scANVI-Vascular_ctl231128.raw.h5ad')
# Keep only the part of each observation name that precedes its last two
# '-'-separated fields.
Vascular_scANVI.obs_names = list(Vascular_scANVI.obs_names.str.rsplit('-', n = 2).str[0])
Vascular_scANVI

AnnData object with n_obs × n_vars = 18472 × 32285
    obs: 'cell_source', 'cell_type', 'donor', 'n_counts', 'n_genes', 'percent_mito', 'percent_ribo', 'region', 'sample', 'scrublet_score', 'cell_states', 'seed_labels', 'genotype', 'batch', 'doublet_scores', 'n_genes_by_counts', 'total_counts', 'total_counts_mt', 'pct_counts_mt', 'total_counts_ribo', 'pct_counts_ribo', 'percent_mt2', 'percent_chrY', 'XIST-counts', 'S_score', 'G2M_score', '_scvi_batch', '_scvi_labels', 'C_scANVI', 'leiden'
    var: 'gene_ids', 'feature_types', 'genome', 'mt', 'ribo', 'n_cells_by_counts', 'mean_counts', 'pct_dropout_by_counts', 'total_counts'

- Read in Immune cells

In [10]:
# Load the immune compartment.
Immune_scANVI = sc.read_h5ad('../data/heart_mm_nuclei-23-0092_scANVI-Immune_ctl231212.raw.h5ad')
# Keep only the part of each observation name that precedes its last two
# '-'-separated fields.
Immune_scANVI.obs_names = list(Immune_scANVI.obs_names.str.rsplit('-', n = 2).str[0])
# Expose the labels stored in `C_scANVI_new` under `C_scANVI`, the column name
# the other compartments carry.
Immune_scANVI.obs['C_scANVI'] = Immune_scANVI.obs['C_scANVI_new'].copy()
Immune_scANVI

AnnData object with n_obs × n_vars = 24006 × 27478
    obs: 'nCount_HTO', 'nFeature_HTO', 'HTO_classification', 'Library', 'CD45_Annotation', 'genotype', 'sample', 'donor', 'seed_labels', 'cell_source', 'cell_type', 'n_counts', 'n_genes', 'percent_mito', 'percent_ribo', 'region', 'scrublet_score', 'cell_states', 'batch', 'doublet_scores', 'n_genes_by_counts', 'total_counts', 'total_counts_mt', 'pct_counts_mt', 'total_counts_ribo', 'pct_counts_ribo', 'percent_mt2', 'percent_chrY', 'XIST-counts', 'S_score', 'G2M_score', '_scvi_batch', '_scvi_labels', 'leiden', 'MPC_Annotation', 'compartment', 'C_scANVI_new', 'C_scANVI'
    var: 'gene_ids-Lymphoid_scANVI', 'feature_types-Lymphoid_scANVI', 'genome-Lymphoid_scANVI', 'mt-Lymphoid_scANVI', 'ribo-Lymphoid_scANVI', 'n_cells_by_counts-Lymphoid_scANVI', 'mean_counts-Lymphoid_scANVI', 'pct_dropout_by_counts-Lymphoid_scANVI', 'total_counts-Lymphoid_scANVI', 'vst.mean-MDX_MPC', 'vst.variance-MDX_MPC', 'vst.variance.expected-MDX_MPC', 'vst.variance.s

### Merge cell compartments and compare with full object

In [11]:
# Stack the four compartments into one object. `join = 'inner'` keeps only the
# variables shared by all of them, and the originating compartment is recorded
# in `obs['compartment']`.
compartments = CMC_scANVI.concatenate(
    FB_scANVI,
    Vascular_scANVI,
    Immune_scANVI,
    join = 'inner',
    batch_key = 'compartment',
    batch_categories = ['CMC', 'FB', 'Vascular', 'Immune'],
)
# Drop the last '-'-separated field of every observation name (presumably the
# batch suffix appended by `concatenate` — confirm).
compartments.obs_names = list(compartments.obs_names.str.rsplit('-', n = 1).str[0])
compartments

AnnData object with n_obs × n_vars = 61918 × 27478
    obs: 'cell_source', 'cell_type', 'donor', 'n_counts', 'n_genes', 'percent_mito', 'percent_ribo', 'region', 'sample', 'scrublet_score', 'cell_states', 'seed_labels', 'genotype', 'batch', 'doublet_scores', 'n_genes_by_counts', 'total_counts', 'total_counts_mt', 'pct_counts_mt', 'total_counts_ribo', 'pct_counts_ribo', 'percent_mt2', 'percent_chrY', 'XIST-counts', 'S_score', 'G2M_score', '_scvi_batch', '_scvi_labels', 'C_scANVI', 'leiden', 'nCount_HTO', 'nFeature_HTO', 'HTO_classification', 'Library', 'CD45_Annotation', 'MPC_Annotation', 'compartment', 'C_scANVI_new'
    var: 'gene_ids-CMC', 'feature_types-CMC', 'genome-CMC', 'mt-CMC', 'ribo-CMC', 'n_cells_by_counts-CMC', 'mean_counts-CMC', 'pct_dropout_by_counts-CMC', 'total_counts-CMC', 'gene_ids-FB', 'feature_types-FB', 'genome-FB', 'mt-FB', 'ribo-FB', 'n_cells_by_counts-FB', 'mean_counts-FB', 'pct_dropout_by_counts-FB', 'total_counts-FB', 'gene_ids-Lymphoid_scANVI-Immune', 'feature

### Group fine-grained annotations into coarse groups

In [12]:
# List the fine-grained labels present in the merged object.
compartments.obs['C_scANVI'].cat.categories

Index(['B_cells', 'CD4+T', 'DC', 'DOCK4+MØ', 'EC1_cap', 'EC2_cap', 'EC3_cap',
       'EC5_art', 'EC6_ven', 'EC7_atria', 'FB1', 'FB2', 'FB3', 'FB4', 'FB5',
       'FB6', 'LYVE1+MØ', 'M2MØ', 'Mast', 'MoMø', 'Monocytes', 'NK',
       'Neutrophils', 'PC1_vent', 'PC2_atria', 'PC3_str', 'SMC1_basic',
       'SMC2_art', 'aCM1', 'proIMØ', 'vCM1', 'vCM2', 'vCM3', 'vCM4'],
      dtype='object')

In [None]:
# Coarse cell type -> fine-grained `C_scANVI` labels that collapse into it.
# Besides the original entries, the table lists the spellings that occur in the
# merged object's category index ('DOCK4+MØ', 'LYVE1+MØ', 'B_cells'), which
# previously had no entry and so kept their fine-grained names.
# Labels with no entry here (e.g. 'DC', 'Neutrophils') keep their original name.
coarse_to_fine = {
    'vCM': ['vCM1', 'vCM2', 'vCM3', 'vCM4'],
    'aCM': ['aCM1'],
    'FB': ['FB1', 'FB2', 'FB3', 'FB4', 'FB5', 'FB6'],
    'MØ': ['DOCK4+aMØ', 'DOCK4+vMØ', 'DOCK4+MØ', 'LYVE1+MØ1', 'LYVE1+MØ2', 'LYVE1+MØ', 'M2MØ', 'MoMø', 'proIMØ'],
    'Monocytes': ['CD14+Mo', 'CD69+Mo', 'nøMo'],
    'CD8+T': ['CD8+T_cytox', 'CD8+T_em', 'CD8+T_te', 'CD8+T_trans'],
    'CD4+T': ['CD4+T_act', 'CD4+T_naive'],
    'NK': ['NK_CD16hi', 'NK_CD56hi'],
    'Mast': ['Mast'],
    'B': ['B', 'B_cells'],
    'B_plasma': ['B_plasma'],
    'EC': ['EC7_atria'],
    'EC_ven': ['EC6_ven'],
    'EC_art': ['EC5_art'],
    'EC_cap': ['EC1_cap', 'EC2_cap', 'EC3_cap'],
    'SMC': ['SMC1_basic', 'SMC2_art'],
    'PC': ['PC1_vent', 'PC2_atria', 'PC3_str'],
}

# Kept under their original names for any later cell that still refers to them.
trans_to = list(coarse_to_fine)
trans_from = list(coarse_to_fine.values())

# Flat lookup: fine-grained label -> coarse label.
fine_to_coarse = {
    fine_label: coarse_label
    for coarse_label, fine_labels in coarse_to_fine.items()
    for fine_label in fine_labels
}

# Build the coarse column in one pass and assign it as a whole. The previous
# chained assignment (`obs['cell_type'][mask] = ...`) writes into a temporary
# under pandas Copy-on-Write and then leaves `obs` unchanged.
fine_annotation = compartments.obs['C_scANVI'].astype(str)
compartments.obs['cell_type'] = fine_annotation.map(fine_to_coarse).fillna(fine_annotation)

In [None]:
# Number of cells per `cell_type` label after the regrouping.
compartments.obs['cell_type'].value_counts()

In [None]:
# Use the coarse cell types as `seed_labels`, the column later given to
# `setup_anndata` as `labels_key`.
compartments.obs['seed_labels'] = compartments.obs['cell_type'].copy()
#del(compartments.obs['C_scANVI'])
compartments

### Visualise cell type distribution per condition

In [None]:
# Number of cells per seed label.
compartments.obs['seed_labels'].value_counts()

In [None]:
# Contingency table of cell counts: seed label (rows) by genotype (columns).
pd.crosstab(compartments.obs['seed_labels'], compartments.obs['genotype'])

### Split the merged object into an annotated reference and an unlabelled query by genotype

In [None]:
# Number of cells per genotype; the genotype decides the reference/query split below.
compartments.obs['genotype'].value_counts()

In [None]:
# Reference set: cells whose genotype is 'WT' or 'Mdx'; they keep their seed labels.
reference = compartments[compartments.obs['genotype'].isin(['WT', 'Mdx'])]
reference

In [None]:
# Query set: every cell whose genotype is neither 'WT' nor 'Mdx'.
# `.copy()` materialises the subset so that the label overwrite below lands on
# an independent AnnData rather than on a view of `compartments`.
query = compartments[~compartments.obs['genotype'].isin(['WT', 'Mdx'])].copy()
# Mark all query cells as unlabelled; 'Unknown' is the category later passed to
# `SCANVI.from_scvi_model`.
query.obs['seed_labels'] = 'Unknown'
query

In [None]:
# Recombine reference and query into the working object. `obs['batch']` is
# written by this call and holds 'reference' or 'query'.
adata = reference.concatenate(
    query,
    join = 'inner',
    batch_key = 'batch',
    batch_categories = ['reference', 'query'],
)
adata

In [None]:
# QC scatter of total counts against mitochondrial percentage, coloured by reference/query.
sc.pl.scatter(adata, x = 'total_counts', y = 'pct_counts_mt', color = "batch", frameon = False)

In [None]:
# Number of cells per genotype in the recombined object.
adata.obs['genotype'].value_counts()

In [None]:
# Number of cells per sample; `sample` is the batch key used for HVG selection below.
adata.obs['sample'].value_counts()

### Select HVGs

In [None]:
# Snapshot the object with its full gene set before the in-place HVG subsetting.
adata_raw = adata.copy()
# Keep a copy of the current matrix in a layer for HVG selection and model setup.
adata.layers['counts'] = adata.X.copy()

# Select the 7000 most variable genes per `sample` batch from the 'counts'
# layer and subset `adata` to them in place.
sc.pp.highly_variable_genes(
    adata,
    n_top_genes = 7000,
    flavor = "seurat_v3",
    batch_key = "sample",
    layer = "counts",
    subset = True,
)
adata

### Transfer of annotation with scANVI

In [None]:
# Register the fields scVI/scANVI read from `adata`.
# NOTE(review): 'donor' is supplied both as the batch key and as a categorical
# covariate — confirm the duplication is intended.
scvi.model.SCVI.setup_anndata(
    adata,
    layer = 'counts',
    batch_key = "donor",
    labels_key = "seed_labels",
    categorical_covariate_keys = ["donor", "cell_source"],
)

In [None]:
# scVI model: 50 latent dimensions, 3 layers, negative-binomial likelihood with
# dispersion set to 'gene-batch'.
# NOTE(review): the `arches_params` dict defined earlier is not passed here.
scvi_model = scvi.model.SCVI(
    adata,
    n_layers = 3,
    n_latent = 50,
    gene_likelihood = 'nb',
    dispersion = 'gene-batch',
)

In [None]:
# Train scVI for 50 epochs, validating after every epoch so the ELBO curves
# plotted below have one point per epoch.
scvi_model.train(
    max_epochs = 50,
    use_gpu = 1,
    enable_progress_bar = True,
    check_val_every_n_epoch = 1,
)

In [None]:
# Store the scVI latent representation of every cell under `obsm['X_scVI']`.
adata.obsm["X_scVI"] = scvi_model.get_latent_representation(adata)

### Evaluate model performance using the [_Svensson_](https://www.nxn.se/valent/2023/8/10/training-scvi-posterior-predictive-distributions-over-epochs) method

In [None]:
# Long-format table of the training and validation ELBO per epoch.
history_df = scvi_model.history['elbo_train'].astype(float)
history_df = history_df.join(scvi_model.history['elbo_validation'].astype(float))
history_df = history_df.reset_index().melt(id_vars = ['epoch'])

p.options.figure_size = 12, 6

# Line + point plot of both ELBO curves, skipping epoch 0.
p_ = p.ggplot(
    data = history_df.query('epoch > 0'),
    mapping = p.aes(x = 'epoch', y = 'value', color = 'variable'),
)
p_ = p_ + p.geom_line()
p_ = p_ + p.geom_point()
p_ = p_ + p.scale_color_manual({'elbo_train': 'black', 'elbo_validation': 'red'})
p_ = p_ + p.theme_minimal()

p_.save('fig1.png', dpi = 300)

print(p_)

### Label transfer with `scANVI` 

In [None]:
# Initialise scANVI from the trained scVI model; cells labelled 'Unknown'
# (the query set) are treated as unlabelled.
scanvi_model = scvi.model.SCANVI.from_scvi_model(
    scvi_model,
    unlabeled_category = 'Unknown',
)

In [None]:
# Train scANVI for 10 epochs, validating after every epoch.
scanvi_model.train(
    max_epochs = 10,
    use_gpu = 1,
    enable_progress_bar = True,
    check_val_every_n_epoch = 1,
)

In [None]:
# Store the scANVI label prediction for every cell under `obs['C_scANVI_new']`.
adata.obs["C_scANVI_new"] = scanvi_model.predict(adata)

- Extract latent representation

In [None]:
# Store the scANVI latent representation under `obsm['X_scANVI']`; it is the
# representation used for the neighbour graph below.
adata.obsm["X_scANVI"] = scanvi_model.get_latent_representation(adata)

### Explore model performance using the [_Svensson_](https://www.nxn.se/valent/2023/8/10/training-scvi-posterior-predictive-distributions-over-epochs) method

In [None]:
# Long-format table of the scANVI training and validation ELBO per epoch.
history_df = (
    scanvi_model.history['elbo_train'].astype(float)
    .join(scanvi_model.history['elbo_validation'].astype(float))
    .reset_index()
    .melt(id_vars = ['epoch'])
)

p.options.figure_size = 12, 6

# Line + point plot of both ELBO curves, skipping epoch 0.
p_ = (
    p.ggplot(p.aes(x = 'epoch', y = 'value', color = 'variable'), history_df.query('epoch > 0'))
    + p.geom_line()
    + p.geom_point()
    + p.scale_color_manual({'elbo_train': 'black', 'elbo_validation': 'red'})
    + p.theme_minimal()
)

# Saved under its own file name: the scVI training-curve cell writes
# 'fig1.png', and reusing that name here overwrote the scVI figure.
p_.save('fig2.png', dpi = 300)

print(p_)

- Visualise corrected dataset

In [None]:
# Neighbour graph and UMAP computed on the scANVI latent space, then an
# overview coloured by sample, genotype and the label columns.
sc.pp.neighbors(
    adata,
    n_neighbors = 50,
    use_rep = "X_scANVI",
    metric = 'minkowski',
)
sc.tl.umap(adata, min_dist = 0.3, spread = 5, random_state = 1712)
sc.pl.umap(
    adata,
    color = ['sample', 'genotype', 'C_scANVI_new', 'seed_labels', 'cell_type', 'C_scANVI'],
    frameon = False,
    size = 0.8,
    legend_fontsize = 5,
    ncols = 4,
)

In [None]:
# UMAP coloured by the QC metrics and the reference/query `batch` column.
sc.pl.umap(
    adata,
    color = [
        'n_genes',
        'doublet_scores',
        'batch',
        'n_genes_by_counts',
        'total_counts',
        'total_counts_mt',
        'pct_counts_mt',
        'total_counts_ribo',
        'pct_counts_ribo',
        'n_counts',
    ],
    frameon = False,
    size = 0.6,
    legend_fontsize = 5,
    ncols = 4,
    cmap = 'magma',
)

### Modify object to plot canonical marker genes

In [None]:
# Plotting object over the full gene set: square root of the total-count
# normalised `adata_raw` matrix, combined with the annotations and embeddings
# of `adata`.
adata_toplot = anndata.AnnData(
    X = np.sqrt(sc.pp.normalize_total(adata_raw, inplace = False)["X"]),
    obs = adata.obs,
    var = adata_raw.var,
    obsm = adata.obsm,
)
adata_toplot

In [None]:
# UMAP coloured by the fine-grained labels and a panel of marker genes.
sc.pl.umap(
    adata_toplot,
    color = [
        'C_scANVI',
        'Ttn',
        'Nppa',
        'Dcn',
        'Vwf',
        'Myh11',
        'Rgs4',
        'Kcnj8',
        'C1qa',
        'Cd3e',
        'Trem2',
        'Adipoq',
        'Nrxn1',
        'Msln',
    ],
    frameon = False,
    size = 0.6,
    legend_fontsize = 5,
    ncols = 4,
    cmap = 'RdPu',
)

### Visualise proportions

In [None]:
# UMAP coloured by sample, genotype and fine-grained label.
sc.pl.umap(
    adata,
    color = ['sample', 'genotype', 'C_scANVI'],
    frameon = False,
    size = 0.6,
    legend_fontsize = 5,
    ncols = 4,
)

In [None]:
# Six-colour palette for the categorical UMAPs.
bauhaus_colors = [
    '#FF0000',
    '#FFFF00',
    '#000000',
    '#4D5D53',
    '#0000FF',
    '#808080',
]
# UMAP coloured by sample with the palette above.
sc.pl.umap(
    adata,
    color = ['sample'],
    palette = bauhaus_colors,
    frameon = False,
    size = 0.6,
    legend_fontsize = 5,
    ncols = 4,
)

In [None]:
# UMAP coloured by fine-grained label using `bauhaus_colors`.
# NOTE(review): the palette holds 6 colours while `C_scANVI` has many more
# categories — confirm how scanpy assigns colours in that case.
sc.pl.umap(adata, frameon = False, color = ['C_scANVI'], size = 0.6, legend_fontsize = 5, ncols = 4, palette = bauhaus_colors)

In [None]:
# Cell counts for every (genotype, cell_type) pair.
df = adata_toplot.obs.groupby(['genotype', 'cell_type']).size().reset_index(name = 'counts')

# Percentage of each cell type within its genotype. `transform('sum')` returns
# the per-genotype total on the same index as `df`, so the division is aligned
# row by row. The earlier `apply` + `reset_index` result only matched `df` by
# position.
genotype_totals = df.groupby('genotype')['counts'].transform('sum')
df['proportions'] = df['counts'] / genotype_totals * 100
# Ten waffle squares per percentage point, truncated to whole squares.
df['waffle_counts'] = (df['proportions'] * 10).astype(int)

In [None]:
def generate_modified_bauhaus_palette(n_colors, seed = None):
    """Build a shuffled palette of brightness variations of eight base colours.

    Each base colour is scaled by 0.8, 0.9, 1.0, ... (as many steps as needed
    to reach ``n_colors`` candidates), every channel is clipped to [0, 1],
    duplicates are dropped, and the first ``n_colors`` distinct colours are
    shuffled.

    Parameters
    ----------
    n_colors
        Number of colours requested. Values <= 0 give an empty list.
    seed
        Optional seed for the shuffle. With the default ``None`` the global
        NumPy random state is used, exactly as before; with an integer the
        order is reproducible and the global state is left untouched.

    Returns
    -------
    list of (r, g, b) tuples with channels in [0, 1]. The list can be shorter
    than ``n_colors`` when clipping collapses variations into duplicates
    (black, for instance, has a single variation).
    """
    # Nothing requested: previously 0 fell through the `== n_colors` check and
    # returned every candidate colour.
    if n_colors <= 0:
        return []

    # Define specific shades
    base_colors = [
        (1, 0, 0),      # Red
        (0.07, 0.04, 0.56),  # Ultramarine blue
        (0, 0.28, 0.67),  # Cobalt blue
        (1, 0.9, 0),  # Bauhaus yellow
        (0, 0, 0),  # Black
        (0.5, 0, 0.5),  # Purple
        (1, 0.55, 0),  # Orange
        (0.54, 0.17, 0.89),  # Violet
    ]

    # Enough brightness steps per base colour to cover the request
    variation_steps = n_colors // len(base_colors) + 1

    # Walk the candidates base colour by base colour, keeping the first
    # `n_colors` distinct ones; the set makes the duplicate check O(1).
    seen = set()
    unique_colors = []
    for color in base_colors:
        for i in range(variation_steps):
            # Scale brightness, clipping every channel to [0, 1]
            variation = tuple(min(max(c * (0.8 + i * 0.1), 0), 1) for c in color)
            if variation in seen:
                continue
            seen.add(variation)
            unique_colors.append(variation)
            if len(unique_colors) == n_colors:
                break
        if len(unique_colors) == n_colors:
            break

    # Shuffle the unique colors
    if seed is None:
        np.random.shuffle(unique_colors)
    else:
        np.random.default_rng(seed).shuffle(unique_colors)

    return unique_colors

# Generate the modified palette
bauhaus_palette = generate_modified_bauhaus_palette(18)

In [None]:
# One colour per cell type, fixed before the loop so that a cell type keeps the
# same colour in every genotype's figure. Assigning colours by row position
# inside each genotype shifted them whenever a genotype's rows differed.
cell_type_colors = {
    cell_type: bauhaus_palette[i % len(bauhaus_palette)]
    for i, cell_type in enumerate(df['cell_type'].unique())
}

# One waffle chart per genotype, one square group per cell type.
for group in df['genotype'].unique():
    temp_df = df[df['genotype'] == group]
    data = dict(zip(temp_df['cell_type'], temp_df['waffle_counts']))

    # Colours follow the row order of this genotype's cell types
    colors = [cell_type_colors[cell_type] for cell_type in temp_df['cell_type']]

    fig = plt.figure(
        FigureClass=Waffle, 
        rows=7, 
        values=data, 
        title={'label': f'Genotype {group}', 'loc': 'left', 'fontsize': 14},
        labels=[f"{k} ({v}%)" for k, v in zip(temp_df['cell_type'], temp_df['proportions'].round(2))],
        legend={'loc': 'lower left', 'bbox_to_anchor': (0, -0.4), 'ncol': len(data), 'framealpha': 0, 'fontsize': 14},
        figsize=(40, 4),
        colors=colors
    )
    plt.show()

### Export annotated sample object 

In [None]:
# Reduce every observation name to its first two '-'-separated fields.
adata.obs.index = pd.Index(
    ['-'.join(barcode.split('-', 2)[:2]) for barcode in adata.obs.index]
)
adata.obs.index

In [None]:
# Apply the same trimming to `adata_raw`: keep the first two '-'-separated
# fields of every observation name.
adata_raw.obs.index = pd.Index(
    ['-'.join(barcode.split('-', 2)[:2]) for barcode in adata_raw.obs.index]
)
adata_raw.obs.index

In [None]:
# Inspect the observation names after trimming.
adata.obs_names

In [None]:
# Fine-grained label categories carried by the object.
adata.obs['C_scANVI'].cat.categories

In [None]:
# Number of cells per fine-grained label.
adata.obs['C_scANVI'].value_counts()

### Export annotated object with raw counts

In [None]:
# Summary of the HVG-subset object that holds the embeddings and predictions.
adata

In [None]:
# Summary of the snapshot taken before HVG subsetting.
adata_raw

In [None]:
# Export object: the matrix and variables of `adata_raw` combined with the
# annotations of `adata`.
adata_export = anndata.AnnData(
    X = adata_raw.X,
    obs = adata.obs,
    var = adata_raw.var,
)
# Carry over copies of the embeddings, ...
for embedding_key in ('X_scVI', 'X_umap', 'X_scANVI'):
    adata_export.obsm[embedding_key] = adata.obsm[embedding_key].copy()
# ... the pairwise observation matrices and the unstructured metadata.
adata_export.obsp = adata.obsp.copy()
adata_export.uns = adata.uns.copy()
adata_export