In [11]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
import sys
import time

# Use the GPU when CUDA is available; otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [12]:
def load_mnist_dataset(batch_size, resize=None, root='../../Datasets/FashionMNIST'):
    """Create train and test ``DataLoader``s over torchvision's FashionMNIST.

    Args:
        batch_size: number of samples per batch for both loaders.
        resize: when truthy, forwarded as ``size`` to ``transforms.Resize``,
            which is applied before ``transforms.ToTensor``.
        root: dataset directory handed to ``FashionMNIST`` (``download=True``
            is always passed).

    Returns:
        ``(train_iter, test_iter)``; the training loader shuffles, the test
        loader does not. Both load in the main process (``num_workers=0``).
    """
    # The optional resize goes first so ToTensor sees the final image size.
    steps = [transforms.Resize(size=resize)] if resize else []
    steps.append(transforms.ToTensor())
    pipeline = transforms.Compose(steps)

    def make_split(is_train):
        # Train and test splits differ only in the `train` flag.
        return torchvision.datasets.FashionMNIST(
            root=root, train=is_train, download=True, transform=pipeline)

    def make_loader(dataset, shuffle):
        return torch.utils.data.DataLoader(
            dataset, batch_size=batch_size, shuffle=shuffle, num_workers=0)

    train_split = make_split(True)
    test_split = make_split(False)
    return make_loader(train_split, True), make_loader(test_split, False)

In [13]:
class AlexNet(nn.Module):
    """AlexNet-style CNN taking 1-channel images and producing 10 scores.

    ``conv`` is the convolutional feature extractor and ``fc`` the
    fully-connected classifier. The shape comments assume an input of
    n*1*224*224; the first linear layer is sized for the resulting 256*5*5
    feature map.
    """

    def __init__(self):
        super().__init__()
        feature_layers = [
            # input: n*1*224*224
            nn.Conv2d(1, 96, kernel_size=11, stride=4),     # n*96*54*54
            nn.ReLU(),
            nn.MaxPool2d(3, stride=2),                      # n*96*26*26
            nn.Conv2d(96, 256, kernel_size=5, padding=2),   # n*256*26*26
            nn.ReLU(),
            nn.MaxPool2d(3, stride=2),                      # n*256*12*12
            nn.Conv2d(256, 384, kernel_size=3, padding=1),  # n*384*12*12
            nn.ReLU(),
            nn.Conv2d(384, 384, kernel_size=3, padding=1),  # n*384*12*12
            nn.ReLU(),
            nn.Conv2d(384, 256, kernel_size=3, padding=1),  # n*256*12*12
            nn.ReLU(),
            nn.MaxPool2d(3, stride=2),                      # n*256*5*5
        ]
        self.conv = nn.Sequential(*feature_layers)

        classifier_layers = [
            nn.Linear(256 * 5 * 5, 4096),
            nn.ReLU(),
            nn.Dropout(p=0.5),
            nn.Linear(4096, 4096),
            nn.ReLU(),
            nn.Dropout(p=0.5),
            nn.Linear(4096, 10),
        ]
        self.fc = nn.Sequential(*classifier_layers)

    def forward(self, img):
        """Run ``img`` through ``conv``, flatten per sample, then apply ``fc``."""
        # Flatten everything after the batch dimension, using the input's batch size.
        flat = self.conv(img).view(img.shape[0], -1)
        return self.fc(flat)

In [14]:
def evaluate_accuracy(data_iter, net, device=None):
    """Return the fraction of samples in ``data_iter`` that ``net`` classifies correctly.

    Args:
        data_iter: iterable yielding ``(X, y)`` batches; ``y`` holds the class
            index of each sample.
        net: an ``nn.Module`` or any callable mapping a batch ``X`` to
            per-class scores (the prediction is ``argmax`` over dim 1).
        device: device to move each batch to. When ``None`` and ``net`` is an
            ``nn.Module`` with parameters, the device of its first parameter
            is used; otherwise batches are left where they are.

    Returns:
        Accuracy as a float in ``[0, 1]``; ``0.0`` if ``data_iter`` yields no
        samples.

    Note:
        An ``nn.Module`` is put in evaluation mode for the whole pass and is
        afterwards returned to whichever mode (train/eval) it was in on entry.
    """
    is_module = isinstance(net, nn.Module)
    if device is None and is_module:
        # next(..., None) avoids materialising every parameter and tolerates
        # parameter-free modules.
        first_param = next(net.parameters(), None)
        if first_param is not None:
            device = first_param.device

    # Remember the caller's mode so evaluation does not silently flip an
    # eval-mode model into training mode.
    was_training = is_module and net.training
    if is_module:
        net.eval()  # evaluation mode: disables dropout

    acc_sum, n = 0.0, 0
    try:
        with torch.no_grad():
            for X, y in data_iter:
                if device is not None:
                    X, y = X.to(device), y.to(device)
                acc_sum += (net(X).argmax(dim=1) == y).float().sum().item()
                n += y.shape[0]
    finally:
        if was_training:
            net.train()  # back to training mode, even if a batch raised

    return acc_sum / n if n else 0.0

In [15]:
def train(net, train_iter, test_iter, batch_size, optimizer, device, num_epochs):
    """Train ``net`` with cross-entropy loss, printing metrics after every epoch.

    Args:
        net: model to train; it is moved to ``device`` first.
        train_iter: iterable of ``(X, y)`` training batches.
        test_iter: iterable of ``(X, y)`` batches passed to
            ``evaluate_accuracy`` at the end of each epoch.
        batch_size: not used by this function; kept so existing calls keep
            working.
        optimizer: optimizer whose ``zero_grad()`` / ``step()`` are called once
            per batch (presumably built over ``net``'s parameters).
        device: device the model and every training batch are moved to.
        num_epochs: number of passes over ``train_iter``.

    Prints, per epoch: mean batch loss, training accuracy, test accuracy and
    elapsed seconds (the timing includes the test evaluation).
    """
    net = net.to(device)
    print("training on", device)
    loss = torch.nn.CrossEntropyLoss()
    for epoch in range(num_epochs):
        # Set training mode explicitly so dropout is active regardless of the
        # mode the model was left in by the caller or by evaluation.
        net.train()
        train_l_sum, train_acc_sum, n, batch_num, start = 0.0, 0.0, 0, 0, time.time()
        for X, y in train_iter:
            X = X.to(device)
            y = y.to(device)
            y_hat = net(X)
            l = loss(y_hat, y)
            optimizer.zero_grad()
            l.backward()
            optimizer.step()
            # Accumulate a plain float: adding the loss tensor itself would
            # keep autograd history for every batch of the epoch.
            train_l_sum += l.item()
            train_acc_sum += (y_hat.argmax(dim=1) == y).sum().item()
            n += y.shape[0]
            batch_num += 1
        test_acc = evaluate_accuracy(test_iter, net)
        # max(..., 1) keeps the report from dividing by zero when train_iter
        # yielded nothing.
        print('epoch %d, loss %.4f, train acc %.3f, test acc %.3f, time %.1f sec'
              % (epoch + 1, train_l_sum / max(batch_num, 1), train_acc_sum / max(n, 1),
                 test_acc, time.time() - start))

In [16]:
# One mini-batch size for both loaders; images are resized to 224x224 to
# match the input size noted in the AlexNet definition.
batch_size = 128
train_iter, test_iter = load_mnist_dataset(batch_size, resize=224)

In [17]:
# Adam learning rate and number of training epochs.
lr = 0.001
num_epochs = 5

# Freshly initialised network and the optimizer over its parameters.
net = AlexNet()
optimizer = optim.Adam(params=net.parameters(), lr=lr)

In [19]:
# Training on the CPU is extremely slow, so the call below is left commented out.
# train(net, train_iter, test_iter, batch_size, optimizer, device, num_epochs)