In [None]:
# Implementation follows https://blog.csdn.net/chenzhoujian_/article/details/106873451

import glob
import os
import time

import numpy as np
from matplotlib import pyplot as plt
from PIL import Image

import mindspore
import mindspore as ms
import mindspore.dataset as ds
import mindspore.nn as nn
from mindspore import ops, Tensor, context
from mindspore.train import Model
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor

from SegNet import create_dataset, SegNet

In [None]:
# Dataset root containing leftImg8bit/ and gtFine/, and where the list files go.
DATASET_ROOT = "/root/autodl-tmp"
LIST_DIR = "./dataset-list"


def _make_list_txt(split, num, root=DATASET_ROOT, out_dir=LIST_DIR):
    """Write `<out_dir>/<split>.txt` listing up to `num` image/label path pairs.

    Image paths are the files matched by `<root>/leftImg8bit/<split>/*/*`,
    sorted so that the capped subset is deterministic. Each output line is
    "<image path> <label path>\\n". The label path is derived by replacing
    "leftImg8bit" with "gtFine" throughout the path, then "gtFine.png" with
    "gtFine_labelTrainIds.png". For example,
    ".../leftImg8bit/train/x/a_leftImg8bit.png" becomes
    ".../gtFine/train/x/a_gtFine_labelTrainIds.png".

    Args:
        split: sub-directory name ("train", "test" or "val"); also the stem of
            the output file.
        num: maximum number of pairs to write; None writes every match.
        root: dataset root directory.
        out_dir: directory for the list file; created if it does not exist.
    """
    paths = sorted(glob.glob(os.path.join(root, "leftImg8bit", split, "*", "*")))
    if num is not None:
        paths = paths[:max(num, 0)]

    os.makedirs(out_dir, exist_ok=True)
    with open(os.path.join(out_dir, split + ".txt"), "w") as txt:
        for path in paths:
            label = path.replace("leftImg8bit", "gtFine").replace("gtFine.png", "gtFine_labelTrainIds.png")
            txt.write(path + " " + label + "\n")


def make_train_txt(num, root=DATASET_ROOT, out_dir=LIST_DIR):
    """Write `<out_dir>/train.txt` with up to `num` training image/label pairs."""
    _make_list_txt("train", num, root, out_dir)


def make_test_txt(num, root=DATASET_ROOT, out_dir=LIST_DIR):
    """Write `<out_dir>/test.txt` with up to `num` test image/label pairs."""
    _make_list_txt("test", num, root, out_dir)


def make_val_txt(num, root=DATASET_ROOT, out_dir=LIST_DIR):
    """Write `<out_dir>/val.txt` with up to `num` validation image/label pairs."""
    _make_list_txt("val", num, root, out_dir)


# NOTE(review): the Cityscapes fine split sizes are train 2975 / val 500 /
# test 1525; the test and val caps below look swapped — confirm intent.
train_num = 2972
test_num = 500
val_num = 1525

make_train_txt(train_num)
make_test_txt(test_num)
make_val_txt(val_num)

In [None]:
# Manually set hyperparameters
CLASS_NUM = 19
EPOCH = 20
BATCH_SIZE = 4
LR = 0.01
MOMENTUM = 0.9
# Per-class weights for the training loss, one entry per class index.
# Background on why such a weight vector is used:
# https://blog.csdn.net/fanzonghao/article/details/85263553
# The values were generated by couter.py.
CATE_WEIGHT = [
    0.11921289124514069, 0.9772031489113517, 0.2606578051907899,
    9.068186030082103, 6.772279222279968, 4.845227365263553,
    28.52810833819015, 10.758335113118157, 0.3736856892826064,
    5.133604194351756, 1.4827554927399786, 4.886283011781665,
    44.10664265269921, 0.8495922090922964, 22.22468639649049,
    25.267384685741828, 25.526398354063044, 60.29661028702076,
    14.370822828405153
]
assert len(CATE_WEIGHT) == CLASS_NUM, "CATE_WEIGHT needs exactly one weight per class"
TXT_PATH = "./dataset-list/train.txt"
# NOTE(review): this is a PyTorch-style (.pth) file name; confirm that the
# SegNet class's loader can consume it under MindSpore.
PRE_TRAINING = "vgg16_bn-6c64b313.pth"
WEIGHTS = "./weights/"

# Make sure the checkpoint output directory exists
os.makedirs(WEIGHTS, exist_ok=True)

# Load the training data
train_data = create_dataset(txt_path=TXT_PATH, batch_size=BATCH_SIZE)

In [None]:
def train(SegNet, train_data):
    """Train the network with SGD and a class-weighted cross-entropy loss.

    Runs EPOCH passes over the training data and records the loss of every
    step. The final weights are saved as SegNet_weights.ckpt in the directory
    of WEIGHTS, and the loss curve is plotted to Loss.svg.

    Args:
        SegNet: the network instance to train. The name shadows the SegNet
            class imported at the top of the notebook.
        train_data: either a mindspore.dataset.Dataset, which is used as-is,
            or a source accepted by ds.GeneratorDataset that yields
            (image, label) samples. The latter is shuffled and batched with
            BATCH_SIZE.

    Returns:
        List of per-step training losses as Python floats.
    """
    # NOTE(review): `load_weights` is not a mindspore.nn.Cell method, and
    # PRE_TRAINING names a PyTorch (.pth) file — confirm that the SegNet class
    # implements this loader.
    SegNet.load_weights(PRE_TRAINING)

    # create_dataset(...) receives batch_size, so it likely returns an already
    # batched dataset. Wrapping and batching that again would add a second
    # batch axis, so only wrap raw sources.
    if isinstance(train_data, ds.Dataset):
        train_dataset = train_data
    else:
        train_dataset = ds.GeneratorDataset(
            source=train_data, column_names=["image", "label"], shuffle=True
        ).batch(BATCH_SIZE)

    optimizer = nn.SGD(SegNet.trainable_params(), learning_rate=LR, momentum=MOMENTUM)

    # Class-weighted cross-entropy on (N, C, H, W) logits vs (N, H, W) labels.
    # CATE_WEIGHT is now actually applied. 255 is the "ignore" value in
    # Cityscapes labelTrainIds — confirm for other data.
    weight = Tensor(np.array(CATE_WEIGHT).astype(np.float32))
    loss_func = nn.CrossEntropyLoss(weight=weight, ignore_index=255, reduction="mean")

    def forward_fn(images, labels):
        return loss_func(SegNet(images), labels)

    # MindSpore has no loss.backward()/optimizer.step(). Differentiate
    # forward_fn with respect to the optimizer's parameters instead.
    grad_fn = ms.value_and_grad(forward_fn, None, optimizer.parameters)

    losses = []
    SegNet.set_train(True)
    print("Start Training...")
    for epoch in range(EPOCH):
        for step, data in enumerate(train_dataset.create_dict_iterator()):
            b_x = data["image"]
            # Collapse any extra channel axis to (N, H, W) int32 class ids.
            # -1 keeps a short final batch working.
            b_y = data["label"]
            b_y = ops.Reshape()(b_y, (-1,) + tuple(b_y.shape[-2:])).astype(ms.int32)

            loss, grads = grad_fn(b_x, b_y)
            optimizer(grads)

            loss_value = float(loss.asnumpy())
            losses.append(loss_value)
            if step % 100 == 0:
                print(f"Epoch: {epoch} || Step: {step} || Loss: {loss_value:.4f}")

    # dirname() gives the weights directory whether WEIGHTS is still the
    # "./weights/" directory or was reassigned to a checkpoint path by a
    # later cell.
    weights_dir = os.path.dirname(WEIGHTS) or "."
    os.makedirs(weights_dir, exist_ok=True)
    save_path = os.path.join(weights_dir, "SegNet_weights.ckpt")
    ms.save_checkpoint(SegNet, save_path)
    print(f"Model saved to {save_path}")

    # Plot the loss curve
    fig, ax = plt.subplots(figsize=(10, 6))
    ax.plot(losses, label="Training Loss")
    ax.set_xlabel("Steps")
    ax.set_ylabel("Loss")
    ax.set_title("Training Loss Curve")
    ax.legend()
    ax.grid()
    fig.savefig("Loss.svg")
    plt.show()

    return losses

# Build the model and train it. Arguments are presumably (in_channels,
# num_classes) — confirm against the SegNet class.
# NOTE: this rebinds the name SegNet from the class to the instance, and later
# cells rely on that. Re-run the import cell before re-running this one.
SegNet = SegNet(3, CLASS_NUM)
losses = train(SegNet, train_data)

In [None]:
# Color palette: COLORS[c] is the RGB color drawn for class index c
COLORS = [
    [128, 64, 128],  # 'road'
    [244, 35, 232],  # 'sidewalk'
    [70, 70, 70],    # 'building'
    [102, 102, 156], # 'wall'
    [190, 153, 153], # 'fence'
    [153, 153, 153], # 'pole'
    [250, 170, 30],  # 'traffic light'
    [220, 220, 0],   # 'traffic sign'
    [107, 142, 35],  # 'vegetation'
    [152, 251, 152], # 'terrain'
    [70, 130, 180],  # 'sky'
    [220, 20, 60],   # 'person'
    [255, 0, 0],     # 'rider'
    [0, 0, 142],     # 'car'
    [0, 0, 70],      # 'truck'
    [0, 60, 100],    # 'bus'
    [0, 80, 100],    # 'train'
    [0, 0, 230],     # 'motorcycle'
    [119, 11, 32]    # 'bicycle'
]

# Nonzero: load the trained checkpoint from WEIGHTS before running inference.
MODE = 1
SAMPLES = "samples/"  # directory of input images for qualitative testing
OUTPUTS = "outputs/"  # directory where the blended visualizations are written
# Must match the MindSpore checkpoint written by train()
# (./weights/SegNet_weights.ckpt). The previous ".pth" name pointed at a file
# that is never created.
WEIGHTS = "weights/SegNet_weights.ckpt"

# Overlay blend factor: weight of the segmentation colors
# (1 - alpha is the weight of the source image).
alpha = 0.6

In [None]:
def test(SegNet):
    """Segment every image in SAMPLES and save color overlays to OUTPUTS.

    Processing steps for each file in SAMPLES:
    1. The image is converted to RGB, resized to 224x224 and scaled to [0, 1].
    2. It is run through the network.
    3. The logits are bilinearly resized to 1024x2048 before taking the
       per-pixel argmax.
    4. The class map is colorized with COLORS and alpha-blended (module-level
       `alpha`) onto the source image resized to 2048x1024.
    5. The result is displayed and saved under the same file name.

    Args:
        SegNet: the network instance to run. The name shadows the SegNet
            class imported at the top of the notebook.
    """
    if MODE:
        # Load the trained weights
        ms.load_checkpoint(WEIGHTS, net=SegNet)
    SegNet.set_train(False)

    out_h, out_w = 1024, 2048
    os.makedirs(OUTPUTS, exist_ok=True)
    resize_logits = ops.ResizeBilinear((out_h, out_w), align_corners=False)
    argmax = ops.Argmax(axis=1)

    # Sorted for a deterministic processing order (os.listdir order is arbitrary)
    for path in sorted(os.listdir(SAMPLES)):
        # Load and preprocess the image.
        # NOTE(review): this preprocessing (224x224, /255, no mean/std) must
        # match create_dataset — confirm.
        image_src = Image.open(os.path.join(SAMPLES, path)).convert("RGB")
        image = np.array(image_src.resize((224, 224))) / 255.0
        image = np.transpose(image, (2, 0, 1))  # HWC -> CHW
        image = Tensor(image[np.newaxis, ...], dtype=ms.float32)  # add batch axis

        # Inference: upsample logits to the output size, then take class ids
        output = resize_logits(SegNet(image))
        output = argmax(output).asnumpy().squeeze()  # drop the batch axis

        # Colorize the class map
        image_seg = np.zeros((out_h, out_w, 3), dtype=np.uint8)
        for c in range(CLASS_NUM):
            image_seg[output == c] = COLORS[c]

        # Blend with the source image, resized to the segmentation size.
        # Uses the configured module-level `alpha` instead of a hard-coded
        # local value that silently overrode it.
        image_src_array = np.array(image_src.resize((out_w, out_h)))
        image_blend = (image_src_array * (1 - alpha) + image_seg * alpha).astype(np.uint8)

        # Show and save the result
        result = Image.fromarray(image_blend)
        result.show()
        result.save(os.path.join(OUTPUTS, path))

# Run inference on the sample images
test(SegNet)

In [None]:
VAL_PATHS = "./dataset-list/val.txt"


def _image_miou(predict, label, class_num):
    """Return the mean IoU of one predicted class map against its label map.

    Pixels whose label is outside [0, class_num) (e.g. a 255 "ignore" value)
    are excluded from both prediction and label. A class that appears in
    neither map is skipped rather than scored as 0. If no class appears at
    all, NaN is returned.
    """
    valid = label < class_num
    ious = []
    for c in range(class_num):
        pred_c = (predict == c) & valid
        label_c = label == c
        union = np.count_nonzero(pred_c | label_c)
        if union:
            ious.append(np.count_nonzero(pred_c & label_c) / union)
    return float(np.mean(ious)) if ious else float("nan")


def calculate_mIoU(val_paths, model, class_num):
    """Compute a per-image mean IoU over the image/label pairs in val_paths.

    Each line of val_paths holds "<image path> <label path>"; blank or
    malformed lines are skipped. For each pair, the image is resized to
    224x224, scaled to [0, 1] and run through the model. The predicted class
    map is then resized to the label's resolution with nearest-neighbour
    sampling, because interpolating class ids would invent labels at region
    boundaries. Every class id 0..class_num-1 is scored.

    Args:
        val_paths: path of the list file.
        model: network instance to evaluate.
        class_num: number of classes.

    Returns:
        List of per-image mIoU values (NaN for an image with no scorable class).
    """
    model.set_train(False)
    argmax = ops.Argmax(axis=1)
    mIoU = []
    with open(val_paths, "r") as list_file:
        for index, line in enumerate(list_file):
            fields = line.split()
            if len(fields) < 2:
                continue
            image_path, label_path = fields[0], fields[1]

            # Load the image with PIL (already imported; OpenCV never was).
            # NOTE(review): RGB order matches test(); confirm it matches the
            # channel order and normalization used by create_dataset.
            image = Image.open(image_path).convert("RGB").resize((224, 224))
            image = np.array(image) / 255.0
            image = np.transpose(image, (2, 0, 1))  # HWC -> CHW
            image = Tensor(image[np.newaxis, ...], dtype=ms.float32)  # add batch axis

            # Inference -> (224, 224) class-id map
            output = argmax(model(image)).asnumpy().squeeze()

            # Load the label as single-channel class ids and match its size
            label = np.array(Image.open(label_path).convert("L"))
            height, width = label.shape
            predict = np.array(
                Image.fromarray(output.astype(np.uint8)).resize((width, height), Image.NEAREST)
            )

            mIoU.append(_image_miou(predict, label, class_num))
            print(f"miou_{index}: {mIoU[-1]:.4f}")
    return mIoU

mIoU = calculate_mIoU(VAL_PATHS, SegNet, CLASS_NUM)

result_file = "result.txt"
# nanmean skips images that had no scorable class.
mean_mIoU = np.nanmean(mIoU)
print("\n")
print(f"mIoU: {mean_mIoU:.4f}")

# Explicit UTF-8: the report contains non-ASCII text.
with open(result_file, "a", encoding="utf-8") as file:
    file.write(f"评价日期：{time.asctime(time.localtime(time.time()))}\n")
    file.write(f"使用的权重：{WEIGHTS}\n")
    file.write(f"mIoU: {mean_mIoU:.4f}\n")