In [1]:
import numpy as np


def as_array(x):
    """Promote a scalar to a 0-d ``np.ndarray``; pass every other value through.

    Anything ``np.isscalar`` does not treat as a scalar (arrays, lists,
    ``None``) is returned as the very same object.
    """
    return np.array(x) if np.isscalar(x) else x


class Variable:
    """A computation-graph node wrapping an ndarray and its gradient.

    Attributes:
        data: the wrapped ``np.ndarray`` (``None`` is also accepted).
        grad: gradient of the final output w.r.t. ``data``; ``None`` until a
            backward pass reaches this variable.
        creator: the function object that produced this variable, or ``None``
            for user-created (leaf) variables.
    """

    def __init__(self, data):
        # Only ndarray input (or None) is accepted; scalars must be wrapped
        # first, e.g. with as_array.
        if data is not None:
            if not isinstance(data, np.ndarray):
                raise TypeError('{} is not supported'.format(type(data)))

        self.data = data
        self.grad = None
        self.creator = None

    def set_creator(self, func):
        """Record ``func`` as the function that produced this variable."""
        self.creator = func

    def cleargrad(self):
        """Reset ``grad`` so the variable can take part in a fresh backward pass."""
        self.grad = None

    def backward(self):
        """Back-propagate gradients from this variable to all of its ancestors.

        If ``self.grad`` is unset it is seeded with ones, so the final output
        never needs a manual ``grad = 1`` assignment. Functions are processed
        consumers-before-producers (reverse topological order), each exactly
        once, so every function sees its complete output gradients. Gradients
        reaching a variable along several paths are summed, not overwritten.

        Gradients accumulate into existing ``grad`` values; call ``cleargrad``
        on reused variables before running another backward pass.
        """
        if self.grad is None:
            self.grad = np.ones_like(self.data)

        for f in self._backward_order():
            gys = [output.grad for output in f.outputs]  # upstream gradients
            gxs = f.backward(*gys)  # gradients w.r.t. f's inputs
            # Single-input functions may return a bare value; normalise to a
            # tuple so it can be paired with f.inputs.
            if not isinstance(gxs, tuple):
                gxs = (gxs,)
            for x, gx in zip(f.inputs, gxs):
                # Sum contributions when x is used more than once. Use
                # `x.grad + gx` rather than `+=` so an array that may be shared
                # with another variable's grad is never mutated in place.
                x.grad = gx if x.grad is None else x.grad + gx

    def _backward_order(self):
        """Return the functions upstream of this variable, consumers first.

        Iterative depth-first post-order over ``creator``/``inputs`` links
        (avoids Python's recursion limit on deep graphs), then reversed.
        Each function appears exactly once. A leaf variable yields ``[]``.
        """
        order = []
        if self.creator is None:
            return order
        visited = set()
        stack = [(self.creator, False)]
        while stack:
            f, expanded = stack.pop()
            if expanded:
                # Every producer feeding f has already been appended.
                order.append(f)
                continue
            if id(f) in visited:
                continue
            visited.add(id(f))
            stack.append((f, True))
            for x in f.inputs:
                if x.creator is not None and id(x.creator) not in visited:
                    stack.append((x.creator, False))
        order.reverse()
        return order

class Function:
    """Base class for differentiable operations on ``Variable`` objects.

    Calling an instance unwraps the input variables' ``data``, runs
    ``forward`` on those arrays, wraps each result in a new ``Variable`` whose
    creator is this function, and stores ``inputs``/``outputs`` for use during
    backpropagation. Subclasses implement ``forward(*xs)`` and
    ``backward(*gys)``.
    """

    def __call__(self, *inputs):
        xs = [x.data for x in inputs]
        ys = self.forward(*xs)
        # A single-output forward() may return a bare value instead of a tuple.
        if not isinstance(ys, tuple):
            ys = (ys,)
        # forward() may return NumPy scalars (e.g. arithmetic on 0-d arrays),
        # which Variable's ndarray check would reject; as_array wraps them.
        outputs = [Variable(as_array(y)) for y in ys]
        for output in outputs:
            output.set_creator(self)
        self.inputs = inputs
        self.outputs = outputs
        # Unwrap the single-output case so f(x) returns a Variable directly.
        return outputs if len(outputs) > 1 else outputs[0]

    def forward(self, *xs):
        """Compute the output(s) from raw input arrays ``xs``.

        Returns a single array or a tuple of arrays. Must be overridden.
        """
        raise NotImplementedError

    def backward(self, *gys):
        """Map output gradients ``gys`` to input gradient(s).

        Returns a single gradient or a tuple with one entry per input.
        Must be overridden.
        """
        raise NotImplementedError

In [2]:
class Add(Function):
    """Element-wise addition ``y = x0 + x1``."""

    def forward(self, x0, x1):
        y = x0 + x1
        return y

    def backward(self, gy):
        # dy/dx0 = dy/dx1 = 1, so both inputs receive the upstream gradient.
        return gy, gy


def add(x0, x1):
    """Return a Variable holding ``x0 + x1`` and record the Add node for backprop."""
    # Pass inputs in the caller's order so Add.inputs lines up with (x0, x1).
    return Add()(x0, x1)

In [3]:
class Square(Function):
    """Element-wise square ``y = x ** 2``; its local derivative is ``2 * x``."""

    def forward(self, x):
        return x ** 2

    def backward(self, gy):
        # Function.__call__ stored the input Variable; read its raw array.
        x_data = self.inputs[0].data
        return 2 * x_data * gy


def square(x):
    """Return a Variable holding ``x ** 2`` and record the Square node for backprop."""
    op = Square()
    return op(x)

In [4]:
# Forward: z = x0**2 + x1**2 at (x0, x1) = (2, 3).
# Backward: dz/dx0 = 2*x0 and dz/dx1 = 2*x1.
x0, x1 = Variable(np.array(2.0)), Variable(np.array(3.0))
z = add(square(x0), square(x1))
z.backward()
print(z.data, x0.grad, x1.grad, sep='\n')

13.0
4.0
1
6.0
