In [1]:
import torch 
from torch.nn import functional as F 
import torch.nn as nn 

In [2]:
class SimpleRNN(nn.Module):
    """Single-step Elman RNN cell with a linear read-out layer.

    Each call to ``forward`` advances the recurrence by one time step:

        h_t = tanh(W_ih @ x_t + W_hh @ h_{t-1} + b_hh)
        y_t = W_ho @ h_t + b_ho

    Parameters
    ----------
    input_size: Number of features in each input vector.
    hidden_size: Number of hidden units.
    output_size: Number of output features.
    """

    def __init__(self, input_size: int, hidden_size: int, output_size: int):
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.output_size = output_size

        # Input -> hidden projection. It has no bias of its own because hid2hid
        # already adds one to the same pre-activation sum.
        self.inp2hidden = nn.Linear(input_size, hidden_size, bias=False)
        # Hidden -> hidden recurrent weights. Without them the previous state
        # would be added unweighted, and the cell could not learn how to carry
        # information across time steps.
        self.hid2hid = nn.Linear(hidden_size, hidden_size)
        # Hidden -> output read-out.
        self.hid2out = nn.Linear(hidden_size, output_size)

    def forward(self, x: torch.Tensor, hidden_state: torch.Tensor):
        """
        Run one time step and return the output and the new hidden state.

        Inputs
        ------
        x: Input tensor whose last dimension is ``input_size``,
           e.g. (batch_size, input_size).
        hidden_state: Previous hidden state whose last dimension is
           ``hidden_size``, e.g. the (batch_size, hidden_size) tensor
           returned by ``init_zero_hidden``.

        Outputs
        -------
        out: Linear output of shape (..., output_size). No activation is
             applied, so the values can be fed directly to losses that apply
             their own softmax, such as ``nn.CrossEntropyLoss``.
        hidden_state: New hidden state ``tanh(i2h(x) + h2h(hidden_state))``
             of shape (..., hidden_size).
        """
        x = self.inp2hidden(x)
        hidden_state = self.hid2hid(hidden_state)
        hidden_state = torch.tanh(x + hidden_state)
        out = self.hid2out(hidden_state)
        return out, hidden_state

    def init_zero_hidden(self, batch_size: int = 1) -> torch.Tensor:
        """
        Helper function.
        Return an all-zero hidden state of shape (batch_size, hidden_size).
        ``batch_size`` defaults to 1.

        The tensor is created on the same device and with the same dtype as
        the module's parameters, so it can be passed straight to ``forward``
        after ``model.to(device)`` or ``model.double()``.
        """
        ref = self.hid2out.weight
        return torch.zeros(
            batch_size,
            self.hidden_size,
            device=ref.device,
            dtype=ref.dtype,
            requires_grad=False,
        )
        

### Shapes and Dimensions

![shape](https://miro.medium.com/v2/resize:fit:786/format:webp/1*ky5-nx-6uMBualEs7HUwGA.png)

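**Shape of the result of each matrix product (yellow indicators in the figure)**

1. Input-to-hidden product: `batch_size x hidden_size`
2. Hidden-to-hidden (recurrent) product: `batch_size x hidden_size`
3. Hidden-to-output product: `batch_size x output_size`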