diff --git a/src/Datasets.py b/src/Datasets.py index 44ab2c5..bbe3795 100644 --- a/src/Datasets.py +++ b/src/Datasets.py @@ -3,6 +3,7 @@ import watch_n_patch import scipy from torch.utils.data import Dataset +import cv2 PATCH = 'watch_n_patch' @@ -54,8 +55,10 @@ def __len__(self): def __getitem__(self, idx): name = list(self.joints.keys())[idx] + name_rgb = name.replace(name[name.find("."):], ".jpg").replace("depth", "rgbjpg") img = scipy.io.loadmat(name)['depth'] + img_rgb = cv2.imread(name_rgb) arr = np.array(img) tmp = np.zeros((arr.shape[0], arr.shape[1], 3)) @@ -68,4 +71,7 @@ def __getitem__(self, idx): kpts = np.array(kpts) kpts = kpts[np.newaxis, :] - return img, kpts, name + img = img * 255 / np.amax(img) + img = img.astype(np.uint8) + + return [img, img_rgb], kpts, [name, name_rgb] \ No newline at end of file diff --git a/src/Noter.py b/src/Noter.py index c9e5908..cc903cc 100644 --- a/src/Noter.py +++ b/src/Noter.py @@ -1,6 +1,7 @@ -from Datasets import ComposedDataset +from src.Datasets import ComposedDataset +from src.RGB_Dataset import ComposedDataset as RGB_Dataset import numpy as np -from watch_n_patch import WATCH_N_PATCH_JOINTS +from src.watch_n_patch import WATCH_N_PATCH_JOINTS import json import os import cv2 @@ -91,7 +92,11 @@ def start(self, skip_or_keep: str = "skip"): # Starting function self.master.update() next_name = None - for i, (img, kpts, name) in enumerate(self.dataset): + for i, (imgs, kpts, names) in enumerate(self.dataset): + if isinstance(names, list): + name = names[0] + else: + name = names if name in self.json_dict: if skip_or_keep == "keep": kpts = np.array(self.json_dict[name]) @@ -128,26 +133,29 @@ def start(self, skip_or_keep: str = "skip"): self.status.set(f"[{curr_seq}] di [{seq}]") self.master.update() - rgb_name = name.split(self.slash) - if name.split('.')[-1] == 'png': - rgb_name[-2] = 'RGB' - last_split = rgb_name[-1].split('_') - last_split[-1] = 'RGB.png' - rgb_name[-1] = "_".join(last_split) - rgb_name = self.slash.join(rgb_name) - if name.split('.')[-1] == 'mat': - rgb_name[-2] = 'rgbjpg' - last_split = rgb_name[-1].split('.') - last_split[-1] = 'jpg' - rgb_name[-1] = ".".join(last_split) - rgb_name = self.slash.join(rgb_name) - - rgb = cv2.imread(rgb_name) - rgb = cv2.resize(rgb, None, fx=0.4, fy=0.4, interpolation=cv2.INTER_CUBIC) - - img = img * 255 / np.amax(img) - img = img.astype(np.uint8) - + # rgb_name = name.split(self.slash) + # if name.split('.')[-1] == 'png': + # rgb_name[-2] = 'RGB' + # last_split = rgb_name[-1].split('_') + # last_split[-1] = 'RGB.png' + # rgb_name[-1] = "_".join(last_split) + # rgb_name = self.slash.join(rgb_name) + # if name.split('.')[-1] == 'mat': + # rgb_name[-2] = 'rgbjpg' + # last_split = rgb_name[-1].split('.') + # last_split[-1] = 'jpg' + # rgb_name[-1] = ".".join(last_split) + # rgb_name = self.slash.join(rgb_name) + + new_imgs = list() + if isinstance(imgs, list): + img = imgs[0] + for index, new_img in enumerate(imgs[1:]): + new_imgs.append(cv2.resize(new_img, None, fx=0.4, fy=0.4, interpolation=cv2.INTER_CUBIC)) + cv2.namedWindow(names[index + 1]) + cv2.moveWindow(names[index + 1], 750, 300) + else: + img = imgs img, kpts = self.upscale(img, kpts) tmp = img.astype(np.uint8).copy() @@ -155,13 +163,13 @@ def start(self, skip_or_keep: str = "skip"): self.draw_kpts(tmp, kpts, self.radius) cv2.namedWindow(name) - cv2.namedWindow(rgb_name) - cv2.moveWindow(rgb_name, 750, 300) cv2.moveWindow(name, 15, 150) cv2.setMouseCallback(name, self.click_left, [name, tmp, kpts]) while True: - cv2.imshow(rgb_name, rgb) + if len(new_imgs) > 0: + for new_name, el in zip(names[1:], new_imgs): + cv2.imshow(new_name, el) cv2.imshow(name, tmp) key = cv2.waitKey(1) & 0xFF diff --git a/src/RGB_Dataset.py b/src/RGB_Dataset.py new file mode 100644 index 0000000..1d94c66 --- /dev/null +++ b/src/RGB_Dataset.py @@ -0,0 +1,66 @@ +import os +import numpy as np +import watch_n_patch +import scipy +from torch.utils.data import Dataset +import cv2 + +PATCH = 'watch_n_patch' + +OFFICE_SPLIT = ['data_03-58-25', 'data_03-25-32', 'data_02-32-08', 'data_03-05-15', 'data_11-11-59', + 'data_03-21-23', 'data_03-35-07', 'data_03-04-16', 'data_04-30-36', 'data_02-50-20'] +KITCHEN_SPLIT = ['data_04-51-42', 'data_04-52-02', 'data_02-10-35', 'data_03-45-21', 'data_03-53-06', + 'data_12-07-43', 'data_05-04-12', 'data_04-27-09', 'data_04-13-06', 'data_01-52-55'] + + +class ComposedDataset(Dataset): + def __init__(self, root_dir=None): + """ + Args: + root_dir (string): Directory with all the images. + split (string): Split for custom Dataset + """ + print("Loader started.") + self.root_dir = root_dir + self.P_ID = list() + self.joints = dict() + + # Load Watch-n-patch + print("Loading Watch-n-patch...") + + mat = scipy.io.loadmat(os.path.join(root_dir, PATCH, 'data_splits', 'kitchen_split.mat')) + kitchen_splits = mat['test_name'][0] + + mat = scipy.io.loadmat(os.path.join(root_dir, PATCH, 'data_splits', 'office_split.mat')) + office_splits = mat['test_name'][0] + + patch_joints = dict() + for el in kitchen_splits: + if el not in KITCHEN_SPLIT: + continue + patch_joints = {**patch_joints, **watch_n_patch.get_joints_rgb(os.path.join(root_dir, PATCH, "kitchen", el[0]))} + for el in office_splits: + if el not in OFFICE_SPLIT: + continue + patch_joints = {**patch_joints, **watch_n_patch.get_joints_rgb(os.path.join(root_dir, PATCH, "office", el[0]))} + print("Done.") + + self.size = len(patch_joints) + print(f"{self.size} images loaded.\n") + self.joints = {**patch_joints} + + + def __len__(self): + return self.size + + def __getitem__(self, idx): + name = list(self.joints.keys())[idx] + + img = cv2.imread(name) + img = cv2.resize(img, None, fx=0.4, fy=0.4, interpolation=cv2.INTER_CUBIC) + + kpts = [i for i in self.joints[name].values()] + kpts = np.array(kpts) + kpts = kpts[np.newaxis, :] + + return img, kpts * 0.4, name \ No newline at end of file diff --git a/src/__init__.py b/src/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/watch_n_patch.py b/src/watch_n_patch.py index 995104e..f06e1df 100644 --- a/src/watch_n_patch.py +++ b/src/watch_n_patch.py @@ -49,6 +49,40 @@ def get_joints(data_path: str): return joints +def get_joints_rgb(data_path: str): + """ + :param + data_path : str + path to watch-n-patch .jpg file + :return + joints : dict + dictionary with frame_path as key and joints annotation as value + """ + joints = dict() + body = scipy.io.loadmat(os.path.join(data_path, 'body.mat'))['body'] + + # Watch-n-patch save depth images on \depth folder, if you're on unix system change the path format + DEPTH = os.path.join(data_path, 'rgbjpg') + names = get_image_name(DEPTH) + for frame in range(len(body)): + for k in range(6): + if body[frame][k]['isBodyTracked'] == 1: + joint_tracked = body[frame][k]['joints'] + joints[os.path.join(data_path, 'rgbjpg', names[frame])] = dict() + for i in range(len(joint_tracked[0][0][0])): + # Joint not tracked are set as not visible, using (-1, -1) as coordinates + if joint_tracked[0][0][0][i]['trackingState'][0][0][0][0] == 0: + joints[os.path.join(data_path, 'rgbjpg', names[frame])][i] = (-1, -1) + else: + # Getting joint annotations + for j in joint_tracked[0][0][0][i]['color'][0]: + x = j[0][0] + y = j[1][0] + joints[os.path.join(data_path, 'rgbjpg', names[frame])][i] = (round(x), round(y)) + break + return joints + + def get_image_name(img_dir: str): images = os.listdir(img_dir) images.sort() @@ -56,4 +90,4 @@ def get_image_name(img_dir: str): images.remove(".DS_Store") if any("._.DS_Store" in s for s in images): images.remove("._.DS_Store") - return images + return images \ No newline at end of file