In [460]:
import os.path as osp
import numpy as np
import torch
from PIL import Image
import zarr


from modules.smpl_model.config_smpl import *
from modules.utils.image_utils import to_tensor, transform
from modules.utils.data_utils_h36m import get_data_list_h36m
from modules.utils.geometry import get_smpl_coord

In [None]:
class ImageWiseH36M(torch.utils.data.Dataset):
    """Per-image Human3.6M dataset.

    Each item is a copy of one entry of the datalist built by ``get_data_list_h36m``,
    extended with ``'img_path'`` (path of the frame under ``img_dir``) and ``'img'``
    (the image tensor). The image comes from the first available source:

    1. a zarr array loaded fully into memory at construction (``load_from_zarr``);
    2. an in-memory cache filled lazily on first access (``store_images=True``);
    3. disk: the frame and its ``*_mask.jpg`` are read, cropped to ``data['bbox']``
       when present, background pixels are zeroed, and ``to_tensor``/``transform``
       are applied.
    """

    def __init__(self,
                 data_path: str = '../H36M',
                 split: str = 'train',
                 load_from_zarr: str = None,
                 load_datalist: str = None,
                 fitting_thr: int = None,
                 img_size=224,
                 store_images: bool = False):
        """
        Args:
            data_path: root directory. Images are read from ``<data_path>/Human36M/images``
                and annotations from ``<data_path>/Human36M/annotations``.
            split: stored as ``self.split``. NOTE(review): it is not used to select
                subjects; every split loads subjects 1, 5, 6, 7, 8, 9, 11. Confirm the
                intended protocol.
            load_from_zarr: optional path to a zarr array of images, read fully into
                memory. Assumed to be index-aligned with the datalist (TODO confirm).
            load_datalist: forwarded to ``get_data_list_h36m`` as ``load_from_pkl``.
            fitting_thr: forwarded to ``get_data_list_h36m``. Per the original note, a
                threshold in millimetres between SMPL-mesh joints and H36M ground truth.
            img_size: size passed to ``transform`` and used for the cache tensor shape.
            store_images: when True (and ``load_from_zarr`` is None), keep each
                transformed image in memory after its first load. Defaults to False
                because the cache is preallocated for the whole dataset.
        """
        super().__init__()
        self.split = split
        self.img_dir = osp.join(data_path, 'Human36M', 'images')
        self.annot_dir = osp.join(data_path, 'Human36M', 'annotations')
        self.load_from_zarr = load_from_zarr
        self.store_images = store_images
        self.fitting_thr = fitting_thr  # millimetres: threshold on SMPL-mesh joints vs. H36M GT joints
        self.subject_list = [1, 5, 6, 7, 8, 9, 11]
        self.img_size = img_size
        self.datalist = get_data_list_h36m(annot_dir=self.annot_dir,
                                           subject_list=self.subject_list,
                                           fitting_thr=self.fitting_thr,
                                           load_from_pkl=load_datalist,
                                           store_as_pkl=False,
                                           out_dir=None)
        # Images read from disk are moved to this device in __getitem__.
        # NOTE(review): CUDA tensors created inside a Dataset may not work with
        # DataLoader(num_workers > 0); consider moving to device in the training loop.
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        if self.load_from_zarr is not None:
            # The whole array is read into host memory up front.
            self.imgs = torch.from_numpy(zarr.load(self.load_from_zarr))
        elif self.store_images:
            num_items = len(self.datalist)
            # Kept on the CPU next to img_cache so the per-item lookup needs no GPU sync.
            self.img_cache_indicator = torch.zeros(num_items, dtype=torch.bool)
            # Assumes transform() yields a (3, img_size, img_size) float tensor — TODO confirm.
            self.img_cache = torch.empty(num_items, 3, img_size, img_size, dtype=torch.float32)

    def __len__(self):
        """Number of samples, i.e. entries in the datalist."""
        return len(self.datalist)

    def __getitem__(self, index):
        """Return a copy of ``datalist[index]`` with ``'img_path'`` and ``'img'`` added."""
        # Shallow copy: writing into the stored entry would keep every loaded image
        # tensor referenced from self.datalist for the lifetime of the dataset.
        data = dict(self.datalist[index])
        img_path = osp.join(self.img_dir, data['img_name'])
        if self.load_from_zarr is not None:
            img_tensor = self.imgs[index]
        elif self.store_images and self.img_cache_indicator[index]:
            img_tensor = self.img_cache[index]
        else:
            img_tensor = self._load_image(img_path, data)
            if self.store_images:
                # NOTE(review): with DataLoader workers, each process fills its own cache copy.
                self.img_cache[index] = img_tensor
                self.img_cache_indicator[index] = True
        data['img_path'] = img_path
        data['img'] = img_tensor
        return data

    def _load_image(self, img_path, data):
        """Read a frame and its foreground mask from disk, crop, mask, and transform it."""
        with Image.open(img_path) as img_file:
            img = np.array(img_file)
        # The mask sits beside the frame as '<stem>_mask.jpg'.
        sub_dir, img_name = osp.split(data['img_name'])
        mask_path = osp.join(self.img_dir, sub_dir, osp.splitext(img_name)[0] + '_mask.jpg')
        with Image.open(mask_path) as mask_file:
            mask = np.array(mask_file)
        if data['bbox'] is not None:
            img, mask = self._crop(img, mask, data['bbox'])
        img = self._apply_mask(img, mask)
        img_tensor = to_tensor(img).to(self.device)
        return transform(img_tensor, img_size=self.img_size)

    @staticmethod
    def _crop(img, mask, bbox):
        """Crop ``img`` and ``mask`` to ``bbox = (x_min, y_min, x_max, y_max)``.

        Coordinates are truncated to int (float boxes would fail as slice indices)
        and clamped at 0 so a negative corner does not wrap around the array.
        """
        x_min, y_min, x_max, y_max = (max(0, int(v)) for v in bbox)
        return img[y_min:y_max, x_min:x_max], mask[y_min:y_max, x_min:x_max]

    @staticmethod
    def _apply_mask(img, mask):
        """Zero every pixel of ``img`` whose mask value is 0; ``img`` keeps its dtype.

        ``mask`` may be (H, W) or (H, W, C). A multi-channel mask counts a pixel as
        foreground when any channel is non-zero.
        NOTE(review): JPEG masks can contain small non-zero artefacts; a threshold
        above 0 may be needed.
        """
        foreground = mask != 0
        if foreground.ndim == 3:
            foreground = foreground.any(axis=-1)
        if img.ndim == 3:
            foreground = foreground[..., None]
        return img * foreground

In [30]:
# Build the H36M datalist for all seven annotated subjects; store_as_pkl=True presumably
# writes the result as a pickle under out_dir (TODO confirm in get_data_list_h36m).
# NOTE(review): the last run of this cell raised `KeyError: 0` from inside
# get_data_list_h36m (not visible here). Also, ImageWiseH36M reads annotations from
# '<data_path>/Human36M/annotations' rather than '../H36M/annotations' — confirm layout.
datalist = get_data_list_h36m(
    annot_dir='../H36M/annotations',
    subject_list=[1, 5, 6, 7, 8, 9, 11],
    fitting_thr=25,
    store_as_pkl=True,
    out_dir='../H36M',
)

KeyError: 0