### Evaluation on the self-test dataset

In [1]:
import os
import pandas as pd
import numpy as np
import warnings
# Silence every Python warning for the remainder of the session.
warnings.filterwarnings("ignore")

# File extension appended to saved figure paths (used by plot_barplot's savefig path).
figsuplix = 'pdf'

In [2]:
def is_Exist_file(path):
    """Delete whatever exists at ``path``; do nothing if the path does not exist.

    Despite its name, this helper returns nothing. It removes an existing file.
    """
    import os
    if not os.path.exists(path):
        return
    os.remove(path)


def mkdir(path):
    """Create directory ``path`` (including missing parents) unless it already exists.

    Leading/trailing whitespace and trailing backslashes are stripped first.
    A status line is printed in both cases: "<path> 创建成功" ("created
    successfully") or "<path> 目录已存在" ("directory already exists"). The
    latter is also printed when a non-directory file exists at ``path``.
    """
    import os
    cleaned = path.strip().rstrip("\\")
    if os.path.exists(cleaned):
        # Nothing to create; only report it.
        print(cleaned + ' 目录已存在')
        return
    os.makedirs(cleaned)
    print(cleaned + ' 创建成功')

In [3]:
## note significant
def get_note_for_off_target(sign_data, cell_line, data_label, comb):
    """Return the significance marker for one model comparison.

    Looks up the first row of ``sign_data`` whose 'cell_line', 'data_label' and
    'combination' columns equal the given values (``comb`` looks like
    'MLP vs XGBoost'). Its 'p_value' is then mapped to a marker:
    '**' if p < 1e-5, '*' if p < 0.05, otherwise 'n.s.' (also for a NaN p-value).

    Raises:
        IndexError: if no row of ``sign_data`` matches.
    """
    mask = ((sign_data['cell_line'] == cell_line) &
            (sign_data['data_label'] == data_label) &
            (sign_data['combination'] == comb))
    pvalues = sign_data.loc[mask, 'p_value']
    if pvalues.empty:
        raise IndexError('no significance entry for cell_line=%r, data_label=%r, combination=%r'
                         % (cell_line, data_label, comb))
    # Positional access: a label lookup returns a Series (making the comparisons
    # below ambiguous) when sign_data's index contains duplicate labels.
    pvalue = pvalues.iloc[0]
    if pvalue < 1e-5:
        return '**'
    if pvalue < 0.05:
        return '*'
    return 'n.s.'


def note_significant_for_off_target(ax, plot_data, sign_data, cell_line, data_label, y, order,
                                    model_comb_list=(("MLP", "XGBoost"), ("MLP", "Elastic"))):
    """Draw significance brackets between pairs of models on a bar plot.

    For each (model1, model2) pair, a horizontal line is drawn between the two
    bars' x positions. It sits above the larger ``y`` value of the two models and
    above the previous bracket, and is labelled with the marker returned by
    get_note_for_off_target ('**', '*' or 'n.s.'). The new height of each bracket
    is printed.

    Args:
        ax: axes holding the bar plot.
        plot_data: DataFrame with a 'model_label' column and the metric column ``y``.
        sign_data: significance table passed to get_note_for_off_target.
        cell_line: selects the row of ``sign_data``.
        data_label: selects the row of ``sign_data``.
        y: metric column; 'pearson' uses a vertical step of 0.05, others 0.04.
        order: x-axis order of model labels; maps a model name to its bar position.
        model_comb_list: model pairs to annotate. The default is the pairs
            previously hard-coded here.
    """
    delt_y = 0.05 if y == 'pearson' else 0.04
    last_y = 0
    for model1, model2 in model_comb_list:
        comb = '%s vs %s' % (model1, model2)
        note = get_note_for_off_target(sign_data, cell_line, data_label, comb)
        index1, index2 = order.index(model1), order.index(model2)
        maxy = plot_data.loc[plot_data['model_label'].isin([model1, model2]), y].max()
        # One step above both the pair's tallest value and the previous bracket.
        last_y = max(last_y, maxy) + delt_y
        ax.hlines(last_y, index1, index2, colors="black")
        # ha='center' so the marker is centred over the bracket instead of starting at its midpoint.
        ax.text((index1 + index2) / 2, last_y - 0.01, note, fontsize=13, ha='center')
        print(model1, model2, last_y)


## plot bar
def plot_barplot(plot_data, cell_line, data_label, y, sign_data, save_dir='./plot'):
    """Draw, save and show a bar plot comparing the off-target models on one metric.

    Args:
        plot_data: DataFrame with a 'model_label' column and a metric column named ``y``.
        cell_line: cell line name, used in the title, the file name and the
            significance lookup.
        data_label: 'val' produces a validation-set title; any other value a test-set title.
        y: metric column, one of 'pearson', 'spearman' or 'mse'. It selects the
            y-label and y-limits. Significance brackets are drawn only for the two
            correlations.
        sign_data: significance table used for the brackets.
        save_dir: output directory, created if missing. The file is named
            'self-<cell_line>-<data_label>-<ylabel>_<title>.<figsuplix>'.
    """
    import matplotlib.pyplot as plt
    import seaborn as sns

    # matplotlib 3.6 renamed the bundled seaborn styles to "seaborn-v0_8-*" and 3.8
    # removed the old names, so "seaborn-white" alone raises OSError on current versions.
    style = "seaborn-v0_8-white"
    if style not in plt.style.available:
        style = "seaborn-white"  # older matplotlib
    plt.style.use(style)
    fig, ax = plt.subplots(1, 1, figsize=(6, 4))

    ylabel_dict = {'spearman': 'Spearman correlation',
                   'pearson': 'Pearson correlation',
                   'mse': 'MSE'}
    ylabel = ylabel_dict[y]
    order = ['MLP', 'XGBoost', 'Elastic', 'Lasso', 'Ridge']
    # Draw explicitly on the axes created above.
    ax = sns.barplot(x='model_label', y=y, data=plot_data, saturation=2, linewidth=0.1,
                     capsize=.2, order=order, ax=ax)

    # Narrow the first len(order) bar patches to a fixed width, keeping each centred.
    new_width = 0.5
    for bar, _ in zip(ax.patches, order):
        centre = bar.get_x() + bar.get_width() / 2.
        bar.set_x(centre - new_width / 2.)
        bar.set_width(new_width)

    # Hide the top and right axis spines.
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    plt.ylabel(ylabel, fontsize=12, weight='bold')
    plt.xlabel("Models", fontsize=12, weight='bold')

    if y == 'pearson':
        note_significant_for_off_target(ax, plot_data, sign_data, cell_line, data_label, y, order)
        plt.ylim(0, 1.1)
    elif y == 'spearman':
        note_significant_for_off_target(ax, plot_data, sign_data, cell_line, data_label, y, order)
        plt.ylim(0, 0.8)
    else:  # mse
        plt.ylim(0, 0.010)

    dataset = 'validation' if data_label == 'val' else 'test'
    title = 'Performance of %s off-target models for %s datasets' % (cell_line, dataset)
    plt.title(title, fontsize=12, weight='bold')
    mkdir(save_dir)
    savefig_path = save_dir + '/self-%s-%s-%s_%s.%s' % (cell_line, data_label, ylabel, title, figsuplix)
    plt.savefig(savefig_path, dpi=300, bbox_inches='tight')
    plt.show()
#########################################################

### Evaluation on public datasets

In [4]:
import os
import pandas as pd
import numpy as np
import warnings
# Silence every Python warning for the remainder of the session.
warnings.filterwarnings("ignore")

In [5]:
def is_Exist_file(path):
    """Delete whatever exists at ``path``; do nothing if the path does not exist.

    Despite its name, this helper returns nothing. It removes an existing file.
    """
    import os
    if not os.path.exists(path):
        return
    os.remove(path)


def mkdir(path):
    """Create directory ``path`` (including missing parents) unless it already exists.

    Leading/trailing whitespace and trailing backslashes are stripped first.
    A status line is printed in both cases: "<path> 创建成功" ("created
    successfully") or "<path> 目录已存在" ("directory already exists"). The
    latter is also printed when a non-directory file exists at ``path``.
    """
    import os
    cleaned = path.strip().rstrip("\\")
    if os.path.exists(cleaned):
        # Nothing to create; only report it.
        print(cleaned + ' 目录已存在')
        return
    os.makedirs(cleaned)
    print(cleaned + ' 创建成功')

In [6]:
## Exclude the gRNAs that were used for training from the public datasets
def excluding_gRNAs_list(cell_line, pdata_label):
    """Return the gRNA sequences to drop from a public dataset for a given cell line.

    K562 and Jurkat share one table. Any other cell line uses the H1 table. A
    ``pdata_label`` not in the table ('CRISPOR', 'CRISTA', 'Elevation',
    'Kleinstiver', 'Listgarten', 'Tsai') falls back to the
    Combined_GUIDE-seq_datasets list. A fresh list is returned on every call.
    """
    if cell_line in ('K562', 'Jurkat'):
        per_dataset = {
            'CRISPOR': ['TGGATGGAGGAATGAGGAGT', 'GCCTCCCCAAAGCCTGGCCA', 'GACCCCCTCCACCCCGCCTC',
                        'GAGTCCGAGCAGAAGAAGAA', 'GGAATCCCTTCTGCAGCACC'],
            'CRISTA': ['GTCACCTCCAATGACTAGGG', 'GAGTCCGAGCAGAAGAAGAA', 'GGAATCCCTTCTGCAGCACC',
                       'GACCCCCTCCACCCCGCCTC'],
            'Elevation': ['GATGGTAGATGGAGACTCAG', 'GCCGGAGGGGTTTGCACAGA'],
            'Kleinstiver': ['GTCACCTCCAATGACTAGGG', 'GAGTCCGAGCAGAAGAAGAA', 'GGAATCCCTTCTGCAGCACC',
                            'GACCCCCTCCACCCCGCCTC'],
            'Listgarten': ['GATGGTAGATGGAGACTCAG', 'GAGTCCGAGCAGAAGAAGAA', 'GCCGGAGGGGTTTGCACAGA'],
            'Tsai': ['GAGTCCGAGCAGAAGAAGAA', 'GGAATCCCTTCTGCAGCACC', 'GACCCCCTCCACCCCGCCTC'],
        }
        combined = ['GACCCCCTCCACCCCGCCTC', 'GATGGTAGATGGAGACTCAG', 'GTCACCTCCAATGACTAGGG',
                    'GAGTCCGAGCAGAAGAAGAA', 'GGAATCCCTTCTGCAGCACC', 'GCCGGAGGGGTTTGCACAGA']
    else:  # H1
        per_dataset = {
            'CRISPOR': ['TGGATGGAGGAATGAGGAGT', 'GCCTCCCCAAAGCCTGGCCA', 'GACCCCCTCCACCCCGCCTC',
                        'GGTGAGTGAGTGTGTGCGTG', 'GGAATCCCTTCTGCAGCACC'],
            'CRISTA': ['GTCACCTCCAATGACTAGGG', 'GGAATCCCTTCTGCAGCACC', 'GACCCCCTCCACCCCGCCTC',
                       'GGTGAGTGAGTGTGTGCGTG'],
            'Elevation': ['GATGGTAGATGGAGACTCAG', 'GCCGGAGGGGTTTGCACAGA'],
            'Kleinstiver': ['GTCACCTCCAATGACTAGGG', 'GGAATCCCTTCTGCAGCACC', 'GACCCCCTCCACCCCGCCTC',
                            'GGTGAGTGAGTGTGTGCGTG'],
            'Listgarten': ['GATGGTAGATGGAGACTCAG', 'GCCGGAGGGGTTTGCACAGA'],
            'Tsai': ['GGAATCCCTTCTGCAGCACC', 'GACCCCCTCCACCCCGCCTC', 'GGTGAGTGAGTGTGTGCGTG'],
        }
        combined = ['GACCCCCTCCACCCCGCCTC', 'GGTGAGTGAGTGTGTGCGTG', 'GATGGTAGATGGAGACTCAG',
                    'GTCACCTCCAATGACTAGGG', 'GGAATCCCTTCTGCAGCACC', 'GCCGGAGGGGTTTGCACAGA']
    # Unknown labels are treated as the combined GUIDE-seq datasets.
    return list(per_dataset.get(pdata_label, combined))

In [7]:
## Model prediction columns
## Select the chosen models for each cell line
def selected_model_cols_for_cell(cell_line):
    """Map each model's display name to its prediction column for ``cell_line``.

    The five trained models (MLP, Elastic, Ridge, Lasso, XGBoost) use the column
    selected for the given cell line. The four baseline scores (Elevation, CFD,
    Hsu-Zhang, CCTop) use the same columns for every cell line. Any cell line
    other than 'K562' or 'Jurkat' is treated as H1. Keys are ordered: trained
    models first, then baselines.
    """
    if cell_line == 'K562':
        mlp, elastic, ridge, lasso, xgb = ('MLP-all+P+M-44', 'Elastic-all+P+M-24', 'Ridge-all+P+M-49',
                                           'Lasso-all+P+M-23', 'XGBoost-all+P+M-48')
    elif cell_line == 'Jurkat':
        mlp, elastic, ridge, lasso, xgb = ('MLP-all+M-7', 'Elastic-all+M-2', 'Ridge-all+M-2',
                                           'Lasso-all+M-46', 'XGBoost-all+M-27')
    else:  # H1
        mlp, elastic, ridge, lasso, xgb = ('MLP-all+P+M-10', 'Elastic-all+P+M-28', 'Ridge-all+P+M-21',
                                           'Lasso-all+P+M-12', 'XGBoost-all+P+M-45')
    baseline_scores = {
        'Elevation-score': 'Elevation-score_score',
        'CFD-score': 'CFD_score',
        'Hsu-Zhang-score': 'Hsu-Zhang_score',
        'CCTop-score': 'CCTop_score',
    }
    return {'MLP': mlp, 'Elastic': elastic, 'Ridge': ridge, 'Lasso': lasso,
            'XGBoost': xgb, **baseline_scores}


def selected_model_cols_for_cell_2(cell_line):
    """Like selected_model_cols_for_cell, but restricted to MLP plus the four baseline scores.

    Any cell line other than 'K562' or 'Jurkat' is treated as H1. Key order is
    MLP first, then Elevation, CFD, Hsu-Zhang and CCTop.
    """
    mlp_col_by_cell = {'K562': 'MLP-all+P+M-44', 'Jurkat': 'MLP-all+M-7'}
    columns = {'MLP': mlp_col_by_cell.get(cell_line, 'MLP-all+P+M-10')}
    columns.update({
        'Elevation-score': 'Elevation-score_score',
        'CFD-score': 'CFD_score',
        'Hsu-Zhang-score': 'Hsu-Zhang_score',
        'CCTop-score': 'CCTop_score',
    })
    return columns


## plot auc
def plot_ROC(data, ytrue_col, ypred_col_dict, title, savefig_path, curve='auc'):
    """Plot ROC or precision-recall curves for several prediction columns and save the figure.

    Args:
        data: DataFrame holding the true labels and one column per model prediction.
        ytrue_col: column of true labels; value 1 is the positive class.
        ypred_col_dict: mapping of legend label -> prediction column; one curve per entry.
        title: figure title.
        savefig_path: output path passed to ``plt.savefig``.
        curve: 'auc' draws ROC curves (legend shows ROC-AUC); any other value draws
            precision-recall curves (legend shows the trapezoidal PR-AUC).

    Nothing is drawn or saved when ``ypred_col_dict`` is empty. The figure is
    closed after saving.
    """
    ytrue = data[ytrue_col]
    if len(ypred_col_dict) == 0:
        # Return before creating a figure so no empty, never-closed figure is left behind.
        return
    from sklearn.metrics import roc_curve, roc_auc_score, precision_recall_curve, auc
    import matplotlib.pyplot as plt

    colors = ['red', 'royalblue', 'darkorange', 'lightgreen', 'palevioletred', 'teal',
              'maroon', 'indigo', 'darkorchid', 'mediumorchid', 'thistle', 'pink', 'blueviolet',
              'plum', 'violet', 'purple', 'm', 'lightseagreen', 'magenta',
              'orchid', 'chartreuse', 'deeppink', 'hotpink']
    if curve == 'auc':
        xlabel, ylabel = 'False Positive Rate', 'True Positive Rate'
    else:
        xlabel, ylabel = 'Recall', 'Precision'

    plt.figure(figsize=(6, 4))
    for i, (label, ypred_col) in enumerate(ypred_col_dict.items()):
        ypred = data[ypred_col]
        if curve == 'auc':
            score_text = ', AUC=%s' % (round(roc_auc_score(ytrue, ypred), 3))
            # pos_label=1: samples whose true value is 1 are the positive class.
            x, y, _ = roc_curve(ytrue, ypred, pos_label=1)
        else:
            precision, recall, _ = precision_recall_curve(ytrue, ypred)
            # Recall on the x-axis, precision on the y-axis, matching the axis labels.
            x, y = recall, precision
            score_text = ', PR-AUC=%s' % (round(auc(recall, precision), 3))
        # Wrap around the palette so more curves than colors cannot raise IndexError.
        plt.plot(x, y, color=colors[i % len(colors)], label=label + score_text, linewidth=0.7)

    if curve == 'auc':
        plt.ylim(0.5, 1.025)
    plt.title(title, fontsize=12, weight='bold')
    plt.xlabel(xlabel, fontsize=12, weight='bold')
    plt.ylabel(ylabel, fontsize=12, weight='bold')
    plt.legend(prop={'weight': 'bold', 'size': 6})
    plt.savefig(savefig_path, dpi=300, bbox_inches='tight')
    plt.close()
        
        
## main
## For all models
def plot_AUC_and_ROC(data, pdata_label, ytrue_col, data_label, gRNA_n,
                     savefig_1_path, savefig_2_path, cell_line=None):
    """Plot ROC and PR curves for one public dataset and return a one-row summary.

    Args:
        data: DataFrame containing ``ytrue_col`` and every prediction column
            returned by selected_model_cols_for_cell(cell_line).
        pdata_label: public dataset name, used in the titles and the summary.
        ytrue_col: column with the true labels.
        data_label: stored in the summary's 'data label' column.
        gRNA_n: stored in the summary's 'gRNA num' column (presumably the number
            of gRNAs in ``data`` — confirm with the caller).
        savefig_1_path: output path for the ROC figure.
        savefig_2_path: output path for the PR figure.
        cell_line: cell line whose model columns are plotted. When omitted, the
            module-level global ``cell_line`` is used, as earlier notebook calls expect.

    Returns:
        Single-row DataFrame with columns 'cell line', 'public data', 'data label',
        'gRNA num' and 'off-targets count' (the number of rows in ``data``).

    Raises:
        NameError: if ``cell_line`` is neither passed nor defined as a global.
    """
    if cell_line is None:
        # Backward compatibility: this function used to read an undeclared global.
        try:
            cell_line = globals()['cell_line']
        except KeyError:
            raise NameError("cell_line was not passed and no global 'cell_line' is defined") from None
    print('\n%s-%s ... ...' % (cell_line, pdata_label))
    title1 = 'ROC-AUC for %s' % (pdata_label)
    title2 = 'PR-AUC for %s' % (pdata_label)
    ypred_col_dict = selected_model_cols_for_cell(cell_line)
    plot_ROC(data, ytrue_col, ypred_col_dict, title1, savefig_1_path, curve='auc')
    plot_ROC(data, ytrue_col, ypred_col_dict, title2, savefig_2_path, curve='pr-auc')
    stat_data = pd.DataFrame({'cell line': [cell_line],
                              'public data': [pdata_label],
                              'data label': [data_label],
                              'gRNA num': [gRNA_n],
                              'off-targets count': [data.shape[0]]})
    return stat_data