In [1]:
import os
import argparse
import numpy as np
import json
from tqdm import tqdm
from tars.base.dataset import Dataset, DatasetType
from tars.config.base.dataset_config import DatasetConfig
from tars.datasets.instance_segmentation_dataset import InstanceSegmentationDataset

In [2]:
def get_instance_target(raw_mask_image, task_dir):
    """Build per-object boxes, labels and binary masks for one mask frame.

    Args:
        raw_mask_image: color-coded instance mask; anything ``np.array`` turns
            into an ``(H, W, C)`` array with at least three channels.
        task_dir: directory containing ``augmented_traj_data.json``. Its
            ``['scene']['color_to_object_type']`` table maps color keys of the
            form ``"(a, b, c)"`` to entries carrying an ``'objectType'`` name.

    Returns:
        Tuple ``(boxes, labels, masks)`` of numpy arrays with one entry per
        color that actually occurs in the image:
            boxes:  float32, rows of ``[xmin, ymin, xmax, ymax]`` (inclusive
                    pixel coordinates of the mask extent).
            labels: int64, ``DatasetConfig.objects_vocab`` index of the type.
            masks:  uint8, one ``(H, W)`` binary mask per object.
        When no listed color occurs in the image, all three arrays are empty
        with shape ``(0,)``.

    Raises:
        ValueError: if a color key cannot be parsed into integers.
    """
    with open(os.path.join(task_dir, 'augmented_traj_data.json'), 'r') as f:
        color_data = json.load(f)['scene']['color_to_object_type']

    mask_image = np.array(raw_mask_image)
    boxes = []
    labels = []
    masks = []

    for color_key, obj_info in color_data.items():
        obj_idx = DatasetConfig.objects_vocab.word2index(obj_info['objectType'])

        # Split on ',' alone and let int() absorb surrounding whitespace, so
        # both "(1, 2, 3)" and "(1,2,3)" are accepted.
        color = tuple(int(part) for part in color_key.strip().strip('()').split(','))

        # The key's components are matched against the image channels in
        # reverse order (channel 0 <-> last component).
        # NOTE(review): presumably key and image use opposite RGB/BGR
        # conventions -- confirm against the data generator.
        obj_mask = (
            (mask_image[:, :, 0] == color[2])
            & (mask_image[:, :, 1] == color[1])
            & (mask_image[:, :, 2] == color[0])
        )

        # Colors listed for the scene but absent from this frame are skipped.
        if not obj_mask.any():
            continue

        # Bounding box is the tight extent of the mask pixels.
        rows, cols = np.where(obj_mask)
        boxes.append([np.min(cols), np.min(rows), np.max(cols), np.max(rows)])
        labels.append(obj_idx)
        masks.append(obj_mask)

    # convert to numpy arrays
    boxes = np.asarray(boxes, dtype=np.float32)
    labels = np.asarray(labels, dtype=np.int64)
    masks = np.asarray(masks, dtype=np.uint8)

    return boxes, labels, masks

In [3]:
# Generate and save an instance-segmentation target for every frame of the
# split; frames that yield no object instance are queued for deletion.
split = "valid_unseen"
dataset = InstanceSegmentationDataset(DatasetType(split), preprocess=True)
remove_rgbs, remove_masks = [], []

for idx in tqdm(range(len(dataset))):
    rgb_img, mask_img = dataset[idx]
    rgb_path = rgb_img.filename
    # the task directory sits two levels above the image file
    task_dir = os.path.dirname(os.path.dirname(rgb_path))
    boxes, labels, masks = get_instance_target(mask_img, task_dir)

    if boxes.shape[0] == 0:
        # bad target: remember both images so they can be deleted later
        remove_rgbs.append(rgb_path)
        remove_masks.append(mask_img.filename)
        continue

    # mirror the image path under the target directory, swapping png -> npz
    new_target_file = rgb_path.replace(
        DatasetConfig.high_res_img_dir, DatasetConfig.instance_target_dir
    )
    new_target_file = new_target_file.replace("png", "npz")
    os.makedirs(os.path.dirname(new_target_file), exist_ok=True)
    # arrays are stored positionally (boxes, labels, masks)
    np.savez_compressed(new_target_file, boxes, labels, masks)

print(f"\nbad targets: {len(remove_rgbs)}/{len(dataset)}")

100%|██████████| 14529/14529 [55:45<00:00,  4.34it/s]bad targets: 111/14529



In [4]:
# delete images corresponding to bad targets
# RGB images are removed first, then their mask images
for stale_path in remove_rgbs + remove_masks:
    os.remove(stale_path)

In [7]:
# Re-create the dataset for the same split, this time without preprocess=True.
# NOTE(review): presumably this mode serves (image, target) pairs built from the
# saved target files -- confirm in InstanceSegmentationDataset.
dataset = InstanceSegmentationDataset(DatasetType(split))

In [8]:
# sample count after the bad-target images were deleted
len(dataset)

14418

In [9]:
# verify the dataset
# Verify the dataset: every sample must load a non-empty target whose labels,
# boxes and masks all describe the same number of instances.
for i in tqdm(range(len(dataset))):
    img, tgt = dataset[i]
    # identity check: `!= None` can dispatch to an overloaded __ne__
    assert tgt is not None, f"sample {i}: target is missing"
    assert tgt["labels"].shape[0] > 0, f"sample {i}: target has no instances"
    assert tgt["labels"].shape[0] == tgt["boxes"].shape[0], (
        f"sample {i}: labels/boxes count mismatch"
    )
    assert tgt["boxes"].shape[0] == tgt["masks"].shape[0], (
        f"sample {i}: boxes/masks count mismatch"
    )

100%|██████████| 14418/14418 [02:19<00:00, 103.60it/s]
