In [None]:
import dataclasses as dc
import os

import torch
import numpy as np
from PIL import Image

from pydrake.math import RollPitchYaw, RotationMatrix, RigidTransform
from pydrake.common.eigen_geometry import Quaternion

import neural_renderer as nr
from simple_mesh import Mesh, intrinsic_matrix_from_fov, RgbDepth

In [None]:
# Print numpy floats with 4 decimals and thousands separators.
np.set_printoptions(formatter={"float_kind": lambda value: f"{value:,.4f}"})
# Nothing in this notebook backpropagates, so disable autograd globally.
# The trailing `;` stops Jupyter from echoing the returned grad-mode object.
torch.set_grad_enabled(False);

In [None]:
# Every tensor in the notebook lives on the GPU in single precision
# (presumably required: neural_renderer's rasterizer is a CUDA kernel).
device, dtype = torch.device("cuda"), torch.float32

In [None]:
# Output image size in pixels (e.g. 128x128 for quicker previews).
width = 512
height = 512
# nr.Renderer only accepts a single `image_size`, so the intrinsics built from
# (width, height) below match the rendered image only when it is square.
assert width == height, "nr.Renderer renders square images (single image_size)"
fov_y = np.pi / 4  # 45 degrees; presumably the vertical field of view
# Pinhole intrinsics, moved to the render device and given a leading batch axis
# (presumably a 3x3 matrix -> (1, 3, 3)).
K = intrinsic_matrix_from_fov(width, height, fov_y)
K = torch.from_numpy(K).to(device, dtype).unsqueeze(0)

# NOTE(review): upstream neural_renderer's Renderer defaults `orig_size=1024`,
# which 'projection' mode uses to normalize the pixel coordinates produced by K.
# K above is built for a `width`-pixel image -- confirm whether this fork needs
# `orig_size=width` to keep the principal point centered.
# NOTE: ambient (0.75) + directional (0.5) intensities can push lit colors past
# 1.0, so rendered images should be clipped before converting to uint8.
renderer = nr.Renderer(
    camera_mode='projection',
    projection_left_hand=False,
    image_size=width,
    light_direction=[-1, -1, -1],
    light_intensity_directional=0.5,
    light_intensity_ambient=0.75,
    background_color=[0.0, 0.0, 1.0],  # blue background
)

In [None]:
# SECURITY: allow_pickle=True unpickles arbitrary Python objects, which can run
# code -- only load .npy files from a trusted source.
# `.item()` unwraps the array into the single Python object that was saved.
full_state0 = np.load(file="full_state0.npy", allow_pickle=True).item()
# Reference image saved alongside the state, shown for visual comparison.
full_state0_img = Image.open("full_state0.png")
display(full_state0_img)

In [None]:
# Split the saved state tuple. Underscore-prefixed fields are unused here
# (presumably velocity / deformation gradient / affine momentum -- confirm).
(x, _v, _F, _C, primitive_states) = full_state0["state"]

In [None]:
np.shape(x)  # one row per particle; presumably (num_particles, 3) -- confirm

In [None]:
# Red sphere mesh, instanced later as the marker for each particle.
sphere = Mesh.from_file(
    "sphere_lowpoly.obj",
    color=torch.tensor([1.0, 0.0, 0.0]),  # solid red (RGB)
)
# Uniform vertex scale for the marker. A dedicated name avoids overloading `r`,
# which the cylinder cell assigns to the cylinder radius.
sphere_scale = 1.0 / 150
sphere.vertices *= sphere_scale

In [None]:
# Textured cylinder, scaled by `r` along x and z and by `h` along y
# (radius/height if the .obj is a unit cylinder along y -- confirm).
cylinder = Mesh.from_file("textured_cylinder.obj", load_texture=True)
r, h = 0.03, 0.3
# `.to(tensor)` matches both the dtype and the device of the vertices.
scale = torch.tensor([r, h, r]).to(cylinder.vertices)
cylinder.vertices = cylinder.vertices * scale[None, :]

In [None]:
def rigid_transform_to_torch(X, *, unsqueeze):
    """Convert a pydrake ``RigidTransform`` into ``(R, t)`` torch tensors.

    Args:
        X: transform exposing ``rotation().matrix()`` and ``translation()``.
        unsqueeze: when True, prepend a size-1 batch axis to both outputs.

    Returns:
        ``(R, t)`` on the notebook-wide ``device`` / ``dtype``: the rotation
        matrix and translation vector, each optionally batched.
    """
    def _as_tensor(array):
        # Copy first: torch.from_numpy shares memory with its argument.
        tensor = torch.from_numpy(array.copy()).to(device, dtype)
        return tensor.unsqueeze(0) if unsqueeze else tensor

    return _as_tensor(X.rotation().matrix()), _as_tensor(X.translation())

In [None]:
def render(mesh, X):
    """Rasterize ``mesh`` with the global ``renderer`` and intrinsics ``K``.

    Args:
        mesh: object whose ``unsqueeze()`` returns (vertices, faces, textures),
            presumably with a leading batch axis added.
        X: pydrake RigidTransform whose rotation/translation become the
            renderer's extrinsics ``R`` / ``t`` (callers pass ``X_CW``).

    Returns:
        An ``RgbDepth`` holding the first two outputs of ``renderer.render``.
        The third output is discarded.
    """
    batched_vertices, batched_faces, batched_textures = mesh.unsqueeze()
    R, t = rigid_transform_to_torch(X, unsqueeze=True)
    rendered = renderer.render(
        batched_vertices,
        batched_faces,
        textures=batched_textures,
        R=R,
        t=t,
        K=K,
    )
    result = RgbDepth()
    result.rgb, result.depth, _ = rendered
    return result

In [None]:
# Fixed +90 degree roll (about x) relating the saved state's frame P to world W.
R_WP = RotationMatrix(RollPitchYaw(np.deg2rad([90, 0, 0])))
X_WP = RigidTransform(R_WP)

In [None]:
# One sphere per row of `x`, re-expressed from frame P into world W.
scene = Mesh.empty()
xs_PO = x.copy()
xs_WO = torch.from_numpy((X_WP @ xs_PO.T).T).to(device, dtype)
for p_WO in xs_WO:
    scene.add_object(sphere, p_WO)

In [None]:
scene_cylinder = Mesh.empty()

# primitive_states: a position (first 3 entries) followed by a wxyz quaternion.
# NOTE(review): Quaternion() needs exactly 4 values, so this presumes exactly
# 7 entries in total -- confirm.
p_PS, qwxyz_PS = primitive_states[:3], primitive_states[3:]
X_PS = RigidTransform(Quaternion(qwxyz_PS), p_PS)
# Compose into world, then place the cylinder with the resulting pose.
X_WS = X_WP @ X_PS
R, t = rigid_transform_to_torch(X_WS, unsqueeze=False)
scene_cylinder.add_object(cylinder, t, R)

In [None]:
@dc.dataclass
class RgbDepth:
    """Color image plus depth map that can be z-composited layer by layer.

    NOTE: this redefinition shadows the ``RgbDepth`` imported from
    ``simple_mesh``. ``render()`` looks the name up at call time, so it builds
    instances of this class once this cell has run.
    """

    rgb: torch.Tensor = None  # Batched, NCHW
    depth: torch.Tensor = None  # Batched, (N, H, W): `rgb` without its channel axis

    def add(self, other):
        """Composite ``other`` into ``self``, keeping the nearer surface per pixel.

        Where ``other.depth`` is strictly smaller than ``self.depth``, color and
        depth come from ``other``; ties keep ``self``. Adding an empty
        ``RgbDepth`` is a no-op.

        ``other`` is never mutated. The first add stores copies of its tensors,
        and later adds build fresh tensors with ``torch.where`` instead of
        writing in place into tensors that may still be owned by a caller
        (e.g. the result of an earlier ``render``).
        """
        if other.rgb is None:
            assert other.depth is None
            return
        if self.rgb is None:
            assert self.depth is None
            self.rgb = other.rgb.clone()
            self.depth = other.depth.clone()
            return
        assert self.depth is not None
        closer = other.depth < self.depth
        # Add a channel axis so the (N, H, W) mask broadcasts over however many
        # color channels `rgb` has.
        self.rgb = torch.where(closer.unsqueeze(1), other.rgb, self.rgb)
        self.depth = torch.where(closer, other.depth, self.depth)

    def numpy(self):
        """Return ``(rgb, depth)`` as numpy arrays.

        ``rgb`` drops a size-1 batch axis and is transposed to H x W x C.
        ``depth`` is returned with its batch axis intact.
        """
        rgb = self.rgb.detach().squeeze(0).cpu().numpy().transpose(1, 2, 0)
        depth = self.depth.detach().cpu().numpy()
        return rgb, depth

In [None]:
# Camera pose in world: at (0.5, -0.5, 1.0), rolled 180 degrees about x, so
# its +z axis points along world -z.
X_WC = RigidTransform(
    p=[0.5, -0.5, 1.0],
    R=RollPitchYaw(np.deg2rad([180, 0, 0])).ToRotationMatrix(),
)
# render() feeds this transform to the renderer as extrinsics, so pass
# world-to-camera.
X_CW = X_WC.inverse()

# Render the particles and the cylinder separately, then z-composite them.
full = RgbDepth()
dough = render(scene, X_CW)
full.add(dough)
cyl = render(scene_cylinder, X_CW)
full.add(cyl)

rgb, _ = full.numpy()
# Lit colors can exceed 1.0 (ambient + directional intensity > 1), and
# casting out-of-range floats to uint8 wraps around instead of saturating.
# Clip to [0, 1] before scaling.
rgb = (np.clip(rgb, 0.0, 1.0) * 255).astype(np.uint8)
display(Image.fromarray(rgb))

In [None]:
# Sweep the bare cylinder mesh through a few roll angles. The world origin sits
# 0.5 along the camera's z axis.
X_CW = RigidTransform([0, 0, 0.5])

for rot_x in np.arange(0, 45, 15):  # 0, 15, 30 degrees (stop value excluded)
    print(rot_x)
    # To sweep pitch (about y) instead, use RollPitchYaw([0, np.deg2rad(rot_x), 0]).
    R_WWh = RollPitchYaw([np.deg2rad(rot_x), 0, 0]).ToRotationMatrix()
    X_WWh = RigidTransform(R_WWh)
    full = RgbDepth()
    full.add(render(cylinder, X_CW @ X_WWh))
    rgb, _ = full.numpy()
    # Clip before the uint8 cast: lit colors can exceed 1.0 and would otherwise
    # wrap around to dark values.
    rgb = (np.clip(rgb, 0.0, 1.0) * 255).astype(np.uint8)
    display(Image.fromarray(rgb))