## Sequence 2 Sequence Modeling

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Parameter as P

### About
A sequence-to-sequence (seq2seq) model takes in a sequence of variable length and produces a sequence of variable length. The most common real-world example is machine translation.

One of the key components is the LSTM cell, which we built in the "Text - Gated Neural Nets (LSTM, GRU)" notebook.

In [3]:
class LSTMCell(nn.Module):
    """A single LSTM step implemented with explicit weight matrices.

    Each gate (input ``i``, forget ``f``, output ``o``) and the candidate
    cell value (``c``) has its own input-to-hidden matrix ``Wx*`` of shape
    ``(input_size, hidden_size)``, hidden-to-hidden matrix ``Wh*`` of shape
    ``(hidden_size, hidden_size)`` and bias ``b*`` of shape
    ``(1, hidden_size)``.

    Args:
        input_size: size of the last dimension of the input.
        hidden_size: size of the hidden state and of the cell state.
    """

    def __init__(self, input_size, hidden_size):
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size

        # NOTE: the creation order below is kept fixed so that seeded
        # initialization and ``parameters()`` ordering stay reproducible.

        # input to hidden (small random values: standard normal * 0.01)
        self.Wxi = P(torch.randn(input_size, hidden_size) * .01)
        self.Wxf = P(torch.randn(input_size, hidden_size) * .01)
        self.Wxo = P(torch.randn(input_size, hidden_size) * .01)
        self.Wxc = P(torch.randn(input_size, hidden_size) * .01)

        # hidden to hidden
        self.Whi = P(torch.randn(hidden_size, hidden_size) * .01)
        self.Whf = P(torch.randn(hidden_size, hidden_size) * .01)
        self.Who = P(torch.randn(hidden_size, hidden_size) * .01)
        self.Whc = P(torch.randn(hidden_size, hidden_size) * .01)

        # bias (zero-initialized; broadcast over leading dimensions)
        self.bi = P(torch.zeros(1, hidden_size))
        self.bf = P(torch.zeros(1, hidden_size))
        self.bo = P(torch.zeros(1, hidden_size))
        self.bc = P(torch.zeros(1, hidden_size))

    def forward(self, input, hidden):
        """Advance the cell by one step.

        Args:
            input: tensor whose last dimension is ``input_size``.
            hidden: pair ``(h, c)`` with the previous hidden and cell
                state; the last dimension of each is ``hidden_size``.

        Returns:
            ``(h_t, (h_t, c_t))``: the new hidden state as the step output,
            followed by the new ``(hidden, cell)`` state pair.
        """
        h, c = hidden  # previous h, c

        # gates: sigmoid of a linear map (input, hidden) -> hidden
        i_t = torch.sigmoid(input @ self.Wxi + h @ self.Whi + self.bi)
        f_t = torch.sigmoid(input @ self.Wxf + h @ self.Whf + self.bf)
        o_t = torch.sigmoid(input @ self.Wxo + h @ self.Who + self.bo)

        # candidate cell value: tanh of a linear map (input, hidden) -> hidden
        g_t = torch.tanh(input @ self.Wxc + h @ self.Whc + self.bc)

        # note that this is elementwise multiplication,
        # not matrix multiplication
        c_t = c * f_t + i_t * g_t
        h_t = o_t * torch.tanh(c_t)

        return h_t, (h_t, c_t)

    def initHidden(self):
        """Return a zero ``(h, c)`` pair, each of shape ``(1, hidden_size)``.

        The tensors are created with the dtype and device of the cell's
        weights, so the state stays usable after ``.to(...)``/``.double()``.
        """
        ref = self.Wxi
        h0 = torch.zeros(1, self.hidden_size, dtype=ref.dtype, device=ref.device)
        c0 = torch.zeros(1, self.hidden_size, dtype=ref.dtype, device=ref.device)
        return h0, c0

### Encoder Architecture
The encoder of a seq2seq network is an RNN that outputs some value for every word from the input sentence. For every input word, the encoder outputs a vector and a hidden state, and uses the hidden state for the next input word.

In [6]:
class EncoderRNN(nn.Module):
    """Encoder step: embed one token index and feed it through an LSTM cell.

    Args:
        input_size: number of rows in the embedding table.
        hidden_size: embedding width, which is also the LSTM input and
            hidden size.
    """

    def __init__(self, input_size, hidden_size):
        super().__init__()
        self.hidden_size = hidden_size

        # nn.Embedding(input_size, hidden_size) is a lookup table holding
        # `input_size` learnable vectors, each of length `hidden_size`.
        self.embedding = nn.Embedding(input_size, hidden_size)
        self.lstm = LSTMCell(hidden_size, hidden_size)

    def forward(self, input, hidden):
        """Encode one token.

        Args:
            input: index tensor looked up in the embedding table; the
                result is reshaped to ``(1, 1, -1)``, so a single index is
                expected (the LSTM input width is ``hidden_size``).
            hidden: ``(h, c)`` state pair from the previous step, e.g. the
                value returned by ``initHidden()``.

        Returns:
            ``(output, hidden)`` exactly as returned by the LSTM cell.
        """
        embedded = self.embedding(input).view(1, 1, -1)
        output, hidden = self.lstm(embedded, hidden)
        return output, hidden

    def initHidden(self):
        """Return a zero ``(h, c)`` pair, each of shape ``(1, 1, hidden_size)``.

        The LSTM cell unpacks its state as ``h, c = hidden``, so a pair is
        required (a lone tensor cannot be unpacked into two states). The
        tensors take the dtype and device of the embedding weights rather
        than relying on a module-level ``device`` variable.
        """
        ref = self.embedding.weight
        h0 = torch.zeros(1, 1, self.hidden_size, dtype=ref.dtype, device=ref.device)
        c0 = torch.zeros(1, 1, self.hidden_size, dtype=ref.dtype, device=ref.device)
        return h0, c0

### Decoder Architecture
The decoder is another RNN that takes the encoder output vector(s) and outputs a sequence of words to create the translation.

In [7]:
class DecoderRNN(nn.Module):
    """Decoder step: embed the previous token, run the LSTM cell, score outputs.

    Args:
        hidden_size: embedding width, which is also the LSTM input and
            hidden size.
        output_size: number of rows in the embedding table and number of
            scores produced per step.
    """

    def __init__(self, hidden_size, output_size):
        super().__init__()
        self.hidden_size = hidden_size

        self.embedding = nn.Embedding(output_size, hidden_size)
        self.lstm = LSTMCell(hidden_size, hidden_size)
        self.out = nn.Linear(hidden_size, output_size)
        self.softmax = nn.LogSoftmax(dim=1)

    def forward(self, input, hidden):
        """Decode one token.

        Args:
            input: index tensor looked up in the embedding table; the
                result is reshaped to ``(1, 1, -1)``, so a single index is
                expected (the LSTM input width is ``hidden_size``).
            hidden: ``(h, c)`` state pair from the previous step.

        Returns:
            ``(output, hidden)`` where ``output`` holds log-softmax scores
            over ``output_size`` entries (computed along dim 1) and
            ``hidden`` is the new ``(h, c)`` pair from the LSTM cell.
        """
        embedded = self.embedding(input).view(1, 1, -1)
        activated = F.relu(embedded)
        output, hidden = self.lstm(activated, hidden)
        # output[0] drops the leading singleton dimension before projecting
        # to output_size scores.
        output = self.softmax(self.out(output[0]))
        return output, hidden

    def initHidden(self):
        """Return a zero ``(h, c)`` pair, each of shape ``(1, 1, hidden_size)``.

        The LSTM cell unpacks its state as ``h, c = hidden``, so a pair is
        required (a lone tensor cannot be unpacked into two states). The
        tensors take the dtype and device of the embedding weights rather
        than relying on a module-level ``device`` variable.
        """
        ref = self.embedding.weight
        h0 = torch.zeros(1, 1, self.hidden_size, dtype=ref.dtype, device=ref.device)
        c0 = torch.zeros(1, 1, self.hidden_size, dtype=ref.dtype, device=ref.device)
        return h0, c0

### Dataset Prep