In [15]:
import math

import torch
import torch.nn as nn
import torch.nn.functional as functional
import torchvision.datasets as dataset
import torchvision.transforms as transforms
from torch.autograd import Variable
from torch.utils.data import DataLoader

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [16]:
cuda = torch.cuda.is_available()
Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor

torch.manual_seed(125)
if cuda:
    torch.cuda.manual_seed(125)

In [17]:
mnist_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (1.0,))
])

data_root = "../data"

train_dataset = dataset.MNIST(root=data_root, train=True, transform=mnist_transform, download=True)
valid_dataset = dataset.MNIST(root=data_root, train=False, transform=mnist_transform, download=True)
test_dataset = dataset.MNIST(root=data_root, train=False, transform=mnist_transform, download=True)

In [18]:
batch_size = 64
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
valid_loader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=True)  # ┌ valid shuffle is not general,
test_loader = DataLoader(test_dataset, batch_size=batch_size,
                         shuffle=True)  # │ but sometimes use it for learning more variations

In [19]:
batch_size = 100
n_iters = 6000
num_epochs = int(n_iters / (len(train_dataset) / batch_size))

In [20]:
class GRUCell(nn.Module):
    def __init__(self, input_size, hidden_size, bias=True):
        super(GRUCell, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.bias = bias

        self.x2h = nn.Linear(input_size, 3 * hidden_size, bias=bias)
        self.h2h = nn.Linear(hidden_size, 3 * hidden_size, bias=bias)
        self.reset_parameters()

    def reset_parameters(self):
        std = 1.0 / math.sqrt(self.hidden_size)
        for w in self.parameters():
            w.data.uniform_(-std, std)

    def forward(self, x, hidden):
        x = x.view(-1, x.size(1))
        gate_x = self.x2h(x).squeeze()
        gate_h = self.h2h(hidden).squeeze()

        i_r, i_i, i_n = gate_x.chunk(3, 1)
        h_r, h_i, h_n = gate_h.chunk(3, 1)

        reset_gate = functional.sigmoid(i_r + h_r)
        input_gate = functional.sigmoid(i_i + h_i)
        new_gate = functional.tanh(i_n + (reset_gate * h_n))

        return new_gate + input_gate * (hidden - new_gate)

In [21]:
class GRUModel(nn.Module):
    def __init__(self, input_dim, hidden_dim, layer_dim, output_dom, bias=True):
        super(GRUModel, self).__init__()
        self.hidden_dim = hidden_dim
        self.layer_dim = layer_dim

        self.gru_cell = GRUCell(input_dim, hidden_dim, layer_dim)
        self.fc = nn.Linear(hidden_dim, output_dom)

    def forward(self, x):
        h0 = Variable(torch.zeros(self.layer_dim, x.size(0), self.hidden_dim)).to(device)

        outs = []
        hn = h0[0, :, :]

        for seq in range(x.size(1)):
            hn = self.gru_cell(x[:, seq, :], hn)
            outs.append(hn)

        out = outs[-1].squeeze()
        out = self.fc(out)

        return out

In [22]:
input_dim = 28
hidden_dim = 128
layer_dim = 1
output_dom = 10

model = GRUModel(input_dim, hidden_dim, layer_dim, output_dom).to(device)

criteria = nn.CrossEntropyLoss()
learning_rate = 0.1
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)

In [23]:
seq_dim = 28
loss_list = []
_iter = 0

for epoch in range(num_epochs):
    for i, (images, labels) in enumerate(train_loader):
        images = Variable(images.view(-1, seq_dim, input_dim).to(device))
        labels = Variable(labels.to(device))

        optimizer.zero_grad()
        outputs = model(images)

        loss = criteria(outputs, labels)
        loss.backward()

        optimizer.step()
        loss_list.append(loss.item())
        _iter += 1

        if _iter % 500 == 0:
            correct = 0
            total = 0

            for _images, _labels in valid_loader:
                _images = Variable(_images.view(-1, seq_dim, input_dim).to(device))

                outputs = model(_images)

                _, predicted = torch.max(outputs.data, 1)

                total += _labels.size(0)

                correct += (predicted.cpu() == _labels.cpu()).sum()

            accuracy = 100 * correct / total

            loss_list.append(loss.data)

            print("Iteration: {}, Loss: {}, Accuracy: {}".format(_iter, loss.data, accuracy))

Iteration: 500, Loss: 1.6616928577423096, Accuracy: 43.59000015258789
Iteration: 1000, Loss: 0.8945668935775757, Accuracy: 76.19999694824219
Iteration: 1500, Loss: 0.29147762060165405, Accuracy: 89.7300033569336
Iteration: 2000, Loss: 0.23627933859825134, Accuracy: 93.51000213623047
Iteration: 2500, Loss: 0.03288726136088371, Accuracy: 95.05000305175781
Iteration: 3000, Loss: 0.03037494421005249, Accuracy: 95.81999969482422
Iteration: 3500, Loss: 0.16210567951202393, Accuracy: 96.33999633789062
Iteration: 4000, Loss: 0.193087637424469, Accuracy: 96.19000244140625
Iteration: 4500, Loss: 0.051720187067985535, Accuracy: 97.0
Iteration: 5000, Loss: 0.13900159299373627, Accuracy: 97.26000213623047
Iteration: 5500, Loss: 0.08090292662382126, Accuracy: 97.62000274658203
Iteration: 6000, Loss: 0.10488346219062805, Accuracy: 97.69000244140625
Iteration: 6500, Loss: 0.07984013855457306, Accuracy: 97.80000305175781
Iteration: 7000, Loss: 0.10250388830900192, Accuracy: 97.55999755859375
Iteration:

In [24]:
from typing import Iterator


def evaluate(mdl, val_iter: Iterator):
    corrects, p_total, p_total_loss = 0, 0, 0
    mdl.eval()
    for img, i_labels in val_iter:
        img = Variable(img.view(-1, seq_dim, input_dim).to(device))

        logit = mdl(img)
        f_loss = functional.cross_entropy(logit, i_labels.to(device), reduction='sum')
        _, res_predict = torch.max(logit.data, 1)
        p_total += i_labels.size(0)
        p_total_loss += f_loss.item()
        corrects += (res_predict.cpu() == i_labels).sum()

    avg_loss = p_total_loss / p_total
    acc = 100 * corrects / p_total
    return avg_loss, acc

In [25]:
test_loss, test_acc = evaluate(model, test_loader)
print("Test Loss: %5.2f | Test Acc: %5.2f" % (test_loss, test_acc))

Test Loss:  0.07 | Test Acc: 97.98
