In [1]:
from scEMAIL_model import *
import pandas as pd

### load the preprocessed target data

In [3]:
dataset = 'Pancreas'
print("dataname:", dataset)

# Preprocessed target data (AnnData); the expression matrix is cast to float32.
adata = sc.read("/data/wanh/scEMAIL/real_data/{}_adata.h5ad".format(dataset))
X = adata.X.astype(np.float32)

# The count matrix lives in a CSV file: every field is parsed as a float and
# rounded to the nearest integer, one inner list per CSV row.
with open('/data/wanh/scEMAIL/real_data/{}_count_X.csv'.format(dataset), newline='') as count_file:
    count_reader = csv.reader(count_file, delimiter=',', quotechar='|')
    count_X = np.array([[round(float(field)) for field in record] for record in count_reader])

dataname: Pancreas


### load the pre-trained source model

In [9]:
# Load the checkpoint saved for this dataset's source model.
checkpoint = torch.load("/data/wanh/scEMAIL/source_model/real_data/AE_weights_{}.pth.tar".format(dataset))
source_weights = checkpoint['ae_state_dict']

# The length of the classifier bias is used as the number of source classes.
source_class_num = source_weights["classifier.0.bias"].size()[0]

model = target_model(
    input_dim=adata.n_vars,
    z_dim=32,
    n_clusters=source_class_num,
    encodeLayer=[256, 64],
    decodeLayer=[64, 256],
    sigma=2.5,
).cuda()

# Overlay every tensor from the checkpoint onto the freshly built model's
# state dict, then load the merged dict back into the model.
model_dict = model.state_dict()
model_dict.update(source_weights)
model.load_state_dict(model_dict)

# Source cell-type names: the first column of each row in the annotation CSV.
with open('/data/wanh/scEMAIL/real_data/{}_annotation.csv'.format(dataset), newline='') as annotation_file:
    annotation_reader = csv.reader(annotation_file, delimiter=',', quotechar='|')
    source_annotation = [record[0] for record in annotation_reader]
print("source cell types:",source_class_num, source_annotation)

source cell types: 9 ['acinar', 'alpha', 'beta', 'delta', 'ductal', 'endothelial', 'epsilon', 'gamma', 'mesenchymal']


In [10]:
# Fit the model on the target data and record the wall-clock time of the call.
# The two return values are used below as the novel-cell-type flag and the
# predicted cell-type labels.
fit_start = time()
bimodality, pred_celltype = model.fit(
    x=X,
    annotation=source_annotation,
    X_raw=count_X,
    size_factor=adata.obs.size_factors,
    pretrain_epoch=10,
    midtrain_epoch=20,
    K=5,
    KK=5,
    alpha=0.1,
)
time_cost = time() - fit_start

number of samples: 2282
number of class: 9
bimodality of dip test: 0.42115788421157885 False
bimodality coefficient:(>0.555 indicates bimodality) 0.5071104284345349 False
ood sample exists: False
Pretrain epoch [1/1], ZINB loss:5.5391
Pretrain epoch [2/1], ZINB loss:4.2187
Pretrain epoch [3/1], ZINB loss:3.3590
Pretrain epoch [4/1], ZINB loss:3.7244
Pretrain epoch [5/1], ZINB loss:3.0431
Pretrain epoch [6/1], ZINB loss:3.1818
Pretrain epoch [7/1], ZINB loss:2.8888
Pretrain epoch [8/1], ZINB loss:2.5971
Pretrain epoch [9/1], ZINB loss:2.7450
Pretrain epoch [1/2], ZINB loss:2.5437
Pretrain epoch [2/2], ZINB loss:2.4650
Pretrain epoch [3/2], ZINB loss:2.3449
Pretrain epoch [4/2], ZINB loss:2.3353
Pretrain epoch [5/2], ZINB loss:2.1993
Pretrain epoch [6/2], ZINB loss:2.2195
Pretrain epoch [7/2], ZINB loss:1.8883
Pretrain epoch [8/2], ZINB loss:1.9429
Pretrain epoch [9/2], ZINB loss:1.8747
Pretrain epoch [1/3], ZINB loss:1.7869
Pretrain epoch [2/3], ZINB loss:1.7559
Pretrain epoch [3/3], ZI

Midtrain epoch [8/15], ZINB loss:0.8891,  neighbor loss 1:-2.1462, expanded neighbor loss 1:-1.9792, self loss:-0.7795
Midtrain epoch [9/15], ZINB loss:0.9042,  neighbor loss 1:-2.0226, expanded neighbor loss 1:-1.9711, self loss:-0.7759
current error: tensor(0.0013, device='cuda:0')
Midtrain epoch [1/16], ZINB loss:0.9032,  neighbor loss 1:-2.0330, expanded neighbor loss 1:-1.9797, self loss:-0.7830
Midtrain epoch [2/16], ZINB loss:0.8630,  neighbor loss 1:-1.8658, expanded neighbor loss 1:-1.9835, self loss:-0.7815
Midtrain epoch [3/16], ZINB loss:0.9279,  neighbor loss 1:-2.0561, expanded neighbor loss 1:-1.9633, self loss:-0.7736
Midtrain epoch [4/16], ZINB loss:0.8702,  neighbor loss 1:-2.0816, expanded neighbor loss 1:-1.9748, self loss:-0.7772
Midtrain epoch [5/16], ZINB loss:0.9158,  neighbor loss 1:-2.3169, expanded neighbor loss 1:-1.9767, self loss:-0.7815
Midtrain epoch [6/16], ZINB loss:0.9507,  neighbor loss 1:-2.1729, expanded neighbor loss 1:-1.9637, self loss:-0.7759
M

### calculate annotation accuracy

In [12]:
# Report whether novel cell types were flagged, then score the predicted
# labels against the reference annotation stored in adata.obs["celltype"].
print("novel cell types exist:",bimodality)
cellname = np.array(adata.obs["celltype"])
print("target cell types:",np.unique(cellname))

# A length mismatch would make the element-wise comparison below meaningless,
# so fail loudly instead of reporting a bogus accuracy.
assert len(pred_celltype) == len(cellname), "pred_celltype and adata.obs['celltype'] differ in length"
accuracy = float(np.mean(pred_celltype == cellname))

# Build the summary from a plain list of records rather than a NumPy array:
# np.array([[str, float, float]]) coerces every entry to a string, which would
# leave the "Total accuracy" and "Time consuming" columns as text, not numbers.
result = [[dataset, accuracy, time_cost]]
output = pd.DataFrame(result, columns=["Dataset","Total accuracy","Time consuming"])
print(output)

novel cell types exist: False
target cell types: ['acinar' 'alpha' 'beta' 'delta' 'ductal' 'mesenchymal']
    Dataset      Total accuracy      Time consuming
0  Pancreas  0.9614373356704645  12.809916734695435
