# Parts of the U-Net model

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F


class DoubleConv(nn.Module):
    """Two consecutive 3x3 Conv2d -> BatchNorm2d -> ReLU stages.

    Padding of 1 keeps the spatial size unchanged; only the channel count
    changes (in_channels -> out_channels in the first stage).
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        layers = []
        # First stage maps in_channels -> out_channels, second keeps out_channels.
        for stage_in in (in_channels, out_channels):
            layers += [
                nn.Conv2d(stage_in, out_channels, kernel_size=3, padding=1),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
            ]
        self.double_conv = nn.Sequential(*layers)

    def forward(self, x):
        features = self.double_conv(x)
        return features


class Down(nn.Module):
    """Encoder stage: 2x2 max-pooling (halves H and W) followed by DoubleConv."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        layers = [nn.MaxPool2d(2), DoubleConv(in_channels, out_channels)]
        self.maxpool_conv = nn.Sequential(*layers)

    def forward(self, x):
        pooled_and_convolved = self.maxpool_conv(x)
        return pooled_and_convolved


class Up(nn.Module):
    """Upscaling then double conv.

    Upsamples the deeper feature map ``x1`` by 2x, pads it to the spatial size
    of the skip connection ``x2``, concatenates ``[x2, x1]`` along the channel
    axis and applies ``DoubleConv(in_channels, out_channels)``.
    """

    def __init__(self, in_channels, out_channels, bilinear=True):
        super().__init__()

        # if bilinear, use the normal convolutions to reduce the number of channels
        if bilinear:
            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        else:
            # Learned 2x upsampling. Assumes x1 carries in_channels // 2 channels,
            # so that x1 concatenated with x2 totals in_channels for DoubleConv.
            self.up = nn.ConvTranspose2d(in_channels // 2, in_channels // 2, kernel_size=2, stride=2)

        self.conv = DoubleConv(in_channels, out_channels)

    def forward(self, x1, x2):
        x1 = self.up(x1)
        # Input is NCHW. x1 can be a pixel smaller than x2 when an encoder map
        # had an odd H or W before max-pooling, so pad x1 up to x2's size.
        # Plain ints are used: F.pad expects a sequence of ints, and building
        # tensors here would allocate on every forward pass.
        diff_y = x2.size(2) - x1.size(2)
        diff_x = x2.size(3) - x1.size(3)

        # F.pad pads the last dim first: (left, right, top, bottom).
        x1 = F.pad(x1, [diff_x // 2, diff_x - diff_x // 2,
                        diff_y // 2, diff_y - diff_y // 2])

        x = torch.cat([x2, x1], dim=1)
        return self.conv(x)


class OutConv(nn.Module):
    """Final 1x1 convolution mapping feature channels to ``out_channels`` output maps."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        logits = self.conv(x)
        return logits

# Full assembly of the parts to form the complete network

In [3]:
class UNet(nn.Module):
    """U-Net: a four-level encoder/decoder with skip connections.

    Args:
        n_channels: number of channels in the input image.
        n_classes: number of output maps produced by the final 1x1 conv.
        bilinear: passed to every ``Up`` block; selects bilinear interpolation
            (True) or a transposed convolution (False) for upsampling.
    """

    def __init__(self, n_channels, n_classes, bilinear=True):
        super().__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.bilinear = bilinear

        # Encoder: 64 -> 128 -> 256 -> 512 -> 512 feature channels.
        self.inc = DoubleConv(n_channels, 64)
        self.down1 = Down(64, 128)
        self.down2 = Down(128, 256)
        self.down3 = Down(256, 512)
        self.down4 = Down(512, 512)
        # Decoder: each Up's in_channels is the concatenated width
        # (upsampled map + matching encoder skip map).
        self.up1 = Up(1024, 256, bilinear)
        self.up2 = Up(512, 128, bilinear)
        self.up3 = Up(256, 64, bilinear)
        self.up4 = Up(128, 64, bilinear)
        self.outc = OutConv(64, n_classes)

    def forward(self, x):
        # Encoder pass, keeping every intermediate map for the skip connections.
        features = [self.inc(x)]
        for down in (self.down1, self.down2, self.down3, self.down4):
            features.append(down(features[-1]))

        # Decoder pass: start from the bottleneck and pair each stage with the
        # encoder map of the same resolution (popped deepest-first).
        out = features.pop()
        for up in (self.up1, self.up2, self.up3, self.up4):
            out = up(out, features.pop())
        return self.outc(out)

if __name__ == '__main__':
    # Smoke test: build a 3-channel-input, 1-class model and print its layer tree.
    net = UNet(3, 1)
    print(net)

UNet(
  (inc): DoubleConv(
    (double_conv): Sequential(
      (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
      (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (5): ReLU(inplace=True)
    )
  )
  (down1): Down(
    (maxpool_conv): Sequential(
      (0): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      (1): DoubleConv(
        (double_conv): Sequential(
          (0): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU(inplace=True)
          (3): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (4): BatchNorm2d(128, eps=1e-05, moment

# Train UNet

In [4]:
from utils.dataset import ISBI_Loader
from torch import optim

def train_net(net, device, data_path, epochs=40, batch_size=1, lr=0.00001,
              save_path='best_model.pth', dataset=None):
    """Train ``net`` with RMSprop and BCEWithLogitsLoss, checkpointing the best weights.

    Args:
        net: model whose output has the same shape as the labels (raw logits).
        device: device each batch is moved to. The model is assumed to already
            live there (the caller moves it).
        data_path: directory handed to ``ISBI_Loader`` when ``dataset`` is None.
        epochs: number of passes over the training set.
        batch_size: DataLoader batch size.
        lr: RMSprop learning rate.
        save_path: file the best ``state_dict`` is written to.
        dataset: optional dataset yielding ``(image, label)`` pairs. When given,
            it is used instead of ``ISBI_Loader(data_path)``.

    Returns:
        The lowest per-batch training loss observed, as a float
        (``inf`` if no batch was processed).
    """
    # Load the training set.
    train_set = dataset if dataset is not None else ISBI_Loader(data_path)
    train_loader = torch.utils.data.DataLoader(dataset=train_set,
                                               batch_size=batch_size,
                                               shuffle=True)
    optimizer = optim.RMSprop(net.parameters(), lr=lr, weight_decay=1e-8, momentum=0.9)
    # The model emits raw logits, so use the fused (numerically stable) sigmoid + BCE.
    criterion = nn.BCEWithLogitsLoss()
    # Kept as a Python float so no device tensor / autograd graph is retained.
    best_loss = float('inf')
    for _ in range(epochs):
        net.train()
        for image, label in train_loader:
            optimizer.zero_grad()
            image = image.to(device=device, dtype=torch.float32)
            label = label.to(device=device, dtype=torch.float32)
            pred = net(image)
            loss = criterion(pred, label)
            loss_value = loss.item()
            print('Loss/train', loss_value)
            # NOTE(review): "best" is judged on a single training batch, not on
            # a validation set, so the selection is noisy. The checkpoint holds
            # the weights that produced this loss (saved before the update step).
            if loss_value < best_loss:
                best_loss = loss_value
                torch.save(net.state_dict(), save_path)
            loss.backward()
            optimizer.step()
    print('done')
    return best_loss

# Pick the compute device: CUDA when available, otherwise the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Single-channel input, one output map. nn.Module.to() moves the module in
# place and returns it, so construction and placement can be chained.
net = UNet(n_channels=1, n_classes=1).to(device=device)
# Training-set directory handed to train_net; start training.
data_path = "data/train/"
train_net(net, device, data_path)



Loss/train 0.7277677655220032
Loss/train 0.6595970392227173
Loss/train 0.6281815767288208
Loss/train 0.6094270348548889
Loss/train 0.5284303426742554
Loss/train 0.5453810691833496
Loss/train 0.5190775990486145
Loss/train 0.48609039187431335
Loss/train 0.4894610643386841
Loss/train 0.46415722370147705
Loss/train 0.4443259835243225
Loss/train 0.4565272331237793
Loss/train 0.45964720845222473
Loss/train 0.44525468349456787
Loss/train 0.4157237410545349
Loss/train 0.40177133679389954
Loss/train 0.40595850348472595
Loss/train 0.3797835409641266
Loss/train 0.4148378372192383
Loss/train 0.3880782127380371
Loss/train 0.4008256494998932
Loss/train 0.4380495548248291
Loss/train 0.41005077958106995
Loss/train 0.3782820701599121
Loss/train 0.4108103811740875
Loss/train 0.3880535066127777
Loss/train 0.4061937928199768
Loss/train 0.36766961216926575
Loss/train 0.38242197036743164
Loss/train 0.3674065172672272
Loss/train 0.36919093132019043
Loss/train 0.3800937831401825
Loss/train 0.3411778211593628


Loss/train 0.25738129019737244
Loss/train 0.24069589376449585
Loss/train 0.2636692523956299
Loss/train 0.2586187720298767
Loss/train 0.2349618822336197
Loss/train 0.2571025490760803
Loss/train 0.2513042390346527
Loss/train 0.2521103024482727
Loss/train 0.27424895763397217
Loss/train 0.2535446882247925
Loss/train 0.24738043546676636
Loss/train 0.2270861566066742
Loss/train 0.2721731662750244
Loss/train 0.2570911645889282
Loss/train 0.2501601278781891
Loss/train 0.23039492964744568
Loss/train 0.22945666313171387
Loss/train 0.233295738697052
Loss/train 0.22871826589107513
Loss/train 0.2348945587873459
Loss/train 0.23606997728347778
Loss/train 0.21930333971977234
Loss/train 0.23055672645568848
Loss/train 0.2365962564945221
Loss/train 0.22632503509521484
Loss/train 0.24971629679203033
Loss/train 0.24809543788433075
Loss/train 0.2326025366783142
Loss/train 0.22084076702594757
Loss/train 0.22340573370456696
Loss/train 0.2293793261051178
Loss/train 0.22727075219154358
Loss/train 0.265468478202

Loss/train 0.2299439013004303
Loss/train 0.22793111205101013
Loss/train 0.17342647910118103
Loss/train 0.18887339532375336
Loss/train 0.2103615701198578
Loss/train 0.18673530220985413
Loss/train 0.18010520935058594
Loss/train 0.18429622054100037
Loss/train 0.19264158606529236
Loss/train 0.17047753930091858
Loss/train 0.18863004446029663
Loss/train 0.18533477187156677
Loss/train 0.2042684704065323
Loss/train 0.21969084441661835
Loss/train 0.1824648678302765
Loss/train 0.18202880024909973
Loss/train 0.19869495928287506
Loss/train 0.16918835043907166
Loss/train 0.16665522754192352
Loss/train 0.17898061871528625
Loss/train 0.16574181616306305
Loss/train 0.17288286983966827
Loss/train 0.19583475589752197
Loss/train 0.19333821535110474
Loss/train 0.1847568154335022
Loss/train 0.1824151873588562
Loss/train 0.18232271075248718
Loss/train 0.1893659383058548
Loss/train 0.17971718311309814
Loss/train 0.18982283771038055
Loss/train 0.17911681532859802
Loss/train 0.17049047350883484
Loss/train 0.17

Loss/train 0.17719686031341553
Loss/train 0.1413010060787201
Loss/train 0.17092236876487732
Loss/train 0.16432228684425354
Loss/train 0.15279464423656464
Loss/train 0.14312590658664703
Loss/train 0.13402500748634338
Loss/train 0.15148313343524933
Loss/train 0.15432003140449524
Loss/train 0.144993394613266
Loss/train 0.14424720406532288
Loss/train 0.15283477306365967
Loss/train 0.15073351562023163
Loss/train 0.15446437895298004
Loss/train 0.14013884961605072
Loss/train 0.1355498731136322
Loss/train 0.1649799793958664
Loss/train 0.1529901623725891
Loss/train 0.1470465362071991
Loss/train 0.13479724526405334
Loss/train 0.16566714644432068
Loss/train 0.1576082408428192
Loss/train 0.15658248960971832
Loss/train 0.14538197219371796
Loss/train 0.14157456159591675
Loss/train 0.16477283835411072
Loss/train 0.18290257453918457
Loss/train 0.14895862340927124
Loss/train 0.133034348487854
Loss/train 0.14616909623146057
Loss/train 0.14337316155433655
Loss/train 0.13867905735969543
Loss/train 0.15183

Loss/train 0.11575613170862198
Loss/train 0.13267311453819275
Loss/train 0.10915237665176392
Loss/train 0.1358335018157959
Loss/train 0.11258839815855026
Loss/train 0.11967572569847107
Loss/train 0.10483244806528091
Loss/train 0.11667843163013458
Loss/train 0.12309788167476654
Loss/train 0.10093630850315094
Loss/train 0.12057138979434967
Loss/train 0.12130872905254364
Loss/train 0.10840428620576859
Loss/train 0.14752553403377533
Loss/train 0.10888612270355225
Loss/train 0.10901141166687012
Loss/train 0.10479117184877396
Loss/train 0.11764159798622131
Loss/train 0.10858500003814697
Loss/train 0.127885103225708
Loss/train 0.10381962358951569
Loss/train 0.11429998278617859
Loss/train 0.10490847378969193
Loss/train 0.11155711859464645
Loss/train 0.12720413506031036
Loss/train 0.11900614947080612
Loss/train 0.12517857551574707
Loss/train 0.10609009861946106
Loss/train 0.10340137779712677
Loss/train 0.12136414647102356
Loss/train 0.12691758573055267
Loss/train 0.11697518825531006
Loss/train 

# Test

In [5]:
import glob
import numpy as np
import os
import cv2

if __name__ == "__main__":
    # 选择设备，有cuda用cuda，没有就用cpu
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # 加载网络，图片单通道，分类为1。
    net = UNet(n_channels=1, n_classes=1)
    # 将网络拷贝到deivce中
    net.to(device=device)
    # 加载模型参数
    net.load_state_dict(torch.load('best_model.pth', map_location=device))
    # 测试模式
    net.eval()
    # 读取所有图片路径
    tests_path = glob.glob('data/test/*.png')
    # 遍历素有图片
    for test_path in tests_path:
        # 保存结果地址
        save_res_path = test_path.split('.')[0] + '_Seg.png'
        # 读取图片
        img = cv2.imread(test_path)
        # 转为灰度图
        img = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
        # 转为batch为1，通道为1，大小为512*512的数组
        img = img.reshape(1, 1, img.shape[0], img.shape[1])
        # 转为tensor
        img_tensor = torch.from_numpy(img)
        # 将tensor拷贝到device中，只用cpu就是拷贝到cpu中，用cuda就是拷贝到cuda中。
        img_tensor = img_tensor.to(device=device, dtype=torch.float32)
        # 预测
        pred = net(img_tensor)
        # 提取结果
        pred = np.array(pred.data.cpu()[0])[0]
        # 处理结果
        pred[pred >= 0.5] = 255
        pred[pred < 0.5] = 0
        # 保存图片
        cv2.imwrite(save_res_path, pred)



# Evaluate

In [15]:
from PIL import Image
from pylab import *
from medpy import metric
from IPython.display import HTML, display

# Show the test input next to its predicted mask for a quick visual check.
display(HTML(
    "<table><tr>"
    "<td><img src='data/test/0.png'></td>"
    "<td><img src='data/test/0_Seg.png'></td>"
    "</tr></table>"
))

def calculate_metric_percase(pred, gt):
    """Compute overlap and surface-distance metrics for one binary segmentation.

    Both inputs are binarised (any non-zero value counts as foreground) and
    must have the same shape.

    Returns:
        (dice, jaccard, hd95, asd). hd95 and asd are in pixel units because no
        voxel spacing is passed. They are NaN when either mask has no
        foreground, since medpy raises RuntimeError for empty masks.

    Raises:
        ValueError: if ``pred`` and ``gt`` have different shapes.
    """
    pred = np.asarray(pred) > 0
    gt = np.asarray(gt) > 0
    if pred.shape != gt.shape:
        raise ValueError(f'shape mismatch: pred {pred.shape} vs gt {gt.shape}')

    dice = metric.binary.dc(pred, gt)
    jc = metric.binary.jc(pred, gt)
    # Surface distances are undefined without a foreground object in both masks.
    if pred.any() and gt.any():
        hd = metric.binary.hd95(pred, gt)
        asd = metric.binary.asd(pred, gt)
    else:
        hd = asd = float('nan')
    return dice, jc, hd, asd

# Load the predicted mask and the reference image as 2-D (grayscale) arrays.
# FIXME(review): data/test/0.png is the raw test *input* that the inference
# cell segmented, not a ground-truth label. These scores therefore measure
# agreement with the input's non-zero pixels (all treated as foreground), not
# segmentation accuracy. Point GT_PATH at the real label mask for 0.png.
GT_PATH = "data/test/0.png"
PRED_PATH = "data/test/0_Seg.png"
im_gt = np.array(Image.open(GT_PATH).convert('L'))
im_pred = np.array(Image.open(PRED_PATH).convert('L'))

print('Evaluation results on DICE, JC, HD, ASD:')
calculate_metric_percase(im_pred, im_gt)

Evaluation resuts on DICE, JC, HD, ASD:


(0.8956019996081678, 0.8109413447782546, 129.0, 46.49052762124423)