In [1]:
from aiida import load_profile
from aiida.orm import Int
from aiida_workgraph import WorkGraph
from aiida.engine import calcfunction
from aiida_workgraph.decorator import build_task_from_callable

In [2]:
from inspect import isfunction

In [3]:
def get_kwargs(lst):
    """Map each edge's target input port to the source it is wired from.

    Args:
        lst: Iterable of edge dicts with keys ``targetHandle``, ``source`` and
            ``sourceHandle``.

    Returns:
        dict ``{targetHandle: {"source": ..., "sourceHandle": ...}}``. If two
        edges share a ``targetHandle``, the later edge wins.
    """
    port_map = {}
    for edge in lst:
        port = edge['targetHandle']
        port_map[port] = {
            'source': edge['source'],
            'sourceHandle': edge['sourceHandle'],
        }
    return port_map

In [4]:
def wrap_function(func, **kwargs):
    """Turn a plain Python function into a calcfunction carrying a WorkGraph task spec.

    The function is wrapped with ``calcfunction`` first. The wrapped callable is then
    passed to ``build_task_from_callable``, and the resulting task spec is attached to
    the calcfunction as both ``.task`` and ``.node``.

    Args:
        func: The plain Python function to wrap.
        **kwargs: Optional keys:
            ``inputs`` / ``outputs``: lists forwarded to ``build_task_from_callable``
                (default ``[]``).
            ``identifier``: stored on the result; falls back to ``func.__name__``
                when missing or falsy.

    Returns:
        The calcfunction-decorated callable with ``identifier``, ``task`` and
        ``node`` attributes set.
    """
    calc = calcfunction(func)
    task_spec = build_task_from_callable(
        calc,
        inputs=kwargs.get("inputs", []),
        outputs=kwargs.get("outputs", []),
    )
    calc.identifier = kwargs.get("identifier") or func.__name__
    calc.task = task_spec
    calc.node = task_spec
    return calc

In [5]:
def group_edges(edges_lst):
    """Group edges by target node and convert each group into task keyword arguments.

    Edges that point at the same target do not need to be adjacent in ``edges_lst``.
    All of a target's edges are collected before conversion, so a target whose edges
    are split across the list is not overwritten. Targets appear in the result in
    order of first occurrence, and edges keep their relative order within a target.

    Args:
        edges_lst: Iterable of edge dicts with keys ``target``, ``targetHandle``,
            ``source`` and ``sourceHandle``.

    Returns:
        dict mapping each target node id to ``get_kwargs`` of its edges, i.e.
        ``{target: {targetHandle: {"source": ..., "sourceHandle": ...}}}``.
        An empty input yields an empty dict.
    """
    edges_by_target = {}
    for ed in edges_lst:
        # Accumulate rather than start a new group on every change of target,
        # so non-adjacent edges for the same target are all kept.
        edges_by_target.setdefault(ed["target"], []).append(ed)
    return {target: get_kwargs(lst=edges) for target, edges in edges_by_target.items()}

In [6]:
def get_output_labels(edges_lst):
    """Collect, per source node, the distinct named output ports that edges read from.

    Edges whose ``sourceHandle`` is ``None`` carry a plain value rather than a named
    output, so they are ignored. Each label is recorded once per source, in order of
    first use, even when the same output feeds several targets. Duplicates would
    otherwise become duplicate output declarations when the labels are turned into
    task outputs.

    Args:
        edges_lst: Iterable of edge dicts with keys ``source`` and ``sourceHandle``.

    Returns:
        dict mapping source node id -> list of unique output labels.
    """
    output_label_dict = {}
    for ed in edges_lst:
        label = ed['sourceHandle']
        if label is None:
            continue
        labels = output_label_dict.setdefault(ed["source"], [])
        if label not in labels:
            labels.append(label)
    return output_label_dict

In [7]:
def get_function_dict(nodes_dict, output_label_dict):
    """Wrap every function-valued node with ``wrap_function``.

    Args:
        nodes_dict: node id -> node content. Only entries for which
            ``isfunction`` is true are wrapped; all others are skipped.
        output_label_dict: source node id -> list of output port names. When a
            node has an entry, those names are declared as its outputs; otherwise
            it is wrapped without explicit outputs.

    Returns:
        dict mapping node id -> wrapped callable, in ``nodes_dict`` iteration order.
    """
    wrapped = {}
    for node_id, obj in nodes_dict.items():
        if not isfunction(obj):
            continue
        if node_id in output_label_dict:
            extra = {"outputs": [{"name": label} for label in output_label_dict[node_id]]}
        else:
            extra = {}
        wrapped[node_id] = wrap_function(func=obj, **extra)
    return wrapped

In [8]:
def _task_order(function_dict, total_dict):
    """Order function-node ids so each one follows every function node it reads from.

    The order is stable: ids already in dependency order keep their original
    position. Raises ValueError if the function nodes form a dependency cycle.
    """
    ordered, placed = [], set()
    pending = list(function_dict)
    while pending:
        deferred = []
        for node_id in pending:
            deps = {
                edge['source']
                for edge in total_dict.get(node_id, {}).values()
                if edge['source'] in function_dict
            }
            if deps <= placed:
                ordered.append(node_id)
                placed.add(node_id)
            else:
                deferred.append(node_id)
        if len(deferred) == len(pending):
            # No node became ready during this pass, so the remaining ones depend on each other.
            raise ValueError(f"Dependency cycle among function nodes: {deferred}")
        pending = deferred
    return ordered


def build_workgraph(function_dict, total_dict, nodes_dict, label="my_workflow"):
    """Assemble a WorkGraph from wrapped functions and their grouped input edges.

    A task is only added after every function node it reads an output from, so
    ``function_dict`` need not be topologically sorted. For already-sorted input,
    tasks are added in the same order as before.

    Args:
        function_dict: node id -> wrapped callable. The task name is the
            callable's ``__name__``.
        total_dict: node id -> ``{input port: {"source": id, "sourceHandle": port or None}}``.
            Function nodes without an entry are added with no wired inputs.
        nodes_dict: node id -> node content. Used to look up the literal value for
            edges whose ``sourceHandle`` is None and whose source is not a task.
        label: Name of the WorkGraph.

    Returns:
        The constructed ``WorkGraph``.

    Raises:
        ValueError: if an edge names an output port on a node that is not a
            function node, or if the function nodes form a dependency cycle.
    """
    wg = WorkGraph(label)
    mapping_dict = {}  # node id -> task name, for tasks already added to wg
    for k in _task_order(function_dict, total_dict):
        v = function_dict[k]
        name = v.__name__
        kwargs = {}
        for tk, tv in total_dict.get(k, {}).items():
            if tv['source'] in mapping_dict:
                # Link to an output socket of an upstream task.
                kwargs[tk] = wg.tasks[mapping_dict[tv['source']]].outputs[tv['sourceHandle']]
            elif tv['sourceHandle'] is None:
                # Plain value node: pass its content directly as the input.
                kwargs[tk] = nodes_dict[tv['source']]
            else:
                raise ValueError(
                    f"Input {tk!r} of node {k!r} reads output {tv['sourceHandle']!r} "
                    f"from node {tv['source']!r}, which is not a function node."
                )
        wg.add_task(v, name=name, **kwargs)
        mapping_dict[k] = name
    return wg

In [9]:
def add_x_and_y(x, y):
    """Example task with three named outputs.

    Returns a dict with ``"a"`` = ``x * 1.0``, ``"b"`` = ``y * 1.0`` and
    ``"c"`` = ``x + y``.
    """
    total = x + y
    return {"a": x * 1.0, "b": y * 1.0, "c": total}

In [10]:
def add_x_and_y_and_z(l, m, n):
    """Example task: return ``l + m + n``, evaluated left to right."""
    return l + m + n

In [11]:
# Example workflow as an edge list. Each edge wires output `sourceHandle` of node
# `source` into input `targetHandle` of node `target`. A `sourceHandle` of None means
# the source node's content from `nodes_dict` is passed in directly as a value.
edges_lst = [
    {'target': 1, 'targetHandle': 'l', 'source': 0, 'sourceHandle': 'c'},
    {'target': 1, 'targetHandle': 'm', 'source': 0, 'sourceHandle': 'a'},
    {'target': 1, 'targetHandle': 'n', 'source': 0, 'sourceHandle': 'b'},
    {'target': 0, 'targetHandle': 'x', 'source': 2, 'sourceHandle': None},
    {'target': 0, 'targetHandle': 'y', 'source': 3, 'sourceHandle': None},
]

In [12]:
# Node id -> node content. Functions (0, 1) become tasks; plain values (2, 3) are
# literal inputs that the edge list feeds into node 0's `x` and `y` ports.
nodes_dict = {
    0: add_x_and_y,
    1: add_x_and_y_and_z,
    2: 1,
    3: 2,
}

In [13]:
# For each source node, the named output ports that downstream edges read from.
output_label_dict = get_output_labels(edges_lst)
output_label_dict

{0: ['c', 'a', 'b']}

In [14]:
# Per target node: input port -> {source node id, source output port}.
total_dict = group_edges(edges_lst=edges_lst)
total_dict

{1: {'l': {'source': 0, 'sourceHandle': 'c'},
  'm': {'source': 0, 'sourceHandle': 'a'},
  'n': {'source': 0, 'sourceHandle': 'b'}},
 0: {'x': {'source': 2, 'sourceHandle': None},
  'y': {'source': 3, 'sourceHandle': None}}}

In [15]:
# Wrap every function node as a calcfunction task, declaring the named outputs collected above.
function_dict = get_function_dict(nodes_dict=nodes_dict, output_label_dict=output_label_dict)
function_dict 

{0: <function __main__.add_x_and_y(x, y)>,
 1: <function __main__.add_x_and_y_and_z(l, m, n)>}

In [16]:
# Load the default AiiDA profile before building and running the WorkGraph.
load_profile()

Profile<uuid='7bb8761123324468bb98821cbb757251' name='presto'>

In [17]:
# Assemble the WorkGraph from the wrapped functions and their wired inputs.
wg = build_workgraph(function_dict=function_dict, total_dict=total_dict, nodes_dict=nodes_dict, label="my_workflow")
wg

NodeGraphWidget(settings={'minimap': True}, style={'width': '90%', 'height': '600px'}, value={'name': 'my_work…

In [18]:
# Execute the workflow; the log shows each task being run and finishing in turn.
wg.run()

01/17/2025 06:25:49 PM <121241> aiida.orm.nodes.process.workflow.workchain.WorkChainNode: [REPORT] [325|WorkGraphEngine|continue_workgraph]: Continue workgraph.
01/17/2025 06:25:49 PM <121241> aiida.orm.nodes.process.workflow.workchain.WorkChainNode: [REPORT] [325|WorkGraphEngine|continue_workgraph]: tasks ready to run: add_x_and_y
01/17/2025 06:25:49 PM <121241> aiida.orm.nodes.process.workflow.workchain.WorkChainNode: [REPORT] [325|WorkGraphEngine|run_tasks]: Run task: add_x_and_y, type: CALCFUNCTION


------------------------------------------------------------
kwargs:  {'x': 1, 'y': 2}


01/17/2025 06:25:50 PM <121241> aiida.orm.nodes.process.workflow.workchain.WorkChainNode: [REPORT] [325|WorkGraphEngine|update_task_state]: Task: add_x_and_y finished.
01/17/2025 06:25:50 PM <121241> aiida.orm.nodes.process.workflow.workchain.WorkChainNode: [REPORT] [325|WorkGraphEngine|continue_workgraph]: Continue workgraph.
01/17/2025 06:25:50 PM <121241> aiida.orm.nodes.process.workflow.workchain.WorkChainNode: [REPORT] [325|WorkGraphEngine|continue_workgraph]: tasks ready to run: add_x_and_y_and_z
01/17/2025 06:25:50 PM <121241> aiida.orm.nodes.process.workflow.workchain.WorkChainNode: [REPORT] [325|WorkGraphEngine|run_tasks]: Run task: add_x_and_y_and_z, type: CALCFUNCTION


------------------------------------------------------------
kwargs:  {'l': <Int: uuid: 96a6024b-7d8c-413d-899f-7274e4f8b7e7 (pk: 331) value: 3>, 'm': <Float: uuid: 89a98d17-178c-4fd9-a53f-cc3c7505aa79 (pk: 329) value: 1.0>, 'n': <Float: uuid: a92ee4c9-bb9d-4f59-a24d-a95bdd64655f (pk: 330) value: 2.0>}


01/17/2025 06:25:50 PM <121241> aiida.orm.nodes.process.workflow.workchain.WorkChainNode: [REPORT] [325|WorkGraphEngine|update_task_state]: Task: add_x_and_y_and_z finished.
01/17/2025 06:25:50 PM <121241> aiida.orm.nodes.process.workflow.workchain.WorkChainNode: [REPORT] [325|WorkGraphEngine|continue_workgraph]: Continue workgraph.
01/17/2025 06:25:50 PM <121241> aiida.orm.nodes.process.workflow.workchain.WorkChainNode: [REPORT] [325|WorkGraphEngine|continue_workgraph]: tasks ready to run: 
01/17/2025 06:25:50 PM <121241> aiida.orm.nodes.process.workflow.workchain.WorkChainNode: [REPORT] [325|WorkGraphEngine|finalize]: Finalize workgraph.


{'execution_count': <Int: uuid: 256c990b-b752-4bd6-b1c4-1e877dfc4946 (pk: 334) value: 1>}

In [19]:
# Read the final value: the default `result` output of the last task (6.0 here, for x=1, y=2).
wg.to_dict()['tasks']['add_x_and_y_and_z']['outputs']['result']['property']['value']

<Float: uuid: cf10cf10-df64-4654-94db-6b81805ff254 (pk: 333) value: 6.0>