In [None]:
import torch
import torch.nn as nn
import torch.optim as optim

from torchvision import datasets, transforms
from torch.utils.data import DataLoader

from cnn import CNN
from train import train

# Configuration (device, random seed, model paths, training length)

In [None]:
# Pick the accelerator: CUDA first, then Apple MPS, otherwise fall back to CPU.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

# Seed torch's global RNG so weight initialisation and DataLoader shuffling draw
# from a fixed seed on every Restart & Run All. (Bit-exact GPU reproducibility may
# additionally require deterministic-algorithm settings, which are not set here.)
SEED = 42
torch.manual_seed(SEED)

# Set to the path of a saved state_dict to skip training and load those weights.
PATH_TO_READY_MODEL = None
# Where freshly trained weights are written.
PATH_TO_SAVE_MODEL = "model.pth"
# Maximum number of training epochs (training may stop earlier via early stopping).
EPOCHS = 20

# Load dataset

In [12]:
# Preprocessing: ToTensor converts images to float tensors scaled to [0, 1], then
# Normalize standardises them with the commonly cited MNIST training-set mean / std.
MNIST_MEAN = (0.1307,)
MNIST_STD = (0.3081,)
BATCH_SIZE = 64

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(MNIST_MEAN, MNIST_STD),
])

# Both splits share the same storage location, download flag and preprocessing.
mnist_kwargs = dict(root='./data', download=True, transform=transform)
train_dataset = datasets.MNIST(train=True, **mnist_kwargs)
test_dataset = datasets.MNIST(train=False, **mnist_kwargs)

# Only the training data is shuffled; evaluation iterates in a fixed order.
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

# Setup model

In [13]:
# Build the network on the selected device, plus its optimizer and loss function.
LEARNING_RATE = 1e-4

model = CNN()
model = model.to(device)

optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

# nn.CrossEntropyLoss applies log-softmax internally, so CNN presumably returns
# unnormalised logits — confirm in cnn.py. Without class weights the loss holds no
# tensors, so moving it to `device` is effectively a no-op.
criterion = nn.CrossEntropyLoss()
criterion = criterion.to(device)

# Train or load model

In [14]:
# Either restore previously trained weights, or train from scratch and save the result.
if PATH_TO_READY_MODEL is not None:
    # A state_dict only contains tensors and plain containers, so weights_only=True
    # is sufficient and prevents torch.load from unpickling arbitrary objects (which
    # could execute code from an untrusted checkpoint file).
    state_dict = torch.load(PATH_TO_READY_MODEL, map_location=device, weights_only=True)
    model.load_state_dict(state_dict)
else:
    # NOTE(review): test_loader is handed to train() as its evaluation loader, and
    # train() appears to do early stopping. If it uses that loader for stopping or
    # model selection, the test accuracy reported later is optimistically biased —
    # consider holding out a validation split of the training set. TODO confirm.
    train(model, optimizer, criterion, train_loader, test_loader, device=device, num_epochs=EPOCHS)
    # Fall back to the default file name when no (or an empty) save path is configured.
    torch.save(model.state_dict(), PATH_TO_SAVE_MODEL or "model.pth")

Epoch [1/20], Loss: 0.2861
Epoch [2/20], Loss: 0.0575
Epoch [3/20], Loss: 0.0374
Epoch [4/20], Loss: 0.0278
Epoch [5/20], Loss: 0.0218
Epoch [6/20], Loss: 0.0167
Epoch [7/20], Loss: 0.0121
Epoch [8/20], Loss: 0.0107
Epoch [9/20], Loss: 0.0087
Epoch [10/20], Loss: 0.0067
Epoch [11/20], Loss: 0.0057
Epoch [12/20], Loss: 0.0049
Epoch [13/20], Loss: 0.0042
Early stopping triggered.


# Evaluate accuracy on the test data

In [25]:
def evaluate_accuracy(model, loader, device):
    """Count correct top-1 predictions of `model` over every batch in `loader`.

    Puts the model in eval mode and disables gradient tracking while iterating.

    Args:
        model: network to evaluate; assumed to return per-class scores of shape
            (batch, num_classes) — TODO confirm against cnn.py.
        loader: iterable yielding ``(images, labels)`` batches.
        device: device the batches are moved to before the forward pass.

    Returns:
        ``(correct, total)``: number of correctly classified samples and number
        of samples seen.
    """
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, real_labels in loader:
            images = images.to(device)
            real_labels = real_labels.to(device)

            # Predicted class = index of the highest score along the class dimension.
            predicted_labels = model(images).argmax(dim=1)

            total += real_labels.size(0)
            correct += (predicted_labels == real_labels).sum().item()
    return correct, total


correct, total = evaluate_accuracy(model, test_loader, device)

# Guard against an empty loader instead of raising ZeroDivisionError.
if total == 0:
    print('Test dataset is empty; accuracy is undefined.')
else:
    print(f'Accuracy on test dataset: {100 * correct / total:.2f}%')

Accuracy on test dataset: 99.25%
