In [46]:
import tensorflow as tf
import numpy as np
from tensorflow.keras.datasets import mnist
from tensorflow.keras.utils import to_categorical

In [47]:
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train = x_train.reshape(-1, 784).astype("float32") / 255.0
x_test = x_test.reshape(-1, 784).astype("float32") / 255.0
y_train = to_categorical(y_train, num_classes=10)
y_test = to_categorical(y_test, num_classes=10)

# Define the SNN layer

In [48]:
class SpikingLayer(tf.keras.layers.Layer):
    def __init__(self, units):
        super(SpikingLayer, self).__init__()
        self.units = units

    def build(self, input_shape):
        self.kernel = self.add_weight("kernel", shape=[input_shape[-1], self.units])

    def call(self, inputs):
        spikes = tf.cast(inputs > 0, dtype=tf.float32)
        output = tf.matmul(spikes, self.kernel)
        return output

In [49]:
model = tf.keras.Sequential([
    SpikingLayer(64),
    SpikingLayer(32),
    SpikingLayer(10)
])

In [50]:
model.compile(
    optimizer="adam", 
    loss=tf.keras.losses.CategoricalCrossentropy(from_logits=True), 
    metrics=["accuracy"]
)

model.fit(
    x_train, 
    y_train, 
    batch_size=32, 
    epochs=100, 
    validation_data=(x_test, y_test)
)

Epoch 1/100
Epoch 2/100
Epoch 3/100
Epoch 4/100
Epoch 5/100
Epoch 6/100
Epoch 7/100
Epoch 8/100
Epoch 9/100
Epoch 10/100
Epoch 11/100
Epoch 12/100
Epoch 13/100
Epoch 14/100
Epoch 15/100
Epoch 16/100
Epoch 17/100
Epoch 18/100
Epoch 19/100
Epoch 20/100
Epoch 21/100
Epoch 22/100
Epoch 23/100
Epoch 24/100
Epoch 25/100
Epoch 26/100
Epoch 27/100
Epoch 28/100
Epoch 29/100
Epoch 30/100
Epoch 31/100
Epoch 32/100
Epoch 33/100
Epoch 34/100
Epoch 35/100
Epoch 36/100
Epoch 37/100
Epoch 38/100
Epoch 39/100
Epoch 40/100
Epoch 41/100
Epoch 42/100
Epoch 43/100
Epoch 44/100
Epoch 45/100
Epoch 46/100
Epoch 47/100
Epoch 48/100
Epoch 49/100
Epoch 50/100
Epoch 51/100
Epoch 52/100
Epoch 53/100
Epoch 54/100
Epoch 55/100
Epoch 56/100
Epoch 57/100
Epoch 58/100
Epoch 59/100
Epoch 60/100
Epoch 61/100
Epoch 62/100
Epoch 63/100
Epoch 64/100
Epoch 65/100
Epoch 66/100
Epoch 67/100
Epoch 68/100
Epoch 69/100
Epoch 70/100
Epoch 71/100
Epoch 72/100
Epoch 73/100
Epoch 74/100
Epoch 75/100
Epoch 76/100
Epoch 77/100
Epoch 78

<keras.callbacks.History at 0x7f96e9d61570>

In [51]:
model.summary()

Model: "sequential_10"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 spiking_layer (SpikingLayer  (None, 64)               50176     
 )                                                               
                                                                 
 spiking_layer_1 (SpikingLay  (None, 32)               2048      
 er)                                                             
                                                                 
 spiking_layer_2 (SpikingLay  (None, 10)               320       
 er)                                                             
                                                                 
Total params: 52,544
Trainable params: 52,544
Non-trainable params: 0
_________________________________________________________________


In [52]:
model.save("SNN_MNIST.h5")

In [53]:
test_loss, test_acc = model.evaluate(x_test, y_test)
print("Test Loss:", test_loss)
print("Test Accuracy:", test_acc)

Test Loss: 1.5973379611968994
Test Accuracy: 0.45559999346733093


# hls4ml Config Part

In [36]:
import hls4ml

class SNNLayer(hls4ml.model.layers.Layer):

    def initialize(self):
        inp = self.get_input_variable()
        shape = inp.shape
        dims = inp.dim_names
        self.add_output_variavle(shape, dims)

In [35]:
def parse_reverse_layer(keras_layer, input_names, input_shapes, data_reader):
    layer = {}
    layer['class_name'] = 'SNNLayer'
    layer['name'] = keras_layer['config']['name']
    layer['n_in'] = input_shapes[0][1]

    if input_names is not None:
        layer['inputs'] = input_names

    return layer, [shape for shape in input_shapes[0]]