# Compare Autodiff packages for speeds

+ The goal is to differentiate a log-likelihood function - the workhorse of probability theory, mathematical statistics and machine learning
    + $f: \mathbb{R}^n \rightarrow \mathbb{R}^m$
    + Forward mode AD: efficient for $m >> n$ 
    + Reverse mode AD: efficient for $m << n$ 
    + In GWAS, 
        - $f$ is the loglikelihood function (i.e. $m = 1$)
        - $n = (\# \text{ fixed effects}) + (\# \text{ VC params}) + (\# \text{ SNP effect}) \approx 20$
        - **Isn't it true that only the SNP effect should be allowed to vary? So $n = 1$?**
+ Source: https://gist.github.com/ForceBru/63a08b62cb4bdf6a6d8bc23924929d48
+ We add a few more recent packages (e.g. Enzyme) to the list

In [1]:
# ml julia/1.9 python/3.9.0
using Random, DelimitedFiles
using ForwardDiff, ReverseDiff, Zygote, Symbolics, Enzyme
using LinearAlgebra
using BenchmarkTools
using Pkg
ENV["COLUMNS"] = 240

BLAS.set_num_threads(1)
@show Threads.nthreads()
Pkg.status()

[36m[1m[ [22m[39m[36m[1mInfo: [22m[39mPrecompiling ReverseDiff [37e2e3b7-166d-5795-8a7a-e32c996b4267]
[36m[1m[ [22m[39m[36m[1mInfo: [22m[39mPrecompiling Zygote [e88e6eb3-aa80-5325-afca-941959d7151f]
[36m[1m[ [22m[39m[36m[1mInfo: [22m[39mPrecompiling AdaptStaticArraysExt [e1699a77-9e31-5da8-bb3e-0a796f95f0a0]
[36m[1m[ [22m[39m[36m[1mInfo: [22m[39mPrecompiling ConstructionBaseStaticArraysExt [8497ba20-d017-5d93-8a79-2639523b7219]
[36m[1m[ [22m[39m[36m[1mInfo: [22m[39mPrecompiling StructArraysStaticArraysExt [d1e1e8be-46cf-5459-abb8-be6c7518b661]
[36m[1m[ [22m[39m[36m[1mInfo: [22m[39mPrecompiling Symbolics [0c5d862f-8b57-4792-8d23-62f2024744c7]
[36m[1m[ [22m[39m[36m[1mInfo: [22m[39mPrecompiling RecursiveArrayToolsZygoteExt [6283b665-1224-52f5-a8f0-5953a1198cc4]
[36m[1m[ [22m[39m[36m[1mInfo: [22m[39mPrecompiling PreallocationToolsReverseDiffExt [723d033e-e474-5c37-8984-530550ab56d4]
[36m[1m[ [22m[39m[36m[1mInfo: [22m[3

Threads.nthreads() = 1
[36m[1mProject[22m[39m QuasiCopula v0.1.1
[32m[1mStatus[22m[39m `~/.julia/dev/QuasiCopula/Project.toml`
[33m⌅[39m [90m[f65535da] [39mConvex v0.14.18
  [90m[a93c6f00] [39mDataFrames v1.6.1
  [90m[31c24e10] [39mDistributions v0.25.107
  [90m[7da242da] [39mEnzyme v0.11.19
  [90m[7a1cc6ca] [39mFFTW v1.8.0
  [90m[f6369f11] [39mForwardDiff v0.10.36
  [90m[38e38edf] [39mGLM v1.9.0
[33m⌅[39m [90m[b6b21f68] [39mIpopt v0.8.0
  [90m[bdcacae8] [39mLoopVectorization v0.12.166
  [90m[fdba3010] [39mMathProgBase v0.7.8
  [90m[92933f4c] [39mProgressMeter v1.10.0
  [90m[189a3867] [39mReexport v1.2.2
  [90m[4e780e97] [39mSnpArrays v0.3.21
  [90m[276daf66] [39mSpecialFunctions v2.3.1
[33m⌅[39m [90m[4c63d2b9] [39mStatsFuns v0.9.18
[33m⌅[39m [90m[c751599d] [39mToeplitzMatrices v0.7.1
  [90m[37e2e46d] [39mLinearAlgebra
  [90m[9a3f8284] [39mRandom
[36m[1mInfo[22m[39m Packages marked with [33m⌅[39m have new versions available but 

## Setup 

In [2]:
# ========== Benchmark setup ==========
SEED = 42
N_SAMPLES = 10000
N_COMPONENTS = 5

rnd = Random.MersenneTwister(SEED)
data = randn(rnd, N_SAMPLES)
params0 = [rand(rnd, N_COMPONENTS); randn(rnd, N_COMPONENTS); 2rand(rnd, N_COMPONENTS)]

# save file to be read into python later
DelimitedFiles.writedlm("gen_data.csv", data, ',')
DelimitedFiles.writedlm("gen_params0.csv", params0, ',')

# ========== Objective function ==========

normal_pdf(x::Real, mean::Real, var::Real) =
    exp(-(x - mean)^2 / (2var)) / sqrt(2π * var)

mixture_pdf(x::Real, weights::AbstractVector{<:Real}, means::AbstractVector{<:Real}, vars::AbstractVector{<:Real}) =
    sum(
        w * normal_pdf(x, mean, var)
        for (w, mean, var) in zip(weights, means, vars)
    )

normal_pdf(x, mean, var) =
    exp(-(x - mean)^2 / (2var)) / sqrt(2π * var)


function mixture_loglikelihood(params::AbstractVector, data::AbstractVector)
    K = length(params) ÷ 3
    weights, means, stds = @views params[1:K], params[K+1:2K], params[2K+1:end]
    mat = normal_pdf.(data, means', stds' .^2) # (N, K)
    #@show size(mat)

    # original objective (doesn't work)
    sum(mat .* weights', dims=2) .|> log |> sum

    # another form of original objective commented out by the original author (same issue)
#     sum(
#         sum(
#             weight * normal_pdf(x, mean, std^2)
#             for (weight, mean, std) in zip(weights, means, stds)
#         ) |> log
#         for x in data
#     )

    # objective re-written by me (same issue)
#     obj = zero(eltype(mat))
#     for x in data
#         obj_i = zero(eltype(mat))
#         for (weight, mean, std) in zip(weights, means, stds)
#             obj_i += weight * normal_pdf(x, mean, std^2)
#         end
#         obj += log(obj_i)
#     end
#     return obj
end
        
objective = params -> mixture_loglikelihood(params, data)
        
function generate_gradient(out_fname::AbstractString, K::Integer)
    @assert K > 0
    Symbolics.@variables x ws[1:K] mus[1:K] stds[1:K]

    args=[x, ws, mus, stds]
    expr = Symbolics.gradient(
        mixture_pdf(x, ws, mus, collect(stds) .^2) |> log,
        [ws; mus; stds]
    )

    fn, fn_mut = Symbolics.build_function(expr, args...)
    
    write(out_fname, string(fn_mut))
end

        
@show params0
@show objective(params0)
@info "Settings" SEED N_SAMPLES N_COMPONENTS length(params0)

        
# ========== Gradient with Symbolics.jl ==========
@info "Generating gradient functions..."
GRAD_FNS = Union{Nothing, Function}[nothing]
for K in 2:5
    fname = "grad_$K.jl"
    @show generate_gradient(fname, K)
    push!(GRAD_FNS, include(fname))
end

function my_gradient!(out::AbstractVector{<:Real}, tmp::AbstractVector{<:Real}, xs::AbstractVector{<:Real}, params::AbstractVector{<:Real})
    K = length(params) ÷ 3
    grad! = GRAD_FNS[K]
    weights, means, stds = @views params[1:K], params[K+1:2K], params[2K+1:end]

    out .= 0
    for x in xs
        grad!(tmp, x, weights, means, stds)
        out .+= tmp
    end
end

params0 = [0.6509560930859444, 0.17036894385064993, 0.21319596776697636, 0.4705968797513471, 0.9066124779371352, -0.7596053407203316, 0.4501019833316699, -0.03382219163257187, -0.01866041173235008, 1.4306488869677423, 0.7822675207825798, 1.238815697809096, 1.9650279191800957, 1.9106539785480954, 1.1385080309238362]
objective(params0) = -7473.507394000162


[36m[1m┌ [22m[39m[36m[1mInfo: [22m[39mSettings
[36m[1m│ [22m[39m  SEED = 42
[36m[1m│ [22m[39m  N_SAMPLES = 10000
[36m[1m│ [22m[39m  N_COMPONENTS = 5
[36m[1m└ [22m[39m  length(params0) = 15
[36m[1m[ [22m[39m[36m[1mInfo: [22m[39mGenerating gradient functions...


generate_gradient(fname, K) = 5482
generate_gradient(fname, K) = 9638
generate_gradient(fname, K) = 14985
generate_gradient(fname, K) = 21520


my_gradient! (generic function with 1 method)

## Actual Benchmarks

In [3]:
@info "Computing gradient w/ Symbolics"
let
    grad_storage = similar(params0)
    tmp = similar(params0)

    # 1. Compile
    my_gradient!(grad_storage, tmp, data, params0)
    # 2. Benchmark
    trial = run(@benchmarkable $my_gradient!($grad_storage, $tmp, $data, $params0) samples=10_000 evals=1 seconds=60)
    show(stdout, MIME("text/plain"), trial)
    println()
    @show grad_storage
end

@info "Computing gradient w/ ForwardDiff"
let
    grad_storage = similar(params0)
    cfg_grad = ForwardDiff.GradientConfig(objective, params0, ForwardDiff.Chunk{length(params0)}())

    # 1. Compile
    ForwardDiff.gradient!(grad_storage, objective, params0, cfg_grad)
    # 2. Benchmark
    trial = run(@benchmarkable ForwardDiff.gradient!($grad_storage, $objective, $params0, $cfg_grad) samples=10_000 evals=1 seconds=60)
    show(stdout, MIME("text/plain"), trial)
    println()
    @show grad_storage
end

@info "Computing gradient w/ ReverseDiff"
let
    grad_storage = similar(params0)
    objective_tape = ReverseDiff.GradientTape(objective, params0) |> ReverseDiff.compile

    # 1. Compile
    ReverseDiff.gradient!(grad_storage, objective_tape, params0)
    # 2. Benchmark
    trial = run(@benchmarkable ReverseDiff.gradient!($grad_storage, $objective_tape, $params0) samples=10_000 evals=1 seconds=60)
    show(stdout, MIME("text/plain"), trial)
    println()
    @show grad_storage
end

@info "Computing gradient w/ Zygote reverse"
let
    # 1. Compile
    grad_storage = Zygote.gradient(objective, params0)
    # 2. Benchmark
    trial = run(@benchmarkable Zygote.gradient($objective, $params0) samples=10_000 evals=1 seconds=60)
    show(stdout, MIME("text/plain"), trial)
    println()
    @show grad_storage
end

@info "Computing gradient w/ Enzyme reverse"
let
    # 1. Compile
    grad_storage = zeros(length(params0))
    # 2. Benchmark
    trial = run(@benchmarkable Enzyme.gradient!($Reverse, $grad_storage, $objective, $params0) samples=10_000 evals=1 seconds=60)
    show(stdout, MIME("text/plain"), trial)
    println()
    @show grad_storage
end

println("Done!")

[36m[1m[ [22m[39m[36m[1mInfo: [22m[39mComputing gradient w/ Symbolics


BenchmarkTools.Trial: 2355 samples with 1 evaluation.
 Range [90m([39m[36m[1mmin[22m[39m … [35mmax[39m[90m):  [39m[36m[1m13.548 ms[22m[39m … [35m60.095 ms[39m  [90m┊[39m GC [90m([39mmin … max[90m): [39m0.00% … 36.85%
 Time  [90m([39m[34m[1mmedian[22m[39m[90m):     [39m[34m[1m24.025 ms              [22m[39m[90m┊[39m GC [90m([39mmedian[90m):    [39m0.00%
 Time  [90m([39m[32m[1mmean[22m[39m ± [32mσ[39m[90m):   [39m[32m[1m25.324 ms[22m[39m ± [32m 8.043 ms[39m  [90m┊[39m GC [90m([39mmean ± σ[90m):  [39m1.59% ±  5.96%

  [39m [39m▂[39m▃[39m [39m▂[39m▅[39m█[39m▆[39m▃[39m [39m [39m [39m [39m [39m [39m▁[39m▂[39m [39m▂[39m [34m [39m[39m [32m [39m[39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m 
  [39m▃[39m█[39m█[39m█[39m█[39m█[

[36m[1m[ [22m[39m[36m[1mInfo: [22m[39mComputing gradient w/ ForwardDiff


BenchmarkTools.Trial: 4381 samples with 1 evaluation.
 Range [90m([39m[36m[1mmin[22m[39m … [35mmax[39m[90m):  [39m[36m[1m 5.645 ms[22m[39m … [35m45.086 ms[39m  [90m┊[39m GC [90m([39mmin … max[90m): [39m 0.00% … 70.33%
 Time  [90m([39m[34m[1mmedian[22m[39m[90m):     [39m[34m[1m11.848 ms              [22m[39m[90m┊[39m GC [90m([39mmedian[90m):    [39m 0.00%
 Time  [90m([39m[32m[1mmean[22m[39m ± [32mσ[39m[90m):   [39m[32m[1m13.605 ms[22m[39m ± [32m 5.503 ms[39m  [90m┊[39m GC [90m([39mmean ± σ[90m):  [39m11.24% ± 13.25%

  [39m [39m [39m [39m [39m [39m▅[39m▅[39m█[39m█[39m▄[39m▄[39m▅[39m▇[39m▆[39m▇[34m▄[39m[39m▁[39m▂[39m [32m [39m[39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m 
  [39m▂[39m▃[39m▃[39m▄[39m█[39m

[36m[1m[ [22m[39m[36m[1mInfo: [22m[39mComputing gradient w/ ReverseDiff


BenchmarkTools.Trial: 596 samples with 1 evaluation.
 Range [90m([39m[36m[1mmin[22m[39m … [35mmax[39m[90m):  [39m[36m[1m 65.397 ms[22m[39m … [35m176.319 ms[39m  [90m┊[39m GC [90m([39mmin … max[90m): [39m0.00% … 0.00%
 Time  [90m([39m[34m[1mmedian[22m[39m[90m):     [39m[34m[1m 96.394 ms               [22m[39m[90m┊[39m GC [90m([39mmedian[90m):    [39m0.00%
 Time  [90m([39m[32m[1mmean[22m[39m ± [32mσ[39m[90m):   [39m[32m[1m100.475 ms[22m[39m ± [32m 18.937 ms[39m  [90m┊[39m GC [90m([39mmean ± σ[90m):  [39m0.00% ± 0.00%

  [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m▂[39m▅[39m█[39m▅[39m▃[39m▆[39m▅[39m▃[39m [39m▇[34m▅[39m[39m▃[39m▂[32m▂[39m[39m▁[39m▃[39m▁[39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m 
  [39m▂[39m▄[39m▃[39m

[36m[1m[ [22m[39m[36m[1mInfo: [22m[39mComputing gradient w/ Zygote reverse


BenchmarkTools.Trial: 10000 samples with 1 evaluation.
 Range [90m([39m[36m[1mmin[22m[39m … [35mmax[39m[90m):  [39m[36m[1m2.334 ms[22m[39m … [35m35.828 ms[39m  [90m┊[39m GC [90m([39mmin … max[90m): [39m0.00% … 77.48%
 Time  [90m([39m[34m[1mmedian[22m[39m[90m):     [39m[34m[1m4.797 ms              [22m[39m[90m┊[39m GC [90m([39mmedian[90m):    [39m0.00%
 Time  [90m([39m[32m[1mmean[22m[39m ± [32mσ[39m[90m):   [39m[32m[1m5.136 ms[22m[39m ± [32m 2.152 ms[39m  [90m┊[39m GC [90m([39mmean ± σ[90m):  [39m7.14% ± 13.47%

  [39m [39m [39m█[39m▅[39m [39m [39m▅[39m▁[39m▁[39m▃[39m▃[39m▄[39m [39m [39m▂[34m▃[39m[39m▃[32m▂[39m[39m▁[39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m 
  [39m▃[39m█[39m█[39m█[39m▆[39m█[39m█[39

[36m[1m[ [22m[39m[36m[1mInfo: [22m[39mComputing gradient w/ Enzyme reverse


BenchmarkTools.Trial: 1 sample with 1 evaluation.
 Single result which took [34m3.826 ms[39m (0.00% GC) to evaluate,
 with a memory estimate of [33m1.83 MiB[39m, over [33m58[39m allocations.
grad_storage = [5297.814324727782, 4763.559259051035, 3824.5840662504697, 3896.850998205266, 3408.9108396746446, 1058.3192613979938, -149.37508594367515, -22.473937315596682, -56.84715458264528, -1741.7742573028147, -603.9101820698656, -298.72845554586803, -285.9002461898592, -649.5577894738506, -202.70373627213408]
Done!


In [4]:
@info "Computing gradient w/ Enzyme reverse"
let
    # 1. Compile
    grad_storage = zeros(length(params0))
    # 2. Benchmark
    trial = run(@benchmarkable Enzyme.gradient!($Reverse, $grad_storage, $objective, $params0) samples=10_000 evals=1 seconds=60)
    show(stdout, MIME("text/plain"), trial)
    println()
    @show grad_storage
end


[36m[1m[ [22m[39m[36m[1mInfo: [22m[39mComputing gradient w/ Enzyme reverse


BenchmarkTools.Trial: 10000 samples with 1 evaluation.
 Range [90m([39m[36m[1mmin[22m[39m … [35mmax[39m[90m):  [39m[36m[1m2.600 ms[22m[39m … [35m23.682 ms[39m  [90m┊[39m GC [90m([39mmin … max[90m): [39m0.00% … 32.41%
 Time  [90m([39m[34m[1mmedian[22m[39m[90m):     [39m[34m[1m3.245 ms              [22m[39m[90m┊[39m GC [90m([39mmedian[90m):    [39m0.00%
 Time  [90m([39m[32m[1mmean[22m[39m ± [32mσ[39m[90m):   [39m[32m[1m3.525 ms[22m[39m ± [32m 1.188 ms[39m  [90m┊[39m GC [90m([39mmean ± σ[90m):  [39m4.77% ± 10.43%

  [39m [39m▅[39m▆[39m█[39m█[34m▅[39m[39m▂[39m▁[32m▂[39m[39m▁[39m▁[39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m 
  [39m▅[39m█[39m█[39m█[39m█[34m█[39m[39m

15-element Vector{Float64}:
  5297.814324727782
  4763.559259051035
  3824.5840662504697
  3896.850998205266
  3408.9108396746446
  1058.3192613979938
  -149.37508594367515
   -22.473937315596682
   -56.84715458264528
 -1741.7742573028147
  -603.9101820698656
  -298.72845554586803
  -285.9002461898592
  -649.5577894738506
  -202.70373627213408

## JAX (python)

In [1]:
# install: /share/software/user/open/python/3.9.0/bin/pip3 install --upgrade "jax[cpu]"
# test run: /share/software/user/open/python/3.9.0/bin/python3.9

import timeit

import numpy as np

import jax
import jax.numpy as jnp

# Enable float64 support
jax.config.update("jax_enable_x64", True)

@jax.jit
def normal_pdf(data, mean, var):
    return jnp.exp(-(data - mean)**2 / (2 * var)) / jnp.sqrt(2 * jnp.pi * var)

@jax.jit
def mixture_loglikelihood(params: jnp.ndarray, data: jnp.ndarray)  -> float:
    K = len(params) // 3
    weights, means, stds = params[:K], params[K:2*K], params[2*K:]
    mat = normal_pdf(data, means.T, stds.T**2) # (N, K)
    return jnp.log((mat * weights.T).sum(1)).sum()

data = np.loadtxt("gen_data.csv").flatten()[:, None]
params0 = np.loadtxt("gen_params0.csv").flatten()
params0 = jnp.array(params0)

objective = lambda params: mixture_loglikelihood(params, data)

print(objective(params0))
# Output: -443.40397372007186

the_grad = jax.jit(jax.grad(objective))

print(the_grad(params0))
# Output: [289.73084956 199.27559525 236.68945778 292.06123402  -9.42979939
#  26.72229565  -1.91803555  37.9874909  -24.09562015 -13.93568733
# -38.00044666  12.87712892]

# ========== Benchmark ==========
N_SAMPLES, N_EVALS = 10_000, 1 # like in Julia
bench_secs = timeit.repeat(
    "the_grad(params0)", globals={'the_grad': the_grad, 'params0': params0},
    repeat=N_SAMPLES, number=N_EVALS
)
bench_mus = 1_000_000 * np.array(bench_secs)

print(f"{bench_mus.mean():.3f} μs ± {bench_mus.std():.3f} μs ({N_SAMPLES} samples with {N_EVALS} evaluation)")

-7473.507394000162
[ 5297.81432473  4763.55925905  3824.58406625  3896.85099821
  3408.91083967  1058.3192614   -149.37508594   -22.47393732
   -56.84715458 -1741.7742573   -603.91018207  -298.72845555
  -285.90024619  -649.55778947  -202.70373627]
2515.012 μs ± 718.190 μs (10000 samples with 1 evaluation)


## Summary

In [None]:
Symbolics.jl
24.025 ms
ForwardDiff.jl
11.848 ms
ReverseDiff.jl
96.394 ms
Zygote.jl (reverse)
4.797 ms
Enzyme.jl (reverse mode)
3.245 ms
JAX (python)
2.515 ms


## Does Enzyme.jl/Zygote.jl support BLAS?

In [15]:
using Enzyme
using Zygote
using LinearAlgebra
using BenchmarkTools

Enzyme.jl

In [41]:
# objective = 0.5 || y - X*beta ||^2
function ols(y, X, beta)
    storage = X * beta
    obj = zero(eltype(beta))
    for i in eachindex(y)
        obj += abs2(y[i] - storage[i])
    end
    return 0.5obj
end
ols(beta::AbstractVector) = ols(y, X, beta)

# simulate data
n = 100
p = 50
X = randn(n, p)
y = randn(n)
beta = randn(p)
ols(y, X, beta)

# autodiff grad
grad1 = zeros(length(beta))
Enzyme.gradient!(Reverse, grad1, ols, beta) # method 1
grad2 = Enzyme.gradient(Reverse, ols, beta) # method 2
grad3 = zeros(length(beta))
Enzyme.autodiff(Reverse, ols, Active, Duplicated(beta, grad3)) # method 3

# analytical grad
true_grad = -X' * (y - X*beta)

# compare answers
[true_grad grad1 grad2 grad3]

50×4 Matrix{Float64}:
  223.421      223.421      223.421      223.421
   56.115       56.115       56.115       56.115
  162.285      162.285      162.285      162.285
  -58.6251     -58.6251     -58.6251     -58.6251
  109.732      109.732      109.732      109.732
 -152.867     -152.867     -152.867     -152.867
   51.7608      51.7608      51.7608      51.7608
  -64.8512     -64.8512     -64.8512     -64.8512
   38.8877      38.8877      38.8877      38.8877
  -62.3895     -62.3895     -62.3895     -62.3895
   92.2303      92.2303      92.2303      92.2303
  208.785      208.785      208.785      208.785
   47.6175      47.6175      47.6175      47.6175
    ⋮                                   
  -26.5664     -26.5664     -26.5664     -26.5664
   65.8905      65.8905      65.8905      65.8905
  279.925      279.925      279.925      279.925
  -18.8817     -18.8817     -18.8817     -18.8817
 -169.649     -169.649     -169.649     -169.649
 -134.628     -134.628     -134.628     -134.

Zygote.jl

In [19]:
# objective = 0.5 || y - X*beta ||^2
function ols(y, X, beta)
    storage = X * beta
    obj = zero(eltype(beta))
    for i in eachindex(y)
        obj += abs2(y[i] - storage[i])
    end
    return 0.5obj
end
ols(beta::AbstractVector) = ols(y, X, beta)

# simulate data
n = 100
p = 50
X = randn(n, p)
y = randn(n)
beta = randn(p)
storage1 = zeros(n)
storage2 = zeros(n)
ols(y, X, beta)

# autodiff grad
grad_storage = Zygote.gradient(ols, beta)

# analytical grad
true_grad = -X' * (y - X*beta)

# compare answers
[true_grad grad_storage[1]]

50×2 Matrix{Float64}:
  -30.7309    -30.7309
   75.0246     75.0246
   -9.56967    -9.56967
  -60.3384    -60.3384
   86.1821     86.1821
  245.025     245.025
   51.3395     51.3395
   39.1938     39.1938
   45.1931     45.1931
  -25.5218    -25.5218
  106.324     106.324
  -13.0246    -13.0246
   18.577      18.577
    ⋮        
   -4.90943    -4.90943
 -158.693    -158.693
   13.4801     13.4801
  -47.8658    -47.8658
   60.1435     60.1435
   93.3104     93.3104
   -3.67877    -3.67877
   44.3463     44.3463
   38.8036     38.8036
   82.8235     82.8235
  155.893     155.893
 -187.871    -187.871

ForwardDiff.jl

In [2]:
using ForwardDiff

function A_mul_b!(c::AbstractVector{T}, A::AbstractMatrix, b::AbstractVector) where T
    n, p = size(A)
    fill!(c, zero(T))
    for j in 1:p, i in 1:n
        c[i] += A[i, j] * b[j]
    end
    return c
end

function ols(y, X, beta, storage=zeros(eltype(beta), size(X, 1)))
    A_mul_b!(storage, X, beta)
    storage .= y .- storage
    return 0.5 * sum(abs2, storage)
end
ols(beta::AbstractVector) = ols(y, X, beta)

# simulate data
n = 100
p = 50
X = randn(n, p)
y = randn(n)
beta = randn(p)
storage = zeros(n)
ols(y, X, beta, storage)

grad = ForwardDiff.gradient(ols, beta)
true_grad = -X' * (y - X*beta)
[true_grad grad]

50×2 Matrix{Float64}:
  101.766     101.766
 -142.45     -142.45
  183.354     183.354
  -48.4099    -48.4099
   21.6648     21.6648
  -33.7665    -33.7665
  -38.6346    -38.6346
  -50.3303    -50.3303
  -49.5342    -49.5342
  -37.1839    -37.1839
   42.0213     42.0213
 -161.4      -161.4
  119.435     119.435
    ⋮        
  -89.8549    -89.8549
 -171.656    -171.656
  124.696     124.696
   48.1196     48.1196
  -32.7268    -32.7268
    9.2581      9.2581
   39.0069     39.0069
  -83.9941    -83.9941
  -17.3246    -17.3246
  274.159     274.159
    4.91454     4.91454
 -106.634    -106.634

## Does Enzyme.jl/Zygote.jl work with `struct`?

In [36]:
using Enzyme
using Zygote
using LinearAlgebra

struct MyData
    X::Matrix{Float64}
    y::Vector{Float64}
    storage::Vector{Float64}
end

# objective = 0.5 || y - X*beta ||^2
function ols(data::MyData, beta)
    storage = data.X * beta # works
#     mul!(data.storage, data.X, beta)
    obj = zero(eltype(data.X))
    for i in eachindex(data.y)
        obj += abs2(data.y[i] - storage[i])
    end
    return 0.5obj
end
ols(beta::AbstractVector) = ols(data, beta)

# simulate data
n = 10000
p = 50
X = randn(n, p)
y = randn(n)
beta = randn(p)
storage = zeros(n)
data = MyData(X, y, storage)
ols(data, beta)

# autodiff grad
grad = zeros(length(beta))
@time Enzyme.autodiff(Reverse, ols, Active, Duplicated(beta, grad))

# zygote grad
@time grad2 = Zygote.gradient(ols, beta)[1]

# analytical grad
true_grad = -X' * (y - X*beta)

# compare answers
[true_grad grad grad2]

  0.153032 seconds (251.62 k allocations: 12.037 MiB, 99.30% compilation time: 100% of which was recompilation)
  0.367664 seconds (676.90 k allocations: 1.521 GiB, 21.31% gc time, 32.68% compilation time: 100% of which was recompilation)


50×3 Matrix{Float64}:
  10958.1     10958.1     10958.1
   6872.77     6872.77     6872.77
   4268.43     4268.43     4268.43
  -6559.18    -6559.18    -6559.18
     -1.629      -1.629      -1.629
   1296.32     1296.32     1296.32
   9705.7      9705.7      9705.7
  -9632.29    -9632.29    -9632.29
 -20403.3    -20403.3    -20403.3
  -7022.28    -7022.28    -7022.28
   4736.58     4736.58     4736.58
   5876.46     5876.46     5876.46
 -20949.6    -20949.6    -20949.6
      ⋮                  
 -18599.9    -18599.9    -18599.9
  -7730.13    -7730.13    -7730.13
  -9266.25    -9266.25    -9266.25
    713.507     713.507     713.507
  -4060.83    -4060.83    -4060.83
  11029.0     11029.0     11029.0
  -8578.22    -8578.22    -8578.22
  23480.5     23480.5     23480.5
  16420.2     16420.2     16420.2
  -9767.68    -9767.68    -9767.68
  -3973.54    -3973.54    -3973.54
   1664.5      1664.5      1664.5