In [1]:
import vtk
import numpy as np
from vtkmodules.vtkFiltersGeneral import vtkBooleanOperationPolyDataFilter
import os
import random
from tqdm import tqdm
import trimesh

def check_intersection(polydata1, polydata2):
    """Print whether two surfaces produce any intersection cells.

    Both inputs are fed to a ``vtkIntersectionPolyDataFilter``. The verdict is
    written to stdout only; nothing is returned.

    :param polydata1: Surface connected to input port 0 of the filter.
    :param polydata2: Surface connected to input port 1 of the filter.
    """
    finder = vtk.vtkIntersectionPolyDataFilter()
    for port, surface in enumerate((polydata1, polydata2)):
        finder.SetInputData(port, surface)
    finder.Update()

    # At least one cell in the filter output counts as "touching".
    has_contact = finder.GetOutput().GetNumberOfCells() > 0
    if has_contact:
        verdict = "The two objects touch or intersect."
    else:
        verdict = "The two objects do not touch."
    print(verdict)
def check_non_manifold_edges(polydata):
    """Return True when no boundary or non-manifold edges are extracted.

    Despite the name, boundary-edge extraction is switched on together with
    non-manifold-edge extraction, so either kind makes this return False.

    :param polydata: Surface to inspect.
    :return: ``True`` if the edge extraction yields no cells, else ``False``.
    """
    edge_finder = vtk.vtkFeatureEdges()
    edge_finder.SetInputData(polydata)

    # Sharp (feature) and regular manifold edges are excluded from the output;
    # only the two "defect" edge classes are extracted.
    edge_finder.FeatureEdgesOff()
    edge_finder.ManifoldEdgesOff()
    edge_finder.BoundaryEdgesOn()
    edge_finder.NonManifoldEdgesOn()

    edge_finder.Update()
    return edge_finder.GetOutput().GetNumberOfCells() == 0
    
def check_boundary_edges(polydata):
    """Return True when the boundary-edge extraction finds nothing.

    :param polydata: Surface to inspect.
    :return: ``False`` if the edge extraction yields at least one cell,
        ``True`` otherwise.
    """
    border_finder = vtk.vtkFeatureEdges()
    border_finder.SetInputData(polydata)

    # Feature and non-manifold edges are switched off explicitly; the
    # manifold-edge switch is left at whatever the filter defaults to.
    border_finder.NonManifoldEdgesOff()
    border_finder.FeatureEdgesOff()
    border_finder.BoundaryEdgesOn()
    border_finder.Update()

    open_edge_count = border_finder.GetOutput().GetNumberOfCells()
    return not open_edge_count > 0

def check_and_fill_holes(polydata, hole_size=10):
    """Run a hole-filling pass and keep its result only if it added cells.

    :param polydata: Surface to patch.
    :param hole_size: Value handed to ``vtkFillHolesFilter.SetHoleSize``.
    :return: The filter output when it has more cells than ``polydata``;
        otherwise the ``polydata`` object that was passed in.
    """
    patcher = vtk.vtkFillHolesFilter()
    patcher.SetHoleSize(hole_size)
    patcher.SetInputData(polydata)
    patcher.Update()

    patched = patcher.GetOutput()
    added_cells = patched.GetNumberOfCells() - polydata.GetNumberOfCells()
    return patched if added_cells > 0 else polydata
    
def create_random_mesh(type):
    """Build a primitive surface of random size.

    The size is drawn once as ``random.randint(5, 10) / 4`` and used as the
    edge length (cube) or radius (sphere, cylinder); the cylinder height is
    twice the size.

    :param type: One of ``"cube"``, ``"sphere"`` or ``"cylinder"``.
    :return: The source's output polydata, or ``None`` for any other ``type``.
    """
    size = random.randint(5, 10) / 4

    if type == "cube":
        source = vtk.vtkCubeSource()
        for set_length in (source.SetXLength, source.SetYLength, source.SetZLength):
            set_length(size)
    elif type == "sphere":
        source = vtk.vtkSphereSource()
        source.SetRadius(size)
        source.SetPhiResolution(15)
        source.SetThetaResolution(15)
    elif type == "cylinder":
        source = vtk.vtkCylinderSource()
        source.SetRadius(size)
        source.SetHeight(size * 2)
        source.SetResolution(20)
    else:
        # Unrecognised shape names produce no mesh.
        return None

    source.Update()
    return source.GetOutput()
# Step 1: Load the STL file using vtkSTLReader
def load_stl(filename):
    """Read an STL file and return the reader's output.

    :param filename: Path passed to ``vtkSTLReader.SetFileName``.
    :return: The polydata produced by the reader.
    """
    stl_reader = vtk.vtkSTLReader()
    stl_reader.SetFileName(filename)
    stl_reader.Update()
    surface = stl_reader.GetOutput()
    return surface
def extrude_polydata(polydata, distance=10):
    """Extrude a surface with ``vtkLinearExtrusionFilter``.

    :param polydata: Surface to extrude.
    :param distance: Value passed to ``SetScaleFactor``.
    :return: The extruded polydata.
    """
    sweeper = vtk.vtkLinearExtrusionFilter()
    sweeper.SetInputData(polydata)
    sweeper.SetScaleFactor(distance)

    # NOTE(review): the mode is set to normal extrusion while an explicit
    # (0, 0, -1) vector is also supplied - confirm which of the two actually
    # determines the extrusion direction for the inputs used here.
    sweeper.SetExtrusionTypeToNormalExtrusion()
    sweeper.SetVector(0, 0, -1)

    sweeper.Update()
    return sweeper.GetOutput()

def rotate_polydata(polydata, angle=None, axis=(0, 1, 0)):
    """Rotate a surface about an axis through the origin.

    :param polydata: Surface to rotate.
    :param angle: Rotation angle handed to ``vtkTransform.RotateWXYZ``
        (degrees). When omitted, a new random integer in [15, 120] is drawn
        on every call. Previously the random default was evaluated only once,
        when the function was defined, so every "default" call reused the
        same angle.
    :param axis: ``(x, y, z)`` rotation axis; defaults to the Y axis.
    :return: The rotated polydata.
    """
    if angle is None:
        # Draw at call time so each call gets its own random angle.
        angle = random.randint(15, 120)

    axis_x, axis_y, axis_z = axis

    transform = vtk.vtkTransform()
    transform.RotateWXYZ(angle, axis_x, axis_y, axis_z)

    transformFilter = vtk.vtkTransformPolyDataFilter()
    transformFilter.SetInputData(polydata)
    transformFilter.SetTransform(transform)
    transformFilter.Update()
    return transformFilter.GetOutput()

# Step 2: Compute the center of mass of the model
def compute_center_of_mass(polydata):
    """Return the center reported by ``vtkCenterOfMass`` for a surface.

    :param polydata: Surface whose center is computed.
    :return: The value of ``GetCenter()`` after the filter has run.
    """
    centroid_filter = vtk.vtkCenterOfMass()
    centroid_filter.SetInputData(polydata)
    centroid_filter.Update()
    centroid = centroid_filter.GetCenter()
    return centroid

def create_coordinate_array(size_x, size_y, spacing=2.0):
    """Build a flat array of grid points with random heights.

    x is sampled at ``size_x`` evenly spaced values across
    ``[-size_x / 2 * spacing, size_x / 2 * spacing]`` and y likewise with
    ``size_y``. z is drawn uniformly between -0.7 and 0.7 and does not depend
    on ``spacing``.

    :param size_x: Number of samples along x.
    :param size_y: Number of samples along y.
    :param spacing: Multiplier applied to the half-extent of both axes.
    :return: Array of shape ``(size_x * size_y, 3)`` with columns x, y, z.
    """
    half_x = size_x / 2 * spacing
    half_y = size_y / 2 * spacing
    grid_x, grid_y = np.meshgrid(
        np.linspace(-half_x, half_x, size_x),
        np.linspace(-half_y, half_y, size_y),
    )
    heights = np.random.uniform(-0.7, 0.7, (size_x, size_y))

    # One row per grid point: (x, y, z).
    return np.column_stack((grid_x.ravel(), grid_y.ravel(), heights.ravel()))

def create_mesh_from_coordinates(coordinates):
    """Triangulate a flattened grid of points and attach normals.

    The grid dimensions are not passed in; they are recovered from the number
    of points on the assumption that the grid is square.

    :param coordinates: Sequence of ``(x, y, z)`` points in grid order.
    :return: Output of ``vtkPolyDataNormals`` run on the triangulated grid.
    """
    grid_points = vtk.vtkPoints()
    for xyz in coordinates:
        grid_points.InsertNextPoint(xyz)

    surface = vtk.vtkPolyData()
    surface.SetPoints(grid_points)

    # Recover the grid dimensions (square grid assumed).
    n_rows = int(np.sqrt(len(coordinates)))
    n_cols = len(coordinates) // n_rows

    triangles = vtk.vtkCellArray()
    for row in range(n_rows - 1):
        for col in range(n_cols - 1):
            # Point ids of the four corners of the current grid cell.
            p00 = row * n_cols + col
            p10 = (row + 1) * n_cols + col
            p11 = (row + 1) * n_cols + (col + 1)
            p01 = row * n_cols + (col + 1)

            # Split the cell into two triangles sharing the p00-p11 diagonal.
            for corner_ids in ((p00, p10, p11), (p00, p11, p01)):
                triangles.InsertNextCell(3)
                for point_id in corner_ids:
                    triangles.InsertCellPoint(point_id)

    surface.SetPolys(triangles)

    normals_filter = vtk.vtkPolyDataNormals()
    normals_filter.SetInputData(surface)
    normals_filter.SetFeatureAngle(60.0)
    normals_filter.Update()
    return normals_filter.GetOutput()

# Step 3: Translate the model to move its center of mass to the origin (0, 0, 0)
def translate_to_origin(polydata, point_location):
    """Shift a surface by the negative of ``point_location``.

    :param polydata: Surface to move.
    :param point_location: ``(x, y, z)`` point that should end up at the origin.
    :return: The translated polydata.
    """
    offset_x, offset_y, offset_z = point_location

    shift = vtk.vtkTransform()
    shift.Translate(-offset_x, -offset_y, -offset_z)

    mover = vtk.vtkTransformPolyDataFilter()
    mover.SetInputData(polydata)
    mover.SetTransform(shift)
    mover.Update()
    return mover.GetOutput()

def clean_polydata(polydata):
    """Run ``vtkCleanPolyData`` with its default settings.

    :param polydata: Surface to clean.
    :return: The filter output.
    """
    cleaner = vtk.vtkCleanPolyData()
    cleaner.SetInputData(polydata)
    cleaner.Update()
    cleaned = cleaner.GetOutput()
    return cleaned

def triangulate(polydata):
    """Run ``vtkTriangleFilter`` on a surface.

    :param polydata: Surface to triangulate.
    :return: The filter output.
    """
    to_triangles = vtk.vtkTriangleFilter()
    to_triangles.SetInputData(polydata)
    to_triangles.Update()
    triangulated = to_triangles.GetOutput()
    return triangulated

def decimate(polydata, reduction_factor=0.5):
    """Simplify a surface with ``vtkDecimatePro``, topology preservation on.

    :param polydata: Surface to simplify.
    :param reduction_factor: Value passed to ``SetTargetReduction``.
    :return: The decimated polydata.
    """
    reducer = vtk.vtkDecimatePro()
    reducer.SetInputData(polydata)
    reducer.PreserveTopologyOn()
    reducer.SetTargetReduction(reduction_factor)
    reducer.Update()
    return reducer.GetOutput()

def recalculate_normals(polydata):
    """Recompute point and cell normals for a surface.

    Consistency and auto-orientation are switched on; splitting is switched
    off.

    :param polydata: Surface whose normals are regenerated.
    :return: Output of ``vtkPolyDataNormals``.
    """
    normal_filter = vtk.vtkPolyDataNormals()
    normal_filter.SetInputData(polydata)

    # Enable both kinds of normals plus consistent, auto-oriented winding.
    for enable in (
        normal_filter.ComputePointNormalsOn,
        normal_filter.ComputeCellNormalsOn,
        normal_filter.ConsistencyOn,
        normal_filter.AutoOrientNormalsOn,
    ):
        enable()

    # Splitting off: avoid splitting vertices at sharp edges.
    normal_filter.SplittingOff()

    normal_filter.Update()
    return normal_filter.GetOutput()


def make_mesh_watertight(polydata):
    """Clean, re-orient and hole-fill a surface, then verify it is closed.

    The surface goes through ``clean_polydata``, ``recalculate_normals`` and
    ``check_and_fill_holes`` in that order.

    :param polydata: Surface to repair.
    :return: The processed surface when both edge checks pass, else ``None``.
    """
    candidate = check_and_fill_holes(recalculate_normals(clean_polydata(polydata)))

    is_closed = check_non_manifold_edges(candidate) and check_boundary_edges(candidate)
    if not is_closed:
        return None
    return candidate
    
def move_mesh(mesh, x, y, z):
    """Translate a surface by ``(x, y, z)``.

    :param mesh: Surface to move.
    :param x: Offset along x.
    :param y: Offset along y.
    :param z: Offset along z.
    :return: The translated polydata produced by the transform filter.
    """
    offset = vtk.vtkTransform()
    offset.Translate(x, y, z)

    mover = vtk.vtkTransformPolyDataFilter()
    mover.SetInputData(mesh)
    mover.SetTransform(offset)
    mover.Update()
    moved = mover.GetOutput()
    return moved

def apply_boolean_operation(polydata1, polydata2, operation):
    """
    Apply a boolean operation to two vtkPolyData objects.

    :param polydata1: First operand (filter input 0).
    :param polydata2: Second operand (filter input 1).
    :param operation: 'union', 'intersection' or 'difference'.
    :return: The single vtkPolyData returned by the filter's ``GetOutput()``.
    :raises ValueError: If ``operation`` is not one of the three names above.
    """
    boolean_op = vtk.vtkBooleanOperationPolyDataFilter()

    # Pick the setter that matches the requested operation name.
    mode_setters = (
        ('union', boolean_op.SetOperationToUnion),
        ('intersection', boolean_op.SetOperationToIntersection),
        ('difference', boolean_op.SetOperationToDifference),
    )
    for mode_name, set_mode in mode_setters:
        if operation == mode_name:
            set_mode()
            break
    else:
        raise ValueError("Invalid operation type. Choose 'union', 'intersection', or 'difference'.")

    for port, operand in enumerate((polydata1, polydata2)):
        boolean_op.SetInputData(port, operand)
    boolean_op.Update()

    outcome = boolean_op.GetOutput()
    if outcome is None:
        print('boolean problem')
    return outcome
def combine_components(components):
    """
    Append several vtkPolyData objects into one.

    Parameters:
        components (list): vtkPolyData objects to merge.

    Returns:
        vtk.vtkPolyData: Output of a vtkAppendPolyData fed with every component.
    """
    merger = vtk.vtkAppendPolyData()
    for part in components:
        merger.AddInputData(part)
    merger.Update()

    merged = merger.GetOutput()
    return merged


def get_nr_regions(polydata):
    """Count the connected regions of a surface.

    :param polydata: Surface to analyse.
    :return: Number of regions reported by ``vtkConnectivityFilter``, or
        ``None`` if running the filter raised an error.
    """
    try:
        connectivity_filter = vtk.vtkConnectivityFilter()
        connectivity_filter.SetInputData(polydata)
        connectivity_filter.SetExtractionModeToAllRegions()
        connectivity_filter.Update()
        return connectivity_filter.GetNumberOfExtractedRegions()
    except Exception:
        # Best-effort: ordinary failures map to None. Unlike a bare ``except``,
        # this lets KeyboardInterrupt / SystemExit propagate so a long run can
        # still be stopped.
        return None
def extract_components(polydata):
    """
    Extract the connected components of a vtkPolyData object.

    Parameters:
        polydata (vtk.vtkPolyData): Input surface containing one or more
            connected components.

    Returns:
        list: One vtkPolyData per connected region, in region-id order.
            An empty list is returned when ``polydata`` is ``None`` or is not
            a vtkPolyData (a diagnostic line is printed in that case).
    """
    if polydata is None:
        print('polydata is none')
    if not isinstance(polydata, vtk.vtkPolyData):
        print('polydata is not vtkPolyData')
        # Nothing can be extracted from invalid input.
        return []

    # First pass: label every region to learn how many there are.
    connectivity_filter = vtk.vtkConnectivityFilter()
    connectivity_filter.SetInputData(polydata)
    connectivity_filter.SetExtractionModeToAllRegions()
    connectivity_filter.Update()
    num_components = connectivity_filter.GetNumberOfExtractedRegions()

    components = []
    for region_id in range(num_components):
        # A fresh filter per region keeps the specified-region list at
        # exactly one entry.
        component_filter = vtk.vtkConnectivityFilter()
        component_filter.SetInputData(polydata)
        component_filter.SetExtractionModeToSpecifiedRegions()
        component_filter.AddSpecifiedRegion(region_id)
        component_filter.Update()

        # Copy from the per-region filter. Copying from the all-regions
        # filter (as before) made every entry a copy of the whole mesh.
        component = vtk.vtkPolyData()
        component.DeepCopy(component_filter.GetOutput())
        components.append(component)

    return components



# Main function to load, center, cut, separate, and visualize the STL model
def main(type_of_obj, mesh, index):
    """Cut the reference model with ``mesh`` and write one dataset sample.

    The model is always read from ``frag_05.stl``; ``type_of_obj`` and
    ``index`` only determine the output file names.

    Files written:
      * ``dataset_3d/ground_truth/{type_of_obj}_{index}.obj`` - the centered,
        cleaned and triangulated model. This file is written even when the
        function later returns ``False``.
      * ``dataset_3d/train/{type_of_obj}_{index}.obj`` - what remains after
        subtracting ``mesh`` from that model.

    :param type_of_obj: Label used in the output file names.
    :param mesh: Cutting surface subtracted from the model.
    :param index: Sample number used in the output file names.
    :return: ``True`` if the train sample was written; ``False`` if the
        boolean difference produced no components or the combined result
        failed the watertight check.
    """
    original_model = load_stl("frag_05.stl")

    # Move the center of mass to the origin, then normalise the mesh.
    center = compute_center_of_mass(original_model)
    centered_model = move_mesh(original_model, -center[0], -center[1], -center[2])
    centered_model = clean_polydata(centered_model)
    centered_model = triangulate(centered_model)

    obj_writer = vtk.vtkOBJWriter()
    obj_writer.SetFileName(f"dataset_3d/ground_truth/{type_of_obj}_{index}.obj")
    obj_writer.SetInputData(centered_model)
    obj_writer.Write()

    check_intersection(centered_model, mesh)

    difference = apply_boolean_operation(centered_model, mesh, 'difference')
    results = extract_components(difference)

    # No fragments (or an invalid one) means there is nothing to write.
    if not results or None in results:
        return False

    # NOTE(review): the repaired per-fragment meshes computed here are never
    # used - the untouched fragments are combined below. Decide whether the
    # repaired meshes should replace them (and what to do when a fragment
    # cannot be made watertight) or whether this loop can be removed.
    for component in results:
        clean = clean_polydata(component)
        clean = recalculate_normals(clean)
        clean = make_mesh_watertight(clean)
    print('done cleaning')

    final_mesh = make_mesh_watertight(combine_components(results))
    if final_mesh is None:
        # make_mesh_watertight signals failure with None; report the failure
        # instead of handing the OBJ writer no input and claiming success.
        return False

    obj_writer = vtk.vtkOBJWriter()
    obj_writer.SetFileName(f"dataset_3d/train/{type_of_obj}_{index}.obj")
    obj_writer.SetInputData(final_mesh)
    obj_writer.Write()
    return True




In [2]:
def create_cutting_shape():
    """Build the random slab used to cut the model.

    A 10 x 10 random height grid is meshed, cleaned, moved so that its center
    of mass is at the origin, rotated by 90 (about ``rotate_polydata``'s
    default axis) and extruded with a distance of 1/4. The result then goes
    through normal recalculation, triangulation, cleaning and the watertight
    check.

    :return: The finished cutting surface, or ``None`` when
        ``make_mesh_watertight`` rejects it.
    """
    surface = create_mesh_from_coordinates(create_coordinate_array(10, 10))
    surface = clean_polydata(surface)
    surface = translate_to_origin(surface, compute_center_of_mass(surface))

    slab = extrude_polydata(rotate_polydata(surface, 90), 1/4)

    # Post-processing stages, applied in order.
    for stage in (recalculate_normals, triangulate, clean_polydata, make_mesh_watertight):
        slab = stage(slab)
    return slab

In [3]:
def create_dataset(type_of_obj, nr_samples):
    """Generate up to ``nr_samples`` samples under ``dataset_3d/``.

    Each sample uses a new random cutting shape. The progress bar counts
    samples directly; the earlier percentage bar advanced by a rounded
    ``100 / nr_samples`` and could finish above or below 100.

    :param type_of_obj: Label forwarded to ``main`` and shown on the bar.
    :param nr_samples: Number of samples to create.
    """
    # Output folders only need to be created once, not on every iteration.
    os.makedirs('dataset_3d/train', exist_ok=True)
    os.makedirs('dataset_3d/ground_truth', exist_ok=True)

    with tqdm(total=nr_samples) as pbar:
        pbar.set_description(f'Creating dataset for {type_of_obj}')
        i = 0
        while i < nr_samples:
            mesh = create_cutting_shape()
            if not main(type_of_obj, mesh, i):
                # NOTE(review): the first failed sample ends the whole run
                # with fewer than nr_samples files - confirm this is wanted
                # rather than retrying with a new cutting shape.
                break
            i += 1
            pbar.update(1)

In [4]:
#create_dataset("cube", 4000)

In [5]:
# Generate a single sample. "cylinder" only labels the progress bar and the
# output file names; main() reads its model from frag_05.stl regardless.
create_dataset("cylinder", 1)

Creating dataset for cylinder:   0%|          | 0/100 [00:00<?, ?it/s]

The two objects touch or intersect.


In [6]:
#create_dataset("sphere", 4000)

In [None]:
# trimesh is already imported in the first cell; repeated here so this cell
# can be run on its own.
import trimesh

In [None]:
mesh = 