In [1]:
import numpy as np
import nrrd
import open3d as o3d
import scipy.ndimage
import nibabel as nib
import os
from scipy import ndimage
import pydicom
import time
import nibabel as nib
import pyvista as pv
import copy
import vtk
from scipy.ndimage import label, find_objects,generate_binary_structure
from skimage import morphology,draw,measure
import matplotlib.pyplot as plt
from skimage.morphology import ball, dilation, erosion, skeletonize_3d
import traceback
from PIL import Image
from sklearn.decomposition import PCA
from mpl_toolkits.mplot3d import Axes3D
from scipy.interpolate import splprep, splev
import json
from stl import mesh
import SimpleITK as sitk
from abc import ABC, abstractmethod
from collections import deque
import pdb
import open3d.visualization.gui as gui
import open3d.visualization.rendering as rendering
from scipy.optimize import linear_sum_assignment

Jupyter environment detected. Enabling Open3D WebVisualizer.
[Open3D INFO] WebRTC GUI backend enabled.
[Open3D INFO] WebRTCWindowSystem: HTTP handshake server disabled.


In [2]:
class utilsBase(ABC):
    def load_dicom_series_as_3d_array(self,folder_path):
        """
        从文件夹中读取 DICOM 序列并转换为三维 NumPy 数组。

        参数：
        folder_path: 包含 DICOM 文件的文件夹路径

        返回：
        3D NumPy 数组，表示 DICOM 序列图像体积
        """
        # 使用 ImageSeriesReader 读取 DICOM 序列
        reader = sitk.ImageSeriesReader()
        # 获取文件夹中所有 DICOM 文件的文件名
        dicom_series = reader.GetGDCMSeriesFileNames(folder_path)
        reader.SetFileNames(dicom_series)
        # 读取图像
        image = reader.Execute()
        # 将 SimpleITK 图像转换为 NumPy 数组
        img_array = sitk.GetArrayFromImage(image)
        # img_array 的形状是 [slice, height, width]
        # 如果需要的是 [height, width, slice] 形式，可以使用 np.transpose(img_array, (1, 2, 0))
                
        spacing = np.array(image.GetSpacing())
        direction = np.array(image.GetDirection()).reshape(3, 3)
        origin = np.array(image.GetOrigin())
        
        
        return np.transpose(img_array, (2, 1, 0)), spacing, origin, direction
    def retain_largest_connected_component(self, data):
        # 标记连通区域
        st = time.time()
        labeled_array = measure.label(data,background = 0)
        print('labeled_array time',time.time() - st)
        st = time.time()
        regions = measure.regionprops(labeled_array)
        print('regions time',time.time() - st)
        st = time.time()
        # 如果没有找到任何连通区域，直接返回原数组
        # 找到每个连通区域的大小
        regions = sorted(regions, key=lambda x:x.area, reverse=True)
        print('sorted time', time.time() - st)
        st = time.time()
        # 创建一个新数组，只包含最大的连通区域
        new_data = np.zeros_like(data)
        if len(regions) >0:
            largest_component = regions[0]
            new_data[largest_component.coords[:,0],largest_component.coords[:,1],largest_component.coords[:,2]] = 1
            # pcd = o3d.geometry.PointCloud()
            # # 假设 data 是一个 Nx3 的 numpy 数组
            # # data = np.random.rand(100, 3)  # 随机生成100个点
            # pcd.points = o3d.utility.Vector3dVector(largest_component.coords)
            # # 可视化点云
            # o3d.visualization.draw_geometries([pcd])
            print('final time', time.time() - st)
            # st = time.time()
            return new_data, largest_component.bbox
        else:
            print(traceback.format_exc())
            pdb.set_trace()
            return new_data, None
        
    def numpy_to_vtk_image(self,data):
        # 将 NumPy 数组转换为 vtkImageData
        image = vtk.vtkImageData()
        image.SetDimensions(data.shape)
        image.AllocateScalars(vtk.VTK_UNSIGNED_CHAR, 1)

        # 填充 vtkImageData 对象

        flat_data = np.transpose(data,(2,1,0)).ravel()
        for i in range(len(flat_data)):
            image.GetPointData().GetScalars().SetTuple1(i, flat_data[i])

        return image
    def extract_surface(self,image):
        # 使用 Marching Cubes 算法提取表面
        extractor = vtk.vtkDiscreteMarchingCubes()
        extractor.SetInputData(image)
        extractor.GenerateValues(1, 1, 1)  # 参数设置为（数量，最小值，最大值），提取值为1的表面
        extractor.Update()
        return extractor.GetOutput()

    def smooth_mesh(self,polydata):
        # 网格平滑
        smoother = vtk.vtkSmoothPolyDataFilter()
        smoother.SetInputData(polydata)
        smoother.SetNumberOfIterations(15)
        smoother.SetRelaxationFactor(0.1)
        smoother.FeatureEdgeSmoothingOff()
        smoother.BoundarySmoothingOn()
        smoother.Update()
        return smoother.GetOutput()
    def save_as_stl(self, polydata, filename):
    # 保存为 STL 文件
        stlWriter = vtk.vtkSTLWriter()
        stlWriter.SetFileName(filename)
        stlWriter.SetInputData(polydata)
        stlWriter.Write()
    def getSmoothedBiggestStl(self, big, hilumBox, middleRegion, offset):
        big[:max(1,hilumBox[0][0]-middleRegion),:,:] = 0
        big[min(hilumBox[0][1]+middleRegion,511):,:,:] = 0
        big[:,:max(hilumBox[1][0]-middleRegion,1),:] = 0
        big[:,min(hilumBox[1][1]+middleRegion,511):,:] = 0
        big[:,:,:max(hilumBox[2][0]-middleRegion-offset,1)] = 0
        big[:,:,min(511,hilumBox[2][1]+middleRegion-offset):] = 0
        
        
        big,box = self.retain_largest_connected_component(big)
        vtk_image = self.numpy_to_vtk_image(big)
        polydata = self.extract_surface(vtk_image)
        smoothed = self.smooth_mesh(polydata)
        return smoothed 
    
    def ww_wc(self, img, k='lungNoduleClass'):
        ref_dict = {"tsetra": [-600, 5500], "ADC": [1400, 1800], "BVAL": [60, 140], 'lungNoduleClass': [-500, 1500]}

        wcenter = ref_dict[k][0]
        wwidth = ref_dict[k][1]
        minvalue = (2 * wcenter - wwidth) / 2.0 + 0.5
        maxvalue = (2 * wcenter + wwidth) / 2.0 + 0.5

        dfactor = 255.0 / (maxvalue - minvalue)

        zo = np.ones(img.shape) * minvalue
        Two55 = np.ones(img.shape) * maxvalue
        img = np.where(img < minvalue, zo, img)
        img = np.where(img > maxvalue, Two55, img)
        img = ((img - minvalue) * dfactor)  # .astype('uint8')

        return img

    def getCos(self, v1,v2):
        V1 = v1 / np.linalg.norm(v1)
        V2 = v2 / np.linalg.norm(v2)
        return np.dot(V1, V2) 
    def decimate_mesh(self,input_polydata, reduction_rate=0.95):
        decimator = vtk.vtkDecimatePro()
        decimator.SetInputData(input_polydata)
        decimator.SetTargetReduction(reduction_rate)
        decimator.PreserveTopologyOn()
        decimator.Update()

        return decimator.GetOutput()
    
    def smoother(self, mesh,iteration=15, rate = 0.4):
        smoother1 = vtk.vtkSmoothPolyDataFilter()
        smoother1.SetInputData(mesh)
        smoother1.SetNumberOfIterations(iteration)  # 平滑迭代的次数
        smoother1.SetRelaxationFactor(rate)  # 平滑弛豫因子
        smoother1.FeatureEdgeSmoothingOff()
        smoother1.BoundarySmoothingOn()
        smoother1.Update()
        return smoother1.GetOutput()
    
    def gaussian_filter_smooth(self, points, sigma=2.0):
        from scipy.ndimage import gaussian_filter1d
        smoothed_points = np.copy(points)
        for i in range(3):  # 假设点是三维的
            smoothed_points[:, i] = gaussian_filter1d(points[:, i], sigma)
        return smoothed_points
    def smoothViaKeyPoints(self, points,key_indices, pointsNum = 100):
        key_points = points[key_indices]
        # 创建时间参数t
        t = np.linspace(0, 1, len(points))
        t_keys = t[key_indices]

        # 创建样条曲线
        cs_x = CubicSpline(t_keys, key_points[:, 0])
        cs_y = CubicSpline(t_keys, key_points[:, 1])
        cs_z = CubicSpline(t_keys, key_points[:, 2])

        # 绘制曲线
        t_fine = np.linspace(0, 1, pointsNum)
        return np.stack((cs_x(t_fine),cs_y(t_fine),cs_z(t_fine)),axis = 1)
    def mean_insert(self, points, num_insert= 3):
        newpoints = []
        for i, point in enumerate(points):
            if i+1 < len(points):
                newpoints.append(point)
                newpoints.append(((np.array(point) + np.array(points[i+1]))/2 ).tolist())
        newpoints.append(points[-1])
        newpoints.append( (np.array(points[-1]) + 0.001).tolist())
        return newpoints

    def create_hemisphere(self, radius, center, resolution):
        sphereSource = vtk.vtkSphereSource()
        sphereSource.SetRadius(radius)
        sphereSource.SetCenter(center)
        sphereSource.SetThetaResolution(resolution)
        sphereSource.SetPhiResolution(resolution)
        sphereSource.SetStartPhi(-180)
        sphereSource.SetEndPhi(180)
        sphereSource.Update()
        return sphereSource.GetOutput()
    def merge_and_save_mesh(self, mesh1, mesh2):
        """ 合并两个 mesh 并保存为 STL 文件 """
        appendFilter = vtk.vtkAppendPolyData()
        appendFilter.AddInputData(mesh1)
        appendFilter.AddInputData(mesh2)
        appendFilter.Update()
        return appendFilter.GetOutput()
    def ensure_triangles(self, poly_data):
        triangle_filter = vtk.vtkTriangleFilter()
        triangle_filter.SetInputData(poly_data)
        triangle_filter.Update()
        return triangle_filter.GetOutput()
    def perform_boolean_union(self, mesh1, mesh2):
        clean_filter1 = vtk.vtkCleanPolyData()
        clean_filter1.SetInputData(mesh1)
        clean_filter1.Update()

        clean_filter2 = vtk.vtkCleanPolyData()
        clean_filter2.SetInputData(mesh2)
        clean_filter2.Update()

        normals1 = vtk.vtkPolyDataNormals()
        normals1.SetInputData(clean_filter1.GetOutput())
        normals1.ComputePointNormalsOn()
        normals1.ComputeCellNormalsOn()
        normals1.Update()

        normals2 = vtk.vtkPolyDataNormals()
        normals2.SetInputData(clean_filter2.GetOutput())
        normals2.ComputePointNormalsOn()
        normals2.ComputeCellNormalsOn()
        normals2.Update()
        visualize(normals1.GetOutput())
        visualize(normals2.GetOutput())

        tri_poly_data1 = ensure_triangles(normals1.GetOutput())
        tri_poly_data2 = ensure_triangles(normals2.GetOutput())
        intersection_filter = vtk.vtkIntersectionPolyDataFilter()
        intersection_filter.DebugOn()
        intersection_filter.SetInputData(0, tri_poly_data1)
        intersection_filter.SetInputData(1, tri_poly_data2)
        intersection_filter.Update()

        boolean_filter = vtk.vtkBooleanOperationPolyDataFilter()
        boolean_filter.DebugOn()
        boolean_filter.SetOperationToUnion()
        boolean_filter.SetInputData(0, decimate_mesh(tri_poly_data1))
        boolean_filter.SetInputData(1, decimate_mesh(tri_poly_data2))
        boolean_filter.SetTolerance(1e-5)
        boolean_filter.Update()

        mesh = convert2Triangles(boolean_filter.GetOutput())
        visualize(mesh)

        return mesh

    def convert2Triangles(self, mesh):
        tri_filter = vtk.vtkTriangleFilter()
        tri_filter.SetInputData(mesh)
        tri_filter.PassLinesOff()
        tri_filter.PassVertsOff()
        tri_filter.Update()
        return tri_filter.GetOutput()

    def subdivide_mesh(self, mesh, number_of_subdivisions=1):
        subdivision_filter = vtk.vtkLoopSubdivisionFilter()
        subdivision_filter.SetInputData(mesh)
        subdivision_filter.SetNumberOfSubdivisions(number_of_subdivisions)
        subdivision_filter.Update()
        return subdivision_filter.GetOutput()
    def visualize(self, poly_data):
        mapper = vtk.vtkPolyDataMapper()
        mapper.SetInputData(poly_data)
        actor = vtk.vtkActor()
        actor.SetMapper(mapper)
        renderer = vtk.vtkRenderer()
        render_window = vtk.vtkRenderWindow()
        render_window.AddRenderer(renderer)
        render_window_interactor = vtk.vtkRenderWindowInteractor()
        render_window_interactor.SetRenderWindow(render_window)
        renderer.AddActor(actor)
        renderer.SetBackground(0.1, 0.2, 0.3)
        render_window.Render()
        render_window_interactor.Start()
    def check_mesh_intersection(self, mesh1, mesh2):
        intersection_filter = vtk.vtkIntersectionPolyDataFilter()
        intersection_filter.SetInputData(0, mesh1)
        intersection_filter.SetInputData(1, mesh2)
        intersection_filter.Update()

        # 检查是否有交点
        if intersection_filter.GetNumberOfIntersectionPoints() > 0:
            return True
        else:
            return False
    def writeStlAndMerge(self, mesh,previous, need2MergeStl = True,boolUnion= False):
        if boolUnion:
            previous = self.perform_boolean_union(previous, mesh)
        else:
            previous = self.merge_and_save_mesh(previous,mesh)

        return previous
    
    def subdivide_mesh(self, mesh, number_of_subdivisions=1):
        subdivision_filter = vtk.vtkLoopSubdivisionFilter()
        subdivision_filter.SetInputData(mesh)
        subdivision_filter.SetNumberOfSubdivisions(number_of_subdivisions)
        subdivision_filter.Update()
        return subdivision_filter.GetOutput()
    
    
    def lookTree(self, bloodTree,pointsListAll):
        pointsListAll.append(bloodTree['line'])
        for subt in bloodTree['subtree']:
            self.lookTree(subt,pointsListAll)
            
    def seelinesAndPoints(self, lines, thispoints, needColorSlowlyChange = False, color1 = [1,0,0],color2 = [0,0,1]):
        showlist = []
        for ii,pt in enumerate(thispoints):
            newpoints = [[0,0,0]]
            if len(pt) == 2 and (type(pt[0]) != int and type(pt[0]) != float):
                pt = pt[0]
            newpoints.append(pt) 
            pcd = o3d.geometry.PointCloud()
            try:
                pcd.points = o3d.utility.Vector3dVector(np.array(newpoints))
            except:
                print(traceback.format_exc())
                pdb.set_trace()
            if needColorSlowlyChange:
                pcd.paint_uniform_color(color1)
            else:
                pcd.paint_uniform_color(color2)
            showlist.append(pcd)
        time.sleep(1)
        for ii,line in enumerate(lines):
            newline = []
            for pt in line:    
                if len(pt) >=2 :
                    pt = pt[0]
                newline.append(pt)
            pcd = o3d.geometry.PointCloud()
            pcd.points = o3d.utility.Vector3dVector(np.array(newline))
            if needColorSlowlyChange:
                pcd.paint_uniform_color([ii/ len(lines) , 0,0])
            else:
                pcd.paint_uniform_color(color1)
            showlist.append(pcd)

        # 可视化点云
        o3d.visualization.draw_geometries(showlist)

    def fitCurveAndComputeTangents(self,points):
        points = np.array(points)
        if len(points) < 2:
            raise ValueError("点集至少需要包含两个点")
        # 如果点集只有两个点，直接计算线性切线方向
        if len(points) == 2:
            direction = points[1] - points[0]
            direction = direction / np.linalg.norm(direction)  # 归一化
            return [direction] * 2  # 两个点的切线方向相同, 返回两个切线
        # 对超过两个点的情况，使用B样条拟合
        if len(points ) == 3:
            npoints = [points[0],(points[0]+points[1])/2, points[1],(points[1]+points[2]) /2, points[2] ]
            points = np.array(npoints)
        try:
            tck, u = splprep(points[::max(1,len(points)//50)].T, s=10)
        except:
            # pdb.set_trace()
            raise ValueError("something wrong in tck, points")
        # 求解切线方向 (求导)
        tangents = np.array(splev(u, tck, der=1))  # der=1 表示求一阶导数
        tangents = tangents.T  # 转置回 (n, 3) 形式
        # 归一化切线向量
        tangents = tangents / np.linalg.norm(tangents, axis=1)[:, np.newaxis]
        return tangents

In [92]:
class utilsVolumSearching(utilsBase):
    def findPath(self, volume, start, end):
        """
        使用广度优先搜索（BFS）从start到end查找路径，并返回路径上的所有点。
        参数:
        volume: 三维二值图像，值为1表示血管，0表示背景
        start: 起始点 (x1, y1, z1)
        end: 终止点 (x2, y2, z2)

        返回:
        path: 从start到end的路径上的所有点
        """
        # 用于表示每个点的邻域，26邻域表示允许在3个维度上移动
        directions = [(-1, 0, 0), (1, 0, 0), (0, -1, 0), (0, 1, 0), (0, 0, -1), 
                      (0, 0, 1),(-1,-1,0),(-1,1,0), (1,-1,0) ,(1, 1, 0),(-1, 0, -1),
                      (-1,0,1),(1,0,-1),(1,0,1),(0,-1,-1),(0,-1,1),(0,1,-1),
                      (0,1,1),(-1,-1,-1),(-1,-1,1),(-1,1,-1),(-1,1,1),(1,-1,-1),(1,1,-1),(1,-1,1),(1,1,1)]
        queue = deque([start])
        came_from = {start: None}  # 用来记录路径
        while queue:
            current = queue.popleft()
            # 如果找到了终点
            if current == end:
                break
            # 遍历邻居
            for direction in directions:
                neighbor = (current[0] + direction[0], current[1] + direction[1], current[2] + direction[2])
                # 检查邻居是否有效并且没有被访问过
                if is_valid_point(volume, neighbor) and neighbor not in came_from:
                    queue.append(neighbor)
                    came_from[neighbor] = current
        # 如果终点没有被访问过，说明没有路径
        if end not in came_from:
            return None

        # 回溯路径
        path = []
        current = end
        while current is not None:
            path.append(current)
            current = came_from[current]
        path.reverse()  # 翻转路径以便从start到end
        return np.asarray(path)
    
    def findDirectPath(self,connectPoint,mask, region,growResult):
        coords = region.coords
        closedPoint = coords[np.argmin(np.sqrt(np.sum((connectPoint - coords)**2,axis = 1)))]
        path = findPath(growResult,tuple(connectPoint),tuple(closedPoint))
        newmask = np.zeros(mask.shape)
        newmask[mask>0] = 1
        newmask[path] = 1
        newmask[region.coords[:,0],region.coords[:,1],region.coords[:,2] ] = 1
        return newmask
    def checkConnected(self, root,region, threshold = 3):
        import numpy as np
        from scipy.spatial import KDTree
        def getAllPoints(root, l):
            for li in root['line']:
                l.append(li[0])
            for subt in root['subtree']:
                getAllPoints(subt,l)
        def compute_average_distance(A, B):
            # 构建A点集的KD树
            tree = KDTree(A)
            # 查找B中每个点在A中的最近邻点，并计算距离
            distances, _ = tree.query(B)
            # 计算所有点的平均距离
            average_distance = np.mean(distances)
            return average_distance
        A = []
        B = []
        getAllPoints(root,A)
        if type(region) != dict:
            B = region.coords.tolist()
        else:
            getAllPoints(region,B)
        average_distance = compute_average_distance(A, B)
        return average_distance < threshold
    def getNewPointOrder(self, queuesDic, nowPoint,mask, mainDirection, targetPoints, lookAngle, pixelOrConvo = 1):
        directions = [(-1, 0, 0), (1, 0, 0), (0, -1, 0), (0, 1, 0), (0, 0, -1), 
                      (0, 0, 1),(-1,-1,0),(-1,1,0), (1,-1,0) ,(1, 1, 0),(-1, 0, -1),
                      (-1,0,1),(1,0,-1),(1,0,1),(0,-1,-1),(0,-1,1),(0,1,-1),
                      (0,1,1),(-1,-1,-1),(-1,-1,1),(-1,1,-1),(-1,1,1),(1,-1,-1),(1,1,-1),(1,-1,1),(1,1,1)]
        resDirs = {}
        order = 0
        resDirAndOrder = []
        if pixelOrConvo == 1:
            for d in directions:
                newP = [nowPoint[0]+d[0],nowPoint[1]+d[1],nowPoint[2]+d[2]]
                if newP[0]<0 or newP[0]>=mask.shape[0]:
                    continue
                if newP[1]<0 or newP[1]>=mask.shape[1]:
                    continue
                if newP[2]<0 or newP[2]>=mask.shape[2]:
                    continue
                if mask[newp[0],newp[1],newp[2]] != 0:
                    continue
                flagT = 0
                for tp in targetPoints:
                    cos1 = getCos(d, tp - nowPoint)
                    if cos1< np.cos(lookAngle):
                        flagT = 1
                        resDirAndOrder.append([d,order])
                    order+=1
                flagM = 0
                if flagT ==0:
                    cos1 =getCos(d, mainDirection)
                    if cos1< np.cos(lookAngle):
                        flagM = 1

                        resDirAndOrder.append([d,order])

                if flagM == 0:
                    order +=1

                    resDirAndOrder.append([d,order])
            
            return  resDirAndOrder
        else:
            return []
    def region_grow(self, volume, rootRegion, classMask, region, root, targetPoints, thresholdNeibour=200.0,thresholdHU = -200):
        try:    
            if checkConnected(root,region):
                return
        except:
            print(traceback.format_exc())
            pdb.set_trace()
        region_mask = np.zeros(volume.shape)
        showGrowing = np.zeros(volume.shape)
        rootMask = np.zeros(volume.shape)
        rootMask[rootRegion.coords[:,0],
                 rootRegion.coords[:,1],
                 rootRegion.coords[:,2]] = 1
        seeds = [region.coords[0]]
        struct = generate_binary_structure(3, 3)  # 3D的26邻域
        for seed in seeds:
            region_mask[seed[0],seed[1],seed[2]] = 1
        label = classMask[region.coords[0][0],region.coords[0][1],region.coords[0][2]]
        nowMask = np.where(classMask == label, 1, 0)
        nowCoords = np.where(nowMask>0)
        showz = nowCoords[2][0]
        img_volume = ww_wc(volume)
        image1_np = img_volume[:,:,showz]
        image2_np = nowMask[:,:,showz]
        image1 = Image.fromarray(image1_np)
        image2 = Image.fromarray(image2_np*200)

        pca = PCA(n_components=3)
        pca.fit(region.coords)
        mainDirection = pca.components_[0]
        if np.dot((rootRegion.coords[0] - region.coords[0]), mainDirection) <0:
            mainDirection = -mainDirection

        # 确保两张图片大小一致
        if image1.size == image2.size:
            print(f"图片大小一致: {image1.size}")
        else:
            print(f"图片大小不一致: image1: {image1.size}, image2: {image2.size}")

        # 获取图像大小
        width, height = image1.size

        # 创建一个新的空白图像用于并排显示两张图片
        combined_image = Image.new('RGB', (width * 2, height))

        # 将两张图片粘贴到新的空白图像上
        # combined_image.paste(image1, (0, 0))
        # combined_image.paste(image2, (width, 0))


        image2_gray = image2.convert('L')

        # 2. 创建一个全蓝色的图片并应用到标签上
        blue_label = Image.new('RGBA', image2.size, (0, 0, 255, 0))  # 初始透明蓝色
        for x in range(image2.size[0]):
            for y in range(image2.size[1]):
                # 根据灰度图的像素值，调整蓝色的透明度
                gray_value = image2_gray.getpixel((x, y))
                blue_label.putpixel((x, y), (0, 0, 255, int(gray_value * 0.3)))  # 透明度 0.3 (取值范围 0-255)

        # 将第一张图片转换为 RGBA 模式，以便可以叠加
        image1_rgba = image1.convert('RGBA')

        # 叠加蓝色标签到第一张图片上
        # combined_image = Image.new('RGB', (width * 2, height))

        combined_image.paste(Image.alpha_composite(image1_rgba, blue_label), (0, 0))
        combined_image.paste(image1, (width, 0))
        # 显示叠加后的图片
        combined_image.show()
        count = 0 
        speedQueue = {}
        queuesNum = 16
        for i in range(queuesNum):
            speedQueue[i] = []
        while True:
            flagContinue = False
            for queueIndex in range(queuesNum):
                if len(speedQueue[queueIndex])>0:
                    flagContinue = True
                    break
            if not flagContinue:
                break
            x, y, z = speedQueue[queueIndex].pop(0)
            if rootMask[x,y,z] == label:
                connectPoint = [x,y,z]
                break
            current_value = volume[x, y, z]
            count +=1
    #         for dx, dy, dz in np.ndindex(struct.shape):
    #             if struct[dx, dy, dz]:  # 忽略非邻居
    #                 nx, ny, nz = x + dx - 1, y + dy - 1, z + dz - 1
            resDirAndOrder = getNewPointOrder(
                queueDic, 
                nowPoint,mask, 
                mainDirection, 
                targetPoints, 
                lookAngle, 
                pixelOrConvo = 1)
            for dr in resDirAndOrder:
                neighbor_value = volume[dr[0][0],dr[0][1],dr[0][2]]
                if abs(neighbor_value - current_value) <= thresholdNeibour  and neighbor_value > thresholdHU:
                    region_mask[dr[0][0],dr[0][1],dr[0][2]] = 1
                    showGrowing[dr[0][0],dr[0][1],dr[0][2]] = 1
                    seeds[dr[1]].append(dr[0])
                    image1_np = img_volume[:,:,nz]
                    image2_np = showGrowing[:,:,nz]
                    image1 = Image.fromarray(image1_np)
                    image2 = Image.fromarray(image2_np*200)
                    width, height = image1.size
                    combined_image = Image.new('RGB', (width * 2, height))
                    image2_gray = image2.convert('L')
                    blue_label = Image.new('RGBA', image2.size, (0, 0, 255, 0))  # 初始透明蓝色
                    for x in range(image2.size[0]):
                        for y in range(image2.size[1]):
                            gray_value = image2_gray.getpixel((x, y))
                            blue_label.putpixel((x, y), (0, 0, 255, int(gray_value * 0.3)))  # 透明度 0.3 (取值范围 0-255)
                    image1_rgba = image1.convert('RGBA')
                    combined_image.paste(Image.alpha_composite(image1_rgba, blue_label), (0, 0))
                    combined_image.paste(image1, (width, 0))
                    combined_image.show()
        growResult = np.where((region_mask == 1)|(nowMask == 1),1,0 )

        nz = connectPoint[2]
        image1_np = img_volume[:,:,nz]
        image2_np = growResult[:,:,nz]
        image1 = Image.fromarray(image1_np)
        image2 = Image.fromarray(image2_np*200)
        width, height = image1.size
        combined_image = Image.new('RGB', (width * 2, height))
        image2_gray = image2.convert('L')
        blue_label = Image.new('RGBA', image2.size, (0, 0, 255, 0))  # 初始透明蓝色
        for x in range(image2.size[0]):
            for y in range(image2.size[1]):
                gray_value = image2_gray.getpixel((x, y))
                blue_label.putpixel((x, y), (0, 0, 255, int(gray_value * 0.3)))  # 透明度 0.3 (取值范围 0-255)
        image1_rgba = image1.convert('RGBA')
        combined_image.paste(Image.alpha_composite(image1_rgba, blue_label), (0, 0))
        combined_image.paste(image1, (width, 0))
        combined_image.show()

        try:
            newMask = findDirectPath(connectPoint, nowMask, region, growResult)
        except:
            print(traceback.format_exc())
            pdb.set_trace()
        # coords = np.where(bigJing==1)
        # center = np.mean(hilumBox,axis=1)[::-1]
        # coords = np.stack(coords,axis = 1)
        # pcd = o3d.geometry.PointCloud()
        # pcd.points = o3d.utility.Vector3dVector(coords)
        # # 可视化点云
        # o3d.visualization.draw_geometries([pcd])
        # newMask = np.where(nowMask ==1,1,0 )
        # newMask = np.where((region_mask)|(nowMask ==1),1,0 )
        BiggerRegionsDong, smallerRegionsDong = retain_connected_component_list(newMask,1)
        newtree = getTreeFromRegion(newMask, center, BiggerRegionsDong[0])
        # deleteOtherBranch(seed1,seed2,newroot, oldroot, targetroot)

        # coords = np.where(bigJing==1)
        # center = np.mean(hilumBox,axis=1)[::-1]
        # coords = np.stack(coords,axis = 1)
        # pcd = o3d.geometry.PointCloud()
        # pcd.points = o3d.utility.Vector3dVector(coords)
        # # 可视化点云
        # o3d.visualization.draw_geometries([pcd])

        return region_mask
    

In [93]:
class utilsTree(utilsBase):
    def label_vessel_grades(self,vessel_data, start_point):
        '''
        完成线段构建
        '''
        directions = [(-1, 0, 0), (1, 0, 0), (0, -1, 0), (0, 1, 0), (0, 0, -1), 
                      (0, 0, 1),(-1,-1,0),(-1,1,0), (1,-1,0) ,(1, 1, 0),(-1, 0, -1),
                      (-1,0,1),(1,0,-1),(1,0,1),(0,-1,-1),(0,-1,1),(0,1,-1),
                      (0,1,1),(-1,-1,-1),(-1,-1,1),(-1,1,-1),(-1,1,1),(1,-1,-1),(1,1,-1),(1,-1,1),(1,1,1)]
        grades = np.zeros(vessel_data.shape)
        queue = deque([(start_point, 1, (-1,-1,-1))])  # (position, grade)
        grades[start_point[0],start_point[1],start_point[2]] = 1
        gradesDic = {}
        pointListWithPrior = {}
        while queue:
            (x, y, z), current_grade,(px,py,pz) = queue.popleft()
            connections = 0
            for dx, dy, dz in directions:
                nx, ny, nz = x + dx, y + dy, z + dz
                if 0 <= nx < vessel_data.shape[0] and 0 <= ny < vessel_data.shape[1] and 0 <= nz < vessel_data.shape[2]:
                    if vessel_data[nx, ny, nz] == 1  and grades[nx, ny, nz] == 0:
                        connections += 1
                        if connections > 1:  # More than one connection means a new branch
                            new_grade = current_grade + 1
                        else:
                            new_grade = current_grade
                        if new_grade not in gradesDic:
                            gradesDic[new_grade] = []

                        grades[nx, ny, nz] = new_grade
                        queue.append(((nx, ny, nz), new_grade,(x, y, z)))
                        if new_grade not in pointListWithPrior:
                            pointListWithPrior[new_grade] = [((nx, ny, nz),(x, y, z))]
                        else:
                            pointListWithPrior[new_grade].append(((nx, ny, nz),(x, y, z)))
        def getAllSeg(pointListWithPrior):
            '''
            完成线段构建，输入前序点集，输出线段集
            '''
            SegDic = {}
            for k in pointListWithPrior:
                readed = [0] * len(pointListWithPrior[k])

                startSignal = 1
                segNum = 0
                SegDic[k] = {}
                findflag = 0
                while sum(readed)<len(pointListWithPrior[k]):
                    # if startSignal:
                    #     SegDic[SegNum] = deque([pointListWithPrior[k][index])
                    if findflag != 1:
                        segNum += 1
                        for i, pointPair in enumerate(pointListWithPrior[k]):
                            if readed[i] == 1:
                                continue
                            SegDic[k][segNum] = deque([pointListWithPrior[k][i]])
                            readed[i] = 1 
                            break
                    findflag = 0
                    for i, pointPair in enumerate(pointListWithPrior[k]):
                        if readed[i] == 1:
                            continue
                        if pointPair[0] == SegDic[k][segNum][-1][1]:
                            SegDic[k][segNum].append(pointPair)
                            findflag = 1
                            readed[i] = 1 
                            break
                        if pointPair[1] == SegDic[k][segNum][0][0]:
                            SegDic[k][segNum].appendleft(pointPair)
                            findflag = 1
                            readed[i] = 1 
                            break
            return SegDic
        pointshowList = []
        for g in pointListWithPrior:
            for points in pointListWithPrior[g]:
                pointshowList.append(points[0])
        segDic = getAllSeg(pointListWithPrior)
        return grades,segDic
    def keyPointFind_treeGenerate(self, bloodTree, layer, segDic,findDad):
        for i in segDic:
            for j in segDic[i]:
                for k, point in enumerate(bloodTree['line']):
                    try:
                        if tuple(segDic[i][j][-1][1]) == tuple(point[0]) and findDad[i][j] == 0:
                            newtree = {
                                "line":list(segDic[i][j]),
                                "subtree":[],
                                "deep":[],
                                "subLength":[],
                                "dividePointIndex":[],
                                "layer":layer,}
                            findDad[i][j] = 1
                            bloodTree['subtree'].append(newtree)
                            bloodTree['dividePointIndex'].append(k)
                            self.keyPointFind_treeGenerate(newtree,layer+1, segDic,findDad)
                    except:
                        print(traceback.format_exc())
                        print(segDic[i][j][-1][1],segDic.keys(),segDic[i].keys(), i, j ,point,findDad[i][j],findDad.shape)
    def emptyDeep(self,bloodTree):
        bloodTree['subLength'] = []
        for i, subt in enumerate(bloodTree['subtree']):
            self.emptyDeep(subt)
    def assignDeep(self,bloodTree):
        maxdeep = 0
        maxlength = 0
        maxLengthInd = 0
        for i, subt in enumerate(bloodTree['subtree']):
            sbdeep, sblength = self.assignDeep(subt)
            index = bloodTree['dividePointIndex'][i]
            maxdeep = max(sbdeep,maxdeep)
            if sblength> maxlength:
                maxlength = sblength
                maxLengthInd = index
            bloodTree['deep'].append(sbdeep)
            bloodTree['subLength'].append(sblength)
        return maxdeep+1, maxlength+len(list(bloodTree['line'])[maxLengthInd:])
    def findLongestLine(self,bloodTree):
        if len(bloodTree['line'])>0 and len(bloodTree['subtree']) > 0:
            try:
                index = np.argmax(bloodTree['subLength'])
                newline = self.findLongestLine(bloodTree['subtree'][index])+ copy.deepcopy(bloodTree['line'][bloodTree['dividePointIndex'][index]:])
                return newline
            except:
                print('error happened',traceback.format_exc())
                pdb.set_trace()
                raise
        else:
            return copy.deepcopy(bloodTree['line'])
    def findSmallBranch(self,bloodTree):
        branches = []
        if len(bloodTree['subtree']) ==0:
            return branches
        # if len(bloodTree['subtree']) !=0 and len(bloodTree['subLength']) == 0:
            # pdb.set_trace()
        maxLengthInd = np.argmax(bloodTree['subLength'])    
        if bloodTree["dividePointIndex"][maxLengthInd] == 0:
            for i in range(len(bloodTree["dividePointIndex"])):
                if i != maxLengthInd:
                    branches.append({
                        'branch':copy.deepcopy(bloodTree['subtree'][i]),
                        'dividePointIndex': bloodTree['subLength'][maxLengthInd] + bloodTree["dividePointIndex"][i] - \
                        bloodTree["dividePointIndex"][maxLengthInd]
                        })

            return branches+self.findSmallBranch(bloodTree['subtree'][maxLengthInd])
        if bloodTree["dividePointIndex"][maxLengthInd] >= 1:
            newtree = {
                "line":copy.deepcopy(bloodTree['line'][:bloodTree["dividePointIndex"][maxLengthInd]]),
                "subtree":[],
                "deep":[],
                "subLength":[],
                "dividePointIndex":[],
                "layer":999,}
            for i in range(len(bloodTree["dividePointIndex"])):
                if bloodTree["dividePointIndex"][i] < bloodTree["dividePointIndex"][maxLengthInd]:
                    newtree['subtree'].append(copy.deepcopy(bloodTree['subtree'][i]))
                    newtree['dividePointIndex'].append(bloodTree["dividePointIndex"][i])
                    newtree['subLength'].append(bloodTree['subLength'][i])
                if bloodTree["dividePointIndex"][i]>=bloodTree["dividePointIndex"][maxLengthInd] and i != maxLengthInd:
                    branches.append(
                        {'branch':copy.deepcopy(bloodTree['subtree'][i]),
                        'dividePointIndex': bloodTree['subLength'][maxLengthInd] + bloodTree["dividePointIndex"][i] - \
                         bloodTree["dividePointIndex"][maxLengthInd]
                        })
            branches.append(
                {'branch':newtree,
                 'dividePointIndex': bloodTree['subLength'][maxLengthInd]})
            return branches + self.findSmallBranch(bloodTree['subtree'][maxLengthInd])
    def createNewTreeFromOld(self,bloodTree):
        newline = self.findLongestLine(bloodTree)
        newtree = {
            "line":newline,
            "subtree":[],
            "deep":[],
            "subLength":[],
            "dividePointIndex":[],
            "layer":0,}
        # seePoints(newline)
        branches = self.findSmallBranch(bloodTree)
        for br in branches:
            newtree['subtree'].append(self.createNewTreeFromOld(br['branch']))
            newtree["dividePointIndex"].append(br['dividePointIndex'])
            if br['branch']['subLength'] != []:
                newtree["subLength"].append(max(br['branch']['subLength']))
        
        return newtree

In [94]:
import math
class utilsConnection(utilsBase):
    def __init__(self):
        super(utilsBase,self).__init__()
        self.vesselWithDir = None


        
    def getPointsInField(self, p1, pts, v, r, h):
        # 将所有向量转换为numpy数组
        p1 = np.array(p1)
        pts = np.array(pts)
        V = np.array(v)
        # 计算向量P1P2
        P1P2 = pts - p1
        # 计算单位向量
        V_unit = V / np.linalg.norm(V)
        # 计算P1P2在V方向上的投影长度
        projection_length = np.dot(P1P2, V_unit)
        # 判断p2是否在锥体的高度范围内
    #     if projection_length < 0 or projection_length > h:
    #         return False
        # 计算P1P2与V的夹角余弦值
        P1P2_unit = P1P2 / np.expand_dims(np.linalg.norm(P1P2, axis = 1),axis=1)
        cos_angle = np.dot(P1P2_unit, V_unit)
        # 锥体角度的余弦值
        cos_r = np.cos(r)
        # 判断p2是否在锥形的角度范围内
        return pts[(cos_angle>np.cos(r))&(projection_length<=h)&(projection_length>0)]
    
    def getCosVal(self, v1,v2):
        """计算两个向量之间的夹角"""
        cos_theta = np.dot(v1, v2) / (np.linalg.norm(v1) * np.linalg.norm(v2))
        cos_theta = np.clip(cos_theta, -1.0, 1.0)  # 确保余弦值在合法范围内
        return cos_theta
    
    def dfsGetPoints(self,targetRoot,targetPoints):
        for p in targetRoot['line']:
            targetPoints.append(p[0])
        for subt in targetRoot['subtree']:
            self.dfsGetPoints(subt,targetPoints)
        
    def dfsGetPointsWithWeight(self,targetRoot,targetPoints):
        for p in targetRoot['line'][::-1]:
            targetPoints.append(p[0])
        for subt in targetRoot['subtree']:
            self.dfsGetPoints(subt,targetPoints)       
    def fitCurveAndComputeTangents(self,points):
        points = np.array(points)
        if len(points) < 2:
            raise ValueError("点集至少需要包含两个点")
        # 如果点集只有两个点，直接计算线性切线方向
        if len(points) == 2:
            direction = points[1] - points[0]
            direction = direction / np.linalg.norm(direction)  # 归一化
            return [direction] * 2  # 两个点的切线方向相同, 返回两个切线
        # 对超过两个点的情况，使用B样条拟合
        if len(points ) == 3:
            npoints = [points[0],(points[0]+points[1])/2, points[1],(points[1]+points[2]) /2, points[2] ]
            points = np.array(npoints)
        try:
            tck, u = splprep(points.T, s=10)
        except:
            pdb.set_trace()
        # 求解切线方向 (求导)
        tangents = np.array(splev(u, tck, der=1))  # der=1 表示求一阶导数
        tangents = tangents.T  # 转置回 (n, 3) 形式
        # 归一化切线向量
        tangents = tangents / np.linalg.norm(tangents, axis=1)[:, np.newaxis]
        return tangents
    
    def findBestPoint(self, root, trp, points, directions):
        newLineDirection = (trp - points ) / np.expand_dims(np.linalg.norm(trp - points , axis=1),axis=1)
        rp = []
        rs = []

        targetbranch = []
        self.findBranch(root,trp,targetbranch)
        rootPoints = [p[0] for p in root['line']]
        tangents = self.fitCurveAndComputeTangents(rootPoints)
        if len(targetbranch)>1:
            tan = tangents[targetbranch[1]]        
        
        for direc in directions:
            cosLineMain = np.dot(newLineDirection, direc) 
            if len(targetbranch)>1:
                cosLineMain+= np.dot(newLineDirection, tan)
            
            newLineLength = np.sqrt(np.sum((trp - points  )**2, axis = 1))
            w = [0.1,0.9]
            scores = (100/newLineLength)* w[0] + cosLineMain*w[1]
            p,s = trp[np.argmax(scores)], scores[np.argmax(scores)]
            rp.append(p)
            rs.append(s)

        rp = np.stack(rp, axis = 0)
        rs = np.stack(rs, axis = 0)

        return rp[np.argmax(rs)],rs[np.argmax(rs)] , rp, rs
    
    def shortVersionConnection(self, root, targetRoot, angle = 75, height = 40):
        rpts = []
        self.dfsGetPointsWithWeight(root,rpts)
        trpts = []
        self.dfsGetPointsWithWeight(targetRoot,trpts)

        pcatr = PCA(n_components=3)
        pcatr.fit(np.array(trpts))
        mainDirectionGR = pcatr.components_[0]

        rtp = np.array(root['line'][-1][0])
        trtp = np.array(targetRoot['line'][-1][0])

        if len(targetRoot['line']) > 6:
            interval = 3
        else:
            interval = 1
        lpoints = targetRoot['line'][::interval]
        lpoints = [p[0] for p in lpoints]

        tgs = self.fitCurveAndComputeTangents(lpoints)

        if np.dot(mainDirectionGR, trtp - rtp) > 0:
            mainDirectionGR = -mainDirectionGR

        
        points = self.getPointsInField(trtp, np.array(rpts), np.array(mainDirectionGR), r=(angle*np.pi/180), h=height)
        pointstg1 = self.getPointsInField(trtp, np.array(rpts), np.array(tgs[-1]), r=(angle*np.pi/180), h=height)

        resPoints = [points,pointstg1]
        for ii in range(-3,-7,-2):
            try:
                resPoints.append(
                    self.getPointsInField(trtp, np.array(rpts), np.array(tgs[ii]), r=(angle*np.pi/180), h=height)
                                )
            except:
                pass

        dirs = [mainDirectionGR]
        for ii in range(-1,-4,-1):
            try:
                dirs.append(tgs[ii])
            except:
                pass

        resPoints = np.concatenate(resPoints, axis = 0)

        return resPoints, dirs
    
    def getGenLine(self, stp,dsp, double = False):
        def getDis(p1,p2):
            return np.sqrt(np.sum((p1-p2)**2))


        dire = dsp - stp

        direUnit = dire / np.linalg.norm(dire)

        nowstp = (stp+direUnit).astype(np.int32)
        if double:
            addpoints = [[tuple(nowstp),tuple(stp)]]
        else:
            addpoints = [nowstp]

        while getDis(dsp, nowstp) > 3:
            nowdir = dsp - nowstp
            nowdireUnit = nowdir / np.linalg.norm(nowdir) * 2
            laststp = nowstp
            nowstp = (nowstp+nowdireUnit).astype(np.int32)
            if double:
                addpoints.append([tuple(nowstp),tuple(laststp)])
            else:
                addpoints.append(nowstp)
        return addpoints


    def findBranch(self, root,rootPoint,tb):
            flag = 0
            recordi = -1
            for i, p in enumerate(root['line']):
                if np.sum((p[0] - rootPoint) ** 2) < 1e-6:
                    flag = 1
                    recordi = i

                    break
            if flag == 1:
                tb.append([root,i])
            else:
                for subt in root['subtree']:
                    self.findBranch(subt,rootPoint, tb)
    def addLine2Root(self, root,subroot,tarPoint, rootPoint, addlines,  showAllroot = None):
        # def findBranch(root,rootPoint,tb):
        #     flag = 0
        #     recordi = -1
        #     for i, p in enumerate(root['line']):
        #         if np.sum((p[0] - rootPoint) ** 2) < 1e-6:
        #             flag = 1
        #             recordi = i

        #             break
        #     if flag == 1:
        #         tb.append([root,i])
        #     else:
        #         for subt in root['subtree']:
        #             findBranch(subt,rootPoint, tb)

        tb = []
        self.findBranch(root,rootPoint, tb)

        pointList = [tarPoint,rootPoint]
        if tb[0][1] == 0:
            tb[0][0]['line'] = subroot['line']+ addlines +tb[0][0]['line']
            for i, dpi in enumerate(tb[0][0]['dividePointIndex']):
                tb[0][0]['dividePointIndex'][i] = len(subroot['line'])+ len(addlines)+tb[0][0]['dividePointIndex'][i]
            for i, subt in enumerate(subroot['subtree']):
                tb[0][0]['subtree'].append(subt)
                tb[0][0]['dividePointIndex'].append(subroot['dividePointIndex'][i])
            if showAllroot is not None:
                lineList = []
                self.lookTree(showAllroot,lineList)
                self.lookTree(subroot,lineList)
                self.seelinesAndPoints(lineList,[])
        else:
            tb[0][0]['dividePointIndex'].append(tb[0][1])
            tb[0][0]['subtree'].append(subroot)
            subroot['line'][0] = [subroot['line'][0][0],rootPoint]
            subroot['line'] = subroot['line'] + addlines
            
    def checkWongStart(self, connectPoint, root, subRoot, addline):
        def calculate_angle(v1, v2):
            """计算两个向量之间的夹角"""
            cos_theta = np.dot(v1, v2) / (np.linalg.norm(v1) * np.linalg.norm(v2))
            cos_theta = np.clip(cos_theta, -1.0, 1.0)  # 确保余弦值在合法范围内
            return np.degrees(np.arccos(cos_theta))
        def find_turning_points(points, smoothing_factor=10, angle_threshold=30):
            # 对三维点集进行平滑
            tck, u = splprep(points.T, s=smoothing_factor)
            smoothed_points = np.array(splev(u, tck)).T
            # 计算切线方向向量
            derivatives = np.gradient(smoothed_points, axis=0)
            directions = derivatives / np.linalg.norm(derivatives, axis=1, keepdims=True)
            # 检测拐弯点
            turning_points = []
            angles = []
            for i in range(1, len(directions) - 1):
                angle = calculate_angle(directions[i - 1], directions[i])
                if angle > angle_threshold:
                    turning_points.append(smoothed_points[i])
                    angles.append(angle)

            return np.array(turning_points), angles
    
    def findAllNode(self, root, edgeList, lastNode = None):
        if lastNode is None:
            edgeList.append(root['line'][-1][0])
        edgeList.append(root['line'][0][0])

        for i,subt in enumerate(root['subtree']):
            self.findAllNode(subt, edgeList, lastNode = root['dividePointIndex'][i])
                
#     def judgeWrongConnect()

    
    
    
    def judgeGoodStartPoint(self, subRoot,root, addline, box, angle = 150, tanInterval = 10):
        '''
        subRoot:待合并分支
        root:主要枝干
        addline
        
        '''
        
        edgeList = []
        self.findAllNode(subRoot, edgeList, None)
        minCos = 1
        lowcosEdge = []
        lowCosVal = []
    
        for i,subt in enumerate(subRoot['subtree']):
            jointLine = subt['line'] + subRoot['line'][subRoot['dividePointIndex'][i]:]
            points = [point[0] for point in jointLine]
            tangents = self.fitCurveAndComputeTangents(points)
            interval = min(len(tangents), tanInterval)
            
            for i, t in enumerate(tangents):
                if i > len(tangents) - interval:
                    break
                nowMinCos = 1
                for ii in range(1, interval):
                    cosine = self.getCosVal(tangents[i], tangents[i+ii])
                    if cosine<minCos:
                        minCos = cosine
                    if cosine< math.cos(angle*np.pi/180):
                        lowcosEdge.append(jointLine[i+ii//2])
                        lowCosVal.append(cosine)
        # if judge!= 100
        # pdb.set_trace()
        if minCos > math.cos(angle*np.pi/180):
            return True, edgeList,lowcosEdge,lowCosVal
        return False,edgeList,lowcosEdge, lowCosVal

    def getVesselWithDir(self,root,lengthThreshold = 3):
        self.pointsWithDir = {}
        self.vesselWithDir = []
        if len(root['line']) < lengthThreshold:
            return
        points = [p[0] for p in root['line']][::5]
        try:
            tangents = self.fitCurveAndComputeTangents(points)
            for p in points:
                self.pointsWithDir[tuple(p)] = tangents

            self.vesselWithDir.append[[points,tangents]]
        except:
            pass
        for subt in root['subtree']:
            self.getVesselWithDir(subt, lengthThreshold=lengthThreshold)
        
        
        
    def findClosedLines(newLine,oldLines):
        from scipy.optimize import linear_sum_assignment
        from scipy.interpolate import interp1d
        def interpolate_points(points, num_points):
            """
            将点集合插值到指定数量的点。
            Args:
                points (np.ndarray): 原始点集合，形状为 [n, 3]
                num_points (int): 插值后的点数量
            Returns:
                np.ndarray: 插值后的点集合，形状为 [num_points, 3]
            """
            t_original = np.linspace(0, 1, len(points))
            t_interpolated = np.linspace(0, 1, num_points)
            interpolated_points = np.array([
                interp1d(t_original, points[:, dim], kind='linear')(t_interpolated)
                for dim in range(points.shape[1])
            ]).T
            return interpolated_points
        def calculate_average_distance(A, B):
            """
            计算两条线段的平均最小匹配距离，适用于点数不一致的情况。
            Args:
                A (np.ndarray): 线段 A 的点集合，形状为 [n, 3]
                B (np.ndarray): 线段 B 的点集合，形状为 [m, 3]
            Returns:
                float: 两条线段之间的平均距离
            """
            # 插值使两个集合的点数相等
            if len(A) > len(B):
                B = interpolate_points(B, len(A))
            elif len(A) < len(B):
                A = interpolate_points(A, len(B))
            
            # 计算距离矩阵
            distance_matrix = np.linalg.norm(A[:, np.newaxis, :] - B[np.newaxis, :, :], axis=2)
            
            # 使用匈牙利算法找到最优匹配
            row_ind, col_ind = linear_sum_assignment(distance_matrix)
            
            # 计算匹配对之间的总距离
            total_distance = distance_matrix[row_ind, col_ind].sum()
            
            # 计算平均距离
            average_distance = total_distance / len(A)
            
            return average_distance
        def find_nearest_segments(new_segment, segments, k=3):
            """
            找到距离新线段最近的 k 条线段。
            Args:
                new_segment (np.ndarray): 新的线段，形状为 [n, 3]
                segments (list of np.ndarray): 已知的 n 条线段，每条线段是形状为 [m, 3] 的点集合
                k (int): 返回最近的线段数量
            Returns:
                list of tuple: 距离最近的 k 条线段及其距离，格式为 (index, distance)
            """
            distances = []
            for i, segment in enumerate(segments):
                avg_distance = calculate_average_distance(new_segment, segment)
                distances.append((i, avg_distance))
            
            # 根据距离排序并取前 k 个
            distances = sorted(distances, key=lambda x: x[1])
            nearest_segments = distances[:k]
            return nearest_segments
        return find_nearest_segments(newLine, rootLines, k=3)
    
    
    def getOriginFromAroundVessel(self, brt, root, vesselLengthThreshold = 30, distanceThreshold = 20):
        '''
        root
        '''
        if self.vesselWithDir is None:
            self.getVesselWithDir(root, distanceThreshold)
        
        lines = [l[0] for l in self.vesselWithDir]

        readyLine = [p[0] for p in brt['line']]
        closedLinesInd = self.findClosedLines(readyLine, lines)
        alldis = [d for cli,d in closedLinesInd]
        alldis = np.sum(alldis)

        allPoint = []
        for cli, dis in closedLinesInd:
            allPoint +=  self.vesselWithDir[cli][0]
            allDire = self.vesselWithDir[cli][0]
        
        alldirec = np.asarry(allPoint) - brt['line'][-1][0]
        pdb.set_trace()
            



        

In [95]:
class utilsVTK(utilsBase):
    def writeStlAndMergeOther(self, mesh, need2MergeStl = True):
        if need2MergeStl:
            reader = vtk.vtkSTLReader()
            reader.SetFileName("G:\\tube11_previous_other.stl")
            reader.Update()
            previous = reader.GetOutput()
            if reader.GetOutput().GetNumberOfPoints() == 0:
                print("No points were loaded. Check file path and format.")
            else:
                # if check_mesh_intersection(reader.GetOutput(),mesh):
                try:
                    # mesh = merge_and_save_mesh(previous, mesh)
                    mesh = perform_boolean_union(previous, mesh)

                    # mesh = decimate_mesh(mesh,0.5)
                except:
                    mesh = previous
                # else:
                    # mesh = merge_and_save_mesh(reader.GetOutput(),mesh)
        writer = vtk.vtkSTLWriter()
        writer.SetFileName('G:\\tube11_previous_other.stl')
        writer.SetInputData(mesh)
        writer.Write()

        return mesh

    def cosine_annealing_lr(self, t, eta_min=0.001, eta_max=0.1, T_max=100, k=1):
        """Calculate the learning rate at epoch t with cosine annealing."""
        return eta_min + 0.5 * (eta_max - eta_min) * (1 + np.cos(np.pi * (T_max-t) / T_max))
    def linear(self, t,T_max,eta_max,eta_min, k =0.6):
        return eta_min + (eta_max - eta_min) * (t / (T_max - 1))*k

    def quadratic_annealing_lr(self, t, eta_max=1,eta_min = 0.5, T_max=100):
        """Calculate the learning rate at epoch t with quadratic annealing."""
        return (eta_max - eta_min) * (1 - (t / T_max) ** 2) + eta_min
    def inverse_time_decay_lr(self, t, eta_min=0.1, eta_max = 1, k=0.03, T_max=100):
        """Calculate the learning rate at epoch t using inverse time decay."""
        return (eta_max- eta_min) / (1 + k * t) + eta_min

    def exponential_decay_lr(self, eta_max=1, eta_min =0.1, k=0.2, t=1,T_max = 100):
        """Calculate the learning rate at epoch t with natural exponential decay."""
        return (eta_max -eta_min)  * np.exp(-k * t) + eta_min
    
    def create_tube(self, points, need2MergeStl= False, maxRadius=0.1, minRadius = 0.1, color=(1, 0, 0), k= 0.6):
        # 创建点数据

        points_vtk = vtk.vtkPoints()
        for p in points:
            points_vtk.InsertNextPoint(p)

        # 创建线段
        lines = vtk.vtkCellArray()
        line = vtk.vtkPolyLine()
        line.GetPointIds().SetNumberOfIds(len(points))
        for i, _ in enumerate(points):
            line.GetPointIds().SetId(i, i)
        lines.InsertNextCell(line)

        # 创建 PolyData
        poly_data = vtk.vtkPolyData()
        poly_data.SetPoints(points_vtk)
        poly_data.SetLines(lines)

        # 使用管道过滤器
        tube_filter = vtk.vtkTubeFilter()

        # 设置半径随长度变化
        tube_filter.SetVaryRadiusToVaryRadiusByAbsoluteScalar()

        # 创建半径变化数组
        radii = vtk.vtkFloatArray()
        radii.SetNumberOfValues(poly_data.GetNumberOfPoints())
        for i in range(poly_data.GetNumberOfPoints()):
            # 线性半径变化，你可以根据需要调整计算方式
            # radii.SetValue(i, minRadius + (maxRadius - minRadius) * (i / (poly_data.GetNumberOfPoints() - 1)))
            radii.SetValue(i, self.linear(t=i, eta_min=minRadius, eta_max=maxRadius, T_max=poly_data.GetNumberOfPoints(),k=k))

        poly_data.GetPointData().SetScalars(radii)

        tube_filter.SetInputData(poly_data)
        # tube_filter.SetRadius(radius)
        tube_filter.SetNumberOfSides(30)
        tube_filter.CappingOn()
        tube_filter.Update()
        return tube_filter.GetOutput()
    def getLonger(self,pointsLine, addNum = 3):
        tangents = self.fitCurveAndComputeTangents(pointsLine)

        lastPoints = pointsLine[0]
        for i in range(addNum):
            addpoint = lastPoints - np.mean([tangents[-1],tangents[-2]]) *2
            pointsLine.append(addpoint)
            lastpoint = addpoint
        return pointsLine
        

    
    def renderingTree(self, bloodTree,layer,mainMesh = None, lastLayerMaxRadius = 5,lastLayerMinRadius=3, \
                      lastLayersLine = None, previous = None, k = 0.6 , smoothed_bigJing = None):
        '''
        渲染血管树
        '''
        if len(bloodTree['line']) == 0 :
            return
        pl = []
        keyIndex = copy.deepcopy(bloodTree['dividePointIndex'])

        '''
        延长覆盖
        '''
        pointsLine = [p[0] for p in bloodTree['line']]
        # try:
        #     pl = self.getLonger(pointsLine)
        # except:
        pl = pointsLine
            
        
        
        # for point in list(bloodTree['line']):
        #     pl = pl +[point[0]]
        
        
        keyIndex.append(0)
        
        if np.sum((np.array(bloodTree['line'][-1][1]) - np.array([0,0,0]))**2) > 1e-8:
            pl = pl+[bloodTree['line'][-1][1]]
            pl = pl+[bloodTree['line'][-1][1]]

            pl = pl+[bloodTree['line'][-1][1]]

            keyIndex.append(len(pl)-1)
            keyIndex.append(len(pl)-2)
            keyIndex.append(len(pl)-3)
        if lastLayersLine is not None:
            lastLayersLine = np.array(lastLayersLine)
            lastpoint = np.array(pl[-1])

            closestPoint = np.argmin(np.sum((lastLayersLine - lastpoint)**2,axis= 1))
            pl =  pl + lastLayersLine.tolist()[closestPoint:closestPoint+6]
            keyIndex = keyIndex+ [len(pl) - i for i in range(6)]


        keyIndex = list(set(keyIndex))
        keyIndex = sorted(keyIndex)
        pl = np.array(pl)
        if len(pl) == 1:
#             if nowBronchi == 1:
#                 mesh = self.decimate_mesh(self.convert2Triangles(self.create_hemisphere(lastLayerMaxRadius*k,pl[-1],30)),0.9)
#             else:
            mesh = self.decimate_mesh(self.convert2Triangles(smoothed_bigJing),0.9)

            if previous == None:
                previous = smoothed_bigJing
            else:
                previous = self.writeStlAndMerge(mesh, previous,need2MergeStl = mainMesh != None)
            for si,subt in enumerate(bloodTree['subtree']):
                previous = self.renderingTree(subt,layer+1, mainMesh = 1,lastLayerMaxRadius = lastLayerMaxRadius,\
                                         lastLayerMinRadius=lastLayerMinRadius, lastLayersLine = pl, previous = previous, k = k)
        else:
            if len(bloodTree['subtree']) == 0:
                pl = np.asarray(self.mean_insert(pl,2))
                pl = self.gaussian_filter_smooth(pl)
            else:
                pl = np.asarray(self.mean_insert(pl,2))

                pl = self.gaussian_filter_smooth(pl)

            if lastLayersLine is not None:
                lastLayersLine = np.array(lastLayersLine)
                lastpoint = np.array(pl[-1])
                closestPoint = np.argmin(np.sum((lastLayersLine - lastpoint)**2,axis= 1))
                maxRadius= ((closestPoint) / len(lastLayersLine)) * (lastLayerMaxRadius - lastLayerMinRadius) + lastLayerMinRadius

                maxRadius= self.linear(t=  closestPoint, eta_min=lastLayerMinRadius, eta_max=lastLayerMaxRadius, T_max=len(lastLayersLine),k = k)
                minRadius = lastLayerMinRadius

            else:
                maxRadius= lastLayerMaxRadius 
                minRadius =lastLayerMinRadius
            mesh = self.convert2Triangles(self.create_tube(pl,need2MergeStl = mainMesh != None, maxRadius=maxRadius, minRadius = minRadius,k = k))
            if previous == None:
#                 if nowBronchi == 1:
#                     previous = self.convert2Triangles(self.writeStlAndMerge(mesh,self.create_hemisphere(lastLayerMaxRadius*k,pl[-1],30), need2MergeStl = mainMesh != None))
#                 else:
                previous = self.convert2Triangles(self.writeStlAndMerge(mesh,smoothed_bigJing, need2MergeStl = mainMesh != None))
            else:
                previous = self.writeStlAndMerge(mesh,previous, need2MergeStl = mainMesh != None)
            for subt in bloodTree['subtree']:
                previous = self.renderingTree(subt,layer+1, mainMesh = 1,lastLayerMaxRadius = maxRadius,lastLayerMinRadius=minRadius,\
                                         lastLayersLine = pl, previous = previous, k = k)
        return previous
    
    
    def readDecSmoWrtTrans(self,mesh, tmpPath, targetPath, spacing, origin, direction, zoom_factors):
        writer = vtk.vtkSTLWriter()
        writer.SetFileName(tmpPath)
        writer.SetInputData(mesh)
        writer.Write()
        mesh1 = pv.read(tmpPath)
        mesh1 = self.decimate_mesh(self.smoother(pv.PolyData(mesh1),15,.5 ),.99)

        savetmpDong = targetPath
        writer = vtk.vtkSTLWriter()
        writer.SetFileName(savetmpDong)
        writer.SetInputData(mesh1)
        writer.Write()
        
        voxel_coords = self.loadStl(targetPath)
        # Convert to physical coordinates
        physical_coords = self.voxel2PhysicalCoordinates(zoom_factors,voxel_coords, spacing, origin, direction)
        # Assuming the faces remain the same, you can reshape and save the new mesh
        faces = physical_coords.reshape(-1, 3, 3)
        self.saveStl(targetPath, physical_coords, faces)
    
        
        
        
    def loadStl(self,stlFile):
        """
        Loads an STL file and returns the vertices as a numpy array.

        :param stl_file: Path to the STL file.
        :return: numpy array of shape (N, 3) containing the vertices.
        """
        stl_mesh = mesh.Mesh.from_file(stlFile)
        return stl_mesh.vectors.reshape(-1, 3)

    def saveStl(self,stlFile, vertices, faces):
        """
        Saves vertices and faces to an STL file.

        :param stl_file: Path to the STL file.
        :param vertices: numpy array of shape (N, 3) containing the vertices.
        :param faces: numpy array of shape (M, 3, 3) containing the faces.
        """
        new_mesh = mesh.Mesh(np.zeros(faces.shape[0], dtype=mesh.Mesh.dtype))
        new_mesh.vectors = faces
        new_mesh.save(stlFile)
    def voxel2PhysicalCoordinates(self, zoom_factors, voxelCoords, spacing, origin, direction):
        """
        Converts voxel coordinates to physical coordinates for a given SimpleITK image.

        :param voxel_coords: numpy array of shape (N, 3) containing voxel coordinates.
        :param image: SimpleITK image.
        :return: numpy array of shape (N, 3) containing physical coordinates.
        """
        # Get image spacing, origin, and direction
        spacing = np.array([spacing[0]/zoom_factors[0],spacing[1]/zoom_factors[1],spacing[2]])

        # Convert voxel coordinates to physical coordinates
        physical_coords = np.dot(direction, voxelCoords.T * spacing[:, None]) + origin[:, None]
        return physical_coords.T


In [96]:
from typing import Union

class treeBase(ABC):
    def __init__(self , treeType:Union['lineMode','pointMode'], tempFolder = '/tmp'):
        '''
        treeType: multi type for make a tree
        connectType:
            1 for main Dongmai
            2 for main Jingmai
            3 for main bronchi
            4 for both dongmai and jingmai
            5 for all
        '''
        self.treeType = treeType
        self.tempFolder = tempFolder
        self.utb = utilsBase()
        self.utt = utilsTree()
        self.utvs = utilsVolumSearching()
        self.utconn = utilsConnection()
        self.uvtk=utilsVTK()
        
    def read_file(self,filename):
        if filename.endswith('nrrd'):
            img_data_nrrd, header_nrrd = nrrd.read(filename)
        elif filename.endswith('nii.gz'):
            img_data_nrrd = nib.load(filename)
            # 获取图像数据为numpy数组
            img_data_nrrd = img_data_nrrd.get_fdata()
        elif filename.endswith('npy'):
            img_data_nrrd = np.array(np.load(filename))
        return img_data_nrrd
    
    def SeePoint(self, lines, colors, text=None):

        # def make_point_cloud(npts, center, 1):
        #     pts = np.random.uniform(-radius, radius, size=[npts, 3]) + center
        #     cloud = o3d.geometry.PointCloud()
        #     cloud.points = o3d.utility.Vector3dVector(pts)
        #     cloud.colors = o3d.utility.Vector3dVector(colors)
        #     return cloud
        
        showlist = []
        app = gui.Application.instance
        app.initialize()
        vis = o3d.visualization.O3DVisualizer("Open3D - 3D Text", 1024, 768)
        vis.show_settings = True
        
        # 添加点云到Visualizer
        for li,line in enumerate(lines):
            newpoints = []
            for ii,pt in enumerate(line):
                if len(pt) ==2 and (type(pt[0]) != int and type(pt[0]) != float and type(pt[0]) != np.int64):
                    pt = pt[0]
                newpoints.append(pt)
            pcd = o3d.geometry.PointCloud()
            pcd.points = o3d.utility.Vector3dVector(np.asarray(newpoints))
            pcd.paint_uniform_color(colors[li])
            if text is not None and li ==len(line)-1:
                for idx in range(0, len(pcd.points)):
                    vis.add_3d_label(points.points[idx], text[idx])
            showlist.append(pcd)
            time.sleep(.5)
            
            vis.add_geometry(f'line{li}',pcd)
        vis.reset_camera_to_default()
        # vis.run()
        # vis.destroy_window()
        app.add_window(vis)
        app.run()

    def getPoints2Show(self,root, resList):
        resList += [[int(p[0][0]),int(p[0][1]),int(p[0][2])] for p in root['line']]
        for subt in root['subtree']:
            self.getPoints2Show(subt,resList)

        
    def processArtifact(self, data, hilumBox, classname = 2):
        vessel = np.where(data==classname,1,0)
        mr = 5
        offset= 5
        middleRegion = np.where(data==classname,1,0)
        middleRegion[:hilumBox[0][0]-mr,:,:] = 0
        middleRegion[hilumBox[0][1]+mr:,:,:] = 0
        middleRegion[:,:hilumBox[1][0]-mr,:] = 0
        middleRegion[:,hilumBox[1][1]+mr:,:] = 0
        middleRegion[:,:,:hilumBox[2][0]-mr] = 0
        middleRegion[:,:,hilumBox[2][1]+mr:] = 0
        
#         middleRegionc = np.where(middleRegion == 1)
#         middleRegionc = np.stack(middleRegionc,axis=1)

#         pcd = o3d.geometry.PointCloud()

#         # 假设 data 是一个 Nx3 的 numpy 数组
#         # data = np.random.rand(100, 3)  # 随机生成100个点
#         pcd.points = o3d.utility.Vector3dVector(middleRegionc)
#         # 可视化点云
#         o3d.visualization.draw_geometries([pcd])
        
        
        structure = np.ones((3, 3, 3), dtype=np.uint8)

        middleRegion = ndimage.binary_erosion(middleRegion, structure=structure, iterations=5)
        middleRegion,Box = self.utb.retain_largest_connected_component(middleRegion)
        middleRegion = ndimage.binary_dilation(middleRegion, structure=structure, iterations=5)
        
        
        x_slice = slice(Box[0]+10, Box[1]-10)
        y_slice = slice(Box[2]+10, Box[3]-10)
        z_slice = slice(Box[5]-10, Box[4]+10)
        
        vessel[x_slice, 
             y_slice, 
             z_slice][vessel[x_slice, 
                           y_slice, 
                           z_slice] == 1] = middleRegion[x_slice, 
                                                 y_slice, 
                                                 z_slice][vessel[x_slice, 
                                                               y_slice, 
                                                               z_slice] == 1]
        vessel[vessel == 1] = classname
        data[data == classname] = vessel[data == classname]
        
        
        
#         vesselc = np.where(vessel == classname)
#         vesselc = np.stack(vesselc,axis=1)

#         pcd = o3d.geometry.PointCloud()

#         # 假设 data 是一个 Nx3 的 numpy 数组
#         # data = np.random.rand(100, 3)  # 随机生成100个点
#         pcd.points = o3d.utility.Vector3dVector(vesselc)
#         # 可视化点云
#         o3d.visualization.draw_geometries([pcd])
        
        return data
        
    
    def readData(self, dcmFolder, segResultPath, hilumBoxFile, dealJing = True):
        volumeImage,spacing,origin,direction = self.utb.load_dicom_series_as_3d_array(dcmFolder)
        zoom_factors = (512/volumeImage.shape[0], 512/volumeImage.shape[1], 1)
        resData = self.read_file(segResultPath)
        if hilumBoxFile is not None:
            hilumBox = self.read_file(hilumBoxFile)
            hilumBox = hilumBox[:3,:][::-1,:].astype(np.int32)
        else:
            hilumBox = np.array([[0,512],[0,512],[0,512]])
        resData  = scipy.ndimage.zoom(resData, zoom_factors, order=0)
        
        return volumeImage, zoom_factors,resData,hilumBox,spacing,origin,direction
    def getHilums(self, resData,hilumBox ):
        bigDong = np.where(resData ==1 ,1,0)
        # bigJing = np.where(resData ==2 ,1,0)
        middleRegion = 20
        offset = 15
        print('hilumBox',hilumBox)
        bigDongStl = self.utb.getSmoothedBiggestStl(bigDong,hilumBox,middleRegion, offset )
        # bigJingStl = self.utb.getSmoothedBiggestStl(bigJing,hilumBox,middleRegion, offset)
        return bigDongStl
    
    def getTreeFromRegion(self, data, centerMiddle, skeleton):
        try:
            skcoords = skeleton.coords
        except:
            skcoords = skeleton
        mostrecent = np.argmin(np.sum((centerMiddle - skcoords)**2, axis = 1))
        mostrecentcoord = skcoords[mostrecent]
        bloodTree = {
        "line":[(mostrecentcoord, (0,0,0))],
        "subtree":[],
        "deep":[],
        "subLength":[],
        "dividePointIndex":[],
        "layer":0,
        }
        skeletonData = np.zeros(data.shape)
        skeletonData[skcoords[:,0],skcoords[:,1],skcoords[:,2]] = 1
    #     pcd = o3d.geometry.PointCloud()
    #     # 可视化点云
    #     pcd.points = o3d.utility.Vector3dVector(skcoords)
    #     # 可视化点云
    #     o3d.visualization.draw_geometries([pcd])

        grades,segDic = self.utt.label_vessel_grades(skeletonData, mostrecentcoord)
        maxlength = max([len(segDic[k]) for k in segDic.keys()])
        findDad = np.array([[0 for j in range(maxlength+3)] for i in range(len(segDic.keys())+3) ])
        self.utt.keyPointFind_treeGenerate(bloodTree, 0, segDic,findDad)
        maxdeep = self.utt.assignDeep(bloodTree)
        newtree = self.utt.createNewTreeFromOld(bloodTree)

        # check
        
        return newtree

        
    
    def retain_connected_component_list(self, data, cls, pixelThreshold = 50):
        '''
        得到两种不同区域池子
        分别用不同的策略进行连接
        '''
        if cls == -1:
            nowData = np.where(data > 0,1,0)    
        else:
            nowData = np.where(data == cls,1,0)

        skeleton = morphology.skeletonize(nowData)

        labeled_array = measure.label(skeleton,background = 0)
        regions = measure.regionprops(labeled_array)
        regions = sorted(regions, key=lambda x:x.area, reverse=True)
        BiggerRegions = [rg for rg in regions if rg.area >=pixelThreshold]
        smallerRegions = [rg for rg in regions if rg.area <pixelThreshold]
        return BiggerRegions,smallerRegions
        
    def getTreeFromHumanInteraction(self, jdic, data,regionBox,centerDong,centerJing):
        mainTreeDong = self.getTreeFromRegion(data,centerDong,np.array(jdic['dong']))
        mainTreeJing = self.getTreeFromRegion(data,centerJing,np.array(jdic['jing']))
        
        for treePoints, soureP,tarP, mainTreeClass in jdic['humanPair']:
            nowMainTree = mainTreeDong if mainTreeClass == 1 else mainTreeJing
            
            brtree = self.getTreeFromRegion(data,centerJing,np.array(jdic['jing']))
            
            addLines = self.utconn.getGenLine(soureP, tarP, double = True)
            

            self.utconn.addLine2Root(nowMainTree, 
                                 brtree, 
                                 tarPoint = brtree['line'][-1][0],
                                 rootPoint = tarP, 
                                 addlines = addLines,
                                 showAllroot = None )
        
        self.utt.emptyDeep(mainTreeDong)
        maxdeep = self.utt.assignDeep(mainTreeDong)
        self.utt.emptyDeep(mainTreeJing)
        maxdeep = self.utt.assignDeep(mainTreeJing)
        
        mainTreeDong = self.utt.createNewTreeFromOld(mainTreeDong)
        mainTreeJing = self.utt.createNewTreeFromOld(mainTreeJing)
        return mainTreeDong,mainTreeJing

        
    def getAllLineAndTree(self, data, regionBox , centerDong, threshold=0.01, tryNums = 10):
        
        BiggerRegionsDong, smallerRegionsDong = self.retain_connected_component_list(data,1 )
        # BiggerRegionsJing, smallerRegionsJing = self.retain_connected_component_list(data,2)
        mainTreeDong = self.getTreeFromRegion(data, centerDong, BiggerRegionsDong[0])
        # mainTreeJing = self.getTreeFromRegion(data, centerJing, BiggerRegionsJing[0])

        # biggerRegionPool = BiggerRegionsDong[1:] + BiggerRegionsJing[1:]
        biggerRegionPool = BiggerRegionsDong[1:] 


        
                
        self.utt.emptyDeep(mainTreeDong)
        maxdeep = self.utt.assignDeep(mainTreeDong)
        # self.utt.emptyDeep(mainTreeJing)
        # maxdeep = self.utt.assignDeep(mainTreeJing)
        
        mainTreeDong = self.utt.createNewTreeFromOld(mainTreeDong)
        # mainTreeJing = self.utt.createNewTreeFromOld(mainTreeJing)

        return mainTreeDong
    def allProcess(self,config):
        dcmfolder = config['dcmPath']
        segResultPath = config['resPath']
        hilumBoxPath = config['hilumBoxPath']
        tarfolder = config['targetFolder']
        volumeImage, zoom_factors,resData,hilumBox, spacing, origin, direction = self.readData(dcmfolder,segResultPath,hilumBoxPath)
        humanPath = os.path.join(tarfolder,'humanConnection.json')
        
        mainTreeDong = self.getAllLineAndTree(resData, hilumBox, [256,256,0], threshold=0.01)
        
        bigDongStl = self.getHilums(resData,hilumBox)
        meshDong =  self.uvtk.renderingTree(mainTreeDong,
                                            0,
                                            lastLayerMaxRadius = 10,
                                            lastLayerMinRadius=2,
                                            lastLayersLine = None,
                                            previous = None, 
                                            k = 1,
                                            smoothed_bigJing = bigDongStl)
        self.uvtk.readDecSmoWrtTrans(meshDong, os.path.join(self.tempFolder, 'dongmai.stl'), os.path.join(tarfolder,'动脉.stl'), spacing, origin, direction,zoom_factors)


In [None]:
import networkx as nx
def generate_vessel_graph(self, bloodTree):
    """
    生成血管中心线线段联通图
    
    参数:
    bloodTree: 血管树结构
    
    返回:
    G: networkx图对象
    class_points: 三个类别定位点
    """
    # 1. 收集所有点
    all_points = []
    self.getPoints2Show(bloodTree, all_points)
    all_points = np.array(all_points)
    
    # 2. 构建KD树用于快速查找邻居
    from scipy.spatial import cKDTree
    tree = cKDTree(all_points)
    
    # 3. 找到所有点的邻居
    neighbors = tree.query_ball_point(all_points, r=2.0)  # r=2.0是邻居搜索半径
    
    # 4. 找到联通点(有3个或以上邻居的点)
    connection_points = []
    for i, neighbor_list in enumerate(neighbors):
        if len(neighbor_list) >= 3:  # 至少有3个邻居
            connection_points.append(all_points[i])
    
    connection_points = np.array(connection_points)
    
    # 5. 构建networkx图
    G = nx.Graph()
    
    # 添加节点
    for i, point in enumerate(connection_points):
        G.add_node(i, pos=point)
    
    # 构建KD树用于联通点之间的邻居查找
    conn_tree = cKDTree(connection_points)
    
    # 找到联通点之间的边
    edges = conn_tree.query_pairs(r=3.0)  # r=3.0是边连接半径
    
    # 添加边
    G.add_edges_from(edges)
    
    # 6. 找到相距最远的三个点作为类别定位点
    def find_farthest_points(graph, start_point):
        distances = nx.single_source_shortest_path_length(graph, start_point)
        return max(distances.items(), key=lambda x: x[1])[0]
    
    # 选择第一个点(随机选择)
    point1 = 0
    # 找到距离point1最远的点
    point2 = find_farthest_points(G, point1)
    # 找到距离point1和point2最远的点
    point3 = find_farthest_points(G, point2)
    
    class_points = [point1, point2, point3]
    
    # 7. 使用Prim算法生成三个子图
    subgraphs = []
    for start_point in class_points:
        # 使用Prim算法生成最小生成树
        mst = nx.minimum_spanning_tree(G, algorithm='prim')
        # 从start_point开始进行BFS
        bfs_tree = nx.bfs_tree(mst, start_point)
        subgraphs.append(bfs_tree)
    
    return G, class_points, subgraphs

def visualize_graph(self, G, class_points, subgraphs):
    """
    可视化生成的图
    
    参数:
    G: networkx图对象
    class_points: 三个类别定位点
    subgraphs: 三个子图
    """
    import matplotlib.pyplot as plt
    
    # 创建图形
    plt.figure(figsize=(12, 8))
    
    # 获取节点位置
    pos = nx.get_node_attributes(G, 'pos')
    
    # 绘制原始图
    nx.draw_networkx_nodes(G, pos, node_size=50, node_color='gray', alpha=0.6)
    nx.draw_networkx_edges(G, pos, alpha=0.4)
    
    # 绘制三个类别定位点
    colors = ['red', 'blue', 'green']
    for i, point in enumerate(class_points):
        nx.draw_networkx_nodes(G, pos, nodelist=[point], 
                             node_size=100, node_color=colors[i])
    
    # 绘制三个子图
    for i, subgraph in enumerate(subgraphs):
        nx.draw_networkx_edges(subgraph, pos, edge_color=colors[i], 
                             width=2, alpha=0.8)
    
    plt.axis('off')
    plt.show()


# 在treeBase类中添加这些方法后，可以这样使用：
G, class_points, subgraphs = tb.generate_vessel_graph(mainTreeDong)
tb.visualize_graph(G, class_points, subgraphs)

In [97]:
'''integrade test'''
config = {
    'dcmPath':'D:\\data\\liver\\1.3.12.2.1107.5.1.4.95750.30000023073123385442300144053',
    'resPath':'D:\\data\\liver\\xiaqiang.nrrd',
    'hilumBoxPath':None,
    'targetFolder':'D:\\data\\liver\\'
}
# a = np.load(config['hilumBoxPath'])

# pdb.set_trace()

tb = treeBase('lineMode', 'D:\\')

start = time.time()
tb.allProcess(config)
print(time.time() - start)


hilumBox [[  0 512]
 [  0 512]
 [  0 512]]
labeled_array time 0.6338648796081543
regions time 0.2275989055633545
sorted time 0.002925395965576172
final time 0.10933208465576172
132.7871060371399
