In [24]:
using Plots
using Random
using Statistics
using Enzyme

In [1]:
function gravity(pos :: AbstractArray{Float64, 2}) :: AbstractArray{Float64, 2}
    vec_x = pos[1, :] .- pos[1, :]'
    vec_y = pos[2, :] .- pos[2, :]'
    vec_z = pos[3, :] .- pos[3, :]'

    r_sq = (vec_x.^2 + vec_y.^2 + vec_z.^2) .+ 1e-5

    force_x = -vec_x ./ r_sq
    force_y = -vec_y ./ r_sq
    force_z = -vec_z ./ r_sq

    force_x = sum(force_x, dims=2)
    force_y = sum(force_y, dims=2)
    force_z = sum(force_z, dims=2)

    force = [force_x force_y force_z]

    return transpose(force)
end

gravity (generic function with 1 method)

In [20]:
function kdk!(
    pos_t0 :: AbstractArray{Float64, 2},
    vel_t0 :: AbstractArray{Float64, 2},
    pos_t1 :: AbstractArray{Float64, 2},
    vel_t1 :: AbstractArray{Float64, 2},
    dt :: Float64) :: Nothing
    
    # compute a(t)
    force_t0 = gravity(pos_t0)
    # v(t + dt/2) = v(t) + a(t) * dt/2
    vel_t1 .+= force_t0 .* dt / 2
    # x(t + dt) = x(t) + v(t + dt/2) * dt
    pos_t1 .+= vel_t1 .* dt
    # a(t + dt)
    force_t1 = gravity(pos_t1)
    # v(t + dt) = v(t + dt/2) + a(t + dt) * dt/2
    vel_t1 .+= force_t1 .* dt / 2    

    return nothing
end

kdk! (generic function with 1 method)

In [21]:
function forward!(
    pos :: AbstractArray{Float64, 3},
    vel :: AbstractArray{Float64, 3},
    dt :: Float64,
    n_steps :: Int) :: Nothing

    for i in 1:n_steps
        kdk(
            view(pos, i, :, :),
            view(vel, i, :, :),
            view(pos, i+1, :, :),
            view(vel, i+1, :, :),
            dt)
    end

    return nothing
end

forward! (generic function with 1 method)

Initialize and compute the forward pass

In [22]:
n_particles = 50
n_steps = 200

pos = zeros(Float64, (n_steps+1, 3, n_particles))
vel = zeros(Float64, (n_steps+1, 3, n_particles))

pos_t0 = randn(Float64, (3, n_particles))

pos[1, :, :] .= pos_t0

forward(pos, vel, 0.01, n_steps)

3×50 view(::Array{Float64, 3}, 1, :, :) with eltype Float64:
  0.0711079  -0.121845   1.35599   …  1.71407   -0.391519  -0.976129
 -1.51845     2.18974    0.562172     0.122297   1.34041    1.11806
  0.842503    0.94564   -0.686825     0.312992  -1.28875   -0.235379

Let us define the loss function.

In [53]:
function loss!(
    pos_tn :: AbstractArray{Float64, 2},
    pos_tn_star :: AbstractArray{Float64, 2},
    loss :: AbstractArray{Float64, 2}) :: Nothing

    loss = (pos_tn .- pos_tn_star).^2

    return nothing
end

loss! (generic function with 1 method)

We can now find the loss with respect to the final positions of the particles.

In [54]:
pos_tn_star = zeros(Float64, (3, n_particles))
pos_tn_star[1, :] = range(-1, 1, length=n_particles)
loss_value = zeros(Float64, (3, n_particles))

loss!(pos[end, :, :], pos_tn_star, loss_value)

println("Loss: ", loss_value)

Loss: [0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0]


We introduce the backward function, which passes the gradients through the simulation.

In [None]:

# lets get the loss gradient
partial_loss = zeros(Float64, (3, n_particles))

autodiff(
    Reverse,
    loss,
    pos[end, :, :],
)

In [25]:
function backward(
    pos :: AbstractArray{Float64, 3},
    vel :: AbstractArray{Float64, 3},
    partial_pos :: AbstractArray{Float64, 3},
    partial_vel :: AbstractArray{Float64, 3},
    dt :: Float64,
    n_steps :: Int) :: Nothing

    for i in n_steps:-1:1

        autodiff(
            Reverse,
            kdk,
            Duplicated(view(pos, i, :, :), view(partial_pos, i, :, :)),
            Duplicated(view(vel, i, :, :), view(partial_vel, i, :, :)),
            Const(view(vel, i, :, :)),
            Const(view(pos, i+1, :, :)),
            Const(dt)
        )
        
    end

end
    

backward (generic function with 1 method)