In [None]:
import matplotlib.pyplot as plt
import jax
import jax.numpy as jnp

from data_classes import State, Control, Trajectory, VolumeElements
from solver import FiniteVolumeSolver, FVEulerSolver
from adjoint_solver import AdjointFiniteVolumeSolver, AdjointFVEulerSolver
from optimizer import Optimizer
from plotting_utils import plot_densities, plot_controls, plot_residuals, plot_all_residuals, plot_objective, animate_evolution, animate_evolutions, animate_trajectories


In [2]:
# 1. Spatial domain: 1001 uniformly spaced edges on [0, 25], i.e. 1000 finite volumes
edges = jnp.linspace(0.0, 25.0, 1001)
ve = VolumeElements.from_edges(edges)
# Cell centres as provided by VolumeElements; the initial density is sampled on these
centres = ve.centers

In [3]:
# 2. Time discretization: 1001 uniformly spaced nodes on [0, 1]
t = jnp.linspace(0.0, 1.0, 1001)
# Uniform time step; the first forward difference equals t[1] - t[0]
dt = jnp.diff(t)[0]

In [4]:
# Example parameters

K0 = 0.05   # prefactor of the coagulation kernel K_fun
a0 = 0.2    # prefactor of the fragmentation rate alpha_fun
lam = 1/2   # exponent of the fragmentation rate alpha_fun
nu = 0.0    # exponent of the daughter distribution b_fun; admissible range: -1 < nu <= inf


def K_fun(x, y):
    """Coagulation kernel: K0 * (1 + x)^0.25 * (1 + y)^0.25 where x + y <= 25, else 0.

    The cut-off at 25 coincides with the upper edge of the spatial grid.
    """
    return K0 * jnp.where(x + y <= 25, (1 + x) ** 0.25 * (1 + y) ** 0.25, 0.0)


def alpha_fun(x):
    """Fragmentation rate: a0 * x^lam."""
    return a0 * x ** lam


def b_fun(x, y):
    """Daughter distribution: (2 + nu) / y * (x / y)^nu for a fragment x of a parent y."""
    return (2.0 + nu) / y * (x / y) ** nu


# Initial conditions: Gaussian bump centred at x = centre, restricted to x <= 25
centre = 7.5
f0 = jnp.where(centres <= 25, 1.0 * jnp.exp(-(centres - centre) ** 2 / 50), 0.0)
state0 = State(f0, centres)

# Weight of the quadratic control penalty (used by the running cost and passed to Optimizer)
weight = 0.5

sign = 1.0  # -1 for maximization, +1 for minimization

# Box constraints on the control values, passed to Optimizer
u_min = 0.5
u_max = 4.0

# Directory the plotting helpers write their figures to
output_dir = 'test_case_paper'

In [5]:
# Example costs

# Target region [lower_bound, upper_bound] of particle sizes used by the terminal cost
lower_bound, upper_bound = 0.0, 5.0  # minimum / maximum particle size in the target region

def chi(x: jax.Array, lb: jax.Array, ub: jax.Array) -> jax.Array:
    """Indicator function of the closed interval [lb, ub], evaluated elementwise.

    Returns a float array that is 1.0 where lb <= x <= ub and 0.0 elsewhere.
    """
    inside = (x >= lb) & (x <= ub)
    return jnp.where(inside, 1.0, 0.0)

def psi(f: jax.Array, x: jax.Array, lb: jax.Array, ub: jax.Array) -> jax.Array:
    """Rectangle-rule integral of f over the points of x that lie in [lb, ub].

    The cell width is taken from the first two entries of x, so x is assumed to
    be uniformly spaced (it needs at least two entries).
    """
    cell_width = x[1] - x[0]
    mask = chi(x, lb, ub)
    return jnp.sum(f * mask * cell_width)

# terminal cost
def terminal_cost(f: State) -> jax.Array:
    """Terminal cost: sign times the integral of the density over [lower_bound, upper_bound]."""
    density, grid = f.f, f.centers
    return sign * psi(density, grid, lower_bound, upper_bound)

def terminal_cost_grad(f: State) -> jax.Array:
    """Derivative of the terminal cost w.r.t. the density: sign * indicator of the target region.

    NOTE(review): this is the pointwise derivative; the cell width that psi multiplies
    by is not included — presumably the form the adjoint solver expects; confirm.
    """
    target_mask = chi(f.centers, lower_bound, upper_bound)
    return sign * target_mask

# running cost
def running_cost(control: Control) -> jax.Array:
    """Quadratic control penalty 0.5 * weight * sum((u - 1)^2) * dt.

    A rectangle-rule sum penalising deviation from the reference control u = 1.
    The step is read from the first two time nodes, so control.times is assumed
    to be uniformly spaced.
    """
    step = control.times[1] - control.times[0]
    # Stored under a name that does not shadow this function
    cost = 0.5 * weight * jnp.sum((control.values - 1.0) ** 2) * step
    return cost

def running_cost_grad(control: Control) -> jax.Array:
    """Pointwise derivative of the running-cost integrand w.r.t. u: weight * (u - 1).

    The time step used by running_cost is not included in this derivative.
    """
    deviation = control.values - 1.0
    return weight * deviation

In [6]:
# ------------------------------------
#      Solvers and optimizer
# ------------------------------------

# Positional arguments shared by the forward and adjoint solvers:
# grid, time nodes, coagulation kernel, fragmentation rate, daughter distribution.
fv_model_args = (ve, t, K_fun, alpha_fun, b_fun)

solver = FVEulerSolver(*fv_model_args, mass_fix='no')

adjoint_solver = AdjointFVEulerSolver(*fv_model_args)

opt = Optimizer(
    init_state=state0,
    forward_solver=solver,
    adjoint_solver=adjoint_solver,
    terminal_cost=terminal_cost,
    terminal_cost_grad=terminal_cost_grad,
    running_cost=running_cost,
    running_cost_grad=running_cost_grad,
    weight=weight,
    u_min=u_min,
    u_max=u_max,
    # Armijo line-search settings — presumably max backtracking steps, step shrink
    # factor and sufficient-decrease constant; confirm in optimizer.py
    max_ls_iter=20,
    armijo_rho=0.5,
    armijo_sigma=0.001,
    verbose=True,
)

In [None]:
# --------------------------------------------------------------------
# 5.  Optimisation
# --------------------------------------------------------------------

# 5.1 Optimization routine (projected gradient descent with adjoint gradients)

u0 = jnp.ones_like(t)  # reference control u = 1 ("no control"); also the initial guess

print("\n--- optimise with adjoint ---")
out_adj = opt.pgd_adjoint(1.0 * u0, lr=0.1, n_iter=40, tolerance=7.5e-2)
(u_adj, J_adj, terminal_cost_adj, grad_hist, delta_control_adj, delta_H_adj,
 res_hist_adj, loss_hist_adj, terminal_hist_adj, it_final) = out_adj

# 5.2 Re-run forward solver to obtain full trajectories
optimal_control = Control(u_adj, t)
no_control = Control(u0, t)

state_T_adj, traj_adj, _ = solver.solve(initial_state=state0, control=optimal_control)
state_T_no_control, traj_no_control, _ = solver.solve(initial_state=state0, control=no_control)

# Evaluate both controls with the same cost functions so the printed values are comparable
J_1 = opt.total_cost(traj_no_control, no_control)
terminal_cost_1 = opt.terminal_cost(state_T_no_control)
J_opt = opt.total_cost(traj_adj, optimal_control)
terminal_cost_opt = opt.terminal_cost(state_T_adj)

print(f"No control   J_no_con  = {float(J_1):.6e}, terminal cost = {float(terminal_cost_1):.6e}")
print(f"Optimal      J_opt     = {float(J_opt):.6e}, terminal cost = {float(terminal_cost_opt):.6e}")

# 5.3 Mass loss diagnostics: relative change of the first moment under the optimal control
# (the forward solver was built with mass_fix='no', so no mass correction is applied)
rel_mass_loss = (state0.first_moment() - state_T_adj.first_moment()) / state0.first_moment()
print("relative mass loss (optimal control):", rel_mass_loss)


--- optimise with adjoint ---
iter 000,  delta_control_rel=7.500e-01, delta_H_rel=6.968e-01 ||g||=3.046e+00, J=4.382829e+00 terminal_cost = 4.382829189300537
iter 001, ls=0, delta_control_rel=6.748e-01, delta_H_rel=5.645e-01 res=2.671e+00, ||g||=3.046e+00, J=3.625527e+00 terminal_cost = 3.6023333072662354
iter 002, ls=0, delta_control_rel=6.113e-01, delta_H_rel=4.632e-01 res=2.011e+00, ||g||=2.088e+00, J=3.254709e+00 terminal_cost = 3.1888129711151123
iter 003, ls=0, delta_control_rel=5.535e-01, delta_H_rel=3.812e-01 res=1.584e+00, ||g||=1.599e+00, J=3.034791e+00 terminal_cost = 2.921473979949951
iter 004, ls=0, delta_control_rel=5.009e-01, delta_H_rel=3.137e-01 res=1.284e+00, ||g||=1.285e+00, J=2.892425e+00 terminal_cost = 2.7317216396331787


In [None]:
# --------------------------------------------------------------------
# 6.  Plots
# --------------------------------------------------------------------

# Same value as the `tolerance` passed to opt.pgd_adjoint in the optimisation cell.
# Both residual plots presumably draw it as a reference level, so they must agree with
# the optimizer and with each other; keep these values in sync.
residual_tol = 7.5e-2

weight_str = f"w_{weight}_"

plot_controls(no_control, optimal_control, u_min, u_max, name=weight_str + 'controls.png', output_dir=output_dir)

plot_densities(state0, state_T_adj, state_T_no_control, lower_bound, upper_bound, name=weight_str + 'densities.png', output_dir=output_dir)

# NOTE(review): the delta histories are sliced with [:-1], presumably because they hold one
# more entry than res_hist_adj — confirm against Optimizer.pgd_adjoint.
plot_residuals(delta_control_adj[:-1], res_hist_adj, tol=residual_tol, name=weight_str + 'residuals.png', output_dir=output_dir, mark_every=1)

# Saved under its own file name; previously it reused 'residuals.png' and overwrote the figure above.
plot_all_residuals(delta_control_adj[:-1], delta_H_adj[:-1], res_hist_adj, tol=residual_tol, name=weight_str + 'all_residuals.png', output_dir=output_dir, mark_every=1)

plot_objective(loss_hist_adj, terminal_hist_adj, name=weight_str + 'loss.png', output_dir=output_dir, markevery=2)

plt.show()