In [13]:
import cma
import numpy as np
from self_modelling import HopfieldNetworkRC

In [14]:
def normalise_sample(x):
    """Binarise a continuous sample onto the spin values {-1, 1}.

    Strictly positive entries become 1 and entries <= 0 become -1.
    The input is not modified; a new array is returned whose dtype follows
    NumPy's promotion of ``x`` with the integer constants (so a float
    sample yields float spins). Entries that are neither > 0 nor <= 0
    (i.e. NaN) are passed through unchanged.
    """
    source = x.copy()
    # Map the non-positive entries first, then overwrite the positive ones with 1.
    non_positive_mapped = np.where(source <= 0, -1, source)
    return np.where(source > 0, 1, non_positive_mapped)

class CMAES:
    """Minimise the Hopfield energy of a coupling matrix with CMA-ES.

    CMA-ES searches a continuous space; every candidate is binarised with
    ``normalise_sample`` before its energy is evaluated, so the objective
    is piecewise constant in the search space.
    """

    def __init__(self, weights):
        """Store the coupling matrix.

        Parameters
        ----------
        weights : numpy.ndarray
            Coupling matrix; presumably square ``(N, N)`` (TODO confirm).
            Its first dimension is exposed as ``self.N``, the problem
            dimension used to size initial guesses.
        """
        self.weights = weights
        self.N = self.weights.shape[0]

    def cost(self, sample):
        """Return the energy ``-0.5 * s @ W @ s.T`` of the binarised ``sample``.

        Parameters
        ----------
        sample : numpy.ndarray
            Continuous candidate proposed by CMA-ES.

        Returns
        -------
        numpy scalar
            First element of the flattened energy; for a 1-D sample the
            product is already a scalar.
        """
        spins = normalise_sample(sample)
        energy = -0.5 * (spins @ self.weights) @ spins.T
        # flatten()[0] handles both a scalar result (1-D sample) and a
        # (1, 1) array (row-vector sample).
        return energy.flatten()[0]

    def run(self, x, sigma, iterations, popsize=50, seed=None):
        """Run CMA-ES from ``x`` and return cma's result tuple.

        Parameters
        ----------
        x : array-like
            Initial mean of the search distribution.
        sigma : float
            Initial step size.
        iterations : int
            Number of generations; used both as the upper limit
            (``maxiter``) and as the minimum to run (``min_iterations``).
        popsize : int, optional
            Candidates per generation (default 50, the previous hard-coded value).
        seed : int, optional
            Seed for cma's sampler, for reproducible runs. When None (the
            default) the option is not set and cma chooses its own seed.
            Note that cma treats 0 like None.

        Returns
        -------
        cma's ``CMAEvolutionStrategy.result`` (exposes ``xbest``, ``fbest``, ...).
        """
        options = {
            # Number of iterations with flat fitness tolerated before
            # stopping; plateaus are expected because cost() binarises.
            'tolflatfitness': 100,
            'maxiter': iterations,
            'popsize': popsize,
            # Wall-clock limit in seconds; cma evaluates this string (5 hours).
            'timeout': '5*60*60',
            'verbose': 3,
        }
        if seed is not None:
            options['seed'] = seed
        es = cma.CMAEvolutionStrategy(x, sigma, options)
        # min_iterations keeps optimize() running for all `iterations`
        # generations even if a termination criterion fires earlier.
        return es.optimize(self.cost, iterations=iterations, min_iterations=iterations).result

In [15]:
# Size argument for HopfieldNetworkRC, presumably the number of neurons/spins
# (TODO confirm). The CMA-ES log below reports "dimension 100" for the
# resulting weight matrix.
N = 100
hopfield_model = HopfieldNetworkRC(N)

In [16]:
# Seed the initial guess so re-running the notebook starts from the same point.
# The CMA-ES sampler draws its own seed (shown in the run log) unless one is set.
SEED = 42
rng = np.random.default_rng(SEED)

cmaes_model = CMAES(hopfield_model.weights)
# Random initial spin configuration in {-1, 1}^N.
x0 = rng.choice([-1, 1], size=cmaes_model.N)
sigma0 = 1 / 3  # initial CMA-ES step size
iterations = 100  # number of CMA-ES generations

In [17]:
# Run CMA-ES from the random start for a fixed number of generations.
result = cmaes_model.run(x=x0, sigma=sigma0, iterations=iterations)

(25_w,50)-aCMA-ES (mu_w=14.0,w_1=14%) in dimension 100 (seed=825501, Fri Dec 17 13:45:59 2021)
Iterat #Fevals   function value  axis ratio  sigma  min&max std  t[m:s]
    1     50 -1.000000000000000e+01 1.0e+00 3.16e-01  3e-01  3e-01 0:00.0
    2    100 -6.000000000000000e+00 1.0e+00 3.03e-01  3e-01  3e-01 0:00.0
    3    150 -1.000000000000000e+01 1.0e+00 2.94e-01  3e-01  3e-01 0:00.0
  100   5000 -5.860000000000000e+02 2.0e+00 2.24e+00  2e+00  2e+00 0:01.3


In [18]:
# Best solution seen over the whole run (not the last generation), so its
# energy can be lower than the final value in the log above.
x, energy = result.xbest, result.fbest
norm_x = normalise_sample(x)
# Sanity check: the reported fitness is the Hopfield energy of the binarised
# best sample (cost() binarises internally, so this is idempotent).
assert np.isclose(cmaes_model.cost(norm_x), energy)
norm_x, energy

(array([ 1.,  1., -1., -1., -1.,  1., -1., -1., -1., -1., -1.,  1.,  1.,
        -1.,  1.,  1.,  1.,  1., -1.,  1., -1.,  1., -1.,  1.,  1.,  1.,
         1.,  1.,  1., -1., -1., -1.,  1.,  1., -1., -1., -1.,  1.,  1.,
        -1., -1.,  1.,  1., -1., -1., -1., -1.,  1., -1.,  1.,  1.,  1.,
         1., -1., -1., -1.,  1.,  1.,  1.,  1.,  1., -1.,  1., -1., -1.,
         1., -1.,  1.,  1., -1.,  1., -1.,  1.,  1., -1., -1.,  1.,  1.,
         1.,  1., -1., -1.,  1., -1., -1.,  1., -1.,  1., -1.,  1.,  1.,
         1.,  1.,  1.,  1.,  1.,  1., -1.,  1., -1.]),
 -590.0)