## Set path

In [1]:
import os
dataset_dir = os.path.join(os.getcwd(), 'datasets/')
outputs_dir = os.path.join(os.getcwd(), 'outputs/')
if not os.path.exists(outputs_dir):
    os.makedirs(outputs_dir)

save_dir = os.path.join(outputs_dir, "different samples/CITE-SLN111-Gayoso-Mouse1toMouse2/TotalVI/")
if not os.path.exists(save_dir):
    os.makedirs(save_dir)

## Load necessary libraries

In [2]:
import numpy as np
import pandas as pd
import scanpy as sc
import scvi
import anndata

Global seed set to 0


## Load data

In [3]:
train_data = sc.read_h5ad(os.path.join(dataset_dir, "different samples/CITE-SLN111-Gayoso/Mouse1.h5ad"))
test_data = sc.read_h5ad(os.path.join(dataset_dir, "different samples/CITE-SLN111-Gayoso/Mouse2.h5ad"))
train_data, test_data

(AnnData object with n_obs × n_vars = 9264 × 13553
     obs: 'n_protein_counts', 'n_proteins', 'seurat_hash_id', 'batch_indices', 'hash_id', 'n_genes', 'percent_mito', 'leiden_subclusters', 'cell_types'
     var: 'gene_ids', 'feature_types', 'highly_variable', 'highly_variable_mean_variance', 'encode', 'hvg_encode'
     uns: 'protein_name', 'version'
     obsm: 'protein_expression',
 AnnData object with n_obs × n_vars = 7564 × 13553
     obs: 'n_protein_counts', 'n_proteins', 'seurat_hash_id', 'batch_indices', 'hash_id', 'n_genes', 'percent_mito', 'leiden_subclusters', 'cell_types'
     var: 'gene_ids', 'feature_types', 'highly_variable', 'highly_variable_mean_variance', 'encode', 'hvg_encode'
     uns: 'protein_name', 'version'
     obsm: 'protein_expression')

## Combine gene expression data from the training and test sets

In [4]:
train_rna_expression = pd.DataFrame(train_data.X.todense(), columns=train_data.var.index, index=train_data.obs.index)
test_rna_expression = pd.DataFrame(test_data.X.todense(), columns=test_data.var.index, index=test_data.obs.index)
rna_expression = pd.concat([train_rna_expression, test_rna_expression], axis=0)
rna_expression

index,0610007P14Rik,0610009B22Rik,0610009L18Rik,0610009O20Rik,0610010F05Rik,0610010K14Rik,0610011F06Rik,0610012G03Rik,0610030E20Rik,0610037L13Rik,...,mt-Co2,mt-Co3,mt-Cytb,mt-Nd1,mt-Nd2,mt-Nd3,mt-Nd4,mt-Nd4l,mt-Nd5,mt-Nd6
index,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
AAACCCAAGGGTAATT-1,2.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,1.0,1.0,...,80.0,107.0,64.0,22.0,38.0,19.0,43.0,4.0,5.0,0.0
AAACCCAAGGTAAACT-1,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,70.0,78.0,31.0,19.0,29.0,8.0,24.0,1.0,5.0,0.0
AAACCCACACTAGGTT-1,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,28.0,39.0,19.0,7.0,4.0,3.0,12.0,1.0,2.0,0.0
AAACCCACAGATACCT-1,0.0,1.0,0.0,0.0,0.0,1.0,1.0,0.0,1.0,0.0,...,54.0,43.0,26.0,9.0,22.0,10.0,8.0,1.0,0.0,1.0
AAACCCACAGGAATAT-1,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,1.0,...,47.0,64.0,36.0,15.0,22.0,8.0,19.0,1.0,2.0,0.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
TTTGTTGGTAAGATCA-2,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,12.0,15.0,6.0,8.0,3.0,8.0,4.0,1.0,0.0,0.0
TTTGTTGGTCAGACTT-2,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,20.0,19.0,6.0,4.0,8.0,11.0,6.0,1.0,2.0,0.0
TTTGTTGGTGTTTACG-2,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,...,24.0,28.0,10.0,6.0,8.0,14.0,7.0,4.0,2.0,0.0
TTTGTTGTCAGAATAG-2,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,17.0,14.0,9.0,11.0,7.0,6.0,2.0,4.0,2.0,0.0


## If the dataset lacks batch information, set *batch_index* to 0 for training cells and 1 for test cells

In [5]:
cells = pd.concat([train_data.obs, test_data.obs],axis=0)
cells["batch_index"] = np.repeat(0, cells.shape[0])
cells["batch_index"][train_data.shape[0]:] = 1
cells.index = range(cells.shape[0])
cells

A value is trying to be set on a copy of a slice from a DataFrame

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  cells["batch_index"][train_data.shape[0]:] = 1


Unnamed: 0,n_protein_counts,n_proteins,seurat_hash_id,batch_indices,hash_id,n_genes,percent_mito,leiden_subclusters,cell_types,batch_index
0,2319.0,100,Spleen,0,Spleen,3137,0.062138,120,NKT,0
1,3760.0,105,Spleen,0,Spleen,2256,0.057545,6,CD122+ CD8 T,0
2,1351.0,104,Spleen,0,Spleen,1367,0.058373,3,Transitional B,0
3,3341.0,102,Lymph_Node,0,Lymph Node,1567,0.065386,4,Mature B,0
4,3708.0,102,Lymph_Node,0,Lymph Node,1895,0.059644,0,CD4 T,0
...,...,...,...,...,...,...,...,...,...,...
16823,3172.0,99,Lymph_Node,1,Lymph Node,774,0.060875,1,Mature B,1
16824,3603.0,97,Spleen,1,Spleen,1130,0.035604,0,CD4 T,1
16825,1915.0,97,Spleen,1,Spleen,1187,0.052554,3,Transitional B,1
16826,3010.0,100,Lymph_Node,1,Lymph Node,715,0.073801,5,Mature B,1


## Combine protein expression data from the training and test sets

In [6]:
train_protein_expression = pd.DataFrame(train_data.obsm["protein_expression"].todense(), columns=train_data.uns["protein_name"], index=train_data.obs.index)
test_protein_expression = pd.DataFrame(test_data.obsm["protein_expression"].todense(), columns=test_data.uns["protein_name"], index=test_data.obs.index)
protein_expression = pd.concat([train_protein_expression,test_protein_expression],axis=0)
protein_expression

Unnamed: 0_level_0,ADT_CD102_A0104,ADT_CD103_A0201,ADT_CD106_A0226,ADT_CD115(CSF-1R)_A0105,ADT_CD117(c-Kit)_A0012,ADT_CD11a_A0595,ADT_CD11c_A0106,ADT_CD122(IL-2Rb)_A0227,ADT_CD127(IL-7Ra)_A0198,ADT_CD134(OX-40)_A0195,...,ADT_TCRVr1.1-Cr4_A0209,ADT_TCRVr2_A0211,ADT_TCRVr3_A0210,ADT_TCRbchain_A0120,ADT_TCRr-d_A0121,ADT_TER-119-ErythroidCells_A0122,ADT_Tim-4_A0567,ADT_XCR1_A0568,ADT_anti-P2RY12_A0415,ADT_integrinb7_A0214
index,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
AAACCCAAGGGTAATT-1,71.0,6.0,6.0,7.0,1.0,89.0,7.0,2.0,3.0,9.0,...,0.0,2.0,3.0,31.0,3.0,0.0,9.0,4.0,1.0,67.0
AAACCCAAGGTAAACT-1,14.0,1.0,2.0,3.0,3.0,176.0,4.0,3.0,3.0,10.0,...,0.0,1.0,8.0,22.0,7.0,0.0,4.0,2.0,4.0,20.0
AAACCCACACTAGGTT-1,72.0,2.0,6.0,3.0,2.0,11.0,3.0,1.0,1.0,16.0,...,0.0,1.0,8.0,3.0,2.0,0.0,7.0,1.0,3.0,6.0
AAACCCACAGATACCT-1,28.0,3.0,7.0,8.0,1.0,57.0,1.0,1.0,3.0,9.0,...,0.0,0.0,7.0,5.0,5.0,1.0,4.0,2.0,2.0,45.0
AAACCCACAGGAATAT-1,15.0,4.0,2.0,2.0,5.0,86.0,2.0,3.0,3.0,10.0,...,0.0,2.0,5.0,76.0,1.0,0.0,10.0,4.0,1.0,17.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
TTTGTTGGTAAGATCA-2,52.0,2.0,5.0,2.0,4.0,69.0,2.0,2.0,4.0,11.0,...,1.0,0.0,6.0,1.0,0.0,1.0,1.0,1.0,3.0,29.0
TTTGTTGGTCAGACTT-2,45.0,8.0,0.0,2.0,3.0,120.0,1.0,1.0,11.0,8.0,...,0.0,0.0,4.0,41.0,3.0,1.0,4.0,7.0,3.0,28.0
TTTGTTGGTGTTTACG-2,124.0,2.0,6.0,1.0,3.0,17.0,1.0,4.0,3.0,4.0,...,1.0,1.0,4.0,1.0,1.0,2.0,5.0,1.0,3.0,8.0
TTTGTTGTCAGAATAG-2,30.0,9.0,6.0,7.0,1.0,57.0,2.0,1.0,1.0,6.0,...,0.0,2.0,3.0,1.0,4.0,1.0,5.0,6.0,4.0,27.0


## Save protein expression data in test set for evaluation

In [7]:
test_protein_expression.to_csv(os.path.join(save_dir, "test_raw_protein_expression.txt"), sep="\t")

## Convert the merged gene expression data and protein expression data to anndata format

In [8]:
data = anndata.AnnData(X=rna_expression.values, var=train_data.var, obs=cells)
data.obsm["protein_expression"] = protein_expression.values
data.uns["protein_name"] = train_data.uns["protein_name"]
data



AnnData object with n_obs × n_vars = 16828 × 13553
    obs: 'n_protein_counts', 'n_proteins', 'seurat_hash_id', 'batch_indices', 'hash_id', 'n_genes', 'percent_mito', 'leiden_subclusters', 'cell_types', 'batch_index'
    var: 'gene_ids', 'feature_types', 'highly_variable', 'highly_variable_mean_variance', 'encode', 'hvg_encode'
    uns: 'protein_name'
    obsm: 'protein_expression'

## Read batch and cell name information

In [9]:
batch = data.obs["batch_index"].values
cell_names = test_protein_expression.index
batch, cell_names

(array([0, 0, 0, ..., 1, 1, 1]),
 Index(['AAACCCAAGAATCTAG-2', 'AAACCCACACCGGAAA-2', 'AAACCCACACTACTTT-2',
        'AAACCCAGTAGGCAAC-2', 'AAACCCAGTCTCAGGC-2', 'AAACCCATCTCGTGAA-2',
        'AAACGAAAGAGGTATT-2', 'AAACGAACAAACACCT-2', 'AAACGAACACCCTTAC-2',
        'AAACGAATCGACCTAA-2',
        ...
        'TTTGGTTTCGTCCTTG-2', 'TTTGTTGAGGTAGGCT-2', 'TTTGTTGCAGAAATCA-2',
        'TTTGTTGCAGTTGGTT-2', 'TTTGTTGCATAGATGA-2', 'TTTGTTGGTAAGATCA-2',
        'TTTGTTGGTCAGACTT-2', 'TTTGTTGGTGTTTACG-2', 'TTTGTTGTCAGAATAG-2',
        'TTTGTTGTCGAGTGGA-2'],
       dtype='object', name='index', length=7564))

## Mask protein expression data in test set

In [10]:
data.obsm["protein_expression"][train_protein_expression.shape[0]:, :] = np.zeros(test_protein_expression.shape)

## Select highly variable genes

In [11]:
sc.pp.highly_variable_genes(data, batch_key="batch_index", flavor="seurat_v3", n_top_genes=4000, subset=True)

## Convert to scvi object

In [12]:
scvi.model.TOTALVI.setup_anndata(data, batch_key="batch_index", protein_expression_obsm_key="protein_expression")
model = scvi.model.TOTALVI(data, latent_distribution="normal", n_layers_decoder=2)
model

No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)


[34mINFO    [0m Generating sequential column names                                                                        
[34mINFO    [0m Found batches with missing protein expression                                                             
[34mINFO    [0m Computing empirical prior initialization for protein background.                                          






## Train model

In [13]:
model.train()

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Epoch 295/400:  74%|████████████████████████████████████████████████████▎                  | 295/400 [10:07<03:23,  1.94s/it, loss=580, v_num=1]Epoch 00295: reducing learning rate of group 0 to 2.4000e-03.
Epoch 328/400:  82%|██████████████████████████████████████████████████████████▏            | 328/400 [11:12<02:27,  2.05s/it, loss=588, v_num=1]Epoch 00328: reducing learning rate of group 0 to 1.4400e-03.
Epoch 396/400:  99%|██████████████████████████████████████████████████████████████████████▎| 396/400 [13:28<00:07,  1.95s/it, loss=573, v_num=1]Epoch 00396: reducing learning rate of group 0 to 8.6400e-04.
Epoch 400/400: 100%|███████████████████████████████████████████████████████████████████████| 400/400 [13:36<00:00,  1.93s/it, loss=577, v_num=1]

`Trainer.fit` stopped: `max_epochs=400` reached.


Epoch 400/400: 100%|███████████████████████████████████████████████████████████████████████| 400/400 [13:36<00:00,  2.04s/it, loss=577, v_num=1]


## Impute protein expression data in test

In [14]:
_, protein_means = model.get_normalized_expression(n_samples=25, transform_batch=0, include_protein_background=True, sample_protein_mixing=False, 
                                                   return_mean=True)
predicted_protein = pd.DataFrame(protein_means.iloc[train_protein_expression.shape[0]:,:].values, index=test_data.obs.index, 
                                 columns=test_data.uns["protein_name"])
predicted_protein

Unnamed: 0_level_0,ADT_CD102_A0104,ADT_CD103_A0201,ADT_CD106_A0226,ADT_CD115(CSF-1R)_A0105,ADT_CD117(c-Kit)_A0012,ADT_CD11a_A0595,ADT_CD11c_A0106,ADT_CD122(IL-2Rb)_A0227,ADT_CD127(IL-7Ra)_A0198,ADT_CD134(OX-40)_A0195,...,ADT_TCRVr1.1-Cr4_A0209,ADT_TCRVr2_A0211,ADT_TCRVr3_A0210,ADT_TCRbchain_A0120,ADT_TCRr-d_A0121,ADT_TER-119-ErythroidCells_A0122,ADT_Tim-4_A0567,ADT_XCR1_A0568,ADT_anti-P2RY12_A0415,ADT_integrinb7_A0214
index,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
AAACCCAAGAATCTAG-2,38.298965,4.838750,3.170877,2.208748,2.943351,35.889439,1.709485,1.643250,1.149943,9.997074,...,0.392259,2.493951,5.990996,2.781760,3.202429,0.661361,4.849572,2.602395,3.477031,25.962280
AAACCCACACCGGAAA-2,75.393692,4.787436,3.211803,2.285113,2.916541,10.653851,1.654977,1.633892,1.159133,10.054410,...,0.390018,2.395954,6.019639,2.805542,3.248852,0.608538,4.945483,2.539578,3.456675,8.463384
AAACCCACACTACTTT-2,37.164722,4.869664,3.338182,2.253772,2.954064,47.055954,1.687118,1.595588,1.147133,10.068223,...,0.401388,2.480696,6.141207,2.839814,3.217414,0.625307,4.917911,2.579165,3.402333,28.409063
AAACCCAGTAGGCAAC-2,49.512177,4.819985,3.415898,2.234860,2.953491,46.824188,1.726000,1.600700,1.165271,9.947783,...,0.397768,2.508254,6.005615,2.844030,3.271807,0.675059,4.962556,2.604548,3.503935,27.692053
AAACCCAGTCTCAGGC-2,38.180374,4.861646,3.215620,2.220744,2.998650,40.973579,1.675673,1.576723,1.155495,9.819227,...,0.392433,2.460261,5.995323,2.808746,3.212855,0.617171,4.832979,2.539773,3.441365,28.455059
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
TTTGTTGGTAAGATCA-2,47.071709,4.834207,3.415769,2.259956,2.991157,42.736572,1.726898,1.608829,1.164149,10.037500,...,0.388243,2.499568,6.013940,2.754824,3.186126,0.668289,4.942200,2.560801,3.410412,27.222607
TTTGTTGGTCAGACTT-2,25.386141,4.917342,3.191632,2.239081,5.130885,95.176033,1.693355,1.656218,5.225467,10.787858,...,0.395005,2.559529,6.020990,71.683304,3.169446,0.605307,4.891374,2.568089,3.444194,18.921734
TTTGTTGGTGTTTACG-2,74.214737,5.014071,3.176440,2.276683,2.908447,31.866217,1.686445,1.566241,1.192797,9.953971,...,0.392231,2.443239,6.090394,2.827880,3.204513,0.613120,4.856096,2.514205,3.431981,12.791815
TTTGTTGTCAGAATAG-2,42.382954,4.804501,3.260980,2.275431,2.993080,41.361351,1.708885,1.585249,1.137194,9.988891,...,0.401485,2.479378,6.028633,2.833426,3.249082,0.702626,4.834034,2.562946,3.459940,24.746408


## Save prediction and trained model

In [15]:
predicted_protein.to_csv(os.path.join(save_dir, "test_protein_prediction.txt"), sep="\t")
model.save(dir_path=save_dir, save_anndata=True, overwrite=True)