In [1]:
import numpy as np

def softmax(x):
    """Softmax computed along axis 0.

    The maximum along axis 0 is subtracted before exponentiating so that
    large inputs do not overflow in np.exp; the shift cancels out in the
    ratio, so the result is unchanged.

    Arguments:
    x -- array of scores; normalisation runs over its first axis

    Returns:
    Array of the same shape as x whose entries sum to 1 along axis 0.
    """
    shifted = x - np.max(x, axis=0, keepdims=True)
    exp_scores = np.exp(shifted)
    normalizer = np.sum(exp_scores, axis=0, keepdims=True)
    return exp_scores / normalizer

In [2]:
def rnn_cell_forward(xt, a_prev, parameters):
    """Run a single time step of a basic RNN cell.

    Arguments:
    xt -- input at the current time step
    a_prev -- hidden state coming from the previous time step
    parameters -- dict holding "Wax", "Waa", "Wya", "ba" and "by"

    Returns:
    a_next -- new hidden state, tanh(Wax @ xt + Waa @ a_prev + ba)
    yt_pred -- prediction for this step, softmax(Wya @ a_next + by)
    cache -- tuple (a_next, a_prev, xt, parameters) kept for the backward pass
    """
    # Pull out every weight up front, in a fixed order.
    Wax, Waa, Wya, ba, by = (
        parameters[key] for key in ("Wax", "Waa", "Wya", "ba", "by")
    )

    # Hidden-state update.
    input_term = Wax @ xt
    recurrent_term = Waa @ a_prev
    a_next = np.tanh(input_term + recurrent_term + ba)

    # Output distribution for this time step.
    logits = Wya @ a_next + by
    yt_pred = softmax(logits)

    return a_next, yt_pred, (a_next, a_prev, xt, parameters)

In [3]:
def rnn_forward(x, a0, parameters):
    """Forward pass of a basic RNN over every time step.

    Arguments:
    x -- input sequence, 3-D array unpacked as (n_x, m, T_x)
    a0 -- initial hidden state, expected to fit a (n_a, m) slot
    parameters -- dict of weights; "Wya" (shape (n_y, n_a)) fixes the output
                  and hidden sizes, the rest is consumed by rnn_cell_forward

    Returns:
    a -- hidden states for every step, shape (n_a, m, T_x)
    y_pred -- predictions for every step, shape (n_y, m, T_x)
    caches -- tuple (list of per-step caches, x) for the backward pass
    """
    _, m, T_x = x.shape
    n_y, n_a = parameters["Wya"].shape

    # Pre-allocate the outputs; one slice along the last axis per time step.
    a = np.zeros((n_a, m, T_x))
    y_pred = np.zeros((n_y, m, T_x))
    step_caches = []

    hidden = a0
    for t in range(T_x):
        hidden, y_t, step_cache = rnn_cell_forward(x[..., t], hidden, parameters)
        a[..., t] = hidden
        y_pred[..., t] = y_t
        step_caches.append(step_cache)

    return a, y_pred, (step_caches, x)

In [4]:
# Test code: run rnn_forward on small random inputs with a fixed seed.
np.random.seed(1)
# These sizes are notebook-level names; the backward-pass test cell reuses n_a, m and T_x.
n_x, m, T_x = 3, 10, 4
n_a, n_y = 5, 2
# The order of the randn calls below determines the sampled values, so do not reorder them.
x_test = np.random.randn(n_x, m, T_x)
a0_test = np.random.randn(n_a, m)
parameters_test = {
    "Waa": np.random.randn(n_a, n_a),
    "Wax": np.random.randn(n_a, n_x),
    "Wya": np.random.randn(n_y, n_a),
    "ba": np.random.randn(n_a, 1),
    "by": np.random.randn(n_y, 1)
}
# caches_out is kept at notebook level because the backward-pass test cell consumes it.
a_out, y_pred_out, caches_out = rnn_forward(x_test, a0_test, parameters_test)
# Index 4 on the first axis: the last of the n_a hidden units, shape (m, T_x).
print("a_out[4] = \n", a_out[4])
# Index 1 on the first axis: the second of the n_y output rows, shape (m, T_x).
print("y_pred_out[1] = \n", y_pred_out[1])

a_out[4] = 
 [[-0.99935897 -0.57882882  0.99953622  0.99692362]
 [-0.99999375  0.77911235 -0.99861469 -0.99833267]
 [ 0.98895163  0.9905525   0.87805502  0.99623046]
 [ 0.9999802   0.99693738  0.99745184  0.97406138]
 [-0.9912801   0.98087418  0.76076959  0.54482277]
 [ 0.74865774 -0.59005528 -0.97721203  0.92063859]
 [-0.96279238 -0.99825059  0.95668547 -0.76146336]
 [-0.99251598 -0.95934467 -0.97402324  0.99861032]
 [ 0.93272501  0.81262652  0.65510908  0.69252916]
 [-0.1343305  -0.99995298 -0.9994704  -0.98612292]]
y_pred_out[1] = 
 [[8.70631878e-04 1.09227408e-01 2.95793685e-01 8.02699998e-02]
 [6.05834882e-04 6.73052187e-01 1.21038427e-03 1.17806974e-02]
 [5.72253732e-03 4.00062909e-03 2.06047094e-03 7.23375910e-01]
 [7.95603732e-01 8.62248606e-01 1.11182569e-01 8.15159466e-01]
 [2.57915964e-01 5.83969344e-01 9.40379273e-01 4.35479788e-02]
 [6.56311720e-01 1.13248528e-03 1.02336394e-02 1.00862411e-01]
 [6.07472236e-02 2.92414498e-01 1.14848948e-01 6.65231284e-01]
 [6.44953539e-02 

In [5]:
def rnn_cell_backward(da_next, cache):
    """
    Backward pass for a single RNN cell (one time step).

    Only the path through the hidden-state update
    a_next = tanh(Wax @ xt + Waa @ a_prev + ba) is differentiated here;
    da_next must already contain the full upstream gradient on a_next.

    Arguments:
    da_next -- gradient of the loss with respect to a_next, same shape as a_next
    cache -- tuple (a_next, a_prev, xt, parameters) as produced by rnn_cell_forward;
             of the parameters only "Waa" is read

    Returns:
    da_prev -- gradient with respect to a_prev
    dWax -- gradient with respect to Wax
    dWaa -- gradient with respect to Waa
    dba -- gradient with respect to ba, summed over the batch axis (axis 1)
    """
    (a_next, a_prev, xt, parameters) = cache
    # Waa is the only weight the backward step needs; Wax and Wya are not
    # looked up so the cache does not have to carry them.
    Waa = parameters["Waa"]

    # a_next = tanh(z), hence d tanh(z) / dz = 1 - a_next ** 2.
    dtanh = (1 - a_next ** 2) * da_next

    # Gradients of z = Wax @ xt + Waa @ a_prev + ba with respect to each term.
    dWax = dtanh @ xt.T
    dWaa = dtanh @ a_prev.T
    dba = np.sum(dtanh, axis=1, keepdims=True)

    # Gradient handed back to the previous time step.
    da_prev = Waa.T @ dtanh

    return da_prev, dWax, dWaa, dba

In [11]:
def rnn_backward(da, caches):
    """
    Backward pass of the basic RNN over every time step.

    Arguments:
    da -- gradient of the loss with respect to the hidden states,
          numpy array of shape (n_a, m, T_x)
    caches -- tuple (list of per-step caches, x) saved by the forward pass

    Returns:
    gradients -- dict with
        "dWax"    -- gradient of Wax, shape (n_a, n_x), summed over time steps
        "dWaa"    -- gradient of Waa, shape (n_a, n_a), summed over time steps
        "dba"     -- gradient of ba, shape (n_a, 1), summed over time steps
        "da_prev" -- gradient with respect to the initial hidden state,
                     shape (n_a, m)
    """
    (step_caches, x) = caches
    n_a, m, T_x = da.shape
    n_x = x.shape[0]

    # Initialise the accumulators.
    gradients = {
        "dWax": np.zeros((n_a, n_x)),
        "dWaa": np.zeros((n_a, n_a)),
        "dba": np.zeros((n_a, 1)),
        "da_prev": np.zeros((n_a, m))
    }

    # Gradient flowing back from the following time step; none after the last one.
    da_next = np.zeros((n_a, m))

    # Walk the time axis backwards.
    for t in reversed(range(T_x)):
        # The hidden state at step t receives gradient both directly (da[:, :, t])
        # and from step t + 1 (da_next).
        da_next, dWax_t, dWaa_t, dba_t = rnn_cell_backward(
            da[:, :, t] + da_next, step_caches[t]
        )

        # The weights are shared across time, so their gradients accumulate.
        gradients["dWax"] += dWax_t
        gradients["dWaa"] += dWaa_t
        gradients["dba"] += dba_t

    # What is left after step 0 is the gradient on the initial hidden state.
    # Previously this entry was left at its zero initialisation.
    gradients["da_prev"] = da_next

    return gradients

In [12]:
# Backpropagation test code.
# Relies on n_a, m, T_x and caches_out left in the notebook namespace by the forward test cell.
np.random.seed(2)
# Random upstream gradient on the hidden states, shape (n_a, m, T_x).
da_test = np.random.randn(n_a, m, T_x)
gradients_out = rnn_backward(da_test, caches_out)
print("dWax = \n", gradients_out["dWax"])
print("dWaa = \n", gradients_out["dWaa"])
print("dba = \n", gradients_out["dba"])

dWax = 
 [[-3.30479652 -0.5889298   2.67389609]
 [-1.90854233  0.65640142 -1.14827789]
 [-2.0974342   4.76115666 -0.44652908]
 [-3.01352381  4.25404723 -1.20572801]
 [ 0.50119786 -3.11451526 -0.39116022]]
dWaa = 
 [[-2.13574198e+00  6.80711347e-01 -1.44129083e+00  2.45147206e+00
   1.08426242e-03]
 [ 3.59828711e+00  1.90362485e+00  5.50531297e+00 -2.46100000e+00
   2.41395762e+00]
 [ 4.84818068e+00  3.71006436e+00  2.16039211e+00  1.97751724e+00
   4.78034097e+00]
 [ 1.00859633e+01  7.14294000e+00  2.59060550e+00  4.46558687e+00
   8.66822485e-01]
 [ 1.73852065e+00  1.31795266e-02  1.15182331e+00 -1.55384609e+00
  -9.78216278e-01]]
dba = 
 [[-2.46456133]
 [10.68553905]
 [ 6.82076449]
 [ 8.92964609]
 [ 2.76198809]]
