In [1]:
import numpy as np
import networkx as nx
from SNN_neat import DAG, plot_dag
from SNN_neat import ReLU, LeakyReLU, Sigmoid, Linear, Tanh
from tqdm import tqdm
import plotly.graph_objects as go
from time import time

In [2]:
# Network width: number of input nodes and number of output nodes.
WIDTH = 10

dag = DAG(WIDTH, WIDTH)

# Link input k to output k before computing the processing order, so every output
# is reachable from an input (presumably needed for get_processing_order to place
# the outputs in a later layer -- TODO confirm against SNN_neat).
for k, source_id in enumerate(dag.input_nodes):
    dag.add_connection(source_id, dag.output_nodes[k])

# Fully connect each layer to the next one. This re-adds the k -> k links; the printed
# DAG lists each target only once, so add_connection appears to ignore duplicates.
dag.processing_order = dag.get_processing_order()
for lower_layer, upper_layer in zip(dag.processing_order[:-1], dag.processing_order[1:]):
    for source_id in lower_layer:
        for target_id in upper_layer:
            dag.add_connection(source_id, target_id)

# Recompute the order now that all connections exist.
dag.processing_order = dag.get_processing_order()
print(dag)

DAG(['0', '1', '2', '3', '4', '5', '6', '7', '8', '9']
	SpikingNode(0: bias: -0.9803834756929917, threshold: 0.7, in: [] out: ['10', '11', '12', '13', '14', '15', '16', '17', '18', '19'], weights: [])
	SpikingNode(1: bias: 0.019586353340679263, threshold: 0.7, in: [] out: ['11', '10', '12', '13', '14', '15', '16', '17', '18', '19'], weights: [])
	SpikingNode(2: bias: -0.8852345246654172, threshold: 0.7, in: [] out: ['12', '10', '11', '13', '14', '15', '16', '17', '18', '19'], weights: [])
	SpikingNode(3: bias: -0.20267887105479154, threshold: 0.7, in: [] out: ['13', '10', '11', '12', '14', '15', '16', '17', '18', '19'], weights: [])
	SpikingNode(4: bias: 0.6530144882420845, threshold: 0.7, in: [] out: ['14', '10', '11', '12', '13', '15', '16', '17', '18', '19'], weights: [])
	SpikingNode(5: bias: 0.5306451092875222, threshold: 0.7, in: [] out: ['15', '10', '11', '12', '13', '14', '16', '17', '18', '19'], weights: [])
	SpikingNode(6: bias: -0.8209014862491772, threshold: 0.7, in: [] out

In [3]:
# Threshold sweep: for each firing threshold and each input-activity level, run many
# randomly parameterised copies of the network and record the output layer's mean
# activity. Result: thresholds[threshold] = [input activities, mean output activities].
N_TRIALS = 100  # random networks sampled per (threshold, input activity) point
RNG_SEED = 0    # seed for numpy's legacy global RNG so the sweep is reproducible


def _set_threshold(dag, threshold):
    """Assign ``threshold`` to every node listed in ``dag.processing_order``."""
    for layer in dag.processing_order:
        for node_id in layer:
            dag.get_node(node_id).threshold = threshold


def _randomize_parameters(dag):
    """Give every node new weights and a new bias drawn uniformly from [-1, 1),
    rescaled together so that the vector (weights..., bias) has unit L2 norm."""
    for layer in dag.processing_order:
        for node_id in layer:
            node = dag.get_node(node_id)
            weights = np.random.rand(len(node.weights)) * 2 - 1
            bias = np.random.rand() * 2 - 1
            # One shared norm, computed once: weights and bias get the same scale factor.
            norm = np.linalg.norm(list(weights) + [bias])
            node.weights = weights / norm
            node.bias = bias / norm


def _layer_activities(dag):
    """Return, for each layer in ``dag.processing_order``, the mean of ``node.spike``
    over that layer's nodes."""
    return [
        sum(dag.get_node(node_id).spike for node_id in layer) / len(layer)
        for layer in dag.processing_order
    ]


np.random.seed(RNG_SEED)
thresholds = {}
percentages = np.arange(0, 1, 0.05)
n_layers = len(dag.processing_order)
start = time()
for threshold in tqdm(np.arange(-0.1, 0.1, 0.05), desc="Threshold"):
    thresholds[threshold] = [[], []]  # [input activities, mean output-layer activities]
    # leave=False keeps only the outer progress bar in the saved output.
    for percentage in tqdm(percentages, desc="Percentage", leave=False):
        _set_threshold(dag, threshold)

        layer_avg_activations = [[] for _ in range(n_layers)]
        # Average over many random networks at this (threshold, input activity) point.
        for _ in range(N_TRIALS):
            dag.reset()
            # Each of the WIDTH input values is 1 with probability `percentage`, else 0.
            inputs = [int(np.random.rand() < percentage) for _ in range(WIDTH)]
            _randomize_parameters(dag)
            dag.process_once(inputs)
            for layer_idx, activity in enumerate(_layer_activities(dag)):
                layer_avg_activations[layer_idx].append(activity)

        thresholds[threshold][0].append(percentage)
        # Only the output (last) layer enters the result.
        thresholds[threshold][1].append(np.mean(layer_avg_activations[-1]))

end = time()
print("Time taken: ", end - start)

Percentage: 100%|██████████| 20/20 [00:01<00:00, 14.07it/s]
Percentage: 100%|██████████| 20/20 [00:01<00:00, 14.26it/s]
Percentage: 100%|██████████| 20/20 [00:01<00:00, 14.26it/s]
Percentage: 100%|██████████| 20/20 [00:01<00:00, 14.46it/s]
Threshold: 100%|██████████| 4/4 [00:05<00:00,  1.41s/it]

Time taken:  5.629703760147095





In [4]:
# Plot mean output-layer activity against input activity, one line per tested threshold.
# The y = x diagonal marks where the output activity reproduces the input activity.
fig = go.Figure()
for threshold, (xs, ys) in thresholds.items():
    fig.add_trace(go.Scatter(x=xs, y=np.array(ys), mode="lines+markers", name=f"T: {threshold}"))

# Use the shared input grid rather than the leaked loop variable `xs`
# (which would be undefined if no thresholds were tested).
fig.add_trace(go.Scatter(x=percentages, y=percentages, mode="lines", name="y=x"))
fig.update_xaxes(title="Input activity")
fig.update_yaxes(title="Mean activity")
# The figure shows every threshold, so the title must not name only the last loop value.
fig.update_layout(title="Mean output activity vs. input activity")
fig.show()


In [5]:
# For every input activity, find the threshold whose mean output activity is closest
# to that input activity (ties go to the first threshold tested).
print(percentages.shape)
# [0]: best threshold, [1]: |input - output| at that threshold,
# [2]: mean output activity at that threshold
closest_thresholds = [[], [], []]
for input_percentage in percentages:
    closest_threshold = None
    closest_distance = float("inf")
    closest_activity = None
    for threshold, (xs, ys) in thresholds.items():
        output_activity = ys[xs.index(input_percentage)]
        distance = abs(input_percentage - output_activity)
        if distance < closest_distance:
            closest_distance = distance
            closest_threshold = threshold
            # Keep the activity of the *winning* threshold, not of the last one iterated.
            closest_activity = output_activity
    closest_thresholds[0].append(closest_threshold)
    closest_thresholds[1].append(closest_distance)
    closest_thresholds[2].append(closest_activity)

# Plot the best threshold, its distance and its output activity for each input activity.
fig = go.Figure()
fig.update_layout(title="Closest Thresholds")
fig.update_xaxes(title="Input activity")
series_names = ("Closest threshold", "Distance |input - output|", "Output activity")
for values, series_name in zip(closest_thresholds, series_names):
    fig.add_trace(go.Scatter(x=percentages, y=values, mode="lines+markers", name=series_name))
fig.show()

(20,)


In [6]:
import json

# Persist the sweep results: threshold -> [input activities, mean output activities].
# JSON object keys are strings, so json writes the float thresholds as text keys.
with open("thresholds.json", mode="w") as json_file:
    json.dump(thresholds, json_file, indent=4)

In [7]:
# For each input activity, record the first-tested threshold whose mean output activity
# is nearest to that input activity, together with that distance, then save to JSON.
closest_thresholds = {}
for activity_in in percentages:
    best_threshold, best_distance = None, float("inf")
    for candidate, (xs, ys) in thresholds.items():
        gap = abs(activity_in - ys[xs.index(activity_in)])
        # Strict "<" keeps the earliest threshold on ties.
        if gap < best_distance:
            best_threshold, best_distance = candidate, gap
    closest_thresholds[activity_in] = {"threshold": best_threshold, "distance": best_distance}

# save the closest thresholds to a json file
with open("closest_thresholds.json", "w") as f:
    json.dump(closest_thresholds, f, indent=4)