In [None]:
import torch
import numpy as np
import torchvision.datasets

In [None]:
# Download MNIST into DATA_ROOT on first run (later runs reuse the local copy)
# and load both the training and the test split.
DATA_ROOT = './'
mnist_train = torchvision.datasets.MNIST(root=DATA_ROOT, train=True, download=True)
mnist_test = torchvision.datasets.MNIST(root=DATA_ROOT, train=False, download=True)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./MNIST/raw/train-images-idx3-ubyte.gz


  0%|          | 0/9912422 [00:00<?, ?it/s]

Extracting ./MNIST/raw/train-images-idx3-ubyte.gz to ./MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./MNIST/raw/train-labels-idx1-ubyte.gz


  0%|          | 0/28881 [00:00<?, ?it/s]

Extracting ./MNIST/raw/train-labels-idx1-ubyte.gz to ./MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./MNIST/raw/t10k-images-idx3-ubyte.gz


  0%|          | 0/1648877 [00:00<?, ?it/s]

Extracting ./MNIST/raw/t10k-images-idx3-ubyte.gz to ./MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./MNIST/raw/t10k-labels-idx1-ubyte.gz


  0%|          | 0/4542 [00:00<?, ?it/s]

Extracting ./MNIST/raw/t10k-labels-idx1-ubyte.gz to ./MNIST/raw



In [None]:
# Raw image and label tensors of each split.
# `.data` / `.targets` are the current torchvision attributes; the old
# `train_data` / `train_labels` / `test_data` / `test_labels` names are
# deprecated aliases that warn and return these same tensors.
X_train = mnist_train.data
y_train = mnist_train.targets
X_test = mnist_test.data
y_test = mnist_test.targets



In [None]:
# Run on the first CUDA GPU when one is visible, otherwise fall back to CPU.
if torch.cuda.is_available():
    device = 'cuda:0'
else:
    device = 'cpu'

In [None]:
# Add a channel axis (N, H, W) -> (N, 1, H, W), convert to float32 and move to
# `device`. Pixel values are left in their original range (no rescaling).
# Built from the dataset tensors rather than from X_train/X_test themselves so
# that re-running this cell does not keep stacking extra axes.
X_train = mnist_train.data.unsqueeze(1).to(device=device, dtype=torch.float32)
X_test = mnist_test.data.unsqueeze(1).to(device=device, dtype=torch.float32)

In [None]:
# Sanity check: expected (60000, 1, 28, 28) after adding the channel axis.
X_train.size()

torch.Size([60000, 1, 28, 28])

In [None]:
class NextLeNet3(torch.nn.Module):
    """LeNet-style CNN mapping 1-channel 28x28 images to 10 class logits.

    Each of the two feature stages applies two 3x3 convolutions (padding 1, so
    height/width are preserved), then a ReLU and a 2x2 average pool that halves
    height/width: 28 -> 14 -> 7. The resulting 16 x 7 x 7 feature map is
    flattened and fed through a 120 -> 84 -> 10 fully connected head.

    NOTE(review): there is no activation between the two convolutions of a
    stage, so each pair composes to a single linear map; confirm intended.
    """

    def __init__(self):
        super().__init__()

        def conv3x3(in_ch, out_ch):
            # 3x3 kernel with padding=1 keeps the spatial size unchanged.
            return torch.nn.Conv2d(
                in_channels=in_ch, out_channels=out_ch, kernel_size=3, padding=1
            )

        # Stage 1: 1 -> 8 channels, 28x28 -> 14x14.
        self.conv1 = conv3x3(1, 8)
        self.conv2 = conv3x3(8, 8)
        self.act1 = torch.nn.ReLU()
        self.pool1 = torch.nn.AvgPool2d(kernel_size=2, stride=2)

        # Stage 2: 8 -> 16 channels, 14x14 -> 7x7.
        self.conv3 = conv3x3(8, 16)
        self.conv4 = conv3x3(16, 16)
        self.act2 = torch.nn.ReLU()
        self.pool2 = torch.nn.AvgPool2d(kernel_size=2, stride=2)

        # Classifier head over the flattened 16 * 7 * 7 features.
        self.fc1 = torch.nn.Linear(7 * 7 * 16, 120)
        self.act3 = torch.nn.ReLU()
        self.fc2 = torch.nn.Linear(120, 84)
        self.act4 = torch.nn.ReLU()
        self.fc3 = torch.nn.Linear(84, 10)

    def forward(self, x):
        """Map an (N, 1, 28, 28) batch to (N, 10) unnormalised logits."""
        stages = (
            (self.conv1, self.conv2, self.act1, self.pool1),
            (self.conv3, self.conv4, self.act2, self.pool2),
        )
        for first_conv, second_conv, activation, pooling in stages:
            x = pooling(activation(second_conv(first_conv(x))))

        # Collapse channel and spatial axes into one feature axis per sample.
        x = torch.flatten(x, start_dim=1)

        hidden = self.act3(self.fc1(x))
        hidden = self.act4(self.fc2(hidden))
        return self.fc3(hidden)

In [None]:
# Seed torch (weight initialisation, torch-based shuffling) and numpy
# (numpy-based shuffling) so training is reproducible after a kernel restart.
SEED = 0
torch.manual_seed(SEED)
np.random.seed(SEED)
model = NextLeNet3()

In [None]:
# Move all parameters/buffers to the selected device (Module.to returns self).
model = model.to(device=device)

In [None]:
# Criterion: cross-entropy on the raw logits returned by the model.
loss = torch.nn.CrossEntropyLoss()

# Plain SGD (no momentum) with a small fixed step size.
LEARNING_RATE = 1.0e-4
optimizer = torch.optim.SGD(params=model.parameters(), lr=LEARNING_RATE)

In [None]:
# Labels must live on the same device as the model outputs for the loss.
y_train = y_train.to(device=device)

In [None]:
# Test labels go to the same device as the predictions they are compared with.
y_test = y_test.to(device=device)

In [None]:
# Mini-batch SGD training with a test-set accuracy printout.
batch_size = 100
# NOTE(review): the recorded accuracies level off around 0.989; far fewer
# epochs than this are likely enough.
n_epochs = 10000
eval_every = 1  # evaluate on the test set every `eval_every` epochs


def _predict_labels(net, inputs, chunk_size=1000):
    """Return argmax class predictions of `net` for every row of `inputs`.

    Runs under torch.no_grad() (no autograd graph or stored activations) and in
    eval mode, restoring the module's previous train/eval mode afterwards.
    `inputs` is processed `chunk_size` rows at a time to bound peak memory.
    """
    was_training = net.training
    net.eval()
    try:
        with torch.no_grad():
            labels = torch.cat([
                net(inputs[start:start + chunk_size]).argmax(dim=1)
                for start in range(0, inputs.shape[0], chunk_size)
            ])
    finally:
        net.train(was_training)
    return labels


for epoch in range(n_epochs):
    model.train()
    # Fresh shuffle each epoch, generated directly on the data's device.
    order = torch.randperm(X_train.shape[0], device=X_train.device)
    for start_id in range(0, X_train.shape[0], batch_size):
        optimizer.zero_grad()
        batch_id = order[start_id:start_id + batch_size]
        # Call the module (not .forward) so registered hooks run.
        logits = model(X_train[batch_id])
        loss_pred = loss(logits, y_train[batch_id])
        loss_pred.backward()
        optimizer.step()

    if epoch % eval_every == 0:
        predict = _predict_labels(model, X_test)
        # Integer count / total keeps the printed value exact (e.g. 0.9248).
        print((predict == y_test).sum().item() / y_test.shape[0])


0.9248
0.9538
0.9634
0.9657
0.9729
0.9752
0.9746
0.9761
0.9782
0.9777
0.9817
0.9825
0.9808
0.9836
0.9835
0.983
0.9847
0.9848
0.9858
0.9844
0.984
0.9858
0.986
0.9868
0.9858
0.9868
0.9864
0.9872
0.9861
0.9874
0.9871
0.9881
0.9867
0.9877
0.9878
0.9881
0.9876
0.9875
0.9885
0.9876
0.9876
0.9888
0.9878
0.9875
0.9878
0.9881
0.9888
0.9883
0.9883
0.9885
0.988
0.9882
0.9887
0.9886
0.9879
0.9882
0.988
0.9886
0.9886
0.9892
0.9879
0.9886
0.9887
0.9885
0.9887
0.9887
0.9881
0.9882
0.9882
0.9885
0.989
0.9891
0.9888
0.9889
0.9888
0.9889
0.9884
0.9884
0.989
0.9892
0.9891
0.989
0.9889
0.9891
0.989
0.9888
0.989
0.9892
0.9893
0.989
0.9892
0.9893
0.9893
0.9889
0.9888
0.9889
0.9893
0.9889
0.9891
0.989
0.9891
0.9887
0.9894
0.9892
0.9894
0.9889
0.9893
0.9888
0.9891
0.9891
0.9889
0.9891
0.989
0.989
0.9891
0.9895
0.9896
0.9894
0.9896
0.9891
0.9894
0.9895
0.9892
0.9896
0.9893
0.9893
0.9896
0.9894
0.9895
0.9897
0.9894
0.9898
0.9896
0.9895
0.9891
0.9896
0.9895
0.9894
0.9896
0.9896
0.9894
0.9894
0.9897
0.9896
0.9895