# Enzyme.jl

Julia is a high-level programming language that uses LLVM as its compiler backend.
Enzyme.jl uses Julia's GPU compiler infrastructure to provide a custom optimization
pipeline that inserts the Enzyme LLVM pass.

It then uses LLVM's ORC JIT (v2 or v1) to compile the adjoints and calls them through
Julia's foreign-function interface.

```julia
function mysum(X)
    acc = zero(eltype(X))
    @simd for x in X
        acc += x
    end
    acc
end
```

# Installing Enzyme

This tutorial was tested with Julia 1.7-beta3.

Using the Julia package manager:
```julia
import Pkg
Pkg.add("Enzyme")
```

In [37]:
import Pkg
Pkg.activate(; temp=true)
Pkg.add(Pkg.PackageSpec(name="Enzyme", rev="822afeff2c8a9b87c8fb93c6415cc3ffb19924e8"))
Pkg.add("BenchmarkTools")
Pkg.add("ForwardDiff")

[32m[1m  Activating[22m[39m new project at `/tmp/jl_yAfFxB`
[32m[1m   Resolving[22m[39m package versions...
[32m[1m    Updating[22m[39m `/tmp/jl_yAfFxB/Project.toml`
 [90m [7da242da] [39m[92m+ Enzyme v0.7.0 `https://github.com/wsmoses/Enzyme.jl.git#822afef`[39m
[32m[1m    Updating[22m[39m `/tmp/jl_yAfFxB/Manifest.toml`
 [90m [79e6a3ab] [39m[92m+ Adapt v3.3.1[39m
 [90m [fa961155] [39m[92m+ CEnum v0.4.1[39m
 [90m [7da242da] [39m[92m+ Enzyme v0.7.0 `https://github.com/wsmoses/Enzyme.jl.git#822afef`[39m
 [90m [e2ba6199] [39m[92m+ ExprTools v0.1.6[39m
 [90m [61eb1bfa] [39m[92m+ GPUCompiler v0.13.7[39m
 [90m [692b3bcd] [39m[92m+ JLLWrappers v1.3.0[39m
 [90m [929cbde3] [39m[92m+ LLVM v4.6.0[39m
 [90m [d8793406] [39m[92m+ ObjectFile v0.3.7[39m
 [90m [21216c6a] [39m[92m+ Preferences v1.2.2[39m
 [90m [189a3867] [39m[92m+ Reexport v1.2.2[39m
 [90m [53d494c1] [39m[92m+ StructIO v0.3.0[39m
 [90m [a759f4b9] [39m[92m+ TimerOutputs v0.

In [38]:
using Enzyme
using ForwardDiff
using BenchmarkTools

# Activity annotations
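- `Const`: the argument is treated as a constant and is not differentiated.
- `Active`: a scalar argument whose derivative is returned by `autodiff`.
- `Duplicated`: the argument is passed together with a shadow of the same type, into which the derivative is accumulated.
- `DuplicatedNoNeed`: like `Duplicated`, but the primal result is not needed, so Enzyme may skip computing it.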

In [39]:
square(x) = x^2

square (generic function with 1 method)

In [40]:
autodiff(square, 1.0)

()

The default activity for values is `Const`, which is why the call above returns an empty tuple.

In [41]:
autodiff(square, Const(1.0))

()

In [42]:
autodiff(square, Active(1.0))

(2.0,)

## Supporting mutating functions

Enzyme can differentiate through mutating functions. This requires that the user pass in the shadow variables with the `Duplicated` or `DuplicatedNoNeed` activity annotation.

In [43]:
function cube(y, x)
	y[] = x[]^3
	return nothing
end

cube (generic function with 1 method)

In [44]:
x = Ref(4.0)
y = Ref(0.0)
cube(y, x)
y[]

64.0


To compute the gradient with respect to `x`, we seed the shadow `dy` with `1.0`;
the reverse pass propagates it into the shadow `dx`.


In [45]:
x = Ref(4.0)
dx = Ref(0.0)

y = Ref(0.0)
dy = Ref(1.0)

autodiff(cube, Duplicated(y, dy), Duplicated(x, dx))
y[], dy[], x[], dx[]

(64.0, 0.0, 4.0, 48.0)
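
The numbers above can be checked by hand. Writing $\bar{y}$ for the seed stored in `dy` and $\bar{x}$ for the value accumulated into `dx` (notation introduced here for illustration only):

$$
y = x^3 = 4^3 = 64, \qquad
\bar{x} \mathrel{+}= \frac{\partial y}{\partial x}\,\bar{y} = 3x^2\,\bar{y} = 3 \cdot 4^2 \cdot 1 = 48 .
$$

The `0.0` read back from `dy[]` is consistent with the reverse pass resetting the shadow of the overwritten output once the seed has been consumed.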

# Reflection

In [46]:

Enzyme.Compiler.enzyme_code_llvm(cube, Const,
	Tuple{Enzyme.Duplicated{Base.RefValue{Float64}}, 
	Duplicated{Base.RefValue{Float64}}}, debuginfo=:none)

; Function Attrs: alwaysinline
define void @diffejulia_cube_9969wrap({}* %0, {}* %1, {}* %2, {}* %3) #3 {
entry:
  %"'ipc6.i" = bitcast {}* %3 to double*
  %4 = bitcast {}* %2 to double*
  %5 = load double, double* %4, align 8
  %6 = fmul double %5, %5
  %7 = fmul double %5, %6
  %"'ipc.i" = bitcast {}* %1 to double*
  %8 = bitcast {}* %0 to double*
  store double %7, double* %8, align 8
  %9 = load double, double* %"'ipc.i", align 8
  store double 0.000000e+00, double* %"'ipc.i", align 8
  %10 = load double, double* %"'ipc6.i", align 8
  %11 = fmul fast double %6, 3.000000e+00
  %reass.mul = fmul fast double %11, %9
  %12 = fadd fast double %reass.mul, %10
  store double %12, double* %"'ipc6.i", align 8
  ret void
}
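
To make the IR easier to read, here is a hand-written Julia sketch of what the generated function appears to do, step by step. The name `cube_adjoint!` and the local variable names are hypothetical and chosen only for this illustration; this is not code that Enzyme emits or an API it provides.

```julia
# Hand-written sketch mirroring the IR above (plain Julia, no Enzyme required).
# `cube_adjoint!` is a hypothetical name used only for illustration.
function cube_adjoint!(y, dy, x, dx)
    xv = x[]                  # %5: load x
    x2 = xv * xv              # %6: x^2
    y[] = xv * x2             # %7 and store: primal result x^3
    ybar = dy[]               # %9: read the seed from the shadow of y
    dy[] = 0.0                # the shadow of y is reset after it has been read
    dx[] += 3.0 * x2 * ybar   # %11, %reass.mul, %12: accumulate 3x^2 * seed into dx
    return nothing
end

x, dx = Ref(4.0), Ref(0.0)
y, dy = Ref(0.0), Ref(1.0)
cube_adjoint!(y, dy, x, dx)
(y[], dy[], x[], dx[])        # (64.0, 0.0, 4.0, 48.0)
```

The result agrees with the values returned by the `autodiff` call in the previous section.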


# Differentiating through control-flow
Let's differentiate through some control flow. For this kind of scalar code one would normally reach for `ForwardDiff.jl`, since machine-learning-oriented toolkits such as Zygote carry unacceptable overheads here.

In [47]:
# Taylor series for `-log(1-x)`
# eval at -log(1-1/2) = -log(1/2)
function taylor(f::T, N=10^7) where T
    g = zero(T)
    for i in 1:N
        g += f^i / i
    end
    return g
end

autodiff(taylor, Active(0.5), Const(10^8))


(2.0,)
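
The result can be checked against the closed form. With $g_N$ denoting the truncated series computed by `taylor` (a symbol introduced here for illustration):

$$
g_N(x) = \sum_{i=1}^{N} \frac{x^i}{i} \;\xrightarrow{N \to \infty}\; -\log(1-x),
\qquad
g_N'(x) = \sum_{i=1}^{N} x^{i-1} = \frac{1 - x^N}{1 - x} \;\xrightarrow{N \to \infty}\; \frac{1}{1-x} .
$$

At $x = 0.5$ the limit is $1/(1 - 0.5) = 2$, and for $N = 10^8$ the term $x^N$ is far below floating-point resolution, which matches the `(2.0,)` above.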

In [48]:
fwd_taylor(x) = ForwardDiff.derivative(taylor, x)

enz_taylor(x) = autodiff(taylor, Active(x))


enz_taylor (generic function with 1 method)

In [49]:

@benchmark fwd_taylor($(Ref(0.5))[])

BenchmarkTools.Trial: 6 samples with 1 evaluation.
 Range (min … max):  869.791 ms …   1.031 s  ┊ GC (min … max): 0.00% … 0.00%
 Time  (median):     903.407 ms              ┊ GC (median):    0.00%
 Time  (mean ± σ):   919.832 ms ± 61.069 ms  ┊ GC (mean ± σ):  0.00% ± 0.00%

In [50]:
@benchmark enz_taylor($(Ref(0.5))[])

BenchmarkTools.Trial: 11 samples with 1 evaluation.
 Range (min … max):  473.358 ms … 553.634 ms  ┊ GC (min … max): 0.00% … 0.00%
 Time  (median):     489.023 ms               ┊ GC (median):    0.00%
 Time  (mean ± σ):   495.482 ms ±  23.014 ms  ┊ GC (mean ± σ):  0.00% ± 0.00%


# Differentiating through more complicated codes

## A custom matrix multiply

In [51]:

function mymul!(R, A, B)
    @assert axes(A,2) == axes(B,1)
    @inbounds @simd for i in eachindex(R)
        R[i] = 0
    end
    @inbounds for j in axes(B, 2), i in axes(A, 1)
        @inbounds @simd for k in axes(A,2)
            R[i,j] += A[i,k] * B[k,j]
        end
    end
    nothing
end

mymul! (generic function with 1 method)

In [52]:
A = rand(1024, 64)
B = rand(64, 512)

R = zeros(size(A,1), size(B,2))
∂z_∂R = rand(size(R)...)  # Some gradient/tangent passed to us

∂z_∂A = zero(A)
∂z_∂B = zero(B)

Enzyme.autodiff(mymul!, 
	Duplicated(R, ∂z_∂R),
	Duplicated(A, ∂z_∂A),
	Duplicated(B, ∂z_∂B))

()


Let's confirm the correctness of the result:

In [53]:
R ≈ A * B

true

and the correctness of the gradients:

In [54]:
∂z_∂A ≈ ∂z_∂R * B'

true
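
The identity being tested follows from the chain rule. For a scalar $z$ that depends on $R = AB$, with $R_{ij} = \sum_k A_{ik} B_{kj}$:

$$
\frac{\partial z}{\partial A_{ik}} = \sum_{j} \frac{\partial z}{\partial R_{ij}}\, B_{kj}
\;\;\Longrightarrow\;\;
\frac{\partial z}{\partial A} = \frac{\partial z}{\partial R}\, B^{\top},
\qquad
\frac{\partial z}{\partial B_{kj}} = \sum_{i} A_{ik}\, \frac{\partial z}{\partial R_{ij}}
\;\;\Longrightarrow\;\;
\frac{\partial z}{\partial B} = A^{\top}\, \frac{\partial z}{\partial R} .
$$

The first identity is the one checked above; the second gives the corresponding expected value for `∂z_∂B`.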

# Some more fun

In [55]:
struct LList
    next::Union{LList,Nothing}
    val::Float64
end

function sumlist(n::LList)
    sum = 0.0
    while n !== nothing
        sum += n.val
        n = n.next
    end
    sum
end

sumlist (generic function with 1 method)

In [56]:
regular = LList(LList(nothing, 1.0), 2.0)
shadow  = LList(LList(nothing, 0.0), 0.0)
autodiff(sumlist, Duplicated(regular, shadow))

()

In [57]:
shadow.val ≈ 1.0

true

In [58]:
shadow.next.val ≈ 1.0

true
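
Both shadow values are `1.0` because every node contributes its `val` to the sum with weight one. A minimal hand-written sketch of that reverse pass is shown below. It assumes a hypothetical mutable type `MList` and a hypothetical helper `sumlist_adjoint!`, introduced only so that plain Julia can write into the shadow; it is not how Enzyme is implemented.

```julia
# Hypothetical mutable counterpart of `LList`, used only for this illustration.
mutable struct MList
    next::Union{MList,Nothing}
    val::Float64
end

# Hand-written reverse pass of `sumlist`: each `sum += n.val` in the primal
# contributes `seed * 1` to the shadow of that node's `val`.
function sumlist_adjoint!(node::Union{MList,Nothing}, seed::Float64=1.0)
    while node !== nothing
        node.val += seed
        node = node.next
    end
    return nothing
end

mshadow = MList(MList(nothing, 0.0), 0.0)
sumlist_adjoint!(mshadow)
(mshadow.val, mshadow.next.val)   # (1.0, 1.0)
```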

# Differentiating through parallelism

In [59]:
function tasktest(M, x)
    xr = Ref(x)
    task = Threads.@spawn begin
        @inbounds M[1] = xr[]
    end
    @inbounds M[2] = x
    wait(task)
    nothing
end

tasktest (generic function with 1 method)

In [60]:
R = Float64[0., 0.]
dR = Float64[2., 3.]

Enzyme.autodiff(tasktest, Duplicated(R, dR), Active(2.0))



(5.0,)

In [61]:
Float64[2.0, 2.0] ≈ R
Float64[0.0, 0.0] ≈ dR

true
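
The returned `(5.0,)` can be derived by hand. `tasktest` writes `x` into both `M[1]` (inside the spawned task) and `M[2]`, so with $\bar{M}_1 = 2$ and $\bar{M}_2 = 3$ denoting the seeds stored in `dR` (notation introduced here for illustration):

$$
M_1 = x, \quad M_2 = x
\;\;\Longrightarrow\;\;
\bar{x} = \frac{\partial M_1}{\partial x}\,\bar{M}_1 + \frac{\partial M_2}{\partial x}\,\bar{M}_2 = 2 + 3 = 5 .
$$

After the call, `R` holds `[2.0, 2.0]`, and the comparison of `dR` against zeros is consistent with the shadows of the overwritten entries being reset once their seeds are consumed.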