Test the accuracy of the catheter projection in past frames

In [7]:
import sys
sys.path.append('..')
sys.path.insert(1, 'E:/OneDrive - UC San Diego/UCSD/Lab/Catheter/diff_catheter/scripts')

import torch
import torch.nn as nn

import numpy as np

from construction_bezier import ConstructionBezier
from loss_define import (
    ContourChamferLoss, 
    TipDistanceLoss, 
    ImageContourChamferLoss, 
    GenerateRefData
)

from catheter_motion_tensor import CatheterMotion
from utils import *

class ReconstructionOptimizer(nn.Module):
    """Render the catheter for the current frame and for motion-predicted past frames.

    The current-frame Bezier parameters ``para_init`` are held as an
    ``nn.Parameter``. ``CatheterMotion.past_frames_prediction`` maps them,
    together with ``delta_u_list``, to one parameter set per past frame.
    ``forward`` projects each catheter onto its frame's image and saves the
    overlays, so the projection accuracy can be inspected.
    """

    def __init__(self, p_start, para_init, image_ref, gpu_or_cpu, past_frames_list, delta_u_list, img_save_path, image_save_path_list):
        '''
        Initialize the catheter optimization model.

        Args:
            p_start (tensor): starting point of the catheter; moved to
                ``gpu_or_cpu`` and detached.
            para_init (numpy array): Bezier parameters for the current frame;
                wrapped in a trainable ``nn.Parameter``.
            image_ref (numpy array): reference image of the current frame
                (the notebook passes the output of ``process_image``).
                Stored as a float32 buffer.
            gpu_or_cpu (str or torch.device): device to run on, e.g. 'cuda' or 'cpu'.
            past_frames_list (list of numpy arrays): reference images of the
                past frames; raw contours/centerlines are extracted from each.
            delta_u_list (tensor): per-past-frame input deltas passed to
                ``CatheterMotion.past_frames_prediction``.
            img_save_path (str): path of the current frame's raw image,
                loaded by ``ConstructionBezier.loadRawImage``.
            image_save_path_list (list of str): paths of the past frames' raw
                images. ``forward`` pairs entry i with the i-th predicted
                parameter set.
        '''
        super().__init__()

        # NOTE: ConstructionBezier is deliberately NOT stored as an instance
        # attribute. Intermediate tensors it keeps from one forward pass would
        # be referenced in the backward pass of the next iteration. The graph
        # would then grow with the iteration count, and iterations would no
        # longer be independent. A fresh instance is built in every forward().
        self.img_save_path = img_save_path

        self.contour_chamfer_loss = ContourChamferLoss(device=gpu_or_cpu)
        self.contour_chamfer_loss.to(gpu_or_cpu)
        self.tip_distance_loss = TipDistanceLoss(device=gpu_or_cpu)
        self.tip_distance_loss.to(gpu_or_cpu)
        self.image_contour_chamfer_loss = ImageContourChamferLoss(device=gpu_or_cpu)
        self.image_contour_chamfer_loss.to(gpu_or_cpu)

        # Placeholders for scalar loss values; not assigned anywhere in this class.
        self.tip_euclidean_distance_loss = None
        self.tip_loss = None

        self.p_start = p_start.to(gpu_or_cpu).detach()
        self.para_init = nn.Parameter(torch.from_numpy(para_init).to(gpu_or_cpu),
                                      requires_grad=True)

        image_ref = torch.from_numpy(image_ref.astype(np.float32))
        self.register_buffer('image_ref', image_ref)

        # Extract reference data once here instead of in every forward pass.
        # These are plain attributes, not buffers, so module.to(device) does
        # not move them.
        self.generate_ref_data = GenerateRefData(self.image_ref)
        self.ref_catheter_contour = self.generate_ref_data.get_raw_contour()
        self.ref_catheter_centerline = self.generate_ref_data.get_raw_centerline()

        self.delta_u_list = delta_u_list.to(gpu_or_cpu)
        self.image_save_path_list = image_save_path_list

        # Reference contours/centerlines for each past frame, in frame order.
        self.contour_list = []
        self.centerline_list = []
        for frame in past_frames_list:
            frame_ref_data = GenerateRefData(torch.from_numpy(frame.astype(np.float32)))
            self.contour_list.append(frame_ref_data.get_raw_contour())
            self.centerline_list.append(frame_ref_data.get_raw_centerline())

        self.gpu_or_cpu = gpu_or_cpu

    def forward(self, save_img_path):
        '''
        Project the current-frame catheter and the motion-predicted past-frame
        catheters onto their images, and save the overlays.

        Args:
            save_img_path (str): prefix (directory with trailing separator)
                for the output images. Writes ``current.png`` for the current
                frame and ``predicted_<i>.png`` for past frame i.

        Returns:
            None. Results are only written as images; no loss is computed.
        '''
        # Fresh instance per call; see the NOTE in __init__.
        build_bezier = ConstructionBezier(radius=0.0015)
        build_bezier.to(self.gpu_or_cpu)
        build_bezier.loadRawImage(self.img_save_path)

        # Bezier cylinder mesh, then its 2D projection, then the projected centerline.
        build_bezier.getBezierCurveCylinder(self.p_start, self.para_init)
        build_bezier.getCylinderMeshProjImg()
        build_bezier.getBezierProjImg()

        build_bezier.draw2DCylinderImage(self.image_ref, save_img_path + 'current.png')

        # Predict one parameter set per past frame from the current parameters.
        catheter_motion = CatheterMotion(self.p_start, self.gpu_or_cpu, l=0.2, r=0.01)
        predicted_paras = catheter_motion.past_frames_prediction(self.delta_u_list, self.para_init)

        for i, frame_paras in enumerate(predicted_paras):
            frame_path = self.image_save_path_list[i]

            # NOTE(review): past frames use ConstructionBezier's default radius,
            # while the current frame uses radius=0.0015. Confirm this is intended.
            past_bezier = ConstructionBezier()
            past_bezier.to(self.gpu_or_cpu)
            past_bezier.loadRawImage(frame_path)
            past_bezier.getBezierCurveCylinder(self.p_start, frame_paras.to(self.gpu_or_cpu))
            past_bezier.getCylinderMeshProjImg()
            past_bezier.getBezierProjImg()

            frame_ref = torch.from_numpy(process_image(frame_path).astype(np.float32))
            past_bezier.draw2DCylinderImage(frame_ref, f'{save_img_path}predicted_{i}.png')


In [8]:
import os

scripts_path = 'E:/OneDrive - UC San Diego/UCSD/Lab/Catheter/diff_catheter/scripts/test_diff_render_catheter_v2'
dataset_folder = "gt_dataset4"

###========================================================
### 1) SET TO GPU OR CPU COMPUTING
###========================================================
if torch.cuda.is_available():
    gpu_or_cpu = torch.device("cuda:0")
    torch.cuda.set_device(gpu_or_cpu)
else:
    gpu_or_cpu = torch.device("cpu")

###========================================================
### 2) VARIABLES FOR BEZIER CURVE CONSTRUCTION
###========================================================
p_start = torch.tensor([0.02, 0.002, 0.000001])  # 0 here will cause NaN in draw2DCylinderImage, pTip

gt_name = 'gt_35_-0.0008_0.0008_0.2_0.01'
case_naming = scripts_path + '/' + dataset_folder + '/' + gt_name
img_save_path = case_naming + '.png'
cc_specs_path = case_naming + '.npy'
target_specs_path = None
viewpoint_mode = 1
transparent_mode = 0

img_ref_binary = process_image(img_save_path)

# Ground Truth parameters for catheter used in SRC presentation.
# The current frame is initialized from ground truth, so the past-frame
# renders reflect the motion-model prediction rather than reconstruction error.
# (The hand-picked initial guess formerly assigned here was overwritten before use.)
para_gt_np = read_gt_params(cc_specs_path)
para_init = np.array(para_gt_np, dtype=np.float32)

# Past frames, paired by index with the rows of delta_u_list.
folder_path = scripts_path + '/' + dataset_folder + '/'
image_save_path_list = [
    folder_path + 'gt_33_-0.0008_0.0004_0.2_0.01.png',
    folder_path + 'gt_31_-0.0008_0.0000_0.2_0.01.png',
]
past_frames_list = [process_image(path) for path in image_save_path_list]

delta_u_list = torch.tensor([[0, 0.0004], [0, 0.0004]])

###========================================================
### 3) SET UP AND RUN OPTIMIZATION MODEL
###========================================================
catheter_optimize_model = ReconstructionOptimizer(p_start, para_init, img_ref_binary, gpu_or_cpu, past_frames_list, delta_u_list, img_save_path, image_save_path_list).to(gpu_or_cpu)

save_img_path = scripts_path + '/test_imgs/test_past_frame/'
# The forward pass writes its images under this prefix; create the directory
# up front so the writes do not fail.
os.makedirs(save_img_path, exist_ok=True)

# Invoke the module (not .forward directly) so any registered nn.Module hooks run.
catheter_optimize_model(save_img_path)