# $\color{green}{\text{Neural Mesh Renderer in pytorch}}$
## [Kato etal (CVPR18')](http://openaccess.thecvf.com/content_cvpr_2018/papers/Kato_Neural_3D_Mesh_CVPR_2018_paper.pdf)
### implements four examples to use Neaural Renderer with pytorch
-------------------------------------
--------------------------
go to the repo [here](https://github.com/hiroharu-kato/neural_renderer) and hit 
```bash
sudo python setup.py install
```

#### importing relevent modules 

In [1]:
from __future__ import division
import os
import argparse
import glob

import torch
import torch.nn as nn
import numpy as np
from skimage.io import imread, imsave
from tqdm import tqdm_notebook as tqdm

import imageio
from IPython.display import display, HTML

import neural_renderer as nr

current_dir = os.getcwd()
data_dir = os.path.join(current_dir, 'data/3d_renderer')


#### supporting functions

In [2]:
def make_gif(filename):
    with imageio.get_writer(filename, mode='I') as writer:
        for filename in sorted(glob.glob('/tmp/_tmp_*.png')):
            writer.append_data(imread(filename))
            os.remove(filename)
    writer.close()

#### constants file names 

In [9]:
filename_obj =  os.path.join(data_dir,'motor.obj')
camera_distance = 2.732

## example1 : rendering obj file from different angles

In [10]:
filename_ref = os.path.join(data_dir,'example1_ref.png')
file_out = os.path.join(data_dir,'example1_out.gif')
file_opt = os.path.join(data_dir,'example1_opt.gif')

# other settings

elevation = 35
texture_size = 2

# load .obj
vertices, faces = nr.load_obj(filename_obj)
vertices = vertices[None, :, :]  # [num_vertices, XYZ] -> [batch_size=1, num_vertices, XYZ]
faces = faces[None, :, :]  # [num_faces, 3] -> [batch_size=1, num_faces, 3]

# create texture [batch_size=1, num_faces, texture_size, texture_size, texture_size, RGB]
textures = torch.ones(1, faces.shape[1], texture_size, texture_size, texture_size, 3, dtype=torch.float32).cuda()

In [11]:
# to gpu

# create renderer
renderer = nr.Renderer(camera_mode='look_at')
image_collection = []
# draw object
loop = tqdm(range(0, 360, 4))
writer = imageio.get_writer(file_out, mode='I')
for num, azimuth in enumerate(loop):
    loop.set_description('Drawing')
    renderer.eye = nr.get_points_from_angles(camera_distance, elevation, azimuth)
    images = renderer(vertices, faces, textures,)  # [batch_size, RGB, image_size, image_size]
    image = images.detach().cpu().numpy()[0].transpose((1, 2, 0))  # [image_size, image_size, RGB]
    writer.append_data((255*image).astype(np.uint8))
    image_collection.append((255*image).astype(np.uint8))
writer.close()
display(HTML("<img src='data/3d_renderer/example1_out.gif'></img>"))

HBox(children=(IntProgress(value=0, max=90), HTML(value='')))




In [12]:
?nr.lighting

[0;31mSignature:[0m [0mnr[0m[0;34m.[0m[0mlighting[0m[0;34m([0m[0mfaces[0m[0;34m,[0m [0mtextures[0m[0;34m,[0m [0mintensity_ambient[0m[0;34m=[0m[0;36m0.5[0m[0;34m,[0m [0mintensity_directional[0m[0;34m=[0m[0;36m0.5[0m[0;34m,[0m [0mcolor_ambient[0m[0;34m=[0m[0;34m([0m[0;36m1[0m[0;34m,[0m [0;36m1[0m[0;34m,[0m [0;36m1[0m[0;34m)[0m[0;34m,[0m [0mcolor_directional[0m[0;34m=[0m[0;34m([0m[0;36m1[0m[0;34m,[0m [0;36m1[0m[0;34m,[0m [0;36m1[0m[0;34m)[0m[0;34m,[0m [0mdirection[0m[0;34m=[0m[0;34m([0m[0;36m0[0m[0;34m,[0m [0;36m1[0m[0;34m,[0m [0;36m0[0m[0;34m)[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0;31mDocstring:[0m <no docstring>
[0;31mFile:[0m      ~/anaconda3/envs/mytorch/lib/python3.7/site-packages/neural_renderer-1.1.3-py3.7-linux-x86_64.egg/neural_renderer/lighting.py
[0;31mType:[0m      function


## example2 : optimizaing the mesh vertices

In [28]:
filename_ref = os.path.join(data_dir,'example2_ref.png')
file_out = os.path.join(data_dir,'example2_out.gif')
file_opt = os.path.join(data_dir,'example2_opt.gif')

class Model(nn.Module):
    def __init__(self, filename_obj, filename_ref):
        super(Model, self).__init__()

        # load .obj
        vertices, faces = nr.load_obj(filename_obj)
        self.vertices = nn.Parameter(vertices[None, :, :])
        self.register_buffer('faces', faces[None, :, :])

        # create textures
        texture_size = 2
        textures = torch.ones(1, self.faces.shape[1], texture_size, texture_size, texture_size, 3, dtype=torch.float32)
        self.register_buffer('textures', textures)

        # load reference image
        image_ref = torch.from_numpy(imread(filename_ref).astype(np.float32).mean(-1) / 255.)[None, ::]
        self.register_buffer('image_ref', image_ref)

        # setup renderer
        renderer = nr.Renderer(camera_mode='look_at')
        self.renderer = renderer

    def forward(self):
        self.renderer.eye = nr.get_points_from_angles(2.732, 0, 90)
        image = self.renderer(self.vertices, self.faces, mode='silhouettes')
        loss = torch.sum((image - self.image_ref[None, :, :])**2)
        return loss

In [29]:
model = Model(filename_obj, filename_ref)
model.cuda()

optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()))
# optimizer.setup(model)
loop = tqdm(range(300))
for i in loop:
    loop.set_description('Optimizing')
    # optimizer.target.cleargrads()
    optimizer.zero_grad()
    loss = model()
    loss.backward()
    optimizer.step()
    images = model.renderer(model.vertices, model.faces, mode='silhouettes')
    image = images.detach().cpu().numpy()[0]
    imsave('/tmp/_tmp_%04d.png' % i, image)
make_gif(file_opt)

# draw object
loop = tqdm(range(0, 360, 4))
for num, azimuth in enumerate(loop):
    loop.set_description('Drawing')
    model.renderer.eye = nr.get_points_from_angles(2.732, 0, azimuth)
    images = model.renderer(model.vertices, model.faces, model.textures)
    image = images.detach().cpu().numpy()[0].transpose((1, 2, 0))
    imsave('/tmp/_tmp_%04d.png' % num, image)
make_gif(file_out)
display(HTML("<img src='data/3d_renderer/example2_out.gif'>output</img>"))
print("\n\n")
display(HTML("<img src='data/3d_renderer/example2_opt.gif'>optimization</img>"))

HBox(children=(IntProgress(value=0, max=300), HTML(value='')))

  .format(dtypeobj_in, dtypeobj_out))


HBox(children=(IntProgress(value=0, max=90), HTML(value='')))

  .format(dtypeobj_in, dtypeobj_out))







## example3 : optimizaing mesh texture

In [30]:
filename_ref = os.path.join(data_dir,'example3_ref.png')
file_out = os.path.join(data_dir,'example3_out.gif')
file_opt = os.path.join(data_dir,'example3_opt.gif')

class Model(nn.Module):
    def __init__(self, filename_obj, filename_ref):
        super(Model, self).__init__()
        vertices, faces = nr.load_obj(filename_obj)
        self.register_buffer('vertices', vertices[None, :, :])
        self.register_buffer('faces', faces[None, :, :])

        # create textures
        texture_size = 4
        textures = torch.zeros(1, self.faces.shape[1], texture_size, texture_size, texture_size, 3, dtype=torch.float32)
        self.textures = nn.Parameter(textures)

        # load reference image
        image_ref = torch.from_numpy(imread(filename_ref).astype('float32') / 255.).permute(2,0,1)[None, ::]
        self.register_buffer('image_ref', image_ref)

        # setup renderer
        renderer = nr.Renderer(camera_mode='look_at')
        renderer.perspective = False
        renderer.light_intensity_directional = 0.0
        renderer.light_intensity_ambient = 1.0
        self.renderer = renderer


    def forward(self):
        self.renderer.eye = nr.get_points_from_angles(2.732, 0, np.random.uniform(0, 360))
        image = self.renderer(self.vertices, self.faces, torch.tanh(self.textures))
        loss = torch.sum((image - self.image_ref) ** 2)
        return loss

In [31]:

model = Model(filename_obj, filename_ref)
model.cuda()

optimizer = torch.optim.Adam(model.parameters(), lr=0.1, betas=(0.5,0.999))
loop = tqdm(range(300))
for _ in loop:
    loop.set_description('Optimizing')
    optimizer.zero_grad()
    loss = model()
    loss.backward()
    optimizer.step()

# draw object
loop = tqdm(range(0, 360, 4))
for num, azimuth in enumerate(loop):
    loop.set_description('Drawing')
    model.renderer.eye = nr.get_points_from_angles(2.732, 0, azimuth)
    images = model.renderer(model.vertices, model.faces, torch.tanh(model.textures))
    image = images.detach().cpu().numpy()[0].transpose((1, 2, 0))
    imsave('/tmp/_tmp_%04d.png' % num, image)
make_gif(file_out)
display(HTML("<img src='data/3d_renderer/example3_out.gif'>output</img>"))

HBox(children=(IntProgress(value=0, max=300), HTML(value='')))

HBox(children=(IntProgress(value=0, max=90), HTML(value='')))

  .format(dtypeobj_in, dtypeobj_out))


## example4 : optimizaing Camera external paramters 

In [32]:
filename_ref = os.path.join(data_dir,'example4_ref.png')
file_out = os.path.join(data_dir,'example4_out.gif')
file_opt = os.path.join(data_dir,'example4_opt.gif')
is_make_reference_image = False

class Model(nn.Module):
    def __init__(self, filename_obj, filename_ref=None):
        super(Model, self).__init__()
        # load .obj
        vertices, faces = nr.load_obj(filename_obj)
        self.register_buffer('vertices', vertices[None, :, :])
        self.register_buffer('faces', faces[None, :, :])

        # create textures
        texture_size = 2
        textures = torch.ones(1, self.faces.shape[1], texture_size, texture_size, texture_size, 3, dtype=torch.float32)
        self.register_buffer('textures', textures)

        # load reference image
        image_ref = torch.from_numpy((imread(filename_ref).max(-1) != 0).astype(np.float32))
        self.register_buffer('image_ref', image_ref)

        # camera parameters
        self.camera_position = nn.Parameter(torch.from_numpy(np.array([6, 10, -14], dtype=np.float32)))

        # setup renderer
        renderer = nr.Renderer(camera_mode='look_at')
        renderer.eye = self.camera_position
        self.renderer = renderer

    def forward(self):
        image = self.renderer(self.vertices, self.faces, mode='silhouettes')
        loss = torch.sum((image - self.image_ref[None, :, :]) ** 2)
        return loss





def make_reference_image(filename_ref, filename_obj):
    model = Model(filename_obj)
    model.cuda()

    model.renderer.eye = nr.get_points_from_angles(2.732, 30, -15)
    images = model.renderer.render(model.vertices, model.faces, torch.tanh(model.textures))
    image = images.detach().cpu().numpy()[0]
    imsave(filename_ref, image)

In [33]:
if is_make_reference_image:
    make_reference_image(filename_ref, filename_obj)

model = Model(filename_obj, filename_ref)
model.cuda()

# optimizer = chainer.optimizers.Adam(alpha=0.1)
optimizer = torch.optim.Adam(model.parameters(), lr=0.1)
loop = tqdm(range(1000))
for i in loop:
    optimizer.zero_grad()
    loss = model()
    loss.backward()
    optimizer.step()
    images = model.renderer(model.vertices, model.faces, torch.tanh(model.textures))
    image = images.detach().cpu().numpy()[0].transpose(1,2,0)
    imsave('/tmp/_tmp_%04d.png' % i, image)
    loop.set_description('Optimizing (loss %.4f)' % loss.data)
    if loss.item() < 70:
        break
make_gif(file_out)
display(HTML("<img src='data/3d_renderer/example4_out.gif'>output</img>"))

HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))

  warn('%s is a low contrast image' % fname)
  .format(dtypeobj_in, dtypeobj_out))
  warn('%s is a low contrast image' % fname)
  warn('%s is a low contrast image' % fname)
  warn('%s is a low contrast image' % fname)
  warn('%s is a low contrast image' % fname)
  warn('%s is a low contrast image' % fname)
  warn('%s is a low contrast image' % fname)
  warn('%s is a low contrast image' % fname)
  warn('%s is a low contrast image' % fname)
  warn('%s is a low contrast image' % fname)
  warn('%s is a low contrast image' % fname)
  warn('%s is a low contrast image' % fname)
  warn('%s is a low contrast image' % fname)
  warn('%s is a low contrast image' % fname)
  warn('%s is a low contrast image' % fname)
  warn('%s is a low contrast image' % fname)
  warn('%s is a low contrast image' % fname)
  warn('%s is a low contrast image' % fname)
  warn('%s is a low contrast image' % fname)
  warn('%s is a low contrast image' % fname)
  warn('%s is a low contrast image' % fname)
  warn('%s is a l