In [1]:
import skimage.io
import matplotlib.pyplot as plt
import simple_snake_new as sis
import numpy as np
%matplotlib qt

In [None]:
# NOTE(review): this cell failed with "AttributeError: module 'simple_snake_new'
# has no attribute 'evolve_snake'" -- the imported module does not provide
# evolve_snake. An evolve_snake implementation is defined in the next cell of
# this notebook; add it to simple_snake_new.py (or run the loop after that cell).
I = skimage.io.imread('data/plusplus.png')
if I.ndim == 3:
    # Drop an alpha channel (LA / RGBA) so transparency does not bias the
    # intensities, then average the remaining channels to get grayscale.
    if I.shape[2] in (2, 4):
        I = I[..., :-1]
    I = I.mean(axis=2)

# Initial snake: circle centred in the image, radius a quarter of the smaller side.
radius = min(I.shape) / 4
center = (I.shape[0] / 2, I.shape[1] / 2)

N = 60               # number of snake points
step_size = 0.0001   # passed to sis.evolve_snake
plot_pause = 0.01    # seconds between animation frames
alpha = 0.001        # regularization weights passed to sis.regularization_matrix
beta = 0.001
n_iterations = 100

snake = sis.make_circular_snake(N, center, radius)
B = sis.regularization_matrix(N, alpha, beta)

fig, ax = plt.subplots(figsize=(5, 5))

closed = np.hstack([np.arange(N), 0])  # indices that close the curve for plotting

for i in range(n_iterations):
    snake = sis.evolve_snake(snake, I, B, step_size)

    ax.clear()
    ax.imshow(I, cmap='gray')
    ax.plot(snake[closed, 1], snake[closed, 0], 'b.-')
    ax.set_title(f'iteration {i+1}')
    fig.canvas.draw()
    plt.pause(plot_pause)

plt.show()

AttributeError: module 'simple_snake_new' has no attribute 'evolve_snake'

In [None]:
import numpy as np
import scipy.interpolate
import scipy.linalg
import skimage.draw
import matplotlib.pyplot as plt
try:
    import skimage.io
    HAS_SKIMAGE_IO = True
except ImportError:
    HAS_SKIMAGE_IO = False
    print("skimage.io not available, will create synthetic image")


# ===== ORIGINAL SNAKE FUNCTIONS =====

def make_circular_snake(N, center, radius):
    """Create N points evenly spaced on a circle.

    Parameters
    ----------
    N : int
        Number of snake points.
    center : sequence of two numbers
        Circle centre, in the same (row, col) order used for snake points.
    radius : float
        Circle radius.

    Returns
    -------
    ndarray of shape (N, 2)
        Point k is center + radius * (cos t_k, sin t_k) with t_k = 2*pi*k/N.
    """
    theta = np.linspace(0, 2 * np.pi, N, endpoint=False)
    offsets = radius * np.column_stack((np.cos(theta), np.sin(theta)))
    origin = np.asarray(center).reshape([1, 2])
    return origin + offsets


def normalize(n):
    """Scale every row of ``n`` to unit Euclidean length.

    Rows of length zero are divided by 1 (left unchanged) instead of
    producing NaNs.
    """
    length = np.sqrt(np.sum(n * n, axis=1, keepdims=True))
    safe_length = np.where(length == 0, 1, length)
    return n / safe_length


def get_normals(snake):
    """Return one unit normal vector per snake point.

    For point k, the unit vectors along snake[k-1] - snake[k] and
    snake[k] - snake[k+1] (indices wrap around, so the snake is closed) are
    summed and normalized to give a tangent t. The normal is t rotated by
    90 degrees: (-t[1], t[0]).
    """
    to_previous = normalize(np.roll(snake, 1, axis=0) - snake)
    from_next = np.roll(to_previous, -1, axis=0)
    tangent = normalize(from_next + to_previous)
    return np.column_stack((-tangent[:, 1], tangent[:, 0]))


def distribute_points(snake):
    """Resample a closed snake so its points are equidistant along the curve.

    The snake is treated as closed (the last point connects back to the
    first). The cumulative arc length of that closed polygon is normalized to
    [0, 1]. Each coordinate is then linearly interpolated at N evenly spaced
    positions, starting at the original first point.

    Parameters
    ----------
    snake : array_like of shape (N, 2)

    Returns
    -------
    ndarray of shape (N, 2)
        The resampled points. If the snake has zero total length (empty, a
        single point, or all points coincide), a float copy of the input is
        returned rather than dividing by zero and producing NaNs.
    """
    snake = np.asarray(snake, dtype=float)
    N = len(snake)
    closed = np.vstack([snake, snake[:1]])  # append first point to close the curve
    segment = np.sqrt((np.diff(closed, axis=0) ** 2).sum(axis=1))
    arc = np.concatenate([[0.0], np.cumsum(segment)])  # arc length at each vertex
    total = arc[-1]
    if total == 0:
        return snake.copy()
    t = arc / total  # normalize to 0-1
    x = np.linspace(0, 1, N, endpoint=False)  # new (equidistant) positions
    return np.stack([np.interp(x, t, closed[:, k]) for k in range(2)], axis=1)


def is_ccw(A, B, C):
    """Return True if A -> B -> C makes a counterclockwise turn.

    Compares the two products of the 2-D cross product (B - A) x (C - A);
    collinear points give False because the comparison is strict.
    """
    lhs = (C[1] - A[1]) * (B[0] - A[0])
    rhs = (B[1] - A[1]) * (C[0] - A[0])
    return lhs > rhs


def is_crossing(A, B, C, D):
    """Return True if the orientation test reports that segments AB and CD intersect.

    The segments are considered crossing when A and B lie on opposite sides
    of CD and C and D lie on opposite sides of AB, as decided by is_ccw.
    Degenerate configurations (shared endpoints, collinear points) are not
    treated specially and follow from is_ccw's strict comparison.
    """
    cd_separates_ab = is_ccw(A, C, D) != is_ccw(B, C, D)
    if not cd_separates_ab:
        return cd_separates_ab
    ab_separates_cd = is_ccw(A, B, C) != is_ccw(A, B, D)
    return ab_separates_cd


def is_counterclockwise(snake):
    """Check whether the closed polygon through the snake points is counterclockwise.

    Uses the shoelace-style sum over every edge of (x2 - x1) * (y2 + y1),
    with column 0 as x and column 1 as y. A negative sum means
    counterclockwise in this convention. The closing edge from the last point
    back to the first is included, so the answer does not depend on which
    point the snake starts at.

    Parameters
    ----------
    snake : ndarray of shape (N, 2)

    Returns
    -------
    bool
    """
    following = np.roll(snake, -1, axis=0)  # wraps so the closing edge is counted
    return np.dot(following[:, 0] - snake[:, 0],
                  following[:, 1] + snake[:, 1]) < 0


def remove_intersections(snake, method='new'):
    """Reorder snake points to remove self-intersections.

    Every pair of non-adjacent edges (i, i+1) and (j, j+1) of the closed
    snake is tested with is_crossing. When they cross, the points of the
    smaller of the two loops they delimit are reversed. Finally the point
    order is flipped if needed, so the result is counterclockwise according
    to is_counterclockwise.

    Parameters
    ----------
    snake : ndarray of shape (N, 2)
        Not modified; a reordered copy is returned.
    method : str, optional
        Unused; kept for backward compatibility with existing callers.

    Returns
    -------
    ndarray of shape (N, 2)
    """
    N = len(snake)
    # Padded copy: closed[N] duplicates closed[0] so edge (N-1, N) closes the curve.
    closed = snake[np.hstack([np.arange(N), 0])]
    for i in range(N - 2):
        for j in range(i + 2, N):
            if is_crossing(closed[i], closed[i + 1], closed[j], closed[j + 1]):
                # Reverse vertices of the smaller loop: either the inner run
                # i+1..j or the run j+1..i+N that wraps past the start.
                rb, re = (i + 1, j) if j - i < N // 2 else (j + 1, i + N)
                indices = np.arange(rb, re + 1) % N
                closed[indices] = closed[indices[::-1]]
                # A wrapping reversal can move closed[0]; keep the padding
                # point in sync so later tests use the real closing edge.
                closed[-1] = closed[0]
    snake = closed[:-1]
    return snake if is_counterclockwise(snake) else np.flip(snake, axis=0)


def keep_snake_inside(snake, shape):
    """Clamp snake coordinates to the pixel range of an image of the given shape.

    Column 0 is limited to [0, shape[0] - 1] and column 1 to
    [0, shape[1] - 1]. The input array is modified in place and also returned.
    """
    for axis in (0, 1):
        upper = shape[axis] - 1
        snake[:, axis] = np.clip(snake[:, axis], 0, upper)
    return snake

    
def regularization_matrix(N, alpha, beta):
    """Return the smoothing matrix (I - S)^-1 for a closed snake of N points.

    S is circulant. Its stencil, centred on each point with wrap-around, is
    alpha * [1, -2, 1] (second difference) plus beta * [-1, 4, -6, 4, -1]
    (negated fourth difference). evolve_snake multiplies the snake by the
    returned matrix as its internal (smoothing) step.
    """
    stencil = (alpha * np.array([0, 1, -2, 1, 0])
               + beta * np.array([-1, 4, -6, 4, -1]))
    first_column = np.zeros(N)
    first_column[[-2, -1, 0, 1, 2]] = stencil
    system = np.eye(N) - scipy.linalg.circulant(first_column)
    return scipy.linalg.inv(system)


def evolve_snake(snake, I, B, step_size):
    """Perform a single step of snake evolution.

    Parameters
    ----------
    snake : ndarray of shape (N, 2)
        Current snake points in (row, col) image coordinates.
    I : 2-D ndarray
        Grayscale image.
    B : ndarray of shape (N, N)
        Smoothing matrix, e.g. from regularization_matrix.
    step_size : float
        Scale applied to the external force along the point normals.

    Returns
    -------
    ndarray of shape (N, 2)
        The updated snake. It is moved along its normals by the external
        force, multiplied by B, untangled (remove_intersections), resampled
        equidistantly (distribute_points) and clamped to the image
        (keep_snake_inside).
    """
    # External part: reuse calculate_external_energy so the values plotted as
    # "external energy" are exactly the ones driving the evolution.
    f_ext = calculate_external_energy(snake, I)
    displacement = step_size * f_ext[:, None] * get_normals(snake)
    snake = snake + displacement

    snake = B @ snake  # internal part

    snake = remove_intersections(snake)
    snake = distribute_points(snake)
    snake = keep_snake_inside(snake, I.shape)
    return snake


# ===== ENERGY PLOTTING FUNCTIONS =====

def calculate_internal_energy(snake, alpha=0.01, beta=0.1):
    """Return the per-point internal energy terms of a closed snake.

    Differences wrap around, so the last point is adjacent to the first.

    Parameters
    ----------
    snake : ndarray of shape (N, 2)
    alpha : float
        Weight of the squared first difference (next point minus current).
    beta : float
        Weight of the squared second difference (next - 2*current + previous).

    Returns
    -------
    tuple of three ndarrays of shape (N,)
        (continuity_energy, curvature_energy, continuity + curvature).
    """
    following = np.roll(snake, -1, axis=0)
    preceding = np.roll(snake, 1, axis=0)

    first_diff = following - snake
    second_diff = following - 2 * snake + preceding

    continuity = alpha * (first_diff ** 2).sum(axis=1)
    curvature = beta * (second_diff ** 2).sum(axis=1)
    return continuity, curvature, continuity + curvature


def calculate_external_energy(snake, I):
    """Return the region-based external force at each snake point.

    m_in and m_out are the mean intensities of I inside and outside the
    polygon formed by the snake. Each point p gets
    (m_in - m_out) * (2 * I(p) - m_in - m_out), where I(p) is the image value
    at p interpolated with a bivariate spline. evolve_snake uses the same
    quantity, scaled by its step size, as the displacement along the normals.

    Returns
    -------
    ndarray of shape (N,)

    NOTE(review): if the snake covers none (or all) of the image, one mean is
    taken over an empty selection and the result is NaN.
    """
    inside = skimage.draw.polygon2mask(I.shape, snake)
    mean_in = I[inside].mean()
    mean_out = I[~inside].mean()

    rows = np.arange(I.shape[0])
    cols = np.arange(I.shape[1])
    spline = scipy.interpolate.RectBivariateSpline(rows, cols, I)
    values = spline(snake[:, 0], snake[:, 1], grid=False)

    return (mean_in - mean_out) * (2 * values - mean_in - mean_out)


def plot_snake_energies(snake, I, alpha=0.01, beta=0.1, iteration=None):
    """Plot internal and external energies as curves along the snake contour.

    Draws a 2x2 figure with the following panels:
    - the external term per point;
    - the internal components per point;
    - their sum;
    - the snake over the image, coloured by that sum.

    When ``iteration`` is given, the figure is saved to
    'snake_energy_iteration_<iteration, 3 digits>.png' and the iteration is
    appended to every title. Otherwise it is saved to
    'snake_energy_analysis.png'. The figure is then shown.

    Returns
    -------
    tuple
        (external_energy, continuity_energy, curvature_energy, total_internal),
        each an ndarray with one value per snake point.
    """
    continuity_energy, curvature_energy, total_internal = calculate_internal_energy(snake, alpha, beta)
    external_energy = calculate_external_energy(snake, I)
    total_energy = external_energy + total_internal
    t = np.arange(len(snake))

    suffix = '' if iteration is None else f' (Iteration {iteration})'

    def finish_curve_axes(ax, ylabel, title):
        # Shared labelling for the three line-plot panels.
        ax.set_xlabel('Snake Point Index')
        ax.set_ylabel(ylabel)
        ax.set_title(title + suffix)
        ax.grid(True, alpha=0.3)
        ax.legend()

    fig, axes = plt.subplots(2, 2, figsize=(12, 10))
    ext_ax, int_ax, tot_ax, img_ax = axes.ravel()

    ext_ax.plot(t, external_energy, 'b-', linewidth=2, label='External Energy')
    finish_curve_axes(ext_ax, 'External Energy', 'External Energy Along Snake Contour')

    int_ax.plot(t, continuity_energy, 'r-', linewidth=2, label=f'Continuity (α={alpha})')
    int_ax.plot(t, curvature_energy, 'g-', linewidth=2, label=f'Curvature (β={beta})')
    int_ax.plot(t, total_internal, 'k--', linewidth=2, label='Total Internal')
    finish_curve_axes(int_ax, 'Internal Energy', 'Internal Energy Components Along Snake Contour')

    tot_ax.plot(t, total_energy, 'm-', linewidth=2, label='Total Energy')
    tot_ax.plot(t, external_energy, 'b--', alpha=0.7, label='External')
    tot_ax.plot(t, total_internal, 'r--', alpha=0.7, label='Internal')
    finish_curve_axes(tot_ax, 'Energy', 'Total Energy Along Snake Contour')

    img_ax.imshow(I, cmap='gray')
    scatter = img_ax.scatter(snake[:, 1], snake[:, 0],
                             c=total_energy, cmap='viridis', s=20)
    img_ax.set_title('Snake with Total Energy Color Coding' + suffix)
    plt.colorbar(scatter, ax=img_ax, label='Total Energy')

    plt.tight_layout()

    if iteration is None:
        out_name = 'snake_energy_analysis.png'
    else:
        out_name = f'snake_energy_iteration_{iteration:03d}.png'
    plt.savefig(out_name, dpi=150, bbox_inches='tight')

    plt.show()

    return external_energy, continuity_energy, curvature_energy, total_internal


def create_synthetic_plusplus_image(shape=(200, 200), half_width=15, noise=0.1, rng=None):
    """Create a synthetic '+' shaped test image.

    Parameters
    ----------
    shape : tuple of int
        (height, width) of the image.
    half_width : int
        Half the thickness, in pixels, of each bar of the '+'. Bar extents
        are clipped at the image border, so small images do not wrap around.
    noise : float
        Amplitude of the additive uniform noise in [0, noise).
    rng : None, int or numpy.random.Generator
        Source of the noise. ``None`` keeps the previous behavior of drawing
        from the global ``np.random`` state. An int seed or a Generator gives
        noise that is reproducible independently of global state.

    Returns
    -------
    ndarray of float with the given shape
        1.0 on the bars and 0.0 elsewhere, plus noise.
    """
    h, w = shape
    I = np.zeros(shape)

    # max(0, ...) prevents a negative start index from wrapping to the far edge.
    row_lo = max(0, h // 2 - half_width)
    col_lo = max(0, w // 2 - half_width)

    I[row_lo:h // 2 + half_width, w // 4:3 * w // 4] = 1.0   # horizontal bar
    I[h // 4:3 * h // 4, col_lo:w // 2 + half_width] = 1.0   # vertical bar

    if rng is None:
        I += noise * np.random.random(shape)
    else:
        I += noise * np.random.default_rng(rng).random(shape)

    return I


# ===== EVOLUTION DRIVERS AND ENERGY PLOTS =====

def plot_energy_evolution(external_energies, internal_energies, title="Energy Evolution"):
    """Plot total external and internal energy against iteration number.

    The figure is saved as '<title lower-cased, spaces -> underscores>.png'
    and then shown.

    Parameters
    ----------
    external_energies, internal_energies : sequence of float
        One value per recorded iteration. Each sequence is plotted against
        0, 1, 2, ... up to its own length.
    title : str
        Figure title, also used to build the output file name.

    Raises
    ------
    ValueError
        If both sequences are empty.
    """
    external_energies = np.asarray(external_energies, dtype=float)
    internal_energies = np.asarray(internal_energies, dtype=float)
    n = max(len(external_energies), len(internal_energies))
    if n == 0:
        raise ValueError("no energy values to plot")

    fig, ax = plt.subplots(figsize=(8, 6))
    ax.plot(np.arange(len(external_energies)), external_energies, 'b-', linewidth=2, label='EXT')
    ax.plot(np.arange(len(internal_energies)), internal_energies, color='orange', linewidth=2, label='INT')
    ax.set_xlabel('Iteration')
    ax.set_ylabel('Energy')
    ax.set_title(title)
    ax.legend()
    ax.grid(True, alpha=0.3)
    # A single point would otherwise give identical (0, 0) x-limits.
    ax.set_xlim(0, max(n - 1, 1))

    # Pad the y-range by 10%; NaN/inf entries are ignored so they cannot break ylim.
    values = np.concatenate([external_energies, internal_energies])
    values = values[np.isfinite(values)]
    if values.size:
        y_min, y_max = values.min(), values.max()
        span = y_max - y_min
        margin = 0.1 * span if span > 0 else max(0.1 * abs(y_max), 1.0)
        lower = y_min - margin
        if y_min >= 0:
            # Non-negative data: do not extend the axis below zero.
            lower = max(0.0, lower)
        ax.set_ylim(lower, y_max + margin)

    fig.tight_layout()
    fig.savefig(f'{title.lower().replace(" ", "_")}.png', dpi=150, bbox_inches='tight')
    plt.show()


def _load_plusplus_image():
    """Load 'data/plusplus.png' as a 2-D grayscale array.

    Falls back to create_synthetic_plusplus_image((200, 200)) when skimage.io
    is unavailable or the file cannot be read.
    """
    try:
        if not HAS_SKIMAGE_IO:
            raise FileNotFoundError("skimage.io not available")
        I = skimage.io.imread('data/plusplus.png')
        if I.ndim == 3:
            # Drop an alpha channel (LA / RGBA) so transparency does not bias
            # the intensities, then average the colour channels.
            if I.shape[2] in (2, 4):
                I = I[..., :-1]
            I = I.mean(axis=2)
        print("Loaded plusplus.png successfully")
    except OSError:  # includes FileNotFoundError
        print("Could not load 'data/plusplus.png', creating synthetic '+' image")
        I = create_synthetic_plusplus_image((200, 200))
    return I


def _plot_run_summary(I, initial_snake, snake, max_iterations,
                      external_energy_history, internal_energy_history):
    """Show initial snake, final snake and energy curves side by side, and save them as a PNG."""
    closed = np.hstack([np.arange(len(snake)), 0])  # indices that close the curve
    fig, axes = plt.subplots(1, 3, figsize=(15, 5))

    axes[0].imshow(I, cmap='gray')
    axes[0].plot(initial_snake[closed, 1], initial_snake[closed, 0], 'r-', linewidth=2)
    axes[0].set_title('Initial Snake')
    axes[0].axis('equal')

    axes[1].imshow(I, cmap='gray')
    axes[1].plot(snake[closed, 1], snake[closed, 0], 'r-', linewidth=2)
    axes[1].set_title(f'Final Snake (after {max_iterations} iterations)')
    axes[1].axis('equal')

    iterations = range(len(external_energy_history))
    axes[2].plot(iterations, external_energy_history, 'b-', linewidth=2, label='EXT')
    axes[2].plot(iterations, internal_energy_history, 'orange', linewidth=2, label='INT')
    axes[2].set_xlabel('Iteration')
    axes[2].set_ylabel('Energy')
    axes[2].set_title('Energy Evolution')
    axes[2].legend()
    axes[2].grid(True, alpha=0.3)

    plt.tight_layout()
    plt.savefig('snake_evolution_comparison.png', dpi=150, bbox_inches='tight')
    plt.show()


def run_snake_evolution(max_iterations=60, show_animation=True, alpha=0.001,
                        beta=0.001, step_size=0.0001, N=60, plot_pause=0.01):
    """Run the snake evolution on the '+' test image while tracking energies.

    The snake starts as a circle centred in the image with a radius of a
    quarter of the smaller image side.

    Parameters
    ----------
    max_iterations : int
        Number of evolve_snake steps.
    show_animation : bool
        Redraw the snake after every step.
    alpha, beta : float
        Weights for regularization_matrix and calculate_internal_energy.
    step_size : float
        Passed to evolve_snake.
    N : int
        Number of snake points.
    plot_pause : float
        Seconds to pause between animation frames.

    Returns
    -------
    snake : ndarray of shape (N, 2)
        The final snake.
    I : 2-D ndarray
        The image used.
    external_energy_history, internal_energy_history : list of float
        max_iterations + 1 entries each (the initial state plus one per step).
        External entries are sum(|calculate_external_energy|); internal
        entries are the summed calculate_internal_energy total.
    """
    I = _load_plusplus_image()

    radius = min(I.shape) / 4
    center = (I.shape[0] / 2, I.shape[1] / 2)

    snake = make_circular_snake(N, center, radius)
    B = regularization_matrix(N, alpha, beta)
    initial_snake = snake.copy()

    print(f"Image shape: {I.shape}")
    print(f"Snake initialized with {N} points")
    print(f"Center: {center}, Radius: {radius}")
    print(f"Parameters: alpha={alpha}, beta={beta}, step_size={step_size}")

    external_energy_history = []
    internal_energy_history = []

    def record_energies(current):
        # Append the summed |external force| and summed internal energy of one state.
        _, _, total_internal = calculate_internal_energy(current, alpha, beta)
        external = calculate_external_energy(current, I)
        external_energy_history.append(np.sum(np.abs(external)))
        internal_energy_history.append(np.sum(total_internal))

    record_energies(snake)
    print(f"Initial - External: {external_energy_history[0]:.2f}, Internal: {internal_energy_history[0]:.2f}")

    if show_animation:
        fig, ax = plt.subplots(figsize=(8, 8))
        closed = np.hstack([np.arange(N), 0])  # indices that close the curve

    for i in range(max_iterations):
        snake = evolve_snake(snake, I, B, step_size)
        record_energies(snake)

        if show_animation:
            ax.clear()
            ax.imshow(I, cmap='gray')
            ax.plot(snake[closed, 1], snake[closed, 0], 'r-', linewidth=2)
            ax.set_title(f'Snake Evolution - Iteration {i+1}')
            ax.axis('equal')
            fig.canvas.draw()
            plt.pause(plot_pause)

        if (i + 1) % 10 == 0:
            print(f"Iteration {i+1} - External: {external_energy_history[-1]:.2f}, Internal: {internal_energy_history[-1]:.2f}")

    if show_animation:
        plt.show()

    _plot_run_summary(I, initial_snake, snake, max_iterations,
                      external_energy_history, internal_energy_history)

    plot_energy_evolution(external_energy_history, internal_energy_history,
                          "Snake Energy Evolution")

    print("\nEvolution completed!")
    print(f"Final - External: {external_energy_history[-1]:.2f}, Internal: {internal_energy_history[-1]:.2f}")

    return snake, I, external_energy_history, internal_energy_history


def compare_different_parameters(param_sets=None, n_iterations=60, N=60):
    """Compare energy evolution for several (alpha, beta, step_size) settings.

    Each setting starts from the same circular snake. Energies are recorded
    for the initial state and after every evolve_snake step (n_iterations + 1
    values), matching run_snake_evolution. One panel per setting is drawn,
    saved to 'parameter_comparison.png' and shown.

    Parameters
    ----------
    param_sets : list of dict, optional
        Each dict needs 'alpha', 'beta', 'step_size' and 'name' keys.
        Defaults to the single 'Original' setting (alpha=beta=0.001,
        step_size=0.0001). Further settings can be compared, e.g.
        {'alpha': 0.01, 'beta': 0.001, 'step_size': 0.0001, 'name': 'High Alpha'}.
    n_iterations : int
        Number of evolution steps per setting.
    N : int
        Number of snake points.

    Returns
    -------
    dict
        Maps each setting's name to (external_history, internal_history).
        Later settings with a duplicate name overwrite earlier ones.
    """
    # Load/create image
    try:
        if not HAS_SKIMAGE_IO:
            raise FileNotFoundError("skimage.io not available")
        I = skimage.io.imread('data/plusplus.png')
        if I.ndim == 3:
            # Ignore an alpha channel (LA / RGBA) before averaging channels.
            if I.shape[2] in (2, 4):
                I = I[..., :-1]
            I = I.mean(axis=2)
    except OSError:  # includes FileNotFoundError
        I = create_synthetic_plusplus_image((200, 200))

    if param_sets is None:
        param_sets = [
            {'alpha': 0.001, 'beta': 0.001, 'step_size': 0.0001, 'name': 'Original'},
        ]
    if not param_sets:
        return {}

    # Size the grid to the number of settings instead of a fixed 2x2.
    n_sets = len(param_sets)
    ncols = 1 if n_sets == 1 else 2
    nrows = -(-n_sets // ncols)
    fig, axes = plt.subplots(nrows, ncols, figsize=(6 * ncols, 5 * nrows), squeeze=False)
    axes = axes.ravel()
    for unused_ax in axes[n_sets:]:
        unused_ax.set_visible(False)

    radius = min(I.shape) / 4
    center = (I.shape[0] / 2, I.shape[1] / 2)
    histories = {}

    for ax, params in zip(axes, param_sets):
        snake = make_circular_snake(N, center, radius)
        B = regularization_matrix(N, params['alpha'], params['beta'])

        external_history = []
        internal_history = []

        def record(current):
            # Append summed |external force| and summed internal energy.
            _, _, total_internal = calculate_internal_energy(current, params['alpha'], params['beta'])
            external_energy = calculate_external_energy(current, I)
            external_history.append(np.sum(np.abs(external_energy)))
            internal_history.append(np.sum(total_internal))

        record(snake)
        for _ in range(n_iterations):
            snake = evolve_snake(snake, I, B, params['step_size'])
            record(snake)

        histories[params['name']] = (external_history, internal_history)

        iterations = range(len(external_history))
        ax.plot(iterations, external_history, 'b-', linewidth=2, label='EXT')
        ax.plot(iterations, internal_history, 'orange', linewidth=2, label='INT')
        ax.set_xlabel('Iteration')
        ax.set_ylabel('Energy')
        ax.set_title(f"{params['name']}\n(α={params['alpha']}, β={params['beta']})")
        ax.legend()
        ax.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.savefig('parameter_comparison.png', dpi=150, bbox_inches='tight')
    plt.show()

    return histories


# Script entry point; inside a Jupyter kernel __name__ is also "__main__", so
# this runs when the cell is executed.
if __name__ == "__main__":
    print("=== Running Snake Evolution with Energy Tracking ===")
    evolution_results = run_snake_evolution(max_iterations=60, show_animation=True)
    final_snake, image, ext_history, int_history = evolution_results

    print("\n=== Comparing Different Parameter Settings ===")
    compare_different_parameters()

=== Running Snake Evolution with Energy Tracking ===
Loaded plusplus.png successfully
Image shape: (588, 588)
Snake initialized with 60 points
Center: (294.0, 294.0), Radius: 147.0
Parameters: alpha=0.001, beta=0.001, step_size=0.0001
Initial - External: 1459933.79, Internal: 14.36
Iteration 10 - External: 923306.62, Internal: 24.02
Iteration 20 - External: 684349.36, Internal: 34.68
Iteration 30 - External: 614057.13, Internal: 44.06
Iteration 40 - External: 658984.20, Internal: 49.82
Iteration 50 - External: 539051.06, Internal: 56.76
Iteration 60 - External: 394869.04, Internal: 52.34

Evolution completed!
Final - External: 394869.04, Internal: 52.34

=== Comparing Different Parameter Settings ===


KeyboardInterrupt: 

: 