In [8]:
import jax
import jax.numpy as jnp
import numpy as np

In [3]:
from typing import Any, Sequence,Callable,NamedTuple, Optional, Tuple
PyTree = Any  # Alias for a JAX pytree (nested containers with array leaves); typed as Any.

class Optimizer(NamedTuple):
    """A pair of pure functions that together implement an optimizer.

    ``init(params)`` builds the initial optimizer state.
    ``update(grads, state, params=None)`` returns ``(updates, new_state)``.
    The optimizers in this notebook return already-negated steps
    (e.g. ``-lr * grad``), so ``updates`` are meant to be added to the params.
    """
    # Given params, initialize the optimizer state. The state may be any
    # pytree: an empty tuple for sgd, a (step, m, v) tuple for adam, or a
    # params-shaped pytree for sgd_momentum.
    init: Callable[[PyTree], PyTree]

    # (grads, state, optional params) -> (transformed updates, new state).
    update: Callable[[PyTree, PyTree, Optional[PyTree]], Tuple[PyTree, PyTree]]


In [4]:
#Implementing SGD
from jax.tree_util import tree_map

def sgd(lr):
    """Plain stochastic gradient descent.

    Args:
        lr: learning rate; each returned update leaf is ``-lr * grad``.

    Returns:
        An ``Optimizer`` whose state is an empty tuple (SGD keeps no state).
    """
    def init(params):
        # SGD keeps no per-parameter state.
        return tuple()

    def update(updates, state, params=None):
        # Scale every gradient leaf by -lr. Bind the result to a distinct
        # name so the scaled step (not the raw gradients) is returned.
        scaled_updates = tree_map(lambda u: -lr * u, updates)
        return scaled_updates, state

    return Optimizer(init, update)


In [5]:
#Implementing SGD with Momentum
#momentum weights the running average of past gradients; (1 - momentum) weights the new gradient

def sgd_momentum(lr, momentum=0.0):
    """SGD driven by an exponential moving average of the gradients.

    Each state buffer is updated as
    ``buf = (1 - momentum) * grad + momentum * buf`` and the returned step
    is ``-lr * buf``. With ``momentum=0.0`` this reduces to plain SGD.

    Args:
        lr: learning rate.
        momentum: weight given to the previous buffer value.

    Returns:
        An ``Optimizer`` whose state is a pytree of buffers shaped like params.
    """
    def _blend(prev, grad):
        # Mix the fresh gradient into the running average.
        return (1 - momentum) * grad + momentum * prev

    def init(params):
        # One zero-initialised buffer per leaf of params.
        return tree_map(jnp.zeros_like, params)

    def update(updates, state, params=None):
        new_state = tree_map(_blend, state, updates)
        step = tree_map(lambda buf: -lr * buf, new_state)
        return step, new_state

    return Optimizer(init, update)

In [9]:
#Implementing Adam

def adam(lr, beta1=0.9, beta2=0.999, eps=1e-8):
    """Adam optimizer with bias-corrected first and second moment estimates.

    Args:
        lr: learning rate.
        beta1: decay rate of the first-moment (gradient mean) estimate.
        beta2: decay rate of the second-moment (squared gradient) estimate.
            The default follows the standard Adam setting (0.999).
        eps: added to the denominator for numerical stability.

    Returns:
        An ``Optimizer`` whose state is ``(step, first_moment, second_moment)``,
        where both moments have the same pytree structure as params.
    """
    def init(params):
        # Step counter plus zero-initialised moment buffers shaped like params.
        step = 0.
        param_momentum = tree_map(jnp.zeros_like, params)
        param_2nd_momentum = tree_map(jnp.zeros_like, params)
        return (step, param_momentum, param_2nd_momentum)

    def update(updates, state, params=None):
        step, param_momentum, param_2nd_momentum = state
        step = step + 1

        param_momentum = tree_map(
            lambda m, g: (1 - beta1) * g + beta1 * m, param_momentum, updates)
        param_2nd_momentum = tree_map(
            lambda m2, g: (1 - beta2) * g ** 2 + beta2 * m2, param_2nd_momentum, updates)

        # Bias correction: the moments start at zero, so early estimates are
        # biased towards zero; dividing by (1 - beta ** step) compensates.
        momentum_correction = 1 - beta1 ** step
        second_momentum_correction = 1 - beta2 ** step

        def update_param(m, m2):
            m_hat = m / momentum_correction
            m2_hat = m2 / second_momentum_correction
            return -lr * m_hat / (jnp.sqrt(m2_hat) + eps)

        new_updates = tree_map(update_param, param_momentum, param_2nd_momentum)

        return new_updates, (step, param_momentum, param_2nd_momentum)

    return Optimizer(init, update)
