In [9]:
import os
from pathlib import Path
import sys
sys.path.append(str(Path(os.path.abspath('')).parent))
print(str(Path(os.path.abspath('')).parent))
import math
import torch
from torch import Tensor
from delaunay_rasterization.internal.alphablend_tiled_slang import render_alpha_blend_tiles_slang_raw
from scipy.spatial import Voronoi, Delaunay
from torch import nn
import matplotlib.pyplot as plt
import mediapy
from icecream import ic
from data import loader
import random
import time
import tinycudann as tcnn
from utils.contraction import contract_mean_std
from utils import topo_utils
from tqdm import tqdm
import numpy as np
# from dtet import DelaunayTriangulation
from dtet.build.dtet import DelaunayTriangulation

# NOTE(review): `K` is not referenced by any visible cell; render() builds its own local
# intrinsics matrix also named `K`, which shadows this value there. Confirm it is unused.
K = 20

def fov2focal(fov, pixels):
    """Convert a field of view (radians) to a pinhole focal length in pixels.

    Inverse of ``focal2fov``: focal = pixels / (2 * tan(fov / 2)).
    """
    half_angle_tan = math.tan(fov / 2)
    return pixels / (2 * half_angle_tan)

def focal2fov(focal, pixels):
    """Convert a pinhole focal length in pixels to a field of view in radians.

    Inverse of ``fov2focal``: fov = 2 * atan(pixels / (2 * focal)).
    """
    ratio = pixels / (2 * focal)
    return 2 * math.atan(ratio)

def get_slang_projection_matrix(znear, zfar, fy, fx, height, width, device):
    """Build a 4x4 perspective projection matrix from pinhole focal lengths.

    The near-plane window is symmetric (left = -right, bottom = -top), so the
    off-centre terms in column 2 evaluate to zero.

    Args:
        znear, zfar: near and far clipping distances.
        fy, fx: focal lengths in pixels.
        height, width: image size in pixels.
        device: device on which the returned tensor is created.

    Returns:
        A 4x4 tensor built with ``torch.tensor(..., device=device)``.
    """
    # Half-extents of the near-plane window (tan(half-fov) scaled by znear).
    right = (width / (2 * fx)) * znear
    top = (height / (2 * fy)) * znear
    left, bottom = -right, -top

    z_sign = 1.0
    depth_range = zfar - znear

    scale_x = 2.0 * znear / (right - left)
    scale_y = 2.0 * znear / (top - bottom)
    offset_x = (right + left) / (right - left)
    offset_y = (top + bottom) / (top - bottom)
    depth_scale = z_sign * zfar / depth_range
    depth_shift = -(zfar * znear) / depth_range

    rows = [
        [scale_x, 0.0,     offset_x,    0.0],
        [0.0,     scale_y, offset_y,    0.0],
        [0.0,     0.0,     depth_scale, depth_shift],
        [0.0,     0.0,     z_sign,      0.0],
    ]
    return torch.tensor(rows, device=device)

def common_camera_properties_from_gsplat(viewmats, Ks, height, width, znear=0.01, zfar=100.0):
    """Derive the camera quantities the rasterizer needs from gsplat-style inputs.

    Args:
        viewmats: view matrix, returned unchanged as the world-view transform.
            Assumed to be a 4x4 world-to-camera matrix -- TODO confirm with callers.
        Ks: 3x3 intrinsics matrix; only the focal lengths Ks[0, 0] (fx) and
            Ks[1, 1] (fy) are read, and its device is used for the projection matrix.
        height, width: image size in pixels.
        znear, zfar: clipping distances for the projection matrix. The defaults are
            the values that were previously hard-coded.

    Returns:
        (world_view_transform, projection_matrix, cam_pos, fovy, fovx).
        - cam_pos is column 3 of ``viewmats.inverse()``.
        - fovy and fovx are in radians.
    """
    world_view_transform = viewmats
    fx = Ks[0, 0]
    fy = Ks[1, 1]
    projection_matrix = get_slang_projection_matrix(znear, zfar, fy, fx, height, width, Ks.device)
    fovx = focal2fov(fx, width)
    fovy = focal2fov(fy, height)

    # Translation column of the inverse view matrix: the camera origin in world space
    # when viewmats is a rigid world-to-camera transform (assumption, see Args).
    cam_pos = viewmats.inverse()[:, 3]

    return world_view_transform, projection_matrix, cam_pos, fovy, fovx

/home/amai/delaunay_rasterization


In [10]:
# Scene root. Set the NERF_SCENE_DIR environment variable to point elsewhere instead of
# editing this absolute path; the default is the path this notebook was run with.
SCENE_DIR = os.environ.get("NERF_SCENE_DIR", "/data/nerf_datasets/360/bicycle")
train_cameras, test_cameras, scene_info = loader.load_dataset(SCENE_DIR, "images_8", data_device="cuda", eval=True)

Reading camera 194/194
Loaded Train Cameras: 169
Loaded Test Cameras: 25


In [11]:

# Seed every RNG this notebook draws from:
# - torch: vertex jitter and the uniform random points below.
# - Python `random`: training-camera sampling in the training loop.
torch.manual_seed(2)
random.seed(2)

# Start from the point cloud shipped with the scene.
vertices = torch.as_tensor(scene_info.point_cloud.points)
minv = vertices.min(dim=0, keepdim=True).values
maxv = vertices.max(dim=0, keepdim=True).values

# Densify: repeat each point `repeats` times and perturb every copy with Gaussian noise
# (std 0.1, in the point cloud's coordinate units).
repeats = 10
vertices = vertices.reshape(-1, 1, 3).expand(-1, repeats, 3)
vertices = vertices + torch.randn(*vertices.shape) * 1e-1

# Add N extra points drawn uniformly inside the point cloud's axis-aligned bounding box.
N = 50000
vertices = torch.cat([
    vertices.reshape(-1, 3),
    torch.rand((N, 3)) * (maxv - minv) + minv,
], dim=0)
vertices = nn.Parameter(vertices.cuda())

device = torch.device('cuda')

# Multi-resolution hash encoding of 3D positions (tiny-cuda-nn HashGrid).
encoding = tcnn.Encoding(3, {
    "otype": "HashGrid",
    "n_levels": 16,
    "n_features_per_level": 2,
    "log2_hashmap_size": 14,
    "base_resolution": 1,
    "per_level_scale": 1.5,
})

# Small fused MLP producing 4 raw outputs per point. rgbs_fn applies a sigmoid to the
# first three outputs and exponentiates the fourth.
network = tcnn.Network(encoding.n_output_dims, 4, {
    "otype": "FullyFusedMLP",
    "activation": "ReLU",
    "output_activation": "None",
    "n_neurons": 64,
    "n_hidden_layers": 2,
})

net = torch.nn.Sequential(encoding, network).to(device)


def safe_exp(x):
  """Exponentiate ``x`` after clamping it from above at 5 so the result cannot overflow."""
  return torch.exp(torch.clamp(x, max=5))

def safe_trig_helper(x, fn, t=100 * torch.pi):
  """Apply ``fn`` (e.g. sin or cos) to ``x`` after range-folding and NaN cleanup.

  Behaviour:
  - Entries with ``|x| >= t`` are replaced by ``x % t`` (Python-style remainder).
  - The default ``t`` is a multiple of 2*pi, so sin/cos values are kept up to rounding.
  - ``torch.nan_to_num`` is applied before ``fn``: NaN becomes 0, and +/-inf become the
    largest finite values.
  - Non-finite inputs fold to NaN and therefore end up at 0.
  """
  is_small = x.abs() < t
  folded = torch.where(is_small, x, torch.remainder(x, t))
  return fn(folded.nan_to_num())


def safe_cos(x):
  """Cosine of ``x`` evaluated via ``safe_trig_helper``; finite even for huge or NaN inputs."""
  return safe_trig_helper(x, fn=torch.cos)


def safe_sin(x):
  """Sine of ``x`` evaluated via ``safe_trig_helper``; finite even for huge or NaN inputs."""
  return safe_trig_helper(x, fn=torch.sin)

def rgbs_fn(xyz):
  """Query ``net`` at points ``xyz`` and return 4 activated channels per point.

  Steps:
  1. Contract the points with ``contract_mean_std`` (unit std weights).
  2. Rescale with ``(c / 2 + 1) / 2`` before calling ``net``. This maps [-2, 2] onto
     [0, 1]; that the contraction lands in [-2, 2] is an assumption -- TODO confirm.
  3. Channels 0-2 go through a sigmoid; channel 3 becomes ``safe_exp(raw - 3)``.

  Indexing is ``[:, k]``, so ``xyz`` is assumed to be a 2-D (num_points, 3) tensor.
  """
  contracted, _ = contract_mean_std(xyz, torch.ones_like(xyz[..., 0]))
  raw = net((contracted / 2 + 1) / 2).float()
  color = torch.sigmoid(raw[:, :3])
  density = safe_exp(raw[:, 3:] - 3)
  return torch.cat([color, density], dim=1)


In [12]:
# Sanity check: print the projection matrix stored on the first training camera.
# In the recorded output, the 1.0 / -0.01 perspective terms sit at [2, 3] / [3, 2].
# That is the transpose of get_slang_projection_matrix's layout.
print(train_cameras[0].projection_matrix)
camera = train_cameras[0]

tensor([[ 1.8801,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  2.8164,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  1.0001,  1.0000],
        [ 0.0000,  0.0000, -0.0100,  0.0000]], device='cuda:0')


In [13]:
def render(camera, indices, vertices, cell_values=None):
    """Rasterize the tetrahedral mesh ``(indices, vertices)`` as seen from ``camera``.

    Builds pinhole intrinsics from the camera's fields of view, with the principal point
    at the image centre. Everything is then forwarded to
    ``render_alpha_blend_tiles_slang_raw``, with ``rgbs_fn`` as the per-point colour
    function. ``cell_values`` is passed through unchanged.

    NOTE(review): ``tile_size`` is read from the module-level global at call time. It is
    assigned in the training cell, so calling this earlier raises NameError.

    Returns:
        Whatever ``render_alpha_blend_tiles_slang_raw`` returns. The notebook reads its
        'render' and 'circumcenters' entries.
    """
    width, height = camera.image_width, camera.image_height
    intrinsics = torch.tensor([
        [fov2focal(camera.fovx, width), 0, width / 2],
        [0, fov2focal(camera.fovy, height), height / 2],
        [0, 0, 1],
    ]).cuda()

    # Transposed world_view_transform is what the rasterizer receives. Column 3 of its
    # inverse is used as the camera position.
    view = camera.world_view_transform.T
    cam_pos = view.inverse()[:, 3]

    return render_alpha_blend_tiles_slang_raw(
        indices, vertices, rgbs_fn,
        view, intrinsics, cam_pos,
        camera.fovy, camera.fovx, camera.image_height, camera.image_width,
        cell_values=cell_values, tile_size=tile_size,
    )

optim = torch.optim.Adam([
    {"params": net.parameters(), "lr": 1e-2},
    {"params": [vertices], "lr": 2.5e-4},
])
tile_size = 8                       # read by render() as a module-level global
num_iters = 5001
accum_steps = 20                    # views per optimizer step; also the re-triangulation period
video_every = 200                   # how often the accumulated preview video is displayed
preview_camera = train_cameras[0]   # fixed view used for one preview frame per step
images = []
for i in tqdm(range(num_iters)):
    # Sample this iteration's training view. It is kept in `camera` and never reassigned
    # below, so the preview frame cannot replace the view that gets trained on.
    ind = random.randint(0, len(train_cameras) - 1)
    camera = train_cameras[ind]

    if i % accum_steps == 0:
        if i > 0:
            # Apply the gradients of the `accum_steps` views rendered since the last step.
            loss.backward()
            optim.step()
            optim.zero_grad()
        loss = 0

        # Re-triangulate the current vertices first. `rgbs` below is indexed by cell, so
        # it must be computed from the same `indices` that every render uses until the
        # next step. This also defines `indices` before its first use at i == 0.
        triangulation = DelaunayTriangulation()
        triangulation.init_from_points(vertices.detach().cpu().numpy())
        indices = torch.as_tensor(triangulation.get_cells().astype(np.int32)).cuda()

        # Evaluate the field once per cell at its circumcenter. These values (and their
        # autograd graph) are shared by all renders until the next optimizer step.
        render_pkg = render(camera, indices, vertices)
        circumcenter = render_pkg['circumcenters']
        rgbs = torch.zeros((circumcenter.shape[0], 4), device=circumcenter.device)
        rgbs[:indices.shape[0]] = rgbs_fn(circumcenter[:indices.shape[0]])

        # Preview frame from the fixed camera. It is only displayed, so skip autograd.
        with torch.no_grad():
            preview = render(preview_camera, indices, vertices, cell_values=rgbs)['render']
        images.append(preview.permute(1, 2, 0).detach().cpu().numpy())

    target = camera.original_image.cuda()
    render_pkg = render(camera, indices, vertices, cell_values=rgbs)
    image = render_pkg['render']
    loss += ((target - image) ** 2).mean()

    if i % video_every == 0:
        mediapy.show_video(images)
mediapy.show_video(images)

  0%|          | 0/5001 [00:00<?, ?it/s]

0
This browser does not support the video tag.


  4%|▍         | 188/5001 [00:41<19:20,  4.15it/s]

0
This browser does not support the video tag.


  8%|▊         | 400/5001 [01:25<12:12,  6.28it/s]

0
This browser does not support the video tag.


 12%|█▏        | 597/5001 [02:08<12:36,  5.82it/s]

0
This browser does not support the video tag.


 16%|█▌        | 786/5001 [02:52<16:50,  4.17it/s]

0
This browser does not support the video tag.


 20%|█▉        | 1000/5001 [03:36<10:39,  6.25it/s]

0
This browser does not support the video tag.


 24%|██▍       | 1197/5001 [04:19<10:53,  5.82it/s]

0
This browser does not support the video tag.


 28%|██▊       | 1387/5001 [05:03<14:42,  4.09it/s]

0
This browser does not support the video tag.


 32%|███▏      | 1599/5001 [05:46<09:04,  6.25it/s]

0
This browser does not support the video tag.


 36%|███▌      | 1786/5001 [06:31<13:08,  4.08it/s]

0
This browser does not support the video tag.


 40%|███▉      | 1988/5001 [07:15<11:43,  4.29it/s]

0
This browser does not support the video tag.


 44%|████▍     | 2199/5001 [07:59<07:32,  6.20it/s]

0
This browser does not support the video tag.


 48%|████▊     | 2386/5001 [08:43<10:43,  4.06it/s]

0
This browser does not support the video tag.


 52%|█████▏    | 2596/5001 [09:27<07:02,  5.69it/s]

0
This browser does not support the video tag.


 56%|█████▌    | 2796/5001 [10:10<06:25,  5.72it/s]

0
This browser does not support the video tag.


 60%|█████▉    | 3000/5001 [10:55<05:23,  6.19it/s]

0
This browser does not support the video tag.


 64%|██████▍   | 3200/5001 [11:38<04:43,  6.35it/s]

0
This browser does not support the video tag.


 68%|██████▊   | 3397/5001 [12:22<04:29,  5.95it/s]

0
This browser does not support the video tag.


 70%|██████▉   | 3480/5001 [12:42<05:33,  4.56it/s]


KeyboardInterrupt: 