导入相关的包

In [None]:
from mydataset import MyDataset
import os
from PIL import Image
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from tqdm.notebook import tqdm
from unet import UNet
import time
import os

import cv2
from PIL import Image
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.cluster import KMeans

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torchvision.transforms import InterpolationMode

from tqdm.notebook import tqdm

一些常量

In [None]:
# Mini-batch sizes for the training and validation loaders.
batch_size = 8
test_batch_size = 8

# Epoch of the checkpoint to resume from (0 means train from scratch) and the
# number of additional epochs to run in this session.
start_epoch = 10
epochs = 90

# Initial learning rate and the epochs at which the LR scheduler steps it down.
lr = 1e-3
milestones = [25, 50, 75, 100, 125, 150, 175, 200, 250, 300, 400]

# Use the first CUDA device when one is available, otherwise the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


一些函数

In [None]:

def get_model_name(s, lens):
    """Return the path of the model-weights checkpoint for epoch ``s + lens``.

    The file lives under ``saves/`` in the current working directory.
    """
    filename = f"saves/U-Net-{s + lens}th.model"
    return os.path.join(os.getcwd(), filename)


def get_optimizer_name(s, lens):
    """Return the path of the optimizer-state checkpoint for epoch ``s + lens``.

    The file lives under ``saves/`` in the current working directory.
    """
    filename = f"saves/U-Net-{s + lens}th.opt"
    return os.path.join(os.getcwd(), filename)


def get_scheduler_name(s, lens):
    """Return the path of the LR-scheduler-state checkpoint for epoch ``s + lens``.

    The file lives under ``saves/`` in the current working directory.
    """
    filename = f"saves/U-Net-{s + lens}th.sche"
    return os.path.join(os.getcwd(), filename)


def get_loss_name_pdf(s, lens):
    """Return the path of the loss-curve PDF labelled with epoch ``s + lens``.

    The file lives under ``saves/`` in the current working directory.
    """
    filename = f"saves/U-Net-{s + lens}th-loss.pdf"
    return os.path.join(os.getcwd(), filename)


def get_loss_name_npy():
    """Return the path of the ``.npy`` file holding the per-epoch loss history."""
    relative_path = "saves/U-Net-loss.npy"
    return os.path.join(os.getcwd(), relative_path)


def get_train_iou_npy():
    """Return the path of the ``.npy`` file holding the training-set IoU history."""
    relative_path = "saves/U-Net-train-iou.npy"
    return os.path.join(os.getcwd(), relative_path)


def get_val_iou_npy():
    """Return the path of the ``.npy`` file holding the validation-set IoU history."""
    relative_path = "saves/U-Net-val-iou.npy"
    return os.path.join(os.getcwd(), relative_path)


def get_iou_pdf(s, lens):
    """Return the path of the IoU-curves PDF labelled with epoch ``s + lens``.

    The file lives under ``saves/`` in the current working directory.
    """
    filename = f"saves/U-Net-{s + lens}th-iou.pdf"
    return os.path.join(os.getcwd(), filename)


def cal_iou(data_loader, model, num_classes=9, device=None):
    """Compute the mean per-class IoU of ``model`` over ``data_loader``.

    For each sample and each class id ``c`` in ``range(num_classes)`` the IoU
    is ``|pred == c AND target == c| / |pred == c OR target == c|``. Samples
    in which class ``c`` appears in neither the prediction nor the target are
    skipped for that class, and the remaining per-sample IoUs are averaged.

    The model is evaluated under ``torch.no_grad()`` and, if it was in
    training mode, is switched to eval mode for the duration of the call and
    restored afterwards, so evaluation neither builds autograd graphs nor
    runs layers such as BatchNorm/Dropout in their training behaviour.

    Args:
        data_loader: Iterable yielding ``(X, Y)`` batches, where ``Y`` holds
            integer class ids with shape ``(batch, H, W)`` and ``model(X)``
            returns per-class scores with the class axis at ``dim=1``.
        model: Callable mapping ``X`` to class scores.
        num_classes: Number of classes to evaluate (ids ``0..num_classes-1``).
        device: Device the batches are moved to. When ``None`` the device of
            the model's first parameter is used; if the model has no
            parameters the batches are left where they are.

    Returns:
        A list of ``num_classes`` floats. An entry is ``nan`` when the class
        never occurs in any prediction or target, instead of raising
        ``ZeroDivisionError``.
    """
    if device is None:
        # Follow the model's own placement rather than a module-level global.
        parameters = getattr(model, "parameters", None)
        first_param = next(iter(parameters()), None) if callable(parameters) else None
        if first_param is not None:
            device = first_param.device

    # Accumulate on the CPU in float64: one transfer per batch instead of one
    # host/device synchronisation per sample and class.
    iou_sums = torch.zeros(num_classes, dtype=torch.float64)
    sample_counts = torch.zeros(num_classes, dtype=torch.long)
    class_ids = torch.arange(num_classes).view(1, num_classes, 1, 1)

    was_training = getattr(model, "training", False)
    if was_training:
        model.eval()
    try:
        with torch.no_grad():
            for X, Y in data_loader:
                if device is not None:
                    X, Y = X.to(device), Y.to(device)
                Y_pred = torch.argmax(model(X), dim=1)

                # Broadcast to (batch, class, H, W) boolean masks.
                ids = class_ids.to(Y.device)
                target_masks = Y.unsqueeze(1) == ids
                pred_masks = Y_pred.unsqueeze(1) == ids
                intersection = (target_masks & pred_masks).sum(dim=(2, 3))
                union = (target_masks | pred_masks).sum(dim=(2, 3))

                present = union > 0
                # Where union == 0 the intersection is 0 too, so clamping the
                # denominator yields a ratio of 0 that adds nothing to the sum.
                ratios = intersection.float() / union.clamp(min=1).float()
                iou_sums += ratios.cpu().double().sum(dim=0)
                sample_counts += present.sum(dim=0).cpu()
    finally:
        # Leave the model in the mode the caller had it in.
        if was_training:
            model.train()

    return [
        total / count if count else float("nan")
        for total, count in zip(iou_sums.tolist(), sample_counts.tolist())
    ]


def one_hot_encode(label, label_values):
    """Convert a colour-coded label image into a stack of per-class masks.

    Each entry of ``label_values`` is compared element-wise against ``label``;
    a pixel belongs to that class when every component along the last axis
    matches. The resulting boolean masks are stacked along a new last axis,
    so its size equals ``len(label_values)``.
    """
    class_masks = [
        # Element-wise comparison, then collapse the colour axis to one bool.
        np.all(np.equal(label, colour), axis=-1)
        for colour in label_values
    ]
    return np.stack(class_masks, axis=-1)

def reverse_one_hot(image):
    """Collapse the last (channel) axis of ``image`` to the index of its maximum."""
    return np.argmax(image, axis=-1)


class MyDataset(Dataset):
    """Segmentation dataset that preloads every (image, mask) pair into memory.

    Images are read with OpenCV, converted to RGB, turned into normalised
    tensors and resized to ``self.Size``. Masks are colour-coded RGB images
    that are mapped to per-pixel class indices using the colours listed in
    ``labels_class_dict.csv`` and resized with nearest-neighbour interpolation
    so that no new class ids are introduced.

    Args:
        image_dir: Directory containing the input images.
        mask_dir: Directory containing the colour-coded masks.
        train_dir: Directory containing ``labels_class_dict.csv`` (columns
            ``class_names``, ``r``, ``g``, ``b``).

    Raises:
        ValueError: If the two directories hold a different number of files,
            or if a file cannot be decoded as an image.
    """

    def __init__(self, image_dir, mask_dir, train_dir):
        self.Size = (256, 256)
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        # os.listdir returns names in arbitrary order, and images are paired
        # with masks by position, so both listings are sorted first.
        # NOTE(review): this assumes image and mask file names sort into the
        # same order (e.g. identical base names) -- confirm for the dataset.
        self.image_fns = sorted(os.listdir(image_dir))
        self.mask_fns = sorted(os.listdir(mask_dir))
        if len(self.image_fns) != len(self.mask_fns):
            raise ValueError(
                f"Found {len(self.image_fns)} images in {image_dir!r} but "
                f"{len(self.mask_fns)} masks in {mask_dir!r}; cannot pair them."
            )
        self.class_dict = pd.read_csv(os.path.join(train_dir, 'labels_class_dict.csv'))
        # Get class names
        self.class_names = self.class_dict['class_names'].tolist()
        # Get class RGB values
        self.class_rgb_values = self.class_dict[['r', 'g', 'b']].values.tolist()
        self.image_preprocessed = []
        self.mask_preprocessed = []

        # Nearest-neighbour keeps mask values valid class ids after resizing.
        mask_transform = transforms.Resize(size=self.Size, interpolation=InterpolationMode.NEAREST)

        for image_file_name, mask_file_name in zip(self.image_fns, self.mask_fns):
            image = self._read_rgb(os.path.join(self.image_dir, image_file_name))
            image = self.transform(image)

            mask = self._read_rgb(os.path.join(self.mask_dir, mask_file_name))
            # Colour image -> per-class boolean masks -> class index per pixel.
            # Pixels matching none of the listed colours end up as class 0.
            mask = reverse_one_hot(one_hot_encode(mask, self.class_rgb_values))
            mask = torch.from_numpy(mask).long()
            mask = mask.unsqueeze(0)  # add a leading channel axis for Resize
            mask = mask_transform(mask)
            mask = mask.squeeze(0)  # drop the channel axis again

            self.image_preprocessed.append(image)
            self.mask_preprocessed.append(mask)

    @staticmethod
    def _read_rgb(path):
        """Read the image at ``path`` and return it as an RGB array."""
        bgr = cv2.imread(path)
        if bgr is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise ValueError(f"Could not read image file: {path}")
        return cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)

    def __len__(self):
        """Return the number of image/mask pairs."""
        return len(self.image_fns)

    def transform(self, image):
        """Convert an RGB array to a normalised tensor resized to ``self.Size``."""
        transform_ops = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
            transforms.Resize(size=self.Size, interpolation=InterpolationMode.NEAREST)
        ])
        return transform_ops(image)

    def __getitem__(self, index):
        """Return the preprocessed ``(image, mask)`` pair at ``index``."""
        return self.image_preprocessed[index], self.mask_preprocessed[index]


设置目录

In [None]:

# Training split: images, masks and the class dictionary live under new_data/train.
train_dir = os.path.join(os.getcwd(), "new_data", "train")
image_dir = os.path.join(train_dir, "images")
mask_dir = os.path.join(train_dir, "masks")
train_dataset = MyDataset(image_dir=image_dir, mask_dir=mask_dir, train_dir=train_dir)
# Shuffle so that every epoch sees the samples in a different order; without
# it each epoch replays exactly the same batches in the same sequence.
train_data_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

# Validation split: same layout under new_data/val (the class dictionary is
# read from the validation directory). Order does not matter for evaluation.
val_dir = os.path.join(os.getcwd(), "new_data", "val")
image_dir = os.path.join(val_dir, "images")
mask_dir = os.path.join(val_dir, "masks")
test_dataset = MyDataset(image_dir=image_dir, mask_dir=mask_dir, train_dir=val_dir)
test_data_loader = DataLoader(test_dataset, batch_size=test_batch_size)

加载上次的模型与loss数据

In [None]:
# Build the model, loss, optimizer and LR schedule from scratch.
model = UNet(num_classes=9).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=lr)
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=milestones, gamma=0.5)
epoch_losses = []
train_iou_arr = []
val_iou_arr = []
if start_epoch > 0:
    # Resume from the checkpoint written at `start_epoch`. map_location remaps
    # the stored tensors onto the current device, so a checkpoint saved on a
    # GPU can still be loaded on a CPU-only machine (and vice versa).
    model.load_state_dict(torch.load(get_model_name(start_epoch, 0), map_location=device))
    optimizer.load_state_dict(torch.load(get_optimizer_name(start_epoch, 0), map_location=device))
    scheduler.load_state_dict(torch.load(get_scheduler_name(start_epoch, 0), map_location=device))
    # The metric histories are optional; missing files leave the lists empty.
    if os.path.exists(get_loss_name_npy()):
        epoch_losses = np.load(get_loss_name_npy()).tolist()
    if os.path.exists(get_train_iou_npy()):
        train_iou_arr = np.load(get_train_iou_npy()).tolist()
    if os.path.exists(get_val_iou_npy()):
        val_iou_arr = np.load(get_val_iou_npy()).tolist()
# Drop any entries recorded after the epoch being resumed from, so the
# histories line up with the restored model state.
epoch_losses = epoch_losses[:start_epoch]
train_iou_arr = train_iou_arr[:start_epoch]
val_iou_arr = val_iou_arr[:start_epoch]

开始训练

In [None]:

# Create the output directory up front: otherwise the first np.save / torch.save
# fails only after a full epoch of training has already been spent.
os.makedirs(os.path.dirname(get_loss_name_npy()), exist_ok=True)

for epoch in range(1, epochs + 1):  # 1-based count of epochs run in this session
    print("on iteration:" + str(start_epoch + epoch))
    start_time = time.time()

    # Make sure training-time layer behaviour is active for the update steps,
    # whatever mode the evaluation code left the model in.
    model.train()
    epoch_loss = 0
    for X, Y in train_data_loader:
        X, Y = X.to(device), Y.to(device)
        optimizer.zero_grad()
        Y_pred = model(X)
        loss = criterion(Y_pred, Y)
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
    # Record the mean batch loss of this epoch.
    epoch_losses.append(epoch_loss / len(train_data_loader))

    # Compute per-class IoU on the training and validation sets every epoch.
    cur_train_iou = cal_iou(train_data_loader, model)
    cur_val_iou = cal_iou(test_data_loader, model)
    train_iou_arr.append(cur_train_iou)
    val_iou_arr.append(cur_val_iou)

    # Checkpoint model/optimizer/scheduler whenever the global epoch number
    # is a multiple of 20.
    if (start_epoch + epoch) % 20 == 0:
        torch.save(model.state_dict(), get_model_name(start_epoch, epoch))
        torch.save(optimizer.state_dict(), get_optimizer_name(start_epoch, epoch))
        torch.save(scheduler.state_dict(), get_scheduler_name(start_epoch, epoch))

    # The metric histories are rewritten after every epoch.
    np.save(get_loss_name_npy(), epoch_losses)
    np.save(get_train_iou_npy(), train_iou_arr)
    np.save(get_val_iou_npy(), val_iou_arr)

    scheduler.step()
    end_time = time.time()
    print("curt: " + str(end_time - start_time) + "/s")


打印并保存loss图

In [None]:
# Plot the mean training loss per epoch and save the figure as a PDF.
fig, axe = plt.subplots(figsize=(10, 10))
# Build the 1-based epoch axis from the history that was actually recorded,
# not from the configured epoch count: the two differ when training was
# interrupted or a history file was missing on resume, and plot() then fails
# on mismatched lengths.
x_axe = list(range(1, len(epoch_losses) + 1))
axe.plot(x_axe, epoch_losses)
axe.set_xlabel('epoch')
axe.set_ylabel('loss')
axe.set_title('loss lr=' + str(lr))
fig.savefig(get_loss_name_pdf(start_epoch, epochs))


打印并保存iou图

In [None]:
# Plot train/validation IoU per class on a 3x3 grid (one panel per class id
# 0..8) and save the figure as a PDF.
fig2, axes2 = plt.subplots(3, 3, figsize=(15, 15))
# Each curve gets a 1-based epoch axis matching its own recorded length, so
# the plot does not depend on variables left over from another cell.
train_epochs = range(1, len(train_iou_arr) + 1)
val_epochs = range(1, len(val_iou_arr) + 1)
# axes2.flat walks the grid row by row, so the running index equals row * 3 + col.
for class_id, ax in enumerate(axes2.flat):
    ax.plot(train_epochs, [x[class_id] for x in train_iou_arr])
    ax.plot(val_epochs, [x[class_id] for x in val_iou_arr])
    # Set the ticks on every panel: plt.yticks() only affects the current
    # axes, i.e. the last subplot created.
    ax.set_yticks([0, 0.2, 0.4, 0.6, 0.8, 1])
    ax.set_title("IOU of class " + str(class_id))
    ax.legend(['train', 'val'])
fig2.savefig(get_iou_pdf(start_epoch, epochs))
