In [37]:
from scEMAIL_model import *
import pandas as pd

### Load the preprocessed target data and its raw count matrix

In [38]:
dataset = 'Pancreas'
print("dataname:", dataset)

# Preprocessed target expression matrix, cast to float32 for use as the model input.
adata = sc.read("/data/wanh/scEMAIL/real_data/{}_adata.h5ad".format(dataset))
X = adata.X.astype(np.float32)

# Raw count matrix, passed as X_raw to model.fit below.
# np.loadtxt parses the comma-separated file in one vectorized call instead of a
# per-element Python loop; ndmin=2 keeps a single-row file 2-D. np.rint rounds
# half-to-even, exactly like Python's built-in round(), and the int64 cast
# matches the dtype numpy infers from a list of Python ints.
count_X = np.rint(
    np.loadtxt('/data/wanh/scEMAIL/real_data/{}_count_X.csv'.format(dataset),
               delimiter=',', ndmin=2)
).astype(np.int64)

dataname: Pancreas


### Load the pre-trained source model and the source cell-type annotations

In [39]:
# Pre-trained source-model weights for this dataset.
checkpoint = torch.load("/data/wanh/scEMAIL/source_model/real_data/AE_weights_{}.pth.tar".format(dataset))
# Number of source cell types, read from the size of the checkpoint's
# "classifier.0.bias" tensor; it sets the model's n_clusters.
source_class_num = checkpoint['ae_state_dict']["classifier.0.bias"].size()[0]
model = target_model(input_dim=adata.n_vars, z_dim=32, n_clusters=source_class_num,
                     encodeLayer=[256, 64], decodeLayer=[64, 256], sigma=2.5).cuda()

# Overlay the checkpoint weights on the freshly initialised state dict so that
# any model parameters absent from the checkpoint keep their initial values.
model_dict = model.state_dict()
model_dict.update(checkpoint['ae_state_dict'])
model.load_state_dict(model_dict)

# One source cell-type name per row (first CSV column); blank rows are skipped.
# NOTE(review): presumably listed in classifier output order — confirm.
with open('/data/wanh/scEMAIL/real_data/{}_annotation.csv'.format(dataset), newline='') as csvfile:
    source_annotation = [row[0] for row in csv.reader(csvfile, delimiter=',', quotechar='|') if row]

# Each classifier output needs exactly one name, otherwise predictions are mislabeled.
assert len(source_annotation) == source_class_num, (
    "annotation file lists {} cell types but the classifier has {} outputs".format(
        len(source_annotation), source_class_num))
print("source cell types:", source_class_num, source_annotation)

source cell types: 9 ['acinar', 'alpha', 'beta', 'delta', 'ductal', 'endothelial', 'epsilon', 'gamma', 'mesenchymal']


In [40]:
# Adapt the source model to the target data and time the whole fit.
# fit returns the novel-cell-type flag (reported as "novel cell types exist" below)
# and the predicted cell-type label for each target cell.
t0 = time()
bimodality, pred_celltype = model.fit(
    x=X,                                  # preprocessed target expression (float32)
    annotation=source_annotation,         # source cell-type names
    X_raw=count_X,                        # rounded raw counts
    size_factor=adata.obs.size_factors,
    pretrain_epoch=10,
    midtrain_epoch=20,
    K=5,
    KK=5,
    alpha=0.1,
)
time_cost = time() - t0

number of samples: 2282
number of class: 9
bimodality of dip test: 0.7144285571442855 False
bimodality coefficient:(>0.555 indicates bimodality) 0.4676114187132847 False
ood sample exists: False
Pretrain epoch [1/1], ZINB loss:5.3878
Pretrain epoch [2/1], ZINB loss:4.6107
Pretrain epoch [3/1], ZINB loss:4.0291
Pretrain epoch [4/1], ZINB loss:3.8227
Pretrain epoch [5/1], ZINB loss:3.3147
Pretrain epoch [6/1], ZINB loss:3.1258
Pretrain epoch [7/1], ZINB loss:3.0232
Pretrain epoch [8/1], ZINB loss:2.9984
Pretrain epoch [9/1], ZINB loss:2.6662
Pretrain epoch [1/2], ZINB loss:2.4081
Pretrain epoch [2/2], ZINB loss:2.7254
Pretrain epoch [3/2], ZINB loss:2.4217
Pretrain epoch [4/2], ZINB loss:2.4856
Pretrain epoch [5/2], ZINB loss:2.1758
Pretrain epoch [6/2], ZINB loss:2.1847
Pretrain epoch [7/2], ZINB loss:2.1552
Pretrain epoch [8/2], ZINB loss:1.9722
Pretrain epoch [9/2], ZINB loss:2.0193
Pretrain epoch [1/3], ZINB loss:1.9396
Pretrain epoch [2/3], ZINB loss:1.8370
Pretrain epoch [3/3], ZIN

Midtrain epoch [3/16], ZINB loss:0.9099,  neighbor loss 1:-2.2737, expanded neighbor loss 1:-1.9736, self loss:-0.7788
Midtrain epoch [4/16], ZINB loss:0.9095,  neighbor loss 1:-2.1520, expanded neighbor loss 1:-1.9541, self loss:-0.7724
Midtrain epoch [5/16], ZINB loss:0.8957,  neighbor loss 1:-2.0071, expanded neighbor loss 1:-1.9756, self loss:-0.7785
Midtrain epoch [6/16], ZINB loss:0.8781,  neighbor loss 1:-2.0473, expanded neighbor loss 1:-1.9866, self loss:-0.7824
Midtrain epoch [7/16], ZINB loss:0.8905,  neighbor loss 1:-2.0994, expanded neighbor loss 1:-1.9750, self loss:-0.7778
Midtrain epoch [8/16], ZINB loss:0.8719,  neighbor loss 1:-1.8907, expanded neighbor loss 1:-1.9811, self loss:-0.7813
Midtrain epoch [9/16], ZINB loss:0.8730,  neighbor loss 1:-1.7444, expanded neighbor loss 1:-1.9746, self loss:-0.7746
current error: tensor(0.0022, device='cuda:0')
Midtrain epoch [1/17], ZINB loss:0.8734,  neighbor loss 1:-1.9539, expanded neighbor loss 1:-1.9792, self loss:-0.7832
M

### Calculate annotation accuracy on the target cells

In [41]:
print("novel cell types exist:", bimodality)
cellname = np.array(adata.obs["celltype"])
print("target cell types:", np.unique(cellname))

# Compare predictions to ground truth element-wise; refuse mismatched lengths
# rather than letting numpy produce a meaningless comparison.
assert len(pred_celltype) == len(cellname), "prediction/label length mismatch"
accuracy = np.mean(np.asarray(pred_celltype) == cellname)

# Build the summary from a plain list of rows so pandas infers per-column dtypes.
# Wrapping the row in np.array first promoted every field to a string, so the
# accuracy and timing columns were text instead of floats.
result = [[dataset, accuracy, time_cost]]
output = pd.DataFrame(result, columns=["Dataset", "Total accuracy", "Time consuming"])
print(output)

novel cell types exist: False
target cell types: ['acinar' 'alpha' 'beta' 'delta' 'ductal' 'mesenchymal']
    Dataset      Total accuracy      Time consuming
0  Pancreas  0.9666958808063103  13.497034788131714
