## 16 복잡한 계산 그래프(구현 편)


### 16.1 세대 추가

In [1]:
class Variable:
    """A node of the computation graph that holds an array and its gradient.

    NOTE: this version's ``backward`` walks the graph with a plain LIFO
    stack. A function reachable through several paths is pushed once per
    path, so it may run before all gradients of its outputs have been
    accumulated, and it may run more than once.
    """

    def __init__(self, data):
        """Wrap ``data`` in a graph node.

        Args:
            data: ``np.ndarray`` or ``None``.

        Raises:
            TypeError: If ``data`` is neither ``None`` nor an ``np.ndarray``.
        """
        if data is not None and not isinstance(data, np.ndarray):
            raise TypeError('{} is not supported'.format(type(data)))
        self.data = data
        # Gradient with respect to this variable; filled in by backward().
        self.grad = None
        # Function that produced this variable (None for a leaf).
        self.creator = None
        # Generation number; 0 until a creator is attached.
        self.generation = 0

    def set_creator(self, func):
        """Record ``func`` as the producer and place this variable one
        generation after it."""
        self.creator = func
        self.generation = func.generation + 1

    def backward(self):
        """Propagate gradients from this variable back through its creators.

        If ``self.grad`` is unset it is seeded with ones shaped like
        ``self.data``. Gradients arriving at the same input are summed.
        """
        if self.grad is None:
            self.grad = np.ones_like(self.data)
        # A leaf variable has no creator, so there is nothing to propagate.
        # Without this guard None would be pushed and `f.outputs` would fail.
        if self.creator is None:
            return
        funcs = [self.creator]
        while funcs:
            f = funcs.pop()
            # A function may have several outputs; collect all their grads.
            gys = [output.grad for output in f.outputs]
            gxs = f.backward(*gys)
            if not isinstance(gxs, tuple):
                # Normalise a single gradient to a 1-tuple.
                gxs = (gxs,)
            for x, gx in zip(f.inputs, gxs):
                if x.grad is None:
                    x.grad = gx
                else:
                    # Build a new array rather than adding in place, because
                    # gx may be the same object as another variable's grad.
                    x.grad = x.grad + gx
                if x.creator is not None:
                    funcs.append(x.creator)

    def cleargrad(self):
        """Reset the stored gradient so the variable can be reused."""
        self.grad = None

In [2]:
class Function:
    """Base class for operations that connect variables in the graph.

    Subclasses override ``forward`` (raw input arrays -> output array or
    tuple of arrays) and ``backward`` (output gradients -> input gradient or
    tuple of gradients).
    """

    def __call__(self, *inputs):
        """Apply the function to ``inputs`` and link the results to it.

        Args:
            *inputs: Objects exposing ``.data`` and ``.generation``
                (``Variable`` instances).

        Returns:
            A single ``Variable`` when ``forward`` yields one result,
            otherwise a list of ``Variable`` objects.
        """
        xs = [x.data for x in inputs]
        # Unpack so forward() receives each array as its own argument.
        ys = self.forward(*xs)
        # Allow forward() to return a bare value instead of a tuple.
        if not isinstance(ys, tuple):
            ys = (ys,)
        outputs = [Variable(as_array(y)) for y in ys]
        # The function's generation is that of its newest input; `default`
        # keeps a call with no inputs from raising ValueError.
        self.generation = max((x.generation for x in inputs), default=0)
        for output in outputs:
            output.set_creator(self)
        self.inputs = inputs
        self.outputs = outputs
        return outputs if len(outputs) > 1 else outputs[0]

    def forward(self, *xs):
        """Compute outputs from raw arrays; must be overridden.

        Variadic because ``__call__`` invokes it as ``self.forward(*xs)``.
        """
        raise NotImplementedError()

    def backward(self, *gys):
        """Compute input gradients from output gradients; must be overridden.

        Variadic because callers invoke it with the unpacked gradients of
        every output.
        """
        raise NotImplementedError()

### 16.2 세대 순으로 꺼내기

In [5]:
if __name__ == '__main__':
    # Dummy functions whose generations are deliberately out of order.
    generations = [2, 0, 1, 4, 2]
    funcs = []
    for g in generations:
        f = Function()
        f.generation = g
        funcs.append(f)

    # Show the generations in insertion order.
    print([f.generation for f in funcs])

[2, 0, 1, 4, 2]


In [7]:
if __name__ == '__main__':
    # Order ascending by generation so the largest generation sits last.
    funcs.sort(key=lambda func: func.generation)
    print([f.generation for f in funcs])

    # list.pop() removes the last element, i.e. the highest generation.
    f = funcs.pop()
    # Bare expression: displayed as the cell result in a notebook and has no
    # effect when run as a script.
    f.generation


[0, 1, 2, 2, 4]


### 16.3 Variable 클래스의 backward

In [25]:
class Variable:
    """A node of the computation graph that holds an array and its gradient.

    ``backward`` visits creator functions in descending generation order and
    visits each function at most once.
    """

    def __init__(self, data):
        """Wrap ``data`` in a graph node.

        Args:
            data: ``np.ndarray`` or ``None``.

        Raises:
            TypeError: If ``data`` is neither ``None`` nor an ``np.ndarray``.
        """
        if data is not None and not isinstance(data, np.ndarray):
            raise TypeError('{} is not supported'.format(type(data)))
        self.data = data
        # Gradient with respect to this variable; filled in by backward().
        self.grad = None
        # Function that produced this variable (None for a leaf).
        self.creator = None
        # Generation number; 0 until a creator is attached.
        self.generation = 0

    def set_creator(self, func):
        """Record ``func`` as the producer and place this variable one
        generation after it."""
        self.creator = func
        self.generation = func.generation + 1

    def backward(self):
        """Propagate gradients from this variable back through its creators.

        If ``self.grad`` is unset it is seeded with ones shaped like
        ``self.data``. Gradients arriving at the same input are summed.
        """
        if self.grad is None:
            self.grad = np.ones_like(self.data)
        # A leaf variable has no creator, so there is nothing to propagate.
        # Without this guard None would be queued and `f.outputs` would fail.
        if self.creator is None:
            return

        funcs = []
        seen_set = set()

        def add_func(f):
            # Queue each function once and keep the list sorted ascending by
            # generation, so pop() always yields the newest function.
            if f not in seen_set:
                funcs.append(f)
                seen_set.add(f)
                funcs.sort(key=lambda x: x.generation)

        add_func(self.creator)
        while funcs:
            f = funcs.pop()
            # A function may have several outputs; collect all their grads.
            gys = [output.grad for output in f.outputs]
            gxs = f.backward(*gys)
            if not isinstance(gxs, tuple):
                # Normalise a single gradient to a 1-tuple.
                gxs = (gxs,)
            for x, gx in zip(f.inputs, gxs):
                if x.grad is None:
                    x.grad = gx
                else:
                    # Build a new array rather than adding in place, because
                    # gx may be the same object as another variable's grad.
                    x.grad = x.grad + gx
                if x.creator is not None:
                    add_func(x.creator)

    def cleargrad(self):
        """Reset the stored gradient so the variable can be reused."""
        self.grad = None

In [26]:
def as_array(x):
    """Return ``x`` as an ``np.ndarray`` if it is a scalar, else unchanged.

    Values for which ``np.isscalar`` is true are wrapped with ``np.array``;
    anything else is returned as the same object.
    """
    return np.array(x) if np.isscalar(x) else x


In [27]:
class Function:
    """Base class for operations that connect variables in the graph.

    Subclasses override ``forward`` (raw input arrays -> output array or
    tuple of arrays) and ``backward`` (output gradients -> input gradient or
    tuple of gradients).
    """

    def __call__(self, *inputs):
        """Apply the function to ``inputs`` and link the results to it.

        Args:
            *inputs: Objects exposing ``.data`` and ``.generation``
                (``Variable`` instances).

        Returns:
            A single ``Variable`` when ``forward`` yields one result,
            otherwise a list of ``Variable`` objects.
        """
        xs = [x.data for x in inputs]
        # Unpack so forward() receives each array as its own argument.
        ys = self.forward(*xs)
        # Allow forward() to return a bare value instead of a tuple.
        if not isinstance(ys, tuple):
            ys = (ys,)
        outputs = [Variable(as_array(y)) for y in ys]
        # The function's generation is that of its newest input; `default`
        # keeps a call with no inputs from raising ValueError.
        self.generation = max((x.generation for x in inputs), default=0)
        for output in outputs:
            output.set_creator(self)
        self.inputs = inputs
        self.outputs = outputs
        return outputs if len(outputs) > 1 else outputs[0]

    def forward(self, *xs):
        """Compute outputs from raw arrays; must be overridden.

        Variadic because ``__call__`` invokes it as ``self.forward(*xs)``.
        """
        raise NotImplementedError()

    def backward(self, *gys):
        """Compute input gradients from output gradients; must be overridden.

        Variadic because callers invoke it with the unpacked gradients of
        every output.
        """
        raise NotImplementedError()

In [28]:
class Add(Function):
    """Addition of two inputs."""

    def forward(self, x0, x1):
        """Return the sum ``x0 + x1``."""
        return x0 + x1

    def backward(self, gy):
        """Return the gradients for ``x0`` and ``x1``.

        The derivative of a sum with respect to each operand is 1, so the
        upstream gradient ``gy`` is handed to both inputs unchanged.
        """
        return gy, gy

In [29]:
def add(x0, x1):
    """Apply a fresh ``Add`` function to ``x0`` and ``x1`` and return the result."""
    op = Add()
    return op(x0, x1)

In [30]:
class Square(Function):
    """Squaring of a single input."""

    def forward(self, x):
        """Return ``x ** 2``."""
        return x ** 2

    def backward(self, gy):
        """Return the gradient for the input given upstream gradient ``gy``.

        d(x^2)/dx = 2x, evaluated at the data of the first recorded input.
        """
        x = self.inputs[0].data
        return 2 * x * gy

In [31]:
def square(x):
    """Apply a fresh ``Square`` function to ``x`` and return the result."""
    op = Square()
    return op(x)

In [32]:
import numpy as np

### 16.4 동작 확인

In [35]:
if __name__ == '__main__':
    # Graph with a shared intermediate: a = x^2, y = a^2 + a^2.
    x = Variable(np.array(2.0))
    a = square(x)
    y = add(square(a), square(a))

    # Backpropagate from y; `a` is reached through both square branches.
    y.backward()

    # y = 2 * x^4 and dy/dx = 8 * x^3, evaluated at x = 2.
    print(y.data)
    print(x.grad)

32.0
64.0
