In [4]:
import numpy as np

### Nelder-Mead

In [1]:
def rosenbrock(x):
    """Evaluate the two-variable Rosenbrock function at ``x``.

    Computes ``100 * (x0**2 - x1)**2 + (1 - x0)**2`` where ``x0`` and ``x1``
    are the first two entries of ``x``; any further entries are ignored.
    The value is zero at ``(1, 1)``.
    """
    first, second = x[0], x[1]
    valley_term = (first**2 - second)**2
    offset_term = (1 - first)**2
    return 100 * valley_term + offset_term

In [7]:
def toExit(f_min, f_max, epsilon):
    """Return whether the spread between two objective values is within tolerance.

    The difference ``f_max - f_min`` is divided by ``|f_max| + |f_min|``,
    floored at 1, and the quotient is compared with ``epsilon``. The test is
    therefore relative for large values and absolute for small ones.
    """
    spread = f_max - f_min
    # Flooring the scale at 1 avoids dividing by a near-zero magnitude.
    scale = np.maximum(np.abs(f_max) + np.abs(f_min), 1)
    return spread / scale <= epsilon

def NM(f, x, alpha, beta, gamma, epsilon, M, sigma=0.5):
    """Minimise ``f`` with a Nelder-Mead style simplex search.

    Each iteration orders the vertices by objective value, reflects the worst
    vertex through the centroid of the others, and then expands, accepts,
    contracts or shrinks depending on how the reflected value compares with
    the best, second-worst and worst values.

    Parameters
    ----------
    f : callable
        Objective function. It is called with one vertex at a time and must
        return a scalar.
    x : array-like
        Initial vertices, one per row. The data is copied into a float array,
        so integer input is accepted and the caller's object is not modified.
    alpha : float
        Reflection coefficient: ``v = (1 + alpha) * u - alpha * x_worst``.
    beta : float
        Contraction coefficient: ``w = beta * x_worst + (1 - beta) * u``.
    gamma : float
        Expansion coefficient: ``w = (1 + gamma) * v - gamma * u``.
    epsilon : float
        Tolerance passed to ``toExit`` together with the best and worst
        objective values; the search stops as soon as it returns true.
    M : int
        Maximum number of iterations. With ``M <= 0`` the best of the
        initial vertices is returned.
    sigma : float, optional
        Shrink factor used to pull every vertex towards the best one when a
        contraction is rejected. Defaults to 0.5.

    Returns
    -------
    tuple
        ``(x_min, f_min)``: the best vertex of the final simplex and the
        objective value at that vertex.

    Raises
    ------
    ValueError
        If fewer than two vertices are supplied.
    """
    # Force a float dtype: with integer input, np.array(x) would yield an
    # integer array and every vertex written back below would be truncated.
    X = np.array(x, dtype=float)
    if X.ndim == 0 or X.shape[0] < 2:
        raise ValueError("NM needs at least two starting vertices")

    # Objective values are cached next to the vertices, so only new trial
    # points (and the vertices moved by a shrink) are evaluated.
    F = np.array([f(x_i) for x_i in X], dtype=float)

    for _ in range(M):
        # Order the simplex from best (row 0) to worst (last row).
        idx = np.argsort(F)
        X = X[idx]
        F = F[idx]

        # Scalars copied out of F: they keep the pre-update values even after
        # F[-1] is overwritten further down.
        f_min, f_sec_max, f_max = F[0], F[-2], F[-1]

        if toExit(f_min, f_max, epsilon):  # convergence
            break

        # Centroid of every vertex except the worst one.
        u = np.mean(X[:-1], axis=0)

        # Reflection of the worst vertex through the centroid.
        v = (1 + alpha) * u - alpha * X[-1]
        f_v = f(v)

        if f_v < f_min:
            # Expansion: push further along the reflection direction.
            w = (1 + gamma) * v - gamma * u
            f_w = f(w)
            # NOTE(review): the expanded point is accepted whenever it beats
            # the old worst value; the original Nelder-Mead paper compares it
            # with the best value instead. Left as is to keep results stable.
            if f_w < f_max:
                X[-1], F[-1] = w, f_w
            else:
                X[-1], F[-1] = v, f_v
        elif f_v < f_sec_max:
            # Reflected point is neither the new best nor the new worst.
            X[-1], F[-1] = v, f_v
        else:
            if f_v < f_max:
                # Still an improvement on the worst vertex: contract from v.
                X[-1], F[-1] = v, f_v

            # Contraction towards the centroid.
            w = beta * X[-1] + (1 - beta) * u
            f_w = f(w)

            # NOTE(review): the threshold is the old worst value even when v
            # has just replaced the worst vertex. Left as is.
            if f_w <= f_max:
                X[-1], F[-1] = w, f_w
            else:
                # Shrink every vertex except the best towards the best.
                X[1:] = X[0] + sigma * (X[1:] - X[0])
                F[1:] = [f(x_i) for x_i in X[1:]]
    else:
        # Runs only when the loop did not break (budget used up, or M <= 0).
        # The last update left the simplex unsorted, so order it once more to
        # make row 0 the true best vertex.
        idx = np.argsort(F)
        X = X[idx]
        F = F[idx]

    return X[0], F[0]  # x_min and f_min

In [8]:
### Example: minimise the Rosenbrock function
# Starting vertices, one point of the plane per row.
x_ini = [
    [-1.2, 1.0],
    [0.0, 0.0],
    [1.0, 0.0],
    [0.0, 1.0],
]

# alpha, beta and gamma are the reflection, contraction and expansion
# coefficients; epsilon is the stopping tolerance and M the iteration budget.
x_min, fx_min = NM(
    rosenbrock,
    x_ini,
    alpha=1,
    beta=0.5,
    gamma=1,
    epsilon=1e-5,
    M=500,
)

print("Solution: ", x_min)
print("Minimum value: ", fx_min)

Solution:  [0.99970352 0.9995812 ]
Minimum value:  3.118050268455314e-06
