In [None]:
import sys

sys.path.append('..')

import os
import numpy as np

# import transforms
# import bezier_interspace_transforms
from bezier_set import BezierSet
import camera_settings

import torch

import open3d as o3d

In [None]:
class DiffRenderCatheter:
    def __init__(self):

        ## initialize camera parameters
        self.setCameraParams(camera_settings.a, camera_settings.b, camera_settings.center_x, camera_settings.center_y,
                             camera_settings.image_size_x, camera_settings.image_size_y, camera_settings.extrinsics,
                             camera_settings.intrinsics)


        # mesh_cylinder = o3d.geometry.TriangleMesh.create_mesh_cylinder(radius=1.0, height=2.0, resolution=20, split=4)
        
        ## initialize a catheter
        n_beziers = 1
        self.bezier_set = BezierSet(n_beziers)

        self.bezier_num_samples = 101
        self.bezier_surface_resolution = 50

        self.bezier_radius = 0.0015

    def setCameraParams(self, fx, fy, cx, cy, size_x, size_y, camera_extrinsics, camera_intrinsics):
        """
        Set intrinsic and extrinsic camera parameters

        Args:
            fx (float): horizontal direction focal length
            fy (float): vertical direction focal length
            cx (float): horizontal center of image
            cy (float): vertical center of image
            size_x (int): width of image
            size_y (int): height of image
            camera_extrinsics ((4, 4) numpy array): RT matrix 
            camera_intrinsics ((3, 3) numpy array): K matrix 
        """
        self.fx = fx
        self.fy = fy
        self.cx = cx
        self.cy = cy
        self.size_x = size_x
        self.size_y = size_y
        self.cam_RT_H = camera_extrinsics
        self.cam_K = camera_intrinsics

        # camera E parameters
        cam_RT_H = torch.tensor([[1., 0., 0., 0.], [0., 1., 0., 0.], [0., 0., 1., 0.], [0., 0., 0., 1.]])
        invert_y = torch.tensor([[1., 0., 0., 0.], [0., 1., 0., 0.], [0., 0., 1., 0.], [0., 0., 0., 1.]])
        self.cam_RT_H = torch.matmul(invert_y, cam_RT_H)

    def getBezierCurve(self, para_gt, p_start):

        p_mid = para_gt[0:3]
        p_end = para_gt[3:6]
        p_c1 = 4 / 3 * p_mid - 1 / 3 * p_start
        p_c2 = 4 / 3 * p_mid - 1 / 3 * p_end
        # self.control_pts = torch.vstack((p_start, c2, p_end, c1))

        self.num_samples = 101
        sample_list = torch.linspace(0, 1, self.num_samples)

        # Get positions and normals from samples along bezier curve
        self.bezier_pos = torch.zeros(self.num_samples, 3)
        self.bezier_der = torch.zeros(self.num_samples, 3)
        self.bezier_snd_der = torch.zeros(self.num_samples, 3)
        for i, s in enumerate(sample_list):
            self.bezier_pos[i, :] = (1 - s)**3 * p_start + 3 * s * (1 - s)**2 * \
                p_c1 + 3 * (1 - s) * s**2 * p_c2 + s**3 * p_end
            self.bezier_der[i, :] = -(1 - s)**2 * p_start + ((1 - s)**2 - 2 * s *
                                                             (1 - s)) * p_c1 + (-s**2 + 2 *
                                                                                (1 - s) * s) * p_c2 + s**2 * p_end
            self.bezier_snd_der[i, :] = 6 * (1 - s) * (p_c2 - 2 * p_c1 + p_start) + 6 * s * (p_end - 2 * p_c2 + p_c1)

        # Convert positions and normals to camera frame
        pos_bezier_H = torch.cat((self.bezier_pos, torch.ones(self.num_samples, 1)), dim=1)

        pos_bezier_cam_H = torch.transpose(torch.matmul(self.cam_RT_H, torch.transpose(pos_bezier_H, 0, 1)), 0, 1)
        self.pos_bezier_cam = pos_bezier_cam_H[1:, :-1]

        der_bezier_H = torch.cat((self.bezier_der, torch.zeros((self.num_samples, 1))), dim=1)
        der_bezier_cam_H = torch.transpose(torch.matmul(self.cam_RT_H, torch.transpose(der_bezier_H[1:, :], 0, 1)), 0,
                                           1)
        self.der_bezier_cam = der_bezier_cam_H[:, :-1]

    def getBezierTNB(self, bezier_pos, bezier_der, bezier_snd_der):

        bezier_der_n = torch.linalg.norm(bezier_der, ord=2, dim=1)
        self.bezier_tangent = bezier_der / torch.unsqueeze(bezier_der_n, dim=1)

        bezier_normal_numerator = torch.linalg.cross(bezier_der, torch.linalg.cross(bezier_snd_der, bezier_der))
        bezier_normal_numerator_n = torch.mul(
            bezier_der_n, torch.linalg.norm(torch.linalg.cross(bezier_snd_der, bezier_der), ord=2, dim=1))
        
        self.bezier_normal = bezier_normal_numerator / torch.unsqueeze(bezier_normal_numerator_n, dim=1)

        bezier_binormal_numerator = torch.linalg.cross(bezier_der, bezier_snd_der)
        bezier_binormal_numerator_n = torch.linalg.norm(bezier_binormal_numerator, ord=2, dim=1)

        self.bezier_binormal = bezier_binormal_numerator / torch.unsqueeze(bezier_binormal_numerator_n, dim=1)
    
    def getBezierSurface(self):

        self.bezier_surface = torch.zeros(self.bezier_num_samples, self.bezier_surface_resolution, 3)

        theta_list = torch.linspace(0.0, np.pi, self.bezier_surface_resolution)

        for i in range(self.bezier_num_samples):
            surface_vec = self.bezier_radius * (- torch.mul(diff_catheter.bezier_normal[1, :], torch.unsqueeze(torch.cos(theta_list),dim=1)) +
                                                torch.mul(diff_catheter.bezier_binormal[1, :], torch.unsqueeze(torch.sin(theta_list),dim=1)))
            print(surface_vec)
            self.bezier_surface[i, :, :] = self.bezier_pos[i, :] + surface_vec

In [None]:
diff_catheter = DiffRenderCatheter()

para_gt = torch.tensor([0.02003904, 0.0016096, 0.10205799, 0.02489567, -0.04695673, 0.196168896], dtype=torch.float)
p_start = torch.tensor([0.02, 0.002, 0.0])

diff_catheter.getBezierCurve(para_gt, p_start)
diff_catheter.getBezierTNB(diff_catheter.bezier_pos, diff_catheter.bezier_der, diff_catheter.bezier_snd_der)
diff_catheter.getBezierSurface()

In [None]:
range(diff_catheter.bezier_num_samples)

In [None]:
diff_catheter.bezier_normal[1, :]

theta_list = torch.linspace(0.0, np.pi, 50)
torch.mul(diff_catheter.bezier_normal[1, :], torch.unsqueeze(torch.cos(theta_list),dim=1))


In [141]:

surface_vertices = torch.reshape(diff_catheter.bezier_surface, (-1, 3))
top_center_vertice = torch.unsqueeze(diff_catheter.bezier_pos[0, :], dim=0)
bot_center_vertice = torch.unsqueeze(diff_catheter.bezier_pos[-1, :], dim=0)
torch.cat((top_center_vertice, bot_center_vertice, surface_vertices), dim=0)
# top_center_vertice = torch.squeeze(diff_catheter.bezier_pos[0, :],dim=0)
# bot_center_vertice.shape

tensor([[ 2.0000e-02,  2.0000e-03,  0.0000e+00],
        [ 2.4896e-02, -4.6957e-02,  1.9617e-01],
        [ 2.0149e-02,  5.0745e-04, -2.2088e-06],
        ...,
        [ 2.4557e-02, -4.5496e-02,  1.9617e-01],
        [ 2.4651e-02, -4.5477e-02,  1.9617e-01],
        [ 2.4746e-02, -4.5464e-02,  1.9617e-01]])