In [1]:
from LeNet5 import LeNet5
import numpy as np
import os
import torch
from torchvision.datasets import mnist
from torch.nn import CrossEntropyLoss
from torch.optim import SGD, Adam
from torch.utils.data import DataLoader
from torchvision.transforms import ToTensor

In [2]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
batch_size = 256
train_dataset = mnist.MNIST(root='./train', train=True, transform=ToTensor(), download=True)
test_dataset = mnist.MNIST(root='./test', train=False, transform=ToTensor(), download=True)
train_loader = DataLoader(train_dataset, batch_size=batch_size)
test_loader = DataLoader(test_dataset, batch_size=batch_size)
model = LeNet5().to(device)
optimizer = Adam(model.parameters(), lr=1e-3)
loss_fn = CrossEntropyLoss()
all_epoch = 100
if not os.path.isdir("models"):
    os.mkdir("models")

In [3]:
for current_epoch in range(all_epoch):
    model.train()
    for idx, (data, label) in enumerate(train_loader):
        data = data.to(device)
        label = label.to(device)
        optimizer.zero_grad()
        predict_y = model(data.float())
        loss = loss_fn(predict_y, label.long())
        loss.backward()
        optimizer.step()

    all_correct_num = 0
    all_sample_num = 0
    model.eval()

    for idx, (test_x, test_label) in enumerate(test_loader):
        test_x = test_x.to(device)
        test_label = test_label.to(device)
        predict_y = model(test_x.float()).detach()
        predict_y = torch.argmax(predict_y, dim=-1)
        current_correct_num = predict_y == test_label
        all_correct_num += np.sum(current_correct_num.to('cpu').numpy(), axis=-1)
        all_sample_num += current_correct_num.shape[0]
    acc = all_correct_num / all_sample_num
    print('Epoch %d/%d: Loss=%.6f, Accuracy=%.6f' % (current_epoch, all_epoch, loss.item(), acc))
    torch.save(model.state_dict(), 'models/lenet_{:.3f}.pt'.format(acc))
print("Model finished training")

Epoch 0/100: Loss=1.163798, Accuracy=0.689100
Epoch 1/100: Loss=0.612977, Accuracy=0.838400
Epoch 2/100: Loss=0.443324, Accuracy=0.886400
Epoch 3/100: Loss=0.351810, Accuracy=0.908300
Epoch 4/100: Loss=0.301460, Accuracy=0.922400
Epoch 5/100: Loss=0.267620, Accuracy=0.932200
Epoch 6/100: Loss=0.242694, Accuracy=0.939700
Epoch 7/100: Loss=0.223391, Accuracy=0.946700
Epoch 8/100: Loss=0.208120, Accuracy=0.951700
Epoch 9/100: Loss=0.196093, Accuracy=0.957300
Epoch 10/100: Loss=0.186627, Accuracy=0.962600
Epoch 11/100: Loss=0.179057, Accuracy=0.966600
Epoch 12/100: Loss=0.173066, Accuracy=0.968600
Epoch 13/100: Loss=0.168687, Accuracy=0.971000
Epoch 14/100: Loss=0.165859, Accuracy=0.973500
Epoch 15/100: Loss=0.164254, Accuracy=0.975200
Epoch 16/100: Loss=0.163469, Accuracy=0.976700
Epoch 17/100: Loss=0.163188, Accuracy=0.977000
Epoch 18/100: Loss=0.163209, Accuracy=0.978100
Epoch 19/100: Loss=0.163417, Accuracy=0.978100
Epoch 20/100: Loss=0.163725, Accuracy=0.978500
Epoch 21/100: Loss=0.16