In [None]:
import os
import logging
import numpy as np
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision.transforms import Compose, Resize, ToTensor, Normalize
from torchmetrics import Accuracy
from init import seed_everything, dump_args, data_preprocess, test_preprocess
from dataset import CharDataset
from models import LeNet, ResNet18
from classifier import CNNClassifier

接着使用 `args` 存储相关的参数，同时使用 `dump_args` 将本次运行的参数导出到文件。如果要保存日志文件，须设置参数 `log=True`。

In [None]:
# Run configuration: every tunable setting used by the later cells lives in `args`.
args = {
    "model": "ResNet18",                 # one of: "LeNet", "ResNet18"
    "epoches": 50,
    "batch_size": 64,
    "learning_rate": 0.001,
    "record_path": "./record",
    "save_path": "./save",               # directory here; expanded to the checkpoint file path below
    "raw_data_path": "./data/raw",
    "data_path": "./data/data.npz",
    "random_seed": 42,
    "mode": "train",                     # one of: "train", "test", "train_and_test"
}

# Expand the save directory into the full checkpoint path: <save_path>/<model>_best.ckpt
checkpoint_name = f"{args['model']}_best.ckpt"
args["save_path"] = os.path.join(args["save_path"], checkpoint_name)

# Export this run's parameters. The third positional argument is passed as True.
# NOTE(review): presumably that argument is the logging switch — confirm against init.dump_args.
record_path = dump_args(args, args["record_path"], True, args["mode"])

接着调用 `seed_everything`，使用 `args["random_seed"]` 固定随机种子，以便实验结果可复现。

In [None]:
# Apply the configured random seed before any data loading or model construction below.
# NOTE(review): presumably seeds every RNG in use (random / numpy / torch) — confirm in init.seed_everything.
seed_everything(args["random_seed"])

尝试加载训练数据，如果训练数据不存在，则通过 `data_preprocess` 生成。加载数据后初始化训练集和验证集。

In [None]:
# Generate the training data file from the raw data on first use.
if not os.path.exists(args["data_path"]):
    data_preprocess(args["raw_data_path"], args["data_path"])

# np.load on an .npz archive returns a lazy NpzFile that keeps the file open;
# pull out the arrays we need inside a `with` block so the handle is closed.
with np.load(args["data_path"]) as data_npz:
    x_train, y_train = data_npz["x_train"], data_npz["y_train"]
    x_valid, y_valid = data_npz["x_valid"], data_npz["y_valid"]

# Select the input transform for the configured model. `transform` is reused
# by the test cell, so it must stay defined at notebook level.
if args["model"] == "LeNet":
    # For LeNet the tensor is additionally resized to 32x32.
    transform = Compose([ToTensor(), Resize((32,32), antialias=False)])
elif args["model"] == "ResNet18":
    transform = ToTensor()
else:
    raise NotImplementedError(f"Unsupported model: {args['model']!r}")

# Training batches are shuffled; validation batches keep their stored order.
dataset_train = CharDataset(x_train, y_train, transform)
dataloader_train = DataLoader(dataset_train, args["batch_size"], shuffle=True)

dataset_valid = CharDataset(x_valid, y_valid, transform)
dataloader_valid = DataLoader(dataset_valid, args["batch_size"], shuffle=False)

In [None]:
# Map each supported model name to its constructor.
model_builders = {"LeNet": LeNet, "ResNet18": ResNet18}
if args["model"] not in model_builders:
    raise NotImplementedError(f"Unsupported model: {args['model']!r}")

# Both networks are constructed with the arguments (1, 12).
# NOTE(review): presumably (input channels, number of classes) — confirm in models.py.
model = model_builders[args["model"]](1, 12)

classifier = CNNClassifier(model, gpu=0)

# Resume from an existing checkpoint when there is one.
if os.path.exists(args["save_path"]):
    classifier.load_model(args["save_path"])
elif args["mode"] == "test":
    # Without a checkpoint, test-only mode would score freshly initialised
    # weights; make that visible instead of silently reporting the numbers.
    logging.warning(
        "No checkpoint found at %s; test mode will evaluate an untrained model.",
        args["save_path"],
    )

In [None]:
# Train only when the selected mode includes a training phase.
should_train = args["mode"] in ("train", "train_and_test")
if should_train:
    fit_options = {
        "train_loader": dataloader_train,
        "valid_loader": dataloader_valid,
        "epoches": args["epoches"],
        "learning_rate": args["learning_rate"],
        "save_path": args["save_path"],
        "log_interval": 1,
    }
    classifier.fit(**fit_options)

In [None]:
# Evaluate on the test set only when the selected mode includes a test phase.
# A conditional is used instead of exit(0): inside a Jupyter kernel, exit()
# shuts the kernel down, which would discard the trained classifier and all
# notebook state whenever the mode is "train".
if args["mode"] in ("test", "train_and_test"):
    # Placeholders: fill in both paths before running a test mode.
    raw_test_path = ""
    test_path = ""
    if not test_path:
        # np.load("") can never succeed, so fail early with a clear message
        # rather than calling test_preprocess with empty paths first.
        raise ValueError(
            "test_path is empty; set raw_test_path and test_path before running in test mode"
        )

    # Generate the test data file from the raw test data on first use.
    if not os.path.exists(test_path):
        test_preprocess(raw_test_path, test_path, shuffle=True)

    # Extract the arrays inside a `with` block so the lazy NpzFile handle is closed.
    with np.load(test_path) as test_npz:
        x_test, y_test = test_npz["x_test"], test_npz["y_test"]

    # `transform` is the per-model transform chosen in the data loading cell.
    dataset_test = CharDataset(x_test, y_test, transform)
    dataloader_test = DataLoader(dataset_test, args["batch_size"], shuffle=False)

    device = classifier.device

    test_loss = 0.0
    loss_fn = F.cross_entropy
    metric = Accuracy(task="multiclass", num_classes=12, top_k=1).to(device)
    for x, y in dataloader_test:
        x = x.to(device)
        y = y.long().to(device)  # cross_entropy needs integer (long) class targets
        # NOTE(review): presumably classifier.predict runs in eval mode without
        # gradient tracking — confirm in classifier.py.
        preds = classifier.predict(x)
        test_loss += loss_fn(preds, y).item()
        metric.update(preds, y)

    # Average of the per-batch mean losses; skip the division for an empty
    # loader so it does not raise ZeroDivisionError.
    num_batches = len(dataloader_test)
    if num_batches:
        test_loss /= num_batches
    test_acc = metric.compute()
    logging.info("Test mean loss: {:.6f}, mean accuracy: {:.6f}".format(test_loss, test_acc))