Skip to content

Commit

Permalink
Fix tests for WaveOpticsPropagation 0.2.0 API
Browse files Browse the repository at this point in the history
  • Loading branch information
roflmaostc committed Mar 11, 2024
1 parent 32ae56d commit a239bc5
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 2 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ Parameters = "0.12"
Plots = "1"
RadonKA = "0.6"
Statistics = "1.10"
WaveOpticsPropagation = "0.1.0"
WaveOpticsPropagation = "0.2.0"
Zygote = "0.6"

[extras]
Expand Down
2 changes: 1 addition & 1 deletion src/wave_optics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ Optimize the patterns to match the target with the wave optical model.
"""
function optimize_patterns(target, ps::WaveOptics, op::GradientBased, loss::LossThreshold)
function optimize_patterns(target, ps::WaveOptics, op::GradientBased, loss::Union{LossThreshold, LossThresholdSparsity})
angles = ps.angles
μ = ps.μ
L = ps.L
Expand Down
29 changes: 29 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@ using ChainRulesTestUtils
@test target2 !== (0.7 .> optimize_patterns((target2), geometry_vial, optimizer2, LossThreshold(thresholds=(0.65, 0.75)))[2])
@test target2 == (0.45 .< optimize_patterns((target2), geometry_vial, optimizer2, LossThreshold(thresholds=(0.4, 0.5)))[2])
@test target2 !== (0.45 .> optimize_patterns((target2), geometry_vial, optimizer2, LossThreshold(thresholds=(0.4, 0.5)))[2])

@test target2 == (0.45 .< optimize_patterns((target2), geometry_vial, optimizer2, LossThresholdSparsity(thresholds=(0.4, 0.5)))[2])

patterns, printed, res = optimize_patterns((target2), geometry_vial, optimizer2, LossThreshold(thresholds=(0.4, 0.5)))
save_patterns(tempdir(), patterns, printed, angles2, target2; overwrite=true)
Expand Down Expand Up @@ -65,6 +67,33 @@ end

end

@testset "Simple wave optical simulation with sparse loss" begin

sz2 = (24, 24, 24)
target = box(Float32, sz2, (17, 17, 10)) .- box(Float32, sz2, (9, 9, 8));

n_resin = 1.5f0
angles = range(0, π, 20)
optimizer = GradientBased(optimizer=Optim.LBFGS(), options=Optim.Options(iterations=15, store_trace=true))


L = 100f-6
loss = LossThresholdSparsity(thresholds=(0.65, 0.75))

optimizer = GradientBased(optimizer=Optim.LBFGS(), options=Optim.Options(iterations=20, store_trace=true))

waveoptics = WaveOptics(
z=(range(-L/2, L/2, size(target,1))),
L=L,
λ=405f-9 / n_resin,
μ=nothing,
angles=angles,
)
patterns, printed, res = optimize_patterns(target, waveoptics, optimizer, loss)
@test target == (0.7 .< printed)
@test target !== (0.7 .> printed)

end


@testset "test rrule of custom loss" begin
Expand Down

0 comments on commit a239bc5

Please sign in to comment.