# RNN

In [1]:
from torch import nn
from torch.nn import functional as F
import math
import torch

### RNNのスクラッチ実装

##### • 入力層の次元数と，隠れ層の次元数を引数にとる
##### • 入力層と隠れ層の重みとバイアスをパラメータとして保持
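##### • 初期値は一様分布 $\mathcal{U}\left(-\frac{1}{\sqrt{\text{hidden\_size}}},\ \frac{1}{\sqrt{\text{hidden\_size}}}\right)$ からランダムサンプル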
##### • forwardメソッドに順伝搬を行う処理を記述
##### • input: [batch_size, seq_len, input_size]およびh_0:[1, batch_size, hidden_size]を引数にする
##### • 全stepの隠れ状態[batch_size, seq_len, hidden_size]および最後の
##### stepの隠れ状態[1, batch_size, hidden_size]を戻り値として返す
##### • 出力層は実装不要
##### • forwardメソッドの出力をnn.Linearに入力し最終的な予測値を計算する想定

In [2]:
class MyRNN:
    """Single-layer tanh RNN implemented from scratch on plain tensors.

    Holds the input-to-hidden and hidden-to-hidden weights and biases:
    ``W_in`` [hidden_size, input_size], ``W_h`` [hidden_size, hidden_size],
    ``b_in`` [hidden_size] and ``b_h`` [hidden_size].  There is no output
    layer; the result of ``forward`` is meant to be fed to ``nn.Linear``.
    """

    def __init__(self, input_size, hidden_size):
        """Draw every parameter from U(-1/sqrt(hidden_size), 1/sqrt(hidden_size))."""
        self.hidden_size = hidden_size
        init_range = 1.0 / math.sqrt(hidden_size)
        self.W_in = torch.empty(hidden_size, input_size).uniform_(-init_range, init_range)
        self.W_h = torch.empty(hidden_size, hidden_size).uniform_(-init_range, init_range)

        self.b_in = torch.empty(hidden_size).uniform_(-init_range, init_range)
        self.b_h = torch.empty(hidden_size).uniform_(-init_range, init_range)

    def forward(self, input, h_0=None):
        """Run the recurrence over every time step.

        Args:
            input: tensor of shape [batch_size, seq_len, input_size].
            h_0: initial hidden state of shape [1, batch_size, hidden_size];
                defaults to all zeros when omitted.

        Returns:
            ``(output_seq, h_n)`` where ``output_seq`` holds the hidden state
            of every step, [batch_size, seq_len, hidden_size], and ``h_n`` is
            the hidden state of the last step, [1, batch_size, hidden_size].
        """
        batch_size, seq_len, _ = input.size()

        if h_0 is None:
            h_0 = torch.zeros(1, batch_size, self.hidden_size)  # .to(device)
        self.h_0 = h_0

        outputs = []
        # Start from the *resolved* initial state.  Previously the loop started
        # from the raw argument, which is still None when h_0 was omitted.
        h = self.h_0  # [1, batch_size, hidden_size]
        for i in range(seq_len):
            # input[:, i] : [batch_size, input_size]
            # squeeze(0) is a no-op after the first step, where h is already 2-D.
            h = torch.tanh(
                input[:, i] @ self.W_in.T + self.b_in
                + h.squeeze(0) @ self.W_h.T + self.b_h
            )  # [batch_size, hidden_size]

            # [batch_size, hidden_size] -> [batch_size, 1, hidden_size]
            outputs.append(h.unsqueeze(1))

        # Concatenate along the time axis so the state of every step is returned.
        output_seq = torch.cat(outputs, dim=1)  # [batch_size, seq_len, hidden_size]
        h_n = h.unsqueeze(0)  # [batch_size, hidden_size] -> [1, batch_size, hidden_size]

        return output_seq, h_n

#### 補足

In [4]:
# Uniform distribution demo: fill a tensor from U(0, 1), then redraw it
# in place from U(-init_range, init_range).
hidden_size, input_size = 3, 5
W_in = torch.empty(hidden_size, input_size)
W_in.uniform_()

init_range = 1.0 / math.sqrt(hidden_size)
W_in.uniform_(-init_range, init_range)

tensor([[-0.1766,  0.5647,  0.0658,  0.1871,  0.4833],
        [-0.5748,  0.2428, -0.2839,  0.2757,  0.0639],
        [-0.3644,  0.2232, -0.0383,  0.5226,  0.5703]])

In [5]:
# Initialise h_0 as an all-zero tensor of shape [1, batch_size, hidden_size].
batch_size, hidden_size = 8, 3
h_0 = torch.zeros((1, batch_size, hidden_size))
h_0

tensor([[[0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.]]])

In [6]:
# Indexing check on a [2, 3, 5] tensor:
#   a[:, 0, 0] -> first element of the first row of every batch entry, shape [2]
#   a[:, 0]    -> first row of every batch entry, shape [2, 5]
a = torch.randn(2, 3, 5)
print(a, a[:, 0, 0], a[:, 0], sep="\n\n")

tensor([[[ 0.2567,  0.0449,  1.4392, -1.2831, -0.9640],
         [-0.4567,  1.1090,  1.7074, -0.1938,  0.3962],
         [-0.1587, -0.3230, -0.5574, -0.4526,  0.8397]],

        [[ 1.3157, -0.3365, -1.8597,  0.0821, -0.2248],
         [-0.8528, -0.0931,  0.0334, -0.9589,  0.6655],
         [ 1.1303,  0.2768,  0.3341,  2.2097, -0.3889]]])

tensor([0.2567, 1.3157])

tensor([[ 0.2567,  0.0449,  1.4392, -1.2831, -0.9640],
        [ 1.3157, -0.3365, -1.8597,  0.0821, -0.2248]])


### テスト

In [7]:
# Smoke test: push a random batch through MyRNN and print the output shapes.
input_size, hidden_size = 10, 3
batch_size, seq_len = 8, 5

input_tensor = torch.randn((batch_size, seq_len, input_size))
rnn = MyRNN(input_size=input_size, hidden_size=hidden_size)
output_seq, h_n = rnn.forward(input_tensor)
print(output_seq.shape, h_n.shape)

torch.Size([8, 5, 3]) torch.Size([1, 8, 3])


### MyRNNモデル

In [10]:
class MyRNNModel():
    """Scratch RNN followed by a linear head applied to the last hidden state."""

    def __init__(self, input_size, hidden_size, output_size):
        self.rnn = MyRNN(input_size, hidden_size)
        self.fc = nn.Linear(in_features=hidden_size, out_features=output_size)

    def forward(self, x):
        """Return ``fc`` applied to the final hidden state of ``rnn``.

        The second value returned by ``self.rnn.forward`` is ``h_n``
        ([1, batch, hidden_size]); its leading axis is squeezed away before
        the linear layer, giving [batch, output_size].
        """
        _, h_n = self.rnn.forward(x)
        last_hidden = h_n.squeeze(0)  # [batch, hidden_size]
        return self.fc(last_hidden)

In [9]:
# Build the scratch model and check the shape of its prediction.
output_size = 2
model = MyRNNModel(
    input_size=input_size,
    hidden_size=hidden_size,
    output_size=output_size,
)
out = model.forward(input_tensor)
out.shape

torch.Size([8, 2])

### nn.RNN 

In [40]:
class RNNModel(nn.Module):
    """``nn.RNN`` (batch-first) followed by a linear layer on the last time step."""

    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        self.rnn = nn.RNN(
            input_size=input_size,
            hidden_size=hidden_size,
            batch_first=True,
        )
        self.fc = nn.Linear(in_features=hidden_size, out_features=output_size)

    def forward(self, x):
        """Map ``x`` to one prediction per sequence (many-to-one).

        ``self.rnn`` returns ``output_seq`` [batch_size, seq_len, hidden_size]
        and ``h_n`` [1, batch_size, hidden_size].  Only the last time step of
        ``output_seq`` ([batch_size, hidden_size]) is passed to ``fc``, so the
        result is [batch_size, output_size].

        Alternatives noted in the original notebook:
        ``self.fc(h_n.squeeze(0))`` for the same many-to-one setup, and
        ``self.fc(output_seq)`` for many-to-many tasks such as NER.
        """
        output_seq, _ = self.rnn(x)
        last_step = output_seq[:, -1, :]  # [batch_size, hidden_size]
        return self.fc(last_step)

In [41]:
# Build the nn.RNN-based model and check the shape of its prediction.
output_size = 2
model = RNNModel(
    input_size=input_size,
    hidden_size=hidden_size,
    output_size=output_size,
)
out = model(input_tensor)
out.shape

torch.Size([8, 2])

In [42]:
# Re-declare the dimensions, rebuild the model and check the output size again.
input_size, hidden_size = 10, 3
batch_size, seq_len = 8, 5
model = RNNModel(
    input_size=input_size,
    hidden_size=hidden_size,
    output_size=output_size,
)
out = model(input_tensor)
out.size()

torch.Size([8, 2])

In [43]:
# List every registered parameter of the model together with its size.
named_params = model.named_parameters()
for name, param in named_params:
    print(f"{name}: {param.size()}")

rnn.weight_ih_l0: torch.Size([3, 10])
rnn.weight_hh_l0: torch.Size([3, 3])
rnn.bias_ih_l0: torch.Size([3])
rnn.bias_hh_l0: torch.Size([3])
fc.weight: torch.Size([2, 3])
fc.bias: torch.Size([2])


#### ↑ nn.Linear は最後の次元に対してのみ作用する。つまり、入力テンソル [batch_size, seq_len, hidden_size] のうち、最後の次元以外の軸（batch_size, seq_len）には何も作用しない

### RNN back propagation

In [None]:
class MyRNN:
    """Scratch tanh RNN with a hand-written backward pass (BPTT).

    Parameters are created with ``requires_grad=True`` so the manually
    computed gradients can be compared with the ones autograd produces.
    """

    def __init__(self, input_size, hidden_size):
        """Draw every parameter from U(-1/sqrt(hidden_size), 1/sqrt(hidden_size))."""
        self.hidden_size = hidden_size
        init_range = 1.0 / math.sqrt(hidden_size)
        self.W_in = torch.empty(hidden_size, input_size).uniform_(-init_range, init_range).clone().requires_grad_(True)
        self.W_h = torch.empty(hidden_size, hidden_size).uniform_(-init_range, init_range).clone().requires_grad_(True)

        self.b_in = torch.empty(hidden_size).uniform_(-init_range, init_range).clone().requires_grad_(True)
        self.b_h = torch.empty(hidden_size).uniform_(-init_range, init_range).clone().requires_grad_(True)

    def forward(self, input, h_0=None):
        """Run the recurrence and cache what ``backward`` needs.

        Args:
            input: tensor of shape [batch_size, seq_len, input_size].
            h_0: initial hidden state of shape [1, batch_size, hidden_size];
                defaults to all zeros when omitted.

        Returns:
            ``(output_seq, h_n)``: the hidden state of every step,
            [batch_size, seq_len, hidden_size], and the hidden state of the
            last step, [1, batch_size, hidden_size].
        """
        batch_size, seq_len, _ = input.size()

        if h_0 is None:
            h_0 = torch.zeros(1, batch_size, self.hidden_size)  # .to(device)

        # Cached for backward(): the inputs, the initial state and the length.
        self.input = input
        self.h_0 = h_0
        self.seq_len = seq_len

        outputs = []
        h = h_0  # [1, batch_size, hidden_size]
        for i in range(seq_len):
            # input[:, i] : [batch_size, input_size]
            h = torch.tanh(
                input[:, i] @ self.W_in.T + self.b_in
                + h.squeeze(0) @ self.W_h.T + self.b_h
            )  # [batch_size, hidden_size]

            # [batch_size, hidden_size] -> [batch_size, 1, hidden_size]
            outputs.append(h.unsqueeze(1))

        # Concatenate along the time axis so the state of every step is kept.
        self.output_seq = torch.cat(outputs, dim=1)  # [batch_size, seq_len, hidden_size]
        h_n = h.unsqueeze(0)  # [batch_size, hidden_size] -> [1, batch_size, hidden_size]

        return self.output_seq, h_n

    def backward(self, out_grad):
        """Back-propagate through time from the last hidden state.

        Args:
            out_grad: gradient of the loss with respect to the last-step
                hidden state; it is written into ``grad_output_seq[:, -1, :]``
                so it must fit [batch_size, hidden_size].

        Returns:
            ``(grad_W_in, grad_W_h, grad_b_in, grad_b_h)``.  The same tensors
            are stored on ``self`` under those names, and ``self.grad_h``
            holds the gradient with respect to the initial hidden state
            ([batch_size, hidden_size]).  The ``*_list`` attributes keep a
            snapshot of each running value after every time step, in the
            order the steps are visited (last step first).

        Raises:
            RuntimeError: if ``forward`` has not been called yet.
        """
        if not hasattr(self, "output_seq"):
            raise RuntimeError("forward() must be called before backward()")

        self.grad_W_in_list = []
        self.grad_W_h_list = []
        self.grad_b_in_list = []
        self.grad_b_h_list = []

        self.grad_h_list = []
        self.grad_h_tanh_list = []

        # The gradients are computed by hand, so autograd must not record
        # these operations on the requires_grad parameters.
        with torch.no_grad():
            grad_W_in = torch.zeros_like(self.W_in)
            grad_W_h = torch.zeros_like(self.W_h)
            grad_b_in = torch.zeros_like(self.b_in)
            grad_b_h = torch.zeros_like(self.b_h)
            # Gradient flowing back from step i+1 into h_i; nothing follows the last step.
            grad_h = torch.zeros_like(self.h_0.squeeze(0))  # [batch_size, hidden_size]

            # Only the last step receives a gradient from outside the RNN.
            grad_output_seq = torch.zeros_like(self.output_seq)  # [b, seq_len, hidden_size]
            grad_output_seq[:, -1, :] = out_grad

            for i in reversed(range(self.seq_len)):
                h_i = self.output_seq[:, i, :]

                # Derivative of tanh: d tanh(a)/da = 1 - tanh(a)^2 = 1 - h_i^2.
                grad_h_tanh = (grad_output_seq[:, i, :] + grad_h) * (1.0 - h_i ** 2)

                # Matrix products / sums over axis 0 add up the batch contributions.
                grad_W_in += grad_h_tanh.T @ self.input[:, i]
                grad_b_in += grad_h_tanh.sum(dim=0)
                grad_b_h += grad_h_tanh.sum(dim=0)

                if i != 0:
                    # Previous hidden state comes from the cached outputs.
                    h_prev = self.output_seq[:, i - 1, :]
                else:
                    # The first step consumed the initial state h_0.
                    h_prev = self.h_0.squeeze(0)
                grad_W_h += grad_h_tanh.T @ h_prev

                # Gradient with respect to the previous hidden state.
                grad_h = grad_h_tanh @ self.W_h

                # Snapshot the running values; clone() so that the later
                # in-place updates do not alter what was stored.
                self.grad_W_in_list.append(grad_W_in.clone())
                self.grad_W_h_list.append(grad_W_h.clone())
                self.grad_b_in_list.append(grad_b_in.clone())
                self.grad_b_h_list.append(grad_b_h.clone())
                self.grad_h_list.append(grad_h.clone())
                self.grad_h_tanh_list.append(grad_h_tanh.clone())

        self.grad_W_in = grad_W_in
        self.grad_W_h = grad_W_h
        self.grad_b_in = grad_b_in
        self.grad_b_h = grad_b_h
        self.grad_h = grad_h

        return grad_W_in, grad_W_h, grad_b_in, grad_b_h
                
            
            

        
        
        

In [2]:
# reversed(range(n)) walks the indices from n - 1 down to 0.
countdown = reversed(range(5))
for i in countdown:
    print(i)

4
3
2
1
0
