In [55]:
import torch
import torchvision
from PIL import Image
import torchvision.transforms as transforms
import numpy as np
import json
import requests
import matplotlib.pyplot as plt

torch.manual_seed(7)
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [56]:
from torch.utils.tensorboard import SummaryWriter
writer = SummaryWriter('./logs')

In [57]:
random_affine = torchvision.transforms.RandomAffine(degrees=10,
                                                        scale=(0.9, 1.1),
                                                        translate=(0.1, 0.1),
                                                        interpolation=torchvision.transforms.InterpolationMode.BICUBIC)

In [58]:
#定义超参数
BATCH_SIZE = 128

transform1 = transforms.Compose([transforms.ToTensor(),
                                transforms.Normalize((0.5,0.5,0.5),(0.5,0.5,0.5)),
                                 torchvision.transforms.RandomHorizontalFlip()])

transform2 = transforms.Compose([transforms.ToTensor(),
                                transforms.Normalize((0.5,0.5,0.5),(0.5,0.5,0.5))])

trainset = torchvision.datasets.CIFAR10(root='./data',train = True,
                                        download=True, transform=transform1)
trainloader = torch.utils.data.DataLoader(trainset,batch_size = BATCH_SIZE,
                                          shuffle = True, num_workers=0,pin_memory = False)

testset = torchvision.datasets.CIFAR10(root='./data',train = False,
                                        download=True, transform=transform2)
testloader = torch.utils.data.DataLoader(testset,batch_size = BATCH_SIZE,
                                          shuffle = False, num_workers=0,pin_memory = False)

classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

Files already downloaded and verified
Files already downloaded and verified


In [59]:
mynet = torchvision.models.quantization.resnet18(weights=None, progress=True, quantize=False)
mynet = mynet.train().to(DEVICE)
optimizer = torch.optim.SGD(mynet.parameters(), lr=0.01, momentum=0.9)
loss_func = torch.nn.CrossEntropyLoss()

In [60]:
lost = 0
mynet = mynet.train().to(DEVICE)
for epoch in range(10):
    running_loss = 0.0
    for step, (b_x,b_y)in enumerate(trainloader):
        outputs = mynet(b_x.to(DEVICE)) # 喂给 net 训练数据 x, 输出预测值
        loss = loss_func(outputs, b_y.to(DEVICE)) # 计算两者的误差
        optimizer.zero_grad() # 清空上一步的残余更新参数值
        loss.backward() # 误差反向传播, 计算参数更新值
        optimizer.step() # 将参数更新值施加到 net 的 parameters 上
        writer.add_scalar("loss18",loss.item(),step)
        # 打印状态信息
        running_loss += loss.item()
        #if(lost>loss.item()):
        #    lost = loss.item()
        #    torch.save(mynet.cpu(),str(loss.item()*100)+"ciarf.pth")
        #    mynet = mynet.train().to(DEVICE)

        #if step % 1000 == 999:    # 每2000个批次打印一次
        #    print('[%d, %5d] loss: %.3f' %
        #          (epoch + 1, step + 1, running_loss / 2000))
        #    running_loss = 0.0
    print('[%d, %5d] loss: %.3f' %
            (epoch + 1, step + 1, running_loss / 2000))
    running_loss = 0.0

print('Finished Training')
writer.close()

[1,   391] loss: 0.307
[2,   391] loss: 0.221
[3,   391] loss: 0.185
[4,   391] loss: 0.161
[5,   391] loss: 0.141
[6,   391] loss: 0.125
[7,   391] loss: 0.113
[8,   391] loss: 0.101
[9,   391] loss: 0.093
[10,   391] loss: 0.083
Finished Training
[1,   391] loss: 0.313
[2,   391] loss: 0.223
[3,   391] loss: 0.186
[4,   391] loss: 0.161
[5,   391] loss: 0.143
[6,   391] loss: 0.128
[7,   391] loss: 0.114
[8,   391] loss: 0.104
[9,   391] loss: 0.094
[10,   391] loss: 0.084
Finished Training


In [61]:
mynet = mynet.eval().to(DEVICE)

In [62]:
all_counter=0
correct_counter=0
for i, data in enumerate(testloader, 0):
    inputs, labels = data
    inputs = inputs.to(DEVICE)
    labels = labels.to(DEVICE)
    out = mynet(inputs)
    out = out.detach().cpu().argmax(1)
    t = labels.cpu()
    for m in range(len(t)):
        all_counter += 1
        if t[m] == out[m]:
            correct_counter += 1

print(correct_counter, all_counter, correct_counter / all_counter)

7418 10000 0.7418
7370 10000 0.737


In [63]:
#mynet=torch.load("0.7026215083897114ciarf.pth").to(DEVICE)
all_counter=0
correct_counter=0
for i, data in enumerate(testloader, 0):
    inputs, labels = data
    inputs = inputs.to(DEVICE)
    labels = labels.to(DEVICE)
    out = mynet(inputs)
    out = out.detach().cpu().argmax(1)
    t = labels.cpu()
    for m in range(len(t)):
        all_counter += 1
        if t[m] == out[m]:
            correct_counter += 1

print(correct_counter, all_counter, correct_counter / all_counter)

7418 10000 0.7418
7370 10000 0.737
