Skip to content

Commit

Permalink
refactor: fix min max adaptive loss
Browse files Browse the repository at this point in the history
  • Loading branch information
sathvikbhagavan committed Jan 31, 2024
1 parent 5a8c40a commit 8abaf99
Showing 1 changed file with 4 additions and 2 deletions.
6 changes: 4 additions & 2 deletions src/adaptive_losses.jl
Original file line number Diff line number Diff line change
Expand Up @@ -237,13 +237,15 @@ function generate_adaptive_loss_function(pinnrep::PINNRepresentation,
adaloss::MiniMaxAdaptiveLoss,
pde_loss_functions, bc_loss_functions)
pde_max_optimiser = adaloss.pde_max_optimiser
pde_max_optimiser_setup = OptimizationOptimisers.Optimisers.setup(pde_max_optimiser, adaloss.pde_loss_weights)
bc_max_optimiser = adaloss.bc_max_optimiser
bc_max_optimiser_setup = OptimizationOptimisers.Optimisers.setup(bc_max_optimiser, adaloss.bc_loss_weights)
iteration = pinnrep.iteration

function run_minimax_adaptive_loss(θ, pde_losses, bc_losses)
if iteration[1] % adaloss.reweight_every == 0
OptimizationOptimisers.Optimisers.update(pde_max_optimiser, adaloss.pde_loss_weights, -pde_losses)
OptimizationOptimisers.Optimisers.update(bc_max_optimiser, adaloss.bc_loss_weights, -bc_losses)
OptimizationOptimisers.Optimisers.update!(pde_max_optimiser_setup, adaloss.pde_loss_weights, -pde_losses)
OptimizationOptimisers.Optimisers.update!(bc_max_optimiser_setup, adaloss.bc_loss_weights, -bc_losses)
logvector(pinnrep.logger, adaloss.pde_loss_weights,
"adaptive_loss/pde_loss_weights", iteration[1])
logvector(pinnrep.logger, adaloss.bc_loss_weights,
Expand Down

0 comments on commit 8abaf99

Please sign in to comment.