In [23]:
import matplotlib.pyplot as plt
import autograd.numpy as np  # Thinly-wrapped version of Numpy
from autograd import grad, elementwise_grad, jacobian
from autograd import extend

# Gradient threshold used by grad_descent's final convergence check.
EPS_TOL = 1e-10

In [29]:
def grad_descent(f, df_dx, initial_x, tol=None):
    """Minimise ``f`` by gradient descent with a step-halving line search.

    A fixed budget of 50 descent steps is run. For each step the current step
    size ``alpha`` is tried and halved (at most 25 times) until the objective
    strictly decreases. ``alpha`` is never reset, so it only shrinks over the
    whole run.

    Args:
        f: Objective, called as ``f(x)``; its result is compared with ``<``.
        df_dx: Gradient of ``f``, called as ``df_dx(x)``.
        initial_x: Starting point (scalar or array).
        tol: Bound on the gradient magnitude used by the final convergence
            check. ``None`` (the default) uses the module-level ``EPS_TOL``.

    Returns:
        The last accepted iterate. If any component of the gradient there has
        magnitude >= ``tol``, a warning line is printed; the iterate is
        returned either way.
    """
    if tol is None:
        tol = EPS_TOL

    alpha = 1.0
    curr_x = initial_x
    for _ in range(50):
        curr_val = f(curr_x)
        # The gradient depends only on curr_x, which is fixed while
        # backtracking, so evaluate it once per descent step.
        grad_x = df_dx(curr_x)
        for _ in range(25):
            new_x = curr_x - alpha * grad_x
            if f(new_x) < curr_val:
                curr_x = new_x
                break
            alpha /= 2.

    curr_grad = df_dx(curr_x)
    # Compare magnitudes: a large *negative* gradient is not convergence.
    # np.all reduces array-valued gradients to a single truth value.
    converged = np.all(np.abs(curr_grad) < tol)
    if not converged:
        print("DID NOT CONVERGE! CHECK")
    return curr_x


In [30]:
def O_(x, y):
    """Default inner objective: a quadratic in ``y`` plus a bilinear ``x*y`` term."""
    quadratic_term = (y - 2) ** 2
    coupling_term = 2 * x * y
    return quadratic_term + coupling_term

@extend.primitive
def argmin_O(x, y_init=None, O=O_):
    """Return the ``y`` found by ``grad_descent`` when minimising ``O(x, y)`` over ``y``.

    ``O`` must take two arguments ``(x, y)``: ``x`` is held fixed, ``y`` is the
    variable optimised over, and the result is meant to be differentiated with
    respect to ``x``. ``y_init`` is the starting point of the descent and is
    required despite its ``None`` default.
    """
    assert y_init is not None

    def objective_in_y(y):
        return O(x, y)

    gradient_in_y = grad(objective_in_y)
    return grad_descent(objective_in_y, gradient_in_y, y_init)


def argmin_O_vjp(ans, x, y_init=None, O=O_):
    """
    Build the vector-Jacobian product (VJP) of ``argmin_O`` with respect to ``x``.

    ``ans`` is the minimiser returned by ``argmin_O``, so the stationarity
    condition dO/dy(x, ans) = 0 holds there. Differentiating that condition
    with respect to ``x`` (implicit differentiation) gives

        d ans/dx = -(d2O/dy2)^-1 * d2O/dydx,

    with both second derivatives evaluated at (x, ans), i.e. at the optimum
    rather than at the starting point of the descent.

    The returned function receives v = dloss/d ans and returns
    v * d ans/dx = dloss/dx, which is the quantity we are actually interested in.

    ``y_init`` is unused here; it stays in the signature so that it matches
    the arguments of ``argmin_O``.
    """
    dO_dy = grad(O, 1)

    if np.ndim(ans) == 0:
        # Scalar y: the Hessian is a single number, so divide instead of inverting.
        d2O_dy2 = grad(dO_dy, 1)(x, ans)
        d2O_dydx = grad(dO_dy, 0)(x, ans)
        return lambda v: v*(-1./d2O_dy2)*d2O_dydx

    # Vector y: dO_dy is vector-valued, so its derivatives are Jacobians
    # (grad only accepts scalar-output functions).
    hess_yy = jacobian(dO_dy, 1)(x, ans)
    cross_yx = jacobian(dO_dy, 0)(x, ans)
    # Solve the linear system instead of forming the explicit inverse.
    dans_dx = np.negative(np.linalg.solve(hess_yy, cross_yx))
    # Contract v with the y-axis of d ans/dx (a true v^T J product).
    return lambda v: np.dot(v, dans_dx)

# Attach argmin_O_vjp to argmin_O through autograd's extend.defvjp.
extend.defvjp(argmin_O, argmin_O_vjp)



# Quick check: minimise the default objective O_ over y for x = 1.0, starting at y = 5.0.
argmin_O(1.0, 5.0)


1.0

In [31]:
"""
Safety check that gradient is correct
"""
from autograd.test_util import check_grads
import functools

def O1(x, y):
    """Inner objective for the first check: squared residual ``2*x - y``."""
    residual = 2 * x - y
    return residual ** 2

def example_func(x):
    """Outer loss ``(x - 2*y)**2`` where ``y`` is ``argmin_O`` of ``O1`` at ``x``, started from 7."""
    y_opt = argmin_O(x, 7., O1)
    outer_residual = x - 2*y_opt
    return outer_residual**2

# Differentiate example_func with autograd and print the gradient at x = 5.
grad_of_example = grad(example_func)
print("Gradient: \n", grad_of_example(5.))

# Check the gradients numerically, just to be safe.

def finite_diff(f, x, h=1e-7):
    """Estimate df/dx at ``x`` with a central difference.

    Args:
        f: Function of one argument.
        x: Point at which to estimate the derivative.
        h: Half-width of the difference stencil; ``f`` is evaluated at
            ``x + h`` and ``x - h``. Defaults to the previously hard-coded 1e-7.

    Returns:
        ``(f(x + h) - f(x - h)) / (2 * h)``.
    """
    df_dx = (f(x + h) - f(x - h)) / (2. * h)
    return df_dx

# Numerical estimate at the same point, for comparison with the autograd value above.
print("Finite diff grad: {}".format(finite_diff(example_func, 5.)))


# Run autograd's check_grads on example_func in 'rev' mode at x = 5.
check_grads(example_func, modes=['rev'])(5.)


Gradient: 
 90.0
Finite diff grad: 90.00000005698894


In [39]:
def O2(x, y):
    """Inner objective for the second check: squared residual ``3*x - 5*y``."""
    residual = 3 * x - 5 * y
    return residual ** 2

def f2(x):
    """Outer loss ``(x - y)**2`` where ``y`` is ``argmin_O`` of ``O2`` at ``x``, started from 7."""
    y_opt = argmin_O(x, 7., O2)
    gap = x - y_opt
    return gap**2

# Print the finite-difference and autograd gradients of f2 at x = 5 side by side,
# then run autograd's check_grads on f2 in 'rev' mode at the same point.
print("Finite diff grad: {} | Analytical grad: {}".format(finite_diff(f2, 5.), grad(f2)(5.)))
check_grads(f2, modes=['rev'])(5.)


Finite diff grad: 1.6000000080396148 | Analytical grad: 1.5999999999989738
