In [3]:
# imports
import jax
import jax.numpy as jnp
import numpy as np

In [4]:
class Node:
    def __init__(self, bias, in_weights, out_weights, id): #, decay_time, processing_type, signal_delay): id int
        """
        Setup for simple node
        We get inputs from other nodes and outputs to other nodes
        The nummber and connection nodes in training should dynamically change
        The processing inside is not fixed and will be either an embedding or attention
        To close mimic real neurons the activation will also have a deacay time

        This architecture is still in development only some features will be added for now
        so debugging is simpler

Current features:
- Simple processing
- Topology functions
To be added:
- Decay time
- Different processing (embedding, attention)
- Signal delay
        """
        # setup
        self.bias = bias # float
        self.in_weights = in_weights # float tensor
        self.out_weights = out_weights # float tensor

        # processing
        # Here development will happen
        self.activation = 0 # float

        # used in other classes
        self.id = id
    
# Run functions
    def run(self, inputs: jnp.ndarray) -> jnp.ndarray:
        # processing inputs and computing next step
        # TODO: change computation so intermediary value is saved in node
        outputs = jnp.dot(self.out_weights, self.activation)
        self.activation = jnp.dot(self.in_weights, inputs) + self.bias
        return outputs
    
# Topology functions
    # TODO: Change so that it's connection pair wise, do in Cluster
    # Maybe rewrite function for better functionality with cluster
    def add_input(self, weight: float):
        #Add a new input to the node
        self.in_weights = jnp.append(self.in_weights, weight)
        return None
    
    def add_output(self, weight: float):
        #Add a new input to the node
        self.out_weights = jnp.append(self.out_weights, weight)
        return None
    
    def remove_input(self, index: int):
        #Remove a input from the node
        self.in_weights = jnp.delete(self.in_weights, index)
        return None
    
    def remove_output(self, index: int):
        # Remove a output from the node
        self.out_weights = jnp.delete(self.out_weights, index)
        return None

# Get functions
    def get_id(self):
        return self.id

# Debugging functions
    def get_activation(self):
        return self.activation 
    
    def get_bias(self):
        return self.bias
    
    def get_in_weights(self):
        return self.in_weights
    
    def get_out_weights(self):
        return self.out_weights
    
    def get_input_size(self):
        return self.in_weights.size
    
    def get_output_size(self):
        return self.out_weights.size
    
    def print_node(self):
        # Gives a overview of the node
        print("Node id: ", self.id)
        print("Activation: ", self.activation)
        print("Bias: ", self.bias)
        print("Inputs: ", self.in_weights)
        print("Outputs: ", self.out_weights)
        return None

In [39]:
class ComputationTree:
    def __init__(self, nodes: list[Node] = None):
        """
        This class saves and edits the computation tree of the nodes
        Meaning it saves the connections between the nodes
        Additionally it displays neighboring nodes needed for training
        """
        # TODO : Code when no nodes are given
        # TODO : Make compatible with node ids
        self.nodes = nodes # jnp.array of nodes
        self.connections = jnp.zeros((len(nodes), len(nodes))) # directional adjecency matrix, row 

    def add_node(self, node: Node):
        #TODO: Check
        # add node to list of nodes
        self.nodes.append(node)
        # not connected to any node
        self.connections = jnp.append(self.connections, jnp.zeros((len(self.connections), 1)), axis=1)
        self.connections = jnp.append(self.connections, jnp.zeros((1, len(self.connections[0]))), axis=0)
        return None
    
    def remove_node(self, index: int):
        # remove node from list of nodes
        self.nodes.delete(index)
        self.connections = jnp.delete(self.connections, index, axis=0)
        self.connections = jnp.delete(self.connections, index, axis=1)
        return None

    def add_connection(self, node1: Node, node2: Node):
        # TODO : Add connection between connection and node input place
        # get ids like addresses in connection
        node1_id = int(node1.get_id())
        node2_id = int(node2.get_id())

        # add connection
        self.connections = self.connections.at[node1_id, node2_id].set(1)

        # add in and out to node
        # TODO: variable initialisation
        node1.add_output(0)
        node2.add_input(0)
        return None
    
    def get_neighbors(self, node: Node, degree: int) -> jnp.array:
        # TODO: Add function that removes double entries and check inputs
        # This function gets neighbors of a certain degree using collect_neighbors
        # This function is primarly concerned with converting ids and getting rid of double entries
        neighbors = self.collect_neighbors(node, degree)
        neighbors = jnp.array(neighbors) # list to array
        neighbors = jnp.unique(neighbors) # remove double entries
        return neighbors

    def collect_neighbors(self, node, degree):
        # collects the neighbor of a certain degree of a node
        # degree 1: direct neighbors, degree 2: neighbors of neighbors, ...

        # init empty list to collect neighbors
        neighbors = []

        # recursive search
        if degree <= 1:
            neighbors = list(jnp.where(self.connections[node.get_id()] == 1)[0])
        else:
            cur_neighbors = list(jnp.where(self.connections[node.get_id()] == 1)[0])
            # go over neighbors
            for cur_neighbor in cur_neighbors:
                neighbors.extend(self.collect_neighbors(self.nodes[cur_neighbor], degree - 1))
        return neighbors
    
    def run(self, inputs: jnp.ndarray, input_idx: jnp.ndarray, output_idx: jnp.ndarray):
        # TODO: Use edge list to run nodes
        # TODO: Design method to input data into nodes, not overwriting them
        # run nodes in order
        for i in range(len(self.nodes)):
            inputs = self.nodes[i].run(inputs)

        return inputs

    # debugging functions
    def print_connections(self):
        print(self.connections)
        return None

In [40]:
# Create a few more nodes for testing
Node0 = Node(0, jnp.array([0.1, 0.2, 0.3]), jnp.array([0.1, 0.2, 0.3]), 0)
Node1 = Node(0, jnp.array([0.1, 0.2, 0.3]), jnp.array([0.1, 0.2, 0.3]), 1)
Node2 = Node(0, jnp.array([0.1, 0.2, 0.3]), jnp.array([0.1, 0.2, 0.3]), 2)
Node3 = Node(0, jnp.array([0.1, 0.2, 0.3]), jnp.array([0.1, 0.2, 0.3]), 3)
Node4 = Node(0, jnp.array([0.1, 0.2, 0.3]), jnp.array([0.1, 0.2, 0.3]), 4)
Node5 = Node(0, jnp.array([0.1, 0.2, 0.3]), jnp.array([0.1, 0.2, 0.3]), 5)

In [41]:
# Test prep for ComputationTree
nodes = [Node0, Node1, Node2, Node3, Node4] # used nodes
CT = ComputationTree(nodes)
CT.add_connection(Node0, Node1)
CT.add_connection(Node1, Node2)
CT.add_connection(Node0, Node3)
CT.add_connection(Node3, Node4)
CT.add_node(Node5)
CT.add_connection(Node0, Node5)
CT.add_connection(Node5, Node4)

In [42]:
# Test ComputationTree
CT.print_connections()
print(CT.get_neighbors(Node0, 2))

[[0. 1. 0. 1. 0. 1.]
 [0. 0. 1. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 1. 0.]
 [0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 1. 0.]]
[2 4]


In [43]:
# resume development after compuleting computation Tree
class Cluster:
    def __init__(self, input, outputs, num_nodes=1, nodeModel: Node = Node): # uses basic node model, can be changed in the future
        """
        This creates a cluster of nodes
        The cluster will have an input and output
        these are either connected to other clusters or inference for the user

Planned:
- Different types of nodes
- Different types of connections
- Propagation delay
- Spacial embedding used for topology
        """
        # basic params
        self.input = input
        self.outputs = outputs
        self.num_nodes = num_nodes
        self.nodes = []
        self.nodeModel = nodeModel

        # the nodes use to get information into and out of the cluster
        self.input_nodes = jnp.array([])
        self.output_nodes = jnp.array([])
        
        # setup
        for i in range(num_nodes):
            # create disconnected nodes
            self.nodes.append(self.nodeModel(0, jnp.array([]), jnp.array([]), i))
        # create computation tree
        self.computationTree = ComputationTree(self.nodes)
        # out and input for testing, later done in setupNetwork
        self.input_nodes = jnp.append(self.input_nodes, int (nodes[0].id))
        self.output_nodes = jnp.append(self.output_nodes, int (nodes[-1].id))
        # setup network
        self.setupNetwork()

# additional functions for later

    def setupNetwork(self, initialization: str = "simple"):
        """ 
        This function uses diffrent types of algorithms to setup the network
        For now it creates a simple network with a topology similar to normal nns
Planned:
- Random connections
- Evolutionary algorithm
- Load own nodes
        """
        # initialization for in and output
        
        # network initialization
        if initialization == "simple":
            # simple network
            # just a linear connection
            for i in range(len(self.nodes) - 1):
                self.computationTree.add_connection(self.nodes[i], self.nodes[i+1])
        
        
    def run(self, inputs):
        # runs the whole network
        if len(inputs) != len(self.input_nodes):
            raise ValueError("Input size does not match")
        # run network
        output = self.computationTree.run(inputs, self.input_nodes, self.output_nodes)
        return output

    def expand(self):
        """
        Expands the Cluster depending on information density
        and the surrounding nodes
        """
        pass

    def connections(self):
        pass
    
    def get_nodes(self):
        pass


In [44]:
# Test of Cluster
Cluster0 = Cluster(1, 1, 5) # 1 input, 1 output, 5 nodes
value = jnp.array([1])
print(Cluster0.run(value))

0.0


TypeError: dot_general requires contracting dimensions to have the same shape, got (0,) and (1,).