<a href="https://colab.research.google.com/github/PSuHyeon/Simple_TensorFlow/blob/main/Var_Fun.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [49]:
import numpy as np
import weakref
import contextlib


In [50]:
class Config:
  """Process-wide switches consulted while Functions are being called."""

  # Function.__call__ records creator/input/output links only while this is True.
  enable_backprop = True

@contextlib.contextmanager
def using_config(name, value):
  """Temporarily set ``Config.<name>`` to ``value`` for the body of a ``with`` block.

  The value that was in place on entry is written back on exit, including
  when the body raises.
  """
  previous = getattr(Config, name)
  setattr(Config, name, value)
  try:
    yield
  finally:
    # Always put the saved value back, whatever happened inside the block.
    setattr(Config, name, previous)


def no_grad():
  """Return a context manager that sets ``Config.enable_backprop`` to False inside its block."""
  return using_config(name='enable_backprop', value=False)

def as_array(x):
  """Wrap a scalar (per ``np.isscalar``) in an ``np.ndarray``; return anything else untouched."""
  return np.array(x) if np.isscalar(x) else x

def as_variable(obj):
  """Return ``obj`` itself if it is already a Variable, otherwise wrap it in a new Variable."""
  return obj if isinstance(obj, Variable) else Variable(obj)

Class Variable:

In [51]:
class Variable:
  """Array container that remembers which Function produced it.

  Attributes:
    data: the wrapped ``np.ndarray`` (or ``None``).
    name: optional label given at construction.
    grad: gradient filled in by ``backward``; ``None`` until then.
    creator: the Function that produced this variable, or ``None``.
    generation: 0 for a variable without a creator, otherwise the
      creator's generation plus one (see ``set_creator``).
  """

  # NumPy looks for ``__array_priority__`` (two trailing underscores). With a
  # higher priority than ndarray, an ndarray on the left-hand side of a binary
  # operator defers to this class's reflected operator instead of handling the
  # operation itself.
  __array_priority__ = 200

  def __init__(self, data, name=None):
    """Store ``data``; raise TypeError unless it is ``None`` or an ``np.ndarray``."""
    if data is not None and not isinstance(data, np.ndarray):
      raise TypeError('{}은(는) 지원하지 않습니다.'.format(type(data)))

    self.data = data
    self.name = name
    self.grad = None
    self.creator = None
    self.generation = 0

  @property
  def shape(self):
    """Shape of the wrapped array."""
    return self.data.shape

  @property
  def ndim(self):
    """Number of dimensions of the wrapped array."""
    return self.data.ndim

  @property
  def size(self):
    """Number of elements in the wrapped array."""
    return self.data.size

  @property
  def dtype(self):
    """dtype of the wrapped array."""
    return self.data.dtype

  @property
  def T(self):
    """Transposed variable, obtained through the module-level ``transpose``."""
    return transpose(self, None)

  def set_creator(self, f):
    """Record ``f`` as the producing Function and place this variable one generation after it."""
    self.creator = f
    self.generation = f.generation + 1

  def cleargrad(self):
    """Forget any stored gradient."""
    self.grad = None

  def __len__(self):
    return len(self.data)

  def __repr__(self):
    if self.data is None:
      return 'variable(None)'
    # Indent continuation lines so they line up under the opening 'variable('.
    p = str(self.data).replace('\n', '\n' + ' ' * 9)
    return 'variable(' + p + ')'

  def backward(self, retain_grad=False, create_graph=False):
    """Propagate gradients from this variable back through its creators.

    Args:
      retain_grad: when False, the gradients held by each processed
        function's outputs are reset to ``None`` after use.
      create_graph: value given to ``Config.enable_backprop`` while the
        backward computations run.
    """
    if self.grad is None:
      self.grad = Variable(np.ones_like(self.data))

    # A variable that no Function produced has nothing to propagate to;
    # without this guard the loop below would dereference a None creator.
    if self.creator is None:
      return

    funcs = []
    seen_set = set()

    def add_func(f):
      # Keep the pending list ordered by generation so that pop() always
      # yields the function with the highest generation.
      if f not in seen_set:
        funcs.append(f)
        seen_set.add(f)
        funcs.sort(key=lambda fn: fn.generation)

    add_func(self.creator)
    while funcs:
      f = funcs.pop()
      # Outputs are held as weak references; call them to get the variables.
      gys = [output().grad for output in f.outputs]
      with using_config('enable_backprop', create_graph):
        gxs = f.backward(*gys)
        if not isinstance(gxs, tuple):
          gxs = (gxs,)
        for x, gx in zip(f.inputs, gxs):
          # Accumulate when the same variable feeds several functions.
          if x.grad is not None:
            x.grad = x.grad + gx
          else:
            x.grad = gx

          if x.creator is not None:
            add_func(x.creator)
        if not retain_grad:
          for y in f.outputs:
            y().grad = None

  def reshape(self, *sh):
    """Reshape via the module-level ``reshape``; accepts ``v.reshape(2, 3)`` or ``v.reshape((2, 3))``."""
    if len(sh) == 1 and isinstance(sh[0], (tuple, list)):
      sh = sh[0]
    return reshape(self, sh)

  def sum(self, axis=None, keepdims=False):
    """Sum via the module-level ``sum`` function defined in this file."""
    return sum(self, axis, keepdims)

Class Function (base class; each specific operation subclasses it and overrides `forward` / `backward`):

  Define-by-run: as the computation proceeds forward, each result Variable keeps track of its creator (the Function that produced it).

In [52]:
class Function:
  """Base class for operations applied to Variables.

  Subclasses override ``forward`` (working on the raw ``data`` of the inputs)
  and ``backward`` (working on output gradients). Calling an instance runs
  ``forward`` and, while ``Config.enable_backprop`` is true, records the
  inputs and outputs on the instance.
  """

  def __call__(self, *inputs):
    variables = [as_variable(arg) for arg in inputs]
    results = self.forward(*(v.data for v in variables))
    if not isinstance(results, tuple):
      results = (results,)
    outputs = [Variable(as_array(r)) for r in results]

    if Config.enable_backprop:
      self.generation = max(v.generation for v in variables)
      for out in outputs:
        out.set_creator(self)
      self.inputs = variables
      # Outputs are stored as weak references; each output already refers
      # back to this function through its ``creator`` attribute.
      self.outputs = [weakref.ref(out) for out in outputs]

    if len(outputs) > 1:
      return outputs
    return outputs[0]

  def forward(self, x):
    """Compute the result from raw input data; must be overridden."""
    raise NotImplementedError()

  def backward(self, x):
    """Compute input gradients from output gradients; must be overridden."""
    raise NotImplementedError()

Specific higher functions:

In [53]:
class Square(Function):
  """Element-wise square, ``y = x ** 2``."""

  def forward(self, input):
    y = input ** 2
    return y

  def backward(self, gy):
    """Return ``2 * x * gy``, the gradient with respect to the input."""
    # Use the input Variable itself rather than its raw ``.data`` so the
    # result is built from Variable operations, as Mul/Div/Pow do. When
    # backward runs with backprop enabled (create_graph=True), the gradient's
    # dependence on x is then recorded as well.
    x = self.inputs[0]
    gx = 2 * x * gy
    return gx

def square(x):
  """Apply a fresh ``Square`` function to ``x`` and return its result."""
  op = Square()
  return op(x)

In [54]:
class Exp(Function):
  """Element-wise exponential, ``y = exp(x)``."""

  def forward(self, input):
    y = np.exp(input)
    # The original had a bare ``return`` here, which handed None back to
    # Function.__call__ and produced an output Variable holding no data.
    return y

  def backward(self, gy):
    """Return ``exp(x) * gy``, the gradient with respect to the input."""
    # Built from Variable operations (like Mul/Div/Pow) instead of raw
    # ``.data`` so the result stays a Variable expression of x.
    x = self.inputs[0]
    gx = exp(x) * gy
    return gx

def exp(x):
  """Apply a fresh ``Exp`` function to ``x`` and return its result."""
  op = Exp()
  return op(x)

In [55]:
class Add(Function):
  """Element-wise addition of two arrays."""

  def forward(self, x0, x1):
    # Remember both operand shapes; backward needs them to decide whether
    # the incoming gradient has to be reduced with sum_to.
    self.x0_shape, self.x1_shape = x0.shape, x1.shape
    return x0 + x1

  def backward(self, gy):
    """Pass ``gy`` to both inputs, reduced to each input's shape when the shapes differ."""
    if self.x0_shape == self.x1_shape:
      return gy, gy
    grad_lhs = sum_to(gy, self.x0_shape)
    grad_rhs = sum_to(gy, self.x1_shape)
    return grad_lhs, grad_rhs

def add(x1, x2):
  """Return ``x1 + x2`` computed through an ``Add`` function.

  Both operands go through ``as_array`` first, so a plain Python/NumPy scalar
  on either side is accepted.
  """
  x1 = as_array(x1)
  # When used as Variable.__add__/__radd__, x1 is the Variable and x2 is the
  # other operand. The original converted only x1, so ``Variable + 3.0``
  # reached Variable() with a bare float and raised TypeError.
  x2 = as_array(x2)
  return Add()(x1, x2)

# Addition is commutative, so the same function serves both operand orders.
Variable.__add__ = add
Variable.__radd__ = add

In [56]:
class Mul(Function):
  """Element-wise multiplication of two arrays."""

  def forward(self, x0, x1):
    self.x0_shape, self.x1_shape = x0.shape, x1.shape
    return x0 * x1

  def backward(self, gy):
    """Return ``(x1 * gy, x0 * gy)``, reduced to each input's shape when the shapes differ."""
    lhs = self.inputs[0]
    rhs = self.inputs[1]
    grad_lhs = rhs * gy
    grad_rhs = lhs * gy
    if lhs.shape == rhs.shape:
      return grad_lhs, grad_rhs
    grad_lhs = sum_to(grad_lhs, lhs.shape)
    grad_rhs = sum_to(grad_rhs, rhs.shape)
    return grad_lhs, grad_rhs

def mul(x0, x1):
  """Return ``x0 * x1`` through a ``Mul`` function; ``x1`` is passed through ``as_array`` first."""
  rhs = as_array(x1)
  op = Mul()
  return op(x0, rhs)

# The same function is installed for both operand orders.
Variable.__mul__ = Variable.__rmul__ = mul

In [57]:
class Neg(Function):
  """Element-wise negation, ``y = -x``."""

  def forward(self, x):
    negated = -x
    return negated

  def backward(self, gy):
    """The gradient of ``-x`` is ``-gy``."""
    grad_in = -gy
    return grad_in
  
def neg(x):
  """Apply a fresh ``Neg`` function to ``x`` and return its result."""
  op = Neg()
  return op(x)

# Unary minus on a Variable goes through neg().
Variable.__neg__ = neg

In [58]:
class Sub(Function):
  """Element-wise subtraction, ``y = x0 - x1``."""

  def forward(self, x0, x1):
    self.x0_shape, self.x1_shape = x0.shape, x1.shape
    return x0 - x1

  def backward(self, gy):
    """Return ``(gy, -gy)``, reduced to each input's shape when the shapes differ."""
    grad_lhs = gy
    grad_rhs = -gy
    if self.x0_shape == self.x1_shape:
      return grad_lhs, grad_rhs
    grad_lhs = sum_to(grad_lhs, self.x0_shape)
    grad_rhs = sum_to(grad_rhs, self.x1_shape)
    return grad_lhs, grad_rhs

  
def sub(x0, x1):
  """Return ``x0 - x1`` through a ``Sub`` function; ``x1`` is passed through ``as_array`` first."""
  rhs = as_array(x1)
  op = Sub()
  return op(x0, rhs)

def rsub(x0, x1):
  """Reflected subtraction: return ``x1 - x0`` (operands swapped before calling ``Sub``)."""
  other = as_array(x1)
  op = Sub()
  return op(other, x0)

# Install the binary minus operator and its reflected form on Variable.
Variable.__sub__ = sub
Variable.__rsub__ = rsub


In [59]:
class Div(Function):
  """Element-wise true division, ``y = x0 / x1``."""

  def forward(self, x0, x1):
    self.x0_shape, self.x1_shape = x0.shape, x1.shape
    return x0 / x1

  def backward(self, gy):
    """Return ``(gy / x1, -gy * x0 / x1**2)``, reduced to each input's shape when the shapes differ."""
    numerator = self.inputs[0]
    denominator = self.inputs[1]
    grad_num = gy / denominator
    grad_den = (-gy * numerator) / (denominator ** 2)
    if self.x0_shape == self.x1_shape:
      return grad_num, grad_den
    grad_num = sum_to(grad_num, self.x0_shape)
    grad_den = sum_to(grad_den, self.x1_shape)
    return grad_num, grad_den

def div(x0, x1):
  """Return ``x0 / x1`` through a ``Div`` function; ``x1`` is passed through ``as_array`` first."""
  rhs = as_array(x1)
  op = Div()
  return op(x0, rhs)

def rdiv(x0, x1):
  """Reflected division: return ``x1 / x0`` (operands swapped before calling ``Div``)."""
  other = as_array(x1)
  op = Div()
  return op(other, x0)

# Install the true-division operator and its reflected form on Variable.
Variable.__truediv__ = div
Variable.__rtruediv__ = rdiv

In [60]:
class Pow(Function):
  """Raise the input to a fixed exponent ``c``, ``y = x ** c``."""

  def __init__(self, c):
    self.c = c

  def forward(self, x):
    return x ** self.c

  def backward(self, gy):
    """Return ``c * x**(c - 1) * gy``."""
    base = self.inputs[0]
    exponent = self.c
    local_grad = exponent * (base ** (exponent - 1))
    return local_grad * gy

def pow(x,c):
  """Return ``x ** c`` through a ``Pow`` function; ``c`` is handed to ``Pow`` as given."""
  op = Pow(c)
  return op(x)

# ``variable ** c`` goes through pow().
Variable.__pow__ = pow


In [61]:
class Reshape(Function):
  """Reshape the input array to a fixed target shape."""

  def __init__(self, shape):
    self.shape = shape

  def forward(self, x):
    # Keep the incoming shape so backward can restore it on the gradient.
    self.x_shape = x.shape
    return x.reshape(self.shape)

  def backward(self, gy):
    """Reshape the gradient back to the shape the input had in forward."""
    original_shape = self.x_shape
    return reshape(gy, original_shape)

def reshape(x, shape):
  """Reshape ``x`` to ``shape``; when the shape already matches, just return ``x`` as a Variable."""
  already_matches = x.shape == shape
  return as_variable(x) if already_matches else Reshape(shape)(x)

In [62]:
class Transpose(Function):
    """Permute the axes of the input array (``axes=None`` uses ndarray.transpose's default)."""

    def __init__(self, axes=None):
        self.axes = axes

    def forward(self, x):
        return x.transpose(self.axes)

    def backward(self, gy):
        """Apply the inverse permutation to the gradient."""
        axes = self.axes
        if axes is None:
            return transpose(gy)

        # Map negative axes to their non-negative form, then argsort to get
        # the permutation that undoes the forward one.
        count = len(axes)
        normalized = [ax % count for ax in axes]
        inverse = tuple(np.argsort(normalized))
        return transpose(gy, inverse)


def transpose(x, axes=None):
    """Transpose ``x`` through a ``Transpose`` function.

    Args:
        x: value to transpose.
        axes: ``None`` or an empty sequence for the default transpose, a
            sequence of axis indices, or a one-element sequence wrapping
            such a sequence (or wrapping ``None``), which is unwrapped.
    """
    if axes is None or len(axes) == 0:
        axes = None
    elif len(axes) == 1:
        if isinstance(axes[0], (tuple, list)) or axes[0] is None:
            # The original assigned to a misspelled name (``xes``), so the
            # unwrapped value was discarded and the nested sequence was
            # passed on to Transpose unchanged.
            axes = axes[0]
    return Transpose(axes)(x)


In [63]:
class Sum(Function):
    """Sum the input array over ``axis``, optionally keeping reduced dimensions."""

    def __init__(self, axis, keepdims):
        self.axis, self.keepdims = axis, keepdims

    def forward(self, x):
        # The input shape is needed in backward to expand the gradient again.
        self.shape = x.shape
        return x.sum(axis=self.axis, keepdims=self.keepdims)

    def backward(self, gy):
        """Reshape the gradient via ``reshape_sum_backward``, then broadcast it to the input shape."""
        reshaped = reshape_sum_backward(gy, self.shape, self.axis, self.keepdims)
        return broadcast_to(reshaped, self.shape)

def sum(x, axis = None, keepdims = False):
  """Sum ``x`` over ``axis`` through a ``Sum`` function (this name shadows the builtin ``sum``)."""
  op = Sum(axis, keepdims)
  return op(x)

In [64]:
class Broadcast_to(Function):
  """Broadcast the input array to a fixed target shape with ``np.broadcast_to``."""

  def __init__(self, shape):
    self.shape = shape

  def forward(self, x):
    # Keep the incoming shape so backward can reduce the gradient to it.
    self.x_shape = x.shape
    return np.broadcast_to(x, self.shape)

  def backward(self, gy):
    """Reduce the gradient back to the input's shape with ``sum_to``."""
    return sum_to(gy, self.x_shape)

def broadcast_to(x, shape):
  """Broadcast ``x`` to ``shape``; when the shape already matches, just return ``x`` as a Variable."""
  already_matches = x.shape == shape
  return as_variable(x) if already_matches else Broadcast_to(shape)(x)

In [65]:
class Sum_to(Function):
  """Reduce the input array to a fixed target shape by summing (see ``u_sum_to``)."""

  def __init__(self, shape):
    self.shape = shape

  def forward(self, x):
    # Keep the incoming shape; backward must expand the gradient to it.
    self.x_shape = x.shape
    y = u_sum_to(x, self.shape)
    return y

  def backward(self, gy):
    """Broadcast the gradient back to the shape the input had in forward."""
    # The gradient must match the input's shape (``x_shape``). The original
    # broadcast to ``self.shape``, the reduced shape that forward produced,
    # so the gradient came back with the output's shape instead.
    return broadcast_to(gy, self.x_shape)

def sum_to(x, shape):
  """Sum ``x`` down to ``shape``; when the shape already matches, just return ``x`` as a Variable."""
  already_matches = x.shape == shape
  return as_variable(x) if already_matches else Sum_to(shape)(x)

Numerical differentiation:

In [66]:
def numerical_diff(f, x, e=1e-4):
  """Estimate the derivative of ``f`` at ``x`` with a central difference.

  Args:
    f: callable taking a Variable and returning an object with a ``data`` attribute.
    x: Variable at which to differentiate.
    e: half-width of the difference step.

  Returns:
    ``(f(x + e).data - f(x - e).data) / (2 * e)``.
  """
  # Arithmetic on a 0-d ndarray yields a NumPy scalar, which Variable()
  # rejects with TypeError; as_array turns it back into an ndarray, the same
  # way Function.__call__ treats forward results.
  x0 = Variable(as_array(x.data + e))
  x1 = Variable(as_array(x.data - e))
  y0 = f(x0)
  y1 = f(x1)
  return (y0.data - y1.data) / (2 * e)

Auxiliary function for sum:

In [67]:
def reshape_sum_backward(gy, x_shape, axis, keepdims):
    """Reshape ``gy`` so the axes removed by a sum reappear as length-1 axes.

    ``gy`` is returned with its own shape when ``x_shape`` is empty, ``axis``
    is None, or ``keepdims`` is true; otherwise a 1 is inserted at every
    (non-negative-normalized) summed axis position.
    """
    ndim = len(x_shape)
    if axis is None or isinstance(axis, tuple):
        axes = axis
    else:
        axes = (axis,)

    if ndim == 0 or axes is None or keepdims:
        target = gy.shape
    else:
        target = list(gy.shape)
        # Insert in ascending order so earlier insertions do not shift the
        # positions of later ones.
        for position in sorted(a if a >= 0 else a + ndim for a in axes):
            target.insert(position, 1)

    return gy.reshape(target)

In [68]:
def u_sum_to(x, shape):
    """Sum array ``x`` so that the result has shape ``shape``.

    Extra leading axes of ``x`` are summed and removed; axes where ``shape``
    has length 1 are summed with the dimension kept.
    """
    lead = x.ndim - len(shape)
    lead_axis = tuple(range(lead))
    unit_axes = tuple(i + lead for i, dim in enumerate(shape) if dim == 1)

    reduced = x.sum(lead_axis + unit_axes, keepdims=True)
    # Drop the leading axes that exist only in x, not in the target shape.
    return reduced.squeeze(lead_axis) if lead > 0 else reduced