In [1]:
%matplotlib notebook

In [2]:
SMPLSH_Dir = r'..\SMPL_Socks\SMPL_reimp'

import sys
sys.path.insert(0, SMPLSH_Dir)
import smplsh_torch
import numpy as np

In [3]:
import os
import torch
import matplotlib.pyplot as plt
from skimage.io import imread

# Util function for loading meshes
from pytorch3d.io import load_objs_as_meshes
from tqdm import tqdm_notebook
import torch.nn as nn
import torch.nn.functional as F
# Data structures and functions for rendering
from pytorch3d.structures import Meshes, Textures
from pytorch3d.renderer import (
    look_at_view_transform,
    OpenGLPerspectiveCameras, 
    PointLights, 
    DirectionalLights, 
    Materials, 
    RasterizationSettings, 
    MeshRenderer, 
    MeshRasterizer,  
    TexturedSoftPhongShader,
    SoftSilhouetteShader,
    SoftPhongShader,
    HardPhongShader,
    BlendParams
)
from pytorch3d.transforms.so3 import (
    so3_exponential_map,
    so3_relative_angle,
)
# add path for demo utils functions 
import sys
import os
sys.path.append(os.path.abspath(''))

In [5]:
smplshData = r'C:\Code\MyRepo\ChbCapture\06_Deformation\SMPL_Socks\SMPLSH\SmplshModel.npz'
# Setup
device = torch.device("cuda:0")
torch.cuda.set_device(device)

pose_size = 3 * 52
beta_size = 10

smplsh = smplsh_torch.SMPLModel(device, smplshData)
np.random.seed(9608)
pose = torch.from_numpy((np.random.rand(pose_size) - 0.5) * 0.4)\
        .type(torch.float64).to(device)
betas = torch.from_numpy((np.random.rand(beta_size) - 0.5) * 0.06) \
        .type(torch.float64).to(device)
trans = torch.from_numpy(np.zeros(3)).type(torch.float64).to(device)

verts = smplsh(betas, pose, trans).type(torch.float32)
# Initialize each vertex to be gray in color.
verts_rgb = ( 0.5 *torch.ones_like(verts))[None]  # (1, V, 3)
textures = Textures(verts_rgb=verts_rgb.to(device))

smplshMesh = Meshes([verts], [smplsh.faces.to(device)], textures=textures)


In [None]:
class RenderingCfg:
    def __init_(s):
        s.sigma = 1e-4
        s.blurRange = 1e-4
        s.face_per_pixel = 50
        

In [6]:
# Initialize an OpenGL perspective camera.
# With world coordinates +Y up, +X left and +Z in, the front of the cow is facing the -Z direction. 
# So we move the camera by 180 in the azimuth direction so it is facing the front of the cow. 
R, T = look_at_view_transform(2.7, 0, 180) 
cameras = OpenGLPerspectiveCameras(device=device, R=R, T=T)
# blend_params = BlendParams(sigma=1e-4, gamma=1e-4)
blend_params = BlendParams(sigma=1e-4, gamma=1e-4)

# Define the settings for rasterization and shading. Here we set the output image to be of size
# 512x512. As we are rendering images for visualization purposes only we will set faces_per_pixel=1
# and blur_radius=0.0. Refer to rasterize_meshes.py for explanations of these parameters. 
raster_settings = RasterizationSettings(
    image_size=512, 
#     blur_radius= np.log(1. / 1e-4 - 1.) * blend_params.sigma, 
    blur_radius= np.log(1. / 1e-4 - 1.) * blend_params.sigma, 
#     blur_radius= 0, 


    # blur_radius=0.0005, 
#     faces_per_pixel=10, 
  faces_per_pixel=50, 

    bin_size=0
)

# Place a point light in front of the object. As mentioned above, the front of the cow is facing the 
# -z direction. 
lights = PointLights(device=device, location=[[0.0, 0.0, -3.0]])

# Create a phong renderer by composing a rasterizer and a shader. The textured phong shader will 
# interpolate the texture uv coordinates for each vertex, sample from a texture image and 
# apply the Phong lighting model
rasterizer=MeshRasterizer(
        cameras=cameras, 
        raster_settings=raster_settings
    )
renderer = MeshRenderer(
    rasterizer = rasterizer,
    shader=SoftPhongShader(
        device=device, 
        cameras=cameras,
        lights=lights,
        blend_params=blend_params
    )
#     shader=SoftSilhouetteShader(
#         blend_params=blend_params
#         # device=device, 
#         # cameras=cameras,
#         # lights=lights
#     )
)

In [8]:
with torch.no_grad():
    imageRef = renderer(smplshMesh)
plt.figure(figsize=(10, 10))
plt.imshow(imageRef[0, ..., :3].cpu().numpy())

<IPython.core.display.Javascript object>

<matplotlib.image.AxesImage at 0x1f013ebcbe0>

In [9]:
memStats = torch.cuda.memory_stats(device=device)
print('Before release: active_bytes.all.current:',memStats['active_bytes.all.current'] / 1000000)
torch.cuda.empty_cache()
memStats = torch.cuda.memory_stats(device=device)
print('After release: active_bytes.all.current:',memStats['active_bytes.all.current'] / 1000000)

Before release: active_bytes.all.current: 87.561216
After release: active_bytes.all.current: 87.561216


In [10]:
noiseLevel = 0.8
numBodyParameters = 3 * 22
# posePerturbed = torch.tensor(pose.cpu().numpy() + (np.random.rand(pose_size) - 0.5) * noiseLevel, dtype=torch.float64, device=device, requires_grad=True)
# Keep hand fixed
poseHands = pose[numBodyParameters:].clone().detach()
poseBody = torch.tensor(pose[:numBodyParameters].cpu().numpy() + (np.random.rand(numBodyParameters) - 0.5) * noiseLevel, dtype=torch.float64, device=device, requires_grad=True)
posePerturbed = torch.cat([poseBody, poseHands])

In [11]:
vertsPerturbed = smplsh(betas, posePerturbed, trans).type(torch.float32)
smplshMeshPerturbed = Meshes([vertsPerturbed], [smplsh.faces.to(device)], textures=textures)

In [12]:
with torch.no_grad():
    image = renderer(smplshMeshPerturbed)
plt.figure(figsize=(10, 10))
plt.imshow(image[0, ..., :3].cpu().detach().numpy())

<IPython.core.display.Javascript object>

<matplotlib.image.AxesImage at 0x1f0140b52b0>

In [13]:
memStats = torch.cuda.memory_stats(device=device)
print('Before release: active_bytes.all.current:',memStats['active_bytes.all.current'] / 1000000)
torch.cuda.empty_cache()
memStats = torch.cuda.memory_stats(device=device)
print('After release: active_bytes.all.current:',memStats['active_bytes.all.current'] / 1000000)

Before release: active_bytes.all.current: 94.431744
After release: active_bytes.all.current: 94.431744


In [14]:
optimizer = torch.optim.Adam([poseBody], lr=0.001)

In [15]:
loss = torch.sum((imageRef[..., 3] - image[..., 3]) ** 2)
print(loss.data)
print(loss.item())

tensor(16571.5547, device='cuda:0')
16571.5546875


In [16]:
loop = tqdm_notebook(range(2000))
for i in loop:
    optimizer.zero_grad()
    torch.cuda.empty_cache()

#     newPose = torch.cat([rootNodeRot, keepPoses])
    poseHands = pose[numBodyParameters:].clone().detach()
#     poseBody = torch.tensor(pose[:numBodyParameters].cpu().numpy() + (np.random.rand(numBodyParameters) - 0.5) * noiseLevel, dtype=torch.float64, device=device, requires_grad=True)
    posePerturbed = torch.cat([poseBody, poseHands])
    
    vertsPerturbed = smplsh(betas, posePerturbed, trans).type(torch.float32)
    smplshMeshPerturbed = Meshes([vertsPerturbed], [smplsh.faces.to(device)], textures=textures)
    
    images = renderer(smplshMeshPerturbed)
    # targetImg = images[0, ..., :3]
    loss = torch.sum((imageRef[..., 3] - images[..., 3]) ** 2)
    # loss, _ = model()
    
    loss.backward()
    optimizer.step()
    
    loop.set_description('Optimizing (loss %.4f, poseDiff: %.4f)' % (loss.data, torch.sum((pose-posePerturbed)**2).item()))
    
    if loss.item() < 200:
        break
    
    # Save outputs to create a GIF. 
    if i % 50 == 0:
        # R = look_at_rotation(model.camera_position[None, :], device=model.device)
        # T = -torch.bmm(R.transpose(1, 2), model.camera_position[None, :, None])[:, :, 0]   # (1, 3)
        # image = phong_renderer(meshes_world=model.meshes.clone(), R=R, T=T)
        # image = image[0, ..., :3].detach().squeeze().cpu().numpy()
        # image = img_as_ubyte(image)
        # writer.append_data(image)
        
        plt.figure()
        plt.imshow(images[0, ..., 3].cpu().detach().numpy())
        plt.title("iter: %d, loss: %0.2f" % (i, loss.data))
        plt.grid("off")
        plt.axis("off")

Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  """Entry point for launching an IPython kernel.


HBox(children=(FloatProgress(value=0.0, max=2000.0), HTML(value='')))

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>




KeyboardInterrupt: 

In [13]:
pose-posePerturbed

tensor([-0.2447,  0.3422, -0.2373,  0.3649, -0.2616, -0.0476, -0.0941, -0.2357,
        -0.3787, -0.1301, -0.1537,  0.1795,  0.0179, -0.1825, -0.2310,  0.2016,
        -0.2271, -0.1093, -0.3683, -0.4409, -0.3348, -0.3873,  0.1037, -0.1938,
        -0.0957,  0.5329, -0.3495, -0.0790, -0.2934,  0.2958, -0.1063, -0.8500,
        -0.1057,  0.1648, -0.6699,  0.8646,  0.0738, -0.3528, -0.3130, -0.3838,
        -0.1699, -0.2309, -0.1751,  0.1931,  0.2607,  0.3415, -0.4133,  0.4704,
        -0.1910, -0.1993, -0.4412, -0.2046,  0.0514,  0.0082, -0.2211, -0.2421,
        -0.1057,  0.5190, -0.2806,  0.1567, -0.0867, -0.1846,  0.0584,  0.1515,
        -0.2763, -0.0747,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
         0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
         0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
         0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
         0.0000,  0.0000,  0.0000,  0.00

In [17]:
with torch.no_grad():
    image = renderer(smplshMeshPerturbed)
plt.figure(figsize=(10, 10))
plt.imshow(image[0, ..., :3].cpu().detach().numpy())

<IPython.core.display.Javascript object>

<matplotlib.image.AxesImage at 0x1f030272f98>