In [1]:
import numpy as np
import os
import trimesh
import meshio
from noise import pnoise3
import tetgen as tg
import ipywidgets as widgets
from IPython.display import display
import plotly.graph_objects as go
import zipfile
from concurrent.futures import ThreadPoolExecutor, as_completed
from PIL import Image
from tqdm import tqdm
import logging

In [2]:

class HeightmapMeshGenerator:
    def __init__(self, size_x=100, size_y=100, amplitude=1.0, noise_function=None, **noise_params):
        """Initialize the generator with given parameters."""
        self.size_x = size_x
        self.size_y = size_y
        self.amplitude = amplitude
        self.noise_function = noise_function if noise_function is not None else self.fbm_noise
        self.noise_params = noise_params

        self.heightmap = None
        self.normals = None
        self.mesh_vertices = None
        self.mesh_faces = None
        self.tetrahedral_mesh = None

        # Attributes for normal map
        self.normal_map_scale = 1  # Default scale factor
        self.normal_map_resolution = (512, 512)  # Default resolution (width, height)

        # Attributes for displacement map
        self.displacement_map_scale = 1  # Default scale factor
        self.displacement_map_resolution = (512, 512)  # Default resolution (width, height)

        # Output folder setup
        self.output_folder = "output"
        os.makedirs(self.output_folder, exist_ok=True)

        self.rng = np.random.default_rng(42)

        # Create a Plotly figure widget for visualization
        self.fig_widget = go.FigureWidget()

    # Noise Functions

    def fbm_noise(self, x, y, z=0, octaves=4, persistence=0.5, lacunarity=2.0):
        """Fractal Brownian Motion noise."""
        value = 0.0
        amplitude = 1.0
        frequency = 1.0
        for _ in range(octaves):
            value += amplitude * pnoise3(x * frequency, y * frequency, z * frequency)
            amplitude *= persistence
            frequency *= lacunarity
        return value

    def perlin_noise(self, x, y, z=0):
        """Simple Perlin noise."""
        return pnoise3(x, y, z)

    def square_wave(self, x, y, z=0, frequency=10.0, amplitude=1.0):
        """Square wave noise."""
        return amplitude * np.sign(np.sin(frequency * x))
    
    def sine_wave(self, x, y, z=0, frequency=10.0, amplitude=1.0):
        """Sine wave noise."""
        return amplitude * np.sin(frequency * x)

    def beckmann_noise(self, x, y, alpha=0.5):
        """Beckmann microfacet distribution."""
        tan_theta_h = np.sqrt(x**2 + y**2) / alpha
        cos_theta_h = 1 / np.sqrt(1 + tan_theta_h**2)
        D = np.exp(-(tan_theta_h**2)) / (np.pi * alpha**2 * cos_theta_h**4)
        return D

    def ggx_noise(self, x, y, alpha=0.5):
        """GGX (Trowbridge-Reitz) microfacet distribution."""
        tan2_theta_h = (x**2 + y**2) / alpha**2
        cos_theta_h = 1 / np.sqrt(1 + tan2_theta_h)
        cos2_theta_h = cos_theta_h**2
        D = alpha**2 / (np.pi * ((cos2_theta_h * (alpha**2 - 1) + 1)**2))
        return D

    def blinn_noise(self, x, y, n=20):
        """Blinn-Phong microfacet distribution."""
        tan_theta_h = np.sqrt(x**2 + y**2)
        cos_theta_h = 1 / np.sqrt(1 + tan_theta_h**2)
        D = (n + 2) / (2 * np.pi) * cos_theta_h**n
        return D

    def mandelbrot_noise(self, x, y, L=10.0, D_f=1.5, gamma=1.2, M=10, N_max=10):
        """
        Ausloos-Berman Weierstrass-Mandelbrot fractal noise function.
        Generates a fractal height value at coordinates (x, y) using the formula:

        y(x) = L^(4 - 2*D_f) ln(γ) * Σ_m=1^M Σ_n=1^N_max [γ^(n-1)(2(D_f-2)) * 
                (cos(2πγ^(n-1)x/L - cos(πm/M) + φ_mn)]

        Parameters:
            x, y (float): Spatial coordinates for evaluation.
            L (float): Length scale of the noise.
            D_f (float): Fractal dimension (affects self-similarity scaling).
            gamma (float): Frequency scaling factor.
            M (int): Number of angular divisions (outer summation index).
            N_max (int): Maximum frequency index (inner summation index).
        Returns:
            float: Fractal noise value at (x, y).
        """
        # Precompute scaling factors
        rng = self.rng
        log_gamma = np.log(gamma)
        scale_factor = L**(4 - 2 * D_f) * log_gamma
        noise_value = np.zeros_like(x)

        for n in range(1, N_max + 1):
            freq_scale = gamma**(n - 1)
            amp_scale = freq_scale**(2 * (D_f - 2))
            phi_mn = rng.uniform(0, 2 * np.pi, size=(M, 1, 1))
            m_values = np.arange(1, M + 1).reshape(-1, 1, 1)

            # Compute argument for both x and y
            arg_x = (2 * np.pi * freq_scale * x / L) - np.cos(np.pi * m_values / M) + phi_mn
            arg_y = (2 * np.pi * freq_scale * y / L) - np.cos(np.pi * m_values / M) + phi_mn
            noise_value += amp_scale * (np.sum(np.cos(arg_x), axis=0) + np.sum(np.cos(arg_y), axis=0))

        return scale_factor * noise_value


    def generate_filename(self, file_type):
        """Generate filename based on parameters."""
        noise_name = self.noise_function.__name__.replace("_noise", "").replace("_", "").capitalize()
        base_name = f"heightmap_{noise_name}_S{self.size_x}x{self.size_y}_A{self.amplitude:.2f}"

        if self.noise_function.__name__ in ['fbm_noise', 'perlin_noise']:
            base_name += f"_O{self.noise_params.get('octaves', 4)}P{self.noise_params.get('persistence', 0.5):.2f}L{self.noise_params.get('lacunarity', 2.0):.2f}"
        elif self.noise_function.__name__ in ['beckmann_noise', 'ggx_noise']:
            base_name += f"_A{self.noise_params.get('alpha', 0.5):.2f}"
        elif self.noise_function.__name__ == 'blinn_noise':
            base_name += f"_N{self.noise_params.get('n', 20)}"
        elif self.noise_function.__name__ in ['sine_wave', 'square_wave']:
            base_name += f"_F{self.noise_params.get('frequency', 10.0):.1f}A{self.noise_params.get('noise_amplitude', 1.0):.1f}"

        if file_type == "stl":
            return f"{base_name}.stl"
        elif file_type == "msh":
            return f"{base_name}.msh"
        elif file_type == "normal":
            return f"{base_name}_normal.png"
        elif file_type == "displacement":
            return f"{base_name}_displacement.png"
        return "output_file"

    # Heightmap Generation
    def generate_heightmap(self):
        """Generate the heightmap using the selected noise function."""
        size_x, size_y = self.size_x, self.size_y
        x = np.linspace(-1, 1, size_x)
        y = np.linspace(-1, 1, size_y)
        x_grid, y_grid = np.meshgrid(x, y, indexing='ij')

        # Updated non_vectorized list
        non_vectorized = [self.fbm_noise, self.perlin_noise]
        
        if self.noise_function in non_vectorized:
            heightmap = np.zeros((size_x, size_y))
            for i in range(size_x):
                for j in range(size_y):
                    xi = x[i]
                    yj = y[j]
                    heightmap[i, j] = self.noise_function(xi, yj, **self.noise_params)
        else:
            # Vectorized functions
            exponent = self.noise_params.get('exponent', 1.0)
            noise_params_no_exponent = {k: v for k, v in self.noise_params.items() if k != 'exponent'}
            D = self.noise_function(x_grid, y_grid, **noise_params_no_exponent)
            D_transformed = D ** exponent
            heightmap = D_transformed

        # Normalize and scale the heightmap
        heightmap = self.normalize_heightmap(heightmap) * self.amplitude
        self.heightmap = heightmap


    def normalize_heightmap(self, heightmap):
        """Normalize the heightmap to range [0, 1]."""
        min_val = np.min(heightmap)
        max_val = np.max(heightmap)
        if max_val - min_val == 0:
            # Fallback to add slight variation
            heightmap += self.rng.normal(0, 1e-6, heightmap.shape)
            min_val = np.min(heightmap)
            max_val = np.max(heightmap)
            if max_val - min_val == 0:
                raise ValueError("Heightmap still has no variation after adjustment.")
        return (heightmap - min_val) / (max_val - min_val)

    # Mesh Generation and Processing

    def heightmap_to_closed_mesh(self):
        """Convert the heightmap to a closed 3D mesh."""
        size_x, size_y = self.heightmap.shape

        # Create grid of indices
        i = np.arange(size_x)
        j = np.arange(size_y)
        ii, jj = np.meshgrid(i, j, indexing='ij')

        # Flatten the grid indices
        ii_flat = ii.flatten()
        jj_flat = jj.flatten()
        height_flat = self.heightmap.flatten()

        # Generate top and bottom vertices
        top_vertices = np.column_stack((ii_flat, jj_flat, height_flat))
        bottom_vertices = np.column_stack((ii_flat, jj_flat, np.zeros_like(height_flat)))

        # Combine vertices
        vertices = np.vstack((top_vertices, bottom_vertices))

        # Index offset for bottom vertices
        offset = size_x * size_y

        # Create vertex indices grid
        vertex_indices = np.arange(size_x * size_y).reshape((size_x, size_y))
        vertex_indices_bottom = vertex_indices + offset

        # Indices for cells (quads)
        v0 = vertex_indices[:-1, :-1]
        v1 = vertex_indices[1:, :-1]
        v2 = vertex_indices[:-1, 1:]
        v3 = vertex_indices[1:, 1:]

        # Flatten the indices
        v0_flat = v0.flatten()
        v1_flat = v1.flatten()
        v2_flat = v2.flatten()
        v3_flat = v3.flatten()

        # Create faces for the top surface
        faces_top = np.vstack([
            np.column_stack([v0_flat, v1_flat, v2_flat]),
            np.column_stack([v1_flat, v3_flat, v2_flat])
        ])

        # Bottom surface faces (reverse the order to flip normals)
        v0b = vertex_indices_bottom[:-1, :-1]
        v1b = vertex_indices_bottom[1:, :-1]
        v2b = vertex_indices_bottom[:-1, 1:]
        v3b = vertex_indices_bottom[1:, 1:]

        v0b_flat = v0b.flatten()
        v1b_flat = v1b.flatten()
        v2b_flat = v2b.flatten()
        v3b_flat = v3b.flatten()

        faces_bottom = np.vstack([
            np.column_stack([v0b_flat, v2b_flat, v1b_flat]),
            np.column_stack([v1b_flat, v2b_flat, v3b_flat])
        ])

        # Side faces
        # Left side (j=0)
        v_top_left = vertex_indices[:, 0]
        v_bot_left = vertex_indices_bottom[:, 0]

        faces_left = np.vstack([
            np.column_stack([v_top_left[:-1], v_bot_left[:-1], v_bot_left[1:]]),
            np.column_stack([v_top_left[:-1], v_bot_left[1:], v_top_left[1:]])
        ])

        # Right side (j=size_y-1)
        v_top_right = vertex_indices[:, -1]
        v_bot_right = vertex_indices_bottom[:, -1]

        faces_right = np.vstack([
            np.column_stack([v_top_right[:-1], v_bot_right[1:], v_bot_right[:-1]]),
            np.column_stack([v_top_right[:-1], v_top_right[1:], v_bot_right[1:]])
        ])

        # Front side (i=0)
        v_top_front = vertex_indices[0, :]
        v_bot_front = vertex_indices_bottom[0, :]

        faces_front = np.vstack([
            np.column_stack([v_top_front[:-1], v_bot_front[:-1], v_bot_front[1:]]),
            np.column_stack([v_top_front[:-1], v_bot_front[1:], v_top_front[1:]])
        ])

        # Back side (i=size_x-1)
        v_top_back = vertex_indices[-1, :]
        v_bot_back = vertex_indices_bottom[-1, :]

        faces_back = np.vstack([
            np.column_stack([v_top_back[:-1], v_bot_back[1:], v_bot_back[:-1]]),
            np.column_stack([v_top_back[:-1], v_top_back[1:], v_bot_back[1:]])
        ])

        # Combine all faces
        faces = np.vstack((faces_top, faces_bottom, faces_left, faces_right, faces_front, faces_back))

        self.vertices = vertices
        self.faces = faces

    def repair_mesh(self):
        """Repair the mesh if it's not watertight."""
        mesh = trimesh.Trimesh(vertices=self.vertices, faces=self.faces)
        if not mesh.is_watertight:
            # Attempt to repair the mesh
            trimesh.repair.fill_holes(mesh)
            trimesh.repair.fix_normals(mesh)
            trimesh.repair.fix_inversion(mesh)
            # Remove duplicate faces or vertices
            mesh.remove_duplicate_faces()
            mesh.remove_duplicate_vertices()
            mesh.remove_unreferenced_vertices()
            # Check again if the mesh is watertight
            if not mesh.is_watertight:
                print("Warning: Mesh is still not watertight after repair.")
                # Try to repair again
                trimesh.repair.fill_holes(mesh)
                trimesh.repair.fix_normals(mesh)
                trimesh.repair.fix_inversion(mesh)
                mesh.remove_duplicate_faces()
                mesh.remove_duplicate_vertices()
                mesh.remove_unreferenced_vertices()
                if not mesh.is_watertight:
                    raise RuntimeError("Mesh is not watertight after repair.")
        self.mesh_vertices = mesh.vertices
        self.mesh_faces = mesh.faces

    def simplify_mesh(self, target_face_count):
        """Simplify the mesh to a target number of faces."""
        mesh = trimesh.Trimesh(vertices=self.mesh_vertices, faces=self.mesh_faces)
        simplified_mesh = mesh.simplify_quadratic_decimation(target_face_count)
        self.mesh_vertices = simplified_mesh.vertices
        self.mesh_faces = simplified_mesh.faces
        print(f"Simplified to {len(self.mesh_faces)} faces.")

    def save_mesh_to_stl(self, filename):
        """Save the mesh to an STL file."""
        meshio_mesh = meshio.Mesh(points=self.mesh_vertices, cells=[("triangle", self.mesh_faces)])
        meshio.write(filename, meshio_mesh)
        print(f"Mesh saved to {filename}")

    def tetrahedralize(self, input_file='heightmap.stl', output_file='heightmap_tet.msh'):
        """Perform tetrahedralization on an STL file and save the result."""
        if not os.path.isfile(input_file):
            raise FileNotFoundError(f"File {input_file} not found.")

        imesh = meshio.read(input_file)
        V = imesh.points
        F = imesh.cells_dict.get("triangle", [])

        if len(F) == 0:
            raise ValueError(f"No triangular faces found in {input_file}.")

        # Create TetGen mesher object
        mesher = tg.TetGen(V, F)

        try:
            # Perform tetrahedralization
            Vtg, Ctg = mesher.tetrahedralize(
                order=1,
                mindihedral=5.0,
                minratio=1.0
            )
        except RuntimeError as e:
            raise RuntimeError("Failed to tetrahedralize. May need to repair surface by making it manifold.") from e

        # Create and save the tetrahedral mesh
        omesh = meshio.Mesh(Vtg, [("tetra", Ctg)])
        meshio.write(output_file, omesh)
        self.tetrahedral_mesh = omesh
        print(f"Tetrahedral mesh saved to {output_file}")

    # Visualization
    def update_plotly_figure(self):
        """Update the Plotly figure with the current heightmap."""
        if self.heightmap is None:
            return
        heightmap_safe = np.nan_to_num(self.heightmap, nan=0.0, posinf=1.0, neginf=0.0)
        size_x, size_y = heightmap_safe.shape
        x = np.linspace(-1, 1, size_x)
        y = np.linspace(-1, 1, size_y)
        x, y = np.meshgrid(x, y)

        # Clear the figure before updating to avoid duplicates
        self.fig_widget.data = []  # Clear previous data
        self.fig_widget.add_surface(z=heightmap_safe, x=x, y=y, colorscale='Viridis')

        self.fig_widget.update_layout(
            title='Heightmap',
            autosize=True,
            scene=dict(
                zaxis=dict(title='Height'),
                xaxis=dict(title='X Axis'),
                yaxis=dict(title='Y Axis')
            ),
            margin=dict(l=65, r=50, b=65, t=90)
        )

    def compute_normals(self):
        """Compute the normals for the heightmap mesh."""
        dzdx = np.gradient(self.heightmap, axis=0)
        dzdy = np.gradient(self.heightmap, axis=1)
        normals = np.dstack((-dzdx, -dzdy, np.ones_like(self.heightmap)))
        norm = np.linalg.norm(normals, axis=2, keepdims=True)
        norm[norm == 0] = 1  # Prevent division by zero
        self.normals = normals / norm

    def save_normal_map_as_png(self, filename=None):
        """Save the normal map as a detailed PNG image with adjustable resolution."""
        if self.normals is None:
            raise ValueError("Normals have not been computed yet.")

        if filename is None:
            filename = self.generate_filename("normal")

        # Map normals from [-1, 1] to [0, 255] for RGB representation
        normals_normalized = (self.normals + 1.0) / 2.0  # Scale to [0,1]
        normals_image = (normals_normalized * 255).astype(np.uint8)

        # Create the image using PIL
        img = Image.fromarray(normals_image, 'RGB')

        # Adjust resolution based on normal_map_resolution
        target_resolution = self.normal_map_resolution
        img = img.resize(target_resolution, Image.NEAREST)  # Use NEAREST to preserve normal data
        print(f"Normal map resized to {target_resolution} pixels.")

        # Apply scaling if scale factor is greater than 1
        if self.normal_map_scale != 1:
            new_size = (img.width * self.normal_map_scale, img.height * self.normal_map_scale)
            img = img.resize(new_size, Image.NEAREST)  # Use NEAREST to preserve normal data
            print(f"Normal map upscaled to {new_size} pixels.")

        # Save the image
        img.save(os.path.join(self.output_folder, filename))
        print(f"Normal map saved to {os.path.join(self.output_folder, filename)}")

    def save_displacement_map_as_png(self, filename=None):
        """Save the displacement map as a grayscale PNG image with adjustable resolution."""
        if self.heightmap is None:
            raise ValueError("Heightmap has not been generated yet.")

        if filename is None:
            filename = self.generate_filename("displacement")

        # Normalize heightmap to [0, 255] for grayscale representation
        displacement_normalized = (self.heightmap * 255).astype(np.uint8)
        displacement_image = Image.fromarray(displacement_normalized, 'L')  # 'L' mode for grayscale

        # Adjust resolution based on displacement_map_resolution
        target_resolution = self.displacement_map_resolution
        displacement_image = displacement_image.resize(target_resolution, Image.NEAREST)  # Use NEAREST to preserve data
        print(f"Displacement map resized to {target_resolution} pixels.")

        # Apply scaling if scale factor is greater than 1
        if self.displacement_map_scale != 1:
            new_size = (displacement_image.width * self.displacement_map_scale, displacement_image.height * self.displacement_map_scale)
            displacement_image = displacement_image.resize(new_size, Image.NEAREST)  # Use NEAREST to preserve data
            print(f"Displacement map upscaled to {new_size} pixels.")

        # Save the image
        displacement_image.save(os.path.join(self.output_folder, filename))
        print(f"Displacement map saved to {os.path.join(self.output_folder, filename)}")

    def display_interactive_controls(self, stl_filename='heightmap.stl', mesh_filename='heightmap_tet.msh'):
        """Display interactive controls and update the heightmap accordingly."""
        # Create interactive sliders
        amplitude_slider = widgets.FloatSlider(
            value=self.amplitude,
            min=0.1,
            max=20.0,
            step=0.1,
            description='Amplitude:',
            continuous_update=False
        )
        size_x_slider = widgets.IntSlider(
            value=self.size_x,
            min=10,
            max=200,
            step=10,
            description='Size X:',
            continuous_update=False
        )
        size_y_slider = widgets.IntSlider(
            value=self.size_y,
            min=10,
            max=200,
            step=10,
            description='Size Y:',
            continuous_update=False
        )
        noise_function_dropdown = widgets.Dropdown(
            options=[
                'Fractal Brownian Motion', 
                'Perlin Noise', 
                'Beckmann', 
                'Blinn-Phong', 
                'GGX', 
                'Sine Wave', 
                'Square Wave',
                'Mandelbrot'
            ],
            value='Fractal Brownian Motion',
            description='Noise Function:'
        )

        # Sliders for Fractal Brownian Motion parameters
        octaves_slider = widgets.IntSlider(
            value=self.noise_params.get('octaves', 4),
            min=1,
            max=10,
            step=1,
            description='Octaves:',
            continuous_update=False
        )
        persistence_slider = widgets.FloatSlider(
            value=self.noise_params.get('persistence', 0.5),
            min=0.1,
            max=1.0,
            step=0.1,
            description='Persistence:',
            continuous_update=False
        )
        lacunarity_slider = widgets.FloatSlider(
            value=self.noise_params.get('lacunarity', 2.0),
            min=1.0,
            max=5.0,
            step=0.1,
            description='Lacunarity:',
            continuous_update=False
        )

        # Slider for alpha parameter in Beckmann and GGX distributions
        alpha_slider = widgets.FloatSlider(
            value=self.noise_params.get('alpha', 0.5),
            min=0.01,
            max=1.0,
            step=0.01,
            description='Alpha (Roughness):',
            continuous_update=False
        )

        # Slider for n parameter in Blinn-Phong distribution
        n_slider = widgets.IntSlider(
            value=self.noise_params.get('n', 20),
            min=1,
            max=100,
            step=1,
            description='Shininess n:',
            continuous_update=False
        )

        # Slider for exponent to adjust peak sharpness
        exponent_slider = widgets.FloatSlider(
            value=self.noise_params.get('exponent', 1.0),
            min=0.1,
            max=5.0,
            step=0.1,
            description='Exponent:',
            continuous_update=False
        )

        # Sliders for Sine Wave and Square Wave parameters
        frequency_slider = widgets.FloatSlider(
            value=self.noise_params.get('frequency', 10.0),
            min=1.0,
            max=20.0,
            step=1.0,
            description='Frequency:',
            continuous_update=False
        )
        amplitude_slider_wave = widgets.FloatSlider(
            value=self.noise_params.get('amplitude', 1.0),
            min=0.1,
            max=5.0,
            step=0.1,
            description='Wave Amplitude:',
            continuous_update=False
        )

        # Sliders for Mandelbrot parameters
        mandelbrot_L_slider = widgets.FloatSlider(
            value=self.noise_params.get('L', 1.0),
            min=1.0,
            max=20.0,
            step=0.1,
            description='L:',
            continuous_update=False
        )

        DF_slider = widgets.FloatSlider(
            value=self.noise_params.get('D_f', 1.0),
            min=1.0,
            max=2.0,
            step=0.1,
            description='D_f:',
            continuous_update=False
        )
        gamma_slider = widgets.FloatSlider(
            value=self.noise_params.get('gamma', 1.0),
            min=1.0,
            max=2.5,
            step=0.1,
            description='Gamma:',
            continuous_update=False
        )
        M_slider = widgets.IntSlider(
            value=self.noise_params.get('M', 2),
            min=5,
            max=10,
            step=1,
            description='M:',
            continuous_update=False
        )

        N_max_slider = widgets.IntSlider(
            value=self.noise_params.get('N_max', 100),
            min=5,
            max=1000,
            step=10,
            description='N_max:',
            continuous_update=False
        )

        # Container to hold parameter sliders
        params_container = widgets.VBox()

        # Function to update parameter sliders based on selected noise function
        def update_params_sliders(*args):
            if noise_function_dropdown.value == 'Fractal Brownian Motion':
                params_container.children = [octaves_slider, persistence_slider, lacunarity_slider]
            elif noise_function_dropdown.value == 'Perlin Noise':
                params_container.children = []  # No additional parameters for Perlin Noise
            elif noise_function_dropdown.value == 'Beckmann' or noise_function_dropdown.value == 'GGX':
                params_container.children = [alpha_slider, exponent_slider]
            elif noise_function_dropdown.value == 'Blinn-Phong':
                params_container.children = [n_slider, exponent_slider]
            elif noise_function_dropdown.value == 'Sine Wave' or noise_function_dropdown.value == 'Square Wave':
                params_container.children = [frequency_slider, amplitude_slider_wave]
            elif noise_function_dropdown.value == 'Mandelbrot':
                params_container.children = [mandelbrot_L_slider, DF_slider, gamma_slider, M_slider, N_max_slider]
            else:
                params_container.children = []

        # Observe changes in the noise function dropdown
        noise_function_dropdown.observe(update_params_sliders, 'value')

        # Initialize parameter sliders
        update_params_sliders()

        # Define the update function
        def update_params(change):
            self.amplitude = amplitude_slider.value
            self.size_x = size_x_slider.value
            self.size_y = size_y_slider.value

            # Set the noise function based on the dropdown selection
            if noise_function_dropdown.value == 'Fractal Brownian Motion':
                self.noise_function = self.fbm_noise
                self.noise_params = {
                    'octaves': octaves_slider.value,
                    'persistence': persistence_slider.value,
                    'lacunarity': lacunarity_slider.value
                }
            elif noise_function_dropdown.value == 'Perlin Noise':
                self.noise_function = self.perlin_noise
                self.noise_params = {}
            elif noise_function_dropdown.value == 'Beckmann':
                self.noise_function = self.beckmann_noise
                self.noise_params = {
                    'alpha': alpha_slider.value,
                    'exponent': exponent_slider.value
                }
            elif noise_function_dropdown.value == 'GGX':
                self.noise_function = self.ggx_noise
                self.noise_params = {
                    'alpha': alpha_slider.value,
                    'exponent': exponent_slider.value
                }
            elif noise_function_dropdown.value == 'Blinn-Phong':
                self.noise_function = self.blinn_noise
                self.noise_params = {
                    'n': n_slider.value,
                    'exponent': exponent_slider.value
                }
            elif noise_function_dropdown.value == 'Sine Wave':
                self.noise_function = self.sine_wave
                self.noise_params = {
                    'frequency': frequency_slider.value,
                    'amplitude': amplitude_slider_wave.value
                }
            elif noise_function_dropdown.value == 'Square Wave':
                self.noise_function = self.square_wave
                self.noise_params = {
                    'frequency': frequency_slider.value,
                    'amplitude': amplitude_slider_wave.value
                }
            elif noise_function_dropdown.value == 'Mandelbrot':
                self.noise_function = self.mandelbrot_noise
                self.noise_params = {
                    'L': mandelbrot_L_slider.value,
                    'D_f': DF_slider.value,
                    'gamma': gamma_slider.value,
                    'M': M_slider.value,
                    'N_max': N_max_slider.value
                }
            else:
                self.noise_function = self.fbm_noise
                self.noise_params = {}

            # Regenerate heightmap and update the figure
            self.generate_heightmap()
            self.update_plotly_figure()

        # Attach the update_params function to widget events
        for widget_control in [
            amplitude_slider, size_x_slider, size_y_slider,
            octaves_slider, persistence_slider, lacunarity_slider,
            alpha_slider, n_slider,
            frequency_slider, amplitude_slider_wave, exponent_slider,
            mandelbrot_L_slider,
            DF_slider, gamma_slider, M_slider, N_max_slider,
            noise_function_dropdown
        ]:
            widget_control.observe(update_params, 'value')

        # Create the "Save Mesh" button
        save_mesh_button = widgets.Button(
            description='Save Mesh',
            button_style='success'
        )
        save_output = widgets.Output()

        # Define the function to save the mesh when the button is clicked
        def on_save_mesh_button_clicked(b):
            with save_output:
                save_output.clear_output()
                print("Saving mesh...")
                self.heightmap_to_closed_mesh()
                self.repair_mesh()
                self.save_mesh_to_stl(stl_filename)
                print(f"Mesh saved to {stl_filename}")
                self.tetrahedralize(stl_filename, mesh_filename)
                print(f"Tetrahedral mesh saved to {mesh_filename}")
                self.compute_normals()
                normal_map_filename = self.generate_filename("normal")
                self.save_normal_map_as_png(normal_map_filename)
                print(f"Normal map saved to {normal_map_filename}")
                displacement_map_filename = self.generate_filename("displacement")
                self.save_displacement_map_as_png(displacement_map_filename)
                print(f"Displacement map saved to {displacement_map_filename}")

        # Attach the function to the button click event
        save_mesh_button.on_click(on_save_mesh_button_clicked)

        # Create the UI layout
        ui = widgets.VBox([
            amplitude_slider,
            size_x_slider,
            size_y_slider,
            noise_function_dropdown,
            params_container,
            save_mesh_button,
            save_output
        ])

        # Initial update
        self.generate_heightmap()
        self.update_plotly_figure()

        # Display the interactive widgets and Plotly figure
        display(ui, self.fig_widget)

In [3]:
# Create an instance of the generator with default parameters
generator = HeightmapMeshGenerator()

# Display interactive controls
generator.display_interactive_controls()

VBox(children=(FloatSlider(value=1.0, continuous_update=False, description='Amplitude:', max=20.0, min=0.1), I…

FigureWidget({
    'data': [{'colorscale': [[0.0, '#440154'], [0.1111111111111111, '#482878'],
                             [0.2222222222222222, '#3e4989'], [0.3333333333333333,
                             '#31688e'], [0.4444444444444444, '#26828e'],
                             [0.5555555555555556, '#1f9e89'], [0.6666666666666666,
                             '#35b779'], [0.7777777777777778, '#6ece58'],
                             [0.8888888888888888, '#b5de2b'], [1.0, '#fde725']],
              'type': 'surface',
              'uid': 'a34789fb-0372-4967-bfdf-9d8817b2ab27',
              'x': array([[-1.        , -0.97979798, -0.95959596, ...,  0.95959596,  0.97979798,
                            1.        ],
                          [-1.        , -0.97979798, -0.95959596, ...,  0.95959596,  0.97979798,
                            1.        ],
                          [-1.        , -0.97979798, -0.95959596, ...,  0.95959596,  0.97979798,
                            1.        ],
  