<a href="https://colab.research.google.com/github/AWorldOfChaos/SoC-2024-Robust-ML/blob/main/Uday/bnn_tf.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import tensorflow as tf
from tensorflow.keras.layers import Layer, Flatten, BatchNormalization
from tensorflow.keras import Model
import numpy as np
import matplotlib.pyplot as plt
from tensorflow.keras.callbacks import EarlyStopping

# Fetch the MNIST train/test splits bundled with Keras.
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()

# Scale the images by 1/255 (presumably mapping 8-bit pixel intensities
# into [0, 1] -- confirm against the dataset documentation).
x_train = x_train / 255.0
x_test = x_test / 255.0

@tf.custom_gradient
def binarize(x):
    """Map ``x`` element-wise to +1/-1 with a clipped pass-through gradient.

    Forward pass: elements with ``x >= 0`` become +1, all others become -1.
    Backward pass: the upstream gradient is passed through unchanged where
    ``|x| <= 1`` and zeroed everywhere else.

    The output takes the dtype of ``x`` and the gradient mask takes the dtype
    of the upstream gradient, instead of both being pinned to float32, so the
    op does not produce dtype mismatches for non-float32 inputs.

    Args:
        x: Tensor to binarize (expected to be floating point).

    Returns:
        The ``(output, grad_fn)`` pair required by ``tf.custom_gradient``.
    """
    def grad(dy):
        # 1 where |x| <= 1, 0 elsewhere, in the same dtype as the incoming
        # gradient so the product below never mixes dtypes.
        pass_through = tf.cast(tf.abs(x) <= 1, dtype=dy.dtype)
        return dy * pass_through

    # Build the +1/-1 constants from x so they share its dtype and shape.
    plus_one = tf.ones_like(x)
    return tf.where(x >= 0, plus_one, -plus_one), grad

# def hard_tanh(x):
#     return tf.clip_by_value(x, -1, 1)

class BinarizedDense(Layer):
    """Fully connected layer whose kernel is binarized on every forward pass.

    The layer keeps a real-valued kernel and bias as trainable weights. In
    ``call`` the kernel is passed through ``binarize`` before the matrix
    product; the bias is added as-is.

    Args:
        units: Number of output units (size of the last output dimension).
        activation: Optional activation applied to the layer output. May be a
            callable, a Keras activation identifier (e.g. ``'softmax'``), or
            ``None`` for no activation.
        **kwargs: Standard ``Layer`` keyword arguments (``name``, ``dtype``,
            ...), forwarded to the base class.
    """

    def __init__(self, units, activation=None, **kwargs):
        super().__init__(**kwargs)
        self.units = units
        # Resolve string identifiers to callables; keep None as None so the
        # "no activation" case stays distinguishable.
        self.activation = (
            None if activation is None
            else tf.keras.activations.get(activation)
        )

    def build(self, input_shape):
        """Create the kernel and bias once the input feature size is known."""
        # Distinct names so name-based weight formats can tell the two
        # variables apart.
        self.w = self.add_weight(name='kernel',
                                 shape=(input_shape[-1], self.units),
                                 initializer='glorot_uniform',
                                 trainable=True)
        self.b = self.add_weight(name='bias',
                                 shape=(self.units,),
                                 initializer='zeros',
                                 trainable=True)

    def call(self, inputs):
        """Return ``activation(inputs @ binarize(w) + b)``."""
        binary_w = binarize(self.w)
        outputs = tf.matmul(inputs, binary_w) + self.b
        if self.activation is not None:
            return self.activation(outputs)
        return outputs

    def get_config(self):
        """Return the constructor arguments so the layer can be re-created."""
        config = super().get_config()
        config.update({
            'units': self.units,
            'activation': (
                None if self.activation is None
                else tf.keras.activations.serialize(self.activation)
            ),
        })
        return config

class BNNModel(Model):
    """Binarized multi-layer perceptron with a 10-way softmax output.

    The input is flattened (the ``Flatten`` layer is declared with a
    per-sample shape of ``(28, 28)``) and then passed through five hidden
    stages of widths 512, 256, 128, 64 and 32. Each stage is
    ``BinarizedDense -> BatchNormalization -> binarize``. A final
    ``BinarizedDense(10)`` with softmax produces the output.
    """

    def __init__(self):
        super().__init__()
        self.flatten = Flatten(input_shape=(28, 28))
        self.dense1 = BinarizedDense(512)
        self.bn1 = BatchNormalization()
        self.dense2 = BinarizedDense(256)
        self.bn2 = BatchNormalization()
        self.dense3 = BinarizedDense(128)
        self.bn3 = BatchNormalization()
        self.dense4 = BinarizedDense(64)
        self.bn4 = BatchNormalization()
        self.dense5 = BinarizedDense(32)
        self.bn5 = BatchNormalization()
        self.dense6 = BinarizedDense(10, activation=tf.nn.softmax)

    def call(self, inputs, training=None):
        """Run the forward pass.

        Args:
            inputs: Batch of inputs to classify.
            training: Forwarded to every ``BatchNormalization`` layer so the
                caller can select training or inference behaviour explicitly.
                ``None`` leaves the decision to Keras.

        Returns:
            The output of the final softmax layer (last dimension of size 10).
        """
        x = self.flatten(inputs)
        # Hidden stages in execution order; all five share the same
        # dense -> batch-norm -> binarize structure.
        hidden_stages = (
            (self.dense1, self.bn1),
            (self.dense2, self.bn2),
            (self.dense3, self.bn3),
            (self.dense4, self.bn4),
            (self.dense5, self.bn5),
        )
        for dense, batch_norm in hidden_stages:
            x = batch_norm(dense(x), training=training)
            x = binarize(x)
        return self.dense6(x)

# Instantiate the binarized network.
model = BNNModel()

# Adam optimizer with sparse categorical cross-entropy; accuracy is tracked
# as the reported metric.
model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])

# Stop once the validation loss has failed to improve for 6 consecutive
# epochs, and roll the model back to its best weights.
early_stopping = EarlyStopping(
    monitor='val_loss',
    patience=6,
    restore_best_weights=True,
)

# Train for at most 100 epochs, using a 0.1 fraction of the training data
# for validation.
model.fit(
    x_train,
    y_train,
    epochs=100,
    validation_split=0.1,
    callbacks=[early_stopping],
)

# Report performance on the test split.
test_loss, test_acc = model.evaluate(x_test, y_test)
print(f'Test accuracy: {test_acc}')

# Print the per-layer parameter summary.
model.summary()


Epoch 1/100
Epoch 2/100
Epoch 3/100
Epoch 4/100
Epoch 5/100
Epoch 6/100
Epoch 7/100
Epoch 8/100
Epoch 9/100
Epoch 10/100
Epoch 11/100
Epoch 12/100
Epoch 13/100
Epoch 14/100
Epoch 15/100
Epoch 16/100
Epoch 17/100
Epoch 18/100
Epoch 19/100
Epoch 20/100
Epoch 21/100
Epoch 22/100
Epoch 23/100
Test accuracy: 0.9679999947547913
Model: "bnn_model_1"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 flatten_1 (Flatten)         multiple                  0         
                                                                 
 binarized_dense_6 (Binariz  multiple                  401920    
 edDense)                                                        
                                                                 
 batch_normalization_5 (Bat  multiple                  2048      
 chNormalization)                                                
                                                             