In [1]:
using Plots
using Random
using Statistics
using Enzyme

In [8]:
@views function gravity(pos :: AbstractArray{Float64, 2}) :: AbstractArray{Float64, 2}
    vec_x = pos[1, :] .- pos[1, :]'
    vec_y = pos[2, :] .- pos[2, :]'
    vec_z = pos[3, :] .- pos[3, :]'

    r_sq = (vec_x.^2 + vec_y.^2 + vec_z.^2) .+ 1e-5

    force_x = -vec_x ./ r_sq
    force_y = -vec_y ./ r_sq
    force_z = -vec_z ./ r_sq

    force_x = sum(force_x, dims=2)
    force_y = sum(force_y, dims=2)
    force_z = sum(force_z, dims=2)

    force = [force_x force_y force_z]

    return transpose(force)
end

gravity (generic function with 1 method)

In [9]:
function kdk!(
    pos_t0 :: AbstractArray{Float64, 2},
    vel_t0 :: AbstractArray{Float64, 2},
    pos_t1 :: AbstractArray{Float64, 2},
    vel_t1 :: AbstractArray{Float64, 2},
    dt :: Float64) :: Nothing
    
    # compute a(t)
    force_t0 = gravity(pos_t0)
    # v(t + dt/2) = v(t) + a(t) * dt/2
    vel_t1 .+= force_t0 .* dt / 2
    # x(t + dt) = x(t) + v(t + dt/2) * dt
    pos_t1 .+= vel_t1 .* dt
    # a(t + dt)
    force_t1 = gravity(pos_t1)
    # v(t + dt) = v(t + dt/2) + a(t + dt) * dt/2
    vel_t1 .+= force_t1 .* dt / 2    

    return nothing
end

kdk! (generic function with 1 method)

In [10]:
@views function forward!(
    pos :: AbstractArray{Float64, 3},
    vel :: AbstractArray{Float64, 3},
    dt :: Float64,
    n_steps :: Int) :: Nothing

    for i in 1:n_steps
        kdk!(
            pos[i, :, :],
            vel[i, :, :],
            pos[i+1, :, :],
            vel[i+1, :, :],
            dt)
    end

    return nothing
end

forward! (generic function with 1 method)

Initialize and compute the forward pass

In [11]:
n_particles = 50
n_steps = 200

pos = zeros(Float64, (n_steps+1, 3, n_particles))
vel = zeros(Float64, (n_steps+1, 3, n_particles))

pos_t0 = randn(Float64, (3, n_particles))

pos[1, :, :] .= pos_t0

forward!(pos, vel, 0.01, n_steps)

Let us define the loss function.

In [33]:
@views function loss!(
    pos_tn :: AbstractArray{Float64, 2},
    pos_tn_star :: AbstractArray{Float64, 2},
    loss :: AbstractArray{Float64, 2}) :: Nothing

    loss[:, :] = (pos_tn .- pos_tn_star).^2

    return nothing
end

loss! (generic function with 1 method)

Now we can find the gradients of the loss with respect to the final positions of the particles.

In [56]:
loss_values = zeros(Float64, (3, n_particles))
partial_loss = ones(Float64, (3, n_particles))
partial_pos = zeros(Float64, (n_steps+1, 3, n_particles))
partial_vel = zeros(Float64, (n_steps+1, 3, n_particles))

@views begin
    autodiff(
        Reverse,
        loss!, 
        Duplicated(pos[end, :, :], partial_pos[end, :, :]),
        Const(pos_tn_star),
        Duplicated(loss_value, partial_loss))
end

partial_pos[end, :, :]

3×50 Matrix{Float64}:
 2.0525       1.84907      1.78852    …  -1.84012    -1.86306    -1.92625
 0.00413336  -0.00260682  -0.0597922     -0.0195634  -0.0491868   0.00554026
 0.0562904   -0.0336393    0.0060033      0.0259994  -0.0211828   0.0202027

We introduce the backward function, which passes the gradients through the simulation.

In [59]:
function backward(
    pos :: AbstractArray{Float64, 3},
    vel :: AbstractArray{Float64, 3},
    partial_pos :: AbstractArray{Float64, 3},
    partial_vel :: AbstractArray{Float64, 3},
    dt :: Float64,
    n_steps :: Int) :: Nothing

    for i in n_steps:-1:1

        autodiff(
            Reverse,
            kdk!,
            Duplicated(view(pos, i, :, :), view(partial_pos, i, :, :)),
            Duplicated(view(vel, i, :, :), view(partial_vel, i, :, :)),
            Const(view(vel, i, :, :)),
            Const(view(pos, i+1, :, :)),
            Const(dt)
        )
        
    end

end
    

backward (generic function with 1 method)

In [60]:
backward(
    pos,
    vel,
    partial_pos,
    partial_vel,
    0.01,
    n_steps)

[33m[1m│ [22m[39m  T = Tuple{Vararg{Base.OneTo{Int64}}}
[33m[1m└ [22m[39m[90m@ Enzyme C:\Users\andri\.julia\packages\GPUCompiler\kqxyC\src\utils.jl:59[39m


Enzyme.Compiler.EnzymeRuntimeException: Enzyme execution failed.
Mismatched activity for:   %value_phi87 = phi {} addrspace(10)* [ %103, %L187 ], [ %getfield, %L179 ] const val:   %getfield = load atomic {} addrspace(10)*, {} addrspace(10)* addrspace(11)* %getfield_addr unordered, align 8, !dbg !139, !tbaa !57, !alias.scope !90, !noalias !93, !nonnull !51, !dereferenceable !141, !align !142
 value=Unknown object of type Array{Float64, 3}
You may be using a constant variable as temporary storage for active memory (https://enzyme.mit.edu/julia/stable/faq/#Activity-of-temporary-storage). If not, please open an issue, and either rewrite this variable to not be conditionally active or use Enzyme.API.runtimeActivity!(true) as a workaround for now

Stacktrace:
  [1] ==
    @ .\promotion.jl:521
  [2] !=
    @ .\operators.jl:276
  [3] _newindexer
    @ .\broadcast.jl:631
  [4] shapeindexer
    @ .\broadcast.jl:626
  [5] newindexer
    @ .\broadcast.jl:625
  [6] extrude
    @ .\broadcast.jl:676
  [7] preprocess
    @ .\broadcast.jl:984
  [8] preprocess_args
    @ .\broadcast.jl:986
  [9] preprocess
    @ .\broadcast.jl:983
 [10] copyto!
    @ .\broadcast.jl:1000
 [11] copyto!
    @ .\broadcast.jl:956
 [12] copy
    @ .\broadcast.jl:928
 [13] materialize
    @ .\broadcast.jl:903
 [14] gravity
    @ c:\Users\andri\projects\msc-thesis\experiments\dp.ipynb:2
