In [1]:
import math
import os
import pickle
import collections
import math
from itertools import repeat
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import init
import pathlib, sys, os, random, time
import cv2
from torch.utils import data
from PIL import Image
from torchvision import datasets,transforms
from torchvision.utils import make_grid, save_image
import matplotlib.pyplot as plt
from tqdm import tqdm_notebook
from tqdm.notebook import tqdm
import albumentations as A
import functools
from torchvision import models
import torchvision
import Augmentor

In [2]:
transform = transforms.Compose([
 transforms.Resize((448,448)),
 transforms.ToTensor(), # 将图片转换为Tensor,归一化至[0,1]
#  transforms.Normalize(mean=[.5, .5, .5], std=[.5, .5, .5]) # 标准化至[-1,1]
])

In [3]:
EPOCHES = 120
BATCH_SIZE = 32
IMAGE_SIZE = 256
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu' 

trfm = A.Compose([
    A.Resize(IMAGE_SIZE, IMAGE_SIZE),
    A.HorizontalFlip(p=0.5),
    A.VerticalFlip(p=0.5),
    A.RandomRotate90(),
])

In [4]:
class CrackData(data.Dataset):
    """Dataset for Crack detection"""
    def __init__(self, data_images, data_GT,transform):
        imgs = os.listdir(data_images)
        self.imgs=[os.path.join(data_images,k) for k in imgs]
        GT = os.listdir(data_GT)
        self.GTs=[os.path.join(data_GT,k) for k in GT]
        self.transforms=transform

    def __getitem__(self, index):
        img_path = self.imgs[index]
        GT_path = self.GTs[index]
        pil_img = Image.open(img_path)
        pil_GT = Image.open(GT_path)
        pil_GT = pil_GT.convert('L')
        if self.transforms:
            data = self.transforms(pil_img)
            label = self.transforms(pil_GT)
        else:
            pil_img = np.asarray(pil_img)
            pil_GT = np.asarray(pil_GT)
            data = torch.from_numpy(pil_img)
            label = torch.from_numpy(pil_GT)
        return data, label
    
    def __len__(self):
        return len(self.imgs)

In [5]:
dataset = CrackData('./datas/train_img','./datas/train_lab',transform)
len(dataset)

765

In [6]:
valid_idx, train_idx = [], []
for i in range(len(dataset)):
    if i % 7 == 0:
        valid_idx.append(i)
    else:
        train_idx.append(i)
   # elif i % 7 == 1:
print(len(valid_idx))
print(len(train_idx))
       
        
train_ds = data.Subset(dataset, train_idx)
valid_ds = data.Subset(dataset, valid_idx)

# define training and validation data loaders
loader = data.DataLoader(
    train_ds, batch_size=4, shuffle=True, drop_last=True, num_workers=0)

vloader = data.DataLoader(
    valid_ds, batch_size=4, shuffle=False,drop_last=True, num_workers=0)

110
655


In [7]:
def x2conv(in_channels, out_channels, inner_channels=None):
    inner_channels = out_channels // 2 if inner_channels is None else inner_channels
    down_conv = nn.Sequential(
        nn.Conv2d(in_channels, inner_channels, kernel_size=3, padding=1, bias=False),
        nn.BatchNorm2d(inner_channels),
        nn.ReLU(inplace=True),
        nn.Conv2d(inner_channels, out_channels, kernel_size=3, padding=1, bias=False),
        nn.BatchNorm2d(out_channels),
        nn.ReLU(inplace=True))
    return down_conv

In [8]:
class encoder(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(encoder, self).__init__()
        self.down_conv = x2conv(in_channels, out_channels)
        self.pool = nn.MaxPool2d(kernel_size=2, ceil_mode=True)

    def forward(self, x):
        x = self.down_conv(x)
        x = self.pool(x)
        return x

In [9]:
class up_conv(nn.Module):
    def __init__(self,ch_in,ch_out):
        super(up_conv,self).__init__()
        self.up = nn.Sequential(
#             nn.Upsample(scale_factor=2),
#             ch_out = ch_in/2,
            nn.Conv2d(ch_in,ch_out,kernel_size=3,stride=1,padding=1,bias=True),
            nn.BatchNorm2d(ch_out),
            nn.ReLU(inplace=True)
        )

    def forward(self,x):
        x = self.up(x)
        return x

In [10]:
class up2_conv(nn.Module):
    def __init__(self,ch_in,ch_out):
        super(up2_conv,self).__init__()
        self.up = nn.Sequential(
            nn.Upsample(scale_factor=2),
            nn.Conv2d(ch_in,ch_out,kernel_size=3,stride=1,padding=1,bias=True),
            nn.BatchNorm2d(ch_out),
            nn.ReLU(inplace=True)
        )

    def forward(self,x):
        x = self.up(x)
        return x

In [11]:
class down_conv(nn.Module):
    def __init__(self,ch_in,ch_out):
        super(down_conv,self).__init__()
        self.up = nn.Sequential(
#             nn.Upsample(scale_factor=2),
            nn.Conv2d(ch_in,ch_out,kernel_size=3,stride=1,padding=1,bias=True),
            nn.BatchNorm2d(ch_out),
            nn.ReLU(inplace=True)
        )

    def forward(self,x):
        x = self.up(x)
        return x

In [12]:
class decoder(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(decoder, self).__init__()
        self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
        self.up_conv = x2conv(in_channels, out_channels)

    def forward(self, x_copy, x, interpolate=True):
        x = self.up(x)

        if (x.size(2) != x_copy.size(2)) or (x.size(3) != x_copy.size(3)):
            if interpolate:
                # Iterpolating instead of padding
                x = F.interpolate(x, size=(x_copy.size(2), x_copy.size(3)),
                                mode="bilinear", align_corners=True)
            else:
                # Padding in case the incomping volumes are of different sizes
                diffY = x_copy.size()[2] - x.size()[2]
                diffX = x_copy.size()[3] - x.size()[3]
                x = F.pad(x, (diffX // 2, diffX - diffX // 2,
                                diffY // 2, diffY - diffY // 2))

        # Concatenate
        x = torch.cat([x_copy, x], dim=1)
        x = self.up_conv(x)
        return x

In [13]:
class MDM(nn.Module):
    def __init__(self):
        super(MDM, self).__init__()

        self.conv_3x3_1 = nn.Conv2d(512, 256, kernel_size=3, padding=2, dilation=2)
        self.bn_conv_3x3_1 = nn.BatchNorm2d(256)
        self.conv_3x3_12 = nn.Conv2d(256, 256, kernel_size=3, padding=2, dilation=2)
        self.bn_conv_3x3_12 = nn.BatchNorm2d(256)

        self.conv_3x3_2 = nn.Conv2d(512, 256, kernel_size=3, stride=1, padding=4, dilation=4)
        self.bn_conv_3x3_2 = nn.BatchNorm2d(256)
        self.conv_3x3_22 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=4, dilation=4)
        self.bn_conv_3x3_22 = nn.BatchNorm2d(256)

        self.conv_3x3_3 = nn.Conv2d(512, 256, kernel_size=3, stride=1, padding=8, dilation=8)
        self.bn_conv_3x3_3 = nn.BatchNorm2d(256)
        self.conv_3x3_32 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=8, dilation=8)
        self.bn_conv_3x3_32 = nn.BatchNorm2d(256)

        self.conv_3x3_4 = nn.Conv2d(512, 256, kernel_size=3, stride=1, padding=12, dilation=12)
        self.bn_conv_3x3_4 = nn.BatchNorm2d(256)
        self.conv_3x3_42 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=12, dilation=12)
        self.bn_conv_3x3_42 = nn.BatchNorm2d(256)

        self.conv_1x1_1 = nn.Conv2d(1024, 1024, kernel_size=1)
        self.bn_conv_1x1_1 = nn.BatchNorm2d(1024)

    def forward(self, feature_map):
        # (feature_map has shape (batch_size, 512, h/16, w/16)) (assuming self.resnet is ResNet18_OS16 or ResNet34_OS16. If self.resnet instead is ResNet18_OS8 or ResNet34_OS8, it will be (batch_size, 512, h/8, w/8))

        # feature_map_h = feature_map.size()[2] # (== h/16)
        # feature_map_w = feature_map.size()[3] # (== w/16)

        out_3x3_1 = F.relu(self.bn_conv_3x3_1(self.conv_3x3_1(feature_map))) # (shape: (batch_size, 256, h/16, w/16))
        out_3x3_12 = F.relu(self.bn_conv_3x3_12(self.conv_3x3_12(out_3x3_1)))  # (shape: (batch_size, 256, h/16, w/16))
#         print('out_3x3_12', out_3x3_12.shape)

        out_3x3_2 = F.relu(self.bn_conv_3x3_2( self.conv_3x3_2(feature_map))) # (shape: (batch_size, 256, h/16, w/16))
        out_3x3_22 = F.relu(self.bn_conv_3x3_22(self.conv_3x3_22(out_3x3_2)))  # (shape: (batch_size, 256, h/16, w/16))
#         print('out_3x3_22', out_3x3_22.shape)

        out_3x3_3 = F.relu(self.bn_conv_3x3_3(self.conv_3x3_3(feature_map))) # (shape: (batch_size, 256, h/16, w/16))
        out_3x3_32 = F.relu(self.bn_conv_3x3_32(self.conv_3x3_32(out_3x3_3))) # (shape: (batch_size, 256, h/16, w/16))
#         print('out_3x3_32', out_3x3_32.shape)

        out_3x3_4 = F.relu(self.bn_conv_3x3_4(self.conv_3x3_4(feature_map)))  # (shape: (batch_size, 256, h/16, w/16))
        out_3x3_42 = F.relu(self.bn_conv_3x3_42(self.conv_3x3_42(out_3x3_4)))  # (shape: (batch_size, 256, h/16, w/16))
#         print('out_3x3_42', out_3x3_42.shape)

        out = torch.cat([out_3x3_12, out_3x3_22, out_3x3_32, out_3x3_42], 1) # (shape: (batch_size, 1280, h/16, w/16))
        out = F.relu(self.bn_conv_1x1_1(self.conv_1x1_1(out))) # (shape: (batch_size, 256, h/16, w/16))

        return out

In [14]:
class sSE(nn.Module):
    def __init__(self, in_channels):
        super().__init__()
        self.Conv1x1 = nn.Conv2d(in_channels, 1, kernel_size=1, bias=False)
        self.norm = nn.Sigmoid()

    def forward(self, U):
        q = self.Conv1x1(U)  # U:[bs,c,h,w] to q:[bs,1,h,w]
        q = self.norm(q)
        return U * q  # 广播机制

class cSE(nn.Module):
    def __init__(self, in_channels):
        super().__init__()
        self.avgpool = nn.AdaptiveAvgPool2d(1)#全局平均池化
        self.Conv_Squeeze = nn.Conv2d(in_channels, in_channels // 2, kernel_size=1, bias=False)
        self.Conv_Excitation = nn.Conv2d(in_channels//2, in_channels, kernel_size=1, bias=False)
        self.norm = nn.Sigmoid()

    def forward(self, U):
        z = self.avgpool(U)# shape: [bs, c, h, w] to [bs, c, 1, 1]
        z = self.Conv_Squeeze(z) # shape: [bs, c/2]#线性变换
        z = self.Conv_Excitation(z) # shape: [bs, c]
        z = self.norm(z)#非线性变换
        return U * z.expand_as(U)#每个通道乘以一个权重

class scSE(nn.Module):
    def __init__(self, in_channels):
        super().__init__()
        self.cSE = cSE(in_channels)
        self.sSE = sSE(in_channels)

    def forward(self, U):
        U_sse = self.sSE(U)
        U_cse = self.cSE(U)
        return U_cse+U_sse

In [15]:
class MDM2(nn.Module):
    def __init__(self, in_channel, out_channel):
        super(MDM2, self).__init__()

        self.conv_3x3_1 = nn.Conv2d(in_channel, out_channel, kernel_size=3, padding=2, dilation=2)
        self.bn_conv_3x3_1 = nn.BatchNorm2d(out_channel)
        self.conv_3x3_12 = nn.Conv2d(out_channel, out_channel, kernel_size=3, padding=2, dilation=2)
        self.bn_conv_3x3_12 = nn.BatchNorm2d(out_channel)

        self.conv_3x3_2 = nn.Conv2d(in_channel, out_channel, kernel_size=3, stride=1, padding=4, dilation=4)
        self.bn_conv_3x3_2 = nn.BatchNorm2d(out_channel)
        self.conv_3x3_22 = nn.Conv2d(out_channel, out_channel, kernel_size=3, stride=1, padding=4, dilation=4)
        self.bn_conv_3x3_22 = nn.BatchNorm2d(out_channel)

        self.conv_3x3_3 = nn.Conv2d(in_channel, out_channel, kernel_size=3, stride=1, padding=8, dilation=8)
        self.bn_conv_3x3_3 = nn.BatchNorm2d(out_channel)
        self.conv_3x3_32 = nn.Conv2d(out_channel, out_channel, kernel_size=3, stride=1, padding=8, dilation=8)
        self.bn_conv_3x3_32 = nn.BatchNorm2d(out_channel)

        self.conv_3x3_4 = nn.Conv2d(in_channel, out_channel, kernel_size=3, stride=1, padding=12, dilation=12)
        self.bn_conv_3x3_4 = nn.BatchNorm2d(out_channel)
        self.conv_3x3_42 = nn.Conv2d(out_channel, out_channel, kernel_size=3, stride=1, padding=12, dilation=12)
        self.bn_conv_3x3_42 = nn.BatchNorm2d(out_channel)

        self.conv_1x1_1 = nn.Conv2d(out_channel*4, out_channel*4, kernel_size=1)
        self.bn_conv_1x1_1 = nn.BatchNorm2d(out_channel*4)
        self.attenion = scSE(out_channel)

    def forward(self, feature_map):
        # (feature_map has shape (batch_size, 512, h/16, w/16)) (assuming self.resnet is ResNet18_OS16 or ResNet34_OS16. If self.resnet instead is ResNet18_OS8 or ResNet34_OS8, it will be (batch_size, 512, h/8, w/8))

        # feature_map_h = feature_map.size()[2] # (== h/16)
        # feature_map_w = feature_map.size()[3] # (== w/16)

        out_3x3_1 = F.relu(self.bn_conv_3x3_1(self.conv_3x3_1(feature_map))) # (shape: (batch_size, 256, h/16, w/16))
        out_3x3_12 = F.relu(self.bn_conv_3x3_12(self.conv_3x3_12(out_3x3_1)))  # (shape: (batch_size, 256, h/16, w/16))
        out_3x3_12 = self.attenion(out_3x3_12)
        print('out_3x3_12', out_3x3_12.shape)

        out_3x3_2 = F.relu(self.bn_conv_3x3_2( self.conv_3x3_2(feature_map))) # (shape: (batch_size, 256, h/16, w/16))
        out_3x3_22 = F.relu(self.bn_conv_3x3_22(self.conv_3x3_22(out_3x3_2)))  # (shape: (batch_size, 256, h/16, w/16))
        out_3x3_22 = self.attenion(out_3x3_22)
        print('out_3x3_22', out_3x3_22.shape)

        out_3x3_3 = F.relu(self.bn_conv_3x3_3(self.conv_3x3_3(feature_map))) # (shape: (batch_size, 256, h/16, w/16))
        out_3x3_32 = F.relu(self.bn_conv_3x3_32(self.conv_3x3_32(out_3x3_3))) # (shape: (batch_size, 256, h/16, w/16))
        out_3x3_32 = self.attenion(out_3x3_32)
        print('out_3x3_32', out_3x3_32.shape)

        out_3x3_4 = F.relu(self.bn_conv_3x3_4(self.conv_3x3_4(feature_map)))  # (shape: (batch_size, 256, h/16, w/16))
        out_3x3_42 = F.relu(self.bn_conv_3x3_42(self.conv_3x3_42(out_3x3_4)))  # (shape: (batch_size, 256, h/16, w/16))
        out_3x3_42 = self.attenion(out_3x3_42)
        print('out_3x3_42', out_3x3_42.shape)

        out = torch.cat([out_3x3_12, out_3x3_22, out_3x3_32, out_3x3_42], 1) # (shape: (batch_size, 1280, h/16, w/16))
        out = F.relu(self.bn_conv_1x1_1(self.conv_1x1_1(out))) # (shape: (batch_size, 256, h/16, w/16))

        return out

In [16]:
class Attention_block(nn.Module):
    def __init__(self, F_g, F_l, F_int):
        super(Attention_block, self).__init__()
        self.W_g = nn.Sequential(
            nn.Conv2d(F_g, F_int, kernel_size=1, stride=1, padding=0, bias=True),
            nn.BatchNorm2d(F_int)
        )

        self.W_x = nn.Sequential(
            nn.Conv2d(F_l, F_int, kernel_size=1, stride=1, padding=0, bias=True),
            nn.BatchNorm2d(F_int)
        )

        self.psi = nn.Sequential(
            nn.Conv2d(F_int, 1, kernel_size=1, stride=1, padding=0, bias=True),
            nn.BatchNorm2d(1),
            nn.Sigmoid()
        )

        self.relu = nn.ReLU(inplace=True)

    def forward(self, g, x):
        # 下采样的gating signal 卷积
        g1 = self.W_g(g)
        # 上采样的 l 卷积
        x1 = self.W_x(x)
        # concat + relu
        psi = self.relu(g1 + x1)
        # channel 减为1，并Sigmoid,得到权重矩阵
        psi = self.psi(psi)
        # 返回加权的 x
        return x * psi

In [17]:
class eca_layer(nn.Module):
    """Constructs a ECA module.

    Args:
        channel: Number of channels of the input feature map
        k_size: Adaptive selection of kernel size
    """
    def __init__(self, channel, k_size=3):
        super(eca_layer, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.conv = nn.Conv1d(1, 1, kernel_size=k_size, padding=(k_size - 1) // 2, bias=False) 
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        # feature descriptor on the global spatial information
        y = self.avg_pool(x)

        # Two different branches of ECA module
        y = self.conv(y.squeeze(-1).transpose(-1, -2)).transpose(-1, -2).unsqueeze(-1)

        # Multi-scale information fusion
        y = self.sigmoid(y)

        return x * y.expand_as(x)

In [18]:
class h_sigmoid(nn.Module):
    def __init__(self, inplace=True):
        super(h_sigmoid, self).__init__()
        self.relu = nn.ReLU6(inplace=inplace)

    def forward(self, x):
        return self.relu(x + 3) / 6
    
class h_swish(nn.Module):
    def __init__(self, inplace=True):
        super(h_swish, self).__init__()
        self.sigmoid = h_sigmoid(inplace=inplace)

    def forward(self, x):
        return x * self.sigmoid(x)

class swish(nn.Module):
    def forward(self, x):
        return x * torch.sigmoid(x)
    
class CoordAtt(nn.Module):
    def __init__(self, inp, oup, groups=32):
        super(CoordAtt, self).__init__()
        self.pool_h = nn.AdaptiveAvgPool2d((None, 1))
        self.pool_w = nn.AdaptiveAvgPool2d((1, None))

        mip = max(8, inp // groups)

        self.conv1 = nn.Conv2d(inp, mip, kernel_size=1, stride=1, padding=0)
        self.bn1 = nn.BatchNorm2d(mip)
        self.conv2 = nn.Conv2d(mip, oup, kernel_size=1, stride=1, padding=0)
        self.conv3 = nn.Conv2d(mip, oup, kernel_size=1, stride=1, padding=0)
        self.relu = h_swish()

    def forward(self, x):
        identity = x
        n,c,h,w = x.size()
        x_h = self.pool_h(x)
        x_w = self.pool_w(x).permute(0, 1, 3, 2)

        y = torch.cat([x_h, x_w], dim=2)
        y = self.conv1(y)
        y = self.bn1(y)
        y = self.relu(y) 
        x_h, x_w = torch.split(y, [h, w], dim=2)
        x_w = x_w.permute(0, 1, 3, 2)

        x_h = self.conv2(x_h).sigmoid()
        x_w = self.conv3(x_w).sigmoid()
        x_h = x_h.expand(-1, -1, h, w)
        x_w = x_w.expand(-1, -1, h, w)

        y = identity * x_w * x_h

        return y

In [19]:
class UNet(nn.Module):
    def __init__(self, num_classes, in_channels=3):
        super(UNet, self).__init__()

        self.start_conv = x2conv(in_channels, 64)
        self.down1 = encoder(64, 128)
        self.down2 = encoder(128, 256)
        self.down3 = encoder(256, 512)
        self.down4 = encoder(512, 1024)

        self.middle_conv = x2conv(1024, 1024)
        
        self.Up1 = up_conv(ch_in=1024, ch_out=512)
        self.Up2 = up2_conv(ch_in=512,  ch_out=256)
        self.Up3 = up2_conv(ch_in=256, ch_out=128)
        self.Up4 = up2_conv(ch_in=128, ch_out=64)
        
        self.CDown1 = down_conv(1024, 512)
        self.CDown2 = down_conv(512, 256)
        self.CDown3 = down_conv(256, 128)
        self.CDown4 = down_conv(128, 64)
        
        
        self.Channel1 = eca_layer(512)
        self.Channel2 = eca_layer(256)
        self.Channel3 = eca_layer(128)
        self.Channel4 = eca_layer(64)
        
        self.CA1 = CoordAtt(512,512)
        self.CA2 = CoordAtt(256,256)
        self.CA3 = CoordAtt(128,128)
        self.CA4 = CoordAtt(64,64)
        
        self.Att1 = Attention_block(F_g=512, F_l=512, F_int=256)
        self.Att2 = Attention_block(F_g=256, F_l=256, F_int=128)
        self.Att3 = Attention_block(F_g=128, F_l=128, F_int=64)
        self.Att4 = Attention_block(F_g=64, F_l=64, F_int=32)



        self.up1 = decoder(1024, 512)
        self.up2 = decoder(512, 256)
        self.up3 = decoder(256, 128)
        self.up4 = decoder(128, 64)
        self.final_conv = nn.Conv2d(64, num_classes, kernel_size=1)
        self._initialize_weights()
        self.mdm = MDM2(512,256)

        self.side5_conv = nn.Conv2d(1024, num_classes, kernel_size=1, stride=1, bias=False)
        self.side4_conv = nn.Conv2d(512, num_classes, kernel_size=1, stride=1, bias=False)
        self.side3_conv = nn.Conv2d(256, num_classes, kernel_size=1, stride=1, bias=False)
        self.side2_conv = nn.Conv2d(128, num_classes, kernel_size=1, stride=1, bias=False)
        self.side1_conv = nn.Conv2d(64, num_classes, kernel_size=1, stride=1, bias=False)
        self.fuse_conv = nn.Conv2d(num_classes * 5, num_classes, kernel_size=1, stride=1, bias=False)

    def _initialize_weights(self):
        for module in self.modules():
            if isinstance(module, nn.Conv2d) or isinstance(module, nn.Linear):
                nn.init.kaiming_normal_(module.weight)
                if module.bias is not None:
                    module.bias.data.zero_()
            elif isinstance(module, nn.BatchNorm2d):
                module.weight.data.fill_(1)
                module.bias.data.zero_()

    def forward(self, x):
        h, w = x.size()[2:] #224 224
        x1 = self.start_conv(x)
        print('x1',x1.shape) #10, 64, 224, 224
        x2 = self.down1(x1)
        print('x2',x2.shape) # 10, 128, 112, 112
        x3 = self.down2(x2)
        print('x3',x3.shape) # 10, 256, 56, 56
        x4 = self.down3(x3)   
        print('x4',x4.shape) #10, 512, 28, 28
        side5 = self.mdm(x4) 
        print('side5',side5.shape) #([10, 1024, 28, 28])
        

        
#         d5 = self.Up1(side5) 
#         print('d5',d5.shape)#d5 torch.Size([10, 512, 28, 28])
#         x4_a = self.Att1(g=d5, x=x4)
#         print('x4_a',x4_a.shape) #x4_a torch.Size([10, 512, 28, 28])
        x4_c = self.Channel1(x4)
        print('x4_c',x4_c.shape)
    
        x4_CA = self.CA1(x4)
        print('x4_CA',x4_CA.shape) #torch.Size([10, 512, 28, 28])
        x4_CA = x4_CA + x4_c
        d4 = self.Up1(side5) 
        print('d4',d4.shape) #torch.Size([10, 512, 28, 28])

        side4 = self.up1(x4_CA, side5)
        print('side4',side4.shape) #10, 512, 28, 28
        
        side4 = torch.cat([side4, d4], 1)# torch.Size([10, 1024, 28, 28])
        side4 = self.CDown1(side4)
        print('side4',side4.shape) #10, 512, 28, 28

#         x3_a = self.Att2(g=d4, x=x3)
#         print('x3_a',x3_a.shape) #x3_a torch.Size([10, 256, 56, 56])
        
        d3 = self.Up2(d4) 
        print('d4',d4.shape) #d4 torch.Size([10, 256, 56, 56])
        x3_c = self.Channel2(x3)
        print('x3_c',x3_c.shape)
        x3_CA = self.CA2(x3)
        x3_CA = x3_CA + x3_c
        print('x3_CA',x3_CA.shape) #([10, 256, 56, 56])
        side3 = self.up2(x3_CA, side4)
        print('side3',side3.shape) #side3 torch.Size([10, 256, 56, 56])
        side3 = torch.cat([side3, d3], 1)
        side3 = self.CDown2(side3)
        print('side3',side3.shape) # 10, 256, 56, 56
        
        d2 = self.Up3(d3)
        print('d2',d2.shape) #d3 torch.Size([10, 128, 112, 112])
#         x2_a = self.Att3(g=d3, x=x2)
#         print('x2_a',x2_a.shape) # x2_a torch.Size([10, 128, 112, 112])

        x2_c = self.Channel3(x2)
        print('x2_c',x2_c.shape)
        x2_CA = self.CA3(x2)
        
        x2_CA = x2_CA + x2_c
        print('x2_CA',x2_CA.shape) #x2_CA torch.Size([10, 128, 112, 112])
        side2 = self.up3(x2_CA, side3)
        print('side2',side2.shape) #side2 torch.Size([10, 128, 112, 112])
        side2 = torch.cat([side2, d2], 1)
        side2 = self.CDown3(side2)
        print('side2',side2.shape) #side2 torch.Size([10, 128, 112, 112])
        
        d1= self.Up4(d2)
        print('d1',d1.shape) #d2 torch.Size([10, 64, 224, 224])
#         x1_a = self.Att4(g=d2, x=x1)
#         print('x1_a',x1_a.shape) #x2_a torch.Size([10, 64, 224, 224])

        x1_c = self.Channel3(x1)
        print('x1_c',x1_c.shape)

        x1_CA = self.CA4(x1)
        x1_CA = x1_CA + x1_c
        print('x1_CA',x1_CA.shape)
        side1 = self.up4(x1_CA, side2)
        side1 = torch.cat([side1, d1], 1)
        side1 = self.CDown4(side1)
        print('side1',side1.shape) #10, 64, 224, 224

        side_output5 = self.side5_conv(side5)
        side_output4 = self.side4_conv(side4)
        side_output3 = self.side3_conv(side3)
        side_output2 = self.side2_conv(side2)
        side_output1 = self.final_conv(side1)
        # side_output1 = self.side1_conv(side1)

        side_output2 = F.interpolate(side_output2, size=(h, w), mode='bilinear',
                                     align_corners=True)  # self.up2(side_output2)
        side_output3 = F.interpolate(side_output3, size=(h, w), mode='bilinear',
                                     align_corners=True)  # self.up4(side_output3)
        side_output4 = F.interpolate(side_output4, size=(h, w), mode='bilinear',
                                     align_corners=True)  # self.up8(side_output4)
        side_output5 = F.interpolate(side_output5, size=(h, w), mode='bilinear',
                                     align_corners=True)  # self.up16(side_output5)

        fused = self.fuse_conv(torch.cat([side_output1,
                                          side_output2,
                                          side_output3,
                                          side_output4,
                                          side_output5], dim=1))

        return side_output1, side_output2, side_output3, side_output4, side_output5, fused

In [20]:
@torch.no_grad()
def validation(model, loader, loss_fn):
    losses = []
    model.eval()
    for image, target in loader:
        image, target = image.to(DEVICE), target.float().to(DEVICE)
        #output = model(image)['out']
        output = model(image)
        loss1 = loss_fn(output[0], target)
        loss2 = loss_fn(output[1], target)
        loss3 = loss_fn(output[2], target)
        loss4 = loss_fn(output[3], target)
        loss5 = loss_fn(output[4], target)
        fuse_loss = loss_fn(output[-1], target)
        loss =loss1+loss2+loss3+loss4+loss5+fuse_loss
        losses.append(loss.item())
        
    return np.array(losses).mean()

In [None]:
model = UNet(num_classes=1)
images = torch.rand(10, 3, 224, 224)
result = model(images)
print(result[0].shape)

In [14]:
model.to(DEVICE)

optimizer = torch.optim.Adam(model.parameters(),
                 lr=1e-4, weight_decay=1e-3)

class SoftDiceLoss(nn.Module):
    def __init__(self, smooth=1., dims=(-2,-1)):

        super(SoftDiceLoss, self).__init__()
        self.smooth = smooth
        self.dims = dims
    
    def forward(self, x, y):
        tp = (x * y).sum(self.dims)
        fp = (x * (1 - y)).sum(self.dims)
        fn = ((1 - x) * y).sum(self.dims)
        
        dc = (2 * tp + self.smooth) / (2 * tp + fp + fn + self.smooth)
        dc = dc.mean()
        return 1 - dc
    
bce_fn = nn.BCEWithLogitsLoss()
dice_fn = SoftDiceLoss()

def loss_fn(y_pred, y_true):
    bce = bce_fn(y_pred, y_true)
#     dice = dice_fn(y_pred.sigmoid(), y_true)
    return bce

In [15]:
header = r'''
        Train | Valid
Epoch |  Loss |  Loss | Time, m
'''
#          Epoch         metrics            time
raw_line = '{:6d}' + '\u2502{:7.3f}'*2 + '\u2502{:6.2f}'
print(header)

EPOCHES = 200
best_loss = 10
for epoch in range(1, EPOCHES+1):
    losses = []
    start_time = time.time()
    model.train()
    for image, target in tqdm_notebook(loader):
        
        image, target = image.to(DEVICE), target.float().to(DEVICE)
        optimizer.zero_grad()
        #output = model(image)['out']
        output = model(image)
        loss1 = loss_fn(output[0], target)
        loss2 = loss_fn(output[1], target)
        loss3 = loss_fn(output[2], target)
        loss4 = loss_fn(output[3], target)
        loss5 = loss_fn(output[4], target)
        fuse_loss = loss_fn(output[-1], target)
        loss =loss1+loss2+loss3+loss4+loss5+fuse_loss
        loss.backward()
        optimizer.step()
        losses.append(loss.item())
        # print(loss.item())
        
    vloss = validation(model, vloader, loss_fn)
    print(raw_line.format(epoch, np.array(losses).mean(), vloss,
                              (time.time()-start_time)/60**1))
    losses = []
    
    if vloss < best_loss:
        best_loss = vloss
        torch.save(model.state_dict(), 'UHDN_Tunnel200_best.pth')


        Train | Valid
Epoch |  Loss |  Loss | Time, m



HBox(children=(IntProgress(value=0, max=42), HTML(value='')))


     1│  3.503│  3.093│  0.45


HBox(children=(IntProgress(value=0, max=42), HTML(value='')))


     2│  2.975│  2.836│  0.57


HBox(children=(IntProgress(value=0, max=42), HTML(value='')))


     3│  2.793│  2.720│  0.42


HBox(children=(IntProgress(value=0, max=42), HTML(value='')))


     4│  2.693│  2.653│  0.44


HBox(children=(IntProgress(value=0, max=42), HTML(value='')))


     5│  2.623│  2.610│  0.43


HBox(children=(IntProgress(value=0, max=42), HTML(value='')))


     6│  2.568│  2.543│  0.42


HBox(children=(IntProgress(value=0, max=42), HTML(value='')))


     7│  2.517│  2.498│  0.44


HBox(children=(IntProgress(value=0, max=42), HTML(value='')))


     8│  2.468│  2.442│  0.44


HBox(children=(IntProgress(value=0, max=42), HTML(value='')))


     9│  2.421│  2.398│  0.43


HBox(children=(IntProgress(value=0, max=42), HTML(value='')))


    10│  2.378│  2.375│  0.44


HBox(children=(IntProgress(value=0, max=42), HTML(value='')))


    11│  2.327│  2.318│  0.44


HBox(children=(IntProgress(value=0, max=42), HTML(value='')))


    12│  2.285│  2.263│  0.44


HBox(children=(IntProgress(value=0, max=42), HTML(value='')))


    13│  2.239│  2.240│  0.44


HBox(children=(IntProgress(value=0, max=42), HTML(value='')))


    14│  2.194│  2.180│  0.45


HBox(children=(IntProgress(value=0, max=42), HTML(value='')))


    15│  2.142│  2.126│  0.44


HBox(children=(IntProgress(value=0, max=42), HTML(value='')))


    16│  2.110│  2.088│  0.43


HBox(children=(IntProgress(value=0, max=42), HTML(value='')))


    17│  2.052│  2.036│  0.46


HBox(children=(IntProgress(value=0, max=42), HTML(value='')))


    18│  2.017│  1.985│  0.44


HBox(children=(IntProgress(value=0, max=42), HTML(value='')))


    19│  1.955│  1.924│  0.42


HBox(children=(IntProgress(value=0, max=42), HTML(value='')))


    20│  1.899│  1.882│  0.42


HBox(children=(IntProgress(value=0, max=42), HTML(value='')))


    21│  1.850│  1.844│  0.44


HBox(children=(IntProgress(value=0, max=42), HTML(value='')))


    22│  1.793│  1.786│  0.43


HBox(children=(IntProgress(value=0, max=42), HTML(value='')))


    23│  1.734│  1.702│  0.42


HBox(children=(IntProgress(value=0, max=42), HTML(value='')))


    24│  1.687│  1.652│  0.42


HBox(children=(IntProgress(value=0, max=42), HTML(value='')))


    25│  1.622│  1.628│  0.43


HBox(children=(IntProgress(value=0, max=42), HTML(value='')))


    26│  1.567│  1.558│  0.44


HBox(children=(IntProgress(value=0, max=42), HTML(value='')))


    27│  1.511│  1.508│  0.43


HBox(children=(IntProgress(value=0, max=42), HTML(value='')))


    28│  1.450│  1.417│  0.42


HBox(children=(IntProgress(value=0, max=42), HTML(value='')))


    29│  1.408│  1.381│  0.41


HBox(children=(IntProgress(value=0, max=42), HTML(value='')))


    30│  1.341│  1.317│  0.43


HBox(children=(IntProgress(value=0, max=42), HTML(value='')))


    31│  1.266│  1.244│  0.43


HBox(children=(IntProgress(value=0, max=42), HTML(value='')))


    32│  1.194│  1.156│  0.43


HBox(children=(IntProgress(value=0, max=42), HTML(value='')))


    33│  1.126│  1.065│  0.44


HBox(children=(IntProgress(value=0, max=42), HTML(value='')))


    34│  1.039│  1.025│  0.43


HBox(children=(IntProgress(value=0, max=42), HTML(value='')))


    35│  0.958│  0.937│  0.43


HBox(children=(IntProgress(value=0, max=42), HTML(value='')))


    36│  0.882│  0.847│  0.44


HBox(children=(IntProgress(value=0, max=42), HTML(value='')))


    37│  0.804│  0.812│  0.43


HBox(children=(IntProgress(value=0, max=42), HTML(value='')))


    38│  0.734│  0.687│  0.44


HBox(children=(IntProgress(value=0, max=42), HTML(value='')))


    39│  0.689│  0.684│  0.44


HBox(children=(IntProgress(value=0, max=42), HTML(value='')))


    40│  0.601│  0.640│  0.44


HBox(children=(IntProgress(value=0, max=42), HTML(value='')))


    41│  0.560│  0.586│  0.43


HBox(children=(IntProgress(value=0, max=42), HTML(value='')))


    42│  0.503│  0.512│  0.43


HBox(children=(IntProgress(value=0, max=42), HTML(value='')))


    43│  0.456│  0.492│  0.43


HBox(children=(IntProgress(value=0, max=42), HTML(value='')))


    44│  0.402│  0.456│  0.43


HBox(children=(IntProgress(value=0, max=42), HTML(value='')))


    45│  0.373│  0.479│  0.44


HBox(children=(IntProgress(value=0, max=42), HTML(value='')))


    46│  0.345│  0.476│  0.43


HBox(children=(IntProgress(value=0, max=42), HTML(value='')))


    47│  0.335│  0.419│  0.43


HBox(children=(IntProgress(value=0, max=42), HTML(value='')))


    48│  0.325│  0.358│  0.44


HBox(children=(IntProgress(value=0, max=42), HTML(value='')))


    49│  0.303│  0.398│  0.44


HBox(children=(IntProgress(value=0, max=42), HTML(value='')))


    50│  0.279│  0.402│  0.44


HBox(children=(IntProgress(value=0, max=42), HTML(value='')))


    51│  0.261│  0.379│  0.42


HBox(children=(IntProgress(value=0, max=42), HTML(value='')))


    52│  0.243│  0.387│  0.42


HBox(children=(IntProgress(value=0, max=42), HTML(value='')))


    53│  0.235│  0.359│  0.43


HBox(children=(IntProgress(value=0, max=42), HTML(value='')))


    54│  0.216│  0.374│  0.45


HBox(children=(IntProgress(value=0, max=42), HTML(value='')))


    55│  0.205│  0.367│  0.44


HBox(children=(IntProgress(value=0, max=42), HTML(value='')))


    56│  0.200│  0.371│  0.42


HBox(children=(IntProgress(value=0, max=42), HTML(value='')))


    57│  0.196│  0.362│  0.43


HBox(children=(IntProgress(value=0, max=42), HTML(value='')))


    58│  0.182│  0.361│  0.42


HBox(children=(IntProgress(value=0, max=42), HTML(value='')))


    59│  0.187│  0.335│  0.44


HBox(children=(IntProgress(value=0, max=42), HTML(value='')))


    60│  0.169│  0.352│  0.45


HBox(children=(IntProgress(value=0, max=42), HTML(value='')))


    61│  0.178│  0.298│  0.43


HBox(children=(IntProgress(value=0, max=42), HTML(value='')))


    62│  0.175│  0.359│  0.42


HBox(children=(IntProgress(value=0, max=42), HTML(value='')))


    63│  0.193│  0.289│  0.43


HBox(children=(IntProgress(value=0, max=42), HTML(value='')))


    64│  0.186│  0.440│  0.43


HBox(children=(IntProgress(value=0, max=42), HTML(value='')))


    65│  0.198│  0.317│  0.43


HBox(children=(IntProgress(value=0, max=42), HTML(value='')))


    66│  0.167│  0.326│  0.43


HBox(children=(IntProgress(value=0, max=42), HTML(value='')))


    67│  0.159│  0.335│  0.43


HBox(children=(IntProgress(value=0, max=42), HTML(value='')))


    68│  0.142│  0.306│  0.43


HBox(children=(IntProgress(value=0, max=42), HTML(value='')))


    69│  0.139│  0.336│  0.43


HBox(children=(IntProgress(value=0, max=42), HTML(value='')))


    70│  0.138│  0.328│  0.43


HBox(children=(IntProgress(value=0, max=42), HTML(value='')))


    71│  0.134│  0.350│  0.44


HBox(children=(IntProgress(value=0, max=42), HTML(value='')))


    72│  0.127│  0.333│  0.44


HBox(children=(IntProgress(value=0, max=42), HTML(value='')))


    73│  0.123│  0.335│  0.46


HBox(children=(IntProgress(value=0, max=42), HTML(value='')))


    74│  0.125│  0.347│  0.43


HBox(children=(IntProgress(value=0, max=42), HTML(value='')))


    75│  0.127│  0.327│  0.43


HBox(children=(IntProgress(value=0, max=42), HTML(value='')))


    76│  0.160│  0.326│  0.44


HBox(children=(IntProgress(value=0, max=42), HTML(value='')))


    77│  0.158│  0.307│  0.42


HBox(children=(IntProgress(value=0, max=42), HTML(value='')))


    78│  0.136│  0.349│  0.43


HBox(children=(IntProgress(value=0, max=42), HTML(value='')))


    79│  0.120│  0.342│  0.43


HBox(children=(IntProgress(value=0, max=42), HTML(value='')))


    80│  0.118│  0.328│  0.44


HBox(children=(IntProgress(value=0, max=42), HTML(value='')))


    81│  0.113│  0.342│  0.45


HBox(children=(IntProgress(value=0, max=42), HTML(value='')))


    82│  0.110│  0.333│  0.43


HBox(children=(IntProgress(value=0, max=42), HTML(value='')))


    83│  0.108│  0.359│  0.47


HBox(children=(IntProgress(value=0, max=42), HTML(value='')))


    84│  0.110│  0.347│  0.44


HBox(children=(IntProgress(value=0, max=42), HTML(value='')))


    85│  0.107│  0.354│  0.45


HBox(children=(IntProgress(value=0, max=42), HTML(value='')))


    86│  0.121│  0.367│  0.45


HBox(children=(IntProgress(value=0, max=42), HTML(value='')))


    87│  0.118│  0.321│  0.44


HBox(children=(IntProgress(value=0, max=42), HTML(value='')))


    88│  0.116│  0.334│  0.44


HBox(children=(IntProgress(value=0, max=42), HTML(value='')))


    89│  0.105│  0.332│  0.47


HBox(children=(IntProgress(value=0, max=42), HTML(value='')))


    90│  0.108│  0.346│  0.43


HBox(children=(IntProgress(value=0, max=42), HTML(value='')))


    91│  0.102│  0.330│  0.43


HBox(children=(IntProgress(value=0, max=42), HTML(value='')))


    92│  0.098│  0.372│  0.65


HBox(children=(IntProgress(value=0, max=42), HTML(value='')))


    93│  0.101│  0.325│  0.58


HBox(children=(IntProgress(value=0, max=42), HTML(value='')))


    94│  0.110│  0.357│  0.45


HBox(children=(IntProgress(value=0, max=42), HTML(value='')))


    95│  0.115│  0.333│  0.44


HBox(children=(IntProgress(value=0, max=42), HTML(value='')))


    96│  0.112│  0.436│  0.42


HBox(children=(IntProgress(value=0, max=42), HTML(value='')))


    97│  0.105│  0.333│  0.43


HBox(children=(IntProgress(value=0, max=42), HTML(value='')))


    98│  0.101│  0.312│  0.44


HBox(children=(IntProgress(value=0, max=42), HTML(value='')))


    99│  0.096│  0.317│  0.44


HBox(children=(IntProgress(value=0, max=42), HTML(value='')))


   100│  0.092│  0.377│  0.45


HBox(children=(IntProgress(value=0, max=42), HTML(value='')))


   101│  0.093│  0.349│  0.44


HBox(children=(IntProgress(value=0, max=42), HTML(value='')))


   102│  0.093│  0.348│  0.44


HBox(children=(IntProgress(value=0, max=42), HTML(value='')))


   103│  0.089│  0.389│  0.44


HBox(children=(IntProgress(value=0, max=42), HTML(value='')))


   104│  0.093│  0.345│  0.45


HBox(children=(IntProgress(value=0, max=42), HTML(value='')))


   105│  0.092│  0.364│  0.44


HBox(children=(IntProgress(value=0, max=42), HTML(value='')))


   106│  0.178│  1.555│  0.43


HBox(children=(IntProgress(value=0, max=42), HTML(value='')))


   107│  0.327│  0.331│  0.45


HBox(children=(IntProgress(value=0, max=42), HTML(value='')))


   108│  0.251│  0.285│  0.44


HBox(children=(IntProgress(value=0, max=42), HTML(value='')))


   109│  0.178│  0.278│  0.43


HBox(children=(IntProgress(value=0, max=42), HTML(value='')))


   110│  0.157│  0.366│  0.43


HBox(children=(IntProgress(value=0, max=42), HTML(value='')))


   111│  0.148│  0.303│  0.45


HBox(children=(IntProgress(value=0, max=42), HTML(value='')))


   112│  0.125│  0.298│  0.43


HBox(children=(IntProgress(value=0, max=42), HTML(value='')))


   113│  0.112│  0.364│  0.44


HBox(children=(IntProgress(value=0, max=42), HTML(value='')))


   114│  0.109│  0.267│  0.44


HBox(children=(IntProgress(value=0, max=42), HTML(value='')))


   115│  0.102│  0.332│  0.44


HBox(children=(IntProgress(value=0, max=42), HTML(value='')))


   116│  0.095│  0.305│  0.44


HBox(children=(IntProgress(value=0, max=42), HTML(value='')))


   117│  0.128│  0.288│  0.46


HBox(children=(IntProgress(value=0, max=42), HTML(value='')))


   118│  0.121│  0.283│  0.46


HBox(children=(IntProgress(value=0, max=42), HTML(value='')))


   119│  0.102│  0.323│  0.43


HBox(children=(IntProgress(value=0, max=42), HTML(value='')))


   120│  0.091│  0.329│  0.43


HBox(children=(IntProgress(value=0, max=42), HTML(value='')))


   121│  0.087│  0.311│  0.44


HBox(children=(IntProgress(value=0, max=42), HTML(value='')))


   122│  0.085│  0.342│  0.42


HBox(children=(IntProgress(value=0, max=42), HTML(value='')))


   123│  0.085│  0.351│  0.45


HBox(children=(IntProgress(value=0, max=42), HTML(value='')))


   124│  0.080│  0.328│  0.44


HBox(children=(IntProgress(value=0, max=42), HTML(value='')))


   125│  0.079│  0.339│  0.44


HBox(children=(IntProgress(value=0, max=42), HTML(value='')))


   126│  0.079│  0.342│  0.45


HBox(children=(IntProgress(value=0, max=42), HTML(value='')))


   127│  0.079│  0.351│  0.45


HBox(children=(IntProgress(value=0, max=42), HTML(value='')))


   128│  0.077│  0.350│  0.45


HBox(children=(IntProgress(value=0, max=42), HTML(value='')))


   129│  0.081│  0.311│  0.44


HBox(children=(IntProgress(value=0, max=42), HTML(value='')))


   130│  0.141│  0.255│  0.43


HBox(children=(IntProgress(value=0, max=42), HTML(value='')))


   131│  0.148│  0.321│  0.44


HBox(children=(IntProgress(value=0, max=42), HTML(value='')))


   132│  0.123│  0.290│  0.44


HBox(children=(IntProgress(value=0, max=42), HTML(value='')))


   133│  0.108│  0.283│  0.43


HBox(children=(IntProgress(value=0, max=42), HTML(value='')))


   134│  0.095│  0.335│  0.44


HBox(children=(IntProgress(value=0, max=42), HTML(value='')))


   135│  0.083│  0.336│  0.44


HBox(children=(IntProgress(value=0, max=42), HTML(value='')))


   136│  0.082│  0.331│  0.43


HBox(children=(IntProgress(value=0, max=42), HTML(value='')))


   137│  0.078│  0.355│  0.44


HBox(children=(IntProgress(value=0, max=42), HTML(value='')))


   138│  0.076│  0.338│  0.44


HBox(children=(IntProgress(value=0, max=42), HTML(value='')))


   139│  0.076│  0.354│  0.44


HBox(children=(IntProgress(value=0, max=42), HTML(value='')))


   140│  0.074│  0.344│  0.43


HBox(children=(IntProgress(value=0, max=42), HTML(value='')))


   141│  0.075│  0.344│  0.44


HBox(children=(IntProgress(value=0, max=42), HTML(value='')))


   142│  0.075│  0.368│  0.43


HBox(children=(IntProgress(value=0, max=42), HTML(value='')))


   143│  0.078│  0.326│  0.44


HBox(children=(IntProgress(value=0, max=42), HTML(value='')))


   144│  0.074│  0.343│  0.46


HBox(children=(IntProgress(value=0, max=42), HTML(value='')))


   145│  0.073│  0.390│  0.45


HBox(children=(IntProgress(value=0, max=42), HTML(value='')))


   146│  0.075│  0.322│  0.45


HBox(children=(IntProgress(value=0, max=42), HTML(value='')))


   147│  0.074│  0.327│  0.45


HBox(children=(IntProgress(value=0, max=42), HTML(value='')))


   148│  0.076│  0.343│  0.45


HBox(children=(IntProgress(value=0, max=42), HTML(value='')))


   149│  0.075│  0.327│  0.44


HBox(children=(IntProgress(value=0, max=42), HTML(value='')))


   150│  0.073│  0.342│  0.47


HBox(children=(IntProgress(value=0, max=42), HTML(value='')))


   151│  0.073│  0.309│  0.44


HBox(children=(IntProgress(value=0, max=42), HTML(value='')))


   152│  0.070│  0.359│  0.45


HBox(children=(IntProgress(value=0, max=42), HTML(value='')))


   153│  0.072│  0.363│  0.47


HBox(children=(IntProgress(value=0, max=42), HTML(value='')))


   154│  0.072│  0.359│  0.45


HBox(children=(IntProgress(value=0, max=42), HTML(value='')))


   155│  0.070│  0.331│  0.43


HBox(children=(IntProgress(value=0, max=42), HTML(value='')))


   156│  0.075│  0.316│  0.44


HBox(children=(IntProgress(value=0, max=42), HTML(value='')))


   157│  0.076│  0.355│  0.42


HBox(children=(IntProgress(value=0, max=42), HTML(value='')))


   158│  0.076│  0.330│  0.44


HBox(children=(IntProgress(value=0, max=42), HTML(value='')))


   159│  0.077│  0.363│  0.43


HBox(children=(IntProgress(value=0, max=42), HTML(value='')))


   160│  0.072│  0.355│  0.44


HBox(children=(IntProgress(value=0, max=42), HTML(value='')))


   161│  0.075│  0.361│  0.44


HBox(children=(IntProgress(value=0, max=42), HTML(value='')))


   162│  0.072│  0.355│  0.43


HBox(children=(IntProgress(value=0, max=42), HTML(value='')))


   163│  0.074│  0.341│  0.45


HBox(children=(IntProgress(value=0, max=42), HTML(value='')))


   164│  0.075│  0.317│  0.46


HBox(children=(IntProgress(value=0, max=42), HTML(value='')))


   165│  0.078│  0.341│  0.45


HBox(children=(IntProgress(value=0, max=42), HTML(value='')))


   166│  0.092│  0.229│  0.46


HBox(children=(IntProgress(value=0, max=42), HTML(value='')))


   167│  0.195│  0.516│  0.46


HBox(children=(IntProgress(value=0, max=42), HTML(value='')))


   168│  0.202│  0.278│  0.45


HBox(children=(IntProgress(value=0, max=42), HTML(value='')))


   169│  0.138│  0.264│  0.45


HBox(children=(IntProgress(value=0, max=42), HTML(value='')))


   170│  0.103│  0.259│  0.44


HBox(children=(IntProgress(value=0, max=42), HTML(value='')))


   171│  0.094│  0.275│  0.45


HBox(children=(IntProgress(value=0, max=42), HTML(value='')))


   172│  0.088│  0.283│  0.45


HBox(children=(IntProgress(value=0, max=42), HTML(value='')))


   173│  0.078│  0.298│  0.44


HBox(children=(IntProgress(value=0, max=42), HTML(value='')))


   174│  0.073│  0.319│  0.45


HBox(children=(IntProgress(value=0, max=42), HTML(value='')))


   175│  0.070│  0.333│  0.45


HBox(children=(IntProgress(value=0, max=42), HTML(value='')))


   176│  0.071│  0.331│  0.44


HBox(children=(IntProgress(value=0, max=42), HTML(value='')))


   177│  0.068│  0.318│  0.43


HBox(children=(IntProgress(value=0, max=42), HTML(value='')))


   178│  0.064│  0.315│  0.44


HBox(children=(IntProgress(value=0, max=42), HTML(value='')))


   179│  0.067│  0.321│  0.43


HBox(children=(IntProgress(value=0, max=42), HTML(value='')))


   180│  0.071│  0.330│  0.44


HBox(children=(IntProgress(value=0, max=42), HTML(value='')))


   181│  0.066│  0.330│  0.44


HBox(children=(IntProgress(value=0, max=42), HTML(value='')))


   182│  0.066│  0.331│  0.46


HBox(children=(IntProgress(value=0, max=42), HTML(value='')))


   183│  0.067│  0.310│  0.45


HBox(children=(IntProgress(value=0, max=42), HTML(value='')))


   184│  0.067│  0.318│  0.44


HBox(children=(IntProgress(value=0, max=42), HTML(value='')))


   185│  0.067│  0.339│  0.44


HBox(children=(IntProgress(value=0, max=42), HTML(value='')))


   186│  0.065│  0.326│  0.45


HBox(children=(IntProgress(value=0, max=42), HTML(value='')))


   187│  0.068│  0.335│  0.44


HBox(children=(IntProgress(value=0, max=42), HTML(value='')))


   188│  0.066│  0.327│  0.44


HBox(children=(IntProgress(value=0, max=42), HTML(value='')))


   189│  0.063│  0.313│  0.44


HBox(children=(IntProgress(value=0, max=42), HTML(value='')))


   190│  0.066│  0.348│  0.43


HBox(children=(IntProgress(value=0, max=42), HTML(value='')))


   191│  0.065│  0.327│  0.46


HBox(children=(IntProgress(value=0, max=42), HTML(value='')))


   192│  0.068│  0.327│  0.42


HBox(children=(IntProgress(value=0, max=42), HTML(value='')))


   193│  0.068│  0.319│  0.45


HBox(children=(IntProgress(value=0, max=42), HTML(value='')))


   194│  0.065│  0.342│  0.44


HBox(children=(IntProgress(value=0, max=42), HTML(value='')))


   195│  0.064│  0.322│  0.47


HBox(children=(IntProgress(value=0, max=42), HTML(value='')))


   196│  0.064│  0.333│  0.46


HBox(children=(IntProgress(value=0, max=42), HTML(value='')))


   197│  0.063│  0.327│  0.45


HBox(children=(IntProgress(value=0, max=42), HTML(value='')))


   198│  0.063│  0.326│  0.44


HBox(children=(IntProgress(value=0, max=42), HTML(value='')))


   199│  0.066│  0.325│  0.44


HBox(children=(IntProgress(value=0, max=42), HTML(value='')))


   200│  0.078│  0.291│  0.46


In [15]:
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu' #使用GPU或者cpu
model.load_state_dict(torch.load("UHDN_Tunnel200_best.pth")) #加载模型的参数
model.to(DEVICE) #模型送到当前设备中

print(model)

UNet(
  (start_conv): Sequential(
    (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU(inplace=True)
  )
  (down1): encoder(
    (down_conv): Sequential(
      (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
      (3): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (4): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (5): ReLU(inplace=True)
    )
    (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mod

In [35]:
test_imgs = os.listdir('./Test/Zhou/1')
img_name = test_imgs
test_imgs=[os.path.join('./Test/Zhou/1',k) for k in test_imgs]
test_imgs
i = 0

In [36]:
for img_data in test_imgs:
    img=Image.open(img_data)
    img = transform(img)
    img.unsqueeze_(0)
    img = img.to(DEVICE)
    output = model(img)
    save_image(output[0], './Test/Zhou/SwinTunel/'+img_name[i])
    i = i+1