## 22 연산자 오버로드(3)

+ 실제로 오버로드할 수 있는 연산자는 매우 많다. 이번 단계에서는 다음 연산자들을 구현한다.
+ `__neg__(self)` : `-self`
+ `__sub__(self, other)` : `self - other`
+ `__rsub__(self, other)` : `other - self`
+ `__truediv__(self, other)` : `self / other`
+ `__rtruediv__(self, other)` : `other / self`
+ `__pow__(self, other)` : `self ** other`

### 이전 코드


In [67]:
import weakref
import numpy as np
import contextlib


class Config:
    """Process-wide framework settings, read as class attributes."""

    # Checked by Function.__call__: when False, no graph links are recorded.
    enable_backprop = True


@contextlib.contextmanager
def using_config(name, value):
    """Temporarily set the ``Config`` attribute *name* to *value*.

    Intended for use in a ``with`` statement. The previous value is
    restored on exit, even if the body raises.
    """
    saved = getattr(Config, name)
    setattr(Config, name, value)
    try:
        yield
    finally:
        # Always restore, so an exception inside the block cannot leak the setting
        setattr(Config, name, saved)


def no_grad():
    """Return a context manager that sets ``enable_backprop`` to False while active."""
    return using_config(name='enable_backprop', value=False)


In [68]:
def as_array(x):
    """Wrap a scalar (as judged by ``np.isscalar``) in a 0-d ndarray.

    Any other value, including existing ndarrays and None, is returned
    unchanged.
    """
    return np.array(x) if np.isscalar(x) else x

In [69]:
def as_variable(obj):
    """Return *obj* unchanged if it is already a Variable, else wrap it in one."""
    return obj if isinstance(obj, Variable) else Variable(obj)

In [70]:
class Variable:
    """A node of the computational graph holding data and its gradient.

    Attributes (all set in ``__init__``):
        data: the wrapped ``np.ndarray``, or None.
        name: optional label for the variable.
        grad: gradient filled in by ``backward`` (None until computed).
        creator: the Function that produced this variable (None for leaves).
        generation: graph depth, used to order backpropagation.
    """

    # Higher than ndarray's priority so that `ndarray <op> Variable` defers to
    # Variable's reflected operators (e.g. __radd__) instead of NumPy's.
    __array_priority__ = 200

    def __init__(self, data, name=None):
        """Create a variable wrapping *data*.

        Raises:
            TypeError: if *data* is neither None nor an ``np.ndarray``.
        """
        if data is not None and not isinstance(data, np.ndarray):
            raise TypeError('{} is not supported'.format(type(data)))
        self.data = data
        self.name = name
        self.grad = None
        self.creator = None
        # Generation (depth) counter
        self.generation = 0

    def set_creator(self, func):
        """Record *func* as the creator; this variable's generation is one past it."""
        self.creator = func
        self.generation = func.generation + 1

    def backward(self, retain_grad=False):
        """Backpropagate gradients from this variable through the graph.

        Args:
            retain_grad: when False, the gradients of each function's outputs
                are cleared once that function has been processed.
        """
        if self.grad is None:
            self.grad = np.ones_like(self.data)
        # A leaf variable has no creator: its gradient is set and there is
        # nothing further to propagate (previously this crashed on None.outputs).
        if self.creator is None:
            return

        funcs = []
        seen_set = set()

        def add_func(f):
            # Keep funcs sorted by generation so pop() returns the deepest one.
            if f not in seen_set:
                funcs.append(f)
                seen_set.add(f)
                funcs.sort(key=lambda x: x.generation)

        add_func(self.creator)
        while funcs:
            f = funcs.pop()
            # f.outputs holds weak references; call each to get its Variable.
            gys = [output().grad for output in f.outputs]
            gxs = f.backward(*gys)
            if not isinstance(gxs, tuple):
                gxs = (gxs,)

            for x, gx in zip(f.inputs, gxs):
                # Accumulate (without in-place ops) when an input is reused.
                if x.grad is None:
                    x.grad = gx
                else:
                    x.grad = x.grad + gx
                if x.creator is not None:
                    add_func(x.creator)

            # Drop intermediate gradients unless the caller asked to keep them.
            if not retain_grad:
                for y in f.outputs:
                    y().grad = None

    def cleargrad(self):
        """Reset the stored gradient to None."""
        self.grad = None

    # The properties below forward array metadata from ``self.data`` so they
    # can be read like attributes.
    @property
    def shape(self):
        return self.data.shape

    @property
    def ndim(self):
        return self.data.ndim

    @property
    def size(self):
        return self.data.size

    @property
    def dtype(self):
        return self.data.dtype

    def __len__(self):
        return len(self.data)

    def __repr__(self):
        if self.data is None:
            return 'variable(None)'
        # Indent continuation lines so multi-row arrays align under "variable("
        p = str(self.data).replace('\n', '\n' + ' ' * 9)
        return 'variable(' + p + ')'


In [71]:
import weakref
class Function:
    """Base class for differentiable operations.

    Subclasses implement ``forward`` (on raw ndarrays) and ``backward``
    (mapping output gradients to input gradients).
    """

    def __call__(self, *inputs):
        inputs = [as_variable(value) for value in inputs]
        raw_outputs = self.forward(*[var.data for var in inputs])
        if not isinstance(raw_outputs, tuple):
            raw_outputs = (raw_outputs,)
        outputs = [Variable(as_array(raw)) for raw in raw_outputs]

        if Config.enable_backprop:
            # A function's generation is the deepest generation among its inputs.
            self.generation = max(var.generation for var in inputs)
            for out in outputs:
                out.set_creator(self)
            self.inputs = inputs
            # Outputs already reference this function via `creator`; holding
            # them weakly here avoids a strong reference cycle.
            self.outputs = [weakref.ref(out) for out in outputs]

        if len(outputs) > 1:
            return outputs
        return outputs[0]

    def forward(self, xs):
        raise NotImplementedError()

    def backward(self, gys):
        raise NotImplementedError()

In [72]:
class Add(Function):
    """Elementwise addition: y = x0 + x1."""

    def forward(self, x0, x1):
        return x0 + x1

    def backward(self, gy):
        # dy/dx0 = dy/dx1 = 1, so the upstream gradient passes to both inputs.
        return (gy, gy)

In [73]:
def add(x0, x1):
    """Return the Variable ``x0 + x1``.

    Also bound to ``Variable.__radd__``, so *x1* may be a Python/NumPy scalar
    (e.g. ``x + 3.0`` or ``3.0 + x``).
    """
    # Variable only accepts ndarray data, so scalars must become ndarrays
    # before Function.__call__ wraps them (same as sub/div).
    x1 = as_array(x1)
    return Add()(x0, x1)

In [74]:
class Square(Function):
    """Elementwise square: y = x ** 2."""

    def forward(self, x):
        return x ** 2

    def backward(self, gy):
        # d(x^2)/dx = 2x
        x = self.inputs[0].data
        return 2 * x * gy

In [75]:
def square(x):
    """Return the Variable ``x ** 2``."""
    op = Square()
    return op(x)

In [76]:
class Mul(Function):
    """Elementwise multiplication: y = x0 * x1."""

    def forward(self, x0, x1):
        return x0 * x1

    def backward(self, gy):
        # d(x0*x1)/dx0 = x1 and d(x0*x1)/dx1 = x0
        left = self.inputs[0].data
        right = self.inputs[1].data
        return gy * right, gy * left

In [77]:
def mul(x0, x1):
    """Return the Variable ``x0 * x1``.

    Also bound to ``Variable.__rmul__``, so *x1* may be a Python/NumPy scalar
    (e.g. ``x * 3.0`` or ``3.0 * x``).
    """
    # Variable only accepts ndarray data, so scalars must become ndarrays
    # before Function.__call__ wraps them (same as sub/div).
    x1 = as_array(x1)
    return Mul()(x0, x1)

In [78]:
# Hook addition and multiplication into Python's operator protocol. Both are
# commutative, so the same function serves as the normal and reflected form.
Variable.__add__ = Variable.__radd__ = add
Variable.__mul__ = Variable.__rmul__ = mul


### 22.1 음수(부호 변환)

In [79]:
class Neg(Function):
    """Elementwise negation: y = -x."""

    def forward(self, x):
        return -x

    def backward(self, gy):
        # d(-x)/dx = -1
        return -gy

def neg(x):
    """Return the Variable ``-x`` (bound to ``Variable.__neg__``)."""
    op = Neg()
    return op(x)


Variable.__neg__ = neg

In [80]:
x = Variable(np.array(2.0))
y = -x  # unary minus dispatches to Variable.__neg__, i.e. neg -> Neg
print(y)

variable(-2.0)


### 22.2 뺄셈

+ `x - 1.0`처럼 Variable이 왼쪽에 있으면 `x.__sub__(1.0)`이 호출되므로 x0은 자연스럽게 Variable이다. 하지만 x1은 int, float, ndarray 등 다양한 형태일 수 있으므로, 스칼라가 들어와도 처리할 수 있도록 as_array로 ndarray로 변환해 준다.

In [81]:
class Sub(Function):
    """Elementwise subtraction: y = x0 - x1."""

    def forward(self, x0, x1):
        return x0 - x1

    def backward(self, gy):
        # dy/dx0 = 1 and dy/dx1 = -1
        return gy, -gy



In [82]:
def sub(x0, x1):
    """Return the Variable ``x0 - x1`` (bound to ``Variable.__sub__``).

    *x1* may be a scalar; it is converted with ``as_array`` first.
    """
    return Sub()(x0, as_array(x1))


Variable.__sub__ = sub

+ `__rsub__`는 오른쪽 피연산자가 Variable일 때 호출된다. 예를 들어 `2.0 - x`를 계산하면 `x.__rsub__(2.0)`이 호출되므로, x0에는 Variable인 `x`가, x1에는 `2.0`처럼 다양한 형태의 값이 들어온다. 이때도 x1을 as_array로 변환해야 하지만, 실제로 계산해야 하는 값은 x1 - x0이므로 Sub에 넘길 때 인자 순서를 반대로 넣어 줘야 한다.


In [85]:
def rsub(x0, x1):
    """Return the Variable ``x1 - x0`` (bound to ``Variable.__rsub__``).

    Python evaluates ``x1 - x0`` as ``x0.__rsub__(x1)`` when the left operand
    cannot handle the subtraction, so the operands are swapped here.
    """
    lhs = as_array(x1)
    return Sub()(lhs, x0)


Variable.__rsub__ = rsub

In [86]:
x = Variable(np.array(2.0))
y1 = 2.0 - x  # float on the left -> Python falls back to Variable.__rsub__ (rsub)
y2 = x - 1.0  # Variable on the left -> Variable.__sub__ (sub)
print(y1)
print(y2)

variable(0.0)
variable(1.0)


### 22.3 나눗셈

In [87]:
class Div(Function):
    """Elementwise division: y = x0 / x1."""

    def forward(self, x0, x1):
        return x0 / x1

    def backward(self, gy):
        # dy/dx0 = 1 / x1 and dy/dx1 = -x0 / x1**2
        num = self.inputs[0].data
        den = self.inputs[1].data
        return gy / den, gy * (-num / den ** 2)

def div(x0, x1):
    """Return the Variable ``x0 / x1`` (bound to ``Variable.__truediv__``)."""
    return Div()(x0, as_array(x1))


def rdiv(x0, x1):
    """Return the Variable ``x1 / x0`` (bound to ``Variable.__rtruediv__``).

    Python evaluates ``x1 / x0`` as ``x0.__rtruediv__(x1)`` when the left
    operand cannot handle the division, so the operands are swapped here.
    """
    return Div()(as_array(x1), x0)


Variable.__truediv__ = div
Variable.__rtruediv__ = rdiv

In [90]:
x = Variable(np.array(2.0))
y = x/3   # Variable on the left -> Variable.__truediv__ (div)
y2 = 3/x  # int on the left -> Python falls back to Variable.__rtruediv__ (rdiv)
print(y)
print(y2)

variable(0.6666666666666666)
variable(1.5)


### 22.4 거듭제곱

In [88]:
class Pow(Function):
    """Elementwise power with a fixed exponent: y = x ** c."""

    def __init__(self, c):
        # c is stored as an attribute rather than passed as a graph input,
        # so backward returns a gradient only for x.
        self.c = c

    def forward(self, x):
        return x ** self.c

    def backward(self, gy):
        # d(x^c)/dx = c * x^(c-1)
        base = self.inputs[0].data
        exponent = self.c
        return exponent * base ** (exponent - 1) * gy

def pow(x, c):
    """Return the Variable ``x ** c`` (bound to ``Variable.__pow__``).

    NOTE: this module-level name shadows the builtin ``pow``; it is kept
    as-is because callers may refer to it by this name.
    """
    op = Pow(c)
    return op(x)


Variable.__pow__ = pow


In [89]:
x = Variable(np.array(2.0))
y = x**3  # Variable.__pow__ -> pow(x, 3) -> Pow(3)
print(y)

variable(8.0)
