In [3]:
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader 


BATCH_SIZE = 64
EPOCH = 50
LR = 0.001

my_tf = transforms.Compose([
    transforms.Resize((224,224)),  #模型的输入图像尺寸
    transforms.ToTensor(),         
    transforms.Normalize([0.4914, 0.4822, 0.4465],[0.247, 0.243, 0.261]) 
    ])

train_dataset = torchvision.datasets.CIFAR10(root='./', train=True, transform=transforms.ToTensor(), download=True)
test_dataset = torchvision.datasets.CIFAR10(root='./', train=False, transform=transforms.ToTensor(), download=True)

train_dataloader = DataLoader(dataset=train_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_dataloader = DataLoader(test_dataset)

model = torchvision.models.resnet50(pretrained=True)

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=LR)

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
model.to(device)

for epoch in range(EPOCH):
    for idx, data in enumerate(train_dataloader):
        inputs, labels = data
        inputs = inputs.to(device)
        labels = labels.to(device)
        output = model(inputs)
        loss = criterion(output, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    print('epoch{} loss:{:.4f}'.format(epoch+1, loss.item()))

#保存模型参数
torch.save(model, 'cifar10_resnet.pt')
#模型加载
model = torch.load('cifar10_resnet.pt')
#测试
model.eval()
correct, total = 0, 0
for data in test_dataloader:
    inputs, labels = data
    inputs = inputs.to(device)
    labels = labels.to(device)
    output = model(inputs)
    _,idx = torch.max(output.data,1) # 输出最大值的位置
    total += labels.size(0) # 全部图片
    correct +=(idx==labels).sum() # 正确的图片
print(f"accuracy:{100.*correct/total}")

Files already downloaded and verified
Files already downloaded and verified
epoch1 loss:2.6683
epoch2 loss:0.9296
epoch3 loss:0.5147
epoch4 loss:0.4831
epoch5 loss:1.4039
epoch6 loss:0.2270
epoch7 loss:0.9371
epoch8 loss:0.4071
epoch9 loss:0.5657
epoch10 loss:0.1740
epoch11 loss:0.3852
epoch12 loss:0.7975
epoch13 loss:0.0313
epoch14 loss:0.6374
epoch15 loss:0.1955
epoch16 loss:2.4862
epoch17 loss:0.0254
epoch18 loss:0.4474
epoch19 loss:0.1573
epoch20 loss:0.0212
epoch21 loss:0.0096
epoch22 loss:0.0924
epoch23 loss:0.0190
epoch24 loss:0.0654
epoch25 loss:0.7753
epoch26 loss:0.0264
epoch27 loss:0.0616
epoch28 loss:0.0020
epoch29 loss:0.2466
epoch30 loss:0.0037
epoch31 loss:0.0077
epoch32 loss:0.9050
epoch33 loss:0.2553
epoch34 loss:0.0201
epoch35 loss:0.0493
epoch36 loss:0.0130
epoch37 loss:0.3112
epoch38 loss:2.1373
epoch39 loss:0.2105
epoch40 loss:0.0197
epoch41 loss:0.1113
epoch42 loss:0.0287
epoch43 loss:0.0014
epoch44 loss:0.0923
epoch45 loss:0.0026
epoch46 loss:0.0005
epoch47 loss: