Skip to content

Commit

Permalink
test: fix typo in NNODE tests
Browse files Browse the repository at this point in the history
  • Loading branch information
sathvikbhagavan committed Jan 21, 2024
1 parent 5df70a8 commit 6e95919
Showing 1 changed file with 6 additions and 6 deletions.
12 changes: 6 additions & 6 deletions test/NNODE_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ opt = OptimizationOptimisers.Adam(0.1, (0.9, 0.95))
sol = solve(prob, NeuralPDE.NNODE(chain, opt), dt = 1 / 20.0f0, verbose = true,
abstol = 1.0f-10, maxiters = 200)

@test_throws Any solve(prob, NeuralPDE.NNODE(chain, opt; autodiff = true), dt = 1 / 20.0f0,
solve(prob, NeuralPDE.NNODE(chain, opt; autodiff = true), dt = 1 / 20.0f0,
verbose = true, abstol = 1.0f-10, maxiters = 200)

sol = solve(prob, NeuralPDE.NNODE(chain, opt), verbose = true,
Expand All @@ -26,7 +26,7 @@ sol = solve(prob, NeuralPDE.NNODE(chain, opt), verbose = true,
sol = solve(prob, NeuralPDE.NNODE(luxchain, opt), dt = 1 / 20.0f0, verbose = true,
abstol = 1.0f-10, maxiters = 200)

Any solve(prob, NeuralPDE.NNODE(luxchain, opt; autodiff = true),
solve(prob, NeuralPDE.NNODE(luxchain, opt; autodiff = true),
dt = 1 / 20.0f0,
verbose = true, abstol = 1.0f-10, maxiters = 200)

Expand Down Expand Up @@ -90,13 +90,13 @@ opt = OptimizationOptimisers.Adam(0.01)
sol = solve(prob, NeuralPDE.NNODE(chain, opt), verbose = true, maxiters = 400)
@test sol.errors[:l2] < 0.5

solve(prob, NeuralPDE.NNODE(chain, opt; batch = true), verbose = true,
@test_throws AssertionError solve(prob, NeuralPDE.NNODE(chain, opt; batch = true), verbose = true,
maxiters = 400)

sol = solve(prob, NeuralPDE.NNODE(luxchain, opt), verbose = true, maxiters = 400)
@test sol.errors[:l2] < 0.5

@test_throws Any solve(prob, NeuralPDE.NNODE(luxchain, opt; batch = true), verbose = true,
@test_throws AssertionError solve(prob, NeuralPDE.NNODE(luxchain, opt; batch = true), verbose = true,
maxiters = 400)

sol = solve(prob,
Expand Down Expand Up @@ -150,15 +150,15 @@ sol = solve(prob, NeuralPDE.NNODE(chain, opt), verbose = true, maxiters = 400,
abstol = 1.0f-8)
@test sol.errors[:l2] < 0.5

@test_throws Any solve(prob, NeuralPDE.NNODE(chain, opt; batch = true), verbose = true,
@test_throws AssertionError solve(prob, NeuralPDE.NNODE(chain, opt; batch = true), verbose = true,
maxiters = 400,
abstol = 1.0f-8)

sol = solve(prob, NeuralPDE.NNODE(luxchain, opt), verbose = true, maxiters = 400,
abstol = 1.0f-8)
@test sol.errors[:l2] < 0.5

@test_throws Any solve(prob, NeuralPDE.NNODE(luxchain, opt; batch = true), verbose = true,
@test_throws AssertionError solve(prob, NeuralPDE.NNODE(luxchain, opt; batch = true), verbose = true,
maxiters = 400,
abstol = 1.0f-8)

Expand Down

0 comments on commit 6e95919

Please sign in to comment.