In [1]:
import os
import re

import nibabel as nib
import numpy as np
from tqdm import tqdm

In [None]:
# Input/output locations. Each can be overridden with an environment variable so
# the notebook runs on machines other than the original authors'; the defaults
# reproduce the original setup.
det_path = os.environ.get(
    "RIBFRAC_DET_PATH",
    "/home/yxy/disk/Repository/RibFrac/NoduleNet/results/cross_val_test/res/72",
)
gt_path = os.environ.get("RIBFRAC_GT_PATH", "/home/yutongx/src_data/labels")
save_path = os.environ.get("RIBFRAC_BBOX_SAVE_PATH", "./bbox")

# Label volumes are split into train/val sub-directories; outputs mirror that split.
gt_train_path = os.path.join(gt_path, "train")
gt_val_path = os.path.join(gt_path, "val")

save_train_path = os.path.join(save_path, "train")
save_val_path = os.path.join(save_path, "val")

In [None]:
# Detections for the train and val cases live in one directory. Pair each label
# file with its detection by case ID (the digits in the file name, the same key
# `preprocessing` compares). Relying on sorted positions lining up is fragile,
# because lexicographic order does not follow numeric order (e.g. "10" < "9").
det_names = sorted(os.listdir(det_path))
gt_train_names = sorted(os.listdir(gt_train_path))
gt_val_names = sorted(os.listdir(gt_val_path))


def _case_id(name):
    """Return the digits of `name`, used as the case ID."""
    return re.sub(r"\D", "", name)


det_by_id = {_case_id(name): name for name in det_names}
# Every detection must have a unique ID, and there must be exactly one detection
# per train/val label file.
if (len(det_by_id) != len(det_names)
        or len(det_names) != len(gt_train_names) + len(gt_val_names)):
    raise ValueError('Detection names mismatch.')
try:
    det_train_names = [det_by_id[_case_id(name)] for name in gt_train_names]
    det_val_names = [det_by_id[_case_id(name)] for name in gt_val_names]
except KeyError as missing:
    raise ValueError(
        f'Detection names mismatch: no detection for case ID {missing}.') from None

In [None]:
def masks2bboxes(masks, border):
    """Turn a 3-D instance-label volume into one bounding box per non-zero label.

    Args:
        masks: 3-D array of labels; 0 (any falsy value) is treated as background.
        border: amount added to every box extent.

    Returns:
        A list (ascending label order) of
        ``[cz, cy, cx, dz + border, dy + border, dx + border, 1]`` entries, where
        the centers are midpoints of the voxel extremes along each axis and
        ``d*`` is the inclusive voxel extent (max - min + 1).
    """
    boxes = []
    for label in (value for value in np.unique(masks) if value):
        zz, yy, xx = np.nonzero(masks == label)
        if zz.size == 0:
            # e.g. a NaN label never compares equal to itself; nothing to box.
            continue
        axes = (zz, yy, xx)
        centers = [(a.max() + a.min()) / 2. for a in axes]
        extents = [a.max() - a.min() + 1 + border for a in axes]
        boxes.append(centers + extents + [1])
    return boxes

def _overlap_dims(box_a, box_b):
    """Return the per-axis overlap length of two center/size boxes.

    Each box is read as ``[cz, cy, cx, dz, dy, dx, ...]``; trailing entries
    (e.g. the label appended by ``masks2bboxes``) are ignored. A non-positive
    value means the boxes do not overlap along that axis.
    """
    a = np.asarray(box_a[:6], dtype=np.float64)
    b = np.asarray(box_b[:6], dtype=np.float64)
    lo = np.maximum(a[:3] - a[3:] / 2., b[:3] - b[3:] / 2.)
    hi = np.minimum(a[:3] + a[3:] / 2., b[:3] + b[3:] / 2.)
    return hi - lo


def preprocessing(det_names, gt_names, save_path, border=8, th=0.2,
                  det_dir="", gt_dir=""):
    """Split detector proposals into positives/negatives against GT boxes and save them.

    Detection and label files are paired positionally. For each pair, GT boxes are
    extracted from the label volume with ``masks2bboxes``. A proposal becomes
    positive when it overlaps a GT box by more than ``border // 2`` on every axis
    and their IoU exceeds ``th``; each proposal is assigned at most once. Results
    are written to ``<save_path>/<case id>_bbox.npz`` with arrays ``gt_pos``,
    ``rpn_pos`` and ``rpn_neg``.

    Args:
        det_names: detection file names, loaded with ``np.load``.
        gt_names: label-volume file names, loaded with nibabel.
        save_path: output directory; created if missing.
        border: added to every GT box extent; also sets the minimum per-axis overlap.
        th: IoU threshold for a positive proposal.
        det_dir, gt_dir: optional directories prefixed to the names. The default
            "" keeps the names as given (i.e. relative to the working directory).

    Raises:
        ValueError: the name lists differ in length, or a pair's case IDs differ.

    NOTE(review): assumes each detection row is [<dropped col 0>, cz, cy, cx,
    dz, dy, dx] in the same (z, y, x) voxel frame as the labels — TODO confirm
    against the detector output format.
    """
    def idx(name):
        # Case ID = all digits in the file name.
        return re.sub(r"\D", "", name)

    det_names, gt_names = list(det_names), list(gt_names)
    if len(det_names) != len(gt_names):
        # zip() would silently drop the unmatched tail.
        raise ValueError('Length mismatch: %d detections vs %d labels.'
                         % (len(det_names), len(gt_names)))
    # os.makedirs returns None, so its result must not be assigned to save_path.
    os.makedirs(save_path, exist_ok=True)

    for det_name, gt_name in tqdm(zip(det_names, gt_names), total=len(gt_names)):
        print('====================')
        print('det_name:', det_name)
        print('gt_name:', gt_name)

        if idx(det_name) != idx(gt_name):
            raise ValueError('Index Mismatch.')

        # NOTE(review): allow_pickle=True can execute code embedded in the file;
        # only use with trusted detector output.
        det_src = np.load(os.path.join(det_dir, det_name), allow_pickle=True)
        # Swap first and last axes so masks2bboxes unpacks (z, y, x); presumes the
        # NIfTI data is stored (x, y, z) — confirm.
        gt_src = np.swapaxes(
            nib.load(os.path.join(gt_dir, gt_name)).get_fdata(), -1, 0
        ).astype(np.uint8)

        gt_pos = masks2bboxes(gt_src, border=border)
        proposals = [det_src[i, 1:] for i in range(det_src.shape[0])]
        matched = np.zeros(len(proposals), dtype=bool)
        rpn_pos = []

        print('--------------------')
        print('det_nums:', len(proposals))
        print('gt_nums:', len(gt_pos))

        for gt in tqdm(gt_pos):
            gt = np.asarray(gt, dtype=np.float64)
            print('cen:', np.around(gt[:3], 2))
            gt_vol = np.prod(gt[3:6])
            for k, rpn in enumerate(proposals):
                if matched[k]:
                    # Reduce ambiguity: a proposal is assigned to one GT at most.
                    continue
                intsc = _overlap_dims(rpn, gt)
                if np.all(intsc > border // 2):
                    print('rpn hit:')
                    intsc_vol = np.prod(intsc)
                    rpn_vol = np.prod(np.asarray(rpn[3:6], dtype=np.float64))
                    iou = intsc_vol / (rpn_vol + gt_vol - intsc_vol
                                       + np.finfo(np.float64).eps)
                    print('iou:', iou)
                    if iou > th:
                        matched[k] = True
                        rpn_pos.append(rpn)
        # Build negatives once at the end instead of removing from the list being
        # iterated (which skipped elements and made list.remove compare arrays).
        rpn_neg = [rpn for rpn, hit in zip(proposals, matched) if not hit]

        np.savez(os.path.join(save_path, idx(gt_name) + '_bbox.npz'),
                 gt_pos=np.array(gt_pos), rpn_pos=np.array(rpn_pos),
                 rpn_neg=np.array(rpn_neg))

In [3]:
from time import sleep

# Smoke-test of the tqdm progress bar: ten steps of 10 out of a total of 100.
# The context manager closes the bar when the block exits.
with tqdm(total=100) as pbar:
    for i in range(10):
        sleep(0.1)
        pbar.update(10)


  0%|          | 0/100 [00:14<?, ?it/s][A

 10%|█         | 10/100 [00:00<00:00, 98.58it/s][A
 20%|██        | 20/100 [00:00<00:00, 98.24it/s][A
 30%|███       | 30/100 [00:00<00:00, 98.01it/s][A
 40%|████      | 40/100 [00:00<00:00, 97.80it/s][A
 50%|█████     | 50/100 [00:00<00:00, 97.79it/s][A
 60%|██████    | 60/100 [00:00<00:00, 97.71it/s][A
 70%|███████   | 70/100 [00:00<00:00, 97.66it/s][A
 80%|████████  | 80/100 [00:00<00:00, 97.58it/s][A
 90%|█████████ | 90/100 [00:00<00:00, 97.52it/s][A
100%|██████████| 100/100 [00:01<00:00, 97.27it/s][A
