Conversation

@andrewrosemberg
Member

Tests the end-to-end proxy penalty training pipeline with Lux.

@codecov

codecov bot commented Jul 9, 2025

Codecov Report

All modified and coverable lines are covered by tests ✅


@klamike klamike changed the title from "WIP: LUX + BNK" to "Lux integration test" Jul 10, 2025
@klamike klamike merged commit b29456b into main Jul 10, 2025
4 checks passed
@klamike
Collaborator

klamike commented Jul 10, 2025

Gave it a shot with Reactant.jl (patch below) and got an interesting error, seemingly from inside Lux? Probably I'm just doing something wrong...

2025-07-10 00:12:50.523984: I external/xla/xla/service/service.cc:153] XLA service 0x17f09970 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:
2025-07-10 00:12:50.524010: I external/xla/xla/service/service.cc:161]   StreamExecutor device (0): NVIDIA H100 80GB HBM3, Compute Capability 9.0
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1752120770.526854  654708 se_gpu_pjrt_client.cc:1370] Using BFC allocator.
I0000 00:00:1752120770.527096  654708 gpu_helpers.cc:136] XLA backend allocating 63771869184 bytes on device 0 for BFCAllocator.
I0000 00:00:1752120770.527199  654708 gpu_helpers.cc:177] XLA backend will use up to 21257289728 bytes on device 0 for CollectiveBFCAllocator.
I0000 00:00:1752120770.573955  654708 cuda_dnn.cc:471] Loaded cuDNN version 91000
Penalty Training: Error During Test at BatchNLPKernels.jl/test/test_penalty.jl:92
  Got exception outside of a @test
  Scalar indexing is disallowed.
  Invocation of getindex(::ConcretePJRTArray, ::Vararg{Int, N}) resulted in scalar indexing of a GPU array.
  This is typically caused by calling an iterating implementation of a method.
  Such implementations *do not* execute on the GPU, but very slowly on the CPU,
  and therefore should be avoided.
  
  If you want to allow scalar iteration, use `allowscalar` or `@allowscalar`
  to enable scalar iteration globally or for the operations in question.
  Stacktrace:
    [1] error(s::String)
      @ Base ./error.jl:35
    [2] errorscalar(op::String)
      @ GPUArraysCore GPUArraysCore/aNaXo/src/GPUArraysCore.jl:151
    [3] _assertscalar(op::String, behavior::GPUArraysCore.ScalarIndexing)
      @ GPUArraysCore GPUArraysCore/aNaXo/src/GPUArraysCore.jl:124
    [4] assertscalar(op::String)
      @ GPUArraysCore GPUArraysCore/aNaXo/src/GPUArraysCore.jl:112
    [5] getindex(::ConcretePJRTArray{Float32, 2, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}, ::Int64, ::Int64)
      @ Reactant Reactant/6PN6T/src/ConcreteRArray.jl:306
    [6] _generic_matmatmul!(C::ConcretePJRTArray{Float32, 2, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}, A::ConcretePJRTArray{Float32, 2, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}, B::ConcretePJRTArray{Float32, 2, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}, _add::LinearAlgebra.MulAddMul{true, true, Bool, Bool})
      @ LinearAlgebra LinearAlgebra/src/matmul.jl:894
    [7] generic_matmatmul!
      LinearAlgebra/src/matmul.jl:868 [inlined]
    [8] _mul!
      LinearAlgebra/src/matmul.jl:287 [inlined]
    [9] mul!
      LinearAlgebra/src/matmul.jl:285 [inlined]
   [10] mul!
      LinearAlgebra/src/matmul.jl:253 [inlined]
   [11] *
      LinearAlgebra/src/matmul.jl:114 [inlined]
   [12] matmul
      @ LuxLib/XxZ1M/src/impl/matmul.jl:54 [inlined]
   [13] fused_dense
      @ LuxLib/XxZ1M/src/impl/dense.jl:26 [inlined]
   [14] fused_dense
      @ LuxLib/XxZ1M/src/impl/dense.jl:16 [inlined]
   [15] fused_dense_bias_activation
      @ LuxLib/XxZ1M/src/api/dense.jl:36 [inlined]
   [16] (::Dense{typeof(relu), Int64, Int64, Nothing, Nothing, Static.True})(x::ConcretePJRTArray{Float32, 2, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}, ps::@NamedTuple{weight::ConcretePJRTArray{Float32, 2, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}, bias::ConcretePJRTArray{Float32, 1, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}}, st::@NamedTuple{})
      @ Lux Lux/ie6Qh/src/layers/basic.jl:363
   [17] apply
      @ LuxCore/q0Mrq/src/LuxCore.jl:155 [inlined]
   [18] macro expansion
      @ Lux/ie6Qh/src/layers/containers.jl:0 [inlined]
   [19] applychain(layers::@NamedTuple{layer_1::Dense{typeof(relu), Int64, Int64, Nothing, Nothing, Static.True}, layer_2::Dense{typeof(relu), Int64, Int64, Nothing, Nothing, Static.True}, layer_3::Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}}, x::ConcretePJRTArray{Float32, 2, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}, ps::@NamedTuple{layer_1::@NamedTuple{weight::ConcretePJRTArray{Float32, 2, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}, bias::ConcretePJRTArray{Float32, 1, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}}, layer_2::@NamedTuple{weight::ConcretePJRTArray{Float32, 2, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}, bias::ConcretePJRTArray{Float32, 1, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}}, layer_3::@NamedTuple{weight::ConcretePJRTArray{Float32, 2, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}, bias::ConcretePJRTArray{Float32, 1, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}}}, st::@NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}, layer_3::@NamedTuple{}})
      @ Lux Lux/ie6Qh/src/layers/containers.jl:511
   [20] (::Chain{@NamedTuple{layer_1::Dense{typeof(relu), Int64, Int64, Nothing, Nothing, Static.True}, layer_2::Dense{typeof(relu), Int64, Int64, Nothing, Nothing, Static.True}, layer_3::Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing})(x::ConcretePJRTArray{Float32, 2, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}, ps::@NamedTuple{layer_1::@NamedTuple{weight::ConcretePJRTArray{Float32, 2, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}, bias::ConcretePJRTArray{Float32, 1, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}}, layer_2::@NamedTuple{weight::ConcretePJRTArray{Float32, 2, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}, bias::ConcretePJRTArray{Float32, 1, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}}, layer_3::@NamedTuple{weight::ConcretePJRTArray{Float32, 2, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}, bias::ConcretePJRTArray{Float32, 1, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}}}, st::@NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}, layer_3::@NamedTuple{}})
      @ Lux Lux/ie6Qh/src/layers/containers.jl:509
   [21] test_penalty_training(; filename::String, dev_gpu::Function, backend::CUDABackend, batch_size::Int64, dataset_size::Int64, rng::TaskLocalRNG, T::Type)
      @ Main BatchNLPKernels.jl/test/test_penalty.jl:60
   [22] macro expansion
      @ BatchNLPKernels.jl/test/test_penalty.jl:99 [inlined]
   [23] macro expansion
      /packagesTest/src/Test.jl:1704 [inlined]
   [24] top-level scope
      @ BatchNLPKernels.jl/test/test_penalty.jl:93
   [25] include(fname::String)
      @ Main ./sysimg.jl:38
   [26] top-level scope
      @ BatchNLPKernels.jl/test/runtests.jl:45
   [27] include(fname::String)
      @ Main ./sysimg.jl:38
   [28] top-level scope
      @ none:6
   [29] eval
      @ ./boot.jl:430 [inlined]
   [30] exec_options(opts::Base.JLOptions)
      @ Base ./client.jl:296
   [31] _start()
      @ Base ./client.jl:531

Frames [6]-[12] are where I think the issue is -- shouldn't Reactant intercept that matmul?
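For context, my understanding (which could be wrong) is that Reactant only rewrites array ops inside functions traced with `@compile`; calling the Lux model eagerly on `ConcretePJRTArray`s falls back to Base's generic matmul, which scalar-indexes. A minimal sketch of the traced path, assuming a standalone Lux model (the model and names here are illustrative, not the ones in `test_penalty.jl`):

```julia
using Lux, Random, Reactant

# Illustrative model, not the one under test
model = Chain(Dense(4 => 8, relu), Dense(8 => 1))
dev = reactant_device()
ps, st = Lux.setup(Random.default_rng(), model) |> dev
x = rand(Float32, 4, 16) |> dev

# model(x, ps, st)                # eager call: dispatches to Base's generic
                                  # matmul and hits the scalar-indexing error
fwd = @compile model(x, ps, st)   # traced call: XLA handles the matmul
y, _ = fwd(x, ps, st)
```

If that's right, the fix would be making sure the forward pass in the test goes through a compiled/traced path rather than calling the model eagerly on `reactant_device()` data.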

Patch
diff --git a/Project.toml b/Project.toml
index 482f08d..bbd301d 100644
--- a/Project.toml
+++ b/Project.toml
@@ -19,6 +19,8 @@ BNKJuMP = "JuMP"
 ExaModels = "0.8.3"
 
 [extras]
+Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
+Reactant = "3c362404-f566-11ee-1572-e11a4b42c853"
 AcceleratedKernels = "6a4ca0a5-0e36-4168-a932-d9be78d558f1"
 CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
 DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
@@ -38,4 +40,4 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
 pocl_jll = "627d6b7a-bbe6-5189-83e7-98cc0a5aeadd"
 
 [targets]
-test = ["Test", "CUDA", "GPUArraysCore", "LinearAlgebra", "OpenCL", "pocl_jll", "AcceleratedKernels", "DifferentiationInterface", "FiniteDifferences", "Zygote", "PGLib", "PowerModels", "Lux", "LuxCUDA", "MLUtils", "Optimisers", "Random"]
+test = ["Test", "Enzyme", "Reactant", "CUDA", "GPUArraysCore", "LinearAlgebra", "OpenCL", "pocl_jll", "AcceleratedKernels", "DifferentiationInterface", "FiniteDifferences", "Zygote", "PGLib", "PowerModels", "Lux", "LuxCUDA", "MLUtils", "Optimisers", "Random"]
diff --git a/test/runtests.jl b/test/runtests.jl
index 1b1dfd1..50de649 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -22,6 +22,7 @@ using MLUtils
 using Optimisers
 using CUDA
 using Random
+using Reactant, Enzyme
 import GPUArraysCore: @allowscalar
 
 ExaModels.convert_array(x, ::OpenCLBackend) = CLArray(x)
@@ -41,8 +42,8 @@ end
 
 include("luksan.jl")
 include("power.jl")
+include("test_penalty.jl")
 include("test_viols.jl")
 include("test_diff.jl")
 include("api.jl")
-include("config.jl")
-include("test_penalty.jl")
\ No newline at end of file
+include("config.jl")
\ No newline at end of file
diff --git a/test/test_penalty.jl b/test/test_penalty.jl
index 7723f61..04620f4 100644
--- a/test/test_penalty.jl
+++ b/test/test_penalty.jl
@@ -73,7 +73,7 @@ function test_penalty_training(; filename="pglib_opf_case14_ieee.m", dev_gpu = g
     data = DataLoader((Θ_train); batchsize=batch_size, shuffle=true) .|> dev_gpu
     for (Θ) in data
         _, loss_val, stats, train_state = Training.single_train_step!(
-            AutoZygote(),          # AD backend
+            AutoEnzyme(; mode=Enzyme.set_runtime_activity(Enzyme.Reverse)),          # AD backend
             PenaltyLoss,
             (Θ),  # data
             train_state
@@ -91,7 +91,7 @@ end
 
 @testset "Penalty Training" begin
     backend, dev = if haskey(ENV, "BNK_TEST_CUDA")
-        CUDABackend(), gpu_device()
+        CUDABackend(), reactant_device()
     else
         CPU(), cpu_device()
     end
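As a quick check that scalar indexing is the only blocker (and not something deeper), the `@allowscalar` escape hatch mentioned in the error message can wrap the failing call; it will run, just at CPU-loop speed, so it's a diagnostic rather than a fix (`model`, `x`, `ps`, `st` here stand in for the objects in the test):

```julia
import GPUArraysCore: @allowscalar

# Diagnostic only: permits the slow scalar fallback for this one call.
y, st = @allowscalar model(x, ps, st)
```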

@klamike klamike mentioned this pull request Jul 10, 2025
@klamike klamike deleted the ar/penalty_proxy branch July 24, 2025 18:05