In [2]:
# Mount Google Drive at /content/drive so the project files stored on Drive are reachable from this runtime.
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [3]:
# Enter the folder of the project to train. Relative paths used later in the notebook
# ("../Data", "./saved_models", "evaluate_data.xls", "./testResultSeg") resolve against it.
%cd /content/drive/My Drive/Colab Notebooks/ML2020_spring_crack_detection/CrackU-Net-mean-squared-error

/content/drive/My Drive/Colab Notebooks/ML2020_spring_crack_detection/CrackU-Net-mean-squared-error


In [None]:
# TPU setup (Colab): verify that a TPU runtime is attached and name the PyTorch/XLA wheels to install.
import os

# Use .get(): indexing os.environ raises a bare KeyError when the variable is absent,
# which would hide the explanatory assertion message below.
assert os.environ.get('COLAB_TPU_ADDR'), 'Make sure to select TPU from Edit > Notebook settings > Hardware accelerator'

# Google Cloud Storage bucket and wheel file names of the TPU-compatible builds
# (the cp36 tag means the wheels target CPython 3.6).
DIST_BUCKET="gs://tpu-pytorch/wheels"
TORCH_WHEEL="torch-1.15-cp36-cp36m-linux_x86_64.whl"
TORCH_XLA_WHEEL="torch_xla-1.15-cp36-cp36m-linux_x86_64.whl"
TORCHVISION_WHEEL="torchvision-0.3.0-cp36-cp36m-linux_x86_64.whl"

# Install Colab TPU compat PyTorch/TPU wheels and dependencies.
# Remove the preinstalled torch/torchvision, copy the three wheels named above from
# DIST_BUCKET into the current working directory, then pip-install them. IPython
# substitutes $DIST_BUCKET, $TORCH_WHEEL, ... with the Python variables of the same name.
!pip uninstall -y torch torchvision
!gsutil cp "$DIST_BUCKET/$TORCH_WHEEL" .
!gsutil cp "$DIST_BUCKET/$TORCH_XLA_WHEEL" .
!gsutil cp "$DIST_BUCKET/$TORCHVISION_WHEEL" .
!pip install "$TORCH_WHEEL"
!pip install "$TORCH_XLA_WHEEL"
!pip install "$TORCHVISION_WHEEL"
# libomp5: OpenMP runtime library, presumably a native dependency of the torch_xla wheel — confirm.
!sudo apt-get install libomp5

In [None]:
# TPU相关-------------------TPU相关---------------------TPU相关----------
# 导入相关的库
# import torch
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.distributed.data_parallel as dp


# 后续对于torchTensor类型的量可以使用 .to(xm.xla_device()) 将torch.FloatTensor转换成 xla tensor

In [4]:
# xlutils provides xlutils.copy.copy, which write_excel_xls_append uses to turn a
# read-only xlrd workbook into a writable xlwt workbook.
!pip install xlutils



In [0]:
from models_crack_unet import SegmentNet, weights_init_normal
from dataset_crack_unet import CFDDataset

import torch.nn as nn
import torch

from torchvision import datasets
from torchvision.utils import save_image
import torchvision.transforms as transforms
from torch.autograd import Variable
from torch.utils.data import DataLoader

import os
import sys
import argparse
import time
import PIL.Image as Image

import numpy as np

import xlwt
import xlrd
from xlutils.copy import copy

In [None]:
# TPU: look up the XLA (TPU) devices available to this runtime.
# Number of TPU cores to request. It was never defined in the notebook, which made this cell
# raise NameError. 8 matches the --tpu_num default in get_arguments (a Colab TPU has 8 devices).
num_cores = 8
devices = xm.get_xla_supported_devices(max_devices=num_cores) if num_cores != 0 else []

print("Devices: {}".format(devices))

In [6]:
# In a notebook the options are parsed with parse_args(args=[]). sys.argv holds the Jupyter
# kernel's own launch arguments (e.g. "-f <connection file>"), and parsing those as script
# options makes argparse exit with an error.

def _str2bool(value):
    """Parse a boolean option value such as "true"/"false", "yes"/"no" or "1"/"0".

    ``type=bool`` must not be used with argparse: ``bool("False")`` is True because every
    non-empty string is truthy, so a flag declared that way could never be turned off.

    Raises:
        argparse.ArgumentTypeError: if the value is not a recognised boolean spelling.
    """
    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in ("1", "true", "t", "yes", "y"):
        return True
    if text in ("0", "false", "f", "no", "n"):
        return False
    raise argparse.ArgumentTypeError("expected a boolean value, got %r" % (value,))


def get_arguments(args=None):
    """Build the training options.

    Args:
        args: optional list of command-line style tokens, e.g. ``["--batch_size", "4"]``.
            ``None`` (the default) parses an empty list, so every option keeps its default.

    Returns:
        argparse.Namespace holding the options declared below.
    """
    parser = argparse.ArgumentParser()

    parser.add_argument("--tpu", type=_str2bool, default=True, help="tpu")
    parser.add_argument("--tpu_num", type=int, default=8, help="number of TPU cores")  # a Colab TPU has 8 devices
    parser.add_argument("--worker_num", type=int, default=4, help="number of input workers")
    parser.add_argument("--batch_size", type=int, default=8, help="batch size of input")
    parser.add_argument("--lr", type=float, default=0.0005, help="adam: learning rate")
    parser.add_argument("--b1", type=float, default=0.5, help="adam: decay of first order momentum of gradient")
    parser.add_argument("--b2", type=float, default=0.999, help="adam: decay of second order momentum of gradient")

    parser.add_argument("--begin_epoch", type=int, default=2, help="begin_epoch")
    parser.add_argument("--end_epoch", type=int, default=101, help="end_epoch")

    parser.add_argument("--need_test", type=_str2bool, default=True, help="need to test")
    parser.add_argument("--test_interval", type=int, default=2, help="interval of test")
    parser.add_argument("--need_save", type=_str2bool, default=True, help="need to save")
    parser.add_argument("--save_interval", type=int, default=2, help="interval of save weights")

    parser.add_argument("--img_width", type=int, default=480, help="size of image width")
    parser.add_argument("--img_height", type=int, default=320, help="size of image height")

    return parser.parse_args(args=[] if args is None else list(args))


opt = get_arguments()

print(opt)

Namespace(b1=0.5, b2=0.999, batch_size=4, begin_epoch=2, cuda=True, end_epoch=101, gpu_num=1, img_height=320, img_width=480, lr=0.0005, need_save=True, need_test=True, save_interval=2, test_interval=2, worker_num=0)


In [0]:
# Dataset root, relative to the current working directory (the project folder).
# CFDDataset later receives subFold="CFD" (train) and subFold="CFD/cfd_TEST" (test),
# presumably resolved under this root — confirm in dataset_crack_unet.
dataSetRoot = "../Data" 

In [0]:
# Loss: mean squared error between the predicted segmentation map and the ground-truth mask.
criterion_segment = nn.MSELoss()

# Build the segmentation network (SegmentNet from models_crack_unet).
segment_net = SegmentNet(init_weights=True)

In [0]:
# Choose the training device, then resume from a checkpoint or initialise the weights.
if opt.tpu:
    # Single-device XLA training: train_my_model / test_my_model move every batch with
    # .to(xm.xla_device()), so the network has to live on that same device.
    # (The former dp.DataParallel(SegmentNet, ...) call discarded the network built above,
    # and it also wrapped the loss object as if it were a network.)
    segment_net = segment_net.to(xm.xla_device())

if opt.begin_epoch != 0:
    # Resume from the checkpoint written at the end of epoch `begin_epoch`
    # (save_my_model names its files segment_net_<epoch>.pth).
    segment_net.load_state_dict(torch.load("./saved_models/segment_net_%d.pth" % (opt.begin_epoch)))
else:
    # Training from scratch: initialise the weights with weights_init_normal.
    segment_net.apply(weights_init_normal)

In [0]:
# Adam optimizer over every parameter of the segmentation network
# (learning rate and beta coefficients come from the parsed options).
optimizer_seg = torch.optim.Adam(
    params=segment_net.parameters(),
    lr=opt.lr,
    betas=(opt.b1, opt.b2),
)

In [0]:
# Preprocessing applied to the input images and the ground-truth masks before training.
transforms_ = transforms.Compose([
    # Input images: resize to (height, width) with bicubic interpolation, then convert to a tensor.
    transforms.Resize((opt.img_height, opt.img_width), interpolation=Image.BICUBIC),
    transforms.ToTensor()
])

transforms_mask = transforms.Compose([
    # Masks: nearest-neighbour resizing keeps the labels binary. The default (bilinear)
    # interpolation blends crack and background pixels into in-between values, and those
    # never compare equal to 0 or 1 in the pixel-accuracy and evaluate_metric checks.
    transforms.Resize((opt.img_height, opt.img_width), interpolation=Image.NEAREST),
    transforms.ToTensor()
])

In [0]:
# Training and test data loaders for the CFD crack dataset under dataSetRoot.
# Training batches are shuffled. The test set is read one image at a time and in order,
# so the result images saved by test_my_model are numbered consistently across epochs.
trainCFDloader = DataLoader(
    dataset=CFDDataset(
        dataSetRoot,
        transforms_=transforms_,
        transforms_mask=transforms_mask,
        subFold="CFD",
        isTrain=True,
    ),
    batch_size=opt.batch_size,
    shuffle=True,
    num_workers=opt.worker_num,
)

# NOTE(review): the test set is also built with isTrain=True, presumably so that CFDDataset
# returns the masks needed for evaluation — confirm against dataset_crack_unet.CFDDataset.
testloader = DataLoader(
    dataset=CFDDataset(
        dataSetRoot,
        transforms_=transforms_,
        transforms_mask=transforms_mask,
        subFold="CFD/cfd_TEST",
        isTrain=True,
    ),
    batch_size=1,
    shuffle=False,
    num_workers=opt.worker_num,
)

In [0]:
# Threshold the network output: values below the threshold become 0, values at or above it
# become 1. In the saved images a pixel of 1 is white and a pixel of 0 is black.
def data_threshold(data, threshold):
    """Binarize predictions: 0.0 where ``data < threshold``, 1.0 where ``data >= threshold``.

    Args:
        data: tensor of predicted values (any shape; callers pass flattened maps). A plain
            sequence of numbers is also accepted and converted with ``torch.as_tensor``.
        threshold: scalar cut-off.

    Returns:
        float32 tensor of 0.0/1.0 values with the same shape and device as ``data``.
    """
    # A single vectorized comparison replaces the former per-element ``i // threshold`` loop.
    # That loop ran one Python iteration per pixel, and it produced -1 for negative values
    # and 2 or more for values >= 2 * threshold instead of 0/1.
    return (torch.as_tensor(data) >= threshold).float()

In [0]:
# Compute accuracy, precision, recall and F1 of a prediction.
# `mask` is the ground truth, `data` is the prediction.
def evaluate_metric(mask, data, zero_division=0.0):
    """Pixel-wise accuracy, precision, recall and F1 of a binary prediction.

    A pixel counts as positive when its value equals 1; every other value is negative.

    Args:
        mask: ground-truth values (tensor, or anything ``torch.as_tensor`` accepts).
        data: predicted values with the same number of elements as ``mask``
            (on the same device when both are tensors).
        zero_division: value returned for a metric whose denominator is 0. Examples are
            precision when nothing is predicted positive, or recall when the mask has no
            positives. Such cases previously raised ZeroDivisionError.

    Returns:
        Tuple ``(accuracy, precision, recall, F1)``.
    """
    truth = torch.as_tensor(mask).flatten() == 1
    pred = torch.as_tensor(data).flatten() == 1

    # Vectorized confusion-matrix counts. .item() converts each count to a Python int,
    # so there is one device sync per count rather than one per pixel.
    count_TP = (truth & pred).sum().item()
    count_FN = (truth & ~pred).sum().item()
    count_FP = (~truth & pred).sum().item()
    count_TN = (~truth & ~pred).sum().item()

    def _ratio(num, den):
        # Guarded division: `zero_division` instead of ZeroDivisionError.
        return num / den if den else zero_division

    count = count_TP + count_FN + count_FP + count_TN

    accuracy = _ratio(count_TP + count_TN, count)
    precision = _ratio(count_TP, count_TP + count_FP)
    recall = _ratio(count_TP, count_TP + count_FN)
    F1 = _ratio(2 * count_TP, 2 * count_TP + count_FP + count_FN)

    return accuracy, precision, recall, F1

In [0]:
# Excel workbook that accumulates per-epoch statistics. Sheet order matters, because
# write_excel_xls_append addresses the sheets by index:
#   sheet1 (index 0): training epoch, loss, accuracy
#   sheet2 (index 1): test     epoch, loss, accuracy
#   sheet3 (index 2): test     epoch, accuracy, precision, recall, F1
_EVALUATE_XLS_SHEETS = [
    ('sheet1', ['epoch', 'loss', 'accuracy']),
    ('sheet2', ['epoch', 'loss', 'accuracy']),
    ('sheet3', ['epoch', 'accuracy', 'precision', 'recall', 'F1']),
]

# Create a fresh workbook when training starts from scratch. Also create it when resuming
# but the file is missing; otherwise every later append would fail to open it.
if opt.begin_epoch == 0 or not os.path.exists('evaluate_data.xls'):
    workbook = xlwt.Workbook(encoding='utf-8')
    for sheet_name, header in _EVALUATE_XLS_SHEETS:
        worksheet = workbook.add_sheet(sheet_name)
        for col, title in enumerate(header):
            worksheet.write(0, col, title)
    workbook.save('evaluate_data.xls')

def write_excel_xls_append(path, value, sheet_num):
    """Append ``value`` as a new row below the existing rows of sheet ``sheet_num`` in the .xls at ``path``.

    xlrd can only read and xlwt can only write. The workbook is therefore opened with xlrd,
    converted into a writable xlwt workbook with xlutils.copy.copy, extended by one row,
    and saved back over the same path.
    """
    n_cols = len(value)  # number of cells to write in the new row

    book_read = xlrd.open_workbook(path)
    sheet_read = book_read.sheet_by_name(book_read.sheet_names()[sheet_num])
    next_row = sheet_read.nrows  # first empty row after the existing data

    book_write = copy(book_read)
    sheet_write = book_write.get_sheet(sheet_num)
    for col in range(n_cols):
        sheet_write.write(next_row, col, value[col])

    book_write.save(path)
    print("xls格式表格[追加]写入数据成功！")

In [2]:
# Training function for one epoch.
def train_my_model(epoch, iter_data, lenNum):
    """Train ``segment_net`` for one epoch and append the epoch's mean loss/accuracy to the xls log.

    Uses the notebook globals ``segment_net``, ``criterion_segment``, ``optimizer_seg`` and ``opt``.

    Args:
        epoch: current epoch number (used for logging only).
        iter_data: iterator (or iterable) over training batches. Each batch is a dict with
            "img" and "mask" tensors.
        lenNum: number of batches to draw from ``iter_data``.
    """
    device = xm.xla_device()
    # iter() returns an iterator unchanged, so a DataLoader can also be passed directly.
    # (The body previously read the undefined name `iterdata` instead of this parameter.)
    batches = iter(iter_data)

    # Per-epoch sums used to report mean loss and mean accuracy.
    train_loss_sum, train_acc_sum, batch_count = 0.0, 0.0, 0.0

    for _ in range(lenNum):
        batchData = next(batches)

        img = batchData["img"].to(device)
        mask = batchData["mask"].to(device)

        optimizer_seg.zero_grad()

        seg = segment_net(img)["seg"]

        # Training loss.
        loss_seg = criterion_segment(seg, mask)
        loss_seg.backward()
        optimizer_seg.step()
        # XLA tensors are lazy: mark_step() cuts the traced graph here and runs this training
        # step, so the pending graph does not keep growing across iterations.
        xm.mark_step()

        loss_value = loss_seg.item()
        train_loss_sum += loss_value

        # Pixel accuracy: fraction of pixels whose thresholded prediction equals the mask.
        net_seg = data_threshold(seg.detach().flatten(), 0.6)  # prediction
        mask_seg = mask.flatten()                              # ground truth
        right_seg = torch.eq(net_seg, mask_seg).sum().float().item()
        batch_acc = right_seg / float(mask_seg.numel())
        train_acc_sum += batch_acc

        batch_count += 1

        # Per-batch progress.
        print("[Epoch:{0}],[batch_count:{1}],[loss:{2:.6f}],[accuracy:{3:.6f}]".format(epoch, batch_count, loss_value, batch_acc))

    # Epoch means; the guard avoids a division by zero when lenNum == 0.
    denom = batch_count if batch_count else 1.0
    mean_loss = train_loss_sum / denom
    mean_acc = train_acc_sum / denom
    print("[Epoch {0}/{1}], [loss:{2:.6f}], accuracy:{3:.6f}]".format(epoch, opt.end_epoch, mean_loss, mean_acc))

    # Append (epoch, loss, accuracy) to the training sheet (index 0) of the xls log.
    print("------------------------------------------------------------------------------------------")
    print("开始写入训练过程的epoch, loss, accuracy")
    train_xls_value = [epoch, mean_loss, mean_acc]
    write_excel_xls_append("evaluate_data.xls", train_xls_value, 0)
    print("------------------------------------------------------------------------------------------")

In [6]:
# Checkpoint-saving function.
def save_my_model(epoch, segment_net, need_save, save_interval):
    """Save ``segment_net.state_dict()`` to ./saved_models/segment_net_<epoch>.pth periodically.

    A checkpoint is written only when ``need_save`` is true, ``save_interval`` is positive,
    ``epoch`` is a multiple of ``save_interval`` and ``epoch >= save_interval``; otherwise
    the function does nothing.
    """
    # The save used to sit outside this condition. It therefore ran every epoch and raised
    # UnboundLocalError (save_path_str undefined) whenever the condition was false.
    if not (need_save and save_interval > 0 and epoch % save_interval == 0 and epoch >= save_interval):
        return

    save_path_str = "./saved_models"
    os.makedirs(save_path_str, exist_ok=True)

    torch.save(segment_net.state_dict(), "%s/segment_net_%d.pth" % (save_path_str, epoch))
    print("------------------------------------------------------------------------------------------")
    print("save weights ! epoch = %d" % epoch)
    print("------------------------------------------------------------------------------------------")

In [8]:
# 定义模型的测试函数
def test_my_model(epoch, segment_net, criterion_segment, testloader, need_test, test_interval):
    # -----------------------------------------------------------------------------
    # 对模型进行测试，并保存结果
    if need_test and epoch % test_interval == 0 and epoch >= test_interval:
        test_loss_sum, test_acc_sum, batch_count = 0.0, 0.0, 0.0
        result_evaluate_epoch = np.array([0.0, 0.0, 0.0, 0.0])

    for i, testBatch in enumerate(testloader):
        # imgTest = testBatch["img"].cuda()
        imgTest = testBatch["img"].to(xm.xla_device())
        t1 = time.time()
        rstTest = segment_net(imgTest)
        t2 = time.time()

        # 计算测试过程的损失loss
        # mask = testBatch["mask"].cuda()
        mask = testBatch["mask"].to(xm.xla_device())
        loss_test = criterion_segment(rstTest["seg"], mask)

        test_loss_sum += loss_test.item()

        # 计算测试过程的accuracy
        net_seg = data_threshold(rstTest["seg"].clone().flatten(), 0.6)  # 预测值
        mask_seg = mask.clone().flatten()                  # 真实值

        # 对每个像素点的值进行比较，相同的点计入right_seg 
        right_seg = torch.eq(net_seg, mask_seg).sum().float().item()
        total_num = float(mask.clone().flatten().size()[0])

        batch_acc = right_seg/total_num
        test_acc_sum += batch_acc

        batch_count += 1

        # 对一个batch测试结果进行综合评估，并进行累加，方便后续保存
        result_evaluate_batch = np.array(list(evaluate_metric(mask_seg, net_seg)))
        result_evaluate_epoch += result_evaluate_batch

        # 对保存的图片进行阈值化处理
        seg_shape = rstTest["seg"].data.shape
        segTest_flatten = data_threshold(rstTest["seg"].flatten(), 0.6)
        segTest = segTest_flatten.reshape(seg_shape[0], seg_shape[1], seg_shape[2], seg_shape[3])

        # 建立文件的保存路径
        save_path_str = "./testResultSeg/epoch_%d"%(epoch)
        if os.path.exists(save_path_str) == False:
            os.makedirs(save_path_str, exist_ok=True)

        # 输出文件的保存信息
        print("processing image NO %d, time comsuption %fs"%(i, t2 - t1))
        save_image(imgTest.data, "%s/img_%d.jpg"% (save_path_str, i))
        save_image(segTest.data, "%s/img_%d_seg.jpg"% (save_path_str, i))

    # 将上述测试过程的评估参数acc,precision,recall和F1分数进行保存
    print("------------------------------------------------------------------------------------------")
    print("开始写入评估参数")
    result_evaluate_epoch = result_evaluate_epoch/np.array([batch_count])
    test_xls_metric = [epoch] + list(result_evaluate_epoch)
    write_excel_xls_append("evaluate_data.xls", test_xls_metric, 2)
    print("------------------------------------------------------------------------------------------")


    # 输出测试过程每个epoch平均的loss和accuracy
    print("------------------------------------------------------------------------------------------")
    print("[Epoch {0}/{1}], [loss:{2:.6f}], accuracy:{3:.6f}]".format(epoch, opt.end_epoch, test_loss_sum/batch_count, test_acc_sum/batch_count))
    print("------------------------------------------------------------------------------------------")

    # 将上述epoch, loss, accuracy数据写入xls文件
    print("------------------------------------------------------------------------------------------")
    print("开始写入测试过程的epoch, loss, accuracy")
    test_xls_value = [epoch, test_loss_sum/batch_count, test_acc_sum/batch_count]
    write_excel_xls_append("evaluate_data.xls", test_xls_value, 1)
    print("------------------------------------------------------------------------------------------")

In [None]:
# Main loop: train, checkpoint and test once per epoch.
# NOTE(review): when resuming, segment_net_<begin_epoch>.pth was saved after epoch begin_epoch
# finished, so this loop re-runs that epoch and re-appends its xls rows. Starting at
# begin_epoch + 1 may be what is intended — confirm.
for epoch in range(opt.begin_epoch, opt.end_epoch):
    iterCFD = iter(trainCFDloader)
    lenNum = len(trainCFDloader)

    segment_net.train()

    # Train for one epoch.
    train_my_model(epoch=epoch, iter_data=iterCFD, lenNum=lenNum)

    # Save a checkpoint every opt.save_interval epochs (was `opy.save_interval`: NameError).
    save_my_model(epoch=epoch, segment_net=segment_net, need_save=opt.need_save, save_interval=opt.save_interval)

    # Test every opt.test_interval epochs. test_interval previously received opt.need_test
    # (True == 1), which made testing run every epoch.
    test_my_model(epoch=epoch, segment_net=segment_net, criterion_segment=criterion_segment,
                  testloader=testloader, need_test=opt.need_test, test_interval=opt.test_interval)
    

In [9]:
# Legacy inline version of the training / checkpoint / test loop, kept for reference only.
# It is a bare string literal, so nothing inside it executes; the cell merely displays the
# string. The loop that actually runs is the previous cell, which calls train_my_model,
# save_my_model and test_my_model.
'''
for epoch in range(opt.begin_epoch, opt.end_epoch):

  iterCFD = trainCFDloader.__iter__()

  lenNum = len(trainCFDloader)

  segment_net.train()

  # -----------------------------------------------------------------------------
  # 开始训练
  # 记录每一个epoch的总损失和总精度
  train_loss_sum, train_acc_sum, batch_count = 0.0, 0.0, 0.0

  for i in range(0, lenNum):
      
    batchData = iterCFD.__next__()
    
    # img = batchData["img"].cuda()
    # mask = batchData["mask"].cuda()
    img = batchData["img"].to(xm.xla_device())
    mask = batchData["mask"].to(xm.xla_device())


    optimizer_seg.zero_grad()

    rst = segment_net(img)
    seg = rst["seg"]

    # 计算训练过程的损失loss
    loss_seg = criterion_segment(seg, mask)
    loss_seg.backward()
    optimizer_seg.step()

    train_loss_sum += loss_seg.item() 
    
    # 计算训练过程的accuracy
    net_seg = data_threshold(seg.clone().flatten(), 0.6)    # 预测值
    mask_seg = mask.clone().flatten()              # 真实值
    
    # 对每个像素点的值进行比较，相同的点计入right_seg 
    right_seg = torch.eq(net_seg, mask_seg).sum().float().item()
    total_num = float(mask.clone().flatten().size()[0])
    
    batch_acc = right_seg/total_num
    train_acc_sum += batch_acc
    
    batch_count += 1

    # 输出每个epoch之中每个batch的信息
    print("[Epocn:{0}],[batch_count:{1}],[loss:{2:.6f}],[accuracy:{3:.6f}]".format(epoch, batch_count, loss_seg.item(), batch_acc))
    
  # 输出训练过程每个epoch平均的loss和accuracy
  print("[Epoch {0}/{1}], [loss:{2:.6f}], accuracy:{3:.6f}]".format(epoch, opt.end_epoch, train_loss_sum/batch_count, train_acc_sum/batch_count))
    
  # 将上述epoch, loss, accuracy数据写入xls文件
  print("------------------------------------------------------------------------------------------")
  print("开始写入训练过程的epoch, loss, accuracy")
  train_xls_value = [epoch, train_loss_sum/batch_count, train_acc_sum/batch_count]
  write_excel_xls_append("evaluate_data.xls", train_xls_value, 0)
  print("------------------------------------------------------------------------------------------")


  # -----------------------------------------------------------------------------
  # 以一定周期保存训练之后的模型
  if opt.need_save and epoch % opt.save_interval == 0 and epoch >= opt.save_interval:

    save_path_str = "./saved_models"
    if os.path.exists(save_path_str) == False:
        os.makedirs(save_path_str, exist_ok=True)

    torch.save(segment_net.state_dict(), "%s/segment_net_%d.pth" % (save_path_str, epoch))
    print("------------------------------------------------------------------------------------------")
    print("save weights ! epoch = %d" %epoch)
    print("------------------------------------------------------------------------------------------")
    pass
    

  # -----------------------------------------------------------------------------
  # 对模型进行测试，并保存结果
  if opt.need_test and epoch % opt.test_interval == 0 and epoch >= opt.test_interval:

    test_loss_sum, test_acc_sum, batch_count = 0.0, 0.0, 0.0
    result_evaluate_epoch = np.array([0.0, 0.0, 0.0, 0.0])

    for i, testBatch in enumerate(testloader):
      # imgTest = testBatch["img"].cuda()
      imgTest = testBatch["img"].to(xm.xla_device())
      t1 = time.time()
      rstTest = segment_net(imgTest)
      t2 = time.time()

      # 计算测试过程的损失loss
      # mask = testBatch["mask"].cuda()
      mask = testBatch["mask"].to(xm.xla_device())
      loss_test = criterion_segment(rstTest["seg"], mask)

      test_loss_sum += loss_test.item()

      # 计算测试过程的accuracy
      net_seg = data_threshold(rstTest["seg"].clone().flatten(), 0.6)  # 预测值
      mask_seg = mask.clone().flatten()                  # 真实值
      
      # 对每个像素点的值进行比较，相同的点计入right_seg 
      right_seg = torch.eq(net_seg, mask_seg).sum().float().item()
      total_num = float(mask.clone().flatten().size()[0])
      
      batch_acc = right_seg/total_num
      test_acc_sum += batch_acc
      
      batch_count += 1

      # 对一个batch测试结果进行综合评估，并进行累加，方便后续保存
      result_evaluate_batch = np.array(list(evaluate_metric(mask_seg, net_seg)))
      result_evaluate_epoch += result_evaluate_batch
      
      # 对保存的图片进行阈值化处理
      seg_shape = rstTest["seg"].data.shape
      segTest_flatten = data_threshold(rstTest["seg"].flatten(), 0.6)
      segTest = segTest_flatten.reshape(seg_shape[0], seg_shape[1], seg_shape[2], seg_shape[3])

      # 建立文件的保存路径
      save_path_str = "./testResultSeg/epoch_%d"%(epoch)
      if os.path.exists(save_path_str) == False:
          os.makedirs(save_path_str, exist_ok=True)

      # 输出文件的保存信息
      print("processing image NO %d, time comsuption %fs"%(i, t2 - t1))
      save_image(imgTest.data, "%s/img_%d.jpg"% (save_path_str, i))
      save_image(segTest.data, "%s/img_%d_seg.jpg"% (save_path_str, i))

    # 将上述测试过程的评估参数acc,precision,recall和F1分数进行保存
    print("------------------------------------------------------------------------------------------")
    print("开始写入评估参数")
    result_evaluate_epoch = result_evaluate_epoch/np.array([batch_count])
    test_xls_metric = [epoch] + list(result_evaluate_epoch)
    write_excel_xls_append("evaluate_data.xls", test_xls_metric, 2)
    print("------------------------------------------------------------------------------------------")
    
    
    # 输出测试过程每个epoch平均的loss和accuracy
    print("------------------------------------------------------------------------------------------")
    print("[Epoch {0}/{1}], [loss:{2:.6f}], accuracy:{3:.6f}]".format(epoch, opt.end_epoch, test_loss_sum/batch_count, test_acc_sum/batch_count))
    print("------------------------------------------------------------------------------------------")
    
    # 将上述epoch, loss, accuracy数据写入xls文件
    print("------------------------------------------------------------------------------------------")
    print("开始写入测试过程的epoch, loss, accuracy")
    test_xls_value = [epoch, test_loss_sum/batch_count, test_acc_sum/batch_count]
    write_excel_xls_append("evaluate_data.xls", test_xls_value, 1)
    print("------------------------------------------------------------------------------------------")
'''

'\nfor epoch in range(opt.begin_epoch, opt.end_epoch):\n\n  iterCFD = trainCFDloader.__iter__()\n\n  lenNum = len(trainCFDloader)\n\n  segment_net.train()\n\n  # -----------------------------------------------------------------------------\n  # 开始训练\n  # 记录每一个epoch的总损失和总精度\n  train_loss_sum, train_acc_sum, batch_count = 0.0, 0.0, 0.0\n\n  for i in range(0, lenNum):\n      \n    batchData = iterCFD.__next__()\n    \n    # img = batchData["img"].cuda()\n    # mask = batchData["mask"].cuda()\n    img = batchData["img"].to(xm.xla_device())\n    mask = batchData["mask"].to(xm.xla_device())\n\n\n    optimizer_seg.zero_grad()\n\n    rst = segment_net(img)\n    seg = rst["seg"]\n\n    # 计算训练过程的损失loss\n    loss_seg = criterion_segment(seg, mask)\n    loss_seg.backward()\n    optimizer_seg.step()\n\n    train_loss_sum += loss_seg.item() \n    \n    # 计算训练过程的accuracy\n    net_seg = data_threshold(seg.clone().flatten(), 0.6)    # 预测值\n    mask_seg = mask.clone().flatten()              # 真实值\n    \