In [1]:
import copy
import os

import torch
import torch.nn as nn
import torchvision
from torch.utils.data import DataLoader
from torchvision.transforms import transforms

import tools
from models.xception import Xception

In [2]:
# Run on the first CUDA GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# Dataset root directory; the "train", "val" and "test" sub-folders are read
# from it further down. Set the JIA_DATA_ROOT environment variable to run the
# notebook on another machine; the original local path stays the default.
data_root_dir = os.environ.get(
    "JIA_DATA_ROOT",
    "C:\\Users\\Rui\\Documents\\dataset\\classify\\jia",
)

In [8]:
# Training function
def net_train(net, train_loader, criterion, optimizer, output, epoch, log_interval=100):
    """Train ``net`` for one pass over ``train_loader``.

    Args:
        net: Model to optimise; it is switched to training mode first.
        train_loader: Iterable of batches where element 0 holds the inputs and
            element 1 the labels; both are moved to the global ``device``.
        criterion: Callable returning the loss for ``(outputs, labels)``.
        optimizer: Optimizer used for ``zero_grad()`` / ``step()``.
        output: Forwarded to ``tools.print_log`` as its ``file`` argument.
        epoch: Zero-based epoch index; it is logged as ``epoch + 1``.
        log_interval: Number of iterations between loss log lines. The logged
            value is the mean loss over that window. Defaults to 100, which
            matches the previous hard-coded behaviour.

    Raises:
        ValueError: If ``log_interval`` is smaller than 1.

    Note:
        Iterations after the last complete window are trained on but their
        loss is not logged.
    """
    if log_interval < 1:
        raise ValueError("log_interval must be >= 1, got %r" % (log_interval,))

    net.train()
    # Loss accumulated since the last log line.
    running_loss = 0.0
    for i, data in enumerate(train_loader, 0):
        # Move the batch to the compute device.
        inputs, labels = data[0].to(device=device), data[1].to(device=device)

        # Clear gradients left over from the previous step.
        optimizer.zero_grad()

        # Forward pass -> loss -> backward pass -> parameter update.
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # Accumulate the loss and report its mean once per window.
        running_loss += loss.item()
        if (i + 1) % log_interval == 0:
            tools.print_log(
                "epoch %s, iteration %s, loss: %.3f" % (epoch + 1, i + 1, running_loss / log_interval),
                file=output,
            )
            running_loss = 0.0

    tools.print_log("Training Epoch Finished", file=output)

In [4]:
# Evaluation function
def net_test(net, test_loader, output):
    """Measure the classification accuracy of ``net`` on ``test_loader``.

    The network is put into evaluation mode for the duration of the pass so
    that layers which behave differently while training (e.g. Dropout,
    BatchNorm) run in inference mode; the previous mode is restored before
    returning, even if the pass raises.

    Args:
        net: Model to evaluate.
        test_loader: Iterable of batches where element 0 holds the inputs and
            element 1 the integer class labels; both are moved to the global
            ``device``.
        output: Forwarded to ``tools.print_log`` as its ``file`` argument.

    Returns:
        float: Accuracy in percent (0-100). ``0.0`` when the loader yields no
        samples.
    """
    correct = 0.0
    total = 0.0
    was_training = net.training
    net.eval()
    try:
        # Gradients are not needed for evaluation.
        with torch.no_grad():
            for data in test_loader:
                images, labels = data[0].to(device=device), data[1].to(device=device)
                outputs = net(images)
                # Predicted class = index of the largest score along dim 1.
                _, predicted = torch.max(outputs, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
    finally:
        # Leave the model in whatever mode the caller had it in.
        net.train(was_training)

    if total == 0:
        # An empty loader would otherwise divide by zero below.
        tools.print_log("Test loader yielded no samples; accuracy reported as 0.0", file=output)
        return 0.0

    precision = 100 * correct / total
    tools.print_log("Current accuracy of the network test: %s %%" % precision, file=output)
    return precision

In [5]:
# Dataset function
def get_data_loader(data_dir, batch_size=16):
    """Create a shuffled ``DataLoader`` over the image folder at ``data_dir``.

    Each image goes through ``Grayscale`` and then ``ToTensor`` before being
    batched.

    Args:
        data_dir: Root directory handed to ``torchvision.datasets.ImageFolder``.
        batch_size: Number of samples per batch (default 16).

    Returns:
        DataLoader: Loader over the folder dataset with shuffling enabled.
    """
    preprocessing = transforms.Compose([transforms.Grayscale(), transforms.ToTensor()])
    image_folder = torchvision.datasets.ImageFolder(root=data_dir, transform=preprocessing)
    loader = DataLoader(dataset=image_folder, batch_size=batch_size, shuffle=True)
    return loader

In [7]:
# Data: one loader per dataset split.
train_loader = get_data_loader(os.path.join(data_root_dir, "train"), batch_size=32)
val_loader = get_data_loader(os.path.join(data_root_dir, "val"), batch_size=32)
test_loader = get_data_loader(os.path.join(data_root_dir, "test"), batch_size=32)
# Use the classes ImageFolder actually discovered. Counting os.listdir entries
# would also count stray files in the folder (e.g. Thumbs.db) as classes.
num_classes = len(train_loader.dataset.classes)
model_name = "Xception"
net = Xception(num_classes=num_classes).to(device=device)
# Loss function
criterion = nn.CrossEntropyLoss()
# Optimizer
optimizer = torch.optim.Adam(net.parameters(), lr=1e-3)
# Number of training epochs
n_epoch = 100
best_model, best_accuracy, best_epoch = net, 0.0, 0

# Make sure the output locations exist before anything is written to them.
os.makedirs("log", exist_ok=True)
os.makedirs("checkpoint", exist_ok=True)

# Results log; the context manager closes the file even if training fails.
with open("log/base_classifier.txt", "a") as output:
    tools.print_log("%s for jia: %s" % (model_name, net), file=output)
    for epoch in range(n_epoch):
        # Early stopping (currently disabled): stop when the model has not
        # improved for ten epochs.
        # if epoch - best_epoch >= 10:
        #     break
        # Logged 1-based so it matches the epoch numbers printed by net_train.
        tools.print_log("epoch %s/%s" % (epoch + 1, n_epoch), file=output)
        # Each epoch: train once, then validate once.
        net_train(net, train_loader, criterion, optimizer, output, epoch)
        accuracy = net_test(net, val_loader, output)
        if accuracy > best_accuracy:
            best_accuracy = accuracy
            # Snapshot the weights: `net` keeps training, so a plain reference
            # would always end up pointing at the last epoch's model.
            best_model = copy.deepcopy(net)
            best_epoch = epoch
        tools.print_log("Current accuracy of the network validation: %s" % accuracy, file=output)
        tools.print_log("Best accuracy of the network validation: %s" % best_accuracy, file=output)
        tools.print_log("Test accuracy of the network: %s" % net_test(net, test_loader, output), file=output)
# Save the best model
torch.save(best_model, "checkpoint/base_classifier.pkl")