# ゲート付きRNN

## UGRNN（Update Gate RNN）

In [1]:
import torch
import torch.nn as nn

In [2]:
class UGRNN(nn.Module):
    """Update-gate RNN (UGRNN) layer, unrolled step by step over a batch-first sequence.

    At each time step ``t`` the input ``x_t`` is concatenated with the previous
    hidden state ``h`` and the layer computes::

        c  = tanh(W_c [x_t, h] + b_c)      # candidate state   (self.transform)
        g  = sigmoid(W_g [x_t, h] + b_g)   # update gate       (self.update_gate)
        h' = g * c + (1 - g) * h

    Args:
        input_size: Number of features in each input time step.
        hidden_size: Number of features in the hidden state.
    """

    def __init__(self, input_size, hidden_size):
        super().__init__()
        self.hidden_size = hidden_size

        # Linear maps applied to the concatenation [x_t, h_{t-1}].
        # Attribute names are kept as-is so existing state_dict keys still load.
        self.transform = nn.Linear(input_size + hidden_size, hidden_size)
        self.update_gate = nn.Linear(input_size + hidden_size, hidden_size)

        # Activation functions
        self.tanh = nn.Tanh()
        self.sigmoid = nn.Sigmoid()

    def forward(self, input, h_0=None):
        """Run the UGRNN over a whole sequence.

        Args:
            input: Batch-first tensor of shape ``[batch_size, seq_len, input_size]``.
            h_0: Optional initial hidden state of shape ``[1, batch_size, hidden_size]``.
                When omitted, a zero state is created on the same device and with
                the same dtype as ``input``.

        Returns:
            A tuple ``(output_seq, h_n)``. ``output_seq`` has shape
            ``[batch_size, seq_len, hidden_size]`` and holds the hidden state after
            every step. ``h_n`` has shape ``[1, batch_size, hidden_size]`` and is the
            final hidden state (equal to ``h_0`` when ``seq_len == 0``).
        """
        batch_size, seq_len, _ = input.size()

        if h_0 is None:
            # Follow the input's device/dtype so the layer works on GPU and after
            # .double()/.half() without a hard-coded device.
            h_0 = input.new_zeros((1, batch_size, self.hidden_size))

        h = h_0.squeeze(0)  # [1, batch_size, hidden_size] -> [batch_size, hidden_size]
        outputs = []
        for t in range(seq_len):
            # [batch_size, input_size + hidden_size]
            combined = torch.cat((input[:, t, :], h), dim=1)
            h_candidate = self.tanh(self.transform(combined))
            gate = self.sigmoid(self.update_gate(combined))
            # Convex blend: gate -> 1 takes the candidate, gate -> 0 keeps the old state.
            h = gate * h_candidate + (1 - gate) * h
            outputs.append(h)

        if outputs:
            output_seq = torch.stack(outputs, dim=1)  # [batch_size, seq_len, hidden_size]
        else:
            # torch.stack/torch.cat reject an empty list; return an empty sequence instead.
            output_seq = h.new_zeros((batch_size, 0, self.hidden_size))
        h_n = h.unsqueeze(0)  # [batch_size, hidden_size] -> [1, batch_size, hidden_size]

        return output_seq, h_n

In [3]:
# Toy dimensions for a quick smoke test of the layer.
input_size, hidden_size = 10, 3
seq_len, batch_size = 4, 5

# Random batch-first input: [batch_size, seq_len, input_size].
input_tensor = torch.randn(batch_size, seq_len, input_size)
ugrnn = UGRNN(input_size, hidden_size)
# output_seq: [batch_size, seq_len, hidden_size]; h_n: [1, batch_size, hidden_size]
output_seq, h_n = ugrnn(input_tensor)

In [5]:
output_seq.size()  # [batch_size, seq_len, hidden_size]

torch.Size([5, 4, 3])

In [7]:
h_n.size()  # [1, batch_size, hidden_size]

torch.Size([1, 5, 3])