In [25]:
import math

import torch
import torch.nn
import torch.nn as nn
import torch.nn.functional as functional
from torch.autograd import Variable
from torch.utils.data import DataLoader

In [26]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
cuda = torch.cuda.is_available()

print(device)

Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor

torch.manual_seed(125)

if torch.cuda.is_available():
    torch.cuda.manual_seed_all(125)

cuda


In [27]:
import torchvision.transforms as transforms

mnist_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (1.0,))
])

In [28]:
from torchvision.datasets import MNIST

download_root = "../data/"

train_dataset = MNIST(download_root, transform=mnist_transform, train=True, download=True)
vali_dataset = MNIST(download_root, transform=mnist_transform, train=False, download=True)
test_dataset = MNIST(download_root, transform=mnist_transform, train=False, download=True)

In [29]:
batch_size = 64
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
vali_loader = DataLoader(vali_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=True)

In [30]:
batch_size = 100
n_iters = 6000
num_epochs = n_iters / (len(train_dataset) / batch_size)
num_epochs = int(num_epochs)

In [31]:
class LSTMCell(nn.Module):
    def __init__(self, input_size, hidden_size, bias=True):
        super(LSTMCell, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.bias = bias
        self.x2h = nn.Linear(input_size, 4 * hidden_size, bias=bias)
        self.h2h = nn.Linear(hidden_size, 4 * hidden_size, bias=bias)
        self.reset_parameters()

    def reset_parameters(self):
        std = 1.0 / math.sqrt(self.hidden_size)
        for w in self.parameters():
            w.data.uniform_(-std, std)

    def forward(self, x, hidden):
        hx, cx = hidden
        x = x.view(-1, x.size(1))
        gates = self.x2h(x) + self.h2h(hx)

        gates = gates.squeeze()

        ingate, forgetgate, cellgate, outgate = gates.chunk(4, 1)

        ingate = torch.sigmoid(ingate)
        forgetgate = torch.sigmoid(forgetgate)
        cellgate = torch.tanh(cellgate)
        outgate = torch.sigmoid(outgate)

        cy = torch.mul(cx, forgetgate) + torch.mul(ingate, cellgate)
        hy = torch.mul(outgate, torch.tanh(cy))

        return (hy, cy)

In [32]:
class LSTMModel(nn.Module):
    def __init__(self, input_dim, hidden_dim, layer_dim, output_dim, bias=True):
        super(LSTMModel, self).__init__()

        self.hidden_dim = hidden_dim
        self.layer_dim = layer_dim

        self.lstm = LSTMCell(input_dim, hidden_dim, layer_dim)

        self.fc = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        h0 = torch.zeros(self.layer_dim, x.size(0), self.hidden_dim).requires_grad_().to(device)
        c0 = torch.zeros(self.layer_dim, x.size(0), self.hidden_dim).requires_grad_().to(device)

        outs = []

        cn = c0[0, :, :]
        hn = h0[0, :, :]

        for seq in range(x.size(1)):
            hn, cn = self.lstm(x[:, seq, :], (hn, cn))
            outs.append(hn)

        out = outs[-1].squeeze()

        out = self.fc(out)

        return out

In [33]:
input_dim = 28
hidden_dim = 128
layer_dim = 1
output_dim = 10

model = LSTMModel(input_dim, hidden_dim, layer_dim, output_dim)
model.to(device)

criterion = nn.CrossEntropyLoss()
learning_rate = 0.1
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)

In [34]:
seq_dim = 28
loss_list = []
iter = 0

for epoch in range(num_epochs):
    for i, (images, labels) in enumerate(train_loader):
        images = Variable(images.view(-1, seq_dim, input_dim).to(device))
        labels = Variable(labels.to(device))

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)

        loss.backward()
        optimizer.step()
        loss_list.append(loss.item())
        iter += 1

        if iter % 500 == 0:
            correct = 0
            total = 0

            for images, labels in vali_loader:
                images = Variable(images.view(-1, seq_dim, input_dim).to(device))
                outputs = model(images)
                _, predicted = torch.max(outputs.data, 1)

                total += labels.size(0)

                correct += (predicted.cpu() == labels.cpu()).sum()

            accuracy = 100 * correct / total

            print("Iteration: {}. Loss: {}. Accuracy: {}".format(iter, loss.item(), accuracy))

Iteration: 500. Loss: 2.237457275390625. Accuracy: 21.420000076293945
Iteration: 1000. Loss: 0.9093605875968933. Accuracy: 75.72000122070312
Iteration: 1500. Loss: 0.4703645706176758. Accuracy: 87.77999877929688
Iteration: 2000. Loss: 0.2919614911079407. Accuracy: 93.33000183105469
Iteration: 2500. Loss: 0.14118364453315735. Accuracy: 93.83999633789062
Iteration: 3000. Loss: 0.08196574449539185. Accuracy: 96.5
Iteration: 3500. Loss: 0.09124790877103806. Accuracy: 95.87999725341797
Iteration: 4000. Loss: 0.0626092329621315. Accuracy: 97.12000274658203
Iteration: 4500. Loss: 0.046092256903648376. Accuracy: 97.25
Iteration: 5000. Loss: 0.09187523275613785. Accuracy: 96.94999694824219
Iteration: 5500. Loss: 0.09047114104032516. Accuracy: 97.30000305175781
Iteration: 6000. Loss: 0.023957697674632072. Accuracy: 97.80999755859375
Iteration: 6500. Loss: 0.01368747465312481. Accuracy: 97.55000305175781
Iteration: 7000. Loss: 0.020472373813390732. Accuracy: 97.8499984741211
Iteration: 7500. Loss

In [43]:
def evaluate(model, val_iter):
    corrects, total, total_loss = 0, 0, 0

    model.eval()

    for images, labels in val_iter:
        images = Variable(images.view(-1, seq_dim, input_dim).to(device))
        labels = labels.to(device)  # move labels to the same device as the model

        logits = model(images)
        loss = functional.cross_entropy(logits, labels, reduction='sum')
        total += labels.size(0)
        total_loss += loss.item()
        corrects += (logits.argmax(1) == labels).sum()

    avg_loss = total_loss / len(val_iter.dataset)
    avg_acc = corrects / total

    return avg_loss, avg_acc

In [44]:
test_loss, test_acc = evaluate(model, test_loader)
print("Test Loss: {}, Test Accuracy: {}".format(test_loss, test_acc))

Test Loss: 0.0597270193759352, Test Accuracy: 0.9803999662399292
