# Annotate w1118_42d cells with cell-type labels

In [None]:
# Record the Python interpreter version used for this run (IPython shell escape).
!python --version

## Load required packages 

In [None]:
import numpy as np
import pandas as pd
import anndata as ad
import scanpy as sc
import torch
import scvi
import scarches as sca
from scarches.dataset.trvae.data_handling import remove_sparsity
from scipy.sparse import csr_matrix
from scipy.stats import median_abs_deviation
import seaborn as sns
from pathlib import Path
import matplotlib.pyplot as plt
import gdown
import copy as cp
import os


In [None]:
# Report which GPUs this process may use (CUDA_VISIBLE_DEVICES, presumably set by the job scheduler).
cuda_visible_devices = os.getenv("CUDA_VISIBLE_DEVICES")
print(cuda_visible_devices)
# os.environ only accepts str values: re-exporting an unset variable (None) raised
# TypeError, so only write it back when it is actually set.
if cuda_visible_devices is not None:
    os.environ["CUDA_VISIBLE_DEVICES"] = cuda_visible_devices

In [None]:
# GPU diagnostics. torch.cuda.current_device() and get_device_name() raise when no
# CUDA device is visible, so device details are only queried when CUDA is available.
# (The former print(torch.cuda.device(0)) printed a context-manager object, not device info.)
cuda_available = torch.cuda.is_available()
print(cuda_available)
print(torch.cuda.device_count())
if cuda_available:
    current_device = torch.cuda.current_device()
    print(current_device)
    print(torch.cuda.get_device_name(current_device))

In [None]:
# Plot styling. sns.set() is an alias of sns.set_theme(), which re-applies its default
# style ("darkgrid") and so silently discarded an earlier sns.set_style('white');
# set the style and font scale in a single call instead.
# NOTE(review): scanpy's set_figure_params may itself reset font-size rcParams;
# confirm the seaborn font scale survives if it matters.
sns.set_theme(style='white', font_scale=1.5)
sc.settings.set_figure_params(dpi=80, facecolor="white")
sc.logging.print_header()
sc.settings.verbosity = 3

## Read w1118_42d query dataset

In [None]:

## Input/output locations, all rooted at the project directory
basepath = Path("/projectnb/mccall/sbandyadka/drpr42d_snrnaseq/")
inputpath = basepath / 'analysis' / 'preprocess'   # QC'd query data (read below)
referencepath = basepath / 'reference' / 'FCA'     # FCA reference data
outputpath = basepath / 'analysis' / 'scarches'    # scArches outputs


In [None]:
## Load the sctk-qc output and store the loaded object as .raw before any subsetting
query_h5ad = inputpath / "w1118_42d_slim.h5ad"
w1118_42d = sc.read_h5ad(query_h5ad)
w1118_42d.raw = w1118_42d

w1118_42d

In [None]:
# Sample-level metadata, broadcast to every cell of the query object.
sample_metadata = {
    'age': 42,
    'fly_genetics': "w1118",
    'sex': "mixed",
    'dissection_lab': "mccall",
    'tissue': "head",
}
for obs_column, value in sample_metadata.items():
    w1118_42d.obs[obs_column] = value


In [None]:
## Pre-trained scArches reference models, one directory per tissue.
## Paths are derived from `outputpath` so the project root is defined in one place
## (they resolve to the same directories as before). They stay strings ending in a
## separator so any `model_dir + filename` code keeps working.
head_model = str(outputpath.joinpath('head_model')) + os.sep
reference_latent_head = sc.read_h5ad(Path(head_model, 'reference_latent.h5ad'))

antenna_model = str(outputpath.joinpath('antenna_model')) + os.sep
reference_latent_antenna = sc.read_h5ad(Path(antenna_model, 'reference_latent.h5ad'))

In [None]:
def _label_batch_keys(obs, label_key, batch_key='batch', sep='-'):
    """Return one "<label><sep><batch>" string per row of ``obs``.

    Both columns are cast to ``str`` before joining.
    """
    labels = np.array(obs[label_key], dtype=str)
    batches = np.array(obs[batch_key], dtype=str)
    return np.char.add(np.char.add(labels, sep), batches)


## Combine each annotation level with the cell's batch. 'broad_annotation_batch' is
## the cell-type key passed to scHPL's learn_tree below.
for reference_latent_obj in (reference_latent_head, reference_latent_antenna):
    reference_latent_obj.obs['broad_annotation_batch'] = _label_batch_keys(reference_latent_obj.obs, 'broad_annotation')
    reference_latent_obj.obs['annotation_batch'] = _label_batch_keys(reference_latent_obj.obs, 'annotation')

In [None]:
# Learn a scHPL hierarchy over the head reference latent space, using the
# "<broad_annotation>-<batch>" labels as cell types. Batches are added in the order
# they first appear in obs.
head_batchorder = reference_latent_head.obs['batch'].unique().tolist()
head_tree_ref_broadannotation, head_mp_ref_broadannotation = sca.classifiers.scHPL.learn_tree(
    data=reference_latent_head,
    batch_key='batch',
    batch_order=head_batchorder,
    cell_type_key='broad_annotation_batch',
    classifier='knn',
    dynamic_neighbors=True,
    dimred=False,
    print_conf=False,
)

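# Archive

The cells below are kept for reference only. They use `fca_reference` and `reference_latent`, which are not defined earlier in this notebook, so they will not run as-is.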

In [None]:
# Restrict the reference to genes also measured in the query. Genes keep the
# reference's own order: list(set(...)) ordering depends on string hashing, which is
# randomized per Python session, so the resulting gene order was not reproducible.
query_genes = set(w1118_42d.var_names)
common_genes = [gene for gene in fca_reference.var_names if gene in query_genes]
print(len(common_genes))  # a bare expression mid-cell is never displayed

fca_reference = fca_reference[:, common_genes].copy()
fca_reference

In [None]:
# Join each annotation level with the cell's batch ("<label>-<batch>"), so the same
# label coming from different batches becomes a distinct string.
reference_batches = np.array(reference_latent.obs['batch'], dtype=str)
for source_column, target_column in (('broad_annotation', 'broad_annotation_batch'),
                                     ('annotation', 'annotation_batch')):
    source_labels = np.array(reference_latent.obs[source_column], dtype=str)
    reference_latent.obs[target_column] = np.char.add(np.char.add(source_labels, '-'), reference_batches)

In [None]:
# scHPL tree over broad annotations (split per batch), built from the reference latent space.
batchorder = reference_latent.obs['batch'].unique().tolist()
tree_ref_broadannotation, mp_ref_broadannotation = sca.classifiers.scHPL.learn_tree(
    data=reference_latent,
    batch_key='batch',
    batch_order=batchorder,
    cell_type_key='broad_annotation_batch',
    classifier='knn',
    dynamic_neighbors=True,
    dimred=False,
    print_conf=False,
)

In [None]:
# scHPL tree over the full (fine-grained) annotations, same settings as the broad tree.
fullannotation_tree_options = {
    'data': reference_latent,
    'batch_key': 'batch',
    'batch_order': batchorder,
    'cell_type_key': 'annotation_batch',
    'classifier': 'knn',
    'dynamic_neighbors': True,
    'dimred': False,
    'print_conf': False,
}
tree_ref_fullannotation, mp_ref_fullannotation = sca.classifiers.scHPL.learn_tree(**fullannotation_tree_options)

In [None]:
# Subset the query to exactly the reference genes, in reference order (presumably
# what load_query_data below expects). Report the overlap deterministically, and fail
# with a readable message if any reference gene is absent from the query instead of
# an opaque indexing KeyError.
query_genes = set(w1118_42d.var_names)
commongenes = [gene for gene in fca_reference.var_names if gene in query_genes]
print(len(commongenes))
missing_genes = [gene for gene in fca_reference.var_names if gene not in query_genes]
if missing_genes:
    raise KeyError(
        f"{len(missing_genes)} reference genes are missing from the query, e.g. {missing_genes[:5]}"
    )
w1118_42d = w1118_42d[:, fca_reference.var_names]
w1118_42d

In [None]:
# Turn the gene-subset view from the previous cell into an independent AnnData, so
# later in-place edits (new obs columns, model setup) act on a real object, not a view.
w1118_42d = w1118_42d.copy()

In [None]:
# All query cells belong to a single batch.
QUERY_BATCH = "McCall_w1118_42d_batch1"
w1118_42d.obs['batch'] = QUERY_BATCH

In [None]:
# Show the first five rows of the query cell metadata.
w1118_42d.obs.head(5)

In [None]:

# Directory of the trained reference model the query is mapped onto. It is derived from
# `outputpath` (the same directory as the former hard-coded string, with a trailing
# separator) so the project root is defined in one place.
ref_path = str(outputpath) + os.sep
model = sca.models.SCVI.load_query_data(
    w1118_42d,
    ref_path,
    freeze_dropout=True,
)

In [None]:
# Train the query model for a fixed number of epochs.
QUERY_MAX_EPOCHS = 50
model.train(max_epochs=QUERY_MAX_EPOCHS)

In [None]:
# Latent representation of the cells the query model was set up with (w1118_42d).
# Keep the query's cell identifiers as obs_names so the saved latent can be joined
# back to w1118_42d; AnnData would otherwise number the rows 0..n-1.
query_latent = sc.AnnData(model.get_latent_representation())
query_latent.obs_names = w1118_42d.obs_names
query_latent.obs['batch'] = w1118_42d.obs["batch"].tolist()

In [None]:
# Persist the query model and its latent embedding under outputpath.
# `outputpath` is also the directory the reference model is loaded from (`ref_path`),
# so saving the query model there with overwrite=True replaced the reference model.
# Save it to its own subdirectory instead. The latent is written next to it rather
# than into the current working directory.
query_model_dir = outputpath.joinpath('w1118_42d_query_model')
model.save(query_model_dir, overwrite=True)
query_latent.write(outputpath.joinpath('query_latent.h5ad'))

In [None]:
# Stack reference and query cells into one AnnData; the 'ref_query' obs column records
# which input each cell came from.
# NOTE(review): AnnData.concatenate is deprecated in recent anndata releases in favour of
# anndata.concat; the defaults (obs-name suffixing, merging) differ, so verify before migrating.
adata_full = fca_reference.concatenate(w1118_42d, batch_key="ref_query")
adata_full

In [None]:
# Embed reference + query cells with the query model and carry over their metadata.
full_latent = sc.AnnData(model.get_latent_representation(adata=adata_full))
for obs_column in ("annotation", "broad_annotation", "broad_annotation_extrapolated", "batch"):
    full_latent.obs[obs_column] = adata_full.obs[obs_column].tolist()

In [None]:
# Neighbour graph, then Leiden clustering and UMAP; the order matters because the
# last two read the neighbour graph.
for step in (sc.pp.neighbors, sc.tl.leiden, sc.tl.umap):
    step(full_latent)

In [None]:
# Display a summary of the combined latent AnnData (obs columns, obsm/obsp entries).
full_latent

In [None]:
# UMAP of the combined reference + query latent space, coloured by broad annotation.
sc.pl.umap(
    full_latent,
    color=['broad_annotation'],
    frameon=False,
    wspace=0.6,
    s=25,
)

In [None]:
# Predict labels for the query cells from their latent coordinates with both scHPL trees.
query_pred_fullannotation = sca.classifiers.scHPL.predict_labels(query_latent.X, tree=tree_ref_fullannotation)
query_pred_broadannotation = sca.classifiers.scHPL.predict_labels(query_latent.X, tree=tree_ref_broadannotation)
# Backward-compatible alias for the previously misspelled name used by later cells.
query_pred_braodannotation = query_pred_broadannotation

In [None]:
# Copy element [0] of each scHPL prediction result into the query's obs.
# NOTE(review): presumably predict_labels returns a tuple whose first item is the
# per-cell label array; confirm, since if it returns the array itself, [0] is a
# single label broadcast to every cell.
prediction_columns = {
    'predicted_fullannotation': query_pred_fullannotation,
    'predicted_broadannotation': query_pred_braodannotation,
}
for obs_column, prediction in prediction_columns.items():
    w1118_42d.obs[obs_column] = prediction[0]

In [None]:
# Build the query kNN graph from the scArches latent embedding (as done for full_latent)
# rather than from w1118_42d.X. With no representation given, sc.pp.neighbors runs a
# default PCA on .X, which presumably holds the unnormalized counts the scVI model
# consumes. query_latent rows were produced from this object, so they align with its cells.
w1118_42d.obsm['X_scarches'] = query_latent.X
sc.pp.neighbors(w1118_42d, use_rep='X_scarches')
sc.tl.leiden(w1118_42d)
sc.tl.umap(w1118_42d)


In [None]:
# Query UMAP coloured by the broad labels predicted by the scHPL tree.
sc.pl.umap(w1118_42d,color=['predicted_broadannotation'])

In [None]:
# Query UMAP coloured by the fine-grained labels predicted by the scHPL tree.
sc.pl.umap(w1118_42d,color=['predicted_fullannotation'])

In [None]:
# Marker genes grouped by the cell class they mark (as annotated in the original list),
# flattened in group order for the dot plots below.
_marker_groups = {
    "neurons": ["elav", "lncRNA:noe", "VAChT", "VGlut", "Gad1", "Vmat", "SerT", "Tdc2", "ple"],
    "glia": ["repo", "lncRNA:CR34335", "alrm", "wrapper", "Indy", "moody"],
    "photoreceptors": ["ninaC", "trp", "trpl"],
    "hemocytes": ["Hml"],
    "fat body": ["ppl"],
    "drpr": ["drpr"],
}
markers = [gene for group in _marker_groups.values() for gene in group]


In [None]:
# Marker expression per predicted broad label.
# NOTE(review): with use_raw left at its default, scanpy may read expression from
# w1118_42d.raw (set when the data were loaded); confirm that is intended.
sc.pl.dotplot(w1118_42d, markers, 'predicted_broadannotation')

In [None]:
# Marker expression per predicted fine-grained label.
# NOTE(review): with use_raw left at its default, scanpy may read expression from
# w1118_42d.raw (set when the data were loaded); confirm that is intended.
sc.pl.dotplot(w1118_42d, markers, 'predicted_fullannotation')