In [None]:
class ComputationTree:
    """
    Saves and edits the computation tree of a cluster of nodes.

    The tree stores the directed connections between nodes in an adjacency
    matrix (``connections[i, j] == 1`` means an edge from node i to node j)
    and can look up neighboring nodes, which are needed for training.
    The cluster has an input and an output; these are either connected to
    other clusters or used for inference by the user.

    Planned:
    - Different types of nodes
    - Different types of connections
    - Propagation delay
    - Spatial embedding used for topology
    """

    def __init__(self, input, outputs, num_nodes=1, nodeModel: Node = Node): # uses basic node model, can be changed in the future
        """
        Create ``num_nodes`` disconnected nodes and wire them via ``setupNetwork``.

        Args:
            input: stored as ``self.input``; not otherwise used in this class.
            outputs: stored as ``self.outputs``; not otherwise used in this class.
            num_nodes: number of nodes to create. Must be greater than 1, so
                the default of 1 raises.
            nodeModel: callable used to build each node as
                ``nodeModel(0, jnp.array([]), jnp.array([]), i)``.

        Raises:
            ValueError: if ``num_nodes <= 1``.
        """
        # TODO : Code when no nodes are given
        # TODO : Make compatible with node ids
        # Validate before allocating anything.
        if num_nodes <= 1:
            raise ValueError("Number of nodes must be greater than 1")

        self.nodeModel = nodeModel
        # Plain Python list of node objects. A jnp array cannot hold arbitrary
        # objects and has no .append. Node ids are used as row/column indices
        # into ``connections``, and nodes are created with id == list index.
        self.nodes = []
        # Directed adjacency matrix: connections[i, j] == 1 <=> edge i -> j.
        self.connections = jnp.zeros((num_nodes, num_nodes))

        # basic params
        self.input = input
        self.outputs = outputs
        self.num_nodes = num_nodes

        # setup: create disconnected nodes whose id equals their list index
        for i in range(num_nodes):
            self.nodes.append(self.nodeModel(0, jnp.array([]), jnp.array([]), i))

        # In- and output nodes for testing; later this is done in setupNetwork.
        # This must run after the nodes exist.
        self.input_nodes = jnp.append(jnp.array([]), int(self.nodes[0].get_id()))
        self.output_nodes = jnp.append(jnp.array([]), int(self.nodes[-1].get_id()))

        # setup network
        self.setupNetwork()

    def add_node(self, node: Node):
        """
        Append ``node`` to the tree without any connections.

        The adjacency matrix grows by one zero column and one zero row.
        NOTE(review): node ids are used as matrix indices, so ``node`` presumably
        needs id == ``len(self.nodes) - 1`` after insertion. TODO confirm.
        """
        # add node to list of nodes
        self.nodes.append(node)
        # not connected to any node: add a zero column, then a zero row
        self.connections = jnp.append(self.connections, jnp.zeros((len(self.connections), 1)), axis=1)
        self.connections = jnp.append(self.connections, jnp.zeros((1, len(self.connections[0]))), axis=0)
        return None

    def remove_node(self, index: int):
        """
        Remove the node at ``index`` together with its row and column in the
        adjacency matrix.

        NOTE(review): nodes after ``index`` shift down by one, so their ids no
        longer match their list index. ``input_nodes``/``output_nodes`` are not
        updated either.
        """
        # lists have no .delete(); use del on the index
        del self.nodes[index]
        self.connections = jnp.delete(self.connections, index, axis=0)
        self.connections = jnp.delete(self.connections, index, axis=1)
        return None

    def add_connection(self, node1: Node, node2: Node):
        """
        Add a directed edge ``node1 -> node2``.

        Sets ``connections[id1, id2] = 1``, gives ``node1`` a new output and
        ``node2`` a new input, both with the fixed initial weight 0.5.
        """
        # TODO : Add connection between connection and node input place
        # get ids like addresses in connection
        node1_id = int(node1.get_id())
        node2_id = int(node2.get_id())

        # add connection (jax arrays are immutable, hence .at[].set)
        self.connections = self.connections.at[node1_id, node2_id].set(1)

        # add in and out to node
        # TODO: variable initialisation
        node1.add_output(0.5)
        node2.add_input(0.5)
        return None

    def get_neighbors(self, node: Node, degree: int) -> jnp.array:
        """
        Return the unique ids of the outgoing neighbors of ``node`` at ``degree``.

        Uses ``collect_neighbors`` and removes duplicate entries with
        ``jnp.unique``, which also sorts the ids.
        """
        # TODO: check inputs
        neighbors = self.collect_neighbors(node, degree)
        neighbors = jnp.array(neighbors)  # list to array
        neighbors = jnp.unique(neighbors)  # remove double entries
        return neighbors

    def collect_neighbors(self, node, degree):
        """
        Collect the ids of outgoing neighbors of ``node`` at exactly ``degree`` hops.

        Degree 1 (or anything <= 1) gives the direct successors, degree 2 gives
        the successors of those, and so on. The result is a list that may
        contain duplicates.
        """
        # direct successors: columns set to 1 in this node's row
        direct = list(jnp.where(self.connections[int(node.get_id())] == 1)[0])
        if degree <= 1:
            return direct

        # recursive search one hop deeper from every direct successor
        neighbors = []
        for cur_neighbor in direct:
            neighbors.extend(self.collect_neighbors(self.nodes[int(cur_neighbor)], degree - 1))
        return neighbors

    def run(self, inputs, input_idx=None, output_idx=None, *, verbose: bool = False):
        """
        Run the whole network on ``inputs``.

        Nodes are executed sequentially in list order starting at index 1,
        because node 0 is treated as the input node. Each node's result is fed
        to the next node, and the last result is returned.

        Args:
            inputs: values fed into the network. Their length must equal the
                number of input nodes.
            input_idx: ids of the input nodes; defaults to ``self.input_nodes``.
            output_idx: ids of the output nodes; defaults to
                ``self.output_nodes``. Currently unused by the sequential run.
            verbose: print per-node debugging information.

        Raises:
            ValueError: if the input size does not match the input nodes.
        """
        # TODO: Use edge list to run nodes
        # TODO: Design method to input data into nodes, not overwriting them
        if input_idx is None:
            input_idx = self.input_nodes
        if output_idx is None:
            output_idx = self.output_nodes
        if len(inputs) != len(input_idx):
            raise ValueError("Input size does not match")

        # run nodes in order, skipping the input node
        for node in self.nodes[1:]:
            if verbose:
                print(inputs)
                print("Node: ", node.get_id())
                print(node.get_in_weights())
                print(node.get_out_weights())
                print(node.get_activation())
            inputs = node.run(inputs)
        return inputs

    # additional functions for later

    def setupNetwork(self, initialization: str = "simple"):
        """
        Set up the connections of the network using the chosen algorithm.

        For now only ``"simple"`` is supported. It creates a linear chain
        ``node0 -> node1 -> ... -> nodeN``, similar in topology to a normal NN.

        Planned:
        - Random connections
        - Evolutionary algorithm
        - Load own nodes

        Raises:
            ValueError: for an unknown ``initialization``.
        """
        # network initialization
        if initialization == "simple":
            # just a linear connection between consecutive nodes
            for i in range(len(self.nodes) - 1):
                self.add_connection(self.nodes[i], self.nodes[i + 1])
        else:
            raise ValueError(f"Unknown initialization: {initialization!r}")

    def expand(self):
        """
        Expands the Cluster depending on information density
        and the surrounding nodes
        """
        pass

    def connections(self):
        # NOTE(review): stub. __init__ assigns the adjacency matrix to the
        # instance attribute ``self.connections``, which shadows this method on
        # every instance.
        pass

    def get_nodes(self):
        """Return the list of node objects (index == node id)."""
        return self.nodes

    # debugging functions
    def print_connections(self):
        print(self.connections)
        return None

    def print_nodes(self):
        for node in self.nodes:
            node.print_node()
        return None

    def print_node_inputs(self, idx: int):
        node = self.nodes[idx]
        print("Node inputs: ", node.get_in_weights())
        return None