In [None]:
import os
import re
import sys

import nibabel as nib
import numpy as np

In [None]:
# Input/output directories. Each one can be overridden through an environment
# variable so the notebook runs on another machine without editing this cell;
# the defaults are the original locations.
det_path = os.environ.get("DET_BBOX_DIR", "/home/yutongx/src_data/det_bbox")
gt_path = os.environ.get("GT_LABEL_DIR", "/home/yutongx/src_data/labels")
save_path = os.environ.get("BBOX_SAVE_DIR", "./bbox")

# Ground-truth labels are pre-split into train/val sub-directories.
gt_train_path = os.path.join(gt_path, "train")
gt_val_path = os.path.join(gt_path, "val")

# Outputs mirror the same train/val split.
save_train_path = os.path.join(save_path, "train")
save_val_path = os.path.join(save_path, "val")

In [None]:
def idx(name):
    """Return the digits found in the basename of ``name`` as a string.

    Returns "" when the basename has no digits. Only the basename is scanned,
    so digits in parent directories (e.g. ``/data/run2/case_7.npy``) do not
    leak into the case index.
    """
    return re.sub(r"\D", "", os.path.basename(name))


def get_names(path):
    """Return the sorted full paths of the non-hidden entries of directory ``path``.

    Hidden entries (leading ".", e.g. Jupyter's ``.ipynb_checkpoints``) are
    skipped: they would otherwise inflate file counts and break numeric index
    parsing downstream.
    """
    return sorted(os.path.join(path, entry) for entry in os.listdir(path)
                  if not entry.startswith('.'))

def case_number(name):
    """Integer case index of ``name``, parsed with ``idx``."""
    return int(idx(name))


N_DET_CASES = 500     # expected number of detection files (train + val together)
TRAIN_MAX_CASE = 420  # cases numbered <= this go to the train split, the rest to val

det_names = get_names(det_path)
if len(det_names) != N_DET_CASES:
    raise ValueError('Detection names mismatch.')

# Sort both sides numerically so that zip() in preprocessing pairs the same
# case. Previously detections were sorted by digit string and labels by full
# path; those two orders can disagree (e.g. when a non-digit suffix follows
# the number), which trips the index assertion.
det_train_names = sorted((name for name in det_names if case_number(name) <= TRAIN_MAX_CASE),
                         key=case_number)
det_val_names = sorted((name for name in det_names if case_number(name) > TRAIN_MAX_CASE),
                       key=case_number)

gt_train_names = sorted(get_names(gt_train_path), key=case_number)
gt_val_names = sorted(get_names(gt_val_path), key=case_number)

In [None]:
def masks2bboxes(masks, border):
    """Turn a 3-D instance-label volume into one box per non-zero label.

    Parameters
    ----------
    masks : ndarray
        3-D label array; every distinct non-zero value is one instance.
    border : number
        Padding added to each box extent.

    Returns
    -------
    list of list
        One ``[c0, c1, c2, d0, d1, d2]`` entry per label, in ascending label
        order. ``c*`` is the midpoint of the occupied index range on that
        axis; ``d*`` is the inclusive range length plus ``border``.
    """
    boxes = []
    for label in np.unique(masks):
        if not label:
            continue  # 0 is background
        voxels = np.argwhere(masks == label)  # one index row per voxel of this label
        # Unpacking into three names enforces 3-D input.
        (z0, y0, x0), (z1, y1, x1) = voxels.min(axis=0), voxels.max(axis=0)
        boxes.append([(z1 + z0) / 2., (y1 + y0) / 2., (x1 + x0) / 2.,
                      z1 - z0 + 1 + border, y1 - y0 + 1 + border, x1 - x0 + 1 + border])
    return boxes

def remove_array(lst, arr, idx=0):
    """Delete, in place, the first element of ``lst`` equal to ``arr``.

    The search starts at position ``idx``, and equality is ``np.array_equal``.
    ``idx`` here is a local start position; it shadows the module-level
    ``idx`` helper only inside this function.

    Raises
    ------
    ValueError
        If no element from ``idx`` onward equals ``arr``.
    """
    n_items = len(lst)
    cursor = idx
    while cursor != n_items:
        if np.array_equal(lst[cursor], arr):
            del lst[cursor]
            break
        cursor += 1
    else:
        raise ValueError('array not found in list.')

def _overlap_iou(rpn, gt):
    """IoU of two center/size boxes, or None when they overlap by < 1 on any axis."""
    rpn = np.asarray(rpn)
    gt = np.asarray(gt)
    rpn_size, gt_size = rpn[3:], gt[3:]
    # (s1 + s2) / 2 - |c1 - c2| is the overlap length only for partially
    # overlapping intervals. When one interval contains the other, the overlap
    # is the smaller extent, so clip to both sizes; otherwise the IoU is inflated.
    intsc = np.minimum((rpn_size + gt_size) / 2 - np.abs(rpn[:3] - gt[:3]),
                       np.minimum(rpn_size, gt_size))
    if not np.all(intsc >= 1):
        return None
    intsc_vol = np.prod(intsc)
    union = np.prod(rpn_size) + np.prod(gt_size) - intsc_vol
    return intsc_vol / (union + np.finfo(np.float64).eps)


def preprocessing(det_names, gt_names, save_path, border=8, th=0.2, eps=0.02):
    """Label detector proposals against ground-truth boxes and save them per case.

    Detection files and label volumes are paired by position. Ground-truth
    boxes come from ``masks2bboxes``. Every proposal is compared with every
    ground-truth box and labelled by its best IoU over all of them:

    * best IoU > ``th + eps``                  -> positive (``rpn_pos``)
    * ``th - eps`` < best IoU <= ``th + eps``  -> ambiguous, dropped
    * otherwise, including no overlap          -> negative (``rpn_neg``)

    Each case is written to ``<save_path>/<case index>_bbox.npz`` with the
    arrays ``gt_pos``, ``rpn_pos`` and ``rpn_neg``.

    Parameters
    ----------
    det_names : sequence of str
        Detection files loadable with ``np.load``, one 2-D array per case.
    gt_names : sequence of str
        Label volumes loadable with nibabel, in the same case order.
    save_path : str
        Output directory; created if missing.
    border : number
        Padding passed to ``masks2bboxes``.
    th, eps : float
        IoU threshold and half-width of the ambiguous band around it.

    Raises
    ------
    ValueError
        If ``det_names`` and ``gt_names`` have different lengths.
    AssertionError
        If a pair's numeric case indices differ.
    """
    # os.makedirs returns None. The former conditional assignment therefore set
    # save_path to None on a fresh directory, and np.savez then failed.
    os.makedirs(save_path, exist_ok=True)
    if len(det_names) != len(gt_names):
        # zip() would otherwise silently drop the unmatched tail.
        raise ValueError('det_names and gt_names differ in length: %d vs %d'
                         % (len(det_names), len(gt_names)))

    for det_name, gt_name in zip(det_names, gt_names):
        print('====================')
        print('det_name:', det_name)
        print('gt_name:', gt_name)

        assert idx(det_name) == idx(gt_name), 'Index Mismatch.'

        # NOTE(review): allow_pickle=True can execute code from the file;
        # only use this on trusted detector output.
        det_src = np.load(det_name, allow_pickle=True)
        # Swap first and last axes: presumably NIfTI (x, y, z) -> (z, y, x)
        # to match masks2bboxes' axis order. TODO confirm.
        gt_src = np.swapaxes(nib.load(gt_name).get_fdata(), -1, 0).astype(np.uint8)

        gt_pos = masks2bboxes(gt_src, border=border)
        # Column 0 of each detection row is dropped (presumably a score, TODO confirm).
        proposals = [det_src[i, 1:] for i in range(det_src.shape[0])]
        # Best IoU of each proposal over all GT boxes; -inf means "never overlapped".
        # Labelling on the best IoU (instead of removing proposals from the list
        # being iterated) avoids skipping proposals and makes the result
        # independent of GT order.
        best_iou = np.full(len(proposals), -np.inf)

        print('gt_num:', len(gt_pos))
        for gt_index, gt in enumerate(gt_pos):
            print('--------------------')
            print('gt_idx:', gt_index)
            print('gt_cen:', np.around(gt[:3], 2))
            print('gt_size:', np.around(gt[3:], 2))

            hit = 1
            for rpn_index, rpn in enumerate(proposals):
                iou = _overlap_iou(rpn, gt)
                if iou is None:
                    continue
                print('----------')
                print('hit #%d:' % hit)
                print('rpn_cen:', np.around(rpn[:3], 2))
                print('rpn_size:', np.around(rpn[3:], 2))
                hit += 1
                print('iou:', round(iou, 4))
                best_iou[rpn_index] = max(best_iou[rpn_index], iou)

        rpn_pos = [rpn for rpn, iou in zip(proposals, best_iou) if iou > th + eps]
        rpn_neg = [rpn for rpn, iou in zip(proposals, best_iou) if iou <= th - eps]

        np.savez(os.path.join(save_path, idx(gt_name) + '_bbox.npz'),
                 gt_pos=np.array(gt_pos), rpn_pos=np.array(rpn_pos), rpn_neg=np.array(rpn_neg))

In [None]:
# Build the train-split box files.
preprocessing(det_names=det_train_names, gt_names=gt_train_names,
              save_path=save_train_path, border=8, th=0.2, eps=0.02)

In [None]:
# Build the val-split box files with the same settings as train.
preprocessing(det_names=det_val_names, gt_names=gt_val_names,
              save_path=save_val_path, border=8, th=0.2, eps=0.02)