In [14]:
import cv2
import numpy as np
import os
import pandas as ps
import torch
from torch.utils.data import Dataset
import matplotlib.pyplot as plt
from PIL import Image
from sklearn.model_selection import train_test_split
import torch.nn as nn 
from torch import optim
from torch.utils.data import ConcatDataset, DataLoader, Subset, Dataset
# This is for the progress bar.
from tqdm.auto import tqdm
from torch.nn import functional as F
from torchvision import transforms
from torchvision.utils import save_image
from net.unetpp import NestedUNet as unetpp

In [3]:
#文件夹位置加载
train_folder_path = '../data/FAZ/Domain1/train/imgs'
mask_folder_path = '../data/FAZ/Domain1/train/mask'
test_folder_path='../data/FAZ/Domain1/test/imgs'
test_mask_folder_path = '../data/FAZ/Domain1/test/mask'

In [4]:
train_files = os.listdir(train_folder_path)
mask_files= os.listdir(mask_folder_path)
test_files = os.listdir(test_folder_path)
test_mask_files = os.listdir(test_mask_folder_path)

# 清理训练集中的非png文件
del_count = 0
for i in range(len(train_files)):
    if train_files[i-del_count].endswith(".png"):
        train_files[i-del_count] = os.path.join(train_folder_path,train_files[i-del_count])
        pass
    else:
        print(train_files[i-del_count])
        del train_files[i-del_count]
        del_count = del_count+1

# 清理训练集mask中的非png文件       
del_count = 0
for i in range(len(mask_files)):
    if mask_files[i-del_count].endswith(".png"):
        mask_files[i-del_count] = os.path.join(mask_folder_path,mask_files[i-del_count])
        pass
    else:
        print(mask_files[i-del_count])
        del mask_files[i-del_count]
        del_count = del_count+1

# 清理测试集img中的非png文件       
del_count = 0
for i in range(len(test_files)):
    if test_files[i-del_count].endswith(".png"):
        test_files[i-del_count] = os.path.join(test_folder_path,test_files[i-del_count])
    else:
        print(test_files[i-del_count])
        del test_files[i-del_count]
        del_count = del_count+1

# 清理测试集mask中的非png文件       
del_count = 0
for i in range(len(test_mask_files)):
    if test_mask_files[i-del_count].endswith(".png"):
        test_mask_files[i-del_count] = os.path.join(test_mask_folder_path,test_mask_files[i-del_count])
        pass
    else:
        print(test_mask_files[i-del_count])
        del test_mask_files[i-del_count]
        del_count = del_count+1
# test_files = os.listdir(test_folder_path)
print(len(train_files),len(mask_files))
print(len(test_files),len(test_mask_files))

.ipynb_checkpoints
244 244
60 60


In [6]:
# 定义数据集

# 转化为tensor
transform=transforms.Compose([
    transforms.ToTensor()
])
class MyDataset(Dataset):

    def __init__(self,train_files = None,mask_files = None):
        super(MyDataset).__init__()
        self.train_files =train_files
        if train_files != None:
            self.train_files = train_files
        self.mask_files =mask_files
        if mask_files != None:
            self.mask_files = mask_files
        print(f"One sample",self.train_files[0])
    
    def __len__(self):
        return len(self.train_files)

    # cv读取图像信息
    def __getitem__(self,idx):   
        fname = self.train_files[idx]
        im = cv2.imread(fname)
        mname = self.mask_files[idx]
        segemnt_im = cv2.imread(mname)
        return transform(im/255).float(),transform(segemnt_im/255).float()

# 数据集
dataset = MyDataset(train_files=train_files,mask_files=mask_files)

One sample ../data/FAZ/Domain1/train/imgs/057_N_60.png


## loss 函数

In [7]:
device=torch.device('cuda' if torch.cuda.is_available() else 'cpu')
weight_path='params/unet.pth'
data_path=r'data'
save_path='train_image'

def dice_loss(target,predictive,ep=1e-8):
    intersection = 2 * torch.sum(predictive * target) + ep
    union = torch.sum(predictive) + torch.sum(target) + ep
    loss = 1 - intersection / union
    return loss
    
def focal_loss(y_pred, y_real, eps = 1e-8, gamma = 2):
    #y_pred =  # hint: torch.clamp
    L = (y_pred.clamp(min=0) - y_pred*y_real + torch.log(1 + torch.exp(-torch.abs(y_pred)))).mean()
    focal_loss = 1*(1-torch.exp(-L))**gamma * L
    return focal_loss


### Tarin

In [None]:
data_loader=DataLoader(dataset,batch_size=8,shuffle=True)
net=UNet().to(device)
if os.path.exists(weight_path):
    net.load_state_dict(torch.load(weight_path))
    print('successful load weight！')
else:
    print('not successful load weight')

opt=optim.Adam(net.parameters())
loss_fun=nn.BCELoss()
epoch=50
lossArr =[]

# start to train
for j in range(epoch):
    lossItem = 0
    for i,(image,segment_image) in enumerate(data_loader):
            image, segment_image=image.to(device),segment_image.to(device)

            out_image=net(image)
            train_loss=loss_fun(out_image,segment_image)

            opt.zero_grad()
            train_loss.backward()
            opt.step()
            lossItem = lossItem + train_loss.item()
            
            if i%50==0:
                torch.save(net.state_dict(),weight_path)
            # 组合图片
            _image=image[0]
            _segment_image=segment_image[0]
            _out_image=out_image[0]
            img=torch.stack([_image,_segment_image,_out_image],dim=0)
            save_image(img,f'{save_path}/{i}.png')
    print(f'epoch{j}-train_loss===>>{lossItem}')
    lossArr.append(lossItem)
x =range(epoch)
plt.figure()
plt.plot(x, lossArr)
plt.show()

#

In [8]:
test_dataset = MyDataset(train_files=test_files,mask_files=test_mask_files)
test_loader=DataLoader(test_dataset,batch_size=8,shuffle=True)

One sample ../data/FAZ/Domain1/test/imgs/024_N_30.png


In [12]:
net=UNet().to(device)
if os.path.exists(weight_path):
    net.load_state_dict(torch.load(weight_path))
    print('successful load weight！')
else:
    print('not successful load weight')
net.eval()

# img=cv2.imread(os.path.join(test_folder_path,file))
# mask_img = cv2.imread(os.path.join(test_mask_folder_path,file))
# img_data=transform(img).float()
# img_data=torch.unsqueeze(img_data,dim=0).to(device)
# out = net(img_data)[0]
# print(f'shape of output:{out.shape}')
# loss_item=dice_loss(out,transform(mask_img).to(device))
# print(f'loss:{loss_item}')
# out = out.permute((1,2,0)).cpu().detach().numpy()



# fig, axes = plt.subplots(nrows=1, ncols=3, figsize=(12, 8))
# axes[0].imshow(out)
# axes[0].axis('off')
# axes[0].set_title(f'Image predict')
# axes[1].imshow(mask_img)
# axes[1].axis('off')
# axes[1].set_title(f'Image mask')
# axes[2].imshow(img)
# axes[2].axis('off')
# axes[2].set_title(f'Image original')

# plt.tight_layout()
# plt.show()
loss_fun=nn.BCELoss()
lossItem = 0
for i,(image,segment_image) in enumerate(test_loader):
    image, segment_image=image.to(device),segment_image.to(device)
    out_image=net(image)
    train_loss=loss_fun(out_image,segment_image)
    lossItem = lossItem + train_loss.item()
print(f'loss:{lossItem}')        

# for item in test_file:
#     img=cv2.imread(os.path.join(test_folder_path,item))
#     img_data=transform(img/255).float()
#     img_data=torch.unsqueeze(img_data,dim=0).to(device)
#     mask_img = cv2.imread(os.path.join(test_mask_folder_path,item))
#     out_image = net(img_data)[0]
#     loss_item_dice.append(dice_loss(out_image,transform(mask_img).to(device)))
#     # loss_item_focal.append(focal_loss(out_image,transform(mask_img).to(device)))

# plt.plot(range(len(loss_item_dice)+1), loss_item_dice)
# plt.show()
    # out = torch.from_numpy(out)
    # save_image(out_image,f'{mask_save_path}/{item}')

successful load weight！
loss:0.1792113371193409


In [None]:
# 什么网络捏

In [9]:
#定义nnU-Net网络
""" Parts of the U-Net model """

class Conv_Block(nn.Module):
    def __init__(self,in_channel,out_channel):
        super(Conv_Block, self).__init__()
        self.layer=nn.Sequential(
            nn.Conv2d(in_channel,out_channel,3,1,1,padding_mode='reflect',bias=False),
            nn.BatchNorm2d(out_channel),
            nn.Dropout2d(0.3),
            nn.LeakyReLU(),
            nn.Conv2d(out_channel, out_channel, 3, 1, 1, padding_mode='reflect', bias=False),
            nn.BatchNorm2d(out_channel),
            nn.Dropout2d(0.3),
            nn.LeakyReLU()
        )
    def forward(self,x):
        return self.layer(x)

# 下采样
class DownSample(nn.Module):
    def __init__(self,channel):
        super(DownSample, self).__init__()
        self.layer=nn.Sequential(
            nn.Conv2d(channel,channel,3,2,1,padding_mode='reflect',bias=False),
            nn.BatchNorm2d(channel),
            nn.LeakyReLU()
        )
    def forward(self,x):
        return self.layer(x)

# 上采样
class UpSample(nn.Module):
    def __init__(self,channel):
        super(UpSample, self).__init__()
        self.layer=nn.Conv2d(channel,channel//2,1,1)
    def forward(self,x,feature_map):
        up=F.interpolate(x,scale_factor=2,mode='nearest')
        out=self.layer(up)
        return torch.cat((out,feature_map),dim=1)


class UNet(nn.Module):
    def __init__(self):
        super(UNet, self).__init__()
        self.c1=Conv_Block(3,64)
        self.d1=DownSample(64)
        self.c2=Conv_Block(64,128)
        self.d2=DownSample(128)
        self.c3=Conv_Block(128,256)
        self.d3=DownSample(256)
        self.c4=Conv_Block(256,512)
        self.d4=DownSample(512)
        self.c5=Conv_Block(512,1024)
        self.u1=UpSample(1024)
        self.c6=Conv_Block(1024,512)
        self.u2 = UpSample(512)
        self.c7 = Conv_Block(512, 256)
        self.u3 = UpSample(256)
        self.c8 = Conv_Block(256, 128)
        self.u4 = UpSample(128)
        self.c9 = Conv_Block(128, 64)
        self.out=nn.Conv2d(64,3,3,1,1)
        self.Th=nn.Sigmoid()

    def forward(self,x):
        R1=self.c1(x)
        R2=self.c2(self.d1(R1))
        R3 = self.c3(self.d2(R2))
        R4 = self.c4(self.d3(R3))
        R5 = self.c5(self.d4(R4))
        O1=self.c6(self.u1(R5,R4))
        O2 = self.c7(self.u2(O1, R3))
        O3 = self.c8(self.u3(O2, R2))
        O4 = self.c9(self.u4(O3, R1))

        return self.Th(self.out(O4))

x=torch.randn(2,3,256,256)
net=UNet()
print(net(x).shape) 

torch.Size([2, 3, 256, 256])
