使用一对多的多分类策略，为每个状态建立单独的分类器，然后分别训练

- 配置全局参数可以任意指定要训练的模型和对应的数据集
    - 保存生成的数据集，并且要支持读取和复用之前生成的数据集
未完成部分
- 配置全局参数可以针对指定的模型进行保存和读取
    - 模型的保存和读取，要支持保存不同版本，同时留下模型的相关数据，方便后续选择出最优的模型
- 拆分训练集和验证集
- 配置整合函数main

In [None]:
from StatusChecker.TraditionalStatusChecker import TraditionalStatusChecker

# Status whose classifier is being trained; its images are the positive samples.
current_train_tag = TraditionalStatusChecker.ASC_STATUS_FIGHTING
# Cap on the number of positive samples drawn per run (despite the "_min"
# suffix it is used as the upper bound inside a min() call).
preset_positive_data_count_min = 100
# Learning rate passed to the SGD optimizer.
learn_rate = 0.1

In [None]:
# Collect the value of every ASC_STATUS_* attribute of TraditionalStatusChecker,
# leaving out ASC_STATUS_UNKNOWN. dir() yields the names in sorted order, so
# the position of each status in this list is stable between runs.
status_list = [
    getattr(TraditionalStatusChecker, attr_name)
    for attr_name in dir(TraditionalStatusChecker)
    if attr_name.startswith("ASC_STATUS_") and attr_name != 'ASC_STATUS_UNKNOWN'
]

print(status_list)
print(f"total {len(status_list)} status")
# An earlier, hard-coded version of this list (kept for reference) was:
#   level_selection, restore_sanity_medicine, restore_sanity_stone, team_up,
#   fighting, battle_settlement, annihilation_settlement, level_up


根据训练目标选取数据集
- 正向数据集选取目标标签，数量选取所有的或者指定阈值
- 反向数据集选取其余的标签数据，数量取正向数据集长度平均至每个TAG，不够的暂时全上

In [None]:
import os
import torch
import random

# Base directory of the processed dataset, resolved relative to the current
# working directory. The code below joins it with a status value to reach the
# directory holding that status's images.
train_set_base_path = os.path.join(os.getcwd(), '..', 'dataset', 'processed')

# TODO: when the amount of data per status is unbalanced, the scarce statuses
# will probably skew the result a lot — revisit the sampling strategy.
def generate_train_data_list():
    """Pick the image file names used to train the ``current_train_tag`` model.

    Positive samples are drawn from the directory of ``current_train_tag`` and
    capped at ``preset_positive_data_count_min``. Negative samples are drawn
    from the directory of every other status in ``status_list``; each of them
    contributes an equal share of the positive count (rounded up), or all of
    its files when it holds fewer than that.

    Files are sampled without replacement, so no file name appears twice. The
    selection is also written to a timestamped file in the ``log`` directory
    next to the current working directory (created when missing).

    Returns:
        list[str]: the positive file names followed by the negative ones.
    """
    from datetime import datetime

    negative_statuses = [status for status in status_list if status != current_train_tag]
    # List every directory once; the listing is needed for counting and sampling.
    files_by_status = {
        status: os.listdir(os.path.join(train_set_base_path, status))
        for status in [current_train_tag, *negative_statuses]
    }

    # Decide how many samples each status contributes.
    positive_count = min(
        len(files_by_status[current_train_tag]),
        preset_positive_data_count_min
    )
    data_count = {current_train_tag: positive_count}
    print(f"needed positive data count: {positive_count}")
    # Ceiling division keeps the share an integer (random.sample needs an int)
    # and never rounds a non-empty positive set down to zero negatives.
    if negative_statuses:
        negative_data_count_need_each = -(-positive_count // len(negative_statuses))
    else:
        negative_data_count_need_each = 0
    print(f"needed negative data count: {negative_data_count_need_each}")
    for status in negative_statuses:
        # Use whichever is smaller: the requested share or what is available.
        data_count[status] = min(
            negative_data_count_need_each,
            len(files_by_status[status])
        )
    print(f"decided data count use: {data_count}")

    # Draw a fresh random subset on every call instead of always taking the
    # first files; the tag can be recovered from the file name later on.
    train_set_positive_data_set = random.sample(
        files_by_status[current_train_tag], k=data_count[current_train_tag]
    )
    train_set_negative_data_set = []
    for status in negative_statuses:
        train_set_negative_data_set += random.sample(
            files_by_status[status], k=data_count[status]
        )

    # Record the selection in the log directory.
    log_path = os.path.join(os.getcwd(), '..', 'log')
    os.makedirs(log_path, exist_ok=True)
    current_time = datetime.now()
    print(f"start log at {current_time}")
    # Avoid ':' and spaces in the file name; they are invalid on some platforms.
    log_name = f"train_data_set_list-{current_time:%Y%m%d-%H%M%S-%f}.log"
    with open(os.path.join(log_path, log_name), "w", encoding="utf-8") as log_file:
        log_file.write("\n===========positive start===============\n")
        log_file.write("\n".join(train_set_positive_data_set))
        log_file.write("\n===========positive end=================\n")
        log_file.write("\n===========negative start===============\n")
        log_file.write("\n".join(train_set_negative_data_set))
        log_file.write("\n===========negative end=================\n")
    print("log finished")

    return train_set_positive_data_set + train_set_negative_data_set

# Maps an image file name to the tag (status index) it belongs to.
def get_image_tag(filename:str) -> int:
    """Return the index in ``status_list`` of the first status contained in *filename*.

    Args:
        filename: image file name that embeds its status as a substring.

    Returns:
        Position of the matching status in ``status_list``.

    Raises:
        ValueError: if no status from ``status_list`` occurs in *filename*.
    """
    for index, tag in enumerate(status_list):
        if tag in filename:
            return index
    # Fail loudly instead of returning None, which would only surface later as
    # an obscure indexing or tensor-construction error.
    raise ValueError(f"no known status tag found in file name: {filename!r}")

def generate_target_list(train_file_list:list):
    """Build the label list that runs parallel to *train_file_list*.

    Every file name is converted with ``get_image_tag``, so entry ``i`` of the
    result is the tag of ``train_file_list[i]``.

    NOTE(review): these are status indices, not 0/1 labels — confirm this is
    what the binary loss used for training expects.
    """
    return [get_image_tag(file_name) for file_name in train_file_list]

配置 DataSet

In [None]:
import cv2 as cv
from torch.utils.data import Dataset

def default_loader(image_name:str):
    """Read a training image as a grayscale array.

    The image is looked up under ``train_set_base_path`` in the directory of
    the status that ``get_image_tag`` finds in *image_name*.

    Args:
        image_name: file name of the image (not a full path).

    Returns:
        The array produced by ``cv.imread`` with ``cv.IMREAD_GRAYSCALE``.

    Raises:
        FileNotFoundError: if OpenCV could not read the file. ``cv.imread``
            signals failure by returning ``None`` instead of raising.
    """
    image_path = os.path.join(
        train_set_base_path, status_list[get_image_tag(image_name)], image_name)
    image_loaded = cv.imread(image_path, cv.IMREAD_GRAYSCALE)
    if image_loaded is None:
        raise FileNotFoundError(f"failed to read image: {image_path}")
    print(f"load image {image_name}")
    return image_loaded

class TrainSet(Dataset):
    """Dataset that loads images by file name on access and pairs them with labels.

    Args:
        image_name_list: image file names; each is handed to *loader* on access.
        targets: labels, parallel to *image_name_list*.
        loader: callable that turns a file name into array-like image data.
        device: device on which the returned tensors are created. Defaults to
            CUDA when it is available and to the CPU otherwise.
    """

    def __init__(self, image_name_list:list, targets:list, loader=default_loader,
                 device=None)-> None:
        self.images = image_name_list
        self.targets = targets
        self.loader = loader
        # Fall back to the CPU so the dataset also works on hosts without a GPU.
        if device is None:
            device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.device = torch.device(device)

    def __getitem__(self, index: int):
        """Return ``(image, target)`` tensors for the sample at *index*."""
        image = self.loader(self.images[index])
        # unsqueeze(0) prepends one dimension; presumably the loader yields a
        # 2-D grayscale array, making this the single channel — TODO confirm.
        image = torch.tensor(image, device=self.device, dtype=torch.float32).unsqueeze(0)
        target = self.targets[index]
        target = torch.tensor(target, dtype=torch.long, device=self.device)
        return image, target

    def __len__(self)-> int:
        """Return the number of samples."""
        return len(self.images)

def get_train_set(train_file_list:list, target_list:list) -> Dataset:
    """Wrap the file names and labels in a ``TrainSet`` and print a short summary.

    Args:
        train_file_list: image file names.
        target_list: labels, parallel to *train_file_list*.

    Returns:
        The constructed dataset.
    """
    train_set = TrainSet(train_file_list, target_list)
    print(train_set)
    sample_count = len(train_set)
    # Show one sample as a sanity check. Index 3 is used when it exists;
    # smaller datasets show their last sample instead of raising IndexError.
    if sample_count > 0:
        print(train_set[min(3, sample_count - 1)])
    print(sample_count)
    return train_set

构建 DataLoader

In [None]:
from torch.utils.data import DataLoader
def get_train_loader(train_set:Dataset) -> DataLoader:
    """Create a shuffling ``DataLoader`` with batch size 1 over *train_set*.

    The loader and the sizes of the first two elements of one batch are
    printed as a sanity check.

    Args:
        train_set: dataset whose batches hold at least two elements.

    Returns:
        The constructed loader.
    """
    train_loader = DataLoader(train_set, batch_size=1, shuffle=True)
    print(train_loader)
    # Use the built-in next(): DataLoader iterators no longer provide a
    # ``.next()`` method. One batch is enough to print both sizes.
    first_batch = next(iter(train_loader), None)
    if first_batch is not None:
        print(first_batch[0].size())
        print(first_batch[1].size())
    return train_loader


定义模型

In [None]:
import torch

class Lambda(torch.nn.Module):
    """Module that applies an arbitrary callable to its input.

    Lets a plain function be used as a layer, e.g. inside
    ``torch.nn.Sequential``.
    """

    def __init__(self, func):
        """Store *func*, the callable invoked by ``forward``."""
        super(Lambda, self).__init__()
        self.func = func

    def forward(self, xb):
        """Return ``func(xb)``."""
        result = self.func(xb)
        return result

def get_module(lr=None):
    """Build the binary classifier for one status and its SGD optimizer.

    The network is three strided 3x3 convolutions with ReLU, an adaptive
    average pool to 16x36, a flatten step and three linear layers ending in a
    single sigmoid output, i.e. one probability per sample.

    Input must be a batch shaped ``(N, 1, H, W)``.

    Args:
        lr: learning rate for the optimizer. Defaults to the module-level
            ``learn_rate`` when omitted.

    Returns:
        tuple: ``(module, optimize)`` — the network and its SGD optimizer, in
        that order.
    """
    if lr is None:
        lr = learn_rate
    module = torch.nn.Sequential(
        torch.nn.Conv2d(1, 16, kernel_size=3, stride=2, padding=3),
        torch.nn.ReLU(),
        torch.nn.Conv2d(16, 16, kernel_size=3, stride=2, padding=3),
        torch.nn.ReLU(),
        torch.nn.Conv2d(16, 10, kernel_size=3, stride=2, padding=3),
        torch.nn.ReLU(),
        # Fixed-size pooling makes the linear layers independent of the
        # resolution of the input image.
        torch.nn.AdaptiveAvgPool2d((16, 36)),
        # Built-in flatten (N, 10, 16, 36) -> (N, 5760). Unlike a lambda
        # wrapped in a module, it can be pickled together with the model.
        torch.nn.Flatten(),
        # NOTE(review): there is no non-linearity between these linear layers,
        # so together they are equivalent to one linear map.
        torch.nn.Linear(10*16*36, 16*36),
        torch.nn.Linear(16*36, 36),
        torch.nn.Linear(36, 1),
        # Sigmoid, not Softmax: softmax over a single value is always 1.0,
        # which would make the output constant and untrainable.
        torch.nn.Sigmoid(),
    )
    optimize = torch.optim.SGD(module.parameters(), lr=lr)
    return module, optimize

def get_loss_function():
    """Return the loss callable used for training: binary cross entropy.

    NOTE(review): ``binary_cross_entropy`` expects probabilities as input and
    float targets of the same shape — confirm the dataset targets match.
    """
    return torch.nn.functional.binary_cross_entropy

fit() 函数以及模型的保存和恢复

In [None]:
import numpy

def loss_batch(model, loss_func, xb, yb, opt=None):
    """Compute the loss of one batch and, when *opt* is given, take one step.

    Args:
        model: callable producing predictions for *xb*.
        loss_func: callable ``(prediction, target) -> loss tensor``.
        xb: batch of inputs.
        yb: batch of targets.
        opt: optimizer. When provided, the loss is back-propagated, the
            optimizer steps once and its gradients are cleared; when ``None``
            the loss is only evaluated.

    Returns:
        tuple: ``(loss value as a Python number, number of samples in xb)``.
    """
    prediction = model(xb)
    batch_loss = loss_func(prediction, yb)

    should_update = opt is not None
    if should_update:
        batch_loss.backward()
        opt.step()
        # Clear the gradients so they do not accumulate into the next batch.
        opt.zero_grad()

    batch_size = len(xb)
    return batch_loss.item(), batch_size

def fit(epochs:int, module:torch.nn.Module,
        loss_function, optimize, train_loader, validate_loader):
    """Train *module* for *epochs* epochs and report the validation loss.

    Each epoch runs every batch of *train_loader* through ``loss_batch`` with
    the optimizer, then evaluates *validate_loader* without gradients and
    prints the epoch number together with the sample-weighted mean loss.

    Args:
        epochs: number of passes over *train_loader*.
        module: network to train.
        loss_function: callable ``(prediction, target) -> loss tensor``.
        optimize: optimizer updating the parameters of *module*.
        train_loader: iterable of ``(xb, yb)`` training batches.
        validate_loader: iterable of ``(xb, yb)`` validation batches, or
            ``None`` to skip validation.
    """
    index = 0
    for epoch in range(epochs):
        # Return to training mode: the validation pass of the previous epoch
        # left the module in eval mode.
        module.train()
        for xb, yb in train_loader:
            index = index + 1
            if index % 300 == 0:
                print(f"start index {index} data")
            loss_batch(module, loss_function, xb, yb, optimize)

        if validate_loader is None:
            continue

        module.eval()
        with torch.no_grad():
            results = [loss_batch(module, loss_function, xb, yb)
                       for xb, yb in validate_loader]
        if not results:
            # Nothing to average; unpacking zip(*[]) would fail.
            continue
        losses, nums = zip(*results)
        # Weight each batch loss by its size to get the per-sample mean.
        val_loss = numpy.sum(numpy.multiply(losses, nums)) / numpy.sum(nums)
        print(epoch, val_loss)

准备启动

In [None]:
# Start-up script: build the data list, dataset, loader and model, then train.
# Kept at top level for now; it is meant to become main() later.

train_file_list = generate_train_data_list()
target_list = generate_target_list(train_file_list)
print(status_list)
print(len(train_file_list))
print(train_file_list[0])
print(len(target_list))
print(target_list[0])

train_set = get_train_set(train_file_list, target_list)

train_loader = get_train_loader(train_set)

# get_module() returns (module, optimizer) in that order.
module, optimize = get_module()
loss_function = get_loss_function()

# NOTE(review): no validation loader exists yet because the train/validation
# split is still to be implemented; supply one here once it is available.
fit(epochs=1, module=module, loss_function=loss_function,
    optimize=optimize, train_loader=train_loader, validate_loader=None)
