Skip to content

Commit

Permalink
Added Custom RGB-only Dataset boilerplate code
Browse files Browse the repository at this point in the history
  • Loading branch information
Deusy94 committed Aug 22, 2019
1 parent 2ba959e commit 5055c29
Show file tree
Hide file tree
Showing 5 changed files with 142 additions and 28 deletions.
8 changes: 7 additions & 1 deletion src/Datasets.py
Expand Up @@ -3,6 +3,7 @@
import watch_n_patch
import scipy
from torch.utils.data import Dataset
import cv2

PATCH = 'watch_n_patch'

Expand Down Expand Up @@ -54,8 +55,10 @@ def __len__(self):

def __getitem__(self, idx):
name = list(self.joints.keys())[idx]
name_rgb = name.replace(name[name.find("."):], ".jpg").replace("depth", "rgbjpg")

img = scipy.io.loadmat(name)['depth']
img_rgb = cv2.imread(name_rgb)

arr = np.array(img)
tmp = np.zeros((arr.shape[0], arr.shape[1], 3))
Expand All @@ -68,4 +71,7 @@ def __getitem__(self, idx):
kpts = np.array(kpts)
kpts = kpts[np.newaxis, :]

return img, kpts, name
img = img * 255 / np.amax(img)
img = img.astype(np.uint8)

return [img, img_rgb], kpts, [name, name_rgb]
60 changes: 34 additions & 26 deletions src/Noter.py
@@ -1,6 +1,7 @@
from Datasets import ComposedDataset
from src.Datasets import ComposedDataset
from src.RGB_Dataset import ComposedDataset as RGB_Dataset
import numpy as np
from watch_n_patch import WATCH_N_PATCH_JOINTS
from src.watch_n_patch import WATCH_N_PATCH_JOINTS
import json
import os
import cv2
Expand Down Expand Up @@ -91,7 +92,11 @@ def start(self, skip_or_keep: str = "skip"):
# Starting function
self.master.update()
next_name = None
for i, (img, kpts, name) in enumerate(self.dataset):
for i, (imgs, kpts, names) in enumerate(self.dataset):
if isinstance(names, list):
name = names[0]
else:
name = names
if name in self.json_dict:
if skip_or_keep == "keep":
kpts = np.array(self.json_dict[name])
Expand Down Expand Up @@ -128,40 +133,43 @@ def start(self, skip_or_keep: str = "skip"):
self.status.set(f"[{curr_seq}] di [{seq}]")
self.master.update()

rgb_name = name.split(self.slash)
if name.split('.')[-1] == 'png':
rgb_name[-2] = 'RGB'
last_split = rgb_name[-1].split('_')
last_split[-1] = 'RGB.png'
rgb_name[-1] = "_".join(last_split)
rgb_name = self.slash.join(rgb_name)
if name.split('.')[-1] == 'mat':
rgb_name[-2] = 'rgbjpg'
last_split = rgb_name[-1].split('.')
last_split[-1] = 'jpg'
rgb_name[-1] = ".".join(last_split)
rgb_name = self.slash.join(rgb_name)

rgb = cv2.imread(rgb_name)
rgb = cv2.resize(rgb, None, fx=0.4, fy=0.4, interpolation=cv2.INTER_CUBIC)

img = img * 255 / np.amax(img)
img = img.astype(np.uint8)

# rgb_name = name.split(self.slash)
# if name.split('.')[-1] == 'png':
# rgb_name[-2] = 'RGB'
# last_split = rgb_name[-1].split('_')
# last_split[-1] = 'RGB.png'
# rgb_name[-1] = "_".join(last_split)
# rgb_name = self.slash.join(rgb_name)
# if name.split('.')[-1] == 'mat':
# rgb_name[-2] = 'rgbjpg'
# last_split = rgb_name[-1].split('.')
# last_split[-1] = 'jpg'
# rgb_name[-1] = ".".join(last_split)
# rgb_name = self.slash.join(rgb_name)

new_imgs = list()
if isinstance(imgs, list):
img = imgs[0]
for index, new_img in enumerate(imgs[1:]):
new_imgs.append(cv2.resize(new_img, None, fx=0.4, fy=0.4, interpolation=cv2.INTER_CUBIC))
cv2.namedWindow(names[index + 1])
cv2.moveWindow(names[index + 1], 750, 300)
else:
img = imgs
img, kpts = self.upscale(img, kpts)

tmp = img.astype(np.uint8).copy()
kpts_back = kpts.copy()

self.draw_kpts(tmp, kpts, self.radius)
cv2.namedWindow(name)
cv2.namedWindow(rgb_name)
cv2.moveWindow(rgb_name, 750, 300)
cv2.moveWindow(name, 15, 150)
cv2.setMouseCallback(name, self.click_left, [name, tmp, kpts])

while True:
cv2.imshow(rgb_name, rgb)
if len(new_imgs) > 0:
for new_name, el in zip(names[1:], new_imgs):
cv2.imshow(new_name, el)
cv2.imshow(name, tmp)
key = cv2.waitKey(1) & 0xFF

Expand Down
66 changes: 66 additions & 0 deletions src/RGB_Dataset.py
@@ -0,0 +1,66 @@
import os
import numpy as np
import watch_n_patch
import scipy
from torch.utils.data import Dataset
import cv2

PATCH = 'watch_n_patch'

OFFICE_SPLIT = ['data_03-58-25', 'data_03-25-32', 'data_02-32-08', 'data_03-05-15', 'data_11-11-59',
'data_03-21-23', 'data_03-35-07', 'data_03-04-16', 'data_04-30-36', 'data_02-50-20']
KITCHEN_SPLIT = ['data_04-51-42', 'data_04-52-02', 'data_02-10-35', 'data_03-45-21', 'data_03-53-06',
'data_12-07-43', 'data_05-04-12', 'data_04-27-09', 'data_04-13-06', 'data_01-52-55']


class ComposedDataset(Dataset):
def __init__(self, root_dir=None):
"""
Args:
root_dir (string): Directory with all the images.
split (string): Split for custom Dataset
"""
print("Loader started.")
self.root_dir = root_dir
self.P_ID = list()
self.joints = dict()

# Load Watch-n-patch
print("Loading Watch-n-patch...")

mat = scipy.io.loadmat(os.path.join(root_dir, PATCH, 'data_splits', 'kitchen_split.mat'))
kitchen_splits = mat['test_name'][0]

mat = scipy.io.loadmat(os.path.join(root_dir, PATCH, 'data_splits', 'office_split.mat'))
office_splits = mat['test_name'][0]

patch_joints = dict()
for el in kitchen_splits:
if el not in KITCHEN_SPLIT:
continue
patch_joints = {**patch_joints, **watch_n_patch.get_joints_rgb(os.path.join(root_dir, PATCH, "kitchen", el[0]))}
for el in office_splits:
if el not in OFFICE_SPLIT:
continue
patch_joints = {**patch_joints, **watch_n_patch.get_joints_rgb(os.path.join(root_dir, PATCH, "office", el[0]))}
print("Done.")

self.size = len(patch_joints)
print(f"{self.size} images loaded.\n")
self.joints = {**patch_joints}


def __len__(self):
return self.size

def __getitem__(self, idx):
name = list(self.joints.keys())[idx]

img = cv2.imread(name)
img = cv2.resize(img, None, fx=0.4, fy=0.4, interpolation=cv2.INTER_CUBIC)

kpts = [i for i in self.joints[name].values()]
kpts = np.array(kpts)
kpts = kpts[np.newaxis, :]

return img, kpts * 0.4, name
Empty file added src/__init__.py
Empty file.
36 changes: 35 additions & 1 deletion src/watch_n_patch.py
Expand Up @@ -49,11 +49,45 @@ def get_joints(data_path: str):
return joints


def get_joints_rgb(data_path: str):
"""
:param
data_path : str
path to watch-n-patch .jpg file
:return
joints : dict
dictionary with frame_path as key and joints annotation as value
"""
joints = dict()
body = scipy.io.loadmat(os.path.join(data_path, 'body.mat'))['body']

# Watch-n-patch save depth images on \depth folder, if you're on unix system change the path format
DEPTH = os.path.join(data_path, 'rgbjpg')
names = get_image_name(DEPTH)
for frame in range(len(body)):
for k in range(6):
if body[frame][k]['isBodyTracked'] == 1:
joint_tracked = body[frame][k]['joints']
joints[os.path.join(data_path, 'rgbjpg', names[frame])] = dict()
for i in range(len(joint_tracked[0][0][0])):
# Joint not tracked are set as not visible, using (-1, -1) as coordinates
if joint_tracked[0][0][0][i]['trackingState'][0][0][0][0] == 0:
joints[os.path.join(data_path, 'rgbjpg', names[frame])][i] = (-1, -1)
else:
# Getting joint annotations
for j in joint_tracked[0][0][0][i]['color'][0]:
x = j[0][0]
y = j[1][0]
joints[os.path.join(data_path, 'rgbjpg', names[frame])][i] = (round(x), round(y))
break
return joints


def get_image_name(img_dir: str):
images = os.listdir(img_dir)
images.sort()
if any(".DS_Store" in s for s in images):
images.remove(".DS_Store")
if any("._.DS_Store" in s for s in images):
images.remove("._.DS_Store")
return images
return images

0 comments on commit 5055c29

Please sign in to comment.