# 5.6 Affine/Softmax 계층 구현하기
## 5.6.1 Affine 계층(p. 171)

In [None]:
import numpy as np

class Affine:
    """Fully connected (affine) layer computing ``out = x @ W + b``.

    The forward input is cached so that ``backward`` can compute the
    gradients of the weights (``dW``) and of the bias (``db``).
    """

    def __init__(self, W, b):
        # W and b are supplied when the layer is constructed.
        self.W = W
        self.b = b
        self.x = None   # input cached by forward() for use in backward()
        self.dW = None  # gradient w.r.t. W, set by backward()
        self.db = None  # gradient w.r.t. b, set by backward()

    def forward(self, x):
        """Return ``x @ W + b`` and remember ``x`` for the backward pass.

        ``x`` may be a single sample (1-D) or a batch with one sample per row.
        """
        self.x = x
        out = np.dot(x, self.W) + self.b
        return out

    def backward(self, dout):
        """Propagate the upstream gradient ``dout`` (dL/d out) through the layer.

        Sets ``self.dW`` (same shape as ``W``) and ``self.db`` (one entry per
        output unit) and returns ``dx``, the gradient w.r.t. the forward input.
        """
        dx = np.dot(dout, self.W.T)
        # View a single 1-D sample as a batch of one. Otherwise np.dot on two
        # 1-D vectors is an inner product (a scalar or a shape error) instead
        # of the outer product x^T dout, and summing over axis 0 would
        # collapse db to a scalar. 2-D batches pass through unchanged.
        x2d = np.atleast_2d(self.x)
        dout2d = np.atleast_2d(dout)
        self.dW = np.dot(x2d.T, dout2d)
        self.db = np.sum(dout2d, axis=0)

        return dx

In [None]:
def softmax(x):
    """Return the softmax of ``x`` taken along its last axis.

    A 1-D input is treated as one score vector. For a 2-D batch each row is
    normalised independently, so every row of the result sums to 1.
    """
    # Subtract the per-row maximum before exponentiating to avoid overflow.
    # Softmax is shift-invariant, so this does not change the result.
    c = np.max(x, axis=-1, keepdims=True)
    x_ = np.exp(x - c)
    y = x_ / np.sum(x_, axis=-1, keepdims=True)
    return y

def cross_entropy_error(y, t):
    """Return the cross-entropy error averaged over the batch.

    Args:
        y: predicted probabilities, shape (C,) for one sample or (N, C) for
            a batch.
        t: targets. Either one-hot (or soft) vectors with the same number of
            elements as ``y``, or integer class labels (a scalar for one
            sample, shape (N,) for a batch).

    Returns:
        The cross-entropy summed over the batch and divided by the batch size.

    Raises:
        ValueError: if integer labels are given whose count differs from the
            batch size.
    """
    y = np.asarray(y)
    t = np.asarray(t)
    if y.ndim == 1:
        # A single sample: view it as a batch of one.
        t = t.reshape(1, t.size)
        y = y.reshape(1, y.size)

    batch_size = y.shape[0]
    # Small constant keeps log() finite when a predicted probability is 0.
    delta = 1e-7

    if t.size == y.size:
        # One-hot (or soft) targets: sum t * log(y) over every element.
        return -np.sum(t * np.log(y + delta)) / batch_size

    # Integer class labels: use the predicted probability of the true class.
    # Without this branch, (N,) labels would silently broadcast against (N, C).
    labels = t.reshape(-1).astype(int)
    if labels.size != batch_size:
        raise ValueError(
            f"expected {batch_size} labels, got {labels.size}"
        )
    return -np.sum(np.log(y[np.arange(batch_size), labels] + delta)) / batch_size

In [None]:
class SoftmaxWithLoss:
    """Output layer that applies softmax followed by cross-entropy loss."""

    def __init__(self):
        self.loss = None  # loss from the last forward() call
        self.y = None     # softmax output from the last forward() call
        self.t = None     # targets from the last forward() call

    def forward(self, x, t):
        """Apply softmax to the scores ``x`` and return the loss against ``t``."""
        self.t = t
        self.y = softmax(x)
        self.loss = cross_entropy_error(self.y, self.t)
        return self.loss

    def backward(self, dout=1):
        """Return dL/dx for the scores passed to the last forward() call.

        For softmax followed by cross-entropy the gradient reduces to
        ``(y - t) / batch_size``. It is scaled by the upstream gradient
        ``dout`` (1 when this is the final layer). ``t`` may be one-hot
        vectors or integer class labels. The result has the shape of ``y``.
        """
        # A 1-D sample is a batch of one, matching how the loss is computed.
        # Using t.shape[0] here would wrongly divide by the number of classes.
        y = np.atleast_2d(self.y)
        batch_size = y.shape[0]
        t = np.asarray(self.t)
        if t.size == y.size:
            # One-hot (or soft) targets.
            dx = y - t.reshape(y.shape)
        else:
            # Integer class labels: subtract 1 at each sample's true class.
            labels = t.reshape(-1).astype(int)
            dx = y.copy()
            dx[np.arange(batch_size), labels] -= 1
        dx = dx * dout / batch_size

        return dx.reshape(np.shape(self.y))

In [None]:
# Quick check: one forward and one backward pass through SoftmaxWithLoss.
import numpy as np

SlLayer = SoftmaxWithLoss()

# Input scores: a batch of three samples over three classes.
x_ = np.array([
    [0.1, 0.6, 0.3],
    [0.2, 0.2, 0.6],
    [0.05, 0.05, 0.9],
])
# Ground truth as one-hot rows (the correct class of each sample).
t_ = np.array([
    [0, 0, 1],
    [0, 1, 0],
    [0, 0, 1],
])

# Forward pass yields the cross-entropy error; backward yields dL/dx.
loss_1 = SlLayer.forward(x_, t_)
dx = SlLayer.backward()

for caption, value in (("Cross Entropy Error: ", loss_1), ("dx: ", dx)):
    print(caption, value)

<class 'numpy.ndarray'>
self.y : [[0.08433704 0.13904827 0.10300949]
 [0.09320684 0.09320684 0.13904827]
 [0.08022387 0.08022387 0.18769553]]
Cross Entropy Error:  2.1062666488247532
dx:  [[ 0.02811235  0.04634942 -0.29899684]
 [ 0.03106895 -0.30226439  0.04634942]
 [ 0.02674129  0.02674129 -0.27076816]]
