In [1]:
# imports
import jax
import jax.numpy as jnp
import numpy as np

In [2]:
class Node:
    """
    A class representing a node in a neural network. It's used as a building block for the Cluster class.

    This class simulates a simple node with basic processing capabilities.
    The node receives inputs from other nodes and sends outputs to other nodes; the number of
    connections and their weights can change dynamically during training.

    This is part of a research project and is still in development. Only some features are
    implemented for now so that debugging stays simple.

    Current features:
        - Simple processing (weighted sum of the inputs plus bias; no activation function yet)
        - Topology functions (adding/removing input and output connections)
    To be added:
        - Decay time (to mimic real neurons more closely)
        - Different processing (embedding, attention)
        - Signal delay

    Attributes:
        bias (float): The bias of the node.
        in_weights (jnp.array): The weights of the inputs.
        out_weights (jnp.array): The weights of the outputs.
        in_connections (list): List of node IDs that feed into this node.
        out_connections (list): List of node IDs that this node feeds into.
        activation (float): The current activation value of the node.
        outputs (jnp.array): The current outputs of the node.
        inputs (jnp.array): The current inputs to the node.
        id (int): The unique identifier of the node.

    Note:
        Weights passed to the constructor are not associated with any connection IDs
        (``in_connections``/``out_connections`` start empty). Connections added through
        ``add_input``/``add_output`` keep weights, IDs and cached inputs/outputs
        positionally aligned.
    """

    def __init__(self, bias, in_weights, out_weights, id):
        """
        Initializes the Node with the given parameters.

        Args:
            bias (float): The bias of the node.
            in_weights (jnp.array): The weights of the inputs.
            out_weights (jnp.array): The weights of the outputs.
            id (int): The unique identifier of the node.
        """
        self.bias = bias
        self.in_weights = in_weights
        self.out_weights = out_weights
        self.in_connections = []
        self.out_connections = []
        self.activation = 0
        self.outputs = jnp.array([])
        self.inputs = jnp.array([])
        self.id = id

    # Run functions
    def run(self):
        """
        Runs one step of the node's processing.

        The outputs are computed from the activation of the *previous* step
        (``out_weights * activation``) before the activation is updated to
        ``dot(in_weights, inputs) + bias``. A signal therefore needs one step
        to pass through the node.
        """
        # TODO: Save last activation so it's not lost
        self.outputs = jnp.dot(self.out_weights, self.activation)
        self.activation = jnp.dot(self.in_weights, self.inputs) + self.bias  # TODO: Look if activation function is needed

    def set_inputs(self, inputs):
        """
        Sets the inputs for the node.

        Args:
            inputs (jnp.array): The inputs for the node, must match the size of in_weights.

        Raises:
            ValueError: If the size of inputs does not match the size of in_weights.
        """
        if inputs.size != self.in_weights.size:
            raise ValueError("Input size does not match in_weights size")
        self.inputs = inputs

    def get_inputNodes(self):
        """
        Returns a list of the IDs of the nodes which feed into this node.

        Returns:
            list: List of node IDs.
        """
        return self.in_connections

    def get_outputs(self):
        """
        Returns the outputs of the node.

        If the node has not produced any outputs yet, returns zeros (one per output weight).

        Returns:
            jnp.array: The outputs of the node.
        """
        if self.outputs.size == 0:
            return jnp.zeros(self.out_weights.size)
        return self.outputs

    def get_outputNodes(self):
        """
        Returns a list of the IDs of the nodes which this node feeds into.

        Returns:
            list: List of node IDs.
        """
        return self.out_connections

    # Topology functions
    # TODO: Create a function that automatically sorts connections
    # TODO: Change so that it's connection pair wise, do in Cluster
    # Maybe rewrite function for better functionality with cluster

    def add_input(self, weight, nodeId):
        """
        Adds a new input to the node.

        If the stored inputs are aligned with the input weights, a 0 is appended for the
        new connection so that ``run()`` can be called directly after the change.

        Args:
            weight (float): The weight of the input.
            nodeId (int): The ID of the node to connect to.
        """
        if self.inputs.size == self.in_weights.size:
            self.inputs = jnp.append(self.inputs, 0.0)
        self.in_weights = jnp.append(self.in_weights, weight)
        self.in_connections.append(nodeId)

    def add_output(self, weight, nodeId):
        """
        Adds a new output to the node.

        If the node already has computed outputs, a 0 is appended for the new connection.
        This keeps ``outputs`` positionally aligned with ``out_connections``, which callers
        rely on when they look up the output belonging to a specific target node.

        Args:
            weight (float): The weight of the output.
            nodeId (int): The ID of the node to connect to.
        """
        if self.outputs.size and self.outputs.size == self.out_weights.size:
            self.outputs = jnp.append(self.outputs, 0.0)
        self.out_weights = jnp.append(self.out_weights, weight)
        self.out_connections.append(nodeId)

    def remove_input(self, index):
        """
        Removes an input from the node.

        Args:
            index (int): The index of the input to remove.
        """
        # check alignment before the weights change size
        aligned = self.inputs.size == self.in_weights.size
        self.in_weights = jnp.delete(self.in_weights, index)
        self.in_connections.pop(index)
        if aligned and self.inputs.size:
            self.inputs = jnp.delete(self.inputs, index)

    def remove_output(self, index):
        """
        Removes an output from the node.

        Args:
            index (int): The index of the output to remove.
        """
        # check alignment before the weights change size
        aligned = self.outputs.size == self.out_weights.size
        self.out_weights = jnp.delete(self.out_weights, index)
        self.out_connections.pop(index)
        if aligned and self.outputs.size:
            self.outputs = jnp.delete(self.outputs, index)

    # Helper functions
    def get_id(self):
        """
        Returns the ID of the node.

        Returns:
            int: The ID of the node.
        """
        return self.id

    def get_activation(self):
        """
        Returns the activation of the node.

        Returns:
            float: The activation value.
        """
        return self.activation

    def get_bias(self):
        """
        Returns the bias of the node.

        Returns:
            float: The bias value.
        """
        return self.bias

    def get_in_weights(self):
        """
        Returns the input weights of the node.

        Returns:
            jnp.array: The input weights.
        """
        return self.in_weights

    def get_out_weights(self):
        """
        Returns the output weights of the node.

        Returns:
            jnp.array: The output weights.
        """
        return self.out_weights

    def get_input_size(self):
        """
        Returns the size of the input weights.

        Returns:
            int: The size of the input weights.
        """
        return self.in_weights.size

    def get_output_size(self):
        """
        Returns the size of the output weights.

        Returns:
            int: The size of the output weights.
        """
        return self.out_weights.size

    # Debugging functions
    def print_node(self):
        """
        Prints the node information (id, activation, bias, input and output weights).
        """
        print("Node id: ", self.id)
        print("Activation: ", self.activation)
        print("Bias: ", self.bias)
        print("Inputs: ", self.in_weights)
        print("Outputs: ", self.out_weights)


In [3]:
# Smoke test of the Node class: one input and one output, both with weight 0.5.
# Uses its own name so it is not confused with the Node0 defined further below.
test_node = Node(0.1, jnp.array([0.5]), jnp.array([0.5]), 0)
test_node.set_inputs(jnp.array([1.0]))
test_node.run()
# outputs use the activation from *before* this step (initially 0);
# the new activation is 0.5 * 1.0 + 0.1
assert jnp.allclose(test_node.get_outputs(), jnp.array([0.0]))
assert jnp.isclose(test_node.get_activation(), 0.6)

In [63]:
class Cluster:
    """
    A cluster of nodes and the computation graph connecting them.

    The cluster stores the connections between its nodes as a directed adjacency matrix
    (``connections[i, j] == 1`` means node ``i`` feeds into node ``j``). It keeps the
    per-node weight lists in sync with that matrix, and provides the neighbor queries
    needed for training. The cluster has an input and an output. These are either
    connected to other clusters or used for inference by the user.

    Node IDs are used directly as row/column indices of the adjacency matrix and as
    indices into ``self.nodes``. The code therefore assumes that node ``k`` has ID ``k``.

    Overview:
        Init:
            1) Model parameters
                - nodes init
                - connections
                - io
            2) Model setup
                - node creation
                - connection setup

    Planned:
        - Different types of nodes
        - Different types of connections
        - Propagation delay
        - Spatial embedding used for topology
    """

    def __init__(self, num_inputs: int, num_outputs, num_nodes: int = 0, nodes: list[Node] = None, nodeModel: Node = Node, init_net: bool = True):  # uses basic node model, can be changed in the future
        """
        Creates a cluster either from a list of existing nodes or by creating ``num_nodes`` new ones.

        Args:
            num_inputs (int): Number of external (user) inputs of the cluster.
            num_outputs (int): Number of outputs of the cluster.
            num_nodes (int): Number of nodes to create when ``nodes`` is not given.
            nodes (list[Node] | None): Existing nodes to build the cluster from. Takes precedence over ``num_nodes``.
            nodeModel: Node class used to create new nodes.
            init_net (bool): If True, ``setupNetwork`` is called after the cluster is created.

        Raises:
            ValueError: If neither nodes nor num_nodes is given, or the cluster would have
                fewer than 2 nodes.
        """
        # TODO : Make compatible with node ids
        # NOTE : Either nodes are given or the parameters for init
        self.nodeModel = nodeModel

        if not nodes and num_nodes == 0:
            raise ValueError("No nodes given, provide either a list of nodes or the number of nodes to be created")

        # model parameters
        if nodes:
            self.nodes = nodes
            self.num_nodes = len(nodes)
        else:
            self.nodes = []
            self.num_nodes = num_nodes

        # io
        self.input = num_inputs
        self.outputs = num_outputs

        # cluster size
        if self.num_nodes <= 1:
            raise ValueError("Number of nodes must be greater than 1, or provide list of nodes")
        # directional adjacency matrix: connections[i, j] == 1 means node i -> node j
        self.connections = jnp.zeros((self.num_nodes, self.num_nodes))

        if not nodes:
            # create disconnected nodes whose ids match their positions
            for i in range(self.num_nodes):
                self.nodes.append(self.nodeModel(0, jnp.array([]), jnp.array([]), i))

        # io nodes
        # TODO: assignation of input and output nodes (later done in setupNetwork)
        # NOTE Magic number for now: the first node is the input, the last node is the output
        self.input_nodes = [int(self.nodes[0].id)]
        self.output_nodes = [int(self.nodes[-1].id)]

        if init_net:
            self.setupNetwork()

    # Setup functions

    def setupNetwork(self, initialization: str = "simple"):
        """
        Sets up the connections of the network (not implemented yet).

        Different algorithms are planned. The first one creates a simple network with a
        topology similar to normal neural networks.

        Planned:
            - Random connections
            - Evolutionary algorithm
            - Load own nodes
        """
        # TODO: IO setup
        # NOTE : This function should be moved out to dedicated classes
        # TODO
        return None

    # Topology functions

    def add_node(self, node: Node):
        """
        Adds a node to the cluster without connecting it to any other node.

        The adjacency matrix grows by one zero row and one zero column. Because IDs are used
        as matrix indices, the node's ID is expected to equal its new position
        (``len(self.nodes) - 1``).

        Args:
            node (Node): node of nodeModel used to create the cluster
        """
        self.nodes.append(node)
        self.connections = jnp.pad(self.connections, ((0, 1), (0, 1)))
        self.num_nodes = len(self.nodes)
        return None

    def add_connection(self, node1: Node, node2: Node):
        """
        Adds a connection node1 -> node2.

        The change is made to both the adjacency matrix and the weights of the two nodes.
        Adding an already existing connection is a no-op, so the weight lists never
        contain duplicate edges.

        Args:
            node1 (Node): node from which the connection originates
            node2 (Node): node to which the connection goes

        Raises:
            IndexError: If a node ID is not a valid index of the adjacency matrix.
        """
        # the ids are the addresses of the nodes in the adjacency matrix
        node1_id = int(node1.get_id())
        node2_id = int(node2.get_id())
        size = self.connections.shape[0]
        # JAX clamps out-of-range reads and drops out-of-range writes silently, so check explicitly
        if not (0 <= node1_id < size and 0 <= node2_id < size):
            raise IndexError(f"Node ids ({node1_id}, {node2_id}) out of range for cluster of size {size}")
        if self.connections[node1_id, node2_id] == 1:
            return None
        self.connections = self.connections.at[node1_id, node2_id].set(1)
        # TODO: variable initialisation
        node1.add_output(0.5, node2_id)
        node2.add_input(0.5, node1_id)
        return None

    def remove_node(self, index: int):
        """
        Removes the node at ``index`` from the node list and the adjacency matrix.

        NOTE: node ids are not shifted yet, and the removed node's id is not yet removed
        from the connection lists of the other nodes.

        Args:
            index (int): index of the node to be removed (could later be changed to node id)
        """
        del self.nodes[index]
        self.connections = jnp.delete(self.connections, index, axis=0)
        self.connections = jnp.delete(self.connections, index, axis=1)
        self.num_nodes = len(self.nodes)
        return None

    # TODO: Add add-input and add-output functions

    # Helper functions

    def get_neighbors(self, node: Node, degree: int = 1) -> jnp.ndarray:
        """
        Returns the neighbors of a certain degree of a node.

        Example: degree 1 gives the direct successors, degree 2 gives the successors of the
        successors, and so on. The search never steps directly back to the node it just
        came from.

        Args:
            node (Node): node for which the neighbors are searched
            degree (int): degree of neighbors that should be searched

        Returns:
            jnp.ndarray: sorted, unique node ids

        Raises:
            ValueError: If degree < 1.
        """
        neighbors = self.collect_neighbors(node, node, degree)
        return jnp.unique(jnp.array(neighbors))

    def collect_neighbors(self, node: Node, prevNode: Node, degree: int = 1) -> list:
        """
        Recursively collects the neighbors of a certain degree of a node (may contain duplicates).

        Args:
            node (Node): node for which the neighbors are searched
            prevNode (Node): node the search came from; a direct edge back to it is skipped
            degree (int): degree of neighbors that should be searched

        Returns:
            list[int]: node ids

        Raises:
            ValueError: If degree < 1.
        """
        if degree <= 0:
            raise ValueError("Degree must be at least 1")
        node_id = int(node.get_id())
        prev_id = int(prevNode.get_id())
        # direct successors: row node_id of the adjacency matrix
        cur_neighbors = jnp.where(self.connections[node_id] == 1)[0].tolist()
        # don't walk straight back to the previous node
        if self.connections[node_id][prev_id] == 1:
            cur_neighbors.remove(prev_id)
        if degree <= 1:
            # end of search
            return cur_neighbors
        neighbors = []
        for cur_neighbor in cur_neighbors:
            # TODO: optimize efficiency
            neighbors.extend(self.collect_neighbors(self.nodes[cur_neighbor], node, degree - 1))
        return neighbors

    # Run functions
    def run(self, u_inputs: jnp.ndarray, verbose: bool = True) -> jnp.ndarray:
        """
        Runs the cluster for one iteration with the given input.

        It might take multiple iterations for a signal to reach the output, and longer
        still for the output to become stable. The main complexity is in remapping the
        outputs of each node to the inputs of its successors. All nodes first read the
        outputs of the previous step, and only then is every node run. This approach will
        be made more efficient later.

        Args:
            u_inputs (jnp.ndarray): user inputs for the cluster, one per input node
            verbose (bool): print intermediate values for debugging

        Returns:
            jnp.ndarray: activations of the output nodes

        Raises:
            ValueError: If fewer user inputs than input nodes are given.
        """
        u_inputs = jnp.atleast_1d(jnp.asarray(u_inputs))
        if u_inputs.size < len(self.input_nodes):
            raise ValueError(f"Expected {len(self.input_nodes)} user inputs, got {u_inputs.size}")

        nodes = self.nodes
        # snapshot the current network values before any node runs
        net_inputs = [node.get_outputs() for node in nodes]
        output_connections = [node.get_outputNodes() for node in nodes]
        input_connections = [node.get_inputNodes() for node in nodes]
        if verbose:
            print("inputs: ", net_inputs)
            print("out_conn: ", output_connections)
            print("in_conn: ", input_connections)

        # for every edge j -> i pick the entry of j's outputs that belongs to target i
        node_inputs = [[] for _ in nodes]
        for i in range(len(nodes)):
            for j in input_connections[i]:
                list_pos = output_connections[j].index(i)
                node_inputs[i].append(net_inputs[j][list_pos])
        if verbose:
            print("node_inputs: ", node_inputs)

        # run nodes; the external input is added to the activation of the input nodes
        for i, node in enumerate(nodes):
            node.set_inputs(jnp.array(node_inputs[i]))
            node.run()
            if node.id in self.input_nodes:
                u_idx = self.input_nodes.index(node.id)
                node.activation = node.get_activation() + u_inputs[u_idx]
                if verbose:
                    print("input: ", u_inputs[u_idx])

        # get output
        return jnp.array([self.nodes[k].get_activation() for k in self.output_nodes])

    # Debugging functions
    def print_connections(self):
        """
        Prints out the adjacency matrix of the cluster.
        Useful for debugging.
        """
        print(self.connections)
        return None

    def print_activations(self):
        """
        Prints out the activations of the nodes.
        Useful for debugging.
        """
        for node in self.nodes:
            print("Node: ", node.get_id(), "Activation: ", node.get_activation())
        return None

    # Planned methods (not implemented yet):
    #   expand(self): expands the Cluster depending on information density and the surrounding nodes
    #   connections(self): NOTE the name clashes with the self.connections attribute
    #   get_nodes(self)

In [64]:
# Six disconnected test nodes (bias 0, no weights) whose ids 0..5 match their positions
nodes = [Node(0, jnp.array([]), jnp.array([]), node_id) for node_id in range(6)]
Node0, Node1, Node2, Node3, Node4, Node5 = nodes

In [65]:
# Build a cluster from the existing nodes; skip the (not yet implemented) network setup
Cluster1 = Cluster(num_inputs=1, num_outputs=1, nodes=nodes, init_net=False)


In [66]:
# Edges 0->1, 1->2, 0->3 and 1->0 (nodes 0 and 1 form a two-node cycle)
Cluster1.add_connection(node1=Node0, node2=Node1)
Cluster1.add_connection(node1=Node1, node2=Node2)
Cluster1.add_connection(node1=Node0, node2=Node3)
Cluster1.add_connection(node1=Node1, node2=Node0)

In [67]:
# Show the adjacency matrix (row = source node, column = target node)
Cluster1.print_connections();

[[0. 1. 0. 1. 0. 0.]
 [1. 0. 1. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0.]]


In [68]:
# Second-degree neighbors of node 0
Cluster1.get_neighbors(node=Node0, degree=2)

Array([2], dtype=int32)

In [84]:
# Run one step of the cluster with a single external input of 1.
# Named `user_input` so the builtin `input` is not shadowed.
user_input = jnp.array([1])
Cluster1.run(user_input)

inputs:  [Array([0.53125, 0.53125], dtype=float32), Array([0.1328125, 0.1328125], dtype=float32), Array([], shape=(0,), dtype=float32), Array([], shape=(0,), dtype=float32), Array([], shape=(0,), dtype=float32), Array([], shape=(0,), dtype=float32)]
out_conn:  [[1, 3], [2, 0], [], [], [], []]
in_conn:  [[1], [0], [1], [0], [], []]
node_inputs:  [[Array(0.1328125, dtype=float32)], [Array(0.53125, dtype=float32)], [Array(0.1328125, dtype=float32)], [Array(0.53125, dtype=float32)], [], []]
input:  1
[0.]


In [85]:
# Show the activation of every node after the run
Cluster1.print_activations();

Node:  0 Activation:  1.0664062
Node:  1 Activation:  0.265625
Node:  2 Activation:  0.06640625
Node:  3 Activation:  0.265625
Node:  4 Activation:  0.0
Node:  5 Activation:  0.0
