In [1]:
import tensorflow as tf
import numpy as np
from tensorflow.keras.datasets import mnist

2023-07-05 01:31:48.230336: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  SSE4.1 SSE4.2 AVX AVX2 FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [3]:
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()

# Rescale pixel intensities by 1/255 as float32 (assumes raw 0-255 integer
# pixels, as the divisor suggests). The generator's loop variable stays local,
# so no extra global is left behind.
train_images, test_images = (
    images.astype('float32') / 255 for images in (train_images, test_images)
)

# Define the SNN layer

In [21]:
class SNNLayer(tf.keras.layers.Layer):
    """Dense spiking layer: thresholds a linear membrane potential into 0/1 spikes.

    Forward pass::

        membrane_potential = inputs @ kernel
        spikes = 1.0 where membrane_potential >= threshold, else 0.0

    The step function has zero derivative almost everywhere, so a plain
    thresholding op lets no gradient reach ``kernel`` (it would stay at its
    random initialisation despite ``trainable=True``). The backward pass
    therefore uses a fast-sigmoid surrogate derivative
    ``1 / (1 + surrogate_slope * |membrane_potential - threshold|) ** 2``.
    The forward output is unchanged.

    Args:
        num_neurons: Number of output neurons (columns of ``kernel``).
        threshold: Firing threshold compared against the membrane potential.
        surrogate_slope: Steepness of the surrogate derivative; larger values
            concentrate the gradient more tightly around the threshold.
        **kwargs: Standard Layer arguments (``name``, ``dtype``, ...). They are
            forwarded so the layer can be rebuilt from ``get_config()``.
    """

    def __init__(self, num_neurons, threshold=1.0, surrogate_slope=10.0, **kwargs):
        super().__init__(**kwargs)
        self.num_neurons = num_neurons
        self.threshold = threshold
        self.surrogate_slope = surrogate_slope

    def build(self, input_shape):
        # The matmul in call() contracts over the last input axis, so size the
        # kernel from it (identical to input_shape[1] for (batch, features) input).
        self.kernel = self.add_weight(
            name="kernel",
            shape=(input_shape[-1], self.num_neurons),
            initializer="random_normal",
            trainable=True,
        )

    def call(self, inputs):
        membrane_potential = tf.matmul(inputs, self.kernel)
        threshold = self.threshold
        slope = self.surrogate_slope

        @tf.custom_gradient
        def spike(potential):
            # Exact Heaviside step in the forward pass: 1.0 at/above threshold, else 0.0.
            fired = tf.cast(potential >= threshold, potential.dtype)

            def grad(upstream):
                # Surrogate derivative: peaks at the threshold and decays with distance from it.
                distance = tf.abs(potential - threshold)
                return upstream / tf.square(1.0 + slope * distance)

            return fired, grad

        return spike(membrane_potential)

    def get_config(self):
        """Return the constructor arguments so the layer survives model save/load."""
        config = super().get_config()
        config.update({
            "num_neurons": self.num_neurons,
            "threshold": self.threshold,
            "surrogate_slope": self.surrogate_slope,
        })
        return config

In [29]:
# 28x28 image -> flattened 784-vector -> 256 spiking units -> 10-way softmax readout.
# NOTE(review): the hls4ml cell further down rebinds the name `SNNLayer`; re-running
# this cell after that one would build the model with the wrong class.
model = tf.keras.Sequential()
model.add(tf.keras.layers.Flatten(input_shape=(28, 28)))
model.add(SNNLayer(num_neurons=256, threshold=0.5))
model.add(tf.keras.layers.Dense(10, activation='softmax'))

In [31]:
# Sparse categorical cross-entropy: labels are integer class ids, not one-hot vectors.
model.compile(
    loss='sparse_categorical_crossentropy',
    optimizer='adam',
    metrics=['accuracy'],
)
# NOTE(review): the test split doubles as validation data here; use a separate
# validation split if these metrics ever drive model selection.
model.fit(
    train_images,
    train_labels,
    validation_data=(test_images, test_labels),
    epochs=10,
)

Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10


<keras.callbacks.History at 0x7f9638416ef0>

In [32]:
# Print each layer's output shape and parameter count, plus the totals.
model.summary()

Model: "sequential_8"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 flatten_6 (Flatten)         (None, 784)               0         
                                                                 
 snn_layer_3 (SNNLayer)      (None, 256)               200704    
                                                                 
 dense_13 (Dense)            (None, 10)                2570      
                                                                 
Total params: 203,274
Trainable params: 203,274
Non-trainable params: 0
_________________________________________________________________


In [33]:
# Save the trained model to SNN_MNIST.h5 (the .h5 suffix selects the HDF5 format);
# presumably this file is the input to the hls4ml conversion below — confirm.
# NOTE(review): reloading it needs custom_objects={'SNNLayer': <the Keras layer class>},
# and the layer's arguments only round-trip if the Keras SNNLayer implements get_config().
model.save("SNN_MNIST.h5")

# hls4ml extension: IR layer and Keras parser for the custom SNN layer

In [36]:
import hls4ml

class SNNLayer(hls4ml.model.layers.Layer):
    """hls4ml IR layer for the custom Keras ``SNNLayer``.

    Presumably registered with hls4ml under the class name 'SNNLayer' to match
    the ``class_name`` emitted by the Keras parser — confirm at the registration site.

    NOTE(review): defining this class rebinds the notebook-global name ``SNNLayer``,
    shadowing the Keras layer defined earlier. Re-running the Keras model cell, or
    passing ``SNNLayer`` as a Keras custom object, after this cell will pick up this
    class instead. Consider a distinct Python name (e.g. ``HSNNLayer``) once every
    reference to it is updated.
    """

    def initialize(self):
        """Declare this layer's output variable using the input variable's shape and dimension names."""
        inp = self.get_input_variable()
        shape = inp.shape
        dims = inp.dim_names
        # The hls4ml Layer method is `add_output_variable`; the misspelled
        # `add_output_variavle` does not exist and fails with AttributeError.
        # NOTE(review): this copies the input shape unchanged, but the Keras layer
        # maps n_in features to num_neurons outputs. The output shape (and the kernel
        # weights) presumably still need to be declared from the parsed attributes. TODO.
        self.add_output_variable(shape, dims)

In [35]:
def parse_reverse_layer(keras_layer, input_names, input_shapes, data_reader):
    """Translate a Keras ``SNNLayer`` description into an hls4ml layer dict.

    The name and skeleton come from a "reverse" custom-layer template; the
    function itself handles ``SNNLayer``.

    Args:
        keras_layer: Keras layer description. ``keras_layer['config']`` must
            contain ``'name'`` and may contain ``'num_neurons'`` and ``'threshold'``.
        input_names: Names of the producing layers, or None.
        input_shapes: Sequence of input shapes; only ``input_shapes[0]`` is used,
            and its last entry is taken as the fan-in.
        data_reader: Unused by this parser (no weights are read here).

    Returns:
        ``(layer, output_shape)``: the hls4ml attribute dict and the output shape
        as a new list.
    """
    config = keras_layer['config']
    input_shape = list(input_shapes[0])

    layer = {
        'class_name': 'SNNLayer',
        'name': config['name'],
        # The layer contracts over the last input axis (same as index 1 for 2-D input).
        'n_in': input_shape[-1],
    }
    if input_names is not None:
        layer['inputs'] = input_names

    output_shape = list(input_shape)
    n_out = config.get('num_neurons')
    if n_out is not None:
        # The SNN layer maps n_in features to num_neurons spikes, so only the last axis changes.
        layer['n_out'] = n_out
        output_shape[-1] = n_out
    # Configs saved without 'num_neurons' keep the previous identity-shape behaviour.
    if 'threshold' in config:
        layer['threshold'] = config['threshold']

    return layer, output_shape