In [None]:
using BenchmarkTools

# Lab 09 - Generated Functions & IR

In this lab you will practice two advanced metaprogramming techniques:

* _**Generated functions**_ can help you write specialized code for certain
  kinds of parametric types with more flexibility and/or less code.
* _**IRTools.jl**_ is a package that simplifies the manipulation of lowered and
  typed Julia code.
  
  
## `@generated` Functions

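Remember the three most important things about generated functions:
* They return *quoted expressions* (like macros).
* You have access to type information of your input variables: in the body of
  the generated function (outside the returned expression), the argument names
  are bound to their *types*, not their values.
* They have to be _**pure**_.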
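A minimal sketch of the first two points, using a hypothetical function `typename_of` (not part of the lab): in the generator body the argument name `x` is bound to the argument's *type*, while inside the returned quoted expression `x` refers to the runtime value again.

```julia
@generated function typename_of(x)
    # Generator body: runs at compile time, `x` is the argument's *type* (e.g. Float64)
    name = string(x)
    # Returned expression: runs at call time, here `x` is the actual value
    return :(($name, x))
end

typename_of(2.5)   # ("Float64", 2.5)
```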

### A faster `polynomial`

Throughout this course we have come back to our `polynomial` function, which
evaluates a polynomial based on the Horner scheme. Below you can find a version
of the function that operates on a tuple of length $N$.

In [None]:
function polynomial(x, p::NTuple{N}) where N
    acc = p[N]
    for i in N-1:-1:1
        acc = x*acc + p[i]
    end
    acc
end
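For reference, the loop evaluates the polynomial in nested (Horner) form, where `p[i]` is the coefficient of $x^{i-1}$ (the loop ends by adding `p[1]`, the constant term):

$$
p(x) = \sum_{i=1}^{N} p_i\, x^{i-1} = p_1 + x\Big(p_2 + x\big(p_3 + \dots + x\,(p_{N-1} + x\,p_N)\big)\Big)
$$

It starts with $\mathrm{acc} = p_N$ and applies $\mathrm{acc} \leftarrow x\,\mathrm{acc} + p_i$ for $i = N-1, \dots, 1$, one nesting level per iteration.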

Julia has its own implementation of this function called `evalpoly`. If we
compare the performance of our `polynomial` and Julia's `evalpoly` we can
observe a pretty big difference:

In [None]:
x = 2.0

In [None]:
p = ntuple(float,20)

In [None]:
@btime polynomial($x,$p)

In [None]:
@btime evalpoly($x,$p)

Julia's implementation uses a generated function which specializes on different
tuple lengths (i.e. it *unrolls* the loop) and eliminates the (small) overhead
of looping over the tuple. This is possible because the length of the tuple is
known at compile time. You can check the difference between `polynomial`
and `evalpoly` yourself via the introspection tools you know, e.g.
`@code_lowered`.
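To see what unrolling means without giving away the exercise below, here is a minimal sketch with a hypothetical function `unrolled_sum` that sums a tuple. Because `N` is a type parameter, the generator knows it at compile time and can emit `t[1] + t[2] + ... + t[N]` as straight-line code:

```julia
@generated function unrolled_sum(t::NTuple{N,Any}) where N
    # The generator sees N as a plain integer, so this loop runs at compile time
    N == 0 && return :(0)
    ex = :(t[1])
    for i in 2:N
        ex = :($ex + t[$i])   # grow the expression: ((t[1] + t[2]) + t[3]) + ...
    end
    return ex                 # no loop is left in the code that runs at call time
end

unrolled_sum((1, 2, 3))       # 6
```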

<div class="alert alert-block alert-success">
<b>Exercise:</b>
    
Rewrite the `polynomial` function as a generated function with the signature
`genpoly(x::Number, p::NTuple{N}) where N`.
    
**Hints:**
* Remember that you have to generate a quoted expression inside your generated
  function, so you will need things like `:($expr1 + $expr2)`.
* You can debug the expression you are generating by omitting the `@generated`
  macro from your function.
    
</div>

<div class="alert alert-block alert-info">
<b>Solution</b>: </div>

###
You should get the same performance as `evalpoly` (and as `@poly` from Lab 7,
with the added convenience of not having to spell out all the coefficients in
your code, as in `p = @poly 1 2 3 ...`).

In [None]:
@btime genpoly($x,$p)

### Fast, Static Matrices

Another great example that makes heavy use of generated functions are *static
arrays*. A static array is an array of fixed size which can be implemented via
an `NTuple`. This means that it will be allocated on the stack, which can buy
us a lot of performance for smaller static arrays. We define a
`StaticMatrix{T,C,R,L}` where the paramteric types represent the matrix element
type `T` (e.g. `Float32`), the number of rows `R`, the number of columns `C`,
and the total length of the matrix `L=C*R` (which we need to set the size of
the `NTuple`).
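The reason is that a tuple of numbers is plain immutable data (an *isbits* type), which the compiler can keep on the stack or in registers instead of allocating it on the heap. A quick check with `isbitstype` shows the difference to a regular `Matrix`:

```julia
isbitstype(NTuple{6,Float64})   # true: fixed size, immutable, no pointers
isbitstype(Matrix{Float64})     # false: mutable array with heap-allocated storage
```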


In [None]:
struct StaticMatrix{T,R,C,L} <: AbstractArray{T,2}
    data::NTuple{L,T}
end

function StaticMatrix(x::AbstractMatrix{T}) where T
    (R,C) = size(x)
    StaticMatrix{T,R,C,C*R}(x |> Tuple)
end
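The constructor flattens the matrix with `Tuple`, which iterates over a Julia matrix in column-major order (column by column). Entry `(r, c)` of a matrix with `R` rows therefore ends up at position `(c - 1) * R + r` of the tuple, which is worth keeping in mind for the `getindex` exercise. A small check on an ordinary matrix (the name `layout_ok` is just for illustration):

```julia
layout_ok = let x = reshape(1:6, 2, 3)   # the 2×3 matrix [1 3 5; 2 4 6]
    t = Tuple(x)                          # (1, 2, 3, 4, 5, 6): column by column
    R = size(x, 1)
    # entry (r, c) sits at linear position (c - 1) * R + r of the tuple
    all(t[(c - 1) * R + r] == x[r, c] for r in 1:R, c in 1:size(x, 2))
end                                       # true
```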

<div class="alert alert-block alert-success">
<b>Exercise:</b>
    
As a warm-up, overload the `Base` functions `size`, `length`,
`getindex(x::StaticMatrix,i::Int)`, and `getindex(x::StaticMatrix,r::Int,c::Int)`.
    
</div>

<div class="alert alert-block alert-info">
<b>Solution</b>: </div>

###

You can check whether everything works correctly by comparing to a normal `Matrix`:


In [None]:
x = rand(2,3)

In [None]:
x[1,2]

In [None]:
a = StaticMatrix(x)

In [None]:
a[1,2]

<div class="alert alert-block alert-success">
<b>Exercise:</b>

Overload matrix multiplication between two static matrices
```julia
Base.:*(x::StaticMatrix{T,K,M},y::StaticMatrix{T,M,N}) where {T,K,M,N}
```
with a generated function that creates an expression without loops. Below you
can see an example of the expression that would be generated when multiplying
two $2\times 2$ matrices.
```julia
:(StaticMatrix{T,2,2,4}((
    (x[1,1]*y[1,1] + x[1,2]*y[2,1]),
    (x[2,1]*y[1,1] + x[2,2]*y[2,1]),
    (x[1,1]*y[1,2] + x[1,2]*y[2,2]),
    (x[2,1]*y[1,2] + x[2,2]*y[2,2])
)))
```

**Hints:**

* You can get output like above by leaving out the `@generated` in front of your
  overload.
* It might be helpful to implement matrix multiplication in a *normal* Julia
  function first.
* You can construct an expression for a sum of multiple elements like below.
    
</div>

In [None]:
Expr(:call,:+,1,2,3)

In [None]:
Expr(:call,:+,1,2,3) |> eval
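Combined with interpolation, this is enough to build one entry of a matrix product as a loop-free sum. The helper `entry_expr` below is hypothetical (not the solution to the exercise) and assumes the two factors are called `x` and `y`, as in the example expression above:

```julia
# Hypothetical helper: expression for entry (r, c) of x*y, where M is the
# number of columns of x (= rows of y). The loop runs while *building* the
# expression; the expression itself is a plain sum.
entry_expr(r, c, M) = Expr(:call, :+, [:(x[$r, $k] * y[$k, $c]) for k in 1:M]...)

entry_expr(1, 2, 2)   # :(x[1, 1] * y[1, 2] + x[1, 2] * y[2, 2])

# Evaluate the expression with local bindings for x and y
entry_value = eval(:(let x = [1 2; 3 4], y = [5 6; 7 8]
    $(entry_expr(1, 2, 2))
end))                 # 22, the same as ([1 2; 3 4] * [5 6; 7 8])[1, 2]
```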

<div class="alert alert-block alert-info">
<b>Solution</b>: </div>

###

You can check that your matrix multiplication works by multiplying two random
matrices. Which one is faster?

In [None]:
a = rand(2,3)

In [None]:
b = rand(3,4)

In [None]:
c = StaticMatrix(a)

In [None]:
d = StaticMatrix(b)

In [None]:
a*b

In [None]:
c*d

## `OptionalArgChecks.jl`

The package [`OptionalArgChecks.jl`](https://github.com/simeonschaub/OptionalArgChecks.jl)
makes it possible to add checks to a function which can then be removed by
calling the function with the `@skip` macro. For example, we can check if the
input to a function `f` is an even number:


In [None]:
function f(x::Number)
    iseven(x) || error("Input has to be an even number!")
    x
end

If you are doing more involved argument checking, it can take quite some time to
perform all your checks. However, if you want to be fast and are completely
sure that you always pass correct inputs to your function, you
might want to remove the checks in some cases. Hence, we would like to transform the
IR of the function above:


In [None]:
using IRTools
using IRTools: @code_ir
@code_ir f(1)

into something like this:

In [None]:
transformed_f(x::Number) = x
@code_ir transformed_f(1)

### Marking Argument Checks
As a first step we will implement a macro that marks checks which we might want
to remove later by surrounding them with `:meta` expressions. This will make it
easy to detect which parts of the code can be removed. A `:meta` expression can
be created like this:


In [None]:
Expr(:meta, :mark_begin)

In [None]:
Expr(:meta, :mark_end)

and they will not be evaluated but remain in your IR. To surround an expression
with two meta expressions you can use a `:block` expression:

In [None]:
ex = :(x+x)

In [None]:
Expr(:block, :(print(x)), ex, :(print(x)))

<div class="alert alert-block alert-success">
<b>Exercise:</b>
    
Define a macro `@mark` that takes an expression and surrounds it with two
meta expressions marking the beginning and end of a check.

**Hints:**
* Defining a function `_mark(ex::Expr)` which manipulates your expressions can
  help a lot with debugging your macro.

    
</div>
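The hint describes a common pattern: put the expression manipulation into an ordinary function, which you can call directly on quoted code, and keep the macro a thin wrapper around it. Here is a sketch with a hypothetical macro `@twice` that runs its argument two times (deliberately not the `@mark` solution):

```julia
# Ordinary function: easy to test on quoted expressions
_twice(ex) = Expr(:block, ex, ex)

# The macro only forwards to `_twice`; `esc` makes variables in `ex`
# refer to the caller's scope instead of being renamed by macro hygiene
macro twice(ex)
    return esc(_twice(ex))
end

_twice(:(push!(events, :tick)))   # inspect the generated block

events = Symbol[]
@twice push!(events, :tick)
events                            # [:tick, :tick]
```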

<div class="alert alert-block alert-info">
<b>Solution</b>: </div>

###

If you have defined a `_mark` function, you can test that it works like this:

In [None]:
_mark(:(println(x)))

The complete macro should work as shown below:

In [None]:
function f(x::Number)
    @mark @show x
    x
end;

In [None]:
@code_ir f(2)

In [None]:
f(2)

### Removing Argument Checks

Now comes the tricky part, for which we need `IRTools.jl`.
We want to remove all lines that are between our two meta expressions.
You can delete the line that corresponds to a certain variable with the `delete!`
and `var` functions.
For example, deleting the line that defines variable `%4` works like this:

In [None]:
using IRTools: delete!, var

In [None]:
ir = @code_ir f(2)

In [None]:
delete!(ir, var(4))

<div class="alert alert-block alert-success">
<b>Exercise:</b>
    
Write a function `skip(ir::IR)` which deletes all lines between the meta
expressions `:mark_begin` and `:mark_end`.

**Hints:**
You can check whether a statement is one of our meta expressions like this:
    
</div>

In [None]:
ismarkbegin(e::Expr) = Meta.isexpr(e,:meta) && e.args[1]===:mark_begin

In [None]:
ismarkbegin(Expr(:meta,:mark_begin))
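Before working with IR, it can help to get the scanning logic right on a simpler data structure. The sketch below uses a hypothetical helper `strip_marked` on a plain vector of statements (not IRTools' `IR`): it drops every statement from a `:mark_begin` marker up to and including the matching `:mark_end`. The predicates mirror the hint above but accept any statement, not only `Expr`s. In `skip`, a similar scan would run over the IR and call `delete!` on the marked lines.

```julia
ismarkbegin(e) = Meta.isexpr(e, :meta) && e.args[1] === :mark_begin
ismarkend(e)   = Meta.isexpr(e, :meta) && e.args[1] === :mark_end

# Hypothetical helper operating on a Vector of statements, not on IRTools.IR
function strip_marked(stmts::AbstractVector)
    kept = Any[]
    inside = false                    # are we between the two markers?
    for st in stmts
        if ismarkbegin(st)
            inside = true
        elseif ismarkend(st)
            inside = false
        elseif !inside
            push!(kept, st)           # statement outside any marked region
        end
    end
    return kept
end

stmts = Any[:(y = 2x), Expr(:meta, :mark_begin), :(println(y)),
            Expr(:meta, :mark_end), :(return y)]
strip_marked(stmts)                   # Any[:(y = 2x), :(return y)]
```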

<div class="alert alert-block alert-info">
<b>Solution</b>: </div>


###

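Your function should transform the IR of `f` like below. To run the modified IR,
we turn it back into a function with `IRTools.func`; its first argument stands for
the function object itself, which we do not need here, so we pass `nothing`.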


In [None]:
using IRTools: func

In [None]:
ir = @code_ir f(2)

In [None]:
ir = skip(ir)

In [None]:
func(ir)(nothing, 2)  # no output from @show!

However, with a slightly more complicated IR, like the one below, this version of
our function will fail. It actually fails so badly that running
`func(ir)(nothing,2)` after `skip` will cause the build of this page to crash,
so we cannot show you the output here ;).

In [None]:
function g(x)
    @mark iseven(x) && println("even")
    x
end

In [None]:
ir = @code_ir g(2)

In [None]:
ir = skip(ir)

The crash is due to `%4` not existing anymore. We can fix this by emptying the
block in which we found the `:mark_begin` expression and branching to the
block that contains `:mark_end` (unless they are in the same block already).
Any (branching) code that remains in between is then never reached and should
be removed by the compiler.

<div class="alert alert-block alert-success">
<b>Exercise:</b>
    
Use the functions `IRTools.block`, `IRTools.branches`, `IRTools.empty!`, and
`IRTools.branch!` to modify `skip` such that it also empties the `:mark_begin`
block, and adds a branch to the `:mark_end` block (unless they are the same
block).

**Hints**
* `block(ir, var(4))` returns the block of the IR that contains the variable `%4`.
* `empty!` removes all statements in a block.
* `branches` returns all branches of a block.
* `branch!(a,b)` creates a branch from the end of block `a` to the beginning of
  block `b`.

    
</div>

<div class="alert alert-block alert-info">
<b>Solution</b>: </div>

###

The result should construct valid IR for our `g` function.


In [None]:
g(2)

In [None]:
ir = @code_ir g(2)

In [None]:
ir = skip(ir)

In [None]:
func(ir)(nothing,2)

And it should not break when applying it to `f`.

In [None]:
f(2)

In [None]:
ir = @code_ir f(2)

In [None]:
ir = skip(ir)

In [None]:
func(ir)(nothing,2)

### Recursively Removing Argument Checks

The last step to finalize the `skip` function is to make it work recursively.
In the current version we can handle functions that contain `@mark` statements,
but we are not going any deeper than that. Checks inside the functions they call
are not touched:


In [None]:
foo(x) = bar(baz(x))

function bar(x)
    @mark iseven(x) && println("The input is even.")
    x
end

function baz(x)
    @mark x<0 && println("The input is negative.")
    x
end

In [None]:
ir = @code_ir foo(-2)

In [None]:
ir = skip(ir)

In [None]:
func(ir)(nothing,-2)

For recursion we will use the macro `IRTools.@dynamo`, which will make recursion
of our `skip` function a lot easier. Additionally, it will save us from all the
`func(ir)(nothing, args...)` statements. To use `@dynamo` we have to slightly
modify how we define and call `skip`:

```julia
using IRTools: @dynamo, IR

@dynamo function skip(args...)
    ir = IR(args...)
    ir === nothing && return  # no Julia IR available for this call
    
    # same code as before that modifies `ir`
    # ...

    return ir
end

# now we can call `skip` like this
skip(f,2)
```

Now we can easily apply `skip` recursively, because we can simply pass the
called function and the arguments of each call expression on to `skip`:

```julia
using IRTools: xcall

for (x, st) in ir
    Meta.isexpr(st.expr, :call) || continue   # only rewrite function calls
    ir[x] = xcall(skip, st.expr.args...)      # f(args...) becomes skip(f, args...)
end
```

The function `xcall` creates the expression `Expr(:call, skip, args...)`. Since
`st.expr.args` starts with the called function itself, a call `f(a, b)` becomes
`skip(f, a, b)`. Note that you can modify the expression of a given variable in
the IR via `setindex!`, i.e. `ir[x] = ...`.
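The same rewriting idea can be tried out on ordinary surface syntax, without IRTools. The hypothetical function `wrap_calls` below turns every call `f(args...)` in an expression tree into `wrapper(f, args...)`, analogous to what `xcall(skip, st.expr.args...)` does to a single IR statement; the toy wrapper `logged` (also hypothetical) records which functions were called:

```julia
# Recursively turn every call f(args...) into wrapper(f, args...)
function wrap_calls(ex, wrapper)
    ex isa Expr || return ex                       # symbols and literals stay as they are
    args = map(a -> wrap_calls(a, wrapper), ex.args)
    return Meta.isexpr(ex, :call) ? Expr(:call, wrapper, args...) : Expr(ex.head, args...)
end

wrap_calls(:(bar(baz(x))), :skip)   # :(skip(bar, skip(baz, x)))

# A toy wrapper that records the called function, then calls it
seen = Symbol[]
logged(f, args...) = (push!(seen, nameof(f)); f(args...))

eval(wrap_calls(:(abs(min(-3, 2))), :logged))   # 3
seen                                            # [:min, :abs]
```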

<div class="alert alert-block alert-success">
<b>Exercise:</b>

Modify `skip` such that it uses `@dynamo` and apply it recursively to all
`:call` expressions that you encounter while looping over the given IR.
This will dive all the way down to `Core.Builtin`s and `Core.IntrinsicFunction`s,
which you cannot manipulate anymore (because they are written in C).
You have to end the recursion at these places, which can be done via multiple
dispatch of `skip` on `Builtin`s and `IntrinsicFunction`s.

Once you are done with this you can also define a macro such that you can
conveniently call `@skip` with an expression:

```julia
skip(f,2)
@skip f(2)
```
    
</div>
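Ending the recursion works with ordinary multiple dispatch because builtins are instances of subtypes of `Core.Builtin`, and intrinsics are of type `Core.IntrinsicFunction`. The hypothetical function `kind` below only demonstrates how dispatch tells them apart from regular Julia functions; in `skip`, methods for these types would call the function directly instead of recursing into its IR:

```julia
kind(f) = :julia_function                       # generic functions: have Julia IR
kind(f::Core.Builtin) = :builtin                # implemented in C, no Julia IR
kind(f::Core.IntrinsicFunction) = :intrinsic    # compiler intrinsics, no Julia IR

kind(sin)                       # :julia_function
kind(getfield)                  # :builtin
kind(Core.Intrinsics.add_int)   # :intrinsic
```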

<div class="alert alert-block alert-info">
<b>Solution</b>: </div>

####

In [None]:
@code_ir foo(2)

In [None]:
@code_ir skip(foo,2)

In [None]:
foo(-2)

In [None]:
skip(foo,-2)

In [None]:
@skip foo(-2)

## References

* Blog post on [static matrices](https://wesselb.github.io/2020/12/13/julia-learning-circle-meeting-3.html) with `@generated` functions
* [`OptionalArgChecks.jl`](https://github.com/simeonschaub/OptionalArgChecks.jl)
* IRTools [Dynamo](https://fluxml.ai/IRTools.jl/latest/dynamo/)