### Notebook for the integration and label transfer of CD45+ lymphoid immune cells with `scANVI`

- **Developed by:** Carlos Talavera-López Ph.D
- **Würzburg Institute for Systems Immunology - Faculty of Medicine - Julius-Maximilians-Universität Würzburg**
- v231212

### Import required modules

In [1]:
import scvi
import torch
import anndata
import warnings
import numpy as np
import scanpy as sc
import pandas as pd
import plotnine as p
from pywaffle import Waffle
import matplotlib.pyplot as plt
from scib_metrics.benchmark import Benchmarker

### Set up working environment

In [None]:
# Scanpy session setup: verbose logging (level 3), print package versions for
# reproducibility, and set default figure parameters for screen and export.
sc.settings.verbosity = 3
sc.logging.print_versions()
sc.settings.set_figure_params(
    dpi=180,
    color_map='magma_r',
    dpi_save=300,
    vector_friendly=True,
    format='svg',
)

In [None]:
# Global session settings.
# NOTE: this silences *all* Python warnings for the session, including pandas /
# anndata warnings about chained assignment or modifying views that later
# cells may trigger.
warnings.simplefilter(action = 'ignore')
# Seed used by scvi-tools for reproducible training.
scvi.settings.seed = 1712
# IPython magics (only valid inside a Jupyter/IPython kernel): white figure
# background and high-resolution inline figures.
%config InlineBackend.print_figure_kwargs = {'facecolor' : "w"}
%config InlineBackend.figure_format = 'retina'
# Let PyTorch use faster, lower-precision kernels for float32 matmuls.
torch.set_float32_matmul_precision('medium')

In [None]:
# Network settings, presumably the scArches-style reference-mapping defaults
# (layer norm instead of batch norm, covariates fed to the encoder).
# NOTE(review): this dict is not passed to any model in this notebook (SCVI is
# built with explicit arguments below) -- confirm whether it is still needed.
arches_params = {
    "use_layer_norm": "both",
    "use_batch_norm": "none",
    "encode_covariates": True,
    "dropout_rate": 0.2,
    "n_layers": 3,
}

### Read in and format individual cell compartments

- Read in CD45+ immune cells 

In [None]:
# Load the CD45+ object annotated in 'MPC_Annotation' and copy its metadata
# into the shared column names used by every compartment in this notebook.
MDX_MPC = sc.read_h5ad('../../data/MDX_MPC.h5ad')
mpc_columns = {
    'genotype': 'Condition',
    'sample': 'hash.ID',
    'donor': 'Sample',
    'seed_labels': 'MPC_Annotation',
}
for target, source in mpc_columns.items():
    MDX_MPC.obs[target] = MDX_MPC.obs[source].copy()
MDX_MPC.obs['cell_source'] = 'HIRI-CD45+'
MDX_MPC

In [None]:
# Load the pooled CD45+ object annotated in 'CD45_Annotation' and copy its
# metadata into the shared column names used by every compartment.
MDX_POOL_NEW = sc.read_h5ad('../../data/MDX_POOL_NEW.h5ad')
pool_columns = {
    'genotype': 'Condition',
    'sample': 'hash.ID',
    'donor': 'Sample',
    'seed_labels': 'CD45_Annotation',
}
for target, source in pool_columns.items():
    MDX_POOL_NEW.obs[target] = MDX_POOL_NEW.obs[source].copy()
MDX_POOL_NEW.obs['cell_source'] = 'HIRI-CD45+'
MDX_POOL_NEW

- Read in Lymphoid cells

In [None]:
# Lymphoid object whose 'C_scANVI' column (presumably from an earlier scANVI
# run) provides the seed labels for this compartment.
lymphoid_path = '../../data/heart_mm_nuclei-23-0092_scANVI-Lymphoid_ctl231128.raw.h5ad'
Lymphoid_scANVI = sc.read_h5ad(lymphoid_path)
# Cut each barcode at its second-to-last '-' (drops the last two '-'-suffixes),
# presumably so names are comparable with the other objects -- TODO confirm.
trimmed_barcodes = [barcode.rsplit('-', 2)[0] for barcode in Lymphoid_scANVI.obs_names]
Lymphoid_scANVI.obs_names = trimmed_barcodes
Lymphoid_scANVI.obs['seed_labels'] = Lymphoid_scANVI.obs['C_scANVI'].copy()
Lymphoid_scANVI

### Merge cell compartments and compare with full object

In [None]:
# Stack the three objects on their shared genes; 'compartment' records the origin.
compartment_names = ['MDX_POOL_NEW', 'Lymphoid_scANVI', 'MDX_MPC']
compartments = MDX_POOL_NEW.concatenate(
    Lymphoid_scANVI, MDX_MPC,
    batch_key='compartment',
    batch_categories=compartment_names,
    join='inner',
)
# concatenate() appends '-<compartment>' to every barcode; strip that suffix again.
# NOTE(review): identical barcodes from different compartments become duplicate
# obs_names after this -- confirm that is acceptable.
compartments.obs_names = [barcode.rsplit('-', 1)[0] for barcode in compartments.obs_names]
compartments

### Group fine grained annotations into coarse groups

In [None]:
# List every seed label present across the merged compartments.
compartments.obs['seed_labels'].cat.categories

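- Remove low-quality cells, myeloid cells and other populations not analysed here (e.g. `gd_T`, `Proliferating`), keeping the lymphoid compartment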

In [None]:
# Seed labels to exclude: judging by their names, myeloid populations
# (macrophages, monocytes, DCs, mast cells/basophils, neutrophils), low-quality
# or unresolved clusters, and other populations not analysed in this notebook.
excluded_labels = [
    'A_Res_Mac_MHCII', 'B_TLF_Mac', 'C_Ccr2+MHCII+_Mac', 'D_Inflammatory_Mac',
    'G_Ly6Chi_Mo', 'H_Ly6Clow_Mo', 'I_cDC2', 'J_Mature_DC',
    'K_nd1', 'L_nd2', 'Ly6Chi_Mono', 'Ly6Clo_Mono', 'M_low_quality', 'Macrophages',
    'Mast/Baso', 'Neutrophils', 'Proliferating', 'gd_T', 'low_quality_cells',
    'myeloid??', 'pDC', 'cDC2',
]
# Boolean indexing returns an AnnData *view*. .copy() makes it an independent
# object, so the following cells that write into .obs change a real AnnData.
# Otherwise they rely on an implicit view-to-copy conversion, whose warning is
# hidden by the global warnings filter.
compartments_clean = compartments[~compartments.obs['seed_labels'].isin(excluded_labels)].copy()
compartments_clean

### Make uniform annotation for genotype labels

In [None]:
# Genotype spellings before harmonisation.
compartments_clean.obs['genotype'].value_counts()

In [None]:
# Map every spelling variant (inner lists) onto one canonical genotype label.
trans_from=[['MdxSCID', 'Mdx-SCID'],
['Ctrl', 'WT']]
trans_to = ['MdxSCID', 'WT']
genotype_map = {
    variant: canonical
    for variants, canonical in zip(trans_from, trans_to)
    for variant in variants
}

# Cast to plain strings first (drops the categorical dtype so new values are
# allowed). Then replace in one vectorised step. The previous chained
# assignment `obs['genotype'][mask] = ...` writes into a temporary Series and
# silently does nothing under pandas Copy-on-Write.
genotypes = compartments_clean.obs['genotype'].astype(str)
compartments_clean.obs['genotype'] = genotypes.replace(genotype_map)

In [None]:
# Genotype counts after harmonisation (should only show canonical labels).
compartments_clean.obs['genotype'].value_counts()

In [None]:
# Drop the original 'C_scANVI' column; for the lymphoid object its values were
# already copied into 'seed_labels' when it was read in.
del(compartments_clean.obs['C_scANVI'])
compartments_clean

### Inspect cell type counts

In [None]:
# Number of cells per seed label (all genotypes combined).
compartments_clean.obs['seed_labels'].value_counts()

### Harmonise cell type labels

In [None]:
# Seed labels remaining after filtering, before relabelling.
compartments_clean.obs['seed_labels'].cat.categories

In [None]:
# Labels that should not act as seeds are relabelled 'Unknown', the category
# scANVI treats as unlabelled later in the notebook.
trans_from = [['Conventional_T', 'E_Isg15+', 'F_Spp1+Gpnmb+', 'non-immune']]
trans_to = ['Unknown']
label_map = {
    old_label: new_label
    for old_labels, new_label in zip(trans_from, trans_to)
    for old_label in old_labels
}

# Cast to str (so 'Unknown' is a permitted value), then replace in one step
# instead of chained indexing, which is a silent no-op under pandas
# Copy-on-Write.
seed_labels = compartments_clean.obs['seed_labels'].astype(str)
compartments_clean.obs['seed_labels'] = seed_labels.replace(label_map)

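- Note: the label 'Macrophages' is very unspecific and covers a lot of cells; so that other, more detailed annotations can be mapped, those cells must not act as seed labels. They are already removed by the filtering step above, so the relabelling to 'Unknown' in the previous cell only covers 'Conventional_T', 'E_Isg15+', 'F_Spp1+Gpnmb+' and 'non-immune'.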

In [None]:
# Turn the plain-string seed labels back into a categorical column and list
# the final categories.
compartments_clean.obs['seed_labels'] = compartments_clean.obs['seed_labels'].astype('category')
compartments_clean.obs['seed_labels'].cat.categories

In [None]:
# Cells per seed label (rows) and genotype (columns).
pd.crosstab(compartments_clean.obs['seed_labels'], compartments_clean.obs['genotype'])

### Split into an annotated reference (MdxSCID, WT) and an unannotated query (all other genotypes)

In [None]:
# Genotypes available for the reference/query split.
compartments_clean.obs['genotype'].value_counts()

In [None]:
# Reference: MdxSCID and WT cells, which keep their seed labels.
# (Left as an AnnData view; it is only read by the concatenation below.)
reference = compartments_clean[compartments_clean.obs['genotype'].isin(['MdxSCID', 'WT'])]
reference

In [None]:
# Query: all other genotypes. Their seed labels are replaced with 'Unknown'
# (the unlabelled category passed to scANVI below), so scANVI predicts them.
# .copy() makes the subset an independent AnnData before writing into .obs,
# instead of relying on the implicit view-to-copy conversion.
query = compartments_clean[~compartments_clean.obs['genotype'].isin(['MdxSCID', 'WT'])].copy()
query.obs['seed_labels'] = 'Unknown'
query

In [None]:
# Stack reference and query on their shared genes; 'batch' records the origin
# and concatenate() appends '-reference' / '-query' to every barcode.
adata = reference.concatenate(
    query,
    batch_key='batch',
    batch_categories=['reference', 'query'],
    join='inner',
)
adata

- Clean merged object

In [None]:
# Remove per-cell metadata not needed downstream: the source columns already
# copied into 'genotype' / 'sample' / 'donor', hashtag (HTO) and count QC
# columns, and earlier clustering results. Deleted in the original order.
unused_obs_columns = [
    'Condition', 'hash.ID', 'Sample',
    'HTO_maxID', 'HTO_secondID', 'HTO_margin',
    'nCount_RNA', 'nFeature_RNA', 'nCount_ADT', 'nFeature_ADT',
    'percent.mt',
    'RNA_snn_res.0.2', 'seurat_clusters', 'RNA_snn_res.0.3', 'RNA_snn_res.0.5',
]
for column in unused_obs_columns:
    del adata.obs[column]
adata

In [None]:
# QC overview: total counts vs. mitochondrial percentage, coloured by reference/query.
sc.pl.scatter(adata, x = 'total_counts', y = 'pct_counts_mt', color = "batch", frameon = False)

In [None]:
# Cells per genotype in the merged reference + query object.
adata.obs['genotype'].value_counts()

In [None]:
# Cells per sample ('sample' is the batch key for HVG selection and scVI).
adata.obs['sample'].value_counts()

### Select HVGs

In [None]:
# Keep an untouched all-gene copy (used later for plotting and export) and
# store the current matrix in a 'counts' layer for the count-based models.
# NOTE(review): assumes adata.X still holds raw counts here -- confirm.
adata_raw = adata.copy()
adata.layers['counts'] = adata.X.copy()

# Select 7,000 highly variable genes per 'sample' (seurat_v3 flavour, computed
# on the 'counts' layer) and subset adata to them in place.
hvg_settings = dict(
    flavor="seurat_v3",
    n_top_genes=7000,
    layer="counts",
    batch_key="sample",
    subset=True,
)
sc.pp.highly_variable_genes(adata, **hvg_settings)
adata

### Transfer of annotation with scANVI

In [None]:
# Register adata with scvi-tools: counts from the 'counts' layer, 'sample' as
# batch, donor and cell source as extra categorical covariates, and
# 'seed_labels' as labels (reused when scANVI is built from this model).
scvi.model.SCVI.setup_anndata(
    adata,
    layer='counts',
    batch_key="sample",
    labels_key="seed_labels",
    categorical_covariate_keys=["donor", "cell_source"],
)

In [None]:
# scVI model: 50-dimensional latent space, 3 layers, negative binomial
# likelihood with a gene- and batch-specific dispersion.
scvi_model = scvi.model.SCVI(
    adata,
    n_layers=3,
    n_latent=50,
    gene_likelihood='nb',
    dispersion='gene-batch',
)

In [None]:
# Train scVI for 20 epochs, evaluating the validation ELBO after every epoch
# so the training curve can be inspected below.
# NOTE(review): `use_gpu` is deprecated in newer scvi-tools (replaced by
# `accelerator` / `devices`); in releases that accept it, an int selects a GPU
# index, so 0 likely means "GPU 0", not "CPU" -- confirm for the installed version.
scvi_model.train(
    max_epochs=20,
    use_gpu=0,
    check_val_every_n_epoch=1,
    enable_progress_bar=True,
)

In [None]:
# Store the scVI latent embedding (one row per cell) for later export.
adata.obsm["X_scVI"] = scvi_model.get_latent_representation(adata)

### Evaluate model performance using the [_Svensson_](https://www.nxn.se/valent/2023/8/10/training-scvi-posterior-predictive-distributions-over-epochs) method

In [None]:
# Train/validation ELBO per epoch for scVI, reshaped to long format
# (columns: epoch, variable, value) for plotnine.
elbo_train = scvi_model.history['elbo_train'].astype(float)
elbo_validation = scvi_model.history['elbo_validation'].astype(float)
history_df = elbo_train.join(elbo_validation).reset_index().melt(id_vars=['epoch'])

p.options.figure_size = 12, 6

# Epoch 0 is left out of the plot (presumably its much larger ELBO would
# compress the y-axis).
elbo_colours = {'elbo_train': 'black', 'elbo_validation': 'red'}
p_ = (
    p.ggplot(p.aes(x='epoch', y='value', color='variable'), history_df.query('epoch > 0'))
    + p.geom_line()
    + p.geom_point()
    + p.scale_color_manual(elbo_colours)
    + p.theme_minimal()
)

p_.save('fig1.png', dpi=300)

print(p_)

### Label transfer with `scANVI` 

In [None]:
# Initialise scANVI from the trained scVI model. Cells whose seed label is
# 'Unknown' are treated as unlabelled and receive predicted labels.
scanvi_model = scvi.model.SCANVI.from_scvi_model(scvi_model, unlabeled_category='Unknown')

In [None]:
# Fine-tune scANVI for 20 epochs with per-epoch validation.
# NOTE(review): see the scVI training cell regarding the deprecated `use_gpu`
# argument -- confirm it behaves as intended with the installed scvi-tools.
scanvi_model.train(
    max_epochs=20,
    use_gpu=0,
    check_val_every_n_epoch=1,
    enable_progress_bar=True,
)

In [None]:
# Predicted cell type for every cell (reference and query).
adata.obs["C_scANVI_new"] = scanvi_model.predict(adata)

- Extract latent representation

In [None]:
# scANVI latent embedding; used below for the neighbour graph and UMAP.
adata.obsm["X_scANVI"] = scanvi_model.get_latent_representation(adata)

### Explore model performance using the [_Svensson_](https://www.nxn.se/valent/2023/8/10/training-scvi-posterior-predictive-distributions-over-epochs) method

In [None]:
# Train/validation ELBO per epoch for scANVI, reshaped to long format
# (columns: epoch, variable, value) for plotnine.
history_df = (
    scanvi_model.history['elbo_train'].astype(float)
    .join(scanvi_model.history['elbo_validation'].astype(float))
    .reset_index()
    .melt(id_vars = ['epoch'])
)

p.options.figure_size = 12, 6

# Epoch 0 is left out of the plot, as for the scVI curve.
p_ = (
    p.ggplot(p.aes(x = 'epoch', y = 'value', color = 'variable'), history_df.query('epoch > 0'))
    + p.geom_line()
    + p.geom_point()
    + p.scale_color_manual({'elbo_train': 'black', 'elbo_validation': 'red'})
    + p.theme_minimal()
)

# Saved under its own file name. The scVI training curve earlier in the
# notebook is written to 'fig1.png', and reusing that name here silently
# overwrote it.
p_.save('fig2.png', dpi = 300)

print(p_)

- Visualise corrected dataset

In [None]:
# 50-nearest-neighbour graph on the scANVI latent space, then a UMAP with a
# fixed random_state so the layout is reproducible.
umap_keys = ['genotype', 'cell_source', 'C_scANVI_new', 'sample', 'seed_labels']
sc.pp.neighbors(adata, use_rep="X_scANVI", n_neighbors=50, metric='minkowski')
sc.tl.umap(adata, min_dist=0.6, spread=6, random_state=1712)
sc.pl.umap(adata, color=umap_keys, frameon=False, size=1.5, legend_fontsize=5, ncols=3)

In [None]:
# QC covariates on the UMAP, to check for clusters driven by quality metrics.
qc_keys = [
    'n_genes', 'doublet_scores', 'batch', 'n_genes_by_counts', 'total_counts',
    'total_counts_mt', 'pct_counts_mt', 'total_counts_ribo', 'pct_counts_ribo', 'n_counts',
]
sc.pl.umap(adata, color=qc_keys, frameon=False, size=1, legend_fontsize=5, ncols=4, cmap='magma')

### Modify object to plot canonical marker genes

In [None]:
# Plotting-only object: library-size-normalised, square-root-transformed
# expression of ALL genes (adata_raw was copied before HVG subsetting), paired
# with adata's metadata and embeddings. HVG selection only subset genes, so
# both objects share the same cell order.
normalised = sc.pp.normalize_total(adata_raw, inplace=False)["X"]
adata_toplot = anndata.AnnData(
    X=np.sqrt(normalised),
    obs=adata.obs,
    var=adata_raw.var,
    obsm=adata.obsm,
)
adata_toplot

In [None]:
# Predicted labels next to a panel of canonical marker genes.
marker_genes = ['Ttn', 'Nppa', 'Dcn', 'Vwf', 'Myh11', 'Rgs4', 'Kcnj8', 'C1qa', 'Cd3e', 'Trem2', 'Adipoq', 'Nrxn1', 'Msln']
sc.pl.umap(
    adata_toplot,
    color=['C_scANVI_new'] + marker_genes,
    frameon=False,
    size=1,
    legend_fontsize=5,
    ncols=4,
    cmap='RdPu',
)

### Visualise proportions

In [None]:
# Overview UMAP coloured by sample, genotype and predicted cell type.
sc.pl.umap(adata, frameon = False, color = ['sample', 'genotype', 'C_scANVI_new'], size = 0.6, legend_fontsize = 5, ncols = 4)

In [None]:
# Six-colour palette for the predicted-label UMAP.
bauhaus_colors = [
    '#FF0000',  # red
    '#FFFF00',  # yellow
    '#000000',  # black
    '#4D5D53',  # dark grey-green
    '#0000FF',  # blue
    '#808080',  # grey
]

In [None]:
# Predicted labels drawn with the six-colour palette defined above.
# NOTE(review): with more than six predicted labels, colours will be reused --
# confirm the palette covers all categories.
sc.pl.umap(adata, frameon = False, color = ['C_scANVI_new'], size = 0.6, legend_fontsize = 5, ncols = 4, palette = bauhaus_colors)

In [None]:
# Number of cells per (genotype, predicted cell type) combination.
df = adata_toplot.obs.groupby(['genotype', 'C_scANVI_new']).size().reset_index(name = 'counts')

# Percentage of each cell type within its genotype. transform('sum') returns a
# Series aligned on df's own index, so each row gets its own genotype total.
# The previous apply() + reset_index() result was matched to df only because
# both happened to share the same row order.
genotype_totals = df.groupby('genotype')['counts'].transform('sum')
df['proportions'] = df['counts'] / genotype_totals * 100

# Waffle squares in tenths of a percent (truncated to int).
df['waffle_counts'] = (df['proportions'] * 10).astype(int)

In [None]:
def generate_modified_bauhaus_palette(n_colors, seed=None):
    """Return up to ``n_colors`` distinct RGB tuples derived from a Bauhaus-style base palette.

    Each of eight base colours is scaled by brightness factors 0.8, 0.9, 1.0, ...
    with components clipped to [0, 1]. Duplicates are dropped, the first
    ``n_colors`` distinct colours are kept, and the result is shuffled.

    Parameters
    ----------
    n_colors : int
        Number of colours requested. ``n_colors <= 0`` returns an empty list.
    seed : int, optional
        If given, the shuffle uses ``np.random.default_rng(seed)`` and the palette
        is reproducible. If ``None`` (default), NumPy's global random state is
        used, as before.

    Returns
    -------
    list of tuple
        RGB tuples with components in [0, 1]. The list can be shorter than
        ``n_colors`` when scaled variants collapse after clipping (every variant
        of black is black, and factors >= 1 saturate pure red).
    """
    base_colors = [
        (1, 0, 0),      # Red
        (0.07, 0.04, 0.56),  # Ultramarine blue
        (0, 0.28, 0.67),  # Cobalt blue
        (1, 0.9, 0),  # Bauhaus yellow
        (0, 0, 0),  # Black
        (0.5, 0, 0.5),  # Purple
        (1, 0.55, 0),  # Orange
        (0.54, 0.17, 0.89),  # Violet
    ]

    # Previously n_colors == 0 returned 8 colours, because the length check
    # only ran after the first append.
    if n_colors <= 0:
        return []

    # Enough brightness variants per base colour to exceed n_colors before
    # de-duplication. Scaling all RGB channels changes brightness only, not saturation.
    variation_steps = n_colors // len(base_colors) + 1
    candidates = [
        tuple(min(max(c * (0.8 + i * 0.1), 0), 1) for c in color)
        for color in base_colors
        for i in range(variation_steps)
    ]

    # Keep the first n_colors distinct colours, preserving generation order.
    unique_colors = []
    for color in candidates:
        if len(unique_colors) == n_colors:
            break
        if color not in unique_colors:
            unique_colors.append(color)

    # Shuffle so that neighbouring categories do not get near-identical shades.
    if seed is None:
        np.random.shuffle(unique_colors)
    else:
        np.random.default_rng(seed).shuffle(unique_colors)

    return unique_colors

# Generate the modified palette
bauhaus_palette = generate_modified_bauhaus_palette(18)

In [None]:
# One waffle chart per genotype showing the share of each predicted cell type.
for genotype in df['genotype'].unique():
    subset = df[df['genotype'] == genotype]
    cell_types = subset['C_scANVI_new']

    # Squares per cell type (tenths of a percent, from 'waffle_counts').
    data = dict(zip(cell_types, subset['waffle_counts']))

    # Reuse palette colours cyclically if there are more cell types than colours.
    colors = [bauhaus_palette[k % len(bauhaus_palette)] for k, _ in enumerate(cell_types)]

    legend_labels = [
        f"{cell_type} ({share}%)"
        for cell_type, share in zip(cell_types, subset['proportions'].round(2))
    ]

    fig = plt.figure(
        FigureClass=Waffle,
        rows=7,
        values=data,
        title={'label': f'Genotype {genotype}', 'loc': 'left', 'fontsize': 14},
        labels=legend_labels,
        legend={'loc': 'lower left', 'bbox_to_anchor': (0, -0.4), 'ncol': len(data), 'framealpha': 0, 'fontsize': 14},
        figsize=(40, 4),
        colors=colors,
    )
    plt.show()

### Export annotated sample object 

In [None]:
# Keep only the first two '-'-separated fields of every barcode, which drops
# the '-reference' / '-query' suffix added when reference and query were merged.
# NOTE(review): assumes pre-merge barcodes contain exactly one '-' -- confirm.
trimmed_barcodes = ['-'.join(barcode.split('-', 2)[:2]) for barcode in adata.obs.index]
adata.obs.index = pd.Index(trimmed_barcodes)
adata.obs.index

In [None]:
# Apply the same barcode trimming to the all-gene copy so its cell names
# match adata's.
raw_trimmed_barcodes = ['-'.join(barcode.split('-', 2)[:2]) for barcode in adata_raw.obs.index]
adata_raw.obs.index = pd.Index(raw_trimmed_barcodes)
adata_raw.obs.index

In [None]:
# Check the trimmed cell names.
adata.obs_names

In [None]:
# Predicted cell type categories.
adata.obs['C_scANVI_new'].cat.categories

In [None]:
# Number of cells per predicted cell type.
adata.obs['C_scANVI_new'].value_counts()

### Export annotated object with raw counts

In [None]:
# HVG-subset object holding the annotations and embeddings.
adata

In [None]:
# All-gene copy taken before HVG subsetting.
adata_raw

In [None]:
# Export object: the all-gene matrix from adata_raw combined with the cell
# metadata, embeddings, neighbour graphs and uns entries computed on adata.
# NOTE(review): nothing is written to disk here -- confirm whether a
# write_h5ad call is expected.
adata_export = anndata.AnnData(X=adata_raw.X, obs=adata.obs, var=adata_raw.var)
for embedding_key in ('X_scVI', 'X_umap', 'X_scANVI'):
    adata_export.obsm[embedding_key] = adata.obsm[embedding_key].copy()
adata_export.obsp = adata.obsp.copy()
adata_export.uns = adata.uns.copy()
adata_export