In [13]:
%%bash
# Train scClassifier on the PBMC dataset and save the fitted model to pbmcsca.pth,
# mirroring all console output into pbmcsca.log.
# Presumably meant to run inside the `pyro2` conda environment; the activate and
# deactivate lines are left commented, so activate that env before starting the kernel.
#conda activate pyro2

# Without pipefail the pipeline's exit status is tee's (almost always 0), which
# would hide a crash in scClassifier.py and make this cell look successful.
set -o pipefail

CUDA_VISIBLE_DEVICES=0 python scClassifier.py --sup-data-file /home/zfeng/zfeng/scClassifier_Ex/pbmcsca.mtx \
                        --sup-label-file /home/zfeng/zfeng/scClassifier_Ex/pbmcsca_factors.txt \
                        -lr 0.0001 \
                        -n 300 \
                        -bs 100 \
                        --aux-loss \
                        -alm 100 \
                        -64 \
                        --jit \
                        --cuda \
                        -zi \
                        -likeli negbinomial \
                        -dirichlet \
                        --label-type onehot \
                        --validation-fold 0 \
                        --save-model pbmcsca.pth 2>&1 | tee pbmcsca.log

#conda deactivate

1 epoch: avg losses 75268.9830 68049.2556 0.0000 0.0000 elapsed 17.3707 seconds
2 epoch: avg losses 59315.9411 52032.3614 0.0000 0.0000 elapsed 6.5214 seconds
3 epoch: avg losses 57603.2121 42406.6290 0.0000 0.0000 elapsed 6.2148 seconds
4 epoch: avg losses 55869.6487 33188.3553 0.0000 0.0000 elapsed 6.3397 seconds
5 epoch: avg losses 54771.6093 25805.0103 0.0000 0.0000 elapsed 6.6000 seconds
6 epoch: avg losses 53966.4499 21615.3272 0.0000 0.0000 elapsed 6.5737 seconds
7 epoch: avg losses 53433.4453 17246.6037 0.0000 0.0000 elapsed 6.5423 seconds
8 epoch: avg losses 52960.6305 14089.0966 0.0000 0.0000 elapsed 6.5758 seconds
9 epoch: avg losses 52580.4841 12004.3507 0.0000 0.0000 elapsed 6.5491 seconds
10 epoch: avg losses 52233.1017 10418.8748 0.0000 0.0000 elapsed 6.4991 seconds
11 epoch: avg losses 51889.0660 9002.3251 0.0000 0.0000 elapsed 6.5676 seconds
12 epoch: avg losses 51623.8438 7724.5595 0.0000 0.0000 elapsed 6.5274 seconds
13 epoch: avg losses 51273.1707 6567.8764 0.0000 0

In [14]:
import inspect

import datatable as dt
import numpy as np
import pandas as pd
import torch
from scipy.io import mmread
from sklearn.metrics import accuracy_score, f1_score
from torch.utils.data import DataLoader

from scClassifier import scClassifier
from utils.scdata_cached import setup_data_loader, SingleCellCached



In [15]:
# File locations for the evaluation cells: the trained model (written by the
# training cell's --save-model), the expression matrix, and the label file.
ModelPath, DataPath, LabelPath = (
    'pbmcsca.pth',
    '/home/zfeng/zfeng/scClassifier_Ex/pbmcsca.mtx',
    '/home/zfeng/zfeng/scClassifier_Ex/pbmcsca_celltype.txt',
)


In [16]:
# Load the trained model object saved by the training cell. The checkpoint is a
# whole pickled model (later cells call methods on it), not a state_dict.
# NOTE(review): torch.load unpickles arbitrary objects -- only load trusted checkpoints.
use_float64 = True
use_cuda = True

# Place the restored tensors on the device this session will use, so the
# checkpoint's stored device cannot silently disagree with use_cuda.
map_location = 'cuda' if use_cuda else 'cpu'

# PyTorch >= 2.6 defaults to weights_only=True, which refuses to unpickle a full
# model object. Request full unpickling where the keyword exists; older releases
# lack it and already unpickle fully.
load_kwargs = (
    {'weights_only': False}
    if 'weights_only' in inspect.signature(torch.load).parameters
    else {}
)
model = torch.load(ModelPath, map_location=map_location, **load_kwargs)

In [17]:
# Wrap the expression matrix and its labels in a dataset and iterate it in
# fixed-size batches. shuffle=False keeps rows in file order, which matters
# because a later cell attaches cell barcodes to the outputs positionally.
# NOTE(review): the label file is the cell-type file, but the third argument is
# 'condition' -- presumably a label-mode selector; confirm against SingleCellCached.
batch_size = 100

data_cached = SingleCellCached(
    DataPath,
    LabelPath,
    'condition',
    use_cuda=use_cuda,
    use_float64=use_float64,
)
data_loader = DataLoader(data_cached, shuffle=False, batch_size=batch_size)

In [18]:
# Compute, for every cell, its latent embedding and its reconstructed expression
# with the listed labels muted (these names are single-cell sequencing protocols,
# presumably the factors the model was trained with -- confirm).
MUTE_LABEL_NAMES = [
    "10x Chromium (v2)",
    "10x Chromium (v2) A",
    "10x Chromium (v2) B",
    "10x Chromium (v3)",
    "CEL-Seq2",
    "Drop-seq",
    "Seq-Well",
    "Smart-seq2",
    "inDrops",
]

embeds = []
exprs = []
# Pure inference: disable autograd so no computation graph is kept per batch.
with torch.no_grad():
    for xs, _ys in data_loader:  # labels are not needed here
        if use_cuda:
            xs = xs.cuda()

        zs = model.latent_embedding(xs)
        expr = model.mute_expression(xs, mute_label_names=MUTE_LABEL_NAMES, mute_noise=False)

        # .cpu() is a no-op for CPU tensors, so one path serves both devices.
        embeds.append(zs.detach().cpu().numpy())
        exprs.append(expr.detach().cpu().numpy())

# Stack per-batch results; row order matches the (unshuffled) data loader.
embeds = np.concatenate(embeds, axis=0)
exprs = np.concatenate(exprs, axis=0)



In [19]:
# Read the cell barcodes and gene names used to label the rows and columns of the
# muted-expression matrix. Both files are read header-less; column 0 holds the names.
cells, genes = (
    pd.read_csv(name_file, header=None, index_col=None)
    for name_file in (
        '/home/zfeng/zfeng/scClassifier_Ex/pbmcsca_cell.txt',
        '/home/zfeng/zfeng/scClassifier_Ex/pbmcsca_gene.txt',
    )
)
cells.shape

(14890, 1)

In [20]:
# Label the muted expression with cell barcodes (rows) and gene names (columns),
# write it out with datatable (reset_index moves the row labels into the first
# column so they are saved), then preview the first rows.
df = pd.DataFrame(
    data=exprs,
    index=cells.iloc[:, 0].to_numpy(),
    columns=genes.iloc[:, 0].to_numpy(),
)
dt.Frame(df.reset_index()).to_csv('/home/zfeng/zfeng/scClassifier_Ex/pbmcsca_mutate_express.txt')
df.head()

Unnamed: 0,CLPS,CPB1,MT1G,REG3A,COL1A2,MMP1,CRP,CELA3B,CELA3A,CPA1,...,AC008269.1,AC079061.1,PQLC2L,DOCK2,OPRK1,ZNF878,OPRPN,NEMP2,CCDC81,TEPSIN
D101_5,0.000487,0.006616,0.000219,0.002253,0.000213,3.9e-05,0.00012,0.000166,0.000319,0.000496,...,4.478372e-07,3.688737e-08,0.000341,1.208603e-06,1.001594e-05,3.687854e-06,1.156372e-06,0.000126,8.7e-05,9.953282e-06
D101_7,0.003216,0.005496,0.002627,0.006814,0.000389,0.000499,0.000631,0.003664,0.005937,0.006604,...,3.203163e-08,6.645654e-08,4.1e-05,2.934307e-06,1.842269e-06,1.458249e-05,2.833236e-06,6.9e-05,3.5e-05,3.761142e-07
D101_10,0.000433,0.000976,0.000129,0.002376,0.000322,0.000159,4e-05,0.000264,0.000665,0.000974,...,6.038206e-08,1.21439e-08,3e-06,6.543814e-07,1.940952e-06,4.005412e-06,2.570803e-06,8.7e-05,0.000172,1.68744e-06
D101_13,0.000476,0.003881,0.000139,0.001694,0.000434,0.000175,3.3e-05,0.000229,0.000652,0.000935,...,7.253705e-08,3.011997e-08,1.3e-05,7.043941e-06,3.390564e-06,6.14625e-07,1.12085e-08,7.9e-05,8.7e-05,5.395948e-09
D101_14,0.000509,0.000922,0.000158,0.001945,0.000176,5.2e-05,9e-06,0.000141,0.000476,0.000857,...,1.109188e-06,1.122901e-06,1.2e-05,2.714067e-06,1.056993e-07,1.348281e-06,6.280447e-09,0.000102,0.000137,1.054907e-07


In [21]:
# Display the frame's dimensions: (number of cells, number of genes).
(len(df.index), len(df.columns))

(14890, 2000)