In [None]:
%load_ext autoreload
%autoreload 2

import os
import sys
import datetime
sys.path.append("../")
from concerto_function5_3 import *
from sklearn.metrics import f1_score, accuracy_score
import numpy as np
import scanpy as sc
import scipy.sparse as sps
import matplotlib.pyplot as plt
from metrics import osr_evaluator

from os.path import join

In [None]:
# Select which GPU TensorFlow may use on a multi-GPU machine.
# On a CPU-only machine this cell can be skipped.
# CUDA_VISIBLE_DEVICES only takes effect if it is set before CUDA is initialized,
# so it is assigned before TensorFlow is imported/queried in this cell.
gpu_id = '1'
os.environ["CUDA_VISIBLE_DEVICES"] = gpu_id
import tensorflow as tf

gpus = tf.config.experimental.list_physical_devices(device_type='GPU')
for gpu in gpus:
    try:
        # Grow GPU memory on demand instead of reserving it all up front.
        tf.config.experimental.set_memory_growth(gpu, True)
    except RuntimeError as err:
        # TensorFlow raises RuntimeError when the device was already initialized
        # (e.g. this cell is re-run after a model was built); report instead of crashing.
        print(f'Could not enable memory growth on {gpu.name}: {err}')

In [None]:
exp_id = 'HumanFetal_100k'

# Root directory of the dataset; override with the HUMANFETAL_DATA_DIR environment
# variable rather than editing the hard-coded path.
data_dir = os.environ.get('HUMANFETAL_DATA_DIR', '/home/yanxh/data/HumanFetal_100k')

batch_key = 'domain'     # obs column recording which modality a cell came from
type_key = 'cell_type'   # obs column holding the cell-type label

adata_rna = sc.read_h5ad(join(data_dir, 'RNA', 'adata_rna_sampled.h5ad'))
adata_atac = sc.read_h5ad(join(data_dir, 'ATAC', 'adata_atac.h5ad'))

adata_rna.obs[batch_key] = 'RNA'
adata_atac.obs[batch_key] = 'ATAC'

# RNA labels live in 'Main_cluster_name'; copy them under the shared label column.
adata_rna.obs[type_key] = adata_rna.obs['Main_cluster_name'].values

# The evaluation cells read the query (ATAC) labels from this column, so fail early
# with a clear message if it is missing.
if type_key not in adata_atac.obs.columns:
    raise KeyError(f"ATAC AnnData has no '{type_key}' column in .obs")

In [None]:
# Stack the RNA (reference) and ATAC (query) cells into one AnnData object.
# join='inner' is spelled out here; it is the concat default.
adata_all = sc.concat(
    [adata_rna, adata_atac],
    join='inner',
)
adata_all

# Preprocess

In [None]:
st_time = datetime.datetime.now(tz=None)  # start of the preprocessing stage (local time)

In [None]:
# Preprocess the combined object with the project helper. Per the original note it
# filters cells and runs normalize_total, with optional HVG selection and no scaling;
# HVG selection is turned off here (is_hvg=False).
adata = preprocessing_rna(adata_all,
                          min_features=0,
                          n_top_features=None,
                          is_hvg=False,
                          batch_key=batch_key)

# Split back into the reference (RNA) and query (ATAC) domains.
adata_ref = adata[adata.obs[batch_key] == 'RNA']
adata_query = adata[adata.obs[batch_key] == 'ATAC']

# True for query cells whose label also occurs in the reference ("known" classes);
# the remaining query cells are treated as unknown in the open-set evaluation.
# np.isin replaces np.in1d, which is deprecated as of NumPy 2.0.
shr_mask = np.isin(adata_query.obs[type_key], adata_ref.obs[type_key].unique())
atac_lab = np.array(adata_query.obs[type_key].values)

save_path = './'

# Set to True to also write the split AnnData objects under save_path.
save_splits = False
if save_splits:
    os.makedirs(save_path, exist_ok=True)
    adata_ref.write_h5ad(save_path + 'adata_ref.h5ad')
    adata_query.write_h5ad(save_path + 'adata_query.h5ad')

In [None]:
ed_time = datetime.datetime.now()
elapsed = ed_time - st_time  # timedelta since the preprocessing timer was started
pp_cost = elapsed.total_seconds()
print('pp cost ', pp_cost)

In [None]:
st_time = datetime.datetime.now(tz=None)  # start of the TFRecord-writing stage (local time)

In [None]:
# Serialize reference (RNA) and query (ATAC) cells to TFRecords, tagging each record
# with its domain column and cell-type column. The helper's return values are kept
# in ref_tf_path / query_tf_path.
tf_root = save_path + f'tfrecord/{exp_id}/'
tf_label_kwargs = {'batch_col_name': batch_key, 'label_col_name': type_key}

ref_tf_path = concerto_make_tfrecord_supervised(
    adata_ref, tf_path=tf_root + 'ref_tf/', **tf_label_kwargs)
query_tf_path = concerto_make_tfrecord_supervised(
    adata_query, tf_path=tf_root + 'query_tf/', **tf_label_kwargs)

In [None]:
ed_time = datetime.datetime.now()
elapsed = ed_time - st_time  # timedelta since the TFRecord timer was started
rec_cost = elapsed.total_seconds()
print('rec cost ', rec_cost)

In [None]:
st_time = datetime.datetime.now(tz=None)  # start of the training stage (local time)

In [None]:
# Train on the RNA reference TFRecords together with the ATAC query TFRecords
# ("uda" in the helper name presumably means unsupervised domain adaptation — TODO confirm).
# If you don't want to train the model, load the pre-trained classifier weights into
# weight_path and go straight to the test cell.
# Paths are rebuilt with the same values used when writing the TFRecords, so this
# cell does not depend on that cell having run in the current kernel.
weight_path = save_path + f'weight/{exp_id}/'
ref_tf_path = save_path + f'tfrecord/{exp_id}/ref_tf/'
query_tf_path = save_path + f'tfrecord/{exp_id}/query_tf/'

train_params = {'batch_size': 128,
                'epoch_pretrain': 1,
                'epoch_classifier': 10,
                'lr': 1e-4,
                'drop_rate': 0.1}
concerto_train_inter_supervised_uda2(ref_tf_path, query_tf_path, weight_path,
                                     super_parameters=train_params)

In [None]:
ed_time = datetime.datetime.now()
elapsed = ed_time - st_time  # timedelta since the training timer was started
train_cost = elapsed.total_seconds()
print('train cost ', train_cost)

In [None]:
# Evaluate on the ATAC query: extract reference/query features with the trained
# weights, label each query cell by k-NN against the reference, then compute
# open-set recognition metrics (known = query labels also present in the reference).
weight_path = save_path + f'weight/{exp_id}/'
ref_tf_path = save_path + f'tfrecord/{exp_id}/ref_tf/'
query_tf_path = save_path + f'tfrecord/{exp_id}/query_tf/'

test_epochs = [4]  # value(s) passed as 'epoch' to the test helper — presumably a saved checkpoint
knn_k = 30         # neighbours used by the k-NN label transfer

# Loop-invariant pieces of the evaluation: ground truth for known-class query cells
# and the mask selecting unknown-class query cells.
kn_data_gt = atac_lab[shr_mask]
unk_mask = np.logical_not(shr_mask)

osr_metrics = {}  # epoch -> metrics dict, kept for inspection after the loop
for epoch in test_epochs:
    st_time = datetime.datetime.now()
    results = concerto_test_inter_supervised2(weight_path, ref_tf_path, query_tf_path,
                                              super_parameters={'batch_size': 64, 'epoch': epoch, 'lr': 1e-5, 'drop_rate': 0.1})
    ed_time = datetime.datetime.now()

    test_cost = (ed_time - st_time).total_seconds()
    print('test cost ', test_cost)

    # NN classifier
    query_neighbor, query_prob = knn_classifier(results['source_feature'],
                                                results['target_feature'],
                                                adata_ref,
                                                adata_ref.obs_names,
                                                column_name=type_key,
                                                k=knn_k)
    # Higher open score = less confident prediction; assumes query_prob is the
    # confidence of the predicted label in [0, 1] — TODO confirm.
    open_score = 1 - query_prob

    kn_data_pr = query_neighbor[shr_mask]
    kn_data_open_score = open_score[shr_mask]
    unk_data_open_score = open_score[unk_mask]

    closed_acc, os_auroc, os_aupr, oscr = osr_evaluator(kn_data_pr, kn_data_gt, kn_data_open_score, unk_data_open_score)
    print(closed_acc, os_auroc, os_aupr, oscr)
    osr_metrics[epoch] = {'closed_acc': closed_acc,
                          'os_auroc': os_auroc,
                          'os_aupr': os_aupr,
                          'oscr': oscr}
