In [None]:
import cupy as np
from torchvision import datasets, transforms

# Floating-point element type shared by the cells below. `np` is CuPy in this
# notebook (see the import above), so this is CuPy's float32.
dtype = np.float32

In [None]:
def _load_mnist_pixels(train : bool):
    """Load one MNIST split and return its images as a `dtype` array in [0, 1].

    Parameters
    ----------
    train : bool
        True selects the training split, False the test split.

    Returns
    -------
    Array created with `np.asarray(..., dtype)` holding the images of the split.

    Notes
    -----
    The images are taken from the dataset's `.data` tensor. torchvision applies
    `transform` only when a sample is fetched by indexing the dataset, so the
    transform does not influence the values returned here.
    """
    mnist = datasets.MNIST(
        root        = "./data/MNIST/",
        train       = train,
        # `transform` has to be a callable instance. Passing the bare class
        # (`transforms.ToTensor`) would raise as soon as a sample is fetched
        # through the dataset, because the image would be handed to __init__.
        transform   = transforms.ToTensor(),
        download    = True
        )
    # Scale the raw pixel values by 1/255 on the host, then hand the result to
    # `np` (CuPy in this notebook) with the notebook-wide element type.
    return np.asarray(mnist.data.numpy() / 255, dtype)


training_data_set = _load_mnist_pixels(train = True)
testing_data_set  = _load_mnist_pixels(train = False)

# Optional binarisation of the pixel intensities at a 0.4 threshold (disabled).
#training_data_set[training_data_set <  0.4] = 0
#training_data_set[training_data_set >= 0.4] = 1
#testing_data_set[testing_data_set   <  0.4] = 0
#testing_data_set[testing_data_set   >= 0.4] = 1

In [None]:
def random_separate(
        data_set_size   : int,
        separate_size   : int | tuple,
        seed            : int | None = None
        ):
    
    rs = np.random.RandomState(seed)
    if hasattr(separate_size, "__iter__"):
        _randIdx = rs.choice(data_set_size, sum(list(separate_size)), replace = False)
        randIdx = []
        for i in range(len(separate_size)):
            randIdx.append(_randIdx[sum(separate_size[: i]) : sum(separate_size[: i + 1])])
    else:
        randIdx = rs.choice(data_set_size, separate_size, replace = False)

    return randIdx


# Subset sizes. Training and validation indices are cut from one draw of
# 10000 indices made without replacement, so the two subsets do not overlap.
n_testing    = 500
n_validation = 200
n_training   = 10000 - n_validation

# Indices into the MNIST test split.
idx_testing_data_set = random_separate(testing_data_set.shape[0], n_testing, 12345)

# Indices into the MNIST training split, one array per subset.
idx_training_data_set, idx_validation_data_set = random_separate(
    training_data_set.shape[0],
    (n_training, n_validation),
    12345,
)

# Both selections on the right-hand side are taken from the full training
# array before either name is rebound.
training_data_set, validation_data_set = (
    training_data_set[idx_training_data_set],
    training_data_set[idx_validation_data_set],
)
testing_data_set = testing_data_set[idx_testing_data_set]

# Cell output: the shapes of the three subsets.
training_data_set.shape, validation_data_set.shape, testing_data_set.shape

In [None]:
from Restricted_Boltzmann_Machine.Restricted_Boltzmann_Machine import RBM
from Restricted_Boltzmann_Machine.Training import RBM_training

# Model dimensions: 28 * 28 visible units and 48 hidden units.
visible_size = 28 * 28
hidden_size  = 48
rbm = RBM(
    visible_size  = visible_size,
    hidden_size   = hidden_size,
    seed          = 1234,
    # Use the notebook-wide element type so the model stays in step with the
    # data arrays, which are created with the same `dtype`.
    dtype         = dtype
    )

# `batch_size` is also read by the reconstruction cell further down.
batch_size = 1024
training = RBM_training(
    rbm                 = rbm,
    epochs              = 2000,
    training_samples    = training_data_set,
    validation_samples  = validation_data_set,
    testing_samples     = testing_data_set,
    batch_size          = batch_size,
    learning_rate       = 0.02
)

# NOTE(review): the meaning of the positional argument `1` is defined by
# RBM_training.start_training, which is not visible here - confirm.
training.start_training(1)


In [None]:
import matplotlib.pyplot as plt


def _tile_images(images, grid_shape):
    """Lay out flattened 28x28 images on a grid and return one 2-D mosaic.

    Parameters
    ----------
    images : array of shape (grid_shape[0] * grid_shape[1], 28 * 28)
        One flattened image per row.
    grid_shape : tuple (rows, cols)
        Number of images per mosaic column and row.

    Returns
    -------
    Array of shape (28 * rows, 28 * cols).
    """
    rows, cols = grid_shape
    #(x,y,i,j) -> (x,i,y,j)
    #(0,1,2,3) -> (0,2,1,3)
    # Moving the pixel-row axis next to the grid-row axis lets the final
    # reshape merge (grid row, pixel row) and (grid col, pixel col).
    return (
        images.reshape((rows, cols, 28, 28))
        .transpose((0, 2, 1, 3))
        .reshape(28 * rows, 28 * cols)
    )


verify_shape = (8, 16)
n_figs = verify_shape[0] * verify_shape[1]

# First `n_figs` test images, flattened to one row per image, and a buffer of
# the same shape for their reconstructions.
testing = testing_data_set[: n_figs].reshape((n_figs, 28*28))
reconstruct = np.zeros(shape = (n_figs, 28*28), dtype = dtype)

# Walk over the images in chunks of at most `batch_size`. Stepping by start
# offset also covers a final, shorter chunk when `n_figs` is not a multiple of
# `batch_size`; slicing past the end is clipped to the array length.
for start in range(0, n_figs, batch_size):
    verify_slice = slice(start, start + batch_size)
    # Only the first of the three values returned by `rbm.forward` is kept.
    reconstruct[verify_slice], _, _ = rbm.forward(testing[verify_slice])

testing     = _tile_images(testing, verify_shape)
reconstruct = _tile_images(reconstruct, verify_shape)

# Top panel: input images; bottom panel: their reconstructions.
# `.get()` copies the CuPy arrays to the host for matplotlib.
fig = plt.figure(dpi = 200)
ax1 = fig.add_subplot(211)
ax1.imshow(testing.get(), cmap="gray")
ax1.set_axis_off()

ax2 = fig.add_subplot(212)
ax2.imshow(reconstruct.get(), cmap="gray")
ax2.set_axis_off()

In [None]:
import matplotlib.pyplot as plt

# Draw the weights as one 28x28 tile per hidden unit, with the tiles laid out
# on a `hidden_shape` grid.
hidden_shape = (6, 8)

# Split the weight array into (pixel row i, pixel col j, grid row x, grid col y).
W = rbm.Weight.reshape(28, 28, *hidden_shape)
# Axis order (i,j,x,y) -> (x,i,y,j), i.e. (0,1,2,3) -> (2,0,3,1), so that the
# reshape below merges (x,i) into image rows and (y,j) into image columns.
W = W.transpose((2, 0, 3, 1))
W = W.reshape(28 * hidden_shape[0], 28 * hidden_shape[1])

# Single-axes figure; `.get()` copies the CuPy array to the host for matplotlib.
fig, ax = plt.subplots(dpi = 140)
ax.imshow(W.get(), cmap="gray")
ax.set_axis_off()

