In [1]:
import os
import re

import nibabel as nib
import numpy as np
import pandas as pd

In [2]:
det_path = "/home/yutongx/src_data/det_bbox/"
gt_path = "/home/yutongx/src_data/labels/"
csv_path = '/home/yutongx/src_data/info/'
save_path = "./bbox_multi"
 
gt_train_path = os.path.join(gt_path, "train")
gt_val_path = os.path.join(gt_path, "val")

save_train_path = os.path.join(save_path, "train")
save_val_path = os.path.join(save_path, "val")

csv_train_path = os.path.join(csv_path, 'ribfrac-train-info.csv')
csv_val_path = os.path.join(csv_path, 'ribfrac-val-info.csv')

In [3]:
idx = lambda name: re.sub(r"\D", "", name)
get_names = lambda path: sorted([os.path.join(path, name) for name in os.listdir(path)])

det_names = get_names(det_path)
if len(det_names) == 500:
    det_train_names = sorted([name for name in det_names if int(idx(name))<=420], key=idx)
    det_val_names = sorted([name for name in det_names if int(idx(name))>420], key=idx)
else:
    raise ValueError('Detection numbers mismatch.')
    
gt_train_names = get_names(gt_train_path)
gt_val_names = get_names(gt_val_path)

In [4]:
def masks2bboxes(masks, sub_df, border):
    instance_nums = [num for num in np.unique(masks) if num]
    bboxes = []
    
    for i in instance_nums:
        mask = (masks == i)
        if np.any(mask):
            zz, yy, xx = np.where(mask)   
            label = (sub_df.loc[sub_df['label_id'] == i])['label_code'].values[0]
            
            bboxes.append([(zz.max() + zz.min()) / 2.,
                           (yy.max() + yy.min()) / 2.,
                           (xx.max() + xx.min()) / 2.,
                           zz.max() - zz.min() + 1 + border,
                           yy.max() - yy.min() + 1 + border,
                           xx.max() - xx.min() + 1 + border,
                           label])
    return bboxes

def remove_array(lst,arr,idx=0):
    size = len(lst)
    while idx != size and not np.array_equal(lst[idx],arr):
        idx += 1
    if idx != size:
        lst.pop(idx)
    else:
        raise ValueError('array not found in list.')

def preprocessing(det_names, gt_names, save_path, csv_path, border=8, th_pos=0.6, th_neg=0.1):
    save_path = save_path if os.path.exists(save_path) else os.makedirs(save_path)
    cnt_all = np.zeros(4)
    
    df = pd.read_csv(csv_path)
    
    for det_name, gt_name in zip(det_names, gt_names):
        print('====================')
        print('det_name:', det_name)
        print('gt_name:', gt_name)
        
        assert idx(det_name) == idx(gt_name), 'Index Mismatch.'
        
        det_src = np.load(det_name, allow_pickle=True)
        det_src = np.pad(det_src, ((0,0),(0,1)), 'constant')
        gt_src = (np.swapaxes(nib.load(gt_name).get_fdata(), -1, 0)).astype(np.uint8)

        sub_df = df.loc[df['public_id'] == ''.join('RibFrac' + idx(gt_name))]
        
        gt_pos = masks2bboxes(gt_src, sub_df, border=border)
        rpn_pos = []
        rpn_neg = [det_src[i, 1:] for i in range(det_src.shape[0])]

        for gt in gt_pos:
            hit = 1
            for rpn in reversed(rpn_neg):## del elements while iterating
                intsc = (rpn[3:6] + gt[3:6])/2 - np.abs(rpn[:3] - gt[:3])#calculate intersection size in 3 dims
                
                if np.all(intsc >= 1):
                    print('----------')
                    print('hit #%d:'%(hit))
                    
                    hit +=1
                    intsc_vol = np.prod(intsc)
                    iou = intsc_vol / (np.prod(rpn[3:6]) + np.prod(gt[3:6]) - intsc_vol 
                                   + np.finfo(intsc_vol.dtype).eps)#calculate iou
                    
                    print('iou:', round(iou, 4))
                    if iou > th_neg:
                        remove_array(rpn_neg, rpn)#ambi
                        if iou > th_pos:#pos
                            rpn[-1] = gt[-1]#same with corresponding gt
                            rpn_pos.append(rpn)

        
        cnt = np.array([len(gt_pos), det_src.shape[0], len(rpn_pos), len(rpn_neg)])
        cnt_all += cnt
        
        print('--------------------')
        print('gt_cnt:', cnt[0])
        print('det_cnt:', cnt[1])
        print('rpn_pos_cnt:', cnt[2])
        print('rpn_neg_cnt:', cnt[3])
        
        np.savez(os.path.join(save_path, ''.join((idx(gt_name), '_bbox.npz'))), 
                 gt_pos = np.array(gt_pos), rpn_pos=np.array(rpn_pos), rpn_neg=rpn_neg)
    
    print('====================')
    print('====================')
    print('gt_cnt_all:', cnt_all[0])
    print('det_cnt_all:', cnt_all[1])
    print('rpn_pos_cnt_all:', cnt_all[2])
    print('rpn_neg_cnt_all:', cnt_all[3])

In [None]:
preprocessing(det_train_names, gt_train_names, save_train_path, csv_train_path, border=8, th_pos=0.6, th_neg=0.1)

det_name: /home/yutongx/src_data/det_bbox/1_rpns_list.npy
gt_name: /home/yutongx/src_data/labels/train/RibFrac1-label.nii.gz
----------
hit #1:
iou: 0.109
----------
hit #2:
iou: 0.1099
----------
hit #3:
iou: 0.8784
----------
hit #1:
iou: 0.7762
--------------------
gt_cnt: 2
det_cnt: 63
rpn_pos_cnt: 2
rpn_neg_cnt: 59
det_name: /home/yutongx/src_data/det_bbox/10_rpns_list.npy
gt_name: /home/yutongx/src_data/labels/train/RibFrac10-label.nii.gz
----------
hit #1:
iou: 0.0577
----------
hit #2:
iou: 0.0688
----------
hit #3:
iou: 0.0911
----------
hit #4:
iou: 0.0139
----------
hit #5:
iou: 0.6716
----------
hit #6:
iou: 0.029
----------
hit #1:
iou: 0.0532
----------
hit #2:
iou: 0.2014
----------
hit #3:
iou: 0.733
----------
hit #1:
iou: 0.0822
----------
hit #2:
iou: 0.0723
----------
hit #3:
iou: 0.127
----------
hit #4:
iou: 0.802
----------
hit #1:
iou: 0.1229
----------
hit #2:
iou: 0.781
----------
hit #1:
iou: 0.0494
----------
hit #2:
iou: 0.6818
----------
hit #3:
iou: 0.026

----------
hit #1:
iou: 0.8072
----------
hit #1:
iou: 0.5578
----------
hit #1:
iou: 0.7479
----------
hit #1:
iou: 0.6216
----------
hit #1:
iou: 0.5508
--------------------
gt_cnt: 6
det_cnt: 59
rpn_pos_cnt: 3
rpn_neg_cnt: 54
det_name: /home/yutongx/src_data/det_bbox/108_rpns_list.npy
gt_name: /home/yutongx/src_data/labels/train/RibFrac108-label.nii.gz
----------
hit #1:
iou: 0.1251
----------
hit #2:
iou: 0.845
----------
hit #1:
iou: 0.0185
----------
hit #2:
iou: 0.0119
----------
hit #3:
iou: 0.129
----------
hit #4:
iou: 0.818
----------
hit #1:
iou: 0.1349
----------
hit #2:
iou: 0.1082
----------
hit #3:
iou: 0.6726
----------
hit #1:
iou: 0.7597
----------
hit #1:
iou: 0.8687
----------
hit #1:
iou: 0.1472
----------
hit #2:
iou: 0.1131
----------
hit #3:
iou: 0.1229
----------
hit #4:
iou: 0.8735
----------
hit #1:
iou: 0.7409
----------
hit #1:
iou: 0.0735
----------
hit #2:
iou: 0.811
----------
hit #1:
iou: 0.0622
----------
hit #2:
iou: 0.3983
----------
hit #3:
iou: 0.

----------
hit #1:
iou: 0.5465
----------
hit #1:
iou: 0.0943
----------
hit #2:
iou: 0.7136
----------
hit #1:
iou: 0.2696
--------------------
gt_cnt: 3
det_cnt: 84
rpn_pos_cnt: 1
rpn_neg_cnt: 81
det_name: /home/yutongx/src_data/det_bbox/117_rpns_list.npy
gt_name: /home/yutongx/src_data/labels/train/RibFrac117-label.nii.gz
----------
hit #1:
iou: 0.0862
----------
hit #2:
iou: 0.7724
----------
hit #1:
iou: 0.0929
----------
hit #2:
iou: 0.8805
----------
hit #3:
iou: 0.011
----------
hit #1:
iou: 0.3954
----------
hit #2:
iou: 0.7704
----------
hit #1:
iou: 0.0468
----------
hit #2:
iou: 0.0973
----------
hit #3:
iou: 0.6396
----------
hit #1:
iou: 0.1003
----------
hit #2:
iou: 0.1183
----------
hit #3:
iou: 0.8691
----------
hit #1:
iou: 0.1844
----------
hit #2:
iou: 0.0953
----------
hit #3:
iou: 0.7899
----------
hit #1:
iou: 0.6135
----------
hit #1:
iou: 0.5528
----------
hit #1:
iou: 0.0056
----------
hit #2:
iou: 0.3626
----------
hit #1:
iou: 0.6947
----------
hit #1:
iou:

----------
hit #1:
iou: 0.56
----------
hit #1:
iou: 0.0505
----------
hit #2:
iou: 0.6542
----------
hit #1:
iou: 0.0843
----------
hit #2:
iou: 0.15
----------
hit #3:
iou: 0.1417
----------
hit #4:
iou: 0.0589
----------
hit #5:
iou: 0.7148
----------
hit #1:
iou: 0.1009
----------
hit #2:
iou: 0.1469
----------
hit #3:
iou: 0.1524
----------
hit #4:
iou: 0.8143
----------
hit #1:
iou: 0.0057
----------
hit #2:
iou: 0.1119
----------
hit #3:
iou: 0.6317
----------
hit #1:
iou: 0.8165
----------
hit #1:
iou: 0.6463
----------
hit #1:
iou: 0.0795
----------
hit #2:
iou: 0.7344
----------
hit #1:
iou: 0.4481
----------
hit #1:
iou: 0.7026
----------
hit #1:
iou: 0.729
--------------------
gt_cnt: 14
det_cnt: 70
rpn_pos_cnt: 9
rpn_neg_cnt: 53
det_name: /home/yutongx/src_data/det_bbox/126_rpns_list.npy
gt_name: /home/yutongx/src_data/labels/train/RibFrac126-label.nii.gz
----------
hit #1:
iou: 0.7558
----------
hit #2:
iou: 0.1664
----------
hit #1:
iou: 0.4144
----------
hit #2:
iou: 0.

----------
hit #1:
iou: 0.0397
----------
hit #2:
iou: 0.0865
----------
hit #3:
iou: 0.7096
----------
hit #1:
iou: 0.3852
----------
hit #2:
iou: 0.8678
----------
hit #1:
iou: 0.1844
----------
hit #2:
iou: 0.1328
----------
hit #3:
iou: 0.7821
----------
hit #1:
iou: 0.0069
----------
hit #2:
iou: 0.9389
----------
hit #1:
iou: 0.424
----------
hit #2:
iou: 0.9001
----------
hit #1:
iou: 0.7765
----------
hit #1:
iou: 0.8004
----------
hit #1:
iou: 0.473
----------
hit #2:
iou: 0.7576
----------
hit #1:
iou: 0.1134
----------
hit #2:
iou: 0.9339
----------
hit #1:
iou: 0.0429
----------
hit #2:
iou: 0.0876
----------
hit #3:
iou: 0.7796
----------
hit #1:
iou: 0.1638
----------
hit #2:
iou: 0.1457
----------
hit #3:
iou: 0.0444
----------
hit #4:
iou: 0.4408
----------
hit #1:
iou: 0.0846
----------
hit #2:
iou: 0.6256
----------
hit #1:
iou: 0.0051
----------
hit #1:
iou: 0.0462
----------
hit #2:
iou: 0.5867
----------
hit #1:
iou: 0.0919
----------
hit #2:
iou: 0.3386
----------

----------
hit #1:
iou: 0.791
----------
hit #1:
iou: 0.0894
----------
hit #2:
iou: 0.8914
----------
hit #3:
iou: 0.0299
----------
hit #1:
iou: 0.6974
----------
hit #1:
iou: 0.0828
----------
hit #2:
iou: 0.0064
----------
hit #3:
iou: 0.0804
----------
hit #4:
iou: 0.0071
----------
hit #5:
iou: 0.1528
----------
hit #6:
iou: 0.0971
----------
hit #7:
iou: 0.0207
----------
hit #8:
iou: 0.0084
----------
hit #9:
iou: 0.0737
----------
hit #10:
iou: 0.175
----------
hit #11:
iou: 0.312
----------
hit #12:
iou: 0.1121
----------
hit #13:
iou: 0.278
----------
hit #14:
iou: 0.0537
----------
hit #1:
iou: 0.7957
----------
hit #1:
iou: 0.1454
----------
hit #2:
iou: 0.0905
----------
hit #3:
iou: 0.1029
----------
hit #4:
iou: 0.8289
----------
hit #1:
iou: 0.1647
----------
hit #1:
iou: 0.092
----------
hit #2:
iou: 0.0196
----------
hit #3:
iou: 0.1729
----------
hit #1:
iou: 0.0867
----------
hit #2:
iou: 0.6374
--------------------
gt_cnt: 15
det_cnt: 65
rpn_pos_cnt: 6
rpn_neg_cnt

----------
hit #1:
iou: 0.2336
----------
hit #1:
iou: 0.0121
----------
hit #2:
iou: 0.0008
----------
hit #3:
iou: 0.0831
----------
hit #4:
iou: 0.1763
----------
hit #5:
iou: 0.5104
----------
hit #1:
iou: 0.4106
--------------------
gt_cnt: 3
det_cnt: 82
rpn_pos_cnt: 0
rpn_neg_cnt: 78
det_name: /home/yutongx/src_data/det_bbox/153_rpns_list.npy
gt_name: /home/yutongx/src_data/labels/train/RibFrac153-label.nii.gz
----------
hit #1:
iou: 0.0198
----------
hit #2:
iou: 0.712
----------
hit #1:
iou: 0.1282
----------
hit #2:
iou: 0.1323
----------
hit #3:
iou: 0.085
----------
hit #4:
iou: 0.686
----------
hit #1:
iou: 0.7653
----------
hit #1:
iou: 0.0517
----------
hit #2:
iou: 0.2957
----------
hit #3:
iou: 0.0619
----------
hit #1:
iou: 0.4541
----------
hit #2:
iou: 0.0122
----------
hit #3:
iou: 0.5556
----------
hit #1:
iou: 0.4494
----------
hit #1:
iou: 0.1372
----------
hit #2:
iou: 0.1282
----------
hit #3:
iou: 0.1644
----------
hit #4:
iou: 0.1176
----------
hit #5:
iou: 0

----------
hit #1:
iou: 0.0103
----------
hit #2:
iou: 0.0506
----------
hit #3:
iou: 0.1142
----------
hit #4:
iou: 0.0265
----------
hit #5:
iou: 0.7144
----------
hit #1:
iou: 0.066
----------
hit #2:
iou: 0.5903
----------
hit #1:
iou: 0.0058
----------
hit #2:
iou: 0.1308
----------
hit #3:
iou: 0.0072
----------
hit #4:
iou: 0.1097
----------
hit #5:
iou: 0.1948
----------
hit #6:
iou: 0.1049
----------
hit #7:
iou: 0.248
----------
hit #8:
iou: 0.4956
----------
hit #1:
iou: 0.5014
----------
hit #1:
iou: 0.1115
----------
hit #2:
iou: 0.8297
----------
hit #1:
iou: 0.129
----------
hit #2:
iou: 0.8556
----------
hit #1:
iou: 0.2026
----------
hit #2:
iou: 0.6645
----------
hit #1:
iou: 0.0195
----------
hit #2:
iou: 0.782
----------
hit #1:
iou: 0.0709
----------
hit #2:
iou: 0.6343
----------
hit #1:
iou: 0.0102
----------
hit #2:
iou: 0.0859
----------
hit #3:
iou: 0.5862
----------
hit #1:
iou: 0.0102
----------
hit #2:
iou: 0.8114
--------------------
gt_cnt: 11
det_cnt: 70

----------
hit #1:
iou: 0.049
----------
hit #2:
iou: 0.4136
----------
hit #3:
iou: 0.9243
--------------------
gt_cnt: 1
det_cnt: 67
rpn_pos_cnt: 1
rpn_neg_cnt: 65
det_name: /home/yutongx/src_data/det_bbox/171_rpns_list.npy
gt_name: /home/yutongx/src_data/labels/train/RibFrac171-label.nii.gz
----------
hit #1:
iou: 0.0175
----------
hit #2:
iou: 0.02
----------
hit #3:
iou: 0.4617
----------
hit #1:
iou: 0.0785
----------
hit #2:
iou: 0.1345
----------
hit #3:
iou: 0.0574
----------
hit #4:
iou: 0.023
----------
hit #5:
iou: 0.5943
----------
hit #1:
iou: 0.1218
----------
hit #2:
iou: 0.6322
----------
hit #1:
iou: 0.0323
----------
hit #2:
iou: 0.3144
----------
hit #1:
iou: 0.7226
----------
hit #1:
iou: 0.652
----------
hit #1:
iou: 0.0849
----------
hit #2:
iou: 0.2796
----------
hit #1:
iou: 0.8666
--------------------
gt_cnt: 9
det_cnt: 79
rpn_pos_cnt: 4
rpn_neg_cnt: 69
det_name: /home/yutongx/src_data/det_bbox/172_rpns_list.npy
gt_name: /home/yutongx/src_data/labels/train/Rib

----------
hit #1:
iou: 0.2456
----------
hit #2:
iou: 0.0646
----------
hit #3:
iou: 0.3911
----------
hit #1:
iou: 0.1587
----------
hit #2:
iou: 0.7739
----------
hit #1:
iou: 0.1068
----------
hit #2:
iou: 0.7824
----------
hit #1:
iou: 0.0401
----------
hit #2:
iou: 0.5402
----------
hit #1:
iou: 0.1393
----------
hit #2:
iou: 0.7277
----------
hit #1:
iou: 0.0483
----------
hit #2:
iou: 0.0244
----------
hit #1:
iou: 0.1595
----------
hit #2:
iou: 0.832
----------
hit #1:
iou: 0.0287
----------
hit #1:
iou: 0.0694
----------
hit #2:
iou: 0.267
----------
hit #3:
iou: 0.6846
----------
hit #4:
iou: 0.0025
----------
hit #1:
iou: 0.0329
----------
hit #2:
iou: 0.5787
----------
hit #1:
iou: 0.0607
----------
hit #2:
iou: 0.1302
----------
hit #3:
iou: 0.8578
----------
hit #1:
iou: 0.1615
----------
hit #2:
iou: 0.7855
----------
hit #1:
iou: 0.79
--------------------
gt_cnt: 14
det_cnt: 62
rpn_pos_cnt: 8
rpn_neg_cnt: 43
det_name: /home/yutongx/src_data/det_bbox/18_rpns_list.npy
gt

In [None]:
preprocessing(det_val_names, gt_val_names, save_val_path, csv_val_path, border=8, th_pos=0.6, th_neg=0.1)