In [1]:
from scEMAIL_model import *
import pandas as pd

### load the preprocessed target data

In [2]:
# Name of the target dataset; it selects every input file read below.
dataset = 'Pancreas'
print("dataname:", dataset)

# Expression matrix of the target data, cast to float32 and kept as `X`.
adata = sc.read("/data/wanh/scEMAIL/real_data/{}_adata.h5ad".format(dataset))
X = adata.X.astype(np.float32)

# Count matrix stored as CSV: every field is parsed as a float and rounded
# to the nearest integer before the rows are stacked into one array.
count_path = '/data/wanh/scEMAIL/real_data/{}_count_X.csv'.format(dataset)
with open(count_path, newline='') as count_file:
    count_rows = csv.reader(count_file, delimiter=',', quotechar='|')
    count_X = np.array([[round(float(value)) for value in fields] for fields in count_rows])

dataname: Pancreas


### load the pre-trained source model

In [3]:
# Restore the source checkpoint; its 'ae_state_dict' entry holds the weights.
checkpoint = torch.load("/data/wanh/scEMAIL/source_model/real_data/AE_weights_{}.pth.tar".format(dataset))
source_weights = checkpoint['ae_state_dict']

# The length of the classifier bias gives the number of source classes.
source_class_num = source_weights["classifier.0.bias"].shape[0]

model = target_model(
    input_dim=adata.n_vars,
    z_dim=32,
    n_clusters=source_class_num,
    encodeLayer=[256, 64],
    decodeLayer=[64, 256],
    sigma=2.5,
).cuda()

# Overlay the source weights on the freshly initialised state dict, so entries
# absent from the checkpoint keep their initial values, then load the result.
model_dict = model.state_dict()
model_dict.update(source_weights)
model.load_state_dict(model_dict)

# The first field of every row in the annotation CSV is a source label.
annotation_path = '/data/wanh/scEMAIL/real_data/{}_annotation.csv'.format(dataset)
with open(annotation_path, newline='') as annotation_file:
    annotation_rows = csv.reader(annotation_file, delimiter=',', quotechar='|')
    source_annotation = [fields[0] for fields in annotation_rows]
print("source cell types:",source_class_num, source_annotation)

source cell types: 9 ['acinar', 'alpha', 'beta', 'delta', 'ductal', 'endothelial', 'epsilon', 'gamma', 'mesenchymal']


In [4]:
# Time the whole fit call; `time_cost` is reported next to the accuracy.
fit_start = time()
bimodality, pred_celltype = model.fit(
    x=X,
    annotation=source_annotation,
    X_raw=count_X,
    size_factor=adata.obs.size_factors,
    pretrain_epoch=10,
    midtrain_epoch=20,
    K=5,
    KK=5,
    alpha=0.1,
)
time_cost = time() - fit_start

number of samples: 2282
number of class: 9
bimodality of dip test: 0.3226677332266773 False
bimodality coefficient:(>0.555 indicates bimodality) 0.49933163808816555 False
novel cell types exist: False
Pretrain epoch [1/1]
Pretrain epoch [2/1]
Pretrain epoch [3/1]
Pretrain epoch [4/1]
Pretrain epoch [5/1]
Pretrain epoch [6/1]
Pretrain epoch [7/1]
Pretrain epoch [8/1]
Pretrain epoch [9/1]
Pretrain epoch [1/2]
Pretrain epoch [2/2]
Pretrain epoch [3/2]
Pretrain epoch [4/2]
Pretrain epoch [5/2]
Pretrain epoch [6/2]
Pretrain epoch [7/2]
Pretrain epoch [8/2]
Pretrain epoch [9/2]
Pretrain epoch [1/3]
Pretrain epoch [2/3]
Pretrain epoch [3/3]
Pretrain epoch [4/3]
Pretrain epoch [5/3]
Pretrain epoch [6/3]
Pretrain epoch [7/3]
Pretrain epoch [8/3]
Pretrain epoch [9/3]
Pretrain epoch [1/4]
Pretrain epoch [2/4]
Pretrain epoch [3/4]
Pretrain epoch [4/4]
Pretrain epoch [5/4]
Pretrain epoch [6/4]
Pretrain epoch [7/4]
Pretrain epoch [8/4]
Pretrain epoch [9/4]
Pretrain epoch [1/5]
Pretrain epoch [2/5]
P

### calculate annotation accuracy

In [5]:
print("novel cell types exist:",bimodality)

# Reference labels of the target cells, read from adata.obs["celltype"].
cellname = np.array(adata.obs["celltype"])
print("target cell types:",np.unique(cellname))

# Fraction of cells whose predicted label equals the reference label.
accuracy = np.mean(pred_celltype == cellname)

# Build the summary column by column so the metrics keep their float dtype.
# Packing the name and the two floats into a single np.array coerces the whole
# row to strings, which makes the numeric columns unusable for aggregation.
output = pd.DataFrame({
    "Dataset": [dataset],
    "Total accuracy": [accuracy],
    "Time consuming": [time_cost],
})
print(output)

novel cell types exist: False
target cell types: ['acinar' 'alpha' 'beta' 'delta' 'ductal' 'mesenchymal']
    Dataset      Total accuracy      Time consuming
0  Pancreas  0.9631901840490797  13.350159645080566
