In [47]:
import os
import torch
import numpy as np
from glob import glob
import data_util
import random
from torch.utils.data import DataLoader

In [48]:
# Sanity check: list the pose files that the dataset below pairs with the RGB images.
# The capture root can be overridden with the REAL_CAP_ROOT environment variable
# instead of editing a hard-coded absolute path.
pose_files = sorted(glob(os.path.join(
    os.environ.get('REAL_CAP_ROOT', '/home/max/Downloads/real_cap/multi'), 'pose', '*.txt')))
pose_files

['/home/max/Downloads/real_cap/multi/pose/001_01.txt',
 '/home/max/Downloads/real_cap/multi/pose/001_02.txt',
 '/home/max/Downloads/real_cap/multi/pose/001_03.txt',
 '/home/max/Downloads/real_cap/multi/pose/001_04.txt',
 '/home/max/Downloads/real_cap/multi/pose/001_05.txt',
 '/home/max/Downloads/real_cap/multi/pose/002_01.txt',
 '/home/max/Downloads/real_cap/multi/pose/002_02.txt',
 '/home/max/Downloads/real_cap/multi/pose/002_03.txt',
 '/home/max/Downloads/real_cap/multi/pose/002_04.txt',
 '/home/max/Downloads/real_cap/multi/pose/002_05.txt']

In [83]:
class each_object(torch.utils.data.Dataset):
    """Map-style dataset that groups consecutive views of one captured object.

    RGB images are read from ``<root_dir>/rgb`` (via ``data_util.glob_imgs``) and
    poses from ``<root_dir>/pose/*.txt``. Both lists are sorted by path and paired
    by position. Every view is loaded eagerly in ``__init__``.

    Sample ``idx`` is the ``idx``-th consecutive group of
    ``num_inpt_views + num_trgt_views`` views. The first ``num_inpt_views`` of the
    group (returned in shuffled order) are the inputs; the remainder are the targets.

    Args:
        root_dir: Directory containing ``rgb/`` and ``pose/`` sub-directories.
        img_size: Target size forwarded to ``data_util.load_img``; defaults to
            ``[512, 512]``.
        num_inpt_views: Number of input views per sample.
        num_trgt_views: Number of target views per sample.

    Raises:
        FileNotFoundError: If ``<root_dir>/rgb`` is not a directory.
        ValueError: If the number of RGB images and pose files differ.
    """

    def __init__(self,
                 root_dir,
                 img_size=None,
                 num_inpt_views=4,
                 num_trgt_views=1):
        super().__init__()

        # Build a fresh list per instance instead of sharing a mutable default argument.
        self.img_size = [512, 512] if img_size is None else img_size
        self.num_inpt_views = num_inpt_views
        self.num_trgt_views = num_trgt_views

        self.color_dir = os.path.join(root_dir, 'rgb')
        self.pose_dir = os.path.join(root_dir, 'pose')

        # Fail fast: returning early here used to leave a half-initialised object
        # whose len()/indexing then crashed with AttributeError.
        if not os.path.isdir(self.color_dir):
            raise FileNotFoundError(
                f"Error! root dir is wrong: {self.color_dir} is not a directory")

        self.all_color = sorted(data_util.glob_imgs(self.color_dir))
        self.all_poses = sorted(glob(os.path.join(self.pose_dir, '*.txt')))

        # Images and poses are paired purely by sorted position, so the counts must agree.
        if len(self.all_color) != len(self.all_poses):
            raise ValueError(
                f"found {len(self.all_color)} images in {self.color_dir} but "
                f"{len(self.all_poses)} pose files in {self.pose_dir}")

        print("Buffering files...")
        self.all_views = []
        for i in range(len(self.all_color)):
            if not i % 10:
                print(i)
            self.all_views.append(self.read_view_tuple(i))

    def __len__(self):
        """Number of complete view groups; trailing views that do not fill a group are dropped."""
        return len(self.all_color) // (self.num_inpt_views + self.num_trgt_views)

    def load_rgb(self, path):
        """Load one image as a float32 array scaled by 1/255 and shifted by -0.5.

        Only the first three channels of the last axis are kept, and that axis is
        moved to the front. The result is presumably (3, H, W) — assumes
        ``load_img`` returns (H, W, C) uint8-range data; TODO confirm.
        """
        img = data_util.load_img(path, square_crop=True, downsampling_order=1, target_size=self.img_size)
        img = img[:, :, :3].astype(np.float32) / 255. - 0.5
        img = img.transpose(2, 0, 1)
        return img

    def read_view_tuple(self, idx):
        """Load the ``idx``-th image and its paired pose as a dict of torch tensors.

        Returns:
            ``{'gt_rgb': tensor from load_rgb, 'pose': tensor from data_util.load_pose}``.
            The printed notebook outputs show the pose as a 4x4 matrix.
        """
        gt_rgb = self.load_rgb(self.all_color[idx])
        pose = data_util.load_pose(self.all_poses[idx])

        this_view = {'gt_rgb': torch.from_numpy(gt_rgb),
                     'pose': torch.from_numpy(pose)}
        return this_view

    def __getitem__(self, idx):
        """Return ``(inpt_views, trgt_views)`` for the ``idx``-th view group.

        Raises:
            IndexError: If ``idx`` is outside ``[0, len(self))``. Previously this
                silently returned empty or partial groups.
        """
        if not 0 <= idx < len(self):
            raise IndexError(f"index {idx} out of range for dataset of length {len(self)}")

        group_size = self.num_inpt_views + self.num_trgt_views
        start = idx * group_size
        split = start + self.num_inpt_views

        # random.sample over the whole slice yields a shuffled copy, so input order
        # varies between calls without mutating the buffered all_views list.
        inpt_slice = self.all_views[start:split]
        inpt_views = random.sample(inpt_slice, len(inpt_slice))
        trgt_views = self.all_views[split:start + group_size]

        return inpt_views, trgt_views

In [84]:
# Capture root; each_object reads its `rgb/` and `pose/` sub-directories.
# Override with the REAL_CAP_ROOT environment variable instead of editing this path.
# (A single-argument os.path.join was a no-op, so the plain string is used directly.)
root_dir = os.environ.get('REAL_CAP_ROOT', '/home/max/Downloads/real_cap/multi')
dataset = each_object(root_dir=root_dir)

Buffering files...
0


In [85]:
# One sample per batch, in dataset order (no shuffling), for inspecting poses.
dataloader = DataLoader(dataset=dataset, batch_size=1)

In [90]:
# Print the collated pose of the first target view of every batch.
for inpt_batch, trgt_batch in dataloader:
    first_trgt = trgt_batch[0]
    print(first_trgt['pose'])

tensor([[[ 0.9529, -0.0231, -0.3024, 10.0801],
         [-0.0430, -0.9973, -0.0593,  1.8792],
         [-0.3002,  0.0695, -0.9513, -2.0421],
         [ 0.0000,  0.0000,  0.0000,  1.0000]]])
tensor([[[ 0.9998,  0.0153, -0.0146,  4.4225],
         [ 0.0160, -0.9984,  0.0537, -0.2396],
         [-0.0138, -0.0540, -0.9984,  0.4539],
         [ 0.0000,  0.0000,  0.0000,  1.0000]]])


In [None]:

# Sanity-check batching: every collated view must carry the full batch dimension.
# `opt` was never defined in this notebook (NameError), so the sizes are local constants.
BATCH_SIZE = 2
NUM_EPOCHS = 10

# drop_last avoids a smaller final batch, which would otherwise trip the full-batch assert.
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, drop_last=True)
assert len(dataloader) > 0, "dataset has fewer samples than BATCH_SIZE; no batches to check"

for _ in range(NUM_EPOCHS):
    for inpt_views, trgt_views in dataloader:
        for view in [*inpt_views, *trgt_views]:
            assert view['pose'].shape[0] == BATCH_SIZE