In [2]:
import cv2
import numpy as np
import os
import pandas as ps
import torch
from torch.utils.data import Dataset
import matplotlib.pyplot as plt

from sklearn.model_selection import train_test_split
import torch.nn as nn
from torch import optim
from torch.utils.data import ConcatDataset, DataLoader, Subset, Dataset
# This is for the progress bar.
from tqdm.auto import tqdm
from torch.nn import functional as F
from torchvision import transforms
from torchvision.utils import save_image

In [3]:
#文件夹位置加载
train_folder_path = '../data/FAZ/Domain1/train/imgs'
mask_folder_path = '../data/FAZ/Domain1/train/mask'
test_folder_path='../data/FAZ/Domain1/test/imgs'
train_files = os.listdir(train_folder_path)
mask_files= os.listdir(mask_folder_path)
print(train_files[0])

057_N_60.png


In [4]:
# 定义数据集

transform=transforms.Compose([
    transforms.ToTensor()
])
class MyDataset(Dataset):

    def __init__(self,train_files = None,test_files = None):
        super(MyDataset).__init__()
        self.train_files =train_files
        if train_files != None:
            self.train_files = train_files
        self.test_files =test_files
        if test_files != None:
            self.test_files = test_files
        print(f"One sample",self.train_files[0])
    
    def __len__(self):
        return len(self.train_files)
  
    def __getitem__(self,idx):
        fname = self.train_files[idx]
        im = cv2.imread(os.path.join(train_folder_path,fname))
        #im = self.data[idx]
        mname = self.test_files[idx]
        segemnt_im = cv2.imread(os.path.join(mask_folder_path,mname))
        return transform(im),transform(segemnt_im)

In [6]:
#定义网络
""" Parts of the U-Net model """

class Conv_Block(nn.Module):
    def __init__(self,in_channel,out_channel):
        super(Conv_Block, self).__init__()
        self.layer=nn.Sequential(
            nn.Conv2d(in_channel,out_channel,3,1,1,padding_mode='reflect',bias=False),
            nn.BatchNorm2d(out_channel),
            nn.Dropout2d(0.3),
            nn.LeakyReLU(),
            nn.Conv2d(out_channel, out_channel, 3, 1, 1, padding_mode='reflect', bias=False),
            nn.BatchNorm2d(out_channel),
            nn.Dropout2d(0.3),
            nn.LeakyReLU()
        )
    def forward(self,x):
        return self.layer(x)


class DownSample(nn.Module):
    def __init__(self,channel):
        super(DownSample, self).__init__()
        self.layer=nn.Sequential(
            nn.Conv2d(channel,channel,3,2,1,padding_mode='reflect',bias=False),
            nn.BatchNorm2d(channel),
            nn.LeakyReLU()
        )
    def forward(self,x):
        return self.layer(x)


class UpSample(nn.Module):
    def __init__(self,channel):
        super(UpSample, self).__init__()
        self.layer=nn.Conv2d(channel,channel//2,1,1)
    def forward(self,x,feature_map):
        up=F.interpolate(x,scale_factor=2,mode='nearest')
        out=self.layer(up)
        return torch.cat((out,feature_map),dim=1)


class UNet(nn.Module):
    def __init__(self):
        super(UNet, self).__init__()
        self.c1=Conv_Block(3,64)
        self.d1=DownSample(64)
        self.c2=Conv_Block(64,128)
        self.d2=DownSample(128)
        self.c3=Conv_Block(128,256)
        self.d3=DownSample(256)
        self.c4=Conv_Block(256,512)
        self.d4=DownSample(512)
        self.c5=Conv_Block(512,1024)
        self.u1=UpSample(1024)
        self.c6=Conv_Block(1024,512)
        self.u2 = UpSample(512)
        self.c7 = Conv_Block(512, 256)
        self.u3 = UpSample(256)
        self.c8 = Conv_Block(256, 128)
        self.u4 = UpSample(128)
        self.c9 = Conv_Block(128, 64)
        self.out=nn.Conv2d(64,3,3,1,1)
        self.Th=nn.Sigmoid()

    def forward(self,x):
        R1=self.c1(x)
        R2=self.c2(self.d1(R1))
        R3 = self.c3(self.d2(R2))
        R4 = self.c4(self.d3(R3))
        R5 = self.c5(self.d4(R4))
        O1=self.c6(self.u1(R5,R4))
        O2 = self.c7(self.u2(O1, R3))
        O3 = self.c8(self.u3(O2, R2))
        O4 = self.c9(self.u4(O3, R1))

        return self.Th(self.out(O4))

x=torch.randn(2,3,256,256)
net=UNet()
print(net(x).shape) 

torch.Size([2, 3, 256, 256])


In [7]:
dataset = MyDataset(train_files=train_files,test_files=mask_files)
train_set,validation_set = train_test_split(dataset,test_size=0.2,random_state=377,shuffle=True)
print(f'len of the train_set is {len(train_set)};\nlen of the test_set is {len(validation_set)}')

One sample 057_N_60.png
len of the train_set is 195;
len of the test_set is 49


### Tarin

In [8]:
device=torch.device('cuda' if torch.cuda.is_available() else 'cpu')
weight_path='params/unet.pth'
data_path=r'data'
save_path='train_image'

data_loader=DataLoader(dataset,batch_size=8,shuffle=True)
net=UNet().to(device)
if os.path.exists(weight_path):
    net.load_state_dict(torch.load(weight_path))
    print('successful load weight！')
else:
    print('not successful load weight')

opt=optim.Adam(net.parameters())
loss_fun=nn.BCELoss()

epoch=1
for i in range(epoch):
        for i,(image,segment_image) in enumerate(data_loader):
            image, segment_image=image.to(device),segment_image.to(device)

            out_image=net(image)
            train_loss=loss_fun(out_image,segment_image)

            opt.zero_grad()
            train_loss.backward()
            opt.step()

            if i%5==0:
                print(f'{epoch}-{i}-train_loss===>>{train_loss.item()}')

            if i%50==0:
                torch.save(net.state_dict(),weight_path)

            _image=image[0]
            _segment_image=segment_image[0]
            _out_image=out_image[0]

            img=torch.stack([_image,_segment_image,_out_image],dim=0)
            save_image(img,f'{save_path}/{i}.png')

        epoch+=1


successful load weight！
1-0-train_loss===>>0.04805968701839447
1-5-train_loss===>>0.028662240132689476
1-10-train_loss===>>0.03715529665350914
1-15-train_loss===>>0.03193848580121994
1-20-train_loss===>>0.023226846009492874
1-25-train_loss===>>0.032953523099422455
1-30-train_loss===>>0.02383861504495144
