### ゲート付きRNN

In [1]:
import torch
import torch.nn as nn

### UGRNN

In [9]:
class UGRNN(nn.Module):
    """Update Gate RNN (UGRNN): a recurrent layer with a single update gate.

    For every time step ``t`` the layer computes, from the concatenation
    ``[x_t, h_{t-1}]``::

        candidate = tanh(transform([x_t, h_{t-1}]))
        gate      = sigmoid(update([x_t, h_{t-1}]))
        h_t       = gate * candidate + (1 - gate) * h_{t-1}

    Tensors are batch-first: the input is ``[batch_size, seq_len, input_size]``.
    """

    def __init__(self, input_size, hidden_size):
        """Create the two linear maps used by the cell.

        Args:
            input_size: number of features of each time step ``x_t``.
            hidden_size: number of features of the hidden state ``h_t``.
        """
        super().__init__()
        self.hidden_size = hidden_size

        # Linear transformations applied to the concatenation [x_t, h_{t-1}].
        self.transform = nn.Linear(input_size + hidden_size, hidden_size)
        self.update = nn.Linear(input_size + hidden_size, hidden_size)

        # Activation functions.
        self.tanh = nn.Tanh()
        self.sigmoid = nn.Sigmoid()

    def forward(self, input, h_0=None):
        """Run the layer over a whole sequence.

        Args:
            input: tensor of shape ``[batch_size, seq_len, input_size]``.
            h_0: optional initial hidden state of shape
                ``[1, batch_size, hidden_size]``; zeros when omitted.

        Returns:
            A tuple ``(output_seq, h_n)`` where ``output_seq`` has shape
            ``[batch_size, seq_len, hidden_size]`` (the hidden state at every
            step) and ``h_n`` has shape ``[1, batch_size, hidden_size]`` (the
            final hidden state).
        """
        batch_size, seq_len, _ = input.size()
        if h_0 is None:
            # new_zeros follows the input's device and dtype, so the layer
            # also works after .to(device) / .double() without a mismatch.
            h_0 = input.new_zeros(1, batch_size, self.hidden_size)

        # [1, batch_size, hidden_size] -> [batch_size, hidden_size]
        h = h_0.squeeze(0)
        outputs = []
        for i in range(seq_len):
            # input[:, i, :] : [batch_size, input_size]
            combined = torch.cat((input[:, i, :], h), dim=1)
            hidden_candidate = self.tanh(self.transform(combined))
            update_gate = self.sigmoid(self.update(combined))
            h = update_gate * hidden_candidate + (1 - update_gate) * h
            outputs.append(h)

        if outputs:
            # Stack the per-step states: [batch_size, seq_len, hidden_size].
            output_seq = torch.stack(outputs, dim=1)
        else:
            # Zero-length sequence: nothing to stack, return an empty output
            # instead of failing on an empty list.
            output_seq = input.new_zeros(batch_size, 0, self.hidden_size)

        # Kept as an attribute for backward compatibility with code that
        # reads it after the call.
        self.output_seq = output_seq
        # [batch_size, hidden_size] -> [1, batch_size, hidden_size]
        h_n = h.unsqueeze(0)

        return self.output_seq, h_n

In [12]:
# Smoke test for UGRNN: push one random batch through the layer and
# display the shape of the per-step outputs.
input_size, hidden_size = 10, 3
batch_size, seq_len = 8, 5
input_tensor = torch.randn(batch_size, seq_len, input_size)
ugrnn = UGRNN(input_size=input_size, hidden_size=hidden_size)
output_seq, h_n = ugrnn(input_tensor)
output_seq.shape

torch.Size([8, 5, 3])

### LSTM

In [30]:
class MyLSTM(nn.Module):
    """Hand-written single-layer, batch-first LSTM.

    For every time step ``t`` the layer computes, from the concatenation
    ``[x_t, h_{t-1}]``::

        candidate = tanh(cell_candidate([x_t, h_{t-1}]))
        update    = sigmoid(update_gate([x_t, h_{t-1}]))
        forget    = sigmoid(forget_gate([x_t, h_{t-1}]))
        output    = sigmoid(output_gate([x_t, h_{t-1}]))
        c_t       = update * candidate + forget * c_{t-1}
        h_t       = output * tanh(c_t)
    """

    def __init__(self, input_size, hidden_size):
        """Create the four linear maps used by the cell.

        Args:
            input_size: number of features of each time step ``x_t``.
            hidden_size: number of features of the hidden and cell states.
        """
        super().__init__()
        self.hidden_size = hidden_size

        # Linear transformations for the three gates.
        self.update_gate = nn.Linear(input_size + hidden_size, hidden_size)
        self.forget_gate = nn.Linear(input_size + hidden_size, hidden_size)
        self.output_gate = nn.Linear(input_size + hidden_size, hidden_size)

        # Fully connected layer producing the candidate cell state.
        self.cell_candidate = nn.Linear(input_size + hidden_size, hidden_size)

        # Activation functions.
        self.tanh = nn.Tanh()
        self.sigmoid = nn.Sigmoid()

    def forward(self, input, h_0=None, c_0=None):
        """Run the layer over a whole sequence.

        Args:
            input: tensor of shape ``[batch_size, seq_len, input_size]``.
            h_0: optional initial hidden state of shape
                ``[1, batch_size, hidden_size]``; zeros when omitted.
            c_0: optional initial cell state of shape
                ``[1, batch_size, hidden_size]``; zeros when omitted.

        Returns:
            A tuple ``(output_seq, (h_n, c_n))``. ``output_seq`` has shape
            ``[batch_size, seq_len, hidden_size]``; ``h_n`` and ``c_n`` are
            the final hidden and cell states, each of shape
            ``[1, batch_size, hidden_size]``.
        """
        batch_size, seq_len, _ = input.size()
        # new_zeros follows the input's device and dtype, so the layer also
        # works after .to(device) / .double() without a mismatch.
        if h_0 is None:
            h_0 = input.new_zeros(1, batch_size, self.hidden_size)
        if c_0 is None:
            c_0 = input.new_zeros(1, batch_size, self.hidden_size)

        # [1, batch_size, hidden_size] -> [batch_size, hidden_size]
        h = h_0.squeeze(0)
        c = c_0.squeeze(0)
        outputs = []
        for i in range(seq_len):
            # input[:, i, :] : [batch_size, input_size]
            combined = torch.cat((input[:, i, :], h), dim=1)
            cell_candidate = self.tanh(self.cell_candidate(combined))
            update_gate = self.sigmoid(self.update_gate(combined))
            forget_gate = self.sigmoid(self.forget_gate(combined))
            output_gate = self.sigmoid(self.output_gate(combined))

            c = update_gate * cell_candidate + forget_gate * c
            h = output_gate * self.tanh(c)
            outputs.append(h)

        if outputs:
            # Stack the per-step states: [batch_size, seq_len, hidden_size].
            output_seq = torch.stack(outputs, dim=1)
        else:
            # Zero-length sequence: nothing to stack, return an empty output
            # instead of failing on an empty list.
            output_seq = input.new_zeros(batch_size, 0, self.hidden_size)

        # Kept as an attribute for backward compatibility with code that
        # reads it after the call.
        self.output_seq = output_seq
        # [batch_size, hidden_size] -> [1, batch_size, hidden_size]
        h_n = h.unsqueeze(0)
        c_n = c.unsqueeze(0)

        # The (h_n, c_n) tuple mirrors the return convention of nn.LSTM.
        return self.output_seq, (h_n, c_n)

In [39]:
# Smoke test for MyLSTM: push one random batch through the layer and
# print the shapes of the outputs and of the final (h_n, c_n) states.
input_size, hidden_size = 10, 3
batch_size, seq_len = 8, 5
input_tensor = torch.randn(batch_size, seq_len, input_size)
lstm = MyLSTM(input_size=input_size, hidden_size=hidden_size)
output_seq, (h_n, c_n) = lstm(input_tensor)
print(output_seq.shape)
print(h_n.shape, c_n.shape)

torch.Size([8, 5, 3])
torch.Size([1, 8, 3]) torch.Size([1, 8, 3])


### PyTorchのGRUとLSTMを使用する

In [40]:
# Built-in GRU: same batch-first call pattern as the hand-written layers.
gru = nn.GRU(input_size=input_size, hidden_size=hidden_size, batch_first=True)
output_seq, h_n = gru(input_tensor)
# One shape per line: per-step outputs first, then the final hidden state.
print(output_seq.shape, h_n.shape, sep="\n")

torch.Size([8, 5, 3])
torch.Size([1, 8, 3])


In [41]:
# Built-in LSTM: returns the per-step outputs plus a (h_n, c_n) state tuple.
lstm = nn.LSTM(input_size=input_size, hidden_size=hidden_size, batch_first=True)
output_seq, (h_n, c_n) = lstm(input_tensor)
print(output_seq.size())
print(h_n.size(), c_n.size())

torch.Size([8, 5, 3])
torch.Size([1, 8, 3]) torch.Size([1, 8, 3])


In [48]:
class Model(nn.Module):
    """Sequence model: a recurrent layer followed by a linear head.

    The recurrent layer is chosen by name; the linear head is applied to the
    output of the last time step only.
    """

    def __init__(self, input_size, hidden_size, output_size, rnn_type='LSTM'):
        """Build the recurrent layer and the output projection.

        Args:
            input_size: number of features of each time step.
            hidden_size: number of features of the recurrent hidden state.
            output_size: number of features produced by the linear head.
            rnn_type: one of ``'RNN'``, ``'GRU'``, ``'LSTM'`` or ``'UGRNN'``.

        Raises:
            ValueError: if ``rnn_type`` is not one of the supported names.
        """
        super().__init__()

        if rnn_type == 'RNN':
            self.rnn = nn.RNN(input_size, hidden_size, batch_first=True)

        elif rnn_type == 'GRU':
            self.rnn = nn.GRU(input_size, hidden_size, batch_first=True)

        elif rnn_type == 'LSTM':
            self.rnn = nn.LSTM(input_size, hidden_size, batch_first=True)

        elif rnn_type == 'UGRNN':
            # Not provided by PyTorch; uses the UGRNN class from this file.
            self.rnn = UGRNN(input_size, hidden_size)

        else:
            raise ValueError(
                f'Unsupported RNN type {rnn_type!r}. '
                'Choose from ["LSTM", "RNN", "GRU", "UGRNN"]'
            )

        self.fc = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        """Map a batch of sequences to one output vector per sequence.

        Args:
            x: tensor of shape ``[batch_size, seq_len, input_size]``.

        Returns:
            Tensor of shape ``[batch_size, output_size]``.
        """
        # The second return value (final state) is not needed here.
        output_seq, _ = self.rnn(x)
        # output_seq : [batch_size, seq_len, hidden_size]; keep the last step.
        out = self.fc(output_seq[:, -1, :])
        return out
        
       

In [51]:
# Build the wrapper model around a plain RNN and check the head's output shape.
output_size = 3
model = Model(
    input_size,
    hidden_size,
    output_size,
    rnn_type='RNN',
)
output = model(input_tensor)
print(output.size())

torch.Size([8, 3])


### 補足

In [23]:
# Toy tensors for illustrating one recurrent step by hand.
X = torch.randn((8, 5))    # [batch_size, input_size]
h = torch.randn((8, 32))   # [batch_size, hidden_size]

# Weight acting on the concatenation [X, h]: [hidden_size, input_size + hidden_size]
W = torch.randn((32, 5 + 32))

# Bias, one entry per hidden unit.
b = torch.ones((32,))

In [24]:
# Join input and hidden state along the feature axis: [batch_size, 5 + 32].
combained = torch.cat([X, h], dim=-1)
combained

tensor([[-0.0446, -1.9736,  0.0460,  0.0781, -0.6876, -3.6557,  0.9979,  0.7083,
          0.2976,  1.1999, -0.8881, -0.6283, -2.6703,  0.4093,  1.1520, -0.5572,
          2.9445,  2.0503,  1.1618,  0.5219,  0.2108,  0.1347, -0.0699, -1.0595,
         -0.3145, -0.9347,  0.7353, -0.0433,  1.1855,  1.1288, -0.7172, -1.2068,
          1.2265,  1.1323, -1.3892, -1.3278, -0.3085],
        [ 1.4799, -0.4404, -0.9352,  0.4447, -1.1642, -0.8137,  0.4417, -2.3990,
          1.7263,  0.7223, -0.9813,  0.1665, -1.0050, -1.0454,  1.1890, -1.0982,
         -2.0528,  2.0969, -1.0906, -0.6501,  0.8423,  0.4964, -2.4665,  0.3983,
         -0.5736, -0.1606,  0.3175, -0.0468,  1.7533, -1.9398, -0.4008, -0.5261,
         -0.5312, -1.2497, -1.6455,  0.1489,  0.2087],
        [-0.5195, -0.1675,  0.2761,  0.5483,  0.2494,  0.1480, -0.3050,  0.7821,
         -0.0767, -0.9582,  0.5438,  0.0176,  2.1386,  0.9012, -1.3719,  1.9839,
          1.0006,  0.9636,  1.0587, -1.0281, -0.5082, -0.6767, -0.1877,  0.3550,

In [25]:
# Display the size of the concatenated tensor.
combained.size()

torch.Size([8, 37])

In [27]:
# Hidden layer: affine map of the concatenated [X, h] (no activation applied here).
h = torch.matmul(combained, W.T) + b
