<a href="https://colab.research.google.com/github/Konstantin-Iakovlev/2021-Optimization_Project/blob/main/SQN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

### Stochastic Quasi-Newton Methods for Neural Networks

In [6]:
import jax.numpy as jnp
from jax import grad, hessian, jit, vmap
from jax import random
import numpy as np

In [None]:
class SQN_Optimizer:
    """Stochastic quasi-Newton (SQN) optimizer on a synthetic least-squares task.

    The objective is the mean squared residual ``(w . x_i - y_i) ** 2`` over
    randomly generated data ``X`` (``N x 20``) and ``Y`` (``N``).

    The iteration follows the averaged-iterate SQN scheme: plain stochastic
    gradient steps for the first ``2 * L`` iterations, then steps scaled by a
    limited-memory BFGS inverse-Hessian approximation ``H_t``.  Every ``L``
    iterations the iterates are averaged, and the difference of consecutive
    averages ``s_t`` together with a sub-sampled Hessian product ``y_t``
    forms a new curvature pair.

    Attributes:
        w1: Initial iterate (expected to have 20 entries to match ``X``).
        M: Maximum number of curvature pairs used to build ``H_t``.
        L: Number of iterations between curvature-pair updates.
        N: Number of synthetic samples.
        batch_grad_size: Mini-batch size for stochastic gradients.
        batch_hess_size: Mini-batch size for sub-sampled Hessians.
        X, Y: Synthetic features and targets (NumPy arrays).
        w_array: History of iterates ``w^k`` (starts with a copy of ``w1``).
        wt_array: Averaged iterates, one per completed block of ``L`` steps.
        pairs: The most recent (at most ``M``) ``(s, y)`` curvature pairs.
        alpha: Constant step size.
    """

    def __init__(self, w1: jnp.ndarray, M: int, L: int):
        """Store hyper-parameters and generate the synthetic data set.

        Raises:
            ValueError: If ``M`` or ``L`` is smaller than 1.
        """
        if L < 1:
            raise ValueError("L must be a positive integer")
        if M < 1:
            raise ValueError("M must be a positive integer")
        self.w1 = w1
        self.M = M
        self.L = L
        self.N = 10000
        self.batch_grad_size = 64
        self.batch_hess_size = 64
        # Generate data with a private generator: same stream as the legacy
        # ``np.random.seed(1)`` but without touching NumPy's global RNG state.
        rng = np.random.RandomState(1)
        self.X = rng.rand(self.N, 20)
        self.Y = rng.rand(self.N)
        # w^k
        self.w_array = [self.w1.copy()]
        self.wt_array = []
        self.pairs = []
        self.alpha = 1e-4
        # JAX PRNG state; split on every draw so each batch is independent.
        self._key = random.PRNGKey(1)

    def _sample_batch(self, size: int):
        """Draw ``size`` sample indices from ``range(N)`` with replacement."""
        self._key, subkey = random.split(self._key)
        return random.choice(subkey, self.N, shape=(size,))

    def f(self, w: jnp.ndarray, i):
        """Return the squared residual of sample ``i`` at parameters ``w``."""
        return (jnp.dot(w.T, self.X[i]) - self.Y[i]) ** 2

    def F_maker(self, S):
        """Build the mini-batch objective for the sample indices ``S``.

        Args:
            S: One-dimensional collection of integer indices into the data.

        Returns:
            A function ``F(w)`` giving the mean of ``f(w, i)`` over ``i`` in
            ``S``.  The mean is taken over ``len(S)``, so the same factory is
            correct for gradient and Hessian batches of different sizes.

        Raises:
            ValueError: If ``S`` is empty.
        """
        idx = np.asarray(S)
        if idx.size == 0:
            raise ValueError("S must contain at least one sample index")
        # Gather the batch once, outside F, so that differentiating F does not
        # re-index the data or loop over samples in Python.
        X_batch = jnp.asarray(self.X[idx])
        Y_batch = jnp.asarray(self.Y[idx])

        def F(w: jnp.ndarray):
            residual = jnp.dot(X_batch, w) - Y_batch
            return jnp.mean(residual ** 2)
        return F

    def Ht(self, t: int, M: int, pairs_arr):
        """Return the dense L-BFGS inverse-Hessian approximation ``H_t``.

        Args:
            t: Index of the latest curvature pair (number of pairs created).
            M: Memory size; at most ``M`` most recent pairs are used.
            pairs_arr: Sequence of ``(s, y)`` pairs ordered oldest to newest
                (a list of tuples or an array of shape ``(n, 2, d)``).

        Returns:
            A ``d x d`` matrix.  When no pair is available the identity of
            size ``w1.shape[0]`` is returned, i.e. a plain gradient step.
        """
        m = min(t, M, len(pairs_arr))
        if m <= 0:
            return jnp.eye(self.w1.shape[0])
        recent = pairs_arr[len(pairs_arr) - m:]
        s_last, y_last = recent[-1]
        eye = jnp.eye(s_last.shape[0])
        # Initial scaling (s^T y / y^T y) * I taken from the newest pair.
        H = (jnp.dot(s_last, y_last) / jnp.dot(y_last, y_last)) * eye
        for s, y in recent:
            rho = 1.0 / jnp.dot(y, s)
            V = eye - rho * jnp.outer(y, s)
            # BFGS update: (I - rho s y^T) H (I - rho y s^T) + rho s s^T.
            H = jnp.dot(V.T, jnp.dot(H, V)) + rho * jnp.outer(s, s)
        return H

    def optimize(self, iter_num: int = 1000):
        """Run ``iter_num`` SQN iterations starting from the latest iterate.

        Iterates are appended to ``w_array``; averaged iterates and curvature
        pairs are rebuilt from scratch in ``wt_array`` and ``pairs``.

        Args:
            iter_num: Number of iterations to perform (default 1000).

        Returns:
            The final iterate (also available as ``w_array[-1]``).
        """
        iter_num = int(iter_num)
        dim = self.w1.shape[0]
        t = -1
        w_bar_sum = jnp.zeros((dim,))
        self.wt_array = []
        self.pairs = []
        for k in range(1, iter_num + 1):
            w = self.w_array[-1]
            S = self._sample_batch(self.batch_grad_size)
            g = grad(self.F_maker(S))(w)
            w_bar_sum = w_bar_sum + w
            if k <= 2 * self.L or not self.pairs:
                # Warm-up phase (or no usable curvature yet): plain SGD step.
                direction = g
            else:
                direction = jnp.dot(self.Ht(t, self.M, self.pairs), g)
            self.w_array.append(w - self.alpha * direction)
            if k % self.L == 0:
                t += 1
                self.wt_array.append(w_bar_sum / self.L)
                if t > 0:
                    Sh = self._sample_batch(self.batch_hess_size)
                    hessF = hessian(self.F_maker(Sh))
                    st = self.wt_array[-1] - self.wt_array[-2]
                    yt = jnp.dot(hessF(self.wt_array[-1]), st)
                    # Keep only pairs with positive curvature so that H_t
                    # stays positive definite and rho = 1 / (y^T s) is finite.
                    if float(jnp.dot(st, yt)) > 1e-10:
                        self.pairs.append((st, yt))
                        self.pairs = self.pairs[-self.M:]
                w_bar_sum = jnp.zeros((dim,))
        return self.w_array[-1]