# Introduction to BNNs with Larq

This tutorial demonstrates how to train a simple binarized Convolutional Neural Network (CNN) to classify CIFAR-10 images. It is adapted from the Larq MNIST tutorial, where the same simple network achieves approximately 98% accuracy on the MNIST test set; expect a noticeably lower accuracy on the harder CIFAR-10 dataset. This tutorial uses Larq and the [Keras Sequential API](https://www.tensorflow.org/guide/keras), so creating and training our model will require only a few lines of code.

In [None]:
# import tensorflow as tf
# import larq as lq

!pip -q install tensorflow==2.10.0
!pip -q install larq==0.13.1

import tensorflow as tf
import larq as lq

### Download and prepare the CIFAR10 dataset

In [None]:
# Load the CIFAR-10 training and test splits (images plus integer class labels).
(train_images, train_labels), (test_images, test_labels) = tf.keras.datasets.cifar10.load_data()

# Normalize pixel values to be between -1 and 1.
# Cast to float32 first: dividing the integer image arrays directly would
# promote them to float64 and double the memory footprint of the dataset.
train_images = train_images.astype("float32") / 127.5 - 1
test_images = test_images.astype("float32") / 127.5 - 1

print(train_images.shape, train_labels.shape)  # Should be (50000, 32, 32, 3), (50000, 1)
print(test_images.shape, test_labels.shape)    # Should be (10000, 32, 32, 3), (10000, 1)

### Create the model

The following will create a simple binarized CNN.

The quantization function
$$
q(x) = \begin{cases}
    -1 & x < 0 \\\
    1 & x \geq 0
\end{cases}
$$
is used in the forward pass to binarize the activations and the latent full precision weights. The gradient of this function is zero almost everywhere, which prevents the model from learning.

To be able to train the model the gradient is instead estimated using the Straight-Through Estimator (STE)
(the binarization is essentially replaced by a clipped identity on the backward pass):
$$
\frac{\partial q(x)}{\partial x} = \begin{cases}
    1 & \left|x\right| \leq 1 \\\
    0 & \left|x\right| > 1
\end{cases}
$$

In Larq this can be done by using `input_quantizer="ste_sign"` and `kernel_quantizer="ste_sign"`.
Additionally, the latent full precision weights are clipped to the range [-1, 1] using `kernel_constraint="weight_clip"`.

In [None]:
# Options shared by every quantized layer except the first one: binarize both
# the layer input and the kernel with the sign/STE quantizer, and clip the
# latent full-precision weights.
kwargs = {
    "input_quantizer": "ste_sign",
    "kernel_quantizer": "ste_sign",
    "kernel_constraint": "weight_clip",
}

model = tf.keras.models.Sequential([
    # First layer: only the weights are quantized, not the input images.
    # input_shape is (32, 32, 3) to match the training data.
    lq.layers.QuantConv2D(
        32, (3, 3),
        kernel_quantizer="ste_sign",
        kernel_constraint="weight_clip",
        use_bias=False,
        input_shape=(32, 32, 3),
    ),
    tf.keras.layers.MaxPooling2D((2, 2)),
    tf.keras.layers.BatchNormalization(scale=False),

    # Second convolutional stage, fully binarized.
    lq.layers.QuantConv2D(64, (3, 3), use_bias=False, **kwargs),
    tf.keras.layers.MaxPooling2D((2, 2)),
    tf.keras.layers.BatchNormalization(scale=False),

    # Third convolutional stage (no pooling), then flatten for the dense head.
    lq.layers.QuantConv2D(64, (3, 3), use_bias=False, **kwargs),
    tf.keras.layers.BatchNormalization(scale=False),
    tf.keras.layers.Flatten(),

    # Binarized dense head ending in a 10-way softmax.
    lq.layers.QuantDense(64, use_bias=False, **kwargs),
    tf.keras.layers.BatchNormalization(scale=False),
    lq.layers.QuantDense(10, use_bias=False, **kwargs),
    tf.keras.layers.BatchNormalization(scale=False),
    tf.keras.layers.Activation("softmax"),
])

Almost all parameters in the network are binarized, so they are either -1 or 1. This would make the network extremely fast if it were deployed on custom BNN hardware.

Here is the complete architecture of our model:

In [None]:
# Print Larq's summary of the binarized model defined above.
lq.models.summary(model)

In [None]:
import tensorflow as tf

# Full-precision counterpart of the binarized network: same layer layout, but
# built from plain Keras Conv2D/Dense layers without quantizers.
# NOTE(review): no activation function is passed to the Conv2D/Dense layers;
# confirm this is intended for the baseline.
model_2 = tf.keras.models.Sequential([
    # First convolutional layer; input_shape matches the size of the input data.
    tf.keras.layers.Conv2D(32, (3, 3), use_bias=False, input_shape=(32, 32, 3)),
    tf.keras.layers.MaxPooling2D((2, 2)),
    tf.keras.layers.BatchNormalization(scale=False),

    # Second convolutional layer.
    tf.keras.layers.Conv2D(64, (3, 3), use_bias=False),
    tf.keras.layers.MaxPooling2D((2, 2)),
    tf.keras.layers.BatchNormalization(scale=False),

    # Third convolutional layer.
    tf.keras.layers.Conv2D(64, (3, 3), use_bias=False),
    tf.keras.layers.BatchNormalization(scale=False),

    # Flatten before the dense layers.
    tf.keras.layers.Flatten(),

    # First dense layer.
    tf.keras.layers.Dense(64, use_bias=False),
    tf.keras.layers.BatchNormalization(scale=False),

    # Second dense layer.
    tf.keras.layers.Dense(10, use_bias=False),
    tf.keras.layers.BatchNormalization(scale=False),

    # Final softmax activation for 10-class classification.
    tf.keras.layers.Activation("softmax"),
])

# Print the model summary.
model_2.summary()


In [None]:
# After defining the model, print its summary.
model_2.summary()

### Compile and train the model

Note: This may take a few minutes depending on your system.

In [None]:
# Configure training: Adam optimizer, cross-entropy on integer class labels,
# and accuracy reported as a metric.
model.compile(
    optimizer='adam',
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy'],
)

# Train for 6 epochs in batches of 64, holding out 20% of the training data
# for validation. The returned History object is kept for plotting.
history = model.fit(
    train_images,
    train_labels,
    batch_size=64,
    epochs=6,
    validation_split=0.2,
)

### Evaluate the model

In [None]:
# Measure loss and accuracy of the binarized model on the test split.
test_loss, test_acc = model.evaluate(
    test_images,
    test_labels,
)
print(f"Test accuracy {test_acc * 100:.2f} %")

In [None]:
import matplotlib.pyplot as plt

# Plot the accuracy and loss curves recorded in `history`.
# Each entry maps a key of `history.history` to the legend label of its curve;
# the order here is the order in which the curves are drawn.
curve_labels = {
    'accuracy': 'Precisión de entrenamiento',
    'val_accuracy': 'Precisión de validación',
    'loss': 'Pérdida de entrenamiento',
    'val_loss': 'Pérdida de validación',
}

# Create the figure.
plt.figure(figsize=(8, 6))

# Draw only the curves that are actually present in the training history.
for metric, label in curve_labels.items():
    if metric in history.history:
        plt.plot(history.history[metric], label=label)

# Axis labels, title and legend.
plt.xlabel('Épocas')
plt.ylabel('Valor')
plt.title('Curvas de Entrenamiento y Validación')
plt.legend()
plt.show()

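As you can see from the test accuracy printed above, our simple binarized CNN learns to classify CIFAR-10 images. Note that the figure of around 98 % quoted in the original MNIST version of this tutorial does not carry over to this harder dataset. Not bad for a few lines of code!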

For information on converting Larq models to an optimized format and using or benchmarking them on Android or ARM devices, have a look at [this guide](https://docs.larq.dev/compute-engine/end_to_end/).

In [None]:
# Configure the full-precision model with the same optimizer, loss and metric
# as the binarized one.
model_2.compile(
    optimizer='adam',
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy'],
)

# Same training schedule: 6 epochs, batches of 64, 20% of the training data
# held out for validation. The History object is kept for plotting.
history_2 = model_2.fit(
    train_images,
    train_labels,
    batch_size=64,
    epochs=6,
    validation_split=0.2,
)


In [None]:
# Measure loss and accuracy of the full-precision model on the test split.
# This rebinds test_loss and test_acc, replacing any earlier values.
test_loss, test_acc = model_2.evaluate(
    test_images,
    test_labels,
)
print(f"Test accuracy {test_acc * 100:.2f} %")

In [None]:
import matplotlib.pyplot as plt

# Plot the accuracy and loss curves recorded in `history_2`.
# Each entry maps a key of `history_2.history` to the legend label of its
# curve; the order here is the order in which the curves are drawn.
curve_labels = {
    'accuracy': 'Precisión de entrenamiento',
    'val_accuracy': 'Precisión de validación',
    'loss': 'Pérdida de entrenamiento',
    'val_loss': 'Pérdida de validación',
}

# Create the figure.
plt.figure(figsize=(8, 6))

# Draw only the curves that are actually present in the training history.
for metric, label in curve_labels.items():
    if metric in history_2.history:
        plt.plot(history_2.history[metric], label=label)

# Axis labels, title and legend.
plt.xlabel('Épocas')
plt.ylabel('Valor')
plt.title('Curvas de Entrenamiento y Validación')
plt.legend()
plt.show()