## 32 고차 미분(구현 편)

In [None]:
import weakref
import numpy as np
import contextlib

class Variable:
  """Wrap an ``np.ndarray`` together with the state used for backpropagation.

  Attributes (all assigned in ``__init__``):
    data: the wrapped ``np.ndarray``, or ``None``.
    name: optional label for the variable.
    grad: gradient of this variable; ``None`` until ``backward`` fills it in.
    creator: the function object that produced this variable, or ``None``.
    generation: ``creator.generation + 1``; 0 when there is no creator.
  """

  # NumPy consults this attribute to decide operator precedence when an
  # ndarray is combined with another object.
  __array_priority__ = 200

  def __init__(self, data, name=None):
    """Store ``data`` and reset the backpropagation state.

    Raises:
      TypeError: if ``data`` is neither ``None`` nor an ``np.ndarray``.
    """
    if data is not None and not isinstance(data, np.ndarray):
      raise TypeError('{} is not supported'.format(type(data)))
    self.data = data
    self.name = name
    self.grad = None
    self.creator = None
    self.generation = 0

  def set_creator(self, func):
    """Record ``func`` as the producer and place this variable one generation after it."""
    self.creator = func
    self.generation = func.generation + 1

  def backward(self, retain_grad=False):
    """Propagate gradients from this variable back through its creators.

    If ``self.grad`` is ``None`` it is seeded with ones shaped like
    ``self.data``. Functions are processed in descending ``generation``
    order; each function's ``backward`` receives the grads of its outputs
    and its results are stored (or accumulated) on its inputs.

    Args:
      retain_grad: when false, the ``grad`` of every processed function's
        outputs is reset to ``None`` once that function has been handled.
    """
    if self.grad is None:
      # The seed is a Variable (not a bare ndarray) so that the arithmetic
      # performed inside each function's backward operates on Variables.
      self.grad = Variable(np.ones_like(self.data))

    # A variable without a creator has nothing upstream to propagate to.
    if self.creator is None:
      return

    funcs = []
    seen_set = set()

    def add_func(f):
      # Each function is queued once; the list stays sorted by generation so
      # that pop() yields the function with the highest generation.
      if f not in seen_set:
        funcs.append(f)
        seen_set.add(f)
        funcs.sort(key=lambda x: x.generation)

    add_func(self.creator)
    while funcs:
      f = funcs.pop()
      # f.outputs holds callables (dereferenced with ``output()``).
      gys = [output().grad for output in f.outputs]
      gxs = f.backward(*gys)
      if not isinstance(gxs, tuple):
        gxs = (gxs,)
      for x, gx in zip(f.inputs, gxs):
        if x.grad is None:
          x.grad = gx
        else:
          # Build a new object instead of updating in place.
          x.grad = x.grad + gx
        if x.creator is not None:
          add_func(x.creator)
      if not retain_grad:
        for y in f.outputs:
          y().grad = None

  def cleargrad(self):
    """Reset the stored gradient to ``None``."""
    self.grad = None

  @property
  def shape(self):
    """Shape of the wrapped array."""
    return self.data.shape

  @property
  def ndim(self):
    """Number of dimensions of the wrapped array."""
    return self.data.ndim

  @property
  def size(self):
    """Number of elements in the wrapped array."""
    return self.data.size

  @property
  def dtype(self):
    """dtype of the wrapped array."""
    return self.data.dtype

  def __len__(self):
    return len(self.data)

  def __repr__(self):
    if self.data is None:
      return 'variable(None)'
    # Continuation lines are padded by 9 spaces, the width of 'variable('.
    p = str(self.data).replace('\n', '\n' + ' ' * 9)
    return 'variable(' + p + ')'



### 32.2 함수 클래스의 역전파

In [None]:
class Mul(Function):
  """Multiplication of two inputs, with its backward rule."""

  def forward(self, x0, x1):
    """Return ``x0 * x1``."""
    return x0 * x1

  def backward(self, gy):
    """Return the gradients for ``(x0, x1)``: ``(gy * x1, gy * x0)``.

    The operands are the objects stored in ``self.inputs`` themselves,
    not their ``.data`` arrays.
    """
    lhs, rhs = self.inputs
    grad_lhs = gy * rhs
    grad_rhs = gy * lhs
    return grad_lhs, grad_rhs


In [None]:
class Square(Function):
  """Squaring of a single input, with its backward rule."""

  def forward(self, x):
    """Return ``x ** 2``."""
    return x**2

  def backward(self, gy):
    """Return the gradient for the input: ``2 * x * gy``."""
    (operand,) = self.inputs
    return 2*operand*gy

In [None]:
class Div(Function):
  """Division ``x0 / x1``, with its backward rule."""

  def forward(self, x0, x1):
    """Return ``x0 / x1``."""
    return x0/x1

  def backward(self, gy):
    """Return the gradients for ``(x0, x1)``.

    Numerator:   ``gy / x1``
    Denominator: ``gy * (-x0 / x1 ** 2)``
    """
    numer, denom = self.inputs
    grad_numer = gy / denom
    grad_denom = gy * (-numer / denom ** 2)
    return grad_numer, grad_denom

In [None]:
class Pow(Function):
  """Raise a single input to the fixed exponent ``c``, with its backward rule."""

  def __init__(self, c):
    """Remember the exponent ``c`` used by ``forward`` and ``backward``."""
    self.c = c

  def forward(self, x):
    """Return ``x ** c``."""
    return x ** self.c

  def backward(self, gy):
    """Return the gradient for the input: ``c * x ** (c - 1) * gy``."""
    (base,) = self.inputs
    exponent = self.c
    return exponent * base ** (exponent-1) * gy

### 32.3 역전파를 더 효율적으로 (모드 추가)

In [None]:
import weakref
import numpy as np
import contextlib

class Variable:
  """Wrap an ``np.ndarray`` together with the state used for backpropagation.

  Attributes (all assigned in ``__init__``):
    data: the wrapped ``np.ndarray``, or ``None``.
    name: optional label for the variable.
    grad: gradient of this variable; ``None`` until ``backward`` fills it in.
    creator: the function object that produced this variable, or ``None``.
    generation: ``creator.generation + 1``; 0 when there is no creator.
  """

  # NumPy consults this attribute to decide operator precedence when an
  # ndarray is combined with another object.
  __array_priority__ = 200

  def __init__(self, data, name=None):
    """Store ``data`` and reset the backpropagation state.

    Raises:
      TypeError: if ``data`` is neither ``None`` nor an ``np.ndarray``.
    """
    if data is not None and not isinstance(data, np.ndarray):
      raise TypeError('{} is not supported'.format(type(data)))
    self.data = data
    self.name = name
    self.grad = None
    self.creator = None
    self.generation = 0

  def set_creator(self, func):
    """Record ``func`` as the producer and place this variable one generation after it."""
    self.creator = func
    self.generation = func.generation + 1

  def backward(self, retain_grad=False, create_graph=False):
    """Propagate gradients from this variable back through its creators.

    If ``self.grad`` is ``None`` it is seeded with a Variable of ones shaped
    like ``self.data``. Functions are processed in descending ``generation``
    order; each function's ``backward`` receives the grads of its outputs
    and its results are stored (or accumulated) on its inputs.

    Args:
      retain_grad: when false, the ``grad`` of every processed function's
        outputs is reset to ``None`` once that function has been handled.
      create_graph: value given to ``using_config('enable_backprop', ...)``
        while each function's backward pass and the gradient accumulation
        run. Presumably this decides whether the backward computation is
        itself recorded for further differentiation; confirm against
        ``using_config``, which is defined elsewhere.
    """
    if self.grad is None:
      self.grad = Variable(np.ones_like(self.data))

    # A variable without a creator has nothing upstream to propagate to.
    if self.creator is None:
      return

    funcs = []
    seen_set = set()

    def add_func(f):
      # Each function is queued once; the list stays sorted by generation so
      # that pop() yields the function with the highest generation.
      if f not in seen_set:
        funcs.append(f)
        seen_set.add(f)
        funcs.sort(key=lambda x: x.generation)

    add_func(self.creator)
    while funcs:
      f = funcs.pop()
      # f.outputs holds callables (dereferenced with ``output()``).
      gys = [output().grad for output in f.outputs]

      with using_config('enable_backprop', create_graph):
        gxs = f.backward(*gys)
        if not isinstance(gxs, tuple):
          gxs = (gxs,)
        for x, gx in zip(f.inputs, gxs):
          if x.grad is None:
            x.grad = gx
          else:
            # Build a new object instead of updating in place.
            x.grad = x.grad + gx
          if x.creator is not None:
            add_func(x.creator)

      # Must run once per processed function (inside the while loop);
      # otherwise only the last function's outputs would be cleared.
      if not retain_grad:
        for y in f.outputs:
          y().grad = None

  def cleargrad(self):
    """Reset the stored gradient to ``None``."""
    self.grad = None

  @property
  def shape(self):
    """Shape of the wrapped array."""
    return self.data.shape

  @property
  def ndim(self):
    """Number of dimensions of the wrapped array."""
    return self.data.ndim

  @property
  def size(self):
    """Number of elements in the wrapped array."""
    return self.data.size

  @property
  def dtype(self):
    """dtype of the wrapped array."""
    return self.data.dtype

  def __len__(self):
    return len(self.data)

  def __repr__(self):
    if self.data is None:
      return 'variable(None)'
    # Continuation lines are padded by 9 spaces, the width of 'variable('.
    p = str(self.data).replace('\n', '\n' + ' ' * 9)
    return 'variable(' + p + ')'



### 32.4 __init__.py 변경

In [None]:
# Selects which core implementation the package re-exports:
# True -> dezero.core_simple, False -> dezero.core.
is_simple_core = False # True

if is_simple_core :
  from dezero.core_simple import Variable
  from dezero.core_simple import Function
  from dezero.core_simple import using_config
  from dezero.core_simple import no_grad
  from dezero.core_simple import as_array
  from dezero.core_simple import as_variable
  from dezero.core_simple import setup_variable

else :
  # Same set of names as the branch above, taken from dezero.core.
  from dezero.core import Variable
  from dezero.core import Function
  from dezero.core import using_config
  from dezero.core import no_grad
  from dezero.core import as_array
  from dezero.core import as_variable
  from dezero.core import setup_variable

# Runs at import time, whichever core was selected.
setup_variable()