# Vanilla RNN

# Import Library

In [1]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision import transforms
from torchsummary import summary

import numpy as np
import matplotlib.pyplot as plt

from utils import train_loop,test_loop

# Check Availability of the MPS (Apple-silicon GPU) Device

In [2]:
# Select the compute device: prefer Apple's MPS backend when it is usable,
# otherwise fall back to the CPU so every later `.to(device=mps_device)` /
# `device=mps_device` call still works on machines without an MPS GPU.
# The variable keeps its historical name `mps_device` because later cells use it,
# but it may hold a CPU device.
if not torch.backends.mps.is_available():
    if not torch.backends.mps.is_built():
        print("MPS not available because the current PyTorch install was not "
              "built with MPS enabled.")
    else:
        print("MPS not available because the current MacOS version is not 12.3+ "
              "and/or you do not have an MPS-enabled device on this machine.")
    mps_device = torch.device("cpu")
else:
    print('can use GPU')
    mps_device = torch.device("mps")

can use GPU


# Initial Settings

In [66]:
# Notation used throughout: t = timesteps, n = input_size, h = hidden_size.
timesteps = 10    # t
input_size = 4    # n
hidden_size = 8   # h

# Seeded generator so that Restart & Run All reproduces the same inputs.
rng = np.random.default_rng(0)

# input : (t, n) matrix of uniform [0, 1) samples, moved to the device as float32
random_input = rng.random((timesteps, input_size))
inputs = torch.tensor(random_input)
inputs = inputs.to(dtype=torch.float32, device=mps_device)

# hidden state : (h,) initial state h_0, all zeros
hidden_state_t = torch.zeros((hidden_size,), device=mps_device)
print('hidden state\n', hidden_state_t)

hidden state
 tensor([0., 0., 0., 0., 0., 0., 0., 0.], device='mps:0')


In [67]:
# Seed torch's RNG so the randomly initialised parameters are reproducible on re-run.
torch.manual_seed(0)

# weight for input X : (h, n), uniform [0, 1)
Wx = torch.rand((hidden_size, input_size), dtype=torch.float32, device=mps_device)

# weight for hidden state h : (h, h), uniform [0, 1)
Wh = torch.rand((hidden_size, hidden_size), dtype=torch.float32, device=mps_device)

# bias for hidden state b : (h,), uniform [0, 1)
b = torch.rand((hidden_size,), dtype=torch.float32, device=mps_device)

print('Wx shape : ', Wx.shape,
      '\nWh shape : ', Wh.shape,
      '\nb  shape : ', b.shape)

Wx shape :  torch.Size([8, 4]) 
Wh shape :  torch.Size([8, 8]) 
b  shape :  torch.Size([8])


# Define the Vanilla RNN Cell

### $h_t=\tanh(W_xX_t+W_hh_{t-1}+b)$

In [68]:
def vanilla_RNN(X, h, Wx, Wh, b):
    """Compute one step of a vanilla RNN: h_t = tanh(Wx @ X + Wh @ h + b).

    As used in this notebook, X is one time step of input with shape (n,),
    h is the previous hidden state with shape (h,), Wx is (h, n), Wh is (h, h)
    and b is (h,).

    Returns:
        The new hidden state tanh(Wx @ X + Wh @ h + b).
    """
    pre_activation = Wx @ X + Wh @ h + b
    return pre_activation.tanh()

# Simple Forward Pass (no training)

In [69]:
total_hidden_states = []

# Reset h_0 to zeros (same as the setup cell) so this cell is idempotent:
# without this, re-running it would continue from the previous run's final state.
hidden_state_t = torch.zeros((hidden_size,), device=mps_device)

# iterate over the input sequence, one time step at a time
for input_t in inputs:
    # compute new hidden state value h_t from x_t and h_{t-1}
    new_hidden = vanilla_RNN(
        input_t, hidden_state_t, Wx, Wh, b)

    # save new hidden state value (as a list of its elements)
    total_hidden_states.append(list(new_hidden))

    # update hidden state for the next step
    hidden_state_t = new_hidden

In [70]:
# (number of time steps) x (hidden size) of the collected hidden states
print(f"{len(total_hidden_states)}  x  {len(total_hidden_states[0])}")

10  x  8


# PyTorch nn.RNN()

In [75]:
# Settings for the built-in nn.RNN demo (these rebind the names used above).
input_size = 5
hidden_size = 8
# (batch, seq_len, input_size) because batch_first=True. Use real random values:
# torch.Tensor(1, 10, 5) would be *uninitialized* memory, possibly NaN/inf.
inputs = torch.randn(1, 10, input_size)
cell = nn.RNN(input_size, hidden_size, batch_first=True)

In [76]:
# nn.RNN returns (per-time-step outputs, final hidden state); hx=None means h_0 = 0.
outputs, _status = cell(inputs, hx=None)

In [82]:
# outputs: (batch, seq_len, hidden); _status: (num_layers, batch, hidden)
print(outputs.shape)
print(_status.shape)
# The output at the last time step equals the last layer's final hidden state.
# Index with -1 instead of a hard-coded step so this holds for any seq_len / num_layers.
print(outputs[0, -1] == _status[-1, 0])

torch.Size([1, 10, 8])
torch.Size([1, 1, 8])
tensor([True, True, True, True, True, True, True, True])
