In [1]:
import os
import cv2
import numpy as np
import trimesh
import torch
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
from torchvision.utils import save_image, make_grid
from models.networks.smpl import SMPL
from models.networks.render import SMPLRenderer
from utils.util import load_obj, load_pickle_file, write_pickle_file

In [2]:
from PIL import Image
from torch.utils.data import Dataset
import torchvision.transforms as transforms
class iPER_Dataset(Dataset):
    """One iPER video as a map-style dataset of (frame, SMPL parameters) pairs.

    Frame ``index`` is paired with row ``index`` of the ``'pose'``, ``'shape'``
    and ``'cams'`` arrays loaded from the pickle. Frames are therefore ordered
    by sorted filename so that pairing is deterministic.
    """

    def __init__(self, imgs_path, pose_shape_pkl_path, image_size=256):
        """
        :param imgs_path: directory holding the frames of a single video
        :param pose_shape_pkl_path: pickle with per-frame 'pose', 'shape' and 'cams' arrays
        :param image_size: passed to transforms.Resize
        """
        # Keep the directory: imgs_path_list holds bare filenames from os.listdir.
        self.imgs_dir = imgs_path
        # os.listdir order is arbitrary; sort so frame i lines up with SMPL row i.
        self.imgs_path_list = sorted(os.listdir(imgs_path))
        # NOTE(review): presumably unpickles the file -- only load trusted data.
        self.pose_shape_pkl = load_pickle_file(pose_shape_pkl_path)
        self.Resize = transforms.Resize(image_size)
        self.ToTensor = transforms.ToTensor()

        n_smpls = self.pose_shape_pkl['pose'].shape[0]
        if len(self.imgs_path_list) != n_smpls:
            # A mismatch is only reported here; indexing past the shorter
            # side will fail later in __getitem__.
            print('images: ', len(self.imgs_path_list))
            print('smpls: ', n_smpls)

    def __getitem__(self, index):
        """Return the resized RGB frame and its SMPL pose/shape/camera as float tensors.

        :return: dict with keys 'image', 'pose', 'shape', 'cam'
        """
        img_path = os.path.join(self.imgs_dir, self.imgs_path_list[index])
        # The context manager closes the file handle Image.open keeps open.
        with Image.open(img_path) as raw:
            img = self.Resize(raw.convert('RGB'))
        pose = self.pose_shape_pkl['pose'][index]
        shape = self.pose_shape_pkl['shape'][index]
        cam = self.pose_shape_pkl['cams'][index]

        output = {
            'image': self.ToTensor(img).float(),
            'pose': torch.from_numpy(pose).float(),
            'shape': torch.from_numpy(shape).float(),
            'cam': torch.from_numpy(cam).float()
        }
        return output

    def __len__(self):
        """Number of frames found in the image directory."""
        return len(self.imgs_path_list)

In [3]:
# Dataset location plus loader/resize settings used by the cells below.
data_root, batch_size, image_size = 'data/iPER', 1, 256

In [4]:
# train.txt lists one video ID per line (first whitespace-separated token,
# e.g. 006/1/1). Blank lines are skipped so a trailing newline cannot raise
# IndexError, and the `with` block closes the file handle.
with open(os.path.join(data_root, 'train.txt')) as train_file:
    train_ID_list = [line.split()[0] for line in train_file if line.strip()]
print(len(train_ID_list))

164


In [5]:
# Work with the first training video listed in train.txt (prints e.g. 006/1/1).
video_ID = train_ID_list[0]
print(video_ID)

006/1/1


In [6]:
# Frames live under images/<video_ID>; the matching SMPL fits live under
# smpls/<video_ID>/pose_shape.pkl.
imgs_path, pose_shape_pkl_path = (
    os.path.join(data_root, 'images', video_ID),
    os.path.join(data_root, 'smpls', video_ID, 'pose_shape.pkl'),
)
train_dataset = iPER_Dataset(
    imgs_path=imgs_path,
    pose_shape_pkl_path=pose_shape_pkl_path,
    image_size=image_size,
)
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class UVImageModel(nn.Module):
    """Learnable RGB UV image resampled into a per-face texture.

    The image is stored as unconstrained values ``weight`` and squashed with
    ``tanh``, so its values lie in (-1, 1). ``forward`` samples it at the fixed
    coordinates ``uv`` with ``F.grid_sample``.
    """

    def __init__(self, uv, image_size):
        """
        :param uv: (f, t, t, 2) sampling coordinates (numpy array or tensor) in
                   grid_sample's normalized [-1, 1] coordinate convention
        :param image_size: int, side length of the square UV image
        """
        super(UVImageModel, self).__init__()
        # (1, 3, image_size, image_size); initial value tanh(-1) everywhere
        self.weight = nn.Parameter(torch.zeros(1, 3, image_size, image_size) - 1.0)
        # (f, t, t, 2)
        self.f, self.t = uv.shape[:2]
        # (1, f, t*t, 2): grid_sample treats f as output height and t*t as width.
        # clone() keeps the original copy semantics (no memory shared with the caller).
        grid = torch.as_tensor(uv, dtype=torch.float32).reshape(1, self.f, self.t * self.t, 2).clone()
        # A buffer (not a plain attribute) so .to()/.cuda() move it together with
        # `weight`; non-persistent so state_dict() keys stay unchanged.
        self.register_buffer('uv', grid, persistent=False)

    def forward(self):
        """Sample the UV image at ``uv``.

        :return: texture tensor of shape (1, f, t, t, 3)
        """
        uv_image = torch.tanh(self.weight)
        # align_corners=False is grid_sample's default; passing it explicitly
        # keeps behavior and silences the "default changed" warning.
        texture = F.grid_sample(uv_image, self.uv, align_corners=False)
        # (1,3,f,t,t)
        texture = texture.view(1, 3, self.f, self.t, self.t)
        # (1,f,t,t,3)
        texture = texture.permute(0, 2, 3, 4, 1)

        return texture

    def get_uv_image(self):
        """Return the current UV image, shape (1, 3, image_size, image_size), values in (-1, 1)."""
        return torch.tanh(self.weight)


def compute_uv_image(uv, texture, uv_size=224, n_iters=2000, lr=1e-2, log_every=50, device=None):
    """
    Fit a UV image whose samples at ``uv`` reproduce ``texture``.

    A fresh UVImageModel is optimized with Adam on the mean squared error
    against ``texture``.

    :param uv: (f, t, t, 2)
    :param texture: torch.Tensor [f, t, t, 3]; treated as a constant target
                    (detached, so no gradient flows back into the caller's graph)
    :param uv_size: int, default is 224
    :param n_iters: number of Adam steps, default 2000
    :param lr: Adam learning rate, default 1e-2
    :param log_every: print (iteration, loss) every this many steps; 0/None disables it
    :param device: device to optimize on; defaults to texture's device when it is
                   a tensor, otherwise 'cuda'
    :return: uv_image (3,h,w) rgb(-1,1), detached from the optimization graph
    """
    if device is None:
        device = texture.device if torch.is_tensor(texture) else torch.device('cuda')
    if torch.is_tensor(texture):
        target = texture.detach().to(device)
    else:
        target = torch.as_tensor(texture, dtype=torch.float32, device=device)

    # enable_grad lets the fit run even when called under torch.no_grad().
    with torch.enable_grad():
        uv_image_model = UVImageModel(uv, image_size=uv_size).to(device)
        opt = torch.optim.Adam(uv_image_model.parameters(), lr=lr)
        for it in range(n_iters):
            pred_texture = uv_image_model()
            loss = ((pred_texture - target) ** 2).mean()
            opt.zero_grad()
            loss.backward()
            opt.step()

            if log_every and it % log_every == 0:
                print(it, loss.item())

    with torch.no_grad():
        return uv_image_model.get_uv_image()[0]