# Nelder-Mead minimum search

The Nelder-Mead optimization algorithm is not guaranteed to converge to a local minimum, so other methods are often better. Nevertheless, it's a cool little heuristic algorithm with a nice geometric interpretation.

In [2]:
import numpy as np
import matplotlib.pyplot as plt
import statistics
from mpl_toolkits import mplot3d
from matplotlib import animation
import os
import subprocess

In [3]:
# Nelder-Mead coefficients read by the step helpers below:
#   ALPHA -> reflection, GAMMA -> expansion, RHO -> contraction, SIGMA -> shrink.
ALPHA, GAMMA, RHO, SIGMA = 1, 2, 0.5, 0.5

def func(pt):
    """Objective to minimise: ``r2 + 2*sin(r2)`` where ``r2 = x**2 + y**2``.

    Only ``pt[0]`` (x) and ``pt[1]`` (y) are read. Because the arithmetic is
    element-wise, ``pt`` may hold scalars or NumPy arrays such as the
    ``[X, Y]`` meshgrid pair used for plotting.
    """
    x, y = pt[0], pt[1]
    r_squared = x*x + y*y
    return r_squared + 2*np.sin(r_squared)

def order(simplex):
    """Return a new array of the vertices sorted from best to worst.

    "Best" means the lowest ``func`` value. The sort is stable, so tied
    vertices keep their original relative order.
    """
    vertices = list(simplex)
    vertices.sort(key=func)
    return np.array(vertices)

def centroid(simplex):
    """Return the mean of every vertex except the last one.

    The last vertex is the worst when the simplex has been ordered.
    """
    kept = np.asarray(simplex)[:-1]
    return kept.mean(axis=0)

def reflection(simplex, x_0, alpha=None):
    """Reflect the worst vertex through the centroid.

    Computes ``x_0 + alpha * (x_0 - simplex[-1])``.

    Args:
        simplex: vertices ordered best-to-worst; only ``simplex[-1]`` is read.
        x_0: centroid of every vertex except the worst.
        alpha: reflection coefficient. ``None`` (the default) uses the
            module-level ``ALPHA``, looked up at call time.

    Returns:
        The reflected point ``x_r``.
    """
    if alpha is None:
        alpha = ALPHA
    return x_0 + alpha * (x_0 - simplex[-1])

def expansion(x_0, x_r, gamma=None):
    """Push the reflected point further along the centroid→reflection ray.

    Computes ``x_0 + gamma * (x_r - x_0)``.

    Args:
        x_0: centroid of every vertex except the worst.
        x_r: reflected point.
        gamma: expansion coefficient. ``None`` (the default) uses the
            module-level ``GAMMA``, looked up at call time.

    Returns:
        The expanded point ``x_e``.
    """
    if gamma is None:
        gamma = GAMMA
    return x_0 + gamma * (x_r - x_0)

def contraction(x_r, x_0, rho=None):
    """Move a point part of the way toward the centroid.

    Computes ``x_0 + rho * (x_r - x_0)``. The caller passes either the
    reflected point (outside contraction) or the worst vertex (inside
    contraction) as ``x_r``.

    Args:
        x_r: point to contract toward ``x_0``.
        x_0: centroid of every vertex except the worst.
        rho: contraction coefficient. ``None`` (the default) uses the
            module-level ``RHO``, looked up at call time.

    Returns:
        The contracted point ``x_c``.
    """
    if rho is None:
        rho = RHO
    return x_0 + rho * (x_r - x_0)

def shrink(simplex, sigma=None):
    """Shrink every vertex toward the best vertex ``simplex[0]``.

    Each vertex ``v`` becomes ``simplex[0] + sigma * (v - simplex[0])``. The
    best vertex itself is kept unchanged. This works for a simplex of any
    size or dimension; the original version was hard-coded to exactly three
    2-D vertices.

    Args:
        simplex: vertices ordered best-to-worst (array-like, one row per
            vertex).
        sigma: shrink coefficient. ``None`` (the default) uses the
            module-level ``SIGMA``, looked up at call time.

    Returns:
        A new array with the same number of rows as ``simplex``.
    """
    if sigma is None:
        sigma = SIGMA
    pts = np.asarray(simplex)
    best = pts[0]
    # Broadcasting applies the shrink to every row at once.
    shrunk = best + sigma * (pts - best)
    # Keep the best vertex bit-for-bit (avoids e.g. -0.0 turning into 0.0).
    shrunk[0] = best
    return shrunk

In [4]:
# Run Nelder-Mead from a fixed starting triangle. Every simplex is recorded
# in simplex_list so the search can be animated below, and each step prints
# the move it took.
MAX_ITER = 1000       # safety cap: Nelder-Mead is not guaranteed to converge
TOLERANCE = 0.000001  # stop once the population std-dev of vertex values <= this

simplex_list = []

pt1 = np.array([1.6, 1.4])
pt2 = np.array([0.6, -1.3])
pt3 = np.array([1.7, -1.8])

simplex = np.array([pt1, pt2, pt3])
simplex = order(simplex)

simplex_list.append(simplex)

iteration = 0
while statistics.pstdev([func(pt) for pt in simplex]) > TOLERANCE:
    if iteration >= MAX_ITER:
        print(f"stopped after {MAX_ITER} iterations without converging")
        break
    iteration += 1

    # order() returns a fresh array, so the in-place vertex replacements
    # below never alter simplices already stored in simplex_list.
    simplex = order(simplex)
    # Evaluate each vertex once per iteration instead of once per comparison.
    f_best = func(simplex[0])
    f_second = func(simplex[-2])
    f_worst = func(simplex[-1])

    x_0 = centroid(simplex)
    x_r = reflection(simplex, x_0)
    f_r = func(x_r)

    if f_best <= f_r < f_second:
        # reflection: x_r is neither a new best nor the worst
        print("reflection")
        simplex[-1] = x_r

    elif f_r < f_best:
        # expansion: x_r is a new best, so try going further in that direction
        print("expansion")
        x_e = expansion(x_0, x_r)
        simplex[-1] = x_e if func(x_e) < f_r else x_r

    elif f_r < f_worst:
        # outside contraction: x_r beats only the worst vertex
        x_c = contraction(x_r, x_0)
        if func(x_c) < f_r:
            print("contraction_A")
            simplex[-1] = x_c
        else:
            print("shrinking")
            simplex = shrink(simplex)

    else:
        # inside contraction: x_r is no better than the worst vertex
        x_c = contraction(simplex[-1], x_0)
        if func(x_c) < f_worst:
            print("contraction_B")
            simplex[-1] = x_c
        else:
            print("shrinking")
            simplex = shrink(simplex)

    simplex_list.append(simplex)

expansion
contraction_B
expansion
shrinking
expansion
contraction_A
contraction_B
contraction_A
contraction_B
contraction_B
contraction_B
contraction_B
contraction_B
contraction_A
contraction_B
contraction_B
contraction_B
contraction_B
contraction_B
expansion
contraction_B
contraction_B
contraction_B
contraction_B
contraction_B


In [7]:
def plot_simplex(ax, simplex, func, color='k', zorder=10):
    """Draw *simplex* as a closed outline lying on the surface of *func*.

    Each vertex ``(x, y)`` is lifted to height ``func([x, y])``. The first
    vertex is repeated at the end so the outline closes on itself. The
    outline uses *color* and the vertices are marked in red.

    Returns:
        ``(x_list, y_list, z_list)``: coordinates of the closed outline, each
        a list with ``len(simplex) + 1`` entries.
    """
    closed = list(simplex) + [simplex[0]]
    xs = [vertex[0] for vertex in closed]
    ys = [vertex[1] for vertex in closed]
    zs = [func([vx, vy]) for vx, vy in zip(xs, ys)]
    ax.plot(xs, ys, zs, color=color, linewidth=3, zorder=zorder)
    ax.scatter(xs, ys, zs, color='r', zorder=zorder)
    return xs, ys, zs

def _new_surface_axes(X, Y, Z):
    """Create a 3-D figure with fixed limits and the surface wireframe drawn."""
    fig = plt.figure(figsize=(14, 10))
    ax = fig.add_subplot(111, projection='3d')
    ax.set_xlim(-2, 2)
    ax.set_ylim(-2, 2)
    ax.set_zlim(0, 15)
    ax.plot_wireframe(X, Y, Z, rcount=20, ccount=20, color='blue', zorder=1)
    return fig, ax


def _save_frame(fig, ax, frames_dir, frame_counter):
    """Set the shared camera angle, write ``frame_NNNN.png``, then close *fig*."""
    ax.view_init(elev=50, azim=110)
    frame_path = os.path.join(frames_dir, f'frame_{frame_counter:04d}.png')
    fig.savefig(frame_path)
    # Close explicitly so hundreds of frames don't accumulate open figures.
    plt.close(fig)


def create_frames(simplex_list, func, frames_dir):
    """Render two PNG frames per simplex in *simplex_list* into *frames_dir*.

    For each simplex, the first frame shows it together with the previous
    simplex outlined in red (there is no red outline for the first simplex).
    The second frame shows the simplex alone. Frames are numbered
    consecutively as ``frame_0000.png``, ``frame_0001.png``, ... .

    Args:
        simplex_list: sequence of simplices, each an iterable of (x, y)
            vertices.
        func: objective. It must accept an ``[X, Y]`` meshgrid pair as well
            as single points.
        frames_dir: directory to write into (must already exist).
    """
    # The surface is identical in every frame, so sample it once rather than
    # rebuilding a 1000x1000 grid for every simplex.
    lower, upper, num_pts = -2, 2, 1000
    x = np.linspace(lower, upper, num_pts)
    y = np.linspace(lower, upper, num_pts)
    X, Y = np.meshgrid(x, y)
    Z = func([X, Y])

    frame_counter = 0
    previous_outline = None

    for simplex in simplex_list:
        # Frame 1: current simplex plus the previous one in red.
        fig, ax = _new_surface_axes(X, Y, Z)
        outline = plot_simplex(ax, simplex, func)
        if previous_outline is not None:
            ax.plot(*previous_outline, color='r', linewidth=2, zorder=9)
        _save_frame(fig, ax, frames_dir, frame_counter)
        frame_counter += 1
        previous_outline = outline

        # Frame 2: current simplex alone.
        fig, ax = _new_surface_axes(X, Y, Z)
        plot_simplex(ax, simplex, func)
        _save_frame(fig, ax, frames_dir, frame_counter)
        frame_counter += 1

def create_video(frames_dir, output_video_path, fps=4):
    ffmpeg_command = [
        'ffmpeg',
        '-framerate', str(fps),
        '-i', os.path.join(frames_dir, 'frame_%04d.png'),
        '-c:v', 'libx265',
        '-preset', 'veryslow',
        '-crf', '0',
        '-pix_fmt', 'yuv444p',
        '-vf', 'scale=1920:1080',
        output_video_path
    ]
    subprocess.run(ffmpeg_command, check=True)
    print(f'Video saved to {output_video_path}')

In [None]:
# Fill in both paths before running this cell.
frames_dir = r'FILL_IN'
output_video_path = r'FILL_IN'
if 'FILL_IN' in (frames_dir, output_video_path):
    raise ValueError("set frames_dir and output_video_path before running this cell")
# exist_ok avoids the check-then-create race of os.path.exists + os.makedirs.
os.makedirs(frames_dir, exist_ok=True)

create_frames(simplex_list, func, frames_dir)
create_video(frames_dir, output_video_path)