In [1]:
using LinearAlgebra
using Random
# using BLAS
# using Zygote
using ChainRulesCore
using Flux
using CUDA
using ThreadsX

include("quantum_nn.jl")

[33m[1m│ [22m[39m- Run `import Pkg; Pkg.add("cuDNN")` to install the cuDNN package, then restart julia.
[33m[1m│ [22m[39m- If cuDNN is not installed, some Flux functionalities will not be available when running on the GPU.
[33m[1m└ [22m[39m[90m@ FluxCUDAExt C:\Users\jonat\.julia\packages\Flux\hiqg1\ext\FluxCUDAExt\FluxCUDAExt.jl:57[39m


train_network_fd!

In [None]:
"""
Benchmark comparison function
"""
function benchmark_comparison()
    println("=== Performance Benchmark ===")
    
    n = 8
    num_layers = 10
    
    # Create target
    Random.seed!(42)
    A = randn(ComplexF64, n, n)
    target, _ = qr(A)
    target = Matrix(target)
    
    # Test different approaches
    nn_fast = FastGivensNN(n, num_layers, :nearest_neighbor, use_gpu=false)
    
    println("Timing fast CPU version...")
    @time for _ in 1:100
        U = apply_network_fast!(nn_fast, false)
    end
    
    # Test AD optimization
    println("Testing AD optimization...")
    @time losses = train_network_ad!(nn_fast, target, epochs=200)
    
    final_U = apply_network_fast!(nn_fast, false)
    final_loss = real(tr((final_U - target)' * (final_U - target)))
    
    println("Final loss: $final_loss")
    println("Unitarity error: ", norm(final_U' * final_U - I))
end

"""
Demo function showcasing all optimizations
"""
function demo_optimized_givens_nn()
    println("=== Optimized Givens Neural Network Demo ===")
    
    n = 10
    num_layers = 8
    
    # Create target unitary
    Random.seed!(42)
    A = randn(ComplexF64, n, n)
    target, _ = qr(A)
    target = Matrix(target)
    
    println("Target unitary ($(n)×$(n)):")
    display(round.(target[1:min(4,n), 1:min(4,n)], digits=3))
    println()
    
    nn = FastGivensNN(n, num_layers, :alternating)
    
    println("Network architecture:")
    println("- Dimensions: $(n)×$(n)")
    println("- Layers: $num_layers")
    println("- Total parameters: $(nn.param_count)")
    println("- GPU available: $use_gpu")
    println()
    
    # Initial loss
    initial_U = apply_network_fast!(nn)
    initial_loss = real(tr((initial_U - target)' * (initial_U - target)))
    println("Initial loss: $initial_loss")
    
    # Train with AD
    println("\nTraining with automatic differentiation...")
    # losses = train_network_ad!(nn, target, lr=0.02, epochs=500, use_gpu=use_gpu)
    losses = train_network_fd!(nn, target, lr=0.02, epochs=200)
    
    # Final results
    final_U = apply_network_fast!(nn)
    final_loss = real(tr((final_U - target)' * (final_U - target)))
    
    println("\nOptimization completed!")
    println("Final loss: $final_loss")
    println("Improvement: $(initial_loss/final_loss)x")
    
    println("\nApproximated unitary:")
    display(round.(final_U[1:min(4,n), 1:min(4,n)], digits=3))
    
    # Check unitarity
    unitarity_error = norm(final_U' * final_U - I)
    println("\nUnitarity error: $unitarity_error")
    
    return nn, target, final_U, losses
end

# Run demonstration
nn, target, final_u, losses = demo_optimized_givens_nn()
# benchmark_comparison()

In [4]:
n = 10
nlayers = 8
nn = FastGivensNN(n, nlayers, :alternating)
A = randn(ComplexF64, n, n)
target, _ = qr(A)
target = Matrix(target)
train_network_ad!(nn, target; lr=0.01, epochs=500)

Starting AD training...


LoadError: TaskFailedException

[91m    nested task error: [39mMutating arrays is not supported -- called setindex!(Matrix{ComplexF64}, ...)
    This error occurs when you ask Zygote to differentiate operations that change
    the elements of arrays in place (e.g. setting values with x .= ...)
    
    Possible fixes:
    - avoid mutating operations (preferred)
    - or read the documentation and solutions for this error
      https://fluxml.ai/Zygote.jl/latest/limitations
    
    Stacktrace:
      [1] [0m[1merror[22m[0m[1m([22m[90ms[39m::[0mString[0m[1m)[22m
    [90m    @[39m [90mBase[39m [90m.\[39m[90m[4merror.jl:35[24m[39m
      [2] [0m[1m_throw_mutation_error[22m[0m[1m([22m[90mf[39m::[0mFunction, [90margs[39m::[0mMatrix[90m{ComplexF64}[39m[0m[1m)[22m
    [90m    @[39m [35mZygote[39m [90mC:\Users\jonat\.julia\packages\Zygote\zowwZ\src\lib\[39m[90m[4marray.jl:70[24m[39m
      [3] [0m[1m(::Zygote.var"#544#545"{Matrix{ComplexF64}})[22m[0m[1m([22m[90m#unused#[39m::[0mNothing[0m[1m)[22m
    [90m    @[39m [35mZygote[39m [90mC:\Users\jonat\.julia\packages\Zygote\zowwZ\src\lib\[39m[90m[4marray.jl:82[24m[39m
      [4] [0m[1m(::Zygote.var"#2627#back#546"{Zygote.var"#544#545"{Matrix{ComplexF64}}})[22m[0m[1m([22m[90mΔ[39m::[0mNothing[0m[1m)[22m
    [90m    @[39m [35mZygote[39m [90mC:\Users\jonat\.julia\packages\ZygoteRules\CkVIK\src\[39m[90m[4madjoint.jl:72[24m[39m
      [5] [0m[1mPullback[22m
    [90m    @[39m [90mc:\Users\jonat\OneDrive - Cornell University\programming\cornell courses\research\experimenting\ed\[39m[90m[4mquantum_nn.jl:16[24m[39m[90m [inlined][39m
      [6] [0m[1m(::Zygote.Pullback{Tuple{typeof(apply_givens_inplace!), Matrix{ComplexF64}, Int64, Int64, Float64, Float64}, Any})[22m[0m[1m([22m[90mΔ[39m::[0mNothing[0m[1m)[22m
    [90m    @[39m [35mZygote[39m [90mC:\Users\jonat\.julia\packages\Zygote\zowwZ\src\compiler\[39m[90m[4minterface2.jl:0[24m[39m
      [7] [0m[1mPullback[22m
    [90m    @[39m [90mc:\Users\jonat\OneDrive - Cornell University\programming\cornell courses\research\experimenting\ed\[39m[90m[4mquantum_nn.jl:126[24m[39m[90m [inlined][39m
      [8] [0m[1m(::Zygote.Pullback{Tuple{var"#25#26"{Matrix{ComplexF64}, FastGivensLayer}, Int64}, Tuple{Zygote.var"#2184#back#307"{Zygote.var"#back#306"{:θ, Zygote.Context{false}, FastGivensLayer, Vector{Float64}}}, Zygote.ZBack{ChainRules.var"#getindex_pullback#670"{Vector{Float64}, Tuple{Int64}, Tuple{NoTangent}}}, Zygote.ZBack{ChainRules.var"#getindex_pullback#670"{Vector{Float64}, Tuple{Int64}, Tuple{NoTangent}}}, Zygote.var"#2033#back#217"{Zygote.var"#back#215"{2, 2, Zygote.Context{false}, Int64}}, Zygote.var"#2184#back#307"{Zygote.var"#back#306"{:φ, Zygote.Context{false}, FastGivensLayer, Vector{Float64}}}, Zygote.var"#2184#back#307"{Zygote.var"#back#306"{:layer, Zygote.Context{false}, var"#25#26"{Matrix{ComplexF64}, FastGivensLayer}, FastGivensLayer}}, Zygote.Pullback{Tuple{typeof(apply_givens_inplace!), Matrix{ComplexF64}, Int64, Int64, Float64, Float64}, Any}, Zygote.var"#2033#back#217"{Zygote.var"#back#215"{2, 1, Zygote.Context{false}, Int64}}, Zygote.var"#2184#back#307"{Zygote.var"#back#306"{:layer, Zygote.Context{false}, var"#25#26"{Matrix{ComplexF64}, FastGivensLayer}, FastGivensLayer}}, Zygote.var"#2184#back#307"{Zygote.var"#back#306"{:U, Zygote.Context{false}, var"#25#26"{Matrix{ComplexF64}, FastGivensLayer}, Matrix{ComplexF64}}}, Zygote.var"#back#250"{Zygote.var"#2033#back#217"{Zygote.var"#back#215"{2, 2, Zygote.Context{false}, Int64}}}, Zygote.var"#2033#back#217"{Zygote.var"#back#215"{2, 1, Zygote.Context{false}, Int64}}, Zygote.var"#2184#back#307"{Zygote.var"#back#306"{:layer, Zygote.Context{false}, var"#25#26"{Matrix{ComplexF64}, FastGivensLayer}, FastGivensLayer}}, Zygote.var"#back#249"{Zygote.var"#2033#back#217"{Zygote.var"#back#215"{2, 1, Zygote.Context{false}, Int64}}}, Zygote.var"#2184#back#307"{Zygote.var"#back#306"{:pairs, Zygote.Context{false}, FastGivensLayer, Vector{Tuple{Int64, Int64}}}}, Zygote.ZBack{ChainRules.var"#getindex_pullback#670"{Vector{Tuple{Int64, Int64}}, Tuple{Int64}, Tuple{NoTangent}}}}})[22m[0m[1m([22m[90mΔ[39m::[0mNothing[0m[1m)[22m
    [90m    @[39m [35mZygote[39m [90mC:\Users\jonat\.julia\packages\Zygote\zowwZ\src\compiler\[39m[90m[4minterface2.jl:0[24m[39m
      [9] [0m[1mPullback[22m
    [90m    @[39m [90mC:\Users\jonat\.julia\packages\ThreadsX\Bml38\src\[39m[90m[4mforeach.jl:32[24m[39m[90m [inlined][39m
     [10] [0m[1mPullback[22m
    [90m    @[39m [90m.\[39m[90m[4mthreadingconstructs.jl:410[24m[39m[90m [inlined][39m
     [11] [0m[1m(::Zygote.Pullback{Tuple{ThreadsX.Implementations.var"#59#60"{var"#25#26"{Matrix{ComplexF64}, FastGivensLayer}, Val{false}, SubArray{Int64, 1, Vector{Int64}, Tuple{UnitRange{Int64}}, true}}}, Tuple{Zygote.var"#2184#back#307"{Zygote.var"#back#306"{:f, Zygote.Context{false}, ThreadsX.Implementations.var"#59#60"{var"#25#26"{Matrix{ComplexF64}, FastGivensLayer}, Val{false}, SubArray{Int64, 1, Vector{Int64}, Tuple{UnitRange{Int64}}, true}}, var"#25#26"{Matrix{ComplexF64}, FastGivensLayer}}}, Zygote.var"#2184#back#307"{Zygote.var"#back#306"{:p, Zygote.Context{false}, ThreadsX.Implementations.var"#59#60"{var"#25#26"{Matrix{ComplexF64}, FastGivensLayer}, Val{false}, SubArray{Int64, 1, Vector{Int64}, Tuple{UnitRange{Int64}}, true}}, SubArray{Int64, 1, Vector{Int64}, Tuple{UnitRange{Int64}}, true}}}, Zygote.Pullback{Tuple{typeof(ThreadsX.Implementations.foreach_linear_seq), var"#25#26"{Matrix{ComplexF64}, FastGivensLayer}, SubArray{Int64, 1, Vector{Int64}, Tuple{UnitRange{Int64}}, true}, Val{false}}, Any}, Zygote.var"#1958#back#181"{Zygote.var"#177#180"}, Zygote.var"#2184#back#307"{Zygote.var"#back#306"{:simd, Zygote.Context{false}, ThreadsX.Implementations.var"#59#60"{var"#25#26"{Matrix{ComplexF64}, FastGivensLayer}, Val{false}, SubArray{Int64, 1, Vector{Int64}, Tuple{UnitRange{Int64}}, true}}, Val{false}}}}})[22m[0m[1m([22m[90mΔ[39m::[0mNothing[0m[1m)[22m
    [90m    @[39m [35mZygote[39m [90mC:\Users\jonat\.julia\packages\Zygote\zowwZ\src\compiler\[39m[90m[4minterface2.jl:0[24m[39m
     [12] [0m[1m(::Zygote.var"#393#394"{Nothing, Zygote.Pullback{Tuple{ThreadsX.Implementations.var"#59#60"{var"#25#26"{Matrix{ComplexF64}, FastGivensLayer}, Val{false}, SubArray{Int64, 1, Vector{Int64}, Tuple{UnitRange{Int64}}, true}}}, Tuple{Zygote.var"#2184#back#307"{Zygote.var"#back#306"{:f, Zygote.Context{false}, ThreadsX.Implementations.var"#59#60"{var"#25#26"{Matrix{ComplexF64}, FastGivensLayer}, Val{false}, SubArray{Int64, 1, Vector{Int64}, Tuple{UnitRange{Int64}}, true}}, var"#25#26"{Matrix{ComplexF64}, FastGivensLayer}}}, Zygote.var"#2184#back#307"{Zygote.var"#back#306"{:p, Zygote.Context{false}, ThreadsX.Implementations.var"#59#60"{var"#25#26"{Matrix{ComplexF64}, FastGivensLayer}, Val{false}, SubArray{Int64, 1, Vector{Int64}, Tuple{UnitRange{Int64}}, true}}, SubArray{Int64, 1, Vector{Int64}, Tuple{UnitRange{Int64}}, true}}}, Zygote.Pullback{Tuple{typeof(ThreadsX.Implementations.foreach_linear_seq), var"#25#26"{Matrix{ComplexF64}, FastGivensLayer}, SubArray{Int64, 1, Vector{Int64}, Tuple{UnitRange{Int64}}, true}, Val{false}}, Any}, Zygote.var"#1958#back#181"{Zygote.var"#177#180"}, Zygote.var"#2184#back#307"{Zygote.var"#back#306"{:simd, Zygote.Context{false}, ThreadsX.Implementations.var"#59#60"{var"#25#26"{Matrix{ComplexF64}, FastGivensLayer}, Val{false}, SubArray{Int64, 1, Vector{Int64}, Tuple{UnitRange{Int64}}, true}}, Val{false}}}}}})[22m[0m[1m([22m[0m[1m)[22m
    [90m    @[39m [35mZygote[39m [90mC:\Users\jonat\.julia\packages\Zygote\zowwZ\src\lib\[39m[90m[4mbase.jl:134[24m[39m

In [8]:
using Enzyme
using LinearAlgebra



function get_loss_fn(target::Matrix{ComplexF64}, nn)
    function loss_function_ad(params::Vector{Float64})
        set_parameters!(nn, params)
        U = apply_network_fast!(nn)
        return sum(abs2.(U - target))
    end
    return loss_function_ad
end


n = 5
nlayers = 1
nn = FastGivensNN(n, nlayers, :alternating)
A = randn(ComplexF64, n, n)
target, _ = qr(A)
target = Matrix(target)
x  =  get_parameters(nn)
dx = zero(x)

Enzyme.autodiff(Reverse, get_loss_fn(target, nn), Active, Duplicated(x, dx))
println(ex)

LoadError: Function argument passed to autodiff cannot be proven readonly.
If the the function argument cannot contain derivative data, instead call autodiff(Mode, Const(f), ...)
See https://enzyme.mit.edu/index.fcgi/julia/stable/faq/#Activity-of-temporary-storage for more information.
The potentially writing call is   store {} addrspace(10)* %.fca.1.1.extract, {} addrspace(10)** %.fca.1.1.gep, align 8, !dbg !47, !noalias !61, using   %.fca.1.1.gep = getelementptr inbounds { {} addrspace(10)*, { i64, {} addrspace(10)*, i64, {} addrspace(10)*, {} addrspace(10)* } }, { {} addrspace(10)*, { i64, {} addrspace(10)*, i64, {} addrspace(10)*, {} addrspace(10)* } }* %.innerparm, i64 0, i32 1, i32 1, !dbg !47


In [6]:
get_loss_fn(target, nn)(x)

8.944069804641778

In [17]:
@time apply_network_fast!(nn)

  0.000572 seconds (403 allocations: 36.609 KiB)


6×6 Matrix{ComplexF64}:
 -0.0940218-0.0362039im  -0.293733+0.0140053im  …  -0.115311-0.246804im
  -0.130944+0.294091im    0.318866-0.0907837im     -0.324477-0.0596834im
  0.0884993+0.0138738im  -0.301453-0.295738im        0.41824+0.0528502im
   0.178484-0.258683im    0.266073+0.220996im       -0.23739+0.546708im
  0.0863162-0.806239im    0.182438-0.308155im      -0.245689-0.291155im
   0.313208-0.154415im   -0.611475-0.0591047im  …  -0.204732+0.31134im

In [6]:
losses

500-element Vector{Float64}:
 9.874497844157693
 8.822152491911272
 7.855083321644059
 6.973382995819344
 6.175124826387282
 5.4585627342107
 4.822219625799303
 4.263870558457331
 3.779553908504631
 3.363285068882016
 3.0075161254995892
 2.7039843105421224
 2.44454649456204
 ⋮
 0.0010862079451097633
 0.0010862064291848863
 0.0010862049582019981
 0.0010862035307944759
 0.0010862021456381058
 0.0010862008014497987
 0.0010861994969862568
 0.0010861982310427275
 0.0010861970024517442
 0.0010861958100820762
 0.001086194652837432
 0.0010861935296554823