# Model Inference and Evaluation

In [None]:
from render import Renderer
import torch
import nvdiffrast.torch as dr
from models_animated import EncoderPlusSDF, RenderHead
from utils import *
import matplotlib.pyplot as plt
from rendering_gaussian import render_with_seperate_args, Camera

## Breaking Sphere

In [None]:
# One renderer per captured Breaking Sphere frame (meshes 1.obj .. 5.obj).
Rs = [
    Renderer(
        50,
        512,
        glctx=dr.RasterizeGLContext(),
        fname=f"data/Breaking_Sphere/{frame_idx + 1}.obj",
        scale=1.7,
        with_texture=False,
    )
    for frame_idx in range(5)
]

In [None]:
# Restore the trained Breaking Sphere SDF network and switch it to eval mode.
ckpt_path = "logs/10_DEFORMABLE_BREAKING_SPHERE/checkpoints/model_2500.pth"
sdf_head = EncoderPlusSDF().cuda()
sdf_head.load_state_dict(torch.load(ckpt_path))
sdf_head.eval()

In [None]:
from torchmetrics.image import ssim, lpip
from utils import compute_trimesh_chamfer
import trimesh

# Evaluate the reconstruction at every captured frame. For each frame we
# compute image metrics (MSE, PSNR, SSIM, LPIPS) over all views of that
# frame's renderer, plus the Chamfer distance between the extracted mesh and
# the frame's reference mesh. Values are printed per frame, then averaged.
num_frames = len(Rs)

ssim_metric = ssim.StructuralSimilarityIndexMeasure(
    data_range=1.0, reduction="elementwise_mean"
).cuda()
# normalize=True because the images compared here are in [0, 1]. With the
# default (normalize=False), torchmetrics assumes [-1, 1] inputs and scores
# the wrong intensity range.
lpips_metric = lpip.LearnedPerceptualImagePatchSimilarity(
    net_type="vgg", normalize=True
).cuda()

mse_ = 0
psnr_ = 0
ssim_ = 0
lpips_ = 0
cd_ = 0
plot_images = []
plot_targets = []

for t in range(num_frames):
    renderer = Rs[t]
    # Frames are spread uniformly over model time [0, 1] (t / 4 for 5 frames).
    time_step = t / max(num_frames - 1, 1)
    vertices_np, faces_np = sdf_head.get_zero_points(time_step, 200)
    vertices = torch.from_numpy(vertices_np).float().cuda()
    faces = torch.from_numpy(faces_np.copy()).int().cuda()
    face_normals = compute_face_normals(vertices, faces)
    vertex_normals = compute_vertex_normals(vertices, faces, face_normals)
    imgs = renderer.render(vertices, faces, vertex_normals)[..., :3]
    target_imgs = renderer.target_imgs[..., :3]

    mse = torch.nn.MSELoss()(imgs, target_imgs)
    psnr = -10 * torch.log10(mse)  # assumes a peak signal value of 1.0
    mse_ += mse.item()
    psnr_ += psnr.item()
    print(f"MSE: {mse.item()}")
    print(f"PSNR: {psnr.item()}")

    plot_images.append(imgs[-2].cpu().numpy().clip(0, 1))
    plot_targets.append(target_imgs[-2].cpu().numpy().clip(0, 1))

    # torchmetrics image metrics take channels-first (N, C, H, W) batches.
    imgs = imgs.permute(0, 3, 1, 2)
    target_imgs = target_imgs.permute(0, 3, 1, 2)
    ssim_score = ssim_metric(imgs, target_imgs)
    print(f"SSIM: {ssim_score.item()}")
    # With normalize=True, LPIPS rejects values outside [0, 1], so clamp any
    # render overshoot.
    lpip_score = lpips_metric(imgs.clamp(0, 1), target_imgs.clamp(0, 1))
    print(f"LPIPS: {lpip_score.item()}")
    ssim_ += ssim_score.item()
    lpips_ += lpip_score.item()

    mesh = trimesh.Trimesh(vertices_np, faces_np)
    chamfer_distance = compute_trimesh_chamfer(renderer.mesh, mesh)
    print(f"Chamfer distance: {chamfer_distance}")
    print()
    cd_ += chamfer_distance

print(f"Average MSE: {mse_ / num_frames}")
print(f"Average PSNR: {psnr_ / num_frames}")
print(f"Average SSIM: {ssim_ / num_frames}")
print(f"Average LPIPS: {lpips_ / num_frames}")
print(f"Average Chamfer distance: {cd_ / num_frames}")

In [None]:
# Render a sweep of the learned deformation at num_frames evenly spaced times
# t = i / num_frames (0.00 .. 0.95 for 20 frames), num_cols images per row.
R = Renderer(1, 1024, dr.RasterizeGLContext(), "data/Breaking_Sphere/1.obj")

num_frames = 20
num_cols = 5
num_rows = (num_frames + num_cols - 1) // num_cols
# plt.subplots takes (nrows, ncols). Images are placed at
# ax[i // num_cols, i % num_cols], so the grid must have num_cols columns.
# squeeze=False keeps `ax` 2-D even if there is only a single row.
fig, ax = plt.subplots(
    num_rows, num_cols, figsize=(25, 20), dpi=300, squeeze=False
)

for i in range(num_frames):
    t = i / num_frames
    vertices_np, faces_np = sdf_head.get_zero_points(t, 200)
    vertices = torch.from_numpy(vertices_np).float().cuda()
    faces = torch.from_numpy(faces_np.copy()).int().cuda()
    face_normals = compute_face_normals(vertices, faces)
    vertex_normals = compute_vertex_normals(vertices, faces, face_normals)
    imgs = R.render(vertices, faces, vertex_normals)[..., :3].squeeze()
    cell = ax[i // num_cols, i % num_cols]
    cell.imshow(imgs.cpu().numpy().clip(0, 1))
    cell.axis("off")
    cell.set_title(f"t = {t:.2f}", fontsize=12)

# Blank out any trailing cells when num_frames is not a multiple of num_cols.
for k in range(num_frames, num_rows * num_cols):
    ax[k // num_cols, k % num_cols].axis("off")

plt.tight_layout()
plt.show()

## SMPL First Scene

In [None]:
# One renderer per captured frame of the first SMPL scene (1.obj .. 3.obj).
Rs = [
    Renderer(
        100,
        256,
        glctx=dr.RasterizeGLContext(),
        fname=f"data/SMPL/Colored_1/{frame_idx + 1}.obj",
        scale=1.7,
        with_texture=False,
    )
    for frame_idx in range(3)
]

In [None]:
# Restore the SDF network trained on the first SMPL scene (eval mode).
ckpt_path = "logs/2_MULTIFRAME_SMPL_FIRST_SCENE/checkpoints/model_2800.pth"
sdf_head = EncoderPlusSDF().cuda()
sdf_head.load_state_dict(torch.load(ckpt_path))
sdf_head.eval()

In [None]:
from torchmetrics.image import ssim, lpip
from utils import compute_trimesh_chamfer
import trimesh

# Evaluate the reconstruction at every captured frame. For each frame we
# compute image metrics (MSE, PSNR, SSIM, LPIPS) over all views of that
# frame's renderer, plus the Chamfer distance between the extracted mesh and
# the frame's reference mesh. Values are printed per frame, then averaged.
num_frames = len(Rs)

ssim_metric = ssim.StructuralSimilarityIndexMeasure(
    data_range=1.0, reduction="elementwise_mean"
).cuda()
# normalize=True because the images compared here are in [0, 1]. With the
# default (normalize=False), torchmetrics assumes [-1, 1] inputs and scores
# the wrong intensity range.
lpips_metric = lpip.LearnedPerceptualImagePatchSimilarity(
    net_type="vgg", normalize=True
).cuda()

mse_ = 0
psnr_ = 0
ssim_ = 0
lpips_ = 0
cd_ = 0
plot_images = []
plot_targets = []

for t in range(num_frames):
    renderer = Rs[t]
    # Frames are spread uniformly over model time [0, 1] (t / 2 for 3 frames).
    time_step = t / max(num_frames - 1, 1)
    vertices_np, faces_np = sdf_head.get_zero_points(time_step, 256)
    vertices = torch.from_numpy(vertices_np).float().cuda()
    faces = torch.from_numpy(faces_np.copy()).int().cuda()
    face_normals = compute_face_normals(vertices, faces)
    vertex_normals = compute_vertex_normals(vertices, faces, face_normals)
    imgs = renderer.render(vertices, faces, vertex_normals)[..., :3]
    target_imgs = renderer.target_imgs[..., :3]

    mse = torch.nn.MSELoss()(imgs, target_imgs)
    psnr = -10 * torch.log10(mse)  # assumes a peak signal value of 1.0
    mse_ += mse.item()
    psnr_ += psnr.item()
    print(f"MSE: {mse.item()}")
    print(f"PSNR: {psnr.item()}")

    plot_images.append(imgs[-2].cpu().numpy().clip(0, 1))
    plot_targets.append(target_imgs[-2].cpu().numpy().clip(0, 1))

    # torchmetrics image metrics take channels-first (N, C, H, W) batches.
    imgs = imgs.permute(0, 3, 1, 2)
    target_imgs = target_imgs.permute(0, 3, 1, 2)
    ssim_score = ssim_metric(imgs, target_imgs)
    print(f"SSIM: {ssim_score.item()}")
    # With normalize=True, LPIPS rejects values outside [0, 1], so clamp any
    # render overshoot.
    lpip_score = lpips_metric(imgs.clamp(0, 1), target_imgs.clamp(0, 1))
    print(f"LPIPS: {lpip_score.item()}")
    ssim_ += ssim_score.item()
    lpips_ += lpip_score.item()

    mesh = trimesh.Trimesh(vertices_np, faces_np)
    chamfer_distance = compute_trimesh_chamfer(renderer.mesh, mesh)
    print(f"Chamfer distance: {chamfer_distance}")
    print()
    cd_ += chamfer_distance

print(f"Average MSE: {mse_ / num_frames}")
print(f"Average PSNR: {psnr_ / num_frames}")
print(f"Average SSIM: {ssim_ / num_frames}")
print(f"Average LPIPS: {lpips_ / num_frames}")
print(f"Average Chamfer distance: {cd_ / num_frames}")

In [None]:
# Side-by-side comparison: predicted render (top row) and ground truth
# (bottom row) for each evaluated frame. Titles show that frame's model time,
# which the metrics cell sets to j / (num_frames - 1). It is computed here
# rather than read from the stale loop variable `t`.
num_frames = len(plot_images)
fig, ax = plt.subplots(2, num_frames, figsize=(20, 10), dpi=300, squeeze=False)

for j, (pred_img, gt_img) in enumerate(zip(plot_images, plot_targets)):
    frame_time = j / max(num_frames - 1, 1)
    ax[0, j].imshow(pred_img)
    ax[0, j].axis("off")
    ax[0, j].set_title(f"t = {frame_time:.2f}", fontsize=12)
    ax[1, j].imshow(gt_img)
    ax[1, j].axis("off")
    ax[1, j].set_title(f"GT at {frame_time:.2f}")

plt.tight_layout()
plt.show()

## SMPL Second Scene

In [None]:
# One renderer per captured frame of the second SMPL scene (1.obj .. 3.obj).
Rs = [
    Renderer(
        100,
        256,
        glctx=dr.RasterizeGLContext(),
        fname=f"data/SMPL/Colored_2/{frame_idx + 1}.obj",
        scale=1.7,
        with_texture=False,
    )
    for frame_idx in range(3)
]

In [None]:
# Restore the SDF network trained on the second SMPL scene (eval mode).
ckpt_path = "logs/1_MULTIFRAME_SMPL_SECOND_SCENE/model.pth"
sdf_head = EncoderPlusSDF().cuda()
sdf_head.load_state_dict(torch.load(ckpt_path))
sdf_head.eval()

In [None]:
from torchmetrics.image import ssim, lpip
from utils import compute_trimesh_chamfer
import trimesh

# Evaluate the reconstruction at every captured frame. For each frame we
# compute image metrics (MSE, PSNR, SSIM, LPIPS) over all views of that
# frame's renderer, plus the Chamfer distance between the extracted mesh and
# the frame's reference mesh. Values are printed per frame, then averaged.
num_frames = len(Rs)

ssim_metric = ssim.StructuralSimilarityIndexMeasure(
    data_range=1.0, reduction="elementwise_mean"
).cuda()
# normalize=True because the images compared here are in [0, 1]. With the
# default (normalize=False), torchmetrics assumes [-1, 1] inputs and scores
# the wrong intensity range.
lpips_metric = lpip.LearnedPerceptualImagePatchSimilarity(
    net_type="vgg", normalize=True
).cuda()

mse_ = 0
psnr_ = 0
ssim_ = 0
lpips_ = 0
cd_ = 0
plot_images = []
plot_targets = []

for t in range(num_frames):
    renderer = Rs[t]
    # Frames are spread uniformly over model time [0, 1] (t / 2 for 3 frames).
    time_step = t / max(num_frames - 1, 1)
    vertices_np, faces_np = sdf_head.get_zero_points(time_step, 256)
    vertices = torch.from_numpy(vertices_np).float().cuda()
    faces = torch.from_numpy(faces_np.copy()).int().cuda()
    face_normals = compute_face_normals(vertices, faces)
    vertex_normals = compute_vertex_normals(vertices, faces, face_normals)
    imgs = renderer.render(vertices, faces, vertex_normals)[..., :3]
    target_imgs = renderer.target_imgs[..., :3]

    mse = torch.nn.MSELoss()(imgs, target_imgs)
    psnr = -10 * torch.log10(mse)  # assumes a peak signal value of 1.0
    mse_ += mse.item()
    psnr_ += psnr.item()
    print(f"MSE: {mse.item()}")
    print(f"PSNR: {psnr.item()}")

    plot_images.append(imgs[-2].cpu().numpy().clip(0, 1))
    plot_targets.append(target_imgs[-2].cpu().numpy().clip(0, 1))

    # torchmetrics image metrics take channels-first (N, C, H, W) batches.
    imgs = imgs.permute(0, 3, 1, 2)
    target_imgs = target_imgs.permute(0, 3, 1, 2)
    ssim_score = ssim_metric(imgs, target_imgs)
    print(f"SSIM: {ssim_score.item()}")
    # With normalize=True, LPIPS rejects values outside [0, 1], so clamp any
    # render overshoot.
    lpip_score = lpips_metric(imgs.clamp(0, 1), target_imgs.clamp(0, 1))
    print(f"LPIPS: {lpip_score.item()}")
    ssim_ += ssim_score.item()
    lpips_ += lpip_score.item()

    mesh = trimesh.Trimesh(vertices_np, faces_np)
    chamfer_distance = compute_trimesh_chamfer(renderer.mesh, mesh)
    print(f"Chamfer distance: {chamfer_distance}")
    print()
    cd_ += chamfer_distance

print(f"Average MSE: {mse_ / num_frames}")
print(f"Average PSNR: {psnr_ / num_frames}")
print(f"Average SSIM: {ssim_ / num_frames}")
print(f"Average LPIPS: {lpips_ / num_frames}")
print(f"Average Chamfer distance: {cd_ / num_frames}")

## SMPL Third Scene

In [None]:
# One renderer per captured frame of the third SMPL scene (1.obj .. 3.obj).
Rs = [
    Renderer(
        100,
        256,
        glctx=dr.RasterizeGLContext(),
        fname=f"data/SMPL/Colored_3/{frame_idx + 1}.obj",
        scale=1.7,
        with_texture=False,
    )
    for frame_idx in range(3)
]

In [None]:
# Restore the SDF network trained on the third SMPL scene (eval mode).
ckpt_path = "logs/1_MULTIFRAME_SMPL_THIRD_SCENE/model_1.pth"
sdf_head = EncoderPlusSDF().cuda()
sdf_head.load_state_dict(torch.load(ckpt_path))
sdf_head.eval()

In [None]:
from torchmetrics.image import ssim, lpip
from utils import compute_trimesh_chamfer
import trimesh

# Evaluate the reconstruction at every captured frame. For each frame we
# compute image metrics (MSE, PSNR, SSIM, LPIPS) over all views of that
# frame's renderer, plus the Chamfer distance between the extracted mesh and
# the frame's reference mesh. Values are printed per frame, then averaged.
num_frames = len(Rs)

ssim_metric = ssim.StructuralSimilarityIndexMeasure(
    data_range=1.0, reduction="elementwise_mean"
).cuda()
# normalize=True because the images compared here are in [0, 1]. With the
# default (normalize=False), torchmetrics assumes [-1, 1] inputs and scores
# the wrong intensity range.
lpips_metric = lpip.LearnedPerceptualImagePatchSimilarity(
    net_type="vgg", normalize=True
).cuda()

mse_ = 0
psnr_ = 0
ssim_ = 0
lpips_ = 0
cd_ = 0
plot_images = []
plot_targets = []

for t in range(num_frames):
    renderer = Rs[t]
    # Frames are spread uniformly over model time [0, 1] (t / 2 for 3 frames).
    time_step = t / max(num_frames - 1, 1)
    vertices_np, faces_np = sdf_head.get_zero_points(time_step, 256)
    vertices = torch.from_numpy(vertices_np).float().cuda()
    faces = torch.from_numpy(faces_np.copy()).int().cuda()
    face_normals = compute_face_normals(vertices, faces)
    vertex_normals = compute_vertex_normals(vertices, faces, face_normals)
    imgs = renderer.render(vertices, faces, vertex_normals)[..., :3]
    target_imgs = renderer.target_imgs[..., :3]

    mse = torch.nn.MSELoss()(imgs, target_imgs)
    psnr = -10 * torch.log10(mse)  # assumes a peak signal value of 1.0
    mse_ += mse.item()
    psnr_ += psnr.item()
    print(f"MSE: {mse.item()}")
    print(f"PSNR: {psnr.item()}")

    plot_images.append(imgs[-2].cpu().numpy().clip(0, 1))
    plot_targets.append(target_imgs[-2].cpu().numpy().clip(0, 1))

    # torchmetrics image metrics take channels-first (N, C, H, W) batches.
    imgs = imgs.permute(0, 3, 1, 2)
    target_imgs = target_imgs.permute(0, 3, 1, 2)
    ssim_score = ssim_metric(imgs, target_imgs)
    print(f"SSIM: {ssim_score.item()}")
    # With normalize=True, LPIPS rejects values outside [0, 1], so clamp any
    # render overshoot.
    lpip_score = lpips_metric(imgs.clamp(0, 1), target_imgs.clamp(0, 1))
    print(f"LPIPS: {lpip_score.item()}")
    ssim_ += ssim_score.item()
    lpips_ += lpip_score.item()

    mesh = trimesh.Trimesh(vertices_np, faces_np)
    chamfer_distance = compute_trimesh_chamfer(renderer.mesh, mesh)
    print(f"Chamfer distance: {chamfer_distance}")
    print()
    cd_ += chamfer_distance

print(f"Average MSE: {mse_ / num_frames}")
print(f"Average PSNR: {psnr_ / num_frames}")
print(f"Average SSIM: {ssim_ / num_frames}")
print(f"Average LPIPS: {lpips_ / num_frames}")
print(f"Average Chamfer distance: {cd_ / num_frames}")

## Deformable Bunny

In [None]:
# One renderer per captured frame of the scaling bunny (1.obj .. 3.obj).
# This scene uses the CUDA rasterizer context instead of the OpenGL one.
Rs = [
    Renderer(
        100,
        256,
        glctx=dr.RasterizeCudaContext(),
        fname=f"data/Bunny_Scaling/{frame_idx + 1}.obj",
        scale=1.7,
        with_texture=False,
    )
    for frame_idx in range(3)
]

In [None]:
# Restore the SDF network trained on the deformable bunny (eval mode).
ckpt_path = "logs/1_MULTIFRAME_DEFORMABLE_BUNNY/model.pth"
sdf_head = EncoderPlusSDF().cuda()
sdf_head.load_state_dict(torch.load(ckpt_path))
sdf_head.eval()

In [None]:
from torchmetrics.image import ssim, lpip
from utils import compute_trimesh_chamfer
import trimesh

# Evaluate the reconstruction at every captured frame. For each frame we
# compute image metrics (MSE, PSNR, SSIM, LPIPS) over all views of that
# frame's renderer, plus the Chamfer distance between the extracted mesh and
# the frame's reference mesh. Values are printed per frame, then averaged.
num_frames = len(Rs)

ssim_metric = ssim.StructuralSimilarityIndexMeasure(
    data_range=1.0, reduction="elementwise_mean"
).cuda()
# normalize=True because the images compared here are in [0, 1]. With the
# default (normalize=False), torchmetrics assumes [-1, 1] inputs and scores
# the wrong intensity range.
lpips_metric = lpip.LearnedPerceptualImagePatchSimilarity(
    net_type="vgg", normalize=True
).cuda()

mse_ = 0
psnr_ = 0
ssim_ = 0
lpips_ = 0
cd_ = 0
plot_images = []
plot_targets = []

for t in range(num_frames):
    renderer = Rs[t]
    # Frames are spread uniformly over model time [0, 1] (t / 2 for 3 frames).
    time_step = t / max(num_frames - 1, 1)
    vertices_np, faces_np = sdf_head.get_zero_points(time_step, 256)
    vertices = torch.from_numpy(vertices_np).float().cuda()
    faces = torch.from_numpy(faces_np.copy()).int().cuda()
    face_normals = compute_face_normals(vertices, faces)
    vertex_normals = compute_vertex_normals(vertices, faces, face_normals)
    imgs = renderer.render(vertices, faces, vertex_normals)[..., :3]
    target_imgs = renderer.target_imgs[..., :3]

    mse = torch.nn.MSELoss()(imgs, target_imgs)
    psnr = -10 * torch.log10(mse)  # assumes a peak signal value of 1.0
    mse_ += mse.item()
    psnr_ += psnr.item()
    print(f"MSE: {mse.item()}")
    print(f"PSNR: {psnr.item()}")

    plot_images.append(imgs[-2].cpu().numpy().clip(0, 1))
    plot_targets.append(target_imgs[-2].cpu().numpy().clip(0, 1))

    # torchmetrics image metrics take channels-first (N, C, H, W) batches.
    imgs = imgs.permute(0, 3, 1, 2)
    target_imgs = target_imgs.permute(0, 3, 1, 2)
    ssim_score = ssim_metric(imgs, target_imgs)
    print(f"SSIM: {ssim_score.item()}")
    # With normalize=True, LPIPS rejects values outside [0, 1], so clamp any
    # render overshoot.
    lpip_score = lpips_metric(imgs.clamp(0, 1), target_imgs.clamp(0, 1))
    print(f"LPIPS: {lpip_score.item()}")
    ssim_ += ssim_score.item()
    lpips_ += lpip_score.item()

    mesh = trimesh.Trimesh(vertices_np, faces_np)
    chamfer_distance = compute_trimesh_chamfer(renderer.mesh, mesh)
    print(f"Chamfer distance: {chamfer_distance}")
    print()
    cd_ += chamfer_distance

print(f"Average MSE: {mse_ / num_frames}")
print(f"Average PSNR: {psnr_ / num_frames}")
print(f"Average SSIM: {ssim_ / num_frames}")
print(f"Average LPIPS: {lpips_ / num_frames}")
print(f"Average Chamfer distance: {cd_ / num_frames}")

In [None]:
# Frame 0: predicted render first, then its ground truth.
for shown in (plot_images[0], plot_targets[0]):
    plt.imshow(shown)
    plt.axis("off")
    plt.show()

In [None]:
# Frame 1: predicted render first, then its ground truth.
for shown in (plot_images[1], plot_targets[1]):
    plt.imshow(shown)
    plt.axis("off")
    plt.show()

In [None]:
# Frame 2: predicted render first, then its ground truth.
for shown in (plot_images[2], plot_targets[2]):
    plt.imshow(shown)
    plt.axis("off")
    plt.show()

## Static Screaming Face

In [None]:
# Single untextured renderer for the static screaming-face mesh.
R = Renderer(
    100,
    256,
    fname="data/Static_Screaming_Face/1.obj",
    with_texture=False,
    scale=1.7,
    glctx=dr.RasterizeGLContext(),
)

In [None]:
# Restore the single-frame screaming-face SDF network (eval mode).
# A second run of this scene is stored under 2_SINGLEFRAME_STATIC_SCREAMING_FACE.
ckpt_path = "logs/1_SINGLEFRAME_STATIC_SCREAMING_FACE/model.pth"
sdf_head = EncoderPlusSDF().cuda()
sdf_head.load_state_dict(torch.load(ckpt_path))
sdf_head.eval()

In [None]:
from torchmetrics.image import ssim, lpip
from utils import compute_trimesh_chamfer
import trimesh

# Evaluate the single-frame reconstruction. We compute image metrics (MSE,
# PSNR, SSIM, LPIPS) over all views of the renderer, plus the Chamfer
# distance between the extracted mesh and the reference mesh.
ssim_metric = ssim.StructuralSimilarityIndexMeasure(
    data_range=1.0, reduction="elementwise_mean"
).cuda()
# normalize=True because the images compared here are in [0, 1]. With the
# default (normalize=False), torchmetrics assumes [-1, 1] inputs and scores
# the wrong intensity range.
lpips_metric = lpip.LearnedPerceptualImagePatchSimilarity(
    net_type="vgg", normalize=True
).cuda()

mse_ = 0
psnr_ = 0
ssim_ = 0
lpips_ = 0
cd_ = 0
plot_images = []
plot_targets = []

renderer = R
vertices_np, faces_np = sdf_head.get_zero_points(0, 256)
vertices = torch.from_numpy(vertices_np).float().cuda()
faces = torch.from_numpy(faces_np.copy()).int().cuda()
face_normals = compute_face_normals(vertices, faces)
vertex_normals = compute_vertex_normals(vertices, faces, face_normals)
imgs = renderer.render(vertices, faces, vertex_normals)[..., :3]
target_imgs = renderer.target_imgs[..., :3]

mse = torch.nn.MSELoss()(imgs, target_imgs)
psnr = -10 * torch.log10(mse)  # assumes a peak signal value of 1.0
mse_ += mse.item()
psnr_ += psnr.item()
print(f"MSE: {mse.item()}")
print(f"PSNR: {psnr.item()}")

plot_images.append(imgs[-2].cpu().numpy().clip(0, 1))
plot_targets.append(target_imgs[-2].cpu().numpy().clip(0, 1))

# torchmetrics image metrics take channels-first (N, C, H, W) batches.
imgs = imgs.permute(0, 3, 1, 2)
target_imgs = target_imgs.permute(0, 3, 1, 2)
ssim_score = ssim_metric(imgs, target_imgs)
print(f"SSIM: {ssim_score.item()}")
# With normalize=True, LPIPS rejects values outside [0, 1], so clamp any
# render overshoot.
lpip_score = lpips_metric(imgs.clamp(0, 1), target_imgs.clamp(0, 1))
print(f"LPIPS: {lpip_score.item()}")
ssim_ += ssim_score.item()
lpips_ += lpip_score.item()

mesh = trimesh.Trimesh(vertices_np, faces_np)
chamfer_distance = compute_trimesh_chamfer(renderer.mesh, mesh)
print(f"Chamfer distance: {chamfer_distance}")
print()
cd_ += chamfer_distance

## Static Multi-Object Scene

In [None]:
# Single untextured renderer for the static multi-object scene.
R = Renderer(
    100,
    256,
    fname="data/Multi_Obj/1.obj",
    with_texture=False,
    scale=1.7,
    glctx=dr.RasterizeGLContext(),
)

In [None]:
# Restore the single-frame multi-object SDF network (eval mode).
ckpt_path = "logs/1_SINGLEFRAME_MULTIOBJECT/model.pth"
sdf_head = EncoderPlusSDF().cuda()
sdf_head.load_state_dict(torch.load(ckpt_path))
sdf_head.eval()

In [None]:
from torchmetrics.image import ssim, lpip
from utils import compute_trimesh_chamfer
import trimesh

# Evaluate the single-frame reconstruction. We compute image metrics (MSE,
# PSNR, SSIM, LPIPS) over all views of the renderer, plus the Chamfer
# distance between the extracted mesh and the reference mesh. At the end,
# `imgs` holds the channels-first (N, C, H, W) predicted renders.
ssim_metric = ssim.StructuralSimilarityIndexMeasure(
    data_range=1.0, reduction="elementwise_mean"
).cuda()
# normalize=True because the images compared here are in [0, 1]. With the
# default (normalize=False), torchmetrics assumes [-1, 1] inputs and scores
# the wrong intensity range.
lpips_metric = lpip.LearnedPerceptualImagePatchSimilarity(
    net_type="vgg", normalize=True
).cuda()

mse_ = 0
psnr_ = 0
ssim_ = 0
lpips_ = 0
cd_ = 0
plot_images = []
plot_targets = []

renderer = R
vertices_np, faces_np = sdf_head.get_zero_points(0, 256)
vertices = torch.from_numpy(vertices_np).float().cuda()
faces = torch.from_numpy(faces_np.copy()).int().cuda()
face_normals = compute_face_normals(vertices, faces)
vertex_normals = compute_vertex_normals(vertices, faces, face_normals)
imgs = renderer.render(vertices, faces, vertex_normals)[..., :3]
target_imgs = renderer.target_imgs[..., :3]

mse = torch.nn.MSELoss()(imgs, target_imgs)
psnr = -10 * torch.log10(mse)  # assumes a peak signal value of 1.0
mse_ += mse.item()
psnr_ += psnr.item()
print(f"MSE: {mse.item()}")
print(f"PSNR: {psnr.item()}")

plot_images.append(imgs[-2].cpu().numpy().clip(0, 1))
plot_targets.append(target_imgs[-2].cpu().numpy().clip(0, 1))

# torchmetrics image metrics take channels-first (N, C, H, W) batches.
imgs = imgs.permute(0, 3, 1, 2)
target_imgs = target_imgs.permute(0, 3, 1, 2)
ssim_score = ssim_metric(imgs, target_imgs)
print(f"SSIM: {ssim_score.item()}")
# With normalize=True, LPIPS rejects values outside [0, 1], so clamp any
# render overshoot.
lpip_score = lpips_metric(imgs.clamp(0, 1), target_imgs.clamp(0, 1))
print(f"LPIPS: {lpip_score.item()}")
ssim_ += ssim_score.item()
lpips_ += lpip_score.item()

mesh = trimesh.Trimesh(vertices_np, faces_np)
chamfer_distance = compute_trimesh_chamfer(renderer.mesh, mesh)
print(f"Chamfer distance: {chamfer_distance}")
print()
cd_ += chamfer_distance

In [None]:
# Show one predicted view. permute(1, 2, 0) turns the channels-first image
# back into the (H, W, C) layout imshow expects.
hwc_image = imgs[-14].permute(1, 2, 0)
plt.imshow(hwc_image.cpu().numpy().clip(0, 1))
plt.axis("off")
plt.show()

## Dynamic Chair Deformation

In [None]:
# One untextured renderer per chair deformation frame (1.ply, 2.ply).
Rs = [
    Renderer(
        100,
        256,
        glctx=dr.RasterizeGLContext(),
        fname=f"data/Chair_Deform/{frame_idx + 1}.ply",
        scale=1.7,
        with_texture=False,
    )
    for frame_idx in range(2)
]

In [None]:
# Restore the chair-deformation SDF network (eval mode).
ckpt_path = "logs/1_CHAIR_DEFORMATION/model_1.pth"
sdf_head = EncoderPlusSDF().cuda()
sdf_head.load_state_dict(torch.load(ckpt_path))
sdf_head.eval()

In [None]:
from torchmetrics.image import ssim, lpip
from utils import compute_trimesh_chamfer
import trimesh

# Evaluate the reconstruction at every captured frame. For each frame we
# compute image metrics (MSE, PSNR, SSIM, LPIPS) over all views of that
# frame's renderer, plus the Chamfer distance between the extracted mesh and
# the frame's reference mesh. Values are printed per frame, then averaged.
num_frames = len(Rs)

ssim_metric = ssim.StructuralSimilarityIndexMeasure(
    data_range=1.0, reduction="elementwise_mean"
).cuda()
# normalize=True because the images compared here are in [0, 1]. With the
# default (normalize=False), torchmetrics assumes [-1, 1] inputs and scores
# the wrong intensity range.
lpips_metric = lpip.LearnedPerceptualImagePatchSimilarity(
    net_type="vgg", normalize=True
).cuda()

mse_ = 0
psnr_ = 0
ssim_ = 0
lpips_ = 0
cd_ = 0
plot_images = []
plot_targets = []

for t in range(num_frames):
    renderer = Rs[t]
    # This scene's frames sit at integer model times 0, 1, ...
    time_step = t
    vertices_np, faces_np = sdf_head.get_zero_points(time_step, 200)
    vertices = torch.from_numpy(vertices_np).float().cuda()
    faces = torch.from_numpy(faces_np.copy()).int().cuda()
    face_normals = compute_face_normals(vertices, faces)
    vertex_normals = compute_vertex_normals(vertices, faces, face_normals)
    imgs = renderer.render(vertices, faces, vertex_normals)[..., :3]
    target_imgs = renderer.target_imgs[..., :3]

    mse = torch.nn.MSELoss()(imgs, target_imgs)
    psnr = -10 * torch.log10(mse)  # assumes a peak signal value of 1.0
    mse_ += mse.item()
    psnr_ += psnr.item()
    print(f"MSE: {mse.item()}")
    print(f"PSNR: {psnr.item()}")

    plot_images.append(imgs[-2].cpu().numpy().clip(0, 1))
    plot_targets.append(target_imgs[-2].cpu().numpy().clip(0, 1))

    # torchmetrics image metrics take channels-first (N, C, H, W) batches.
    imgs = imgs.permute(0, 3, 1, 2)
    target_imgs = target_imgs.permute(0, 3, 1, 2)
    ssim_score = ssim_metric(imgs, target_imgs)
    print(f"SSIM: {ssim_score.item()}")
    # With normalize=True, LPIPS rejects values outside [0, 1], so clamp any
    # render overshoot.
    lpip_score = lpips_metric(imgs.clamp(0, 1), target_imgs.clamp(0, 1))
    print(f"LPIPS: {lpip_score.item()}")
    ssim_ += ssim_score.item()
    lpips_ += lpip_score.item()

    mesh = trimesh.Trimesh(vertices_np, faces_np)
    chamfer_distance = compute_trimesh_chamfer(renderer.mesh, mesh)
    print(f"Chamfer distance: {chamfer_distance}")
    print()
    cd_ += chamfer_distance

print(f"Average MSE: {mse_ / num_frames}")
print(f"Average PSNR: {psnr_ / num_frames}")
print(f"Average SSIM: {ssim_ / num_frames}")
print(f"Average LPIPS: {lpips_ / num_frames}")
print(f"Average Chamfer distance: {cd_ / num_frames}")

In [None]:
# Render the chair at interpolated (0.25, 0.5) and extrapolated (1.5) times.
# Times before 1.0 use the first frame's cameras; the rest use the second's.
plot_images = []

for t in [0.0, 0.25, 0.5, 1.0, 1.5]:
    renderer = Rs[1] if t >= 1.0 else Rs[0]
    time_step = t
    vertices_np, faces_np = sdf_head.get_zero_points(time_step, 200)
    vertices = torch.from_numpy(vertices_np).float().cuda()
    faces = torch.from_numpy(faces_np.copy()).int().cuda()
    vertex_normals = compute_vertex_normals(
        vertices, faces, compute_face_normals(vertices, faces)
    )
    imgs = renderer.render(vertices, faces, vertex_normals)[..., :3]
    preview = imgs[-2].cpu().numpy().clip(0, 1)
    plot_images.append(preview)
    plt.imshow(preview)
    plt.show()

In [None]:
# Top row: predicted renders at each time in `ts`. Bottom row: ground truth,
# drawn only under the columns whose time matches a captured frame (t = 0.0
# and t = 1.0). The remaining bottom cells are left blank.
ts = [0.0, 0.25, 0.5, 1.0, 1.5]
fig, ax = plt.subplots(2, len(ts), figsize=(30, 10), dpi=300)

for i, (img, t_value) in enumerate(zip(plot_images, ts)):
    ax[0, i].imshow(img)
    ax[0, i].axis("off")
    ax[0, i].set_title(f"t = {t_value}", fontsize=12)

for i in range(len(ts)):
    ax[1, i].axis("off")

# Look up each GT column from its time value. t = 1.0 is column 3;
# column 4 is the extrapolated t = 1.5, which has no ground truth.
gt_col_t0 = ts.index(0.0)
gt_col_t1 = ts.index(1.0)
ax[1, gt_col_t0].imshow(plot_targets[0])
ax[1, gt_col_t0].set_title("GT at 0.0", fontsize=12)
ax[1, gt_col_t1].imshow(plot_targets[1])
ax[1, gt_col_t1].set_title("GT at 1.0", fontsize=12)

plt.tight_layout()
plt.show()

## Static Bunny With GS

In [None]:
# Untextured renderer for the colored bunny mesh (used for geometry metrics).
R_Silhouette = Renderer(
    50,
    512,
    fname="data/Colored_Bunny/1.obj",
    with_texture=False,
    scale=1.7,
    glctx=dr.RasterizeGLContext(),
)

In [None]:
# Textured renderer for the same bunny mesh (used for appearance metrics).
R_Colored = Renderer(
    50,
    512,
    fname="data/Colored_Bunny/1.obj",
    with_texture=True,
    scale=1.7,
    glctx=dr.RasterizeGLContext(),
)

In [None]:
# Restore both networks of the colored-bunny GS run: the SDF geometry network
# and the render head. Both are switched to eval mode.
run_dir = "logs/1_SINGLEFRAME_STATIC_COLORED_BUNNY_WITH_GS"

sdf_head = EncoderPlusSDF().cuda()
sdf_head.load_state_dict(torch.load(f"{run_dir}/model.pth"))
sdf_head.eval()

render_head = RenderHead().cuda()
render_head.load_state_dict(torch.load(f"{run_dir}/render_head.pth"))
render_head.eval()

In [None]:
from torchmetrics.image import ssim, lpip
from utils import compute_trimesh_chamfer
import trimesh

# Geometry evaluation of the colored-bunny SDF using the untextured renderer.
# We compute image metrics (MSE, PSNR, SSIM, LPIPS) over all views, plus the
# Chamfer distance to the reference mesh. At the end, `imgs` holds the
# channels-first (N, C, H, W) predicted renders.
ssim_metric = ssim.StructuralSimilarityIndexMeasure(
    data_range=1.0, reduction="elementwise_mean"
).cuda()
# normalize=True because the images compared here are in [0, 1]. With the
# default (normalize=False), torchmetrics assumes [-1, 1] inputs and scores
# the wrong intensity range.
lpips_metric = lpip.LearnedPerceptualImagePatchSimilarity(
    net_type="vgg", normalize=True
).cuda()

mse_ = 0
psnr_ = 0
ssim_ = 0
lpips_ = 0
cd_ = 0
plot_images = []
plot_targets = []

renderer = R_Silhouette
time_step = 0
vertices_np, faces_np = sdf_head.get_zero_points(time_step, 200)
vertices = torch.from_numpy(vertices_np).float().cuda()
faces = torch.from_numpy(faces_np.copy()).int().cuda()
face_normals = compute_face_normals(vertices, faces)
vertex_normals = compute_vertex_normals(vertices, faces, face_normals)
imgs = renderer.render(vertices, faces, vertex_normals)[..., :3]
target_imgs = renderer.target_imgs[..., :3]

mse = torch.nn.MSELoss()(imgs, target_imgs)
psnr = -10 * torch.log10(mse)  # assumes a peak signal value of 1.0
mse_ += mse.item()
psnr_ += psnr.item()
print(f"MSE: {mse.item()}")
print(f"PSNR: {psnr.item()}")

plot_images.append(imgs[-2].cpu().numpy().clip(0, 1))
plot_targets.append(target_imgs[-2].cpu().numpy().clip(0, 1))

# torchmetrics image metrics take channels-first (N, C, H, W) batches.
imgs = imgs.permute(0, 3, 1, 2)
target_imgs = target_imgs.permute(0, 3, 1, 2)
ssim_score = ssim_metric(imgs, target_imgs)
print(f"SSIM: {ssim_score.item()}")
# With normalize=True, LPIPS rejects values outside [0, 1], so clamp any
# render overshoot.
lpip_score = lpips_metric(imgs.clamp(0, 1), target_imgs.clamp(0, 1))
print(f"LPIPS: {lpip_score.item()}")
ssim_ += ssim_score.item()
lpips_ += lpip_score.item()

mesh = trimesh.Trimesh(vertices_np, faces_np)
chamfer_distance = compute_trimesh_chamfer(renderer.mesh, mesh)
print(f"Chamfer distance: {chamfer_distance}")
print()
cd_ += chamfer_distance

In [None]:
# Show predicted view 15. permute(1, 2, 0) turns the channels-first image
# back into the (H, W, C) layout imshow expects.
hwc_image = imgs[15].permute(1, 2, 0)
plt.imshow(hwc_image.cpu().numpy().clip(0, 1))
plt.axis("off")
plt.show()

In [None]:
from rendering_gaussian import render_with_seperate_args, Camera

# Appearance evaluation of the render head on the static colored bunny. For
# each view:
#   1. evaluate the head's degree-3 spherical-harmonic coefficients along the
#      camera-to-vertex direction to get per-vertex colours;
#   2. rasterize the extracted mesh with those colours;
#   3. compare the result against the textured ground-truth image.
num_views = 50
image_size = 512
renderer = R_Colored

time_step = 0
vertices_np, faces_np = sdf_head.get_zero_points(time_step, 200)
vertices = torch.from_numpy(vertices_np).float().cuda()
faces = torch.from_numpy(faces_np.copy()).int().cuda()
face_normals = compute_face_normals(vertices, faces)
vertex_normals = compute_vertex_normals(vertices, faces, face_normals)

# Gaussian-splatting cameras for every view. The mesh-based colour render
# below does not consume them.
cameras = [
    Camera(
        renderer.fov_x,
        renderer.fov_y,
        image_size,
        image_size,
        renderer.view_mats[j],
        renderer.mvps[j],
        renderer.view_mats[j][:3, 3],
    )
    for j in range(num_views)
]

mse_ = 0.0
psnr_ = 0.0

# Pure inference. Without no_grad(), each accumulated loss kept the render
# head's autograd graph alive for every view, and the averages were printed
# as grad-carrying CUDA tensors.
with torch.no_grad():
    # Append a zero 4th coordinate to each vertex before querying the head
    # (presumably the time input; this scene is static) — confirm.
    vert = torch.cat([vertices, torch.zeros_like(vertices[..., :1])], dim=-1)
    sh_coeffs, opacity, scaling, rotation = render_head(vert)
    # Loop-invariant: (V, 3, (3 + 1) ** 2) coefficients for degree-3 SH.
    shs_view = sh_coeffs.transpose(1, 2).view(-1, 3, (3 + 1) ** 2)

    for VIEW_IDX in range(num_views):
        dir_pp = vertices[..., :3] - renderer.camera_positions[VIEW_IDX].repeat(
            sh_coeffs.shape[0], 1
        )
        dir_pp_normalized = dir_pp / dir_pp.norm(dim=1, keepdim=True)
        sh2rgb = eval_sh(3, shs_view, dir_pp_normalized)
        colors_precomp = torch.clamp_min(sh2rgb + 0.5, 0.0)
        est = renderer.render_coloured(
            vertices[..., :3],
            faces,
            vertex_normals,
            vertex_colors=colors_precomp,
            view_idx=VIEW_IDX,
            albedo=1.0,
        )[0][..., :3]
        mse = torch.nn.MSELoss()(est, renderer.target_imgs[VIEW_IDX][..., :3])
        psnr = -10 * torch.log10(mse)  # assumes a peak signal value of 1.0
        mse_ += mse.item()
        psnr_ += psnr.item()
        plt.imshow(est.cpu().numpy().clip(0, 1))
        plt.axis("off")
        plt.show()

print(f"Average MSE: {mse_ / num_views}")
print(f"Average PSNR: {psnr_ / num_views}")

## Dynamic Chair with GS

In [None]:
# One textured renderer per chair deformation frame (1.ply, 2.ply).
Rs = [
    Renderer(
        100,
        256,
        glctx=dr.RasterizeGLContext(),
        fname=f"data/Chair_Deform/{frame_idx + 1}.ply",
        scale=1.7,
        with_texture=True,
    )
    for frame_idx in range(2)
]

In [None]:
# Restore both networks of the chair-deformation GS run: the SDF geometry
# network and the render head. Both are switched to eval mode.
run_dir = "logs/1_CHAIR_DEFORMATION_WITH_GS"

sdf_head = EncoderPlusSDF().cuda()
sdf_head.load_state_dict(torch.load(f"{run_dir}/model.pth"))
sdf_head.eval()

render_head = RenderHead().cuda()
render_head.load_state_dict(torch.load(f"{run_dir}/render_head.pth"))
render_head.eval()

In [None]:
from rendering_gaussian import Camera

# Appearance evaluation of the render head on each chair frame. For every view
# of a frame:
#   1. evaluate the head's degree-3 spherical-harmonic coefficients along the
#      camera-to-vertex direction to get per-vertex colours;
#   2. rasterize that frame's extracted mesh with them;
#   3. compare the result against the textured ground truth.
# Per-frame averages are printed, followed by the mean over frames.
# Rendered views are kept in plots_1 / plots_2 / plots_3 (frame 0 / 1 / rest).
plots_1 = []
plots_2 = []
plots_3 = []

num_frames = len(Rs)
num_views = 100
image_size = 256

total_mse = 0.0
total_psnr = 0.0

for i, R_Colored in enumerate(Rs):
    renderer = R_Colored

    time_step = i
    vertices_np, faces_np = sdf_head.get_zero_points(time_step, 200)
    vertices = torch.from_numpy(vertices_np).float().cuda()
    faces = torch.from_numpy(faces_np.copy()).int().cuda()
    face_normals = compute_face_normals(vertices, faces)
    vertex_normals = compute_vertex_normals(vertices, faces, face_normals)

    # Gaussian-splatting cameras for every view. The mesh-based colour render
    # below does not consume them.
    cameras = [
        Camera(
            renderer.fov_x,
            renderer.fov_y,
            image_size,
            image_size,
            renderer.view_mats[j],
            renderer.mvps[j],
            renderer.view_mats[j][:3, 3],
        )
        for j in range(num_views)
    ]

    frame_plots = plots_1 if i == 0 else plots_2 if i == 1 else plots_3
    mse_ = 0.0
    psnr_ = 0.0

    # Pure inference. Without no_grad(), each accumulated loss kept the render
    # head's autograd graph alive across all views.
    with torch.no_grad():
        # NOTE(review): the 4th input coordinate is always 0, even though this
        # frame is at time_step = i. If it is the time input, it should
        # presumably be time_step — confirm against training.
        vert = torch.cat([vertices, torch.zeros_like(vertices[..., :1])], dim=-1)
        sh_coeffs, opacity, scaling, rotation = render_head(vert)
        # Loop-invariant: (V, 3, (3 + 1) ** 2) coefficients for degree-3 SH.
        shs_view = sh_coeffs.transpose(1, 2).view(-1, 3, (3 + 1) ** 2)

        for VIEW_IDX in range(num_views):
            dir_pp = vertices[..., :3] - renderer.camera_positions[VIEW_IDX].repeat(
                sh_coeffs.shape[0], 1
            )
            dir_pp_normalized = dir_pp / dir_pp.norm(dim=1, keepdim=True)
            sh2rgb = eval_sh(3, shs_view, dir_pp_normalized)
            colors_precomp = torch.clamp_min(sh2rgb + 0.5, 0.0)
            est = renderer.render_coloured(
                vertices[..., :3],
                faces,
                vertex_normals,
                vertex_colors=colors_precomp,
                view_idx=VIEW_IDX,
                albedo=1.0,
            )[0][..., :3]
            mse = torch.nn.MSELoss()(est, renderer.target_imgs[VIEW_IDX][..., :3])
            psnr = -10 * torch.log10(mse)  # assumes a peak signal value of 1.0
            mse_ += mse.item()
            psnr_ += psnr.item()
            frame_plots.append(est.cpu().numpy().clip(0, 1))

    print(f"Average MSE: {mse_ / num_views}")
    print(f"Average PSNR: {psnr_ / num_views}")

    # Mean over frames of the per-frame averages.
    total_mse += mse_ / num_views / num_frames
    total_psnr += psnr_ / num_views / num_frames

print(f"Total Average MSE: {total_mse}")
print(f"Total Average PSNR: {total_psnr}")

In [None]:
# Predicted colour render of view 56 for the first chair frame.
sample_render = plots_1[56]
plt.imshow(sample_render)
plt.axis("off")
plt.show()

In [None]:
# Predicted colour render of view 77 for the second chair frame.
sample_render = plots_2[77]
plt.imshow(sample_render)
plt.axis("off")
plt.show()

## Dynamic Eagle Statue

In [None]:
# One untextured renderer per deforming eagle statue frame (1.ply, 2.ply).
Rs = [
    Renderer(
        100,
        256,
        glctx=dr.RasterizeGLContext(),
        fname=f"data/Deforming_Eagle_Statue/{frame_idx + 1}.ply",
        scale=1.7,
        with_texture=False,
    )
    for frame_idx in range(2)
]

In [None]:
# Restore the eagle-statue deformation SDF network (eval mode).
ckpt_path = "logs/3_EAGLE_STATUE_DEFORMATION/model.pth"
sdf_head = EncoderPlusSDF().cuda()
sdf_head.load_state_dict(torch.load(ckpt_path))
sdf_head.eval()

In [None]:
from torchmetrics.image import ssim, lpip
from utils import compute_trimesh_chamfer
import trimesh

# Evaluate the reconstruction at every captured frame. For each frame we
# compute image metrics (MSE, PSNR, SSIM, LPIPS) over all views of that
# frame's renderer, plus the Chamfer distance between the extracted mesh and
# the frame's reference mesh. Values are printed per frame, then averaged.
num_frames = len(Rs)

ssim_metric = ssim.StructuralSimilarityIndexMeasure(
    data_range=1.0, reduction="elementwise_mean"
).cuda()
# normalize=True because the images compared here are in [0, 1]. With the
# default (normalize=False), torchmetrics assumes [-1, 1] inputs and scores
# the wrong intensity range.
lpips_metric = lpip.LearnedPerceptualImagePatchSimilarity(
    net_type="vgg", normalize=True
).cuda()

mse_ = 0
psnr_ = 0
ssim_ = 0
lpips_ = 0
cd_ = 0
plot_images = []
plot_targets = []

for t in range(num_frames):
    renderer = Rs[t]
    # This scene's frames sit at integer model times 0, 1, ...
    time_step = t
    vertices_np, faces_np = sdf_head.get_zero_points(time_step, 200)
    vertices = torch.from_numpy(vertices_np).float().cuda()
    faces = torch.from_numpy(faces_np.copy()).int().cuda()
    face_normals = compute_face_normals(vertices, faces)
    vertex_normals = compute_vertex_normals(vertices, faces, face_normals)
    imgs = renderer.render(vertices, faces, vertex_normals)[..., :3]
    target_imgs = renderer.target_imgs[..., :3]

    mse = torch.nn.MSELoss()(imgs, target_imgs)
    psnr = -10 * torch.log10(mse)  # assumes a peak signal value of 1.0
    mse_ += mse.item()
    psnr_ += psnr.item()
    print(f"MSE: {mse.item()}")
    print(f"PSNR: {psnr.item()}")

    plot_images.append(imgs[-3].cpu().numpy().clip(0, 1))
    plot_targets.append(target_imgs[-3].cpu().numpy().clip(0, 1))

    # torchmetrics image metrics take channels-first (N, C, H, W) batches.
    imgs = imgs.permute(0, 3, 1, 2)
    target_imgs = target_imgs.permute(0, 3, 1, 2)
    ssim_score = ssim_metric(imgs, target_imgs)
    print(f"SSIM: {ssim_score.item()}")
    # With normalize=True, LPIPS rejects values outside [0, 1], so clamp any
    # render overshoot.
    lpip_score = lpips_metric(imgs.clamp(0, 1), target_imgs.clamp(0, 1))
    print(f"LPIPS: {lpip_score.item()}")
    ssim_ += ssim_score.item()
    lpips_ += lpip_score.item()

    mesh = trimesh.Trimesh(vertices_np, faces_np)
    chamfer_distance = compute_trimesh_chamfer(renderer.mesh, mesh)
    print(f"Chamfer distance: {chamfer_distance}")
    print()
    cd_ += chamfer_distance

print(f"Average MSE: {mse_ / num_frames}")
print(f"Average PSNR: {psnr_ / num_frames}")
print(f"Average SSIM: {ssim_ / num_frames}")
print(f"Average LPIPS: {lpips_ / num_frames}")
print(f"Average Chamfer distance: {cd_ / num_frames}")

In [None]:
# Render the eagle at t = 0, 1 and the extrapolated t = 2, always through the
# first frame's cameras, and show the last view of each.
plot_images = []
renderer = Rs[0]

for t in [0.0, 1.0, 2.0]:
    time_step = t
    vertices_np, faces_np = sdf_head.get_zero_points(time_step, 200)
    vertices = torch.from_numpy(vertices_np).float().cuda()
    faces = torch.from_numpy(faces_np.copy()).int().cuda()
    vertex_normals = compute_vertex_normals(
        vertices, faces, compute_face_normals(vertices, faces)
    )
    imgs = renderer.render(vertices, faces, vertex_normals)[..., :3]
    last_view = imgs[-1].cpu().numpy().clip(0, 1)
    plot_images.append(last_view)
    plt.imshow(last_view)
    plt.show()

## Dynamic SMPL2 with GS

In [None]:
# One textured renderer per frame of the second SMPL scene (1.obj .. 3.obj).
Rs = [
    Renderer(
        100,
        256,
        glctx=dr.RasterizeGLContext(),
        fname=f"data/SMPL/Colored_2/{frame_idx + 1}.obj",
        scale=1.7,
        with_texture=True,
    )
    for frame_idx in range(3)
]

In [None]:
# Restore both networks of the second-SMPL-scene GS run (checkpoint set 2):
# the SDF geometry network and the render head, both in eval mode.
run_dir = "logs/1_MULTIFRAME_SMPL_SECOND_SCENE_WITH_GS"

sdf_head = EncoderPlusSDF().cuda()
sdf_head.load_state_dict(torch.load(f"{run_dir}/model_2.pth"))
sdf_head.eval()

render_head = RenderHead().cuda()
render_head.load_state_dict(torch.load(f"{run_dir}/render_head_2.pth"))
render_head.eval()

In [None]:
# Evaluate the SMPL2 model: for every frame, extract the SDF zero level set,
# colour its vertices from the render head's SH coefficients, render every view
# and report MSE / PSNR against the renderer's target images.
NUM_VIEWS = 100  # presumably the view count passed as Renderer's first argument
SH_DEGREE = 3

plots_1 = []
plots_2 = []
plots_3 = []
# Renders of frame 0 go to plots_1, frame 1 to plots_2, later frames to plots_3.
frame_plots = [plots_1, plots_2, plots_3]

total_mse = 0.0
total_psnr = 0.0

for i, renderer in enumerate(Rs):
    # NOTE(review): frames are assumed to be spread uniformly over normalised
    # time [0, 1], as in the Breaking Sphere evaluation (t / 4 for 5 frames).
    # A fixed time of 0 would compare the frame-0 reconstruction against every
    # frame's targets. Confirm this matches the training schedule.
    time_step = i / max(len(Rs) - 1, 1)
    vertices_np, faces_np = sdf_head.get_zero_points(time_step, 200)
    vertices = torch.from_numpy(vertices_np).float().cuda()
    faces = torch.from_numpy(faces_np.copy()).int().cuda()
    face_normals = compute_face_normals(vertices, faces)
    vertex_normals = compute_vertex_normals(vertices, faces, face_normals)

    frame_mse = 0.0
    frame_psnr = 0.0
    target_plots = frame_plots[min(i, len(frame_plots) - 1)]

    # Inference only: no autograd graphs are needed for the metrics.
    with torch.no_grad():
        # Render-head input is (x, y, z, t) per vertex, as in the Initializations cell.
        vert = torch.cat(
            [vertices, torch.full_like(vertices[..., :1], time_step)], dim=-1
        )
        sh_coeffs, _, _, _ = render_head(vert)
        # The SH layout does not depend on the view, so reshape once per frame.
        shs_view = sh_coeffs.transpose(1, 2).view(-1, 3, (SH_DEGREE + 1) ** 2)

        for VIEW_IDX in range(NUM_VIEWS):
            dir_pp = vertices[..., :3] - renderer.camera_positions[VIEW_IDX].repeat(
                sh_coeffs.shape[0], 1
            )
            dir_pp_normalized = dir_pp / dir_pp.norm(dim=1, keepdim=True)
            sh2rgb = eval_sh(SH_DEGREE, shs_view, dir_pp_normalized)
            colors_precomp = torch.clamp_min(sh2rgb + 0.5, 0.0)
            est = renderer.render_coloured(
                vertices[..., :3],
                faces,
                vertex_normals,
                vertex_colors=colors_precomp,
                view_idx=VIEW_IDX,
                albedo=1.0,
            )[0][..., :3]
            mse = torch.nn.functional.mse_loss(
                est, renderer.target_imgs[VIEW_IDX][..., :3]
            )
            # .item() accumulates plain floats instead of CUDA tensors.
            frame_mse += mse.item()
            frame_psnr += (-10 * torch.log10(mse)).item()
            target_plots.append(est.cpu().numpy().clip(0, 1))

    print(f"Average MSE: {frame_mse / NUM_VIEWS}")
    print(f"Average PSNR: {frame_psnr / NUM_VIEWS}")

    total_mse += frame_mse / NUM_VIEWS / len(Rs)
    total_psnr += frame_psnr / NUM_VIEWS / len(Rs)

print(f"Total Average MSE: {total_mse}")
print(f"Total Average PSNR: {total_psnr}")

In [None]:
# Display view index 2 of the first SMPL2 frame, without axes.
plt.imshow(plots_1[2])
plt.gca().set_axis_off()
plt.show()

In [None]:
# Step through every rendered view of the second SMPL2 frame.
for im in plots_2:
    plt.imshow(im)
    plt.gca().set_axis_off()
    plt.show()

In [None]:
# Display view index 65 of the third SMPL2 frame, without axes.
plt.imshow(plots_3[65])
plt.gca().set_axis_off()
plt.show()

## Dynamic SMPL1 with GS

In [None]:
# One textured renderer per frame (1.obj .. 3.obj) of the SMPL "Colored_1"
# sequence; the positional 100 / 256 match the view count and image size used
# by the evaluation cell below.
Rs = [
    Renderer(
        100,
        256,
        glctx=dr.RasterizeGLContext(),
        fname=f"data/SMPL/Colored_1/{frame}.obj",
        scale=1.7,
        with_texture=True,
    )
    for frame in range(1, 4)
]

In [None]:
# Restore the SMPL1 geometry network (SDF head) from its checkpoint and put it
# in inference mode (train(False) is what eval() does).
sdf_head = EncoderPlusSDF().to("cuda")
sdf_head.load_state_dict(
    torch.load("logs/3_MULTIFRAME_SMPL_FIRST_SCENE_WITH_GS/model.pth")
)
sdf_head.train(False)

# Restore the matching Gaussian render head the same way.
render_head = RenderHead().to("cuda")
render_head.load_state_dict(
    torch.load("logs/3_MULTIFRAME_SMPL_FIRST_SCENE_WITH_GS/render_head.pth")
)
render_head.train(False)

In [None]:
# Evaluate the SMPL1 model: for every frame, extract the SDF zero level set,
# colour its vertices from the render head's SH coefficients, render every view
# and report MSE / PSNR against the renderer's target images.
NUM_VIEWS = 100  # presumably the view count passed as Renderer's first argument
SH_DEGREE = 3

plots_1 = []
plots_2 = []
plots_3 = []
# Renders of frame 0 go to plots_1, frame 1 to plots_2, later frames to plots_3.
frame_plots = [plots_1, plots_2, plots_3]

total_mse = 0.0
total_psnr = 0.0

for i, renderer in enumerate(Rs):
    # NOTE(review): frames are assumed to be spread uniformly over normalised
    # time [0, 1], as in the Breaking Sphere evaluation (t / 4 for 5 frames).
    # A fixed time of 0 would compare the frame-0 reconstruction against every
    # frame's targets. Confirm this matches the training schedule.
    time_step = i / max(len(Rs) - 1, 1)
    vertices_np, faces_np = sdf_head.get_zero_points(time_step, 200)
    vertices = torch.from_numpy(vertices_np).float().cuda()
    faces = torch.from_numpy(faces_np.copy()).int().cuda()
    face_normals = compute_face_normals(vertices, faces)
    vertex_normals = compute_vertex_normals(vertices, faces, face_normals)

    frame_mse = 0.0
    frame_psnr = 0.0
    target_plots = frame_plots[min(i, len(frame_plots) - 1)]

    # Inference only: no autograd graphs are needed for the metrics.
    with torch.no_grad():
        # Render-head input is (x, y, z, t) per vertex, as in the Initializations cell.
        vert = torch.cat(
            [vertices, torch.full_like(vertices[..., :1], time_step)], dim=-1
        )
        sh_coeffs, _, _, _ = render_head(vert)
        # The SH layout does not depend on the view, so reshape once per frame.
        shs_view = sh_coeffs.transpose(1, 2).view(-1, 3, (SH_DEGREE + 1) ** 2)

        for VIEW_IDX in range(NUM_VIEWS):
            dir_pp = vertices[..., :3] - renderer.camera_positions[VIEW_IDX].repeat(
                sh_coeffs.shape[0], 1
            )
            dir_pp_normalized = dir_pp / dir_pp.norm(dim=1, keepdim=True)
            sh2rgb = eval_sh(SH_DEGREE, shs_view, dir_pp_normalized)
            colors_precomp = torch.clamp_min(sh2rgb + 0.5, 0.0)
            est = renderer.render_coloured(
                vertices[..., :3],
                faces,
                vertex_normals,
                vertex_colors=colors_precomp,
                view_idx=VIEW_IDX,
                albedo=1.0,
            )[0][..., :3]
            mse = torch.nn.functional.mse_loss(
                est, renderer.target_imgs[VIEW_IDX][..., :3]
            )
            # .item() accumulates plain floats instead of CUDA tensors.
            frame_mse += mse.item()
            frame_psnr += (-10 * torch.log10(mse)).item()
            target_plots.append(est.cpu().numpy().clip(0, 1))

    print(f"Average MSE: {frame_mse / NUM_VIEWS}")
    print(f"Average PSNR: {frame_psnr / NUM_VIEWS}")

    # Average over the actual number of frames rather than a hard-coded 3.
    total_mse += frame_mse / NUM_VIEWS / len(Rs)
    total_psnr += frame_psnr / NUM_VIEWS / len(Rs)

print(f"Total Average MSE: {total_mse}")
print(f"Total Average PSNR: {total_psnr}")

In [None]:
# Display view index 42 of the first SMPL1 frame, without axes.
plt.imshow(plots_1[42])
plt.gca().set_axis_off()
plt.show()

In [None]:
# Step through every rendered view of the second SMPL1 frame.
for im in plots_2:
    plt.imshow(im)
    plt.gca().set_axis_off()
    plt.show()

In [None]:
# Display view index 66 of the third SMPL1 frame, without axes.
plt.imshow(plots_3[66])
plt.gca().set_axis_off()
plt.show()

## Dynamic SMPL3 with GS

In [None]:
# One textured renderer per frame (1.obj .. 3.obj) of the SMPL3 sequence.
# The SMPL1 and SMPL2 cells read data/SMPL/Colored_1 and Colored_2, so the
# third scene reads Colored_3; pointing it at Colored_1 would evaluate the
# SMPL3 checkpoints against the SMPL1 targets.
# NOTE(review): confirm data/SMPL/Colored_3 is the intended directory.
Rs = []

for i in range(3):
    R_Colored = Renderer(
        100,
        256,
        glctx=dr.RasterizeGLContext(),
        fname=f"data/SMPL/Colored_3/{i+1}.obj",
        scale=1.7,
        with_texture=True,
    )
    Rs.append(R_Colored)


In [None]:
# Restore the SMPL3 geometry network (SDF head) from its checkpoint and put it
# in inference mode (train(False) is what eval() does).
sdf_head = EncoderPlusSDF().to("cuda")
sdf_head.load_state_dict(
    torch.load("logs/1_MULTIFRAME_SMPL_THIRD_SCENE_WITH_GS/model.pth")
)
sdf_head.train(False)

# Restore the matching Gaussian render head the same way.
render_head = RenderHead().to("cuda")
render_head.load_state_dict(
    torch.load("logs/1_MULTIFRAME_SMPL_THIRD_SCENE_WITH_GS/render_head.pth")
)
render_head.train(False)

In [None]:
# Evaluate the SMPL3 model: for every frame, extract the SDF zero level set,
# colour its vertices from the render head's SH coefficients, render every view
# and report MSE / PSNR against the renderer's target images.
NUM_VIEWS = 100  # presumably the view count passed as Renderer's first argument
SH_DEGREE = 3

plots_1 = []
plots_2 = []
plots_3 = []
# Renders of frame 0 go to plots_1, frame 1 to plots_2, later frames to plots_3.
frame_plots = [plots_1, plots_2, plots_3]

total_mse = 0.0
total_psnr = 0.0

for i, renderer in enumerate(Rs):
    # NOTE(review): frames are assumed to be spread uniformly over normalised
    # time [0, 1], as in the Breaking Sphere evaluation (t / 4 for 5 frames).
    # A fixed time of 0 would compare the frame-0 reconstruction against every
    # frame's targets. Confirm this matches the training schedule.
    time_step = i / max(len(Rs) - 1, 1)
    vertices_np, faces_np = sdf_head.get_zero_points(time_step, 200)
    vertices = torch.from_numpy(vertices_np).float().cuda()
    faces = torch.from_numpy(faces_np.copy()).int().cuda()
    face_normals = compute_face_normals(vertices, faces)
    vertex_normals = compute_vertex_normals(vertices, faces, face_normals)

    frame_mse = 0.0
    frame_psnr = 0.0
    target_plots = frame_plots[min(i, len(frame_plots) - 1)]

    # Inference only: no autograd graphs are needed for the metrics.
    with torch.no_grad():
        # Render-head input is (x, y, z, t) per vertex, as in the Initializations cell.
        vert = torch.cat(
            [vertices, torch.full_like(vertices[..., :1], time_step)], dim=-1
        )
        sh_coeffs, _, _, _ = render_head(vert)
        # The SH layout does not depend on the view, so reshape once per frame.
        shs_view = sh_coeffs.transpose(1, 2).view(-1, 3, (SH_DEGREE + 1) ** 2)

        for VIEW_IDX in range(NUM_VIEWS):
            dir_pp = vertices[..., :3] - renderer.camera_positions[VIEW_IDX].repeat(
                sh_coeffs.shape[0], 1
            )
            dir_pp_normalized = dir_pp / dir_pp.norm(dim=1, keepdim=True)
            sh2rgb = eval_sh(SH_DEGREE, shs_view, dir_pp_normalized)
            colors_precomp = torch.clamp_min(sh2rgb + 0.5, 0.0)
            est = renderer.render_coloured(
                vertices[..., :3],
                faces,
                vertex_normals,
                vertex_colors=colors_precomp,
                view_idx=VIEW_IDX,
                albedo=1.0,
            )[0][..., :3]
            mse = torch.nn.functional.mse_loss(
                est, renderer.target_imgs[VIEW_IDX][..., :3]
            )
            # .item() accumulates plain floats instead of CUDA tensors.
            frame_mse += mse.item()
            frame_psnr += (-10 * torch.log10(mse)).item()
            target_plots.append(est.cpu().numpy().clip(0, 1))

    print(f"Average MSE: {frame_mse / NUM_VIEWS}")
    print(f"Average PSNR: {frame_psnr / NUM_VIEWS}")

    # Average over the actual number of frames rather than a hard-coded 3.
    total_mse += frame_mse / NUM_VIEWS / len(Rs)
    total_psnr += frame_psnr / NUM_VIEWS / len(Rs)

print(f"Total Average MSE: {total_mse}")
print(f"Total Average PSNR: {total_psnr}")

In [None]:
# Step through every rendered view stored for the third SMPL3 frame.
for im in plots_3:
    plt.imshow(im)
    plt.gca().set_axis_off()
    plt.show()

## Initializations

In [None]:
# Single-view renderer for the bunny mesh; the positional 1 / 1024 are
# presumably the view count and image size (Camera below uses 1024 x 1024).
R = Renderer(
    1,
    1024,
    glctx=dr.RasterizeGLContext(),
    with_texture=False,
    scale=1.7,
    fname="data/Colored_Bunny/1.obj",
)

# Restore the SDF network and switch it to inference mode.
sdf_head = EncoderPlusSDF().to("cuda")
sdf_head.load_state_dict(torch.load("checkpoints/model.pth"))
sdf_head.train(False)

In [None]:
# Camera for the renderer's first (only) view, used by the Gaussian splatting
# call below. The argument order follows the other Camera(...) calls in this
# notebook: fov_x, fov_y, width, height, view matrix, MVP, camera position.
cam = Camera(
    R.fov_x,
    R.fov_y,  # vertical FOV; passing fov_x here only works for square frusta
    1024,
    1024,
    R.view_mats[0],
    R.mvps[0],
    R.camera_positions[0],
)

In [None]:
# Mesh the SDF at t = 0.0, 0.5 and 1.0 (i.e. 0/10, 5/10, 10/10) and keep the
# last image of each render, clipped to [0, 1].
plot_imgs = []

for t in (0.0, 0.5, 1.0):
    vertices_np, faces_np = sdf_head.get_zero_points(t, 200)
    vertices = torch.from_numpy(vertices_np).float().cuda()
    faces = torch.from_numpy(faces_np.copy()).int().cuda()
    face_normals = compute_face_normals(vertices, faces)
    vertex_normals = compute_vertex_normals(vertices, faces, face_normals)
    imgs = R.render(vertices, faces, vertex_normals)[..., :3]
    last_view = imgs[-1]
    plot_imgs.append(last_view.cpu().numpy().clip(0, 1))

In [None]:
# Show the three SDF renders from the previous cell. plot_imgs holds one image
# per sampled time (t = 0.0, 0.5, 1.0), so the valid indices are 0, 1 and 2;
# indexing 5 or 10 raises IndexError.
for idx, title in enumerate(["t = 0.0", "t = 0.5", "t = 1.0"]):
    plt.imshow(plot_imgs[idx])
    plt.axis("off")
    plt.title(title)
    plt.show()

In [None]:
# Restore the initial render head and switch it to inference mode.
render_head = RenderHead().to("cuda")
render_head.load_state_dict(torch.load("checkpoints/render_init.pth"))
render_head.train(False)

In [None]:
# For t = 0.0, 0.5, 1.0 render the bunny two ways from view 0:
#   * plot_imgs   - the SDF mesh coloured by evaluating the render head's SH
#                   coefficients along the camera-to-vertex direction;
#   * plot_splats - the same vertices rasterised as Gaussian splats.
# NOTE(review): the splat output's channel layout is not visible here; confirm
# it is HWC before passing it to imshow.
plot_imgs = []
plot_splats = []

for t in (0.0, 0.5, 1.0):
    vertices_np, faces_np = sdf_head.get_zero_points(t, 200)
    vertices = torch.from_numpy(vertices_np).float().cuda()
    faces = torch.from_numpy(faces_np.copy()).int().cuda()
    face_normals = compute_face_normals(vertices, faces)
    vertex_normals = compute_vertex_normals(vertices, faces, face_normals)

    # Render-head input: vertex position with the time value appended.
    vert = torch.cat([vertices, torch.full_like(vertices[..., :1], t)], dim=-1)
    sh_coeffs, opacity, scaling, rotation = render_head(vert.detach())

    # Degree-3 SH -> per-vertex RGB as seen from camera 0.
    shs_view = sh_coeffs.transpose(1, 2).view(-1, 3, 16)
    dir_pp = vertices[..., :3] - R.camera_positions[0].repeat(sh_coeffs.shape[0], 1)
    dir_pp_normalized = dir_pp / dir_pp.norm(dim=1, keepdim=True)
    colors_precomp = torch.clamp_min(eval_sh(3, shs_view, dir_pp_normalized) + 0.5, 0.0)

    imgs = R.render_coloured(
        vertices[..., :3],
        faces,
        vertex_normals,
        vertex_colors=colors_precomp,
        view_idx=0,
        albedo=1.0,
    )[0][..., :3]
    plot_imgs.append(imgs.cpu().detach().numpy().clip(0, 1))

    # Gaussian splats on a black background.
    splat_out = render_with_seperate_args(
        cam,
        vertices[..., :3],
        3,
        opacity,
        scaling,
        rotation,
        sh_coeffs,
        torch.zeros(3, device="cuda"),
    )
    splats = splat_out["render"]
    plot_splats.append(splats.cpu().detach().numpy().clip(0, 1))

In [None]:
# Show the three SH-coloured renders, one figure per sampled time.
for panel, label in enumerate(("t = 0.0", "t = 0.5", "t = 1.0")):
    plt.imshow(plot_imgs[panel])
    plt.gca().set_axis_off()
    plt.gca().set_title(label)
    plt.show()