##### 导入必要的包

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import glob
import numpy as np
import os
from   lxml import etree
from   matplotlib.patches import Rectangle
from   matplotlib import pyplot as plt
from   PIL import Image
from   torch.optim import lr_scheduler
from   torch.utils import data
from   torchvision import transforms
%matplotlib inline

##### 数据预处理

In [2]:
# Collect training image / annotation paths. Annotation (matte) files are the
# ones whose path contains "matte"; everything else is treated as an image.
# glob returns paths in arbitrary order, so sort first to keep the i-th image
# lined up with the i-th annotation.
# NOTE(review): assumes every image has exactly one matte whose name sorts in
# the same relative order (e.g. `x.png` / `x_matte.png`) - confirm against the dataset.
all_pictures      = sorted(glob.glob(r"../Chapter13 图像的语义分割/HKdataset/HKdataset/training/*.png"))
images            = [p for p in all_pictures if "matte" not in p]
annotations       = [p for p in all_pictures if "matte" in p]
# Shuffle images and annotations with one seeded permutation so pairs stay aligned.
np.random.seed(2021)
index             = np.random.permutation(len(images))
images            = np.array(images)[index]
annotations       = np.array(annotations)[index]
# Collect test image / annotation paths. These must be filtered from the test
# glob; filtering the training list here would silently evaluate on training data.
# NOTE(review): the training and testing globs use different root directories - confirm both are intended.
all_test_pictures = sorted(glob.glob(r"./HKdataset/HKdataset/testing/*.png"))
test_images       = [p for p in all_test_pictures if "matte" not in p]
test_annotations  = [p for p in all_test_pictures if "matte" in p]
# Preprocessing pipeline used by HK_DataSet for both images and masks:
# resize to 256x256, then convert the PIL image to a tensor.
transform = transforms.Compose(
    [
        transforms.Resize(size=[256, 256]),
        transforms.ToTensor(),
    ]
)
# Custom Dataset class
class HK_DataSet(data.Dataset):
    """Dataset pairing each image path with its annotation (matte) path.

    Args:
        imgs_path: indexable collection of image file paths.
        annos_path: indexable collection of annotation file paths, aligned
            index-for-index with ``imgs_path``.

    Each item is ``(img_tensor, anno_tensor)``, both produced by the
    module-level ``transform``. The annotation is binarised (every value
    greater than 0 becomes 1) and returned as a ``LongTensor`` without its
    channel dimension.
    """

    def __init__(self, imgs_path, annos_path):
        self.imgs_path = imgs_path
        self.annos_path = annos_path

    def __getitem__(self, index):
        # Force 3 channels: PNGs may be RGBA, palette or grayscale, while the
        # network's first convolution is built for a fixed channel count.
        with Image.open(self.imgs_path[index]) as pil_img:
            img_tensor = transform(pil_img.convert("RGB"))

        # Force a single channel so the mask always becomes (H, W) after the squeeze.
        with Image.open(self.annos_path[index]) as anno_img:
            anno_tensor = transform(anno_img.convert("L"))

        # Any non-zero matte value counts as foreground (class 1).
        anno_tensor[anno_tensor > 0] = 1
        anno_tensor = torch.squeeze(anno_tensor, dim=0).type(torch.LongTensor)

        return img_tensor, anno_tensor

    def __len__(self):
        return len(self.imgs_path)
# Wrap the path collections in datasets and batch them (batch size 6);
# only the training loader shuffles.
train_ds = HK_DataSet(imgs_path=images, annos_path=annotations)
test_ds  = HK_DataSet(imgs_path=test_images, annos_path=test_annotations)
train_dl = data.DataLoader(dataset=train_ds, batch_size=6, shuffle=True)
test_dl  = data.DataLoader(dataset=test_ds, batch_size=6)

##### 创建模型

模型创建要点:

    1. 编写 卷积模块 (卷积 + BN + 激活函数 ReLU)
    2. 编写 反卷积模块 (反卷积 + BN + 激活函数 ReLU)
    3. 编码器 (4个卷积模块 + 1个用于 shortcut 的卷积模块)
    4. 解码器 (卷积模块 + 反卷积模块 + 卷积模块)
    5. 实现整体网络结构 (卷积模块 + 反卷积模块 + 编码器 + 解码器)
    

In [3]:
# Convolution block
class Convblock(nn.Module):
    """Conv2d followed by BatchNorm2d and an in-place ReLU.

    Args:
        in_channels: channels of the input feature map.
        out_channels: channels produced by the convolution.
        kernel_size, stride, padding: forwarded unchanged to ``nn.Conv2d``.
    """

    def __init__(self, in_channels, out_channels, kernel_size=(3, 3), stride=1, padding=1):
        super().__init__()
        layers = [
            nn.Conv2d(
                in_channels,
                out_channels,
                kernel_size=kernel_size,
                stride=stride,
                padding=padding,
            ),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        ]
        # The attribute name determines the state_dict key prefix, so it is kept as-is.
        self.conv_bn_relu = nn.Sequential(*layers)

    def forward(self, x):
        """Apply conv -> batch norm -> ReLU to ``x`` and return the result."""
        return self.conv_bn_relu(x)

In [4]:
# Transposed-convolution block
class DeConvblock(nn.Module):
    """ConvTranspose2d, optionally followed by BatchNorm2d + ReLU.

    Args:
        in_channels: channels of the input feature map.
        out_channels: channels produced by the transposed convolution.
        kernel_size, stride, padding, output_padding: forwarded unchanged to
            ``nn.ConvTranspose2d``.
    """

    def __init__(self, in_channels, out_channels, kernel_size=(3, 3), stride=2, padding=1, output_padding=1):
        super().__init__()
        self.deconv = nn.ConvTranspose2d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            output_padding=output_padding,
        )
        self.bn = nn.BatchNorm2d(out_channels)

    def forward(self, x, is_out=True):
        """Upsample ``x``.

        With ``is_out=True`` (the default) batch norm and ReLU are applied to
        the transposed-convolution output; with ``is_out=False`` that raw
        output is returned untouched.
        """
        upsampled = self.deconv(x)
        if not is_out:
            return upsampled
        return torch.relu(self.bn(upsampled))

In [5]:
# Encoder block
class Encodeblock(nn.Module):
    """Downsampling stage built from four conv blocks plus a strided shortcut conv block.

    ``conv_1`` and ``short_cut`` both use stride 2, so they shrink the
    spatial size the same way and their outputs can be added.

    Args:
        in_channels: channels of the input feature map.
        out_channels: channels of the returned feature map.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv_1 = Convblock(in_channels, out_channels, stride=2)
        self.conv_2 = Convblock(out_channels, out_channels)
        self.conv_3 = Convblock(out_channels, out_channels)
        self.conv_4 = Convblock(out_channels, out_channels)
        self.short_cut = Convblock(in_channels, out_channels, stride=2)

    def forward(self, x):
        """Return the sum of the first and second conv-pair outputs."""
        # First pair: strided conv followed by a same-size conv.
        first = self.conv_2(self.conv_1(x))
        residual = self.short_cut(x)

        # Second pair consumes the first pair's output plus the shortcut.
        second = self.conv_4(self.conv_3(first + residual))

        # NOTE(review): the shortcut is not part of this final sum - confirm that is intended.
        return first + second

In [6]:
# Decoder block
class Decodeblock(nn.Module):
    """Upsampling stage: 1x1 conv block -> transposed-conv block -> 1x1 conv block.

    The first 1x1 conv reduces the channels to ``in_channels // 4``, the
    transposed convolution keeps that width, and the last 1x1 conv maps it to
    ``out_channels``.

    Args:
        in_channels: channels of the input feature map.
        out_channels: channels of the returned feature map.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        reduced = in_channels // 4
        self.conv_1 = Convblock(in_channels, reduced, kernel_size=(1, 1), padding=0)
        self.deconv = DeConvblock(reduced, reduced)
        self.conv_2 = Convblock(reduced, out_channels, kernel_size=(1, 1), padding=0)

    def forward(self, x):
        """Run ``x`` through the three stages in order and return the result."""
        for stage in (self.conv_1, self.deconv, self.conv_2):
            x = stage(x)
        return x

In [7]:
# Full model
class Net(nn.Module):
    """Encoder/decoder segmentation network with additive skip connections.

    Layout: a strided 7x7 conv block and a 2x2 max-pool, four encoder
    blocks, four decoder blocks whose inputs are summed with the matching
    encoder outputs, and a transposed-conv / conv / transposed-conv head
    that emits per-pixel class scores (no batch norm or ReLU on the last layer).

    Args:
        num_classes: number of output channels (classes). Defaults to 2,
            the value that was previously hard-coded.
        in_channels: number of input image channels. Defaults to 3.

    Input is ``(N, in_channels, H, W)``. With H and W that are multiples of
    64 (e.g. the 256x256 used in this file) the output is
    ``(N, num_classes, H, W)``; other sizes may fail at the skip additions
    or come out at a different spatial size.
    """

    def __init__(self, num_classes=2, in_channels=3):
        super().__init__()

        self.input_conv    = Convblock(in_channels, 64, kernel_size=(7, 7), stride=2, padding=3)
        self.input_maxpool = nn.MaxPool2d(kernel_size=(2, 2))

        self.encode_1      = Encodeblock(64, 64)
        self.encode_2      = Encodeblock(64, 128)
        self.encode_3      = Encodeblock(128, 256)
        self.encode_4      = Encodeblock(256, 512)

        self.decode_4      = Decodeblock(512, 256)
        self.decode_3      = Decodeblock(256, 128)
        self.decode_2      = Decodeblock(128, 64)
        self.decode_1      = Decodeblock(64, 64)

        self.deconv_out1   = DeConvblock(64, 32)
        self.conv_out      = Convblock(32, 32)
        # Final upsampling layer: one output channel per class.
        self.deconv_out2   = DeConvblock(32, num_classes, kernel_size=2, padding=0, output_padding=0)

    def forward(self, x):
        """Return per-pixel class scores for the image batch ``x``."""
        x   = self.input_conv(x)
        x   = self.input_maxpool(x)

        e1  = self.encode_1(x)
        e2  = self.encode_2(e1)
        e3  = self.encode_3(e2)
        e4  = self.encode_4(e3)

        # Each decoder input is the previous decoder output plus the encoder
        # output of the same resolution.
        d4  = self.decode_4(e4)
        d3  = self.decode_3(e3 + d4)
        d2  = self.decode_2(e2 + d3)
        d1  = self.decode_1(e1 + d2)

        f1  = self.deconv_out1(d1)
        f2  = self.conv_out(f1)
        # is_out=False skips batch norm and ReLU so raw scores are returned.
        f3  = self.deconv_out2(f2, is_out=False)

        return f3

In [8]:
# Build the network and move it to the GPU when one is available; fall back
# to the CPU instead of failing on machines without CUDA.
model   = Net()
model   = model.to("cuda" if torch.cuda.is_available() else "cpu")

In [9]:
# Cross-entropy loss with default settings; fit() applies it to the model
# output and the integer mask of each batch.
loss_fn = nn.CrossEntropyLoss()

##### 训练模型  -- 使用IOU指标

In [10]:
from torch.optim import lr_scheduler

# Adam over all model parameters; the scheduler multiplies the learning rate
# by 0.1 after every 7 calls to its step().
optimizer        = torch.optim.Adam(params=model.parameters(), lr=1e-3)
exp_lr_scheduler = lr_scheduler.StepLR(optimizer=optimizer, step_size=7, gamma=0.1)

In [11]:
def _binary_iou(pred, target):
    """Return the IoU of the non-zero regions of two label tensors as a float.

    When both tensors are entirely zero the union is empty; that case is
    scored 1.0 (prediction and target agree) instead of the NaN that 0/0 yields.
    """
    union = torch.logical_or(target, pred).sum().item()
    if union == 0:
        return 1.0
    return torch.logical_and(target, pred).sum().item() / union


def _run_epoch(model, loader, device, training):
    """Run one pass over ``loader`` and return ``(mean loss, pixel accuracy, mean batch IoU)``.

    When ``training`` is true the module-level ``optimizer`` is stepped on
    every batch; otherwise the pass runs in eval mode without gradients.
    All three results are NaN if the loader yields no batches.
    """
    loss_sum = 0.0
    n_samples = 0
    n_correct = 0
    n_pixels = 0
    batch_ious = []

    model.train(training)
    with torch.set_grad_enabled(training):
        for x, y in loader:
            x = x.to(device)
            y = y.to(device)
            logits = model(x)
            loss = loss_fn(logits, y)
            if training:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            pred = torch.argmax(logits.detach(), dim=1)
            batch_size = y.size(0)
            # loss_fn is expected to return the batch mean (the default
            # reduction), so weight it by the batch size to get a true
            # per-sample average even when the last batch is smaller.
            loss_sum += loss.item() * batch_size
            n_samples += batch_size
            n_correct += (pred == y).sum().item()
            # Count the real number of labelled pixels rather than assuming 256x256.
            n_pixels += y.numel()
            batch_ious.append(_binary_iou(pred, y))

    nan = float("nan")
    mean_loss = loss_sum / n_samples if n_samples else nan
    accuracy = n_correct / n_pixels if n_pixels else nan
    mean_iou = sum(batch_ious) / len(batch_ious) if batch_ious else nan
    return mean_loss, accuracy, mean_iou


def fit(epoch, model, trainloader, testloader):
    """Train ``model`` for one epoch, evaluate it, and report the metrics.

    Uses the module-level ``loss_fn``, ``optimizer`` and ``exp_lr_scheduler``.
    Batches are moved to whatever device the model's parameters live on.

    Args:
        epoch: epoch index, used only in the printed progress line.
        model: the network to train.
        trainloader: iterable of ``(x, y)`` training batches.
        testloader: iterable of ``(x, y)`` evaluation batches.

    Returns:
        ``(epoch_loss, epoch_acc, epoch_test_loss, epoch_test_acc)`` where
        the losses are per-sample means and the accuracies are the fraction
        of correctly classified pixels.
    """
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")

    epoch_loss, epoch_acc, epoch_iou = _run_epoch(model, trainloader, device, training=True)
    # The learning-rate schedule advances once per epoch, after the training pass.
    exp_lr_scheduler.step()
    epoch_test_loss, epoch_test_acc, epoch_test_iou = _run_epoch(model, testloader, device, training=False)

    print(
        f"epoch: {epoch}\t"
        f"loss: {epoch_loss:.3f}\taccuracy: {epoch_acc:.3f}\tIOU: {epoch_iou:.3f}\t"
        f"test_loss: {epoch_test_loss:.3f}\ttest_accuracy: {epoch_test_acc:.3f}\t"
        f"test_iou: {epoch_test_iou:.3f}"
    )

    return epoch_loss, epoch_acc, epoch_test_loss, epoch_test_acc

In [12]:
# Number of passes over the training set made by the training loop.
epochs = 40

In [None]:
# Per-epoch metric histories, filled in the order fit() returns its values.
train_loss, train_acc, test_loss, test_acc = [], [], [], []

for epoch in range(epochs):
    epoch_loss, epoch_acc, epoch_test_loss, epoch_test_acc = fit(epoch, model, train_dl, test_dl)
    for history, value in zip(
        (train_loss, train_acc, test_loss, test_acc),
        (epoch_loss, epoch_acc, epoch_test_loss, epoch_test_acc),
    ):
        history.append(value)

##### 保存模型

In [None]:
# Persist only the learned parameters (the state_dict), not the module object.
PATH = 'linknet_model.pth'
torch.save(obj=model.state_dict(), f=PATH)

##### 测试模型

In [None]:
def _show_predictions(net, loader, rows):
    """Plot image / ground-truth mask / predicted mask for the first samples of one batch.

    One batch is drawn from ``loader``; at most ``rows`` samples are shown,
    one per row, in a 3-column grid.
    """
    n_cols = 3
    image, mask = next(iter(loader))
    # Inference only: no gradients are needed for plotting.
    with torch.no_grad():
        pred_mask = torch.argmax(net(image), dim=1)

    # Never index past the end of a batch smaller than the requested row count.
    rows = min(rows, image.size(0))
    plt.figure(figsize=(10, 10))
    for i in range(rows):
        # Subplot slots are numbered row-major, so each row advances by the column count.
        plt.subplot(rows, n_cols, i * n_cols + 1)
        plt.imshow(image[i].permute(1, 2, 0).cpu().numpy())
        plt.subplot(rows, n_cols, i * n_cols + 2)
        plt.imshow(mask[i].cpu().numpy())
        plt.subplot(rows, n_cols, i * n_cols + 3)
        plt.imshow(pred_mask[i].cpu().numpy())


my_model  = Net()
# The checkpoint may have been saved from a CUDA model; map it to the CPU,
# where this freshly built model and the loader batches live.
my_model.load_state_dict(torch.load(PATH, map_location="cpu"))
# Eval mode makes BatchNorm use its running statistics instead of batch statistics.
my_model.eval()
num       = 3

## Results on the training set
_show_predictions(my_model, train_dl, num)

print("\n\n\n\n\n\n\n")

## Results on the test set
_show_predictions(my_model, test_dl, num)