In [2]:
import numpy as np

class BoltzmannMachine:
    def __init__(self, num_nodes):
        self.num_nodes = num_nodes
        self.weights = np.random.randn(self.num_nodes, self.num_nodes)
        self.states = np.random.choice([0, 1], self.num_nodes)

    def energy(self):
        E = -np.dot(self.states, np.dot(self.weights, self.states))
        return E

    def step(self):
        for i in range(self.num_nodes):
            self.update_node(i)

    def update_node(self, node):
        net_input = np.dot(self.weights[node], self.states)
        probability = self.sigmoid(net_input)
        self.states[node] = 1 if np.random.rand() < probability else 0

    def sigmoid(self, x):
        return 1 / (1 + np.exp(-x))

    def train(self, data, epochs=1000, learning_rate=0.1):
        for epoch in range(epochs):
            np.random.shuffle(data)
            for sample in data:
                self.states = sample
                self.step()

                for i in range(self.num_nodes):
                    for j in range(self.num_nodes):
                        if i != j:
                            delta_weight = learning_rate * (self.states[i] * self.states[j] - self.sigmoid(np.dot(self.weights[i], sample)) * self.sigmoid(np.dot(self.weights[j], sample)))
                            self.weights[i, j] += delta_weight
                            self.weights[j, i] = self.weights[i, j]


            if epoch % 100 == 0:
                print(f"Epoch {epoch}: Current state: {self.states}, Energy: {self.energy()}")


num_nodes = 6
bm = BoltzmannMachine(num_nodes)


data = np.array([[0, 1, 1, 0, 1, 0],
                 [1, 0, 0, 1, 0, 1]])

bm.train(data, epochs=500, learning_rate=0.05)


Epoch 0: Current state: [1 0 1 0 1 1], Energy: -16.673181034091755
Epoch 100: Current state: [1 0 1 0 1 1], Energy: -14.320573528874037
Epoch 200: Current state: [1 1 1 0 1 1], Energy: -18.875478063257095
Epoch 300: Current state: [1 1 1 0 1 1], Energy: -24.845904113297568
Epoch 400: Current state: [1 1 1 0 1 1], Energy: -23.932278192693154
