## Train head and antenna combined model

In [None]:
import numpy as np
import pandas as pd
import anndata as ad
import scanpy as sc
import torch
import scvi
import scarches as sca
from scarches.dataset.trvae.data_handling import remove_sparsity
from scipy.sparse import csr_matrix
from scipy.stats import median_abs_deviation
import seaborn as sns
from pathlib import Path
import matplotlib.pyplot as plt
import gdown
import copy as cp
import os
import plotnine as p

# Global plotting / logging configuration for the notebook.
# Style and font scale are passed in a single call: sns.set() resets the
# seaborn style to its own default ("darkgrid") whenever `style` is omitted,
# which would silently undo a preceding sns.set_style('white').
sns.set(style="white", font_scale=1.5)
sc.settings.set_figure_params(dpi=80, facecolor="white")
# Print versions of scanpy and its main dependencies for reproducibility.
sc.logging.print_header()
sc.settings.verbosity = 3


# Report which GPUs (if any) are exposed to this process.
# The environment variable is only read: writing os.getenv(...) back into
# os.environ is a no-op when it is set and raises TypeError when it is unset
# (os.environ values must be str, not None).
cuda_visible_devices = os.getenv("CUDA_VISIBLE_DEVICES")
print(cuda_visible_devices)

cuda_available = torch.cuda.is_available()
print(cuda_available)
print(torch.cuda.device_count())
# The per-device queries below raise when PyTorch cannot see a CUDA device,
# so only run them when at least one device is present.
if cuda_available and torch.cuda.device_count() > 0:
    print(torch.cuda.current_device())
    print(torch.cuda.device(0))
    print(torch.cuda.get_device_name(0))
else:
    print("No CUDA device visible to PyTorch.")


In [None]:

## Project directory layout: reference atlases in, scArches results out
basepath = Path("/projectnb/mccall/sbandyadka/drpr42d_snrnaseq/")
referencepath = basepath / "reference" / "FCA"
outputpath = basepath / "analysis" / "scarches"

In [None]:
## Load the FCA head and antenna reference h5ad files
head_h5ad = referencepath / "v2_fca_biohub_head_10x_raw.h5ad"
antenna_h5ad = referencepath / "v2_fca_biohub_antenna_10x_raw.h5ad"
fca_reference_head = sc.read_h5ad(head_h5ad)
fca_reference_antenna = sc.read_h5ad(antenna_h5ad)

# Apply the same bookkeeping to both references:
#   - keep a copy of the loaded matrix in layers["counts"]
#   - label every cell with its tissue of origin
#   - build a combined "<batch>_<tissue>" key so batches from the two
#     tissues can be told apart
for reference, tissue_label in (
    (fca_reference_head, "fca_head"),
    (fca_reference_antenna, "fca_antenna"),
):
    reference.layers["counts"] = reference.X.copy()
    reference.obs["tissue"] = tissue_label
    reference.obs["tissue_batch"] = (
        reference.obs["batch"].astype(str)
        + "_"
        + reference.obs["tissue"].astype(str)
    )

In [None]:
# Sanity checks before merging: gene identifiers must be unique in each
# reference; then show (n_cells, n_genes) for head and antenna.
for reference in (fca_reference_head, fca_reference_antenna):
    print(reference.var.index.is_unique)
print(fca_reference_head.shape, fca_reference_antenna.shape)

In [None]:
# Stack the head and antenna cells into a single AnnData object.
# NOTE(review): ad.concat is called with its default join; an outer-join
# variant via AnnData.concatenate(..., join='outer', batch_key="concatbatch")
# was tried earlier and left disabled.
reference_parts = [fca_reference_head, fca_reference_antenna]
fca_reference_combined = ad.concat(reference_parts)
# Keep a copy of the merged matrix in layers["counts"] before .X is
# transformed further down.
fca_reference_combined.layers["counts"] = fca_reference_combined.X.copy()
fca_reference_combined



In [None]:
# Gene identifiers should still be unique after the merge; then display how
# many cells each tissue contributed.
genes_unique = fca_reference_combined.var.index.is_unique
print(genes_unique)
fca_reference_combined.obs["tissue"].value_counts()

In [None]:
# Normalize per-cell totals and log1p-transform .X in place; the untouched
# copy of the matrix stays available in layers["counts"].
sc.pp.normalize_total(fca_reference_combined, inplace=True)
sc.pp.log1p(fca_reference_combined)
# Freeze the log-normalized, all-genes state in .raw before any gene
# filtering happens.
fca_reference_combined.raw = fca_reference_combined

In [None]:
# Keep the 2000 most variable genes, selected per tissue_batch group, and
# subset the object to them.
# flavor="seurat_v3" is required because the selection reads raw counts from
# layers["counts"]; scanpy's default flavor ("seurat") expects
# log-transformed input and would rank genes incorrectly on counts.
# NOTE(review): seurat_v3 depends on the optional `scikit-misc` package —
# confirm it is installed in this environment.
sc.pp.highly_variable_genes(
    fca_reference_combined,
    n_top_genes=2000,
    flavor="seurat_v3",
    layer="counts",
    batch_key="tissue_batch",
    subset=True,
)

In [None]:
# Register the combined reference with scVI.
# layer="counts" points the model at the untransformed matrix: .X was
# normalized and log1p-transformed earlier in this notebook, whereas scVI
# models raw counts.
# NOTE(review): "tissue_batch" is passed both as batch_key and as a
# categorical covariate (and "tissue" is fully determined by it). A
# batch-key-only setup, setup_anndata(adata, batch_key="tissue_batch"), was
# considered earlier and left disabled — confirm the duplication is intended.
sca.models.SCVI.setup_anndata(
    fca_reference_combined,
    layer="counts",
    batch_key="tissue_batch",
    categorical_covariate_keys=[
        "age",
        "fly_genetics",
        "tissue",
        "tissue_batch",
        "dissection_lab",
        "sex",
    ],
)

In [None]:
# Second highly-variable-gene pass, this time on .X (no layer given), again
# per tissue_batch group with in-place subsetting.
# NOTE(review): the object was already subset to 2000 genes by the earlier
# HVG cell, and this call runs after setup_anndata has registered the
# object — confirm this pass is still needed.
sc.pp.highly_variable_genes(
    fca_reference_combined,
    batch_key="tissue_batch",
    n_top_genes=2000,
    subset=True,
)

In [None]:
# Display the AnnData summary to inspect the object that will be passed to the model.
fca_reference_combined

In [None]:

# Architecture settings for the combined head + antenna reference model:
# a 2-layer network with covariates fed to the encoder and injected at every
# layer, layer norm in both encoder and decoder, and no batch norm.
scvi_arch_params = dict(
    n_layers=2,
    encode_covariates=True,
    deeply_inject_covariates=True,
    use_layer_norm="both",
    use_batch_norm="none",
)
combined_vae = sca.models.SCVI(fca_reference_combined, **scvi_arch_params)

In [None]:
# Train the reference model, validating after every epoch.
# NOTE(review): max_epochs=2 looks like a smoke-test setting — confirm
# before using the resulting model downstream.
combined_vae.train(
    max_epochs=2,
    check_val_every_n_epoch=1,
)