In [1]:
import torch
import torch.nn as nn
import torchvision

# Seed PyTorch's RNG so that, after a fresh "Restart & Run All", the model starts
# from the same initial weights every time.
torch.manual_seed(0)

# Prefer the GPU, but fall back to the CPU so the notebook also runs without CUDA.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Training split: images reshaped to (N, 1, 28, 28) floats, labels one-hot encoded
# over the 10 digit classes (float, same layout the model's loss/accuracy expect).
mnist_train = torchvision.datasets.MNIST('./data', train=True, download=True)
train_x = mnist_train.data.reshape(-1, 1, 28, 28).float()
train_y = nn.functional.one_hot(mnist_train.targets, num_classes=10).float()

# Test split, prepared exactly like the training split.
mnist_test = torchvision.datasets.MNIST('./data', train=False, download=True)
test_x = mnist_test.data.reshape(-1, 1, 28, 28).float()
test_y = nn.functional.one_hot(mnist_test.targets, num_classes=10).float()

# Standardise both splits with statistics computed on the training images only,
# so no information from the test set leaks into preprocessing.
mean = train_x.mean()
std = train_x.std()
train_x = (train_x - mean) / std
test_x = (test_x - mean) / std

# `batches` is the number of samples per mini-batch (torch.split's chunk size),
# not the number of batches; the name is kept because other cells may use it.
batches = 600
x_train_batches = torch.split(train_x, batches)
y_train_batches = torch.split(train_y, batches)

In [2]:
class CNNModel(nn.Module):
    """Convolutional classifier for single-channel 28x28 images with 10 output classes.

    The whole network is stored in ``self.logits`` and is moved to the
    notebook-level ``device`` on construction, so ``device`` must already be
    defined when the model is instantiated. Inputs passed to the methods must
    live on the same device as the model. Labels are expected to be one-hot
    encoded along dimension 1; they are converted to class indices with
    ``argmax(1)``.
    """

    def __init__(self):
        super().__init__()

        self.logits = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=5, padding=2),   # 1x28x28  -> 32x28x28
            nn.MaxPool2d(kernel_size=2),                  # 32x28x28 -> 32x14x14
            nn.Conv2d(32, 64, kernel_size=5, padding=2),  # 32x14x14 -> 64x14x14
            nn.MaxPool2d(kernel_size=2),                  # 64x14x14 -> 64x7x7
            nn.Flatten(),                                 # 64x7x7   -> 3136
            nn.Linear(64 * 7 * 7, 1024),
            # No-op on the already 2-D activations; kept only so the indices of
            # the layers inside the Sequential (and state_dict keys) are unchanged.
            nn.Flatten(),
            nn.Linear(1024, 10)).to(device)

    def f(self, x):
        """Return the predicted class probabilities: softmax over dim 1 of the logits."""
        return torch.softmax(self.logits(x), dim=1)

    def loss(self, x, y):
        """Return the cross-entropy between the logits for ``x`` and the classes ``y.argmax(1)``.

        The result stays attached to the autograd graph so ``.backward()`` can be called on it.
        """
        return nn.functional.cross_entropy(self.logits(x), y.argmax(1))

    def accuracy(self, x, y):
        """Return the fraction of samples in ``x`` whose predicted class equals ``y.argmax(1)``.

        Runs without gradient tracking: this is evaluation only, and building an
        autograd graph over a large evaluation set wastes memory.
        """
        with torch.no_grad():
            # Softmax is monotonic, so the argmax of the logits is the predicted class.
            predicted = self.logits(x).argmax(1)
            return torch.mean(torch.eq(predicted, y.argmax(1)).float())

In [3]:
model = CNNModel()

optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)

# The test set is the same for every epoch, so copy it to the device once
# instead of re-transferring it at the end of each epoch.
test_x_device = test_x.to(device)
test_y_device = test_y.to(device)

for epoch in range(10):
    # One optimisation step per mini-batch; inputs and labels are split identically,
    # so zipping the two tuples pairs each image batch with its labels.
    for x_batch, y_batch in zip(x_train_batches, y_train_batches):
        model.loss(x_batch.to(device), y_batch.to(device)).backward()
        optimizer.step()
        optimizer.zero_grad()  # clear gradients so they do not accumulate into the next step

    # Evaluation only: no autograd graph is needed for the test-set forward pass.
    with torch.no_grad():
        print("accuracy = %s" % model.accuracy(test_x_device, test_y_device))

accuracy = tensor(0.9405, device='cuda:0')
accuracy = tensor(0.9644, device='cuda:0')
accuracy = tensor(0.9736, device='cuda:0')
accuracy = tensor(0.9785, device='cuda:0')
accuracy = tensor(0.9814, device='cuda:0')
accuracy = tensor(0.9836, device='cuda:0')
accuracy = tensor(0.9842, device='cuda:0')
accuracy = tensor(0.9844, device='cuda:0')
accuracy = tensor(0.9853, device='cuda:0')
accuracy = tensor(0.9853, device='cuda:0')
