In [None]:
import numpy as np
import muon as mu
import scanpy as sc

# Base directory used to build the dataset file paths below. It is joined by
# plain string concatenation (path+"dataset/..."), so it must keep its
# trailing slash.
# NOTE(review): machine-specific absolute path — adjust to the local checkout.
path = "C:/Users/cleme/Desktop/Github/"

As a first step, we load the context and target datasets as `.h5ad` files. As a test dataset, we will use a subset of the mouse and human liver cell atlas. 

The data is preprocessed to contain only 4000 highly variable genes. Cell types with a large number of cells were randomly subsampled to reduce their size.  
Cells with labeling conflicts between the precise and rough labels were also removed. The full datasets can be downloaded at https://www.livercellatlas.org/.

In [None]:
# Read the context (mouse) and target (human) liver datasets from disk.
context_adata = sc.read_h5ad(path + "dataset/mouse_liver_filtered.h5ad")
target_adata = sc.read_h5ad(path + "dataset/human_liver_filtered.h5ad")

# Cast both expression matrices to 32-bit floats, context first, then target.
for _adata in (context_adata, target_adata):
    _adata.X = _adata.X.astype('float32')

We specify the keys under which the cell and batch labels of the context and target datasets are stored.  
The cell labels for the target dataset are used only for plotting and are not needed during training.  
If the target cell labels are unknown, `target_cell_key` can be set to `None`.  

We print the cell labels common to both datasets and the cell labels unique to the context and target datasets.   

In [None]:
# Annotation keys for batch and cell type of each dataset; the cell keys are
# used to index `.obs` below.
context_batch_key = 'sample'
context_cell_key = 'cell_type_fine'

target_batch_key = 'sample'
# May be set to None when the target cell labels are unknown.
target_cell_key = 'cell_type_fine'

# Build each label set once instead of rebuilding it for every comparison.
context_labels = set(context_adata.obs[context_cell_key])
# With no target labels there is no column to look up (indexing `.obs` with
# None would raise), so the comparison runs against an empty set instead.
if target_cell_key is None:
    target_labels = set()
else:
    target_labels = set(target_adata.obs[target_cell_key])

joint_labels = context_labels & target_labels
unique_context_labels = context_labels - target_labels
unique_target_labels = target_labels - context_labels

print('Cell labels occurring in both datasets: ', sorted(joint_labels))
print('Unique context cell labels:', sorted(unique_context_labels))
print('Unique target cell labels:', sorted(unique_target_labels))

Next, we create the `muon.MuData` dataset (https://muon.readthedocs.io/en/latest/) which scPecies uses during training.  
Muon lets us define a container for multimodal data.  
One modality will be our context species dataset and one (or possibly more) will contain the target species dataset(s).  
We instantiate a preprocessing class and register context and target `anndata.AnnData` datasets.  


This also performs the data-level nearest neighbor search.  
We further reduce the dimensionality to the 2500 most highly variable genes.  
When performing an alignment for multiple species, `.setup_target_adata` can be run multiple times, once per target dataset.  

In [None]:
from preprocessing import create_mdata

import warnings
# Ignore every UserWarning raised from this point on.
warnings.filterwarnings(action="ignore", category=UserWarning)

# Register the context dataset first, then attach the target dataset.
# Both are limited to 2500 top genes through the *_n_top_genes arguments.
preprocess = create_mdata(
    context_adata,
    context_batch_key,
    context_cell_key,
    context_dataset_name='mouse',
    context_n_top_genes=2500,
)
preprocess.setup_target_adata(
    target_adata,
    target_batch_key,
    target_cell_key,
    target_dataset_name='human',
    target_n_top_genes=2500,
)

# Persist the result under the name 'liver'.
# NOTE(review): presumably this writes the liver.h5mu file that is read back
# later from path+"dataset/" — confirm against save_mdata.
preprocess.save_mdata(path, 'liver')

We define the context and target scVI models by instantiating the scPecies class.  
We recommend using an NVIDIA GPU during training. CPU training can be slow, and Apple Silicon runs into errors when trying to compute the log-gamma function for the scVI loss.  

In [None]:
from models import scPecies
import torch
import muon as mu


# Run on the GPU when CUDA is available; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Load the MuData file from the dataset directory.
mdata = mu.read_h5mu(path + "dataset/liver.h5mu")

# Instantiate scPecies with 'mouse' as context and 'human' as target modality;
# both use the 'nb' data distribution setting.
model = scPecies(
    device,
    mdata,
    path,
    context_dataset_key='mouse',
    target_dataset_key='human',
    context_data_distr='nb',
    target_data_distr='nb',
)

We train and evaluate the context scVI model.  
The model parameters are automatically saved to the specified path, and the latent representations are stored in the `.obsm` attribute of the context modality of the `muon.MuData` object.

In [None]:
# Train the context model with early stopping disabled. The positional 60 is
# presumably the number of training epochs — TODO confirm against
# scPecies.train_context.
model.train_context(60, early_stopping=False)
# Evaluate the context model after training.
model.eval_context()

Next we train and evaluate the target scVI model using the following commands:

In [None]:
# Train the target model with early stopping disabled. The positional 60 is
# presumably the number of training epochs — TODO confirm against
# scPecies.train_target.
model.train_target(60, early_stopping=False)
# Evaluate the target model after training.
model.eval_target()

After training, we can predict cell labels using the aligned representation.  
We can compare the quality of these predicted labels with those obtained from the data-level nearest neighbor search. 

In [None]:
# Predict target labels from the aligned latent space, then compute the metrics.
model.pred_labels_nns_aligned_latent_space()
model.compute_metrics()

# Every score below is read from the same dict, so look it up once.
metrics = model.mdata.mod[model.target_dataset_key].uns['metrics']

# Balanced accuracy, reported as a percentage with one decimal place.
knn_acc = round(metrics['balanced_accuracy_score_nns_hom_genes'] * 100, 1)
latent_acc = round(metrics['balanced_accuracy_score_nns_aligned_latent_space'] * 100, 1)

# Adjusted Rand index.
knn_adj = round(metrics['adjusted_rand_score_nns_hom_genes'], 3)
latent_adj = round(metrics['adjusted_rand_score_nns_aligned_latent_space'], 3)

# Adjusted mutual information.
knn_mis = round(metrics['adjusted_mutual_info_score_nns_hom_genes'], 3)
latent_mis = round(metrics['adjusted_mutual_info_score_nns_aligned_latent_space'], 3)

# prediction dataframes of the aligned latent knn search and the data-level knn search.
#model.mdata.mod[model.target_dataset_key].uns['prediction_df_nns_aligned_latent_space']
#model.mdata.mod[model.target_dataset_key].uns['prediction_df_nns_hom_genes']


# predicted cell labels of the aligned latent knn search and the data-level knn search.
#model.mdata.mod[model.target_dataset_key].obs['label_nns_aligned_latent_space']
#model.mdata.mod[model.target_dataset_key].obs['label_nns_hom_genes']


print(f'\n Accuracy: data-level knn-search: {knn_acc}%, latent knn-search: {latent_acc}%.')
print(f'\n Adjusted Rand Index: data-level knn-search: {knn_adj}, latent knn-search: {latent_adj}.')
print(f'\n Mutual information: data-level knn-search: {knn_mis}, latent knn-search: {latent_mis}.')

We can plot the results for the liver cell dataset with the provided functions.  
On other datasets, these functions should be modified, or scanpy functions like `scanpy.pl.umap` should be used.

In [None]:
from plot_utils import plot_umap, bar_plot

# Plot the results with the helpers from plot_utils; both take the trained
# model object as their only argument.
plot_umap(model)
bar_plot(model)

Finally, the difference in modeled gene expression can be analyzed by comparing the log2-fold change in normalized gene expression. 

In [None]:
from plot_utils import plot_lfc

# Compute the log2-fold change, passing lfc=1 to the model.
model.compute_logfold_change(lfc=1)

# Where to find the log2-fold change dataframe and the corresponding
# probabilities:
#   model.mdata.mod[model.context_dataset_key].uns['lfc_df']
#   model.mdata.mod[model.context_dataset_key].uns['prob_df']

plot_lfc(model)