Skip to content

Commit

Permalink
Add bake texture related nodes and classes
Browse files Browse the repository at this point in the history
  • Loading branch information
MrForExample committed Jan 6, 2024
1 parent 1c4ec03 commit 0d5e579
Show file tree
Hide file tree
Showing 6 changed files with 345 additions and 59 deletions.
10 changes: 5 additions & 5 deletions __init__.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,9 @@
#__all__ = ['NODE_CLASS_MAPPINGS', 'NODE_DISPLAY_NAME_MAPPINGS']
import importlib
import inspect

NODE_CLASS_MAPPINGS = {}
NODE_DISPLAY_NAME_MAPPINGS = {}

import importlib
import inspect

nodes_filename = "nodes"
module = importlib.import_module(f".{nodes_filename}", package=__name__)
for name, cls in inspect.getmembers(module, inspect.isclass):
Expand All @@ -16,4 +14,6 @@
disp = f"{name}"

NODE_CLASS_MAPPINGS[node] = cls
NODE_DISPLAY_NAME_MAPPINGS[node] = disp
NODE_DISPLAY_NAME_MAPPINGS[node] = disp

__all__ = ['NODE_CLASS_MAPPINGS', 'NODE_DISPLAY_NAME_MAPPINGS']
40 changes: 21 additions & 19 deletions diff_rast/diff_mesh_renderer.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import torch.nn.functional as F
import nvdiffrast.torch as dr

from ..mesh_processer.mesh import Mesh, safe_normalize
from ..mesh_processer.mesh import safe_normalize

def scale_img_nhwc(x, size, mag='bilinear', min='bilinear'):
assert (x.shape[1] >= size[0] and x.shape[2] >= size[1]) or (x.shape[1] < size[0] and x.shape[2] < size[1]), "Trying to magnify image in one dimension and minify in the other"
Expand Down Expand Up @@ -38,41 +38,41 @@ def make_divisible(x, m=8):
return int(math.ceil(x / m) * m)

class Renderer(nn.Module):
def __init__(self, opt):
def __init__(self, mesh, force_cuda_rast):

super().__init__()

self.opt = opt
self.mesh = mesh

self.mesh = Mesh.load(self.opt.mesh, resize=False)

if not self.opt.force_cuda_rast and (not self.opt.gui or os.name == 'nt'):
self.glctx = dr.RasterizeGLContext()
else:
if force_cuda_rast or os.name != 'nt':
self.glctx = dr.RasterizeCudaContext()
else:
self.glctx = dr.RasterizeGLContext()

# extract trainable parameters
self.v_offsets = nn.Parameter(torch.zeros_like(self.mesh.v))
self.raw_albedo = nn.Parameter(trunc_rev_sigmoid(self.mesh.albedo))
self.v_offsets = nn.Parameter(torch.zeros_like(self.mesh.v), requires_grad=True)
self.raw_albedo = nn.Parameter(trunc_rev_sigmoid(self.mesh.albedo), requires_grad=True)


def get_params(self):
def get_params(self, texture_lr, train_geo, geom_lr):

params = [
{'params': self.raw_albedo, 'lr': self.opt.texture_lr},
{'params': self.raw_albedo, 'lr': texture_lr},
]

if self.opt.train_geo:
params.append({'params': self.v_offsets, 'lr': self.opt.geom_lr})
self.train_geo = train_geo
if train_geo:
params.append({'params': self.v_offsets, 'lr': geom_lr})

return params

@torch.no_grad()
def export_mesh(self, save_path):
self.mesh.v = (self.mesh.v + self.v_offsets).detach()
self.mesh.albedo = torch.sigmoid(self.raw_albedo.detach())
self.mesh.write(save_path)

@torch.no_grad()
def update_mesh(self):
self.mesh.v = (self.mesh.v + self.v_offsets).detach()
self.mesh.albedo = torch.sigmoid(self.raw_albedo.detach())

def render(self, pose, proj, h0, w0, ssaa=1, bg_color=1, texture_filter='linear-mipmap-linear'):

Expand All @@ -86,7 +86,7 @@ def render(self, pose, proj, h0, w0, ssaa=1, bg_color=1, texture_filter='linear-
results = {}

# get v
if self.opt.train_geo:
if self.train_geo:
v = self.mesh.v + self.v_offsets # [N, 3]
else:
v = self.mesh.v
Expand All @@ -106,9 +106,10 @@ def render(self, pose, proj, h0, w0, ssaa=1, bg_color=1, texture_filter='linear-

texc, texc_db = dr.interpolate(self.mesh.vt.unsqueeze(0).contiguous(), rast, self.mesh.ft, rast_db=rast_db, diff_attrs='all')
albedo = dr.texture(self.raw_albedo.unsqueeze(0), texc, uv_da=texc_db, filter_mode=texture_filter) # [1, H, W, 3]
print(f"albedo = dr.texture: {albedo.requires_grad}")
albedo = torch.sigmoid(albedo)
# get vn and render normal
if self.opt.train_geo:
if self.train_geo:
i0, i1, i2 = self.mesh.f[:, 0].long(), self.mesh.f[:, 1].long(), self.mesh.f[:, 2].long()
v0, v1, v2 = v[i0, :], v[i1, :], v[i2, :]

Expand All @@ -133,6 +134,7 @@ def render(self, pose, proj, h0, w0, ssaa=1, bg_color=1, texture_filter='linear-

# antialias
albedo = dr.antialias(albedo, rast, v_clip, self.mesh.f).squeeze(0) # [H, W, 3]
print(f"albedo = dr.antialias: {albedo.requires_grad}")
albedo = alpha * albedo + (1 - alpha) * bg_color

# ssaa
Expand Down
135 changes: 133 additions & 2 deletions diff_rast/diff_texturing.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,135 @@
import sys
import random

from ..shared_utils.common_utils import get_persistent_directory
import torch
import torch.nn.functional as F
import numpy as np
from pytorch_msssim import SSIM, MS_SSIM

from ..diff_rast.diff_mesh_renderer import Renderer
from ..shared_utils.camera_utils import orbit_camera, OrbitCamera

class DiffTextureBaker:

def __init__(self, reference_images, reference_masks, reference_orbit_camera_poses, reference_orbit_camera_fovy, mesh,
training_iterations, batch_size, texture_learning_rate, train_mesh_geometry, geometry_learning_rate, ms_ssim_loss_weight, force_cuda_rasterize):
self.device = torch.device("cuda")

self.ref_imgs_num = len(reference_images)

self.all_ref_cam_poses = reference_orbit_camera_poses
self.ref_cam_fovy = reference_orbit_camera_fovy

self.ref_size_H = reference_images[0].shape[0]
self.ref_size_W = reference_images[0].shape[1]

self.cam = OrbitCamera(self.ref_size_W, self.ref_size_H, fovy=reference_orbit_camera_fovy)

# prepare main components for optimization
self.renderer = Renderer(mesh, force_cuda_rasterize).to(self.device)

self.optimizer = torch.optim.Adam(self.renderer.get_params(texture_learning_rate, train_mesh_geometry, geometry_learning_rate))
#self.ssim_loss = SSIM(data_range=1, size_average=True, channel=3)
self.ms_ssim_loss = MS_SSIM(data_range=1, size_average=True, channel=3)
self.lambda_ssim = ms_ssim_loss_weight

self.training_iterations = training_iterations

self.batch_size = batch_size

# prepare reference images and masks
ref_imgs_torch_list = []
ref_masks_torch_list = []
for i in range(self.ref_imgs_num):
ref_imgs_torch_list.append(self.prepare_img(reference_images[i]))
ref_masks_torch_list.append(self.prepare_img(reference_masks[i].unsqueeze(2)))

self.ref_imgs_torch = torch.cat(ref_imgs_torch_list, dim=0)
self.ref_masks_torch = torch.cat(ref_masks_torch_list, dim=0)

def prepare_img(self, img):
img_new = img.permute(2, 0, 1).unsqueeze(0).to(self.device)
img_new = F.interpolate(img_new, (self.ref_size_H, self.ref_size_W), mode="bilinear", align_corners=False).contiguous()
return img_new

def training(self):
starter = torch.cuda.Event(enable_timing=True)
ender = torch.cuda.Event(enable_timing=True)
starter.record()

ref_imgs_masked = []
for i in range(self.ref_imgs_num):
ref_imgs_masked.append((self.ref_imgs_torch[i] * self.ref_masks_torch[i]).unsqueeze(0))

ref_imgs_num_minus_1 = self.ref_imgs_num-1

for step in range(self.training_iterations):

### calculate loss between reference and rendered image from known view
loss = 0
masked_rendered_img_batch = []
masked_ref_img_batch = []
for _ in range(self.batch_size):

i = random.randint(0, ref_imgs_num_minus_1)

radius, elevation, azimuth, center_X, center_Y, center_Z = self.all_ref_cam_poses[i]

# render output
orbit_target = np.array([center_X, center_Y, center_Z], dtype=np.float32)
ref_pose = orbit_camera(elevation, azimuth, radius, target=orbit_target)
ref_cam = (ref_pose, self.cam.perspective)
out = self.renderer.render(*ref_cam, self.ref_size_H, self.ref_size_W, ssaa=1) #ssaa = min(2.0, max(0.125, 2 * np.random.random()))

image = out["image"] # [H, W, 3] in [0, 1]
image = image.permute(2, 0, 1).contiguous() # [3, H, W] in [0, 1]

#print(f"image.requires_grad: {image.requires_grad}")

image_masked = (image * self.ref_masks_torch[i]).unsqueeze(0)

#print(f"image_masked.requires_grad: {image_masked.requires_grad}")
#print(f"ref_imgs_masked[i].requires_grad: {ref_imgs_masked[i].requires_grad}")

masked_rendered_img_batch.append(image_masked)
masked_ref_img_batch.append(ref_imgs_masked[i])

masked_rendered_img_batch_torch = torch.cat(masked_rendered_img_batch, dim=0)
masked_ref_img_batch_torch = torch.cat(masked_ref_img_batch, dim=0)

# rgb loss
loss += (1 - self.lambda_ssim) * F.mse_loss(masked_rendered_img_batch_torch, masked_ref_img_batch_torch)

# D-SSIM loss
# [1, 3, H, W] in [0, 1]
#loss += self.lambda_ssim * (1 - self.ssim_loss(X, Y))
loss += self.lambda_ssim * (1 - self.ms_ssim_loss(masked_ref_img_batch_torch, masked_rendered_img_batch_torch))

print(f"masked_rendered_img_batch_torch.requires_grad: {masked_rendered_img_batch_torch.requires_grad}")
print(f"masked_ref_img_batch_torch.requires_grad: {masked_ref_img_batch_torch.requires_grad}")

print(f"loss.requires_grad: {loss.requires_grad}")

print(f"self.renderer.raw_albedo.requires_grad: {self.renderer.raw_albedo.requires_grad}")

# import kiui
# kiui.lo(hor, ver)
# kiui.vis.plot_image(image)

# optimize step
loss.backward()
self.optimizer.step()
self.optimizer.zero_grad()

torch.cuda.synchronize()

self.need_update = True

print(f"Step: {step}")

self.renderer.update_mesh()

ender.record()
t = starter.elapsed_time(ender)

def get_mesh_and_texture(self):
return (self.renderer.mesh, self.renderer.mesh.albedo, )
8 changes: 4 additions & 4 deletions mesh_processer/mesh.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def load(cls, path=None, resize=True, renormal=True, retex=False, front_dir='+z'
if path is None:
mesh = cls(**kwargs)
# obj supports face uv
elif path.endswith(".obj"):
elif path.lower().endswith(".obj"):
mesh = cls.load_obj(path, **kwargs)
# trimesh only supports vertex uv, but can load more formats
else:
Expand Down Expand Up @@ -421,11 +421,11 @@ def to(self, device):
return self

def write(self, path):
if path.endswith(".ply"):
if path.lower().endswith(".ply"):
self.write_ply(path)
elif path.endswith(".obj"):
elif path.lower().endswith(".obj"):
self.write_obj(path)
elif path.endswith(".glb") or path.endswith(".gltf"):
elif path.lower().endswith(".glb") or path.lower().endswith(".gltf"):
self.write_glb(path)
else:
raise NotImplementedError(f"format {path} not supported!")
Expand Down
Loading

0 comments on commit 0d5e579

Please sign in to comment.