In [10]:
import numpy as np
from sklearn.cluster import SpectralClustering,KMeans
from sklearn.metrics import average_precision_score
from sklearn.metrics import roc_auc_score
from sklearn.metrics import roc_curve, auc
from sklearn.metrics import roc_curve
from sklearn.metrics import precision_recall_curve
from sklearn.metrics import matthews_corrcoef,accuracy_score,precision_score,recall_score,confusion_matrix
from scipy.stats import pearsonr
import matplotlib.pyplot as plt
import os
import pandas as pd
from pathlib import Path

In [11]:
# HL: candidate tables were transferred from dataden
# "/umms-kinfai/duolin/ying/reditools2_candidates/".
covfilename = 5
# Dataset key used to pick the candidate table below and, in later cells,
# to build the per-model prediction file name.
datatype = 'HEK293T_WT_directRNA'

# Dataset key -> file name of the candidate-site table for that dataset.
# NOTE(review): 'HEK293T_DKO_directRNA' and 'HEK_WT_pass' both point at the
# HEK293T_WT table -- confirm this is intended.
hash_candidat = {
    'AFG-H1_directRNA': 'H1-AFG.candidate_sites.tab',
    'AFG-H9_directRNA': 'H9-AFG.candidate_sites.tab',
    "PGC-H1_directRNA": 'H1-PGC.candidate_sites.tab',
    "DE-H1_directRNA": 'H1-DE.candidate_sites.tab',
    "DE-H9_directRNA": 'H9-DE.candidate_sites.tab',
    "GM12878_directRNA": 'GM12878.candidate_sites.tab',
    "H1-hESC_directRNA": 'H1-hESC.candidate_sites.tab',
    "H9-hESC_directRNA": 'H9-hESC.candidate_sites.tab',
    'HEK293T_DKO_directRNA': 'HEK293T_WT.candidate_sites.tab',
    "HEK293T_WT_directRNA": 'HEK293T_WT.candidate_sites.tab',
    "HEK_WT_pass": 'HEK293T_WT.candidate_sites.tab',
}
# Full path of the candidate table for the selected dataset.
candidatefile = f"/nfs/turbo/umms-kinfai/haorli/20240314_ReDD_result_data/figure2a/reditools2_candidates/{hash_candidat[datatype]}"


In [12]:
# Parse the candidate-site table into two dictionaries keyed by
# "<field0>-<field1>" of each row.
# Presumably column 3 is the A-to-G ratio and column 4 the short-read coverage
# (inferred from the variable names) -- TODO confirm against the table format.
AG_ratio_per_site={}
shortreadcoverage={}
# A context manager closes the file even if a row fails to parse, and avoids
# shadowing the builtin `input`.
with open(candidatefile,'r') as candidate_handle:
    for line in candidate_handle:
        if not line.strip():
            # Tolerate blank lines (e.g. a trailing newline at end of file).
            continue
        # The key uses whitespace splitting while the numeric columns use tab
        # splitting, exactly as before; each split is now done once per row.
        ws_fields = line.split()
        tab_fields = line.split("\t")
        chr_ = ws_fields[0]
        pos_ = ws_fields[1]
        chrpos=chr_+"-"+pos_
        AG_ratio_per_site[chrpos]=float(tab_fields[3])
        shortreadcoverage[chrpos]=float(tab_fields[4])

In [16]:
# Root directory holding one result folder per trained model variant.
foldername="/nfs/turbo/umms-kinfai/haorli/20240314_ReDD_result_data/supfig2/ReDD_results/"
# Result directory of each model; entry i is labelled by modelnames[i].
modelist = [
    foldername + run_dir
    for run_dir in (
        '/Eventfeature_5feature_win5/hg38retrain_HEK293T_KO1_chrsplitHEK293T_WT_directRNA_epochs1_100_epochs2_50',
        '/Eventfeature_5feature_win7/hg38retrain_HEK293T_KO1_chrsplitHEK293T_WT_directRNA_epochs1_100_epochs2_50',
        '/Eventfeature_3feature_win9/hg38retrain_HEK293T_KO1_chrsplitHEK293T_WT_directRNA_epochs1_100_epochs2_50',
        '/Eventfeature_5feature_win9/ablationseq0011_nopretrain_HEK293T_KO1_chrsplitHEK293T_WT_directRNA_epochs2_50',
        '/Eventfeature_5feature_win9/ablationseq1011_pretrain_HEK293T_KO1_chrsplitHEK293T_WT_directRNAepochs1_100_epochs2_50',
        '/Eventfeature_5feature_win9/ablationseq0111_pretrain_HEK293T_KO1_chrsplitHEK293T_WT_directRNAepochs1_100_epochs2_50',
        '/Eventfeature_5feature_win9/hg38retrain_HEK293T_KO1_chrsplitHEK293T_WT_directRNA_epochs1_100_epochs2_88',
        '/Eventfeature_5feature_win9/ablationnoKO_HEK293T_KO1_chrsplitHEK293T_WT_directRNAepochs1_100_epochs2_50',
    )
]
# Display/label name for each entry of modelist, in the same order.
modelnames = [
    '5D_5mer',
    '5D_7mer',
    '3D_9mer',
    'editing_status_loss_only',
    'ref_loss',
    'basecall_loss',
    'ReDD',
    'REDD_WT_only',
]

In [37]:
# Number of equal-width bins over (0, 1] and the resulting bin width.
num_bins = 5
binsize=1/num_bins
def binize_list(true_list,predict_list,n_bins=None):
    """Group paired true/predicted values into equal-width bins of the true value.

    The range (0, 1] is cut into ``n_bins`` half-open intervals
    ``(start, end]``. A pair is assigned to a bin by its *true* value, so a
    true value of exactly 0 never lands in any bin. Bins holding fewer than
    two pairs are dropped from the output.

    Parameters
    ----------
    true_list : array-like of float
        Reference values used to decide bin membership.
    predict_list : array-like of float
        Predicted values, aligned element-wise with ``true_list``.
    n_bins : int, optional
        Number of bins. Defaults to the module-level ``num_bins``.

    Returns
    -------
    AG_ratios_bin : list of numpy.ndarray
        True values of each retained bin, in ascending bin order.
    predict_ratio_bin : list of numpy.ndarray
        Predicted values of each retained bin, aligned with ``AG_ratios_bin``.
    sample_size : list of int
        Number of pairs in each retained bin.
    """
    if n_bins is None:
        n_bins = num_bins
    width = 1/n_bins
    # Accept plain lists as well as arrays; boolean masking needs ndarrays.
    true_arr = np.asarray(true_list)
    predict_arr = np.asarray(predict_list)
    sample_size=[]
    AG_ratios_bin=[]
    predict_ratio_bin = []
    # Iterate over an integer bin index rather than np.arange(0, 1, width):
    # with a float step, arange can return one element more or fewer than
    # intended. For the default 5 bins the edges are identical.
    for index in range(n_bins):
        start = index*width
        # Pin the last upper edge to exactly 1.0 so float round-off in
        # start+width can never exclude a true value of 1.0.
        end = 1.0 if index == n_bins-1 else start+width
        index_start_end = np.where((true_arr<=end) &(true_arr>start))[0]
        if(len(index_start_end)>1):
            AG_ratios_bin.append(true_arr[index_start_end])
            predict_ratio_bin.append(predict_arr[index_start_end])
            sample_size.append(len(index_start_end))
    return AG_ratios_bin,predict_ratio_bin,sample_size
def cal_MAE(AG_ratios_bin,predict_ratio_bin):
    """Mean absolute difference between per-bin medians.

    For each bin the median of the true values and the median of the
    predicted values are taken; the result is the mean, over bins, of the
    absolute difference between the two medians.

    Parameters
    ----------
    AG_ratios_bin : sequence of array-like
        True values, one array per bin.
    predict_ratio_bin : sequence of array-like
        Predicted values, one array per bin, in the same bin order.

    Returns
    -------
    float
        The mean absolute gap between the paired bin medians.
    """
    true_medians = np.asarray(list(map(np.median, AG_ratios_bin)))
    predicted_medians = np.asarray(list(map(np.median, predict_ratio_bin)))
    absolute_gaps = np.abs(true_medians - predicted_medians)
    return absolute_gaps.mean()

In [41]:
# For every model: turn per-read scores into a per-site predicted ratio,
# pair it with the reference ratio in AG_ratio_per_site, export the pairs,
# and record the binned error as one row of `rows`.

# Minimum number of reads a site needs before its predicted ratio is used.
long_reads_min_coverage = coverage_cutoff = 5
# A read whose score is at or above this value counts as a positive call.
cutoff = 0.5#0.5
rows = []
# Output folders are created once up front instead of on every iteration.
Path('plot_data').mkdir(exist_ok=True,parents=True)
Path('plot_data/ratio/').mkdir(exist_ok=True,parents=True)
for model,met in zip(modelist,modelnames):

    filename=model+"/test_"+datatype+"_onlycandidate.txt"
    pos_coverage={}  # site key -> number of reads called positive
    coverage = {}    # site key -> total number of reads seen
    # The context manager closes the prediction file once it has been read
    # and avoids shadowing the builtin `input`.
    with open(filename) as prediction_handle:
        for line in prediction_handle:
            if not line.strip():
                # Tolerate blank lines (e.g. a trailing newline).
                continue
            fields = line.split("\t")
            # The score is the last tab-separated field; the site key is built
            # from fields 2 and 3 so it can be matched against the keys of
            # AG_ratio_per_site.
            score=float(fields[-1])
            predict_label = 1 if score>=cutoff else 0

            transid= fields[2]
            transpos = fields[3]
            chrpos = transid+"-"+transpos

            pos_coverage[chrpos]=pos_coverage.get(chrpos,0)+predict_label
            coverage[chrpos]=coverage.get(chrpos,0)+1

    # Predicted ratio = positive reads / all reads, kept only for sites that
    # have a reference ratio and enough read coverage.
    REDD_predict_value_all={}
    for site in AG_ratio_per_site:
        if site in pos_coverage and coverage[site] >= long_reads_min_coverage:
            REDD_predict_value_all[site]=float(pos_coverage[site]/coverage[site])

    coords_list = list(REDD_predict_value_all)
    REDD_true_list=np.asarray([AG_ratio_per_site[key] for key in coords_list])
    REDD_predict_list=np.asarray([REDD_predict_value_all[key] for key in coords_list])
    # export to csv
    df = pd.DataFrame({'Site':coords_list,'truth':REDD_true_list,'ReDD':REDD_predict_list})
    df.to_csv(f'plot_data/ratio/{met}_site_ratio.tsv',sep='\t',index=False)
    # binize
    AG_ratios_bin,predict_ratio_bin,sample_size = binize_list(REDD_true_list,REDD_predict_list)
    mae = cal_MAE(AG_ratios_bin,predict_ratio_bin)
    rows.append([met,mae])

In [42]:
# Write the per-model error table (one [model name, MAE] row per model in
# `rows`) to a tab-separated file, without the index column.
df = pd.DataFrame(rows).set_axis(['Model','MAE'], axis='columns')
df.to_csv('plot_data/quantification_metrics.tsv', sep='\t', index=False)

In [40]:
# Show the current contents of `df` as this cell's output.
# NOTE(review): the recorded output below has an 'Unnamed: 0' column and this
# cell's execution count (40) is lower than the preceding cells' (41, 42), so
# the output presumably comes from an earlier run -- re-run to confirm.
df

Unnamed: 0,Model,MAE
0,5D_5mer,0.264295
1,5D_7mer,0.223775
2,3D_9mer,0.291918
3,editing_status_loss_only,0.249647
4,ref_loss,0.225671
5,basecall_loss,0.210652
6,ReDD,0.172502
7,REDD_WT_only,0.218523
