In [4]:
%%writefile utils.py
import numpy as np

class Variable:
    """A node in the computation graph: a value plus its gradient.

    Attributes:
        data: the wrapped value; must be a ``np.ndarray`` or ``None``.
        grad: gradient of the final output w.r.t. ``data``; ``None`` until
            a backward pass assigns it.
        creator: the function object that produced this variable, or
            ``None`` for a leaf (user-created) variable.
    """

    def __init__(self, data):
        # Only ndarrays (or None) are accepted so operations can rely on
        # NumPy semantics for ``data``.
        if data is not None:
            if not isinstance(data, np.ndarray):
                raise TypeError('{} is not supported'.format(type(data)))
        self.data = data
        self.grad = None
        self.creator = None

    def set_creator(self, func):
        '''
        Record ``func`` as the function that produced this variable
        (stores the graph link used by backpropagation).
        '''
        self.creator = func

    def backward_recur(self):
        '''
        Recursive backpropagation from this variable towards the leaves.

        If ``self.grad`` is unset it is seeded with ones (dy/dy = 1), the
        same as ``backward``. The creator's ``backward`` is applied to this
        variable's gradient, the result is stored on the creator's input,
        and the recursion continues from that input.
        '''
        if self.grad is None:  # seed the output gradient, as backward() does
            self.grad = np.ones_like(self.data)
        f = self.creator  # the function that produced this variable
        if f is not None:
            x = f.input  # that function's input variable
            x.grad = f.backward(self.grad)  # chain rule step
            x.backward_recur()  # continue towards the leaves

    def backward(self):
        '''
        Iterative backpropagation from this variable towards the leaves.

        Seeds ``self.grad`` with ones if it is unset, then walks the chain
        of creators with an explicit work list, assigning each input's
        ``grad`` from its creator's ``backward(output.grad)``.
        '''
        if self.grad is None:  # gradient init
            self.grad = np.ones_like(self.data)
        # A leaf variable has no creator: start with an empty work list
        # rather than [None], which would crash on ``None.input``.
        funcs = [] if self.creator is None else [self.creator]
        while funcs:
            f = funcs.pop()
            x, y = f.input, f.output
            x.grad = f.backward(y.grad)
            if x.creator is not None:
                funcs.append(x.creator)

                
def as_array(x):
    """Wrap a scalar in a 0-d ndarray; return anything else unchanged.

    NumPy operations can return NumPy scalars (e.g. ``np.float64``) instead
    of arrays, and ``Variable`` only accepts ndarrays, so such results are
    converted here.
    """
    return np.array(x) if np.isscalar(x) else x

class Function:
    """Base class for differentiable operations.

    Subclasses implement ``forward`` (on raw ndarray data) and ``backward``
    (gradient propagation). Calling an instance on a ``Variable`` runs the
    forward pass and records the graph link for backpropagation.
    """

    def __call__(self, input):
        """Apply the operation to ``input`` and return the result Variable.

        The result's creator is set to this function, and ``self.input`` /
        ``self.output`` are kept so ``backward`` can read them later.
        """
        result = self.forward(input.data)

        out = Variable(as_array(result))
        out.set_creator(self)

        self.input, self.output = input, out
        return out

    def forward(self, x):
        '''
        Compute the operation on raw data ``x``; subclasses must override.
        '''
        raise NotImplementedError()  # signals "not implemented in base class"

    def backward(self, gy):
        '''
        Return the gradient w.r.t. the input, given ``gy``: the upstream
        gradient to be multiplied in via the chain rule. Subclasses must
        override.
        '''
        raise NotImplementedError()
        
class Square(Function):
    """Elementwise square: y = x ** 2."""

    def forward(self, x):
        return x ** 2

    def backward(self, gy):
        """dy/dx = 2x, multiplied by the upstream gradient ``gy``."""
        x_val = self.input.data
        return 2 * x_val * gy
def square(x):
    """Return a Variable holding ``x ** 2`` (a fresh Square node per call)."""
    op = Square()
    return op(x)


class Exp(Function):
    """Elementwise exponential: y = exp(x)."""

    def forward(self, x):
        return np.exp(x)

    def backward(self, gy):
        """d/dx exp(x) = exp(x), multiplied by the upstream gradient ``gy``."""
        local_grad = np.exp(self.input.data)
        return local_grad * gy
def exp(x):
    """Return a Variable holding ``exp(x)`` (a fresh Exp node per call)."""
    op = Exp()
    return op(x)

    
def numerical_diff(f, x, eps=1e-4):
    '''
    Central-difference numerical derivative of ``f`` at Variable ``x``:

        f'(x) ~= (f(x + eps) - f(x - eps)) / (2 * eps)

    which approaches the true derivative as eps -> 0. ``f`` must accept a
    Variable and return a Variable. The difference is taken elementwise on
    ``x.data``, so for array inputs it is only meaningful for elementwise
    functions.
    '''
    below = f(Variable(x.data - eps))
    above = f(Variable(x.data + eps))
    return (above.data - below.data) / (2 * eps)

# def f(x):
#     '''
#     composite function diff
#     '''
#     A = Square()
#     B = Exp()
#     C = Square()
#     return C(B(A(x)))

Overwriting utils.py


### Unit Test

In [7]:
%%writefile unittest_square.py
from utils import *
import unittest

class SquareTest(unittest.TestCase):
    """Tests for ``square``: forward value, analytic gradient, and agreement
    with the central-difference numerical gradient."""

    def test_forward(self):
        x = Variable(np.array(2.0))
        y = square(x)
        expected = np.array(4.0)
        self.assertEqual(y.data, expected)

    def test_backward(self):
        x = Variable(np.array(3.0))
        y = square(x)
        y.backward()
        expected = np.array(6.0)
        self.assertEqual(x.grad, expected)

    def test_gradient_check(self):
        """Automated gradient check: backprop gradient vs. numerical gradient."""
        # Fixed seed so any failure is reproducible.
        rng = np.random.RandomState(0)
        x = Variable(rng.rand(1))
        y = square(x)
        y.backward()
        num_grad = numerical_diff(square, x)
        # np.allclose is True when |a - b| <= atol (1e-8) + rtol (1e-5) * |b|
        flg = np.allclose(x.grad, num_grad)
        self.assertTrue(flg)

Overwriting unittest_square.py


In [2]:
# !python -m unittest unittest_square.py

Traceback (most recent call last):
  File "C:\python_anaconda3\lib\runpy.py", line 194, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "C:\python_anaconda3\lib\runpy.py", line 87, in _run_code
    exec(code, run_globals)
  File "C:\Users\hyunsoo\study\g\Study_\Mystudy\DL_from_Scratch\dl_3\unittest.py", line 2, in <module>
    import unittest
  File "C:\Users\hyunsoo\study\g\Study_\Mystudy\DL_from_Scratch\dl_3\unittest.py", line 4, in <module>
    class SquareTest(unittest.TestCase):
AttributeError: partially initialized module 'unittest' has no attribute 'TestCase' (most likely due to a circular import)
