In [1]:
# #生成train、val、test的txt文件
# from sklearn.model_selection import train_test_split
# import os

# imagedir = './dataset/2w_dataset/2w_image'
# outdir = './dataset/2w_dataset'
# os.makedirs(outdir,exist_ok=True)

# images = []
# for file in os.listdir(imagedir):
#     filename = file.split('.')[0]
#     images.append(filename)

# # Split the data into training, validation, and test sets (8:1:1 ratio)
# train_size = 0.8
# val_size = 0.1
# test_size = 0.1

# train, temp = train_test_split(images, test_size=(val_size + test_size), random_state=100)
# val, test = train_test_split(temp, test_size=(test_size / (val_size + test_size)), random_state=100)

# # Write the lists to text files
# with open(os.path.join(outdir, "train.txt"), 'w') as f:
#     f.write('\n'.join(train))

# with open(os.path.join(outdir, "val.txt"), 'w') as f:
#     f.write('\n'.join(val))

# with open(os.path.join(outdir, "test.txt"), 'w') as f:
#     f.write('\n'.join(test))

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import PIL
from PIL import Image
import glob
import time
import os
from skimage.io import imread
import copy
import cv2 as cv

import torch
from torch import nn
from torch import optim
import torch.nn.functional as F
import torch.utils.data as Data
from torchvision import transforms
from torchsummary import summary
from unet import UNet
from Models import Unet_dict, NestedUNet, U_Net, R2U_Net, AttU_Net, R2AttU_Net
from ConvUNeXt import ConvUNeXt
from SA_Unet import SA_UNet
from RMA_UNet import RMA_UNet
import torchvision

In [None]:
# Pick the compute device: the GPU when CUDA is usable, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(device)

# 读取数据

In [None]:
#定义读取图像的函数
def read_image(root = './dataset/2w_dataset_revise/train.txt',
               image_dir = './dataset/2w_dataset_revise/2w_image',
               mask_dir = './dataset/2w_dataset_revise/2w_mask',
               loader = None):
    """Load every image/mask pair named in a split file.

    :param root: text file with one sample name (no extension) per entry,
        parsed with ``np.loadtxt(..., dtype=str)``.
    :param image_dir: directory holding ``<name>.jpg`` input images.
    :param mask_dir: directory holding ``<name>.png`` mask images.
    :param loader: callable taking a file path and returning the decoded
        image; defaults to the ``imread`` imported at the top of the notebook.
    :return: ``(data, label)`` — two lists of equal length, ordered like the
        names in ``root``.
    """
    if loader is None:
        loader = imread
    # np.loadtxt collapses a single-entry file to a 0-d array, on which len()
    # and iteration fail; atleast_1d keeps the one-sample case working.
    image = np.atleast_1d(np.loadtxt(root, dtype=str))
    n = len(image)
    data,label = [None]*n, [None]*n
    for i, fname in enumerate(image):
        data[i] = loader('%s/%s.jpg' % (image_dir, fname))
        label[i] = loader('%s/%s.png' % (mask_dir, fname))
    return data,label

In [None]:
#读取训练数据
# Load every training image and mask into memory
traindata,trainlabel = read_image(root = './dataset/2w_dataset_revise/train.txt')
# Load every validation image and mask into memory
valdata,vallabel = read_image(root = './dataset/2w_dataset_revise/val.txt')

In [None]:
#查看训练集和验证集的图像
# 2x2 preview grid: the first training image and its mask fill the left
# column (slots 1 and 3), the first validation image and its mask fill the
# right column (slots 2 and 4).
plt.figure(figsize=(12,8))
for _slot, _picture in ((1, traindata[0]), (3, trainlabel[0]),
                        (2, valdata[0]), (4, vallabel[0])):
    plt.subplot(2, 2, _slot)
    plt.imshow(_picture, cmap=plt.cm.gray)
plt.show()

# 处理图像标签

In [None]:
#列出标签
# Class names; presumably listed in the same order as the rows of `colormap`.
classes = ['background','pore']
# RGB colour used in the mask images for each class; the row index is the
# integer label that image2label assigns to that colour.
colormap = [[0, 0, 0], # 0 = background
            [128, 0, 0]] # 1 = pore (assumed from the order of `classes`)

In [None]:
# Convert a colour-coded mask image into a 2-D map of class indices
def image2label(image, colormap):
    """Map every pixel of an RGB mask to the index of its colour in ``colormap``.

    :param image: array-like of shape (H, W, C) with C >= 3; only the first
        three channels are read, as R, G and B.
    :param colormap: sequence of ``[R, G, B]`` triples; position ``i`` is the
        label assigned to that colour. When a colour is listed twice, the
        later position wins.
    :return: float64 ndarray of shape (H, W). Pixels whose colour is not in
        ``colormap`` get label 0.
    """
    image = np.array(image, dtype="int64")
    # Pack each RGB triple into one unique integer: (R*256 + G)*256 + B.
    # Weighting R and G both by 256 would make e.g. [1,0,0] and [0,1,0] collide.
    ix = (image[:,:,0]*256 + image[:,:,1])*256 + image[:,:,2]
    # Assign labels colour by colour instead of allocating a 256**3-entry
    # lookup table on every call.
    image2 = np.zeros(ix.shape)
    for i,cm in enumerate(colormap):
        image2[ix == (cm[0]*256 + cm[1])*256 + cm[2]] = i
    return image2


# Preprocessing applied to a single image/mask pair
def img_transforms(data, label, colormap):
    """Turn one image and its colour-coded mask into tensors.

    The image is passed through ``ToTensor`` and then normalised per channel
    with the mean/std constants below. The mask is converted to a 2-D
    class-index map by ``image2label`` and wrapped as a tensor. No cropping
    or other augmentation happens here.

    :param data: input image accepted by ``transforms.ToTensor``.
    :param label: colour-coded mask image for ``data``.
    :param colormap: list of RGB triples forwarded to ``image2label``.
    :return: ``(image_tensor, label_tensor)``.
    """
    # Three-channel statistics; an earlier single-channel variant of this
    # notebook used mean 0.446 / std 0.182 instead.
    channel_mean = [0.485, 0.456, 0.406]
    channel_std = [0.229, 0.224, 0.225]
    to_normalized_tensor = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(channel_mean, channel_std)])
    image_tensor = to_normalized_tensor(data)
    label_tensor = torch.from_numpy(image2label(label, colormap))
    return image_tensor, label_tensor

# # 定义需要读取的数据路径的函数
# def read_image_path(root = './dataset/img_voc/train.txt'):
# # 原始图像路径输出为data，标签图像路径输出为label
#     image = np.loadtxt(root,dtype=str)
#     n = len(image)
#     data,label = [None]*n, [None]*n
#     for i, fname in enumerate(image):
#         data[i] = imread('./dataset/img_voc/JPEGImages/%s.jpg' %(fname))
#         label[i] = imread('./dataset/img_voc/SegmentationClass/%s.png' %(fname))
#     return data,label

In [None]:
# Dataset serving (image, label-map) pairs for one split file
class MyDataset(Data.Dataset):
    """In-memory segmentation dataset built from a split file.

    data_root:   path of the split file handed to ``read_image``
    imtransform: callable ``(image, mask, colormap) -> (image, label)``
                 applied to every sample on access
    colormap:    list of RGB triples passed through to ``imtransform``
    """

    def __init__(self, data_root, imtransform, colormap):
        self.data_root = data_root
        self.imtransform = imtransform
        self.colormap = colormap
        # Every image and mask of the split is read once, up front.
        self.data_list, self.label_list = read_image(root=data_root)

    def __len__(self):
        return len(self.data_list)

    def __getitem__(self, idx):
        # Wrap the stored arrays as PIL images; the input is forced to RGB.
        rgb_image = Image.fromarray(self.data_list[idx]).convert('RGB')
        mask_image = Image.fromarray(self.label_list[idx])
        sample, target = self.imtransform(rgb_image, mask_image, self.colormap)
        return sample, target

In [None]:
# Build the training and validation datasets from their split files
voc_train = MyDataset("./dataset/2w_dataset_revise/train.txt", img_transforms, colormap)
voc_val = MyDataset("./dataset/2w_dataset_revise/val.txt", img_transforms, colormap)
# Create the data loaders; each batch holds 8 images (batch_size=8)
train_loader = Data.DataLoader(voc_train, batch_size=8, shuffle=True, num_workers=0, pin_memory=True)
val_loader = Data.DataLoader(voc_val, batch_size=8, shuffle=True, num_workers=0, pin_memory=True)
# Pull batches from the training loader to sanity-check the tensor shapes.
# The loop only breaks once step > 0, so b_x/b_y end up holding the second
# batch (or the only batch when the loader yields just one).
for step,(b_x,b_y) in enumerate(train_loader):
    if step > 0:
        break
# Print the shapes of the image batch and the label batch
print("b_x.shape:",b_x.shape)
print("b_y.shape:",b_y.shape)

In [None]:
# Undo the input normalisation and bring the image back into [0, 1]
def inv_normalize_image(data, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
    """Invert the per-channel normalisation applied by ``img_transforms``.

    :param data: ndarray whose last axis is the channel axis (e.g. H x W x 3).
    :param mean: per-channel mean that was subtracted. Defaults to the
        three-channel values used by ``img_transforms``; pass ``[0.446]`` for
        the older single-channel pipeline.
    :param std: per-channel std that was divided by. Defaults match
        ``img_transforms``; pass ``[0.182]`` for the single-channel pipeline.
    :return: ndarray of ``data * std + mean`` clipped to [0, 1].
    """
    mean = np.asarray(mean)
    std = np.asarray(std)
    # mean/std broadcast against the trailing channel axis of `data`.
    data = data.astype('float32') * std + mean
    return data.clip(0,1)

# Turn a 2-D map of class indices back into a colour image
def label2image(prelabel,colormap):
    """Colour a class-index map using ``colormap``.

    :param prelabel: 2-D ndarray (H, W) of class indices.
    :param colormap: sequence of ``[R, G, B]`` triples; entry ``i`` is the
        colour painted onto pixels labelled ``i``.
    :return: int32 ndarray of shape (H, W, 3). Pixels whose label has no
        entry in ``colormap`` stay black.
    """
    h,w = prelabel.shape
    flat = prelabel.reshape(h*w)
    image = np.zeros((h*w, 3),dtype="int32")
    for ii in range(len(colormap)):
        # A boolean row mask selects exactly the pixels of class ii. Indexing
        # with the (rows, cols) tuple from np.where on an (h*w, 1) array would
        # also hit row 0 and recolour the first pixel for every class present.
        image[flat == ii] = colormap[ii]
    return image.reshape(h,w,3)

# # 可视化一个batch的图像，检查数据预处理是否正确
# b_x_numpy = b_x.data.numpy()
# b_x_numpy = b_x_numpy.transpose(0,2,3,1)
# b_y_numpy = b_y.data.numpy()
# plt.figure(figsize=(16,6))
# for ii in range(8):
#     plt.subplot(2,8,ii+1)
#     plt.imshow(inv_normalize_image(b_x_numpy[ii]))
#     plt.axis("off")
#     plt.subplot(2,8,ii+9)
#     plt.imshow(label2image(b_y_numpy[ii],colormap))
#     plt.axis("off")
# plt.subplots_adjust(wspace=0.1, hspace=0.1)
# plt.show()

# 网络搭建

In [None]:
# Build the network and move it to the selected device. The arguments are
# presumably (input channels, number of classes) — confirm against the
# RMA_UNet constructor. The commented lines are alternative architectures.
model = RMA_UNet(3,2).to(device)
# model = ConvUNeXt(3,2,32).to(device)
# model = UNet(3, 2).to(device)
# model = SA_UNet(3, 2).to(device)
# Print a layer-by-layer summary for a 3 x 256 x 256 input
summary(model, input_size=(3, 256, 256))

In [None]:
def pixel_accuracy(output, mask):
    """Fraction of pixels whose predicted class equals the target class.

    :param output: raw network scores with the class axis at dim 1.
    :param mask: integer class-index tensor matching ``output`` without dim 1.
    :return: Python float, matching pixels divided by total pixels.
    """
    with torch.no_grad():
        predicted = F.softmax(output, dim=1).argmax(dim=1)
        hits = (predicted == mask).int()
        return float(hits.sum()) / float(hits.numel())

In [None]:
def mIoU(pred_mask, mask, smooth=1e-10, n_classes=2):
    """Mean intersection-over-union across classes 1 .. n_classes-1.

    Class 0 is not scored. A class with no pixels in ``mask`` contributes NaN
    and is ignored by the mean; if every scored class is absent the result
    itself is NaN.

    :param pred_mask: raw network scores with the class axis at dim 1.
    :param mask: integer class-index tensor of target labels.
    :param smooth: constant added to numerator and denominator of each IoU.
    :param n_classes: total number of classes, including class 0.
    :return: NaN-ignoring mean of the per-class IoU values.
    """
    with torch.no_grad():
        predicted = torch.argmax(F.softmax(pred_mask, dim=1), dim=1).contiguous().view(-1)
        target = mask.contiguous().view(-1)

        scores = []
        for class_id in range(1, n_classes):
            in_pred = predicted == class_id
            in_target = target == class_id

            # Class absent from the ground truth: record NaN and move on.
            if in_target.long().sum().item() == 0:
                scores.append(np.nan)
                continue

            overlap = torch.logical_and(in_pred, in_target).sum().float().item()
            combined = torch.logical_or(in_pred, in_target).sum().float().item()
            scores.append((overlap + smooth) / (combined + smooth))
        return np.nanmean(scores)

In [None]:
def train_model(model, criterion, optimizer, scheduler, traindataloader, valdataloader, num_epochs):
    """
    Train ``model`` for ``num_epochs`` epochs, validating after each one, and
    return the model restored to the weights with the lowest validation loss.

    Relies on the notebook-level ``device``, ``pixel_accuracy`` and ``mIoU``.

    :param model: network to train
    :param criterion: loss function called as ``criterion(out, b_y)``
    :param optimizer: optimizer stepped once per training batch
    :param scheduler: learning-rate scheduler stepped once per epoch
    :param traindataloader: iterable of training batches ``(b_x, b_y)``
    :param valdataloader: iterable of validation batches ``(b_x, b_y)``
    :param num_epochs: number of epochs to run
    :return: ``(model, data)`` where ``data`` is a dict of per-epoch history
        lists (loss / accuracy / IoU for both splits, plus the learning rate)
        and an ``"epoch"`` range
    """
    best_model_wts = copy.deepcopy(model.state_dict())
    best_loss = 1e10

    train_loss_all = []
    train_acc_all = []
    train_iou_all = []

    val_loss_all = []
    val_acc_all = []
    val_iou_all = []

    lr_list = []

    start = time.time()

    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs))
        print('-' * 10)

        train_loss = 0.0
        train_num = 0
        train_acc = 0
        train_iou = 0

        val_loss = 0.0
        val_num = 0
        val_acc = 0
        val_iou = 0

        ## Training phase
        model.train()
        for step,(b_x,b_y) in enumerate(traindataloader):
            optimizer.zero_grad()
            b_x = b_x.float().to(device)
            b_y = b_y.long().to(device)
            out = model(b_x)
            loss = criterion(out, b_y)
            # NOTE(review): mIoU returns NaN for a batch with no foreground
            # pixels, which turns the whole epoch's IoU into NaN — confirm
            # whether such batches can occur.
            train_acc += pixel_accuracy(out, b_y)
            train_iou += mIoU(out, b_y)
            loss.backward()
            optimizer.step()
            # Weight the batch loss by batch size so the epoch loss is a
            # per-sample average.
            train_loss += loss.item() * len(b_y)
            train_num += len(b_y)
        scheduler.step()

        ## Epoch-level training metrics. The learning rate is read after
        ## scheduler.step(), i.e. it is the rate the next epoch will use.
        lr_list.append(optimizer.state_dict()['param_groups'][0]['lr'])
        train_loss_all.append(train_loss / train_num)
        train_acc_all.append(train_acc/len(traindataloader))
        train_iou_all.append(train_iou/len(traindataloader))
        print('Epoch:{} | Train loss"{:.5f} | Train ACC:{:.5f} | Train Iou:{:.5f}'.format(epoch, train_loss_all[-1], train_acc/len(traindataloader), train_iou/len(traindataloader)))

        ## Validation phase. No gradients are needed, so autograd is switched
        ## off to avoid building a graph for every validation batch.
        model.eval()
        with torch.no_grad():
            for step,(b_x,b_y) in enumerate(valdataloader):
                b_x = b_x.float().to(device)
                b_y = b_y.long().to(device)
                out = model(b_x)
                loss = criterion(out, b_y)
                val_acc += pixel_accuracy(out, b_y)
                val_iou += mIoU(out, b_y)
                val_loss += loss.item() * len(b_y)
                val_num += len(b_y)

        ## Epoch-level validation metrics
        val_loss_all.append(val_loss / val_num)
        val_acc_all.append(val_acc/len(valdataloader))
        val_iou_all.append(val_iou/len(valdataloader))
        print('Epoch:{} | Val loss"{:.5f} | Val ACC:{:.5f} | Val Iou:{:.5f}'.format(epoch, val_loss_all[-1], val_acc/len(valdataloader), val_iou/len(valdataloader)))

        ## Keep a copy of the weights with the lowest validation loss so far
        if val_loss_all[-1] < best_loss:
            best_loss = val_loss_all[-1]
            best_model_wts = copy.deepcopy(model.state_dict())

        ## Cumulative wall-clock time since training started
        time_use = time.time() - start
        print("Train and val complete in {:.0f}m {:.0f}s".format(time_use // 60, time_use %60))

    # Assemble the history once, after the loop, so it also exists when
    # num_epochs is 0.
    data = {"epoch":range(num_epochs),
            "train_loss_all":train_loss_all,
            "train_acc_all":train_acc_all,
            "train_iou_all":train_iou_all,
            "val_loss_all":val_loss_all,
            "val_acc_all":val_acc_all,
            "val_iou_all":val_iou_all,
            "lr_list":lr_list}
    ## Restore the best weights before returning
    model.load_state_dict(best_model_wts)
    return model,data

In [None]:
# Pixel-wise cross-entropy loss over the class scores
criterion = nn.CrossEntropyLoss()
# NOTE(review): lr=1e-8 is the base rate the cosine schedule below anneals
# down from, so the learning rate never exceeds 1e-8 — confirm this tiny
# value is intended.
optimizer = optim.AdamW(model.parameters(), lr=1e-8, weight_decay=1e-4)
# Cosine annealing with warm restarts; T_0=100 is the length of the first
# cycle in scheduler steps (train_model steps it once per epoch).
scheduler=torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=100)

In [None]:
# Run training; `model` is rebound to the best-validation-loss weights and
# `data` receives the per-epoch history dict.
model,data = train_model(model,criterion,optimizer,scheduler,train_loader,val_loader, num_epochs=5000)

In [None]:
# Serialises the whole model object (not just its state_dict), so loading it
# later needs the RMA_UNet class definition to be importable.
torch.save(model,"RMA_UNet.pt")

In [None]:
# Persist the training-history dict returned by train_model
torch.save(data,"RMA_UNet_data.pt")