In [19]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
import sys
import time

# Run on the GPU when CUDA is usable; otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [36]:
def load_mnist_dataset(batch_size, resize=None, root='../../Datasets/FashionMNIST'):
    """Build training and test DataLoaders for the Fashion-MNIST dataset.

    Despite the function name, the dataset loaded is
    ``torchvision.datasets.FashionMNIST``; it is downloaded into ``root``
    when not already present.

    Args:
        batch_size: Number of samples per mini-batch for both loaders.
        resize: Optional target size forwarded to
            ``torchvision.transforms.Resize``. Falsy values (the default
            ``None``) skip resizing.
        root: Directory where the dataset is stored / downloaded.

    Returns:
        A ``(train_iter, test_iter)`` tuple of DataLoaders. The training
        loader shuffles its samples, the test loader does not.
    """
    steps = []
    if resize:
        steps.append(torchvision.transforms.Resize(size=resize))
    steps.append(torchvision.transforms.ToTensor())
    transform = torchvision.transforms.Compose(steps)

    # The datasets need the composed callable, not the raw list of steps:
    # a plain list cannot be called on an image when a sample is fetched.
    mnist_train = torchvision.datasets.FashionMNIST(
        root=root, train=True, download=True, transform=transform)
    mnist_test = torchvision.datasets.FashionMNIST(
        root=root, train=False, download=True, transform=transform)

    train_iter = torch.utils.data.DataLoader(
        mnist_train, batch_size=batch_size, shuffle=True, num_workers=0)
    test_iter = torch.utils.data.DataLoader(
        mnist_test, batch_size=batch_size, shuffle=False, num_workers=0)

    return train_iter, test_iter

In [16]:
class LeNet(nn.Module):
    """LeNet-style CNN: two conv/sigmoid/max-pool stages and a 3-layer MLP head.

    The shape comments below assume a batch of ``n`` single-channel 28x28
    images; the head produces 10 scores per image.
    """

    def __init__(self):
        super().__init__()

        # Feature extractor; shapes are listed for an n*1*28*28 input.
        conv_layers = [
            nn.Conv2d(1, 6, kernel_size=5),    # -> n*6*24*24
            nn.Sigmoid(),
            nn.MaxPool2d(2, stride=2),         # -> n*6*12*12
            nn.Conv2d(6, 16, kernel_size=5),   # -> n*16*8*8
            nn.Sigmoid(),
            nn.MaxPool2d(2, stride=2),         # -> n*16*4*4
        ]
        self.conv = nn.Sequential(*conv_layers)

        # Classifier head operating on the flattened n*256 feature vector.
        flat_features = 16 * 4 * 4
        fc_layers = [
            nn.Linear(flat_features, 120),
            nn.Sigmoid(),
            nn.Linear(120, 84),
            nn.Sigmoid(),
            nn.Linear(84, 10),
        ]
        self.fc = nn.Sequential(*fc_layers)

    def forward(self, img):
        """Return the 10 class scores for each image in the batch ``img``."""
        batch = img.shape[0]
        flattened = self.conv(img).view(batch, -1)
        return self.fc(flattened)

In [17]:
def evaluate_accuracy(data_iter, net, device=None):
    """Compute the classification accuracy of ``net`` over ``data_iter``.

    Args:
        data_iter: Iterable yielding ``(X, y)`` batches, where ``y`` holds the
            integer class labels for the rows of ``X``.
        net: The model. An ``nn.Module`` is switched to evaluation mode for
            the duration of the call and then returned to whatever mode it
            was in before. Any other callable is invoked as-is.
        device: Device the batches are moved to. When ``None`` and ``net`` is
            an ``nn.Module`` with parameters, the device of its first
            parameter is used; when no device can be determined the batches
            are left where they are.

    Returns:
        Fraction of samples whose arg-max prediction equals the label.

    Raises:
        ZeroDivisionError: If ``data_iter`` yields no samples.
    """
    is_module = isinstance(net, nn.Module)
    if device is None and is_module:
        # Parameter-free modules have nothing to infer a device from.
        first_param = next(net.parameters(), None)
        if first_param is not None:
            device = first_param.device

    # Remember the caller's mode so evaluation does not silently flip a
    # model that was deliberately put in eval mode back to training mode.
    was_training = net.training if is_module else False
    if is_module:
        net.eval()  # evaluation mode, e.g. disables dropout

    acc_sum, n = 0.0, 0
    try:
        with torch.no_grad():
            for X, y in data_iter:
                if device is not None:
                    X, y = X.to(device), y.to(device)
                acc_sum += (net(X).argmax(dim=1) == y).float().sum().item()
                n += y.shape[0]
    finally:
        if is_module:
            net.train(was_training)
    return acc_sum / n

In [32]:
def train(net, train_iter, test_iter, batch_size, optimizer, device, num_epochs):
    """Train ``net`` with cross-entropy loss and report metrics every epoch.

    After each epoch one line is printed with the mean per-batch training
    loss, the training accuracy, the test accuracy and the elapsed time.

    Args:
        net: The ``nn.Module`` to train; it is moved to ``device``.
        train_iter: Iterable of ``(X, y)`` training batches.
        test_iter: Iterable of ``(X, y)`` batches used for the test accuracy.
        batch_size: Not used by this function; kept so existing call sites
            keep working.
        optimizer: Optimizer that updates the parameters of ``net``.
        device: Device on which the model and the batches are placed.
        num_epochs: Number of passes over ``train_iter``.

    Raises:
        ZeroDivisionError: If ``train_iter`` yields no batches.
    """
    net = net.to(device)
    print("training on", device)
    loss = torch.nn.CrossEntropyLoss()
    for epoch in range(num_epochs):
        # Set training mode explicitly instead of relying on the evaluation
        # helper to leave the model in it.
        net.train()
        train_l_sum, train_acc_sum, n, batch_num, start = 0.0, 0.0, 0, 0, time.time()
        for X, y in train_iter:
            X = X.to(device)
            y = y.to(device)
            y_hat = net(X)
            batch_loss = loss(y_hat, y)
            optimizer.zero_grad()
            batch_loss.backward()
            optimizer.step()
            # Accumulate a plain Python float: adding the tensor itself would
            # keep autograd history attached to the running sum and leave the
            # sum on the training device.
            train_l_sum += batch_loss.item()
            train_acc_sum += (y_hat.argmax(dim=1) == y).sum().item()
            n += y.shape[0]
            batch_num += 1
        test_acc = evaluate_accuracy(test_iter, net)
        print('epoch %d, loss %.4f, train acc %.3f, test acc %.3f, time %.1f sec'
              % (epoch + 1, train_l_sum / batch_num, train_acc_sum / n, test_acc, time.time() - start))

In [33]:
# Mini-batch size shared by the data loaders and the training call.
batch_size = 256
# Build the train/test loaders with the loader's default (no resize) settings.
train_iter, test_iter = load_mnist_dataset(batch_size=batch_size)

In [34]:
# Optimisation settings: Adam learning rate and number of training epochs.
lr = 0.001
num_epochs = 5
# Fresh model plus an Adam optimizer over all of its parameters.
net = LeNet()
optimizer = optim.Adam(params=net.parameters(), lr=lr)

In [35]:
# Launch training; per-epoch metrics are printed by train().
train(
    net=net,
    train_iter=train_iter,
    test_iter=test_iter,
    batch_size=batch_size,
    optimizer=optimizer,
    device=device,
    num_epochs=num_epochs,
)

training on cpu
epoch 1, loss 1.8079, train acc 0.329, test acc 0.581, time 23.0 sec
epoch 2, loss 0.9450, train acc 0.637, test acc 0.683, time 24.0 sec
epoch 3, loss 0.7506, train acc 0.720, test acc 0.733, time 24.9 sec
epoch 4, loss 0.6666, train acc 0.743, test acc 0.748, time 24.7 sec
epoch 5, loss 0.6166, train acc 0.759, test acc 0.762, time 24.6 sec
