In [28]:
from numbers import Number
from typing import NoReturn
import numpy as np


class Variable:
    """Array container that remembers which function produced it.

    Attributes:
        data: The wrapped ``np.ndarray`` (or ``None``).
        grad: Gradient accumulated by ``backward``; ``None`` until computed.
        creator: The function object that produced this variable, or ``None``
            for a variable that was created directly.
    """

    def __init__(self, data):
        """Wrap ``data``.

        Raises:
            TypeError: If ``data`` is neither ``None`` nor an ``np.ndarray``.
        """
        if data is not None and not isinstance(data, np.ndarray):
            raise TypeError(f"{type(data)} is not supported")
        self.data = data
        self.grad = None
        self.creator = None

    def set_creator(self, function) -> None:
        """Record ``function`` as the producer of this variable."""
        self.creator = function

    def backward(self) -> None:
        """Propagate gradients from this variable back through its creators.

        Pending functions are kept on a plain list used as a LIFO stack, so
        they are processed in the order they were most recently collected.
        The traversal order is intentionally left this simple in this cell.
        """
        if self.grad is None:
            # Seed the output gradient with ones of the same shape as the data.
            self.grad = np.ones_like(self.data)

        if self.creator is None:
            # A variable without a creator has nothing upstream to visit.
            return

        functions_list = [self.creator]
        while functions_list:
            print('functions_list:', functions_list)
            function = functions_list.pop()
            print('Use function:', function)

            gys = [output.grad for output in function.outputs]
            gxs = function.backward(*gys)
            if not isinstance(gxs, tuple):
                gxs = (gxs,)

            for x, gx in zip(function.inputs, gxs):
                if x.grad is None:
                    x.grad = gx
                else:
                    # Build a new object instead of using +=: gx can be the
                    # very same object as another variable's grad (Add.backward
                    # returns gy twice), and an in-place add would alter it too.
                    x.grad = x.grad + gx

                if x.creator is not None:
                    functions_list.append(x.creator)
                    print("Collect function:", x.creator)
            print('functions_list:', functions_list)
            print()

    def cleargrad(self) -> None:
        """Forget the accumulated gradient."""
        self.grad = None


class Function:
    """Base class for operations that can be applied to Variables.

    Subclasses provide ``forward`` (works on the raw ``.data`` of the inputs)
    and ``backward`` (maps output gradients to input gradients).
    """

    def __call__(self, *inputs):
        """Apply the operation to ``inputs`` and return the result.

        The ``.data`` of every input is handed to ``forward``; each result is
        passed through ``as_array`` and wrapped in a new ``Variable`` whose
        creator is this function. The inputs and outputs are stored on the
        instance.

        Returns:
            The single output Variable, or the list of outputs when ``forward``
            produced more than one value.
        """
        raw = self.forward(*[variable.data for variable in inputs])
        results = raw if isinstance(raw, tuple) else (raw,)

        outputs = [Variable(as_array(value)) for value in results]
        for produced in outputs:
            produced.set_creator(self)

        self.inputs = inputs
        self.outputs = outputs
        if len(outputs) > 1:
            return outputs
        return outputs[0]

    def forward(self, xs):
        """Compute the outputs from raw input data; subclasses must override."""
        raise NotImplementedError()

    def backward(self, gys):
        """Compute input gradients from output gradients; subclasses must override."""
        raise NotImplementedError()


class Add(Function):
    """Addition node: y = x0 + x1."""

    def forward(self, x0: Number, x1: Number) -> Number:
        """Return the sum of the two inputs."""
        return x0 + x1

    def backward(self, gy: Number) -> tuple:
        """Hand the upstream gradient, unchanged, to both inputs."""
        return (gy, gy)


class Square(Function):
    """Squaring node: y = x ** 2."""

    def forward(self, x: Number) -> Number:
        """Return ``x`` squared."""
        return x ** 2

    def backward(self, gy: Number) -> Number:
        """Return ``2 * x * gy``, with ``x`` read from the stored input."""
        return 2 * self.inputs[0].data * gy
    

class Cube(Function):
    """Cubing node: y = x ** 3."""

    def forward(self, x: Number) -> Number:
        """Return ``x`` cubed."""
        return x ** 3

    def backward(self, gy: Number) -> Number:
        """Return ``3 * x**2 * gy``, with ``x`` read from the stored input."""
        source = self.inputs[0].data
        return 3 * (source ** 2) * gy


class Exp(Function):
    """Exponential node: y = exp(x)."""

    def forward(self, x: Number) -> Number:
        """Return ``np.exp(x)``."""
        return np.exp(x)

    def backward(self, gy: Number) -> Number:
        """Return ``exp(x) * gy``, with ``x`` read from the stored input."""
        return np.exp(self.inputs[0].data) * gy


def add(x0: Number, x1: Number) -> Number:
    """Apply a fresh ``Add`` node to ``x0`` and ``x1`` and return its output."""
    node = Add()
    return node(x0, x1)


def square(x: Number) -> Number:
    """Apply a fresh ``Square`` node to ``x`` and return its output."""
    node = Square()
    return node(x)


def cube(x: Number) -> Number:
    """Apply a fresh ``Cube`` node to ``x`` and return its output."""
    node = Cube()
    return node(x)


def exp(x: Number) -> Number:
    """Apply a fresh ``Exp`` node to ``x`` and return its output."""
    node = Exp()
    return node(x)


def as_array(x) -> np.ndarray:
    """Wrap ``x`` in ``np.array`` when ``np.isscalar(x)`` is true.

    Anything else is returned as the same object, untouched.
    """
    return np.array(x) if np.isscalar(x) else x


# Wrong flow: functions are processed in plain last-in, first-out order

In [29]:
# Diamond-shaped graph: `a` feeds both square() and cube(), and the two
# results are summed into y.
x = Variable(np.array(3.0))
a = exp(x)
b = square(a)
c = cube(a)
y = add(b, c)
# backward() pops pending functions in last-in, first-out order; the trace
# printed below shows Exp being used before Square.
y.backward()
print(x.grad)

functions_list: [<__main__.Add object at 0x7fefb41e9eb0>]
Use function: <__main__.Add object at 0x7fefb41e9eb0>
Collect function: <__main__.Square object at 0x7fefb41e9b80>
Collect function: <__main__.Cube object at 0x7fefb41e98b0>
functions_list: [<__main__.Square object at 0x7fefb41e9b80>, <__main__.Cube object at 0x7fefb41e98b0>]

functions_list: [<__main__.Square object at 0x7fefb41e9b80>, <__main__.Cube object at 0x7fefb41e98b0>]
Use function: <__main__.Cube object at 0x7fefb41e98b0>
Collect function: <__main__.Exp object at 0x7fefb45e35e0>
functions_list: [<__main__.Square object at 0x7fefb41e9b80>, <__main__.Exp object at 0x7fefb45e35e0>]

functions_list: [<__main__.Square object at 0x7fefb41e9b80>, <__main__.Exp object at 0x7fefb45e35e0>]
Use function: <__main__.Exp object at 0x7fefb45e35e0>
functions_list: [<__main__.Square object at 0x7fefb41e9b80>]

functions_list: [<__main__.Square object at 0x7fefb41e9b80>]
Use function: <__main__.Square object at 0x7fefb41e9b80>
Collect f

# Use a generation number to decide which function has priority

In [15]:
# Stand-in functions tagged by hand with a generation number.
generations = [2, 0, 1, 4, 2]
funcs = []

for g in generations:
    f = Function()
    f.generation = g
    funcs.append(f)

# Generations in insertion order.
print([f.generation for f in funcs])

# Ascending sort puts the largest generation at the end of the list ...
funcs.sort(key=lambda x: x.generation)
print([f.generation for f in funcs])

# ... so pop() returns the function with the largest generation.
f = funcs.pop()
print(f.generation)

[2, 0, 1, 4, 2]
[0, 1, 2, 2, 4]
4


In [30]:
class Variable:
    """Array container that remembers its creator and its generation.

    Attributes:
        data: The wrapped ``np.ndarray`` (or ``None``).
        grad: Gradient accumulated by ``backward``; ``None`` until computed.
        creator: The function object that produced this variable, or ``None``
            for a variable that was created directly.
        generation: 0 for a directly created variable, otherwise the creator's
            generation plus one.
    """

    def __init__(self, data):
        """Wrap ``data``.

        Raises:
            TypeError: If ``data`` is neither ``None`` nor an ``np.ndarray``.
        """
        if data is not None and not isinstance(data, np.ndarray):
            raise TypeError(f"{type(data)} is not supported")
        self.data = data
        self.grad = None
        self.creator = None
        self.generation = 0

    def set_creator(self, function) -> None:
        """Record ``function`` as the producer and derive this generation from it."""
        self.creator = function
        self.generation = function.generation + 1

    def backward(self) -> None:
        """Propagate gradients backwards, largest generation first.

        Pending functions are kept in a list that is re-sorted by generation on
        every insertion, so ``pop()`` always yields the function with the
        largest generation. A function is queued at most once.
        """
        if self.grad is None:
            # Seed the output gradient with ones of the same shape as the data.
            self.grad = np.ones_like(self.data)

        if self.creator is None:
            # A variable without a creator has nothing upstream to visit.
            return

        functions_list = []
        seen_set = set()

        def add_function(f):
            # Queue f only the first time it is encountered.
            if f not in seen_set:
                functions_list.append(f)
                seen_set.add(f)
                functions_list.sort(key=lambda x: x.generation)

        add_function(self.creator)

        while functions_list:
            print('functions_list:', functions_list)
            function = functions_list.pop()
            print('Use function:', function)

            gys = [output.grad for output in function.outputs]
            gxs = function.backward(*gys)
            if not isinstance(gxs, tuple):
                gxs = (gxs,)

            for x, gx in zip(function.inputs, gxs):
                if x.grad is None:
                    x.grad = gx
                else:
                    # Build a new object instead of using +=: gx can be the
                    # very same object as another variable's grad (Add.backward
                    # returns gy twice), and an in-place add would alter it too.
                    x.grad = x.grad + gx

                if x.creator is not None:
                    add_function(x.creator)
                    # Printed for every encounter, even when the function was
                    # already queued earlier and is not added again.
                    print("Collect function:", x.creator)
            print('functions_list:', functions_list)
            print()

    def cleargrad(self) -> None:
        """Forget the accumulated gradient."""
        self.grad = None


class Function:
    """Base class for operations that can be applied to Variables.

    Subclasses provide ``forward`` (works on the raw ``.data`` of the inputs)
    and ``backward`` (maps output gradients to input gradients).
    """

    def __call__(self, *inputs):
        """Apply the operation to ``inputs`` and return the result.

        The ``.data`` of every input is handed to ``forward``; each result is
        passed through ``as_array`` and wrapped in a new ``Variable``. The
        function's generation becomes the largest generation among its inputs,
        and every output records this function as its creator.

        Returns:
            The single output Variable, or the list of outputs when ``forward``
            produced more than one value.
        """
        raw = self.forward(*[variable.data for variable in inputs])
        results = raw if isinstance(raw, tuple) else (raw,)
        outputs = [Variable(as_array(value)) for value in results]

        # Must be set before set_creator, which reads it to number the outputs.
        self.generation = max(variable.generation for variable in inputs)
        for produced in outputs:
            produced.set_creator(self)

        self.inputs = inputs
        self.outputs = outputs
        if len(outputs) > 1:
            return outputs
        return outputs[0]

    def forward(self, xs):
        """Compute the outputs from raw input data; subclasses must override."""
        raise NotImplementedError()

    def backward(self, gys):
        """Compute input gradients from output gradients; subclasses must override."""
        raise NotImplementedError()


class Add(Function):
    """Addition node: y = x0 + x1."""

    def forward(self, x0: Number, x1: Number) -> Number:
        """Return the sum of the two inputs."""
        return x0 + x1

    def backward(self, gy: Number) -> tuple:
        """Hand the upstream gradient, unchanged, to both inputs."""
        return (gy, gy)


class Square(Function):
    """Squaring node: y = x ** 2."""

    def forward(self, x: Number) -> Number:
        """Return ``x`` squared."""
        return x ** 2

    def backward(self, gy: Number) -> Number:
        """Return ``2 * x * gy``, with ``x`` read from the stored input."""
        return 2 * self.inputs[0].data * gy
    

class Cube(Function):
    """Cubing node: y = x ** 3."""

    def forward(self, x: Number) -> Number:
        """Return ``x`` cubed."""
        return x ** 3

    def backward(self, gy: Number) -> Number:
        """Return ``3 * x**2 * gy``, with ``x`` read from the stored input."""
        source = self.inputs[0].data
        return 3 * (source ** 2) * gy


class Exp(Function):
    """Exponential node: y = exp(x)."""

    def forward(self, x: Number) -> Number:
        """Return ``np.exp(x)``."""
        return np.exp(x)

    def backward(self, gy: Number) -> Number:
        """Return ``exp(x) * gy``, with ``x`` read from the stored input."""
        return np.exp(self.inputs[0].data) * gy


def add(x0: Number, x1: Number) -> Number:
    """Apply a fresh ``Add`` node to ``x0`` and ``x1`` and return its output."""
    node = Add()
    return node(x0, x1)


def square(x: Number) -> Number:
    """Apply a fresh ``Square`` node to ``x`` and return its output."""
    node = Square()
    return node(x)


def cube(x: Number) -> Number:
    """Apply a fresh ``Cube`` node to ``x`` and return its output."""
    node = Cube()
    return node(x)


def exp(x: Number) -> Number:
    """Apply a fresh ``Exp`` node to ``x`` and return its output."""
    node = Exp()
    return node(x)


def as_array(x) -> np.ndarray:
    """Wrap ``x`` in ``np.array`` when ``np.isscalar(x)`` is true.

    Anything else is returned as the same object, untouched.
    """
    return np.array(x) if np.isscalar(x) else x

# Correct flow: functions are processed from the largest generation down

In [31]:
# Same diamond-shaped graph as before: `a` feeds both square() and cube(),
# and the two results are summed into y.
x = Variable(np.array(3.0))
a = exp(x)
b = square(a)
c = cube(a)
y = add(b, c)
# With the generation-sorted list, the trace below shows both Cube and Square
# being used before Exp.
y.backward()
print(x.grad)

functions_list: [<__main__.Add object at 0x7fefb41d9a00>]
Use function: <__main__.Add object at 0x7fefb41d9a00>
Collect function: <__main__.Square object at 0x7fefb40f7f70>
Collect function: <__main__.Cube object at 0x7fefd011e7f0>
functions_list: [<__main__.Square object at 0x7fefb40f7f70>, <__main__.Cube object at 0x7fefd011e7f0>]

functions_list: [<__main__.Square object at 0x7fefb40f7f70>, <__main__.Cube object at 0x7fefd011e7f0>]
Use function: <__main__.Cube object at 0x7fefd011e7f0>
Collect function: <__main__.Exp object at 0x7fefb40f7cd0>
functions_list: [<__main__.Exp object at 0x7fefb40f7cd0>, <__main__.Square object at 0x7fefb40f7f70>]

functions_list: [<__main__.Exp object at 0x7fefb40f7cd0>, <__main__.Square object at 0x7fefb40f7f70>]
Use function: <__main__.Square object at 0x7fefb40f7f70>
Collect function: <__main__.Exp object at 0x7fefb40f7cd0>
functions_list: [<__main__.Exp object at 0x7fefb40f7cd0>]

functions_list: [<__main__.Exp object at 0x7fefb40f7cd0>]
Use functio

In [32]:
# `a` is consumed by two separate Square nodes whose outputs are added,
# i.e. y = (x**2)**2 + (x**2)**2.
x = Variable(np.array(2.0))
a = square(x)
y = add(square(a), square(a))
y.backward()

print(y.data)
print(x.grad)

functions_list: [<__main__.Add object at 0x7fefb41d7a00>]
Use function: <__main__.Add object at 0x7fefb41d7a00>
Collect function: <__main__.Square object at 0x7fefb41d7670>
Collect function: <__main__.Square object at 0x7fefb41d7cd0>
functions_list: [<__main__.Square object at 0x7fefb41d7670>, <__main__.Square object at 0x7fefb41d7cd0>]

functions_list: [<__main__.Square object at 0x7fefb41d7670>, <__main__.Square object at 0x7fefb41d7cd0>]
Use function: <__main__.Square object at 0x7fefb41d7cd0>
Collect function: <__main__.Square object at 0x7fefb41d7eb0>
functions_list: [<__main__.Square object at 0x7fefb41d7eb0>, <__main__.Square object at 0x7fefb41d7670>]

functions_list: [<__main__.Square object at 0x7fefb41d7eb0>, <__main__.Square object at 0x7fefb41d7670>]
Use function: <__main__.Square object at 0x7fefb41d7670>
Collect function: <__main__.Square object at 0x7fefb41d7eb0>
functions_list: [<__main__.Square object at 0x7fefb41d7eb0>]

functions_list: [<__main__.Square object at 0x

# use heapq

In [25]:
import heapq

class _OrderedFunction(Function):
    """Function whose instances are ordered by ``generation``.

    heapq compares heap entries with ``<``. The ``Function`` class defined
    earlier in this notebook has no ``__lt__``, so pushing plain ``Function``
    objects raises TypeError on a fresh top-to-bottom run; this local subclass
    supplies the comparison the demo needs.
    """

    def __lt__(self, other):
        return self.generation < other.generation


generations = [2, 0, 1, 4, 3]
funcs = []

for g in generations:
    f = _OrderedFunction()
    f.generation = g
    funcs.append(f)

print([f.generation for f in funcs])


# Build a heap one push at a time ...
pq = []
for f in funcs:
    heapq.heappush(pq, f)

print(pq)
print()

# ... and build another by heapifying the original list in place.
heapq.heapify(funcs)
print(funcs)
print()

# Compare the two heap layouts slot by slot.
for i, s in enumerate(zip(pq, funcs)):
    f1, f2 = s
    print(f'Element{i}:', f1 == f2)

[2, 0, 1, 4, 3]
[<__main__.Function object at 0x7fefb4848580>, <__main__.Function object at 0x7fefb40ec730>, <__main__.Function object at 0x7fefb4848eb0>, <__main__.Function object at 0x7fefb4848bb0>, <__main__.Function object at 0x7fefb40df0a0>]

[<__main__.Function object at 0x7fefb4848580>, <__main__.Function object at 0x7fefb40ec730>, <__main__.Function object at 0x7fefb4848eb0>, <__main__.Function object at 0x7fefb4848bb0>, <__main__.Function object at 0x7fefb40df0a0>]

Element0: True
Element1: True
Element2: True
Element3: True
Element4: True


In [85]:
from collections.abc import Iterable
from collections import namedtuple
from typing import NamedTuple
import functools

@functools.total_ordering
class ReverseCompare(object):
    """Wrapper that inverts the ordering of the wrapped object.

    ``ReverseCompare(a) < ReverseCompare(b)`` is true exactly when ``a > b``,
    which lets a min-heap such as ``heapq`` hand back the largest wrapped
    object first.

    Attributes:
        obj: The wrapped object; it must support ``>`` against its peers.
    """

    def __init__(self, obj):
        self.obj = obj

    def __lt__(self, other):
        if not isinstance(other, ReverseCompare):
            # Returning False here would make the comparisons derived by
            # total_ordering (e.g. ``wrapper > 5``) evaluate to True;
            # NotImplemented lets Python raise TypeError instead.
            return NotImplemented
        return self.obj > other.obj


class PrioritySet(object):
    """Heap-backed queue that holds each wrapped object at most once.

    Entries are wrapper objects exposing the real item as ``.obj``; membership
    is tracked by ``id(entry.obj)``. ``pop`` removes the entry that ``heapq``
    considers smallest, so with ``ReverseCompare`` wrappers it returns the
    object with the largest priority.

    Attributes:
        maxheap: The ``heapq``-managed list of wrappers.
        heapset: ``id`` of every object currently in the heap.
    """

    Item = namedtuple("Item", ["priority", "item"])

    def __init__(self):
        # Start empty so add/pop/len work even if the instance is never called.
        self.maxheap = []
        self.heapset = set()

    def __call__(self, queue=None):
        """Reset the set and refill it from ``queue``.

        Args:
            queue: Iterable of objects to wrap in ``ReverseCompare`` and add,
                or ``None`` for an empty set.

        Returns:
            This instance.

        Raises:
            TypeError: If ``queue`` is neither ``None`` nor iterable.
        """
        self.maxheap = []
        self.heapset = set()
        if queue is None:
            return self

        if not isinstance(queue, Iterable):
            raise TypeError(f"{type(queue)} is not iterable")

        for obj in queue:
            self.add(ReverseCompare(obj))
        return self

    def add(self, x: "ReverseCompare") -> None:
        """Push wrapper ``x`` unless its ``.obj`` is already present."""
        if id(x.obj) not in self.heapset:
            heapq.heappush(self.maxheap, x)
            self.heapset.add(id(x.obj))
            print('update set:', self.heapset)

    def pop(self):
        """Remove the top entry and return its unwrapped object.

        Raises:
            IndexError: If the set is empty (raised by ``heapq.heappop``).
        """
        x = heapq.heappop(self.maxheap)
        self.heapset.remove(id(x.obj))
        return x.obj

    def __len__(self):
        # Also makes an empty set falsy, so ``while priority_set:`` terminates.
        return len(self.maxheap)

    def __repr__(self):
        pq = {i: self.Item(priority=x.obj.generation, item=type(x.obj).__name__) for i, x in enumerate(self.maxheap)}
        return str(pq)

    
def priority_set(iterable_queue):
    """Create a ``PrioritySet`` and fill it from ``iterable_queue``."""
    container = PrioritySet()
    return container(iterable_queue)
        

In [87]:
class Function:
    """Minimal stand-in that is ordered and compared by ``generation``.

    Because ``__eq__`` is defined without ``__hash__``, instances are not
    hashable; containers that need to track them must key on something else
    (``PrioritySet`` uses ``id(obj)``).
    """

    def __init__(self):
        self.generation = None

    def __lt__(self, other):
        if not isinstance(other, Function):
            # Let Python try the reflected operation / raise TypeError rather
            # than failing on a missing ``generation`` attribute.
            return NotImplemented
        return self.generation < other.generation

    def __eq__(self, other):
        if not isinstance(other, Function):
            # Makes ``f == None`` and similar checks evaluate to False.
            return NotImplemented
        return self.generation == other.generation


# Five stand-in functions with distinct generations, in arbitrary order.
generations = [2, 0, 1, 4, 3]
funcs = []


for g in generations:
    f = Function()
    f.generation = g
    funcs.append(f)

# priority_set wraps every function in ReverseCompare and pushes it on the heap.
ps = priority_set(funcs)

# Raw heap layout, front to back, with the identity of each wrapped function.
for d in ps.maxheap:
    print(d.obj.generation, id(d.obj))

# pop() returns the unwrapped function from the front of the heap.
d = ps.pop()
print('max:', d.generation, id(d))

update set: {140667499480976}
update set: {140667499480976, 140667021649280}
update set: {140667499480976, 140667021648128, 140667021649280}
update set: {140667499480976, 140667021648128, 140667021649280, 140667021650192}
update set: {140667021648128, 140667021649280, 140667021650336, 140667499480976, 140667021650192}
4 140667021650192
3 140667021650336
1 140667021648128
0 140667021649280
2 140667499480976
max: 4 140667021650192


In [91]:
class Variable:
    """Array container that remembers its creator and its generation.

    Attributes:
        data: The wrapped ``np.ndarray`` (or ``None``).
        grad: Gradient accumulated by ``backward``; ``None`` until computed.
        creator: The function object that produced this variable, or ``None``
            for a variable that was created directly.
        generation: 0 for a directly created variable, otherwise the creator's
            generation plus one.
    """

    def __init__(self, data):
        """Wrap ``data``.

        Raises:
            TypeError: If ``data`` is neither ``None`` nor an ``np.ndarray``.
        """
        if data is not None and not isinstance(data, np.ndarray):
            raise TypeError(f"{type(data)} is not supported")
        self.data = data
        self.grad = None
        self.creator = None
        self.generation = 0

    def set_creator(self, function) -> None:
        """Record ``function`` as the producer and derive this generation from it."""
        self.creator = function
        self.generation = function.generation + 1

    def backward(self) -> None:
        """Propagate gradients backwards, largest generation first.

        Pending functions live in a ``PrioritySet``; each is wrapped in
        ``ReversedPriority`` so that the heap yields the function with the
        largest generation on ``pop()``.
        """
        if self.grad is None:
            # Seed the output gradient with ones of the same shape as the data.
            self.grad = np.ones_like(self.data)

        if self.creator is None:
            # A variable without a creator has nothing upstream to visit.
            return

        functions_list = PrioritySet()([self.creator])

        while functions_list:
            print('functions_list:', functions_list)
            function = functions_list.pop()
            print('Use function:', function)

            gys = [output.grad for output in function.outputs]
            gxs = function.backward(*gys)
            if not isinstance(gxs, tuple):
                gxs = (gxs,)

            for x, gx in zip(function.inputs, gxs):
                if x.grad is None:
                    x.grad = gx
                else:
                    # Build a new object instead of using +=: gx can be the
                    # very same object as another variable's grad (Add.backward
                    # returns gy twice), and an in-place add would alter it too.
                    x.grad = x.grad + gx

                if x.creator is not None:
                    functions_list.add(ReversedPriority(x.creator))
                    # Printed for every encounter, even when the function is
                    # already queued and add() leaves the set unchanged.
                    print("Collect function:", x.creator)
            print('functions_list:', functions_list)
            print()

    def cleargrad(self) -> None:
        """Forget the accumulated gradient."""
        self.grad = None


class Function:
    """Base class for operations that can be applied to Variables.

    Subclasses provide ``forward`` (works on the raw ``.data`` of the inputs)
    and ``backward`` (maps output gradients to input gradients). Instances are
    ordered and compared by ``generation``, which is assigned in ``__call__``.

    Because ``__eq__`` is defined without ``__hash__``, instances are not
    hashable; ``PrioritySet`` tracks them by ``id(obj)`` instead.
    """

    def __call__(self, *inputs):
        """Apply the operation to ``inputs`` and return the result.

        The ``.data`` of every input is handed to ``forward``; each result is
        passed through ``as_array`` and wrapped in a new ``Variable``. The
        function's generation becomes the largest generation among its inputs,
        and every output records this function as its creator.

        Returns:
            The single output Variable, or the list of outputs when ``forward``
            produced more than one value.
        """
        xs = [x.data for x in inputs]
        ys = self.forward(*xs)
        if not isinstance(ys, tuple):
            ys = (ys,)
        outputs = [Variable(as_array(y)) for y in ys]

        # Must be set before set_creator, which reads it to number the outputs.
        self.generation = max(x.generation for x in inputs)
        for output in outputs:
            output.set_creator(self)

        self.inputs = inputs
        self.outputs = outputs
        return outputs if len(outputs) > 1 else outputs[0]

    def __lt__(self, other):
        if not isinstance(other, Function):
            # Let Python try the reflected operation / raise TypeError rather
            # than failing on a missing ``generation`` attribute.
            return NotImplemented
        return self.generation < other.generation

    def __eq__(self, other):
        if not isinstance(other, Function):
            # Makes ``f == None`` and similar checks evaluate to False.
            return NotImplemented
        return self.generation == other.generation

    def forward(self, xs):
        """Compute the outputs from raw input data; subclasses must override."""
        raise NotImplementedError()

    def backward(self, gys):
        """Compute input gradients from output gradients; subclasses must override."""
        raise NotImplementedError()


class Add(Function):
    """Addition node: y = x0 + x1."""

    def forward(self, x0: Number, x1: Number) -> Number:
        """Return the sum of the two inputs."""
        return x0 + x1

    def backward(self, gy: Number) -> tuple:
        """Hand the upstream gradient, unchanged, to both inputs."""
        return (gy, gy)


class Square(Function):
    """Squaring node: y = x ** 2."""

    def forward(self, x: Number) -> Number:
        """Return ``x`` squared."""
        return x ** 2

    def backward(self, gy: Number) -> Number:
        """Return ``2 * x * gy``, with ``x`` read from the stored input."""
        return 2 * self.inputs[0].data * gy
    

class Cube(Function):
    """Cubing node: y = x ** 3."""

    def forward(self, x: Number) -> Number:
        """Return ``x`` cubed."""
        return x ** 3

    def backward(self, gy: Number) -> Number:
        """Return ``3 * x**2 * gy``, with ``x`` read from the stored input."""
        source = self.inputs[0].data
        return 3 * (source ** 2) * gy


class Exp(Function):
    """Exponential node: y = exp(x)."""

    def forward(self, x: Number) -> Number:
        """Return ``np.exp(x)``."""
        return np.exp(x)

    def backward(self, gy: Number) -> Number:
        """Return ``exp(x) * gy``, with ``x`` read from the stored input."""
        return np.exp(self.inputs[0].data) * gy


def add(x0: Number, x1: Number) -> Number:
    """Apply a fresh ``Add`` node to ``x0`` and ``x1`` and return its output."""
    node = Add()
    return node(x0, x1)


def square(x: Number) -> Number:
    """Apply a fresh ``Square`` node to ``x`` and return its output."""
    node = Square()
    return node(x)


def cube(x: Number) -> Number:
    """Apply a fresh ``Cube`` node to ``x`` and return its output."""
    node = Cube()
    return node(x)


def exp(x: Number) -> Number:
    """Apply a fresh ``Exp`` node to ``x`` and return its output."""
    node = Exp()
    return node(x)


def as_array(x) -> np.ndarray:
    """Wrap ``x`` in ``np.array`` when ``np.isscalar(x)`` is true.

    Anything else is returned as the same object, untouched.
    """
    return np.array(x) if np.isscalar(x) else x


@functools.total_ordering
class ReversedPriority(object):
    """Wrapper that inverts the ordering of the wrapped object.

    ``ReversedPriority(a) < ReversedPriority(b)`` is true exactly when
    ``a > b``, which lets a min-heap such as ``heapq`` hand back the largest
    wrapped object first.

    Attributes:
        obj: The wrapped object; it must support ``>`` against its peers.
    """

    def __init__(self, obj):
        self.obj = obj

    def __lt__(self, other):
        if not isinstance(other, ReversedPriority):
            # Returning False here would make the comparisons derived by
            # total_ordering (e.g. ``wrapper > 5``) evaluate to True;
            # NotImplemented lets Python raise TypeError instead.
            return NotImplemented
        return self.obj > other.obj
    

class PrioritySet(object):
    """Heap-backed queue that holds each wrapped object at most once.

    Entries are wrapper objects exposing the real item as ``.obj``; membership
    is tracked by ``id(entry.obj)``. ``pop`` removes the entry that ``heapq``
    considers smallest, so with ``ReversedPriority`` wrappers it returns the
    object with the largest priority.

    Attributes:
        maxheap: The ``heapq``-managed list of wrappers.
        heapset: ``id`` of every object currently in the heap.
    """

    Item = namedtuple("Item", ["priority", "item"])

    def __init__(self):
        # Start empty so add/pop/len work even if the instance is never called.
        self.maxheap = []
        self.heapset = set()

    def __call__(self, queue=None):
        """Reset the set and refill it from ``queue``.

        Args:
            queue: Iterable of objects to wrap in ``ReversedPriority`` and add,
                or ``None`` for an empty set.

        Returns:
            This instance.

        Raises:
            TypeError: If ``queue`` is neither ``None`` nor iterable.
        """
        self.maxheap = []
        self.heapset = set()
        if queue is None:
            return self

        if not isinstance(queue, Iterable):
            raise TypeError(f"{type(queue)} is not iterable")

        for obj in queue:
            self.add(ReversedPriority(obj))
        return self

    def add(self, x: "ReversedPriority") -> None:
        """Push wrapper ``x`` unless its ``.obj`` is already present."""
        if id(x.obj) not in self.heapset:
            heapq.heappush(self.maxheap, x)
            self.heapset.add(id(x.obj))

    def pop(self):
        """Remove the top entry and return its unwrapped object.

        Raises:
            IndexError: If the set is empty (raised by ``heapq.heappop``).
        """
        x = heapq.heappop(self.maxheap)
        self.heapset.remove(id(x.obj))
        return x.obj

    def __len__(self):
        # Also makes an empty set falsy, so ``while functions_list:`` terminates.
        return len(self.maxheap)

    def __repr__(self):
        pq = {i: self.Item(priority=x.obj.generation, item=type(x.obj).__name__) for i, x in enumerate(self.maxheap)}
        return str(pq)


In [92]:
# `a` is consumed by two separate Square nodes whose outputs are added,
# i.e. y = (x**2)**2 + (x**2)**2.
x = Variable(np.array(2.0))
a = square(x)
y = add(square(a), square(a))
# backward() now schedules functions through the heap-backed PrioritySet.
y.backward()

print(y.data)
print(x.grad)

functions_list: {0: Item(priority=2, item='Add')}
Use function: <__main__.Add object at 0x7fef97b257c0>
Collect function: <__main__.Square object at 0x7fefb433fc70>
Collect function: <__main__.Square object at 0x7fef97b25280>
functions_list: {0: Item(priority=1, item='Square'), 1: Item(priority=1, item='Square')}

functions_list: {0: Item(priority=1, item='Square'), 1: Item(priority=1, item='Square')}
Use function: <__main__.Square object at 0x7fefb433fc70>
Collect function: <__main__.Square object at 0x7fefb433f8b0>
functions_list: {0: Item(priority=1, item='Square'), 1: Item(priority=0, item='Square')}

functions_list: {0: Item(priority=1, item='Square'), 1: Item(priority=0, item='Square')}
Use function: <__main__.Square object at 0x7fef97b25280>
Collect function: <__main__.Square object at 0x7fefb433f8b0>
functions_list: {0: Item(priority=0, item='Square')}

functions_list: {0: Item(priority=0, item='Square')}
Use function: <__main__.Square object at 0x7fefb433f8b0>
functions_list: 