# Interactive tool for SNN-PC

The following Jupyter Notebook code demonstrates the performance of a 3-layer spiking neural network for predictive coding (SNN-PC). In this demo, a user can interactively select an MNIST test sample that the network has never seen before; the network can still make an inference about it based on the other MNIST samples on which it was trained.

In [1]:
import numpy as np
import matplotlib.pyplot as plt
import ipywidgets as widgets
from ipywidgets import interact, interactive

import inference

# Pre-trained synaptic weights, read from the pickle and converted by the
# project's `inference` helper.
w_mat = inference.load_and_convert_weights('weight_dict.pickle')

# Instantiate the SNN-PC with those weights.
snn_pc = inference.snn_pc(w_mat=w_mat)

# MNIST via the TensorFlow/Keras exposed by `inference`; both image splits are
# cast to float32 and scaled by 1/255.
(X_train, y_train), (X_test, y_test) = inference.tf.keras.datasets.mnist.load_data()
X_train, X_test = (images.astype('float32') / 255 for images in (X_train, X_test))


# Creating widgets
# MNIST labels are the digits 0-9, so the digit slider must stop at 9 (an
# IntSlider's max is inclusive; max=10 selected a class with no samples).
# Size hints go through `layout`: IntSlider has no `height`/`width` traits,
# so passing them as keywords was silently ignored.
digit_widget = widgets.IntSlider(
    min=0, max=9, step=1, description='Digit:',
    layout=widgets.Layout(height='20px', width='240px'),
)
sample_widget = widgets.IntSlider(
    min=0, max=100, step=1, description='Sample:',
    layout=widgets.Layout(height='20px', width='240px'),
)

def update_s_range(*args):
    """Limit ``sample_widget`` to the test samples of the selected digit.

    Registered as an observer of ``digit_widget``; the change notification
    passed in ``args`` is not used.
    """
    n_samples = np.count_nonzero(y_test == digit_widget.value)
    # The slider value is used as a 0-based position and IntSlider's max is
    # inclusive, so the last valid position is n_samples - 1. Never go below
    # the slider's min of 0 (that would raise a TraitError).
    sample_widget.max = max(n_samples - 1, 0)

def printer(x, y):
    """Show the chosen MNIST test image and return it for the simulation.

    Called by ``interactive`` with the current slider values.

    Args:
        x: Digit class selected on ``digit_widget``.
        y: Position of the sample among the test images of that digit,
            selected on ``sample_widget``.

    Returns:
        A float32 copy of the selected ``X_test`` image; ``interactive``
        stores it as ``sample_selector.result`` for "Run simulation".

    Raises:
        ValueError: if the test set has no image labelled ``x``.
    """
    sample_indices = np.flatnonzero(y_test == int(x))
    if sample_indices.size == 0:
        raise ValueError('No test samples for digit {}'.format(x))

    # Clamp so a slider position past the last sample of this digit can
    # never index out of range.
    position = min(int(y), sample_indices.size - 1)
    sample_chosen = sample_indices[position]

    # astype makes a copy, so the simulation cannot modify X_test in place.
    img = X_test[sample_chosen].astype(np.float32)

    plt.imshow(img, vmin=0, vmax=1, cmap="Reds")
    plt.title('Selected image')
    plt.axis('off')
    plt.show()

    return img


digit_widget.observe(update_s_range, 'value')
# The observer only fires on later changes, so sync the sample range with the
# initially selected digit once before the selector is built.
update_s_range()
sample_selector = interactive(printer, x=digit_widget, y=sample_widget)

# Network schematic shown in the centre of the simulation layout. The `with`
# block closes the file once its bytes are read (the original leaked it).
with open("snn_pc_schematics.png", "rb") as file:
    image = file.read()
snn_model_img = widgets.Image(
    value=image,
    format='png',
    width=800,
    height=400,
)

# Button that starts a simulation on the currently selected test image.
runSim_button = widgets.Button(description='Run simulation', icon='play')

def run_inference(a):
    """Feed the image returned by the sample selector into the network.

    ``a`` is the clicked button, supplied by ``Button.on_click``; it is unused.
    """
    selected_img = sample_selector.result
    snn_pc(selected_img)

runSim_button.on_click(run_inference)

# Sample selector on the left, network schematic in the centre, run button in
# the footer row.
interactive_simulation = widgets.AppLayout(
    header=None,
    left_sidebar=sample_selector,
    center=snn_model_img,
    footer=runSim_button,
    pane_widths=[3, 3, 3],
    pane_heights=[1, 1, '50px'],
)

# Time slider for browsing the recorded inference: 10-350 ms in 10 ms steps.
slider_widget = widgets.IntSlider(
    min=10,
    max=350,
    step=10,
    description='time (ms)',
)

def slide_inference(x):
    """Plot every area's input, error and prediction maps at time ``x``.

    Called by ``interactive`` with the current value of ``slider_widget``
    (10-350 ms in steps of 10 ms). Draws a 3x3 grid (one row per area, one
    column per map) and returns the matplotlib figure.
    """
    # Map the slider time to a row of ``snn_pc.live_imgs``: 10 ms -> 0,
    # 20 ms -> 1, ...  Assumes one snapshot every 10 ms; TODO confirm against
    # the recording interval in ``inference``.
    curr_t = (int(x) - 10) // 10

    # Shape each area's flat activity vector is reshaped to for display.
    image_shape = [(28, 28), (36, 36), (34, 34)]
    cols = ['Input', 'Error', 'Prediction']
    rows = ['Area {}'.format(row) for row in range(3)]
    # Colour scaling per column. The error map is signed, so it uses a
    # diverging colormap centred on zero.
    col_styles = [
        dict(vmin=600, vmax=3000, cmap="Reds"),
        dict(vmin=-3000, vmax=3000, cmap="bwr"),
        dict(vmin=600, vmax=3000, cmap="Reds"),
    ]

    fig, axs = plt.subplots(nrows=3, ncols=3, figsize=(10, 10))

    for area, shape in enumerate(image_shape):
        area_imgs = snn_pc.live_imgs['pc' + str(area + 1)]
        # Column index doubles as the map index: 0 input, 1 error, 2 prediction.
        for col, style in enumerate(col_styles):
            img = inference.tf.reshape(area_imgs[curr_t, col], shape)
            axs[area, col].imshow(img, **style)

    for ax in axs.flat:
        ax.set_xticks([])
        ax.set_yticks([])

    for ax, col in zip(axs[0], cols):
        ax.set_title(col)

    for ax, row in zip(axs[:, 0], rows):
        ax.set_ylabel(row, rotation='vertical', size='large')

    plt.show()

    return fig

# Re-run slide_inference whenever the time slider moves.
slide_inference_widget = widgets.interactive(slide_inference, x=slider_widget)

SNN-PC initialized


# Select a test image with the sliders, then click "Run simulation".

In [2]:
# Render the sample selector, network schematic and "Run simulation" button.
display(interactive_simulation)

AppLayout(children=(Button(description='Run simulation', icon='play', layout=Layout(grid_area='footer'), style…

Reset to resting state
Simulation in progress. Please wait.


  0%|          | 0/3499 [00:00<?, ?it/s]

Simulation finished!


# Use the slider to see how the inference for the chosen sample evolves over time. Each point represents the mean synaptic current, which corresponds to the mean firing rate.

In [4]:
# Render the time slider together with the per-area input/error/prediction plots.
display(slide_inference_widget)

interactive(children=(IntSlider(value=10, description='time (ms)', max=350, min=10, step=10), Output(outputs=(…