<a href="https://colab.research.google.com/github/JALB-epsilon/CAAM554_Codes/blob/main/HW9.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# This notebook contains all the problems from HW9.



## Setting Up Libraries

In [33]:
import jax 
import jax.numpy as jnp
from jax import grad, jit, jacrev
from jax import random, device_put
import matplotlib.pyplot as plt
# Notebook-wide PRNG key (seed 10). The Armijo step-size sampler defined later
# reads this global; because the key is never split, every draw made with it
# (for the same bounds) returns the same value.
key = random.PRNGKey(10)

# Problem 3

In [34]:
# Optional: to run on a GPU (not required for this notebook), in Colab go to
# Runtime -> Change runtime type -> Hardware accelerator.
# List the devices JAX can currently see.
jax.devices()

[CpuDevice(id=0)]

In [35]:
def F(x):
  """Residual of the 2x2 nonlinear system whose root is sought.

  Args:
    x: 1-D array; only x[0] and x[1] are read individually, while the first
      residual uses the full dot product x.x.

  Returns:
    Length-2 array [x.x - 2, exp(x[0] - 1) + x[1]**2 - 2].
  """
  squared_norm_residual = jnp.dot(x, x) - 2
  exponential_residual = jnp.exp(x[0] - 1) + x[1]**2 - 2
  residuals = jnp.array([squared_norm_residual, exponential_residual])
  return device_put(residuals)

# Compiled residual, and its Jacobian obtained by reverse-mode autodiff.
jit_F = jit(F)
jit_jacob = jit(jacrev(jit_F))

def f(x, jit_F = jit_F):
  """Least-squares merit function 0.5 * ||F(x)||^2.

  Args:
    x: Point at which to evaluate the merit function.
    jit_F: Residual function; defaults to the compiled residual bound at
      definition time.

  Returns:
    Scalar 0.5 * F(x).F(x).
  """
  residual = jit_F(x)
  return jnp.dot(residual, residual) / 2

# Compiled merit function and its autodiff gradient.
# The gradient has the closed form J(x)^T F(x), i.e.
# jnp.dot(jit_jacob(x).T, jit_F(x)); grad() is used instead of coding it by hand.
jit_f = jit(f)
jit_f_grad = jit(grad(jit_f))

In [36]:
# Sanity check: evaluate the compiled residual and its Jacobian at the origin.
x = jnp.array([0.,0.])
jit_F(x), jit_jacob(x)

(DeviceArray([-2.       , -1.6321206], dtype=float32),
 DeviceArray([[0.        , 0.        ],
              [0.36787945, 0.        ]], dtype=float32))

In [37]:
# Merit function value and its autodiff gradient at the same point x.
jit_f(x), jit_f_grad(x)

(DeviceArray(3.3319087, dtype=float32),
 DeviceArray([-0.60042363, -0.        ], dtype=float32))

In [38]:
# Cross-check: the closed form J(x)^T F(x), to compare with the autodiff
# gradient shown in the previous cell.
jnp.dot(jit_jacob(x).T, jit_F(x))

DeviceArray([-0.60042363,  0.        ], dtype=float32)

### Newton's Method for Root Finding with Line Search

In [39]:
def Armijo(f, beta1=0.5, beta2=0.5, alpha=1, rng_key=None):
  """Shrink a line-search step length by a factor drawn from [beta1, beta2].

  Args:
    f: Unused. Kept so existing callers that pass the merit function keep working.
    beta1: Lower bound of the contraction factor.
    beta2: Upper bound of the contraction factor; must satisfy beta1 <= beta2.
    alpha: Current step length.
    rng_key: Optional JAX PRNG key used for the draw. When omitted, the
      notebook-level `key` is used; since that key is never split, repeated
      calls with the same bounds return the same value.

  Returns:
    The new step length, sampled uniformly between beta1*alpha and beta2*alpha
    (exactly beta1*alpha when beta1 == beta2).

  Raises:
    ValueError: If beta1 <= beta2 does not hold. Previously this case printed
      "error" and returned None, which only failed later inside the caller.
  """
  # Written as `not (a <= b)` so that NaN bounds are rejected as well.
  if not (beta1 <= beta2):
    raise ValueError(
        f"Armijo requires beta1 <= beta2, got beta1={beta1}, beta2={beta2}")
  if rng_key is None:
    rng_key = key  # fall back to the notebook-wide PRNG key
  return jax.random.uniform(rng_key, minval=beta1*alpha, maxval=beta2*alpha)

In [42]:
def newton_method(x, jit_F = jit_F, jit_jacob_F = jit_jacob, jit_f = jit_f, jit_f_grad = jit_f_grad,
                 alpha = 1, tol = 1e-12, maxIt = 100, PrintVar = 1, PrintX = None,
                 save_files = None, beta1 = 0.5, beta2 = 0.5, c1 = 1e-4):
  """Newton iteration for F(x) = 0 with a backtracking (Armijo) line search.

  Each iteration solves J(x) s = -F(x) for the Newton step s, shortens the step
  length alpha until the sufficient-decrease condition
  f(x + alpha*s) <= f(x) + c1*alpha*grad_f(x).s holds, and then moves to
  x + alpha*s.

  Args:
    x: Initial guess (1-D array).
    jit_F: Residual function F.
    jit_jacob_F: Jacobian of F.
    jit_f: Merit function used by the divergence check and the line search.
    jit_f_grad: Gradient of the merit function.
    alpha: Initial step length. It is NOT reset between Newton iterations:
      once the line search shortens it, later iterations start from the
      shortened value.
    tol: Stop when the 2-norm of F(x) falls below this value.
      NOTE(review): with float32 arrays 1e-12 is likely only reached when
      |F| is exactly 0 -- confirm the intended precision.
    maxIt: Maximum number of Newton iterations.
    PrintVar: Print (and, with save_files, record) every PrintVar-th
      iteration. 0 or None disables both.
    PrintX: If truthy, also print x on each reported iteration.
    save_files: If truthy, return the recorded history together with x.
    beta1, beta2: Bounds of the step contraction factor passed to Armijo.
    c1: Sufficient-decrease constant of the Armijo condition.

  Returns:
    x if save_files is falsy; otherwise the tuple
    (|F| history, |alpha*s| history, alpha history, recorded iteration indices, x).
  """
  n = len(x)
  if n < 20:
    print(f"initial state x_0 = {x}, alpha:{alpha}")
  if save_files:
    val_F_norm = []
    val_sk_norm = []
    val_alpha_k = []
    iters = []  # not named `iter`, to avoid shadowing the builtin
  i = 0  # keeps the final report well defined even when maxIt == 0
  for i in range(maxIt):
    F = jit_F(x)
    jacobian_F = jit_jacob_F(x)
    s = jnp.linalg.solve(jacobian_F, -F)
    aks = jnp.linalg.norm(s*alpha, ord=2)
    Fnorm = jnp.linalg.norm(F, ord=2)
    # Guard against PrintVar == 0 / None, which used to raise on the modulo.
    if PrintVar and i % PrintVar == 0:
      print(f"iteration{i}, x={x}, |F|={Fnorm},|alpha*s|={aks}, alpha={alpha}" )
      if PrintX:
        print(f"x: {x}")
      if save_files:
        val_F_norm.append(Fnorm)
        val_sk_norm.append(aks)
        val_alpha_k.append(alpha)
        iters.append(i)
    if Fnorm < tol:
      print("****************************************")
      print(f"stop because F(x)= {F}< {tol*(jnp.ones_like(F))} with x = {x}")
      break
    # Use the merit function supplied by the caller (not the module-level `f`)
    # so that custom jit_F / jit_f arguments are actually honoured.
    f_x = jit_f(x)
    if jnp.isnan(f_x):
      print("****************************************")
      print(f"stop because abs(F)= {F}, Newton method diverges")
      break
    # f(x) and the directional derivative do not change while backtracking,
    # so they are evaluated once outside the loop.
    slope = jnp.dot(jit_f_grad(x), s)
    while jit_f(x + alpha*s) > f_x + c1*alpha*slope:
      alpha = Armijo(jit_f, beta1, beta2, alpha)
    # Rebind instead of `+=` so a caller-supplied mutable array (e.g. NumPy)
    # is never modified in place.
    x = x + alpha*s
  print("****************************************")
  # Re-evaluate so the reported residual matches the returned x even when the
  # loop ended by exhausting maxIt (x was updated after F was last computed).
  F = jit_F(x)
  if n < 20:
    print(f"iteration: {i}, F(x): {F}, and x={x} , alpha:{alpha}")
  else:
    print(f"iteration: {i}, F(x): {F}")
  if save_files:
    return val_F_norm, val_sk_norm, val_alpha_k, iters, x
  else:
    return x

In [43]:
# Run Newton's method from x0 = (1.5, 2) and keep the per-iteration history.
# NOTE(review): binding the result to `iter` shadows Python's builtin iter()
# for the rest of the notebook, and `x` overwrites the test point used in
# earlier cells -- consider renaming if no later cell depends on these names.
xinit = jnp.array([1.5,2.])
val_F, val_sk, val_alpha, iter, x =newton_method(xinit, save_files=1, beta1=0.2, beta2=0.5)

initial state x_0 = [1.5 2. ], alpha:1
iteration0, x=[1.5 2. ], |F|=5.601398944854736,|alpha*s|=0.8538780212402344, alpha=1
iteration1, x=[1.0550297 1.2712276], |F|=0.9919562339782715,|alpha*s|=0.24811919033527374, alpha=1
iteration2, x=[1.0013834 1.0289773], |F|=0.08608981221914291,|alpha*s|=0.02860260009765625, alpha=1
iteration3, x=[1.0000008 1.000408 ], |F|=0.001156121725216508,|alpha*s|=0.0004079238569829613, alpha=1
iteration4, x=[0.99999994 1.0000001 ], |F|=2.6656007889869215e-07,|alpha*s|=2.149075726265437e-07, alpha=1
iteration5, x=[1.         0.99999994], |F|=1.6858739115832577e-07,|alpha*s|=5.960465188081798e-08, alpha=1
iteration6, x=[1. 1.], |F|=0.0,|alpha*s|=0.0, alpha=1
****************************************
stop because F(x)= [0. 0.]< [1.e-12 1.e-12] with x = [1. 1.]
****************************************
iteration: 6, F(x): [0. 0.], and x=[1. 1.] , alpha:1
