In [1]:
import numpy as np
import networkx as nx
from SNN_neat import DAG, plot_dag
from SNN_neat import ReLU, LeakyReLU, Sigmoid, Linear, Tanh
from tqdm import tqdm
import plotly.graph_objects as go
from time import time

In [2]:
WIDTH = 100
NET_DEPH = 5

dag = DAG(WIDTH, WIDTH)

for i, node_id in enumerate(dag.input_nodes):
    out_node = dag.output_nodes[i]
    dag.add_connection(node_id, out_node)
    
    # add node in input connection
    new_node = node_id
    for i in range(NET_DEPH):
        _, new_node = dag.add_node(new_node, out_node)

dag.processing_order = dag.get_processing_order()
for i, layer in enumerate(dag.processing_order[:-1]):
    for node1 in layer:
        for node2 in dag.processing_order[i+1]:
            dag.add_connection(node1, node2)

dag.processing_order = dag.get_processing_order()
print(dag)

DAG(['0', '1', '2', '3', '4', '5', '6', '7', '8', '9']
	SpikingNode(0: bias: 0.40295251218977723, threshold: 0.7, in: [] out: ['20', '25', '30', '35', '40', '45', '50', '55', '60', '65'], weights: [])
	SpikingNode(1: bias: -0.33777121061049, threshold: 0.7, in: [] out: ['25', '20', '30', '35', '40', '45', '50', '55', '60', '65'], weights: [])
	SpikingNode(2: bias: -0.6923728615493681, threshold: 0.7, in: [] out: ['30', '20', '25', '35', '40', '45', '50', '55', '60', '65'], weights: [])
	SpikingNode(3: bias: -0.8241405205512724, threshold: 0.7, in: [] out: ['35', '20', '25', '30', '40', '45', '50', '55', '60', '65'], weights: [])
	SpikingNode(4: bias: -0.46283954521328385, threshold: 0.7, in: [] out: ['40', '20', '25', '30', '35', '45', '50', '55', '60', '65'], weights: [])
	SpikingNode(5: bias: -0.48926556683580724, threshold: 0.7, in: [] out: ['45', '20', '25', '30', '35', '40', '50', '55', '60', '65'], weights: [])
	SpikingNode(6: bias: -0.39774723815808843, threshold: 0.7, in: [] ou

In [3]:
THRESHOLD = 0.5
ITERATIONS = 1000

thresholds = []
percentages = np.arange(0, 1.05, 0.05)
start = time()
for percentage in tqdm(percentages, desc="Percentage"):
    # set threshold
    for layer in dag.processing_order:
        for nodeId in layer:
            node = dag.get_node(nodeId)
            node.threshold = THRESHOLD

    layer_avg_activations = []
    # iterate multipple times
    real_percentages = []
    for _ in range(ITERATIONS):
        dag.reset()
        inputs = [int(np.random.rand() < percentage) for _ in range(WIDTH)]

        real_percentages.append(np.sum(inputs) / WIDTH)

        for layer in dag.processing_order:
            for nodeId in layer:
                node = dag.get_node(nodeId)
                new_weights = [1/((len(node.weights))*0.8) for _ in range(len(node.weights))]#(np.random.rand(len(node.weights)) * 2 - 1)
                node.bias = 0#(np.random.rand() * 2 - 1)
                node.weights = new_weights #/ np.linalg.norm(list(new_weights)+[node.bias])
                #node.bias = node.bias / np.linalg.norm(list(new_weights)+[node.bias])

        # process the network #####
        dag.process_once(inputs)
        ###########################

        # get activation 
        avg_activation_probability = 0
        for node_id in dag.output_nodes:
            node = dag.get_node(node_id)
            avg_activation_probability += node.spike

        avg_activation_probability /= len(dag.output_nodes)

        layer_avg_activations.append(avg_activation_probability)

    thresholds.append([
        percentage, 
        np.min(real_percentages), 
        np.average(real_percentages), 
        np.max(real_percentages), 
        np.min(layer_avg_activations), 
        np.average(layer_avg_activations), 
        np.max(layer_avg_activations)
    ])
        
thresholds = np.array(thresholds)
end = time()
print("Time taken: ", end-start)

Percentage: 100%|██████████| 21/21 [00:13<00:00,  1.61it/s]

Time taken:  13.055293560028076





In [4]:
# plot all the tested thresholds, by different lines in the xs vs. ys plot with plotly
# x values are input activity, y values are mean activity of the second layer
fig = go.Figure()
fig.add_trace(go.Scatter(x=thresholds[:,0], y=thresholds[:,0], mode='lines', name="x=y"))
fig.add_trace(go.Scatter(x=thresholds[:,0], y=thresholds[:,1], mode='lines+markers', name="Min %"))
fig.add_trace(go.Scatter(x=thresholds[:,0], y=thresholds[:,2], mode='lines+markers', name="Mean %"))
fig.add_trace(go.Scatter(x=thresholds[:,0], y=thresholds[:,3], mode='lines+markers', name="Max %"))
fig.add_trace(go.Scatter(x=thresholds[:,0], y=thresholds[:,4], mode='lines+markers', name="Min"))
fig.add_trace(go.Scatter(x=thresholds[:,0], y=thresholds[:,5], mode='lines+markers', name="Mean"))
fig.add_trace(go.Scatter(x=thresholds[:,0], y=thresholds[:,6], mode='lines+markers', name="Max"))

fig.update_xaxes(title="Input activity")
fig.update_yaxes(title="Mean activity")
fig.show()
