# How to create more complex workflows

**Please note, this notebook depends on successful execution of the first notebook `1-aiida-intro.ipynb`!**

In the previous notebook, we have seen how we can run arbitrary executables through `aiida-shell` without requiring any
code-specific infrastructure (typically contained in a dedicated AiiDA plugin). In addition, we have seen how we can
feed the output of one task to the input of another task, linking the two and effectively creating a workflow.

Building on this concept, `aiida-workgraph` provides the capability to create workflows in the same manner as one would
build up an actual graph, that is, by adding nodes and edges to it. It further extends the possible building blocks for
our workflow from external scripts (as seen with `aiida-shell`) to other AiiDA building blocks (`CalcFunction`s,
`CalcJob`s, `WorkChain`s, etc.), as well as custom Python code.

We'll cover lots of material in this notebook, so strap yourself in and buckle up! :rocket:

To run the following Python cells, we need to make sure that we select the correct kernel `Python3.10 (AIIDA)`. If it is
not already selected, do so as follows:

<img src="../../data/figs/change_notebook_kernel.png" width="500" style="height:auto; display:block; margin-left:auto; margin-right:auto;">

We then load the AiiDA Jupyter notebook extension, check the profile status, and import all the libraries that we need.
So nothing new here, really...

In [None]:
%load_ext aiida
%aiida

In [None]:
%verdi status

In [None]:
from pathlib import Path
import matplotlib.pyplot as plt
import numpy as np
import time
from IPython.display import Image, display

from aiida import orm
from aiida_shell.parsers import ShellParser
from aiida.tools.visualization import Graph

from aiida_workgraph import WorkGraph, task
from aiida_workgraph.utils import generate_node_graph

In [None]:
diag_code = orm.load_code('diagonalization@localhost')  # The computer label can also be omitted here
query_code = orm.load_code('remote_query@localhost')  # The computer label can also be omitted here
db_path = str(Path('../../data/euro-scipy-2024/diag-wf/remote/matrices.db').resolve())

In [None]:
def provenance_graph(aiida_node):
    graph = Graph()
    graph.recurse_ancestors(aiida_node, annotate_links="both")
    graph.recurse_descendants(aiida_node, annotate_links="both")
    display(graph.graphviz)

## WorkGraph vs. provenance graph

As evident from the import statement:

```python
from aiida_workgraph import WorkGraph, task
```

the first entity we'll be using is, of course, the `WorkGraph`. In addition, we import `task`, which
represents the `WorkGraph` equivalent of a *node* in the graphs we'll be building up.

In line with common graph nomenclature, we'd have loved to use the **Node** keyword for that, but remember, the `Node`
class is already defined in `aiida-core`. To avoid confusion, note that we will now be talking
about two different kinds of graphs:
- **The provenance graph**: AiiDA's way of storing the **Data** and **Processes** inside the SQL database as **Node**s
  and **Link**s
- **The WorkGraph**: The workflow we are building up using the `aiida-workgraph` library

As such, we can build up our workflow as a **WorkGraph**, run it, and AiiDA will store all data in its database, allowing
us to explore the resulting **provenance graph** of our workflow.

Let's start with some simple examples to make things clear. We'll come back to the workflow from the previous
notebook in a bit.

In [None]:
def sleep_and_print(sleep_time, print_statement):
    time.sleep(sleep_time)
    print(print_statement)

wg = WorkGraph('First WG')

wg.add_task(sleep_and_print)
wg.to_html()

Congratulations, you just created your first `WorkGraph`! Let's unpack the code: We first created a very simple Python
function, we then instantiated the `WorkGraph`, and added our function as a task (remember, think of *graph nodes*).

`aiida-workgraph` comes with a visualization tool in which we can see the setup of our workflow. Note that we didn't
actually run it at this point, yet. Let's add a second task:

In [None]:
wg.add_task(sleep_and_print)
wg.to_html()

We can see that we now have two disconnected tasks in our workgraph. To define dependencies between those, we can either
link inputs and outputs, just as we did before with `aiida-shell`, or explicitly enforce that the second task has to wait
on the first one. For now, let's focus on the second case (the first one will require us to introduce a few
more concepts):

In [None]:
wg.tasks.sleep_and_print2.waiting_on.add('sleep_and_print1')
wg.to_html()

In the cell above, we accessed the second task through our `WorkGraph` instance, `wg`. However, the `add_task` function
actually returns the task, so we can also write:

In [None]:
wg = WorkGraph('First WG')

task1 = wg.add_task(sleep_and_print)
task2 = wg.add_task(sleep_and_print)

task2.waiting_on.add('sleep_and_print1')
wg.to_html()

which achieves the same.
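
To build some intuition for what `waiting_on` expresses, the dependency can be written down as a plain mapping and sorted topologically. The snippet below is only a sketch using Python's standard-library `graphlib`; it is not how `aiida-workgraph` schedules tasks internally, and the `waiting_on` dictionary is an illustrative stand-in for the task attributes used above.

```python
from graphlib import TopologicalSorter

# Each key is a task name; each value is the set of task names it waits on.
# This mirrors `task2.waiting_on.add('sleep_and_print1')` from the cell above.
waiting_on = {
    "sleep_and_print1": set(),
    "sleep_and_print2": {"sleep_and_print1"},
}

# A topological sort yields an order in which every task comes after
# all the tasks it waits on.
execution_order = list(TopologicalSorter(waiting_on).static_order())
print(execution_order)  # ['sleep_and_print1', 'sleep_and_print2']
```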

## Running Python code with WorkGraph and AiiDA provenance


If we want to actually run our workflow, we should specify some inputs to our tasks. We can do that, as well as name our
tasks like so:

In [None]:
wg = WorkGraph("Run WG")

task_without_provenance = wg.add_task(
    sleep_and_print, name="lets_start",
    sleep_time=1,
    print_statement="Let's start"
)

display(wg.to_html())
wg.run()

Seems like everything worked out smoothly. Now, let's show the provenance graph of our workflow:

In [None]:
provenance_graph(aiida_node=wg)

But where are our tasks? :anguished:

It is important to note here that AiiDA does not store the plain Python function we used to define our task, nor its
plain Python inputs, in its database. Remember, only the AiiDA classes derived from `Node` implement this functionality,
so AiiDA doesn't know how to store plain Python data in the database. Thankfully, we can easily resolve this issue by
passing AiiDA `orm.Data` types as inputs to the task and accessing their actual `value`s inside the function:

In [None]:
# example for aiida.orm.Data types
print(orm.Int(1))
print(orm.Int(1).value) # get your int value back

In [None]:
def sleep_and_print_with_provenance(sleep_time, print_statement):
    time.sleep(sleep_time.value)
    print(print_statement.value)

wg = WorkGraph("Provenance restored")

task_with_provenance = wg.add_task(
    sleep_and_print_with_provenance, name="lets_start",
    sleep_time=orm.Int(1), # <-- Note this change
    print_statement=orm.Str("Let's start") # Note this change
)

display(wg.to_html())
wg.run()

In [None]:
provenance_graph(aiida_node=wg)

## On creating, returning, and linking data

Now, if we would like to specify data dependencies, we should define a `task.calcfunction` that actually
returns some output so that we can then link it as an input to another task (before, we were only printing).

The function in the next cell achieves just that. Here, we have manually specified our `outputs` in the decorator, and
we return a clone of the `print_statement`, as returning the actual data node would create a cycle in the graph, which
is forbidden:

In [None]:
@task.calcfunction(
    outputs=[
        {'name': 'result', 'identifier': orm.Str}
    ]
)
def sleep_and_return(sleep_time: orm.Int, print_statement):
    time.sleep(sleep_time.value)
    return {'result': orm.Str(print_statement.value)}


wg = WorkGraph("Linked data")

another_task_with_provenance = wg.add_task(
    sleep_and_return, name="actual_print_task",
    sleep_time=orm.Int(1),
    print_statement=orm.Str("I will print the previous return")
)

display(wg.to_html())
wg.run()
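
The "no cycles" rule mentioned above can be illustrated without AiiDA. In the sketch below, the provenance is modelled as a plain mapping from each node to its predecessors, and the standard-library `graphlib` is used to detect cycles. The node labels are just strings chosen for illustration, and this is not how AiiDA itself validates links.

```python
from graphlib import CycleError, TopologicalSorter


def is_acyclic(predecessors):
    """Return True if the graph (node -> set of predecessors) has no cycle."""
    try:
        TopologicalSorter(predecessors).prepare()
    except CycleError:
        return False
    return True


# Returning a new node: input -> process -> fresh output (no cycle).
with_clone = {
    "sleep_and_return": {"print_statement"},
    "result": {"sleep_and_return"},
}

# Returning the input itself: the same node is both input and output.
with_same_node = {
    "sleep_and_return": {"print_statement"},
    "print_statement": {"sleep_and_return"},
}

print(is_acyclic(with_clone))      # True
print(is_acyclic(with_same_node))  # False
```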

By default, WorkGraph does not show output sockets that are not linked to other tasks, but we can see them when plotting the task directly.

In [None]:
another_task_with_provenance.to_html()

We can see a number of other sockets that WorkGraph uses in the background. Note that WorkGraph always uses _result_ as the default output socket for the return value of the function if nothing else is specified. Now let's look at the provenance graph.

In [None]:
provenance_graph(aiida_node=wg)

We can see that the output is now also part of the provenance graph. With this, we can define (almost) arbitrarily complex workflows, as shown below. Feel free to play around with this!

In [None]:
wg = WorkGraph("Arbitrary WorkGraph")

task1 = wg.add_task(
    sleep_and_print_with_provenance, name="lets_start", sleep_time=orm.Int(1), print_statement=orm.Str("Let's start")
)

task2 = wg.add_task(
    sleep_and_print_with_provenance,
    name="lets_continue",
    sleep_time=orm.Int(1),
    print_statement=orm.Str("Let's continue"),
)

task2.waiting_on.add("lets_start")

task3 = wg.add_task(
    sleep_and_print_with_provenance,
    name="wait_both",
    sleep_time=orm.Int(1),
    print_statement=orm.Str("I need to wait for both"),
)

task3.waiting_on.add("lets_start")
task3.waiting_on.add("lets_continue")

disconnected_task = wg.add_task(
    sleep_and_print_with_provenance,
    name="disconnected_task",
    sleep_time=orm.Int(5),
    print_statement=orm.Str("I have no dependencies, but I am one, and I take my time."),
)

task4 = wg.add_task(
    sleep_and_return,
    name="intermediate_step",
    sleep_time=orm.Int(1),
    print_statement=orm.Str("I will print the previous return."),
)

task4.waiting_on.add("disconnected_task")
task4.waiting_on.add("wait_both")

task5 = wg.add_task(
    sleep_and_print_with_provenance,
    name="final_step",
    sleep_time=orm.Int(1),
    print_statement=task4.outputs["result"],
)

display(wg.to_html())
wg.run()

## Closing the circle: Back to the `aiida-shell` example

Now that we have seen how we can construct simple workflows and define task dependencies with the `WorkGraph`, let's use
it to implement the workflow from the previous notebook. The code snippets in the following cells are rather lengthy,
however, the way we execute the external executable is the same as before, just that we now add a `ShellJob` `task` to
the `WorkGraph` (passing the same arguments as before):

In [None]:
wg = WorkGraph("query_and_diag")

matrix_pk = 5
query_output_filename = f"matrix-{matrix_pk}.npy"

query_task = wg.add_task(
    "ShellJob",
    name="query_task",
    command=query_code,
    arguments=["{db_path}", "{matrix_pk}"],
    nodes={"db_path": db_path, "matrix_pk": orm.Int(matrix_pk)},
    outputs=[query_output_filename],
)

# The file name automatically gets converted into an AiiDA link label by `aiida-shell`.
# Link labels can only contain alphanumeric characters and underscores, so we apply the same cleaning
# to the filename to be able to reference it later on.
query_task_link_label = ShellParser.format_link_label(query_output_filename)
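
To see what this cleaning amounts to, here is a small standalone sketch based on regular expressions. The function name `format_link_label_sketch` is hypothetical, and the code only approximates the behaviour described in the comment above; the actual implementation in `aiida-shell` may differ in detail.

```python
import re


def format_link_label_sketch(filename: str) -> str:
    """Approximate the cleaning: keep letters, digits and underscores only."""
    # Replace every run of disallowed characters with a single underscore
    label = re.sub(r"[^0-9a-zA-Z_]+", "_", filename)
    # Collapse repeated underscores into one
    return re.sub(r"_+", "_", label)


print(format_link_label_sketch("matrix-5.npy"))  # matrix_5_npy
```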

### Attaching a parser

Now that we have added the query task as before, the next step is the diagonalization. However, we might not only want to
write the eigenvalues to an output file, but also parse them, e.g. so that the resulting array is stored
**explicitly** in AiiDA's database (rather than just a reference to the file), and so that we can further operate on it
directly in our Python code. To achieve that, we define a parser function:

In [None]:
def parse_array(self, dirpath: Path) -> dict[str, orm.Data]:
    arr = np.loadtxt(dirpath / self.node.inputs.outputs[0])
    data = orm.ArrayData(arr)
    return {"eigvals": data}

We can now pass this parser to our diagonalization task via:

In [None]:
diag_output_filename = f"matrix-{matrix_pk}-eigvals.txt"

diag_task = wg.add_task(
    "ShellJob",
    name="diag_task",
    command=diag_code,
    arguments=["{matrix_file}"],
    nodes={"matrix_file": query_task.outputs[query_task_link_label]},
    outputs=[diag_output_filename],
    # Attach parser here
    parser=parse_array,
    parser_outputs=[{"name": "eigvals"}],
)

display(wg.to_html())
wg.run()

This now allows us to retrieve the eigenvalue outputs directly from the associated AiiDA `Node` attached to the
`WorkGraph` `Task`:

In [None]:
print(diag_task.outputs["eigvals"])
print(diag_task.outputs["eigvals"].value)
print(diag_task.outputs["eigvals"].value.get_array())
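
What the attached parser does can be mimicked with the standard library alone. The sketch below reads a text file with one number per line from a directory and returns the values under the name `eigvals`. The function `parse_floats` and the file content are made up for illustration; the real parser above uses `np.loadtxt` and wraps the result in an `orm.ArrayData`.

```python
import tempfile
from pathlib import Path


def parse_floats(dirpath: Path, filename: str) -> dict[str, list[float]]:
    """Read whitespace-separated numbers from a file and return them by name."""
    text = (dirpath / filename).read_text()
    return {"eigvals": [float(token) for token in text.split()]}


with tempfile.TemporaryDirectory() as tmp:
    dirpath = Path(tmp)
    # Made-up file content, standing in for the output of the diagonalization
    (dirpath / "matrix-5-eigvals.txt").write_text("1.0\n2.5\n4.0\n")
    parsed = parse_floats(dirpath, "matrix-5-eigvals.txt")

print(parsed)  # {'eigvals': [1.0, 2.5, 4.0]}
```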

Let's go back one more time to look at the provenance, now that we are more familiar with the concept.

In [None]:
provenance_graph(wg)

### Extending WorkGraph with arbitrary Python code

As we have seen in the simple examples in the beginning of this notebook, we can set up tasks using any Python code.
This is part of what makes AiiDA workflows so powerful. You can do literally anything!

Using Python code as steps of your workflow is also the native way of defining a workflow in AiiDA by writing
`WorkChain`s, so it is not a feature unique to the `WorkGraph`. However, the `WorkGraph` provides a simplified interface
for defining workflows.

Let's instantiate a new empty `WorkGraph` and add our previous task, as we did before:

In [None]:
wg = WorkGraph("compute_eigvals_wg")
matrix_pk = 5
query_output_filename = f"matrix-{matrix_pk}.npy"

query_task = wg.add_task(
    "ShellJob",
    name="query_task",
    command=query_code,
    arguments=["{db_path}", "{matrix_pk}"],
    nodes={"db_path": db_path, "matrix_pk": orm.Int(matrix_pk)},
    outputs=[query_output_filename],
)

query_task_link_label = ShellParser.format_link_label(query_output_filename)
diag_output_filename = f"matrix-{matrix_pk}-eigvals.txt"

diag_task = wg.add_task(
    "ShellJob",
    name="diag_task",
    command=diag_code,
    arguments=["{matrix_file}"],
    parser=parse_array,
    nodes={"matrix_file": query_task.outputs[query_task_link_label]},
    outputs=[diag_output_filename],
    parser_outputs=[{"name": "eigvals"}],
)

We now define a `calcfunction` to calculate the mean of the eigenvalues and add it to our `WorkGraph`.

(Remember, a `calcfunction` is an AiiDA process that uses ORM data types, and thus is stored in the database and the
provenance graph)

In [None]:
@task.calcfunction
def compute_mean(eigenvalues: orm.ArrayData):
    return {"result": orm.Float(np.mean(eigenvalues.get_array()))}


mean_task = wg.add_task(
    compute_mean, name="mean_task", eigenvalues=diag_task.outputs["eigvals"]
)
display(wg.to_html())
wg.run()

As before, passing the undecorated `compute_mean` Python function would, in principle, also work, but no
provenance would be recorded. This is still allowed, as one might want to execute a step in the workflow that should not
be recorded in the provenance.

WorkGraph uses output sockets that store the resulting node, so we can retrieve the `orm.Float` by taking the `value` of the socket:

In [None]:
print(wg.tasks["mean_task"].outputs["result"]) # output socket result
print(wg.tasks["mean_task"].outputs["result"].value) # resulting orm.Float
print(wg.tasks["mean_task"].outputs["result"].value.value) # resulting orm.Float value

## Combining tasks with the `graph_builder`

As we have seen above, when generating multiple workgraphs with the same steps (e.g. query and diagonalization), we
always need to repeat the code used to add the tasks when we create new `WorkGraph`s. This is quite cumbersome
and will lead to unwanted code repetition. For this purpose, the `aiida-workgraph` provides the `graph_builder`, which
allows one to merge together multiple tasks into one entity, thus enabling the creation of complex, nested `WorkGraph`s. 

The following cell combines the code from the querying, diagonalization, and calculation of the mean into one reusable
`query_diag_mean` entity:

In [None]:
@task.calcfunction
def compute_mean(eigenvalues: orm.ArrayData) -> dict[str, orm.Data]:
    eigenvalues_arr = eigenvalues.get_array()
    node = orm.Float(np.mean(eigenvalues_arr))
    node.attributes["length"] = len(
        eigenvalues_arr
    )  # Note this change, we will discuss this later
    return node


@task.graph_builder(
    outputs=[
        {
            "name": "eigvals",
            "from": "diag_task.eigvals",
        },  # exposes output `eigvals` of task diag_task under the name `eigvals`
        {
            "name": "mean_eigval",
            "from": "mean_task.result",
        },  # exposes output `result` of task mean_task under the name `mean_eigval`
    ]
)
def query_diag_mean(matrix_pk: orm.Int):
    global db_path
    wg = WorkGraph()
    query_output_filename = f"matrix-{matrix_pk.value}.npy"

    query_code = orm.load_code("remote_query@localhost")
    query_task = wg.add_task(
        "ShellJob",
        name="query_task",
        command=query_code,
        arguments=["{db_path}", "{matrix_pk}"],
        nodes={
            "db_path": db_path,
            "matrix_pk": matrix_pk,
        },
        outputs=[query_output_filename],
    )

    query_task_link_label = ShellParser.format_link_label(query_output_filename)
    diag_code = orm.load_code("diagonalization@localhost")
    diag_output_filename = f"matrix-{matrix_pk.value}-eigvals.txt"

    diag_task = wg.add_task(
        "ShellJob",
        name="diag_task",
        command=diag_code,
        arguments=["{matrix_file}"],
        nodes={"matrix_file": query_task.outputs[query_task_link_label]},
        outputs=[diag_output_filename],
        parser=parse_array,
        parser_outputs=[{"name": "eigvals"}],
    )

    wg.add_task(
        compute_mean, name="mean_task", eigenvalues=diag_task.outputs["eigvals"]
    )

    return wg


wg = WorkGraph()
# Here we add the collection of tasks defined previously via the `graph_builder` as one single task
query_diag_mean_task = wg.add_task(query_diag_mean, name="query_diag_mean")
wg.to_html()

We can see that the three tasks are now encapsulated into one step. When we have a look at the task, we can see that the specified outputs are also exposed.

In [None]:
query_diag_mean_task.to_html()
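
The `outputs` specification of the `graph_builder` can be read as a simple re-labelling of inner task outputs. The following plain-Python sketch resolves `"task.output"` references against a nested dictionary; the `expose` helper and the values in `inner_outputs` are made up for illustration and do not reflect how `aiida-workgraph` implements this.

```python
# Outputs produced by the inner tasks (made-up values for illustration)
inner_outputs = {
    "diag_task": {"eigvals": [1.0, 2.0, 3.0]},
    "mean_task": {"result": 2.0},
}

# Same structure as the `outputs` argument of the decorator above
exposed_spec = [
    {"name": "eigvals", "from": "diag_task.eigvals"},
    {"name": "mean_eigval", "from": "mean_task.result"},
]


def expose(spec, outputs):
    """Resolve each 'task.output' reference and re-label it with 'name'."""
    exposed = {}
    for entry in spec:
        task_name, output_name = entry["from"].split(".", 1)
        exposed[entry["name"]] = outputs[task_name][output_name]
    return exposed


print(expose(exposed_spec, inner_outputs))
# {'eigvals': [1.0, 2.0, 3.0], 'mean_eigval': 2.0}
```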

We can now reuse this task in a for loop, running it for several matrices.

In [None]:
wg = WorkGraph("query_diag_mean_wg")
for i in range(5):
    query_diag_mean_task = wg.add_task(
        query_diag_mean, name=f"query_diag_mean_pk{i}", matrix_pk=orm.Int(i)
    )
display(wg.to_html())
wg.run()

While the computation is running, you might want to have a look at `watch -n 1 "verdi process list"` in another terminal
or notebook to see the calculations running.

## Aggregating results

We have now seen how we can use the `graph_builder` to generate reusable multi-step `WorkGraph` components. Assume we
would now like to aggregate all results from running this component various times, and e.g. create a plot from the results:

In [None]:
@task.calcfunction
def aggregate_to_plot(
    **collected_mean_eigvals: dict[str, orm.Float]
) -> dict[str, orm.Data]:
    fig, ax = plt.subplots(figsize=(8, 6))
    collected_mean_eigenvalues_list = [
        orm_float.value for orm_float in collected_mean_eigvals.values()
    ]
    ax.hist(collected_mean_eigenvalues_list, bins=10, color="c", edgecolor="black")
    ax.set_title("Histogram of mean eigenvalues")
    ax.set_xlabel("Mean eigenvalue")
    filename = "plot.jpg"
    plt.savefig(filename)
    plt.close(fig)
    return orm.SinglefileData(Path(filename).absolute())

In [None]:
wg = WorkGraph("plot_mean_eigenvalues")

aggregate_to_plot_task = wg.add_task(aggregate_to_plot, name="aggregate_to_plot_task")

# we have to increase the link limit because by default workgraph only supports one link per input socket
max_query_pk = 3
aggregate_to_plot_task.inputs["collected_mean_eigvals"].link_limit = max_query_pk

for i in range(max_query_pk):
    query_diag_mean_task = wg.add_task(
        query_diag_mean, name=f"query_diag_mean_pk{i}", matrix_pk=orm.Int(i)
    )
    # We create a `link` between the output of the `query_diag_mean` task and the input of the aggregation and plot task
    wg.add_link(
        query_diag_mean_task.outputs["mean_eigval"],
        aggregate_to_plot_task.inputs["collected_mean_eigvals"],
    )
display(wg.to_html())
wg.run()
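
The `**collected_mean_eigvals` signature means that every linked output arrives as one keyword argument of the aggregation function. The plain-Python analogue below shows this mechanism in isolation. The keys and values in `linked_outputs` are made up, and the key names that `aiida-workgraph` actually generates are an assumption here.

```python
def aggregate(**collected_mean_eigvals: float) -> list[float]:
    """Collect all keyword arguments into a list, ordered by key."""
    return [collected_mean_eigvals[key] for key in sorted(collected_mean_eigvals)]


# Hypothetical keys and made-up values: one entry per linked task output
linked_outputs = {
    "query_diag_mean_pk0": 14.1,
    "query_diag_mean_pk1": 15.3,
    "query_diag_mean_pk2": 13.8,
}

values = aggregate(**linked_outputs)
print(values)  # [14.1, 15.3, 13.8]
```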

### Interlude: QueryBuilder with node attributes

A little detail you might have spotted: When defining `compute_mean` for the `query_diag_mean` `graph_builder`, we
actually added a `length` attribute to the `orm.Float` node that holds the mean eigenvalue.

```python
@task.calcfunction
def compute_mean(eigenvalues: orm.ArrayData) -> dict[str, orm.Data]:
    eigenvalues_arr = eigenvalues.get_array()
    node = orm.Float(np.mean(eigenvalues_arr))
    node.attributes["length"] = len(eigenvalues_arr) # Note this change, we will discuss this later
    return node
```

We will now show how to query for this attribute after having collected several calculations. As we have seen in the
previous notebook, a regular query for a specific data type can be constructed like this:

In [None]:
qb = orm.QueryBuilder()

qb.append(orm.Float)

print("Number of entries: ", len(qb.all()))
print("First entry: ", qb.first())

By default, we get a list of lists because we can `project` different properties of the object for retrieval:

In [None]:
qb = orm.QueryBuilder()
qb.append(
    orm.Float,
    project=['uuid', 'attributes.value', 'attributes.length']
)
print("Number of entries: ", len(qb.all()))
print("First entry value and length: ", qb.first())

Since we added the `length` attribute in the previous calculations, we can now filter for certain lengths:

In [None]:
qb = orm.QueryBuilder()
qb.append(
    orm.Float,
    filters=orm.Float.fields.attributes["length"].in_(
        [49, 50, 51]
    ),  # This is the attribute we have set in the `compute_mean` calcfunction
    project=["attributes.value", "attributes.length"],
)
print("Number of entries: ", len(qb.all()))
print("First entry value and length: ", qb.first())

We can also combine filters with logical operators, e.g. to additionally select only values below a specific threshold:

In [None]:
qb = orm.QueryBuilder()
qb.append(
    orm.Float,
    filters=(
        (orm.Float.fields.attributes["length"].in_([49, 50, 51]))
        & (orm.Float.fields.value < 14.5)
    ),
    project=["attributes.value", "attributes.length"],
)
print("Number of entries: ", len(qb.all()))
print("First entry value and length: ", qb.first())

This concludes our little interlude on AiiDA's `QueryBuilder`. We have introduced some concepts that are necessary
for the next and final section of this tutorial notebook. Just keep in mind that the `QueryBuilder` is an extremely
powerful tool that lets you construct (almost) arbitrarily complex queries on your database. More information can be
found in the documentation, e.g.
[here](https://aiida.readthedocs.io/projects/aiida-core/en/latest/howto/query.html#how-to-find-and-query-for-data).
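
To make the semantics of the combined filter concrete, here is a plain-Python analogue that operates on a small list of dictionaries instead of the database. The records are made up for illustration; the `QueryBuilder` evaluates the equivalent condition inside the database rather than in Python.

```python
# Made-up records standing in for stored `orm.Float` nodes
records = [
    {"value": 13.9, "length": 50},
    {"value": 15.2, "length": 50},
    {"value": 14.0, "length": 80},
]

# Equivalent of `length in [49, 50, 51]` AND `value < 14.5`
matches = [
    (record["value"], record["length"])
    for record in records
    if record["length"] in (49, 50, 51) and record["value"] < 14.5
]

print("Number of entries: ", len(matches))  # 1
print("First entry value and length: ", matches[0])  # (13.9, 50)
```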

## How can we incorporate if conditions in `WorkGraph` workflows?

An if condition changes which tasks are executed and which outputs are passed on to the subsequent tasks, and
therefore needs additional logic in the workflow manager to be handled properly. We take an example from
materials science, where we are often interested in structures that correspond to very low eigenvalues, as these
structures are more stable (there are more subtleties that we ignore for the sake of simplicity). Let us select the
matrices with a mean eigenvalue below a threshold of 14.5 and incorporate this selection into the workflow.

In [None]:
# This is the condition task that will be used in the if task
@task.calcfunction
def eigvals_less(mean_eigval: orm.Float) -> bool:
    return mean_eigval < 14.5

# When we have found a suitable candidate, we can celebrate
@task.calcfunction
def heureka(eigvals, pk):
    print("Heureka, we found a new stable material, let's publish in Nature!")

In [None]:
wg = WorkGraph("matrix_discovery")

for matrix_pk in [1, 5]:
    query_diag_mean_task = wg.add_task(
        query_diag_mean,
        name=f"query_diag_mean_pk{matrix_pk}",
        matrix_pk=orm.Int(matrix_pk),
    )
    eigvals_less_task = wg.add_task(
        eigvals_less,
        name=f"eigvals_less_task_pk{matrix_pk}",
        mean_eigval=query_diag_mean_task.outputs["mean_eigval"],
    )
    if_less = wg.add_task(
        "If",  # Note that this is an identifier that marks this to be an If task
        name=f"if_less_pk{matrix_pk}",
        conditions=eigvals_less_task.outputs["result"],  # An If task has this attribute
    )

    heureka_task = wg.add_task(
        heureka,
        name=f"heureka_task_pk{matrix_pk}",
        eigvals=query_diag_mean_task.outputs["eigvals"],
        pk=orm.Int(matrix_pk),
    )
    if_less.children.add(
        f"heureka_task_pk{matrix_pk}"
    )  # this adds the task to the if condition

    # To create a task for the case where the condition is false, one can use `invert_condition=True`
    if_greater_equal = wg.add_task(
        "If",
        name=f"if_greater_equal_pk{matrix_pk}",
        conditions=eigvals_less_task.outputs["result"],
        invert_condition=True
    )

display(wg.to_html())
wg.run()

We can see that for `PK<1>` the if condition was fulfilled, while for `PK<5>` it was not, so the heureka task was not executed.

In [None]:
for matrix_pk in [1, 5]:
    print(
        f"Mean eigval for matrix PK <{matrix_pk}>: ",
        wg.tasks[f"query_diag_mean_pk{matrix_pk}"].outputs["mean_eigval"].value.value,
    )
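
Stripped of all AiiDA machinery, the control flow of the workflow above corresponds to an ordinary `if`/`else`. The sketch below uses made-up mean eigenvalues, and the label `no_candidate_pk...` is hypothetical (in the workflow above, the inverted `If` task has no children). It only illustrates which branch is taken for which matrix.

```python
THRESHOLD = 14.5

# Made-up mean eigenvalues, keyed by matrix PK
mean_eigvals = {1: 14.2, 5: 15.1}

executed = []
for matrix_pk, mean_eigval in mean_eigvals.items():
    condition = mean_eigval < THRESHOLD  # role of `eigvals_less`
    if condition:  # role of the `If` task
        executed.append(f"heureka_task_pk{matrix_pk}")
    else:  # role of the `If` task with `invert_condition=True`
        executed.append(f"no_candidate_pk{matrix_pk}")

print(executed)  # ['heureka_task_pk1', 'no_candidate_pk5']
```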

It should be mentioned that, while the `WorkGraph` provides a simplified interface for writing AiiDA workflows, for very
complex setups, writing a `WorkChain` class with arbitrary Python code might actually be the preferred approach.
Always evaluate your requirements!

If you've made it this far, congratulations! You can now call yourself a `WorkGraph` expert.

Just like we did for the hero run, you have thus acquired the power to fill your entire HPC cluster with AiiDA workflows
that keep full provenance of your data (but please don't ;)