In [None]:
import numpy as np
import matplotlib.pyplot as plt
import cv2 as cv
%matplotlib inline

# Thresholds that are only read inside the disabled mIZH draft (the triple-quoted
# string later in this cell): there a state u[i] above V_MP maps to +1 and anything
# else to -1, while a state within 1e-5 of V_REST leaves x[i] unchanged.
# NOTE(review): units are not stated anywhere in this file; presumably volts
# (30 mV / -65 mV) — confirm before re-enabling the draft.
V_MP = 0.03
V_REST = -0.065
class Hopfield:
    """Hopfield associative memory over square ``size x size`` patterns.

    Attributes:
        iter: number of update sweeps performed by ``test_one_frame``.
        size: side length of a pattern; the network has ``size ** 2`` units.
        W: ``(size ** 2, size ** 2)`` float weight matrix, all zeros until
            ``train`` is called.
    """

    def __init__(self, size=64, iter=10):
        # `iter` shadows the builtin, but it is part of the public signature
        # and is therefore kept as-is.
        self.iter = iter
        self.size = size
        self.W = np.zeros((size ** 2, size ** 2))

    def train(self, X):
        """Add the Hebbian outer product of every pattern in ``X`` to ``W``.

        Args:
            X: iterable of patterns, each reshapeable to ``size ** 2``
                elements (values are used as given; +1/-1 is what
                ``test_one_frame`` produces).

        Weights accumulate across calls. The diagonal is cleared once, after
        all patterns have been summed, so no unit is connected to itself.
        """
        n = self.size ** 2
        for pattern in X:
            flat = np.reshape(pattern, (n,))
            self.W += np.outer(flat, flat) / n
        np.fill_diagonal(self.W, 0)

    def test_one_frame(self, x):
        """Run ``self.iter`` update sweeps starting from ``x``.

        Each sweep draws ``size ** 2`` unit indices at random (with
        replacement), computes the local field of the drawn units from the
        state at the start of the sweep, and then sets those units to +1
        where the field is positive and -1 where it is negative. Units that
        were not drawn, or whose field is exactly 0, keep their value.
        Because the drawn units are updated simultaneously, the energy is not
        guaranteed to decrease monotonically.

        Args:
            x: initial state with ``size ** 2`` elements. It is copied, so
                the caller's array is left untouched.

        Returns:
            tuple: ``(state, energy)`` where ``state`` is the final state
            reshaped to ``(size, size)`` and ``energy`` is a list holding the
            network energy after each sweep.
        """
        n = self.size ** 2
        # np.reshape alone returns a view of a contiguous input, which made
        # the in-place updates below overwrite the caller's data.
        state = np.array(x).reshape(n)
        energy = []
        for _ in range(self.iter):
            drawn = np.zeros(n, dtype=bool)
            drawn[np.random.randint(n, size=n)] = True
            # One mat-vec for all units; fields of undrawn units are masked
            # to 0 so they fall through both comparisons below.
            h = np.where(drawn, self.W.dot(state), 0.0)
            state[h > 0] = 1
            state[h < 0] = -1
            energy.append(self.cal_energy(state))

        return state.reshape(self.size, self.size), energy

    def cal_energy(self, x):
        """Return the energy ``-0.5 * sum(W.dot(x) * x)`` of the flat state ``x``."""
        return -0.5 * np.sum(self.W.dot(x) * x)


def show(x):
    """Display a pattern in an OpenCV window and block until a key is pressed.

    Args:
        x: array-like pattern; entries greater than 0 are drawn white (255),
            everything else black (0).

    The window named "img" is closed again before returning.
    """
    img = np.where(x > 0, 255, 0).astype(np.uint8)
    try:
        cv.imshow("img", img)
        cv.waitKey(0)
    finally:
        # After waitKey returns nothing services the window's events any
        # more; under a notebook kernel it would otherwise stay open and
        # unresponsive, so release it explicitly.
        cv.destroyAllWindows()


# Disabled draft: the triple-quoted string below is a bare expression, so the
# mIZH class written inside it is never defined or executed.
# NOTE(review): presumably an Izhikevich-style neuron variant of Hopfield (going by
# the name and the a1..a5 coefficients) — confirm with the author.
# Visible problems to fix before re-enabling it:
#   * `np.zeros(self.size, 1)` passes 1 as the dtype argument, not as part of a shape.
#   * `iext` is an array, yet it is added into the single element `u[i]`.
#   * `train` reads and writes `x[i]` after the `for x in X` loop has finished,
#     i.e. it only touches the last pattern of X.
# The Chinese note inside `memorize` reads: "check: sum first and then subtract,
# or subtract first and then sum?"
'''
class mIZH(Hopfield):
    def __init__(self, size=64, iter=10):
        self.iter = iter
        self.size = size
        self.W = np.zeros((size ** 2, size ** 2))
        self.a1 = 0.04
        self.a2 = 5.0
        self.a3 = 140.0
        self.a4 = 1.0
        self.a5 = 1.0
        self.r = 0.02
        
    def memorize(self, X):
        n = self.size ** 2
        for x in X:
            x = np.reshape(x, (n, 1))
            xT = np.reshape(x, (1, n))
            self.W += x * xT / n
            # check: 先求和再减，还是先减再求和。
            self.W[np.diag_indices_from(self.W)] = 0
            
    def train(self, X):
        il = np.zeros(self.size, 1)
        u = il.copy()
        for _ in range(self.iter):
            iext = np.zeros(self.size, 1)
            for x in X:
                iext += self.W.dot(x)
            # update 
            for i in range(len(u)):
                il[i] = il[i]+self.r*((self.a1*u[i]+self.a2) * u[i] + self.a3 - il[i])
                u[i] = u[i] + (self.a1*u[i]+self.a2) * u[i] + self.a3 - self.a4 * il[i] + self.a5 * iext
                if abs(u[i]-V_REST) < 1e-5:
                    x[i] = x[i]
                else:
                    if u[i] > V_MP:
                        x[i] = 1;
                    else:
                        x[i] = -1;
'''


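## Demo: recalling a half-masked image

The next cell binarizes a 64×64 grayscale image, stores it in the Hopfield network, sets the lower half of the image to -1, and runs the recall while recording the network energy after every sweep.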

In [None]:
if __name__ == "__main__":
    # Demo: store one binarized image, blank its lower half, run the recall
    # and plot the energy recorded after each sweep.
    image_path = r"D:\resource\mizh\kazuma.jpg"
    size = 64
    # Pixels brighter than this gray level become +1, the rest -1.
    threshold = 255 / 2.5

    # The recall draws random unit indices through np.random; fix the seed so
    # a Restart & Run All reproduces the same result and plot.
    np.random.seed(0)

    img = cv.imread(image_path, 0)  # flag 0 -> load as grayscale
    if img is None:
        # cv.imread signals an unreadable or missing file by returning None
        # instead of raising; fail here with a clear message rather than
        # inside cv.resize.
        raise FileNotFoundError("could not read image: {}".format(image_path))
    img = cv.resize(img, (size, size))
    x = np.where(img > threshold, 1, -1)
    x_masked = x.copy()
    x_masked[size // 2:, :] = -1
    show(x_masked)

    model = Hopfield(size=size)
    model.train([x])
    # test_one_frame may update its argument in place; hand it a copy so
    # x_masked keeps holding the masked input.
    y, energy = model.test_one_frame(x_masked.copy())
    show(y)

    plt.plot(energy, label='energy')
    plt.xlabel('sweep')
    plt.ylabel('energy')
    plt.legend()
    plt.show()
    