## Set path

In [1]:
import os
# Inputs are read from ./datasets and every result is written under ./outputs,
# both resolved against the current working directory.
dataset_dir = os.path.join(os.getcwd(), 'datasets/')
outputs_dir = os.path.join(os.getcwd(), 'outputs/')
# exist_ok=True keeps the cell re-runnable and avoids the check-then-create
# race of a separate os.path.exists() test.
os.makedirs(outputs_dir, exist_ok=True)

# Directory that holds the artefacts of this experiment (Mouse1 -> Mouse2, sciPENN).
save_dir = os.path.join(outputs_dir, "different samples/CITE-SLN111-Gayoso-Mouse1toMouse2/sciPENN/")
os.makedirs(save_dir, exist_ok=True)

## Load necessary libraries

In [2]:
import anndata
from sciPENN.sciPENN_API import sciPENN_API
import pandas as pd

  from .autonotebook import tqdm as notebook_tqdm


## Load data

In [3]:
# Mouse1 is used for training and Mouse2 as the held-out test sample.
# `read_h5ad` is the documented reader; the bare `anndata.read` alias is
# deprecated in recent anndata releases and simply forwards to it.
train_data = anndata.read_h5ad(os.path.join(dataset_dir, "different samples/CITE-SLN111-Gayoso/Mouse1.h5ad"))
test_data = anndata.read_h5ad(os.path.join(dataset_dir, "different samples/CITE-SLN111-Gayoso/Mouse2.h5ad"))
train_data, test_data

(AnnData object with n_obs × n_vars = 9264 × 13553
     obs: 'n_protein_counts', 'n_proteins', 'seurat_hash_id', 'batch_indices', 'hash_id', 'n_genes', 'percent_mito', 'leiden_subclusters', 'cell_types'
     var: 'gene_ids', 'feature_types', 'highly_variable', 'highly_variable_mean_variance', 'encode', 'hvg_encode'
     uns: 'protein_name', 'version'
     obsm: 'protein_expression',
 AnnData object with n_obs × n_vars = 7564 × 13553
     obs: 'n_protein_counts', 'n_proteins', 'seurat_hash_id', 'batch_indices', 'hash_id', 'n_genes', 'percent_mito', 'leiden_subclusters', 'cell_types'
     var: 'gene_ids', 'feature_types', 'highly_variable', 'highly_variable_mean_variance', 'encode', 'hvg_encode'
     uns: 'protein_name', 'version'
     obsm: 'protein_expression')

## Split the training dataset into separate AnnData objects for gene expression and protein expression

In [4]:
# RNA modality: reuse the expression matrix together with the cell (obs)
# and gene (var) annotations of the training sample.
train_rna = anndata.AnnData(
    X=train_data.X,
    obs=train_data.obs,
    var=train_data.var,
)

# Protein modality: the measurements are stored in obsm["protein_expression"]
# and the matching protein names in uns["protein_name"].
train_protein_var = pd.DataFrame(train_data.uns["protein_name"], columns=["protein_name"])
train_protein = anndata.AnnData(X=train_data.obsm["protein_expression"])
train_protein.var = train_protein_var
train_protein.obs = train_data.obs

train_rna, train_protein

(AnnData object with n_obs × n_vars = 9264 × 13553
     obs: 'n_protein_counts', 'n_proteins', 'seurat_hash_id', 'batch_indices', 'hash_id', 'n_genes', 'percent_mito', 'leiden_subclusters', 'cell_types'
     var: 'gene_ids', 'feature_types', 'highly_variable', 'highly_variable_mean_variance', 'encode', 'hvg_encode',
 AnnData object with n_obs × n_vars = 9264 × 110
     obs: 'n_protein_counts', 'n_proteins', 'seurat_hash_id', 'batch_indices', 'hash_id', 'n_genes', 'percent_mito', 'leiden_subclusters', 'cell_types'
     var: 'protein_name')

## Wrap the gene expression data of the test dataset in an AnnData object

In [5]:
# Only the RNA modality of the test sample is handed to the model; the
# measured protein values in test_data.obsm are not copied over here.
test_rna = anndata.AnnData(
    X=test_data.X,
    obs=test_data.obs,
    var=test_data.var,
)
test_rna

AnnData object with n_obs × n_vars = 7564 × 13553
    obs: 'n_protein_counts', 'n_proteins', 'seurat_hash_id', 'batch_indices', 'hash_id', 'n_genes', 'percent_mito', 'leiden_subclusters', 'cell_types'
    var: 'gene_ids', 'feature_types', 'highly_variable', 'highly_variable_mean_variance', 'encode', 'hvg_encode'

## Create sciPENN object
### If batch information is available in the datasets' `obs`, pass it to the parameters *train_batchkeys* and *test_batchkey* (the call below leaves both unset)

In [6]:
# One paired RNA/protein training set and one RNA-only test set.
# train_batchkeys / test_batchkey are not passed in this call.
model = sciPENN_API(
    gene_trainsets=[train_rna],
    protein_trainsets=[train_protein],
    gene_test=test_rna,
)

Searching for GPU
GPU detected, using GPU

QC Filtering Training Cells




QC Filtering Testing Cells

QC Filtering Training Genes
QC Filtering Testing Genes

Normalizing Training Cells
Normalizing Testing Cells

Log-Normalizing Training Data
Log-Normalizing Testing Data

Finding HVGs


  [AnnData(sparse.csr_matrix(a.shape), obs=a.obs) for a in all_adatas],



Normalizing Gene Training Data by Batch


100%|██████████| 1/1 [00:00<00:00,  8.41it/s]



Normalizing Protein Training Data by Batch


100%|██████████| 1/1 [00:00<00:00, 37.80it/s]



Normalizing Gene Testing Data by Batch


100%|██████████| 1/1 [00:00<00:00, 10.60it/s]


## Train model
### The trained model weights will be saved in the directory given by *weights_dir*

In [7]:

# Keep the weights next to the other outputs of this experiment.
weights_dir = os.path.join(save_dir, "sciPENN_weights")
model.train(weights_dir=weights_dir)

Epoch 0 prediction loss = 1.403
Epoch 1 prediction loss = 0.931
Epoch 2 prediction loss = 0.911
Epoch 3 prediction loss = 0.905
Epoch 4 prediction loss = 0.890
Epoch 5 prediction loss = 0.889
Epoch 6 prediction loss = 0.884
Epoch 7 prediction loss = 0.882
Epoch 8 prediction loss = 0.878
Epoch 9 prediction loss = 0.878
Epoch 10 prediction loss = 0.880
Epoch 11 prediction loss = 0.879
Epoch 12 prediction loss = 0.881
Epoch 13 prediction loss = 0.875
Decaying loss to 0.0001
Epoch 14 prediction loss = 0.860
Epoch 15 prediction loss = 0.862
Epoch 16 prediction loss = 0.863
Epoch 17 prediction loss = 0.860
Epoch 18 prediction loss = 0.861
Epoch 19 prediction loss = 0.860
Decaying loss to 1e-05
Epoch 20 prediction loss = 0.860
Epoch 21 prediction loss = 0.859
Epoch 22 prediction loss = 0.859
Epoch 23 prediction loss = 0.861
Epoch 24 prediction loss = 0.860
Epoch 25 prediction loss = 0.861
Decaying loss to 1.0000000000000002e-06
Epoch 26 prediction loss = 0.859


## Impute protein expression for the test dataset

In [8]:
# Run the trained model; presumably this imputes proteins for the `gene_test`
# data given when the model was created (predict() takes no arguments here).
predicted_test = model.predict()
# Keep only the raw matrix in `.X`; row and column labels are attached later,
# when the prediction is wrapped in a DataFrame and saved.
predicted_protein = predicted_test.X

  imputed_test = AnnData(zeros(shape = (len(cells), len(proteins.var))))


## Save prediction

In [9]:
# Label the prediction matrix: rows take the test cells' obs index, columns
# take the protein names of the training set, then write it tab-separated.
prediction_path = os.path.join(save_dir, "test_protein_prediction.txt")
predicted_protein = pd.DataFrame(
    predicted_protein,
    index=test_rna.obs.index,
    columns=train_protein.var["protein_name"],
)
predicted_protein.to_csv(prediction_path, sep="\t")
predicted_protein

protein_name,ADT_CD102_A0104,ADT_CD103_A0201,ADT_CD106_A0226,ADT_CD115(CSF-1R)_A0105,ADT_CD117(c-Kit)_A0012,ADT_CD11a_A0595,ADT_CD11c_A0106,ADT_CD122(IL-2Rb)_A0227,ADT_CD127(IL-7Ra)_A0198,ADT_CD134(OX-40)_A0195,...,ADT_TCRVr1.1-Cr4_A0209,ADT_TCRVr2_A0211,ADT_TCRVr3_A0210,ADT_TCRbchain_A0120,ADT_TCRr-d_A0121,ADT_TER-119-ErythroidCells_A0122,ADT_Tim-4_A0567,ADT_XCR1_A0568,ADT_anti-P2RY12_A0415,ADT_integrinb7_A0214
index,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
AAACCCAAGAATCTAG-2,-0.075074,-0.480129,0.314208,-0.115879,-0.250983,-0.529469,-0.178850,-0.290278,-0.299035,-0.212885,...,0.146033,0.048426,-0.005435,-0.769430,0.216387,-0.038839,0.434090,0.035958,-0.275547,0.352861
AAACCCACACCGGAAA-2,2.070195,0.266736,0.577778,0.676200,0.448604,-1.921980,0.222479,0.173083,0.093942,0.989365,...,0.283938,0.974859,1.071386,-0.541235,1.082818,0.066703,0.969275,0.627288,0.849877,-0.683034
AAACCCACACTACTTT-2,0.115370,-0.626729,-0.219099,-0.348924,-0.692816,-0.538368,-0.355847,-0.490868,-0.591396,-0.681801,...,-0.193898,-0.422313,-0.511603,-0.866305,-0.478096,0.009231,-0.336346,-0.341177,-0.385641,-0.186743
AAACCCAGTAGGCAAC-2,0.130569,-0.453139,0.033271,-0.069973,-0.526858,-0.628178,-0.214568,-0.006709,-0.632176,-0.300095,...,-0.085706,-0.063461,-0.028688,-0.822515,-0.128905,-0.254644,-0.107969,-0.092040,0.012610,0.138125
AAACCCAGTCTCAGGC-2,0.007473,-0.193142,0.059421,0.181116,-0.135164,-0.593218,-0.057032,0.146777,-0.613469,0.141324,...,0.177557,0.229020,0.481565,-0.742698,0.135457,-0.177336,-0.013411,0.217457,0.455122,0.127714
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
TTTGTTGGTAAGATCA-2,0.150713,0.044235,-0.300959,-0.170286,-0.499647,-0.676416,-0.081885,-0.165093,-0.581892,-0.181092,...,-0.352278,-0.047560,-0.140448,-0.741950,-0.260319,0.320621,-0.243821,0.079172,0.137088,0.377301
TTTGTTGGTCAGACTT-2,-0.658673,-0.294298,-0.189475,-0.143172,0.529070,0.544871,-0.196329,-0.126351,0.927257,0.199352,...,-0.028141,0.145750,0.132429,1.455479,0.131354,-0.264727,-0.115566,0.005645,0.092516,-0.355014
TTTGTTGGTGTTTACG-2,1.526162,-0.064761,0.048496,0.044260,-0.099444,-1.181569,0.008174,-0.128753,-0.251232,0.148012,...,-0.220676,0.112171,0.257751,-0.686211,0.091789,0.014448,0.166384,0.145029,0.275247,-0.891063
TTTGTTGTCAGAATAG-2,0.805792,0.036285,0.201381,0.152380,-0.268955,-0.586130,-0.229748,0.134019,-0.539321,0.178207,...,0.011201,0.300608,0.417309,-0.690056,0.459855,-0.132069,0.189741,0.187593,0.592481,0.393760
