Imports

In [50]:
import torch
import torch.nn as nn
from torchvision import datasets, transforms
import torch.optim as optim
from torch.utils.data import DataLoader

from CNN import CNN

import matplotlib.pyplot as plt
import numpy as np

Define Show Function

In [51]:
def show(image):
    plt.imshow(image, cmap='gray')

Import Datasets

In [52]:
train_data = datasets.MNIST(
    root='data',
    train=True,
    download=True,
    transform=transforms.ToTensor()
)

test_data = datasets.MNIST(
    root='data',
    train=False,
    download=True,
    transform=transforms.ToTensor()
)

Check Device

In [53]:
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using {device.upper()}")

Using CUDA


Model Parameters

In [54]:
input_channels = 1
classes = 10
conv_kernel_size = 5
pool_kernel_size = 2
conv_stride = 1
pool_stride = 2
fc1_output = 1124

Define Model

In [55]:
model = CNN(input_channels, classes, conv_kernel_size, pool_kernel_size, conv_stride, pool_stride, fc1_output)

model.to(device)

CNN(
  (conv1): Conv2d(1, 10, kernel_size=(5, 5), stride=(1, 1))
  (conv2): Conv2d(10, 24, kernel_size=(5, 5), stride=(1, 1))
  (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (fc1): Linear(in_features=384, out_features=1124, bias=True)
  (fc2): Linear(in_features=1124, out_features=2248, bias=True)
  (fc3): Linear(in_features=2248, out_features=4496, bias=True)
  (fc4): Linear(in_features=4496, out_features=10, bias=True)
)

Training Parameters

In [56]:
train_batch_size = 32
learning_rate = 0.01
momentum = 0.8
num_epochs = 200

In [57]:
train_loader = DataLoader(train_data, train_batch_size)

test_loader = DataLoader(test_data)

Loss Function And Optimizer

In [58]:
loss_func = nn.CrossEntropyLoss()

optimizer = optim.SGD(model.parameters(), lr=learning_rate, momentum=momentum)

Train Function

In [59]:
def train():

    model.train()

    for epochs in range(num_epochs):

        print("Epoch: ", epochs + 1)

        for i, (images, labels) in enumerate(train_loader):

            images = images.to(device)
            labels = labels.to(device)

            output = model(images)
            loss = loss_func(output, labels)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

Test Function

In [60]:
def test():

    model.eval()

    correct_predictions = 0
    images_evaluated = 0

    with torch.no_grad():

        for image, label in test_loader:

            image = image.to(device)
            label = label.to(device)

            output = model(image)

            images_evaluated += 1

            prediction = torch.argmax(output)

            if prediction == label:
                correct_predictions += 1
    
    accuracy = correct_predictions / images_evaluated * 100

    return accuracy, correct_predictions, images_evaluated

Train & Test

In [61]:
print(test())

print("-------------------------------")

train()

print("-------------------------------")

print(test())

(14.149999999999999, 1415, 10000)
-------------------------------
Epoch:  1
Epoch:  2
Epoch:  3
Epoch:  4
Epoch:  5
Epoch:  6
Epoch:  7
Epoch:  8
Epoch:  9
Epoch:  10
Epoch:  11
Epoch:  12
Epoch:  13
Epoch:  14
Epoch:  15
Epoch:  16
Epoch:  17
Epoch:  18
Epoch:  19
Epoch:  20
Epoch:  21
Epoch:  22
Epoch:  23
Epoch:  24
Epoch:  25
Epoch:  26
Epoch:  27
Epoch:  28
Epoch:  29
Epoch:  30
Epoch:  31
Epoch:  32
Epoch:  33
Epoch:  34
Epoch:  35
Epoch:  36
Epoch:  37
Epoch:  38
Epoch:  39
Epoch:  40
Epoch:  41
Epoch:  42
Epoch:  43
Epoch:  44
Epoch:  45
Epoch:  46
Epoch:  47
Epoch:  48
Epoch:  49
Epoch:  50
Epoch:  51
Epoch:  52
Epoch:  53
Epoch:  54
Epoch:  55
Epoch:  56
Epoch:  57
Epoch:  58
Epoch:  59
Epoch:  60
Epoch:  61
Epoch:  62
Epoch:  63
Epoch:  64
Epoch:  65
Epoch:  66
Epoch:  67
Epoch:  68
Epoch:  69
Epoch:  70
Epoch:  71
Epoch:  72
Epoch:  73
Epoch:  74
Epoch:  75
Epoch:  76
Epoch:  77
Epoch:  78
Epoch:  79
Epoch:  80
Epoch:  81
Epoch:  82
Epoch:  83
Epoch:  84
Epoch:  85
Epoch:  

In [62]:
save_file = "model_data.pth"

torch.save(model, save_file)