In [23]:
import functools
import heapq
from numbers import Number
from collections import namedtuple
from collections.abc import Iterable

import numpy as np
from typing import NoReturn

In [8]:
class Variable:
    """A node of the computation graph: an array plus its gradient.

    Attributes:
        data: The wrapped ``np.ndarray`` (``None`` is also accepted).
        grad: Gradient filled in by ``backward``; ``None`` until then.
        creator: The function object that produced this variable, or
            ``None`` when the variable was created directly.
        generation: 0 without a creator, otherwise the creator's
            generation + 1.
    """

    def __init__(self, data):
        """Wrap ``data``.

        Raises:
            TypeError: if ``data`` is neither ``None`` nor an ``np.ndarray``.
        """
        if data is not None and not isinstance(data, np.ndarray):
            raise TypeError(f"{type(data)} is not supported")
        self.data = data
        self.grad = None
        self.creator = None
        self.generation = 0

    def set_creator(self, function) -> None:
        """Record ``function`` as the producer and place this variable one
        generation after it."""
        self.creator = function
        self.generation = function.generation + 1
        return None

    def backward(self) -> None:
        """Backpropagate from this variable through every function reachable
        via ``creator`` links, accumulating gradients into the inputs.

        If ``grad`` is still ``None`` it is seeded with ones shaped like
        ``data``. A variable without a creator has nothing to propagate, so
        the method returns right after seeding.
        """
        if self.grad is None:
            self.grad = np.ones_like(self.data)

        # Leaf variable: there is no function to walk back through.  Without
        # this guard ``None`` would be queued and dereferenced below.
        if self.creator is None:
            return None

        # PriorityItem reverses the generation ordering, so the heap pops the
        # function with the highest generation first.
        functions_list = PrioritySet()([self.creator])

        while functions_list:
            function = functions_list.pop()
            gys = [output.grad for output in function.outputs]
            gxs = function.backward(*gys)
            if not isinstance(gxs, tuple):
                gxs = (gxs,)

            for x, gx in zip(function.inputs, gxs):
                if x.grad is None:
                    x.grad = gx
                else:
                    # Rebind instead of ``+=``: ``x.grad`` may be the very
                    # same array object as another variable's gradient (see
                    # the assignment above), and an in-place add would
                    # corrupt that one as well.
                    x.grad = x.grad + gx

                if x.creator is not None:
                    functions_list.add(PriorityItem(x.creator))
        return None

    def cleargrad(self) -> None:
        """Forget the stored gradient so the variable can be reused."""
        self.grad = None
        return None


class Function:
    """Base class for operations that can be applied to variables and
    differentiated.

    Subclasses implement ``forward`` (raw data in, raw data out) and
    ``backward`` (output gradients in, input gradients out).

    Instances are ordered and compared by ``generation`` only.  Because
    ``__eq__`` is defined without ``__hash__``, instances are unhashable;
    code that needs membership tests must key on ``id()``.
    """

    def __call__(self, *inputs):
        """Apply the function to ``inputs`` and return the output variable(s).

        Each input must expose ``.data`` and ``.generation``.  The result of
        ``forward`` is wrapped in a tuple when it is not one already, each
        element becomes a ``Variable`` whose creator is this function, and the
        inputs/outputs are stored on ``self`` for use in ``backward``.

        Returns:
            The single output variable, or the list of output variables when
            ``forward`` produced more than one value.
        """
        xs = [x.data for x in inputs]
        ys = self.forward(*xs)
        if not isinstance(ys, tuple):
            ys = (ys,)
        outputs = [Variable(as_array(y)) for y in ys]

        # The function sits at the generation of its latest input.
        self.generation = max(x.generation for x in inputs)
        for output in outputs:
            output.set_creator(self)

        self.inputs = inputs
        self.outputs = outputs
        return outputs if len(outputs) > 1 else outputs[0]

    def __lt__(self, other):
        """Order by generation; defer for operands without a generation."""
        other_generation = getattr(other, "generation", None)
        if other_generation is None:
            # Let Python try the reflected operation / raise TypeError rather
            # than failing with an AttributeError inside the comparison.
            return NotImplemented
        return self.generation < other_generation

    def __eq__(self, other):
        """Equal when generations match; defer for operands without one so
        that e.g. ``f == None`` evaluates to ``False`` instead of raising."""
        other_generation = getattr(other, "generation", None)
        if other_generation is None:
            return NotImplemented
        return self.generation == other_generation

    def forward(self, xs):
        """Compute the outputs from raw input data (subclass responsibility)."""
        raise NotImplementedError()

    def backward(self, gys):
        """Compute input gradients from output gradients (subclass
        responsibility)."""
        raise NotImplementedError()


class Add(Function):
    """Addition of two inputs."""

    def forward(self, x0: Number, x1: Number) -> Number:
        """Return ``x0 + x1``."""
        return x0 + x1

    def backward(self, gy: Number) -> tuple:
        """Hand the upstream gradient ``gy`` unchanged to both inputs."""
        return (gy, gy)


class Square(Function):
    """Raises its input to the second power."""

    def forward(self, x: Number) -> Number:
        """Return ``x ** 2``."""
        return x ** 2

    def backward(self, gy: Number) -> Number:
        """Return ``2 * x * gy`` using the data of the stored first input."""
        (source, *_rest) = self.inputs
        x = source.data
        return 2 * x * gy
    

class Cube(Function):
    """Raises its input to the third power."""

    def forward(self, x: Number) -> Number:
        """Return ``x ** 3``."""
        return x ** 3

    def backward(self, gy: Number) -> Number:
        """Return ``3 * x**2 * gy`` using the data of the stored first input."""
        (source, *_rest) = self.inputs
        x = source.data
        return 3 * (x ** 2) * gy


class Exp(Function):
    """Natural exponential of its input."""

    def forward(self, x: Number) -> Number:
        """Return ``np.exp(x)``."""
        return np.exp(x)

    def backward(self, gy: Number) -> Number:
        """Return ``exp(x) * gy``; the exponential is recomputed from the
        data of the stored first input."""
        (source, *_rest) = self.inputs
        x = source.data
        return np.exp(x) * gy


def add(x0: Number, x1: Number) -> Number:
    """Apply a fresh ``Add`` to ``x0`` and ``x1`` and return what it yields.

    Note: the arguments are annotated as ``Number``, but ``Function.__call__``
    reads ``.data`` and ``.generation`` from every input, so they must be
    objects that provide those attributes.
    """
    operation = Add()
    return operation(x0, x1)


def square(x: Number) -> Number:
    """Apply a fresh ``Square`` to ``x`` and return what it yields.

    Note: ``x`` is annotated as ``Number``, but ``Function.__call__`` reads
    ``.data`` and ``.generation`` from it.
    """
    operation = Square()
    return operation(x)


def cube(x: Number) -> Number:
    """Apply a fresh ``Cube`` to ``x`` and return what it yields.

    Note: ``x`` is annotated as ``Number``, but ``Function.__call__`` reads
    ``.data`` and ``.generation`` from it.
    """
    operation = Cube()
    return operation(x)


def exp(x: Number) -> Number:
    """Apply a fresh ``Exp`` to ``x`` and return what it yields.

    Note: ``x`` is annotated as ``Number``, but ``Function.__call__`` reads
    ``.data`` and ``.generation`` from it.
    """
    operation = Exp()
    return operation(x)


def as_array(x) -> np.ndarray:
    """Promote a scalar (per ``np.isscalar``) to a NumPy array; return
    anything else as the same object, untouched."""
    return np.array(x) if np.isscalar(x) else x


@functools.total_ordering
class PriorityItem(object):
    """Wrapper whose ordering is the reverse of the wrapped object's.

    An item is "less than" another item when its wrapped object is greater,
    which lets a min-heap of items behave like a max-heap of objects.
    Compared with anything that is not a ``PriorityItem``, ``<`` is ``False``.
    """

    def __init__(self, obj):
        self.obj = obj

    def __lt__(self, other):
        if not isinstance(other, PriorityItem):
            return False
        return self.obj > other.obj
    

class PrioritySet(object):
    Item = namedtuple("Item", ["priority", "item"])
    
    def __call__(self, queue=None):
        self.maxheap = []
        self.heapset = set()
        if queue is None:
            return self
        
        if not isinstance(queue, Iterable):
            print(f"{type(queue)} is not iterable")
            return None
        
        queue = map(PriorityItem, queue)
        for reverse_obj in queue:
            self.add(reverse_obj)
        return self

    def add(self, x: PriorityItem) -> None:
        if id(x.obj) not in self.heapset:
            heapq.heappush(self.maxheap, x)
            self.heapset.add(id(x.obj))
        
    def pop(self):
        x = heapq.heappop(self.maxheap)
        self.heapset.remove(id(x.obj))
        return x.obj
    
    def __len__(self):
        return len(self.maxheap)


# Next: hold Function.outputs through weakref so that a Function and the
# Variables it creates no longer form a circular reference.

In [9]:
import weakref

In [14]:
# Create an array and a weak reference to it; the trailing expression
# displays the weakref object itself, not the array.
a = np.array([1,2,3])
b = weakref.ref(a)
b

<weakref at 0x7fa2f5b965e0; to 'numpy.ndarray' at 0x7fa2f54e45d0>

In [15]:
# Calling a weakref returns the object it refers to.
b()

array([1, 2, 3])

In [16]:
# Drop the name `a` and display the weakref again.
# NOTE(review): the recorded output still shows a live ndarray rather than a
# dead reference; presumably IPython's output cache (Out[15] / `_`) kept the
# array returned by the earlier `b()` cell alive -- confirm on a fresh kernel.
a = None
b

<weakref at 0x7fa2f5b965e0; to 'numpy.ndarray' at 0x7fa2f54e45d0>

In [24]:
class Variable:
    """A node of the computation graph: an array plus its gradient.

    Attributes:
        data: The wrapped ``np.ndarray`` (``None`` is also accepted).
        grad: Gradient filled in by ``backward``; ``None`` until then.
        creator: The function object that produced this variable, or
            ``None`` when the variable was created directly.
        generation: 0 without a creator, otherwise the creator's
            generation + 1.
    """

    def __init__(self, data):
        """Wrap ``data``.

        Raises:
            TypeError: if ``data`` is neither ``None`` nor an ``np.ndarray``.
        """
        if data is not None and not isinstance(data, np.ndarray):
            raise TypeError(f"{type(data)} is not supported")
        self.data = data
        self.grad = None
        self.creator = None
        self.generation = 0

    def set_creator(self, function) -> None:
        """Record ``function`` as the producer and place this variable one
        generation after it."""
        self.creator = function
        self.generation = function.generation + 1
        return None

    def backward(self) -> None:
        """Backpropagate from this variable through every function reachable
        via ``creator`` links, accumulating gradients into the inputs.

        If ``grad`` is still ``None`` it is seeded with ones shaped like
        ``data``. A variable without a creator has nothing to propagate, so
        the method returns right after seeding.
        """
        if self.grad is None:
            self.grad = np.ones_like(self.data)

        # Leaf variable: there is no function to walk back through.  Without
        # this guard ``None`` would be queued and dereferenced below.
        if self.creator is None:
            return None

        # PriorityItem reverses the generation ordering, so the heap pops the
        # function with the highest generation first.
        functions_list = PrioritySet()([self.creator])

        while functions_list:
            function = functions_list.pop()
            # ``function.outputs`` holds weak references; call each one to
            # obtain the output variable.
            # NOTE(review): a reference whose variable was already collected
            # returns None here and the attribute access fails -- confirm
            # whether multi-output functions with discarded outputs matter.
            gys = [output().grad for output in function.outputs]
            gxs = function.backward(*gys)
            if not isinstance(gxs, tuple):
                gxs = (gxs,)

            for x, gx in zip(function.inputs, gxs):
                if x.grad is None:
                    x.grad = gx
                else:
                    # Rebind instead of ``+=``: ``x.grad`` may be the very
                    # same array object as another variable's gradient (see
                    # the assignment above), and an in-place add would
                    # corrupt that one as well.
                    x.grad = x.grad + gx

                if x.creator is not None:
                    functions_list.add(PriorityItem(x.creator))
        return None

    def cleargrad(self) -> None:
        """Forget the stored gradient so the variable can be reused."""
        self.grad = None
        return None


class Function:
    """Base class for operations that can be applied to variables and
    differentiated.

    Subclasses implement ``forward`` (raw data in, raw data out) and
    ``backward`` (output gradients in, input gradients out).

    Outputs are stored as weak references, so the function does not keep the
    variables it created alive.

    Instances are ordered and compared by ``generation`` only.  Because
    ``__eq__`` is defined without ``__hash__``, instances are unhashable;
    code that needs membership tests must key on ``id()``.
    """

    def __call__(self, *inputs):
        """Apply the function to ``inputs`` and return the output variable(s).

        Each input must expose ``.data`` and ``.generation``.  The result of
        ``forward`` is wrapped in a tuple when it is not one already, each
        element becomes a ``Variable`` whose creator is this function, and the
        inputs plus weak references to the outputs are stored on ``self``.

        Returns:
            The single output variable, or the list of output variables when
            ``forward`` produced more than one value.
        """
        xs = [x.data for x in inputs]
        ys = self.forward(*xs)
        if not isinstance(ys, tuple):
            ys = (ys,)
        outputs = [Variable(as_array(y)) for y in ys]

        # The function sits at the generation of its latest input.
        self.generation = max(x.generation for x in inputs)
        for output in outputs:
            output.set_creator(self)

        self.inputs = inputs
        # Weak references: each output already points back at this function
        # through ``creator``; a strong reference here would close a cycle.
        self.outputs = [weakref.ref(output) for output in outputs]
        return outputs if len(outputs) > 1 else outputs[0]

    def __lt__(self, other):
        """Order by generation; defer for operands without a generation."""
        other_generation = getattr(other, "generation", None)
        if other_generation is None:
            # Let Python try the reflected operation / raise TypeError rather
            # than failing with an AttributeError inside the comparison.
            return NotImplemented
        return self.generation < other_generation

    def __eq__(self, other):
        """Equal when generations match; defer for operands without one so
        that e.g. ``f == None`` evaluates to ``False`` instead of raising."""
        other_generation = getattr(other, "generation", None)
        if other_generation is None:
            return NotImplemented
        return self.generation == other_generation

    def forward(self, xs):
        """Compute the outputs from raw input data (subclass responsibility)."""
        raise NotImplementedError()

    def backward(self, gys):
        """Compute input gradients from output gradients (subclass
        responsibility)."""
        raise NotImplementedError()


class Add(Function):
    """Addition of two inputs."""

    def forward(self, x0: Number, x1: Number) -> Number:
        """Return ``x0 + x1``."""
        return x0 + x1

    def backward(self, gy: Number) -> tuple:
        """Hand the upstream gradient ``gy`` unchanged to both inputs."""
        return (gy, gy)


class Square(Function):
    """Raises its input to the second power."""

    def forward(self, x: Number) -> Number:
        """Return ``x ** 2``."""
        return x ** 2

    def backward(self, gy: Number) -> Number:
        """Return ``2 * x * gy`` using the data of the stored first input."""
        (source, *_rest) = self.inputs
        x = source.data
        return 2 * x * gy
    

class Cube(Function):
    """Raises its input to the third power."""

    def forward(self, x: Number) -> Number:
        """Return ``x ** 3``."""
        return x ** 3

    def backward(self, gy: Number) -> Number:
        """Return ``3 * x**2 * gy`` using the data of the stored first input."""
        (source, *_rest) = self.inputs
        x = source.data
        return 3 * (x ** 2) * gy


class Exp(Function):
    """Natural exponential of its input."""

    def forward(self, x: Number) -> Number:
        """Return ``np.exp(x)``."""
        return np.exp(x)

    def backward(self, gy: Number) -> Number:
        """Return ``exp(x) * gy``; the exponential is recomputed from the
        data of the stored first input."""
        (source, *_rest) = self.inputs
        x = source.data
        return np.exp(x) * gy


def add(x0: Number, x1: Number) -> Number:
    """Apply a fresh ``Add`` to ``x0`` and ``x1`` and return what it yields.

    Note: the arguments are annotated as ``Number``, but ``Function.__call__``
    reads ``.data`` and ``.generation`` from every input, so they must be
    objects that provide those attributes.
    """
    operation = Add()
    return operation(x0, x1)


def square(x: Number) -> Number:
    """Apply a fresh ``Square`` to ``x`` and return what it yields.

    Note: ``x`` is annotated as ``Number``, but ``Function.__call__`` reads
    ``.data`` and ``.generation`` from it.
    """
    operation = Square()
    return operation(x)


def cube(x: Number) -> Number:
    """Apply a fresh ``Cube`` to ``x`` and return what it yields.

    Note: ``x`` is annotated as ``Number``, but ``Function.__call__`` reads
    ``.data`` and ``.generation`` from it.
    """
    operation = Cube()
    return operation(x)


def exp(x: Number) -> Number:
    """Apply a fresh ``Exp`` to ``x`` and return what it yields.

    Note: ``x`` is annotated as ``Number``, but ``Function.__call__`` reads
    ``.data`` and ``.generation`` from it.
    """
    operation = Exp()
    return operation(x)


def as_array(x) -> np.ndarray:
    """Promote a scalar (per ``np.isscalar``) to a NumPy array; return
    anything else as the same object, untouched."""
    return np.array(x) if np.isscalar(x) else x


@functools.total_ordering
class PriorityItem(object):
    """Wrapper whose ordering is the reverse of the wrapped object's.

    An item is "less than" another item when its wrapped object is greater,
    which lets a min-heap of items behave like a max-heap of objects.
    Compared with anything that is not a ``PriorityItem``, ``<`` is ``False``.
    """

    def __init__(self, obj):
        self.obj = obj

    def __lt__(self, other):
        if not isinstance(other, PriorityItem):
            return False
        return self.obj > other.obj
    

class PrioritySet(object):
    Item = namedtuple("Item", ["priority", "item"])
    
    def __call__(self, queue=None):
        self.maxheap = []
        self.heapset = set()
        if queue is None:
            return self
        
        if not isinstance(queue, Iterable):
            print(f"{type(queue)} is not iterable")
            return None
        
        queue = map(PriorityItem, queue)
        for reverse_obj in queue:
            self.add(reverse_obj)
        return self

    def add(self, x: PriorityItem) -> None:
        if id(x.obj) not in self.heapset:
            heapq.heappush(self.maxheap, x)
            self.heapset.add(id(x.obj))
        
    def pop(self):
        x = heapq.heappop(self.maxheap)
        self.heapset.remove(id(x.obj))
        return x.obj
    
    def __len__(self):
        return len(self.maxheap)


In [25]:
# Ten times over: wrap 10000 fresh random values, apply square three times,
# backpropagate from the result and print the gradient that reached the input.
# Each pass rebinds `x` and `y`, dropping the names that held the previous graph.
# NOTE(review): presumably a check that graphs are released now that
# Function.outputs are weak references -- confirm. No random seed is set, so
# the printed values differ from run to run.
for _ in range(10):
    x = Variable(np.random.randn(10000))
    y = square(square(square(x)))
    y.backward()
    print(x.grad)

[ 5.25580748e+03 -8.49882187e-11  1.53437531e+00 ...  5.62658670e+00
 -1.64784901e+02 -3.07462315e-02]
[-0.95665465 -1.19067698  0.01088825 ... -0.12047008  0.00331084
  0.1141097 ]
[-1.14461565e+02  2.99117034e-13 -7.69669189e+01 ... -1.99732471e-02
 -1.95401580e-01  1.46353668e-03]
[-1.33008659e+00 -3.90626903e+00  8.30331546e+02 ... -9.44647155e-03
 -5.48300411e+02 -1.71559530e-02]
[-1.82419439e+01 -2.03202252e+01  5.73022251e+01 ...  6.88686436e-05
 -9.94707819e-06  5.82356899e+02]
[ 5.70415691e-02 -8.82656387e-04 -1.74273727e-04 ...  1.72172313e-01
  1.05001109e+02 -5.50355020e-03]
[ 3.56019785e+00  2.67524462e-03 -9.49329746e-01 ... -1.04322227e+03
 -1.94151613e-03  1.86977500e+03]
[ 2.37084960e+02 -2.98141718e+00  2.45300602e-04 ... -1.41122341e+01
  2.14331033e+02  5.16900045e-19]
[-1.84764502e+02 -8.74095277e+00 -3.49067391e-01 ... -8.99520742e+01
 -9.06640793e-03  2.07331190e+04]
[ 9.05822273e-03 -1.27759166e-02 -2.43870041e-15 ...  1.36220417e-03
  3.77655195e-06 -3.62383055