Skip to content

Commit

Permalink
Added jaxopt solvers interface to optimal control module.
Browse files Browse the repository at this point in the history
  • Loading branch information
alucantonio committed Jul 13, 2023
1 parent 954c26a commit 061db5f
Show file tree
Hide file tree
Showing 4 changed files with 251 additions and 71 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ __pycache__/*
.ipynb_checkpoints/*
examples/.ipynb_checkpoints/*
.DS_Store
*.vtk

# Project files
.ropeproject
Expand Down
277 changes: 216 additions & 61 deletions examples/pacman_fracture.ipynb

Large diffs are not rendered by default.

21 changes: 21 additions & 0 deletions src/dctkit/math/opt/optctrl.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from typeguard import check_type
from petsc4py import PETSc, init
from petsc4py.PETSc import Vec
import jaxopt


class OptimizationProblem():
Expand Down Expand Up @@ -36,6 +37,8 @@ def __init__(self, dim: int, state_dim: int,

if solver_lib == "petsc":
self.solver = PETScSolver(self)
elif solver_lib == "jaxopt":
self.solver = JAXoptSolver(self)
else:
self.solver = PygmoSolver(self)

Expand Down Expand Up @@ -174,6 +177,24 @@ def run(self, x0: npt.NDArray, **kwargs: Dict) -> npt.NDArray:
return u


class JAXoptSolver(OptimizationSolver):
def __init__(self, prb: OptimizationProblem):
super().__init__(prb)
self.prb.obj_args = {}

def set_obj_args(self, args: Dict):
self.prb.obj_args = args

def objective(self, x: npt.NDArray | Array) -> npt.NDArray | Array:
return self.prb.obj(x, **self.prb.obj_args)

def run(self, x0: npt.NDArray, **kwargs: Dict) -> npt.NDArray | Array:
maxiter = kwargs["maxeval"]
solver = jaxopt.LBFGS(self.objective, maxiter=maxiter)
u = solver.run(x0).params.__array__()
return u


class PygmoSolver(OptimizationSolver):
def __init__(self, prb: OptimizationProblem):
super().__init__(prb)
Expand Down
23 changes: 13 additions & 10 deletions tests/test_poisson.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,16 +41,16 @@ def test_poisson(setup_test, optimizer, energy_formulation):

boundary_values = (np.array(bnodes, dtype=dt.int_dtype), b_values)

dim_0 = S.num_nodes
f_vec = -4.*np.ones(dim_0, dtype=dt.float_dtype)
num_nodes = S.num_nodes
f_vec = -4.*np.ones(num_nodes, dtype=dt.float_dtype)
f = C.Cochain(0, True, S, f_vec)
star_f = C.star(f)

mask = np.ones(dim_0, dtype=dt.float_dtype)
mask = np.ones(num_nodes, dtype=dt.float_dtype)
mask[bnodes] = 0.

# initial guess
u_0 = 0.01*np.random.rand(dim_0).astype(dt.float_dtype)
u_0 = 0.01*np.random.rand(num_nodes).astype(dt.float_dtype)
u_0 = np.array(u_0, dtype=dt.float_dtype)

if optimizer == "scipy":
Expand Down Expand Up @@ -104,7 +104,7 @@ def obj_poisson(x, f, k, boundary_values, gamma, mask):
args = {'f': f_vec, 'k': k, 'boundary_values': boundary_values,
'gamma': gamma, 'mask': mask}

prb = oc.OptimizationProblem(dim=dim_0, state_dim=dim_0, objfun=obj)
prb = oc.OptimizationProblem(dim=num_nodes, state_dim=num_nodes, objfun=obj)
prb.set_obj_args(args)
u = prb.solve(u_0, algo="lbfgs").astype(dt.float_dtype)

Expand All @@ -127,7 +127,8 @@ def energy_poisson(x, f, k, boundary_values, gamma):
energy = norm_grad + bound_term + penalty
return energy

args = (f_vec, k, boundary_values, gamma)
args = {'f': f_vec, 'k': k, 'boundary_values': boundary_values,
'gamma': gamma}
obj = energy_poisson

else:
Expand All @@ -147,12 +148,14 @@ def obj_poisson(x, f, k, boundary_values, gamma, mask):
obj = 0.5*jnp.linalg.norm(r*mask)**2 + 0.5*gamma*penalty
return obj

args = (f_vec, k, boundary_values, gamma, mask)
args = {'f': f_vec, 'k': k, 'boundary_values': boundary_values, 'gamma': gamma,
'mask': mask}
obj = obj_poisson

solver = jaxopt.LBFGS(obj, maxiter=5000)
sol = solver.run(u_0, *args)
u = sol.params
prb = oc.OptimizationProblem(dim=num_nodes, state_dim=num_nodes, objfun=obj,
solver_lib="jaxopt")
prb.set_obj_args(args)
u = prb.solve(x0=u_0)

assert u.dtype == dt.float_dtype
assert u_true.dtype == u.dtype
Expand Down

0 comments on commit 061db5f

Please sign in to comment.