In [None]:
import os
import pickle
from aavomics import database
import anndata
import pandas
import numpy
import scvi
import scanpy

In [None]:
# Base seed; the training cell derives its per-trainer seeds from it (SEED + 1, SEED + 2).
SEED = 1042

# All inputs of this notebook live under <DATA_PATH>/reference_databases.
_REFERENCE_DATABASES_PATH = os.path.join(database.DATA_PATH, "reference_databases")
_ALLEN_REFERENCE_PATH = os.path.join(_REFERENCE_DATABASES_PATH, "20200331_Allen_Cortex_Hippocampus_10X_v3")

# Count matrix of the reference dataset.
REFERENCE_DATA_FILE_PATH = os.path.join(_ALLEN_REFERENCE_PATH, "barcode_transcript_counts_filtered.h5ad")
# Pickled mapping from reference cell-type aliases to the labels used here.
CELL_TYPE_MAP_FILE_PATH = os.path.join(_REFERENCE_DATABASES_PATH, "neuron_type_map_20210130.pkl")
# NOTE(review): the three paths below are not referenced by any cell visible
# in this notebook section - confirm whether they are still needed.
VAE_MODEL_FILE_PATH = os.path.join(_ALLEN_REFERENCE_PATH, "vae_trained")
TRAINED_DATA_FILE_PATH = os.path.join(_ALLEN_REFERENCE_PATH, "vae_trained.h5ad")
SCANVI_TRAINED_DATA_FILE_PATH = os.path.join(_ALLEN_REFERENCE_PATH, "scanvi_trained.h5ad")

In [None]:
# Existing annotation column used to pick out the cells labelled "Neurons".
TAXONOMY_NAME = "CCN202105041"
# Column this notebook creates: the finer-grained predicted label (its
# confidence is stored next to it as "p_<NEW_TAXONOMY_NAME>").
NEW_TAXONOMY_NAME = "CCN202105050"

# Alignment whose counts are fed to the model. The original note says that
# None selects the first available alignment.
ALIGNMENT_NAME = "cellranger_5.0.1_allen_premRNA"
# Alignment whose anndata files carry TAXONOMY_NAME and receive the results.
LABELED_DATA_ALIGNMENT_NAME = "cellranger_5.0.1_gex_mm10_2020_A"

In [None]:
# Load the mapping from reference cell-type alias to the custom label that
# this notebook assigns.
# NOTE(security): pickle.load can execute arbitrary code while unpickling;
# only point CELL_TYPE_MAP_FILE_PATH at a file from a trusted source.
with open(CELL_TYPE_MAP_FILE_PATH, "rb") as pickle_file:
    cell_type_map = pickle.load(pickle_file)

# "" is the category used for cells without a label. The remaining categories
# are sorted so that their order (and therefore the integer code of every
# label) is the same on every kernel start; iterating a bare set of strings
# is not. "" is removed from the mapped labels first so it cannot appear
# twice, which CategoricalDtype rejects.
# Assumes all mapped labels are mutually comparable (e.g. all strings).
cell_type_labels = sorted(set(cell_type_map.values()) - {""})
cell_type_categorical_type = pandas.CategoricalDtype(categories=[""] + cell_type_labels)

In [None]:
# Load the reference dataset (read_h5ad is the reader the other cells of this
# notebook use as well).
reference_data = anndata.read_h5ad(REFERENCE_DATA_FILE_PATH)

# Start with every reference cell unlabelled ("").
reference_data.obs[NEW_TAXONOMY_NAME] = pandas.Categorical(
    [""] * reference_data.n_obs,
    dtype=cell_type_categorical_type
)

# Translate each reference alias into its custom label. The assignment goes
# through a single .loc call on obs; writing via obs[column][mask] is chained
# indexing and may modify a temporary copy instead of the dataframe.
for cell_type, cell_type_label in cell_type_map.items():

    cell_type_mask = reference_data.obs["cell_type_alias_label"] == cell_type

    reference_data.obs.loc[cell_type_mask, NEW_TAXONOMY_NAME] = cell_type_label

In [None]:
# Collect, for every cell set that has been annotated with TAXONOMY_NAME, the
# counts (from ALIGNMENT_NAME) of the cells annotated as neurons.
adatas = []
cell_set_names = []

for cell_set in database.CELL_SETS:

    labelled_anndata_file_path = cell_set.get_anndata_file_path(alignment_name=LABELED_DATA_ALIGNMENT_NAME)

    if not os.path.exists(labelled_anndata_file_path):
        print("Skipping %s, doesn't have anndata file" % cell_set.name)
        continue

    # Only the obs annotations of the labelled file are needed, so it is opened
    # in backed mode. The handle is always closed again: the last cell of this
    # notebook rewrites this very file, which must not still be open here.
    labelled_adata = anndata.read_h5ad(labelled_anndata_file_path, backed="r")

    try:
        labelled_obs = labelled_adata.obs

        if TAXONOMY_NAME not in labelled_obs.columns:
            print("Skipping %s, not annotated with %s" % (cell_set.name, TAXONOMY_NAME))
            continue

        # "Cell Called" is compared against the string "True", as in the stored
        # annotation. NOTE(review): confirm the column really holds strings.
        neuron_mask = (labelled_obs[TAXONOMY_NAME] == "Neurons") & (labelled_obs["Cell Called"] == "True")
        neuron_barcodes = labelled_obs.index[neuron_mask.values]
    finally:
        labelled_adata.file.close()

    anndata_file_path = cell_set.get_anndata_file_path(alignment_name=ALIGNMENT_NAME)

    adata = anndata.read_h5ad(anndata_file_path)
    adata = adata[neuron_barcodes].copy()

    # Every cell of our own samples starts unlabelled (""). The column is built
    # in one assignment rather than through chained indexing, which may write
    # to a temporary copy.
    adata.obs[NEW_TAXONOMY_NAME] = pandas.Categorical(
        [""] * adata.n_obs,
        dtype=cell_type_categorical_type
    )
    adata.obs["sample_id"] = cell_set.name

    adatas.append(adata)
    cell_set_names.append(cell_set.name)

In [None]:
# Reference observation names are split on "-" and the second token is taken
# as the id of the sample a cell came from.
reference_sample_ids = sorted(set(x.split("-")[1] for x in reference_data.obs.index))

sample_ids = set(reference_sample_ids)
sample_ids.update(cell_set_names)

# Sorted so that the category order (and the batch code of every sample) does
# not depend on set iteration order.
sample_ids_categorical_type = pandas.CategoricalDtype(categories=sorted(sample_ids))

# Start with no sample assigned, then fill in the reference samples.
reference_data.obs["sample_id"] = pandas.Categorical(
    [None] * reference_data.n_obs,
    dtype=sample_ids_categorical_type
)

for sample_id in reference_sample_ids:
    # Match the whole trailing token ("-<id>"): a bare endswith(id) would also
    # match a longer id that merely ends in the same characters. The write
    # goes through one .loc call instead of chained indexing.
    sample_mask = reference_data.obs.index.str.endswith("-" + sample_id)
    reference_data.obs.loc[sample_mask, "sample_id"] = sample_id

In [None]:
# Mark the genes that have a non-zero maximum in the reference dataset; only
# those are kept later on.
reference_gene_maxima = reference_data.X.max(axis=0)
gene_maxes = numpy.asarray(reference_gene_maxima.todense()).ravel()
gene_mask = gene_maxes > 0

In [None]:
# Check that the genes of the reference dataset are in the same order as the
# genes of our own data before the reference adopts our var table.
if not adatas:
    raise ValueError("No cell sets were loaded, nothing to compare the reference genes against! Abort!")

query_var = adatas[0].var
reference_gene_names = reference_data.var.index

# Without this guard a longer query table fails with a bare IndexError and a
# shorter one is only partially checked.
if len(query_var) != len(reference_gene_names):
    raise ValueError(
        "Gene count doesn't match: %i in reference vs %i in %s! Abort!"
        % (len(reference_gene_names), len(query_var), cell_set_names[0])
    )

for row_index, (ensembl_id, gene_name) in enumerate(zip(query_var.index, query_var["Gene Name"])):

    reference_gene_name = reference_gene_names[row_index]

    # The reference names a gene either "<name>" or "<name> <ensembl id>".
    if reference_gene_name != gene_name and reference_gene_name != "%s %s" % (gene_name, ensembl_id):
        raise ValueError("Gene name doesn't match for %s at index %i! Abort!" % (gene_name, row_index))

reference_data.var = query_var

In [None]:
# Stack the reference cells and all loaded cell sets into one object. From
# here on `reference_data` refers to the stacked data, not just the reference.
# NOTE(review): the write-back cell at the end strips the last "-"-separated
# token from the observation names, presumably a per-batch suffix appended by
# concatenate - confirm before replacing this call with another API.
reference_data = reference_data.concatenate(adatas)
del adatas
# Keep only the genes selected by gene_mask and materialise the result with
# .copy() so the full stacked object can be dropped right after.
# Assumes the stacked object still has the reference's genes in the same
# order, since gene_mask was computed on the reference alone - TODO confirm.
combined_data = reference_data[:, gene_mask].copy()
del reference_data
# `adata` is the leftover loop variable of the cell-set loading loop.
del adata

In [None]:
# True for cells of the reference dataset, False for cells of our own samples.
reference_data_mask = ~combined_data.obs["sample_id"].isin(cell_set_names).values.flatten()

# Cells of our own samples must be unlabelled (""). The write uses one .loc
# call on obs; obs[column].loc[mask] = ... is chained indexing and may modify
# a temporary copy instead of the dataframe.
combined_data.obs.loc[~reference_data_mask, NEW_TAXONOMY_NAME] = ""
combined_data.obs[NEW_TAXONOMY_NAME] = combined_data.obs[NEW_TAXONOMY_NAME].astype(cell_type_categorical_type)

In [None]:
# Register the combined data with scvi: samples are the batches, the new
# taxonomy column holds the labels.
scvi.data.setup_anndata(
    combined_data,
    batch_key="sample_id",
    labels_key=NEW_TAXONOMY_NAME
)

In [None]:
# KL warm-up length in iterations, shared by both training phases. Value taken
# from the documentation at
# https://www.scvi-tools.org/en/stable/api/reference/scvi.core.trainers.UnsupervisedTrainer.html
kl_warmup_iterations = 128*5000/400

# "" marks the cells whose label is to be predicted.
scanvi = scvi.model.SCANVI(combined_data, unlabeled_category="", n_latent=20, n_layers=2, n_hidden=256)

# Each training phase gets its own seed derived from SEED.
unsupervised_settings = {"seed": SEED + 1}
semisupervised_settings = {
    "seed": SEED + 2,
    "n_iter_kl_warmup": kl_warmup_iterations,
    "n_epochs_kl_warmup": None
}

results = scanvi.train(
    unsupervised_trainer_kwargs=unsupervised_settings,
    semisupervised_trainer_kwargs=semisupervised_settings,
    balanced_sampling=True,
    frequency=1,
    n_epochs_kl_warmup=None,
    n_iter_kl_warmup=kl_warmup_iterations,
)

In [None]:
# Element 1 of the classifier trainer's train/test/validation result is used
# as the test split; its indices select the held-out cells below.
classifier_trainer = scanvi.trainer.classifier_trainer
test_dataset = classifier_trainer.train_test_validation()[1]
test_indices = test_dataset.indices

In [None]:
# Predict a label for every cell of the combined data.
predicted_labels = scanvi.predict(combined_data)
# Second pass with soft=True; presumably returns one score per category for
# each cell - TODO confirm against the scvi version in use.
prediction_scores = scanvi.predict(combined_data, soft=True)
# Highest score along axis 1, stored later as the confidence of the prediction.
prediction_scores_max = prediction_scores.max(axis=1)

In [None]:
# Accuracy on the held-out cells that come from the reference dataset, i.e.
# the only test cells with a known label.
test_is_reference = reference_data_mask[test_indices]
test_predictions = predicted_labels[test_indices][test_is_reference]
test_truth = combined_data[test_indices].obs[NEW_TAXONOMY_NAME][test_is_reference]
n_correct = (test_predictions == test_truth).sum()

print("Accuracy: %.2f%%" % (100*n_correct/test_is_reference.sum()))

In [None]:
# Replace the label column with the predictions and keep the best score per
# cell next to it. This overwrites the labels used by the accuracy cell above.
probability_column = "p_%s" % NEW_TAXONOMY_NAME

combined_data.obs[NEW_TAXONOMY_NAME] = predicted_labels
combined_data.obs[probability_column] = prediction_scores_max

In [None]:
# Write the predicted label and its score back into the labelled anndata file
# of every cell set that took part in the prediction.
probability_column = "p_%s" % NEW_TAXONOMY_NAME

for cell_set_name in cell_set_names:

    print(cell_set_name)

    cell_set = database.CELL_SETS_DICT[cell_set_name]

    anndata_file_path = cell_set.get_anndata_file_path(alignment_name=LABELED_DATA_ALIGNMENT_NAME)

    if not os.path.exists(anndata_file_path):
        print("Missing %s, skipping" % cell_set.name)
        continue

    cell_set_adata = anndata.read_h5ad(anndata_file_path)

    # Drop results of a previous run so stale values cannot survive.
    cell_set_adata.obs.drop(columns=[NEW_TAXONOMY_NAME, probability_column], errors="ignore", inplace=True)

    # Work on a plain dataframe copy of the two result columns instead of
    # changing the index of an AnnData view's obs.
    cell_set_mask = (combined_data.obs["sample_id"] == cell_set_name).values
    cell_set_predictions = combined_data.obs.loc[cell_set_mask, [NEW_TAXONOMY_NAME, probability_column]].copy()

    # Remove the last "-"-separated token of each observation name so the
    # names line up with those stored in the cell set's own file.
    # NOTE(review): presumably the suffix added when the data was stacked.
    cell_set_predictions.index = ["-".join(x.split("-")[0:-1]) for x in cell_set_predictions.index]

    # Assignment aligns on the index; cells without a prediction get NaN.
    cell_set_adata.obs[NEW_TAXONOMY_NAME] = cell_set_predictions[NEW_TAXONOMY_NAME]
    cell_set_adata.obs[probability_column] = cell_set_predictions[probability_column]

    cell_set_adata.write_h5ad(anndata_file_path)