In [None]:
import matplotlib.pyplot as plt
import scipy.optimize as spo
import jax.numpy as np
jnp = np
import jax

from my_timer import timer_decorator
from my_solvers import trust_region

# Definice funkcí s jax knihovnou pro možnost automatické derivace (np je nyní z jax.numpy)

In [None]:
def energy_jax(v_internal, fx, v, h, p):
    """Discrete energy: sum over elements of h * (|v'|^p / p - f * v_mid).

    Parameters
    ----------
    v_internal : values at the interior nodes (the unknowns).
    fx : right-hand side, one value per element.
    v : full nodal vector; only its first and last entries are kept,
        the interior ones are replaced by ``v_internal``.
    h : element lengths, one value per element.
    p : exponent of the energy.

    Returns
    -------
    Scalar value of the discrete energy.
    """
    nodal = v.at[1:-1].set(v_internal)
    right, left = nodal[1:], nodal[:-1]
    slope = (right - left) / h        # piecewise-constant derivative
    midpoint = (right + left) / 2     # value at element midpoints
    density = (1 / p) * np.abs(slope)**p - fx * midpoint

    return np.sum(h * density)


# Initial guess. The factor 0 makes this the zero function; the remaining
# product (x + 1) * (x - 1) vanishes at x = -1 and x = 1.
def u_init(x):
    """Return the (identically zero) initial guess evaluated at ``x``."""
    scale = 0
    return scale * (x + 1) * (x - 1)


# Right-hand side of the problem.
def f(x):
    """Return the constant right-hand side -10, one value per entry of ``x``."""
    ones = np.ones(x.size)
    return -10 * ones


# úloha stejně jako v pLaPlace_numba

In [None]:
# Problem parameters: exponent p and interval [a, b].
p = 3
a, b = -1, 1

# Mesh with ne elements (ne + 1 nodes) and the element lengths.
ne = 500
x = np.linspace(a, b, ne + 1)
h = np.diff(x)

# Right-hand side sampled at the element midpoints.
x_mid = (x[1:] + x[:-1]) / 2
fx = f(x_mid)

# Full nodal vector, its interior part, and the initial guess for the solvers.
v = u_init(x)
v_internal = v[1:-1].copy()
x0 = v_internal.copy()

# Definování funkce, gradientu a hesiánu; nastavení automatické kompilace jit; vyrobení funkcí s jedním vstupem

In [None]:
# Automatic differentiation and JIT compilation of the energy.
fun = jax.jit(energy_jax)
dfun = jax.jit(jax.grad(energy_jax, argnums=0))       # gradient w.r.t. v_internal
ddfun = jax.jit(jax.hessian(energy_jax, argnums=0))   # Hessian w.r.t. v_internal


# Single-argument wrappers. fx, v, h and p are read from the notebook
# namespace at call time, so the wrappers follow later re-definitions.
def fun1(v_internal):
    return fun(v_internal, fx, v, h, p)


def dfun1(v_internal):
    return dfun(v_internal, fx, v, h, p)


def ddfun1(v_internal):
    return ddfun(v_internal, fx, v, h, p)

# V následujících třech buňkách se s prvním zavoláním funkce rovnou kompiluje jit 

In [None]:
# Energy of the initial guess; the first call of the jitted function also compiles it.
print(f"Initial energy: {fun1(v_internal)}")

In [None]:
# Norm of the gradient at the initial guess (first call compiles the jitted gradient).
print(f"||g||={np.linalg.norm(dfun1(v_internal))}")

In [None]:
# Norm of the Hessian at the initial guess (first call compiles the jitted Hessian).
print(f"||H||={np.linalg.norm(ddfun1(v_internal))}")

# vyzkoušení řešení úlohy pomocí trust regionu z my_solvers.py
U mě to trvalo cca 0.5 s pro úlohu s dělením 200 a p=3.

pro dělení 500 a p=3 to trvalo cca 2.1s v porovnání s implementací bez derivací s numbou 27s

In [None]:
# Solve with trust_region from my_solvers, passing the energy, its gradient and its Hessian.
# The call is unpacked into the computed minimizer and the iteration count.
solopt, iterations = trust_region(fun1 , dfun1, ddfun1, x0, c0=1, tolf=1e-6, tolg=1e-6, maxit=1000, verbose = False)

# Pokus o vyřešení pomocí scipy.optimize.minimize, zde bez použití gradientu a hesiánu; pozor, vrací špatný výsledek

In [None]:
# Comparison with scipy.optimize.minimize using its default method;
# no gradient or Hessian is passed in this cell.
from scipy.optimize import minimize


# Wrap minimize with the timer decorator from my_timer.
minimize_timed = timer_decorator(minimize)

print("energy (init)=", fun(v_internal, fx, v, h, p))


# NOTE(review): without `jac`, SciPy approximates the gradient by finite differences;
# combined with JAX's default single precision this may explain the poor result — confirm.
# NOTE(review): `solopt` is rebound here from the trust_region result to the object
# returned by minimize; the solution vector is then `solopt.x`.
solopt = minimize_timed(fun, v_internal, args=(fx, v, h, p))

print("energy (final)=", fun(solopt.x, fx, v, h, p))

# Verze trust regionu v scipy.optimize.minimize, správný výsledek, ale trvá déle než předchozí trust region (vevnitř se zřejmě děje něco jinak (víc iterací))

In [None]:
# Same comparison, now with SciPy's 'trust-exact' method and the JAX gradient and Hessian.
print("energy (init)=", fun(v_internal, fx, v, h, p))


# Uses minimize_timed defined in the previous scipy cell.
solopt = minimize_timed(fun, v_internal, args=(fx, v, h, p), method='trust-exact', jac=dfun, hess=ddfun)

print("energy (final)=", fun(solopt.x, fx, v, h, p))
# Cell output: the `nit` field of the result.
solopt.nit

In [None]:
import fides
import numpy as onp  # plain NumPy; `np` in this notebook is jax.numpy


def obj(x):
    """Objective for fides: return (value, gradient, Hessian) at ``x``.

    The JAX results are converted to a Python float and NumPy arrays.
    """
    return (
        float(fun1(x)),
        onp.asarray(dfun1(x), dtype=float),
        # Third entry is the Hessian (previously the gradient was returned twice).
        onp.asarray(ddfun1(x), dtype=float),
    )


# The problem is unconstrained; fides.Optimizer takes explicit bounds,
# so pass +/- infinity for every unknown.
n_dof = x0.size
opt = fides.Optimizer(
    obj,
    ub=onp.full(n_dof, onp.inf),
    lb=onp.full(n_dof, -onp.inf),
)

opt_f, opt_x, opt_grad, opt_hess = opt.minimize(onp.asarray(x0, dtype=float))

# větší úloha

In [None]:
# Larger problem: same exponent and interval, ne = 1000 elements.
p = 3
a, b = -1, 1
ne = 1000

# Mesh nodes and element lengths.
x = np.linspace(a, b, ne + 1)
h = np.diff(x)

# Right-hand side sampled at the element midpoints.
x_mid = (x[1:] + x[:-1]) / 2
fx = f(x_mid)

# Full nodal vector, its interior part, and the initial guess.
v = u_init(x)
v_internal = v[1:-1].copy()
x0 = v_internal.copy()

In [None]:

# Trust-region solve of the larger problem; note the looser gradient tolerance tolg=1e-3.
solopt, iterations = trust_region(fun1 , dfun1, ddfun1, x0, c0=1, tolf=1e-6, tolg=1e-3, maxit=1000, verbose = False)


In [None]:
# Hessian evaluated at the computed solution (shown as the cell output).
ddfun1(solopt)

### scipy je třeba nastavit mimo default, jinak nedořeší v 1000 iteracích

In [None]:
# SciPy's 'TNC' method with the JAX gradient.
# NOTE(review): `solopt` is rebound to the object returned by minimize; the solution vector is `solopt.x`.
solopt = minimize(fun1, v_internal, method='TNC', jac=dfun1)

print("energy (final)=", fun(solopt.x, fx, v, h, p))
# Cell output: the full result object.
solopt

In [None]:
import importlib

import sparsejac

# Reload the module so that changes made to sparsejac since its first import take effect.
importlib.reload(sparsejac)


In [None]:

from scipy.sparse import diags

# n x n tridiagonal matrix of ones, where n = ne - 1 is the number of interior nodes.
n = ne - 1
offsets = [-1, 0, 1]
diagonals = [np.ones(n - abs(offset)) for offset in offsets]
A = diags(diagonals, offsets)

In [None]:

# Import the submodule explicitly: a plain `import jax` is not guaranteed to load
# jax.experimental.sparse, so do not rely on another package having imported it.
import jax.experimental.sparse

# Sparsity pattern as a JAX BCOO matrix, built from the SciPy sparse matrix A.
sparsity = jax.experimental.sparse.BCOO.from_scipy_sparse(A)


In [None]:
# Automatic differentiation and JIT compilation; the Hessian is now built by
# sparsejac.jacrev applied to the gradient, using the given sparsity pattern.
fun = jax.jit(energy_jax)
dfun = jax.jit(jax.grad(energy_jax, argnums=0))
ddfun = jax.jit(sparsejac.jacrev(jax.grad(energy_jax, argnums=0), sparsity,argnums=0))


# Single-argument wrappers. fx, v, h and p are read from the notebook
# namespace at call time.
def fun1(v_internal):
    return fun(v_internal, fx, v, h, p)


def dfun1(v_internal):
    return dfun(v_internal, fx, v, h, p)


def ddfun1(v_internal):
    return ddfun(v_internal, fx, v, h, p)

In [None]:
from jax.scipy.sparse.linalg import cg

# `solopt` is either the array returned by trust_region or, after the
# scipy.optimize cells, a result object whose solution is in `.x`;
# take the solution vector in both cases.
x_sol = np.asarray(getattr(solopt, "x", solopt))

A = ddfun1(x_sol)
# NOTE: `b` is rebound here from the interval end point to the right-hand-side vector.
b = x_sol.copy()
# Repeat the solve 100 times (presumably a rough timing loop — confirm).
for _ in range(100):
    # cg returns (solution, info); store it under its own name so the
    # mesh array `x` is not overwritten.
    x_cg, _info = cg(A, b, tol=1e-3)


In [None]:
import numpy as onp  # plain NumPy; `np` in this notebook is jax.numpy
import scipy.sparse as sp
import scipy.sparse.linalg as spla

# Convert the matrix to SciPy CSC format before the direct sparse solve.
# NOTE(review): assumes A is a JAX BCOO matrix (with `.data`, `.indices`, `.shape`),
# as produced by the sparsejac-based Hessian — confirm this is the intended conversion.
rows, cols = onp.asarray(A.indices).T
A_csc = sp.csc_matrix(
    (onp.asarray(A.data, dtype=float), (rows, cols)), shape=A.shape
)
spla.spsolve(A_csc, onp.asarray(b, dtype=float))

In [None]:

# Uniform random vector of length 1000 generated from PRNG key 0.
# NOTE(review): this rebinds `x`, which earlier cells use for the mesh nodes.
x = jax.random.uniform(jax.random.PRNGKey(0), shape=(1000,))

# NOTE(review): `fn` is not defined anywhere in the visible notebook, so this cell
# would raise NameError on a fresh kernel — confirm which function was meant.
sparse_fn = jax.jit(sparsejac.jacrev(fn, sparsity))
dense_fn = jax.jit(jax.jacrev(fn))


# ADAM

In [None]:
# Reload my_solvers and then import `adam` from the reloaded module.
import importlib
import my_solvers
importlib.reload(my_solvers)
from my_solvers import adam


In [None]:

# Run adam from my_solvers with the energy and its gradient, starting from x0.
# NOTE(review): `solopt` is rebound again; the return type of adam is not visible here.
solopt = adam(fun1, dfun1, x0, maxit=10000)

# profilování trust regionu
- řeší s hustými maticemi (nevím jestli jax zvládne sparse)
- 171 iterací vs v článku bylo 37

In [None]:
# !pip install line_profiler

In [None]:
# Load the line_profiler IPython extension (used by the %lprun cell below).
%load_ext line_profiler

### Line profiler (moje výsledky)
83% času stráveného při řešení hustého systému lineárních rovnic

```
Line #      Hits         Time  Per Hit   % Time  Line Contents
==============================================================
    71                                           @timer_decorator
    72                                           def trust_region(f, df, ddf, x0, c0=1.0, tolf=1e-6, tolg=1e-3, maxit=1000, verbose=False):
    73                                               """
    74                                               Trust Region (quasi-Newton method)
    75                                           
    76                                               Parameters
    77                                               ----------
    78                                               fun : function
    79                                                   The objective function to be minimized.
    80                                               x0 : numpy.ndarray
    81                                                   The initial guess for the minimum.
    82                                               c0 : float
    83                                                   The initial trust region size.
    84                                               tol : float
    85                                                   The tolerance for the stopping condition.
    86                                           
    87                                               Returns
    88                                               -------
    89                                               xmin : numpy.ndarray
    90                                                   The found minimum.
    91                                               it : int
    92                                                   The number of iterations.
    93                                               """
    94                                           
    95         1       1077.0   1077.0      0.0      c = c0
    96         1        376.0    376.0      0.0      x = x0
    97         1     830351.0 830351.0      0.0      fx = f(x)
    98         1    3272663.0 3272663.0      0.0      g = df(x)
    99         1    8270695.0 8270695.0      0.0      H = ddf(x)
   100                                           
   101         1       1110.0   1110.0      0.0      it = 0
   102       172   14835459.0  86252.7      0.1      while np.linalg.norm(g) > tolg:
   103                                                   # Trial step
   104       172 13996989406.0 81377845.4     82.8          h = -np.linalg.solve(H + c * np.eye(len(x)), g)
   105                                                   # Quadratic model of function f
   106       172  754014682.0 4383806.3      4.5          m = fx + np.dot(g.T, h) + 0.5 * np.dot(np.dot(h.T, H), h)
   107       172   18575530.0 107997.3      0.1          fxn = f(x + h)
   108                                           
   109                                                   # check stopping condition for f
   110       171   12395796.0  72490.0      0.1          if np.abs(fx - fxn) < tolf:
   111         1      94493.0  94493.0      0.0              print("Stopping condition for f is satisfied")
   112         1        462.0    462.0      0.0              break
   113                                           
   114       171   11968786.0  69992.9      0.1          rho = (fx - fxn) / (fx - m)
   115                                           
   116       136   14779538.0 108673.1      0.1          if rho >= 0.1:
   117       136    3481096.0  25596.3      0.0              xn = x + h
   118       136    7809308.0  57421.4      0.0              g = df(xn)
   119       136 2028782573.0 14917518.9     12.0              H = ddf(xn)
   120                                                   else:
   121        35      18548.0    529.9      0.0              xn = x
   122        35      57342.0   1638.3      0.0              fxn = fx
   123                                           
   124                                                   # Adjust the size of the trust region
   125       126   25827956.0 204983.8      0.2          if rho > 0.75:
   126        45     129868.0   2886.0      0.0              c *= 0.5
   127        91    5762495.0  63324.1      0.0          elif rho < 0.1:
   128        35      63492.0   1814.1      0.0              c *= 2
   129                                           
   130       171     664423.0   3885.5      0.0          x = xn
   131       171     518335.0   3031.2      0.0          fx = fxn
   132                                           
   133       171     247481.0   1447.3      0.0          it += 1
   134       171     104773.0    612.7      0.0          if verbose:
   135                                                       print(f"it={it}, f={fx}, c={c}, ||g||={np.linalg.norm(g)}")
   136                                           
   137       171     168070.0    982.9      0.0          if it > maxit:
   138                                                       print("Maximum number of iterations reached")
   139                                                       break
   140                                               else:
   141                                                   print("Stopping condition for g is satisfied")
   142                                           
   143         1     173675.0 173675.0      0.0      print(f"it={it}, f={fx}, c={c}, ||g||={np.linalg.norm(g)}")
   144         1        446.0    446.0      0.0      return x, it
   ```

In [None]:
# Line-profile the solver body while running the decorated call.
# NOTE(review): relies on the timer decorator exposing the original function as
# `__wrapped__` (as functools.wraps does) — confirm in my_timer.
%lprun -f trust_region.__wrapped__ trust_region(fun1 , dfun1, ddfun1, x0, c0=1, tolf=1e-6, tolg=1e-6, maxit=1000, verbose = False)