# Chain derivatives
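The purpose of this notebook is to develop a function that receives an arbitrary
number of functions (a *chain* $f_1, \dots, f_n$, with $f_1$ applied first) together with an input array, and
returns the derivative of their composition evaluated at every point of the array, using the chain rule:
$\frac{d}{dx} f_n(\cdots f_1(x)) = \prod_{k=1}^{n} f_k'(u_{k-1})$, where $u_0 = x$ and $u_k = f_k(u_{k-1})$.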

In [9]:
import matplotlib.pyplot as plt
import matplotlib
import numpy as np
from numpy import ndarray
import pandas as pd
from typing import Callable, List

from src.math import derivative
from src.math import tanh_function, softmax_function, sigmoid, relu, leaky_relu, square

# An Array_Function takes one ndarray argument and returns an ndarray.
Array_Function = Callable[[ndarray], ndarray]

# A Chain is an ordered list of Array_Functions. The derivative helpers in this
# notebook treat chain[0] as the innermost function (applied to the input first)
# and chain[-1] as the outermost.
Chain = List[Array_Function]

In [46]:
def chain_deriv_3(chain: Chain,
                  input_range: ndarray) -> ndarray:
    '''
    Derivative of a composition of exactly three functions, via the chain rule.

    For chain = [f1, f2, f3] (f1 applied first) this evaluates, at each point
    of input_range,

        (f3(f2(f1(x))))' = f3'(f2(f1(x))) * f2'(f1(x)) * f1'(x)

    with every factor obtained from the ``derivative`` helper imported from
    src.math.

    Parameters
    ----------
    chain : Chain
        Exactly three functions, innermost first.
    input_range : ndarray
        Points at which the derivative is evaluated.

    Returns
    -------
    ndarray
        Pointwise product of the three local derivatives.

    Raises
    ------
    AssertionError
        If chain does not contain exactly three functions.
    '''
    assert len(chain) == 3, \
        "This function requires 'Chain' objects to have length 3"

    inner, middle, outer = chain[0], chain[1], chain[2]

    # Forward pass: the argument each later function is actually applied to.
    inner_out = inner(input_range)      # f1(x)
    middle_out = middle(inner_out)      # f2(f1(x))

    # Local derivatives, each evaluated at its own function's argument.
    d_outer = derivative(outer, middle_out)     # f3'(f2(f1(x)))
    d_middle = derivative(middle, inner_out)    # f2'(f1(x))
    d_inner = derivative(inner, input_range)    # f1'(x)

    # Pointwise product, multiplied in the order f1' * f2' * f3'.
    return d_inner * d_middle * d_outer


def test_chain_deriv_3():
    '''
    Regression check for chain_deriv_3 on the chain [leaky_relu, square, sigmoid].

    Evaluates the chain at the integers -3..2 and compares the result with
    previously recorded values; np.testing.assert_allclose raises
    AssertionError on a mismatch.
    '''
    recorded = [-0.05809723, -0.03974509, -0.019992, 0., 0.39322385, 0.07065084]
    points = np.arange(-3, 3)
    computed = chain_deriv_3([leaky_relu, square, sigmoid], points)
    np.testing.assert_allclose(computed, recorded, rtol=1e-6, atol=0)
    
    
INPUT_RANGE = np.arange(-3, 3)  # integers -3, -2, ..., 2 (step defaults to 1)

# Two orderings of the same three functions; chain[0] is applied first.
chain_1, chain_2 = (
    [leaky_relu, square, sigmoid],
    [leaky_relu, sigmoid, square],
)

chain_deriv_3(chain_1, INPUT_RANGE)

array([-0.05809723, -0.03974509, -0.019992  ,  0.        ,  0.39322385,
        0.07065084])

In [25]:
# Run the regression check; it displays nothing on success and raises
# AssertionError (from np.testing.assert_allclose) if the values drift.
# NOTE(review): this cell's execution count (25) is lower than the cell that
# defines chain_deriv_3 (46) — confirm it still passes under Restart & Run All.
test_chain_deriv_3()

In [94]:
# Same inputs as the chain_deriv_3 cell, repeated here presumably so this
# cell can be run on its own.
INPUT_RANGE = np.arange(-3, 3)

chain_1, chain_2 = (
    [leaky_relu, square, sigmoid],
    [leaky_relu, sigmoid, square],
)

def derivative_chain_rule(chain: List,
                          input_range: ndarray) -> ndarray:
    '''
    Derivative of the composition of an arbitrary number of functions.

    For chain = [f1, f2, ..., fn] (f1 applied first, i.e. a ``Chain``) this
    evaluates, at each point of input_range,

        (fn(...f2(f1(x))))' = fn'(u_{n-1}) * ... * f2'(u_1) * f1'(x)

    where u_k = fk(...f1(x)) is the value passed into f_{k+1}. Each local
    derivative comes from the ``derivative`` helper imported from src.math.

    Parameters
    ----------
    chain : list of callables
        The functions to compose, innermost first. Any iterable is accepted;
        it is materialised as a list.
    input_range : ndarray
        Points at which the derivative is evaluated.

    Returns
    -------
    ndarray
        Pointwise product of the local derivatives. For a single-function
        chain this is just that function's derivative at input_range.

    Raises
    ------
    ValueError
        If chain is empty.
    '''
    funcs = list(chain)
    if not funcs:
        raise ValueError(
            "derivative_chain_rule requires at least one function in 'chain'")

    # Forward pass: args[k] is the argument funcs[k] is applied to. The
    # outermost function's own output is not needed, so it is not evaluated.
    args = [input_range]
    for func in funcs[:-1]:
        args.append(func(args[-1]))

    # Local derivative of each function at its own argument.
    local_derivs = [derivative(func, arg) for func, arg in zip(funcs, args)]

    # Multiply outermost-first (fn' * f(n-1)' * ... * f1'); a one-function
    # chain needs no multiplication at all.
    product = local_derivs[-1]
    for local in reversed(local_derivs[:-1]):
        product = np.multiply(product, local)
    return product

# Evaluate the general implementation on the first chain and display it.
x = derivative_chain_rule(chain=chain_1, input_range=INPUT_RANGE)
x


array([-0.05809723, -0.03974509, -0.019992  ,  0.        ,  0.39322385,
        0.07065084])

In [95]:
# Cross-check the general implementation against the hand-written
# three-function version on both orderings defined above (chain_2 was
# previously defined but never exercised).
for chain_idx, chain_under_test in enumerate((chain_1, chain_2), start=1):
    np.testing.assert_allclose(
        derivative_chain_rule(chain_under_test, INPUT_RANGE),
        chain_deriv_3(chain_under_test, INPUT_RANGE),
        rtol=1e-6, atol=0,
        err_msg=f"derivative_chain_rule disagrees with chain_deriv_3 for chain_{chain_idx}",
    )

In [None]:
def countdown(n, chain):
    '''
    Recursion scratch helper: print n, then n - 1, and so on, stopping after
    the first printed value that is not greater than zero.

    ``chain`` is only forwarded to the recursive call; it is otherwise unused.
    '''
    print(n)
    if not n > 0:
        return
    countdown(n - 1, chain)

def create_func(chain: List) -> Callable[[ndarray], ndarray]:
    '''
    Compose the functions in ``chain`` into a single callable.

    ``create_func([f1, f2, f3])(x)`` returns ``f3(f2(f1(x)))``: chain[0] is
    applied first, the same convention chain_deriv_3 uses. An empty chain
    yields the identity function.
    '''
    funcs = list(chain)

    def composed(input_range: ndarray) -> ndarray:
        value = input_range
        for func in funcs:
            value = func(value)
        return value

    return composed


def derivative_chain_rule_recursive(chain: List,
                                    input_range: ndarray) -> ndarray:
    '''
    Recursive form of the chain rule for an arbitrary chain.

    Splits chain = [*inner, outer] and applies

        (outer o g)'(x) = outer'(g(x)) * g'(x),   with g = create_func(inner)

    recursing on ``inner`` until a single function remains. Kept under its own
    name so the iterative ``derivative_chain_rule`` is not shadowed.

    Parameters
    ----------
    chain : list of callables
        The functions to compose, innermost first.
    input_range : ndarray
        Points at which the derivative is evaluated.

    Returns
    -------
    ndarray
        Pointwise product of the local derivatives.

    Raises
    ------
    ValueError
        If chain is empty.

    Note: g(x) is recomputed at every recursion level, so an n-function chain
    costs O(n^2) function evaluations; fine for short chains.
    '''
    funcs = list(chain)
    if not funcs:
        raise ValueError(
            "derivative_chain_rule_recursive requires at least one function in 'chain'")

    *inner, outer = funcs
    if not inner:
        # Base case: a single function is differentiated directly at x.
        return derivative(outer, input_range)

    inner_of_x = create_func(inner)(input_range)
    return np.multiply(derivative(outer, inner_of_x),
                       derivative_chain_rule_recursive(inner, input_range))

# Re-evaluate the iterative chain-rule implementation on the first chain.
x = derivative_chain_rule(chain=chain_1, input_range=INPUT_RANGE)

In [93]:
# Scratch check: the full convolution of 1..5 with a length-5 kernel of ones
# gives the sum of each overlapping window (the output has 5 + 5 - 1 entries).
np.convolve(np.arange(1, 6), np.ones(5, dtype=int))

array([ 1,  3,  6, 10, 15, 14, 12,  9,  5])