### Notebook for the integration and label transfer of myeloid immune cells with `scANVI`

- **Developed by:** Carlos Talavera-López Ph.D
- **Würzburg Institute for Systems Immunology - Faculty of Medicine - Julius-Maximilian-Universität Würzburg**
- **Created on**: 240329
- **Last modified**: 240402

### Import required modules

In [1]:
import scvi
import torch
import anndata
import warnings
import numpy as np
import scanpy as sc
import pandas as pd
import plotnine as p
from pywaffle import Waffle
import matplotlib.pyplot as plt
from scib_metrics.benchmark import Benchmarker

### Set up working environment

In [None]:
# Set scanpy verbosity to 3 and print the versions of the loaded packages.
sc.settings.verbosity = 3
sc.logging.print_versions()

# Figure defaults applied to the scanpy plots drawn below.
sc.settings.set_figure_params(
    dpi = 180,
    dpi_save = 300,
    color_map = 'magma_r',
    vector_friendly = True,
    format = 'svg',
)

In [None]:
# Silence every Python warning for the rest of the session
# (this also hides pandas/anndata warnings raised by later cells).
warnings.simplefilter(action = 'ignore')
# Seed used by scvi-tools.
scvi.settings.seed = 1712
# IPython magics: white figure background and 'retina' inline figure format.
%config InlineBackend.print_figure_kwargs = {'facecolor' : "w"}
%config InlineBackend.figure_format = 'retina'
# Set torch's float32 matrix-multiplication precision to 'medium'.
torch.set_float32_matmul_precision('medium')

In [None]:
# Bundle of network keyword arguments.
# NOTE(review): not referenced again in the visible part of this notebook —
# presumably meant for the model constructor; confirm.
arches_params = {
    "use_layer_norm": "both",
    "use_batch_norm": "none",
    "encode_covariates": True,
    "dropout_rate": 0.2,
    "n_layers": 3,
}

### Read in and format the individual dataset compartments

- Read in CD45+ immune cells 

In [None]:
MDX_MPC = sc.read_h5ad('../../../data/MDX_MPC.h5ad')

# Copy the dataset-specific metadata columns into the column names used
# for all datasets in this notebook.
for harmonised_key, source_key in [
    ('genotype', 'Condition'),
    ('sample', 'hash.ID'),
    ('donor', 'Sample'),
    ('seed_labels', 'MPC_Annotation'),
]:
    MDX_MPC.obs[harmonised_key] = MDX_MPC.obs[source_key].copy()

MDX_MPC.obs['cell_source'] = 'HIRI-CD45+'
MDX_MPC

In [None]:
# Keep only the cells that satisfy every QC threshold:
# 200 < nFeature_RNA < 4000, 200 < nCount_RNA < 15000, percent.mt < 60.
mpc_obs = MDX_MPC.obs
mpc_passes_qc = (
    (mpc_obs['nFeature_RNA'] > 200)
    & (mpc_obs['nFeature_RNA'] < 4000)
    & (mpc_obs['nCount_RNA'] < 15000)
    & (mpc_obs['nCount_RNA'] > 200)
    & (mpc_obs['percent.mt'] < 60)
)
MDX_MPC = MDX_MPC[mpc_passes_qc]

MDX_MPC

In [None]:
MDX_POOL_NEW = sc.read_h5ad('../../../data/MDX_POOL_NEW.h5ad')

# Copy the dataset-specific metadata columns into the column names used
# for all datasets in this notebook.
for harmonised_key, source_key in [
    ('genotype', 'Condition'),
    ('sample', 'hash.ID'),
    ('donor', 'Sample'),
    ('seed_labels', 'CD45_Annotation'),
]:
    MDX_POOL_NEW.obs[harmonised_key] = MDX_POOL_NEW.obs[source_key].copy()

MDX_POOL_NEW.obs['cell_source'] = 'HIRI-CD45+'
MDX_POOL_NEW

In [None]:
# Keep only the cells that satisfy every QC threshold:
# 200 < nFeature_RNA < 4000, 200 < nCount_RNA < 15000, percent.mt < 60.
pool_obs = MDX_POOL_NEW.obs
pool_passes_qc = (
    (pool_obs['nFeature_RNA'] > 200)
    & (pool_obs['nFeature_RNA'] < 4000)
    & (pool_obs['nCount_RNA'] < 15000)
    & (pool_obs['nCount_RNA'] > 200)
    & (pool_obs['percent.mt'] < 60)
)
MDX_POOL_NEW = MDX_POOL_NEW[pool_passes_qc]

MDX_POOL_NEW

- Read in Meyer annotated PBMC cells

In [None]:
pbmc = sc.read_h5ad('../../../data/meyer_nikolic_healthy_pbmc_raw.h5ad') 
pbmc

In [None]:
# Subset to cells whose COVID_status is 'Healthy'. Subsetting an AnnData returns a
# view; the lines below write to `.obs`, so take an explicit copy instead of relying
# on anndata's implicit copy-on-modify (whose warning is silenced in this notebook).
pbmc_healthy = pbmc[pbmc.obs['COVID_status'].isin(['Healthy'])].copy()

# Harmonised metadata columns shared with the other datasets.
pbmc_healthy.obs['sample'] = pbmc_healthy.obs['sample_id'].copy()
pbmc_healthy.obs['seed_labels'] = pbmc_healthy.obs['annotation_detailed'].copy()
pbmc_healthy.obs['cell_source'] = 'Sanger-Cells'
pbmc_healthy.obs['genotype'] = 'human'
pbmc_healthy

In [None]:
# Capitalise every gene symbol (first character upper-case, the rest lower-case),
# presumably so they match the symbol style of the other datasets — TODO confirm.
pbmc_healthy.var_names = list(map(str.capitalize, pbmc_healthy.var_names))
pbmc_healthy.var_names

### Read in DMD query cells

In [None]:
DMD_scANVI = sc.read_h5ad('../../../data/heart_mm_nuclei-23-0092_scANVI-Immune_ctl240329.raw.h5ad')
DMD_scANVI

In [None]:
DMD_scANVI.obs['C_scANVI_new'].cat.categories

In [None]:
# Categories of `C_scANVI_new` that are retained for this notebook.
myeloid_labels = [
    'Baso/Eos', 'CD14+Mo', 'CD16+Mo', 'CD56+NK', 'Ccr2+MHCII+MØ', 'DC', 'ILC',
    'Isg15+MØ', 'Ly6ChiMo', 'Ly6CloMo', 'MHCII+MØtr', 'Mast', 'MØinf', 'NK', 'NKT',
    'Neutrophils', 'Platelets', 'TLF+MØ', 'gdT', 'pDC',
]

# Explicit copy: the statements below modify `.obs`, which should not be done on an
# AnnData view (anndata would otherwise copy implicitly, with a silenced warning).
DMD_myeloid = DMD_scANVI[DMD_scANVI.obs['C_scANVI_new'].isin(myeloid_labels)].copy()

# These cells enter the label transfer unlabelled; the previous annotation is dropped.
DMD_myeloid.obs['seed_labels'] = 'Unknown'
del DMD_myeloid.obs['C_scANVI_new']
DMD_myeloid

### Merge cell compartments and compare with full object

In [None]:
# Concatenate the four datasets on their shared variables (inner join); the dataset
# of origin of each cell is recorded in `obs['compartment']`.
compartments = pbmc_healthy.concatenate(
    DMD_myeloid,
    MDX_POOL_NEW,
    MDX_MPC,
    batch_key = 'compartment',
    batch_categories = ['pbmc_meyer', 'DMD', 'MDX_POOL_NEW', 'MDX_MPC'],
    join = 'inner',
)
compartments

In [None]:
compartments.obs_names

In [None]:
# Concatenate the three non-PBMC datasets on their shared variables (inner join).
# NOTE(review): the batch column is named 'data_sourcecompartment' here but
# 'compartment' in the four-dataset merge — looks like a typo; confirm.
adata_export_raw = DMD_myeloid.concatenate(
    MDX_POOL_NEW,
    MDX_MPC,
    batch_key = 'data_sourcecompartment',
    batch_categories = ['DMD', 'MDX_POOL_NEW', 'MDX_MPC'],
    join = 'inner',
)
adata_export_raw

In [None]:
adata_export_raw.obs_names

### Group fine-grained annotations into coarse groups

In [None]:
compartments.obs['seed_labels'] = compartments.obs['seed_labels'].astype('category')
compartments.obs['seed_labels'].cat.categories

### Remove low quality or irrelevant labels

In [None]:
# Seed labels that are excluded from the analysis.
labels_to_drop = [
    'B invar', 'B n-sw mem', 'B n-sw mem IFN stim', 'B naive', 'B naive IFN stim',
    'B sw mem', 'B_Cells', 'Conventional_T', 'MAIT', 'M_low_quality', 'Plasma cells',
    'Plasmablasts', 'RBC', 'T CD4 CTL', 'T CD4 helper', 'T CD4 naive',
    'T CD4 naive IFN stim', 'T CD8 CM', 'T CD8 CTL', 'T CD8 CTL IFN stim',
    'T CD8 EM', 'T CD8 EMRA', 'T CD8 naive', 'T reg', 'low_quality_cells',
    'myeloid??', 'non-immune',
]

# Keep every cell whose seed label is not in the exclusion list.
has_dropped_label = compartments.obs['seed_labels'].isin(labels_to_drop)
compartments_clean = compartments[~has_dropped_label]
compartments_clean

### Make uniform annotation for genotype labels

In [None]:
compartments_clean.obs['genotype'].value_counts()

In [None]:
# One list of source spellings per harmonised genotype; `trans_from[i]` maps to
# `trans_to[i]`. (The original relied on a stray trailing comma to get this nesting.)
trans_from = [['MdxSCID', 'Mdx-SCID']]
trans_to = ['MdxSCID']

# Work on a standalone string Series and assign it back in one step. Chained
# assignment of the form `obs[col][mask] = value` is not guaranteed to write through
# to the DataFrame and is a silent no-op under pandas Copy-on-Write.
genotype = pd.Series(
    [str(i) for i in compartments_clean.obs['genotype']],
    index = compartments_clean.obs.index,
)
for source_labels, target_label in zip(trans_from, trans_to):
    genotype[genotype.isin(source_labels)] = target_label
compartments_clean.obs['genotype'] = genotype

In [None]:
compartments_clean.obs['genotype'].value_counts()

### Visualise cell type distribution per condition

In [None]:
compartments_clean.obs['seed_labels'].value_counts()

### Harmonise cell type labels

In [None]:
compartments_clean.obs['seed_labels'].cat.categories

In [None]:
# One list of source labels per harmonised label; `trans_from[i]` maps to `trans_to[i]`.
trans_from = [['pDC'],
              ['AS-DC', 'J_Mature_DC'],
              ['cDC1'],
              ['cDC2', 'I_cDC2'],
              ['Mast/Baso'],
              ['Monocyte CD14', 'Monocyte CD14 IFN stim', 'Monocyte CD14 IL6'],
              ['Monocyte CD16', 'Monocyte CD16 IFN stim', 'Monocyte CD16+C1'],
              ['G_Ly6Chi_Mo', 'Ly6Chi_Mono'],
              ['H_Ly6Clow_Mo', 'Ly6Clo_Mono'],
              ['A_Res_Mac_MHCII'], 
              ['B_TLF_Mac'], 
              ['C_Ccr2+MHCII+_Mac'], 
              ['D_Inflammatory_Mac'], 
              ['E_Isg15+'], 
              ['F_Spp1+Gpnmb+'],
              ['NK', 'NK IFN stim'],
              ['NK CD56'],
              ['NK_CD16hi'],
              ['NKT'],
              ['DOCK4+vMØ'],
              ['ILC', 'ILC2'],
              ['HPC', 'HPC IFN stim'],
              ['T g/d', 'gd_T'],
              ['Neutrophils', 'nøMo'],
              ['K_nd1', 'L_nd2', 'Unknown', 'Conventional_T', 'Macrophages', 'MAIT', 'Proliferating', 'gdT', 'T g/d', 'gd_T', 'Cycling', 'Platelets']]

trans_to = ['pDC', 'DC', 'DC1', 'DC2', 'Mast', 'CD14+Mo', 'CD16+Mo', 'Ly6ChiMo', 'Ly6CloMo',  'MHCII+MØtr', 'TLF+MØ', 'Ccr2+MHCII+MØ', 'MØinf', 'Isg15+MØ', 'Spp1+Gpnmb+MØ', 'NK', 'CD56+NK', 'CD16+NK', 'NKT','DOCK4+MØ', 'ILC', 'HPC', 'gdT', 'NØ', 'Unknown']

# Work on a standalone string Series and assign it back in one step. Chained
# assignment of the form `obs[col][mask] = value` is not guaranteed to write through
# to the DataFrame and is a silent no-op under pandas Copy-on-Write.
seed_labels = pd.Series(
    [str(i) for i in compartments_clean.obs['seed_labels']],
    index = compartments_clean.obs.index,
)

# The groups are applied in order, so a later group can overwrite an earlier result.
# NOTE(review): the final group lists 'gdT', 'T g/d' and 'gd_T', so every cell that
# was mapped to 'gdT' by the earlier group ends up as 'Unknown' — confirm this is intended.
for source_labels, harmonised_label in zip(trans_from, trans_to):
    seed_labels[seed_labels.isin(source_labels)] = harmonised_label

compartments_clean.obs['seed_labels'] = seed_labels

- The label 'Macrophages' is too unspecific and covers a large number of cells. To allow more detailed annotations to be mapped onto these cells, they are relabelled as 'Unknown'.

In [None]:
compartments_clean.obs['seed_labels'] = compartments_clean.obs['seed_labels'].astype('category')
compartments_clean.obs['seed_labels'].cat.categories

In [None]:
pd.crosstab(compartments_clean.obs['seed_labels'], compartments_clean.obs['genotype'])

### Read in other unannotated dataset and split into groups

In [None]:
compartments_clean.obs['genotype'].value_counts()

In [None]:
reference = compartments_clean[compartments_clean.obs['genotype'].isin(['Ctrl', 'human'])]
reference

In [None]:
# Every cell whose genotype is neither 'Ctrl' nor 'human' is treated as query.
# Explicit copy: the next line writes to `.obs`, which should not be done on an
# AnnData view (anndata would otherwise copy implicitly, with a silenced warning).
query = compartments_clean[~compartments_clean.obs['genotype'].isin(['Ctrl', 'human'])].copy()

# Query cells carry no seed label into the label transfer.
query.obs['seed_labels'] = 'Unknown'
query

In [None]:
adata = reference.concatenate(query, batch_key = 'batch', batch_categories = ['reference', 'query'], join = 'inner')
adata

In [None]:
# Keep only the first three '-'-separated fields of every cell name, which drops
# any further suffixes appended by the successive concatenations.
# NOTE(review): assumes the trimmed names remain unique — confirm.
trimmed_names = []
for cell_name in adata.obs.index:
    name_fields = cell_name.split('-')
    trimmed_names.append('-'.join(name_fields[:3]))

adata.obs.index = pd.Index(trimmed_names)
adata.obs_names

- Clean merged object

In [None]:
# Metadata columns that are removed from the merged object, in removal order.
obsolete_columns = [
    'Condition',
    'hash.ID',
    'Sample',
    'HTO_maxID',
    'HTO_secondID',
    'HTO_margin',
    'nCount_RNA',
    'nFeature_RNA',
    'nCount_ADT',
    'nFeature_ADT',
    'percent.mt',
    'RNA_snn_res.0.2',
    'seurat_clusters',
    'RNA_snn_res.0.3',
    'RNA_snn_res.0.5',
]
for column in obsolete_columns:
    del adata.obs[column]
adata

In [None]:
sc.pl.scatter(adata, x = 'total_counts', y = 'n_genes', color = "genotype", frameon = False)

In [None]:
adata.obs['genotype'].value_counts()

In [None]:
adata.obs['sample'].value_counts()

### Select HVGs

In [None]:
# Keep a full copy of the object before it is subset to the selected genes.
adata_raw = adata.copy()
# Store the current matrix as the 'counts' layer read by the steps below.
adata.layers['counts'] = adata.X.copy()

# Select the top 8000 highly variable genes (seurat_v3 flavour) from the 'counts'
# layer, using 'sample' as batch key, and subset `adata` to them in place.
hvg_kwargs = {
    'flavor': "seurat_v3",
    'n_top_genes': 8000,
    'layer': "counts",
    'batch_key': "sample",
    'subset': True,
}
sc.pp.highly_variable_genes(adata, **hvg_kwargs)
adata

### Transfer of annotation with scANVI

In [None]:
# Register the fields scvi-tools reads from `adata`: the 'counts' layer, 'sample' as
# batch, 'donor' and 'cell_source' as categorical covariates, 'seed_labels' as labels.
scvi.model.SCVI.setup_anndata(
    adata,
    layer = 'counts',
    batch_key = "sample",
    categorical_covariate_keys = ["donor", "cell_source"],
    labels_key = "seed_labels",
)

In [None]:
# scVI model: 50 latent dimensions, 3 layers, 'gene-batch' dispersion and a
# negative-binomial ('nb') gene likelihood.
scvi_model = scvi.model.SCVI(
    adata,
    n_latent = 50,
    n_layers = 3,
    dispersion = 'gene-batch',
    gene_likelihood = 'nb',
)

In [None]:
# Train scVI (first argument: 30) on GPU device index 1, validating every epoch.
scvi_model.train(
    30,
    accelerator = 'gpu',
    devices = [1],
    check_val_every_n_epoch = 1,
    enable_progress_bar = True,
)

In [None]:
adata.obsm["X_scVI"] = scvi_model.get_latent_representation(adata)

### Evaluate model performance using the [_Svensson_](https://www.nxn.se/valent/2023/8/10/training-scvi-posterior-predictive-distributions-over-epochs) method

In [None]:
# Long-format table with the training and validation ELBO recorded per epoch.
elbo_train = scvi_model.history['elbo_train'].astype(float)
elbo_validation = scvi_model.history['elbo_validation'].astype(float)
history_df = elbo_train.join(elbo_validation).reset_index().melt(id_vars = ['epoch'])

p.options.figure_size = (12, 6)

# Line + point plot of both curves, skipping epoch 0.
p_ = p.ggplot(p.aes(x = 'epoch', y = 'value', color = 'variable'), history_df.query('epoch > 0'))
p_ = p_ + p.geom_line()
p_ = p_ + p.geom_point()
p_ = p_ + p.scale_color_manual({'elbo_train': 'black', 'elbo_validation': 'red'})
p_ = p_ + p.theme_minimal()

p_.save('fig1.png', dpi = 300)

print(p_)

### Label transfer with `scANVI` 

In [None]:
scanvi_model = scvi.model.SCANVI.from_scvi_model(scvi_model, 'Unknown')

In [None]:
# Train scANVI (first argument: 20) on GPU device index 1, validating every epoch.
scanvi_model.train(
    20,
    accelerator = 'gpu',
    devices = [1],
    check_val_every_n_epoch = 1,
    enable_progress_bar = True,
)

In [None]:
adata.obs["C_scANVI_S1"] = scanvi_model.predict(adata)

- Extract latent representation

In [None]:
adata.obsm["X_scANVI"] = scanvi_model.get_latent_representation(adata)

### Explore model performance using the [_Svensson_](https://www.nxn.se/valent/2023/8/10/training-scvi-posterior-predictive-distributions-over-epochs) method

In [None]:
# Long-format table with the scANVI training and validation ELBO recorded per epoch.
history_df = (
    scanvi_model.history['elbo_train'].astype(float)
    .join(scanvi_model.history['elbo_validation'].astype(float))
    .reset_index()
    .melt(id_vars = ['epoch'])
)

p.options.figure_size = 12, 6

# Line + point plot of both curves, skipping epoch 0.
p_ = (
    p.ggplot(p.aes(x = 'epoch', y = 'value', color = 'variable'), history_df.query('epoch > 0'))
    + p.geom_line()
    + p.geom_point()
    + p.scale_color_manual({'elbo_train': 'black', 'elbo_validation': 'red'})
    + p.theme_minimal()
)

# Saved under its own file name: the scVI training curve is written to 'fig1.png'
# earlier in this notebook and would otherwise be overwritten by this plot.
p_.save('fig2.png', dpi = 300)

print(p_)

- Visualise corrected dataset

In [None]:
sc.pp.neighbors(adata, use_rep = "X_scANVI", n_neighbors = 50, metric = 'minkowski')
sc.tl.umap(adata, min_dist = 0.3, spread = 1, random_state = 1712)
sc.pl.umap(adata, frameon = False, color = ['genotype', 'C_scANVI_S1', 'cell_source', 'seed_labels', 'sample'], size = 0.8, legend_fontsize = 5, ncols = 3)

In [None]:
sc.pl.umap(adata, frameon = False, color = ['n_genes', 'doublet_scores', 'batch', 'n_genes_by_counts', 'total_counts', 'total_counts_mt', 'pct_counts_mt', 'total_counts_ribo', 'pct_counts_ribo', 'n_counts'], size = 1, legend_fontsize = 5, ncols = 4, cmap = 'magma')

### Modify object to plot canonical marker genes

In [None]:
# Normalise total counts per cell on the full copy (not in place), then take the
# square root; combine with the annotations and embeddings stored on `adata`.
normalised_counts = sc.pp.normalize_total(adata_raw, inplace = False)["X"]
adata_toplot = anndata.AnnData(
    X = np.sqrt(normalised_counts),
    var = adata_raw.var,
    obs = adata.obs,
    obsm = adata.obsm,
)
adata_toplot

In [None]:
sc.pl.umap(adata_toplot, frameon = False, color = ['C_scANVI_S1', 'C1qa', 'Cd3e', 'Trem2', 'Adipoq', 'Nrxn1', 'Msln'], size = 1, legend_fontsize = 5, ncols = 4, cmap = 'RdPu')

### Visualise proportions

In [None]:
sc.pl.umap(adata, frameon = False, color = ['sample', 'genotype', 'C_scANVI_S1'], size = 0.6, legend_fontsize = 5, ncols = 4)

In [None]:
bauhaus_colors = ['#FF0000', '#FFFF00', '#000000', '#4D5D53', '#0000FF', '#808080']

In [None]:
sc.pl.umap(adata, frameon = False, color = ['genotype'], size = 0.6, legend_fontsize = 5, ncols = 4, palette = bauhaus_colors)

In [None]:
# Number of cells per (genotype, predicted cell type) combination.
df = adata_toplot.obs.groupby(['genotype', 'C_scANVI_S1']).size().reset_index(name = 'counts')

# Percentage of each cell type within its genotype. `transform` returns one value per
# row of `df`, already aligned on its index, so the result does not depend on how
# `groupby(...).apply(...).reset_index()` happens to order or index its output.
genotype_totals = df.groupby('genotype')['counts'].transform('sum')
# A genotype without any cells yields 0 / 0 = NaN; report it as 0 % so the integer
# conversion below cannot fail.
df['proportions'] = (df['counts'] / genotype_totals * 100).fillna(0)

# One waffle square per 0.1 % (truncated towards zero).
df['waffle_counts'] = (df['proportions'] * 10).astype(int)

In [None]:
def generate_modified_bauhaus_palette(n_colors):
    """Return up to `n_colors` distinct RGB tuples derived from a fixed set of base colours.

    Each base colour is scaled by the factors 0.8, 0.9, 1.0, ... (clamped to [0, 1])
    to create variations; duplicates are removed, the first `n_colors` unique
    colours are kept and then shuffled with NumPy's global random state.

    Parameters
    ----------
    n_colors : int
        Number of colours requested. Values <= 0 give an empty list.

    Returns
    -------
    list of tuple
        RGB tuples with components in [0, 1]. The list can be shorter than
        `n_colors` when clamping makes variations collapse into the same colour.
    """
    # Nothing requested: return early instead of falling through with every variation.
    if n_colors <= 0:
        return []

    # Define specific shades
    base_colors = [
        (1, 0, 0),      # Red
        (0.07, 0.04, 0.56),  # Ultramarine blue
        (0, 0.28, 0.67),  # Cobalt blue
        (1, 0.9, 0),  # Bauhaus yellow
        (0, 0, 0),  # Black
        (0.5, 0, 0.5),  # Purple
        (1, 0.55, 0),  # Orange
        (0.54, 0.17, 0.89),  # Violet
    ]

    # Enough variations per base colour to reach `n_colors` before de-duplication.
    variation_steps = n_colors // len(base_colors) + 1
    colors = [
        # Scale every channel by the same factor and clamp it to the valid range.
        tuple(min(max(c * (0.8 + i * 0.1), 0), 1) for c in color)
        for color in base_colors
        for i in range(variation_steps)
    ]

    # Order-preserving de-duplication (e.g. every variation of black is black),
    # truncated to the number of colours needed.
    unique_colors = list(dict.fromkeys(colors))[:n_colors]

    # Shuffle in place using the global NumPy random state.
    np.random.shuffle(unique_colors)

    return unique_colors

# Generate the modified palette
bauhaus_palette = generate_modified_bauhaus_palette(18)

In [None]:
# Fix one colour per cell type up front, so a given cell type keeps the same colour
# in every genotype panel even when a genotype has no row for some cell types.
# (Assigning colours by row position inside each genotype would shift them.)
cell_type_colors = {
    cell_type: bauhaus_palette[i % len(bauhaus_palette)]
    for i, cell_type in enumerate(df['C_scANVI_S1'].unique())
}

# One waffle chart per genotype.
for group in df['genotype'].unique():
    temp_df = df[df['genotype'] == group]
    data = dict(zip(temp_df['C_scANVI_S1'], temp_df['waffle_counts']))

    # Colours in the same order as the plotted cell types.
    colors = [cell_type_colors[cell_type] for cell_type in data]

    fig = plt.figure(
        FigureClass=Waffle, 
        rows=7, 
        values=data, 
        title={'label': f'Genotype {group}', 'loc': 'left', 'fontsize': 14},
        labels=[f"{k} ({v}%)" for k, v in zip(temp_df['C_scANVI_S1'], temp_df['proportions'].round(2))],
        legend={'loc': 'lower left', 'bbox_to_anchor': (0, -0.4), 'ncol': len(data), 'framealpha': 0, 'fontsize': 14},
        figsize=(40, 4),
        colors=colors
    )
    plt.show()

### Export annotated sample object 

- Fix label for cell annotation

In [None]:
adata.obs['cell_type'] = adata.obs['C_scANVI_S1'].copy()
adata.obs['cell_type'] = adata.obs['cell_type'].astype('category')
adata.obs['cell_type'].cat.categories

In [None]:
adata

In [None]:
adata_query = adata[adata.obs['compartment'].isin(['MDX_POOL_NEW', 'DMD', 'MDX_MPC'])]
adata_query

In [None]:
pd.crosstab(adata_query.obs['cell_type'], adata_query.obs['genotype'])

### Prepare object for export

In [None]:
adata_query = adata[adata.obs['compartment'].isin(['MDX_POOL_NEW', 'DMD', 'MDX_MPC'])]
adata_query

In [None]:
adata_raw.obs['cell_type'] = adata.obs['cell_type'].copy()
adata_raw.obs['cell_type'].cat.categories

In [None]:
adata_export = adata_raw[adata_raw.obs['compartment'].isin(['MDX_POOL_NEW', 'DMD', 'MDX_MPC'])]
adata_export

In [None]:
pd.crosstab(adata.obs['cell_type'], adata.obs['genotype'])

### Export annotated object with raw counts

In [None]:
adata_query.obs.index = pd.Index(['-'.join(idx.split('-')[:3]) for idx in adata_query.obs.index])
adata_query.obs_names

In [None]:
adata_export

In [None]:
# Combine the matrix and variables of `adata_export_raw` with the annotations of
# `adata_export`.
# NOTE(review): `adata_export` is a subset of `adata_raw`, while `adata_export_raw`
# was concatenated before the label filtering and the reference/query re-ordering,
# so the number and order of cells in `X` and `obs` may differ — confirm before
# relying on this object.
adata_export_full = anndata.AnnData(X = adata_export_raw.X, obs = adata_export.obs, var = adata_export_raw.var)
adata_export_full

In [None]:
adata_query

In [None]:
#adata_export_raw.obs.index = pd.Index(['-'.join(idx.split('-')[:3]) for idx in adata_export_raw.obs.index])
adata_export_raw.obs_names

In [None]:
adata_export_raw

In [None]:
# Cells of the raw three-dataset merge that are also present in the annotated query
# object, kept in the order of `adata_export_raw`.
is_annotated = adata_export_raw.obs_names.isin(adata_query.obs_names)
common_obs_names = adata_export_raw.obs_names[is_annotated]
adata_subset = adata_export_raw[common_obs_names]
adata_subset

In [None]:
adata_subset

In [None]:
# `adata_subset` keeps the cell order of `adata_export_raw`, whereas `adata_query`
# follows the reference-then-query order of `adata`. AnnData pairs `X` and `obs`
# row by row, so the annotations are re-indexed by cell name to make sure every
# count vector is stored with the metadata of the same cell.
adata_export = anndata.AnnData(
    X = adata_subset.X,
    obs = adata_query.obs.loc[adata_subset.obs_names].copy(),
    var = adata_subset.var,
)
adata_export

In [None]:
pd.crosstab(adata_query.obs['cell_type'], adata_query.obs['genotype'])

In [None]:
# Write the annotated object to disk.
# NOTE(review): the input files are read from '../../../data/', this output goes to
# '../../data/' — confirm the directory depth is intended.
adata_export.write('../../data/heart_mm_nuclei-23-0092_scANVI-Myeloid_ctl240402.raw.h5ad')