In [25]:
import torch
import torchvision
from torchvision import transforms
from torch.utils.data import DataLoader
import torch.nn as nn

In [20]:
image_path = './'
# ToTensor converts each PIL image to a float tensor with values scaled to [0, 1].
transform = transforms.Compose([transforms.ToTensor()])

# download=True only fetches the files when they are missing under `image_path`,
# so the notebook also works from a fresh checkout (Restart & Run All).
mnist_train = torchvision.datasets.MNIST(root=image_path, train=True,
                                         transform=transform, download=True)
mnist_test = torchvision.datasets.MNIST(root=image_path, train=False,
                                        transform=transform, download=True)

batch_size = 64
# Seeds torch's global RNG; later random draws (e.g. the nn.Linear weight
# initialization in the model-construction cell) depend on this seed.
torch.manual_seed(43)
train_dl = DataLoader(mnist_train, batch_size=batch_size, shuffle=True)


## Construct the model

In [27]:
# Multilayer perceptron:
# Flatten -> [Linear -> ReLU] for each entry of hidden_units_list -> Linear(num_classes).
hidden_units_list = [32, 16]
num_classes = 10  # MNIST digit classes 0-9

all_layers = [nn.Flatten()]
# Infer the flattened input width from one transformed training sample.
image_shape = mnist_train[0][0].shape
input_size = image_shape.numel()
for hidden_size in hidden_units_list:
    all_layers.append(nn.Linear(input_size, hidden_size))
    all_layers.append(nn.ReLU())
    input_size = hidden_size

# After the loop `input_size` is the width of the last hidden layer, or the raw
# input width when hidden_units_list is empty, so the output layer is valid
# either way (indexing hidden_units_list[-1] would fail on an empty list).
all_layers.append(nn.Linear(input_size, num_classes))
model = nn.Sequential(*all_layers)
model

Sequential(
  (0): Flatten(start_dim=1, end_dim=-1)
  (1): Linear(in_features=784, out_features=32, bias=True)
  (2): ReLU()
  (3): Linear(in_features=32, out_features=16, bias=True)
  (4): ReLU()
  (5): Linear(in_features=16, out_features=10, bias=True)
)

In [34]:
loss_fn = nn.CrossEntropyLoss()

# NOTE: re-running this cell creates a fresh Adam optimizer but keeps training the
# existing `model`; re-run the model-construction cell first for a clean run.
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
# Seeds the global RNG that the shuffling DataLoader draws from each epoch.
torch.manual_seed(0)
num_epochs = 20

model.train()  # make sure the model is in training mode (e.g. after an eval cell)
for epoch in range(num_epochs):
    num_correct = 0
    for x_batch, y_batch in train_dl:
        preds = model(x_batch)
        loss = loss_fn(preds, y_batch)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        # Training accuracy is measured on predictions made before this step's update.
        num_correct += (preds.argmax(dim=1) == y_batch).sum().item()
    accuracy_hist_train = num_correct / len(train_dl.dataset)
    print(f'Epoch {epoch} Accuracy {accuracy_hist_train:.4f}')

Epoch 0 Accuracy 0.9256
Epoch 1 Accuracy 0.9414
Epoch 2 Accuracy 0.9493
Epoch 3 Accuracy 0.9561
Epoch 4 Accuracy 0.9610
Epoch 5 Accuracy 0.9645
Epoch 6 Accuracy 0.9667
Epoch 7 Accuracy 0.9693
Epoch 8 Accuracy 0.9716
Epoch 9 Accuracy 0.9737
Epoch 10 Accuracy 0.9750
Epoch 11 Accuracy 0.9776
Epoch 12 Accuracy 0.9788
Epoch 13 Accuracy 0.9801
Epoch 14 Accuracy 0.9814
Epoch 15 Accuracy 0.9823
Epoch 16 Accuracy 0.9836
Epoch 17 Accuracy 0.9838
Epoch 18 Accuracy 0.9848
Epoch 19 Accuracy 0.9855


In [39]:
# Evaluate on the held-out test set. Iterating the dataset applies the same
# `transform` used for training, instead of re-implementing it by hand.
# shuffle=False keeps predictions aligned with mnist_test.targets.
eval_batch_size = 1000
test_dl = DataLoader(mnist_test, batch_size=eval_batch_size, shuffle=False)

model.eval()
with torch.no_grad():  # inference only: don't build an autograd graph
    preds = torch.cat([model(x_batch) for x_batch, _ in test_dl])

is_correct = (preds.argmax(dim=1) == mnist_test.targets).float()
print(f'Test accuracy: {is_correct.mean():.4f}')

Test accuracy: 0.9666
