In [127]:
from scEMAIL_model import *
import pandas as pd

### Load the preprocessed target data

In [128]:
# Target dataset to annotate; every file path below is derived from this name.
dataset = 'Neonatal_rib'
# Directory holding the preprocessed target data (AnnData file + raw count CSV).
# NOTE(review): machine-specific absolute path — adjust for your environment.
REAL_DATA_DIR = "/data/wanh/scEMAIL/real_data"
print("dataname:", dataset)

# Preprocessed expression matrix, cast to float32 as the model input.
adata = sc.read("{}/{}_adata.h5ad".format(REAL_DATA_DIR, dataset))
X = adata.X.astype(np.float32)

# Raw counts from CSV. Every value is rounded to an integer with Python's round()
# (ties go to the even integer).
with open("{}/{}_count_X.csv".format(REAL_DATA_DIR, dataset), newline='') as csvfile:
    reader = csv.reader(csvfile, delimiter=',', quotechar='|')
    count_X = np.array([[round(float(value)) for value in row] for row in reader])

# X and count_X are both passed to model.fit (x= and X_raw=), so they must align cell-by-gene.
assert count_X.shape == X.shape, (
    "count matrix shape {} does not match adata shape {}".format(count_X.shape, X.shape))


dataname: Neonatal_rib


### Load the pre-trained source model

In [134]:
# Load the source-domain checkpoint and rebuild the target model around its weights.
# NOTE(review): torch.load unpickles the file — only load checkpoints from a trusted source.
# map_location="cpu" lets a checkpoint saved on any GPU be read on this machine;
# load_state_dict below copies the tensors onto the model's CUDA parameters.
checkpoint = torch.load("/data/wanh/scEMAIL/source_model/real_data/AE_weights_{}.pth.tar".format(dataset),
                        map_location="cpu")
source_state = checkpoint['ae_state_dict']
# The classifier's output bias has one entry per source cell type.
source_class_num = source_state["classifier.0.bias"].size()[0]
# These architecture arguments must match the ones used to pretrain the source model;
# otherwise load_state_dict fails on mismatched tensor shapes.
model = target_model(input_dim=adata.n_vars, z_dim=32, n_clusters=source_class_num,
                     encodeLayer=[256, 64], decodeLayer=[64, 256], sigma=2.5).cuda()
# Overlay the pretrained tensors on the model's own state dict. load_state_dict is strict by
# default, so a checkpoint key the target model does not define raises an error.
model_dict = model.state_dict()
model_dict.update(source_state)
model.load_state_dict(model_dict)

# Source cell-type names, taken from the first column of each non-blank CSV row.
with open('/data/wanh/scEMAIL/real_data/{}_annotation.csv'.format(dataset), newline='') as csvfile:
    reader = csv.reader(csvfile, delimiter=',', quotechar='|')
    source_annotation = [row[0] for row in reader if row]
# Presumably one label per classifier output, in classifier index order — verify against scEMAIL_model.
assert len(source_annotation) == source_class_num, (
    "annotation lists {} cell types but the classifier has {} outputs".format(
        len(source_annotation), source_class_num))
print("source cell types:", source_class_num, source_annotation)

source cell types: 13 ['B cell', 'Dividing cell', 'Endothelial cell', 'Erythroblast', 'Granulocyte', 'Macrophage', 'Muscle cell', 'Neuron', 'Neutrophil', 'Oligodendrocyte', 'Osteoblast', 'Osteoclast', 'Stromal cell']


### Adapt the source model to the target data

In [132]:
# Adapt the pretrained source model to the unlabeled target data and record the wall time.
# fit() returns a flag that a later cell reports as "novel cell types exist", plus per-cell
# predicted labels (a later cell treats the label "Unknown" as a novel-type prediction).
adaptation_params = dict(pretrain_epoch=10, midtrain_epoch=20, K=5, KK=5, alpha=0.1)
start_time = time()
bimodality, pred_celltype = model.fit(
    x=X,
    annotation=source_annotation,
    X_raw=count_X,
    size_factor=adata.obs.size_factors,
    **adaptation_params,
)
time_cost = time() - start_time

number of samples: 1217
number of class: 13
bimodality of dip test: 0.0 True
bimodality coefficient:(>0.555 indicates bimodality) 0.863532056011354 True
ood sample exists: True
Pretrain epoch [1/1], ZINB loss:0.1741
Pretrain epoch [2/1], ZINB loss:0.1808
Pretrain epoch [3/1], ZINB loss:0.1768
Pretrain epoch [4/1], ZINB loss:0.1780
Pretrain epoch [5/1], ZINB loss:0.1794
Pretrain epoch [1/2], ZINB loss:0.1746
Pretrain epoch [2/2], ZINB loss:0.1703
Pretrain epoch [3/2], ZINB loss:0.1728
Pretrain epoch [4/2], ZINB loss:0.1770
Pretrain epoch [5/2], ZINB loss:0.1870
Pretrain epoch [1/3], ZINB loss:0.1667
Pretrain epoch [2/3], ZINB loss:0.1731
Pretrain epoch [3/3], ZINB loss:0.1727
Pretrain epoch [4/3], ZINB loss:0.1869
Pretrain epoch [5/3], ZINB loss:0.1754
Pretrain epoch [1/4], ZINB loss:0.1713
Pretrain epoch [2/4], ZINB loss:0.1739
Pretrain epoch [3/4], ZINB loss:0.1723
Pretrain epoch [4/4], ZINB loss:0.1727
Pretrain epoch [5/4], ZINB loss:0.1844
Pretrain epoch [1/5], ZINB loss:0.1708
Pret

### Calculate annotation accuracy

In [136]:
# Score the open-set annotation against the ground-truth labels in adata.obs["celltype"].
# A target cell is "known" when its true type is one of the source cell types and
# "unknown" (novel) otherwise. Reported metrics:
#   Accuracy of known   - known cells whose predicted label equals the true label
#   Accuracy of unknown - novel cells predicted as "Unknown"
#   Total accuracy      - both kinds of correct prediction over all target cells
#   H-score             - harmonic mean of the two per-group accuracies
# A ratio with an empty denominator (e.g. a target dataset with no novel cell types)
# is reported as NaN instead of raising ZeroDivisionError.
cellname = np.array(adata.obs["celltype"])
assert len(pred_celltype) == len(cellname), "expected exactly one prediction per target cell"

print("novel cell types exist:", bimodality)
print("target cell types:", np.unique(cellname))
print("novel cell type:", [j for j in np.unique(cellname) if j not in source_annotation])

known_types = set(source_annotation)
is_known = [c in known_types for c in cellname]
true_known = sum(is_known)
true_unknown = len(cellname) - true_known
right_pred_known = sum(
    bool(k and p == c) for c, p, k in zip(cellname, pred_celltype, is_known))
right_pred_unknown = sum(
    bool((not k) and p == "Unknown") for p, k in zip(pred_celltype, is_known))


def _ratio(numerator, denominator):
    """Return numerator / denominator, or NaN when the denominator is zero."""
    return numerator / denominator if denominator else float("nan")


accuracy_known = _ratio(right_pred_known, true_known)
accuracy_unknown = _ratio(right_pred_unknown, true_unknown)
total_accuracy = _ratio(right_pred_known + right_pred_unknown, len(cellname))
# Harmonic mean: 0 when both accuracies are 0, NaN when either one is undefined (NaN is truthy).
accuracy_sum = accuracy_known + accuracy_unknown
H_score = 2 * accuracy_known * accuracy_unknown / accuracy_sum if accuracy_sum else 0.0

# Build the table from a record so the metric columns stay numeric. A mixed str/float
# np.array would coerce every value to a string.
output = pd.DataFrame([{
    "Dataset": dataset,
    "Accuracy of known": accuracy_known,
    "Accuracy of unknown": accuracy_unknown,
    "Total accuracy": total_accuracy,
    "H-score": H_score,
    "Time consuming": time_cost,
}])
output

novel cell types exist: True
target cell types: ['B cell' 'Cartilage cell' 'Dividing cell' 'Endothelial cell'
 'Erythroblast' 'Granulocyte' 'Macrophage' 'Muscle cell' 'Neuron'
 'Neutrophil' 'Osteoblast' 'Osteoclast' 'Stromal cell']
novel cell type: ['Cartilage cell']
        Dataset   Accuracy of known Accuracy of unknown      Total accuracy  \
0  Neonatal_rib  0.8719611021069692               0.975  0.9227608874281019   

              H-score      Time consuming  
0  0.9206063664085294  11.259994506835938  
