Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions docs/src/tutorials/index.md
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
# Tutorials

- [Partial evaluation](@ref partial-evaluation).
- [Control flow](@ref control-flow).
- [Sharding](@ref sharding).
- [Profiling](@ref profiling).
- [Multi-Host Environments](@ref distributed).
- [Local build of ReactantExtra](@ref local-build).
- [Control flow](@ref control-flow).
- [Sharding](@ref sharding).
- [Persistent Compilation Cache](@ref persistent_compile_cache).

We are currently working on adding more tutorials to Reactant!! Please check back soon!
116 changes: 116 additions & 0 deletions docs/src/tutorials/partial-evaluation.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
# [Partial Evaluation](@id partial-evaluation)

When compiling functions with Reactant, the function arguments (and possible
closure fields) may contain non-Reactant values, i.e. numbers and arrays that
are not of type `Reactant.AbstractConcreteNumber` or
`Reactant.AbstractConcreteArray`.

The Reactant compiler may (but is not guaranteed to) treat these non-Reactant
values as constant and partially evaluate the function to be compiled based
on this.

For example, the function


```jldoctest partial_evaluation_tutorial
using Reactant
function add(a, b)
a + b
end;

# output

add (generic function with 1 method)
```

when compiled with two `ConcreteRNumber` arguments

```jldoctest partial_evaluation_tutorial; filter = r"I000.*" => s""
using Reactant

x = ConcreteRNumber(3)
y = ConcreteRNumber(4)

addxy = @compile add(x, y)

addxy(x, y)

# output

I0000
I0000
ConcretePJRTNumber{Int64, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}(7)
```

returns a result that depends on both arguments `x` and `y`:


```jldoctest partial_evaluation_tutorial
addxy(ConcreteRNumber(7), ConcreteRNumber(8))

# output

ConcretePJRTNumber{Int64, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}(15)
```

The StableHLO IR code generated here is:

```jldoctest partial_evaluation_tutorial
@code_hlo add(x, y)

# output

module @reactant_add attributes {mhlo.num_partitions = 1 : i64, mhlo.num_replicas = 1 : i64} {
func.func @main(%arg0: tensor<i64>, %arg1: tensor<i64>) -> tensor<i64> {
%0 = stablehlo.add %arg0, %arg1 : tensor<i64>
return %0 : tensor<i64>
}
}
```

So at HLO-level, there a are two variable inputs `%arg0` and `%arg1`.

However, if argument `y` has a non-Reactant value during compilation, (`4` in
this example) then the result when executing the compiled function

```jldoctest partial_evaluation_tutorial; filter = r"I000.*" => s""
addx4 = @compile add(x, 4)

addx4(x, 4)

# output

I0000
I0000
ConcretePJRTNumber{Int64, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}(7)
```

will only change based on `x`, not on the non-Reactant argument `y`, we get
`7 + 4 == 11`, not `7 + 8 == 15`:

```jldoctest partial_evaluation_tutorial
addx4(ConcreteRNumber(7), 8)

# output

ConcretePJRTNumber{Int64, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}(11)
```

The StableHLO code shows that the second argument has been replaced by a
constant `%c` during partial evaluation. When the compiled function is
executed, the value of `y` is ignored - at HLO-level, there is only one
variable input `%arg0`:

```jldoctest partial_evaluation_tutorial
@code_hlo add(x, 4)

# output

module @reactant_add attributes {mhlo.num_partitions = 1 : i64, mhlo.num_replicas = 1 : i64} {
func.func @main(%arg0: tensor<i64>) -> tensor<i64> {
%c = stablehlo.constant dense<4> : tensor<i64>
%0 = stablehlo.add %arg0, %c : tensor<i64>
return %0 : tensor<i64>
}
}
```
Loading