In [1]:
import numpy as np
from numpy.random import default_rng, Generator
import trimesh
from trimesh import load, Trimesh
from trimesh.primitives import Sphere
from magnet_pinn.generator.samplers import BlobSampler, TubeSampler, PropertySampler
from magnet_pinn.generator.structures import Blob, Tube
from magnet_pinn.generator.typing import StructurePhantom
from magnet_pinn.generator.transforms import ToMesh, MeshesCutout, MeshesCleaning, Compose
from magnet_pinn.generator.io import MeshWriter
from magnet_pinn.generator.structures import CustomMeshStructure
from magnet_pinn.generator.utils import generate_fibonacci_points_on_sphere

In [None]:
class MeshBlobSampler:
    
    def __init__(self, child_radius: float, intersection_threshold: float = 0.5):
        if child_radius <= 0:
            raise ValueError("child_radius must be positive")
        if not 0 < intersection_threshold <= 1:
            raise ValueError("intersection_threshold must be in (0, 1]")
            
        self.child_radius = child_radius
        self.intersection_threshold = intersection_threshold

    def sample_children_blobs(self, parent_mesh_structure: CustomMeshStructure, 
                            num_children: int, rng: Generator, 
                            max_iterations: int = 100000) -> list[Blob]:
        if num_children == 0:
            return []
            
        mesh = parent_mesh_structure.mesh
        placed_blobs = []
        
        # Generate candidate positions using boundary-aware sampling
        for i in range(num_children):
            attempts = 0
            blob_placed = False
            
            while attempts < max_iterations and not blob_placed:
                # Sample position with bias toward mesh boundary
                position = self._sample_boundary_biased_position(mesh, rng)
                
                # Check intersection with mesh
                if self._validates_mesh_intersection(mesh, position):
                    # Check collision with existing blobs
                    if self._validates_blob_collision(position, placed_blobs):
                        blob = Blob(position, self.child_radius, 
                                  seed=rng.integers(0, 2**32-1).item())
                        placed_blobs.append(blob)
                        blob_placed = True
                
                attempts += 1
            
            if not blob_placed:
                break  # Stop if unable to place more blobs
                
        return placed_blobs

    def _sample_boundary_biased_position(self, mesh: Trimesh, rng: Generator) -> np.ndarray:
        if rng.random() < 0.3:
            surface_point = mesh.sample(1)[0]
            surface_normal = mesh.face_normals[mesh.nearest.on_surface([surface_point])[2][0]]
            offset_distance = rng.uniform(-self.child_radius, self.child_radius)
            return surface_point + surface_normal * offset_distance
        else:
            bounds = mesh.bounds
            return rng.uniform(bounds[0], bounds[1])

    def _validates_mesh_intersection(self, mesh: Trimesh, position: np.ndarray) -> bool:
        distance_to_mesh = np.abs(mesh.nearest.signed_distance([position])[0])
        
        intersects = distance_to_mesh < self.child_radius
        
        if intersects and self.intersection_threshold > 0.5:
            sphere_points = self._generate_sphere_sample_points(position)
            contained_points = mesh.contains(sphere_points)
            intersection_ratio = np.sum(contained_points) / len(contained_points)
            return intersection_ratio >= self.intersection_threshold
            
        return intersects

    def _validates_blob_collision(self, position: np.ndarray, existing_blobs: list[Blob]) -> bool:
        min_distance = 2 * self.child_radius  # No overlap
        
        for blob in existing_blobs:
            distance = np.linalg.norm(position - blob.position)
            if distance < min_distance:
                return False
        return True

    def _generate_sphere_sample_points(self, center: np.ndarray, num_points: int = 100) -> np.ndarray:
        unit_points = generate_fibonacci_points_on_sphere(num_points)
        return unit_points * self.child_radius + center

In [3]:
class CylinderPhantom:
    def __init__(self, stl_mesh_path: str, num_children_blobs: int = 3, 
                 blob_radius_decrease_per_level: float = 0.3, num_tubes: int = 5,
                 relative_tube_max_radius: float = 0.1, relative_tube_min_radius: float = 0.01):
        self.parent_structure = CustomMeshStructure(stl_mesh_path)

        child_radius = self.parent_structure.radius * blob_radius_decrease_per_level

        self.num_children_blobs = num_children_blobs
        self.num_tubes = num_tubes
        
        self.child_sampler = MeshBlobSampler(
            child_radius,
            0.9
        )
        
        tube_max_radius = relative_tube_max_radius * self.parent_structure.radius
        tube_min_radius = relative_tube_min_radius * self.parent_structure.radius
        self.tube_sampler = TubeSampler(tube_max_radius, tube_min_radius)
    
    def _estimate_sampling_radius(self):
        return 0.7 * self.parent_structure.radius

    def generate(self, seed: int = None) -> StructurePhantom:
        rng = default_rng(seed)
        
        children_blobs = self.child_sampler.sample_children_blobs(
            self.parent_structure,
            num_children=self.num_children_blobs,
            rng=rng
        )
        
        sampling_radius = self._estimate_sampling_radius()
        tubes = self.tube_sampler.sample_tubes(
            center=self.parent_structure.position,
            radius=sampling_radius,
            num_tubes=self.num_tubes,
            rng=rng
        )
        
        return StructurePhantom(
            parent=self.parent_structure,
            children=children_blobs,
            tubes=tubes
        )

In [None]:
cylinder_phantom = CylinderPhantom(
    stl_mesh_path="./phantom.stl",
    num_children_blobs=8,
    blob_radius_decrease_per_level=0.2,
    num_tubes=10,
    relative_tube_max_radius=0.08,
    relative_tube_min_radius=0.02
)

print(f"Loaded mesh with {len(cylinder_phantom.parent_structure.mesh.vertices)} vertices")
print(f"Mesh bounds: {cylinder_phantom.parent_structure.mesh.bounds}")
print(f"Mesh center: {cylinder_phantom.parent_structure.position}")
print(f"Effective radius: {cylinder_phantom.parent_structure.radius:.2f}")

Loaded mesh with 100 vertices
Mesh bounds: [[-125.6217  -150.259   -138.     ]
 [ 125.6217    79.74104  138.     ]]
Mesh center: [ 5.45850530e-07 -4.16992509e+01  3.24919470e-15]
Effective radius: 190.31


In [11]:
meshes = cylinder_phantom.generate(seed=42)
print(f"Generated {len(meshes.children)} children blobs and {len(meshes.tubes)} tubes")

Generated 5 children blobs and 50 tubes


In [12]:
workflow = Compose([
    ToMesh(),
    MeshesCutout(),
    MeshesCleaning()
])

In [None]:
resulting_meshes = workflow(meshes)

In [None]:
trimesh.boolean.union(
    resulting_meshes.children + resulting_meshes.tubes
).show()