In [1]:
# Select matplotlib's interactive notebook backend so figures are embedded
# in the cell output area.
%matplotlib notebook

In [2]:
# Directory that is put on the module search path before importing
# `smplsh_torch` (presumably where that module lives -- confirm).
# The path is relative to the notebook's working directory and uses
# Windows-style separators.
SMPLSH_Dir = r'..\SMPL_Socks\SMPL_reimp'

import sys
# Insert at position 0 so this directory is searched before any other entry.
sys.path.insert(0, SMPLSH_Dir)
import smplsh_torch
import numpy as np

In [3]:
import os
import torch
import matplotlib.pyplot as plt
from skimage.io import imread

# Util function for loading meshes
from pytorch3d.io import load_objs_as_meshes
from tqdm import tqdm_notebook
import torch.nn as nn
import torch.nn.functional as F
# Data structures and functions for rendering
from pytorch3d.structures import Meshes, Textures
from pytorch3d.renderer import (
    look_at_view_transform,
    OpenGLPerspectiveCameras, 
    PointLights, 
    DirectionalLights, 
    Materials, 
    RasterizationSettings, 
    MeshRenderer, 
    MeshRasterizer,  
    TexturedSoftPhongShader,
    SoftSilhouetteShader,
    SoftPhongShader,
    HardPhongShader,
    BlendParams
)
from pytorch3d.transforms.so3 import (
    so3_exponential_map,
    so3_relative_angle,
)
# Append the absolute path of the current working directory to the module
# search path (originally added so demo utility modules could be imported).
# NOTE: `sys` and `os` are already imported earlier in the notebook; the two
# re-imports below are redundant but harmless.
import sys
import os
sys.path.append(os.path.abspath(''))

In [4]:
# Path to the SMPLSH model archive (.npz) that is passed to
# smplsh_torch.SMPLModel.
# The original machine-specific location is kept as the default so existing
# runs behave exactly as before; on any other machine set the
# SMPLSH_MODEL_PATH environment variable instead of editing the notebook.
# (`os` is imported in the imports cell above.)
smplshData = os.environ.get(
    'SMPLSH_MODEL_PATH',
    r'C:\Code\MyRepo\ChbCapture\06_Deformation\SMPL_Socks\SMPLSH\SmplshModel.npz',
)

In [5]:
# Setup: choose the compute device, build the SMPLSH body model and create
# the reference mesh from a reproducible random pose and shape.

# Prefer the first GPU, but fall back to the CPU so the notebook still runs
# on machines without CUDA (torch.cuda.set_device would raise there).
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
if device.type == "cuda":
    torch.cuda.set_device(device)

pose_size = 3 * 52  # pose vector length (presumably 52 joints x 3 rotation components -- confirm)
beta_size = 10      # shape (beta) vector length


def _as_f64(array):
    """Convert a NumPy array into a float64 tensor located on `device`."""
    return torch.from_numpy(array).type(torch.float64).to(device)


smplsh = smplsh_torch.SMPLModel(device, smplshData)

# Fixed seed so the reference pose and shape are the same on every run.
# The draw order matters for reproducibility: pose first, then betas.
np.random.seed(9608)
pose = _as_f64((np.random.rand(pose_size) - 0.5) * 0.4)
betas = _as_f64((np.random.rand(beta_size) - 0.5) * 0.06)
trans = _as_f64(np.zeros(3))

# The model inputs are float64; the resulting vertices are cast to float32
# before the mesh is built.
verts = smplsh(betas, pose, trans).type(torch.float32)
# Initialize each vertex to be gray in color.
verts_rgb = (0.5 * torch.ones_like(verts))[None]  # (1, V, 3)
textures = Textures(verts_rgb=verts_rgb.to(device))

smplshMesh = Meshes([verts], [smplsh.faces.to(device)], textures=textures)


In [6]:
# Perspective camera at distance 2.7 from the origin, 0 degrees elevation and
# 180 degrees azimuth.
R, T = look_at_view_transform(dist=2.7, elev=0, azim=180)
cameras = OpenGLPerspectiveCameras(device=device, R=R, T=T)

# Blending parameters consumed by the soft silhouette shader below.
blend_params = BlendParams(sigma=1e-4, gamma=1e-4)

# Rasterization settings for a 512x512 soft silhouette: the blur radius is
# derived from blend_params.sigma, up to 150 faces are kept per pixel, and
# bin_size=0 requests the naive (non-binned) rasterizer.
# Refer to rasterize_meshes.py for explanations of these parameters.
raster_settings = RasterizationSettings(
    image_size=512,
    blur_radius=np.log(1. / 1e-4 - 1.) * blend_params.sigma,
    faces_per_pixel=150,
    bin_size=0,
)

# Point light at z = -3. It is not passed to the silhouette shader used
# below; it only matters if a Phong shader is swapped in.
lights = PointLights(device=device, location=[[0.0, 0.0, -3.0]])

# Compose the renderer from a rasterizer and a soft silhouette shader.
rasterizer = MeshRasterizer(cameras=cameras, raster_settings=raster_settings)
renderer = MeshRenderer(
    rasterizer=rasterizer,
    shader=SoftSilhouetteShader(blend_params=blend_params),
)

In [7]:
# Render the reference mesh. Channel 3 of the renderer output is what gets
# displayed here and later compared against during optimization.
imageRef = renderer(meshes_world=smplshMesh)
plt.figure(figsize=(10, 10))
plt.imshow(imageRef[0, ..., 3].cpu().numpy())

<IPython.core.display.Javascript object>

<matplotlib.image.AxesImage at 0x26113d08ba8>

In [22]:
# Optimization variable: a 3-vector that replaces the first three pose
# entries (the root node rotation, going by its name). It is initialised
# with small Gaussian noise and tracked by autograd.
rootNodeRot = torch.tensor(
    0.1 * np.random.randn(3),
    dtype=torch.float64,
    device=device,
    requires_grad=True,
)

In [23]:
# Sanity check: a tensor created directly by the user is a leaf of the
# autograd graph.
rootNodeRot.is_leaf

True

In [24]:
# A leaf tensor has no grad_fn. print() is used because a bare None as the
# last expression of a cell would display nothing.
print(rootNodeRot.grad_fn)

None


In [25]:
# Freeze every pose entry except the first three: a detached copy that takes
# no part in autograd.
keepPoses = pose[3:].detach().clone()

# Full pose vector = trainable root rotation followed by the frozen rest.
# It is built with torch.cat rather than in-place assignment so that
# gradients flow back to rootNodeRot.
newPose = torch.cat((rootNodeRot, keepPoses), dim=0)

In [26]:
# newPose is the result of an operation (torch.cat), so it is not a leaf.
newPose.is_leaf

False

In [27]:
# Its grad_fn records the concatenation, i.e. newPose is connected to
# rootNodeRot in the autograd graph.
print(newPose.grad_fn)

<CatBackward object at 0x000002616F65F978>


In [28]:
# Evaluate the body model with the perturbed pose, cast the vertices to
# float32 and wrap them in a mesh that reuses the model faces and the
# existing textures.
vertsPerturbed = smplsh(betas, newPose, trans).to(torch.float32)
smplshMeshPerturbed = Meshes(
    verts=[vertsPerturbed],
    faces=[smplsh.faces.to(device)],
    textures=textures,
)

In [29]:
# Render the perturbed mesh and display channel 3 of the result.
images = renderer(smplshMeshPerturbed)
plt.figure(figsize=(10, 10))
# detach() is required because `images` is part of the autograd graph.
plt.imshow(images[0, ..., 3].cpu().detach().numpy())
# Pass a real boolean: the string "off" is truthy, so current Matplotlib
# versions interpret plt.grid("off") as a request to turn the grid ON.
plt.grid(False);
plt.axis("off");

<IPython.core.display.Javascript object>

In [34]:
# Adam optimizer over the root rotation only; no other tensor is updated.
optimizer = torch.optim.Adam(params=[rootNodeRot], lr=0.005)

In [35]:
# Initial objective: sum of squared differences between channel 3 of the
# reference render and channel 3 of the perturbed render.
loss = (imageRef[..., 3] - images[..., 3]).pow(2).sum()
print(loss.detach())
print(loss.item())

tensor(24.4107, device='cuda:0')
24.41067123413086


In [36]:
# Fit the root rotation by gradient descent so that the rendered silhouette
# matches the reference one.

# tqdm.tqdm_notebook is deprecated; tqdm.notebook.tqdm is its replacement.
from tqdm.notebook import tqdm

# Loop invariants, computed once instead of on every iteration.
facesOnDevice = smplsh.faces.to(device)
# The reference silhouette is a constant target; detach() guarantees that
# backward() never tries to propagate into whatever produced it.
targetSilhouette = imageRef[..., 3].detach()

loop = tqdm(range(500))
for i in loop:
    optimizer.zero_grad()

    # Rebuild the pose each iteration so a fresh graph links the loss to
    # rootNodeRot.
    newPose = torch.cat([rootNodeRot, keepPoses])
    vertsPerturbed = smplsh(betas, newPose, trans).type(torch.float32)
    smplshMeshPerturbed = Meshes([vertsPerturbed], [facesOnDevice], textures=textures)

    images = renderer(smplshMeshPerturbed)
    loss = torch.sum((targetSilhouette - images[..., 3]) ** 2)

    loss.backward()
    optimizer.step()

    # Read the scalar once; it is reused for the progress bar and the title.
    lossValue = loss.item()
    loop.set_description('Optimizing (loss %.4f)' % lossValue)

    # Show the intermediate silhouette every 50 iterations.
    if i % 50 == 0:
        rotX, rotY, rotZ = rootNodeRot.detach().cpu().tolist()
        plt.figure()
        plt.imshow(images[0, ..., 3].cpu().detach().numpy())
        plt.title("iter: %d, loss: %0.2f, rootRot: (%0.2f, %0.2f, %0.2f)" % (i, lossValue, rotX, rotY, rotZ))
        # A real boolean is required: the string "off" is truthy and would
        # switch the grid ON in current Matplotlib versions.
        plt.grid(False)
        plt.axis("off")

Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  """Entry point for launching an IPython kernel.


HBox(children=(FloatProgress(value=0.0, max=500.0), HTML(value='')))

<IPython.core.display.Javascript object>




KeyboardInterrupt: 

In [33]:
# Show the most recent render together with the iteration index, loss and
# root rotation currently held in the kernel (set by the optimization loop).
plt.figure()
plt.imshow(images[0, ..., 3].cpu().detach().numpy())
plt.title("iter: %d, loss: %0.2f, rootRot: (%0.4f, %0.4f, %0.4f)" % (i, loss.item(), rootNodeRot[0], rootNodeRot[1], rootNodeRot[2]))
# A real boolean is required: the string "off" is truthy and would switch
# the grid ON in current Matplotlib versions.
plt.grid(False)
plt.axis("off")

<IPython.core.display.Javascript object>

(-0.5, 511.5, 511.5, -0.5)