# TODO

1. Proste API które można wykorzystać do liczenia automatycznego gradientu
2. Zaimplementowanie za pomocą tego API aproksymacji jakiejś prostej funkcji / regresji liniowej
3. Zaimplementowanie za pomocą tego API sieci neuronowej 
4. Zaimplementowanie tego samego za pomocą API autogradu
5. Dodanie ulepszeń naszego API w taki sposób, żeby było bardziej podobne do tego od autogradu

Based on this [great talk](http://videolectures.net/deeplearning2017_johnson_automatic_differentiation/).

Okay. 
So basically, the goal for this notebook would be to create a a code for which a gradient can be calculated by applying a simple rule. 

The logic would be the following: 
1. Create a Node containing information about parents and about current function
2. Wrap every numpy function in wrapper function that would recieve an input, unbox it, calculate its value, and box it again




In [1]:
# GOAL
from autograd import grad
import autograd.numpy as anp
import numpy as np

def fun(X, Y):
    return anp.sum(X) * anp.sum(Y)

X = np.array([1.0,2.0,3.0])
Y = np.array([3.0,4.0,5.0])

grad_fun = grad(fun)

print(f'Fun: {fun(X, Y)}, deriv: {grad_fun(X, Y)}')


print(g(X, Y))

Fun: 72.0, deriv: [12. 12. 12.]


NameError: name 'g' is not defined

In [None]:
from __future__ import annotations
from numpy.typing import NDArray
from typing import Callable, Iterable


class Node:
    def __init__(self, value: NDArray, parents: Iterable[Node], primitive) -> None:
        """
        :param value: value associated with the node
        :param parents: parents of the Node
        :param primitive: function by which this node was produced
        """
        self.value = value
        self.parents = parents
        self.primitive = primitive

    

In the below cell we define the primitive class which will work as a wrapper for all operations that we can possibly perform and will help us build a computational graph.

In [None]:
from typing import Any
import numpy as np
from abc import ABC, abstractmethod
from functools import partial
from typing import Callable
from numpy.typing import NDArray
import operator as op


class Primitive(ABC):
    def __init__(self, fun: Callable[[NDArray], NDArray], argcount: int) -> None:
        self.fun = fun
        self._argc = argcount

    def __call__(self, *args: Node) -> Node:
        values = [node.value for node in args] 
        print(values)
        primitive_results = self.fun(*values)
        return Node(primitive_results, args, self)

    def grad(self, argnum) -> Callable:
        return partial(self._gradfun, argnum=argnum)

    @abstractmethod
    def _gradfun(self, *args: Node, argnum):
        ...

class sum(Primitive):
    def __init__(self) -> None:
        super().__init__(np.sum, 1)

    def _gradfun(self, *args: Node, argnum):
        arg = args[argnum].value
        return np.ones(arg.shape)

class product(Primitive):
    def __init__(self) -> None:
        super().__init__(np.prod, 1)

    
    def _gradfun(self, *args: Node, argnum):
        factors = [a for i, a in enumerate(args) if i != argnum] # Derivative of a product is simply a factor before the product
        return np.prod(factors, axis=1)
        


class value(Primitive):
    def __init__(self) -> None:
        super().__init__(lambda x: x, 1)
    
    def _gradfun(self, *args: Node, argnum):
        arg = args[argnum]
        return np.zeros_like(arg)

In [None]:

import operator

def toposort(end_node, parents=operator.attrgetter('parents')):
    child_counts = {}
    stack = [end_node]
    while stack:
        node = stack.pop()
        if node in child_counts:
            child_counts[node] += 1
        else:
            child_counts[node] = 1
            stack.extend(parents(node))

    childless_nodes = [end_node]
    while childless_nodes:
        node = childless_nodes.pop()
        yield node
        for parent in parents(node):
            if child_counts[parent] == 1:
                childless_nodes.append(parent)
            else:
                child_counts[parent] -= 1

In [None]:
def forward_pass(fun: Primitive, args):
    start = Node(args, [], value)
    end = fun(*args)
    return start, end

In [None]:
from typing import cast 

def backward_pass(gradient, end_node: Node):
    """
    Retursn final grad and all intermediate grads
    """
    outgrads = { end_node: gradient }
    for node in toposort(end_node): # for all nodes coming into this node
        outgrad = outgrads[node]
        parents = node.parents
        for p in parents:
            primitive = cast(Primitive, p.primitive)
            outgrads[p] = [
                outgrad @ primitive.grad(argnum=i) 
                for i in range(primitive.argnum)]
    return outgrad



def grad(fun: Primitive, argnum=0):
    """
    :param fun: primitive to be differentiated
    :param argnum: number of arguments with respect to which 
    derivatives should be calculated
    """
    def gradient_function(*args):
        """
        returns change vector (jacobian) for every argument 
        in function for provided values
        """
        _, end = forward_pass(fun, args)
        return lambda grad_value: backward_pass(grad_value, end)
        
    return gradient_function



In [None]:
# Nie komplikujemy. Skoro działamy na kombinacji funkcji, to działajmy na kombinacji funkcji :) 

# sum(X) * sum(X) -> product(sum(), sum())
# Działajmy na węzłach


# Primitive - is just a function, that can have arguments.


X = value()
Y = value()

A = sum(X)
B = sum(Y)

C = product(A, B)




# Funkcja zadaniem której jest wywołanie przekazanego 
# primitive z odpowiednimi argumentami
def forward_pass(primitive: Primitive, *args: NDArray):
    start_nodes = [Node(arg, [], value) for arg in args]
    # TODO:
    # - build a graph
    # - 
    # Primitive - function 
    # Node - operation 
    # Nodes contain primitives

    


# In order to call a primitive, we should: 
# 1. Find the very first primitives in the chain (i.e. one that does not have parents)
# 2. Pass values to them, collect their responses and pass forward
# 3. It actually should be done not in forward pass method, but in call of each primitive
# 4. The most simple primitive would look like 


In [None]:
A = sum()

X = Node(np.array([1,2,3]), [], value)


output = A(X)
print(output.value)


grads = [A.grad(i) for i in range(2)]
print([
    g(X) for g in grads
])

[array([1, 2, 3])]
6


IndexError: tuple index out of range

In [None]:
from __future__ import absolute_import
from __future__ import print_function
from builtins import range
import autograd.numpy as np
from autograd import grad
from autograd.test_util import check_grads


def training_loss(weights):
    return anp.sum(weights)


inputs = np.array([[0.52, 1.12,  0.77]])



training_gradient_fun = grad(training_loss)
training_gradient_fun(weights)


array([1., 1., 1.])