In [7]:
using ForwardDiff
using DifferentialEquations
const DE = DifferentialEquations
using Zygote
using DiffEqSensitivity
using PyPlot
using CUDA
using LinearAlgebra

In [75]:
function f(a)
    A = a*[.1 2; .3 .4]
    b = [1/sqrt(2); 1/sqrt(2)]
    for i=1:100
        b = A*b
    end
    return b[1]
end     
f_grad = x -> ForwardDiff.derivative(f,x)

#17 (generic function with 1 method)

In [59]:
function pi_pulse_julia(ω_d)
    ω_q = 5.0; Ω = .1
    H0 = ω_q/2.0*[1 0; 0 -1] .+ 0.0*im; H1 = Ω*[0 1; 1 0] .+ 0.0*im
    ψᵢ = [1 0; 0 0] .+ 0.0*im
    tf = π/Ω
    
    function ρ_dot(ρ, p, t)
        fac = cos(ω_d*t)
        H = H0 + H1*fac
        dρ = -1.0*im * (H * ρ)
        dρ += -1.0*im * (-ρ * H)
        return dρ
    end
    
    prob = DE.ODEProblem(ρ_dot, ψᵢ, (0,tf))
    sol = DE.solve(prob, Tsit5(); reltol=1e-8, abstol=1e-8)
    return sol
end

@btime results = pi_pulse_julia(5.0);

  10.124 ms (154024 allocations: 18.81 MiB)


In [None]:
using BenchmarkTools

In [27]:
function Eq(du,u,p,t)
    E = 200.0
    du[1] = -im * E * u[1]
    du[2] = -im * E * u[2]
end

u = [1.0 + 0.0im, 1.5 + 0.0im]
T = 1000.0
prob = DE.ODEProblem(Eq, u, (0, T))
sol = solve(prob; dense=false)

retcode: Success
Interpolation: 1st order linear
t: 131670-element Vector{Float64}:
    0.0
    0.005000000000000001
    0.010344633659463654
    0.01619847574422298
    0.022344232061996535
    0.02881235563143842
    0.035492729025985
    0.04237475170780038
    0.049399834238509494
    0.05655002245053027
    0.06379269375026475
    0.07111254964928064
    0.07849080044624362
    ⋮
  999.9223526005073
  999.9299476242321
  999.937542647655
  999.9451376717822
  999.9527326956076
  999.9603277191311
  999.9679227433588
  999.9755177672848
  999.983112790909
  999.9907078152373
  999.9983028392638
 1000.0
u: 131670-element Vector{Vector{ComplexF64}}:
 [1.0 + 0.0im, 1.5 + 0.0im]
 [0.5403022976485196 - 0.8414709729990841im, 0.8104534464727791 - 1.2622064594986262im]
 [-0.47778395896528436 - 0.8784773183456952im, -0.7166759384479265 - 1.3177159775185427im]
 [-0.9951916771173307 + 0.09794521574736847im, -1.492787515675996 + 0.14691782362105268im]
 [-0.24114204859261804 + 0.970489544053002

In [138]:
function f(a)
    p0 = a
    f(u,p,t) = p*u
    u0 = 1
    for i=1:100
        tspan = ((i-1)/100,i/100)
        prob = DE.ODEProblem(f,u0,tspan, p=p0)
        tmp_prob = remake(prob, p=p0)
        sol = DE.solve(tmp_prob, Tsit5(); reltol=1e-8, abstol=1e-8)
        u0 = sol.u[end]
    end
    return u0
end
f_grad = x -> ForwardDiff.derivative(f,x)

#69 (generic function with 1 method)

In [139]:
@time f_grad(1.0) 

  1.227274 seconds (2.51 M allocations: 140.309 MiB, 4.15% gc time, 99.45% compilation time)


2.71828182845909

In [140]:
@time f_grad(1.0)

  0.006111 seconds (30.10 k allocations: 1.910 MiB)


2.71828182845909

In [213]:
function test3(a)
    p0 = a
    f(u,p,t) = p*u
    u0 = 0.0*im
    for i=1:10
        tspan = (0.0,2.0)
        prob = DE.ODEProblem(f,u0,tspan, p=p0)
        tmp_prob = remake(prob, p=p0)
        sol = DE.solve(tmp_prob, Tsit5(); reltol=1e-8, abstol=1e-8)
        u0 = sol.u[end]
    end
    return real(u0)
end

test3 (generic function with 1 method)

In [214]:
test3(1.0*im)

0.0

In [215]:
DF(test3) = x -> Zygote.gradient(test3, x)

DF (generic function with 1 method)

In [216]:
DF(test3)(1.0)

LoadError: MethodError: no method matching similar(::Float64, ::Int64)
[0mClosest candidates are:
[0m  similar([91m::PyCall.PyVector{T}[39m, ::Int64...) where T at ~/.julia/packages/PyCall/7a7w0/src/conversions.jl:268
[0m  similar([91m::PyCall.PyVector[39m, ::Any, [91m::Tuple{Vararg{Int64, N}} where N[39m) at ~/.julia/packages/PyCall/7a7w0/src/conversions.jl:265
[0m  similar([91m::ReverseDiff.TrackedArray[39m, ::Union{Integer, AbstractUnitRange}...) at ~/.julia/packages/ReverseDiff/Z4pL0/src/tracked.jl:387
[0m  ...

# Playground

In [6]:
f = x -> x^2

#6 (generic function with 1 method)

In [10]:
@time DF(f)(12)

  0.000012 seconds


24

test3 (generic function with 1 method)

In [159]:
test3_grad = x -> Zygote.gradient(test3,x);

In [160]:
test3([1.01])# + 0*im)

7.538324935121546