In [1]:
import torch
import torch.nn as nn
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from my_netModel.Lenet5 import Lenet5
from my_netModel.ResNet import ResNet18
import matplotlib.pyplot as plt
from torch.optim import lr_scheduler
from my_netModel.fun import print_epoching,Timer,music_play,print_epoching_per

In [2]:
# Use the first GPU when CUDA is available; otherwise fall back to the CPU so the
# notebook still runs (slowly) on machines without a GPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

In [3]:
# Per-channel (R, G, B) normalization statistics shared by both pipelines.
# NOTE(review): these std values are the widely copied CIFAR-10 figures; the
# per-pixel std is often reported as roughly (0.247, 0.243, 0.261) — confirm if
# exact standardization matters (train and val use the same values either way).
_cifar_mean = [0.4914, 0.4822, 0.4465]
_cifar_std = [0.2023, 0.1994, 0.2010]

transform = {
    # Training: random 32x32 crop of a 4-pixel reflect-padded image plus a random
    # horizontal flip for augmentation, then tensor conversion and standardization.
    'train': transforms.Compose([
        transforms.RandomCrop(32, padding=4, padding_mode='reflect'),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean=_cifar_mean, std=_cifar_std),  # image standardization
    ]),
    # Validation: deterministic — resize to 32x32, convert, standardize.
    'val': transforms.Compose([
        transforms.Resize((32, 32)),
        transforms.ToTensor(),
        transforms.Normalize(mean=_cifar_mean, std=_cifar_std),  # image standardization
    ]),
}


In [4]:
# Root directory of the CIFAR-10 files. download is not requested (torchvision's
# default is download=False), so the data must already be present here.
DATA_ROOT = r"D:\data\cifar10"
batch = 32  # mini-batch size for both loaders

train_dataset = datasets.CIFAR10(root=DATA_ROOT, train=True, transform=transform['train'])
train_loader = DataLoader(train_dataset, batch_size=batch, shuffle=True, num_workers=1)

# Evaluation gains nothing from shuffling; a fixed order keeps validation passes
# reproducible (e.g. which samples fall into the final, smaller batch).
val_dataset = datasets.CIFAR10(root=DATA_ROOT, train=False, transform=transform['val'])
val_loader = DataLoader(val_dataset, batch_size=batch, shuffle=False, num_workers=1)

In [None]:
# Model and loss: ResNet-18 moved to the selected device, trained with cross-entropy.
net = ResNet18().to(device)
criterion = nn.CrossEntropyLoss()

# Adam with an initial learning rate of 3e-3. StepLR multiplies the learning rate
# by 0.3 every 15 calls to scheduler.step().
optimizer = torch.optim.Adam(
    net.parameters(),
    lr=0.003,
    betas=(0.9, 0.999),
)
scheduler = lr_scheduler.StepLR(optimizer, step_size=15, gamma=0.3)

In [None]:
# Per-epoch metric histories; each list receives one entry per epoch.
train_loss_list, train_acc_list = [], []
val_epoch_list, val_loss_list, val_acc_list = [], [], []

In [None]:
# Separate stopwatches for the whole epoch, the training pass and the validation pass.
all_timer, train_timer, val_timer = Timer(), Timer(), Timer()


In [8]:
def _run_epoch(loader, training, tag):
    """Run one full pass of `net` over `loader` and return (mean_loss, accuracy).

    Args:
        loader: DataLoader yielding (inputs, labels) batches.
        training: if True, put the model in train mode and take an optimizer
            step after every batch; if False, use eval mode with autograd disabled.
        tag: label forwarded to print_epoching_per for the progress display.

    Returns:
        (mean_loss, accuracy), both averaged over samples (not over batches), so
        a smaller final batch is weighted correctly.
    """
    net.train(training)  # train(False) is equivalent to eval()
    # Accumulate on the device with detached tensors: no autograd graph is kept
    # alive across batches, and there is no host sync on every step.
    loss_sum = torch.zeros((), device=device)
    correct = torch.zeros((), dtype=torch.long, device=device)
    n_seen = 0
    with torch.set_grad_enabled(training):
        for step, (x, label) in enumerate(loader):
            x, label = x.to(device), label.to(device)
            logits = net(x)
            loss = criterion(logits, label)  # mean loss over the batch
            if training:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
            n = x.size(0)
            loss_sum += loss.detach() * n  # undo the batch mean -> batch sum
            correct += (logits.argmax(dim=1) == label).sum()
            n_seen += n
            print_epoching_per(step, len(loader), tag)
    n_seen = max(n_seen, 1)  # guard against an empty loader
    return loss_sum.item() / n_seen, correct.item() / n_seen


epochs = 50
for epoch in range(epochs):
    all_timer.start()

    train_timer.start()
    train_loss, train_acc = _run_epoch(train_loader, True, 'train')
    train_timing = train_timer.stop()
    train_loss_list.append(train_loss)
    train_acc_list.append(train_acc)

    val_timer.start()
    val_loss, val_acc = _run_epoch(val_loader, False, 'test ')
    val_loss_list.append(val_loss)
    val_acc_list.append(val_acc)
    val_epoch_list.append(epoch)
    val_timing = val_timer.stop()
    all_timing = all_timer.stop()

    # Print a one-line summary for this epoch.
    print('Epoch[{}/{}];time:{:0.2f},{:0.2f},{:0.2f};train_loss:{:0.4f};val_loss:{:0.4f};train_acc:{:0.4f};val_acc:{:0.4f}'.format(
        epoch + 1, epochs,
        all_timing, train_timing, val_timing,
        train_loss_list[-1],
        val_loss_list[-1],
        train_acc_list[-1],
        val_acc_list[-1]
    ))
    # Update the learning rate (one scheduler step per epoch).
    scheduler.step()
    

train:23.61%

In [None]:
# Plot the per-epoch training and validation loss curves (displayed, not saved).
# Every recorded epoch is plotted. The x-axis is the 1-based epoch number, so it
# lines up with the "Epoch[k/N]" lines printed during training.
fig, ax = plt.subplots()
ax.plot(range(1, len(train_loss_list) + 1), train_loss_list, label="train_loss")
ax.plot(range(1, len(val_loss_list) + 1), val_loss_list, label="val_loss")
ax.set(title="Loss per epoch", xlabel="epoch", ylabel="loss")
ax.legend()
plt.show()

In [None]:

# Plot the per-epoch training and validation accuracy (fraction correct, 0..1).
# The x-axis is the 1-based epoch number, matching the training log.
fig, ax = plt.subplots()
ax.plot(range(1, len(train_acc_list) + 1), train_acc_list, label="train_acc")
ax.plot(range(1, len(val_acc_list) + 1), val_acc_list, label="val_acc")
ax.set(title="Accuracy per epoch", xlabel="epoch", ylabel="accuracy")
ax.legend()
plt.show()