Skip to content

Commit

Permalink
Added back support for ft_fx3 in DIB-R (#152)
Browse files Browse the repository at this point in the history
Signed-off-by: Tommy Xiang <txiang@dhcp-10-20-53-98.nvidia.com>
  • Loading branch information
TommyX12 committed Feb 18, 2020
1 parent dcacac5 commit cbd057c
Show file tree
Hide file tree
Showing 3 changed files with 57 additions and 18 deletions.
33 changes: 25 additions & 8 deletions kaolin/graphics/dib_renderer/renderer/phongrender.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,19 @@ def __init__(self, height, width):

def set_smooth(self, pfmtx):
self.smooth = True
self.pfmtx = torch.from_numpy(pfmtx).view(1, pfmtx.shape[0], pfmtx.shape[1]).cuda()
self.pfmtx = torch.from_numpy(pfmtx).view(
1, pfmtx.shape[0], pfmtx.shape[1]).cuda()

def forward(self,
points,
cameras,
uv_bxpx2,
texture_bx3xthxtw,
lightdirect_bx3,
material_bx3x3,
shininess_bx1,
ft_fx3=None):

def forward(self, points, cameras, uv_bxpx2, texture_bx3xthxtw, lightdirect_bx3, material_bx3x3, shininess_bx1):
assert lightdirect_bx3 is not None, 'When using the Phong model, light parameters must be passed'
assert material_bx3x3 is not None, 'When using the Phong model, material parameters must be passed'
assert shininess_bx1 is not None, 'When using the Phong model, shininess parameters must be passed'
Expand All @@ -54,6 +64,10 @@ def forward(self, points, cameras, uv_bxpx2, texture_bx3xthxtw, lightdirect_bx3,
# first, MVP projection in vertexshader
points_bxpx3, faces_fx3 = points

# use faces_fx3 as ft_fx3 if not given
if ft_fx3 is None:
ft_fx3 = faces_fx3

# camera_rot_bx3x3, camera_pos_bx3, camera_proj_3x1 = cameras

points3d_bxfx9, points2d_bxfx6, normal_bxfx3 = \
Expand All @@ -72,7 +86,8 @@ def forward(self, points, cameras, uv_bxpx2, texture_bx3xthxtw, lightdirect_bx3,
####################################################
# smooth or not
if self.smooth:
normal_bxpx3 = torch.matmul(self.pfmtx.repeat(normal_bxfx3.shape[0], 1, 1), normal_bxfx3)
normal_bxpx3 = torch.matmul(self.pfmtx.repeat(
normal_bxfx3.shape[0], 1, 1), normal_bxfx3)
n0 = normal_bxpx3[:, faces_fx3[:, 0], :]
n1 = normal_bxpx3[:, faces_fx3[:, 1], :]
n2 = normal_bxpx3[:, faces_fx3[:, 2], :]
Expand All @@ -86,18 +101,20 @@ def forward(self, points, cameras, uv_bxpx2, texture_bx3xthxtw, lightdirect_bx3,
bnum = normal1_bxfx3.shape[0]

# we have uv, normal, eye to interpolate
c0 = uv_bxpx2[:, faces_fx3[:, 0], :]
c1 = uv_bxpx2[:, faces_fx3[:, 1], :]
c2 = uv_bxpx2[:, faces_fx3[:, 2], :]
c0 = uv_bxpx2[:, ft_fx3[:, 0], :]
c1 = uv_bxpx2[:, ft_fx3[:, 1], :]
c2 = uv_bxpx2[:, ft_fx3[:, 2], :]
mask = torch.ones_like(c0[:, :, :1])
uv_bxfx3x3 = torch.cat((c0, mask, c1, mask, c2, mask), dim=2).view(bnum, fnum, 3, -1)
uv_bxfx3x3 = torch.cat(
(c0, mask, c1, mask, c2, mask), dim=2).view(bnum, fnum, 3, -1)

# normal & eye direction
normal_bxfx3x3 = normal_bxfx9.view(bnum, fnum, 3, -1)
eyedirect_bxfx9 = -points3d_bxfx9
eyedirect_bxfx3x3 = eyedirect_bxfx9.view(-1, fnum, 3, 3)

feat = torch.cat((normal_bxfx3x3, eyedirect_bxfx3x3, uv_bxfx3x3), dim=3)
feat = torch.cat(
(normal_bxfx3x3, eyedirect_bxfx3x3, uv_bxfx3x3), dim=3)
feat = feat.view(bnum, fnum, -1)
imfeature, improb_bxhxwx1 = linear_rasterizer(self.height, self.width,
points3d_bxfx9, points2d_bxfx6, normalz_bxfx1, feat)
Expand Down
25 changes: 19 additions & 6 deletions kaolin/graphics/dib_renderer/renderer/shrender.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,13 +45,24 @@ def set_smooth(self, pfmtx):
self.smooth = True
self.pfmtx = pfmtx

def forward(self, points, cameras, uv_bxpx2, texture_bx3xthxtw, lightparam):
def forward(self,
points,
cameras,
uv_bxpx2,
texture_bx3xthxtw,
lightparam,
ft_fx3=None):

assert lightparam is not None, 'When using the Spherical Harmonics model, light parameters must be passed'

##############################################################
# first, MVP projection in vertexshader
points_bxpx3, faces_fx3 = points

# use faces_fx3 as ft_fx3 if not given
if ft_fx3 is None:
ft_fx3 = faces_fx3

# camera_rot_bx3x3, camera_pos_bx3, camera_proj_3x1 = cameras

points3d_bxfx9, points2d_bxfx6, normal_bxfx3 = \
Expand Down Expand Up @@ -83,11 +94,12 @@ def forward(self, points, cameras, uv_bxpx2, texture_bx3xthxtw, lightparam):
fnum = normal1_bxfx3.shape[1]
bnum = normal1_bxfx3.shape[0]

c0 = uv_bxpx2[:, faces_fx3[:, 0], :]
c1 = uv_bxpx2[:, faces_fx3[:, 1], :]
c2 = uv_bxpx2[:, faces_fx3[:, 2], :]
c0 = uv_bxpx2[:, ft_fx3[:, 0], :]
c1 = uv_bxpx2[:, ft_fx3[:, 1], :]
c2 = uv_bxpx2[:, ft_fx3[:, 2], :]
mask = torch.ones_like(c0[:, :, :1])
uv_bxfx3x3 = torch.cat((c0, mask, c1, mask, c2, mask), dim=2).view(bnum, fnum, 3, -1)
uv_bxfx3x3 = torch.cat(
(c0, mask, c1, mask, c2, mask), dim=2).view(bnum, fnum, 3, -1)

# normal
normal_bxfx3x3 = normal_bxfx9.view(bnum, fnum, 3, -1)
Expand All @@ -104,6 +116,7 @@ def forward(self, points, cameras, uv_bxpx2, texture_bx3xthxtw, lightparam):
# fragrement shader
# parallel light
imnormal1_bxhxwx3 = datanormalize(imnormal_bxhxwx3, axis=3)
imrender = fragmentshader(imnormal1_bxhxwx3, lightparam, imtexcoords, texture_bx3xthxtw, hardmask)
imrender = fragmentshader(
imnormal1_bxhxwx3, lightparam, imtexcoords, texture_bx3xthxtw, hardmask)

return imrender, improb_bxhxwx1, normal1_bxfx3
17 changes: 13 additions & 4 deletions kaolin/graphics/dib_renderer/renderer/texrender.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,12 +39,21 @@ def __init__(self, height, width, filtering='nearest'):
self.width = width
self.filtering = filtering

def forward(self, points, cameras, uv_bxpx2, texture_bx3xthxtw):
def forward(self,
points,
cameras,
uv_bxpx2,
texture_bx3xthxtw,
ft_fx3=None):

##############################################################
# first, MVP projection in vertexshader
points_bxpx3, faces_fx3 = points

# use faces_fx3 as ft_fx3 if not given
if ft_fx3 is None:
ft_fx3 = faces_fx3

# camera_rot_bx3x3, camera_pos_bx3, camera_proj_3x1 = cameras

points3d_bxfx9, points2d_bxfx6, normal_bxfx3 = \
Expand All @@ -62,9 +71,9 @@ def forward(self, points, cameras, uv_bxpx2, texture_bx3xthxtw):

############################################################
# second, rasterization
c0 = uv_bxpx2[:, faces_fx3[:, 0], :]
c1 = uv_bxpx2[:, faces_fx3[:, 1], :]
c2 = uv_bxpx2[:, faces_fx3[:, 2], :]
c0 = uv_bxpx2[:, ft_fx3[:, 0], :]
c1 = uv_bxpx2[:, ft_fx3[:, 1], :]
c2 = uv_bxpx2[:, ft_fx3[:, 2], :]
mask = torch.ones_like(c0[:, :, :1])
uv_bxfx9 = torch.cat((c0, mask, c1, mask, c2, mask), dim=2)

Expand Down

0 comments on commit cbd057c

Please sign in to comment.