# Looking for neuron marker candidates for ASK and ASG

On Cao and Taylor C. elegans data using single-cell Variational Inference (scVI)


In [None]:
import scvi
scvi.__version__

In [None]:
%matplotlib inline
%config InlineBackend.figure_format = 'retina'
# Control warnings
import warnings; warnings.simplefilter('ignore')

import os
import numpy as np
import pandas as pd

import matplotlib.pyplot as plt
from scvi.dataset import GeneExpressionDataset
from scvi.models import VAE
from scvi.inference import UnsupervisedTrainer
import torch
import anndata
from tqdm import tqdm


import plotly.express as px
import plotly.graph_objects as go
from IPython.display import Image

In [None]:
## Change the path where the models will be saved 
save_path = "./"

if os.path.isfile('wormcells-data-2020-03-30.h5ad'):
    print ("Found the data file! No need to download.")
else:
    print ("Downloading data...")
    ! wget https://github.com/Munfred/wormcells-site/releases/download/Packer2019Taylor2019Cao2019_wrangle2/wormcells-data-2020-03-30.h5ad


In [None]:
# Load the downloaded dataset. `read_h5ad` is used instead of the generic
# `anndata.read` alias, which is deprecated in newer anndata releases.
adata = anndata.read_h5ad('wormcells-data-2020-03-30.h5ad')
# Keep every cell whose `study` is not 'packer'; .copy() turns the filtered
# view into an independent AnnData object.
adata = adata[adata.obs.study != 'packer'].copy()
adata

In [None]:
# Wrap the AnnData object in an scVI dataset, handing over the `cell_type`
# and `experiment` annotations as the cell-type and batch labels.
# NOTE(review): the `.cat` accessors are passed here rather than column
# names -- confirm this is what AnnDatasetFromAnnData expects in the
# installed scVI version.
cell_type_cat = adata.obs['cell_type'].cat
experiment_cat = adata.obs['experiment'].cat
gene_dataset = scvi.dataset.AnnDatasetFromAnnData(
    adata,
    ctype_label=cell_type_cat,
    batch_label=experiment_cat,
)
# Optional gene filtering, left disabled:
# gene_dataset.filter_genes_by_count(min_count = 1, per_batch=False)
# sel_genes = gene_dataset.gene_names
gene_dataset

## Define and train the model

We now create the model and the trainer object. We train the model and output the model likelihood every epoch. In order to evaluate the likelihood on a test set, we split the dataset (the current code can also do a train/validation/test split).

If a pre-trained model already exists in the save_path, then load that model rather than re-training it. This is particularly useful for large datasets.

In [None]:
# Training settings. The original note says 5 epochs is sufficient for this
# dataset; 10 are configured here.
n_epochs = 10
lr = 1e-3
# Kept False because the saved model that gets loaded was trained on CPU.
use_cuda = False

# VAE sized to the dataset; passing the number of batches sets it up to
# perform batch correction.
vae = VAE(gene_dataset.nb_genes, n_batch=gene_dataset.n_batches)

trainer_options = dict(
    train_size=0.75,  # number between 0 and 1, default 0.8
    use_cuda=use_cuda,
    frequency=1,
)
trainer = UnsupervisedTrainer(vae, gene_dataset, **trainer_options)

In [None]:
# Reuse previously trained weights when a saved file exists under save_path;
# otherwise train the model and store its weights for the next run.
vae_file_name = 'CT_vae_cpu1.pkl'
full_file_save_path = os.path.join(save_path, vae_file_name)

if not os.path.isfile(full_file_save_path):
    trainer.train(n_epochs=n_epochs, lr=lr)
    torch.save(trainer.model.state_dict(), full_file_save_path)
else:
    saved_weights = torch.load(full_file_save_path)
    trainer.model.load_state_dict(saved_weights)
    # Switch the loaded model to evaluation mode.
    trainer.model.eval()

In [None]:
# Tabulate the ELBO history recorded by the trainer, with shorter column
# names for the training and test sets.
elbo_column_names = {'elbo_train_set': 'Train', 'elbo_test_set': 'Test'}
train_test_results = pd.DataFrame(trainer.history)
train_test_results = train_test_results.rename(columns=elbo_column_names)

train_test_results



## Obtaining the posterior object and sample latent space

The posterior object contains a model and a gene_dataset, as well as additional arguments for PyTorch's `DataLoader`. It also comes with many methods and utilities for querying the model, such as differential expression, imputation and differential analysis.


To get an ordered output, we can use the posterior's `.sequential` method, which returns another instance of the posterior (with a shallow copy of all its object references) whose iteration follows the same order as its indices attribute.

In [None]:
# Latent-space extraction, adapted from the scVI tutorial: build a posterior
# over every cell of the dataset and read out its latent representation.
all_cell_indices = np.arange(len(gene_dataset))
full = trainer.create_posterior(
    trainer.model,
    gene_dataset,
    indices=all_cell_indices,
)
# .sequential() is used so the output follows the order of the indices.
latent, batch_indices, labels = full.sequential().get_latent()
batch_indices = batch_indices.ravel()
latent.shape



### Selecting cells to compare

Let's look at cells with validated marker genes

The marker gene table is taken from:
```
Ramiro Lorenzo, Michiho Onizuka, Matthieu Defrance, Patrick Laurent 
Combining single-cell RNA-sequencing with a molecular atlas unveils new markers for Caenorhabditis elegans neuron classes
Nucleic Acids Research 2020
```
http://doi.org/10.1093/nar/gkaa486 


In [None]:
if os.path.isfile('lorenzo_markers_descriptions.csv'):
    print ("Found the data file! No need to download.")
else:
    print ("Downloading marker gene data...")
    ! wget https://github.com/Munfred/worm-markers/releases/download/CT_vae_cpu1/lorenzo_markers_descriptions.csv

lorenzo_markers=pd.read_csv('lorenzo_markers_descriptions.csv', index_col=0)
display(lorenzo_markers.head(15))


In [None]:
# Neurons from the marker table that also occur as a cell type in the
# single-cell data. The set of cell types is built once up front instead of
# calling .unique() on the full obs column for every neuron.
cell_types_in_data = set(adata.obs['cell_type'].unique())
neurons_with_markers = [
    neuron
    for neuron in lorenzo_markers.index.unique()
    if neuron in cell_types_in_data
]
counter = len(neurons_with_markers)

print('In the single cell data we have', counter, 'neurons with known markers:')
print(neurons_with_markers)

# Perform differential expression between each neuron type for which there are markers and all other cells
We plot the bayes factor vs the scVI scale1, which is a value for the expression of the gene (the frequency with which transcripts of that gene are seen in the data)

High bayes factor => evidence for DE in the tissue of interest

We color the known markers for that neuron, the ones for other neurons, and some common genes for reference

We do this for two neurons with several markers as an example.

In [None]:
# For each neuron of interest: run "change"-mode differential expression of
# that neuron against the rest of the cells, annotate and colour the genes,
# write an interactive Bayes-factor plot (HTML) and the DE table (CSV), and
# keep the table in `dedfs` for the cells further down.

# a dictionary to hold the differential expression results dataframes
dedfs = {}

# The gene annotation tables are identical for every neuron, so they are read
# once here rather than on each loop iteration. Both files must be present in
# the working directory.
genedesc = pd.read_csv('celegans_gene_descriptions.csv', index_col=0)
genemaps = pd.read_csv('worm_gene_name_id.csv', index_col=1)

n_samples = 10000
M_permutation = 10000

# Common genes highlighted for reference.
# ('myo-2' previously carried a leading space and therefore never matched.)
sanity_genes = ['unc-122', 'sur-5', 'myo-2', 'rab-3', 'nsf-1', 'snb-1',
                'cha-1', 'unc-17', 'unc-25', 'unc-47', 'eat-4']

for cell_type_1 in ['ASK', 'ASG']:
    # you can change the neuron list to plot other cells, for example:
    # for cell_type_1 in ['ASK', 'ASJ','AWB','AWC_OFF','AWC_ON']:
    cell_type_2 = 'not ' + cell_type_1
    cell_idx1 = (adata.obs['cell_type'] == cell_type_1)
    ncells1 = int(cell_idx1.sum())
    print(ncells1, 'cells of type', cell_type_1)
    # The comparison group drops every cell type whose name *contains*
    # cell_type_1 (substring match), not only the exact match used above.
    cell_idx2 = ~adata.obs['cell_type'].str.contains(cell_type_1)
    ncells2 = int(cell_idx2.sum())
    print(ncells2, 'cells of type', cell_type_2)
    if ncells1 == 0:
        print('No cells available in data...skipping...')
        continue

    ###### DE CHANGE ############
    de = full.differential_expression_score(
        idx1=cell_idx1.values,
        idx2=cell_idx2.values,
        mode='change',  # vanilla is the default
        n_samples=n_samples,
        M_permutation=M_permutation,
    )

    # Annotate each gene (the DE table is indexed by gene id) with its name
    # and description, falling back to placeholders where none is known.
    de['gene_name'] = de.index.map(genemaps['gene_name'])
    de['gene_description'] = de.index.map(genedesc['gene_description'])
    de['gene_id'] = de.index

    de['gene_color'] = 'rgba(100, 100, 100, 0.25)'
    de['gene_name'] = de['gene_name'].fillna('noname')
    de['gene_description'] = de['gene_description'].fillna('No description available')

    de['gene_kind'] = 'other genes'

    # Colours are applied from least to most specific, so a later category
    # overrides an earlier one for genes that belong to both.

    # highlight markers listed for any neuron in blue
    is_known_marker = de['gene_name'].isin(lorenzo_markers['gene_name'])
    de.loc[is_known_marker, 'gene_color'] = 'rgba(0, 0,255, 1)'
    de.loc[is_known_marker, 'gene_kind'] = 'known marker genes for other neurons'

    # highlight sanity genes in green
    is_sanity_gene = de['gene_name'].isin(sanity_genes)
    de.loc[is_sanity_gene, 'gene_color'] = 'rgba(0, 255,0, 1)'
    de.loc[is_sanity_gene, 'gene_kind'] = 'sanity genes for reference'

    # highlight validated markers for the cell of interest in red
    if cell_type_1 in lorenzo_markers.index:
        # List-style .loc always yields a Series, even for a single marker row.
        own_markers = lorenzo_markers.loc[[cell_type_1], 'gene_name']
        for gene in own_markers:
            print(gene)
        is_own_marker = de['gene_name'].isin(own_markers)
        de.loc[is_own_marker, 'gene_color'] = 'rgba(255, 0,0, 1)'
        de.loc[is_own_marker, 'gene_kind'] = 'candidate marker genes for this neuron'
    else:
        print('No known markers listed for', cell_type_1)

    # Break the description into one sentence per line for the hover box.
    # Plain-text replacement: no regex needed, and independent of the pandas
    # default for `regex`.
    de['gene_description_html'] = de['gene_description'].str.replace('. ', '.<br>', regex=False)

    title_text = (
        'Change DE between <b>' + str(ncells1) + ' ' + str(cell_type_1)
        + '</b> and <b>' + str(ncells2) + ' ' + str(cell_type_2) + '</b>'
    )
    fig = go.Figure(layout={
        'title': {'text': title_text, 'x': 0.5},
        'xaxis': {'title': {'text': 'log10 of scVI scale1'}},
        'yaxis': {'title': {'text': 'Bayes Factor'}},
        # , 'width':1600
        # , 'height': 800
    })

    # One trace per gene category so each gets its own legend entry.
    for gene_kind in ['other genes',
                      'candidate marker genes for this neuron',
                      'known marker genes for other neurons',
                      'sanity genes for reference']:
        sel = de[de['gene_kind'] == gene_kind]

        fig.add_trace(go.Scattergl(
            x=np.log10(sel["scale1"]),
            y=sel["bayes_factor"],
            name=gene_kind,
            mode='markers',
            marker=dict(color=sel['gene_color']),
            hoverinfo='text',
            text=sel['gene_description_html'],
            customdata=sel.gene_id.values + '<br>Name: ' + sel.gene_name.values,
            # Concatenating with per-gene Series yields one template per point.
            hovertemplate='%{customdata} <br>' +
                          'Bayes Factor: %{y}<br>' +
                          'Log10 scale1: %{x} <br>' +
                          'Log10 scale2: ' + np.log10(sel["scale2"]).astype(str) +
                          '<br> Ratio scale1/scale2: ' + np.round((sel["scale1"]/sel["scale2"]), 3).astype(str) +
                          '<br>' + sel["gene_kind"].astype(str) +
                          '<extra>%{text}</extra>',
        ))

    fig.update_layout(xaxis=dict(range=[-8, -1]))
    fig.update_layout(showlegend=True, template='none')
    fig.write_html('./' + str(cell_type_1) + '_DEchange.html')

    # save the DE results in csv
    de.to_csv('./' + str(cell_type_1) + '_DEchange.csv')
    dedfs[cell_type_1] = de
    fig.show()

# Use the lines below to save a static image instead of interactive plot
#     img_bytes = fig.to_image(format="png", width=1000, height=800, scale=2)
#     display(Image(img_bytes))
#     print('woop')
#     fig.write_image("markerplots5_aug20/"+ str(cell_type_1)+'_DEchange.png')




## Create a matrix of tissue x gene with expression frequency


In [None]:
## Create a matrix of tissue x gene with expression frequency

# One column per gene; a row is added for each cell type in the loop below.
tde = pd.DataFrame(columns=full.gene_dataset.gene_name)

for cell_type in tqdm(adata.obs.cell_type.unique()):
    # Boolean mask selecting the cells annotated with this cell type.
    in_cell_type = adata.obs.cell_type == cell_type
    res = full.scale_sampler(in_cell_type, give_mean=True)
    tde.loc[cell_type] = res['scale']

# Persist the table so it can be reloaded without re-sampling.
tde.to_csv('./CaoTaylor_tissue_expression_levels.csv')
tde

In [None]:
# Reload the tissue x gene expression table from the CSV on disk; the first
# column holds the row labels.
tde_csv_path = './CaoTaylor_tissue_expression_levels.csv'
tde = pd.read_csv(tde_csv_path, index_col=0)
tde

## Create swarm plot of relative expression 
For each of the top genes with the highest Bayes factor in the tissue of interest, we plot their expression relative to all other 210 tissues. The goal is to identify genes with relatively high expression only in the target tissue.

On the x axis we show the gene name and the -log10 of the frequency. So if a label says `sri-49 | 3.0`, that means that gene sri-49 is present at a frequency of 10^-3 = 0.001, i.e. one in every 1000 transcripts seen in that cell type was of sri-49.


In [None]:
# For each neuron of interest, draw one column of points per top DE gene:
# every point is the log2 expression of that gene in another tissue relative
# to the target tissue.
for cell_type in ['ASK', 'ASG']:
    # A neuron without cells in the data never got a DE table.
    if cell_type not in dedfs:
        print('No differential expression results for', cell_type, '...skipping...')
        continue

    data = []
    # Take the first 50 rows of the DE table and sort the display order by
    # median log fold change.
    sel = dedfs[cell_type][:50].sort_values(by='lfc_median', ascending=False)
    sel_genes = sel['gene_name']

    # Mean expression ratio (other tissues / target tissue) per gene. It is
    # not used for the plot order; kept as an alternative ordering.
    fold_changes = {}

    for gene in sel_genes:
        if gene not in tde.columns:
            continue

        target_level = tde[gene][cell_type]
        # Expression in every other tissue relative to the target tissue.
        relative = (tde[gene] / target_level).drop([cell_type])
        fold_changes[gene] = np.mean(relative)

        # Neurons for which this gene is listed as a marker.
        marker_neurons = lorenzo_markers.index[(lorenzo_markers.gene_name == gene).values]
        color = 'rgb(66, 167, 244)'
        if len(marker_neurons) > 0:
            color = 'rgb(0,0,255)'
        # .any() copes with genes that are markers for zero, one or several
        # neurons; a bare `if` on the comparison array does not.
        if (marker_neurons == cell_type).any():
            color = 'rgb(255,0,0)'

        trace = go.Box(
            y=np.round(np.log2(relative), 3),
            boxpoints='all',
            pointpos=0,
            marker=dict(color=color),
            # Transparent box and fill so only the individual points show.
            line=dict(color='rgba(0,0,0,0)'),
            fillcolor='rgba(0,0,0,0)',
            opacity=1,
            marker_size=4,
            # Label: gene name and -log10 of its frequency in the target tissue.
            name=gene + ' | ' + str(np.round(-np.log10(target_level), 2)),
            hoverinfo='text',
            text=gene,
            customdata=[gene] * len(relative),
            hovertemplate='%{customdata} <br>' +
                          '%{y}<br>' + '<extra></extra>',
            #             '<extra>%{text}</extra>'
        )
        data.append(trace)

    # Sorted once after the loop instead of after every insertion.
    mean_fold_change_order = pd.Series(fold_changes, dtype=float).sort_values()

    # The title names the ordering actually applied above (lfc_median).
    title = ('Relative log2 expression of top ' + str(len(sel_genes)) + ' ' + cell_type +
             ' specific genes vs all other tissues sorted by median log fold change')
    layout = go.Layout(title=title
                       # , width=750, height=500
                       )
    fig = go.Figure(data, layout)
    fig.update_layout(showlegend=False, template='none')
    fig.update_xaxes(tickangle=90, tickfont=dict(color='black', size=9))
    # fig.write_html('./' + cell_type + '_top80_log2swarmplot.html')
    fig.show()