In [127]:
import torch
import torchvision
from torch import nn
from torch.utils import data
import matplotlib.pyplot as plt
import time
from model import AlexNet

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print(device)

cuda:0


In [128]:
def load_data_MNIST(batch_size, resize, trans):
    # training dataset
    train_dataset = torchvision.datasets.MNIST(root = '../data/', train = True, transform = trans, download = True)
    train_loader = data.DataLoader(dataset = train_dataset, batch_size = batch_size, shuffle = True)
    #testing dataset
    test_dataset = torchvision.datasets.MNIST(root = '../data/', train = False, transform = trans, download = True)
    test_loader = data.DataLoader(dataset = test_dataset, batch_size = batch_size, shuffle = True)
    return train_loader, test_loader

In [129]:
def train(net, optimizer, loss, num_epochs, train_loader, test_loader):
    net.to(device)
    best_acc = 0

    for epoch in range(num_epochs):
        # train
        net.train()
        train_l = 0.0
        time_s = time.perf_counter()
        for i, (inputs, labels) in enumerate(train_loader):
            optimizer.zero_grad()
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = net(inputs)
            l = loss(outputs, labels)
            l.backward()
            optimizer.step()

            train_l += l.item()

        # test
        net.eval()
        test_acc = 0.0
        with torch.no_grad():
            for inputs, labels in test_loader:
                inputs, labels = inputs.to(device), labels.to(device)
                outputs = torch.max(net(inputs), dim = 1)[1]            # outputs: [batch, classes], torch.max: [max_item, max_pos]
                test_acc += torch.eq(outputs, labels).sum().item()
            test_acc /= len(test_loader)
            if best_acc < test_acc:
                best_acc = test_acc

        print('[epoch %d] train_loss: %.3f, test_accuracy: %.2f%%, time: %ds' % (epoch + 1, train_l / i, test_acc, (int(time.perf_counter() - time_s))))

    print('Finished training! Best_accuracy: %.2f%%' % (best_acc))

In [130]:
def main():
    # set net
    net = AlexNet(num_classes = 10, init_weights = True)

    # set datasets
    batch_size = 64
    resize = 224
    trans = torchvision.transforms.Compose([torchvision.transforms.Resize(resize),
                                        torchvision.transforms.ToTensor()])
    train_loader, test_loader = load_data_MNIST(batch_size, resize, trans)

    # set optimizer and loss
    learning_rate = 0.001
    momentum = 0.5
    optimizer = torch.optim.SGD(net.parameters(), lr = learning_rate, momentum = momentum)
    loss = nn.CrossEntropyLoss()

    # start training
    num_epochs = 15
    train(net, optimizer, loss, num_epochs, train_loader, test_loader)

In [131]:
if __name__ == '__main__':
    main()

[epoch 1] train_loss: 2.059, test_accuracy: 54.03%, time: 66s
[epoch 2] train_loss: 0.359, test_accuracy: 61.30%, time: 66s
[epoch 3] train_loss: 0.165, test_accuracy: 61.99%, time: 65s
[epoch 4] train_loss: 0.125, test_accuracy: 62.29%, time: 66s
[epoch 5] train_loss: 0.101, test_accuracy: 62.55%, time: 65s
[epoch 6] train_loss: 0.089, test_accuracy: 62.67%, time: 65s
[epoch 7] train_loss: 0.077, test_accuracy: 62.76%, time: 65s
[epoch 8] train_loss: 0.072, test_accuracy: 62.90%, time: 65s
[epoch 9] train_loss: 0.064, test_accuracy: 62.96%, time: 65s
[epoch 10] train_loss: 0.059, test_accuracy: 62.97%, time: 65s
[epoch 11] train_loss: 0.057, test_accuracy: 63.08%, time: 65s
[epoch 12] train_loss: 0.052, test_accuracy: 63.05%, time: 65s
[epoch 13] train_loss: 0.050, test_accuracy: 63.11%, time: 65s
[epoch 14] train_loss: 0.047, test_accuracy: 63.06%, time: 64s
[epoch 15] train_loss: 0.045, test_accuracy: 63.16%, time: 65s
Finished training! Best_accuracy: 63.16%
