# Concept of the pipeline
## Basic ideas
- rubix essentially implements a big data transformation pipeline. 

- a pipeline is composed of nodes, kept in a list sorted by execution order (or, more generally, in a DAG, which is not supported currently). Each node is called a transformer. 

- each step in this pipeline (i.e., each transformer) can itself be seen as composed of other, smaller transformers. This gives us a pattern that can guide the implementation of transformers.

- a simple implementation lives in `rubix.pipeline`

## Restrictions
- jax is purely functional. Anything that is to be transformed with jax has to be a pure function. 
Anything that comes from the environment must be explicitly copied into the function or bound to it, such that the internal state of the function is self-contained. 

- It does not matter what builds these pure functions. Therefore, we use a factory pattern for all configuration work: reading files, pulling data from the net, providing the function arguments to be used in the pipeline, and so on. A factory then produces a pure function that holds all the configuration data as static arguments and keeps only the data it computes on as traceable arguments. 

- we can leverage [`jax.tree_util.Partial`](https://jax.readthedocs.io/en/latest/_autosummary/jax.tree_util.Partial.html) for this, which works like `functools.partial` but is compatible with jax transformations. Note that stateful objects can still be used internally as long as nothing from an outer scope (which may change over time) is read or written. This is the user's responsibility.
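A minimal sketch of this factory pattern follows. It assumes a plain dict stands in for whatever configuration source is actually read. `scale`, `make_scale` and `run` are hypothetical names, not part of rubix. Unlike a `functools.partial`, a `jax.tree_util.Partial` is a pytree, so it can be passed as an argument to a jitted function:

```python
import jax
import jax.numpy as jnp
from jax.tree_util import Partial


def scale(x, factor):
    # pure kernel: the result depends only on its inputs
    return x * factor


def make_scale(config: dict) -> Partial:
    # hypothetical factory: all impure configuration work (parsing, file or network I/O)
    # happens here; only the finished values are bound to the kernel
    factor = float(config["factor"])
    return Partial(scale, factor=factor)


@jax.jit
def run(kernel, x):
    # a Partial is a pytree, so it can cross the jit boundary like an array:
    # its bound values become traced leaves, its function is static metadata
    return kernel(x)


kernel = make_scale({"factor": 2.0})
out = run(kernel, jnp.array([1.0, 2.0, 3.0]))
```

`scale` is defined at module level rather than inside the factory. This keeps the function identity stable, so `run` is not retraced for every new kernel that has the same structure.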


In [1]:
from copy import deepcopy

In [2]:
import jax
import jax.numpy as jnp
from jax import make_jaxpr
from jax.tree_util import Partial
from jax import jit, grad

In [3]:
from rubix.pipeline import linear_pipeline as ltp
from rubix.pipeline import transformer as rtr
from rubix.utils import read_yaml

## Build some simple decorators for function configuration
-  leverages jax.tree_util.Partial
-  builds a partial object to which jax transformations can be applied 
-  three cases: 
   -  build the pure function object: you have to take care of static args/kwargs yourself when calling jit. The decorator only builds the function object.
   -  jit it right away: the usual case. Here you can tell it which args/kwargs to trace or not with the `static_args` and `static_kwargs` keyword arguments.
   -  build expression: mainly to check what comes out at the end or at intermediate steps. Without arguments, it returns a function that builds the jaxpr once called. With arguments, it returns the jaxpr (a `ClosedJaxpr`) directly. Note that, for some reason, `jax.make_jaxpr` does not have `static_argnames` like `jit` does. 
-  With these, we can configure our pipeline transformers. 
-  Not entirely sure right now which of them are useful or needed.
-  these decorators/factory functions live in `rubix.pipeline.transformer`

**Simple transformer decorator that binds a function to arguments** 

This shows the basic implementation. The decorators are available under `rubix.pipeline.transformer` in the package.


In [4]:
def bound_transformer(*args, **kwargs):
    """
    bound_transformer  Create a jax.Partial function from an input
    function and given arguments and keyword arguments.
    The user must take care that positional arguments are bound starting
    from the first, i.e., leftmost, one. If specific arguments should be
    bound, using keyword arguments may be advisable.
    """

    def transformer_wrap(kernel):
        return Partial(
            deepcopy(kernel), *deepcopy(args), **deepcopy(kwargs)
        )  # deepcopy to avoid context dependency

    return transformer_wrap

In [5]:
@bound_transformer(z=5, k=3.14)
def add(x, y, z: float = 0, k: float = 0):
    return x + y + z + k

In [6]:
type(add)

jax.tree_util.Partial

In [7]:
addjit = jax.jit(add)

In [8]:
x = jnp.array([3.0, 2.0, 1.0], dtype=jnp.float32)

In [9]:
addjit(x, x)

Array([14.14, 12.14, 10.14], dtype=float32)
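To see what `Partial` actually stores, here is a small standalone check. It rebuilds `add` with `Partial` directly instead of going through `bound_transformer`. The bound keyword values are the pytree leaves of the object, and gradients with respect to the unbound arguments flow through as usual:

```python
import jax
import jax.numpy as jnp
from jax.tree_util import Partial


def add(x, y, z: float = 0, k: float = 0):
    return x + y + z + k


bound_add = Partial(add, z=5.0, k=3.14)

# bound keyword values are the leaves of the Partial pytree (dict keys are sorted: k, z)
leaves = jax.tree_util.tree_leaves(bound_add)

# d/dx sum(x + x + z + k) = 2 for every element, because x enters twice
x = jnp.array([3.0, 2.0, 1.0], dtype=jnp.float32)
g = jax.grad(lambda v: bound_add(v, v).sum())(x)
```

Because the bound values are leaves, they are traced rather than kept static whenever such a `Partial` is passed *into* a jitted function as an argument.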

#### Compiling transformer: jit individual elements with arguments bound via `Partial`

- can be used for the final pipeline or for intermediate steps during debugging
- combines a `Partial` that binds arguments with a `jit` that takes static args and kwargs. However, bound args and kwargs can **NOT** be static at the same time. In principle, we would want a partial of a jit here, but would the overhead of the wrapper not defeat the purpose of the jit?
- A solution to this would essentially be a configurable jit factory.
- I am not entirely sure why the code below works the way it does.
- not even entirely sure it is useful at all... 

In [10]:
def compiled_transformer(
    *args,
    static_args: tuple = (),
    static_kwargs: tuple = (),
    **kwargs,
):

    def transformer_wrap(kernel):

        return jit(
            Partial(deepcopy(kernel), *deepcopy(args), **deepcopy(kwargs)),
            static_argnums=deepcopy(static_args),
            static_argnames=deepcopy(static_kwargs),
        )

    return transformer_wrap

In [11]:
@compiled_transformer(
    z=5,
    k=-3.14,
    static_kwargs=[
        "k",
    ],
)
def cond_add(x, y, z: float = 0, k: float = 0):
    if k < 0:
        return x + y + z + k
    else:
        return x + y + z + 2 * k

In [12]:
cond_add

<PjitFunction of jax.tree_util.Partial(<function cond_add at 0x11fda5440>, z=5, k=-3.14)>

In [13]:
cond_add(x, x)

Array([7.8599997, 5.8599997, 3.86     ], dtype=float32)

In [14]:
def cond_add(x, y, z: float = 0, k: float = 0):
    if k < 0:
        return x + y + z + k
    else:
        return x + y + z + 2 * k

use on predefined functions without the decorator syntax

In [15]:
cond_add_plus = compiled_transformer(z=5, static_kwargs=["k"])(cond_add)

In [16]:
cond_add_plus

<PjitFunction of jax.tree_util.Partial(<function cond_add at 0x11fda6840>, z=5)>

In [17]:
cond_add_plus(x, x, k=-3.14)

Array([7.8599997, 5.8599997, 3.86     ], dtype=float32)

**Problem**: the `compiled_transformer` decorator cannot mark args or kwargs that are bound to the function as static, i.e., configured parameters are not static here. This only works if the entire pipeline is compiled after it has been assembled. Not sure how to fix that at the moment, if at all.
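One possible reading of the behaviour above follows. It is our interpretation of JAX's tracing rules, sketched standalone with a plain `jax.jit` instead of `compiled_transformer`. `jit` only abstracts the values passed at call time. Keyword values bound inside the `Partial` that is handed to `jit` therefore reach the kernel as ordinary Python constants, and Python control flow on them works without `static_argnames`. Values passed at call time, in contrast, are traced unless they are marked static:

```python
import jax
import jax.numpy as jnp
from jax.tree_util import Partial


def cond_add(x, y, z: float = 0, k: float = 0):
    # Python control flow on k: needs a concrete value at trace time
    if k < 0:
        return x + y + z + k
    return x + y + z + 2 * k


x = jnp.array([3.0, 2.0, 1.0], dtype=jnp.float32)

# k bound inside the Partial that jit wraps: it is never traced
bound = jax.jit(Partial(cond_add, z=5.0, k=-3.14))
out_bound = bound(x, x)

# k passed at call time and not marked static: it becomes a tracer and the `if` fails
try:
    jax.jit(cond_add)(x, x, z=5.0, k=-3.14)
    traced_failed = False
except jax.errors.ConcretizationTypeError:
    traced_failed = True

# k passed at call time and marked static: works again (one compilation per distinct k)
static = jax.jit(cond_add, static_argnames=["k"])
out_static = static(x, x, z=5.0, k=-3.14)
```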

#### Expression-based decorator for extracting the intermediate `jaxpr` object for inspection
- `make_jaxpr` only has `static_argnums` and no `static_argnames`, so keyword arguments cannot be marked static. God knows why.
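`make_jaxpr` offers no `static_argnames`. One possible workaround, assuming the keyword values are known up front, is to bind them with `Partial` before calling `jax.make_jaxpr`. They are then plain Python constants during tracing and show up as literals in the jaxpr, while only the positional arrays become inputs:

```python
import jax
import jax.numpy as jnp
from jax.tree_util import Partial


def cond_add(x, y, z: float = 0, k: float = 0):
    if k < 0:
        return x + y + z + k
    return x + y + z + 2 * k


x = jnp.array([3.0, 2.0, 1.0], dtype=jnp.float32)

# z and k are bound up front, so the Python `if` sees concrete values while tracing
closed = jax.make_jaxpr(Partial(cond_add, z=5.0, k=-3.14))(x, x)

# only the two positional arrays remain as jaxpr inputs
n_inputs = len(closed.jaxpr.invars)
```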

In [18]:
def expression_transformer(
    *args,
    static_args: tuple = (),
):

    def transformer_wrap(kernel):
        if len(args) > 0:
            return jax.make_jaxpr(kernel, static_argnums=static_args)(*args)
        else:
            return jax.make_jaxpr(kernel, static_argnums=static_args)

    return transformer_wrap

In [19]:
@expression_transformer(x, x, 5, 3.14, static_args=[2, 3])
def cond_add(x, y, z: float = 0, k: float = 0):
    if k < 0:
        return x + y + z + k
    else:
        return x + y + z + 2 * k

In [20]:
cond_add

{ lambda ; a:f32[3] b:f32[3]. let
    c:f32[3] = add a b
    d:f32[3] = add c 5.0
    e:f32[3] = add d 6.28000020980835
  in (e,) }

make sure to use the right `static_args` when using Python control flow, or use jax/lax primitives instead.
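The second option, using jax/lax primitives, can be sketched as follows. Here `jax.lax.cond` replaces the Python `if`, so `k` may stay traced and no static arguments are needed. `cond_add_traceable` is a hypothetical name for this variant:

```python
import jax
import jax.numpy as jnp


def cond_add_traceable(x, y, z: float = 0, k: float = 0):
    # the branch is chosen by lax.cond at run time, so k may be a tracer;
    # closing over the traced k inside the branch functions is allowed
    base = x + y + z
    return jax.lax.cond(k < 0, lambda s: s + k, lambda s: s + 2 * k, base)


x = jnp.array([3.0, 2.0, 1.0], dtype=jnp.float32)
f = jax.jit(cond_add_traceable)  # no static arguments needed
neg = f(x, x, 5.0, -3.14)  # takes the k < 0 branch
pos = f(x, x, 5.0, 3.14)   # takes the other branch, reusing the same compiled code
```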

In [21]:
@expression_transformer(x, x, 5, -3.14, static_args=[3])
def cond_add(x, y, z: float = 0, k: float = 0):
    if k < 0:
        return x + y + z + k
    else:
        return x + y + z + 2 * k

In [22]:
cond_add

{ lambda ; a:f32[3] b:f32[3] c:i32[]. let
    d:f32[3] = add a b
    e:f32[] = convert_element_type[new_dtype=float32 weak_type=False] c
    f:f32[3] = add d e
    g:f32[3] = add f -3.140000104904175
  in (g,) }

Without arguments, you get a function that produces the expression once it is called with arguments.

In [23]:
@expression_transformer(static_args=[2, 3])
def cond_add(x, y, z: float = 0, k: float = 0):
    if k < 0:
        return x + y + z + k
    else:
        return x + y + z + 2 * k

In [24]:
cond_add

<function jax.make_jaxpr(cond_add)(x, y, z: float = 0, k: float = 0)>

In [25]:
cond_add(x, x, 3, 2.71)

{ lambda ; a:f32[3] b:f32[3]. let
    c:f32[3] = add a b
    d:f32[3] = add c 3.0
    e:f32[3] = add d 5.420000076293945
  in (e,) }

### Define a number of simple, dumb transformers
- we pretend that their second argument is something we want to configure from the start, hence it should not be traced

- we can use the above decorators to bind their second arg to something we know

In [26]:
def add(x, s: float):
    return x + s


def mult(x, m: float):
    return x * m


def div(x, d: float):
    return x / d


def sub(x, s: float):
    return x - s

## Configuration files and pipeline building

### General remarks about yaml
- yaml format: dictionary 
- inside the dictionary, lists and dicts can be nested arbitrarily. 
- yaml can be customized for node formats that are not provided by default, or for reading in custom types. Look up yaml tags for more. 
- there are yaml libraries for pretty much all common languages

### Here, we use yaml in the following way: 
- the config file essentially describes a DAG as an adjacency list, but the current design is limited to one child per node => linear pipelines only, no branching
- consequently, the build algorithm is limited to linear pipelines for the moment. Both must evolve together.
- while a more general abstract base class is provided, we only implement a linear pipeline class `LinearTransformerPipeline` at the moment
- the essential part is the `Transformers` node of the config. this is the actual DAG adjacency list. This needs to adhere to the format outlined below.
- you can add other nodes to configure other parts of your system: data directories and so on.
- The starting point is always defined by a node that does not depend on another node. 
- The stop point is just the last element in the pipeline
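The sketch below recovers the start node and the execution order from such an adjacency list. It assumes every node has exactly one `depends_on` entry (`null` for the start). `order_linear_pipeline` is a hypothetical helper and not necessarily what `LinearTransformerPipeline` does:

```python
def order_linear_pipeline(transformers: dict) -> list:
    """Return step names in execution order for a linear pipeline config."""
    # the starting point is the unique node that depends on nothing
    starts = [name for name, node in transformers.items() if node["depends_on"] is None]
    if len(starts) != 1:
        raise ValueError(f"expected exactly one start node, found {starts}")

    # invert depends_on (child -> parent) into parent -> child
    children = {}
    for name, node in transformers.items():
        parent = node["depends_on"]
        if parent is None:
            continue
        if parent in children:
            raise ValueError(f"{parent!r} has more than one child; only linear pipelines are supported")
        children[parent] = name

    # walk from the start node until a node without a child (the stop point) is reached
    order = [starts[0]]
    while order[-1] in children:
        order.append(children[order[-1]])
    if len(order) != len(transformers):
        raise ValueError("some nodes are not reachable from the start node")
    return order


# the dependency structure of demo.yml (args/kwargs omitted)
config = {
    "A": {"name": "add", "depends_on": "B"},
    "X": {"name": "mult", "depends_on": "A"},
    "Z": {"name": "div", "depends_on": "X"},
    "B": {"name": "sub", "depends_on": "C"},
    "C": {"name": "add", "depends_on": None},
}
order = order_linear_pipeline(config)
```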


#### Config node structure: 


```yaml
name_of_pipeline_step:
    name: name_of_function_to_use_for_this_step
    depends_on: name_of_step_immediately_prior_in_pipeline  # null for the first step
    args: []  # positional arguments that should be bound to the transformer function
    kwargs:   # keyword arguments that should be bound to the transformer function
        argument1: value1
        argument2: value2
        argumentN: valueN
```

The values in `args` (positional) and `kwargs` (keyword) are used to create the `Partial` object, using the transformer decorators above.


**Example** 
```yaml
Transformers:
  B:
    name: sub
    depends_on: C
    args: []
    kwargs:
      s: 2
  C:
    name: add
    depends_on: null
    args: []
    kwargs:
      s: 4
```
Here, `C` is the starting node, i.e., the first function in the pipeline. 
Whatever you do before that with your data does not concern the pipeline and hence has no influence on differentiability etc. 

For a full example, see the `demo.yml` file.
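The sketch below turns one config node into a `Partial`. It assumes the node layout shown in `read_cfg` below (`name`, `args`, `kwargs`). `build_step` and the `registry` dict are hypothetical stand-ins for what `LinearTransformerPipeline` does internally:

```python
import jax.numpy as jnp
from jax.tree_util import Partial


def add(x, s: float):
    return x + s


def sub(x, s: float):
    return x - s


def build_step(node: dict, registry: dict) -> Partial:
    # hypothetical helper: look up the function by its configured name and bind
    # the positional `args` and keyword `kwargs` of the node to it
    fn = registry[node["name"]]
    return Partial(fn, *node.get("args", []), **node.get("kwargs", {}))


# the example nodes from above, as a yaml reader would return them
config = {
    "B": {"name": "sub", "depends_on": "C", "args": [], "kwargs": {"s": 2}},
    "C": {"name": "add", "depends_on": None, "args": [], "kwargs": {"s": 4}},
}
registry = {"add": add, "sub": sub}
steps = {name: build_step(node, registry) for name, node in config.items()}

x = jnp.array([3.0, 2.0, 1.0], dtype=jnp.float32)
out = steps["B"](steps["C"](x))  # C runs first, then B
```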



### Read yaml and build pipeline 

`read_yaml` is available from `rubix.utils` and is very simple.


In [27]:
read_cfg = read_yaml("./demo.yml")  # implemented in utils

In [28]:
read_cfg

{'Transformers': {'A': {'name': 'add',
   'depends_on': 'B',
   'args': [],
   'kwargs': {'s': 3.0}},
  'X': {'name': 'mult', 'depends_on': 'A', 'args': [], 'kwargs': {'m': 3}},
  'Z': {'name': 'div', 'depends_on': 'X', 'args': [], 'kwargs': {'d': 4}},
  'B': {'name': 'sub', 'depends_on': 'C', 'args': [], 'kwargs': {'s': 2}},
  'C': {'name': 'add', 'depends_on': None, 'args': [], 'kwargs': {'s': 4}}}}

In [29]:
type(read_cfg)

dict

In [30]:
read_cfg["Transformers"]

{'A': {'name': 'add', 'depends_on': 'B', 'args': [], 'kwargs': {'s': 3.0}},
 'X': {'name': 'mult', 'depends_on': 'A', 'args': [], 'kwargs': {'m': 3}},
 'Z': {'name': 'div', 'depends_on': 'X', 'args': [], 'kwargs': {'d': 4}},
 'B': {'name': 'sub', 'depends_on': 'C', 'args': [], 'kwargs': {'s': 2}},
 'C': {'name': 'add', 'depends_on': None, 'args': [], 'kwargs': {'s': 4}}}

In [31]:
type(read_cfg["Transformers"])

dict

Transformers need to be registered when the pipeline is created. If you have a fixed set of them, or many, it may make sense to write a factory function. 
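One hypothetical way to do this is a small registry decorator. It assumes, as the `tp.transformers` output below suggests, that the pipeline looks transformers up by their function `__name__`. `register`, `TRANSFORMERS` and `default_transformers` are illustrative names and are not part of rubix:

```python
TRANSFORMERS = {}  # hypothetical module-level registry: name -> function


def register(fn):
    # key by __name__, which is what the `name` field of a config node refers to
    TRANSFORMERS[fn.__name__] = fn
    return fn


@register
def add(x, s: float):
    return x + s


@register
def mult(x, m: float):
    return x * m


@register
def div(x, d: float):
    return x / d


@register
def sub(x, s: float):
    return x - s


def default_transformers() -> list:
    # hypothetical factory returning every registered transformer
    return list(TRANSFORMERS.values())
```

With this in place, the constructor call below could read `ltp.LinearTransformerPipeline(read_cfg, default_transformers())`.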

In [32]:
tp = ltp.LinearTransformerPipeline(read_cfg, [add, mult, div, sub])

In [33]:
tp.transformers

{'add': <function __main__.add(x, s: float)>,
 'mult': <function __main__.mult(x, m: float)>,
 'div': <function __main__.div(x, d: float)>,
 'sub': <function __main__.sub(x, s: float)>}

The `transformers` member gives us a dict of `name: function` pairs for the transformers. 
This currently has to be set up before the pipeline is assembled, otherwise the pipeline does not know what to assemble it from.

In [34]:
tp.assemble()

In [35]:
tp.pipeline

{'C': jax.tree_util.Partial(<function add at 0x11fda4ae0>, s=4),
 'B': jax.tree_util.Partial(<function sub at 0x11fda7740>, s=2),
 'A': jax.tree_util.Partial(<function add at 0x11fda4ae0>, s=3.0),
 'X': jax.tree_util.Partial(<function mult at 0x11fda62a0>, m=3),
 'Z': jax.tree_util.Partial(<function div at 0x11fda79c0>, d=4)}

Now we have an ordered collection of jax `Partial`s (a dict keyed by step name). Assuming the individual elements are well behaved, all jax transformations can in principle be applied to them. If this holds for the elements, it also holds for their composition, as long as the function used for composition is itself purely functional.
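Here is a minimal sketch of such a composition function, written standalone for illustration. `compose` is a hypothetical name, similar in spirit to the pipeline's `build_expression` but not copied from it. It threads the input through the ordered `Partial`s and reads no outer state, so the whole chain can be jitted at once:

```python
from functools import reduce

import jax
import jax.numpy as jnp
from jax.tree_util import Partial


def add(x, s: float):
    return x + s


def sub(x, s: float):
    return x - s


def mult(x, m: float):
    return x * m


def div(x, d: float):
    return x / d


# same steps and bound values as in demo.yml, in execution order C, B, A, X, Z
steps = [
    Partial(add, s=4),
    Partial(sub, s=2),
    Partial(add, s=3.0),
    Partial(mult, m=3),
    Partial(div, d=4),
]


def compose(steps, x):
    # feed the output of each step into the next one; nothing outside is read or written
    return reduce(lambda acc, step: step(acc), steps, x)


pipeline = jax.jit(Partial(compose, steps))
out = pipeline(jnp.array([3.0, 2.0, 1.0], dtype=jnp.float32))
```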

In [36]:
x = jnp.array([3.0, 2.0, 1.0], dtype=jnp.float32)

The expression that a pipeline builds is a `Partial` object with the pipeline bound to it:

In [37]:
tp.expression

jax.tree_util.Partial(<function LinearTransformerPipeline.build_expression.<locals>.expr at 0x11fda7880>, pipeline=[jax.tree_util.Partial(<function add at 0x11fda4ae0>, s=4), jax.tree_util.Partial(<function sub at 0x11fda7740>, s=2), jax.tree_util.Partial(<function add at 0x11fda4ae0>, s=3.0), jax.tree_util.Partial(<function mult at 0x11fda62a0>, m=3), jax.tree_util.Partial(<function div at 0x11fda79c0>, d=4)])

... it has the same signature as the first function in the pipeline.

In [38]:
func = tp.compile_expression()

In [39]:
func

<PjitFunction of jax.tree_util.Partial(<function LinearTransformerPipeline.build_expression.<locals>.expr at 0x11fda7880>, pipeline=[jax.tree_util.Partial(<function add at 0x11fda4ae0>, s=4), jax.tree_util.Partial(<function sub at 0x11fda7740>, s=2), jax.tree_util.Partial(<function add at 0x11fda4ae0>, s=3.0), jax.tree_util.Partial(<function mult at 0x11fda62a0>, m=3), jax.tree_util.Partial(<function div at 0x11fda79c0>, d=4)])>

In [40]:
func(x)

Array([6.  , 5.25, 4.5 ], dtype=float32)

In [41]:
div(mult(add(sub(add(x, s=4), s=2), s=3), m=3), d=4)

Array([6.  , 5.25, 4.5 ], dtype=float32)

... the output is the same. yay :) 

In [42]:
expr = tp.get_jaxpr()(x)

In [43]:
expr

{ lambda ; a:f32[3]. let
    b:f32[3] = add a 4.0
    c:f32[3] = sub b 2.0
    d:f32[3] = add c 3.0
    e:f32[3] = mul d 3.0
    f:f32[3] = div e 4.0
  in (f,) }

In [44]:
type(expr)

jax._src.core.ClosedJaxpr

In [45]:
def func_manual(x):
    return div(mult(add(sub(add(x, s=4), s=2), s=3), m=3), d=4)

In [46]:
make_jaxpr(func_manual)(x)

{ lambda ; a:f32[3]. let
    b:f32[3] = add a 4.0
    c:f32[3] = sub b 2.0
    d:f32[3] = add c 3.0
    e:f32[3] = mul d 3.0
    f:f32[3] = div e 4.0
  in (f,) }

... and so are the expressions: JAX traces through Python loops, so we don't have to handle expression composition ourselves. Hence we should end up with something that is jax-transformable whenever its elements are. yay :) 
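The loop claim can be checked independently of rubix. The standalone sketch below defines two hypothetical functions, `looped` and `manual`, using the same steps as `demo.yml`. It compares the printed jaxprs of an explicit Python loop over `Partial`s and of the hand-nested call:

```python
import jax
import jax.numpy as jnp
from jax.tree_util import Partial


def add(x, s: float):
    return x + s


def sub(x, s: float):
    return x - s


def mult(x, m: float):
    return x * m


def div(x, d: float):
    return x / d


steps = [Partial(add, s=4), Partial(sub, s=2), Partial(add, s=3), Partial(mult, m=3), Partial(div, d=4)]


def looped(x):
    # the Python loop is unrolled while tracing; only the array operations are recorded
    for step in steps:
        x = step(x)
    return x


def manual(x):
    return div(mult(add(sub(add(x, s=4), s=2), s=3), m=3), d=4)


x = jnp.array([3.0, 2.0, 1.0], dtype=jnp.float32)
same = str(jax.make_jaxpr(looped)(x)) == str(jax.make_jaxpr(manual)(x))
```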

just for completeness, we can mess a bit more with the expression stuff

In [47]:
jax.jacfwd(tp.compile_expression())(x)

Array([[0.75, 0.  , 0.  ],
       [0.  , 0.75, 0.  ],
       [0.  , 0.  , 0.75]], dtype=float32)

In [48]:
jax.jacrev(tp.compile_expression())(x)

Array([[0.75, 0.  , 0.  ],
       [0.  , 0.75, 0.  ],
       [0.  , 0.  , 0.75]], dtype=float32)

In [49]:
jax.hessian(tp.compile_expression())(x)

Array([[[0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.]],

       [[0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.]],

       [[0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.]]], dtype=float32)
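These results are easy to check by hand. With the bound constants from `demo.yml`, the pipeline acts elementwise as an affine map:

$$
f(x) = \frac{3\,\big((x + 4) - 2 + 3\big)}{4} = \frac{3}{4}\,x + \frac{15}{4},
\qquad
\frac{\partial f_i}{\partial x_j} = \frac{3}{4}\,\delta_{ij},
\qquad
\frac{\partial^2 f_i}{\partial x_j\,\partial x_k} = 0 .
$$

The Jacobian is therefore $0.75\,I$ and the Hessian vanishes, matching the outputs above. For example, $x = 3$ gives $0.75 \cdot 3 + 3.75 = 6$.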

### Query individual elements 

this is mainly useful for debugging

compile or get expressions for individual elements

In [50]:
tp.compile_element("A")

<PjitFunction of jax.tree_util.Partial(_HashableCallableShim(jax.tree_util.Partial(<function add at 0x11fda4ae0>, s=3.0)))>

In [51]:
tp.get_jaxpr_for_element("A", x)

{ lambda ; a:f32[3]. let b:f32[3] = add a 3.0 in (b,) }

When no arguments are given, a function is returned that creates the expression once it is called with arguments.

In [52]:
f = tp.get_jaxpr_for_element("A")

In [53]:
f

<function jax.<unnamed function>(x, *, s: float = 3.0)>

In [54]:
f(x)

{ lambda ; a:f32[3]. let b:f32[3] = add a 3.0 in (b,) }
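To debug the whole chain at once rather than element by element, one option is a jitted run that returns every intermediate result keyed by step name. `run_with_intermediates` is a hypothetical helper, not part of rubix. The sketch uses the first two steps of `demo.yml`:

```python
import jax
import jax.numpy as jnp
from jax.tree_util import Partial


def add(x, s: float):
    return x + s


def sub(x, s: float):
    return x - s


# hypothetical named steps in execution order, mirroring nodes C and B of demo.yml
named_steps = [("C", Partial(add, s=4)), ("B", Partial(sub, s=2))]


def run_with_intermediates(named_steps, x):
    # keep every intermediate result, keyed by step name, instead of only the last one
    outputs = {}
    for name, step in named_steps:
        x = step(x)
        outputs[name] = x
    return outputs


debug_run = jax.jit(Partial(run_with_intermediates, named_steps))
outs = debug_run(jnp.array([3.0, 2.0, 1.0], dtype=jnp.float32))
```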

## Alternative structures that allow for more complex systems

- allow injecting new data at intermediate steps, i.e., multiple starting points: this turns the pipeline into an inverted tree. 

- allow a step to depend on multiple other steps: this turns the pipeline into a directed acyclic graph, a common structure in more general data processing systems. 

=> if possible use something simple like `Partial` to accomplish this 
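Below is a rough sketch of the DAG variant along these lines. It assumes each node is a callable (typically a `Partial`) plus a list of the node names whose outputs it consumes. `run_dag` and the node layout are hypothetical. `graphlib` (Python 3.9+) only supplies the evaluation order:

```python
from graphlib import TopologicalSorter

import jax
import jax.numpy as jnp
from jax.tree_util import Partial


def add(x, s: float):
    return x + s


def mult(x, m: float):
    return x * m


def combine(a, b):
    return a + b


# hypothetical DAG: name -> (callable, names of the nodes it depends on);
# "data1" and "data2" are injected inputs, i.e. two starting points
nodes = {
    "A": (Partial(add, s=1.0), ["data1"]),
    "B": (Partial(mult, m=2.0), ["data2"]),
    "C": (combine, ["A", "B"]),  # depends on two other steps
}


def run_dag(nodes, inputs):
    graph = {name: set(deps) for name, (_, deps) in nodes.items()}
    values = dict(inputs)
    # static_order yields every node only after all of its dependencies
    for name in TopologicalSorter(graph).static_order():
        if name in values:  # injected input, nothing to compute
            continue
        fn, deps = nodes[name]
        values[name] = fn(*(values[d] for d in deps))
    return values


dag = jax.jit(lambda d1, d2: run_dag(nodes, {"data1": d1, "data2": d2})["C"])
out = dag(jnp.array([1.0, 2.0, 3.0]), jnp.array([1.0, 1.0, 1.0]))
```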

## Tentative best practices

- think in small steps: a more granular pipeline is easier to write in a pure functional style, easier to reason about and probably also better to optimize. 
- A more granular system also is easier to test and extend 
- ideally write the pipeline such that it can be compiled all at once with `compile_expression`. 


## Summary 
- the pipeline produces the same jax code as the handwritten version. This seems encouraging.
- at which points do we still need to ensure pure functional behavior?
- how will we enforce transformer compatibility?
- this is a pathologically simple case, hence not representative of real-world scenarios
- when does it break? 
- what use cases are not covered?
- what else do you need? 
