In [None]:
from util import create_dir, replace_dir, Clock, compare_dir, split_parameters
from transform_img import flatten_onehot, Diff_size_collect, get_transform, norm_black_color,get_pretreat_transform
from loss import Soft_dice_loss, Focal_loss, SSIM, activate
from plot import plot_grad_flow, Progress_writer, onehot_gird, Loss_record, Acc_record, Loss_writer,Acc_writer
from dataset.dataset import Image_Dataset, Zip_dataset, get_data_files
from dataset.tarpath import Tar_path
from dataset.lmdb_format import Lmdb_dataset

from torch.utils.data import DataLoader, Subset
from torchsummary import summary

from net.unet import U_Net
from net.nested_unet import NestedUNet
from net.regseg import RegSeg
from net.regseg_p import RegSeg_dp

import tarfile
from pathlib import Path
from itertools import islice
import copy
from datetime import datetime, timedelta
import cv2
from matplotlib import pyplot as plt
import torch.nn.functional as F
import torchvision
from torch.utils.tensorboard import SummaryWriter
import torch
import numpy as np
import matplotlib
from torch.cuda.amp import GradScaler

In [None]:
# Check whether CUDA can be used and pick the device everything is moved to.
TRAIN_ON_GPU = torch.cuda.is_available()

if TRAIN_ON_GPU:
    print('CUDA is available. Training on GPU')
    DEVICE = torch.device("cuda:0")
else:
    print('CUDA is not available. Training on CPU')
    DEVICE = torch.device("cpu")


In [None]:
# Parameter settings
# Model parameters

# Label tables map a 1-based class id to its display name.
LABEL_C1 = dict(enumerate(("Human",), start=1))

LABEL_C5 = dict(enumerate(
    ("Hair", "Face", "body", "Leg", "Arm"),
    start=1))

LABEL_C12 = dict(enumerate(
    (
        "Hat", "Hair", "Arm", "Sunglasses",
        "Clothes", "Dress", "Leg", "Pants",
        "Torso-skin", "Scarf", "Skirt", "Face",
    ),
    start=1))

# For every 5-class id, the 12-class ids that are merged into it.
C12_TO_C5 = {
    1: [1, 2],          # Hair  <- Hat, Hair
    2: [4, 12],         # Face  <- Sunglasses, Face
    3: [5, 6, 9, 10],   # body  <- Clothes, Dress, Torso-skin, Scarf
    4: [7, 8, 11],      # Leg   <- Leg, Pants, Skirt
    5: [3],             # Arm   <- Arm
}

# Label table used by this run.
LABEL = LABEL_C5

INPUT_CHANNEL = 3
# One output channel for a single-label task; otherwise one channel per label
# plus one extra (presumably background -- confirm against the target transform).
OUTPUT_CHANNEL = 1 if len(LABEL) <= 1 else len(LABEL) + 1

# Passed to Diff_size_collect.collect_fn when batching
# (presumably the multiple that padded image sizes are aligned to -- confirm in transform_img).
DOWNSAMP_MULTI = 4

BATCH_SIZE = 8
# The optimizer steps once every GRAD_ACCUMULATE batches (see train_step).
GRAD_ACCUMULATE = 1

# Training runs for epochs START_EPOCH .. EPOCH-1.
EPOCH = 50
# Size of the input images
IMAGE_SIZE = 512

# Initial seed of torch's default generator; only recorded (written to TensorBoard), not set here.
RANDOM_SEED = torch.default_generator.initial_seed()

# Dataset / DataLoader parameters
SHUFFLE = True
PIN_MEMORY = TRAIN_ON_GPU
NUM_WORKERS = 0

# Checkpoints are only written once the epoch index has reached this value.
EPOCH_SAVE = 5

# Number of randomly chosen test images visualised per call to test().
TEST_SIZE = 10
# Data directory
DATA_DIR = "../data/graduate/lip_c5_db"

TRAIN_IMG_DIR = "./training"
TRAIN_GD_TRUTH_DIR = "./training_seg"

VAL_IMG_DIR = "./validation"
VAL_GD_TRUTH_DIR = "./validation_seg"

TEST_IMG_DIR = "./test"

# Use a subset of the data: a value below 1 keeps that fraction of the dataset,
# a value above 1 keeps that many samples, exactly 1 keeps the whole dataset.
TRAIN_SUBSET = 1
VALID_SUBSET = 1
# Root directory for saved outputs
SAVE_DIR = Path('../model')

# Whether to verify that the image and segmentation directories contain exactly matching files
COMPARE_FILE_NAME = False

# Whether to clear the TensorBoard runs directory
CLEAR_TENSOR_BOARD_RUNS = True
# If not None, loaded as the model's state dict
STATE_DICT_PATH = "../model/complete/RegSeg_align_3to6_e200_b8_s512/model/model_e199.pth"
# If None, the run directory name is derived from the model and hyper-parameters;
# otherwise this path is used as the run directory.
SAVE_SUB_DIR = None
# The next epoch to run when resuming; epochs are counted from 0
START_EPOCH = 0

# Enables the torch.autocast + GradScaler path in train_step.
AMP = False


In [None]:
def get_optimizer(model: torch.nn.Module, step_per_epoch: int, epoch: int):
    """Create the Adam optimizer and its warm-up + cosine learning-rate schedule.

    The schedule is a linear warm-up followed by cosine annealing; the switch
    happens after ``warmup_steps`` scheduler steps.

    Args:
        model: model whose parameters are optimized. ``split_parameters`` divides
            them into a group that receives weight decay and a group that does not.
        step_per_epoch: number of scheduler steps taken per epoch.
        epoch: total number of epochs the schedule spans.

    Returns:
        ``(optimizer, scheduler)``.
    """
    base_lr = 1e-4
    decay = 1e-4

    warmup_steps = 3000
    warmup_start_factor = 1e-5

    # Steps left for the cosine phase once the warm-up is over.
    cosine_steps = epoch * step_per_epoch - warmup_steps
    # Index of the last step already taken; -1 when starting from epoch 0.
    resume_step = START_EPOCH * step_per_epoch - 1

    verbose = False

    decayed_params, plain_params = split_parameters(model)
    # "initial_lr" is supplied explicitly because the schedulers below are
    # constructed with a ``last_epoch`` that may differ from -1.
    param_groups = [
        {"params": decayed_params, "weight_decay": decay, "initial_lr": base_lr},
        {"params": plain_params, "initial_lr": base_lr},
    ]
    optimizer = torch.optim.Adam(param_groups, lr=base_lr)

    warm_up_scheduler = torch.optim.lr_scheduler.LinearLR(
        optimizer,
        start_factor=warmup_start_factor,
        total_iters=warmup_steps,
        last_epoch=resume_step,
        verbose=verbose)

    # Alternative polynomial decay, kept for reference:
    # lam_scheduler = torch.optim.lr_scheduler.LambdaLR(
    #     optimizer,
    #     lambda epoch: (1 - epoch/max_step)**0.9,
    #     last_epoch=last_step, verbose=verbose)

    cos_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer,
        T_max=cosine_steps,
        eta_min=1e-5,
        last_epoch=resume_step,
        verbose=verbose)

    scheduler = torch.optim.lr_scheduler.SequentialLR(
        optimizer,
        schedulers=[warm_up_scheduler, cos_scheduler],
        milestones=[warmup_steps])

    return optimizer, scheduler


In [None]:
# Model
# NOTE(review): the meaning of the third constructor argument (0) is defined in net.regseg_p.
model = RegSeg_dp(INPUT_CHANNEL, OUTPUT_CHANNEL, 0)
if STATE_DICT_PATH is not None:
    # map_location lets a checkpoint saved during a GPU run be restored on a
    # CPU-only machine (DEVICE falls back to "cpu" when CUDA is unavailable).
    # NOTE: torch.load unpickles the file -- only load checkpoints from trusted sources.
    model.load_state_dict(torch.load(STATE_DICT_PATH, map_location=DEVICE))
# else:
#     state_dict: dict[str,] = torch.load(
#         "../model/RegSeg_dp_3to13_e50_b8_s512_dv3/model/model_e45.pth")
#     model.stem.load_state_dict(
#         {k: v for k, v in state_dict.items() if k.startswith("stem")}, strict=False)
#     model.body.load_state_dict(
#         {k: v for k, v in state_dict.items() if k.startswith("body")}, strict=False)

model.to(DEVICE)

# Summarize with the configured input shape instead of a hard-coded (3, 512, 512).
summary(model, input_size=(INPUT_CHANNEL, IMAGE_SIZE, IMAGE_SIZE),
        batch_size=-1, device=DEVICE.type)


In [None]:
# Create the required directories
create_dir(SAVE_DIR)

# Default run directory: <model class>_<in>to<out>_e<epochs>_b<batch>_s<image size>
run_name = "{}_{}to{}_e{}_b{}_s{}".format(
    model.__class__.__name__,
    INPUT_CHANNEL, OUTPUT_CHANNEL,
    EPOCH, BATCH_SIZE, IMAGE_SIZE)
save_sub_dir = SAVE_DIR / run_name

use_default_dir = SAVE_SUB_DIR is None
if not use_default_dir:
    # An explicitly configured run directory overrides the derived one.
    save_sub_dir = Path(SAVE_SUB_DIR)

state_dir, model_dir, tensorboard_runs_dir = (
    save_sub_dir / sub_name for sub_name in ("state", "model", "runs"))

if use_default_dir:
    # Derived run directory: go through replace_dir
    # (presumably recreates the directory empty -- confirm in util).
    replace_dir(save_sub_dir)
    replace_dir(state_dir)
    replace_dir(model_dir)
    if CLEAR_TENSOR_BOARD_RUNS:
        replace_dir(tensorboard_runs_dir)
else:
    # Explicit run directory: only make sure the directories exist.
    for directory in (save_sub_dir, state_dir, model_dir):
        create_dir(directory)



In [None]:
# Create the TensorBoard writer and record the basic run settings as text
writer = SummaryWriter(log_dir=str(tensorboard_runs_dir))

run_description = """batch size: {}
epoch: {}
image size: {}
random seed: {}
""".format(BATCH_SIZE, EPOCH, IMAGE_SIZE, RANDOM_SEED)

writer.add_text("Basic/Model", run_description)


In [None]:
# Inspect the model structure
def write_graph():
    """Trace the model on one random batch and write its graph to TensorBoard."""
    example_batch = torch.rand(
        (BATCH_SIZE, INPUT_CHANNEL, IMAGE_SIZE, IMAGE_SIZE), device=DEVICE)
    writer.add_graph(model, example_batch, use_strict_trace=False)


In [None]:
# Datasets
# Optionally verify that each image directory matches its ground-truth directory
if COMPARE_FILE_NAME:
    # tarfile.is_tarfile() raises when given a directory, so only probe regular files.
    if Path(DATA_DIR).is_file() and tarfile.is_tarfile(DATA_DIR):
        data_path = Tar_path.make_tar_root_path(DATA_DIR)
    else:
        data_path = Path(DATA_DIR)

    for img_dir, gd_truth_dir in (
            (TRAIN_IMG_DIR, TRAIN_GD_TRUTH_DIR),
            (VAL_IMG_DIR, VAL_GD_TRUTH_DIR)):
        if not compare_dir(data_path/img_dir, data_path/gd_truth_dir):
            # Name the offending pair: the placeholders used to be left unformatted.
            raise RuntimeError("Data dir {} and {} do not match.".format(
                data_path/img_dir, data_path/gd_truth_dir))
    del data_path

# #############################
# transform, target_transform, transform_rm_rand_layer, target_transform_rm_rand_layer = get_transform(IMAGE_SIZE,
#                                                                                                      OUTPUT_CHANNEL)

# # 创建训练数据集
# train_img_dataset = Image_Dataset(
#     get_data_files(TRAIN_IMG_DIR, DATA_DIR), transform)

# train_target_img_dataset = Image_Dataset(
#     get_data_files(TRAIN_GD_TRUTH_DIR, DATA_DIR), target_transform, cv2.IMREAD_GRAYSCALE)

# train_dataset = Zip_dataset(train_img_dataset, train_target_img_dataset)
# # 验证数据集
# val_img_dataset = Image_Dataset(
#     get_data_files(VAL_IMG_DIR, DATA_DIR), transform_rm_rand_layer)

# val_target_img_dataset = Image_Dataset(
#     get_data_files(VAL_GD_TRUTH_DIR, DATA_DIR), target_transform_rm_rand_layer, cv2.IMREAD_GRAYSCALE)

# val_dataset = Zip_dataset(val_img_dataset, val_target_img_dataset)

# # 测试数据集
# test_transform = transform_rm_rand_layer.transforms
# test_transform = torchvision.transforms.Compose(test_transform)


# class Test_Img_dataset(Image_Dataset):

#     def __getitem__(self: Image_Dataset, idx: int):
#         path = self.get_image_path(idx)
#         transform_img, img = self.read_image(path)
#         return transform_img, img


# test_dataset = Test_Img_dataset(get_data_files(
#     TEST_IMG_DIR, DATA_DIR), test_transform)

# LMDB-backed datasets ######################################
(transform,
 target_transform,
 transform_rm_rand_layer,
 target_transform_rm_rand_layer) = get_pretreat_transform(OUTPUT_CHANNEL)

# Training dataset: (image, ground truth) pairs
train_img_dataset = Lmdb_dataset(DATA_DIR, TRAIN_IMG_DIR, transform)
train_target_img_dataset = Lmdb_dataset(
    DATA_DIR, TRAIN_GD_TRUTH_DIR, target_transform)
train_dataset = Zip_dataset(train_img_dataset, train_target_img_dataset)

# Validation dataset: same pairing, using the "rm_rand_layer" transforms
# (presumably the variants without random augmentation -- confirm in transform_img)
val_img_dataset = Lmdb_dataset(
    DATA_DIR, VAL_IMG_DIR, transform_rm_rand_layer)
val_target_img_dataset = Lmdb_dataset(
    DATA_DIR, VAL_GD_TRUTH_DIR, target_transform_rm_rand_layer)
val_dataset = Zip_dataset(val_img_dataset, val_target_img_dataset)

# Test dataset transform: convert BGR to RGB first, then apply the validation image transforms
bgr_to_rgb = torchvision.transforms.Lambda(
    lambda img: cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
test_transform = torchvision.transforms.Compose(
    [bgr_to_rgb] + transform_rm_rand_layer.transforms)


class Test_lmdb_dataset(Lmdb_dataset):
    """LMDB dataset whose items are the two values returned by ``read_image``."""

    def __getitem__(self: Image_Dataset, idx: int):
        """Return ``(transformed image, second value of read_image)`` for item ``idx``."""
        image_path = self.get_image_path(idx)
        # read_image is expected to return exactly two values; unpacking enforces that.
        transformed, original = self.read_image(image_path)
        return transformed, original


# Test images have no ground truth, so only the image directory is wrapped.
test_dataset = Test_lmdb_dataset(DATA_DIR, TEST_IMG_DIR, test_transform)

# Use only part of the data (handy for quickly testing the code)
def _leading_subset(dataset, amount):
    """Return the first part of ``dataset`` selected by ``amount``.

    ``amount < 1`` keeps that fraction of the samples, ``amount > 1`` keeps that
    many samples, and any other value returns ``dataset`` unchanged.
    """
    if amount < 1.:
        keep = int(len(dataset) * amount)
    elif amount > 1.:
        keep = int(amount)
    else:
        return dataset
    return Subset(dataset, list(range(keep)))


train_dataset = _leading_subset(train_dataset, TRAIN_SUBSET)
val_dataset = _leading_subset(val_dataset, VALID_SUBSET)

# Smoke-test the dataset: print the value counts of the first target and plot the first pair
for sample_image, sample_target in islice(train_dataset, 1):
    value_counts = torch.bincount(sample_target.int().flatten())
    print(value_counts)

    plt.figure()
    # Top: the input image, moved to channel-last layout for imshow
    plt.subplot(211)
    plt.imshow(sample_image.permute((1, 2, 0)))
    # Bottom: the target rendered by onehot_gird
    plt.subplot(212)
    target_grid = onehot_gird(sample_target)
    plt.imshow(target_grid.permute((1, 2, 0)))
    plt.show()


In [None]:
# Batching

# Fill value handed to the collate function for ground-truth tensors:
# all zeros, with channel 0 set to 1 when there is more than one output channel.
gd_black = torch.zeros((OUTPUT_CHANNEL, 1, 1))
if OUTPUT_CHANNEL > 1:
    gd_black[0] = 1


def _collate_batch(imgs):
    """Collate samples via Diff_size_collect, giving a fill value per sample position."""
    fill_values = {
        0: torch.tensor(norm_black_color).reshape((3, 1, 1)),  # images
        1: gd_black,                                            # ground truth
    }
    return Diff_size_collect.collect_fn(imgs, DOWNSAMP_MULTI, black=fill_values)


data_loader_para = dict(
    batch_size=BATCH_SIZE,
    shuffle=SHUFFLE,
    pin_memory=PIN_MEMORY,
    num_workers=NUM_WORKERS,
    collate_fn=_collate_batch,
)

# Both loaders share the same settings.
train_data_loader = DataLoader(train_dataset, **data_loader_para)
valid_data_loader = DataLoader(val_dataset, **data_loader_para)


In [None]:
def plot_test_img(inputs: torch.Tensor, predictions: tuple[torch.Tensor]):
    """Build one visualisation image per input sample.

    Each result stacks, top to bottom: a grid of the input next to its
    foreground-only version, followed by one ``onehot_gird`` rendering per
    prediction tensor. Narrower rows are zero-padded on the right to the
    widest row.

    Args:
        inputs: batch of input images; spatial size is read from dims 2 and 3.
        predictions: prediction batches (one per model output); each is resized
            to the input's spatial size. The last one selects the foreground.

    Returns:
        List of CPU tensors, one per sample in ``inputs``.
    """
    # Resize every prediction to the input's spatial size
    size = inputs.shape[2:4]
    predictions = [F.interpolate(i, size) for i in predictions]

    result: list[torch.Tensor] = []
    for image, *preds in zip(inputs, *predictions):
        # Foreground image: pixels labelled 0 by the last prediction are
        # painted with the normalized black colour.
        pred_label = flatten_onehot(preds[-1])
        front_seg = image.clone()
        # Create the fill colour on the image's device/dtype so the assignment
        # also works when the inputs do not live on the CPU.
        black = torch.tensor(norm_black_color).reshape((3, 1)).to(
            device=front_seg.device, dtype=front_seg.dtype)
        front_seg[:, pred_label == 0] = black

        perform_imgs = torchvision.utils.make_grid(
            [image, front_seg], normalize=True)
        output_imgs = [perform_imgs] + [onehot_gird(p) for p in preds]

        width = max(img.size(-1) for img in output_imgs)
        coll_out_imgs = []
        for img in output_imgs:
            left_len = width - img.size(-1)
            if left_len > 0:
                # Pad on the row's own device with its own channel count;
                # a CPU-only pad cannot be concatenated with a GPU row.
                complement = torch.zeros(
                    (img.size(0), img.size(1), left_len),
                    dtype=img.dtype, device=img.device)
                img = torch.cat([img, complement], -1)
            coll_out_imgs.append(img.cpu())

        # Stack the rows vertically
        result.append(torch.cat(coll_out_imgs, 1))
    return result


def writer_output(output, index, epoch):
    """Write one test visualisation to TensorBoard under ``Test/Prediction_<index>``."""
    tag = "Test/Prediction_{}".format(index)
    writer.add_image(tag, output, epoch)


def file_output(out_dir: Path, output: torch.Tensor, index, epoch):
    """Save one test visualisation as ``e<epoch>i_<index>.jpg`` inside ``out_dir``.

    ``out_dir`` may be a ``Path`` or anything ``Path`` accepts (e.g. a string).
    """
    target_dir = out_dir if isinstance(out_dir, Path) else Path(out_dir)
    file_name = "e{}i_{}.jpg".format(epoch, index)
    torchvision.utils.save_image(output, target_dir/file_name)


def test(model: torch.nn.Module, epoch: int, test_size=4, output_fn=writer_output):
    """Run the model on test images and emit one visualisation per image.

    The model is put into eval mode for the duration of the call and its
    previous train/eval mode is restored afterwards. Without this, a call
    made while the model is in training mode (e.g. before the first epoch)
    would run BatchNorm/Dropout in training behaviour even under ``no_grad``.

    Args:
        model: network to evaluate.
        epoch: step value passed on to ``output_fn``.
        test_size: number of randomly drawn test images; a negative value
            means every image in ``test_dataset``, in order.
        output_fn: callable ``(output, index, epoch)`` receiving each visualisation.
    """
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            # Choose the images (plain ints, which is what Dataset.__getitem__ expects)
            if test_size < 0:
                indices = list(range(len(test_dataset)))
            else:
                indices = torch.randint(
                    0, len(test_dataset), (test_size,)).tolist()

            for i, idx in enumerate(indices):
                test_img, _ = test_dataset[idx]

                # Turn the single image into a batch with the regular collate function
                test_imgs, = Diff_size_collect.collect_fn(
                    [(test_img,)], DOWNSAMP_MULTI,
                    black={
                        0: torch.tensor(norm_black_color).reshape((3, 1, 1))
                    })

                # Predict; normalise a single output to a one-element tuple
                pred_test_imgs = model(test_imgs.to(DEVICE))
                if not isinstance(pred_test_imgs, (tuple, list)):
                    pred_test_imgs = (pred_test_imgs,)

                pred_test_imgs = [activate(p) for p in pred_test_imgs]

                output, = plot_test_img(test_imgs, pred_test_imgs)
                output_fn(output, i, epoch)
    finally:
        # Leave the model in whatever mode the caller had it in.
        model.train(was_training)


def test_dataloader():
    """Write the first batch of the training and validation loaders to TensorBoard.

    Each sample of the batch is rendered with ``plot_test_img`` using the
    ground truth in place of a prediction.
    """
    tagged_loaders = (
        ("Test TR DataLoader/target_{}", train_data_loader),
        ("Test VL DataLoader/target_{}", valid_data_loader),
    )
    with torch.no_grad():
        for tag_template, loader in tagged_loaders:
            for images, targets in islice(loader, 1):
                rendered = plot_test_img(images, (targets,))
                for sample_index, picture in enumerate(rendered):
                    writer.add_image(tag_template.format(sample_index), picture)


# Sanity checks before training: visualise predictions of the current weights ...
test(model, epoch=START_EPOCH, test_size=TEST_SIZE)
# test(model, 0, -1, lambda *args: file_output("../tmp/test2", *args))

# ... and the first batch of each data loader.
test_dataloader()


In [None]:
# Optimizer and learning-rate scheduler (one scheduler step per training batch)
optimizer, scheduler = get_optimizer(
    model, step_per_epoch=len(train_data_loader), epoch=EPOCH)
##################################
# Loss functions


class Binary_loss(torch.nn.Module):
    """Sum of focal, SSIM and dice losses for single-channel (binary) outputs.

    The focal loss receives the raw input; SSIM and dice receive the
    sigmoid-activated input.
    """

    def __init__(self,
                 focal: torch.nn.Module,
                 ssim: torch.nn.Module,
                 dice: torch.nn.Module) -> None:
        super().__init__()
        self.focal_loss = focal
        self.ssim_loss = ssim
        self.dice_loss = dice

    def forward(self, input: torch.Tensor, target: torch.Tensor):
        """Return ``focal(input) + ssim(sigmoid(input)) + dice(sigmoid(input))``."""
        focal_term = self.focal_loss(input, target)
        activated = torch.sigmoid(input)
        rest = self.ssim_loss(activated, target) + \
            self.dice_loss(activated, target)
        # Out-of-place sum: an in-place `+=` would overwrite the tensor returned
        # by the focal loss, corrupting any reference a wrapper kept to it.
        return focal_term + rest

    @staticmethod
    def mk_loss_record(focal_loss, ssim_loss, dice_loss, step: int):
        """Wrap each component and the combined loss in a ``Loss_record``.

        Returns:
            ``(focal_record, ssim_record, dice_record, loss_record)`` where
            ``loss_record`` wraps a ``Binary_loss`` built from the three records.
        """
        focal_record = Loss_record(focal_loss, step)
        ssim_record = Loss_record(ssim_loss, step)
        dice_record = Loss_record(dice_loss, step)

        loss = Binary_loss(focal_record, ssim_record, dice_record)
        loss_record = Loss_record(loss, step)
        return focal_record, ssim_record, dice_record, loss_record


class Multi_class_loss(torch.nn.Module):
    """Sum of optional focal, SSIM and dice losses for multi-channel outputs.

    Any component may be ``None`` and is then skipped. The focal loss receives
    the raw input; SSIM and dice receive the input after softmax over dim 1.
    """

    def __init__(self,
                 focal: torch.nn.Module,
                 ssim: torch.nn.Module,
                 dice: torch.nn.Module) -> None:
        super().__init__()
        self.focal_loss = focal
        self.ssim_loss = ssim
        self.dice_loss = dice

    def forward(self, input: torch.Tensor, target: torch.Tensor):
        """Return the sum of the enabled components (``0`` if all are ``None``)."""
        loss = self.focal_loss(
            input, target) if self.focal_loss is not None else 0
        activated = torch.softmax(input, 1)
        # Accumulate out-of-place: `+=` would overwrite the tensor returned by
        # the first enabled component, corrupting any reference a wrapper kept to it.
        if self.ssim_loss is not None:
            loss = loss + self.ssim_loss(activated, target)
        if self.dice_loss is not None:
            loss = loss + self.dice_loss(activated, target)
        return loss

    @staticmethod
    def mk_loss_record(focal_loss, ssim_loss, dice_loss, step: int):
        """Wrap the enabled components and the combined loss in ``Loss_record``s.

        Returns:
            Tuple of the records of the non-``None`` components (in the order
            focal, ssim, dice) followed by the record of the combined loss.
        """
        records = [
            Loss_record(loss_fn, step) if loss_fn is not None else None
            for loss_fn in (focal_loss, ssim_loss, dice_loss)
        ]

        loss = Multi_class_loss(*records)
        loss_record = Loss_record(loss, step)
        return (*[record for record in records if record is not None], loss_record)


class Deep_sup_loss(torch.nn.Module):
    """Weighted sum of per-output losses for deep supervision.

    ``layer_loss[k]`` is applied to the k-th output counted from the END of the
    outputs passed to ``forward`` (the outputs are iterated in reverse).
    """

    def __init__(self, *layer_loss: torch.nn.Module, weights=None) -> None:
        """
        Args:
            *layer_loss: one loss callable per supervised output.
            weights: per-loss weights as a tensor or a sequence of numbers;
                defaults to all ones.
        """
        super().__init__()
        self.layer_loss = layer_loss

        if weights is None:
            weights = torch.empty((len(self.layer_loss),)).fill_(1)
        elif not isinstance(weights, torch.Tensor):
            # mk_loss_record passes a plain list; forward needs a tensor
            # (it calls `.to(device)` and multiplies element-wise).
            weights = torch.tensor(weights)
        self.weights = weights

    def forward(self, *input: torch.Tensor, target: torch.Tensor):
        """Return ``sum_k weights[k] * layer_loss[k](output_k, target_k)``.

        ``target_k`` is ``target`` resized (nearest-exact) to the spatial size
        of ``output_k`` whenever their sizes differ.
        """
        losses = []
        for loss_fn, layer_out in zip(self.layer_loss, reversed(input)):
            layer_target = target
            if layer_out.size() != target.size():
                # Always resize from the ORIGINAL target so that resampling
                # does not compound from one output to the next.
                layer_target = F.interpolate(
                    target, layer_out.size()[-2:], mode="nearest-exact")
            losses.append(loss_fn(layer_out, layer_target))

        return (torch.stack(losses) * self.weights.to(target.device)).sum()

    @staticmethod
    def mk_loss_record(focal_loss, ssim_loss, dice_loss, layer_weights: list[float], step: int):
        """Build recorded per-layer losses plus the recorded deep-supervision loss.

        Returns:
            A list with one entry per layer (the list of records returned by
            ``Multi_class_loss.mk_loss_record``) followed by a one-element list
            holding the record of the combined ``Deep_sup_loss``.
        """
        records = []
        layer_loss = []
        for _ in layer_weights:
            layer_records = Multi_class_loss.mk_loss_record(
                focal_loss, ssim_loss, dice_loss, step)
            # The last record is the layer's combined loss.
            layer_loss.append(layer_records[-1])
            records.append(list(layer_records))

        deep_sup_loss = Deep_sup_loss(*layer_loss, weights=layer_weights)
        deep_sup_record = Loss_record(deep_sup_loss, step)
        records.append([deep_sup_record])
        return records


# Choose the loss family and its components from the number of output channels.
if OUTPUT_CHANNEL == 1:
    loss_family = Binary_loss
    focal_loss = Focal_loss(alpha=0.5, mul_class=False)
    ssim_loss = SSIM(activated=True)
else:
    loss_family = Multi_class_loss
    focal_loss = Focal_loss(1.)
    # ssim_loss = SSIM(activated=True)
    ssim_loss = None  # SSIM is disabled for the multi-class setup
dice_loss = Soft_dice_loss(activated=True)

# Separate records for training and validation; the step counts are the
# respective loader lengths.
train_loss = loss_family.mk_loss_record(
    focal_loss, ssim_loss, dice_loss, len(train_data_loader))
valid_loss = loss_family.mk_loss_record(
    focal_loss, ssim_loss, dice_loss, len(valid_data_loader))

loss_writer = Loss_writer(
    *((record, "Training") for record in train_loss),
    *((record, "Validation") for record in valid_loss),
    writer=writer)

# From here on only the combined loss (last record of each tuple) is called directly.
train_loss: Loss_record = train_loss[-1]
valid_loss: Loss_record = valid_loss[-1]


In [None]:
# Monitoring of the training process
clock = Clock()
# Timer used by the train/valid steps to throttle TensorBoard reporting.
report_interval = timedelta(seconds=30)
print_timer = clock.set_timer(report_interval)


def write_now_time():
    """Write the current wall-clock time to TensorBoard under ``Basic/time``."""
    timestamp = str(datetime.now())
    writer.add_text("Basic/time", timestamp)


progress_writer = Progress_writer(writer, "Basic/Progress")

acc_record = Acc_record(len(valid_data_loader.dataset), len(LABEL)+1)
acc_writer = Acc_writer(writer, acc_record, LABEL)
# Per-epoch bookkeeping: best (lowest) mean validation loss seen so far.
# `np.inf` is used because the `np.Inf` alias no longer exists in NumPy 2.0.
valid_loss_min = np.inf


In [None]:
# Training step
scaler = GradScaler()


def train_step(batch_index: int, epoch: int, x: torch.Tensor, y: torch.Tensor):
    """Run one training batch: forward, backward, optimizer/scheduler step, logging.

    Gradients are accumulated over ``GRAD_ACCUMULATE`` batches; the optimizer
    steps on the last batch of each window, while the scheduler is advanced on
    every batch. With ``AMP`` enabled the forward pass runs under float16
    autocast and the backward/step go through ``scaler``.

    Args:
        batch_index: index of the batch within the current epoch.
        epoch: index of the current epoch.
        x: input batch.
        y: target batch.
    """
    x, y = x.to(DEVICE), y.to(DEVICE)

    # Position of this batch inside its accumulation window
    window_position = (batch_index + 1) % GRAD_ACCUMULATE
    steps_optimizer = window_position == 0
    opens_window = GRAD_ACCUMULATE == 1 or window_position == 1
    if opens_window:
        optimizer.zero_grad()

    def forward_and_loss():
        """Forward pass; a single output is normalised to a one-element tuple."""
        outputs = model(x)
        if not isinstance(outputs, (tuple, list)):
            outputs = (outputs,)
        return outputs, train_loss(*outputs, target=y)

    prediction: tuple[torch.Tensor]
    loss: torch.Tensor
    if not AMP:
        prediction, loss = forward_and_loss()
        loss.backward()
        if steps_optimizer:
            optimizer.step()
    else:
        with torch.autocast(device_type='cuda', dtype=torch.float16):
            prediction, loss = forward_and_loss()

        scaler.scale(loss).backward()
        if steps_optimizer:
            scaler.step(optimizer)
            scaler.update()

    scheduler.step()

    # Logging: detach everything from the graph first
    for tensor_ in (x, y, *prediction, loss):
        tensor_.detach_()

    batches_per_epoch = len(train_data_loader)
    is_last_batch = batch_index == batches_per_epoch - 1
    total_step = epoch * batches_per_epoch + batch_index

    # Throttled reporting: when the timer fires, and always on the last batch
    if clock.is_timeout(print_timer) or is_last_batch:
        write_now_time()
        # Gradient flow per layer
        fig = plot_grad_flow(model.named_parameters())
        writer.add_figure(
            "Training/Gradients", fig, total_step)
        # Distribution of the loss values
        loss_writer.write_histogram(total_step, "Training")
        # Progress
        progress_writer.plot(
            batch_index+1, batches_per_epoch, "Training Batch")

        writer.add_scalar("Basic/learning rate",
                          scheduler.get_last_lr()[0], total_step)

    if is_last_batch:
        # Once per epoch: histograms of every weight and, where present, its gradient
        for name, param in model.named_parameters():
            tag_name = name.replace('.', '/')
            writer.add_histogram(
                "Model weight/{}".format(tag_name),
                param.data.detach().cpu(), epoch)
            if param.grad is not None:
                writer.add_histogram(
                    "Model gradient/{}".format(tag_name),
                    param.grad.data.detach().cpu(), epoch)


def valid_step(batch_index: int, epoch: int, x: torch.Tensor, y: torch.Tensor):
    """Run one validation batch: forward, loss recording, accuracy update, logging.

    The last model output is activated and, together with the target, reduced
    to a single-channel integer map before being passed to ``acc_record``:

    * multi-channel prediction and target: argmax over dim 1 for both;
    * multi-channel prediction, single-channel target: prediction becomes
      "argmax > 0" (any non-zero class), target is cast to int32;
    * single-channel prediction: thresholded at 0.5; a multi-channel target
      becomes "argmax > 0"; the target is cast to int32.

    Args:
        batch_index: index of the batch within the current validation pass.
        epoch: index of the current epoch.
        x: input batch.
        y: target batch.
    """
    x, y = x.to(DEVICE), y.to(DEVICE)

    prediction: tuple[torch.Tensor] = model(x)
    if not isinstance(prediction, (tuple, list)):
        prediction = (prediction,)
    # Called for its recording side effect (valid_loss is a Loss_record).
    valid_loss(*prediction, target=y)

    total_step = epoch * len(valid_data_loader) + batch_index

    # Prepare prediction/target for the accuracy (IoU) computation
    prediction = activate(prediction[-1])

    if prediction.size(1) > 1 and y.size(1) > 1:
        prediction = torch.argmax(prediction, dim=1, keepdim=True)
        y = torch.argmax(y, dim=1, keepdim=True)
    elif prediction.size(1) > 1 and y.size(1) == 1:
        prediction = torch.argmax(
            prediction, dim=1, keepdim=True) > 0
        prediction = prediction.to(torch.int32)
        y = y.to(torch.int32)
    elif prediction.size(1) == 1:
        # Applies to single- AND multi-channel targets. Previously this branch
        # also required a multi-channel target, so a binary prediction with a
        # binary target was passed on as raw probabilities.
        prediction = prediction > 0.5
        prediction = prediction.to(torch.int32)

        if y.size(1) != 1:
            y = torch.argmax(y, dim=1, keepdim=True) > 0
        y = y.to(torch.int32)

    acc_record.calculate(prediction, y)

    if clock.is_timeout(print_timer) or batch_index == len(valid_data_loader)-1:
        write_now_time()
        # Distribution of the loss values
        loss_writer.write_histogram(total_step, "Validation")
        # Distribution of the per-class IoU
        acc_writer.write_histogram(total_step)

        # Progress
        progress_writer.plot(batch_index+1, len(valid_data_loader),
                             "Validation Batch")


def record_epoch_data(current_epoch: int):
    """Write the per-epoch scalars and save the model when validation improved.

    A checkpoint is written only if the mean validation loss is not worse than
    the best one recorded so far AND ``current_epoch`` has reached
    ``EPOCH_SAVE``. ``valid_loss_min`` is updated only when a checkpoint is
    written.
    """
    global valid_loss_min

    # Loss curves
    loss_writer.write_scalas("Loss", current_epoch)
    # IoU curves
    acc_writer.write_scalas(current_epoch)
    # Progress
    progress_writer.plot(current_epoch+1, EPOCH, "Epoch")

    # Checkpointing
    mean_valid_loss = valid_loss.log.mean()
    should_save = mean_valid_loss <= valid_loss_min and EPOCH_SAVE <= current_epoch
    if should_save:
        message = 'Validation loss decreased ({:.6f} --> {:.6f}).  Saving model'.format(
            valid_loss_min, mean_valid_loss)
        writer.add_text("Basic/Saving model", message, current_epoch)

        checkpoint_path = model_dir / 'model_e{}.pth'.format(current_epoch)
        torch.save(model.state_dict(), str(checkpoint_path))

        valid_loss_min = mean_valid_loss

    writer.flush()


In [None]:
# Training
# Non-interactive backend: figures are only rendered for TensorBoard from here on.
matplotlib.use("agg")
for epoch_index in range(START_EPOCH, EPOCH):

    # Training pass
    model.train()
    for batch_index, (inputs, targets) in enumerate(train_data_loader):
        train_step(batch_index, epoch_index, inputs, targets)

    # Validation pass, test visualisation and per-epoch bookkeeping
    model.eval()
    with torch.no_grad():
        for batch_index, (inputs, targets) in enumerate(valid_data_loader):
            valid_step(batch_index, epoch_index, inputs, targets)

        test(model, epoch_index, TEST_SIZE)
        record_epoch_data(epoch_index)


In [None]:
# Close the TensorBoard writer once training has finished.
writer.close()

In [None]:
# Performance profiling
def test_perf():
    """Profile ``train_step`` with ``torch.profiler`` and write TensorBoard traces.

    The profiler schedule is ``repeat`` cycles of wait / warmup / active steps;
    training stops after ``(wait + warmup + active) * repeat`` batches. Traces
    are written to ``./runs/performance``.
    """
    wait, warmup, active, repeat = 5, 5, 10, 2
    batch_budget = (wait + warmup + active) * repeat

    profiler_schedule = torch.profiler.schedule(
        wait=wait, warmup=warmup, active=active, repeat=repeat)
    trace_handler = torch.profiler.tensorboard_trace_handler(
        "./runs/performance")

    with torch.profiler.profile(
            schedule=profiler_schedule,
            on_trace_ready=trace_handler,
            record_shapes=True,
            profile_memory=True,
            with_stack=True
    ) as prof:
        model.train()

        for batch_index, (inputs, targets) in enumerate(train_data_loader):
            if batch_index >= batch_budget:
                break
            train_step(batch_index, 0, inputs, targets)
            # Tell the profiler that one step has finished
            prof.step()
