In [1]:
# -*- coding: utf-8 -*-

import torch
import os
import random
import os
import numpy as np
import logging
from config import Config
from model import SiameseNetwork, choose_optimizer
from evaluate import Evaluator
from loader import load_data

# Root logging setup: INFO and above, each record prefixed with the timestamp,
# the logger name and the level.
logging.basicConfig(
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
# Module-scoped logger used by the training routine in this cell.
logger = logging.getLogger(__name__)

"""
模型训练主程序
"""

def main(config):
    """Train a SiameseNetwork on the configured data and save its final weights.

    Args:
        config: Subscriptable configuration. This function reads
            ``config["model_path"]`` (directory for the checkpoint),
            ``config["train_data_path"]`` (passed to ``load_data``) and
            ``config["epoch"]`` (number of training epochs). The object is
            also forwarded unchanged to ``load_data``, ``SiameseNetwork``,
            ``choose_optimizer`` and ``Evaluator``.

    Returns:
        None.

    Side effects:
        Creates ``config["model_path"]`` (and missing parents) if needed, logs
        progress through the module ``logger``, calls ``evaluator.eval`` after
        every epoch, and writes ``epoch_<n>.pth`` containing the model's
        ``state_dict`` once training ends, where ``<n>`` is the last epoch
        that ran (0 if ``config["epoch"]`` is 0, i.e. the untrained weights).
    """
    # Create the checkpoint directory; makedirs also builds missing parents
    # and tolerates the directory already existing.
    os.makedirs(config["model_path"], exist_ok=True)
    # Load the training data.
    train_data = load_data(config["train_data_path"], config)
    # Build the model.
    model = SiameseNetwork(config)
    # Remember whether a GPU is available so batches follow the model.
    cuda_flag = torch.cuda.is_available()
    if cuda_flag:
        logger.info("gpu可以使用，迁移模型至gpu")
        model = model.cuda()
    # Build the optimizer.
    optimizer = choose_optimizer(config, model)
    # Build the evaluation helper.
    evaluator = Evaluator(config, model, logger)
    # Bound before the loop so the checkpoint name below is defined even when
    # zero epochs are requested.
    epoch = 0
    # Epochs are numbered from 1 in log messages and in the checkpoint name.
    for epoch in range(1, config["epoch"] + 1):
        # Re-enter training mode every epoch.
        # NOTE(review): presumably evaluator.eval() switches the model to
        # eval mode — confirm against Evaluator.
        model.train()
        logger.info("epoch %d begin", epoch)
        train_loss = []
        for batch_data in train_data:
            optimizer.zero_grad()
            if cuda_flag:
                batch_data = [d.cuda() for d in batch_data]
            # Each batch is unpacked into exactly three inputs; adjust here if
            # the loader's batch layout changes (e.g. more inputs/outputs).
            input_id1, input_id2, input_id3 = batch_data
            loss = model(input_id1, input_id2, input_id3)
            train_loss.append(loss.item())
            loss.backward()
            optimizer.step()
        # np.mean([]) would emit a RuntimeWarning; report nan directly instead.
        mean_loss = np.mean(train_loss) if train_loss else float("nan")
        logger.info("epoch average loss: %f", mean_loss)
        evaluator.eval(epoch)
    # Only the weights after the final epoch are persisted.
    model_path = os.path.join(config["model_path"], "epoch_%d.pth" % epoch)
    torch.save(model.state_dict(), model_path)

In [2]:
# Run the training routine with the Config object imported from the config module.
main(Config)

2024-08-03 22:54:15,835 - __main__ - INFO - gpu可以使用，迁移模型至gpu
2024-08-03 22:54:18,896 - __main__ - INFO - epoch 1 begin
2024-08-03 22:54:19,955 - __main__ - INFO - epoch average loss: 0.949083
2024-08-03 22:54:19,956 - __main__ - INFO - 开始测试第1轮模型效果：
2024-08-03 22:54:19,987 - __main__ - INFO - 预测集合条目总量：464
2024-08-03 22:54:19,988 - __main__ - INFO - 预测正确条目：396，预测错误条目：68
2024-08-03 22:54:19,988 - __main__ - INFO - 预测准确率：0.853448
2024-08-03 22:54:19,989 - __main__ - INFO - --------------------
2024-08-03 22:54:19,989 - __main__ - INFO - epoch 2 begin
2024-08-03 22:54:20,014 - __main__ - INFO - epoch average loss: 0.917139
2024-08-03 22:54:20,015 - __main__ - INFO - 开始测试第2轮模型效果：
2024-08-03 22:54:20,045 - __main__ - INFO - 预测集合条目总量：464
2024-08-03 22:54:20,046 - __main__ - INFO - 预测正确条目：405，预测错误条目：59
2024-08-03 22:54:20,046 - __main__ - INFO - 预测准确率：0.872845
2024-08-03 22:54:20,047 - __main__ - INFO - --------------------
2024-08-03 22:54:20,048 - __main__ - INFO - epoch 3 begin
2024-08-03 22