In [1]:
import numpy as np
from pathlib import Path
import cv2
import torch
from einops import rearrange
import k3d


def transpose(R, t, X):
    """Apply a batched affine/rigid transform ``R @ p + t`` to every 3-D point in ``X``.

    Despite its name, this does not transpose anything: each point
    ``X[i, ..., :]`` is mapped to ``R[i] @ X[i, ..., :] + t[i]``.

    Args:
        R: (b, 3, 3) matrices, one per batch element (the rotation part of a pose).
        t: (b, 3) translation vectors.
        X: (b, ..., 3) points with coordinates in the last dim. The notebook passes
            (b, h, w, 3); any number of middle dimensions (e.g. (b, N, 3)) works.

    Returns:
        Tensor shaped like ``X`` (last dim = rows of ``R``) with the transformed points.
    """
    # Collapse all middle dims into one points axis, then put coordinates first
    # so a single batched matmul transforms every point: (b, ..., c) -> (b, c, N).
    # flatten() (unlike reshape(b, -1, c)) also handles zero-sized grids.
    X_cols = X.flatten(1, -2).transpose(1, 2)

    X_after_R = R @ X_cols + t[:, :, None]

    # Back to points-last layout and the caller's original spatial shape.
    out_shape = tuple(X.shape[:-1]) + (X_after_R.shape[1],)
    return X_after_R.transpose(1, 2).reshape(out_shape)


def pi_inv(K, x, d):
    """Inverse pinhole projection: lift pixels with known depth into camera space.

    Args:
        K: batched intrinsics, presumably (B, 3, 3). Only fx = K[:, 0, 0],
            fy = K[:, 1, 1], cx = K[:, 0, 2] and cy = K[:, 1, 2] are read;
            any skew term is ignored.
        x: (..., 2) pixel coordinates, u in channel 0 and v in channel 1.
        d: per-pixel depth, broadcastable with ``x[..., 0]``.

    Returns:
        (..., 3) tensor of camera-space points whose last channel equals ``d``.
    """
    def _intrinsic(row, col):
        # Keep a trailing (1, 1) so the per-batch scalar broadcasts over the pixel grid.
        return K[:, row, col][:, None, None]

    u, v = x[..., 0], x[..., 1]
    cam_x = d * (u - _intrinsic(0, 2)) / _intrinsic(0, 0)
    cam_y = d * (v - _intrinsic(1, 2)) / _intrinsic(1, 1)
    return torch.stack((cam_x, cam_y, d), dim=-1)


def x_2d_coords(h, w, device):
    """Build an (h, w, 2) grid of pixel coordinates.

    Channel 0 holds the column index and channel 1 the row index, so
    ``grid[y, x] == (x, y)``. The grid uses torch's default floating dtype,
    because it is allocated with ``torch.zeros``.

    Args:
        h: number of rows.
        w: number of columns.
        device: torch device on which to allocate the grid.

    Returns:
        Float tensor of shape (h, w, 2).
    """
    x_2d = torch.zeros((h, w, 2), device=device)
    # Broadcast assignment replaces the per-row / per-column Python loops;
    # the integer aranges are cast to x_2d's dtype on assignment.
    x_2d[..., 0] = torch.arange(w, device=device)[None, :]
    x_2d[..., 1] = torch.arange(h, device=device)[:, None]
    return x_2d


def back_projection(depth, pose, K, x_2d=None):
    """Back-project per-pixel depth maps to world-space 3-D points.

    Args:
        depth: (b, h, w) depth maps.
        pose: batched poses, presumably (b, 4, 4) or (b, 3, 4). The top-left 3x3
            block is used as the rotation and the first three entries of column 3
            as the translation. They are named Rwc/twc, i.e. read as
            camera-to-world. NOTE(review): confirm against how poses were saved.
        K: batched intrinsics forwarded to ``pi_inv`` (indexed as ``K[:, r, c]``).
        x_2d: optional pixel-coordinate grid, (b, h, w, 2) or broadcastable to it.
            Defaults to the full ``x_2d_coords`` grid.

    Returns:
        (b, h, w, 3) tensor of world-space points.
    """
    b, h, w = depth.shape
    if x_2d is None:
        # expand() returns a broadcast view, so the grid is not materialized b
        # times as repeat() did. pi_inv only reads it, so a view is safe.
        x_2d = x_2d_coords(h, w, device=depth.device)[None, ...].expand(b, -1, -1, -1)

    # Camera-space points, (b, h, w, 3).
    X_3d = pi_inv(K, x_2d, depth)

    Rwc, twc = pose[:, :3, :3], pose[:, :3, 3]
    X_world = transpose(Rwc, twc, X_3d)

    return X_world.reshape((-1, h, w, 3))

In [2]:
# Which generated scene and style prompt to inspect; together they select the
# outputs/<scene_name>/<short_prompt>/images directory loaded below.
scene_name, short_prompt = "LivingRoom-36282", "classic"

In [3]:
# Locate this scene/prompt's generated outputs. Per-frame artifacts sit side by
# side in one directory and are matched by filename suffix; sorting pairs them.
output_dir = Path("outputs") / scene_name / short_prompt / "images"
depth_inv_files = sorted(output_dir.glob("*_depth_inv.png"))
depth_files = sorted(output_dir.glob("*_depth.png"))
pred_files = sorted(output_dir.glob("*_pred.png"))
pose_files = sorted(output_dir.glob("*_poses.txt"))
K_file = output_dir / "K.txt"

# Fail early with a readable message instead of np.stack's "need at least one
# array to stack" or a batch-size mismatch deep inside back_projection.
assert depth_files, f"no *_depth.png files found in {output_dir}"
assert len(depth_files) == len(pose_files), (
    f"{len(depth_files)} depth maps but {len(pose_files)} pose files in {output_dir}"
)

K = np.loadtxt(K_file)
# Poses are used as-is. back_projection reads the top-left 3x3 as Rwc and
# column 3 as twc. NOTE(review): confirm the files store camera-to-world poses.
poses = np.stack([np.loadtxt(f) for f in pose_files], axis=0)

depths = []
for depth_file in depth_files:
    depth_raw = cv2.imread(str(depth_file), cv2.IMREAD_UNCHANGED)
    # cv2.imread signals an unreadable/missing file by returning None, not raising.
    if depth_raw is None:
        raise FileNotFoundError(f"could not read depth image {depth_file}")
    # Stored as integers; /1000 presumably converts millimetres to metres. TODO confirm.
    depths.append(depth_raw.astype(np.float32) / 1000.0)
depths = np.stack(depths, axis=0)

depths = torch.from_numpy(depths)
poses = torch.from_numpy(poses)
# Add a single leading batch dim, giving (1, 3, 3). pi_inv indexes K[:, r, c],
# and this one K broadcasts across all frames. Wrap exactly once: a second
# wrap would give (1, 1, 3, 3) and break that indexing.
K = torch.from_numpy(K)[None]


In [4]:
# World-space point for every pixel of every frame: (b, h, w, 3).
X_world = back_projection(depth=depths, pose=poses, K=K)

In [5]:
# Merge all frames into one (N, 3) cloud on the CPU and render it with k3d.
X_flat = X_world.reshape((-1, 3)).cpu().numpy()

plot = k3d.plot()
points = k3d.points(X_flat, point_size=0.01)
plot += points
plot.display()



Output()