# How to create more complex workflows

To run the following Python cells, we need to make sure that we select the correct kernel `Python3.10 (AIIDA)`. If it is
not already selected, do so as follows:

<img src="../../data/figs/change_notebook_kernel.png" width="500" style="height:auto; display:block; margin-left:auto; margin-right:auto;">

In [5]:
%load_ext aiida
%aiida

And verify that the profile was created successfully via:

In [7]:
%verdi status

[32m[22m ✔ [0m[22mversion:     AiiDA v2.6.2[0m
[32m[22m ✔ [0m[22mconfig:      /Users/alexgo/.aiida[0m
[32m[22m ✔ [0m[22mprofile:     presto-3[0m
[32m[22m ✔ [0m[22mstorage:     SqliteDosStorage[/Users/alexgo/.aiida/repository/sqlite_dos_a131f6ed7221480fae581f300190e67b]: open,[0m
[32m[22m ✔ [0m[22mbroker:      RabbitMQ v3.13.6 @ amqp://guest:guest@127.0.0.1:5672?heartbeat=600[0m
[32m[22m ✔ [0m[22mdaemon:      Daemon is running with PID 26513[0m


  "cipher": algorithms.TripleDES,
  "class": algorithms.TripleDES,


***
## Concatenating several scripts to one workflow and more :)

In [8]:
from aiida_workgraph import WorkGraph, task
from aiida_shell.parsers import ShellParser
import pathlib
from aiida import orm
import pathlib
from pathlib import Path
import numpy as np
import matplotlib.pyplot as plt
from aiida import orm, engine
from aiida.common.exceptions import NotExistent

In [13]:
diag_code = orm.load_code('diagonalization@localhost')  # The computer label can also be omitted here
query_code = orm.load_code('remote_query@localhost')  # The computer label can also be omitted here
db_path = str(Path('../../data/euro-scipy-2024/diag-wf/remote/matrices.db').resolve())

In [16]:
from aiida_workgraph import WorkGraph
from aiida.orm import SinglefileData
from aiida_shell.parsers import ShellParser


wg = WorkGraph("query_and_diag")

matrix_pk = 5
query_output_filename = f"matrix-{matrix_pk}.npy"
query_task = wg.add_task("ShellJob", name="query_task",
                    command=query_code,
                    arguments=["{db_path}", "{matrix_pk}"],
                    nodes={
                        'db_path': db_path,
                        'matrix_pk': orm.Int(matrix_pk)},
                    outputs=[query_output_filename]
                )

query_task_link_label = ShellParser.format_link_label(query_output_filename)
diag_output_filename = f"matrix-{matrix_pk}-eigvals.txt"

def parse_array(self, dirpath: pathlib.Path) -> dict[str, orm.Data]:
    arr = np.loadtxt(dirpath / self.node.inputs.outputs[0])
    data = orm.ArrayData(arr)
    return {'eigvals': data}

diag_task = wg.add_task("ShellJob", name="diag_task",
                    command=diag_code,
                    arguments=["{matrix_file}"],
                    parser=parse_array,
                    nodes={
                        'matrix_file': query_task.outputs[query_task_link_label]
                    },
                    outputs=[diag_output_filename],
                    parser_outputs=[{"name": "eigvals"}],
                )
diag_task_link_label = ShellParser.format_link_label(diag_output_filename)

display(wg)
wg.run()

NodeGraphWidget(settings={'minimap': True}, style={'width': '90%', 'height': '600px'}, value={'name': 'query_a…

update task state:  query_task
update task state:  diag_task
Continue workgraph.
[34m[1mReport[0m: [6242|WorkGraphEngine|continue_workgraph]: Continue workgraph.


REPORT:aiida.orm.nodes.process.workflow.workchain.WorkChainNode:[6242|WorkGraphEngine|continue_workgraph]: Continue workgraph.


[34m[1mReport[0m: [6242|WorkGraphEngine|continue_workgraph]: tasks ready to run: query_task


REPORT:aiida.orm.nodes.process.workflow.workchain.WorkChainNode:[6242|WorkGraphEngine|continue_workgraph]: tasks ready to run: query_task


------------------------------------------------------------
[34m[1mReport[0m: [6242|WorkGraphEngine|run_tasks]: Run task: query_task, type: SHELLJOB


REPORT:aiida.orm.nodes.process.workflow.workchain.WorkChainNode:[6242|WorkGraphEngine|run_tasks]: Run task: query_task, type: SHELLJOB


Task  type: ShellJob.
task:  query_task RUNNING
task:  diag_task PLANNED
is workgraph finished:  False
[34m[1mReport[0m: [6242|WorkGraphEngine|on_wait]: Process status: Waiting for child processes: 6247


REPORT:aiida.orm.nodes.process.workflow.workchain.WorkChainNode:[6242|WorkGraphEngine|on_wait]: Process status: Waiting for child processes: 6247


Continue workgraph.
[34m[1mReport[0m: [6242|WorkGraphEngine|continue_workgraph]: Continue workgraph.


REPORT:aiida.orm.nodes.process.workflow.workchain.WorkChainNode:[6242|WorkGraphEngine|continue_workgraph]: Continue workgraph.


[34m[1mReport[0m: [6242|WorkGraphEngine|continue_workgraph]: tasks ready to run: diag_task


REPORT:aiida.orm.nodes.process.workflow.workchain.WorkChainNode:[6242|WorkGraphEngine|continue_workgraph]: tasks ready to run: diag_task


------------------------------------------------------------
[34m[1mReport[0m: [6242|WorkGraphEngine|run_tasks]: Run task: diag_task, type: SHELLJOB


REPORT:aiida.orm.nodes.process.workflow.workchain.WorkChainNode:[6242|WorkGraphEngine|run_tasks]: Run task: diag_task, type: SHELLJOB


Task  type: ShellJob.
task:  query_task FINISHED
task:  diag_task RUNNING
is workgraph finished:  False
[34m[1mReport[0m: [6242|WorkGraphEngine|on_wait]: Process status: Waiting for child processes: 6257


REPORT:aiida.orm.nodes.process.workflow.workchain.WorkChainNode:[6242|WorkGraphEngine|on_wait]: Process status: Waiting for child processes: 6257






Continue workgraph.
[34m[1mReport[0m: [6242|WorkGraphEngine|continue_workgraph]: Continue workgraph.


REPORT:aiida.orm.nodes.process.workflow.workchain.WorkChainNode:[6242|WorkGraphEngine|continue_workgraph]: Continue workgraph.


[34m[1mReport[0m: [6242|WorkGraphEngine|continue_workgraph]: tasks ready to run: 


REPORT:aiida.orm.nodes.process.workflow.workchain.WorkChainNode:[6242|WorkGraphEngine|continue_workgraph]: tasks ready to run: 


task:  query_task FINISHED
task:  diag_task FAILED
is workgraph finished:  True
[34m[1mReport[0m: [6242|WorkGraphEngine|is_workgraph_finished]: WorkGraph finished, but tasks: ['diag_task'] failed. Thus all their child tasks are skipped.


REPORT:aiida.orm.nodes.process.workflow.workchain.WorkChainNode:[6242|WorkGraphEngine|is_workgraph_finished]: WorkGraph finished, but tasks: ['diag_task'] failed. Thus all their child tasks are skipped.


{}

### This how you retrieve your outputs after a run with WorkGraph

In [134]:
# TODO
diag_task.node.outputs['aiida_shell_5_eigvals_txt']
diag_task.node.outputs.eigvals.get_array()

array([6.41356973e+02, 4.45374108e+00, 4.45374108e+00, 4.33318476e+00,
       3.97143598e+00, 3.97143598e+00, 3.20457100e+00, 3.20457100e+00,
       3.83870167e+00, 3.83870167e+00, 3.65871409e+00, 3.65871409e+00,
       2.89638174e+00, 2.89638174e+00, 3.44866185e+00, 3.44866185e+00,
       2.99916314e+00, 2.99916314e+00, 2.80708274e+00, 2.80708274e+00,
       2.61687224e+00, 2.61687224e+00, 2.41524706e+00, 2.41524706e+00,
       2.02885708e+00, 2.02885708e+00, 1.66872021e+00, 1.66872021e+00,
       1.59257569e+00, 1.59257569e+00, 1.43553627e+00, 1.43553627e+00,
       1.47411659e+00, 1.47411659e+00, 1.09391969e+00, 1.09391969e+00,
       8.50524615e-01, 8.50524615e-01, 7.73023747e-01, 7.73023747e-01,
       4.61000901e-01, 4.61000901e-01, 2.44442245e-01, 4.65158372e-01,
       4.65158372e-01, 8.12587183e-02, 5.74112503e-03, 1.40393919e-01,
       5.43328987e-02, 9.88294716e-02])

## Extending WorkGraph with arbitrary python code

In [32]:
import matplotlib.pyplot as plt
from pathlib import Path

wg = WorkGraph("compute_eigvals_wg")
matrix_pk = 5
query_output_filename = f"matrix-{matrix_pk}.npy"
query_task = wg.add_task("ShellJob", name="query_task",
                    command=query_task,
                    arguments=["{db_path}", "{matrix_pk}"],
                    nodes={
                        'db_path': db_path,
                        'matrix_pk': orm.Int(matrix_pk)},
                    outputs=[query_output_filename]
                )

query_task_link_label = ShellParser.format_link_label(query_output_filename)
diag_output_filename = f"matrix-{matrix_pk}-eigvals.txt"

def parse_array(self, dirpath: pathlib.Path) -> dict[str, orm.Data]:
    global diag_output_filename
    arr = np.loadtxt(dirpath / diag_output_filename)
    return {'eigvals': orm.ArrayData(arr)}

diag_task = wg.add_task("ShellJob", name="diag_task",
                    command=diag_task,
                    arguments=["{matrix_file}"],
                    parser=parse_array,
                    nodes={
                        'matrix_file': query_task.outputs[query_task_link_label]
                    },
                    outputs=[diag_output_filename],
                    parser_outputs=[{"name": "eigvals"}],
                )

diag_task_link_label = ShellParser.format_link_label(diag_output_filename)

# Why do you have to wrap your function? So aiida understands your function
# Try comment out the code and look at the provenance graph
@task.calcfunction
def compute_mean(eigenvalues: orm.ArrayData) -> dict[str, orm.Data]:
    return orm.Float(np.mean(eigenvalues.get_array()))

plot_task = wg.add_task(plot, name="plot_task", eigenvalues=diag_task.outputs["eigvals"])

wg.run()

Exception: "__deepcopy__" is not in the TaskCollection.
Acceptable names are ['query_task', 'diag_task', 'plot_task']. This collection belongs to NodeGraph(name="compute_eigvals_wg, uuid="50b35c7e-6147-11ef-80b4-86467fc6ae57")
.

#### We can see that the `compute_mean` result (the orm.Float) is not present in the provenance graph when we remove the calcfunction decorator because it is not stored in the database.

In [None]:
from aiida_workgraph.utils import generate_node_graph
generate_node_graph(wg.pk)

#### We can display the image in a similar way by 

In [None]:
import matplotlib.pyplot as plt
from IPython.display import Image, display

print(wg.tasks["plot_task"].outputs["result"].value) # SinglefileData
with wg.tasks["plot_task"].outputs["result"].value.as_path() as filepath:
    display(Image(filename=(filepath)))

In [None]:
def array_parser(self: Parser, dirpath: pathlib.Path) -> dict[str, orm.Data]:
    arr = np.loadtxt(dirpath / node.inputs.outputs[0]) # this is small aiida detail
    data = orm.ArrayData(arr)
    data.attributes['length'] = len(arr)
    return {'eigvals': data}

@task.graph_builder(outputs = [{"name": "eigvals", "from": "diag_task.eigvals"}
                               {"name": "mean_eigval", "from": "compute_mean.result"}])
def query_and_diag(matrix_pk: orm.Int):
    wg = WorkGraph()
    query_output_filename = f"matrix-{matrix_pk.value}.npy"
        
    query_code = orm.load_code(f'query@localhost')
    query_task = wg.add_task("ShellJob", name="query_task",
                        command=query_code,
                        arguments=["{db_path}", "{matrix_pk}"],
                        nodes={
                            'db_path': "/Users/alexgo/code/fair-workflows-workshop/data/euro-scipy-2024/diag-wf/remote/matrices.db",
                            'matrix_pk': matrix_pk},
                        outputs=[query_output_filename]
                    )
    query_task_link_label = ShellParser.format_link_label(query_output_filename)
    diag_output_filename = f"matrix-{matrix_pk.value}-eigvals.txt"

    diag_code = orm.load_code(f'diag@localhost')
    wg.add_task("ShellJob", name="diag_task",
                        command=diag_code,
                        arguments=["{matrix_file}"],
                        parser=array_parser,
                        nodes={
                            'matrix_file': query_task.outputs[query_task_link_label]
                        },
                        outputs=[diag_output_filename],
                        parser_outputs=[{"name": "eigvals"}],
                    )
    @task.calcfunction
    def compute_mean(eigenvalues: orm.ArrayData) -> dict[str, orm.Data]:
        node = orm.Float(np.mean(eigenvalues.get_array()))
        node.attributes["length"] = len(eigenvalues)
        return node

    # TODO add compute_mean

    return wg

wg = WorkGraph()
wg.add_task(query_and_diag)
wg

In [None]:
wg = WorkGraph("processing_data")
for i in range(1,5):
    query_and_diag_task = wg.add_task(query_and_diag, name=f"query_and_diag_pk{i}", matrix_pk=orm.Int(i))
display(wg)
wg.run()

### We want to collect all the results and plot them

In [None]:
# TODO you don't have to compute anymore the mean value because it is exposed by the graph_builder
@task.calcfunction
def assemble_plot(**collected_eigvals) -> dict[str, orm.Data]:
    #return orm.List([arr.get_array() for arr in x.values()])
    fig, ax = plt.subplots(figsize=(8, 6))
    label: str
    eigval_data: orm.ArrayData
    mean_eigenvalues = []
    for _, eigval_data in collected_eigvals.items():
        mean_eigenvalues.append(np.mean(eigval_data.get_array()))
    ax.hist(mean_eigenvalues, bins=10, color="c", edgecolor="black")
    ax.set_title("Histogram of Eigenvalues")
    ax.set_xlabel("Eigenvalue")
    filename = "plot.jpg"
    plt.legend()
    plt.savefig(filename)
    plt.close(fig)
    return orm.SinglefileData(Path(filename).absolute())


In [None]:
wg = WorkGraph("processing_data")
assemble_plot_task = wg.add_task(assemble_plot, name="assemble_plot_task")
# we have to increase the link limit because by default workgraph only supports one link per input socket
assemble_plot_task.inputs["collected_eigvals"].link_limit = 50
for i in range(1,10):
    query_and_diag_task = wg.add_task(query_and_diag, name=f"query_and_diag_pk{i}", matrix_pk=orm.Int(i))
    wg.add_link(query_and_diag_task.outputs["eigvals"], assemble_plot_task.inputs["collected_eigvals"])
display(wg)
wg.run()

## QueryBuilder

### We can query now from our results

In [None]:
# TODO QueryBuilder expand plot something
qb = QueryBuilder()
qb.append(
    orm.ArrayData,
    project=['attributes.array|default']
)
qb.all(flat=True)


### We can also introduce filters in our queriesS

In [None]:
# TODO QueryBuilder expand plot something
qb = QueryBuilder()
qb.append(
    orm.ArrayData,
    filters={
        'attributes.length': {'==': 50}
    },
    project=['attributes.array|default']
)
qb.all(flat=True)

## How can we create if conditions workflows?

In [None]:
@task.calcfunction
def compute_mean(eigvals: orm.ArrayData) -> orm.Float:
    return orm.Float(np.mean(eigvals.get_array()))

@task.calcfunction
def eigvals_less(mean_eigval: orm.Float) -> bool:
    return mean_eigval < 14.5

@task.calcfunction
def heureka(eigvals, pk):
    try:
        path = Path("storage").absolute()
        path.mkdir(exist_ok=True)
        result_path = path / f"eigvals-pk{pk}.npy"
        np.save(result_path, eigvals.get_array())
        success = orm.Int(0)
        success.attributes["path"] = str(result_path)
        success.attributes["error"] = ""
    except Exception as err:
        success = orm.Int(1)
        success.attributes["path"] = ""
        success.attributes["error"] = str(err)
    return success


In [None]:
@task.calcfunction
def eigvals_less(mean_eigval: orm.Float) -> bool:
    return mean_eigval < 14.5

wg = WorkGraph("someother")

martix_pk = 5

query_and_diag_task = wg.add_task(query_and_diag, name=f"query_and_diag_pk{matrix_pk}", matrix_pk=orm.Int(matrix_pk))
compute_mean_task = wg.add_task(compute_mean, name=f"compute_mean_pk{matrix_pk}", eigvals=query_and_diag_task.outputs["eigvals"])
eigvals_less_task = wg.add_task(eigvals_less, name=f"eigvals_less_task_pk{matrix_pk}", mean_eigval=compute_mean_task.outputs["result"])
if_less = wg.add_task("If", name=f"if_less_pk{matrix_pk}", conditions=eigvals_less_task.outputs["result"]) # there as specific conditions socket
heureka = wg.add_task(heureka_task, name=f"heureka_task_pk{matrix_pk}", eigvals=query_and_diag_task.outputs["eigvals"], pk=orm.Int(i))
if_less.children.add(f"heureka_task_pk{matrix_pk}")

display(wg)
wg.run()
