# Bidirectional RNNs
All the recurrent networks we have considered up to now have a “causal” structure, meaning that the state at time $t$ captures only information from the past, $x^{1}, \dots, x^{t-1}$, and the present input $x^t$. Some of the models we have discussed also allow information from past $y$ values to affect the current state when the $y$ values are available. In many applications, however, we want to output a prediction of $y^t$ that may depend on the whole input sequence.

For example, in speech recognition, the correct interpretation of the current sound as a phoneme may depend on the next few phonemes because of co-articulation, and may even depend on the next few words because of the linguistic dependencies between nearby words: if there are two interpretations of the current word that are both acoustically plausible, we may have to look far into the future (and the past) to disambiguate them. This is also true of handwriting recognition and many other sequence-to-sequence learning tasks, described in the next section.

Bidirectional recurrent neural networks (or bidirectional RNNs) were invented to address that need.

![Bidirectional RNN](https://stanford.edu/~shervine/images/bidirectional-rnn.png)

As the name suggests, bidirectional RNNs combine an RNN that moves forward through time, beginning from the start of the sequence, with another RNN that moves backward through time, beginning from the end of the sequence.
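
The idea can be sketched with a hand-written tanh RNN. This is only an illustration under stated assumptions: the helper `run_rnn`, the weight names `W_f`, `U_f`, `W_b`, `U_b`, and the small random sizes are hypothetical and are not part of the model trained later in this notebook. The forward pass reads the sequence from start to end, the backward pass reads a reversed copy, and the two state sequences are concatenated per time step.

```python
import torch

torch.manual_seed(0)

def run_rnn(x, W, U):
    """Simple tanh RNN over x of shape (seq_len, input_size).

    Returns all hidden states, shape (seq_len, hidden_size).
    """
    h = torch.zeros(U.size(0))
    states = []
    for x_t in x:
        h = torch.tanh(W @ x_t + U @ h)
        states.append(h)
    return torch.stack(states)

seq_len, input_size, hidden_size = 5, 3, 4
x = torch.randn(seq_len, input_size)

# Separate (hypothetical) parameters for each direction, scaled to avoid saturation
W_f, U_f = 0.5 * torch.randn(hidden_size, input_size), 0.5 * torch.randn(hidden_size, hidden_size)
W_b, U_b = 0.5 * torch.randn(hidden_size, input_size), 0.5 * torch.randn(hidden_size, hidden_size)

# Forward direction: state t summarizes x^1 .. x^t
h_forward = run_rnn(x, W_f, U_f)
# Backward direction: run on the reversed sequence, then flip back,
# so that state t summarizes x^t .. x^T
h_backward = run_rnn(x.flip(0), W_b, U_b).flip(0)

# Each time step now carries information from both the past and the future
h_bi = torch.cat([h_forward, h_backward], dim=1)
print(h_bi.shape)  # torch.Size([5, 8])
```

Changing the last input leaves `h_forward[0]` untouched but alters `h_backward[0]`, which is exactly the look-ahead that a causal RNN lacks.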



In [0]:
! pip3 install torch torchvision

## Modules

In [0]:
import torch 
import torch.nn as nn
import torchvision
from torchvision import transforms

In [0]:
# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

###  Hyper-parameters

In [0]:
sequence_length = 28
input_size = 28
hidden_size = 128
num_layers = 2
num_classes = 10
batch_size = 100
num_epochs = 2
learning_rate = 0.003


### Dataset

In [0]:
# MNIST dataset
train_dataset = torchvision.datasets.MNIST(root='data/',
                                           train=True, 
                                           transform=transforms.ToTensor(),
                                           download=True)

test_dataset = torchvision.datasets.MNIST(root='data/',
                                          train=False, 
                                          transform=transforms.ToTensor())

# Data loader
train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                           batch_size=batch_size, 
                                           shuffle=True)

test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
                                          batch_size=batch_size, 
                                          shuffle=False)
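
With `transforms.ToTensor()`, each batch from the loaders has shape `(batch_size, 1, 28, 28)`. The hyper-parameters `sequence_length = 28` and `input_size = 28` mean that each image is read as a sequence of 28 rows with 28 pixel features per row. A minimal sketch of that reshaping, using a random tensor as a stand-in for a real batch (an assumption made so the snippet runs without downloading MNIST):

```python
import torch

# Stand-in for one batch from train_loader: (batch, channels, height, width)
images = torch.rand(100, 1, 28, 28)

sequence_length, input_size = 28, 28

# Drop the channel axis and treat every image row as one time step
seq = images.reshape(-1, sequence_length, input_size)
print(seq.shape)  # torch.Size([100, 28, 28])

# Time step t of sample n is row t of image n
print(torch.equal(seq[0, 3], images[0, 0, 3]))  # True
```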

##  Bidirectional recurrent neural network (many-to-one)


In [0]:
class BiRNN(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers, num_classes):
        super(BiRNN, self).__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True, bidirectional=True)
        self.fc = nn.Linear(hidden_size*2, num_classes)  # *2: forward and backward outputs are concatenated
    
    def forward(self, x):
        # Set initial hidden and cell states to zeros, on the same device as the input
        h0 = torch.zeros(self.num_layers*2, x.size(0), self.hidden_size, device=x.device)  # *2 for the two directions
        c0 = torch.zeros(self.num_layers*2, x.size(0), self.hidden_size, device=x.device)
        
        # Forward propagate LSTM
        out, _ = self.lstm(x, (h0, c0))  # out: tensor of shape (batch_size, seq_length, hidden_size*2)
        
        # Classify from the top-layer output at the last time step
        # (forward and backward halves concatenated along the feature axis)
        out = self.fc(out[:, -1, :])
        return out
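
It may help to look at how a bidirectional `nn.LSTM` lays out its outputs. The sketch below uses small, arbitrary sizes (assumptions for illustration, not the notebook's hyper-parameters). The first `hidden_size` features of `out` belong to the forward direction and the rest to the backward direction. At the last time step the forward half is the forward direction's final state, while the backward half is the backward direction's state after a single step; its final state sits at position 0. One possible alternative to `out[:, -1, :]`, not what the model above does, is to concatenate the two final states.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
batch, seq_len, input_size, hidden_size, num_layers = 4, 28, 28, 16, 2

lstm = nn.LSTM(input_size, hidden_size, num_layers,
               batch_first=True, bidirectional=True)
x = torch.randn(batch, seq_len, input_size)

# Omitting (h0, c0) makes PyTorch use zero initial states
out, (h_n, c_n) = lstm(x)

print(out.shape)  # torch.Size([4, 28, 32])  -> (batch, seq_len, hidden_size * 2)
print(h_n.shape)  # torch.Size([4, 4, 16])   -> (num_layers * 2, batch, hidden_size)

fwd_last = out[:, -1, :hidden_size]   # forward direction, after all time steps
bwd_last = out[:, -1, hidden_size:]   # backward direction, after its first step only
bwd_first = out[:, 0, hidden_size:]   # backward direction, after all time steps

# h_n stores the final states; the last two entries are the top layer
print(torch.allclose(h_n[-2], fwd_last, atol=1e-6))   # True
print(torch.allclose(h_n[-1], bwd_first, atol=1e-6))  # True

# Alternative feature vector built from both final states
features = torch.cat([fwd_last, bwd_first], dim=1)
print(features.shape)  # torch.Size([4, 32])
```

Whether this alternative improves accuracy on MNIST is not measured here; the results printed in this notebook come from the `out[:, -1, :]` version.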

In [0]:
model = BiRNN(input_size, hidden_size, num_layers, num_classes).to(device)

### Loss and optimizer

In [0]:
crossentropy = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
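
`nn.CrossEntropyLoss` takes raw, unnormalized scores (logits) and integer class labels, which is why `BiRNN.forward` returns the output of the linear layer without a softmax. A small sketch on random values (made up for illustration) shows that the loss matches a log-softmax followed by the negative log-likelihood:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(5, 10)                 # (batch, num_classes), no softmax applied
targets = torch.tensor([3, 0, 9, 1, 7])     # class indices in [0, num_classes)

loss = nn.CrossEntropyLoss()(logits, targets)

# Same quantity computed in two explicit steps
manual = F.nll_loss(F.log_softmax(logits, dim=1), targets)

print(torch.allclose(loss, manual))  # True
```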

### Train Model

In [0]:
total_step = len(train_loader)
for epoch in range(num_epochs):
    for i, (images, labels) in enumerate(train_loader):
        images = images.reshape(-1, sequence_length, input_size).to(device)
        labels = labels.to(device)
        
        # Forward pass
        outputs = model(images)
        loss = crossentropy(outputs, labels)
        
        # Backward and optimize
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        if (i+1) % 100 == 0:
            print('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}'
                  .format(epoch+1, num_epochs, i+1, total_step, loss.item()))

Epoch [1/2], Step [100/600], Loss: 0.5666
Epoch [1/2], Step [200/600], Loss: 0.4634
Epoch [1/2], Step [300/600], Loss: 0.1023
Epoch [1/2], Step [400/600], Loss: 0.1456
Epoch [1/2], Step [500/600], Loss: 0.2272
Epoch [1/2], Step [600/600], Loss: 0.0968
Epoch [2/2], Step [100/600], Loss: 0.0792
Epoch [2/2], Step [200/600], Loss: 0.0621
Epoch [2/2], Step [300/600], Loss: 0.0273
Epoch [2/2], Step [400/600], Loss: 0.0372
Epoch [2/2], Step [500/600], Loss: 0.0415
Epoch [2/2], Step [600/600], Loss: 0.0362


### Test Model

In [0]:
model.eval()  # inference mode; this model has no dropout or batch norm, but it is good practice
with torch.no_grad():
    correct = 0
    total = 0
    for images, labels in test_loader:
        images = images.reshape(-1, sequence_length, input_size).to(device)
        labels = labels.to(device)
        outputs = model(images)
        predicted = outputs.argmax(dim=1)  # index of the highest-scoring class
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

    print('Test Accuracy of the model on the 10000 test images: {} %'.format(100 * correct / total)) 

Test Accuracy of the model on the 10000 test images: 97.18 %


### Save the model checkpoint

In [0]:

torch.save(model.state_dict(), 'model.ckpt')
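
The checkpoint holds only the parameters (the `state_dict`), so restoring it means building a model with the same architecture first and then loading the weights into it. A minimal round-trip sketch, assuming a small `nn.LSTM` as a stand-in for `BiRNN` and a temporary file instead of `model.ckpt` so the snippet is self-contained:

```python
import os
import tempfile

import torch
import torch.nn as nn

torch.manual_seed(0)

# Stand-in model; in this notebook it would be BiRNN(input_size, hidden_size, ...)
model = nn.LSTM(28, 8, batch_first=True, bidirectional=True)

path = os.path.join(tempfile.mkdtemp(), 'model.ckpt')
torch.save(model.state_dict(), path)

# Rebuild the same architecture, then load the saved parameters into it
restored = nn.LSTM(28, 8, batch_first=True, bidirectional=True)
restored.load_state_dict(torch.load(path, map_location='cpu'))
restored.eval()

x = torch.randn(2, 28, 28)
with torch.no_grad():
    same = torch.allclose(model(x)[0], restored(x)[0])
print(same)  # True
```

`map_location` lets a checkpoint saved on a GPU be loaded on a CPU-only machine.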