In [1]:
import os
import warnings
warnings.filterwarnings('ignore')
import scanpy as sc
import scvi

In [3]:
# Move to the project root, two levels above the notebook's directory, so the
# relative paths used below (e.g. 'output/...') resolve from there.
# The starting directory is recorded only on the first run, which makes this
# cell idempotent: a bare relative chdir('../../') would climb two more levels
# every time the cell is re-executed in the same kernel.
if "NOTEBOOK_DIR" not in globals():
    NOTEBOOK_DIR = os.getcwd()
os.chdir(os.path.join(NOTEBOOK_DIR, os.pardir, os.pardir))

# 1. Check samples

In [5]:
# Load the filtered extended GBmap object (path is relative to the current
# working directory) and end the cell with `adata` to show its summary.
h5ad_path = 'output/extended_gbmap_filtered.h5ad'
adata = sc.read_h5ad(h5ad_path)
adata

AnnData object with n_obs × n_vars = 819910 × 26302
    obs: 'author', 'donor_id', 'assay_ontology_term_id', 'cell_type_ontology_term_id', 'development_stage_ontology_term_id', 'disease_ontology_term_id', 'self_reported_ethnicity_ontology_term_id', 'is_primary_data', 'organism_ontology_term_id', 'sex_ontology_term_id', 'annotation_level_1', 'annotation_level_2', 'annotation_level_3', 'gbmap', 'method', 'stage', 'location', 'sector', 'celltype_original', 'EGFR', 'MET', 'p53', 'TERT', 'ATRX', 'PTEN', 'MGMT', 'chr1p19q', 'PDGFR', 'suspension_type', 'tissue_ontology_term_id', 'tissue_type', 'cell_type', 'assay', 'disease', 'organism', 'sex', 'tissue', 'self_reported_ethnicity', 'development_stage', 'observation_joinid'
    var: 'feature_types', 'genome', 'highly_variable', 'means', 'dispersions', 'dispersions_norm', 'highly_variable_nbatches', 'highly_variable_intersection', 'feature_is_filtered', 'feature_name', 'feature_reference', 'feature_biotype', 'feature_length', 'feature_type'
    

In [7]:
# Number of cells per level-3 annotation.
adata.obs["annotation_level_3"].value_counts()

annotation_level_3
TAM-BDM            179319
TAM-MG             154839
AC-like            121680
MES-like            94457
OPC-like            93483
NPC-like            37271
CD4/CD8             35399
Oligodendrocyte     30699
Mono                21719
OPC                 21447
Mural cell           5281
Endothelial          4980
Neutrophil           4778
DC                   4537
RG                   3753
Neuron               3413
B cell                857
Plasma B              694
NK                    499
Mast                  459
Astrocyte             346
Name: count, dtype: int64

In [8]:
# Number of cells per `cell_type` label (the column later used as scANVI labels).
adata.obs["cell_type"].value_counts()

cell_type
malignant cell                    346891
macrophage                        179319
microglial cell                   154839
mature T cell                      35399
oligodendrocyte                    30699
monocyte                           21719
oligodendrocyte precursor cell     21447
mural cell                          5281
endothelial cell                    4980
neutrophil                          4778
dendritic cell                      4537
radial glial cell                   3753
neuron                              3413
B cell                               857
plasma cell                          694
natural killer cell                  499
mast cell                            459
astrocyte                            346
Name: count, dtype: int64

# 2. Integrate

In [None]:
# Preprocess `adata` (the object loaded above) for integration. This cell
# previously referenced `source_adata`, a name that is never defined in this
# notebook, so it raised NameError on a fresh kernel.

# Fail fast before `adata.X` is normalized in place: the HVG selection below
# (and the scANVI setup that follows) read from `adata.layers["counts"]`.
if "counts" not in adata.layers:
    raise KeyError(
        "adata.layers['counts'] is missing; store the raw counts there "
        "before normalizing."
    )

sc.pp.normalize_total(adata)
sc.pp.log1p(adata)
adata.raw = adata  # keep full dimension safe

# Keep the 5000 most variable genes, selected per `author` batch.
# NOTE(review): the default flavor ('seurat') expects log-normalized input,
# while `layer="counts"` presumably holds raw counts; consider
# flavor="seurat_v3" -- TODO confirm.
sc.pp.highly_variable_genes(
    adata,
    n_top_genes=5000,
    batch_key="author",
    layer="counts",
    subset=True,
)

In [None]:
# Register the fields scANVI reads from `adata`: counts layer, batch column
# and label column.
# With the classmethod-based setup API the unlabeled category is declared here
# rather than in the model constructor. "Unknown" does not occur among the
# `cell_type` values counted above, so every cell is treated as labelled.
scvi.model.SCANVI.setup_anndata(
    adata,
    layer="counts",
    batch_key="author",
    labels_key="cell_type",
    unlabeled_category="Unknown",
)

In [None]:
def _early_stopping_config(metric, threshold, on=None):
    """Build one early-stopping / learning-rate-reduction settings dict.

    Parameters
    ----------
    metric
        Stored under both "early_stopping_metric" and "save_best_state_metric".
    threshold
        Stored under "threshold".
    on
        Stored under "on" when given; the key is omitted when None.

    Returns
    -------
    dict
        Settings with patience 10, LR reduction on plateau enabled,
        LR patience 8 and LR factor 0.1.
    """
    config = {
        "early_stopping_metric": metric,
        "save_best_state_metric": metric,
    }
    if on is not None:
        config["on"] = on
    config.update(
        patience=10,
        threshold=threshold,
        reduce_lr_on_plateau=True,
        lr_patience=8,
        lr_factor=0.1,
    )
    return config


# ELBO-based settings (no "on" key, zero threshold).
early_stopping_kwargs = _early_stopping_config("elbo", threshold=0)
# Accuracy-based settings evaluated on the full dataset.
early_stopping_kwargs_scanvi = _early_stopping_config(
    "accuracy", threshold=0.001, on="full_dataset"
)
# ELBO-based settings evaluated on the full dataset.
early_stopping_kwargs_surgery = _early_stopping_config(
    "elbo", threshold=0.001, on="full_dataset"
)

In [None]:
# Build the scANVI model on `adata`.
# Fixes relative to the previous cell content:
#   * the module is `scvi.model` (`scvi.models` does not exist in scvi-tools);
#   * the unlabeled category is no longer a constructor argument -- in the
#     classmethod-based API it must be passed to `SCANVI.setup_anndata`;
#   * `use_cuda` is not a constructor argument; the device is chosen when
#     `train()` is called, which by default uses a GPU if one is available.
# The architecture options below are forwarded to the underlying module.
model = scvi.model.SCANVI(
    adata,
    n_layers=2,
    encode_covariates=True,
    deeply_inject_covariates=False,
    use_layer_norm="both",
    use_batch_norm="none",
)

In [None]:
# Train scANVI with the current scvi-tools training API. The legacy arguments
# (`n_epochs_unsupervised`, `n_epochs_semisupervised`, `*_trainer_kwargs`,
# `frequency`) are not accepted by `scvi.model.SCANVI.train`.
#
# * `max_epochs=200` keeps the former semi-supervised epoch budget.
#   NOTE(review): the former 500-epoch unsupervised phase has no equivalent
#   when SCANVI is trained from scratch; to restore it, pretrain
#   `scvi.model.SCVI` and build the model with `SCANVI.from_scvi_model`.
# * `check_val_every_n_epoch=1` replaces `frequency=1`.
# * Early stopping monitors the validation ELBO using the ELBO settings from
#   `early_stopping_kwargs`. NOTE(review): accuracy-based stopping depends on
#   the metric names logged by the installed scvi-tools version -- confirm
#   before switching the monitor.
# * Learning-rate reduction on plateau is a training-plan option.
model.train(
    max_epochs=200,
    check_val_every_n_epoch=1,
    early_stopping=True,
    early_stopping_monitor="elbo_validation",
    early_stopping_patience=early_stopping_kwargs["patience"],
    early_stopping_min_delta=early_stopping_kwargs["threshold"],
    plan_kwargs={
        "reduce_lr_on_plateau": early_stopping_kwargs["reduce_lr_on_plateau"],
        "lr_patience": early_stopping_kwargs["lr_patience"],
        "lr_factor": early_stopping_kwargs["lr_factor"],
    },
)

In [None]:

# Store the model outputs on `adata`: the latent embedding in `.obsm` and the
# predicted label per cell in `.obs`.
# The trained object in this notebook is `model`; the previous reference to
# `vae` pointed at a name that is never defined here.
adata.obsm["X_scVI"] = model.get_latent_representation()
adata.obs["pred_label"] = model.predict()

# 3. Reannotate