In [1]:
import heapq
import weakref
import functools
from numbers import Number
from collections import namedtuple
from collections.abc import Iterable

import numpy as np
from typing import NoReturn

In [2]:
class Variable:
    """Node of the computational graph: an ndarray plus its gradient.

    Attributes:
        data: Wrapped ``np.ndarray`` (or ``None``).
        grad: Gradient filled in by ``backward``; ``None`` until computed.
        creator: Function that produced this variable; ``None`` for leaves.
        generation: 0 for leaves, otherwise the creator's generation + 1.
    """

    def __init__(self, data):
        """Wrap ``data``.

        Raises:
            TypeError: if ``data`` is neither ``None`` nor an ``np.ndarray``.
        """
        if data is not None and not isinstance(data, np.ndarray):
            raise TypeError(f"{type(data)} is not supported")
        self.data = data
        self.grad = None
        self.creator = None
        self.generation = 0

    def set_creator(self, function) -> None:
        """Record ``function`` as the producer of this variable.

        The variable is placed one generation after ``function``.
        """
        self.creator = function
        self.generation = function.generation + 1

    def backward(self) -> None:
        """Propagate gradients from this variable back to its inputs.

        Functions are processed from the highest generation to the lowest,
        and the gradient of every visited variable is kept.

        Raises:
            AttributeError: if this variable has no creator, i.e. there is
                no recorded graph to walk. Nothing is modified in that case.
        """
        if self.creator is None:
            # Fail before touching self.grad, with a message that names the
            # real cause instead of an incidental NoneType attribute error.
            raise AttributeError(
                "cannot call backward() on a Variable without a creator"
            )

        if self.grad is None:
            self.grad = np.ones_like(self.data)

        # PrioritySet pops the function with the largest generation first and
        # ignores functions that are already queued.
        functions_list = PrioritySet()([self.creator])

        while functions_list:
            function = functions_list.pop()
            # function.outputs holds weak references; call them to get the
            # Variable objects.
            gys = [output().grad for output in function.outputs]
            gxs = function.backward(*gys)
            if not isinstance(gxs, tuple):
                gxs = (gxs,)

            for x, gx in zip(function.inputs, gxs):
                if x.grad is None:
                    x.grad = gx
                else:
                    # Build a new array rather than using +=: x.grad may be
                    # the very object stored as another variable's grad
                    # (Add.backward returns its gy unchanged).
                    x.grad = x.grad + gx

                if x.creator is not None:
                    functions_list.add(PriorityItem(x.creator))

    def cleargrad(self) -> None:
        """Reset the stored gradient so the variable can be reused."""
        self.grad = None


class Function:
    """Base class for differentiable operations.

    Subclasses implement ``forward`` (on the inputs' raw data) and
    ``backward`` (on upstream gradients). Instances order and compare by
    ``generation``, which is assigned when the function is called.

    Note: because ``__eq__`` is overridden without ``__hash__``, instances
    are unhashable; ``PrioritySet`` tracks them by ``id()``.
    """

    def __call__(self, *inputs):
        """Run ``forward`` on the inputs' data and record the graph links.

        Args:
            *inputs: Variables whose ``data`` is passed to ``forward``.

        Returns:
            A single Variable, or a list of Variables when ``forward``
            returns a tuple with more than one element.
        """
        xs = [x.data for x in inputs]
        ys = self.forward(*xs)
        if not isinstance(ys, tuple):
            ys = (ys,)
        outputs = [Variable(as_array(y)) for y in ys]

        self.generation = max(x.generation for x in inputs)
        for output in outputs:
            output.set_creator(self)

        self.inputs = inputs
        # Outputs already reference this function through ``creator``, so
        # only weak references are stored in the other direction.
        self.outputs = [weakref.ref(output) for output in outputs]
        return outputs if len(outputs) > 1 else outputs[0]

    def __lt__(self, other):
        """Order by generation; defer to Python for non-Function operands."""
        if not isinstance(other, Function):
            return NotImplemented
        return self.generation < other.generation

    def __eq__(self, other):
        """Equal when generations match; defer for non-Function operands."""
        if not isinstance(other, Function):
            return NotImplemented
        return self.generation == other.generation

    def forward(self, xs):
        """Compute the outputs from raw input data (subclass hook)."""
        raise NotImplementedError()

    def backward(self, gys):
        """Compute input gradients from output gradients (subclass hook)."""
        raise NotImplementedError()


class Add(Function):
    """Sum of two inputs."""

    def forward(self, x0: Number, x1: Number) -> Number:
        """Return ``x0 + x1``."""
        return x0 + x1

    def backward(self, gy: Number) -> tuple:
        """Hand the upstream gradient ``gy`` unchanged to both inputs."""
        return (gy, gy)


class Square(Function):
    """Squares its input: ``y = x ** 2``."""

    def forward(self, x: Number) -> Number:
        """Return ``x ** 2``."""
        return x ** 2

    def backward(self, gy: Number) -> Number:
        """Return ``2 * x * gy`` using the data of the recorded input."""
        source = self.inputs[0].data
        return 2 * source * gy
    

class Cube(Function):
    """Cubes its input: ``y = x ** 3``."""

    def forward(self, x: Number) -> Number:
        """Return ``x ** 3``."""
        return x ** 3

    def backward(self, gy: Number) -> Number:
        """Return ``3 * x ** 2 * gy`` using the data of the recorded input."""
        source = self.inputs[0].data
        return 3 * (source ** 2) * gy


class Exp(Function):
    """Exponential of its input: ``y = exp(x)``."""

    def forward(self, x: Number) -> Number:
        """Return ``np.exp(x)``."""
        return np.exp(x)

    def backward(self, gy: Number) -> Number:
        """Return ``exp(x) * gy`` using the data of the recorded input."""
        source = self.inputs[0].data
        return np.exp(source) * gy


def add(x0: Number, x1: Number) -> Number:
    """Apply a fresh ``Add`` function to ``x0`` and ``x1`` and return its result."""
    operation = Add()
    return operation(x0, x1)


def square(x: Number) -> Number:
    """Apply a fresh ``Square`` function to ``x`` and return its result."""
    operation = Square()
    return operation(x)


def cube(x: Number) -> Number:
    """Apply a fresh ``Cube`` function to ``x`` and return its result."""
    operation = Cube()
    return operation(x)


def exp(x: Number) -> Number:
    """Apply a fresh ``Exp`` function to ``x`` and return its result."""
    operation = Exp()
    return operation(x)


def as_array(x) -> np.ndarray:
    """Wrap values accepted by ``np.isscalar`` in an ndarray; return others as is."""
    return np.array(x) if np.isscalar(x) else x


@functools.total_ordering
class PriorityItem(object):
    """Heap wrapper that reverses the ordering of the wrapped object.

    ``heapq`` is a min-heap; inverting ``<`` makes the wrapped object that
    compares greatest come out first.
    """

    def __init__(self, obj):
        self.obj = obj

    def __lt__(self, other):
        """Return True when the wrapped object is greater than ``other``'s.

        Returns ``NotImplemented`` for non-PriorityItem operands so Python
        raises ``TypeError`` instead of producing a misleading ordering
        through the ``total_ordering``-derived operators.
        """
        if not isinstance(other, PriorityItem):
            return NotImplemented
        return self.obj > other.obj
    

class PrioritySet(object):
    Item = namedtuple("Item", ["priority", "item"])
    
    def __call__(self, queue=None):
        self.maxheap = []
        self.heapset = set()
        if queue is None:
            return self
        
        if not isinstance(queue, Iterable):
            print(f"{type(queue)} is not iterable")
            return None
        
        queue = map(PriorityItem, queue)
        for reverse_obj in queue:
            self.add(reverse_obj)
        return self

    def add(self, x: PriorityItem) -> None:
        if id(x.obj) not in self.heapset:
            heapq.heappush(self.maxheap, x)
            self.heapset.add(id(x.obj))
        
    def pop(self):
        x = heapq.heappop(self.maxheap)
        self.heapset.remove(id(x.obj))
        return x.obj
    
    def __len__(self):
        return len(self.maxheap)


In [3]:
# Graph: t = x0 + x1 and y = x0 + t, so x0 feeds two Add nodes and its
# gradient is accumulated from both paths.
x0 = Variable(np.array(1.0))
x1 = Variable(np.array(1.0))
t = add(x0, x1)
y = add(x0, t)
y.backward()

# With this version of Variable every node keeps its gradient after backward().
print(y.grad, t.grad)
print(x0.grad, x1.grad)

1.0 1.0
2.0 1.0


# delete temporary gradients

In [6]:
class Variable:
    """Node of the computational graph: an ndarray plus its gradient.

    Attributes:
        data: Wrapped ``np.ndarray`` (or ``None``).
        grad: Gradient filled in by ``backward``; ``None`` until computed.
        creator: Function that produced this variable; ``None`` for leaves.
        generation: 0 for leaves, otherwise the creator's generation + 1.
    """

    def __init__(self, data):
        """Wrap ``data``.

        Raises:
            TypeError: if ``data`` is neither ``None`` nor an ``np.ndarray``.
        """
        if data is not None and not isinstance(data, np.ndarray):
            raise TypeError(f"{type(data)} is not supported")

        self.data = data
        self.grad = None
        self.creator = None
        self.generation = 0

    def set_creator(self, function) -> None:
        """Record ``function`` as the producer of this variable.

        The variable is placed one generation after ``function``.
        """
        self.creator = function
        self.generation = function.generation + 1

    def backward(self, retain_grad=False) -> None:
        """Propagate gradients from this variable back to its inputs.

        Functions are processed from the highest generation to the lowest.

        Args:
            retain_grad: When False (default), the gradient of every
                function output visited (including this variable) is reset
                to ``None`` once it has been consumed, so only the inputs
                that no function produced keep a gradient.

        Raises:
            AttributeError: if this variable has no creator, i.e. there is
                no recorded graph to walk. Nothing is modified in that case.
        """
        if self.creator is None:
            # Fail before touching self.grad, with a message that names the
            # real cause instead of an incidental NoneType attribute error.
            raise AttributeError(
                "cannot call backward() on a Variable without a creator"
            )

        if self.grad is None:
            self.grad = np.ones_like(self.data)

        # PrioritySet pops the function with the largest generation first and
        # ignores functions that are already queued.
        functions_list = PrioritySet()([self.creator])

        while functions_list:
            function = functions_list.pop()

            # function.outputs holds weak references; call them to get the
            # Variable objects.
            gys = [output().grad for output in function.outputs]
            gxs = function.backward(*gys)

            if not isinstance(gxs, tuple):
                gxs = (gxs,)

            for x, gx in zip(function.inputs, gxs):
                if x.grad is None:
                    x.grad = gx
                else:
                    # Build a new array rather than using +=: x.grad may be
                    # the very object stored as another variable's grad
                    # (Add.backward returns its gy unchanged).
                    x.grad = x.grad + gx

                if x.creator is not None:
                    functions_list.add(PriorityItem(x.creator))

            if not retain_grad:
                # The outputs' gradients have been pushed to the inputs and
                # are no longer needed.
                for y in function.outputs:
                    y().grad = None

    def cleargrad(self) -> None:
        """Reset the stored gradient so the variable can be reused."""
        self.grad = None


class Function:
    """Base class for differentiable operations.

    Subclasses implement ``forward`` (on the inputs' raw data) and
    ``backward`` (on upstream gradients). Instances order and compare by
    ``generation``, which is assigned when the function is called.

    Note: because ``__eq__`` is overridden without ``__hash__``, instances
    are unhashable; ``PrioritySet`` tracks them by ``id()``.
    """

    def __call__(self, *inputs):
        """Run ``forward`` on the inputs' data and record the graph links.

        Args:
            *inputs: Variables whose ``data`` is passed to ``forward``.

        Returns:
            A single Variable, or a list of Variables when ``forward``
            returns a tuple with more than one element.
        """
        xs = [x.data for x in inputs]
        ys = self.forward(*xs)
        if not isinstance(ys, tuple):
            ys = (ys,)
        outputs = [Variable(as_array(y)) for y in ys]

        self.generation = max(x.generation for x in inputs)
        for output in outputs:
            output.set_creator(self)

        self.inputs = inputs
        # Outputs already reference this function through ``creator``, so
        # only weak references are stored in the other direction.
        self.outputs = [weakref.ref(output) for output in outputs]
        return outputs if len(outputs) > 1 else outputs[0]

    def __lt__(self, other):
        """Order by generation; defer to Python for non-Function operands."""
        if not isinstance(other, Function):
            return NotImplemented
        return self.generation < other.generation

    def __eq__(self, other):
        """Equal when generations match; defer for non-Function operands."""
        if not isinstance(other, Function):
            return NotImplemented
        return self.generation == other.generation

    def forward(self, xs):
        """Compute the outputs from raw input data (subclass hook)."""
        raise NotImplementedError()

    def backward(self, gys):
        """Compute input gradients from output gradients (subclass hook)."""
        raise NotImplementedError()


class Add(Function):
    """Sum of two inputs."""

    def forward(self, x0: Number, x1: Number) -> Number:
        """Return ``x0 + x1``."""
        return x0 + x1

    def backward(self, gy: Number) -> tuple:
        """Hand the upstream gradient ``gy`` unchanged to both inputs."""
        return (gy, gy)


class Square(Function):
    """Squares its input: ``y = x ** 2``."""

    def forward(self, x: Number) -> Number:
        """Return ``x ** 2``."""
        return x ** 2

    def backward(self, gy: Number) -> Number:
        """Return ``2 * x * gy`` using the data of the recorded input."""
        source = self.inputs[0].data
        return 2 * source * gy
    

class Cube(Function):
    """Cubes its input: ``y = x ** 3``."""

    def forward(self, x: Number) -> Number:
        """Return ``x ** 3``."""
        return x ** 3

    def backward(self, gy: Number) -> Number:
        """Return ``3 * x ** 2 * gy`` using the data of the recorded input."""
        source = self.inputs[0].data
        return 3 * (source ** 2) * gy


class Exp(Function):
    """Exponential of its input: ``y = exp(x)``."""

    def forward(self, x: Number) -> Number:
        """Return ``np.exp(x)``."""
        return np.exp(x)

    def backward(self, gy: Number) -> Number:
        """Return ``exp(x) * gy`` using the data of the recorded input."""
        source = self.inputs[0].data
        return np.exp(source) * gy


def add(x0: Number, x1: Number) -> Number:
    """Apply a fresh ``Add`` function to ``x0`` and ``x1`` and return its result."""
    operation = Add()
    return operation(x0, x1)


def square(x: Number) -> Number:
    """Apply a fresh ``Square`` function to ``x`` and return its result."""
    operation = Square()
    return operation(x)


def cube(x: Number) -> Number:
    """Apply a fresh ``Cube`` function to ``x`` and return its result."""
    operation = Cube()
    return operation(x)


def exp(x: Number) -> Number:
    """Apply a fresh ``Exp`` function to ``x`` and return its result."""
    operation = Exp()
    return operation(x)


def as_array(x) -> np.ndarray:
    """Wrap values accepted by ``np.isscalar`` in an ndarray; return others as is."""
    return np.array(x) if np.isscalar(x) else x


@functools.total_ordering
class PriorityItem(object):
    """Heap wrapper that reverses the ordering of the wrapped object.

    ``heapq`` is a min-heap; inverting ``<`` makes the wrapped object that
    compares greatest come out first.
    """

    def __init__(self, obj):
        self.obj = obj

    def __lt__(self, other):
        """Return True when the wrapped object is greater than ``other``'s.

        Returns ``NotImplemented`` for non-PriorityItem operands so Python
        raises ``TypeError`` instead of producing a misleading ordering
        through the ``total_ordering``-derived operators.
        """
        if not isinstance(other, PriorityItem):
            return NotImplemented
        return self.obj > other.obj
    

class PrioritySet(object):
    Item = namedtuple("Item", ["priority", "item"])
    
    def __call__(self, queue=None):
        self.maxheap = []
        self.heapset = set()
        if queue is None:
            return self
        
        if not isinstance(queue, Iterable):
            print(f"{type(queue)} is not iterable")
            return None
        
        queue = map(PriorityItem, queue)
        for reverse_obj in queue:
            self.add(reverse_obj)
        return self

    def add(self, x: PriorityItem) -> None:
        if id(x.obj) not in self.heapset:
            heapq.heappush(self.maxheap, x)
            self.heapset.add(id(x.obj))
        
    def pop(self):
        x = heapq.heappop(self.maxheap)
        self.heapset.remove(id(x.obj))
        return x.obj
    
    def __len__(self):
        return len(self.maxheap)


In [7]:
# Same graph as before: t = x0 + x1 and y = x0 + t.
x0 = Variable(np.array(1.0))
x1 = Variable(np.array(1.0))
t = add(x0, x1)
y = add(x0, t)
# retain_grad defaults to False, so gradients of function outputs are dropped.
y.backward()

# y and t were produced by functions, so their grads are None afterwards;
# the inputs x0 and x1 keep theirs.
print(y.grad, t.grad)
print(x0.grad, x1.grad)

None None
2.0 1.0


# Config class

In [8]:
class Config:
    """Global switches read by the autograd classes."""

    # When True, Function.__call__ records the creator/inputs/outputs links
    # that backward() needs.
    enable_backprop = True

In [9]:
class Variable:
    """Node of the computational graph: an ndarray plus its gradient.

    Attributes:
        data: Wrapped ``np.ndarray`` (or ``None``).
        grad: Gradient filled in by ``backward``; ``None`` until computed.
        creator: Function that produced this variable; ``None`` for leaves.
        generation: 0 for leaves, otherwise the creator's generation + 1.
    """

    def __init__(self, data):
        """Wrap ``data``.

        Raises:
            TypeError: if ``data`` is neither ``None`` nor an ``np.ndarray``.
        """
        if data is not None and not isinstance(data, np.ndarray):
            raise TypeError(f"{type(data)} is not supported")

        self.data = data
        self.grad = None
        self.creator = None
        self.generation = 0

    def set_creator(self, function) -> None:
        """Record ``function`` as the producer of this variable.

        The variable is placed one generation after ``function``.
        """
        self.creator = function
        self.generation = function.generation + 1

    def backward(self, retain_grad=False) -> None:
        """Propagate gradients from this variable back to its inputs.

        Functions are processed from the highest generation to the lowest.

        Args:
            retain_grad: When False (default), the gradient of every
                function output visited (including this variable) is reset
                to ``None`` once it has been consumed, so only the inputs
                that no function produced keep a gradient.

        Raises:
            AttributeError: if this variable has no creator, i.e. there is
                no recorded graph to walk (a leaf, or a result computed
                while no creator was recorded). Nothing is modified in
                that case.
        """
        if self.creator is None:
            # Fail before touching self.grad, with a message that names the
            # real cause instead of an incidental NoneType attribute error.
            raise AttributeError(
                "cannot call backward() on a Variable without a creator"
            )

        if self.grad is None:
            self.grad = np.ones_like(self.data)

        # PrioritySet pops the function with the largest generation first and
        # ignores functions that are already queued.
        functions_list = PrioritySet()([self.creator])

        while functions_list:
            function = functions_list.pop()

            # function.outputs holds weak references; call them to get the
            # Variable objects.
            gys = [output().grad for output in function.outputs]
            gxs = function.backward(*gys)

            if not isinstance(gxs, tuple):
                gxs = (gxs,)

            for x, gx in zip(function.inputs, gxs):
                if x.grad is None:
                    x.grad = gx
                else:
                    # Build a new array rather than using +=: x.grad may be
                    # the very object stored as another variable's grad
                    # (Add.backward returns its gy unchanged).
                    x.grad = x.grad + gx

                if x.creator is not None:
                    functions_list.add(PriorityItem(x.creator))

            if not retain_grad:
                # The outputs' gradients have been pushed to the inputs and
                # are no longer needed.
                for y in function.outputs:
                    y().grad = None

    def cleargrad(self) -> None:
        """Reset the stored gradient so the variable can be reused."""
        self.grad = None


class Function:
    """Base class for differentiable operations.

    Subclasses implement ``forward`` (on the inputs' raw data) and
    ``backward`` (on upstream gradients). Instances order and compare by
    ``generation``, which is assigned when the function is called with
    ``Config.enable_backprop`` set.

    Note: because ``__eq__`` is overridden without ``__hash__``, instances
    are unhashable; ``PrioritySet`` tracks them by ``id()``.
    """

    def __call__(self, *inputs):
        """Run ``forward`` on the inputs' data and wrap the results.

        Graph links (generation, creator, inputs, weakly referenced outputs)
        are recorded only while ``Config.enable_backprop`` is true.

        Args:
            *inputs: Variables whose ``data`` is passed to ``forward``.

        Returns:
            A single Variable, or a list of Variables when ``forward``
            returns a tuple with more than one element.
        """
        xs = [x.data for x in inputs]
        ys = self.forward(*xs)
        if not isinstance(ys, tuple):
            ys = (ys,)
        outputs = [Variable(as_array(y)) for y in ys]

        if Config.enable_backprop:
            self.generation = max(x.generation for x in inputs)
            for output in outputs:
                output.set_creator(self)

            self.inputs = inputs
            # Outputs already reference this function through ``creator``,
            # so only weak references are stored in the other direction.
            self.outputs = [weakref.ref(output) for output in outputs]

        return outputs if len(outputs) > 1 else outputs[0]

    def __lt__(self, other):
        """Order by generation; defer to Python for non-Function operands."""
        if not isinstance(other, Function):
            return NotImplemented
        return self.generation < other.generation

    def __eq__(self, other):
        """Equal when generations match; defer for non-Function operands."""
        if not isinstance(other, Function):
            return NotImplemented
        return self.generation == other.generation

    def forward(self, xs):
        """Compute the outputs from raw input data (subclass hook)."""
        raise NotImplementedError()

    def backward(self, gys):
        """Compute input gradients from output gradients (subclass hook)."""
        raise NotImplementedError()


class Add(Function):
    """Sum of two inputs."""

    def forward(self, x0: Number, x1: Number) -> Number:
        """Return ``x0 + x1``."""
        return x0 + x1

    def backward(self, gy: Number) -> tuple:
        """Hand the upstream gradient ``gy`` unchanged to both inputs."""
        return (gy, gy)


class Square(Function):
    """Squares its input: ``y = x ** 2``."""

    def forward(self, x: Number) -> Number:
        """Return ``x ** 2``."""
        return x ** 2

    def backward(self, gy: Number) -> Number:
        """Return ``2 * x * gy`` using the data of the recorded input."""
        source = self.inputs[0].data
        return 2 * source * gy
    

class Cube(Function):
    """Cubes its input: ``y = x ** 3``."""

    def forward(self, x: Number) -> Number:
        """Return ``x ** 3``."""
        return x ** 3

    def backward(self, gy: Number) -> Number:
        """Return ``3 * x ** 2 * gy`` using the data of the recorded input."""
        source = self.inputs[0].data
        return 3 * (source ** 2) * gy


class Exp(Function):
    """Exponential of its input: ``y = exp(x)``."""

    def forward(self, x: Number) -> Number:
        """Return ``np.exp(x)``."""
        return np.exp(x)

    def backward(self, gy: Number) -> Number:
        """Return ``exp(x) * gy`` using the data of the recorded input."""
        source = self.inputs[0].data
        return np.exp(source) * gy


def add(x0: Number, x1: Number) -> Number:
    """Apply a fresh ``Add`` function to ``x0`` and ``x1`` and return its result."""
    operation = Add()
    return operation(x0, x1)


def square(x: Number) -> Number:
    """Apply a fresh ``Square`` function to ``x`` and return its result."""
    operation = Square()
    return operation(x)


def cube(x: Number) -> Number:
    """Apply a fresh ``Cube`` function to ``x`` and return its result."""
    operation = Cube()
    return operation(x)


def exp(x: Number) -> Number:
    """Apply a fresh ``Exp`` function to ``x`` and return its result."""
    operation = Exp()
    return operation(x)


def as_array(x) -> np.ndarray:
    """Wrap values accepted by ``np.isscalar`` in an ndarray; return others as is."""
    return np.array(x) if np.isscalar(x) else x


@functools.total_ordering
class PriorityItem(object):
    """Heap wrapper that reverses the ordering of the wrapped object.

    ``heapq`` is a min-heap; inverting ``<`` makes the wrapped object that
    compares greatest come out first.
    """

    def __init__(self, obj):
        self.obj = obj

    def __lt__(self, other):
        """Return True when the wrapped object is greater than ``other``'s.

        Returns ``NotImplemented`` for non-PriorityItem operands so Python
        raises ``TypeError`` instead of producing a misleading ordering
        through the ``total_ordering``-derived operators.
        """
        if not isinstance(other, PriorityItem):
            return NotImplemented
        return self.obj > other.obj
    

class PrioritySet(object):
    Item = namedtuple("Item", ["priority", "item"])
    
    def __call__(self, queue=None):
        self.maxheap = []
        self.heapset = set()
        if queue is None:
            return self
        
        if not isinstance(queue, Iterable):
            print(f"{type(queue)} is not iterable")
            return None
        
        queue = map(PriorityItem, queue)
        for reverse_obj in queue:
            self.add(reverse_obj)
        return self

    def add(self, x: PriorityItem) -> None:
        if id(x.obj) not in self.heapset:
            heapq.heappush(self.maxheap, x)
            self.heapset.add(id(x.obj))
        
    def pop(self):
        x = heapq.heappop(self.maxheap)
        self.heapset.remove(id(x.obj))
        return x.obj
    
    def __len__(self):
        return len(self.maxheap)


# switch backpropagation mode

In [10]:
# Backprop enabled: the three chained squares compute y = x ** 8 and the
# graph links needed by backward() are recorded.
Config.enable_backprop = True
x = Variable(np.ones((100, 100, 100)))
y = square(square(square(x)))
y.backward()
# dy/dx = 8 * x ** 7, which is 8 everywhere for an all-ones input.
print(x.grad)

[[[8. 8. 8. ... 8. 8. 8.]
  [8. 8. 8. ... 8. 8. 8.]
  [8. 8. 8. ... 8. 8. 8.]
  ...
  [8. 8. 8. ... 8. 8. 8.]
  [8. 8. 8. ... 8. 8. 8.]
  [8. 8. 8. ... 8. 8. 8.]]

 [[8. 8. 8. ... 8. 8. 8.]
  [8. 8. 8. ... 8. 8. 8.]
  [8. 8. 8. ... 8. 8. 8.]
  ...
  [8. 8. 8. ... 8. 8. 8.]
  [8. 8. 8. ... 8. 8. 8.]
  [8. 8. 8. ... 8. 8. 8.]]

 [[8. 8. 8. ... 8. 8. 8.]
  [8. 8. 8. ... 8. 8. 8.]
  [8. 8. 8. ... 8. 8. 8.]
  ...
  [8. 8. 8. ... 8. 8. 8.]
  [8. 8. 8. ... 8. 8. 8.]
  [8. 8. 8. ... 8. 8. 8.]]

 ...

 [[8. 8. 8. ... 8. 8. 8.]
  [8. 8. 8. ... 8. 8. 8.]
  [8. 8. 8. ... 8. 8. 8.]
  ...
  [8. 8. 8. ... 8. 8. 8.]
  [8. 8. 8. ... 8. 8. 8.]
  [8. 8. 8. ... 8. 8. 8.]]

 [[8. 8. 8. ... 8. 8. 8.]
  [8. 8. 8. ... 8. 8. 8.]
  [8. 8. 8. ... 8. 8. 8.]
  ...
  [8. 8. 8. ... 8. 8. 8.]
  [8. 8. 8. ... 8. 8. 8.]
  [8. 8. 8. ... 8. 8. 8.]]

 [[8. 8. 8. ... 8. 8. 8.]
  [8. 8. 8. ... 8. 8. 8.]
  [8. 8. 8. ... 8. 8. 8.]
  ...
  [8. 8. 8. ... 8. 8. 8.]
  [8. 8. 8. ... 8. 8. 8.]
  [8. 8. 8. ... 8. 8. 8.]]]


In [13]:
# Backprop disabled: Function.__call__ skips set_creator, so y.creator stays
# None and y.backward() cannot walk a graph.
# NOTE(review): the flag is left at False after this cell; later cells that
# need gradients must set it back to True.
Config.enable_backprop = False
x = Variable(np.ones((100, 100, 100)))
y = square(square(square(x)))
try:
    y.backward()
    print(x.grad)
except AttributeError:
    # backward() raises AttributeError when the variable has no creator.
    print("unable to execute y.backward()")
    print("backpropogation mode is off")

unable to execute y.backward()
backpropogation mode is off


# use 'with' statement

In [14]:
import contextlib

@contextlib.contextmanager
def using_config(name, value):
    """Temporarily set ``Config.<name>`` to ``value`` inside a ``with`` block.

    The value that was in place on entry is written back on exit, including
    when the body raises.
    """
    previous = getattr(Config, name)
    setattr(Config, name, value)
    try:
        yield
    finally:
        # Always restore, so an exception cannot leave the override active.
        setattr(Config, name, previous)

In [15]:
# Inside the block enable_backprop is False, so square() records no graph
# links; using_config puts the previous value back on exit.
with using_config('enable_backprop', False):
    x = Variable(np.array(2.0))
    y = square(x)

In [16]:
def no_grad():
    """Return a context manager that sets ``Config.enable_backprop`` to False while active."""
    return using_config(name='enable_backprop', value=False)

# Shorthand for using_config('enable_backprop', False): forward-only computation.
with no_grad():
    x = Variable(np.array(2.0))
    y = square(x)