In [1]:
import numpy as np
import networkx as nx
from SNN_neat import DAG, plot_dag
from SNN_neat import ReLU, LeakyReLU, Sigmoid, Linear, Tanh
from tqdm import tqdm

def dummy_uuid_generator():
    """Yield an endless sequence of decimal ID strings: "0", "1", "2", ...

    Deterministic stand-in for uuid-style identifiers.
    """
    counter = 0
    while True:
        yield f"{counter}"
        counter += 1

# Single shared generator, so successive uuid1() calls never repeat an ID.
dummy_gen = dummy_uuid_generator()

def uuid1():
    """Return the next ID string from the shared `dummy_gen` generator."""
    # NOTE(review): nothing in this notebook calls uuid1() or passes it to
    # SNN_neat, so the node IDs seen below presumably come from SNN_neat itself.
    return next(dummy_gen)

In [2]:
# Network dimensions used by the DAG construction tests below.
INPUTS, OUTPUTS = 2, 1

# test creation of DAG
dag = DAG(INPUTS, OUTPUTS)
plot_dag(dag)

In [3]:
# test adding connections: wire every input node directly to every output node
for src_id in dag.input_nodes:
    for dst_id in dag.output_nodes:
        src, dst = dag.get_node(src_id), dag.get_node(dst_id)
        dag.add_connection(src.id, dst.id)
plot_dag(dag)

In [4]:
# test getting connections and the connection check
connections = dag.get_connections()
print("Created Connections", connections)
first_in, first_out = dag.input_nodes[0], dag.output_nodes[0]
print("Check if connection can be found:", dag.is_connected(first_in, first_out))

Created Connections [('0', '2'), ('1', '2')]
Check if connection can be found: True


In [5]:
# test adding nodes: insert a new node on each connection listed in the previous cell
for src, dst in connections:
    added = dag.add_node(src, dst)
    print(f"Adding Node between {(src, dst)}", added)
plot_dag(dag)

Adding Node between ('0', '2') (True, '3')
Adding Node between ('1', '2') (True, '4')


In [6]:
# Connect the two nodes created in cell 5 (IDs '3' and '4', per its printed output).
# NOTE(review): hard-coded IDs assume the ID sequence of a fresh, in-order run;
# the counter is shared across DAGs (cell 17's new DAG starts at '11').
dag.add_connection('3', '4')
 
plot_dag(dag)

In [7]:
# Insert a new node on the '3' -> '4' connection added in the previous cell.
# The (success, new_id) return value is not used here.
dag.add_node('3', '4')
 
plot_dag(dag)

In [8]:
# Insert a new node on the '0' -> '3' connection, which cell 5 created when it
# split '0' -> '2' into '0' -> '3' -> '2'.
dag.add_node('0', '3')
 
plot_dag(dag)

In [9]:
 # test getting processing order
# NOTE(review): this prints raw Layer reprs containing memory addresses, so the
# output differs on every run; printing node IDs per layer would be more useful.
print("Processing Order:", dag.get_processing_order())
 
plot_dag(dag)

Processing Order: [<SNN_neat.Layer object at 0x0000028B2B8F3D10>, <SNN_neat.Layer object at 0x0000028B2B8F3250>, <SNN_neat.Layer object at 0x0000028B2B8F2F90>, <SNN_neat.Layer object at 0x0000028B2B8F2F10>, <SNN_neat.Layer object at 0x0000028B2B8F2E10>, <SNN_neat.Layer object at 0x0000028B2B8F2BD0>]


In [10]:
# Remove node '3' and compare the connection lists before and after;
# per the printed output, every connection touching '3' disappears with it.
print("Starting Cons:", dag.get_connections())

dag.remove_node('3')
    

print("Ending Cons:", dag.get_connections())
 
plot_dag(dag)

Starting Cons: [('0', '6'), ('1', '4'), ('3', '2'), ('3', '5'), ('4', '2'), ('5', '4'), ('6', '3')]
Ending Cons: [('0', '6'), ('1', '4'), ('4', '2'), ('5', '4')]


In [11]:
# Prune nodes left dangling by the removal above.
# NOTE(review): presumably removes hidden nodes that lost their path to the
# inputs/outputs (e.g. '5' and '6') — confirm against SNN_neat.
dag.kill_floaters()
plot_dag(dag)

In [12]:
# Re-add a direct '0' -> '4' connection, then insert a new node on it.
dag.add_connection('0', '4')
ok, new_id = dag.add_node('0', '4')
# add_node returns (success, new_node_id); stop here rather than wire up a bogus ID.
assert ok, "add_node('0', '4') failed"
# NOTE(review): add_node splits '0' -> '4' into '0' -> new_id -> '4' (see cell 10's
# connection list), so '4' -> new_id closes a cycle; cell 13 removes this edge again.
# Confirm whether add_connection is expected to reject cycle-forming edges.
dag.add_connection('4', new_id)
dag.add_node('4', '2')
plot_dag(dag)

In [13]:
# Remove the '4' -> new_id back-edge added in the previous cell.
# Uses new_id rather than a hard-coded '7', so the cell still targets the right
# edge if the global node-ID sequence shifts (e.g. after re-running cells).
dag.remove_connection('4', new_id)
plot_dag(dag)

# Mutation Test

In [14]:
# Apply a series of mutations, plotting the DAG after each step.
# NOTE(review): mutate() is presumably stochastic and no seed is set, so the
# resulting topology (and cell 15's output) will differ between runs — seed
# whichever RNG SNN_neat uses if reproducibility matters.
N_MUTATIONS = 10
for _ in range(N_MUTATIONS):
    dag.mutate()
    plot_dag(dag)

In [15]:
# Print the (src, dst) ID pairs the DAG reports as legal new connections.
print(dag.get_legal_connections())

[('0', '2'), ('0', '8'), ('0', '9'), ('1', '2'), ('1', '7'), ('1', '9'), ('1', '10'), ('4', '2'), ('4', '9'), ('4', '10'), ('7', '2'), ('7', '8'), ('7', '10'), ('8', '7'), ('8', '10'), ('9', '7'), ('9', '8'), ('9', '10'), ('10', '4'), ('10', '7'), ('10', '8'), ('10', '9')]


# Training test: learning a 2-input logic gate

In [16]:
# Choose the dataset: a 2-input logic gate, given as a truth table.
GATE = "XOR"  # one of "AND", "OR", "XOR"

inputs = [
    [0, 0],
    [0, 1],
    [1, 0],
    [1, 1],
]

# Each gate maps one (a, b) input row to its single binary output.
#   AND -> 0, 0, 0, 1    OR -> 0, 1, 1, 1    XOR -> 0, 1, 1, 0
GATE_FUNCTIONS = {
    "AND": lambda a, b: a & b,
    "OR": lambda a, b: a | b,
    "XOR": lambda a, b: a ^ b,
}

# One single-element target list per input row, in the same order as `inputs`.
outputs = [[GATE_FUNCTIONS[GATE](a, b)] for a, b in inputs]



In [17]:
# Instantiate the DAG with one input node per feature and one output node per target.
# NOTE(review): initial weights/biases appear to be random and no seed is set,
# so the printed parameters differ between runs.
dag = DAG(len(inputs[0]), len(outputs[0]), fully_connect=True)
plot_dag(dag)

in1, in2 = dag.input_nodes
out1 = dag.output_nodes[0]

# Insert one hidden node on each input -> output connection.
# add_node returns (success, new_node_id); fail loudly instead of using a bad ID.
ok1, new_node1 = dag.add_node(in1, out1)
ok2, new_node2 = dag.add_node(in2, out1)
assert ok1 and ok2, "failed to insert the hidden nodes"

# Cross-wire each input to the other hidden node, giving a 2-2-1 topology.
dag.add_connection(in2, new_node1)
dag.add_connection(in1, new_node2)

plot_dag(dag)
print(dag)

DAG(['11', '12']
	SpikingNode(11: bias: 0.9223086604968431, threshold: 0.7, in: [] out: ['14', '15'], weights: [])
	SpikingNode(12: bias: 0.4729934397007223, threshold: 0.7, in: [] out: ['15', '14'], weights: [])
	SpikingNode(13: bias: -0.8554616795695391, threshold: 0.7, in: ['14', '15'] out: [], weights: [0.4309357  0.64621842])
	SpikingNode(14: bias: -0.2213201670940126, threshold: 0.7, in: ['11', '12'] out: ['13'], weights: [ 0.95436346 -0.12032397])
	SpikingNode(15: bias: -0.8514137954458241, threshold: 0.7, in: ['12', '11'] out: ['13'], weights: [0.79805811 0.9183859 ])
['13'])


In [18]:
# Train the DAG on the logic-gate dataset defined above.
EPOCHS = 1000
LEARNING_RATE = 0.1
losses = dag.train(inputs, outputs, epochs=EPOCHS, lr=LEARNING_RATE, verbose=True)

Start Training DAG 0 with 1000 epochs


100%|██████████| 1000/1000 [00:03<00:00, 262.08it/s]


In [19]:
# Show the parameters after training, for comparison with the pre-training dump above.
print(dag)

DAG(['11', '12']
	SpikingNode(11: bias: 0.9223086604968431, threshold: 0.7, in: [] out: ['14', '15'], weights: [])
	SpikingNode(12: bias: 0.4729934397007223, threshold: 0.7, in: [] out: ['15', '14'], weights: [])
	SpikingNode(13: bias: 0.7045383204304619, threshold: 0.7, in: ['14', '15'] out: [], weights: [-0.0690643  -0.05378158])
	SpikingNode(14: bias: 0.7086798329059879, threshold: 0.7, in: ['11', '12'] out: ['13'], weights: [-0.04563654 -0.02032397])
	SpikingNode(15: bias: 0.11858620455417652, threshold: 0.7, in: ['12', '11'] out: ['13'], weights: [0.39805811 0.4183859 ])
['13'])


In [20]:
# Evaluate the DAG's performance: for each pattern, show the output spike train,
# its mean over all time steps, and the expected target.
for input_pattern, target in zip(inputs, outputs):
    dag.reset()
    output = dag.process(input_pattern, return_spiketrains=True)
    print(f" {input_pattern} -> {output} ({np.mean(output)}) | {target}")

 [0, 0] -> [[0 0 0 0 0]] (0.0) | [0]
 [0, 1] -> [[1 1 1 1 1]] (1.0) | [1]
 [1, 0] -> [[1 1 1 1 1]] (1.0) | [1]
 [1, 1] -> [[0 0 0 0 0]] (0.0) | [0]
