<a href="https://colab.research.google.com/github/Programlog/MNIST_CNN/blob/main/MNIST_CNN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt

#device configuration
device  = 'cuda' if torch.cuda.is_available() else 'cpu'

#hyperparameters
epochs = 4
batch = 16
learning_rate = 0.001


# transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])

train_dataset = torchvision.datasets.MNIST(root = './data', train = True, download = True, transform = transforms.ToTensor())
test_dataset = torchvision.datasets.MNIST(root = './data', train = False, download = True, transform = transforms.ToTensor())
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size = batch, shuffle = True)
test_loader = torch.utils.data.DataLoader(train_dataset, batch_size = batch, shuffle = False)

classes = ('0', '1', '2', '3', '4', '5', '6', '7', '8', '9')

class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        #first layer is a convolution layer, with input size 1 (B&W photo), arbitrary output size 6, kernel size 5x5
        self.conv1 = nn.Conv2d(1, 8, 5)

        #next is a max pool later, with kernel size 2x2, and stride of 2(shifting two pixels to right)
        self.pool = nn.MaxPool2d(2,2)

        #second layer is another convolution later, with input size 6, output size 16, kernel size 5x5
        self.conv2 = nn.Conv2d(8, 16, 5 )

        #flatten to allow for linear layers
        self.flatten = nn.Flatten()

        #linear layers
        self.fc1 = nn.Linear(16*4*4, 120) # kernel shrinks so new input is now previous output * kernel area
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10) # output nodes should be 10 for the 10 classes (0, 1, 2...)

    def forward(self, x):
        # x = self.pool(F.relu(self.conv1(x)))
        # x = self.pool(F.relu(self.conv2(x)))
        # x = self.flatten(x)

        # x = F.relu(self.fc1(x))
        # x = F.relu(self.fc2(x))
        # x = self.fc3(x)
        # return x
        x = F.relu(self.conv1(x))
        x = self.pool(x)
        x = F.relu(self.conv2(x))
        x = self.pool(x)
        x = self.flatten(x)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x


model = CNN().to(device)
loss_func = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr = learning_rate)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz to ./data/MNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 9912422/9912422 [00:00<00:00, 18273118.23it/s]


Extracting ./data/MNIST/raw/train-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz to ./data/MNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 28881/28881 [00:00<00:00, 495844.05it/s]


Extracting ./data/MNIST/raw/train-labels-idx1-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 1648877/1648877 [00:00<00:00, 4495146.90it/s]


Extracting ./data/MNIST/raw/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 4542/4542 [00:00<00:00, 2754160.59it/s]


Extracting ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw



In [14]:
##Load
model.load_state_dict(torch.load('model.pth'))

<All keys matched successfully>

In [18]:
##Train
for epoch in range(epochs):
    model.train()
    for i, (images, labels) in enumerate(train_loader):
        images = images.to(device)
        labels = labels.to(device)

        #forward pass
        outputs = model(images)

        # calculate cost
        loss = loss_func(outputs, labels)

        #backwards pass and optimize
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if (i+1)%2000 == 0:
            print(f'Epoch [{epoch+1}/{epochs}], Loss: {loss.item(): .5f}')

print('Finished Training')

Epoch [1/4], Loss:  0.00209
Epoch [2/4], Loss:  0.00000
Epoch [3/4], Loss:  0.00291
Epoch [4/4], Loss:  0.00550
Finished Training


In [19]:
## Save
from pathlib import Path

model_path = Path('model.pth')
torch.save(model.state_dict(), model_path)
print("Saved model")

Saved model


In [20]:
##Evaluate
model.eval()
with torch.inference_mode():
    n_correct = 0
    n_samples = 0
    n_class_correct = [0 for i in range(10)]
    n_class_samples = [0 for i in range(10)]

    for images, labels in test_loader:
        images = images.to(device)
        labels = labels.to(device)
        outputs = model(images)

        _, predicted = torch.max(outputs, 1)
        n_samples += labels.size(0)
        n_correct += torch.sum(predicted == labels)

        for i in range(batch):
            label = labels[i]
            pred = predicted[i]
            if label == pred:
                n_class_correct[label] += 1
            n_class_samples[label] += 1

    acc = 100.0 * n_correct / n_samples
    print(f'Accuracy of the network: {acc: .4f} %')

    for i in range(10):
        acc = 100.0 * n_class_correct[i] / n_class_samples[i]
        print(f'Accuracy of number {classes[i]}: {acc: .4f} %')

Accuracy of the network:  99.4783 %
Accuracy of number 0:  99.2402 %
Accuracy of number 1:  98.9321 %
Accuracy of number 2:  99.7315 %
Accuracy of number 3:  99.6412 %
Accuracy of number 4:  99.2811 %
Accuracy of number 5:  99.5204 %
Accuracy of number 6:  99.6958 %
Accuracy of number 7:  99.3136 %
Accuracy of number 8:  99.9829 %
Accuracy of number 9:  99.5293 %
