In [1]:
#基本的引入
import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
#from tensorboardX import SummaryWriter
import matplotlib.pyplot as plt
import numpy as np
import os
import cv2
import random
from torch.utils.data import DataLoader, Dataset
from PIL import Image
from torch.utils import data
from torchvision import transforms as T
import matplotlib.pyplot as plt # plt 用于显示图片

In [None]:
def setup_seed(seed):
    """Seed every random number generator this notebook relies on.

    Args:
        seed (int): value passed to Python's ``random``, NumPy and PyTorch
            (CPU and all CUDA devices).

    Also switches cuDNN to its deterministic mode.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True


# Fix the random seed for the whole notebook.
setup_seed(20)

In [None]:
# Adapted version: adds "val_mIoU" as a value that can be monitored.
class EarlyStopping:
    """Stop training when a monitored validation metric stops improving.

    Call the instance once per epoch with the latest metric value and the
    model. Whenever the metric improves, the model's ``state_dict`` is written
    to ``checkpoint.pt``; after ``patience`` consecutive epochs without
    improvement ``early_stop`` is set to ``True``.
    """

    # Metrics this class knows how to track.
    _MONITORS = ("val_loss", "val_acc", "val_mIoU")

    def __init__(self, monitor="val_acc", patience=7, verbose=False, delta=0):
        """
        Args:
            monitor (string): one of "val_acc", "val_loss" or "val_mIoU".
                            "val_loss" is minimised, the other two are maximised.
                            Default: "val_acc"
            patience (int): How many non-improving calls to tolerate before stopping.
                            Default: 7
            verbose (bool): If True, prints a message for each improvement.
                            Default: False
            delta (float): Minimum change in the monitored quantity to qualify as an improvement.
                            Default: 0

        Raises:
            ValueError: if ``monitor`` is not one of the supported names.
        """
        if monitor not in self._MONITORS:
            # Fail fast: an unknown monitor would otherwise never save a
            # checkpoint and never trigger early stopping.
            raise ValueError(
                f"monitor must be one of {self._MONITORS}, got {monitor!r}")
        self.monitor = monitor
        self.patience = patience
        self.verbose = verbose
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        # np.inf rather than np.Inf: the capitalised alias was removed in NumPy 2.0.
        self.val_loss_min = np.inf
        self.val_acc_max = 0
        self.val_mIoU_max = 0
        self.delta = delta

    def __call__(self, val, model):
        """Record one validation result and update the early-stopping state.

        Args:
            val: latest value of the monitored metric.
            model: model whose ``state_dict`` is saved when ``val`` improves.
        """
        # Internally "higher is better", so a loss is negated.
        score = -val if self.monitor == 'val_loss' else val

        if self.best_score is None:
            # First call: whatever we get is the best so far.
            self.best_score = score
            self.save_checkpoint(val, model)
        elif score < self.best_score + self.delta:
            self.counter += 1
            print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_score = score
            self.save_checkpoint(val, model)
            self.counter = 0

    def save_checkpoint(self, val, model):
        """Remember ``val`` as the best value seen and save the model weights.

        The weights are written to ``checkpoint.pt`` in the working directory,
        overwriting any previous checkpoint.
        """
        if self.monitor == 'val_loss':
            if self.verbose:
                print(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val:.6f}).  Saving model ...')
            self.val_loss_min = val
        elif self.monitor == 'val_acc':
            if self.verbose:
                print(f'Validation accuracy increased ({self.val_acc_max:.6f}% --> {val:.6f}%).  Saving model ...')
            self.val_acc_max = val
        elif self.monitor == 'val_mIoU':
            if self.verbose:
                print(f'Validation mIoU increased ({self.val_mIoU_max:.6f} --> {val:.6f}).  Saving model ...')
            self.val_mIoU_max = val
        # Stores the parameters of the best model seen so far.
        torch.save(model.state_dict(), 'checkpoint.pt')

In [None]:
# Image and mask have to go through the same random flips, so they are
# transformed together by the functions below instead of by a single-input
# torchvision pipeline.
def my_transform1(image, mask):
    """Training-set transform: random flips (data augmentation) plus preprocessing.

    Args:
        image: input picture (a PIL image when called from ``MyDataset``).
        mask: the matching annotation image.

    Returns:
        (image, mask): ``image`` resized to 256x256, converted to a tensor and
        normalised with the mean/std below; ``mask`` resized to 256x256,
        binarised (every value >= 1 becomes 1) and returned as a float tensor.
    """
    # Flips use T.functional: the notebook never imports a module under the
    # alias `tf`, so the former `tf.hflip` / `tf.vflip` calls raised NameError.
    if random.random() > 0.5:
        image = T.functional.hflip(image)
        mask = T.functional.hflip(mask)
    if random.random() > 0.5:
        image = T.functional.vflip(image)
        mask = T.functional.vflip(mask)
    # Image: resize, convert to tensor, normalise.
    transform_image = T.Compose([
        T.Resize([256,256]),
        T.ToTensor(),
        T.Normalize(mean= [0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])
    image = transform_image(image)
    # Mask: resize, convert to an array, map every value >= 1 to 1, then
    # convert to a float tensor.
    mask=T.functional.resize(mask,(256,256))
    mask=np.array(mask)
    mask=(mask>=1).astype(int)
    mask = T.functional.to_tensor(mask).float()
    return image, mask
    
def my_transform2(image, mask):
    """Validation/test transform: deterministic preprocessing, no augmentation.

    Args:
        image: input picture (a PIL image when called from ``MyDataset``).
        mask: the matching annotation image.

    Returns:
        (image, mask): ``image`` resized to 256x256, converted to a tensor and
        normalised; ``mask`` resized to 256x256, binarised (values >= 1 become
        1) and returned as a float tensor.
    """
    fn = T.functional

    # Image: resize -> tensor -> normalise.
    image = fn.resize(image, [256, 256])
    image = fn.to_tensor(image)
    image = fn.normalize(image,
                         mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])

    # Mask: resize -> array -> binarise -> float tensor.
    resized_mask = fn.resize(mask, (256, 256))
    binary_mask = (np.array(resized_mask) >= 1).astype(int)
    return image, fn.to_tensor(binary_mask).float()

In [None]:
class MyDataset(data.Dataset):
    """Image/mask pairs listed in the dataset's ``train_id.txt`` / ``val_id.txt``.

    Only a slice of each id list is used: the first 2000 ids for "train", the
    first 600 ids of ``val_id.txt`` for "valid", and the ids from position
    5000 onwards of ``val_id.txt`` for "test".
    """

    def __init__(self, file_path=None, mask_path=None, transform=None,data="train"):
        """
        Set up the dataset.

        Args:
            file_path: str, directory that holds the ``.jpg`` images.
            mask_path: str, directory that holds the ``.png`` masks.
            transform: callable ``(image, mask) -> (image, mask)`` applied in
                ``__getitem__``; ``None`` disables preprocessing.
            data: which split to build: "train", "valid" or "test".

        Raises:
            ValueError: if ``data`` is not one of the three split names.
        """
        id_files = {"train": "train_id.txt",
                    "valid": "val_id.txt",
                    "test": "val_id.txt"}
        if data not in id_files:
            raise ValueError(
                f'data must be "train", "valid" or "test", got {data!r}')

        self.count=0
        self.transform = transform
        # Read the id list once and close the file straight away; the list of
        # ids is kept on `self.name`.
        id_path = os.path.join("../input/singleperson/TrainVal_images", id_files[data])
        with open(id_path) as id_file:
            self.name = [line.strip() for line in id_file]
        # Collect every usable image/mask pair of the requested split.
        self.init_all_data(file_path,mask_path,data)

    def init_all_data(self,file_path,mask_path,data):
        """
        Fill ``self.images`` and ``self.labels`` with the paths of the split.

        Args:
            file_path: str, directory that holds the images.
            mask_path: str, directory that holds the masks.
            data: split name; selects which positions of the id list are used.
        """
        # Full path of every image and of its matching mask.
        self.images = []
        self.labels = []
        for name in self.name:
            # `count` is the 1-based position in the id list.
            self.count+=1
            if data=="train":
                selected = self.count<=2000
            elif data=="valid":
                selected = self.count<=600
            elif data=="test":
                selected = self.count>=5000
            else:
                selected = False
            # strip() instead of name[0:-1]: a final line without a trailing
            # newline must not lose its last character.
            sample_id = name.strip()
            if not selected or not sample_id:
                continue
            img = os.path.join(file_path,sample_id+".jpg")
            mask = os.path.join(mask_path,sample_id+".png")
            # Keep the pair only when both files can be opened.
            if self.is_valid_image(img) and self.is_valid_image(mask):
                self.images.append(img)
                self.labels.append(mask)
        return None

    def is_valid_image(self, img_path):
        """
        Check whether an image file can be opened.

        Args:
            img_path: str, path of the image to check.
        Returns:
            valid: bool, True if PIL could open the file, False otherwise.
        """
        try:
            # The context manager releases the file handle right after the check.
            with Image.open(img_path):
                valid = True
        except Exception:
            # Missing or unreadable file: skip it. KeyboardInterrupt and
            # SystemExit are deliberately not swallowed.
            valid = False

        return valid

    def __getitem__(self, idx):
        """
        Return the image and mask stored at position ``idx``.

        Args:
            idx: int, index of the sample.
        Returns:
            image: the picture, forced to RGB.
            label: the matching mask.
            Both are PIL images when no transform is set; otherwise they are
            whatever ``self.transform`` returns.
        """
        # Force RGB so that grayscale files in the dataset do not yield
        # single-channel inputs of a different shape.
        image = Image.open(self.images[idx]).convert('RGB')
        # Matching mask; binarisation happens inside the transform.
        label = Image.open(self.labels[idx])

        # Apply the joint preprocessing.
        if self.transform:
            image,label = self.transform(image,label)
        return image, label

    def __len__(self):
        """
        Number of image/mask pairs in the dataset (used by DataLoader to
        determine the length of an epoch).
        """
        return len(self.images)

In [None]:
# Build the three splits. "valid" and "test" read the same folders; MyDataset
# picks a different slice of val_id.txt for each of them.
train_data = MyDataset(
    file_path="../input/singleperson/TrainVal_images/TrainVal_images/train_images",
    mask_path="../input/singleperson/TrainVal_parsing_annotations/TrainVal_parsing_annotations/TrainVal_parsing_annotations/train_segmentations",
    transform=my_transform1,
    data="train",
)
valid_data = MyDataset(
    file_path="../input/singleperson/TrainVal_images/TrainVal_images/val_images",
    mask_path="../input/singleperson/TrainVal_parsing_annotations/TrainVal_parsing_annotations/TrainVal_parsing_annotations/val_segmentations",
    transform=my_transform2,
    data="valid",
)
test_data = MyDataset(
    file_path="../input/singleperson/TrainVal_images/TrainVal_images/val_images",
    mask_path="../input/singleperson/TrainVal_parsing_annotations/TrainVal_parsing_annotations/TrainVal_parsing_annotations/val_segmentations",
    transform=my_transform2,
    data="test",
)

In [None]:
# Number of worker processes used by each DataLoader.
Num_workers=2
# Only the training data is shuffled. Validation and test batches keep a fixed
# order so that the per-batch IoU averages are repeatable between runs.
train_loader=data.DataLoader(dataset=train_data,batch_size=32,
                             shuffle=True, num_workers=Num_workers)
valid_loader=data.DataLoader(dataset=valid_data,batch_size=32,
                             shuffle=False, num_workers=Num_workers)
# `data` (torch.utils.data) is the imported module; the former `Data.DataLoader`
# raised NameError.
test_loader=data.DataLoader(dataset=test_data,batch_size=32,
                            shuffle=False, num_workers=Num_workers)

计算IoU的function

In [None]:
def IoU(inputs, targets, smooth=1):
    """Intersection-over-union between thresholded predictions and targets.

    Args:
        inputs: predicted scores (tensor); values above 0.5 count as foreground.
        targets: ground-truth tensor, flattened as-is.
        smooth: constant added to numerator and denominator.

    Returns:
        The smoothed IoU over all elements of the batch, as a NumPy value, so
        that the caller does not hold on to tensors.
    """
    # Move both operands off the GPU and out of the autograd graph first.
    scores = inputs.cpu().detach().numpy()
    flat_targets = targets.cpu().detach().view(-1)

    # Threshold the scores into a flat 0/1 prediction vector.
    flat_preds = torch.tensor((scores > 0.5).astype(np.int32)).view(-1)

    # overlap = true positives; union = all pixels set in either operand.
    overlap = (flat_preds * flat_targets).sum()
    union = (flat_preds + flat_targets).sum() - overlap

    score = (overlap + smooth) / (union + smooth)
    return score.numpy()

train model

In [None]:
# Tracks "mIoU" per epoch. The value is really the plain IoU from IoU();
# switching to a true mIoU only needs the two IoU(...) calls below changed.
def train_model(model,device, patience, n_epochs):
    """Train ``model`` with early stopping on the validation IoU.

    Relies on notebook-level globals: ``train_loader``, ``valid_loader``,
    ``optimizer``, ``loss_func``, ``IoU`` and ``EarlyStopping``.

    Args:
        model: network to train (already placed on ``device``).
        device: device the batches are moved to.
        patience: epochs without validation-IoU improvement before stopping.
        n_epochs: maximum number of epochs.

    Returns:
        (model, avg_train_losses, avg_valid_losses, avg_train_mIoU,
        avg_valid_mIoU): the model with the best checkpoint loaded (when at
        least one epoch ran) and the per-epoch averages of loss and IoU.
    """
    # per-epoch averages, one entry per completed epoch
    avg_train_losses = []
    avg_valid_losses = []
    avg_train_mIoU = []
    avg_valid_mIoU = []

    # initialize the early_stopping object
    early_stopping = EarlyStopping("val_mIoU",patience=patience, verbose=True,delta=0)

    epoch_len = len(str(n_epochs))

    for epoch in range(1, n_epochs + 1):
        # per-batch values of the current epoch
        train_losses = []
        valid_losses = []
        train_ious = []
        valid_ious = []

        ###################
        # train the model #
        ###################
        model.train() # prep model for training

        for X, y in train_loader:
            X, y = X.to(device), y.to(device)
            # clear the gradients of all optimized variables
            optimizer.zero_grad()
            # forward pass: compute predicted outputs by passing inputs to the model
            output = model(X)
            # calculate the loss
            loss = loss_func(output, y)
            # backward pass: compute gradient of the loss with respect to model parameters
            loss.backward()
            # perform a single optimization step (parameter update)
            optimizer.step()
            # record training loss and training IoU
            train_losses.append(loss.item())
            train_ious.append(IoU(output, y))

        ######################
        # validate the model #
        ######################
        model.eval() # prep model for evaluation

        # No gradients are needed for validation; skipping the autograd graph
        # saves memory and time.
        with torch.no_grad():
            for X, y in valid_loader:
                X, y = X.to(device), y.to(device)
                # forward pass: compute predicted outputs by passing inputs to the model
                output = model(X)
                # calculate the loss
                loss = loss_func(output, y)
                # record validation loss and validation IoU
                valid_losses.append(loss.item())
                valid_ious.append(IoU(output, y))

        # calculate averages over the epoch
        train_loss = np.average(train_losses)
        valid_loss = np.average(valid_losses)
        train_mIoU = np.average(train_ious)
        valid_mIoU = np.average(valid_ious)
        avg_train_losses.append(train_loss)
        avg_valid_losses.append(valid_loss)
        avg_train_mIoU.append(train_mIoU)
        avg_valid_mIoU.append(valid_mIoU)

        # print training/validation statistics
        print_msg = (f'[{epoch:>{epoch_len}}/{n_epochs:>{epoch_len}}] ' +
                     f'train_loss: {train_loss:.5f} ' +
                     f'train_mIoU: {train_mIoU:.5f} ' +
                     f'\n    valid_loss: {valid_loss:.5f} ' +
                     f'valid_mIoU: {valid_mIoU:.5f} ' )

        print(print_msg)
        # early_stopping checks whether the validation IoU has increased and,
        # if it has, makes a checkpoint of the current model
        early_stopping(valid_mIoU, model)

        if early_stopping.early_stop:
            print("Early stopping")
            break

    # Load the best checkpoint, but only if this run actually wrote one
    # (n_epochs == 0 never calls early_stopping).
    if early_stopping.best_score is not None:
        model.load_state_dict(torch.load('checkpoint.pt', map_location=device))

    return  model, avg_train_losses, avg_valid_losses,avg_train_mIoU,avg_valid_mIoU

In [None]:
# Pick the device in this cell: the model is moved to it below, and the cell
# that used to define `device` comes later in the notebook, so a fresh
# "Restart & Run All" failed here with NameError.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# U-Net from the PyTorch Hub brain-segmentation repository. Pretrained weights
# exist for it, but pretrained=False is passed here, so the pretrained
# checkpoint is not requested.
model = torch.hub.load('mateuszbuda/brain-segmentation-pytorch', 'unet',
    in_channels=3, out_channels=1, init_features=32, pretrained=False).to(device)
# Reference: https://pytorch.org/hub/mateuszbuda_brain-segmentation-pytorch_unet/

关于loss function，我试了以下几个。

In [None]:
# PyTorch implementation
class DiceLoss(nn.Module):
    """Soft Dice loss: ``1 - (2*|X∩Y| + smooth) / (|X| + |Y| + smooth)``.

    ``weight`` and ``size_average`` are accepted but not used.
    """

    def __init__(self, weight=None, size_average=True):
        super().__init__()

    def forward(self, inputs, targets, smooth=1):
        """Return the Dice loss over all elements of ``inputs`` and ``targets``.

        ``inputs`` are used as given; no sigmoid is applied in this loss.
        """
        # Work on flat vectors so every element is treated alike.
        flat_pred = inputs.view(-1)
        flat_true = targets.view(-1)

        overlap = (flat_pred * flat_true).sum()
        numerator = 2.*overlap + smooth
        denominator = flat_pred.sum() + flat_true.sum() + smooth

        return 1 - numerator/denominator

In [None]:
# PyTorch implementation
class IoULoss(nn.Module):
    """Soft IoU (Jaccard) loss: ``1 - (|X∩Y| + smooth) / (|X∪Y| + smooth)``.

    ``weight`` and ``size_average`` are accepted but not used.
    """

    def __init__(self, weight=None, size_average=True):
        super().__init__()

    def forward(self, inputs, targets, smooth=1):
        """Return the IoU loss over all elements of ``inputs`` and ``targets``.

        ``inputs`` are used as given; no sigmoid is applied in this loss.
        """
        # Work on flat vectors so every element is treated alike.
        flat_pred = inputs.view(-1)
        flat_true = targets.view(-1)

        # overlap plays the role of the true-positive count; the union is
        # everything covered by either the prediction or the label.
        overlap = (flat_pred * flat_true).sum()
        union = (flat_pred + flat_true).sum() - overlap

        return 1 - (overlap + smooth)/(union + smooth)

In [None]:
# PyTorch implementation
class DiceBCELoss(nn.Module):
    """Sum of the soft Dice loss and the mean binary cross-entropy.

    ``weight`` and ``size_average`` are accepted but not used.
    """

    def __init__(self, weight=None, size_average=True):
        super().__init__()

    def forward(self, inputs, targets, smooth=1):
        """Return ``BCE + Dice loss`` over all elements.

        ``inputs`` are used as given; no sigmoid is applied in this loss.
        """
        # Work on flat vectors so every element is treated alike.
        flat_pred = inputs.view(-1)
        flat_true = targets.view(-1)

        overlap = (flat_pred * flat_true).sum()
        dice_score = (2.*overlap + smooth)/(flat_pred.sum() + flat_true.sum() + smooth)
        dice_term = 1 - dice_score
        bce_term = F.binary_cross_entropy(flat_pred, flat_true, reduction='mean')

        return bce_term + dice_term

开始训练

In [None]:
# Training configuration: learning rate, device, optimiser and loss.
lr = 0.001
# Use the first GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
optimizer = torch.optim.Adam(params=model.parameters(), lr=lr, weight_decay=1e-5)
# Losses tried so far: torch.nn.MSELoss(), torch.nn.BCELoss(), DiceLoss()
# and IoULoss(); the last one is the active choice.
loss_func = IoULoss()

In [None]:
# Upper bound on epochs, and the number of non-improving epochs tolerated
# before early stopping.
n_epochs = 100
patience = 7
# (An Adam optimiser with weight_decay=1e-4 was tried here as an alternative.)
model, train_loss, valid_loss, train_mIoU, valid_mIoU = train_model(
    model=model,
    device=device,
    patience=patience,
    n_epochs=n_epochs,
)

可视化单个图片的训练结果。

In [None]:
# Show one validation sample side by side: input image, ground truth, raw
# model output and the output thresholded at 0.5.
plt.figure(dpi = 600)  # high DPI for a sharper picture
# sample to inspect
image,label=valid_data[190]
# input image (it has been normalised by the transform, so colours look off)
plt.subplot(1,4,1)
plt.title("image")
plt.imshow(np.transpose(image.numpy(),(1,2,0)))
# ground truth
plt.subplot(1,4,2)
plt.title("ground truth")
plt.imshow(label.numpy()[0],cmap="gray")
# raw model output
img = torch.unsqueeze(image,dim=0)
# Follow the notebook's `device` instead of a hardcoded .cuda(), so this cell
# also runs on CPU-only machines.
b_x=img.to(device)
model.eval()
# Inference only: no autograd graph needed.
with torch.no_grad():
    out=model(b_x).to(torch.float64)
out=out.cpu().numpy()[0][0]
plt.subplot(1,4,3)
plt.title("output mask")
plt.imshow(out,cmap="gray")
# thresholded mask
plt.subplot(1,4,4)
plt.title("threshold mask")
out=out>0.5
plt.imshow(out,cmap="gray")
# save the figure
plt.savefig('valid_190_BCE.png')
plt.show()

可视化loss和mIoU

In [None]:
# Plot the per-epoch training and validation loss.
fig, ax = plt.subplots(figsize=(10,8))
ax.plot(range(1,len(train_loss)+1), train_loss, label='Training Loss')
ax.plot(range(1,len(valid_loss)+1), valid_loss, label='Validation Loss')

# Mark the epoch with the lowest validation loss.
minposs = valid_loss.index(min(valid_loss))+1
ax.axvline(minposs, linestyle='--', color='r', label='Early Stopping Checkpoint')

ax.set_xlabel('epochs')
ax.set_ylabel('loss')
ax.set_xlim(0, len(train_loss)+1)  # consistent scale
ax.grid(True)
ax.legend()
fig.tight_layout()
plt.show()
fig.savefig('Dice_loss.png', bbox_inches='tight')

In [None]:
# visualize the mIoU as the network trained
fig = plt.figure(figsize=(10,8))
# The curves are IoU values, so legend and axis are labelled accordingly
# (they used to say "Loss").
plt.plot(range(1,len(train_mIoU)+1),train_mIoU, label='Training mIoU')
plt.plot(range(1,len(valid_mIoU)+1),valid_mIoU,label='Validation mIoU')

# find position of the highest validation mIoU (the epoch that was checkpointed)
maxposs = valid_mIoU.index(max(valid_mIoU))+1
plt.axvline(maxposs, linestyle='--', color='r',label='Early Stopping Checkpoint')

plt.xlabel('epochs')
plt.ylabel('mIoU')
# x-range follows the series plotted in this figure
plt.xlim(0, len(train_mIoU)+1) # consistent scale
plt.grid(True)
plt.legend()
plt.tight_layout()
plt.show()
fig.savefig('Dice_mIoU.png', bbox_inches='tight')

test部分

In [None]:
def test(model, device, test_loader,loss_func):
    """Evaluate ``model`` on ``test_loader`` and print average loss and IoU.

    Args:
        model: network to evaluate.
        device: device the batches are moved to.
        test_loader: iterable of ``(X, y)`` batches.
        loss_func: callable ``(output, y) -> loss``.

    Both averages are taken over batches. Nothing is returned; the result is
    printed. Uses the notebook-level ``IoU`` function.
    """
    model.eval()
    test_loss = []
    test_mIoU=[]
    # Evaluation needs no gradients; no_grad avoids building autograd graphs.
    with torch.no_grad():
        for (X, y) in test_loader:
            X, y = X.to(device), y.to(device)
            # forward pass: compute predicted outputs by passing inputs to the model
            output = model(X)
            # calculate the loss
            loss = loss_func(output, y)
            # record test loss and IoU of this batch
            test_loss.append(loss.item())
            test_mIoU.append(IoU(output,y))

    print('\nTest set: Average loss: {:.4f}, Average mIoU: {:.4f}'.format(
        np.average(test_loss),np.average(test_mIoU)))

In [None]:
# Report the average loss and IoU on the held-out test split.
test(model=model, device=device, test_loader=test_loader, loss_func=loss_func)

In [None]:
# Show one test sample side by side: input image, ground truth, raw model
# output and the output thresholded at 0.5.
plt.figure(dpi = 600)  # high DPI for a sharper picture
# sample to inspect
image,label=test_data[190]
# input image (it has been normalised by the transform, so colours look off)
plt.subplot(1,4,1)
plt.title("image")
plt.imshow(np.transpose(image.numpy(),(1,2,0)))
# ground truth
plt.subplot(1,4,2)
plt.title("ground truth")
plt.imshow(label.numpy()[0],cmap="gray")
# raw model output
img = torch.unsqueeze(image,dim=0)
# Follow the notebook's `device` instead of a hardcoded .cuda(), so this cell
# also runs on CPU-only machines.
b_x=img.to(device)
model.eval()
# Inference only: no autograd graph needed.
with torch.no_grad():
    out=model(b_x).to(torch.float64)
out=out.cpu().numpy()[0][0]
plt.subplot(1,4,3)
plt.title("output mask")
plt.imshow(out,cmap="gray")
# thresholded mask
plt.subplot(1,4,4)
plt.title("threshold mask")
out=out>0.5
plt.imshow(out,cmap="gray")
# save the figure
plt.savefig('test_190_BCE.png')
plt.show()

可视化多个图片的训练结果

In [None]:
def imageshow(num_figure,model,dataloader):
    """Plot several samples with their ground truth and predicted masks.

    Draws ``num_figure`` rows of four panels (input image, ground truth, raw
    model output, output thresholded at 0.5) for the samples at indices
    0, 10, 20, ..., saves the figure as ``imshow.png`` and shows it.

    Args:
        num_figure: number of samples (rows) to draw.
        model: network used for the predictions.
        dataloader: indexable dataset returning ``(image, label)`` tensors
            (despite the name, it is indexed directly, not iterated in batches).
    """
    # squeeze=False keeps `axes` two-dimensional even when num_figure == 1.
    fig, axes = plt.subplots(num_figure, 4,dpi = 600, figsize=(7, 6), squeeze=False)
    # Run on whatever device the model lives on instead of assuming CUDA.
    model_device = next(model.parameters()).device
    # Predict in eval mode without gradients, then restore the previous mode.
    was_training = model.training
    model.eval()
    with torch.no_grad():
        for row in range(num_figure):
            # every 10th sample
            image,label=dataloader[row*10]
            # input image
            axes[row][0].imshow(np.transpose(image.numpy(),(1,2,0)))
            # ground truth
            axes[row][1].imshow(label.numpy()[0],cmap='gray')
            # raw model output
            b_x=torch.unsqueeze(image,dim=0).to(model_device)
            out=model(b_x).to(torch.float64)
            out=out.cpu().numpy()[0][0]
            axes[row][2].imshow(out,cmap="gray")
            # thresholded mask
            out=out>0.5
            axes[row][3].imshow(out,cmap="gray")
    model.train(was_training)
    for ax in axes.ravel():
        ax.axis('off')  # hide the axes
    fig.tight_layout()  # compact layout
    # column titles
    axes[0][0].set_title("Original image")
    axes[0][1].set_title("Ground truth")
    axes[0][2].set_title("Output mask")
    axes[0][3].set_title("Threshold mask")
    # save the figure
    fig.savefig('imshow.png')
    # plt.show() instead of fig.show(): the latter only warns on non-GUI
    # (inline notebook) backends.
    plt.show()

In [None]:
# Draw six validation samples (indices 0, 10, ..., 50).
imageshow(num_figure=6, model=model, dataloader=valid_data)