In [1]:
import os
import re

import nibabel as nib
import numpy as np

In [2]:
det_path = "/home/yutongx/src_data/det_bbox"
gt_path = "/home/yutongx/src_data/labels"
save_path = "./bbox"
 
gt_train_path = os.path.join(gt_path, "train")
gt_val_path = os.path.join(gt_path, "val")

save_train_path = os.path.join(save_path, "train")
save_val_path = os.path.join(save_path, "val")

In [3]:
idx = lambda name: re.sub(r"\D", "", name)
get_names =lambda path: sorted([os.path.join(path, name) for name in os.listdir(path)])

det_names = get_names(det_path)
if len(det_names) == 500:
    det_train_names = sorted([name for name in det_names if int(idx(name))<=420], key=idx)
    det_val_names = sorted([name for name in det_names if int(idx(name))>420], key=idx)
else:
    raise ValueError('Detection numbers mismatch.')
    
gt_train_names = get_names(gt_train_path)
gt_val_names = get_names(gt_val_path)

In [4]:
def masks2bboxes(masks, border):
    instance_nums = [num for num in np.unique(masks) if num]
    bboxes = []
    
    for i in instance_nums:
        mask = (masks == i)
        if np.any(mask):
            zz, yy, xx = np.where(mask)
            bboxes.append([(zz.max() + zz.min()) / 2.,
                           (yy.max() + yy.min()) / 2.,
                           (xx.max() + xx.min()) / 2.,
                           zz.max() - zz.min() + 1 + border,
                           yy.max() - yy.min() + 1 + border,
                           xx.max() - xx.min() + 1 + border])
    return bboxes

def remove_array(lst,arr,idx=0):
    size = len(lst)
    while idx != size and not np.array_equal(lst[idx],arr):
        idx += 1
    if idx != size:
        lst.pop(idx)
    else:
        raise ValueError('array not found in list.')

def preprocessing(det_names, gt_names, save_path, border=8, th_pos=0.6, th_neg=0.1):
    save_path = save_path if os.path.exists(save_path) else os.makedirs(save_path)
    cnt_all = np.zeros(4)
    
    for det_name, gt_name in zip(det_names, gt_names):
        print('====================')
        print('det_name:', det_name)
        print('gt_name:', gt_name)
        
        assert idx(det_name) == idx(gt_name), 'Index Mismatch.'
        
        det_src = np.load(det_name, allow_pickle=True)
        gt_src = (np.swapaxes(nib.load(gt_name).get_fdata(), -1, 0)).astype(np.uint8)

        gt_pos = masks2bboxes(gt_src, border=border)
        rpn_pos = []
        rpn_neg = [det_src[i, 1:] for i in range(det_src.shape[0])]

        for index, gt in enumerate(gt_pos):
            print('--------------------')
            print('gt_idx:', index)
            
            hit = 1
            for rpn in reversed(rpn_neg):## del elements while iterating
                intsc = (rpn[3:] + gt[3:])/2 - np.abs(rpn[:3] - gt[:3])#calculate intersection size in 3 dims
                
                if np.all(intsc >= 1):
                    print('----------')
                    print('hit #%d:'%(hit))
                    
                    hit +=1
                    intsc_vol = np.prod(intsc)
                    iou = intsc_vol / (np.prod(rpn[3:]) + np.prod(gt[3:]) - intsc_vol 
                                   + np.finfo(intsc_vol.dtype).eps)#calculate iou
                    
                    print('iou:', round(iou, 4))
                    if iou > th_neg:
                        remove_array(rpn_neg, rpn)#ambi
                        if iou > th_pos:#pos
                            rpn_pos.append(rpn)

        
        cnt = np.array([len(gt_pos), det_src.shape[0], len(rpn_pos), len(rpn_neg)])
        cnt_all += cnt
        
        print('--------------------')
        print('gt_cnt:', cnt[0])
        print('det_cnt:', cnt[1])
        print('rpn_pos_cnt:', cnt[2])
        print('rpn_neg_cnt:', cnt[3])
        
        np.savez(os.path.join(save_path, ''.join((idx(gt_name), '_bbox.npz'))), 
                 gt_pos = np.array(gt_pos), rpn_pos=np.array(rpn_pos), rpn_neg=rpn_neg)
    
    print('====================')
    print('====================')
    print('gt_cnt_all:', cnt_all[0])
    print('det_cnt_all:', cnt_all[1])
    print('rpn_pos_cnt_all:', cnt_all[2])
    print('rpn_neg_cnt_all:', cnt_all[3])

In [None]:
preprocessing(det_train_names, gt_train_names, save_train_path, border=8, th_pos=0.6, th_neg=0.1)

In [5]:
preprocessing(det_val_names, gt_val_names, save_val_path, border=8, th_pos=0.6, th_neg=0.1)

det_name: /home/yutongx/src_data/det_bbox/421_rpns_list.npy
gt_name: /home/yutongx/src_data/labels/val/RibFrac421-label.nii.gz
--------------------
gt_idx: 0
----------
hit #1:
iou: 0.7697
----------
hit #2:
iou: 0.0771
----------
hit #3:
iou: 0.0153
----------
hit #4:
iou: 0.003
--------------------
gt_idx: 1
----------
hit #1:
iou: 0.686
--------------------
gt_idx: 2
----------
hit #1:
iou: 0.0481
----------
hit #2:
iou: 0.0895
----------
hit #3:
iou: 0.659
--------------------
gt_idx: 3
----------
hit #1:
iou: 0.701
--------------------
gt_cnt: 4
det_cnt: 53
rpn_pos_cnt: 4
rpn_neg_cnt: 49
det_name: /home/yutongx/src_data/det_bbox/422_rpns_list.npy
gt_name: /home/yutongx/src_data/labels/val/RibFrac422-label.nii.gz
--------------------
gt_idx: 0
----------
hit #1:
iou: 0.6613
--------------------
gt_idx: 1
----------
hit #1:
iou: 0.9054
----------
hit #2:
iou: 0.0377
--------------------
gt_idx: 2
----------
hit #1:
iou: 0.0721
----------
hit #2:
iou: 0.6254
----------
hit #3:
iou: 0

--------------------
gt_idx: 0
----------
hit #1:
iou: 0.6454
--------------------
gt_idx: 1
----------
hit #1:
iou: 0.0732
----------
hit #2:
iou: 0.0611
----------
hit #3:
iou: 0.6983
--------------------
gt_idx: 2
----------
hit #1:
iou: 0.18
----------
hit #2:
iou: 0.7587
--------------------
gt_idx: 3
----------
hit #1:
iou: 0.9039
--------------------
gt_idx: 4
----------
hit #1:
iou: 0.1487
----------
hit #2:
iou: 0.0624
----------
hit #3:
iou: 0.1213
----------
hit #4:
iou: 0.7429
--------------------
gt_idx: 5
----------
hit #1:
iou: 0.8237
--------------------
gt_cnt: 6
det_cnt: 64
rpn_pos_cnt: 6
rpn_neg_cnt: 55
det_name: /home/yutongx/src_data/det_bbox/433_rpns_list.npy
gt_name: /home/yutongx/src_data/labels/val/RibFrac433-label.nii.gz
--------------------
gt_idx: 0
----------
hit #1:
iou: 0.118
----------
hit #2:
iou: 0.0721
----------
hit #3:
iou: 0.7805
--------------------
gt_idx: 1
----------
hit #1:
iou: 0.7897
----------
hit #2:
iou: 0.0681
--------------------
gt_idx

--------------------
gt_idx: 0
----------
hit #1:
iou: 0.7968
--------------------
gt_idx: 1
----------
hit #1:
iou: 0.9252
--------------------
gt_idx: 2
----------
hit #1:
iou: 0.0774
----------
hit #2:
iou: 0.7592
--------------------
gt_idx: 3
----------
hit #1:
iou: 0.0493
----------
hit #2:
iou: 0.1095
----------
hit #3:
iou: 0.6207
--------------------
gt_idx: 4
----------
hit #1:
iou: 0.6831
--------------------
gt_idx: 5
----------
hit #1:
iou: 0.1161
----------
hit #2:
iou: 0.0984
----------
hit #3:
iou: 0.6214
--------------------
gt_idx: 6
----------
hit #1:
iou: 0.0529
----------
hit #2:
iou: 0.0246
----------
hit #3:
iou: 0.6505
--------------------
gt_idx: 7
----------
hit #1:
iou: 0.3844
----------
hit #2:
iou: 0.5081
--------------------
gt_idx: 8
----------
hit #1:
iou: 0.1055
----------
hit #2:
iou: 0.2707
----------
hit #3:
iou: 0.0348
----------
hit #4:
iou: 0.2488
--------------------
gt_cnt: 9
det_cnt: 55
rpn_pos_cnt: 7
rpn_neg_cnt: 41
det_name: /home/yutongx/src

--------------------
gt_idx: 0
----------
hit #1:
iou: 0.8844
--------------------
gt_idx: 1
----------
hit #1:
iou: 0.3765
--------------------
gt_idx: 2
----------
hit #1:
iou: 0.4201
--------------------
gt_cnt: 3
det_cnt: 82
rpn_pos_cnt: 1
rpn_neg_cnt: 79
det_name: /home/yutongx/src_data/det_bbox/450_rpns_list.npy
gt_name: /home/yutongx/src_data/labels/val/RibFrac450-label.nii.gz
--------------------
gt_idx: 0
----------
hit #1:
iou: 0.1355
----------
hit #2:
iou: 0.1154
----------
hit #3:
iou: 0.0366
----------
hit #4:
iou: 0.6503
--------------------
gt_cnt: 1
det_cnt: 65
rpn_pos_cnt: 1
rpn_neg_cnt: 62
det_name: /home/yutongx/src_data/det_bbox/451_rpns_list.npy
gt_name: /home/yutongx/src_data/labels/val/RibFrac451-label.nii.gz
--------------------
gt_idx: 0
----------
hit #1:
iou: 0.9244
--------------------
gt_idx: 1
----------
hit #1:
iou: 0.0727
----------
hit #2:
iou: 0.8163
----------
hit #3:
iou: 0.0617
--------------------
gt_idx: 2
----------
hit #1:
iou: 0.0949
---------

--------------------
gt_idx: 0
----------
hit #1:
iou: 0.5299
--------------------
gt_idx: 1
----------
hit #1:
iou: 0.5843
--------------------
gt_idx: 2
----------
hit #1:
iou: 0.0997
----------
hit #2:
iou: 0.7038
--------------------
gt_idx: 3
----------
hit #1:
iou: 0.4043
----------
hit #2:
iou: 0.0481
--------------------
gt_idx: 4
----------
hit #1:
iou: 0.0966
----------
hit #2:
iou: 0.0981
----------
hit #3:
iou: 0.9212
--------------------
gt_cnt: 5
det_cnt: 84
rpn_pos_cnt: 2
rpn_neg_cnt: 79
det_name: /home/yutongx/src_data/det_bbox/465_rpns_list.npy
gt_name: /home/yutongx/src_data/labels/val/RibFrac465-label.nii.gz
--------------------
gt_idx: 0
----------
hit #1:
iou: 0.0526
----------
hit #2:
iou: 0.8998
----------
hit #3:
iou: 0.0066
--------------------
gt_idx: 1
----------
hit #1:
iou: 0.113
----------
hit #2:
iou: 0.1195
----------
hit #3:
iou: 0.0186
----------
hit #4:
iou: 0.7461
--------------------
gt_idx: 2
----------
hit #1:
iou: 0.1603
----------
hit #2:
iou: 0

--------------------
gt_idx: 0
----------
hit #1:
iou: 0.5924
--------------------
gt_idx: 1
----------
hit #1:
iou: 0.0726
----------
hit #2:
iou: 0.1661
----------
hit #3:
iou: 0.7417
----------
hit #4:
iou: 0.1186
--------------------
gt_idx: 2
----------
hit #1:
iou: 0.1206
----------
hit #2:
iou: 0.0691
----------
hit #3:
iou: 0.8268
--------------------
gt_idx: 3
----------
hit #1:
iou: 0.1502
----------
hit #2:
iou: 0.237
----------
hit #3:
iou: 0.6607
----------
hit #4:
iou: 0.0448
--------------------
gt_idx: 4
----------
hit #1:
iou: 0.3034
--------------------
gt_cnt: 5
det_cnt: 67
rpn_pos_cnt: 3
rpn_neg_cnt: 57
det_name: /home/yutongx/src_data/det_bbox/480_rpns_list.npy
gt_name: /home/yutongx/src_data/labels/val/RibFrac480-label.nii.gz
--------------------
gt_idx: 0
----------
hit #1:
iou: 0.182
----------
hit #2:
iou: 0.0955
----------
hit #3:
iou: 0.1468
----------
hit #4:
iou: 0.6073
----------
hit #5:
iou: 0.1838
--------------------
gt_idx: 1
----------
hit #1:
iou: 0.

--------------------
gt_cnt: 0
det_cnt: 36
rpn_pos_cnt: 0
rpn_neg_cnt: 36
det_name: /home/yutongx/src_data/det_bbox/488_rpns_list.npy
gt_name: /home/yutongx/src_data/labels/val/RibFrac488-label.nii.gz
--------------------
gt_idx: 0
----------
hit #1:
iou: 0.1441
----------
hit #2:
iou: 0.0649
----------
hit #3:
iou: 0.7116
--------------------
gt_idx: 1
----------
hit #1:
iou: 0.1191
----------
hit #2:
iou: 0.0657
----------
hit #3:
iou: 0.1781
----------
hit #4:
iou: 0.004
----------
hit #5:
iou: 0.6832
--------------------
gt_idx: 2
----------
hit #1:
iou: 0.1428
----------
hit #2:
iou: 0.0478
----------
hit #3:
iou: 0.6658
--------------------
gt_cnt: 3
det_cnt: 50
rpn_pos_cnt: 3
rpn_neg_cnt: 43
det_name: /home/yutongx/src_data/det_bbox/489_rpns_list.npy
gt_name: /home/yutongx/src_data/labels/val/RibFrac489-label.nii.gz
--------------------
gt_idx: 0
----------
hit #1:
iou: 0.7481
--------------------
gt_idx: 1
--------------------
gt_idx: 2
----------
hit #1:
iou: 0.0991
----------

--------------------
gt_idx: 0
----------
hit #1:
iou: 0.071
----------
hit #2:
iou: 0.0126
----------
hit #3:
iou: 0.7386
--------------------
gt_idx: 1
----------
hit #1:
iou: 0.6585
--------------------
gt_idx: 2
----------
hit #1:
iou: 0.0725
----------
hit #2:
iou: 0.0687
----------
hit #3:
iou: 0.1292
----------
hit #4:
iou: 0.7559
--------------------
gt_idx: 3
----------
hit #1:
iou: 0.0071
----------
hit #2:
iou: 0.2386
----------
hit #3:
iou: 0.1089
----------
hit #4:
iou: 0.8406
--------------------
gt_idx: 4
----------
hit #1:
iou: 0.086
----------
hit #2:
iou: 0.0386
----------
hit #3:
iou: 0.0702
----------
hit #4:
iou: 0.629
--------------------
gt_cnt: 5
det_cnt: 73
rpn_pos_cnt: 5
rpn_neg_cnt: 65
det_name: /home/yutongx/src_data/det_bbox/499_rpns_list.npy
gt_name: /home/yutongx/src_data/labels/val/RibFrac499-label.nii.gz
--------------------
gt_idx: 0
----------
hit #1:
iou: 0.2764
----------
hit #2:
iou: 0.0955
----------
hit #3:
iou: 0.0849
----------
hit #4:
iou: 0.2