Inspired by https://towardsdatascience.com/hopfield-networks-neural-memory-machines-4c94be821073

In [1]:
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt

class Hopfield:
    # array: jax.Array
    weights: jax.Array
    
    def __init__(self, size):
        self.weights = jnp.zeros((size, size))

    def update_weights(self, state: jax.Array): 
        self.weights = (1/state.shape[0]) * jnp.outer(state.T, state)
        self.weights = jnp.fill_diagonal(self.weights, 0, inplace=False)

    def update_state(self, state, key):
        mask =jax.random.uniform(key, state.shape) > 0.75
        # print(mask)
        # print(f"Weights: {self.weights}, state: {state}")
        # print(self.weights @ state)
        new_values = 2*(jnp.dot(self.weights, state) > 0) - 1
        return mask * new_values + (1 - mask) * state
    
    def compute_energy(self, state):
        return -0.5*jnp.dot(jnp.dot(state.T, self.weights),state)
    

In [2]:
key = jax.random.key(0)
hop = Hopfield(2)
hop.update_weights(jnp.array([-1, 1, 1]))
print(hop.weights)
state = jnp.array([-1, 1, -1])
for i in range(10):
    key, k = jax.random.split(key)
    # print(state)
    state = hop.update_state(state, key)

print(state)
hop.weights

[[ 0.         -0.33333334 -0.33333334]
 [-0.33333334  0.          0.33333334]
 [-0.33333334  0.33333334  0.        ]]
[-1  1  1]


Array([[ 0.        , -0.33333334, -0.33333334],
       [-0.33333334,  0.        ,  0.33333334],
       [-0.33333334,  0.33333334,  0.        ]],      dtype=float32, weak_type=True)

In [3]:

#for MNIST fetch
import requests, gzip, os, hashlib
import pygame
import matplotlib.pyplot as plt
import tensorflow 

#Fetch MNIST dataset from the ~SOURCE~
def fetch_MNIST():

    (x_train, y_train), (x_test, y_test) = tensorflow.keras.datasets.mnist.load_data()
    x_train = jnp.expand_dims(x_train, -1)
    x_test = jnp.expand_dims(x_test, -1)
    return x_train

import os
from PIL import Image
import jax
import jax.numpy as jnp

# Make a directory to save frames
os.makedirs("hopfield_frames", exist_ok=True)

def MNIST_Hopfield(): 
    # Fetch MNIST dataset for some random memory downloads
    X = fetch_MNIST().reshape((-1, 784))
    
    # Convert to binary
    X_binary = jnp.where(X > 20, 1, -1)

    # Snag a memory from computer brain
    memories_list = jnp.array([X_binary[1]])
    
    # Initialize Hopfield object
    H_Net = Hopfield(784)
    H_Net.update_weights(memories_list)

    # Draw it all out, updating board each update iteration
    cellsize = 20
    pygame.init()  # Initialize pygame
    surface = pygame.display.set_mode((28 * cellsize, 28 * cellsize)) 
    pygame.display.set_caption("Hopfield Network Visualization")
    
    Running = True
    key = jax.random.PRNGKey(0)
    cells = jax.random.normal(key=jax.random.PRNGKey(1), shape=(784,)) > 0

    frame_count = 0  # To save frames
    
    while Running:
        for event in pygame.event.get():
            if event.type == pygame.QUIT:
                Running = False
                
                # Quit pygame
                pygame.quit()
    
        # Fills surface with color
        surface.fill((211, 211, 211)) 
        
        # Loop through network state array and update colors for each cell
        for r in range(28):  # Iterate over rows
            for c in range(28):  # Iterate over columns
                if cells[r * 28 + c] == -1:
                    col = (135, 206, 250)
                elif cells[r * 28 + c] == 1:
                    col = (0, 0, 128)
                else: 
                    col = (255, 140, 0)
                pygame.draw.rect(surface, col, (r * cellsize, c * cellsize, cellsize, cellsize)) # Draw new cell
        
        # Save each frame as an image
        pygame.image.save(surface, f"hopfield_frames/frame_{frame_count:03d}.png")
        frame_count += 1

        # Update network state
        key, kk = jax.random.split(key)
        cells = H_Net.update_state(cells, kk)
        pygame.display.update()  # Updates display from new .draw in update function
        pygame.time.wait(50)

    # Create a GIF from saved frames
    frames = [Image.open(f"hopfield_frames/frame_{i:03d}.png") for i in range(frame_count)]
    frames[0].save("hopfield_network.gif", save_all=True, append_images=frames[1:], duration=50, loop=0)
    print("GIF saved as hopfield_network.gif")

MNIST_Hopfield()
plt.show() 


pygame 2.6.1 (SDL 2.28.4, Python 3.11.11)
Hello from the pygame community. https://www.pygame.org/contribute.html


2025-01-10 12:19:46.350 python[36745:9956961] +[IMKClient subclass]: chose IMKClient_Legacy
2025-01-10 12:19:46.350 python[36745:9956961] +[IMKInputSession subclass]: chose IMKInputSession_Legacy


error: display Surface quit