In [36]:
import jax
import jax.numpy as jnp
import numpy as np

In [37]:
class Node:
    def __init__(self, bias, in_weights, out_weights, id):
        """
        Set up a simple node.

        A node gets inputs from other nodes and sends outputs to other nodes.
        The number of connections, and which nodes they link to, is meant to
        change dynamically during training.
        The processing inside is not fixed yet and is planned to become either
        an embedding or attention. To more closely mimic real neurons, the
        activation is also planned to have a decay time.

        This architecture is still in development; only some features are
        added for now so debugging is simpler.

        The inputs aggregate into a single state (``self.activation``).

        Args:
            bias: value added to the weighted sum of the inputs (float).
            in_weights: array of incoming weights, one per input.
            out_weights: array of outgoing weights, one per output.
            id: identifier of the node. ComputationTree uses it directly as a
                row/column index of its adjacency matrix.
        """
        # setup
        self.bias = bias  # float
        self.in_weights = in_weights  # float array, one weight per input
        self.out_weights = out_weights  # float array, one weight per output

        # processing
        # Here development will happen
        self.activation = 0  # float; overwritten by processing()
        self.id = id

    def processing(self, inputs):
        """
        Run one processing step on ``inputs``.

        The weighted sum of the inputs plus the bias is stored as the node's
        activation. The activation is then scaled by every outgoing weight.

        Args:
            inputs: array with one value per entry of ``in_weights``.

        Returns:
            Array with one value per entry of ``out_weights`` (when the
            activation is a scalar, i.e. 1-D weights and 1-D inputs).
        """
        self.activation = jnp.dot(self.in_weights, inputs) + self.bias
        # jnp.dot with a scalar operand is plain multiplication
        outputs = jnp.dot(self.out_weights, self.activation)
        return outputs

    def add_input(self, weight):
        # TODO: Check
        # Append a new incoming weight (one more input connection).
        # jnp.append without an axis flattens, so in_weights stays 1-D.
        self.in_weights = jnp.append(self.in_weights, weight)
        return None

    def add_output(self, weight):
        # TODO: Check
        # Append a new outgoing weight (one more output connection).
        # jnp.append without an axis flattens, so out_weights stays 1-D.
        self.out_weights = jnp.append(self.out_weights, weight)
        return None

    def remove_input(self, index):
        # TODO: Check
        # Remove the incoming weight at position `index`.
        self.in_weights = jnp.delete(self.in_weights, index)
        return None

    def remove_output(self, index):
        # TODO: Check
        # Remove the outgoing weight at position `index`.
        self.out_weights = jnp.delete(self.out_weights, index)
        return None

    def get_id(self):
        # Identifier given at construction time.
        return self.id

    # debugging functions
    def get_activation(self):
        # Activation from the most recent processing() call (0 before any call).
        return self.activation

    def get_bias(self):
        return self.bias

    def get_in_weights(self):
        return self.in_weights

    def get_out_weights(self):
        return self.out_weights

In [38]:
# Create two test nodes with zero bias.
# Node1: 3 inputs -> 3 outputs
Node1 = Node(bias=0, in_weights=jnp.array([0.1, 0.2, 0.3]), out_weights=jnp.array([0.1, 0.2, 0.3]), id=0)
# Node2: 3 inputs -> 1 output
Node2 = Node(bias=0, in_weights=jnp.array([0.1, 0.2, 0.3]), out_weights=jnp.array([0.1]), id=1)


In [39]:
# Test for node: run both nodes on the same input and verify the results.
x = jnp.array([1, 2, 3])
node1_out = Node1.processing(x)
node2_out = Node2.processing(x)
print(node1_out)
print(node2_out)
# activation = 0.1*1 + 0.2*2 + 0.3*3 + 0 = 1.4, then scaled by each out weight.
# rtol is loose because the arrays are float32 (e.g. 0.14000002).
np.testing.assert_allclose(node1_out, [0.14, 0.28, 0.42], rtol=1e-5)
np.testing.assert_allclose(node2_out, [0.14], rtol=1e-5)

[0.14000002 0.28000003 0.42000005]
[0.14000002]


In [49]:
class ComputationTree:
    def __init__(self, nodes: "list[Node] | None" = None):
        """
        Store and edit the computation tree (graph) of the nodes.

        It keeps the directed connections between the nodes and reports
        neighboring nodes, which are needed for training.

        Args:
            nodes: initial nodes; ``None`` (the default) creates an empty tree.
                The list is copied, so add_node/remove_node do not mutate the
                caller's list.

        Note:
            Connections are addressed by ``node.get_id()``, and neighbors are
            looked up by position in ``self.nodes``, so each node's id must
            equal its position in ``self.nodes``.
        """
        # TODO : Make compatible with arbitrary node ids (currently id == list position)
        self.nodes = [] if nodes is None else list(nodes)  # Python list of nodes
        # Directed adjacency matrix: connections[i, j] == 1 means an edge i -> j.
        num_nodes = len(self.nodes)
        self.connections = jnp.zeros((num_nodes, num_nodes))

    def _index_of(self, node):
        # Matrix index of `node`; validated because JAX silently drops
        # out-of-range .at[].set updates and clamps out-of-range reads.
        idx = int(node.get_id())
        size = self.connections.shape[0]
        if not 0 <= idx < size:
            raise IndexError(f"node id {idx} is out of range for {size} nodes")
        return idx

    def add_node(self, node):
        # TODO: Check
        """
        Append `node` to the tree, initially unconnected.

        The adjacency matrix grows by one zero row and one zero column. The new
        node's id should equal its new position, ``len(self.nodes) - 1``.
        """
        self.nodes.append(node)
        # Padding works for an empty (0, 0) matrix too, unlike indexing row 0.
        self.connections = jnp.pad(self.connections, ((0, 1), (0, 1)))
        return None

    def remove_node(self, index):
        """
        Remove the node at position `index` together with all its connections.

        NOTE: the ids of nodes after `index` are not renumbered, so they no
        longer match their positions afterwards (see the TODO in __init__).
        """
        del self.nodes[index]  # Python lists have no .delete method
        self.connections = jnp.delete(self.connections, index, axis=0)
        self.connections = jnp.delete(self.connections, index, axis=1)
        return None

    def add_connection(self, node1, node2):
        """
        Add a directed edge from `node1` to `node2`.

        Raises:
            IndexError: if either node id is not a valid matrix index.
        """
        # get ids like addresses in connection
        node1_id = self._index_of(node1)
        node2_id = self._index_of(node2)

        # JAX arrays are immutable: .at[].set returns an updated copy
        self.connections = self.connections.at[node1_id, node2_id].set(1)
        return None

    def get_neighbors(self, node, degree):
        """
        Return the neighbors of `node` exactly `degree` edges away.

        degree 1: direct (outgoing) neighbors, degree 2: neighbors of
        neighbors, ... Nodes reachable via several paths appear once per path.

        Returns:
            List of node positions (Python ints).

        Raises:
            ValueError: if `degree` < 1 (it would otherwise recurse forever on
                cyclic graphs).
        """
        if degree < 1:
            raise ValueError(f"degree must be >= 1, got {degree}")

        # Outgoing edges of `node` are the 1-entries of its row.
        direct = jnp.where(self.connections[self._index_of(node)] == 1)[0].tolist()
        if degree == 1:
            return direct

        # recursive search over the direct neighbors
        neighbors = []
        for neighbor_idx in direct:
            neighbors.extend(self.get_neighbors(self.nodes[neighbor_idx], degree - 1))
        return neighbors

    # debugging functions
    def print_connections(self):
        print(self.connections)
        return None

In [50]:
# Three more identical test nodes with ids 2, 3 and 4 (each gets its own weight arrays)
Node3, Node4, Node5 = (
    Node(0, jnp.array([0.1, 0.2, 0.3]), jnp.array([0.1, 0.2, 0.3]), node_id)
    for node_id in range(2, 5)
)

In [51]:
# Test prep for ComputationTree: build the chain Node1 -> Node2 -> Node3 -> Node4 -> Node5
nodes = [Node1, Node2, Node3, Node4, Node5]  # used nodes; list position must equal node id
CT = ComputationTree(nodes)
for src_node, dst_node in zip(nodes, nodes[1:]):
    CT.add_connection(src_node, dst_node)
# every consecutive pair should now have exactly one directed edge
assert int(CT.connections.sum()) == len(nodes) - 1


TypeError: '<class 'jaxlib.xla_extension.ArrayImpl'>' object does not support item assignment. JAX arrays are immutable. Instead of ``x[idx] = y``, use ``x = x.at[idx].set(y)`` or another .at[] method: https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html

In [52]:
# Test ComputationTree: show the adjacency matrix (row = source node, column = target node)
print(CT.connections)

[[0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0.]]


In [44]:
# resume development after completing the computation tree
class Cluster:
    """Placeholder for a group of nodes; the methods are not implemented yet."""

    def __init__(self, input, outputs, num_nodes):
        """
        Create a cluster of nodes.

        The cluster will have an input and outputs; these are connected either
        to other clusters or to the user for inference.

        Args:
            input: the cluster's input (type not decided yet).
            outputs: the cluster's outputs (type not decided yet).
            num_nodes: presumably the number of nodes in the cluster.
        """
        # Keep the arguments instead of silently discarding them.
        self.input = input
        self.outputs = outputs
        self.num_nodes = num_nodes

    def processing(self, inputs):
        """
        Run one cycle with all nodes.
        """
        raise NotImplementedError("Cluster.processing is not implemented yet")

    def expand(self):
        """
        Expand the cluster depending on information density
        and the surrounding nodes.
        """
        raise NotImplementedError("Cluster.expand is not implemented yet")

    def connections(self):
        """
        Connections of the cluster (not designed yet).
        """
        raise NotImplementedError("Cluster.connections is not implemented yet")

    def get_nodes(self):
        """
        Nodes belonging to the cluster (not designed yet).
        """
        raise NotImplementedError("Cluster.get_nodes is not implemented yet")
