In [41]:
# imports
import jax
import jax.numpy as jnp
import numpy as np

In [42]:
class Node:
    def __init__(self, bias, in_weights, out_weights, id): #, decay_time, processing_type, signal_delay): id int
        """
        Setup for simple node
        We get inputs from other nodes and outputs to other nodes
        The nummber and connection nodes in training should dynamically change
        The processing inside is not fixed and will be either an embedding or attention
        To close mimic real neurons the activation will also have a deacay time

        This architecture is still in development only some features will be added for now
        so debugging is simpler

Current features:
- Simple processing
- Topology functions
To be added:
- Decay time
- Different processing (embedding, attention)
- Signal delay
        """
        # setup
        self.bias = bias # float
        self.in_weights = in_weights # float tensor
        self.out_weights = out_weights # float tensor
        self.in_connections = [] 
        self.out_connections = []

        # processing
        # Here development will happen
        self.activation = 0 # float

        # used in other classes
        self.id = id
    
# Run functions
    def run(self, inputs: jnp.ndarray) -> jnp.ndarray:
        # processing inputs and computing next step
        # TODO: change computation so intermediary value is saved in node
        outputs = jnp.dot(self.out_weights, self.activation)
        self.activation = jnp.dot(self.in_weights, inputs) + self.bias
        return outputs
    
# Topology functions
    # TODO: Change so that it's connection pair wise, do in Cluster
    # Maybe rewrite function for better functionality with cluster
    def add_input(self, weight: float, nodeId: int):
        #Add a new input to the node
        self.in_weights = jnp.append(self.in_weights, weight)
        self.in_connections.append(nodeId)
        return None
    
    def add_output(self, weight: float, nodeId: int):
        #Add a new input to the node
        self.out_weights = jnp.append(self.out_weights, weight)
        self.out_connections.append(nodeId)
        return None
    
    def remove_input(self, index: int):
        #Remove a input from the node
        self.in_weights = jnp.delete(self.in_weights, index)
        self.in_connections.pop(index)
        return None
    
    def remove_output(self, index: int):
        # Remove a output from the node
        self.out_weights = jnp.delete(self.out_weights, index)
        self.out_connections.pop(index)
        return None

# Get functions
    def get_id(self):
        return self.id

# Debugging functions
    def get_activation(self):
        return self.activation 
    
    def get_bias(self):
        return self.bias
    
    def get_in_weights(self):
        return self.in_weights
    
    def get_out_weights(self):
        return self.out_weights
    
    def get_input_size(self):
        return self.in_weights.size
    
    def get_output_size(self):
        return self.out_weights.size
    
    def print_node(self):
        # Gives a overview of the node
        print("Node id: ", self.id)
        print("Activation: ", self.activation)
        print("Bias: ", self.bias)
        print("Inputs: ", self.in_weights)
        print("Outputs: ", self.out_weights)
        return None

In [43]:
# Test for save of activation
Node0 = Node(0.1, jnp.array([0.5]), jnp.array([0.5]), 0)

In [44]:
input = jnp.array([10])
Node0.run(input)

Activation:  0
Activation:  5.1


Array([0.], dtype=float32)

In [97]:
class Cluster:
    def __init__(self, num_inputs: int, num_outputs, num_nodes: int = 0, nodes : list[Node] = None, nodeModel: Node = Node, init_net: bool = True): # uses basic node model, can be changed in the future
        """
        This class saves and edits the computation tree of the nodes
        Meaning it saves the connections between the nodes
        Additionally it displays neighboring nodes needed for training
        """
        """
        This creates a cluster of nodes
        The cluster will have an input and output
        these are either connected to other clusters or inference for the user

Planned:
- Different types of nodes
- Different types of connections
- Propagation delay
- Spacial embedding used for topology
        """

        # TODO : Code when no nodes are given
        # TODO : Make compatible with node ids
        # NOTE : Changes that either nodes are given or parameter for init
        """
        Overview

        Init:
            1) Model parameters
                - nodes init
                - connections
                - io
            2) Model setup
                - node creation
                - connection setup

        Functions:


        """


        # nodes
        self.nodeModel = nodeModel

        if len(nodes) == 0 and num_nodes == 0:
            raise ValueError("No nodes given, provide either a list of nodes or the number of nodes to be created")
        # model parameters
        # TODO: assignation of input and output nodes
        if nodes:
            self.nodes = nodes
            self.num_nodes = len(nodes)
        else:
            self.nodes = []
            self.num_nodes = num_nodes

        # io
        self.input = num_inputs
        self.outputs = num_outputs

        # cluster size
        if self.num_nodes <= 1:
            raise ValueError("Number of nodes must be greater than 1, or provide list of nodes")
        self.connections = jnp.zeros((self.num_nodes, self.num_nodes)) # directional adjecency matrix, row 
        # io
        self.input_nodes = []
        self.output_nodes = []
        
        if init_net:
            # out and input for testing, later done in setupNetwork
            self.input_nodes.append(int (nodes[0].id))
            self.output_nodes.append(int (nodes[-1].id))

            # setup
            for i in range(num_nodes):
                # create disconnected nodes
                self.nodes.append(self.nodeModel(0, jnp.array([]), jnp.array([]), i))
                
            # setup network
            self.setupNetwork()

# Setup functions

    def setupNetwork(self, initialization: str = "simple"):
            """ 
            This function uses diffrent types of algorithms to setup the network
            For now it creates a simple network with a topology similar to normal nns
    Planned:
    - Random connections
    - Evolutionary algorithm
    - Load own nodes
            """
            # TODO: IO setup
            # NOTE : This function should be outlaid to dedicated classes
            
            # TODO


# Topology functions

    def add_node(self, node: Node):
        """
        Takes a node as input and adds it to the adjacency matrix
        There it won't be connected to any other node 

        Args:
            node (Node): node of nodeModel used to create the cluster
        """
        # add node to list of nodes
        self.nodes.append(node)
        # adds node to adjacency matrix
        self.connections = jnp.append(self.connections, jnp.zeros((len(self.connections), 1)), axis=1)
        self.connections = jnp.append(self.connections, jnp.zeros((1, len(self.connections[0]))), axis=0)
        return None
    
    def add_connection(self, node1: Node, node2: Node):
        """
        Takes two nodes and adds a connection between them, node1 -> node2
        The change is made to both the adjacency matrix and the weights of the nodes

        Args:
            node1 (Node): node from which the connection originates 
            node2 (Node): node to which the connection goes
                for both also the classes nodeModel is used
        """
        # get ids, these are the adresses of the nodes
        node1_id = int(node1.get_id())
        node2_id = int(node2.get_id())
        # add connection to adjacency matrix
        self.connections = self.connections.at[node1_id, node2_id].set(1)
        # add connections to nodes
        # TODO: variable initialisation
        node1.add_output(0.5)
        node2.add_input(0.5)
        return None
    
    def remove_node(self, index: int):
        """
        Takes the index of a node and removes it from the node-list and adjacency matrix

        Args:
            index (int): index of node to be removed # NOTE: Could be later changed to node id
        """
        # NOTE: Implement that the node ids are shifted
        # remove node from list of nodes
        self.nodes.delete(index)
        # remove node from adjacency matrix
        self.connections = jnp.delete(self.connections, index, axis=0)
        self.connections = jnp.delete(self.connections, index, axis=1)
        return None
    
    #TODO: Add add-input and add-output functions

# Helper functions

    def get_neighbors(self, node: Node, degree: int = 1) -> jnp.ndarray:
        """
        This function returns the neighbors of a ceratin degree of a node
        Example: degree 1: direct neighbors, degree 2: neighbors of neighbors, ...
    
        This function uses the recursive collect_neighbors function to get the neighbors

        Args:
            node (Node): node for which the neighbors are searched
            degree (int): degree of neighbors that should be searched
        """
        neighbors = self.collect_neighbors(node, node, degree)
        neighbors = jnp.array(neighbors)
        neighbors = jnp.unique(neighbors)
        return neighbors
        
    def collect_neighbors(self, node: Node, prevNode: Node, degree: int = 1) -> jnp.ndarray:
        """
        This function returns the neighbors of a certain degree of a node
        Example: degree 1: direct neighbors, degree 2: neighbors of neighbors, ...
    
        Args:
            node (Node): node for which the neighbors are searched
            degree (int): degree of neighbors that should be searched
        """
        print("Node: ", node.get_id())
        # init empty list
        neighbors = []
        if degree <= 0:
            raise ValueError("Degree must be at least 1")
        # recursive search
        if degree <= 1:
            # end of search
            neighbors = list(jnp.where(self.connections[node.get_id()] == 1)[0])
            # remove previous node if directed connection exists
            if self.connections[node.get_id()][prevNode.get_id()] == 1:
               neighbors.remove(prevNode.get_id())
        else:
            # get next neighbors
            cur_neighbors = list(jnp.where(self.connections[node.get_id()] == 1)[0])
            # remove previous node if directed connection exists
            if self.connections[node.get_id()][prevNode.get_id()] == 1:
                neighbors.remove(prevNode.get_id())
            # go over neighbors
            for cur_neighbor in cur_neighbors:
                #TODO: optimize efficiency
                # get new values
                neighbors.extend(self.collect_neighbors(self.nodes[cur_neighbor], node, degree - 1))
        
        return neighbors
       
# Run functions
    def run(self, inputs: jnp.ndarray) -> jnp.ndarray:
        """
        This function runs the cluster for one iteration with the given input
        It might take multiple iterations to even reach the output and maybe longer for a stable output

        Args:
            inputs (jnp.ndarray): input for the cluster
        """
        # prepare input
        self.set_inputNodes(inputs, self.input_nodes)
        # create input for each node from the adjacency matrix
        node_inputs = jnp.zeros([self.num_nodes, self.num_nodes])
        for i in range(0, self.num_nodes): 
            for j in range(0, self.num_nodes):
                #cover all nodes in the adjacency matrix
                if self.connections[i][j] == 1:
                    node_inputs = self.nodes[j]

    
    def set_inputNodes(self, input, input_nodes: list[int]):
        """
        This function sets the actiavtion of the input nodes to its given input
        The input size must match the number of input nodes
        Args:
            input (jnp.ndarray): input for the cluster
            input_nodes (list[int]): list of input nodes
        """
        if input.size != len(input_nodes):
            raise ValueError("Input size does not match")
        for i in range(len(input_nodes)):
            self.nodes[input_nodes[i]].activation = input[i]
        return None

# Debugging functions
    def print_connections(self):
        """
        Prints out the adjacency matrix of the cluster
        Useful for debugging
        """
        print(self.connections)
        return None
    
    """
    def run(self, inputs: jnp.ndarray, input_idx: jnp.ndarray, output_idx: jnp.ndarray):
        # TODO: Use edge list to run nodes
        # TODO: Design method to input data into nodes, not overwriting them
        # run nodes in order
        for i in range(1, len(self.nodes)): # skip first node since it's the input node
            print(inputs)
            print("Node: ", self.nodes[i].get_id())
            print(self.nodes[i].get_in_weights())
            print(self.nodes[i].get_out_weights())
            print(self.nodes[i].get_activation())
            inputs = self.nodes[i].run(inputs)
        return inputs
    
    # functions from cluster
    
    # additional functions for later

    
        
        
    def run(self, inputs):
        # runs the whole network
        if len(inputs) != len(self.input_nodes):
            raise ValueError("Input size does not match")
        # run network
        output = self.computationTree.run(inputs, self.input_nodes, self.output_nodes)
        return output

    def expand(self):
        
        Expands the Cluster depending on information density
        and the surrounding nodes
        
        pass

    def connections(self):
        pass
    
    def get_nodes(self):
        pass


    # debugging functions
    def print_connections(self):
        print(self.connections)
        return None
    
    def print_nodes(self):
        for node in self.nodes:
            node.print_node()
        return None
    
    def print_node_inputs(self, idx: int):
        node = self.nodes[idx]
        print("Node inputs: ", node.get_in_weights())
        return None

"""

In [98]:
# Create a few more nodes for testing
Node0 = Node(0, jnp.array([0.1, 0.2, 0.3]), jnp.array([0.1, 0.2, 0.3]), 0)
Node1 = Node(0, jnp.array([0.1, 0.2, 0.3]), jnp.array([0.1, 0.2, 0.3]), 1)
Node2 = Node(0, jnp.array([0.1, 0.2, 0.3]), jnp.array([0.1, 0.2, 0.3]), 2)
Node3 = Node(0, jnp.array([0.1, 0.2, 0.3]), jnp.array([0.1, 0.2, 0.3]), 3)
Node4 = Node(0, jnp.array([0.1, 0.2, 0.3]), jnp.array([0.1, 0.2, 0.3]), 4)
Node5 = Node(0, jnp.array([0.1, 0.2, 0.3]), jnp.array([0.1, 0.2, 0.3]), 5)
nodes = [Node0, Node1, Node2, Node3, Node4, Node5]

In [99]:
Cluster1 = Cluster(1, 1, nodes=nodes, init_net=False)


In [100]:
Cluster1.add_connection(Node0, Node1)
Cluster1.add_connection(Node1, Node2)
Cluster1.add_connection(Node0, Node3)
Cluster1.add_connection(Node1, Node0)

In [101]:
Cluster1.print_connections()

[[0. 1. 0. 1. 0. 0.]
 [1. 0. 1. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0.]]


In [102]:
print(Cluster1.get_neighbors(Node0, 2))

Node:  0
Cur neighbor:  1
Node:  1
Cur neighbor:  3
Node:  3
[2]
