# In [None]:  (Jupyter notebook cell marker, kept as a comment so the file parses as Python)
import math
import numpy as np
import trimesh
import os
import random
from pathlib import Path
from typing import Tuple, List
import logging

class ToothCavityGenerator:
    def __init__(self, input_dir: str, output_dir: str):
        
        self.input_dir = Path(input_dir)
        self.output_dir = Path(output_dir)
        
        # Create output directories
        self.cavity_dir = self.output_dir / "cavity_teeth"
        self.filling_dir = self.output_dir / "fillings"
        
        self.cavity_dir.mkdir(parents=True, exist_ok=True)
        self.filling_dir.mkdir(parents=True, exist_ok=True)
        
        # Setup logging
        logging.basicConfig(level=logging.INFO)
        self.logger = logging.getLogger(__name__)
    
    def load_tooth_mesh(self, stl_path: str) -> trimesh.Trimesh:
        """Load a tooth mesh from disk and try to make it watertight.

        Args:
            stl_path: Path of the mesh file handed to ``trimesh.load_mesh``.

        Returns:
            The loaded mesh. If it was not watertight, hole filling and
            face/vertex clean-up have been applied to it in place.

        Raises:
            ValueError: If the loaded object is not a single ``trimesh.Trimesh``.
            Exception: Any loader error is logged and re-raised unchanged.
        """
        try:
            mesh = trimesh.load_mesh(stl_path)
            if not isinstance(mesh, trimesh.Trimesh):
                raise ValueError("Invalid mesh format")

            # Boolean operations need a closed (watertight) surface.
            if not mesh.is_watertight:
                self.logger.warning(f"Mesh {stl_path} is not watertight. Attempting to fix...")
                mesh.fill_holes()
                mesh.remove_degenerate_faces()
                mesh.remove_duplicate_faces()
                mesh.remove_unreferenced_vertices()

                # The repair is best-effort. Say so explicitly when it did not
                # close the surface, instead of letting the later Boolean step
                # fail with an unrelated-looking error.
                if not mesh.is_watertight:
                    self.logger.warning(
                        f"Mesh {stl_path} is still not watertight after repair; "
                        "Boolean operations on it may fail or give invalid results"
                    )

            return mesh
        except Exception as e:
            self.logger.error(f"Failed to load {stl_path}: {e}")
            raise

    def _fractal_noise(self, points: np.ndarray, octaves: int = 4, base_freq: float = 0.5) -> np.ndarray:

        rng = np.random.RandomState(42)
        noise = np.zeros(len(points))
        freq = base_freq
        amp = 1.0
        for o in range(octaves):
            # combine a few plane waves per octave for richer patterns
            for _ in range(4):
                v = rng.normal(size=(3,))
                v /= (np.linalg.norm(v) + 1e-12)
                phase = rng.uniform(0, 2 * np.pi)
                proj = points.dot(v) * freq
                noise += np.sin(proj + phase) * amp
            freq *= 2.0
            amp *= 0.5
        # normalize to -1..1
        noise = noise - np.mean(noise)
        mx = np.max(np.abs(noise)) + 1e-12
        return noise / mx
    
    def distort_mesh(self, mesh: trimesh.Trimesh, noise_strength: float = 0.18, octaves: int = 4, smoothing_iters: int = 8) -> trimesh.Trimesh:
        """Return a copy of ``mesh`` with its surface roughened by fractal noise.

        Every vertex is moved along its vertex normal by
        ``noise * noise_strength * max(bounding-box extent)``, where ``noise``
        comes from ``_fractal_noise`` and lies in [-1, 1]. The result is then
        Taubin-smoothed and stripped of degenerate/duplicate faces and
        unreferenced vertices. The input mesh is not modified.

        Args:
            mesh: Mesh to distort.
            noise_strength: Displacement amplitude as a fraction of the
                largest bounding-box extent.
            octaves: Number of noise octaves passed to ``_fractal_noise``.
            smoothing_iters: Iteration count for the Taubin filter.

        Returns:
            The distorted copy.
        """
        m = mesh.copy()
        verts = m.vertices.copy()
        noise = self._fractal_noise(verts, octaves=octaves, base_freq=0.5)  # shape (n_vertices,)
        # Normals are taken from the undistorted copy (trimesh computes them on demand).
        normals = m.vertex_normals
        scale = noise_strength * np.max(m.extents)
        m.vertices = verts + noise[:, None] * normals * scale

        # Gentle geometric smoothing. This is best-effort: if the filter is
        # missing or fails, keep the unsmoothed mesh and say so. The previous
        # fallback, `Trimesh.smoothed()`, is documented by trimesh as a
        # render-oriented, non-watertight copy, which would break the Boolean
        # operations this mesh is later used in, so it is deliberately not used.
        try:
            from trimesh.smoothing import filter_taubin
            filter_taubin(m, iterations=smoothing_iters)
        except Exception as e:
            self.logger.warning(f"Taubin smoothing skipped ({e}); using unsmoothed distorted mesh")

        m.remove_degenerate_faces()
        m.remove_duplicate_faces()
        m.remove_unreferenced_vertices()
        return m
    
    def _rotation_from_z_to(self, vec: np.ndarray) -> np.ndarray:
        v = vec / (np.linalg.norm(vec) + 1e-12)
        z = np.array([0.0, 0.0, 1.0])
        if np.allclose(v, z):
            return np.eye(4)
        if np.allclose(v, -z):
            R = np.eye(4); R[:3,:3] = np.array([[1,0,0],[0,-1,0],[0,0,-1]])
            return R
        axis = np.cross(z, v)
        axis /= (np.linalg.norm(axis) + 1e-12)
        angle = math.acos(np.clip(np.dot(z, v), -1.0, 1.0))
        # Rodrigues' rotation
        K = np.array([[0, -axis[2], axis[1]],
                      [axis[2], 0, -axis[0]],
                      [-axis[1], axis[0], 0]])
        R3 = np.eye(3) + math.sin(angle) * K + (1 - math.cos(angle)) * (K @ K)
        R = np.eye(4); R[:3,:3] = R3
        return R

    def generate_cavity_shape(self, size_range: Tuple[float, float] = (1.0, 3.0)) -> trimesh.Trimesh:
        """Create one randomly shaped cavity primitive.

        One of three base shapes is picked at random (sphere, anisotropically
        scaled sphere, or bent cylinder), built with a size drawn uniformly
        from ``size_range``, and roughened with ``distort_mesh``.

        Args:
            size_range: (low, high) bounds for the base size.

        Returns:
            The distorted cavity mesh, as built by the trimesh creation helpers.
        """
        kind = random.choice(['sphere', 'ellipsoid', 'cylinder'])
        radius = random.uniform(*size_range)

        if kind == 'sphere':
            ball = trimesh.creation.icosphere(radius=radius, subdivisions=3)
            return self.distort_mesh(ball, noise_strength=0.2, octaves=4)

        if kind == 'ellipsoid':
            ball = trimesh.creation.icosphere(radius=radius, subdivisions=3)
            # Independent stretch factor per axis (x, y, z).
            stretch = [
                random.uniform(lo, hi)
                for lo, hi in ((0.5, 1.4), (0.5, 1.4), (0.7, 1.2))
            ]
            ball.apply_scale(stretch)
            return self.distort_mesh(ball, noise_strength=0.25, octaves=5)

        # Remaining case: warped cylinder.
        length = radius * random.uniform(1.5, 2.5)
        tube = trimesh.creation.cylinder(radius=radius * 0.7, height=length, sections=24)

        # Randomise the orientation before bending.
        tube.apply_transform(trimesh.transformations.random_rotation_matrix())

        # Sinusoidal bend driven by each vertex's z coordinate; the sign and
        # magnitude of the x and y offsets are drawn independently.
        pts = tube.vertices.copy()
        bend = radius * 0.3
        phase = pts[:, 2] * np.pi / length
        pts[:, 0] += np.sin(phase) * bend * (0.5 - np.random.rand())
        pts[:, 1] += np.cos(phase) * bend * (0.5 - np.random.rand())
        tube.vertices = pts

        # Surface noise so the end caps are not flat.
        return self.distort_mesh(tube, noise_strength=0.3, octaves=6)

    def find_surface_point(self, mesh: trimesh.Trimesh) -> Tuple[np.ndarray, np.ndarray]:
        """Pick a random surface point in the upper 30% (by Z) of the mesh.

        Up to 100 single-point surface samples are drawn; the first one whose
        Z coordinate reaches the crown threshold is returned together with the
        normal of the face it was sampled from. If none qualifies, the vertex
        with the largest Z is used instead, with the normal of the first face
        that references it (or +Z if no face does).

        Args:
            mesh: Tooth mesh; its +Z direction is treated as "towards the crown".

        Returns:
            Tuple ``(point, normal)``.
        """
        z_low, z_high = mesh.bounds[:, 2]
        # Anything at or above this height counts as crown.
        min_crown_z = z_low + 0.7 * (z_high - z_low)

        # Rejection sampling: one surface sample per try.
        for _ in range(100):
            samples, tri_ids = mesh.sample(1, return_index=True)
            candidate = samples[0]
            if candidate[2] >= min_crown_z:
                return candidate, mesh.face_normals[tri_ids[0]]

        # No sample landed in the crown: fall back to the top-most vertex.
        self.logger.warning("Could not find surface point in crown area, using highest point")
        top_idx = np.argmax(mesh.vertices[:, 2])
        apex = mesh.vertices[top_idx]

        # Rows of `faces` that mention the apex vertex.
        touching = np.where(mesh.faces == top_idx)[0]
        if len(touching) > 0:
            return apex, mesh.face_normals[touching[0]]

        # Isolated vertex: assume the surface faces straight up.
        return apex, np.array([0.0, 0.0, 1.0])
    
    def position_cavity_on_surface(self, tooth_mesh: trimesh.Trimesh, cavity_mesh: trimesh.Trimesh,
                                   penetration_depth: float = 0.22) -> trimesh.Trimesh:
        """Place a copy of the cavity on the tooth surface, sunk slightly inward.

        A surface point and normal are obtained from ``find_surface_point``.
        The cavity copy is rotated so its +Z axis follows the normal, moved so
        the centre of its bounding box sits on the surface point, and finally
        shifted against the normal by a random fraction of its size.

        Args:
            tooth_mesh: Tooth to place the cavity on.
            cavity_mesh: Cavity shape; it is not modified.
            penetration_depth: Scales the inward shift, which is
                ``max extent * penetration_depth * U(0.25, 0.6)``.

        Returns:
            The transformed copy of ``cavity_mesh``.
        """
        surface_point, surface_normal = self.find_surface_point(tooth_mesh)

        # The projections and shifts below assume a unit normal. A zero-length
        # normal (degenerate face) falls back to +Z, the same default that
        # find_surface_point uses.
        normal = np.asarray(surface_normal, dtype=float)
        length = np.linalg.norm(normal)
        normal = normal / length if length > 1e-12 else np.array([0.0, 0.0, 1.0])

        cavity_positioned = cavity_mesh.copy()

        # Align +Z of the cavity with the surface normal.
        cavity_positioned.apply_transform(self._rotation_from_z_to(normal))

        # Put the cavity's centre on the surface point. Bending and noise leave
        # the shape off-centre, so translating by `surface_point` alone would
        # only move the cavity's local origin there, not its centre.
        centre = cavity_positioned.bounds.mean(axis=0)
        cavity_positioned.apply_translation(np.asarray(surface_point, dtype=float) - centre)

        # Safety net: if every vertex still lies below the surface plane, push
        # the cavity outward until it crosses the plane.
        rel = cavity_positioned.vertices - surface_point
        max_proj = rel.dot(normal).max()
        if max_proj < 0.0:
            outward = -max_proj + 0.01 * np.max(cavity_positioned.extents)
            cavity_positioned.apply_translation(normal * outward)

        # Shallow inward shift so the cavity starts at the surface and
        # penetrates a little.
        cavity_size = np.max(cavity_positioned.extents)
        inward_amount = cavity_size * penetration_depth * random.uniform(0.25, 0.6)
        cavity_positioned.apply_translation(-normal * inward_amount)

        return cavity_positioned
    
    def create_cavity_tooth(self, tooth_mesh: trimesh.Trimesh,
                            num_cavities: int = 1) -> Tuple[trimesh.Trimesh, List[trimesh.Trimesh]]:
        """Carve up to ``num_cavities`` random cavities out of a copy of the tooth.

        Each cavity is generated, positioned on the original tooth surface and
        subtracted from the running result. A cavity whose Boolean difference
        raises or yields an unusable mesh is skipped: the running result is
        left untouched and the cavity is not reported, so every returned
        cavity corresponds to material actually removed from the tooth.

        Args:
            tooth_mesh: Original tooth; it is not modified.
            num_cavities: Number of cavities to attempt.

        Returns:
            Tuple ``(cavity_tooth, cavities)``: the carved tooth and the
            positioned cavity meshes that were successfully subtracted.
        """
        cavity_tooth = tooth_mesh.copy()
        individual_cavities = []

        for i in range(num_cavities):
            try:
                # Generate and position the cavity, then try the subtraction
                # on a temporary so a failure cannot clobber `cavity_tooth`.
                cavity_shape = self.generate_cavity_shape()
                positioned_cavity = self.position_cavity_on_surface(tooth_mesh, cavity_shape)
                carved = cavity_tooth.difference(positioned_cavity)
            except Exception as e:
                self.logger.error(f"Failed to create cavity {i+1}: {e}")
                continue

            # Ensure the result is still a valid, non-empty mesh.
            if not isinstance(carved, trimesh.Trimesh) or carved.vertices.shape[0] == 0:
                self.logger.warning(f"Boolean operation failed for cavity {i+1}, skipping...")
                continue

            # Commit only after the subtraction is known to be good.
            cavity_tooth = carved
            individual_cavities.append(positioned_cavity)
            self.logger.info(f"Created cavity {i+1}/{num_cavities}")

        return cavity_tooth, individual_cavities
    
    def create_fillings(self, original_tooth: trimesh.Trimesh,
                        cavities: List[trimesh.Trimesh]) -> List[trimesh.Trimesh]:
        """Compute the filling piece for each cavity.

        A filling is the Boolean intersection of the original tooth with a
        cavity shape, i.e. the material the cavity removed. Cavities whose
        intersection raises or is not a non-empty ``trimesh.Trimesh`` are
        logged and left out.

        Args:
            original_tooth: Tooth mesh before any cavity was carved.
            cavities: Positioned cavity meshes.

        Returns:
            The valid fillings, in cavity order (failed ones omitted).
        """
        results = []

        for number, cavity in enumerate(cavities, start=1):
            try:
                # Material shared by tooth and cavity.
                piece = original_tooth.intersection(cavity)

                usable = isinstance(piece, trimesh.Trimesh) and piece.vertices.shape[0] > 0
                if not usable:
                    self.logger.warning(f"Failed to create valid filling {number}")
                    continue

                results.append(piece)
                self.logger.info(f"Created filling {number}")
            except Exception as e:
                self.logger.error(f"Failed to create filling {number}: {e}")

        return results
    
    def save_mesh(self, mesh: trimesh.Trimesh, filepath: str):
        """Export ``mesh`` to ``filepath`` and log the outcome.

        Export errors are logged and swallowed, so one failed write does not
        abort the rest of a batch.

        Args:
            mesh: Mesh to write.
            filepath: Destination path passed to ``mesh.export``.
        """
        try:
            mesh.export(filepath)
            self.logger.info(f"Saved: {filepath}")
        except Exception as exc:
            # Best-effort write: report and carry on.
            self.logger.error(f"Failed to save {filepath}: {exc}")
    
    def process_tooth_file(self, stl_filepath: str, num_cavities: int = 1,
                           num_variations: int = 1):
        """Generate and save cavity/filling variations for one tooth file.

        For every variation a cavity tooth is written to ``cavity_dir`` as
        ``<stem>_cavity[_varNNN].stl`` and each filling to ``filling_dir`` as
        ``<stem>_fillingMM[_varNNN].stl``. The ``_varNNN`` suffix is only
        added when more than one variation is requested.

        Args:
            stl_filepath: Path of the tooth mesh to load.
            num_cavities: Cavities to attempt per variation.
            num_variations: Number of independent variations to produce.
        """
        self.logger.info(f"Processing: {stl_filepath}")

        # The original tooth is loaded once and reused for every variation.
        original_tooth = self.load_tooth_mesh(stl_filepath)
        tooth_name = Path(stl_filepath).stem
        tag_variations = num_variations > 1

        for index in range(1, num_variations + 1):
            suffix = f"_var{index:03d}" if tag_variations else ""

            # Carve the cavities, then derive the matching fillings.
            cavity_tooth, cavities = self.create_cavity_tooth(original_tooth, num_cavities)
            fillings = self.create_fillings(original_tooth, cavities)

            # Carved tooth.
            cavity_target = self.cavity_dir / f"{tooth_name}_cavity{suffix}.stl"
            self.save_mesh(cavity_tooth, str(cavity_target))

            # One file per filling, numbered from 01.
            for number, filling in enumerate(fillings, start=1):
                filling_target = self.filling_dir / f"{tooth_name}_filling{number:02d}{suffix}.stl"
                self.save_mesh(filling, str(filling_target))

            self.logger.info(f"Completed variation {index}/{num_variations} for {tooth_name}")
    
    def process_all_teeth(self, num_cavities: int = 1, num_variations: int = 3):
        
        stl_files = list(self.input_dir.glob("*.stl"))
        
        if not stl_files:
            self.logger.error(f"No STL files found in {self.input_dir}")
            return
        
        self.logger.info(f"Found {len(stl_files)} STL files to process")
        
        for stl_file in stl_files:
            try:
                self.process_tooth_file(str(stl_file), num_cavities, num_variations)
            except Exception as e:
                self.logger.error(f"Failed to process {stl_file}: {e}")
                continue
        
        self.logger.info("Processing completed!")


def main():
    """Run the generator with the settings hard-coded below.

    When ``single_file`` is set, only that file is processed; otherwise every
    STL file in ``input_dir`` is processed.
    """
    # --- user settings: edit before running ---
    input_dir = r""
    output_dir = r""
    num_cavities = 1
    num_variations = 1
    single_file = None

    generator = ToothCavityGenerator(input_dir, output_dir)

    # Batch mode is the default; a single file takes precedence when given.
    if not single_file:
        generator.process_all_teeth(num_cavities, num_variations)
        return
    generator.process_tooth_file(single_file, num_cavities, num_variations)


# Run the pipeline only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()