In [1]:
import random
from tqdm import tqdm
import numpy as np
import matplotlib.pyplot as plt
import os
from tqdm import tqdm
from scipy.spatial.distance import cdist
from scipy.optimize import linear_sum_assignment
import copy
from collections import defaultdict
from sklearn.metrics import precision_recall_curve, auc

In [2]:
def compute_true_false(preds, targets, dist_thresh):
    """
    compute the number of true and false predictions with the Hungarian algorithm

    Args:
        preds (array): coordinates of n predicted loops, array of shape (n, 2)
        targets (array): coordinates of m ground truth loop annotations, array of shape (m, 2)
        dist_thresh (float): threshold of the distance between two coordinates for them to be matched
    
    Returns:
        int, int, int, float, float, float: true positive, false positive, false negative, precision, recall, F1-score
    """

    if len(preds) == 0:
        # return all zeros if no prediction was made
        return 0, 0, 0, 0, 0, 0

    dist_matrix = cdist(preds, targets, 'euclidean')  # shape: (pred size, target size) 

    candidate_matrix = np.where(dist_matrix <= dist_thresh, 1, 0)
    # Candidate(i, j) = 1 means prediction i is close enough to targete j
    
    # the problem of uniquely assigning targets with predictions can be solved by the Hungarian algorithm
    # first, reverse of the candidate matrix to a cost matrix, Cost(i, j) = 0 iff Candidate(i, j) = 1, inf otherwise
    # we didn't use negative costs to fit the standard setting of the assignment problem
    # math.inf will cause problems with linear_sum_assignment(), using a large number instead
    cost_matrix = np.where(candidate_matrix == 1, 0, 10**10)

    # pad the cost matrix into a square matrix
    max_dim = max(cost_matrix.shape)
    pad_rows = max_dim - cost_matrix.shape[0]
    pad_cols = max_dim - cost_matrix.shape[1]
    cost_matrix = np.pad(cost_matrix, ((0, pad_rows), (0, pad_cols)), 'constant', constant_values=10**10)
    
    # print('[debug] compute_true_false(): cost matrix shape (afer padding):', cost_matrix.shape)

    # fit the Hungarian algorithm to find the optimal solution
    row_ind, col_ind = linear_sum_assignment(cost_matrix)

    # the solution is represented as index pairs, no repetitive rows or columns are used
    # but it might choose an index pair from the padding region or choose an entry with large cost (sub-optimal)
    # as a result, we need to manually remove those assignments

    index_pairs = list(zip(row_ind, col_ind))  # list of two-element tuples
    # print(f'[debug] compute_true_false(): before post-processing, {len(index_pairs)} assignments')

    final_pairs = copy.deepcopy(index_pairs)

    for pair in index_pairs:
        # to be a valid assignment, the cost must be 0 (i.e., the entry must be 1 in the candidate matrix)
        # and the indcies must not be in the padding region
        if cost_matrix[pair] != 0 or pair[0] >= candidate_matrix.shape[0] or pair[1] >= candidate_matrix.shape[1]:
            final_pairs.remove(pair)

    # print(f'[debug] compute_true_false(): after post-processing, {len(final_pairs)} assignments')

    # the true positive is then the number of remaining assignments
    tp = len(final_pairs)

    # total predictions - true positives (unassigned predictions)
    fp = preds.shape[0] - tp

    # total targets - true positives (unassigned targets)
    fn = targets.shape[0] - tp

    try:
        precision = tp / preds.shape[0]
        recall = tp / targets.shape[0]
        f1score = 2 * precision * recall / (precision + recall)
    except ZeroDivisionError:
        # in case TP=0, then all 3 metrics should be zero
        precision, recall, f1score = 0, 0, 0

    return tp, fp, fn, precision, recall, f1score



In [3]:
def read_loops(loop_path):
    loops = []
    score_id = 6
    if 'chromosight' in loop_path:
        score_id = 10
    elif 'train' in loop_path:
        score_id = 0
    elif 'hiccups' in loop_path:
        score_id = 16
    
    with open(loop_path, 'r') as loop_file:
        for line in loop_file:
            if line.strip('\n').split('\t')[0] == 'chrom1':
                continue
            if '#' in line.strip('\n').split('\t')[0] :
                continue
            
            line_list = line.strip('\n').split('\t')

            loop_info = line_list[:6]
            
            if score_id !=0:
                if 'hicexplorer' in loop_path or 'hiccups' in loop_path:
                    loop_score = 1- float(line_list[score_id])
                    loop_info.append(loop_score)
                    loops.append(loop_info)
                else:
                    loop_score = float(line_list[score_id])
                    loop_info.append(loop_score)
                    loops.append(loop_info)
            else:
                loops.append(loop_info)          
    return loops

In [4]:
threshold_list = np.arange(0, 1.1, 0.1)

In [5]:
cell_type = 'gm12878'

In [None]:
gt_file = '/Dataset/HiC/hic/loop_train/ctcf_{}.bedpe'.format(cell_type)



gt_loops = read_loops(gt_file)
print(len(gt_loops))

threshold_list = np.arange(0, 1.1, 0.1)

PR_Dic = {
    'Chromosight':[],
    'HiCExplorer':[],
    'YOLOOP':[],
    'Peakachu':[],
    'HICCUPS':[]
}




for benchmark in ['HICCUPS','Chromosight','HiCExplorer','YOLOOP','Peakachu']:

    print("method: {}".format(benchmark))
    if benchmark == 'YOLOOP':
        pred_file = 'yoloop_prediction/10kb/yoloop_pred_{}.bedpe'.format(cell_type)
        pred_loops = read_loops(pred_file)
    elif benchmark == 'HICCUPS':
        pred_file = 'benchmarks/hiccups/{}-hic_10kb/merged_loops.bedpe'.format(cell_type)
        pred_loops = read_loops(pred_file)
    elif benchmark == 'Chromosight':
        pred_file = 'benchmarks/chromosight/{}-hic_10kb/{}-hic_10kb.tsv'.format(cell_type,cell_type)
        pred_loops = read_loops(pred_file)
    elif benchmark == 'HiCExplorer':
        pred_file = 'benchmarks/hicexplorer/{}-hic_10kb/{}-hic_10kb.bedgraph'.format(cell_type,cell_type)
        pred_loops = read_loops(pred_file)
    elif benchmark == 'Peakachu':
        pred_loops = []
        for i in [1,9,14]:
            pred_file = 'benchmarks/peakachu/{}-hic_ctcf-chiapet_10kb/pool/chr{}.bedpe'.format(cell_type,i)
            chr_loops = read_loops(pred_file)
            pred_loops +=chr_loops
    
    
    # PR-Threshold
    for threshold in threshold_list:

        thresholded_pred = []

        for pred in pred_loops:
            score = float(pred[-1])
            if score < threshold:
                continue
            thresholded_pred.append(pred)
        print('Threshold : {}'.format(threshold))
        
        precision_list = []
        recall_list = []
                
        for target_chrom in ['chr1','chr9','chr14']:
            gt_list = []
            pred_list = []

            for pred_loop in tqdm(thresholded_pred):
                pred_chr = pred_loop[0]

                if 'chr' not in pred_loop[0]:
                    pred_chr = 'chr'+str(pred_loop[0])

                if pred_chr !=target_chrom:
                    continue
                x = int((int(pred_loop[1]) + int(pred_loop[2])) * 0.5)
                y = int((int(pred_loop[4]) + int(pred_loop[5])) * 0.5)
                pred_list.append([min(x, y), max(x, y)])

            for gt_loop in gt_loops:

                if gt_loop[0] !=target_chrom:
                    continue
                x = int((int(gt_loop[1]) + int(gt_loop[2])) * 0.5)
                y = int((int(gt_loop[4]) + int(gt_loop[5])) * 0.5)
                gt_list.append([min(x, y), max(x, y)])
    
        
            _, _, _, precision, recall, _ = compute_true_false( np.array(pred_list),np.array(gt_list), 10 * 10000)
            precision_list.append(precision)
            recall_list.append(recall)
        avg_precision = np.mean(precision)
        avg_recall = np.mean(recall)
        PR_Dic[benchmark].append([avg_precision,avg_recall])


55086
method: HICCUPS
Threshold : 0.0


100%|██████████| 8469/8469 [00:00<00:00, 1563793.11it/s]
100%|██████████| 8469/8469 [00:00<00:00, 1593466.74it/s]
100%|██████████| 8469/8469 [00:00<00:00, 1682927.97it/s]


Threshold : 0.1


100%|██████████| 8469/8469 [00:00<00:00, 1381731.78it/s]
100%|██████████| 8469/8469 [00:00<00:00, 1646880.27it/s]
100%|██████████| 8469/8469 [00:00<00:00, 1482597.80it/s]


Threshold : 0.2


100%|██████████| 8469/8469 [00:00<00:00, 1474719.17it/s]
100%|██████████| 8469/8469 [00:00<00:00, 1827052.80it/s]
100%|██████████| 8469/8469 [00:00<00:00, 1663461.67it/s]


Threshold : 0.30000000000000004


100%|██████████| 8469/8469 [00:00<00:00, 1506172.01it/s]
100%|██████████| 8469/8469 [00:00<00:00, 1682370.02it/s]
100%|██████████| 8469/8469 [00:00<00:00, 1648714.81it/s]


Threshold : 0.4


100%|██████████| 8469/8469 [00:00<00:00, 1482412.18it/s]
100%|██████████| 8469/8469 [00:00<00:00, 1798742.18it/s]
100%|██████████| 8469/8469 [00:00<00:00, 1654859.57it/s]


Threshold : 0.5


100%|██████████| 8469/8469 [00:00<00:00, 1506938.77it/s]
100%|██████████| 8469/8469 [00:00<00:00, 1772090.82it/s]
100%|██████████| 8469/8469 [00:00<00:00, 1466802.68it/s]


Threshold : 0.6000000000000001


100%|██████████| 8469/8469 [00:00<00:00, 1494449.10it/s]
100%|██████████| 8469/8469 [00:00<00:00, 1778478.98it/s]
100%|██████████| 8469/8469 [00:00<00:00, 1638977.56it/s]


Threshold : 0.7000000000000001


100%|██████████| 8469/8469 [00:00<00:00, 1484146.43it/s]
100%|██████████| 8469/8469 [00:00<00:00, 1648102.84it/s]
100%|██████████| 8469/8469 [00:00<00:00, 1484580.62it/s]


Threshold : 0.8


100%|██████████| 8469/8469 [00:00<00:00, 1504768.30it/s]
100%|██████████| 8469/8469 [00:00<00:00, 1743132.82it/s]
100%|██████████| 8469/8469 [00:00<00:00, 1680857.45it/s]


Threshold : 0.9


100%|██████████| 8469/8469 [00:00<00:00, 1480928.90it/s]
100%|██████████| 8469/8469 [00:00<00:00, 1541667.49it/s]
100%|██████████| 8469/8469 [00:00<00:00, 1607528.65it/s]


Threshold : 1.0


100%|██████████| 4495/4495 [00:00<00:00, 1412345.23it/s]
100%|██████████| 4495/4495 [00:00<00:00, 1548660.79it/s]
100%|██████████| 4495/4495 [00:00<00:00, 1475457.54it/s]


method: Chromosight
Threshold : 0.0


100%|██████████| 40732/40732 [00:00<00:00, 2058317.26it/s]
100%|██████████| 40732/40732 [00:00<00:00, 2157863.77it/s]
100%|██████████| 40732/40732 [00:00<00:00, 2197414.57it/s]


Threshold : 0.1


100%|██████████| 40732/40732 [00:00<00:00, 1962848.30it/s]
100%|██████████| 40732/40732 [00:00<00:00, 2182368.98it/s]
100%|██████████| 40732/40732 [00:00<00:00, 2208321.69it/s]


Threshold : 0.2


100%|██████████| 40732/40732 [00:00<00:00, 2004416.02it/s]
100%|██████████| 40732/40732 [00:00<00:00, 2148852.77it/s]
100%|██████████| 40732/40732 [00:00<00:00, 2184936.76it/s]


Threshold : 0.30000000000000004


100%|██████████| 40732/40732 [00:00<00:00, 2067909.25it/s]
100%|██████████| 40732/40732 [00:00<00:00, 2110989.63it/s]
100%|██████████| 40732/40732 [00:00<00:00, 2143998.68it/s]


Threshold : 0.4


100%|██████████| 23669/23669 [00:00<00:00, 1901528.14it/s]
100%|██████████| 23669/23669 [00:00<00:00, 2104844.30it/s]
100%|██████████| 23669/23669 [00:00<00:00, 2139685.35it/s]


Threshold : 0.5


100%|██████████| 14893/14893 [00:00<00:00, 1652183.92it/s]
100%|██████████| 14893/14893 [00:00<00:00, 1996668.35it/s]
100%|██████████| 14893/14893 [00:00<00:00, 1842213.33it/s]


Threshold : 0.6000000000000001


100%|██████████| 9445/9445 [00:00<00:00, 1470115.46it/s]
100%|██████████| 9445/9445 [00:00<00:00, 1557997.45it/s]
100%|██████████| 9445/9445 [00:00<00:00, 1543910.57it/s]


Threshold : 0.7000000000000001


100%|██████████| 5355/5355 [00:00<00:00, 1368457.80it/s]
100%|██████████| 5355/5355 [00:00<00:00, 1634560.65it/s]
100%|██████████| 5355/5355 [00:00<00:00, 1435083.89it/s]


Threshold : 0.8


100%|██████████| 2025/2025 [00:00<00:00, 1206286.83it/s]
100%|██████████| 2025/2025 [00:00<00:00, 1494802.11it/s]
100%|██████████| 2025/2025 [00:00<00:00, 1217177.64it/s]


Threshold : 0.9


100%|██████████| 124/124 [00:00<00:00, 587013.20it/s]
100%|██████████| 124/124 [00:00<00:00, 895169.87it/s]
100%|██████████| 124/124 [00:00<00:00, 824237.24it/s]


Threshold : 1.0


0it [00:00, ?it/s]
0it [00:00, ?it/s]
0it [00:00, ?it/s]


method: HiCExplorer
Threshold : 0.0


100%|██████████| 9653/9653 [00:00<00:00, 1537932.71it/s]
100%|██████████| 9653/9653 [00:00<00:00, 1522033.63it/s]
100%|██████████| 9653/9653 [00:00<00:00, 1569835.08it/s]


Threshold : 0.1


100%|██████████| 9653/9653 [00:00<00:00, 1468805.24it/s]
100%|██████████| 9653/9653 [00:00<00:00, 1527776.93it/s]
100%|██████████| 9653/9653 [00:00<00:00, 1499319.23it/s]


Threshold : 0.2


100%|██████████| 9653/9653 [00:00<00:00, 1443563.18it/s]
100%|██████████| 9653/9653 [00:00<00:00, 1379757.92it/s]
100%|██████████| 9653/9653 [00:00<00:00, 1491366.45it/s]


Threshold : 0.30000000000000004


100%|██████████| 9653/9653 [00:00<00:00, 1498209.61it/s]
100%|██████████| 9653/9653 [00:00<00:00, 1549883.88it/s]
100%|██████████| 9653/9653 [00:00<00:00, 1546982.14it/s]


Threshold : 0.4


100%|██████████| 9653/9653 [00:00<00:00, 1498043.31it/s]
100%|██████████| 9653/9653 [00:00<00:00, 1832101.75it/s]
100%|██████████| 9653/9653 [00:00<00:00, 1580806.52it/s]


Threshold : 0.5


100%|██████████| 9653/9653 [00:00<00:00, 1513951.93it/s]
100%|██████████| 9653/9653 [00:00<00:00, 1545741.86it/s]
100%|██████████| 9653/9653 [00:00<00:00, 1501989.04it/s]


Threshold : 0.6000000000000001


100%|██████████| 9653/9653 [00:00<00:00, 1454192.10it/s]
100%|██████████| 9653/9653 [00:00<00:00, 1557575.46it/s]
100%|██████████| 9653/9653 [00:00<00:00, 1588185.64it/s]


Threshold : 0.7000000000000001


100%|██████████| 9653/9653 [00:00<00:00, 1499319.23it/s]
100%|██████████| 9653/9653 [00:00<00:00, 1553034.77it/s]
100%|██████████| 9653/9653 [00:00<00:00, 1560637.42it/s]


Threshold : 0.8


100%|██████████| 9653/9653 [00:00<00:00, 1513612.34it/s]
100%|██████████| 9653/9653 [00:00<00:00, 1546568.49it/s]
100%|██████████| 9653/9653 [00:00<00:00, 1546155.06it/s]


Threshold : 0.9


100%|██████████| 9653/9653 [00:00<00:00, 1509774.27it/s]
100%|██████████| 9653/9653 [00:00<00:00, 1583465.00it/s]
100%|██████████| 9653/9653 [00:00<00:00, 1572762.17it/s]


Threshold : 1.0


100%|██████████| 3386/3386 [00:00<00:00, 1255917.35it/s]
100%|██████████| 3386/3386 [00:00<00:00, 1355273.72it/s]
100%|██████████| 3386/3386 [00:00<00:00, 1369387.07it/s]


method: YOLOOP
Threshold : 0.0


100%|██████████| 106417/106417 [00:00<00:00, 1983668.50it/s]
100%|██████████| 106417/106417 [00:00<00:00, 2061069.67it/s]
100%|██████████| 106417/106417 [00:00<00:00, 2175405.01it/s]


Threshold : 0.1


100%|██████████| 106417/106417 [00:00<00:00, 793920.44it/s]
100%|██████████| 106417/106417 [00:00<00:00, 1863833.47it/s]
100%|██████████| 106417/106417 [00:00<00:00, 2238037.51it/s]


Threshold : 0.2


100%|██████████| 106417/106417 [00:00<00:00, 1973258.92it/s]
100%|██████████| 106417/106417 [00:00<00:00, 2126537.09it/s]
100%|██████████| 106417/106417 [00:00<00:00, 2212828.69it/s]


Threshold : 0.30000000000000004


100%|██████████| 100740/100740 [00:00<00:00, 1979741.20it/s]
100%|██████████| 100740/100740 [00:00<00:00, 2147547.09it/s]
100%|██████████| 100740/100740 [00:00<00:00, 2222589.07it/s]


Threshold : 0.4


100%|██████████| 90960/90960 [00:00<00:00, 1977698.65it/s]
100%|██████████| 90960/90960 [00:00<00:00, 2089067.66it/s]
100%|██████████| 90960/90960 [00:00<00:00, 2231545.19it/s]


Threshold : 0.5


100%|██████████| 82655/82655 [00:00<00:00, 1944789.93it/s]
100%|██████████| 82655/82655 [00:00<00:00, 2116588.09it/s]
100%|██████████| 82655/82655 [00:00<00:00, 2195345.61it/s]


Threshold : 0.6000000000000001


100%|██████████| 74860/74860 [00:00<00:00, 1866351.22it/s]
100%|██████████| 74860/74860 [00:00<00:00, 2136479.66it/s]
100%|██████████| 74860/74860 [00:00<00:00, 2206178.97it/s]


Threshold : 0.7000000000000001


100%|██████████| 66892/66892 [00:00<00:00, 2005614.29it/s]
100%|██████████| 66892/66892 [00:00<00:00, 2120916.08it/s]
100%|██████████| 66892/66892 [00:00<00:00, 2095663.87it/s]


Threshold : 0.8


100%|██████████| 58494/58494 [00:00<00:00, 1856932.37it/s]
100%|██████████| 58494/58494 [00:00<00:00, 2124721.73it/s]
100%|██████████| 58494/58494 [00:00<00:00, 2231615.88it/s]


Threshold : 0.9


100%|██████████| 47020/47020 [00:00<00:00, 1894742.56it/s]
100%|██████████| 47020/47020 [00:00<00:00, 2137844.71it/s]
100%|██████████| 47020/47020 [00:00<00:00, 2211861.13it/s]


Threshold : 1.0


100%|██████████| 1255/1255 [00:00<00:00, 2245670.44it/s]
100%|██████████| 1255/1255 [00:00<00:00, 2037882.90it/s]
100%|██████████| 1255/1255 [00:00<00:00, 2277737.57it/s]


method: Peakachu
Threshold : 0.0


100%|██████████| 50896/50896 [00:00<00:00, 787041.85it/s]
100%|██████████| 50896/50896 [00:00<00:00, 1133187.69it/s]
100%|██████████| 50896/50896 [00:00<00:00, 414339.08it/s]


Threshold : 0.1


100%|██████████| 50896/50896 [00:00<00:00, 846655.97it/s]
100%|██████████| 50896/50896 [00:00<00:00, 1091153.63it/s]
100%|██████████| 50896/50896 [00:00<00:00, 702764.98it/s]


Threshold : 0.2


100%|██████████| 50896/50896 [00:00<00:00, 822636.30it/s]
100%|██████████| 50896/50896 [00:00<00:00, 833531.67it/s]
100%|██████████| 50896/50896 [00:00<00:00, 1282160.40it/s]


Threshold : 0.30000000000000004


100%|██████████| 50896/50896 [00:00<00:00, 809632.21it/s]
100%|██████████| 50896/50896 [00:00<00:00, 1201678.04it/s]
100%|██████████| 50896/50896 [00:00<00:00, 1251125.55it/s]


Threshold : 0.4


100%|██████████| 50896/50896 [00:00<00:00, 370778.80it/s]
100%|██████████| 50896/50896 [00:00<00:00, 1279539.77it/s]
100%|██████████| 50896/50896 [00:00<00:00, 1499120.75it/s]


Threshold : 0.5


100%|██████████| 50896/50896 [00:00<00:00, 411406.14it/s]


In [None]:
PR_Dic

In [None]:
# Calculate AUC

for method in PR_Dic.keys():
    precision_list = [PR_Dic[method][i][0] for i in range(len(PR_Dic[method]))]
    recall_list = [PR_Dic[method][i][1] for i in range(len(PR_Dic[method]))]
    
    pr_auc = auc(recall_list, precision_list)
    
    print("{}: AUC:{}".format(method,pr_auc))