Skip to content

Latest commit

 

History

History
244 lines (163 loc) · 10.5 KB

220208.md

File metadata and controls

244 lines (163 loc) · 10.5 KB

Day 116

Deep Learning from Scratch

5. 오차역전파법(backpropagation)

수치 미분은 단순하고 구현하기도 쉽지만 계산 시간이 오래 걸린다는 단점이 있다. 이보다 효율적인 것이 가중치 매개변수의 기울기를 효율적으로 계산하는 오차역전파법이다.

계산 그래프(computational graph)

계산 그래프는 계산 과정을 그래프로 나타낸 것이다. 이 그래프는 그래프 자료구조로, 복수의 노드(node)와 에지(edge)로 표현된다. 에지는 노드 사이의 직선을 말한다.

계산 그래프로 문제 풀이

문제 1 : A는 슈퍼에서 1개에 100원인 사과를 2개 샀다. 이때 지불 금액을 구하여라. 단, 소비세 10%가 부과된다.

문제 1을 계산 그래프로 풀면 아래의 그림과 같이 된다.

fig 5-1

위에서는 x2와 x1.1을 각각 하나의 연산으로 취급해 원 안에 표기한 형태이다.

하지만 곱셈인 x만을 연산으로 생각하면, 사과의 개수와 소비세를 변수로 취급하여 원 밖에 표기한다.

fig 5-2

문제 2 : A는 슈퍼에서 사과를 2개, 귤을 3개 샀다. 사과는 1개에 100원, 귤을 1개에 150원이다. 소비세가 10%일 때 지불 금액은?

fig 5-3

사과와 귤의 금액을 합산하고 소비세까지 계산해주면 된다.

계산 그래프는 왼쪽에서 오른쪽으로 계산을 진행한다.

계산 그래프를 이용한 문제 풀이

  1. 계산 그래프 구성
  2. 그래프에서 계산을 왼쪽에서 오른쪽으로 진행(순전파)

역전파는 오른쪽에서 왼쪽으로 전파되는 것을 말한다.

국소적 계산

계산 그래프는 국소적 계산을 전파하는 것으로 최종 결과를 얻는다는 것이 특징이다.

국소적 계산은 전체에서 어떤 일이 벌어지든 상관없이 자신과 관계된 정보만으로 결과를 출력할 수 있다.

fig 5-4

위의 그림에서 핵심은 각 노드에서의 계산은 국소적 계산이라는 것이다. 각 노드는 자신과 관련한 계산 외에는 아무것도 신경 쓸게 없다.

계산 그래프는 국소적 계산에 집중한다. 전체 계산이 복잡하더라도 각 단계에서 하는 일은 해당 노드의 국소적 계산이다.

계산 그래프로 푸는 이유

계산 그래프의 이점

  1. 국소적 계산으로 전체가 복잡해도 각 노드에서는 단순한 계산에 집중하여 문제를 단순화할 수 있다.
  2. 중간 계산 결과를 모두 보관할 수 있다.
  3. 역전파를 통해 미분을 효율적으로 계산할 수 있다.

문제 1에서 사과 가격이 오르면 최종 금액에 어떤 영향이 끼치는가는 '사과 가격에 대한 지불 금액의 미분'을 구하는 문제가 된다. 이는 계산 그래프에서 역전파를 하면 구할 수 있다.

fig 5-5

역전파(굵은 선)는 국소적 미분을 전달한다. 여기서 역전파는 오른쪽에서 왼쪽으로 1, 1.1, 2.2 순으로 미분 값을 전달한다. 이 결과로 사과 가격에 대한 지불 금액의 미분 값은 2.2이며 사과가 1원 오르면 최종 금액은 2.2원 오른다는 것을 알 수 있다.


연쇄 법칙

국소적 미분을 전달하는 원리는 연쇄법칙(chain rule)에 따른 것이다.

계산 그래프의 역전파

fig 5-6

y = f(x)라는 계산의 역전파를 표현한 것이다.

역전파의 계산 절차는 신호 E에 노드의 국소적 미분을 곱한 후 다음 노드로 전달하는 것이다. 여기서 국소적 미분은 순전파 때의 y = f(x) 계산의 미분을 x에 대한 y의 미분을 의미한다.

연쇄법칙이란

함성 함수는 여러 함수로 구성된 함수다. 예를 들면 z = (x + y)**2라는 식은 아래처럼 두 개의 식으로 구성된다.

e 5 1

연쇄 법칙은 합성 함수의 미분에 대한 성질로, 합성 함수의 미분은 합성 함수를 구성하는 각 함수의 미분의 곱으로 나타낼 수 있다는 법칙이다.

위의 식을 예로 들어 미분을 하면 아래와 같다.

e 5 3

이는 해석적으로 구한 결과로 최종적으로 두 값을 곱해 계산하면 답을 확인할 수 있다.

e 5 4

연쇄법칙과 계산 그래프

위의 연쇄법칙 계산을 계산 그래프로 나타내면 아래와 같이 된다.

fig 5-7

역전파의 계산 절차에서는 노드로 들어온 입력 신호에 그 노드의 국소적 미분(편미분)을 곱한 후 다음 노드로 전달한다.

예를 들어 **2 노드에서의 역전파는 입력이 dz/dz이며, 이에 국소적 미분인 dz/dt(순전파 시에는 입력이 t이고 출력이 z이므로 이 노드에서 국소적 미분은 dz/dt이다)를 곱하고 다음 노드로 넘긴다.


역전파

덧셈 노드의 역전파

식 z = x + y를 예로 들어 역전파를 확인하겠다.

z = x + y를 해석적으로 계산하면 다음과 같다.

e 5 5

이의 계산 그래프는 다음과 같다.

fig 5-9

덧셈 노드의 역전파는 1을 곱하기만 하기 때문에 입력된 값을 그대로 다음 노드로 보낸다.

구체적인 예

fig 5-11

곱셈 노드의 역전파

식은 z = xy이고 이 식의 미분은 다음과 같다.

e 5 6

계산 그래프는 다음과 같다.

fig 5-12

곱셈 노드의 역전파는 상류의 값에 순전파 때의 입력 신호들을 '서로 바꾼 값'을 곱해서 하류로 보낸다.

따라서 곱셈의 역전파는 순방향 입력 신호의 값이 필요해서 곱셈 노드를 구현할 때는 순전파의 입력 신호를 변수에 저장해둔다.

구체적인 예

fig 5-13


단순한 계층 구현

곱셈 계층

모든 계층은 순전파를 처리하는 forward(), 역전파를 처리하는 backward()라는 공통 메소드를 갖도록 구현한다.

곱셈 계층을 구현하면 아래와 같다.

class MulLayer:
    def __init__(self):
        self.x = None
        self.y = None
    
    def forward(self, x, y):
        self.x = x
        self.y = y
        out = x * y
        
        return out
    
    def backward(self, dout):
        dx = dout * self.y # x와 y의 자리를 바꾼다.
        dy = dout * self.x
        
        return dx, dy

__init__()에서는 인스턴스 변수인 x와 y를 초기화 한다. 이 두 변수는 순전파 시의 입력 값을 유지하기 위해 사용한다.

forward()에서는 x와 y를 인수로 받아 두 값을 곱해서 반환한다.

backward()에서는 상류에서 넘어온 미분(dout)에 순전파 때의 값을 서로 바꿔 곱한 후에 하류로 흘린다.

fig 5-16

이 그림을 구현하면 다음과 같다.

apple = 100
apple_num = 2
tax = 1.1

# 계층
mul_apple_layer = MulLayer()
mul_tax_layer = MulLayer()

# 순전파
apple_price = mul_apple_layer.forward(apple, apple_num)
price = mul_tax_layer.forward(apple_price, tax)

print(price)

220.00000000000003

# 역전파
dprice = 1
dapple_price, dtax = mul_tax_layer.backward(dprice)
dapple, dapple_num = mul_apple_layer.backward(dapple_price)

print(dapple, dapple_num, dtax)

2.2 110.00000000000001 200

덧셈 계층

class AddLayer():
    def __init__(self):
        pass
    
    def forward(self, x, y):
        out = x + y
        return out
    
    def backward(self, dout):
        dx = dout * 1
        dy = dout * 1
        return dx, dy

fig 5-17

위의 그림을 구현하면 다음과 같다.

apple = 100
apple_num = 2
orange = 150
orange_num = 3
tax = 1.1

# 계층
mul_apple_layer = MulLayer()
mul_orange_layer = MulLayer()
add_apple_orange_layer = AddLayer()
mul_tax_layer = MulLayer()

# 순전파
apple_price = mul_apple_layer.forward(apple, apple_num)
orange_price = mul_orange_layer.forward(orange, orange_num)
all_price = add_apple_orange_layer.forward(apple_price, orange_price)
price = mul_tax_layer.forward(all_price, tax)

# 역전파
dprice = 1
dall_price, dtax = mul_tax_layer.backward(dprice)
dapple_price, dorange_price = add_apple_orange_layer.backward(dall_price)
dorange, dorange_num = mul_orange_layer.backward(dorange_price)
dapple, dapple_num = mul_apple_layer.backward(dapple_price)

print(price)
print(dapple_num, dapple, dorange, dorange_num, dtax)

715.0000000000001
110.00000000000001 2.2 3.3000000000000003 165.0 650