In [1]:
import torch
import time
from torch import nn
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, TensorDataset
import numpy as np
import torch.nn.functional as F

In [2]:
device = torch.device('cuda')

In [3]:
dtrain = datasets.MNIST(root='MNIST_DATA', train=True, download=True)
dtest = datasets.MNIST(root='MNIST_DATA', train=False, download=True)
train_img = F.pad(dtrain.data.float()[:, None, ...] / 255., (2, 2, 2, 2))
test_img = F.pad(dtest.data.float()[:, None, ...] / 255., (2, 2, 2, 2))
train_iter = DataLoader(TensorDataset(train_img.to(device), dtrain.targets.to(device)), batch_size=128, shuffle=True)
test_iter = DataLoader(TensorDataset(test_img.to(device), dtest.targets.to(device)), batch_size=1000, shuffle=False)
# train_iter = DataLoader(TensorDataset(train_img, dtrain.targets), batch_size=128, shuffle=True)
# test_iter = DataLoader(TensorDataset(test_img, dtest.targets), batch_size=1000, shuffle=False)

In [4]:
# 目前来看inplace=True基本不影响速度
inplace = True
model = nn.Sequential(
    nn.Conv2d(in_channels=1, out_channels=32, kernel_size=5, padding='valid'),
    nn.ReLU(inplace=inplace),
    nn.MaxPool2d(kernel_size=2, stride=2),
    nn.Conv2d(in_channels=32, out_channels=64, kernel_size=5, padding='valid'),
    nn.ReLU(inplace=inplace),
    nn.MaxPool2d(kernel_size=2, stride=2),
    nn.Flatten(),
    nn.Linear(in_features=1600, out_features=1024),
    nn.ReLU(inplace=inplace),
    nn.Linear(in_features=1024, out_features=128),
    nn.ReLU(inplace=inplace),
    nn.Linear(in_features=128, out_features=10)
)
model = model.to(device)


In [5]:
# item()函数确实影响速度, 所以计算LOSS, ACC时不执行.item()
epochs = 10
train_size = len(dtrain)
test_size = len(dtest)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
loss_fn = torch.nn.CrossEntropyLoss().to(device)

def train():
    for epoch in range(1, epochs + 1):
        total_train_loss = 0.
        train_nacc = 0
        ts = time.time()
        for xi, yi in train_iter:
            # xi, yi = xi.cuda(), yi.cuda()
            # set_to_none=True
            optimizer.zero_grad(set_to_none=True)
            # optimizer.zero_grad()
            logits = model(xi)
            loss = loss_fn(logits, yi)
            loss.backward()
            optimizer.step()
            with torch.no_grad():
                total_train_loss += loss * len(xi)
                train_nacc += torch.sum(logits.argmax(axis=-1) == yi)
        train_cost, train_acc = total_train_loss / train_size, train_nacc / train_size

        total_test_loss = 0.
        test_nacc = 0
        with torch.no_grad():
            for xi, yi in test_iter:
                # xi, yi = xi.cuda(), yi.cuda()
                logits = model(xi)
                total_test_loss += loss_fn(logits, yi) * len(xi)
                test_nacc += torch.sum(logits.argmax(axis=-1) == yi)
        test_cost, test_acc = total_test_loss / test_size, test_nacc / test_size
        time_used = time.time() - ts
        print(f"[{epoch}]TIME:{time_used}s COST:{test_cost} ACC:{train_acc} TEST_COST:{test_cost} TEST_ACC:{test_acc}")

%time train()

[1]TIME:8.557946681976318s COST:0.04621458798646927 ACC:0.945733368396759 TEST_COST:0.04621458798646927 TEST_ACC:0.9835000038146973
[2]TIME:8.146683931350708s COST:0.045185547322034836 ACC:0.9862666726112366 TEST_COST:0.045185547322034836 TEST_ACC:0.986299991607666
[3]TIME:8.104036808013916s COST:0.02236361615359783 ACC:0.9905666708946228 TEST_COST:0.02236361615359783 TEST_ACC:0.9918999671936035
[4]TIME:8.092193365097046s COST:0.024845756590366364 ACC:0.9922500252723694 TEST_COST:0.024845756590366364 TEST_ACC:0.9918999671936035
[5]TIME:8.094757080078125s COST:0.028636015951633453 ACC:0.9944166541099548 TEST_COST:0.028636015951633453 TEST_ACC:0.9905999898910522
[6]TIME:8.093015193939209s COST:0.025910209864377975 ACC:0.9948999881744385 TEST_COST:0.025910209864377975 TEST_ACC:0.991599977016449
[7]TIME:8.161656379699707s COST:0.03196645900607109 ACC:0.996483325958252 TEST_COST:0.03196645900607109 TEST_ACC:0.9914999604225159
[8]TIME:8.097290992736816s COST:0.03175250440835953 ACC:0.9966833