In [1]:
#---- Source https://scikit-learn.org/stable/auto_examples/neural_networks/plot_rbm_logistic_classification.html#sphx-glr-auto-examples-neural-networks-plot-rbm-logistic-classification-py

#---------- Generate data ----------------
import numpy as np

from scipy.ndimage import convolve
from sklearn import datasets
from sklearn.preprocessing import minmax_scale
from sklearn.model_selection import train_test_split


def nudge_dataset(X, Y, image_shape=(8, 8)):
    """
    Augment a dataset of flattened images with one-pixel shifted copies.

    Each row of X is interpreted as a flattened image of ``image_shape``.
    The returned dataset holds the original rows followed by one shifted
    copy of the whole dataset per kernel in ``direction_vectors`` (four
    one-pixel, axis-aligned shifts), i.e. it is 5 times bigger than the
    original one. Pixels shifted in from outside the image are zero
    (``mode="constant"``).

    Parameters
    ----------
    X : array-like of shape (n_samples, height * width)
        Flattened images, one per row.
    Y : array-like with n_samples entries along axis 0
        Labels; repeated once per copy of X so rows stay aligned.
    image_shape : tuple of (height, width), default=(8, 8)
        Shape each row of X is reshaped to before shifting.

    Returns
    -------
    X, Y : the augmented samples and the matching repeated labels.

    Raises
    ------
    ValueError
        If the rows of X cannot be reshaped to ``image_shape``.
    """
    # Each 3x3 kernel has a single 1 next to the centre, so convolving with
    # it moves the image content by exactly one pixel.
    direction_vectors = [
        [[0, 1, 0], [0, 0, 0], [0, 0, 0]],
        [[0, 0, 0], [1, 0, 0], [0, 0, 0]],
        [[0, 0, 0], [0, 0, 1], [0, 0, 0]],
        [[0, 0, 0], [0, 0, 0], [0, 1, 0]],
    ]

    X = np.asarray(X)
    flat_shape = X.shape
    height, width = image_shape
    # Stack of 2-D images: axis 0 indexes samples.
    images = X.reshape((flat_shape[0], height, width))

    def shift(stack, w):
        # The kernel has extent 1 along the sample axis, so every image is
        # shifted independently and no pixels leak between samples.
        kernel = np.asarray(w)[np.newaxis, :, :]
        shifted = convolve(stack, mode="constant", weights=kernel)
        return shifted.reshape(flat_shape)

    X = np.concatenate([X] + [shift(images, vector) for vector in direction_vectors])
    # One label block per copy of X (original + one per direction).
    Y = np.concatenate([Y] * (len(direction_vectors) + 1), axis=0)
    return X, Y


# Load the digits data, cast the features to float32 and augment them with
# the shifted copies produced by nudge_dataset.
X, y = datasets.load_digits(return_X_y=True)
X, Y = nudge_dataset(np.asarray(X, dtype="float32"), y)

# 0-1 scaling
X = minmax_scale(X, feature_range=(0, 1))

# Hold out 20% of the samples for testing; the fixed random_state keeps the
# split repeatable across runs.
X_train, X_test, Y_train, Y_test = train_test_split(
    X,
    Y,
    test_size=0.2,
    random_state=0,
)


In [2]:
# ------ Models definition -----------

from sklearn import linear_model
from sklearn.neural_network import BernoulliRBM
from sklearn.pipeline import Pipeline

# Logistic regression classifier; it is also the final stage of the pipeline.
logistic = linear_model.LogisticRegression(
    solver="newton-cg",
    tol=1,
)

# RBM feature extractor, seeded and verbose.
rbm = BernoulliRBM(verbose=True, random_state=0)

# Pipeline: RBM features feed the logistic regression.
rbm_features_classifier = Pipeline(
    steps=[
        ("rbm", rbm),
        ("logistic", logistic),
    ]
)

In [3]:
# ------ Training -----------
from sklearn.base import clone

# Hyper-parameters. These were set by cross-validation, using a
# GridSearchCV. Cross-validation is not repeated here to save time.
# More components tend to give better prediction performance, at the cost
# of a larger fitting time.
rbm.set_params(learning_rate=0.06, n_iter=10, n_components=100)
logistic.set_params(C=6000)

# Fit the RBM-Logistic pipeline (rbm and logistic are its two steps).
rbm_features_classifier.fit(X_train, Y_train)

# Fit a separate logistic regression directly on the pixels: an unfitted
# copy of `logistic` with its own regularisation strength.
raw_pixel_classifier = clone(logistic).set_params(C=100.0)
raw_pixel_classifier.fit(X_train, Y_train)

[BernoulliRBM] Iteration 1, pseudo-likelihood = -25.57, time = 0.12s
[BernoulliRBM] Iteration 2, pseudo-likelihood = -23.68, time = 0.15s
[BernoulliRBM] Iteration 3, pseudo-likelihood = -22.74, time = 0.14s
[BernoulliRBM] Iteration 4, pseudo-likelihood = -21.92, time = 0.14s
[BernoulliRBM] Iteration 5, pseudo-likelihood = -21.48, time = 0.14s
[BernoulliRBM] Iteration 6, pseudo-likelihood = -21.11, time = 0.13s
[BernoulliRBM] Iteration 7, pseudo-likelihood = -20.83, time = 0.13s
[BernoulliRBM] Iteration 8, pseudo-likelihood = -20.55, time = 0.14s
[BernoulliRBM] Iteration 9, pseudo-likelihood = -20.27, time = 0.13s
[BernoulliRBM] Iteration 10, pseudo-likelihood = -20.28, time = 0.13s


In [4]:
#-------- testing --------
from sklearn import metrics

# Print a classification report on the held-out split for each fitted
# model: the RBM pipeline first, then the raw-pixel baseline.
for report_template, fitted_model in (
    ("Logistic regression using RBM features:\n%s\n", rbm_features_classifier),
    ("Logistic regression using raw pixel features:\n%s\n", raw_pixel_classifier),
):
    Y_pred = fitted_model.predict(X_test)
    print(report_template % metrics.classification_report(Y_test, Y_pred))

Logistic regression using RBM features:
              precision    recall  f1-score   support

           0       1.00      0.98      0.99       174
           1       0.90      0.94      0.92       184
           2       0.94      0.94      0.94       166
           3       0.92      0.88      0.90       194
           4       0.95      0.95      0.95       186
           5       0.94      0.92      0.93       181
           6       0.98      0.97      0.97       207
           7       0.93      0.99      0.96       154
           8       0.90      0.88      0.89       182
           9       0.86      0.88      0.87       169

    accuracy                           0.93      1797
   macro avg       0.93      0.93      0.93      1797
weighted avg       0.93      0.93      0.93      1797


Logistic regression using raw pixel features:
              precision    recall  f1-score   support

           0       0.90      0.92      0.91       174
           1       0.59      0.57      0.58  