In [94]:
def get_keys_from_values(dic: dict, values: list):
    """Map every item of *values* back to the first key of *dic* that holds it.

    Items are located with ``list.index`` on the dict's values, i.e. compared
    with ``==`` (identity for objects without a custom ``__eq__``). The result
    follows the order of *values*.

    Raises:
        ValueError: if an item of *values* is not among ``dic.values()``.
    """
    keys = [*dic]
    stored = [*dic.values()]
    result = []
    for value in values:
        position = stored.index(value)  # first match wins; ValueError if absent
        result.append(keys[position])
    return result

In [165]:
class FunctionObj:
    """Pairs a callable with the types it declares for its inputs and outputs.

    The type lists are stored exactly as given. ``validate`` is the hook
    invoked at construction to check them, but it is currently a no-op.
    """

    def __init__(self, func: callable, input_types: list[type], output_types: list[type]):
        self.func, self.input_types, self.output_types = func, input_types, output_types
        self.validate()

    def validate(self):
        """Placeholder check run at construction; does nothing and returns None."""
        # TODO: implement.
        return None

    def run(self, *args):
        """Call the wrapped function with *args* and return its result unchanged."""
        target = self.func
        return target(*args)

class ObjNode:
    """Data-holding vertex of the computational graph.

    Stores a single value in ``data``, where None means "no data yet", plus
    the links that place the node in the graph.
    """

    def __init__(self, obj_type: type):
        # Declared type of the value; recorded only, never checked here.
        self.obj_type: type = obj_type
        # (consumer FuncNode, input slot index) pairs that read this node.
        self.func_children: list[tuple[FuncNode, int]] = list()
        # Presumably the FuncNode producing this node; starts as None.
        self.func_parent: FuncNode = None
        self.data = None
        # Evaluation-order level; reassigned externally when the node is wired in.
        self.comp_graph_level = 0

    def update_data(self, data):
        """Replace the stored value with *data*."""
        self.data = data

    def contains_data(self):
        """Return True unless the stored value is None."""
        return not (self.data is None)


class FuncNode:
    """Function vertex of the computational graph.

    Wraps a FunctionObj and keeps one input slot per declared input type;
    each slot holds an upstream ObjNode or None. It also owns one freshly
    created ObjNode per declared output type to receive the results.
    """

    def __init__(self, func_obj: "FunctionObj"):
        self.func_obj: FunctionObj = func_obj
        # One slot per declared input; None marks an unconnected slot.
        self.input_obj_nodes: list[ObjNode] = [None] * len(self.func_obj.input_types)
        self.output_obj_nodes: list[ObjNode] = [ObjNode(obj_type) for obj_type in self.func_obj.output_types]
        # Point each output back at the function that produces it.
        for output_node in self.output_obj_nodes:
            output_node.func_parent = self
        self.comp_graph_level = 0

    def connect_input(self, input_index: int, obj_node: "ObjNode"):
        """Attach *obj_node* to input slot *input_index*.

        A node previously attached to that slot is detached first, so it no
        longer lists this function among its ``func_children``.

        Raises:
            Exception: if *obj_node* is already attached to this very slot.
            IndexError: if *input_index* is out of range.
        """
        current = self.input_obj_nodes[input_index]
        if current is obj_node:
            raise Exception('node already connected to the same input index!')
        if current is not None:
            # Without this, the replaced node would keep a dangling
            # (self, input_index) entry in its func_children.
            self.disconnect_input(input_index)
        self.input_obj_nodes[input_index] = obj_node
        obj_node.func_children.append((self, input_index))

    def disconnect_input(self, input_index):
        """Detach whatever node occupies slot *input_index*; no-op if the slot is empty."""
        input_obj = self.input_obj_nodes[input_index]
        if input_obj is None:
            return
        input_obj.func_children.remove((self, input_index))
        self.input_obj_nodes[input_index] = None

    def run(self):
        """Apply the wrapped function to the inputs' data and store the results.

        The function's return value is indexed positionally: result[i] goes to
        output node i, so it must provide at least one item per output. An
        unconnected input slot (None) raises AttributeError.
        """
        input_data = [obj.data for obj in self.input_obj_nodes]
        out_data = self.func_obj.run(*input_data)
        for i, obj in enumerate(self.output_obj_nodes):
            obj.update_data(out_data[i])

    def update_level(self):
        """Recompute this node's level and its outputs' levels.

        The level is one more than the highest level among connected inputs
        (0 when none are connected). Outputs get this level + 1. Nodes further
        downstream are not updated.
        """
        self.comp_graph_level = max([node.comp_graph_level for node in self.input_obj_nodes if node is not None], default=-1) + 1

        for output_node in self.output_obj_nodes:
            output_node.comp_graph_level = self.comp_graph_level + 1
    

class ComputationalNetwork:
    """Registry and executor for a graph of ObjNodes and FuncNodes.

    Every node is stored under an integer ID drawn from one shared sequence.
    A new ID is one past the largest ID currently registered, so IDs freed at
    the top of the range by a deletion are handed out again.
    """
    # TODO: Write top-sort.
    # write __str__, __repr__, etc. for all classes

    def __init__(self):
        self.input_obj_nodes = dict()  # ID -> ObjNode whose value is supplied through run()
        self.obj_nodes = dict()        # ID -> every ObjNode (inputs and function outputs)
        self.func_nodes = dict()       # ID -> every FuncNode

    def _next_id(self) -> int:
        """Return one past the largest registered ID (1 for an empty network)."""
        return max([*self.obj_nodes, *self.func_nodes], default=0) + 1

    @staticmethod
    def _ids_of(registry: dict, nodes) -> list:
        """Return the ID under which each of *nodes* is stored in *registry*, in order.

        Nodes are matched by identity. That is the same match ``list.index``
        makes for objects without a custom ``__eq__``, but it costs one pass
        over *registry* instead of one per node.

        Raises:
            ValueError: if some node is not stored in *registry*.
        """
        id_by_identity = {id(node): node_id for node_id, node in registry.items()}
        try:
            return [id_by_identity[id(node)] for node in nodes]
        except KeyError:
            raise ValueError('node is not registered in this network') from None

    def add_func_node(self, func_obj: "FunctionObj", input_node_ids: list[int]):
        """Create a FuncNode for *func_obj*, wire its inputs and register it.

        input_node_ids[i] is the ID of the ObjNode feeding input slot i. The
        function node gets a fresh ID N and its outputs get N+1, N+2, ...

        Returns:
            The new FuncNode.

        Raises:
            KeyError: if an ID in *input_node_ids* is not a registered ObjNode.
            IndexError: if more IDs are given than *func_obj* has inputs.
            Both are raised before anything is wired, leaving the network unchanged.
        """
        # Resolve and check everything first. Otherwise a failure half-way
        # would leave earlier inputs linked to an unregistered FuncNode, and
        # every later lookup through them would fail.
        input_nodes = [self.obj_nodes[node_id] for node_id in input_node_ids]
        if len(input_nodes) > len(func_obj.input_types):
            raise IndexError('more input node IDs than the function has inputs')

        func_node = FuncNode(func_obj)
        for index, obj_node in enumerate(input_nodes):
            func_node.connect_input(index, obj_node)

        func_node_id = self._next_id()
        self.func_nodes[func_node_id] = func_node
        # Outputs take the IDs immediately following the function node.
        for offset, output_node in enumerate(func_node.output_obj_nodes, start=1):
            self.obj_nodes[func_node_id + offset] = output_node

        func_node.update_level()
        return func_node

    def add_input_obj_node(self, obj_type: type):
        """Register a new input ObjNode of *obj_type* under a fresh ID and return it."""
        node = ObjNode(obj_type)
        node_id = self._next_id()
        self.input_obj_nodes[node_id] = node
        self.obj_nodes[node_id] = node  # inputs are also ordinary ObjNodes
        return node

    def clean_data(self):
        """Reset the stored value of every ObjNode to None."""
        for obj in self.obj_nodes.values():
            obj.data = None

    def run(self, inputs: dict):
        """Assign input values, evaluate every function node and collect results.

        Args:
            inputs: maps input-node IDs to the values to store on them.

        Returns:
            {ID: data} for every ObjNode holding data afterwards. Values from
            earlier runs are not cleared; call clean_data() for that.

        Raises:
            ValueError: if a key of *inputs* is not an input-node ID. All keys
                are checked before any value is assigned.
        """
        if any(node_id not in self.input_obj_nodes for node_id in inputs):
            raise ValueError("Provided node ID is not a valid input node.")
        for node_id, value in inputs.items():
            self.input_obj_nodes[node_id].update_data(value)

        # FuncNode.update_level puts every function above its inputs' levels,
        # so ascending level order runs producers before consumers.
        for func_node in sorted(self.func_nodes.values(), key=lambda func_node: func_node.comp_graph_level):
            func_node.run()

        return {node_id: node.data for node_id, node in self.obj_nodes.items() if node.contains_data()}

    def run_remaining(self):
        """Not implemented yet; does nothing and returns None."""
        pass

    def get_all_nodes(self):
        """Return a new dict of every registered node (function and object) by ID."""
        nodes_dict = self.func_nodes.copy()
        nodes_dict.update(self.obj_nodes)
        return nodes_dict

    def delete_func_node(self, id_number):
        """Delete function node *id_number* and everything downstream of it.

        The node's inputs are disconnected, every function consuming one of
        its outputs is deleted recursively, and its output ObjNodes are
        unregistered. Unknown IDs are ignored; this includes nodes already
        removed earlier in the same cascade, e.g. when two paths reach them.
        """
        if id_number not in self.func_nodes:
            return
        func_node: FuncNode = self.func_nodes[id_number]
        for index in range(len(func_node.input_obj_nodes)):
            func_node.disconnect_input(index)

        for output_id in self._ids_of(self.obj_nodes, func_node.output_obj_nodes):
            output_node = self.obj_nodes[output_id]
            # Snapshot consumers first: recursive deletes mutate func_children.
            consumers = [child for child, _ in output_node.func_children]
            for child_id in self._ids_of(self.func_nodes, consumers):
                self.delete_func_node(child_id)
            del self.obj_nodes[output_id]
        del self.func_nodes[id_number]

    def get_dependent_func_ids(self, func_id_number):
        """Not implemented yet; returns None."""
        pass

    def get_dic_structure_schematic(self):
        """Return the graph as an adjacency dict keyed by node ID.

        ObjNode ID -> IDs of the FuncNodes consuming it, one entry per
        connected input slot. FuncNode ID -> IDs of its output ObjNodes.
        ObjNode entries come first, then FuncNode entries.
        """
        dic = {}
        for obj_id, obj in self.obj_nodes.items():
            dic[obj_id] = self._ids_of(self.func_nodes, [func for func, _ in obj.func_children])

        for func_id, func in self.func_nodes.items():
            dic[func_id] = self._ids_of(self.obj_nodes, func.output_obj_nodes)

        return dic


In [167]:

# Demo: build a small graph, evaluate it, then delete a function node and
# everything downstream of it.
network = ComputationalNetwork()
func1 = FunctionObj(lambda x: [x**2], input_types=[int], output_types=[int])
func2 = FunctionObj(lambda x, y: [x + y], input_types=[int, int], output_types=[int])
func3 = FunctionObj(lambda x, y: [x + y, x - y], input_types=[int, int], output_types=[int, int])
# IDs come from one shared sequence: a function node takes the next free ID
# and its outputs take the IDs right after it.
network.add_input_obj_node(int)       # obj 1
network.add_func_node(func1, [1])     # func 2 -> obj 3
network.add_func_node(func2, [1, 3])  # func 4 -> obj 5
network.add_func_node(func3, [3, 5])  # func 6 -> objs 7, 8
network.add_func_node(func1, [5])     # func 9 -> obj 10
print(network.get_dic_structure_schematic())
print(network.run({1: 2}))
# Removing func 4 also removes obj 5 and the functions fed by it (6, 9).
network.delete_func_node(4)
print(network.get_dic_structure_schematic())

{1: [2, 4], 3: [4, 6], 5: [6, 9], 7: [], 8: [], 10: [], 2: [3], 4: [5], 6: [7, 8], 9: [10]}
{1: 2, 3: 4, 5: 6, 7: 10, 8: -2, 10: 36}
{1: [2], 3: [], 2: [3]}


In [162]:
# Show the function nodes still registered (rendered as this cell's output).
# NOTE(review): the output below is tagged In [162], i.e. it was captured before
# the In [167] cell last ran, so it may not reflect delete_func_node(4).
network.func_nodes

{2: <__main__.FuncNode at 0x115f3b070>,
 4: <__main__.FuncNode at 0x115f38340>,
 6: <__main__.FuncNode at 0x115f38b80>,
 9: <__main__.FuncNode at 0x115f31b10>}