In [1]:
import torch
import torch.nn as nn

### Vanilla RNNs

In [None]:
class RNN(nn.Module):
    """Single-step vanilla recurrent cell with a log-softmax output head.

    Each call to ``forward`` processes one time step: the current input and the
    previous hidden state are concatenated along dim 1 and fed to two linear
    layers, one producing the next hidden state and one producing the output.

    Args:
        input_size: Number of features in each input vector.
        hidden_size: Number of features in the hidden state.
        output_size: Number of output classes (width of the log-softmax).
    """

    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        self.hidden_size = hidden_size

        # Both layers read the concatenated [input, previous hidden] vector.
        self.i2h = nn.Linear(input_size + hidden_size, hidden_size, bias=True)
        self.i2o = nn.Linear(input_size + hidden_size, output_size, bias=True)
        self.softmax = nn.LogSoftmax(dim=1)

    def forward(self, input_tensor, hidden_tensor):
        """Run one time step.

        Args:
            input_tensor: Input for this step, expected shape (batch, input_size).
            hidden_tensor: Previous hidden state, expected shape (batch, hidden_size).

        Returns:
            Tuple ``(output, hidden)`` where ``output`` holds log-probabilities over
            dim 1 and ``hidden`` is the next hidden state.
        """
        combined = torch.cat((input_tensor, hidden_tensor), dim=1)

        hidden = torch.tanh(self.i2h(combined))
        # The output is computed from the same concatenation as the hidden state,
        # i.e. from the *previous* hidden state, not from the freshly updated one.
        output = self.i2o(combined)
        output = self.softmax(output)
        return output, hidden

    def init_hidden(self, batch_size):
        """Return an all-zero initial hidden state of shape (batch_size, hidden_size).

        The tensor is allocated with the same dtype and device as the layer
        weights, so it stays compatible after ``model.to(device)`` or
        ``model.double()`` without the caller having to move it manually.
        """
        return self.i2h.weight.new_zeros(batch_size, self.hidden_size)


### Long Short-Term Memory (LSTM)

In [None]:
class LSTM(nn.Module):
    """Single-step LSTM cell with a log-softmax output head.

    Each call to ``forward`` processes one time step. The four gate layers all
    read the concatenation of the current input and the previous hidden state;
    the output layer reads the updated hidden state.

    Args:
        input_size: Number of features in each input vector.
        hidden_size: Number of features in the hidden and cell states.
        output_size: Number of output classes (width of the log-softmax).
    """

    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        self.hidden_size = hidden_size

        self.i2f = nn.Linear(input_size + hidden_size, hidden_size, bias=True)  # forget gate
        self.i2i = nn.Linear(input_size + hidden_size, hidden_size, bias=True)  # input gate
        self.i2o = nn.Linear(input_size + hidden_size, hidden_size, bias=True)  # output gate
        self.i2c = nn.Linear(input_size + hidden_size, hidden_size, bias=True)  # cell candidate

        self.i2y = nn.Linear(hidden_size, output_size, bias=True)  # output layer
        self.softmax = nn.LogSoftmax(dim=1)

    def forward(self, input_tensor, hidden_state, cell_state):
        """Run one time step.

        Args:
            input_tensor: Input for this step, expected shape (batch, input_size).
            hidden_state: Previous hidden state, expected shape (batch, hidden_size).
            cell_state: Previous cell state, expected shape (batch, hidden_size).

        Returns:
            Tuple ``(output, hidden_state, cell_state)`` where ``output`` holds
            log-probabilities over dim 1 and the other two are the updated states.
        """
        combined = torch.cat((input_tensor, hidden_state), dim=1)

        forget_gate = torch.sigmoid(self.i2f(combined))
        input_gate = torch.sigmoid(self.i2i(combined))
        output_gate = torch.sigmoid(self.i2o(combined))
        cell_candidate = torch.tanh(self.i2c(combined))

        # Keep part of the old cell content and add the gated candidate.
        cell_state = forget_gate * cell_state + input_gate * cell_candidate
        hidden_state = output_gate * torch.tanh(cell_state)

        output = self.i2y(hidden_state)
        output = self.softmax(output)

        return output, hidden_state, cell_state

    def init_hidden(self, batch_size):
        """Return all-zero initial ``(hidden_state, cell_state)``.

        Both tensors have shape (batch_size, hidden_size) and are allocated with
        the same dtype and device as the layer weights, so they stay compatible
        after ``model.to(device)`` or ``model.double()``. Two separate tensors
        are returned so the states never share storage.
        """
        reference = self.i2f.weight
        hidden_state = reference.new_zeros(batch_size, self.hidden_size)
        cell_state = reference.new_zeros(batch_size, self.hidden_size)
        return hidden_state, cell_state


### Gated Recurrent Unit

In [None]:
class GRU(nn.Module):
    """Single-step GRU cell with a log-softmax output head.

    Each call to ``forward`` processes one time step. The update and reset gates
    read the concatenation of the current input and the previous hidden state;
    the candidate state reads the input concatenated with the reset-scaled
    hidden state; the output layer reads the updated hidden state.

    Args:
        input_size: Number of features in each input vector.
        hidden_size: Number of features in the hidden state.
        output_size: Number of output classes (width of the log-softmax).
    """

    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        self.hidden_size = hidden_size

        # Gates
        self.i2z = nn.Linear(input_size + hidden_size, hidden_size)  # update gate
        self.i2r = nn.Linear(input_size + hidden_size, hidden_size)  # reset gate
        self.i2h = nn.Linear(input_size + hidden_size, hidden_size)  # candidate hidden state

        # Output layer
        self.i2y = nn.Linear(hidden_size, output_size)
        # Log-probabilities pair with nn.NLLLoss; remove this layer if using
        # nn.CrossEntropyLoss, which applies log-softmax itself.
        self.softmax = nn.LogSoftmax(dim=1)

    def forward(self, input_tensor, hidden_state):
        """Run one time step.

        Args:
            input_tensor: Input for this step, expected shape (batch, input_size).
            hidden_state: Previous hidden state, expected shape (batch, hidden_size).

        Returns:
            Tuple ``(output, hidden_state)`` where ``output`` holds log-probabilities
            over dim 1 and ``hidden_state`` is the updated hidden state.
        """
        combined = torch.cat((input_tensor, hidden_state), dim=1)

        # Gates
        update_gate = torch.sigmoid(self.i2z(combined))
        reset_gate = torch.sigmoid(self.i2r(combined))

        # Apply the reset gate to the previous state before computing the candidate.
        combined_reset = torch.cat((input_tensor, reset_gate * hidden_state), dim=1)
        candidate = torch.tanh(self.i2h(combined_reset))

        # An update gate near 1 takes the candidate; near 0 keeps the previous state.
        hidden_state = (1 - update_gate) * hidden_state + update_gate * candidate

        # Output
        output = self.i2y(hidden_state)
        output = self.softmax(output)

        return output, hidden_state

    def init_hidden(self, batch_size):
        """Return an all-zero initial hidden state of shape (batch_size, hidden_size).

        The tensor is allocated with the same dtype and device as the layer
        weights, so it stays compatible after ``model.to(device)`` or
        ``model.double()`` without the caller having to move it manually.
        """
        return self.i2h.weight.new_zeros(batch_size, self.hidden_size)
    
        