In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision.datasets import MNIST
import torchvision.transforms as transforms
import matplotlib.pyplot as plt


In [2]:
''' load MNIST database '''
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])
dataset_path = '../mnist_dataset'

train_dataset = MNIST(dataset_path, transform=transform, train=True, download=True)
valid_dataset = MNIST(dataset_path, transform=transform, train=False, download=True)
test_dataset = MNIST(dataset_path, transform=transform, train=False, download=True)


In [3]:
''' load MNIST dataset by using dataloader'''
batch_size = 64

train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                         batch_size=batch_size,
                         shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
                         batch_size=1,
                         shuffle=False)


In [4]:
''' test data loader'''
print(len(train_loader))
for batch_idx, (image, label) in enumerate(train_loader):
    if (batch_idx + 1) % 100 == 0:
        print(image.shape, label.shape)


938
torch.Size([64, 1, 28, 28]) torch.Size([64])
torch.Size([64, 1, 28, 28]) torch.Size([64])
torch.Size([64, 1, 28, 28]) torch.Size([64])
torch.Size([64, 1, 28, 28]) torch.Size([64])
torch.Size([64, 1, 28, 28]) torch.Size([64])
torch.Size([64, 1, 28, 28]) torch.Size([64])
torch.Size([64, 1, 28, 28]) torch.Size([64])
torch.Size([64, 1, 28, 28]) torch.Size([64])
torch.Size([64, 1, 28, 28]) torch.Size([64])


In [21]:
''' test implementation '''
# prepare network input 
x_batch, y_batch = iter(train_loader).next()
batch_size, n_chn, hor_dim, ver_dim = x_batch.size()

x = x_batch
x = x.view(-1, 28, 28)

# initial hidden state
h0 = torch.zeros(2 * 2, 64, 128)
c0 = torch.zeros(2 * 2, 64, 128)

print(x.size(), h0.size())

# define neural network 
lstm = nn.LSTM(28, 128, 2, batch_first=True, bidirectional=True)

fc = nn.Linear(128 * 2, 10)

# Conv layer
x, _ = lstm(x, (h0, c0))
x = x[:, -1, :]

x = fc(x)

print(x_batch.size())
print(x.size(), y_batch.size())


torch.Size([64, 28, 28]) torch.Size([4, 64, 128])
torch.Size([64, 1, 28, 28])
torch.Size([64, 10]) torch.Size([64])


In [38]:
''' Model class definition '''
class BiLSTM(nn.Module):
    def __init__(self, hidden_size, num_layers, bidirectional=True):
        super(BiLSTM, self).__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.bidirectional = bidirectional
        
        self.lstm = nn.LSTM(28, hidden_size, num_layers, batch_first=True, bidirectional=bidirectional)
        self.fc = nn.Linear(hidden_size * 2, 10)
        
    def forward(self, x):
        x = x.view(-1, 28, 28)
        batch_size = x.size(0)
        # initial hidden state
        if self.bidirectional:
            h0 = torch.zeros(2 * self.num_layers, batch_size, self.hidden_size)
            c0 = torch.zeros(2 * self.num_layers, batch_size, self.hidden_size)
        else:
            h0 = torch.zeros(self.num_layers, batch_size, self.hidden_size)
            c0 = torch.zeros(self.num_layers, batch_size, self.hidden_size)
        
        # LSTM layer
        x, _ = self.lstm(x, (h0, c0))
        
        # FC layer
        x = x[:, -1, :]
        x = self.fc(x)

        return x

model = BiLSTM(128, 2, bidirectional=True)
print(model)


BiLSTM(
  (lstm): LSTM(28, 32, num_layers=2, batch_first=True, bidirectional=True)
  (fc): Linear(in_features=64, out_features=10, bias=True)
)


In [39]:
''' Training criteria and optimizer definition '''
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
print(criterion, optimizer)

CrossEntropyLoss() Adam (
Parameter Group 0
    amsgrad: False
    betas: (0.9, 0.999)
    eps: 1e-08
    lr: 0.0001
    weight_decay: 0
)


In [40]:
''' Train network '''
num_epochs = 6
model.train()

for epoch in range(num_epochs):
    loss_avg = 0.
    for image, label in train_loader:
        model_out = model(image)
        loss = criterion(model_out, label)
        loss.backward()
        optimizer.step()
        loss_avg += loss / len(train_loader)
        
    print('Epoch: {:} \tTrain loss: {:.6f}'.format(
        epoch+1, loss_avg))
    

Epoch: 1 	Train loss: 1.584579


KeyboardInterrupt: 

In [41]:
''' Test model '''
test_loss = 0.
accuracy_total = 0
model.eval()
for image, label in test_loader:
    # Evaluate loss
    model_out = model(image)
    loss = criterion(model_out, label)
    test_loss += loss / len(test_loader)
    
    # Evaluate classification accuracy
    _, pred = torch.max(model_out, dim=1)
    accuracy = torch.sum((pred == label).float())
    accuracy_total += accuracy / len(test_loader)
    
print('Test set: Accuracy: {:.2f}, Loss: {:.6f}'.format(
    accuracy_total, test_loss))


Test set: Accuracy: 0.82, Loss: 0.548989
