In [2]:
import os
import numpy as np
import torch
import open3d as o3d
import matplotlib.pyplot as plt
import cv2
import torch_scatter as scatter


def euler2mat(angle):
    """Convert Euler angles to rotation matrices.

    The result is ``Rx @ Ry @ Rz``, where ``Rx``, ``Ry`` and ``Rz`` are the
    elementary rotations about the x, y and z axes by ``angle[..., 0]``,
    ``angle[..., 1]`` and ``angle[..., 2]`` (radians, fed to torch.cos/sin).

    Args:
        angle (torch.Tensor): shape ``[3]`` for one triple or ``[b, 3]`` for
            a batch of triples.

    Returns:
        torch.Tensor: ``[3, 3]`` or ``[b, 3, 3]`` rotation matrices.

    Raises:
        ValueError: if ``angle`` is neither 1-D nor 2-D.

    Source:
        https://github.com/ClementPinard/SfmLearner-Pytorch/blob/master/inverse_warp.py
    """
    if angle.dim() == 1:
        x, y, z = angle[0], angle[1], angle[2]
        stack_dim = 0
        out_shape = [3, 3]
    elif angle.dim() == 2:
        b = angle.size(0)
        x, y, z = angle[:, 0], angle[:, 1], angle[:, 2]
        stack_dim = 1
        out_shape = [b, 3, 3]
    else:
        # Explicit raise instead of `assert False`: asserts are stripped under
        # `python -O`, which would surface later as an unrelated NameError.
        raise ValueError(
            f"euler2mat expects a [3] or [b, 3] tensor, got shape {tuple(angle.shape)}"
        )

    # Constant matrix entries: same shape/dtype/device as z, no gradient,
    # and not contaminated by NaN angles (unlike `z * 0`).
    zero = torch.zeros_like(z)
    one = torch.ones_like(z)

    def _stack(*entries):
        # Nine row-major entries -> [3, 3] (single) or [b, 3, 3] (batched).
        return torch.stack(entries, dim=stack_dim).reshape(out_shape)

    cosx, sinx = torch.cos(x), torch.sin(x)
    cosy, siny = torch.cos(y), torch.sin(y)
    cosz, sinz = torch.cos(z), torch.sin(z)

    zmat = _stack(cosz, -sinz, zero,
                  sinz, cosz, zero,
                  zero, zero, one)
    ymat = _stack(cosy, zero, siny,
                  zero, one, zero,
                  -siny, zero, cosy)
    xmat = _stack(one, zero, zero,
                  zero, cosx, -sinx,
                  zero, sinx, cosx)

    return xmat @ ymat @ zmat

def points2img(points, colors, angle, translation, image_height, image_width, size_x=4, size_y=4, return_info=False, step=False, scale=150):
    """Rasterize point clouds into multi-view images by orthographic splatting.

    Every cloud is translated, rotated into each of the ``v`` views, cut
    (see the view-cut comment in the body), scaled to pixels and splatted as a
    ``size_x`` x ``size_y`` square. Where several points land on the same
    pixel, the one with the largest normalized depth (rotated z) wins, via
    ``torch_scatter.scatter_max``.

    The rotated x coordinate indexes image rows (height) and the rotated y
    coordinate indexes image columns (width).

    Args:
        points (torch.Tensor): [B, num_points, 3]. Not modified in place.
        colors (torch.Tensor or None): [B, num_points, 3]. If its maximum is
            above 1 it is divided by 255. When None, every channel holds
            ``floor(1000 * normalized depth)`` instead of a colour.
        angle (torch.Tensor): [v, 3] Euler angles, one triple per view
            (passed to ``euler2mat``).
        translation: a tensor subtracted from ``points`` (must broadcast
            against [B, num_points, 3]); ``'mid'`` for the per-cloud mean with
            its z component scaled by 1/3; ``'min'`` for the per-cloud
            coordinate-wise minimum; anything else for no translation.
        image_height (int): output height H.
        image_width (int): output width W.
        size_x (int, optional): splat extent along rows; 1 or even. Defaults to 4.
        size_y (int, optional): splat extent along columns; 1 or even. Defaults to 4.
        return_info (bool, optional): also return projection metadata. Defaults to False.
        step: unused; kept for backward compatibility.
        scale (float, optional): pixels per unit of point coordinates. Defaults to 150.

    Returns:
        imgs (torch.Tensor): [B, v, 3, image_height, image_width]
        info (dict, only when ``return_info``): ``'x_min'`` / ``'y_min'``
            ([B*v, 1] offsets subtracted before scaling), ``'scale'``,
            ``'rot_mat'`` ([v, 3, 3], already transposed) and ``'through'``
            ([B*v, 1], half of each row's minimum rotated z).
    """
    B_ori, N, _ = points.shape
    v = angle.shape[0]
    B = B_ori * v

    angle = angle.to(points.device)
    # Transposed so row vectors can be right-multiplied: p' = p @ R^T.
    rot_mat = euler2mat(angle).transpose(1, 2)  # [v, 3, 3]

    # ---- translation -------------------------------------------------------
    if torch.is_tensor(translation):
        trans = translation
    elif translation == 'mid':
        # Per-cloud centroid with z scaled by 1/3. Built on points' device so
        # CUDA inputs work (a bare torch.tensor lives on the CPU).
        z_scale = torch.tensor([[1, 1, 1 / 3]], dtype=points.dtype, device=points.device)
        trans = points.mean(dim=1, keepdim=True) * z_scale  # [B, 1, 3]
    elif translation == 'min':
        trans = points.min(dim=1, keepdim=True)[0]  # [B, 1, 3]
    else:
        trans = torch.zeros(1, 3, dtype=points.dtype, device=points.device)

    # Out-of-place on purpose: `points -= trans` would modify the caller's tensor.
    points = points - trans
    points = torch.matmul(points.unsqueeze(1), rot_mat)  # [B_ori, v, N, 3]
    points = points.reshape(B, N, 3)  # row index = b * v + view

    # ---- view-dependent cut ------------------------------------------------
    # Half of each row's minimum rotated z; used as a cut plane below.
    t = points[:, :, 2].min(dim=1, keepdim=True)[0] / 2  # [B, 1]

    # Views 0-3 keep only points with z <= 0, views 4-11 keep z <= t, later
    # views keep everything. Removed points are collapsed onto the origin,
    # not dropped. The view of a row is `row % v`, so every batch element is
    # cut the same way (fixed row ranges 0:4 / 4:12 were only right for B == 1).
    view_idx = torch.arange(B, device=points.device) % v
    top_views = view_idx < 4
    mid_views = (view_idx >= 4) & (view_idx < 12)
    keep = torch.ones(B, N, dtype=torch.bool, device=points.device)
    keep[top_views] = points[top_views, :, 2] <= 0
    keep[mid_views] = points[mid_views, :, 2] <= t[mid_views]
    points = points * keep.unsqueeze(-1)

    if colors is not None:
        # [B, N, 3], rows aligned with `points` (b * v + view).
        colors = torch.repeat_interleave(colors, v, dim=0)

    assert size_x % 2 == 0 or size_x == 1
    assert size_y % 2 == 0 or size_y == 1

    # ---- pixel coordinates -------------------------------------------------
    # Shift each row so its minimum x/y is 0, then convert to pixels.
    x_min = points[:, :, 0].min(dim=1, keepdim=True)[0]  # [B, 1]
    y_min = points[:, :, 1].min(dim=1, keepdim=True)[0]  # [B, 1]
    coord_x = (points[:, :, 0] - x_min) * scale + size_x / 2  # [B, N]
    coord_y = (points[:, :, 1] - y_min) * scale + size_y / 2  # [B, N]

    offs_x = torch.linspace(-size_x / 2, size_x / 2 - 1, size_x, device=points.device)
    offs_y = torch.linspace(-size_y / 2, size_y / 2 - 1, size_y, device=points.device)

    # x is the row index and y the column index of the flat index built below,
    # so x must be clamped to the height and y to the width. (Clamping x to the
    # width overflowed into the next channel whenever H != W.)
    ext_x = torch.clamp(coord_x.unsqueeze(2) + offs_x, 0, image_height - 1)  # [B, N, size_x]
    ext_y = torch.clamp(coord_y.unsqueeze(2) + offs_y, 0, image_width - 1)   # [B, N, size_y]
    ext_x = ext_x.ceil().long().unsqueeze(3).expand(-1, -1, -1, size_y)  # [B, N, size_x, size_y]
    ext_y = ext_y.ceil().long().unsqueeze(2).expand(-1, -1, size_x, -1)  # [B, N, size_x, size_y]

    # ---- per-splat values --------------------------------------------------
    depth = points[:, :, 2]
    depth = depth - depth.min(dim=1, keepdim=True)[0]
    depth_max = depth.max(dim=1, keepdim=True)[0]
    # Normalize to [0, 1]; a flat row (max == 0) stays 0 instead of 0/0 = NaN.
    depth = depth / depth_max.clamp_min(torch.finfo(depth.dtype).tiny)
    depth = depth[:, None, :, None, None].expand(B, 3, N, size_x, size_y)  # [B, 3, N, size_x, size_y]

    if colors is not None:
        if colors.max() > 1:
            colors = colors / 255
        value = colors.permute(0, 2, 1)[:, :, :, None, None].expand(B, 3, N, size_x, size_y)
    else:
        value = depth  # [B, 3, N, size_x, size_y]

    # ---- flat scatter index: channel * H * W + row * W + col ----------------
    hw = image_height * image_width
    pix = ext_x.reshape(B, -1) * image_width + ext_y.reshape(B, -1)  # [B, N*size_x*size_y]
    channel_offset = torch.arange(3, device=points.device).view(1, 3, 1) * hw
    index = (pix.unsqueeze(1) + channel_offset).reshape(B, -1)  # [B, 3*N*size_x*size_y]

    # ---- z-buffer via scatter_max ------------------------------------------
    # Integer part: floor(1000 * depth) selects the winning point per pixel.
    # With colours, colour/2 (in [0, 0.5]) rides along in the fractional part.
    key = (depth.reshape(B, -1) * 1000).floor()
    if colors is not None:
        key = key + value.reshape(B, -1) / 2
    imgs = torch.zeros(B, 3 * hw, dtype=key.dtype, device=points.device)
    imgs, _ = scatter.scatter_max(src=key, index=index, out=imgs, dim=1)
    if colors is not None:
        # Recover the colour from the fractional part.
        imgs = (imgs - imgs.floor()) * 2

    imgs = imgs.reshape(B_ori, v, 3, image_height, image_width)

    if not return_info:
        return imgs
    info = {
        'x_min': x_min,  # [B, 1]
        'y_min': y_min,  # [B, 1]
        'scale': scale,
        'rot_mat': rot_mat,
        'through': t,
    }
    return imgs, info
    


# Load one scene: columns 0-2 of the array are used as xyz, columns 3-5 as colours.
n = "scene0000_00_aligned_vert.npy"
mesh_vertices = np.load(os.path.join('./scenes/', n))
points = torch.tensor(mesh_vertices[:, 0:3])
colors = torch.tensor(mesh_vertices[:, 3:6])


def _show_range(t):
    """Print the per-column maximum and minimum of a 2-D tensor."""
    col_max = t.max(dim=0)[0]
    col_min = t.min(dim=0)[0]
    print(col_max, col_min)


def _apply_to_rows(mat, rows):
    """Left-multiply every row vector of ``rows`` by the 4x4 matrix ``mat``."""
    return torch.matmul(mat, rows.transpose(0, 1)).transpose(0, 1)


_show_range(points)

# Alternative input: a raw .ply scene read with open3d.
# name = os.listdir('./raw_scenes/')
# name = [n for n in name if n.endswith('.ply')]

# n = "scene0020_00_vh_clean_2.ply"
# pcd = o3d.io.read_point_cloud(os.path.join('./raw_scenes/', n))
# points = torch.tensor(np.asarray(pcd.points)).float()
# colors = torch.tensor(np.asarray(pcd.colors)).float()
# print(points.max(dim=0)[0], points.min(dim=0)[0])

# Hard-coded 4x4 matrices for this scene.
world2grid_gen = torch.tensor([[21.3325, 0.186166, 0, 15.5466], 
                               [-2.22001e-009, 2.54388e-007, 21.3333, 15.96],
                               [0.186166, -21.3325, 2.54397e-007, 81.3357], 
                               [0, 0, 0, 1]])
world2grid = torch.tensor([[20.1711, 6.94545, 0, -21.3023], 
                           [-8.28237e-008, 2.40538e-007, 21.3333, 16.0125], 
                           [6.94545, -20.1711, 2.54397e-007, 154.848], 
                           [0, 0, 0, 1]])
intrinsic_depth = torch.tensor([[573.702, 0, 324.702, 0],
                                [0, 574.764, 240.97, 0],
                                [0, 0, 1, 0],
                                [0, 0, 0, 1]])
pose = torch.tensor([[0.294751, 0.357687, -0.886105, 3.17002],
                     [0.951536, -0.195028, 0.23779, 2.91021],
                     [-0.0877612, -0.913249, -0.397837, 2.08994],
                     [0, 0, 0, 1]])

# Homogeneous coordinates: append a column of ones.
points_h = torch.cat(
    [points, torch.ones(points.shape[0], 1, device=points.device)],
    dim=1,
)
_show_range(points_h)

# points_h = torch.matmul(torch.inverse(world2grid_gen), points_h.transpose(0, 1)).transpose(0, 1)
# print(points_h.max(dim=0)[0], points_h.min(dim=0)[0])

# Map through the inverse of world2grid.
points_h = _apply_to_rows(torch.inverse(world2grid), points_h)
_show_range(points_h)

# Apply `pose` as-is.
# NOTE(review): the recorded output after this step has z in roughly
# [-4.2, -3.8], i.e. every point is behind the camera. If `pose` is a
# camera-to-world matrix (as ScanNet pose files usually are), projecting into
# the camera needs torch.inverse(pose) instead — TODO confirm.
points_h = _apply_to_rows(pose, points_h)
_show_range(points_h)

# Apply the depth-camera intrinsics.
# NOTE(review): no division by the z column is done, so columns 0-1 are not
# pixel coordinates yet (recorded output: ~2e3 and ~-1e3). TODO divide by z
# before the image-bounds mask below.
points_h = _apply_to_rows(intrinsic_depth, points_h)
_show_range(points_h)

# image_width = 320
# image_height = 240
# mask = (points_h[:, 0] > 0) & (points_h[:, 0] < image_width) & (points_h[:, 1] > 0) & (points_h[:, 1] < image_height)

# Multi-view rendering of the scene with points2img (disabled).
# angles = torch.tensor([
#                 [0, -np.pi/3, np.pi/4],
#                 [0, -np.pi/3, 3*np.pi/4],
#                 [0, -np.pi/3, 5*np.pi/4],
#                 [0, -np.pi/3, 7*np.pi/4],
                
#                 [0, -np.pi/3, 0],
#                 [0, -np.pi/3, np.pi/4],
#                 [0, -np.pi/3, 2*np.pi/4],
#                 [0, -np.pi/3, 3*np.pi/4],
#                 [0, -np.pi/3, 4*np.pi/4],
#                 [0, -np.pi/3, 5*np.pi/4],
#                 [0, -np.pi/3, 6*np.pi/4],
#                 [0, -np.pi/3, 7*np.pi/4]
#                 ])

# imgs, info = points2img(points, angle=angles, translation='mid', image_height=256, image_width=256, colors=colors, size_x=4, size_y=4, return_info=True)

# plt.figure(figsize=(20, 20))
# for v in range(angles.shape[0]):
#     img = imgs[0, v].detach().numpy().transpose(1, 2, 0)
#     plt.subplot(3, 4, v+1)
#     plt.imshow(img)
# plt.tight_layout()    
# plt.show()

# intrisic_depth = np.array([[577.591, 0, 318.905, 0], [0, 578.73, 242.684, 0], [0, 0, 1, 0], [0, 0, 0, 1]])


tensor([3.4055, 3.6367, 2.9604]) tensor([-3.6358, -3.6122, -0.0649])
tensor([3.4055, 3.6367, 2.9604, 1.0000]) tensor([-3.6358, -3.6122, -0.0649,  1.0000])
tensor([-1.2286,  7.2390, -0.5801,  1.0000]) tensor([-1.5740,  7.0099, -0.9199,  1.0000])
tensor([ 6.1847,  0.2003, -3.8290,  1.0000]) tensor([ 5.7758, -0.1864, -4.1589,  1.0000])
tensor([ 2.2524e+03, -8.5723e+02, -3.8290e+00,  1.0000e+00]) tensor([ 2.0173e+03, -1.0544e+03, -4.1589e+00,  1.0000e+00])


In [None]:
from PIL import Image
import numpy as np
from matplotlib import pyplot as plt