# Optimize Texture

Optimize an object's diffuse texture colors so that its renderings match a target image.

## 0. Install and import requirements

* torch_psdr: core library
* pytorch (cuda): core dependency library
* pywavefront: read meshes in `.obj` format (the forked `pywavefront_uv` package is installed below)
* imageio: read image-based textures
* matplotlib: display rendered images in the notebook

⚠️ ATTENTION:

**Please restart the notebook kernel after installation!**

**Do NOT run the following installation script if the packages are already installed!**

In [None]:
# torch_psdr is in the luling channel. This also installs CUDA-based PyTorch automatically.
! conda install -y torch_psdr -c luling -c pytorch
# Install the other dependencies needed to run this example.
! conda install -y imageio matplotlib -c conda-forge
# Install the forked pywavefront (imported as `pywavefront_uv`).
# `-y` is required: a `!` shell command cannot answer conda's interactive confirmation prompt.
! conda install -y pywavefront_uv -c luling

In [None]:
%matplotlib inline

import torch_psdr as dr
import random
import math
import torch
import torch.optim
import torch.nn.functional as F
from pywavefront_uv import Wavefront
import imageio.v3 as imageio
from utilities import translate, scale, rotate
import matplotlib.pyplot as plt
from IPython.display import display, clear_output

# Global settings shared by the cells below.
device = "cuda:0"  # torch device on which every tensor in this notebook is created
tex_H = tex_W = 512  # height / width (in texels) of the texture being optimized

## 1. Read mesh objects

Define functions that load the scene objects, and an object named `spot` that uses an image-based texture.

In [None]:
def load_obj(obj_path):
  """Load every mesh of an OBJ scene as torch_psdr objects.

  A mesh named "light" becomes a `dr.AreaLight` whose emission is the first
  three components of its material's ambient color. Every other mesh becomes a
  `dr.Mesh` with a `dr.DiffuseBsdfMaterial` built from the first three
  components of the material's diffuse color, a fixed roughness of 0.5 and no
  normal map.

  Args:
    obj_path: path to the `.obj` file.

  Returns:
    dict mapping mesh name to its torch_psdr object; unnamed meshes are keyed
    by `str(hash(mesh))`.
  """
  wavefront = Wavefront(obj_path, create_materials=True, collect_faces=True)
  as_f32 = dict(dtype=torch.float32, device=device)
  verts = torch.tensor(wavefront.vertices, **as_f32)
  tex_coords = wavefront.parser.tex_coords
  uvs = None if tex_coords == [] else torch.tensor(tex_coords, **as_f32)

  loaded = {}
  for mesh_name, mesh in wavefront.meshes.items():
    key = str(hash(mesh)) if mesh_name is None else mesh_name
    faces = torch.tensor(mesh.faces, dtype=torch.int32, device=device)
    vert_idx = faces[:, :, 0]
    # A texture-coordinate index of -1 on the first face's first vertex means
    # the faces are treated as having no UV indices.
    uv_idx = None if faces[0, 0, 1] == -1 else faces[:, :, 1]
    mtl = mesh.materials[0]
    if key == "light":
      loaded[key] = dr.AreaLight(verts, vert_idx, torch.tensor(mtl.ambient[:3], **as_f32))
      continue
    bsdf = dr.DiffuseBsdfMaterial(
      torch.tensor(mtl.diffuse[:3], **as_f32),
      torch.tensor([0.5], **as_f32),
      None,
    )
    loaded[key] = dr.Mesh(verts, uvs, vert_idx, uv_idx, bsdf)
  return loaded

def load_spot(obj_path, diffuse_tex=None, rotate_angle=None):
  """Load the meshes of an OBJ file as textured diffuse `dr.Mesh` objects.

  Args:
    obj_path: path to the `.obj` file.
    diffuse_tex: optional texture tensor used as the diffuse albedo. When None,
      the texture image of the first mesh's first material is read from disk,
      divided by 255, and then reused for every mesh in the file.
    rotate_angle: optional (pitch, yaw, roll) passed to `rotate` before the
      vertices are scaled and translated.

  Returns:
    dict mapping mesh name to its `dr.Mesh`; unnamed meshes are keyed by
    `str(hash(mesh))`. An OBJ file with no meshes yields an empty dict.
  """
  scene = Wavefront(obj_path, create_materials=True, collect_faces=True)
  vertices = torch.tensor(scene.vertices, dtype=torch.float32, device=device)
  if rotate_angle is not None:
    pitch, yaw, roll = rotate_angle
    vertices = rotate(vertices, pitch, yaw, roll)
  # Scale by 150 and shift by 300 on every axis (in place); presumably this
  # positions the mesh inside the Cornell box used by this notebook.
  vertices.mul_(150).add_(300)
  uvs = None if scene.parser.tex_coords == [] else torch.tensor(scene.parser.tex_coords, dtype=torch.float32, device=device)
  objs = {}
  for name, mesh in scene.meshes.items():
    if name is None:
      name = str(hash(mesh))
    indices = torch.tensor(mesh.faces, dtype=torch.int32, device=device)
    face_indices = indices[:,:,0]
    # -1 in the first face's first texture-coordinate slot => no UV indices.
    uv_indices = indices[:,:,1] if indices[0,0,1] != -1 else None
    material = mesh.materials[0]
    if diffuse_tex is None:
      # Loaded lazily from the first mesh's material, then reused for later meshes.
      tex = imageio.imread(material.texture.path)
      diffuse_tex = torch.tensor(tex, dtype=torch.float32, device=device) / 255
    normal = None
    roughness = torch.tensor([1], dtype=torch.float32, device=device)
    material = dr.DiffuseBsdfMaterial(diffuse_tex, roughness, normal)
    obj = dr.Mesh(vertices, uvs, face_indices, uv_indices, material)
    # Register every mesh (previously only the last one was kept, and an
    # empty mesh list raised NameError).
    objs[name] = obj
  return objs

## 2. Set camera

In [None]:
def set_camera():
  """Build the perspective camera used for every render in this notebook.

  The camera is placed at (278, 278, -800), looks at (278, 278, 0) with +y as
  the up vector, uses a `vfov` of 38 degrees (passed in radians), and renders
  a 600x600 image.
  """
  def vec3(x, y, z):
    # float32 3-vector on the notebook's device
    return torch.tensor([x, y, z], dtype=torch.float32, device=device)

  fov_rad = torch.deg2rad(torch.tensor(38.0))
  return dr.PerspectiveCamera(
    look_from=vec3(278, 278, -800),
    look_at=vec3(278, 278, 0),
    up=vec3(0, 1, 0),
    vfov=torch.tensor([fov_rad], dtype=torch.float32, device=device),
    height=600,
    width=600,
  )

## 3. Configure renderer

In [None]:
def render_img(scene, diff_mode=False):
  """Render `scene` with a `dr.PathIntegrator` using fixed sampling settings.

  Args:
    scene: the `dr.Scene` to render.
    diff_mode: if True, call `renderD` (the variant used with `loss.backward()`
      in this notebook); otherwise call `renderC`.

  Returns:
    The result of the chosen render call; callers index it per camera.
  """
  settings = dict(
    n_pass=1,
    spp_interior=12,
    # NOTE: keep this spelling; presumably it is the exact torch_psdr keyword.
    enable_light_visable=False,
    spp_primary_edge=1,
    spp_secondary_edge=1,
    primary_edge_preprocess_rounds=1,
    secondary_edge_preprocess_rounds=1,
    max_bounce=3,
    mis_light_samples=2,
    mis_bsdf_samples=0,
  )
  integrator = dr.PathIntegrator(**settings)
  if diff_mode:
    return integrator.renderD(scene)
  return integrator.renderC(scene)

## 4. Optimize texture

In [None]:
# combine all
def render_once(obj_path, spot_path, diffuse_tex=None, spot_rotation=None, diff_mode=False):
  objs = load_obj(obj_path)
  spot = load_spot(spot_path, diffuse_tex, spot_rotation)
  objs.update(spot)

  lights = [obj for name, obj in objs.items() if name == "light"]
  meshes = [obj for name, obj in objs.items() if name != "light"]

  cameras = [set_camera()]

  scene = dr.Scene(cameras, meshes, lights)

  imgs = render_img(scene, diff_mode)

  return imgs

In [None]:
# Render the target image, using the texture referenced by the spot mesh's own material.
obj_path = "../data/input/cornell_box.obj"
spot_path = "../data/input/spot/spot_triangulated.obj"
target_imgs = render_once(obj_path, spot_path)
target_img = target_imgs[0]  # image from the single camera
tgt_img = target_img.detach().cpu().numpy()
plt.imshow(tgt_img)
# Show explicitly (like the other cells) instead of echoing the AxesImage repr.
plt.show()

In [None]:
# Texture parameters to optimize. They are unconstrained logits; the texture
# actually rendered is sigmoid(tex_param), which keeps every texel in (0, 1).
tex_param = torch.zeros((tex_H, tex_W, 3), dtype=torch.float32, device=device, requires_grad=True)

# Plain SGD. The very large learning rate presumably compensates for tiny
# per-texel gradients; tune it if the loss diverges or stalls.
optim = torch.optim.SGD([tex_param], lr=1e8)

# Optimize the texture against the single fixed target view.
# `step` (not `iter`) avoids shadowing the builtin in the notebook namespace.
for step in range(5):
  tex = torch.sigmoid(tex_param)
  imgs = render_once(obj_path, spot_path, tex, diff_mode=True)
  img = imgs[0]
  loss = F.mse_loss(img, target_img)
  loss.backward()
  optim.step()

  clear_output(wait=False)
  # .item() prints plain numbers instead of full tensor reprs.
  print(f"iter: {step}; avg grad: {tex_param.grad.mean().item()}; loss: {loss.item()}")
  img_np = img.detach().cpu().numpy()
  plt.imshow(img_np)
  plt.show()

  optim.zero_grad()


In [None]:
# Compare the target UV texture with the optimized one.
target_diffuse_tex_path = "../data/input/spot/spot_texture.png"
target_diffuse_tex = imageio.imread(target_diffuse_tex_path)
plt.imshow(target_diffuse_tex)
plt.show()

# tex_param holds pre-sigmoid logits; the texture used for rendering (and the
# one comparable to the target) is sigmoid(tex_param).
optimized_tex = torch.sigmoid(tex_param).detach().cpu().numpy()
plt.imshow(optimized_tex)
plt.show()

## 5. Optimize the whole texture from multiple random view directions

In [None]:
# Optimize the texture from a newly sampled orientation of the spot mesh every
# iteration, so that different parts of the texture are supervised over time.
# tex_param holds logits; the rendered texture is sigmoid(tex_param).
tex_param = torch.zeros((tex_H, tex_W, 3), dtype=torch.float32, device=device, requires_grad=True)
optim = torch.optim.SGD([tex_param], lr=1e7)
iterations = 40
# `step` (not `iter`) avoids shadowing the builtin in the notebook namespace.
for step in range(iterations):
  # Random orientation in radians: pitch and yaw in [0, 2*pi), roll in [0, pi).
  pitch, yaw, roll = random.random() * 2 * math.pi, random.random() * 2 * math.pi, random.random() * math.pi
  # Render the target image in this orientation, using the mesh's own texture.
  target_imgs = render_once(obj_path, spot_path, diffuse_tex=None, spot_rotation=(pitch, yaw, roll), diff_mode=False)
  # Render the same orientation with the current texture parameters.
  tex = torch.sigmoid(tex_param)
  imgs = render_once(obj_path, spot_path, tex, (pitch, yaw, roll), diff_mode=True)
  loss = F.mse_loss(imgs[0], target_imgs[0])
  # backward() adds this view's gradient into tex_param.grad; zero_grad()
  # below resets it, so each step uses only the current view.
  loss.backward()

  # Clear the previous output BEFORE printing; otherwise the log line is erased at once.
  clear_output(wait=False)
  print(f"iter: {step}; avg grad: {tex_param.grad.mean().item()}; loss: {loss.item()}")

  optim.step()
  optim.zero_grad()
  # Show the texture currently being rendered (post-sigmoid).
  tex = torch.sigmoid(tex_param).detach().cpu().numpy()
  plt.imshow(tex)
  plt.show()
