In [54]:
"""
The basis for a framework of automatic differentiation such as PyTorch.

What we want is to:
1. Build a forward pass manually, passing tensors / variables through blocks
2. Make sure each variable remembers which blocks use it (to get its gradients)
"""

import abc
from collections import deque
import functools
import itertools
import math
import numpy as np
import operator
from typing import List

In [63]:
class Variable:
    """Scalar node of the computation graph.

    Attributes (all assigned in ``__init__``):
        value: the scalar payload, coerced to ``float``.
        gradient: ``None`` until ``compute_gradient`` assigns it.
        from_op: the operation that produced this variable, or ``None``.
        requires_grad: whether a gradient is computed for this variable.
        gradient_fcts: operations that consume this variable; each entry
            must expose ``derivative_by(variable)``.
    """

    def __init__(self, value: float, from_op=None, requires_grad=False):
        self.value = float(value)
        self.gradient = None
        self.from_op = from_op
        self.requires_grad = requires_grad
        self.gradient_fcts = []

    def compute_gradient(self):
        """Set ``self.gradient`` from the operations that consume this variable.

        Does nothing unless ``requires_grad`` is true. A variable with no
        consumers is seeded with 1; otherwise the gradient is the sum of
        ``fct.derivative_by(self)`` over every entry of ``gradient_fcts``.
        """
        if self.requires_grad:
            self.gradient = sum(fct.derivative_by(self) for fct in self.gradient_fcts) if self.gradient_fcts else 1

    def backward(self):
        """Compute gradients for every ``requires_grad`` variable reachable from here.

        Nodes are processed in reverse topological order, so each variable is
        handled exactly once and only after every consumer that is reachable
        from ``self`` has received its own gradient. A breadth-first walk does
        not give that guarantee when a variable is consumed at several depths
        of the graph.

        NOTE: ``compute_gradient`` still sums over *all* of a variable's
        ``gradient_fcts``, including operations that are not reachable from
        ``self``.
        """
        # Iterative post-order DFS: a node is appended to `post_order` only
        # after all the arguments of its producing op have been appended.
        post_order = []
        seen = set()
        stack = [(self, False)]
        while stack:
            node, expanded = stack.pop()
            if expanded:
                post_order.append(node)
                continue
            # Same pruning as before: nodes that do not require a gradient
            # are neither computed nor expanded.
            if id(node) in seen or not node.requires_grad:
                continue
            seen.add(id(node))
            stack.append((node, True))
            if node.from_op:
                stack.extend((arg, False) for arg in node.from_op.arguments)

        # Reversed post-order puts every consumer before the variables it reads.
        for node in reversed(post_order):
            node.compute_gradient()

    
class Function(abc.ABC):
    """Abstract operation node linking argument variables to one output variable.

    Subclasses implement ``apply`` (the forward computation on raw values)
    and ``derivative_by`` (the gradient contribution towards one argument).
    """
    # TODO - make it a metaclass
    # TODO - you could make it a Monad in Haskell

    def __init__(self, arguments: List[Variable]):
        # Materialise the arguments: they are iterated several times (forward
        # pass, registration, derivatives), which a one-shot iterable would
        # not survive, and a private copy is unaffected by later mutation of
        # the caller's list.
        self.arguments = list(arguments)
        self.output = None

    def __call__(self) -> Variable:
        """Run the forward pass and return the freshly created output variable.

        The operation is appended to ``gradient_fcts`` of every argument that
        requires a gradient (once per occurrence in ``arguments``), and the
        output requires a gradient as soon as one argument does.
        """
        # Pass a list rather than a generator so `apply` implementations may
        # take `len()` of the values or iterate over them more than once.
        result = self.apply([arg.value for arg in self.arguments])
        self.output = Variable(result, self, requires_grad=False)
        for arg in self.arguments:
            if arg.requires_grad:
                arg.gradient_fcts.append(self)
                self.output.requires_grad = True
        return self.output

    @abc.abstractmethod
    def apply(self, argument_values) -> float:
        """Return the raw result for the given argument values.

        ``__call__`` wraps the returned value in a ``Variable``.
        """

    @abc.abstractmethod
    def derivative_by(self, by: Variable) -> float:
        """Return this operation's gradient contribution towards ``by``."""
    
    
class AddFct(Function):
    """Sum of an arbitrary number of argument variables."""

    def apply(self, argument_values):
        """Left-fold the argument values with ``+``, starting from 0."""
        total = 0
        for value in argument_values:
            total = total + value
        return total

    def derivative_by(self, x: Variable):
        """Return the output's gradient if ``x`` is an argument, else 0.

        The partial derivative of a sum with respect to each term is 1, so
        the output's gradient is forwarded unchanged.
        """
        if x not in self.arguments:
            return 0
        return self.output.gradient


class MultiplyFct(Function):
    """Product of an arbitrary number of argument variables."""

    def apply(self, argument_values):
        """Return the product of the argument values (1 for no arguments)."""
        return functools.reduce(operator.mul, argument_values, 1)

    def derivative_by(self, x: Variable):
        """Return the gradient flowing to one occurrence of ``x`` through this product.

        The local partial derivative (product of the other arguments' values)
        is multiplied by the output's gradient, as the chain rule requires.
        Only the first occurrence of ``x`` is left out of the product: the
        operation is registered on ``x`` once per occurrence, so summing the
        per-occurrence contributions yields e.g. ``2 * x`` for ``x * x``.

        Returns 0 when ``x`` is not an argument of this operation.
        """
        local = 1.
        skipped = False
        for arg in self.arguments:
            # Graph nodes are matched by identity, not by value.
            if arg is x and not skipped:
                skipped = True
                continue
            local *= arg.value
        if not skipped:
            return 0
        return local * self.output.gradient
    
    
def add(v1: Variable, v2: Variable):
    """Build an ``AddFct`` over ``v1`` and ``v2``, run it, and return the result of the call."""
    return AddFct([v1, v2])()


def multiply(v1: Variable, v2: Variable):
    """Build a ``MultiplyFct`` over ``v1`` and ``v2``, run it, and return the result of the call."""
    return MultiplyFct([v1, v2])()


# Demo: z = (x1 + x2) * (x2 + x3) evaluated at x1 = 1, x2 = 2, x3 = 3.
x1, x2, x3 = (Variable(v, requires_grad=True) for v in (1, 2, 3))
y1 = add(x1, x2)
y2 = add(x2, x3)
z = multiply(y1, y2)
z.backward()

# Expected gradients:
#   dz/dx1 = x2 + x3               = 5
#   dz/dx2 = (x2 + x3) + (x1 + x2) = 8
#   dz/dx3 = x1 + x2               = 3
print(x1.gradient, x2.gradient, x3.gradient, sep="\n")

5.0
8.0
3.0
