In [76]:
%load_ext autoreload
import torch
import gsplat
import time
import numpy as np

from render_splats.dataloading import load_splatfacto, load_3dgs, bmv

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [77]:
# Camera settings for nerf_synthetic/lego.
# Image size; forwarded to gsplat as img_height / img_width by the render functions.
height = width = 800
# Focal lengths; the same value is used on both axes.
fx = fy = 1111.1110311937682
# 4x4 matrix that the render functions pass to gsplat as the view matrix.
camera_pose = torch.tensor(
    [
        [ 0.6374,  0.7705, -0.,      0.    ],
        [ 0.2536, -0.2098, -0.9443, -0.    ],
        [-0.7276,  0.6019, -0.3291,  4.0311],
        [ 0.,      0.,      0.,      1.    ],
    ],
    device='cuda',
)

def render_splats_legacy(camera_pose, gs_data, resolution_factor=1):
    """Render ``gs_data`` with the legacy three-stage gsplat pipeline.

    The stages are called explicitly: ``gsplat.project_gaussians`` →
    ``gsplat.spherical_harmonics`` → ``gsplat.rasterize_gaussians``.
    Image size and focal lengths come from the notebook-level ``height``,
    ``width``, ``fx`` and ``fy``.

    Args:
        camera_pose: 4x4 matrix passed to gsplat as ``viewmat``. The
            translation column of its inverse is used as the camera centre
            for the view directions.
        gs_data: Object exposing ``means``, ``scales``, ``quaternions``,
            ``rotation``, ``features``, ``opacities`` and ``max_sh_degree``.
        resolution_factor: Downscale factor applied to the image size *and*
            to the intrinsics (focal lengths and principal point), so the
            field of view is unchanged. The default of 1 reproduces the
            full-resolution render.

    Returns:
        The image returned by ``gsplat.rasterize_gaussians``, composited over
        an all-ones background. The alpha map is computed but discarded.
    """
    block_width = 16
    img_height = round(height / resolution_factor)
    img_width = round(width / resolution_factor)

    # Only the outputs needed for rasterisation are kept; the compensation
    # factor and 3D covariances are ignored.
    means2d, depths, radii, conics, _compensation, num_tiles_hit, _cov3d = gsplat.project_gaussians(
        means3d=gs_data.means,
        scales=gs_data.scales,
        glob_scale=1.,
        quats=gs_data.quaternions,
        viewmat=camera_pose,
        # Intrinsics are scaled together with the image size; scaling only
        # the size would crop the image instead of downsampling it.
        fx=fx / resolution_factor,
        fy=fy / resolution_factor,
        cx=img_width / 2,
        cy=img_height / 2,
        img_height=img_height,
        img_width=img_width,
        block_width=block_width,
        clip_thresh=0.01,
    )

    # Unit directions from the camera centre to every Gaussian.
    camera_center = torch.linalg.inv(camera_pose)[:3, -1]
    directions = torch.nn.functional.normalize(gs_data.means - camera_center)
    # NOTE(review): the directions are multiplied by gs_data.rotation.T,
    # presumably to express them in the frame the SH coefficients were
    # stored in -- confirm against render_splats.dataloading.
    view_dirs = bmv(gs_data.rotation.T, directions)
    harmonics = gsplat.spherical_harmonics(
        degrees_to_use=gs_data.max_sh_degree,
        dirs=view_dirs,
        coeffs=gs_data.features,
    )
    # Shift the SH output by 0.5 and clamp at zero so that no negative colour
    # values reach the rasteriser.
    colors = torch.clamp(harmonics + 0.5, min=0.0)

    rendered_image, _alpha = gsplat.rasterize_gaussians(
        xys=means2d,
        depths=depths,
        radii=radii,
        conics=conics,
        num_tiles_hit=num_tiles_hit,
        colors=colors,
        opacity=gs_data.opacities[:, None],
        img_height=img_height,
        img_width=img_width,
        block_width=block_width,
        background=torch.ones(3, device="cuda"),
        return_alpha=True,
    )
    return rendered_image

def render_splats(camera_pose, gs_data):
    """Render ``gs_data`` from ``camera_pose`` with ``gsplat.rasterization``.

    A batch of one camera is built from the notebook-level ``fx``, ``fy``,
    ``width`` and ``height`` (principal point at the image centre), and the
    scene is rendered over an all-ones background.

    Args:
        camera_pose: 4x4 matrix; a leading batch axis is added before it is
            handed to gsplat as the view matrix.
        gs_data: Object exposing ``means``, ``quaternions``, ``scales``,
            ``opacities``, ``features`` and ``max_sh_degree``.

    Returns:
        The first (and only) image of the rendered batch.
    """
    intrinsics = torch.tensor(
        [
            [fx, 0, width / 2],
            [0, fy, height / 2],
            [0, 0, 1],
        ],
        device='cuda',
    )
    white_background = torch.ones((1, 3), device='cuda')

    # The alpha maps and the render metadata are returned but not needed here.
    images, _alphas, _meta = gsplat.rasterization(
        gs_data.means,
        gs_data.quaternions,
        gs_data.scales,
        gs_data.opacities,
        gs_data.features,
        camera_pose[None],
        intrinsics[None],
        width,
        height,
        sh_degree=gs_data.max_sh_degree,
        backgrounds=white_background,
    )
    return images[0]

In [78]:
# Splatfacto checkpoint inside the local nerfstudio outputs directory
# (machine-specific absolute path).
checkpoint_path = '/home/linus/workspace/nerfstudio/outputs/unnamed/splatfacto/2024-05-28_124041/nerfstudio_models/step-000029999.ckpt'
gs_data = load_splatfacto(checkpoint_path)

### Rendering
The new version (V1.0) is faster than the legacy method, as expected.

#### V1.0

In [79]:
n = 200  # number of timed iterations; reused by the benchmark cells below

# Forward-only benchmark: disable autograd so the result does not depend on
# whether `camera_pose` currently has requires_grad set.
with torch.no_grad():
    # Warm-up call so one-time initialisation is not billed to the timed loop,
    # then wait for all queued GPU work before starting the clock.
    rendered_image = render_splats(camera_pose, gs_data)
    torch.cuda.synchronize()

    t = time.perf_counter()  # monotonic, high-resolution clock
    for _ in range(n):
        rendered_image = render_splats(camera_pose, gs_data)
        # CUDA work is asynchronous; wait for the render to actually finish.
        torch.cuda.synchronize()
    t = time.perf_counter() - t

print(f'avg time: {t / n * 1e3} ms')

avg time: 1.5493535995483398 ms


#### Legacy

In [80]:
# Forward-only benchmark of the legacy pipeline (uses `n` from the V1.0 cell).
# Autograd is disabled so the result does not depend on whether `camera_pose`
# currently has requires_grad set.
with torch.no_grad():
    # Warm-up call so one-time initialisation is not billed to the timed loop,
    # then wait for all queued GPU work before starting the clock.
    rendered_image = render_splats_legacy(camera_pose, gs_data)
    torch.cuda.synchronize()

    t = time.perf_counter()  # monotonic, high-resolution clock
    for _ in range(n):
        rendered_image = render_splats_legacy(camera_pose, gs_data)
        # CUDA work is asynchronous; wait for the render to actually finish.
        torch.cuda.synchronize()
    t = time.perf_counter() - t

print(f'avg time: {t / n * 1e3} ms')

avg time: 1.7884361743927002 ms


### Pose gradients

The new version is roughly 3 times slower than the legacy version.

This can be mitigated by transforming the Gaussians to camera coordinates and rendering with an identity pose.

That variant is still slightly slower than the legacy version, possibly due to the extra rotation-matrix-to-quaternion conversion.

#### V1.0

In [81]:
# Differentiate w.r.t. a detached leaf copy of the pose instead of switching
# requires_grad on the shared `camera_pose`, so this cell leaves no state
# behind that would change the forward-only benchmarks on a re-run.
pose_param = camera_pose.detach().requires_grad_(True)

# Warm-up (forward + backward), then wait for all queued GPU work before
# starting the clock.
rendered_image = render_splats(pose_param, gs_data)
grad, = torch.autograd.grad(rendered_image.mean(), pose_param)
torch.cuda.synchronize()

t = time.perf_counter()  # monotonic, high-resolution clock
for _ in range(n):
    rendered_image = render_splats(pose_param, gs_data)
    grad, = torch.autograd.grad(rendered_image.mean(), pose_param)
    # CUDA work is asynchronous; wait for the backward pass to finish.
    torch.cuda.synchronize()
t = time.perf_counter() - t

print(f'avg time: {t / n * 1e3} ms')

avg time: 13.778949975967407 ms


#### Legacy

In [82]:
# Differentiate w.r.t. a detached leaf copy of the pose instead of switching
# requires_grad on the shared `camera_pose`, so this cell leaves no state
# behind that would change the forward-only benchmarks on a re-run.
pose_param = camera_pose.detach().requires_grad_(True)

# Warm-up (forward + backward), then wait for all queued GPU work before
# starting the clock.
rendered_image = render_splats_legacy(pose_param, gs_data)
grad, = torch.autograd.grad(rendered_image.mean(), pose_param)
torch.cuda.synchronize()

t = time.perf_counter()  # monotonic, high-resolution clock
for _ in range(n):
    rendered_image = render_splats_legacy(pose_param, gs_data)
    grad, = torch.autograd.grad(rendered_image.mean(), pose_param)
    # CUDA work is asynchronous; wait for the backward pass to finish.
    torch.cuda.synchronize()
t = time.perf_counter() - t

print(f'avg time: {t / n * 1e3} ms')

avg time: 4.621554613113403 ms


#### V1.0, transformed gaussians

In [83]:
# Differentiate w.r.t. a detached leaf copy of the pose instead of switching
# requires_grad on the shared `camera_pose`, so this cell leaves no state
# behind that would change the forward-only benchmarks on a re-run.
pose_param = camera_pose.detach().requires_grad_(True)
# The identity view matrix is loop-invariant, so it is built once outside the
# timed region.
identity_pose = torch.eye(4, device='cuda')

# Warm-up (forward + backward), then wait for all queued GPU work before
# starting the clock.
rendered_image = render_splats(identity_pose, pose_param @ gs_data)
grad, = torch.autograd.grad(rendered_image.mean(), pose_param)
torch.cuda.synchronize()

t = time.perf_counter()  # monotonic, high-resolution clock
for _ in range(n):
    # Transforming the Gaussians by the pose stays inside the loop: it is the
    # part of the computation the gradient flows through.
    rendered_image = render_splats(identity_pose, pose_param @ gs_data)
    grad, = torch.autograd.grad(rendered_image.mean(), pose_param)
    # CUDA work is asynchronous; wait for the backward pass to finish.
    torch.cuda.synchronize()
t = time.perf_counter() - t

print(f'avg time: {t / n * 1e3} ms')

avg time: 5.722715854644775 ms
