# 智能码货模型训练

## 1.导入所需模块

In [None]:
import torch
import torch.optim as optim
import torch.nn.functional as F
import torchvision.datasets as datasets
import torchvision.models as models
import torchvision.transforms as transforms
import os
import torch.nn as nn
import warnings
warnings.filterwarnings("ignore", category=UserWarning)

## 2.定义模型训练函数

In [None]:
def _build_datasets(data_dir, input_size, valid_percent):
    """Scan ``data_dir`` and split it into training and validation subsets.

    The folder is loaded twice with ``datasets.ImageFolder``. One view applies
    the colour-jitter augmentation and is used for training. The other view
    applies only resize + normalisation and is used for validation, so the
    reported test accuracy is measured on un-augmented images. A random
    permutation of sample indices decides which images land in which subset.

    Args:
        data_dir: Root folder laid out as ``data_dir/<class_name>/<image>``.
        input_size: Side length (pixels) that images are resized to.
        valid_percent: Fraction of the images held out for validation.

    Returns:
        ``(class_to_idx, train_subset, valid_subset)``.

    Raises:
        ValueError: if the split would leave either subset empty.
        RuntimeError: if the two scans of ``data_dir`` list different files.
    """
    # Per-channel mean/std; presumably measured on this goods dataset -- TODO confirm.
    normalize = transforms.Normalize([0.637, 0.619, 0.5936], [0.336, 0.339, 0.358])
    resize = transforms.Resize((input_size, input_size))

    train_view = datasets.ImageFolder(
        data_dir,
        transforms.Compose([
            transforms.ColorJitter(0.1, 0.1, 0.1, 0.1),
            resize,
            transforms.ToTensor(),
            normalize,
        ]),
    )
    valid_view = datasets.ImageFolder(
        data_dir,
        transforms.Compose([resize, transforms.ToTensor(), normalize]),
    )
    # The same index list is used on both views, so they must enumerate the
    # same images in the same order.
    if train_view.samples != valid_view.samples:
        raise RuntimeError(
            'dataset folder changed while it was being scanned: {}'.format(data_dir))

    num_samples = len(train_view)
    num_valid = int(num_samples * valid_percent)
    num_train = num_samples - num_valid
    # Fail before training starts instead of dividing by zero after an epoch.
    if num_train == 0 or num_valid == 0:
        raise ValueError(
            'valid_percent={} splits {} images into {} training and {} validation '
            'images; both must be non-empty'.format(
                valid_percent, num_samples, num_train, num_valid))

    order = torch.randperm(num_samples).tolist()
    train_subset = torch.utils.data.Subset(train_view, order[:num_train])
    valid_subset = torch.utils.data.Subset(valid_view, order[num_train:])
    return train_view.class_to_idx, train_subset, valid_subset


def _run_epoch(model, loader, loss_fc, device, optimizer=None):
    """Make one full pass over ``loader``.

    When ``optimizer`` is given, the model is in training mode and its
    parameters are updated after every batch. Otherwise the model is in
    evaluation mode and gradients are disabled.

    Returns:
        ``(loss, accuracy)``: ``loss`` is ``loss_fc`` weighted by batch size
        and averaged over all samples; ``accuracy`` is the fraction of samples
        whose arg-max prediction equals the label.
    """
    training = optimizer is not None
    model.train(training)
    loss_sum = 0.0
    correct = 0
    seen = 0
    with torch.set_grad_enabled(training):
        for images, labels in loader:
            images = images.to(device)
            labels = labels.to(device)
            outputs = model(images)
            loss = loss_fc(outputs, labels)
            if training:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            batch_len = images.size(0)
            # Multiply by batch length so the final figure is a per-sample
            # mean even when the last batch is smaller.
            loss_sum += loss.item() * batch_len
            _, predict = torch.max(outputs, 1)
            correct += (predict == labels).sum().item()
            seen += batch_len
    return loss_sum / seen, correct / seen


def train(*, data_dir='goods_datasets', model_dir='students_models',
          input_size=224, valid_percent=0.2, batch_size=32, num_epochs=10,
          lr=0.0001, num_workers=6):
    """Fine-tune an ImageNet-pretrained ResNet-18 on the goods image folder.

    After each epoch the model is evaluated on the held-out split. Whenever
    the validation accuracy beats every previous epoch, the whole model object
    is written to ``<model_dir>/goods.pth`` with ``torch.save``.

    Args:
        data_dir: Image folder with one sub-directory per class.
        model_dir: Directory for the checkpoint; created if missing.
        input_size: Side length images are resized to.
        valid_percent: Fraction of images held out for validation.
        batch_size: Mini-batch size for both loaders.
        num_epochs: Number of training epochs.
        lr: Adam learning rate.
        num_workers: DataLoader worker processes.

    Returns:
        The best validation accuracy reached (``None`` if ``num_epochs`` is 0).

    Raises:
        ValueError: if the train/validation split would leave a subset empty.
    """
    # Use the first GPU when available, otherwise the CPU.
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    if os.path.isdir(model_dir):
        print('模型保存文件夹已被创建')
    # exist_ok still raises if model_dir exists as a regular file.
    os.makedirs(model_dir, exist_ok=True)
    best_model_path = os.path.join(model_dir, 'goods.pth')

    class_idx, train_dataset, test_dataset = _build_datasets(
        data_dir, input_size, valid_percent)
    print('数据标签与对应的索引', class_idx)

    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
    )
    # Evaluation order does not affect the metrics, so there is no need to shuffle.
    test_loader = torch.utils.data.DataLoader(
        test_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
    )

    # Backbone choice. Alternative:
    #     model = models.mobilenet_v2(pretrained=True)
    #     model.classifier[1] = nn.Linear(model.classifier[1].in_features, len(class_idx))
    # NOTE: newer torchvision versions deprecate `pretrained=` in favour of `weights=`.
    model = models.resnet18(pretrained=True)
    # Replace the ImageNet head with one output per goods class.
    model.fc = nn.Linear(model.fc.in_features, len(class_idx))
    model = model.to(device)

    loss_fc = nn.CrossEntropyLoss()
    # Optimizer choice: Adam, or e.g.
    #     optim.SGD(model.parameters(), lr=lr, momentum=0.9, weight_decay=0.01)
    optimizer = optim.Adam(model.parameters(), lr=lr)

    # Tracked across ALL epochs, so the checkpoint is the true best model.
    best_accuracy = None
    for epoch in range(num_epochs):
        train_loss, train_acc = _run_epoch(model, train_loader, loss_fc, device, optimizer)
        test_loss, test_acc = _run_epoch(model, test_loader, loss_fc, device)

        print('epoch={}'.format(epoch + 1))
        print('训练数据集准确率为：{:.2%}，误差为：{}'.format(train_acc, train_loss))
        print('测试数据集准确率为：{:.2%}, 误差为：{}'.format(test_acc, test_loss))
        if best_accuracy is None or test_acc > best_accuracy:
            torch.save(model, best_model_path)
            best_accuracy = test_acc
    return best_accuracy


## 3.开始训练

In [None]:
# Start training when run as a script or notebook cell (Jupyter sets
# __name__ to "__main__"); merely importing this file does not train.
if __name__ == "__main__":
    train()