In [None]:
import os
import sys

import torch
import torch.nn as nn
from tqdm import tqdm
from torch.optim import Adam
from torch.utils.data import DataLoader, random_split
from torchvision.models.alexnet import AlexNet
from torchvision.transforms import Compose, Normalize, Resize, ToTensor

from dataset_models import DetectionDataset


In [None]:
# Choose where tensors and the model live: the first CUDA GPU if PyTorch can see one,
# otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print(f"using {device} device.")


In [None]:
# Dataset directories, resolved against the current working directory:
#   image_path -> the JPEG images
#   prop_path  -> the proposal files consumed by DetectionDataset
image_path = os.path.abspath('datasets/JPEGImages/')
prop_path = os.path.abspath('datasets/Proposals/')

# Fail fast with a precise error type before building the dataset.
# FileNotFoundError subclasses Exception, so existing `except Exception` handlers still work.
# isdir (rather than exists) also rejects a plain file sitting at the expected path.
for required_dir in (image_path, prop_path):
    if not os.path.isdir(required_dir):
        raise FileNotFoundError(f"{required_dir} path does not exist or is not a directory.")


In [None]:
# Preprocessing applied to every sample:
#   1. resize to a fixed 224x224 (the input size this notebook feeds to AlexNet),
#   2. convert to a tensor,
#   3. normalize each RGB channel with mean 0.5 and std 0.5. Assuming ToTensor
#      yields values in [0, 1], this maps them to [-1, 1].
norm_stats = (0.5, 0.5, 0.5)
data_transform = Compose([
    Resize((224, 224)),
    ToTensor(),
    Normalize(norm_stats, norm_stats),
])


In [None]:
# Build the dataset and split it 80/20 into training and validation subsets.
dataset = DetectionDataset(image_path, prop_path, data_transform)
# An empty dataset would otherwise only fail later, inside the shuffling DataLoader.
assert len(dataset) > 0, f"no samples found under {image_path} / {prop_path}"

train_size = int(0.8 * len(dataset))
test_size = len(dataset) - train_size  # size of the validation subset

# A fixed-seed generator makes the split identical on every Restart & Run All.
# Validation accuracy, and therefore best-checkpoint selection, is then comparable
# between runs.
split_seed = 42
train_dataset, validate_dataset = random_split(
    dataset,
    [train_size, test_size],
    generator=torch.Generator().manual_seed(split_seed),
)
train_num = len(train_dataset)
val_num = len(validate_dataset)
train_num, val_num


In [None]:
batch_size = 64
# os.cpu_count() returns None when the core count cannot be determined;
# treat that as a single core instead of crashing on `None - 1`.
cpu_count = os.cpu_count() or 1
# Presumably leaves one core for the main (training) process, while always
# keeping at least one loader worker.
num_workers = max(cpu_count - 1, 1)


In [None]:
# Validation batch size, exposed as a named setting instead of a magic number.
# It defaults to the training batch size so evaluation is not throttled by tiny batches.
# Validation runs under torch.no_grad() in eval mode, so batch size should only affect
# speed and memory, not accuracy. TODO: confirm memory headroom on the target GPU.
val_batch_size = batch_size

# Only the training split is shuffled; validation order does not change accuracy.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
validate_loader = DataLoader(validate_dataset, batch_size=val_batch_size, shuffle=False, num_workers=num_workers)
print(f"using {train_num} images for training, {val_num} images for validation.")


In [None]:
# Create a two-class AlexNet and move its parameters to the selected device.
# Module.to() returns the module itself, so `net` is the moved model.
net = AlexNet(num_classes=2).to(device)
net


In [None]:
# Training configuration:
#   loss_function -> cross-entropy loss on the model's raw outputs
#   optimizer     -> Adam over all model parameters
#   epochs        -> number of passes over the training set
#   save_path     -> directory where the best checkpoint is written
learning_rate = 0.0002
loss_function = nn.CrossEntropyLoss()
optimizer = Adam(net.parameters(), lr=learning_rate)
epochs = 10
save_path = os.path.abspath('./weights')
# exist_ok=True already tolerates an existing directory, so no separate exists() check
# is needed (and there is no check-then-create race).
os.makedirs(save_path, exist_ok=True)

best_acc = 0.0  # best validation accuracy so far; used to decide when to overwrite the checkpoint
# Number of training batches per epoch. train_loader is built without drop_last,
# so this is ceil(train_num / batch_size).
train_steps = len(train_loader)


In [None]:
for epoch in range(epochs):
    # ---------------- training phase ----------------
    net.train()
    # Sum of per-batch losses for this epoch. It stays on the device so that
    # accumulating does not force a device->host sync on every step.
    train_loss = torch.zeros(1).to(device)
    acc_num = torch.zeros(1).to(device)  # correctly classified training samples this epoch
    sample_num = 0                       # training samples seen so far this epoch
    # tqdm wraps the loader and prints a progress bar to stdout.
    train_bar = tqdm(train_loader, file=sys.stdout, ncols=100,
                     desc=f"Epoch {epoch + 1}/{epochs}")
    for step, data in enumerate(train_bar):
        images, labels = data
        images = images.to(device)
        labels = labels.to(device)
        sample_num += images.shape[0]

        optimizer.zero_grad()
        outputs = net(images)  # expected shape: [batch_size, num_classes]
        # torch.max along dim=1 returns (values, indices); the index of the largest
        # output is the predicted class.
        pred_class = torch.max(outputs, dim=1)[1]
        # torch.eq yields a bool tensor. Summing it counts the True entries, i.e. the
        # number of correct predictions in this batch.
        acc_num += torch.eq(pred_class, labels).sum()

        # CrossEntropyLoss takes [N, C] outputs plus targets. The comparison with
        # pred_class above suggests labels are class indices of shape [N]; that is why
        # the two arguments have different shapes.
        loss = loss_function(outputs, labels)
        loss.backward()   # backpropagate gradients
        optimizer.step()  # update parameters

        # Sum the raw batch loss; dividing by the batch count after the epoch gives
        # the mean loss per batch.
        train_loss += loss.detach()

    # One host sync per epoch instead of per step; guard against an empty loader.
    train_acc = acc_num.item() / max(sample_num, 1)
    mean_train_loss = train_loss.item() / max(train_steps, 1)

    # ---------------- validation phase ----------------
    net.eval()
    acc_num = 0.0  # correctly classified validation samples this epoch
    with torch.no_grad():
        for val_images, val_labels in validate_loader:
            outputs = net(val_images.to(device))
            predict_y = torch.max(outputs, dim=1)[1]
            acc_num += torch.eq(predict_y, val_labels.to(device)).sum().item()

    val_acc = acc_num / max(val_num, 1)
    print(f'Epoch {epoch + 1}/{epochs}: train_loss={mean_train_loss:.3f} train_acc={train_acc:.3f} val_accuracy={val_acc:.3f}')
    # Overwrite the checkpoint only when validation accuracy improves on the best so far.
    if val_acc > best_acc:
        best_acc = val_acc
        torch.save(net.state_dict(), os.path.join(save_path, "Detection.pth"))
