import torch
import torch.nn as nn
from torchvision import datasets, transforms
import numpy as np
import matplotlib.pyplot as plt
from Net import LeNet_5

# parameters
momentum = 0.9
record_step = 10  # evaluate on the test set every `record_step` training batches
batch_size = 128
test_batch_size = 1000
lr = 0.001
Epoch = 1
random_seed = 1
# torch.manual_seed seeds the CPU generator AND every CUDA device. Seeding only
# CUDA (torch.cuda.manual_seed_all) left the CPU generator unseeded, and that
# generator drives DataLoader shuffling and, presumably, LeNet_5's default
# weight initialisation (the net is constructed on the CPU before being moved).
torch.manual_seed(random_seed)

# test
def test(test_loader, net, loss_function):
    net.eval()
    total_loss = 0
    acc = 0
    for data, target in test_loader:
        data = torch.tensor(data).type(torch.FloatTensor).cuda()
        target = torch.tensor(target).type(torch.LongTensor).cuda()
        out = net(data)
        loss = loss_function(out, target)
        classification = torch.max(out,1)[1]
        total_loss += loss.item()
        correct = (target == classification).sum()
        acc+=correct.item()
    
    return total_loss/len(test_loader.dataset), acc/len(test_loader.dataset)   
    print(total_loss/len(test_loader.dataset))
    print(acc/len(test_loader.dataset))


# train
def _mnist_loader(train, batch_size, shuffle, download=False):
    """Return a DataLoader over the MNIST split selected by ``train``.

    Both splits share the same preprocessing: ``ToTensor`` followed by
    ``Normalize((0.1307,), (0.3081,))``.
    """
    transform = transforms.Compose([
        transforms.ToTensor(),
        # presumably the MNIST pixel mean/std; kept identical for both splits
        transforms.Normalize((0.1307,), (0.3081,)),
    ])
    dataset = datasets.MNIST('./data', train=train, download=download,
                             transform=transform)
    return torch.utils.data.DataLoader(dataset, batch_size=batch_size,
                                       shuffle=shuffle)


def main():
    """Train LeNet-5 on MNIST, checkpoint the best model and plot accuracy.

    Every ``record_step`` training batches the model is evaluated on the
    whole test set. Whenever the test accuracy matches or beats the best seen
    so far, the full model is saved to ``./net.pkl``. The accuracy curve is
    written to ``pic.png`` and then shown.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # load dataset
    train_loader = _mnist_loader(train=True, batch_size=batch_size,
                                 shuffle=True, download=True)
    # Evaluation order does not affect loss/accuracy, so no shuffling needed.
    test_loader = _mnist_loader(train=False, batch_size=test_batch_size,
                                shuffle=False)
    print('Finished reading data')

    # load net
    net = LeNet_5().to(device)
    loss_function = nn.CrossEntropyLoss().to(device)
    optimizer = torch.optim.SGD(net.parameters(), lr=lr, momentum=momentum)

    # train
    max_accuracy = 0
    accuracy_list = []
    for epoch in range(Epoch):
        for batch_idx, (data, target) in enumerate(train_loader):
            # Re-enter training mode every step: test() may leave the net in eval mode.
            net.train()
            data = torch.as_tensor(data, dtype=torch.float32, device=device)
            target = torch.as_tensor(target, dtype=torch.long, device=device)
            optimizer.zero_grad()
            out = net(data)
            loss = loss_function(out, target)
            loss.backward()
            optimizer.step()

            iteration = batch_idx + 1  # number of batches processed this epoch
            if iteration % record_step == 0:
                test_loss, accuracy = test(test_loader, net, loss_function)
                accuracy_list.append(accuracy)
                # save the best model (a tie also refreshes the checkpoint)
                if accuracy >= max_accuracy:
                    max_accuracy = accuracy
                    torch.save(net, './net.pkl')
                # CrossEntropyLoss already averages over the batch, so the
                # training loss is reported as-is, not divided again.
                print(f'Epoch{epoch+1} iteration{iteration}/{len(train_loader)}: Accuracy = {accuracy} '
                      f'Loss_train = {loss.item()} Loss_test = {test_loss}')

    print('Finished all!')
    # figure
    plt.plot(accuracy_list)
    plt.title('Accuracy Line')
    plt.xlabel('Time')
    plt.ylabel('Accuracy')
    plt.savefig("pic.png")
    plt.show()

if __name__ == "__main__":
    # Script entry point: run training only when executed directly, not on import.
    main()


# Run output (one epoch). Notes for reading this log:
# - "iterationN" is the 0-based batch index, so N+1 batches had been trained.
# - Loss_train / Loss_test are the mean cross-entropy divided by the batch
#   size (128 for training, 1000 for testing), which is why they look tiny.
# Epoch1 iteration9/469: Accuracy = 0.131 Loss_train = 0.018004607409238815 Loss_test = 0.0022998397827148436
# Epoch1 iteration19/469: Accuracy = 0.1375 Loss_train = 0.017895152792334557 Loss_test = 0.0022988242149353027
# Epoch1 iteration29/469: Accuracy = 0.1452 Loss_train = 0.017878560349345207 Loss_test = 0.0022975866317749023
# Epoch1 iteration39/469: Accuracy = 0.151 Loss_train = 0.01792256534099579 Loss_test = 0.002296201753616333
# Epoch1 iteration49/469: Accuracy = 0.1598 Loss_train = 0.017969252541661263 Loss_test = 0.002294724774360657
# Epoch1 iteration59/469: Accuracy = 0.1647 Loss_train = 0.01789519563317299 Loss_test = 0.0022931348800659178
# Epoch1 iteration69/469: Accuracy = 0.1702 Loss_train = 0.017858892679214478 Loss_test = 0.0022914038419723512
# Epoch1 iteration79/469: Accuracy = 0.1688 Loss_train = 0.017848389223217964 Loss_test = 0.002289528489112854
# Epoch1 iteration89/469: Accuracy = 0.177 Loss_train = 0.017940325662493706 Loss_test = 0.0022875229597091673
# Epoch1 iteration99/469: Accuracy = 0.1886 Loss_train = 0.017897559329867363 Loss_test = 0.0022854191541671755
# Epoch1 iteration109/469: Accuracy = 0.2051 Loss_train = 0.01791350543498993 Loss_test = 0.00228315486907959
# Epoch1 iteration119/469: Accuracy = 0.2223 Loss_train = 0.017804570496082306 Loss_test = 0.0022807384729385378
# Epoch1 iteration129/469: Accuracy = 0.2375 Loss_train = 0.017852842807769775 Loss_test = 0.0022781068325042723
# Epoch1 iteration139/469: Accuracy = 0.2568 Loss_train = 0.017739254981279373 Loss_test = 0.002275214099884033
# Epoch1 iteration149/469: Accuracy = 0.2736 Loss_train = 0.017705664038658142 Loss_test = 0.0022720455408096313
# Epoch1 iteration159/469: Accuracy = 0.2914 Loss_train = 0.017776422202587128 Loss_test = 0.0022687005043029785
# Epoch1 iteration169/469: Accuracy = 0.3206 Loss_train = 0.017663512378931046 Loss_test = 0.0022649160861968995
# Epoch1 iteration179/469: Accuracy = 0.3469 Loss_train = 0.017628077417612076 Loss_test = 0.0022607649087905885
# Epoch1 iteration189/469: Accuracy = 0.378 Loss_train = 0.01761803962290287 Loss_test = 0.0022561724662780763
# Epoch1 iteration199/469: Accuracy = 0.399 Loss_train = 0.01758885383605957 Loss_test = 0.002250828289985657
# Epoch1 iteration209/469: Accuracy = 0.4075 Loss_train = 0.017620466649532318 Loss_test = 0.002244613695144653
# Epoch1 iteration219/469: Accuracy = 0.4048 Loss_train = 0.017562108114361763 Loss_test = 0.0022375184774398806
# Epoch1 iteration229/469: Accuracy = 0.39 Loss_train = 0.017417872324585915 Loss_test = 0.0022295536279678344
# Epoch1 iteration239/469: Accuracy = 0.3803 Loss_train = 0.01727321557700634 Loss_test = 0.0022200347661972047
# Epoch1 iteration249/469: Accuracy = 0.3757 Loss_train = 0.01733490079641342 Loss_test = 0.0022087298631668093
# Epoch1 iteration259/469: Accuracy = 0.3684 Loss_train = 0.016976773738861084 Loss_test = 0.002195406460762024
# Epoch1 iteration269/469: Accuracy = 0.3598 Loss_train = 0.017086086794734 Loss_test = 0.002178644371032715
# Epoch1 iteration279/469: Accuracy = 0.3541 Loss_train = 0.016891829669475555 Loss_test = 0.00215859637260437
# Epoch1 iteration289/469: Accuracy = 0.3594 Loss_train = 0.01660800538957119 Loss_test = 0.0021343688011169434
# Epoch1 iteration299/469: Accuracy = 0.3682 Loss_train = 0.01646573655307293 Loss_test = 0.0021042741298675535
# Epoch1 iteration309/469: Accuracy = 0.3832 Loss_train = 0.016296641901135445 Loss_test = 0.0020669954061508177
# Epoch1 iteration319/469: Accuracy = 0.4143 Loss_train = 0.01612268015742302 Loss_test = 0.0020222079277038576
# Epoch1 iteration329/469: Accuracy = 0.4362 Loss_train = 0.015136996284127235 Loss_test = 0.0019681931376457212
# Epoch1 iteration339/469: Accuracy = 0.4675 Loss_train = 0.015051516704261303 Loss_test = 0.0019022202849388122
# Epoch1 iteration349/469: Accuracy = 0.48 Loss_train = 0.014442021027207375 Loss_test = 0.0018259101867675782
# Epoch1 iteration359/469: Accuracy = 0.5183 Loss_train = 0.013943396508693695 Loss_test = 0.001740064811706543
# Epoch1 iteration369/469: Accuracy = 0.5704 Loss_train = 0.012971184216439724 Loss_test = 0.0016495554566383361
# Epoch1 iteration379/469: Accuracy = 0.5875 Loss_train = 0.01316823624074459 Loss_test = 0.0015488527297973633
# Epoch1 iteration389/469: Accuracy = 0.5981 Loss_train = 0.01214633509516716 Loss_test = 0.001437147879600525
# Epoch1 iteration399/469: Accuracy = 0.6425 Loss_train = 0.010402144864201546 Loss_test = 0.001330328905582428
# Epoch1 iteration409/469: Accuracy = 0.6877 Loss_train = 0.009638778865337372 Loss_test = 0.001211769998073578
# Epoch1 iteration419/469: Accuracy = 0.7118 Loss_train = 0.009278555400669575 Loss_test = 0.0011043765902519227
# Epoch1 iteration429/469: Accuracy = 0.7312 Loss_train = 0.00820239633321762 Loss_test = 0.001010335385799408
# Epoch1 iteration439/469: Accuracy = 0.7412 Loss_train = 0.007313249632716179 Loss_test = 0.0009250949561595917
# Epoch1 iteration449/469: Accuracy = 0.7442 Loss_train = 0.0065794577822089195 Loss_test = 0.0008587775468826294
# Epoch1 iteration459/469: Accuracy = 0.7764 Loss_train = 0.006603806279599667 Loss_test = 0.000794600123167038