# In[ ]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.datasets as datasets
import torchvision.transforms as transforms

# Data loading: CIFAR-10 train/test splits (downloaded into the root on first
# use). ToTensor() turns each PIL image into a float tensor scaled to [0, 1].
cifar_root = '/root/cifar10/'
to_tensor = transforms.ToTensor()
train_data, test_data = (
    datasets.CIFAR10(root=cifar_root, train=is_train, transform=to_tensor, download=True)
    for is_train in (True, False)
)

# Show one training image to sanity-check the data pipeline.
# plt was referenced without ever being imported (NameError); import it here.
import matplotlib.pyplot as plt

temp = train_data[1][0].numpy()
print(temp.shape)
# ToTensor() produces channels-first (C, H, W); imshow expects (H, W, C).
temp = temp.transpose(1, 2, 0)
print(temp.shape)
plt.imshow(temp)
plt.show()  # needed to actually render when run as a plain script

from torch.utils.data import DataLoader

# Hyperparameters: number of epochs, mini-batch size, Adam learning rate.
EPOCH, BATCH_SIZE, LR = 3, 128, 0.001

# Split the datasets into mini-batches; only the training set is reshuffled
# every epoch, the test set is iterated in its stored order.
train_loader = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True)
test_loader = DataLoader(test_data, batch_size=BATCH_SIZE)

# Model: ImageNet-pretrained ResNet-18 used as the backbone.
# NOTE(review): newer torchvision deprecates `pretrained=` in favour of
# `weights=`; kept here for compatibility with older torchvision releases.
model = torchvision.models.resnet18(pretrained=True)
# The stock classifier emits 1000 ImageNet logits. Replace it with a freshly
# initialised head sized to the dataset's label set (10 for CIFAR-10), so the
# loss is computed over exactly the classes that can occur.
model.fc = nn.Linear(model.fc.in_features, len(train_data.classes))
# NOTE(review): the pretrained weights were trained on ImageNet mean/std
# normalised inputs; consider adding transforms.Normalize to the dataset
# transform if accuracy matters.

import time

# Run on the first GPU when one is available, otherwise on the CPU. The model
# is moved there before the optimizer is built so the optimizer is created
# over the parameters in their final location.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
model.to(device)

# Loss: cross-entropy over the raw logits produced by the model.
criterion = nn.CrossEntropyLoss()
# Optimizer: Adam over every model parameter.
optimizer = optim.Adam(model.parameters(), lr=LR)

# Training: EPOCH full passes over the training set.
for epoch in range(EPOCH):
    model.train()  # BatchNorm/Dropout in training mode for this pass
    start_time = time.time()
    # Accumulate the loss on-device so there is no host sync on every batch.
    running_loss = torch.zeros((), device=device)
    seen = 0
    for inputs, labels in train_loader:
        inputs, labels = inputs.to(device), labels.to(device)
        # Forward pass
        outputs = model(inputs)
        # Compute the loss
        loss = criterion(outputs, labels)
        # Clear gradients left over from the previous step
        optimizer.zero_grad()
        # Backward pass
        loss.backward()
        # Parameter update
        optimizer.step()
        # CrossEntropyLoss averages over the batch by default; weight it by
        # batch size so the epoch figure is a true per-sample mean even when
        # the final batch is smaller. (Previously only the last batch's loss
        # was reported, which is noisy and unrepresentative of the epoch.)
        running_loss += loss.detach() * labels.size(0)
        seen += labels.size(0)
    epoch_loss = running_loss.item() / max(seen, 1)
    print('epoch{} loss:{:.4f} time:{:.4f}'.format(epoch+1, epoch_loss, time.time()-start_time))

# Save the trained weights. Saving the state_dict instead of pickling the
# whole module keeps the checkpoint loadable under torch.load's
# weights_only=True default (PyTorch >= 2.6) and decoupled from class paths.
file_name = 'cifar10_resnet.pt'
torch.save(model.state_dict(), file_name)
print(file_name+' saved')

# Test: reload the saved weights into the same architecture, then evaluate.
model.load_state_dict(torch.load(file_name, map_location=device))
model.eval()  # inference mode for BatchNorm/Dropout

correct, total = 0, 0
with torch.no_grad():  # evaluation needs no autograd graph
    for images, labels in test_loader:
        images, labels = images.to(device), labels.to(device)
        # Forward pass
        out = model(images)
        # Predicted class = index of the largest logit per sample
        predicted = out.argmax(dim=1)
        # Count how many predictions match the ground-truth labels
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

# Report test-set accuracy
print('10000张测试图像 准确率:{:.4f}%'.format(100.0*correct/total))