Train a simple binarized Convolutional Neural Network to classify MNIST digits

In [None]:
pip install larq

Collecting larq
  Downloading larq-0.13.3-py3-none-any.whl (66 kB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/66.2 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [91m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m[90m╺[0m[90m━━[0m [32m61.4/66.2 kB[0m [31m2.1 MB/s[0m eta [36m0:00:01[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m66.2/66.2 kB[0m [31m1.7 MB/s[0m eta [36m0:00:00[0m
Collecting terminaltables>=3.1.0 (from larq)
  Downloading terminaltables-3.1.10-py2.py3-none-any.whl (15 kB)
Installing collected packages: terminaltables, larq
Successfully installed larq-0.13.3 terminaltables-3.1.10


In [None]:
import tensorflow as tf
import larq as lq

In [None]:
# Load the MNIST train/test splits (downloads the archive on first use)
(train_images, train_labels), (test_images, test_labels) = tf.keras.datasets.mnist.load_data()

# Add a trailing channel axis for the convolutional layers. Using -1 lets the
# sample count be inferred instead of hard-coding 60000 / 10000, so the cell
# keeps working if the arrays are subsampled.
train_images = train_images.reshape((-1, 28, 28, 1))
test_images = test_images.reshape((-1, 28, 28, 1))

# Normalize pixel values to be between -1 and 1
train_images = train_images / 127.5 - 1
test_images = test_images / 127.5 - 1

Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz


In [None]:
# Quantization options shared by every quantized layer after the first one:
# both the layer input and the kernel go through the "ste_sign" quantizer,
# and the kernel uses the "weight_clip" constraint.
kwargs = {
    "input_quantizer": "ste_sign",
    "kernel_quantizer": "ste_sign",
    "kernel_constraint": "weight_clip",
}

model = tf.keras.models.Sequential([
    # First layer: only the weights are quantized, the input images are not
    lq.layers.QuantConv2D(
        32,
        (3, 3),
        kernel_quantizer="ste_sign",
        kernel_constraint="weight_clip",
        use_bias=False,
        input_shape=(28, 28, 1),
    ),
    tf.keras.layers.MaxPooling2D((2, 2)),
    tf.keras.layers.BatchNormalization(scale=False),

    # Second convolutional stage, fully quantized
    lq.layers.QuantConv2D(64, (3, 3), use_bias=False, **kwargs),
    tf.keras.layers.MaxPooling2D((2, 2)),
    tf.keras.layers.BatchNormalization(scale=False),

    # Third convolutional stage, then flatten for the dense head
    lq.layers.QuantConv2D(64, (3, 3), use_bias=False, **kwargs),
    tf.keras.layers.BatchNormalization(scale=False),
    tf.keras.layers.Flatten(),

    # Quantized dense head: 64 hidden units, then one output per digit class
    lq.layers.QuantDense(64, use_bias=False, **kwargs),
    tf.keras.layers.BatchNormalization(scale=False),
    lq.layers.QuantDense(10, use_bias=False, **kwargs),
    tf.keras.layers.BatchNormalization(scale=False),
    tf.keras.layers.Activation("softmax"),
])

In [None]:
# Print Larq's per-layer summary table for the model (input precision, output
# shapes, 1-bit / 32-bit parameter counts, memory and MAC counts).
lq.models.summary(model)

+sequential stats------------------------------------------------------------------------------------------+
| Layer                  Input prec.           Outputs  # 1-bit  # 32-bit  Memory  1-bit MACs  32-bit MACs |
|                              (bit)                        x 1       x 1    (kB)                          |
+----------------------------------------------------------------------------------------------------------+
| quant_conv2d                     -  (-1, 26, 26, 32)      288         0    0.04           0       194688 |
| max_pooling2d                    -  (-1, 13, 13, 32)        0         0       0           0            0 |
| batch_normalization              -  (-1, 13, 13, 32)        0        64    0.25           0            0 |
| quant_conv2d_1                   1  (-1, 11, 11, 64)    18432         0    2.25     2230272            0 |
| max_pooling2d_1                  -    (-1, 5, 5, 64)        0         0       0           0            0 |
| batch_normalizati

In [None]:
# Configure training: Adam optimizer, cross-entropy on integer labels,
# and accuracy as the reported metric.
model.compile(
    optimizer="adam",
    loss="sparse_categorical_crossentropy",
    metrics=["accuracy"],
)

# Train on the training split for 6 epochs in mini-batches of 64.
model.fit(x=train_images, y=train_labels, batch_size=64, epochs=6)

# Measure loss and accuracy on the test split.
test_loss, test_acc = model.evaluate(x=test_images, y=test_labels)

Epoch 1/6
Epoch 2/6
Epoch 3/6
Epoch 4/6
Epoch 5/6
Epoch 6/6


In [None]:
# Report the test accuracy as a percentage with two decimal places.
print(f"Test accuracy {test_acc * 100:.2f} %")

Test accuracy 96.18 %


In [None]:
# Run the trained model on every test image; each row of the result holds the
# softmax scores for the 10 digit classes.
predictions = model.predict(test_images)



In [None]:
# Show the 10 class scores predicted for the first test image.
print(predictions[0])

[0.03467151 0.02637186 0.02903052 0.0492917  0.02694386 0.02578786
 0.0350228  0.66962457 0.04895309 0.05430223]


In [None]:
# Sanity check: the scores for one image come from a softmax layer, so they
# should add up to approximately 1.
print(sum(predictions[1]))

1.0000000335276127


In [None]:
# Number of prediction rows, i.e. one per test image.
print(len(predictions))

10000


In [None]:
# Report the most likely digit for every test image.
# argmax over the class axis picks the winning class for all images in one
# vectorized pass, instead of recomputing max() for each class score of each
# image. It also yields exactly one digit per image: the previous
# float-equality comparison printed several lines when two scores tied.
predicted_digits = predictions.argmax(axis=1)
for image_index, digit in enumerate(predicted_digits):
  print(f"Image number {image_index} contains the digit {digit}")

[1;30;43mStreaming output truncated to the last 5000 lines.[0m
Image number 5000 contains the digit 3
Image number 5001 contains the digit 9
Image number 5002 contains the digit 9
Image number 5003 contains the digit 8
Image number 5004 contains the digit 4
Image number 5005 contains the digit 1
Image number 5006 contains the digit 0
Image number 5007 contains the digit 6
Image number 5008 contains the digit 0
Image number 5009 contains the digit 9
Image number 5010 contains the digit 6
Image number 5011 contains the digit 8
Image number 5012 contains the digit 6
Image number 5013 contains the digit 1
Image number 5014 contains the digit 1
Image number 5015 contains the digit 9
Image number 5016 contains the digit 8
Image number 5017 contains the digit 9
Image number 5018 contains the digit 2
Image number 5019 contains the digit 3
Image number 5020 contains the digit 5
Image number 5021 contains the digit 5
Image number 5022 contains the digit 9
Image number 5023 contains the digit 4