In [35]:
import copy
import json
import os.path as osp
from collections import defaultdict

import numpy as np
import torch
from PIL import Image

from modules.smpl_model._smpl import SMPL
from modules.smpl_model.config_smpl import *
from modules.utils.image_utils import to_tensor, transform, transform_visualize, crop_box


class Human36M(torch.utils.data.Dataset):
    """Human3.6M samples paired with fitted SMPL parameters and camera data.

    On construction the per-subject annotation JSON files are read and reduced
    to ``self.datalist`` (see ``load_data``). ``__getitem__`` serves the image
    from one of three sources, checked in this order: an array pre-loaded with
    ``zarr.load``, a lazily filled in-memory tensor cache, or the image file
    on disk.

    Args:
        data_path: Root folder. Images are read from
            ``<data_path>/Human36M/images`` and annotations from
            ``<data_path>/Human36M/annotations``.
        split: Stored on the instance; this class does not use it to select
            subjects (``subject_list`` is fixed).
        num_required_keypoints: Accepted for interface compatibility; unused.
        debug: Stored on the instance; not read anywhere in this class.
        store_images: If true (and ``load_from_zarr`` is None), every
            transformed image is kept in a pre-allocated float32 tensor of
            shape ``(len(self), 3, img_size, img_size)``.
        load_from_zarr: Optional path handed to ``zarr.load``; when given, the
            whole array is loaded into memory and used instead of image files.
        img_size: Size passed to ``transform`` and used for the cache shape.
    """

    def __init__(self,
                 data_path: str = '../H36M',
                 split: str = 'train',
                 num_required_keypoints: int = 0,
                 debug: bool = True,
                 store_images=True,
                 load_from_zarr: str = None,
                 img_size=224):
        self.debug = debug
        self.split = split
        # These three are read by __getitem__ on every path, so they must be
        # set unconditionally and before any branch below looks at them.
        self.store_images = store_images
        self.load_from_zarr = load_from_zarr
        self.img_size = img_size

        self.img_dir = osp.join(data_path, 'Human36M', 'images')
        self.annot_path = osp.join(data_path, 'Human36M', 'annotations')
        self.action_name = ['Directions', 'Discussion', 'Eating', 'Greeting', 'Phoning', 'Posing', 'Purchases',
                            'Sitting', 'SittingDown', 'Smoking', 'Photo', 'Waiting', 'Walking', 'WalkDog',
                            'WalkTogether']
        self.fitting_thr = 25  # millimeter --> threshold between SMPL-mesh joints and H36M ground truth
        self.subject_list = [1, 5, 6, 7, 8, 9, 11]
        self.datalist = self.load_data()

        if self.load_from_zarr is not None:
            # Imported lazily: zarr is only required when this option is used.
            import zarr
            self.imgs = torch.from_numpy(zarr.load(self.load_from_zarr))  # load the full array into memory
        elif self.store_images:
            num_samples = len(self)
            # Marks which cache rows already hold a transformed image.
            self.img_cache_indicator = torch.zeros(num_samples, dtype=torch.bool)
            # Allocated up front: num_samples * 3 * img_size**2 float32 values.
            self.img_cache = torch.empty(num_samples, 3, img_size, img_size, dtype=torch.float32)

    def _read_subject_json(self, subject, suffix):
        """Load ``Human36M_subject<subject>_<suffix>.json`` from the annotation folder."""
        file_name = 'Human36M_subject' + str(subject) + '_' + suffix + '.json'
        with open(osp.join(self.annot_path, file_name), 'r') as f:
            return json.load(f)

    def load_data(self):
        """Build the list of usable samples from the per-subject annotation files.

        A sample is dropped when no fitted SMPL parameters exist for its
        (subject, action, subaction, frame) or when the fitting error against
        the ground-truth joints exceeds ``self.fitting_thr``.

        Returns:
            List of dicts sorted by ``img_id`` with keys ``img_path``,
            ``img_id``, ``betas``, ``poses``, ``trans``, ``bbox``,
            ``cam_pose`` and ``cam_intr``.

        Raises:
            KeyError: If a kept image has no ground-truth joints, camera or
                annotation entry.
        """
        images, annotations = [], []
        cameras, smpl_params, joints = {}, {}, {}
        for subject in self.subject_list:
            subject_key = str(subject)
            # Image metadata and per-image annotations; extend keeps both lists flat.
            subject_data = self._read_subject_json(subject, 'data')
            images.extend(subject_data['images'])
            annotations.extend(subject_data['annotations'])
            # Cameras, converted once per subject by the notebook-level helper.
            cams = self._read_subject_json(subject, 'camera')
            cameras[subject_key] = {cam_id: get_cam_pose_intr(cam) for cam_id, cam in cams.items()}
            # Fitted SMPL parameters and 3D joint ground truth.
            smpl_params[subject_key] = self._read_subject_json(subject, 'smpl_param')
            joints[subject_key] = self._read_subject_json(subject, 'joint_3d')

        # An annotation's 'image_id' refers to an image's 'id'.
        id_to_imgs = {img['id']: img for img in images}
        id_to_anns = {ann['image_id']: ann for ann in annotations}

        datalist = []
        for img_id, img in id_to_imgs.items():
            img_path = osp.join(self.img_dir, img['file_name'])
            subject = str(img['subject'])
            action = str(img['action_idx'])
            subaction = str(img['subaction_idx'])
            frame = str(img['frame_idx'])
            # Skip frames without fitted SMPL parameters.
            try:
                smpl_param = smpl_params[subject][action][subaction][frame]
            except KeyError:
                continue

            joint3d_smpl = np.array(smpl_param['fitted_3d_pose'], np.float32)
            joint3d_h36m_gt = np.array(joints[subject][action][subaction][frame], dtype=np.float32)
            # Skip frames where the SMPL fit is too far from the ground truth.
            if self.get_fitting_error(joint3d_h36m_gt, joint3d_smpl) > self.fitting_thr:
                continue

            cam_param = cameras[subject][str(img['cam_idx'])]
            datalist.append({
                'img_path': img_path,
                'img_id': img_id,
                'betas': torch.FloatTensor(smpl_param['shape']),
                'poses': torch.FloatTensor(smpl_param['pose']),
                'trans': torch.FloatTensor(smpl_param['trans']),
                'bbox': id_to_anns[img_id]['bbox'],
                # as_tensor accepts both tensors and nested lists from the helper.
                'cam_pose': torch.as_tensor(cam_param['cam_pose'], dtype=torch.float32),
                'cam_intr': torch.as_tensor(cam_param['cam_intr'], dtype=torch.float32),
            })

        datalist.sort(key=lambda sample: sample['img_id'])
        return datalist

    def get_fitting_error(self, joint3d_h36m_gt, joint3d_smpl):
        """Mean per-joint Euclidean distance between ground truth and SMPL joints.

        The ground truth is made relative to its 'Pelvis' joint, and the SMPL
        joints are shifted so that both joint sets share the same mean before
        the distances are computed.

        Args:
            joint3d_h36m_gt: Array of ground-truth joints, one row per joint.
            joint3d_smpl: Array of SMPL-derived joints with matching layout.

        Returns:
            The mean distance, in the units of the inputs (presumably
            millimeters, matching ``fitting_thr`` -- TODO confirm).
        """
        joint3d_h36m_gt = joint3d_h36m_gt - joint3d_h36m_gt[H36M_J17_NAME.index('Pelvis'), None, :]  # root-relative
        # translation alignment
        joint3d_smpl = joint3d_smpl - np.mean(joint3d_smpl, 0)[None, :] + np.mean(joint3d_h36m_gt, 0)[None, :]
        error = np.sqrt(np.sum((joint3d_h36m_gt - joint3d_smpl) ** 2, 1)).mean()
        return error

    def __len__(self):
        """Number of samples that survived filtering in ``load_data``."""
        return len(self.datalist)

    def __getitem__(self, index):
        """Return a copy of sample ``index`` with its image tensor under ``'img'``."""
        # Deep copy so that adding 'img' never mutates the shared datalist entry.
        data = copy.deepcopy(self.datalist[index])
        img_path = data['img_path']

        if self.load_from_zarr is not None:
            # NOTE(review): assumes the zarr array rows are in the same order as
            # the filtered, img_id-sorted datalist -- confirm against the writer.
            img_tensor = self.imgs[index]
        elif self.store_images and self.img_cache_indicator[index]:
            img_tensor = self.img_cache[index]
        else:
            # Context manager closes the file handle once the pixels are copied.
            with Image.open(img_path) as img_file:
                img = np.array(img_file)
            if data['bbox'] is not None:
                # NOTE(review): treats bbox as integer [x_min, y_min, x_max, y_max];
                # verify against the annotation format (float or xywh boxes would
                # need converting before slicing).
                x_min, y_min, x_max, y_max = data['bbox']
                img = img[y_min:y_max, x_min:x_max]
            img_tensor = to_tensor(img)
            img_tensor = transform(img_tensor, img_size=self.img_size)
            if self.store_images:
                self.img_cache[index] = img_tensor
                self.img_cache_indicator[index] = True
        data['img'] = img_tensor
        return data

ModuleNotFoundError: No module named 'modules'

In [26]:
# Scratch dictionary: one nested-dict value next to two plain integer values.
a = {
    1: {3: 4},
    2: 4,
    5: 6,
}

In [28]:
# Print each key of `a` next to its value, one pair per line.
for b, v in a.items():
    print(b, v)

1 {3: 4}
2 {4: 5}


In [29]:
def get_cam_pose_intr(cam_dict):
    """Convert one camera entry into a pose matrix and an intrinsics matrix.

    Args:
        cam_dict: Mapping with keys 'R', 't', 'f' and 'c'. Assumes 'R' is a
            3x3 rotation, 't' a length-3 translation, and 'f' / 'c' are
            two-element focal length and principal point -- TODO confirm
            against the camera JSON files.

    Returns:
        Dict with 'cam_pose' ([R | t] with the row [0, 0, 0, 1] appended) and
        'cam_intr' (3x3 matrix with 'f' on the diagonal, 'c' in the last
        column and 1 in the bottom-right corner).
    """
    rotation = torch.FloatTensor(cam_dict['R'])
    translation = torch.FloatTensor(cam_dict['t'])[:, None]
    bottom_row = torch.FloatTensor([[0, 0, 0, 1]])
    # Stack [R | t] on top of the homogeneous row.
    upper_rows = torch.cat((rotation, translation), dim=1)
    cam_pose = torch.cat((upper_rows, bottom_row), dim=0)

    cam_intr = torch.zeros(3, 3)
    focal_x, focal_y = cam_dict['f']
    cam_intr[0, 0] = focal_x
    cam_intr[1, 1] = focal_y
    cam_intr[0:2, 2] = torch.tensor(cam_dict['c'])
    cam_intr[2, 2] = 1.
    return {'cam_pose': cam_pose, 'cam_intr': cam_intr}
                                

In [10]:
import torch

In [24]:
# Scratch check: build a 3x3 tensor of zeros and echo it as the cell result.
a = torch.zeros((3, 3))
a

tensor([[0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.]])

In [17]:
from collections import defaultdict

In [None]:
# Scratch call: the FloatTensor constructor invoked with no arguments.
torch.FloatTensor()