## Input Feeding 기법
 - 이전 타임 스텝에서 계산된 attention value(= context vector)를 현재 타임 스텝의 디코더 입력(임베딩)에 이어 붙여(concatenate) RNN에 함께 넣는 기법

In [1]:
import torch
import torch.nn as nn

In [2]:
class InputFeedingDecoder(nn.Module):
    """Single-step LSTM decoder that uses input feeding.

    At each step the embedded target token is concatenated with a context
    vector supplied by the caller (e.g. the attention output of the previous
    time step), and the concatenation is fed to the LSTM.

    Args:
        input_size: Embedding dimension of the target tokens.
        hidden_size: LSTM hidden-state dimension.
        output_size: Target vocabulary size (number of embedding rows and
            number of output logits).
        context_size: Width of the context vector passed to ``forward``.
            Defaults to ``hidden_size``. Set it explicitly when the context
            has a different width, e.g. ``2 * hidden_size`` for states taken
            from a bidirectional encoder.
    """

    def __init__(self, input_size, hidden_size, output_size, context_size=None):
        super().__init__()
        if context_size is None:
            context_size = hidden_size
        self.hidden_size = hidden_size
        self.context_size = context_size
        self.embedding = nn.Embedding(output_size, input_size)
        # Input feeding: the LSTM consumes [embedding; context] at every step.
        self.rnn = nn.LSTM(input_size + context_size, hidden_size)
        self.fc = nn.Linear(hidden_size, output_size)

    def forward(self, input, hidden, context):
        """Run one decoding step.

        Args:
            input: 1-D tensor of token ids, shape ``(batch,)``.
            hidden: LSTM state ``(h, c)``, each ``(1, batch, hidden_size)``.
            context: Context vector, shape ``(batch, context_size)``.

        Returns:
            ``(logits, hidden)``: ``logits`` has shape ``(batch, output_size)``
            and ``hidden`` is the updated LSTM state tuple.
        """
        # Add a length-1 time axis. The LSTM is not batch_first, so its
        # input layout is (seq_len=1, batch, features).
        embedded = self.embedding(input).unsqueeze(0)
        rnn_input = torch.cat((embedded, context.unsqueeze(0)), dim=2)
        output, hidden = self.rnn(rnn_input, hidden)
        # Drop the time axis before projecting to vocabulary logits.
        logits = self.fc(output.squeeze(0))
        return logits, hidden

In [5]:
decoder = InputFeedingDecoder(input_size=10, hidden_size=20, output_size=30)

# Zero-initialised LSTM state (h_0, c_0); each is (num_layers=1, batch=1, hidden_size).
hidden = tuple(torch.zeros(1, 1, decoder.hidden_size) for _ in range(2))
# Context vector from the "previous" step: (batch=1, hidden_size).
context = torch.zeros(1, decoder.hidden_size)
# One token id for the single batch element.
input_token = torch.tensor([5])

output, hidden = decoder(input_token, hidden, context)
(output.shape, hidden[0].shape)

(torch.Size([1, 30]), torch.Size([1, 1, 20]))