In [1]:
import matplotlib.pyplot as plt
import autograd.numpy as np  # Thinly-wrapped version of Numpy
from autograd import grad, elementwise_grad, jacobian
from autograd import extend

# Tolerance for grad_descent's convergence check on the final gradient.
EPS_TOL = 1e-10

In [2]:
def grad_descent(f, df_dx, initial_x, max_iter=10, max_halvings=25, alpha=1.0, tol=None):
    """Minimise ``f`` by gradient descent with a step-halving line search.

    Each outer iteration proposes ``x - alpha * df_dx(x)``. If that does not
    strictly decrease ``f``, ``alpha`` is halved (at most ``max_halvings``
    times) and the step is retried. The reduced ``alpha`` carries over to
    later iterations.

    Args:
        f: objective, called as ``f(x)``; its values are compared with ``<``.
        df_dx: gradient of ``f``; must return something shaped like ``x``.
        initial_x: starting point (scalar or array).
        max_iter: maximum number of outer iterations.
        max_halvings: maximum number of step-size halvings per iteration.
        alpha: initial step size.
        tol: convergence threshold on ``|df_dx(x)|`` (checked elementwise).
            Defaults to the module-level ``EPS_TOL``.

    Returns:
        The final iterate. Prints "DID NOT CONVERGE! CHECK" if any gradient
        component at that iterate is not below ``tol`` in magnitude.
    """
    if tol is None:
        tol = EPS_TOL
    curr_x = initial_x
    curr_val = f(curr_x)
    for _ in range(max_iter):
        # The gradient does not change during the line search, so compute it once.
        grad_x = df_dx(curr_x)
        if np.all(np.abs(grad_x) < tol):
            break  # already stationary; more steps would only shrink alpha
        for _ in range(max_halvings):
            new_x = curr_x - alpha * grad_x
            new_val = f(new_x)
            if new_val < curr_val:
                curr_x, curr_val = new_x, new_val
                break
            alpha /= 2.
    # Compare magnitudes: a large negative gradient is not convergence, and
    # np.all makes the check valid for array-valued x as well as scalars.
    converged = np.all(np.abs(df_dx(curr_x)) < tol)
    if not converged:
        print("DID NOT CONVERGE! CHECK")
    return curr_x


In [None]:
def O_(x, y):
    """Toy two-argument objective ``(y - 2)^2 + 2*x*y``; minimised over ``y`` at ``y* = 2 - x``."""
    shifted = y - 2
    return shifted**2 + 2*x*y

@extend.primitive
def argmin_O2(x, y_init=None, O=O_):
    """Return ``argmin_y O(x, y)`` computed by ``grad_descent`` starting at ``y_init``.

    Declared as an autograd primitive, so autograd does not trace through the
    descent iterations; derivatives with respect to ``x`` must come from a
    registered VJP.

    Args:
        x: outer variable, held fixed during the inner minimisation.
        y_init: starting point for ``y``; required (asserted not None).
        O: two-argument objective ``O(x, y)``, minimised over its second argument.
    """
    assert y_init is not None

    def objective_in_y(y):
        return O(x, y)

    return grad_descent(objective_in_y, grad(objective_in_y), y_init)

def argmin_O2_vjp(ans, x, y_init=None, O=O_):
    """VJP maker for ``argmin_O2`` with respect to ``x`` (implicit function theorem).

    At the optimum ``y* = ans`` the inner gradient ``g(x, y) = dO/dy`` is zero,
    so ``dy*/dx = -(d2O/dy2)^{-1} d2O/dydx``, evaluated at ``(x, y*)``.
    The returned function maps the upstream cotangent ``v = dloss/dy*`` to
    ``dloss/dx = v^T dy*/dx``.

    Args:
        ans: the minimiser returned by ``argmin_O2`` (``y*``).
        x: the outer variable passed to ``argmin_O2``.
        y_init: starting point of the inner descent. It is unused here because
            the derivatives must be taken at the optimum, not at the start.
        O: the two-argument objective passed to ``argmin_O2``.

    Returns:
        A function ``v -> dloss/dx`` whose result has the shape of ``x``.
    """
    g = grad(O, 1)  # dO/dy, which vanishes at y* = ans
    d2O_dy2 = jacobian(g, 1)(x, ans)
    d2O_dydx = jacobian(g, 0)(x, ans)

    if np.ndim(d2O_dy2) == 0:
        # Scalar y: inverting the 1x1 Hessian is a plain division.
        return lambda v: v * (-d2O_dydx / d2O_dy2)

    # Vector y: flatten to matrices; v^T (-H^{-1} B) = -(H^{-T} v)^T B.
    # Solve the linear system instead of forming an explicit inverse.
    n_y = np.size(ans)
    hess = np.reshape(d2O_dy2, (n_y, n_y))
    cross = np.reshape(d2O_dydx, (n_y, -1))
    x_shape = np.shape(x)

    def vjp(v):
        u = np.linalg.solve(np.transpose(hess), np.reshape(v, (n_y,)))
        return np.reshape(-np.dot(u, cross), x_shape)

    return vjp

# Register argmin_O2_vjp as the VJP of argmin_O2 for its first positional argument (x).
extend.defvjp(argmin_O2, argmin_O2_vjp)



@extend.primitive
def argmin_O3(x, y, z_init=None, O=O_):
    """Return ``argmin_z O(x, y, z)`` computed by ``grad_descent`` starting at ``z_init``.

    Both ``x`` and ``y`` are held fixed during the inner minimisation; ``y`` may
    itself be a function of ``x`` (e.g. the output of ``argmin_O2``). Because
    this is an autograd primitive, its derivatives must be supplied by a
    registered VJP rather than by tracing ``grad_descent``.

    Args:
        x: first fixed argument of ``O``.
        y: second fixed argument of ``O``.
        z_init: starting point for ``z``; required (asserted not None).
        O: three-argument objective ``O(x, y, z)``, minimised over ``z``.
            NOTE(review): the default ``O_`` takes only two arguments, so
            callers must pass ``O`` explicitly.
    """
    # Only z_init needs a check: y is a required positional argument.
    assert z_init is not None
    Oopt = lambda z: O(x, y, z)
    return grad_descent(Oopt, grad(Oopt), z_init)

def _argmin_O3_partial_vjp(ans, x, y, O, wrt):
    """Build the VJP of ``z* = argmin_z O(x, y, z)`` w.r.t. argument ``wrt`` (0 -> x, 1 -> y).

    Uses the implicit function theorem at the optimum, where ``dO/dz = 0`` at
    ``z* = ans``: ``dz*/dw = -(d2O/dz2)^{-1} d2O/dzdw``.
    """
    g = grad(O, 2)  # dO/dz, which vanishes at z* = ans
    point = (x, y, ans)
    d2O_dz2 = jacobian(g, 2)(*point)
    d2O_dzdw = jacobian(g, wrt)(*point)

    if np.ndim(d2O_dz2) == 0:
        # Scalar z: inverting the 1x1 Hessian is a plain division.
        return lambda v: v * (-d2O_dzdw / d2O_dz2)

    # Vector z: v^T (-H^{-1} B) = -(H^{-T} v)^T B, via a linear solve.
    n_z = np.size(ans)
    hess = np.reshape(d2O_dz2, (n_z, n_z))
    cross = np.reshape(d2O_dzdw, (n_z, -1))
    w_shape = np.shape(point[wrt])

    def vjp(v):
        u = np.linalg.solve(np.transpose(hess), np.reshape(v, (n_z,)))
        return np.reshape(-np.dot(u, cross), w_shape)

    return vjp


def argmin_O3_vjp(ans, x, y, z_init=None, O=O_):
    """VJP maker for ``argmin_O3`` with respect to ``x``, holding ``y`` fixed.

    When ``y`` depends on ``x``, the total derivative comes from autograd
    combining this partial with ``argmin_O3_vjp_y`` and the VJP of whatever
    produced ``y``. It is not computed here.

    Args:
        ans: the minimiser returned by ``argmin_O3`` (``z*``).
        x, y: the fixed arguments passed to ``argmin_O3``.
        z_init: unused; derivatives are taken at the optimum ``ans``.
        O: the three-argument objective passed to ``argmin_O3``.
    """
    return _argmin_O3_partial_vjp(ans, x, y, O, wrt=0)


def argmin_O3_vjp_y(ans, x, y, z_init=None, O=O_):
    """VJP maker for ``argmin_O3`` with respect to ``y``, holding ``x`` fixed."""
    return _argmin_O3_partial_vjp(ans, x, y, O, wrt=1)


# Register the partial VJPs for positional arguments 0 (x) and 1 (y).
extend.defvjp(argmin_O3, argmin_O3_vjp, argmin_O3_vjp_y)



# In each of the objectives below, the argmin is taken over the last positional argument.
def O2(x, y):
    """Inner objective ``(3x - 5y + xy)^2``, minimised over ``y``."""
    residual = 3*x - 5*y + x*y
    return residual**2

def O3(x, y, z):
    """Inner objective ``(x + y - z)^2``, minimised over ``z`` (optimum ``z* = x + y``)."""
    gap = x + y - z
    return gap**2


def O1(x, init_y=3.0, init_z=4.0):
    """Outer objective built from two nested argmin problems.

    First ``y = argmin_y O2(x, y)`` (started from ``init_y``), then
    ``z = argmin_z O3(x, y, z)`` (started from ``init_z``). Returns
    ``(x - 2y)^2 + (z - x)^2``.
    """
    # argmin_O2's signature is (x, y_init, O); init_z belongs to the z-problem only.
    y = argmin_O2(x, init_y, O2)
    z = argmin_O3(x, y, init_z, O3)
    return (x - 2*y)**2 + (z - x)**2

