首先导入所需的模块（`module`）、类（`class`）和函数（`function`）。

In [None]:
import os
import logging
import numpy as np
from utils import Dataset, DataLoader
from init import seed_everything, dump_args, data_preprocess
from model import MLPClassifier, accuracy
from nn import CrossEntropyLoss

接着定义数据集类 `CharDataset`：它继承自 `Dataset`，按索引返回包含特征 `x` 与标签 `y` 的样本。

In [None]:
class CharDataset(Dataset):
    """Index-aligned pair of feature and label arrays exposed as a ``Dataset``.

    Each item is a dict holding the ``"x"`` (features) and ``"y"`` (label)
    entries found at the same index of the two arrays given to the
    constructor.
    """

    def __init__(self, x, y):
        """Store the arrays and cache the sample count.

        Args:
            x: Feature array; its leading axis is taken as the sample axis.
            y: Label array; assumed (not checked) to have as many rows as x.
        """
        self.x, self.y = x, y
        # Sample count is read from the first dimension of ``x`` only.
        self.len = x.shape[0]

    def __len__(self):
        """Return the number of samples (the leading size of ``x``)."""
        return self.len

    def __getitem__(self, index):
        """Return the sample at ``index`` as ``{"x": ..., "y": ...}``."""
        sample_x = self.x[index]
        sample_y = self.y[index]
        return {"x": sample_x, "y": sample_y}

然后用字典 `args` 存储本次运行的参数，并调用 `dump_args` 将这些参数导出到文件。

In [None]:
# Run configuration for this experiment, collected in one dict so it can be
# exported as a whole.
args = {
    "epoches": 100,
    "batch_size": 64,
    "learning_rate": 0.0005,
    "record_path": "./record/char",
    "save_path": "./save/char/best_model.pkl",
    "raw_data_path": "./data/char/train_raw",
    "data_path": "./data/char/data.npz",
    "random_seed": 42,
    # Accepted values: "train", "test", "train_and_test".
    "mode": "train_and_test",
}

# Export this run's parameters; the returned value is kept as ``record_path``
# (presumably the per-run record location — confirm in init.dump_args).
record_path = dump_args(args, args["record_path"], args["mode"])

使用 `seed_everything` 设置随机数种子，确保整个过程可重复。

In [None]:
# Seed the random number generators through the project helper so the whole
# run is reproducible. Which generators it covers is defined in
# init.seed_everything (not visible here).
seed_everything(args["random_seed"])

尝试加载训练数据，如果训练数据不存在，则通过 `data_preprocess` 生成。加载数据后初始化训练集和验证集。

In [None]:
# Generate the cached .npz file from the raw data on first use only.
if not os.path.exists(args["data_path"]):
    data_preprocess(args["raw_data_path"], args["data_path"], shuffle=True)

# The archive object is kept as ``data_npz`` because the test split is read
# from it later in the notebook.
data_npz = np.load(args["data_path"])

# Wrap the training and validation splits as datasets first...
dataset_train = CharDataset(data_npz["x_train"], data_npz["y_train"])
dataset_valid = CharDataset(data_npz["x_valid"], data_npz["y_valid"])

# ...then batch them; only the training batches are shuffled.
dataloader_train = DataLoader(dataset_train, args["batch_size"], shuffle=True)
dataloader_valid = DataLoader(dataset_valid, args["batch_size"], shuffle=False)

初始化分类器对象；如果 `args["save_path"]` 处已有保存的模型，则加载该模型。

In [None]:
# Build the classifier. The 28 * 28 input size presumably matches flattened
# 28x28 character images, and 12 is presumably the number of character
# classes — TODO confirm both against data_preprocess.
classifier = MLPClassifier(28 * 28, 12)
if os.path.exists(args["save_path"]):
    # Start from the previously saved weights when a checkpoint exists.
    classifier.load_model(args["save_path"])
elif args["mode"] == "test":
    # Test-only mode never trains, so without a checkpoint the reported test
    # metrics would come from an untrained model; make that visible instead
    # of failing silently.
    logging.warning(
        "No saved model found at %s; testing an untrained model.",
        args["save_path"],
    )

接着根据 `args["mode"]` 判断是否训练：仅当模式为 `"train"` 或 `"train_and_test"` 时才进行训练。

In [None]:
# Train only when the selected mode includes a training phase. ``fit`` gets
# both loaders, the optimisation settings, and where to write its checkpoint.
if args["mode"] in ("train", "train_and_test"):
    classifier.fit(
        train_loader=dataloader_train,
        valid_loader=dataloader_valid,
        learning_rate=args["learning_rate"],
        epoches=args["epoches"],
        save_path=args["save_path"],
        log_interval=5,
    )

最后，在测试集上评估模型（仅当模式为 `"test"` 或 `"train_and_test"` 时执行）。

In [None]:
# Evaluate on the held-out test split, but only in modes that request it.
# A plain conditional is used rather than exit(): ``exit`` is the interactive
# helper installed by ``site``, not meant for programs, and calling it inside
# a Jupyter kernel shuts the kernel down.
if args["mode"] in ("test", "train_and_test"):
    dataset_test = CharDataset(data_npz["x_test"], data_npz["y_test"])
    dataloader_test = DataLoader(dataset_test, args["batch_size"], shuffle=False)

    loss = 0.0
    acc = 0.0
    loss_func = CrossEntropyLoss()
    for batch in dataloader_test:
        result = classifier.predict(batch["x"])
        loss += loss_func(result["predicts"], batch["y"])
        acc += accuracy(result["predicts"], batch["y"])

    # NOTE(review): this is the mean of per-batch values, so a smaller final
    # batch counts as much as a full one. If loss_func/accuracy return
    # per-batch means, weighting by batch size would give the exact
    # per-sample mean — confirm their return semantics before changing.
    num_batches = len(dataloader_test)
    loss /= num_batches
    acc /= num_batches
    # NOTE(review): logging.info is only shown if logging was configured at
    # INFO level (possibly by dump_args) — confirm.
    logging.info("Test mean loss: {:.6f}, mean accuracy: {:.6f}".format(loss, acc))