In [18]:
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from mpl_toolkits.mplot3d.art3d import Poly3DCollection

def plot_single_shape(vector_10d=None, figsize=(10, 8), show_info=True):
    """
    Plot a single shape from 10D vector
    
    Args:
        vector_10d: numpy array of shape (10,) - if None, generates random
        figsize: tuple - figure size
        show_info: bool - whether to show parameter info in title
    
    Returns:
        fig, ax: matplotlib figure and axis objects
        shape_info: dict with shape details
    """
    # Import your shape generation function
    from generate_shapes import sample_10d_sphere, generate_shape_from_10d_vector_injective
    
    # Generate random vector if none provided
    if vector_10d is None:
        vector_10d = sample_10d_sphere(1)[0]
    
    # Generate shape
    shape_data, shape_type, param_usage = generate_shape_from_10d_vector_injective(vector_10d)
    
    # Create figure
    fig = plt.figure(figsize=figsize)
    ax = fig.add_subplot(111, projection='3d')
    
    # Plot based on shape type
    if shape_type in ["superquadric", "organic_blob"]:
        # Surface plots
        x, y, z = shape_data
        surf = ax.plot_surface(x, y, z, alpha=0.8, cmap='viridis', 
                              linewidth=0, antialiased=True)
        
        # Add colorbar
        fig.colorbar(surf, ax=ax, shrink=0.5, aspect=20)
        
        # Equal aspect ratio
        max_range = max(np.max(np.abs(x)), np.max(np.abs(y)), np.max(np.abs(z))) * 1.1
        ax.set_xlim([-max_range, max_range])
        ax.set_ylim([-max_range, max_range])
        ax.set_zlim([-max_range, max_range])
        
    elif shape_type == "fractal_tree":
        # Tree structure
        branches = shape_data
        if branches:
            # Extract branch information
            starts = np.array([branch[0] for branch in branches])
            ends = np.array([branch[1] for branch in branches])
            thicknesses = np.array([branch[2] if len(branch) > 2 else 0.1 for branch in branches])
            
            # Color by generation/depth
            distances = np.linalg.norm(starts, axis=1)
            if np.max(distances) > 0:
                colors = plt.cm.autumn(distances / np.max(distances))
            else:
                colors = ['brown'] * len(branches)
            
            # Plot branches
            for i, ((start, end), thickness, color) in enumerate(zip(zip(starts, ends), thicknesses, colors)):
                linewidth = max(0.5, thickness * 50)
                ax.plot([start[0], end[0]], [start[1], end[1]], [start[2], end[2]], 
                       color=color, linewidth=linewidth, alpha=0.8)
            
            # Set aspect ratio
            all_points = np.vstack([starts, ends])
            max_range = np.max(np.abs(all_points)) * 1.1
            ax.set_xlim([-max_range, max_range])
            ax.set_ylim([-max_range, max_range])
            ax.set_zlim([0, max_range * 2])
        else:
            ax.text(0, 0, 0, "No branches generated", fontsize=12)
            
    elif shape_type in ["crystal", "geodesic_dome", "twisted_prism", "spiral_tower"]:
        # Mesh-based shapes
        vertices, faces = shape_data
        
        if len(vertices) > 0:
            # Create face collection
            if len(faces) > 0:
                face_vertices = []
                for face in faces:
                    if len(face) >= 3:
                        face_vertices.append(vertices[face])
                
                if face_vertices:
                    # Color selection based on shape type
                    color_map = {
                        "crystal": ("lightblue", "darkblue"),
                        "geodesic_dome": ("cyan", "navy"), 
                        "twisted_prism": ("orange", "darkorange"),
                        "spiral_tower": ("lightgreen", "darkgreen")
                    }
                    face_color, edge_color = color_map.get(shape_type, ("lightgray", "black"))
                    
                    poly_collection = Poly3DCollection(face_vertices, alpha=0.7, 
                                                     facecolor=face_color, edgecolor=edge_color, 
                                                     linewidths=1)
                    ax.add_collection3d(poly_collection)
            
            # Plot vertices
            ax.scatter(vertices[:, 0], vertices[:, 1], vertices[:, 2], 
                      c='red', s=30, alpha=0.8)
            
            # Set aspect ratio
            max_range = np.max(np.abs(vertices)) * 1.1
            ax.set_xlim([-max_range, max_range])
            ax.set_ylim([-max_range, max_range])
            ax.set_zlim([-max_range, max_range])
        else:
            ax.text(0, 0, 0, "No vertices generated", fontsize=12)
            
    elif shape_type == "neural_network":
        # Network structure
        nodes, connections = shape_data
        
        if len(nodes) > 0:
            # Plot connections first
            for connection in connections:
                if len(connection) >= 2:
                    start_idx, end_idx = connection[0], connection[1]
                    if start_idx < len(nodes) and end_idx < len(nodes):
                        start, end = nodes[start_idx], nodes[end_idx]
                        ax.plot([start[0], end[0]], [start[1], end[1]], [start[2], end[2]], 
                               color='blue', alpha=0.3, linewidth=1)
            
            # Plot nodes colored by layer
            x_coords = nodes[:, 0]
            unique_x = np.unique(x_coords)
            layer_colors = plt.cm.tab10(np.linspace(0, 1, len(unique_x)))
            
            for i, x_val in enumerate(unique_x):
                layer_nodes = nodes[np.abs(x_coords - x_val) < 0.01]
                if len(layer_nodes) > 0:
                    ax.scatter(layer_nodes[:, 0], layer_nodes[:, 1], layer_nodes[:, 2], 
                              c=[layer_colors[i]], s=100, alpha=0.9, 
                              edgecolors='black', linewidth=1)
            
            # Set aspect ratio
            x_range = np.max(nodes[:, 0]) - np.min(nodes[:, 0])
            y_range = np.max(nodes[:, 1]) - np.min(nodes[:, 1])
            z_range = np.max(nodes[:, 2]) - np.min(nodes[:, 2])
            max_range = max(y_range, z_range) * 1.2
            
            ax.set_xlim([np.min(nodes[:, 0]) - 0.5, np.max(nodes[:, 0]) + 0.5])
            ax.set_ylim([np.mean(nodes[:, 1]) - max_range/2, np.mean(nodes[:, 1]) + max_range/2])
            ax.set_zlim([np.mean(nodes[:, 2]) - max_range/2, np.mean(nodes[:, 2]) + max_range/2])
        else:
            ax.text(0, 0, 0, "No nodes generated", fontsize=12)
    
    # Styling
    ax.set_xlabel('X', fontsize=12)
    ax.set_ylabel('Y', fontsize=12)
    ax.set_zlabel('Z', fontsize=12)
    
    # Title with optional parameter info
    if show_info:
        title = f"{shape_type.replace('_', ' ').title()}\n"
        title += f"param[0] = {vector_10d[0]:.3f} (shape selector)\n"
        title += f"5D sample: [{vector_10d[1]:.2f}, {vector_10d[2]:.2f}, {vector_10d[3]:.2f}, ...]"
        ax.set_title(title, fontsize=11, fontweight='bold')
    else:
        ax.set_title(f"{shape_type.replace('_', ' ').title()}", fontsize=14, fontweight='bold')
    
    plt.tight_layout()
    
    # Return shape info
    shape_info = {
        "shape_type": shape_type,
        "vector_10d": vector_10d,
        "param_usage": param_usage
    }
    
    return fig, ax, shape_info

def debug_shape_mapping():
    """
    Debug function to show exactly how parameter 0 maps to shape types
    """
    from generate_shapes import sample_10d_sphere, generate_shape_from_10d_vector_injective
    
    shape_types = [
        "superquadric", "fractal_tree", "crystal", "organic_blob",
        "geodesic_dome", "twisted_prism", "spiral_tower", "neural_network"
    ]
    
    print("Shape Type Distribution Analysis:")
    print("=" * 50)
    
    # Sample many vectors and see distribution
    n_samples = 1000
    shape_counts = {shape: 0 for shape in shape_types}
    param0_values = {shape: [] for shape in shape_types}
    
    for _ in range(n_samples):
        vector = sample_10d_sphere(1)[0]
        shape_data, shape_type, _ = generate_shape_from_10d_vector_injective(vector)
        shape_counts[shape_type] += 1
        param0_values[shape_type].append(vector[0])
    
    for shape_name in shape_types:
        count = shape_counts[shape_name]
        percentage = (count / n_samples) * 100
        if count > 0:
            param_range = [min(param0_values[shape_name]), max(param0_values[shape_name])]
            param_avg = np.mean(param0_values[shape_name])
            print(f"{shape_name:15}: {count:4d} samples ({percentage:5.1f}%) | param[0] range: [{param_range[0]:6.3f}, {param_range[1]:6.3f}] | avg: {param_avg:6.3f}")
        else:
            print(f"{shape_name:15}: {count:4d} samples ({percentage:5.1f}%) | No samples found")
    
    print(f"\nTotal: {sum(shape_counts.values())} samples")

def plot_shape_distribution():
    """
    Plot histogram of shape type distribution
    """
    from generate_shapes import sample_10d_sphere, generate_shape_from_10d_vector_injective
    import matplotlib.pyplot as plt
    
    # Sample many vectors
    n_samples = 1000
    param0_values = []
    shape_types = []
    
    for _ in range(n_samples):
        vector = sample_10d_sphere(1)[0]
        shape_data, shape_type, _ = generate_shape_from_10d_vector_injective(vector)
        param0_values.append(vector[0])
        shape_types.append(shape_type)
    
    # Create histogram
    fig, ax = plt.subplots(figsize=(12, 6))
    
    unique_shapes = list(set(shape_types))
    colors = plt.cm.tab10(np.linspace(0, 1, len(unique_shapes)))
    
    for i, shape in enumerate(unique_shapes):
        shape_param0 = [p for p, s in zip(param0_values, shape_types) if s == shape]
        ax.hist(shape_param0, bins=20, alpha=0.7, label=shape, color=colors[i])
    
    ax.set_xlabel('param[0] value')
    ax.set_ylabel('Count')
    ax.set_title('Distribution of Shape Types by param[0]')
    ax.legend()
    ax.grid(True, alpha=0.3)
    
    plt.tight_layout()
    return fig

def plot_shape_by_type(shape_type_name, vector_10d=None, figsize=(10, 8)):
    """
    Plot a specific shape type by forcing parameter 0 to select it
    
    Args:
        shape_type_name: str - one of ["superquadric", "fractal_tree", "crystal", 
                                      "organic_blob", "geodesic_dome", "twisted_prism", 
                                      "spiral_tower", "neural_network"]
        vector_10d: numpy array - if None, generates random but forces shape type
        figsize: tuple - figure size
    """
    from generate_shapes import sample_10d_sphere
    
    # Shape type mapping (same as in your original code)
    shape_types = [
        "superquadric", "fractal_tree", "crystal", "organic_blob",
        "geodesic_dome", "twisted_prism", "spiral_tower", "neural_network"
    ]
    
    if shape_type_name not in shape_types:
        raise ValueError(f"Shape type must be one of: {shape_types}")
    
    # Generate or modify vector to force specific shape type
    if vector_10d is None:
        vector_10d = sample_10d_sphere(1)[0]
    else:
        vector_10d = vector_10d.copy()
    
    # Force parameter 0 to select desired shape type
    shape_idx = shape_types.index(shape_type_name)
    # Map shape index [0,1,2,3,4,5,6,7] to param0 ranges
    # Each shape gets a range of width 2/8 = 0.25 in [-1, 1]
    range_width = 2.0 / 8
    range_start = -1 + shape_idx * range_width
    target_param0 = range_start + range_width / 2  # Center of range
    
    # Debug print
    print(f"Forcing shape: {shape_type_name} (idx={shape_idx})")
    print(f"Target param[0]: {target_param0:.3f}")
    
    # Adjust vector to have desired param[0] while maintaining unit norm
    vector_10d[0] = target_param0
    vector_10d = vector_10d / np.linalg.norm(vector_10d)  # Renormalize
    
    print(f"Final param[0]: {vector_10d[0]:.3f}")
    
    return plot_single_shape(vector_10d, figsize, show_info=True)

def explore_random_shapes(n_shapes=4, figsize=(12, 10)):
    """
    Generate and plot multiple random shapes in a grid
    
    Args:
        n_shapes: int - number of shapes to generate
        figsize: tuple - overall figure size
    """
    from generate_shapes import sample_10d_sphere, generate_shape_from_10d_vector_injective
    
    # Calculate grid dimensions
    cols = 2
    rows = (n_shapes + cols - 1) // cols
    
    fig = plt.figure(figsize=figsize)
    
    for i in range(n_shapes):
        # Generate random vector
        vector = sample_10d_sphere(1)[0]
        shape_data, shape_type, _ = generate_shape_from_10d_vector_injective(vector)
        
        ax = fig.add_subplot(rows, cols, i + 1, projection='3d')
        
        # Quick visualization based on shape type
        if shape_type in ["superquadric", "organic_blob"]:
            x, y, z = shape_data
            ax.plot_surface(x, y, z, alpha=0.7, cmap='viridis')
        elif shape_type == "fractal_tree":
            branches = shape_data
            for start, end, thickness in branches:
                ax.plot([start[0], end[0]], [start[1], end[1]], [start[2], end[2]], 
                       'brown', linewidth=max(0.5, thickness*20))
        elif shape_type in ["crystal", "geodesic_dome", "twisted_prism", "spiral_tower"]:
            vertices, faces = shape_data
            if len(faces) > 0:
                face_vertices = [vertices[face] for face in faces if len(face) >= 3]
                if face_vertices:
                    poly_collection = Poly3DCollection(face_vertices, alpha=0.6)
                    ax.add_collection3d(poly_collection)
            if len(vertices) > 0:
                ax.scatter(vertices[:, 0], vertices[:, 1], vertices[:, 2], s=20)
        elif shape_type == "neural_network":
            nodes, connections = shape_data
            for start_idx, end_idx in connections:
                if start_idx < len(nodes) and end_idx < len(nodes):
                    start, end = nodes[start_idx], nodes[end_idx]
                    ax.plot([start[0], end[0]], [start[1], end[1]], [start[2], end[2]], 
                           'blue', alpha=0.3)
            ax.scatter(nodes[:, 0], nodes[:, 1], nodes[:, 2], c='red', s=50)
        
        ax.set_title(f"{shape_type.replace('_', ' ').title()}\nparam[0]={vector[0]:.2f}", fontsize=10)
        ax.set_xlabel('X')
        ax.set_ylabel('Y') 
        ax.set_zlabel('Z')
    
    plt.tight_layout()
    return fig


debug_shape_mapping()

# Example usage functions for Jupyter notebook:
"""
# Basic usage:
fig, ax, info = plot_single_shape()  # Random shape
plt.show()

# Plot specific shape type:
fig, ax, info = plot_shape_by_type("crystal")
plt.show()

# Plot with your own vector:
my_vector = np.random.randn(10)
my_vector = my_vector / np.linalg.norm(my_vector)
fig, ax, info = plot_single_shape(my_vector)
plt.show()

# Explore multiple random shapes:
fig = explore_random_shapes(n_shapes=6)
plt.show()

# Print shape info:
print(f"Shape type: {info['shape_type']}")
print(f"Parameters: {info['param_usage']}")
"""

Shape Type Distribution Analysis:
superquadric   :    6 samples (  0.6%) | param[0] range: [-0.916, -0.753] | avg: -0.802
fractal_tree   :   62 samples (  6.2%) | param[0] range: [-0.725, -0.501] | avg: -0.576
crystal        :  147 samples ( 14.7%) | param[0] range: [-0.499, -0.250] | avg: -0.357
organic_blob   :  268 samples ( 26.8%) | param[0] range: [-0.249, -0.001] | avg: -0.122
geodesic_dome  :  284 samples ( 28.4%) | param[0] range: [ 0.000,  0.249] | avg:  0.122
twisted_prism  :  175 samples ( 17.5%) | param[0] range: [ 0.254,  0.496] | avg:  0.356
spiral_tower   :   56 samples (  5.6%) | param[0] range: [ 0.500,  0.729] | avg:  0.579
neural_network :    2 samples (  0.2%) | param[0] range: [ 0.783,  0.845] | avg:  0.814

Total: 1000 samples


'\n# Basic usage:\nfig, ax, info = plot_single_shape()  # Random shape\nplt.show()\n\n# Plot specific shape type:\nfig, ax, info = plot_shape_by_type("crystal")\nplt.show()\n\n# Plot with your own vector:\nmy_vector = np.random.randn(10)\nmy_vector = my_vector / np.linalg.norm(my_vector)\nfig, ax, info = plot_single_shape(my_vector)\nplt.show()\n\n# Explore multiple random shapes:\nfig = explore_random_shapes(n_shapes=6)\nplt.show()\n\n# Print shape info:\nprint(f"Shape type: {info[\'shape_type\']}")\nprint(f"Parameters: {info[\'param_usage\']}")\n'