## Converting the object-labeled WISDOM-Real dataset to the Cluttered Omniglot format

In [17]:
import os
import cv2
import matplotlib
import numpy as np

import matplotlib.pyplot as plt
%matplotlib inline

In [18]:
from dataset_utils import mkdir_if_missing
from PIL import Image
from skimage import io
from tqdm import tqdm
import pprint
import json

In [19]:
# ---- Inputs: labeled WISDOM-real data ----
AMODAL_MASK_DIR = "/nfs/diskstation/dmwang/labeled_wisdom_real/dataset"
SCENE_DIR = "/nfs/diskstation/dmwang/labeled_wisdom_real/phoxi/depth_ims"
# The per-scene JSON label files are read from the color-image directory.
JSON_DIR = "/nfs/diskstation/dmwang/labeled_wisdom_real/phoxi/color_ims"
MASK_DIR = "/nfs/diskstation/dmwang/labeled_wisdom_real/phoxi/modal_segmasks"

# ---- Outputs: one fold root with a sub-directory per split ----
OUT_DIR = "/nfs/diskstation/projects/dex-net/segmentation/datasets/mask-net-real/fold_0000"
mkdir_if_missing(OUT_DIR)
for _split in ("train", "val-train", "val-one-shot", "test-train", "test-one-shot"):
    mkdir_if_missing(os.path.join(OUT_DIR, _split))
del _split

# Number of distinct real scenes (image_000000 ... image_000399).
NUM_IMS = 400

# Original image dimensions.
# NOTE(review): which axis is the width vs. the height is not established here — confirm.
# IM_HEIGHT is not referenced by the code in this notebook section.
IM_WIDTH, IM_HEIGHT = 772, 1032

# Scene side length and target side length; IM_SIZE is three times TAR_SIZE (3:1).
IM_SIZE, TAR_SIZE = 384, 128

# Augmentation limits for rotation / shear (degrees).
# NOTE(review): the conversion loop calls make_target with angle=0, shear=0, so these are unused there.
ANGLE, SHEAR = 100, 4

# Storage block size (not referenced by the code in this notebook section).
BLOCK_SIZE = 500

In [20]:
def rot_x(phi, theta, ptx, pty):
    return np.cos(phi+theta)*ptx + np.sin(phi-theta)*pty


def rot_y(phi, theta, ptx, pty):
    return -np.sin(phi+theta)*ptx + np.cos(phi-theta)*pty


def prepare_img(img, angle=100, shear=2.5, scale=2):
    # Apply affine transformations and scale characters for data augmentation
    phi = np.radians(np.random.uniform(-angle, angle))
    theta = np.radians(np.random.uniform(-shear, shear))
    a = scale**np.random.uniform(-1, 1)
    b = scale**np.random.uniform(-1, 1)
    (x, y) = img.shape
    x = a * x
    y = b * y
    xextremes = [rot_x(phi, theta, 0, 0), rot_x(phi, theta, 0, y), rot_x(phi, theta, x, 0), rot_x(phi, theta, x, y)]
    yextremes = [rot_y(phi, theta, 0, 0), rot_y(phi, theta, 0, y), rot_y(phi, theta, x, 0), rot_y(phi, theta, x, y)]
    mnx = min(xextremes)
    mxx = max(xextremes)
    mny = min(yextremes)
    mxy = max(yextremes)

    aff_bas = np.array([[a*np.cos(phi+theta), b*np.sin(phi-theta), -mnx], [-a*np.sin(phi+theta), b*np.cos(phi-theta), -mny], [0, 0, 1]])
    aff_prm = np.linalg.inv(aff_bas)
    pil_img = Image.fromarray(img)
    pil_img = pil_img.transform((int(mxx - mnx),int(mxy - mny)),
                                    method=Image.AFFINE,
                                    data=np.ndarray.flatten(aff_prm[0:2, :]))
    pil_img = pil_img.resize((int(TAR_SIZE * (mxx - mnx) / 100), int(TAR_SIZE * (mxy - mny) / 100)))

    return np.array(pil_img)


def bbox(im):
    # get bounding box coordinates
    rows = np.any(im, axis=1)
    cols = np.any(im, axis=0)
    rmin, rmax = np.where(rows)[0][[0, -1]]
    cmin, cmax = np.where(cols)[0][[0, -1]]

    return rmin, rmax, cmin, cmax


def make_target(modal_mask, angle=0, shear=0, scale=1):
    # make target image
    transformed_mask = prepare_img(modal_mask, angle, shear, scale)
    top, bot, left, right = bbox(transformed_mask)
    obj_size = max(bot - top, right - left)
    margin = max((TAR_SIZE * 2 - obj_size) // 2, 0)
    return cv2.resize(
        transformed_mask[max(0, top - margin):min(transformed_mask.shape[0], bot + margin),
                         max(0, left - margin):min(transformed_mask.shape[1], right + margin)],
        (TAR_SIZE, TAR_SIZE),
        interpolation=cv2.INTER_NEAREST)

def resize_scene(im):
    if len(im.shape) == 2:
        im = np.pad(im, (((IM_WIDTH - IM_SIZE) // 2, (IM_WIDTH - IM_SIZE) // 2), (0, 0)), mode="constant")
    elif len(im.shape) == 3:
        im = np.pad(im, (((IM_WIDTH - IM_SIZE) // 2, (IM_WIDTH - IM_SIZE) // 2), (0, 0), (0, 0)), mode="constant")
    else:
        raise Exception("image dimensions not valid for scene/ground truth, shape: {}".format(im.shape))
    return cv2.resize(
        im,
        (IM_SIZE, IM_SIZE),
        interpolation=cv2.INTER_NEAREST)

In [None]:
# Looping through all image indices
#   Looping through all labels in the json list
#     Get the name
#     Get the target file corresponding to the name
#     Add the image to the batch
#     Process the segmask from modal_segmasks

data_count = 0
for meta_idx in tqdm(range(NUM_IMS * 30)):
    idx = meta_idx % NUM_IMS
    f = open(os.path.join(JSON_DIR, "image_{:06d}.json".format(idx)))
    json_file = json.load(f)
    scene_path = os.path.join(SCENE_DIR, "image_{:06d}.png".format(idx))
    mask_path = os.path.join(MASK_DIR, "image_{:06d}.png".format(idx))

    # read and blockify scene
    scene_im = io.imread(scene_path, as_gray=True)
    scene_im = resize_scene(scene_im)
    
    # triplicate the scene (384, 384) --> (384, 384, 3)
    scene_im = np.stack([scene_im, scene_im, scene_im], axis=2)
    
    scene_im = np.expand_dims(scene_im, axis=0)

    # for modal masks
    joint_mask = io.imread(mask_path, as_gray=True)
    for label in json_file["labels"]:
        target_name = label["label_class"]
        target_id = label["object_id"]
        target_path = os.path.join(AMODAL_MASK_DIR, 
                                   "image_{:06d}".format(idx), 
                                   "{}.png".format(target_name))
        try:
            target_im = io.imread(target_path, as_gray=True)
            amodal_mask = make_target(target_im, angle=0, shear=0)
            amodal_mask[amodal_mask > 0] = 1.0
            """print(np.unique(amodal_mask))
            plt.imshow(amodal_mask)
            plt.show()"""
        except:
            continue
        
        modal_mask = np.copy(joint_mask)
        modal_mask[modal_mask != target_id] = 0
        modal_mask = resize_scene(modal_mask)
        modal_mask[modal_mask > 0] = 1.0
        modal_mask = modal_mask.astype("float64")
        data_count += 1

        """plt.imshow(scene_im[0])
        plt.show()
        plt.imshow(modal_mask)
        plt.show()
        plt.imshow(amodal_mask)
        plt.show()"""
        
        # triplicate the amodal mask (128, 128) --> (128, 128, 3)
        amodal_mask = np.stack([amodal_mask, amodal_mask, amodal_mask], axis=2)
        amodal_mask = amodal_mask.astype("float64")
        
        # create 1-sized blocks
        amodal_mask = np.expand_dims(amodal_mask, axis=0)
        modal_mask = np.expand_dims(modal_mask, axis=2)
        modal_mask = np.expand_dims(modal_mask, axis=0)
        
        """print(modal_mask.shape)
        print(amodal_mask.shape)
        print(scene_im.shape)
        print(np.unique(modal_mask))
        print(np.unique(amodal_mask))
        print(np.unique(scene_im))
        print(modal_mask.dtype)
        print(amodal_mask.dtype)
        print(scene_im.dtype)"""
        
        np.save(os.path.join(
            OUT_DIR,
            "test-train/",
            "image_{:08d}.npy".format(meta_idx)),
               scene_im)
        np.save(os.path.join(
            OUT_DIR,
            "test-train/",
            "segmentation_{:08d}.npy".format(meta_idx)),
               modal_mask)
        np.save(os.path.join(
            OUT_DIR,
            "test-train/",
            "target_{:08d}.npy".format(meta_idx)),
               amodal_mask)
        np.save(os.path.join(
            OUT_DIR,
            "test-one-shot/",
            "image_{:08d}.npy".format(meta_idx)),
               scene_im)
        np.save(os.path.join(
            OUT_DIR,
            "test-one-shot/",
            "segmentation_{:08d}.npy".format(meta_idx)),
               modal_mask)
        np.save(os.path.join(
            OUT_DIR,
            "test-one-shot/",
            "target_{:08d}.npy".format(meta_idx)),
               amodal_mask)        
        
            
        
        



  0%|          | 0/12000 [00:00<?, ?it/s][A
  0%|          | 1/12000 [00:00<2:25:47,  1.37it/s][A
  0%|          | 2/12000 [00:01<2:10:16,  1.53it/s][A
  0%|          | 3/12000 [00:01<2:06:29,  1.58it/s][A
  0%|          | 4/12000 [00:02<2:16:20,  1.47it/s][A
  0%|          | 5/12000 [00:03<2:11:27,  1.52it/s][A
  0%|          | 6/12000 [00:03<2:15:05,  1.48it/s][A
  0%|          | 7/12000 [00:05<3:14:29,  1.03it/s][A
  0%|          | 8/12000 [00:06<3:03:43,  1.09it/s][A
  0%|          | 9/12000 [00:07<2:54:59,  1.14it/s][A
  0%|          | 10/12000 [00:07<2:41:07,  1.24it/s][A
  0%|          | 11/12000 [00:08<2:37:11,  1.27it/s][A
  0%|          | 12/12000 [00:09<2:56:06,  1.13it/s][A
  0%|          | 13/12000 [00:10<2:57:20,  1.13it/s][A
  0%|          | 14/12000 [00:12<4:21:49,  1.31s/it][A
  0%|          | 15/12000 [00:13<3:48:00,  1.14s/it][A
  0%|          | 16/12000 [00:14<3:29:46,  1.05s/it][A
  0%|          | 17/12000 [00:15<3:18:51,  1.00it/s][A
  0%|      

  1%|          | 145/12000 [02:06<3:11:19,  1.03it/s][A
  1%|          | 146/12000 [02:06<2:50:08,  1.16it/s][A
  1%|          | 147/12000 [02:07<2:53:35,  1.14it/s][A
  1%|          | 148/12000 [02:09<3:46:39,  1.15s/it][A
  1%|          | 149/12000 [02:12<5:02:46,  1.53s/it][A
  1%|▏         | 150/12000 [02:14<5:51:58,  1.78s/it][A
  1%|▏         | 151/12000 [02:16<6:13:09,  1.89s/it][A
  1%|▏         | 152/12000 [02:20<8:16:01,  2.51s/it][A
  1%|▏         | 153/12000 [02:24<9:14:44,  2.81s/it][A
  1%|▏         | 154/12000 [02:25<7:34:20,  2.30s/it][A
  1%|▏         | 155/12000 [02:28<8:08:07,  2.47s/it][A
  1%|▏         | 156/12000 [02:30<7:50:47,  2.38s/it][A
  1%|▏         | 157/12000 [02:35<10:32:39,  3.21s/it][A
  1%|▏         | 158/12000 [02:37<9:51:13,  3.00s/it] [A
  1%|▏         | 159/12000 [02:40<9:20:14,  2.84s/it][A
  1%|▏         | 160/12000 [02:44<10:18:46,  3.14s/it][A
  1%|▏         | 161/12000 [02:45<8:22:14,  2.55s/it] [A
  1%|▏         | 162/12000 

  2%|▏         | 288/12000 [05:16<2:17:36,  1.42it/s][A
  2%|▏         | 289/12000 [05:16<2:17:34,  1.42it/s][A
  2%|▏         | 290/12000 [05:17<2:18:22,  1.41it/s][A
  2%|▏         | 291/12000 [05:18<2:06:26,  1.54it/s][A
  2%|▏         | 292/12000 [05:18<2:02:46,  1.59it/s][A
  2%|▏         | 293/12000 [05:19<2:02:48,  1.59it/s][A
  2%|▏         | 294/12000 [05:21<3:28:15,  1.07s/it][A
  2%|▏         | 295/12000 [05:21<2:49:03,  1.15it/s][A
  2%|▏         | 296/12000 [05:22<2:40:59,  1.21it/s][A
  2%|▏         | 297/12000 [05:23<2:26:40,  1.33it/s][A
  2%|▏         | 298/12000 [05:24<2:41:12,  1.21it/s][A
  2%|▏         | 299/12000 [05:24<2:28:21,  1.31it/s][A
  2%|▎         | 300/12000 [05:25<2:20:48,  1.38it/s][A
  3%|▎         | 301/12000 [05:26<2:23:19,  1.36it/s][A
  3%|▎         | 302/12000 [05:26<2:03:05,  1.58it/s][A
  3%|▎         | 303/12000 [05:27<2:08:00,  1.52it/s][A
  3%|▎         | 304/12000 [05:27<1:49:51,  1.77it/s][A
  3%|▎         | 305/12000 [05: