In [1]:
import math
import os
import pickle
import collections
import math
from itertools import repeat
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import init
import pathlib, sys, os, random, time
import cv2
from torch.utils import data
from PIL import Image
from torchvision import datasets,transforms
from torchvision.utils import make_grid, save_image
import matplotlib.pyplot as plt
from tqdm import tqdm_notebook
from tqdm.notebook import tqdm
import albumentations as A
import functools
from torchvision import models
import torchvision

In [2]:
transform = transforms.Compose([
 transforms.ToTensor(), # 将图片转换为Tensor,归一化至[0,1]
 # transforms.Normalize(mean=[.5, .5, .5], std=[.5, .5, .5]) # 标准化至[-1,1]
])

In [3]:
EPOCHES = 120
BATCH_SIZE = 32
IMAGE_SIZE = 256
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu' 

trfm = A.Compose([
    A.Resize(IMAGE_SIZE, IMAGE_SIZE),
    A.HorizontalFlip(p=0.5),
    A.VerticalFlip(p=0.5),
    A.RandomRotate90(),
])

In [4]:
class CrackData(data.Dataset):
    """Dataset for Crack detection"""
    def __init__(self, data_images, data_GT,transform):
        imgs = os.listdir(data_images)
        self.imgs=[os.path.join(data_images,k) for k in imgs]
        GT = os.listdir(data_GT)
        self.GTs=[os.path.join(data_GT,k) for k in GT]
        self.transforms=transform

    def __getitem__(self, index):
        img_path = self.imgs[index]
        GT_path = self.GTs[index]
        pil_img = Image.open(img_path)
        pil_GT = Image.open(GT_path)
        if self.transforms:
            data = self.transforms(pil_img)
            label = self.transforms(pil_GT)
        else:
            pil_img = np.asarray(pil_img)
            pil_GT = np.asarray(pil_GT)
            data = torch.from_numpy(pil_img)
            label = torch.from_numpy(pil_GT)
        return data, label
    
    def __len__(self):
        return len(self.imgs)

In [5]:
dataset = CrackData('./datas/train_img','./datas/train_lab',transform)
len(dataset)

765

In [6]:
valid_idx, train_idx = [], []
for i in range(len(dataset)):
    if i % 7 == 0:
        valid_idx.append(i)
    else:
        train_idx.append(i)
   # elif i % 7 == 1:
print(len(valid_idx))
print(len(train_idx))
       
        
train_ds = data.Subset(dataset, train_idx)
valid_ds = data.Subset(dataset, valid_idx)

# define training and validation data loaders
loader = data.DataLoader(
    train_ds, batch_size=4, shuffle=True, drop_last=True, num_workers=0)

vloader = data.DataLoader(
    valid_ds, batch_size=4, shuffle=False,drop_last=True, num_workers=0)

110
655


In [7]:
def x2conv(in_channels, out_channels, inner_channels=None):
    inner_channels = out_channels // 2 if inner_channels is None else inner_channels
    down_conv = nn.Sequential(
        nn.Conv2d(in_channels, inner_channels, kernel_size=3, padding=1, bias=False),
        nn.BatchNorm2d(inner_channels),
        nn.ReLU(inplace=True),
        nn.Conv2d(inner_channels, out_channels, kernel_size=3, padding=1, bias=False),
        nn.BatchNorm2d(out_channels),
        nn.ReLU(inplace=True))
    return down_conv

In [8]:
class encoder(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(encoder, self).__init__()
        self.down_conv = x2conv(in_channels, out_channels)
        self.pool = nn.MaxPool2d(kernel_size=2, ceil_mode=True)

    def forward(self, x):
        x = self.down_conv(x)
        x = self.pool(x)
        return x

In [9]:
class decoder(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(decoder, self).__init__()
        self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
        self.up_conv = x2conv(in_channels, out_channels)

    def forward(self, x_copy, x, interpolate=True):
        x = self.up(x)

        if (x.size(2) != x_copy.size(2)) or (x.size(3) != x_copy.size(3)):
            if interpolate:
                # Iterpolating instead of padding
                x = F.interpolate(x, size=(x_copy.size(2), x_copy.size(3)),
                                mode="bilinear", align_corners=True)
            else:
                # Padding in case the incomping volumes are of different sizes
                diffY = x_copy.size()[2] - x.size()[2]
                diffX = x_copy.size()[3] - x.size()[3]
                x = F.pad(x, (diffX // 2, diffX - diffX // 2,
                                diffY // 2, diffY - diffY // 2))

        # Concatenate
        x = torch.cat([x_copy, x], dim=1)
        x = self.up_conv(x)
        return x

In [10]:
class RF(nn.Module):
    # Revised from: Receptive Field Block Net for Accurate and Fast Object Detection, 2018, ECCV
    # GitHub: https://github.com/ruinmessi/RFBNet
    def __init__(self, in_channel, out_channel):
        super(RF, self).__init__()
        self.relu = nn.ReLU(True)

        self.branch0 = nn.Sequential(
            BasicConv2d(in_channel, out_channel, 1),
        )
        self.branch1 = nn.Sequential(
            BasicConv2d(in_channel, out_channel, 1),
            BasicConv2d(out_channel, out_channel, kernel_size=(1, 3), padding=(0, 1)),
            BasicConv2d(out_channel, out_channel, kernel_size=(3, 1), padding=(1, 0)),
            BasicConv2d(out_channel, out_channel, 3, padding=3, dilation=3)
        )
        self.branch2 = nn.Sequential(
            BasicConv2d(in_channel, out_channel, 1),
            BasicConv2d(out_channel, out_channel, kernel_size=(1, 5), padding=(0, 2)),
            BasicConv2d(out_channel, out_channel, kernel_size=(5, 1), padding=(2, 0)),
            BasicConv2d(out_channel, out_channel, 3, padding=5, dilation=5)
        )
        self.branch3 = nn.Sequential(
            BasicConv2d(in_channel, out_channel, 1),
            BasicConv2d(out_channel, out_channel, kernel_size=(1, 7), padding=(0, 3)),
            BasicConv2d(out_channel, out_channel, kernel_size=(7, 1), padding=(3, 0)),
            BasicConv2d(out_channel, out_channel, 3, padding=7, dilation=7)
        )

        self.conv_cat = BasicConv2d(4*out_channel, out_channel, 3, padding=1)
        self.conv_res = BasicConv2d(in_channel, out_channel, 1)

    def forward(self, x):
        x0 = self.branch0(x)
        x1 = self.branch1(x)
        x2 = self.branch2(x)
        x3 = self.branch3(x)

        x_cat = self.conv_cat(torch.cat((x0, x1, x2, x3), dim=1))

        x = self.relu(x_cat + self.conv_res(x))
        return x

In [11]:
class BasicConv2d(nn.Module):
    def __init__(self, in_planes, out_planes, kernel_size, stride=1, padding=0, dilation=1):
        super(BasicConv2d, self).__init__()
        self.conv = nn.Conv2d(in_planes, out_planes,
                              kernel_size=kernel_size, stride=stride,
                              padding=padding, dilation=dilation, bias=False)
        self.bn = nn.BatchNorm2d(out_planes)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        return x

In [12]:
class aggregation(nn.Module):
    # dense aggregation, it can be replaced by other aggregation model, such as DSS, amulet, and so on.
    # used after MSF
    def __init__(self, channel):
        super(aggregation, self).__init__()
        self.relu = nn.ReLU(True)

        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        self.conv_upsample1 = BasicConv2d(channel, channel, 3, padding=1)
        self.conv_upsample2 = BasicConv2d(channel, channel, 3, padding=1)
        self.conv_upsample3 = BasicConv2d(channel, channel, 3, padding=1)
        self.conv_upsample4 = BasicConv2d(channel, channel, 3, padding=1)
        self.conv_upsample5 = BasicConv2d(2*channel, 2*channel, 3, padding=1)

        self.conv_concat2 = BasicConv2d(2*channel, 2*channel, 3, padding=1)
        self.conv_concat3 = BasicConv2d(3*channel, 3*channel, 3, padding=1)
        self.conv4 = BasicConv2d(3*channel, 3*channel, 3, padding=1)
        self.conv5 = nn.Conv2d(3*channel, 1, 1)

    def forward(self, x1, x2, x3):
        x1_1 = x1
        x2_1 = self.conv_upsample1(self.upsample(x1)) * x2
        x3_1 = self.conv_upsample2(self.upsample(self.upsample(x1))) \
               * self.conv_upsample3(self.upsample(x2)) * x3

        x2_2 = torch.cat((x2_1, self.conv_upsample4(self.upsample(x1_1))), 1)
        x2_2 = self.conv_concat2(x2_2)

        x3_2 = torch.cat((x3_1, self.conv_upsample5(self.upsample(x2_2))), 1)
        x3_2 = self.conv_concat3(x3_2)

        x = self.conv4(x3_2)
        x = self.conv5(x)

        return x

In [13]:
class h_sigmoid(nn.Module):
    def __init__(self, inplace=True):
        super(h_sigmoid, self).__init__()
        self.relu = nn.ReLU6(inplace=inplace)

    def forward(self, x):
        return self.relu(x + 3) / 6

class h_swish(nn.Module):
    def __init__(self, inplace=True):
        super(h_swish, self).__init__()
        self.sigmoid = h_sigmoid(inplace=inplace)

    def forward(self, x):
        return x * self.sigmoid(x)

In [14]:
class CoordAtt(nn.Module):
    def __init__(self, inp, oup, groups=32):
        super(CoordAtt, self).__init__()
        self.pool_h = nn.AdaptiveAvgPool2d((None, 1))
        self.pool_w = nn.AdaptiveAvgPool2d((1, None))

        mip = max(8, inp // groups)

        self.conv1 = nn.Conv2d(inp, mip, kernel_size=1, stride=1, padding=0)
        self.bn1 = nn.BatchNorm2d(mip)
        self.conv2 = nn.Conv2d(mip, oup, kernel_size=1, stride=1, padding=0)
        self.conv3 = nn.Conv2d(mip, oup, kernel_size=1, stride=1, padding=0)
        self.relu = h_swish()

    def forward(self, x):
        identity = x
        n,c,h,w = x.size()
        x_h = self.pool_h(x)
        x_w = self.pool_w(x).permute(0, 1, 3, 2)

        y = torch.cat([x_h, x_w], dim=2)
        y = self.conv1(y)
        y = self.bn1(y)
        y = self.relu(y) 
        x_h, x_w = torch.split(y, [h, w], dim=2)
        x_w = x_w.permute(0, 1, 3, 2)

        x_h = self.conv2(x_h).sigmoid()
        x_w = self.conv3(x_w).sigmoid()
        x_h = x_h.expand(-1, -1, h, w)
        x_w = x_w.expand(-1, -1, h, w)

        y = identity * x_w * x_h

        return y

In [15]:
class UNet(nn.Module):
    def __init__(self, num_classes=1, in_channels=3, channel=32):
        super(UNet, self).__init__()
        
  # ---- ResNet Backbone ----
        self.start_conv = x2conv(in_channels, 64)
        self.down1 = encoder(64, 128)
        self.down2 = encoder(128, 256)
        self.down3 = encoder(256, 512)
        self.down4 = encoder(512, 1024)
        self.down5 = encoder(1024, 1024)
 # ---- Receptive Field Block like module ---- 
        self.rf1= RF(512,channel)
        self.rf2= RF(1024,channel)
        self.rf3= RF(1024,channel)
        
# ---- aggregation ---- 
        self.agg1 = aggregation(channel)
        self.CA1 = CoordAtt(32,32)
        self.CA2 = CoordAtt(32,32)
        self.CA3 = CoordAtt(32,32)
        
        self.conv5 = nn.Conv2d(channel, 1, 1)
        

#         self.up1 = decoder(1024, 512)
#         self.up2 = decoder(512, 256)
#         self.up3 = decoder(256, 128)
#         self.up4 = decoder(128, 64)
#         self.final_conv = nn.Conv2d(64, num_classes, kernel_size=1)
        self._initialize_weights()


    def _initialize_weights(self):
        for module in self.modules():
            if isinstance(module, nn.Conv2d) or isinstance(module, nn.Linear):
                nn.init.kaiming_normal_(module.weight)
                if module.bias is not None:
                    module.bias.data.zero_()
            elif isinstance(module, nn.BatchNorm2d):
                module.weight.data.fill_(1)
                module.bias.data.zero_()

    def forward(self, x):

        x1 = self.start_conv(x)
        S1 = self.down1(x1)
#         print('S1',S1.size())
        S2 = self.down2(S1)
#         print('S2',S2.size())
        S3 = self.down3(S2)
#         print('S3',S3.size())
        S4 = self.down4(S3)
#         print('S4',S4.size())
        S5 = self.down5(S4)
#         print('S5',S5.size())
        # print(x4.shape)
        S3_RF = self.rf1(S3)
#         print('S3_RF',S3_RF.size())
        S4_RF = self.rf2(S4)
#         print('S4_RF',S4_RF.size())
        S5_RF = self.rf3(S5)
#         print('S5_RF',S5_RF.size())
        
        attention_map = self.agg1(S5_RF,S4_RF,S3_RF)
#         print('attention_map',attention_map.size())
        attention_map_pred = F.interpolate(attention_map, scale_factor=8, mode='bilinear')
#         print('attention_map_pred',attention_map_pred.size())
        
        #Indentification
        S3_1 = self.CA1(S3_RF)
#         print('S3_1',S3_1.size())
        S3_2 = self.CA2(S3_1)
#         print('S3_2',S3_2.size())
        S3_3 = self.CA3(S3_2)
#         print('S3_3',S3_3.size())
        
        S4_1 = self.CA1(S4_RF)
#         print('S4_1',S4_1.size())
        S4_2 = self.CA2(S4_1)
#         print('S4_2',S4_2.size())
        S4_3 = self.CA3(S4_2)
#         print('S4_3',S4_3.size())
        
        S5_1 = self.CA1(S5_RF)
#         print('S5_1',S5_1.size())
        S5_2 = self.CA2(S5_1)
#         print('S5_2',S5_2.size())
        S5_3 = self.CA3(S5_2)
        S5_3 = self.conv5(S5_3)
#         print('S5_3',S5_3.size())
        
        
        guidance_g = F.interpolate(attention_map, scale_factor=0.25, mode='bilinear')
#         print('guidance_g',guidance_g.size())
        S5_F = S5_3 + guidance_g
#         print('S5_F',S5_F.size()) #S5_3 torch.Size([10, 1, 7, 7])
        S5_pred = F.interpolate(S5_F, scale_factor=32, mode='bilinear')
#         print('S5_pred',S5_pred.size())
        
        S4_F = F.interpolate(S5_F, scale_factor=2, mode='bilinear')
        S4_3 = self.conv5(S4_3)
        S4_F = S4_3 + S4_F
        S4_pred = F.interpolate(S4_F, scale_factor=16, mode='bilinear')
#         print('S4_pred',S4_pred.size())
        
        S3_F = F.interpolate(S4_F, scale_factor=2, mode='bilinear')
        S3_3 = self.conv5(S3_3)
        S3_F = S3_3 + S3_F
        S3_pred = F.interpolate(S3_F, scale_factor=8, mode='bilinear')
#         print('S3_pred',S3_pred.size())
        
        
        

#         x = self.up1(x4, x)
#         x = self.up2(x3, x)
#         x = self.up3(x2, x)
#         x = self.up4(x1, x)
#         x = self.final_conv(x)

        return attention_map_pred, S5_pred, S4_pred, S3_pred

In [22]:
@torch.no_grad()
def validation(model, loader, loss_fn):
    losses = []
    model.eval()
    for image, target in loader:
        image, target = image.to(DEVICE), target.float().to(DEVICE)
        #output = model(image)['out']
        output = model(image)
        loss1 = loss_fn(output[0], target)
        loss2 = loss_fn(output[1], target)
        loss3 = loss_fn(output[2], target)
        loss4 = loss_fn(output[3], target)
        loss = loss1 + loss2+loss2+loss3
        losses.append(loss.item())
        
    return np.array(losses).mean()

In [23]:
model = UNet()

In [20]:
images = torch.rand(10, 3, 224, 224)
result = model(images)
result[1].size()

torch.Size([10, 1, 224, 224])

In [24]:
model.to(DEVICE)

optimizer = torch.optim.Adam(model.parameters(),
                 lr=1e-4, weight_decay=1e-3)

class SoftDiceLoss(nn.Module):
    def __init__(self, smooth=1., dims=(-2,-1)):

        super(SoftDiceLoss, self).__init__()
        self.smooth = smooth
        self.dims = dims
    
    def forward(self, x, y):
        tp = (x * y).sum(self.dims)
        fp = (x * (1 - y)).sum(self.dims)
        fn = ((1 - x) * y).sum(self.dims)
        
        dc = (2 * tp + self.smooth) / (2 * tp + fp + fn + self.smooth)
        dc = dc.mean()
        return 1 - dc
    
bce_fn = nn.BCEWithLogitsLoss()
dice_fn = SoftDiceLoss()

def loss_fn(y_pred, y_true):
    bce = bce_fn(y_pred, y_true)
#     dice = dice_fn(y_pred.sigmoid(), y_true)
    return bce

In [25]:
header = r'''
        Train | Valid
Epoch |  Loss |  Loss | Time, m
'''
#          Epoch         metrics            time
raw_line = '{:6d}' + '\u2502{:7.3f}'*2 + '\u2502{:6.2f}'
print(header)

EPOCHES = 2
best_loss = 10
for epoch in range(1, EPOCHES+1):
    losses = []
    start_time = time.time()
    model.train()
    for image, target in tqdm_notebook(loader):
        
        image, target = image.to(DEVICE), target.float().to(DEVICE)
        optimizer.zero_grad()
        #output = model(image)['out']
        output = model(image)
        loss1 = loss_fn(output[0], target)
        loss2 = loss_fn(output[1], target)
        loss3 = loss_fn(output[2], target)
        loss4 = loss_fn(output[3], target)
        loss = loss1+ loss2 + loss3 + loss4
        loss.backward()
        optimizer.step()
        losses.append(loss.item())
        # print(loss.item())
        
    vloss = validation(model, vloader, loss_fn)
    print(raw_line.format(epoch, np.array(losses).mean(), vloss,
                              (time.time()-start_time)/60**1))
    losses = []
    
    if vloss < best_loss:
        best_loss = vloss
        torch.save(model.state_dict(), 'Tunnel_best.pth')


        Train | Valid
Epoch |  Loss |  Loss | Time, m



Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  from ipykernel import kernelapp as app


HBox(children=(FloatProgress(value=0.0, max=163.0), HTML(value='')))


     1│  1.671│  1.303│  2.24


HBox(children=(FloatProgress(value=0.0, max=163.0), HTML(value='')))


     2│  1.141│  0.968│  2.08


In [24]:
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu' #使用GPU或者cpu
model.load_state_dict(torch.load("./HED_best.pth")) #加载模型的参数
model.to(DEVICE) #模型送到当前设备中

print(model)

RuntimeError: Error(s) in loading state_dict for HED:
	Unexpected key(s) in state_dict: "conv1_1_down.weight", "conv1_1_down.bias", "conv1_2_down.weight", "conv1_2_down.bias", "conv2_1_down.weight", "conv2_1_down.bias", "conv2_2_down.weight", "conv2_2_down.bias", "conv3_1_down.weight", "conv3_1_down.bias", "conv3_2_down.weight", "conv3_2_down.bias", "conv3_3_down.weight", "conv3_3_down.bias", "conv4_1_down.weight", "conv4_1_down.bias", "conv4_2_down.weight", "conv4_2_down.bias", "conv4_3_down.weight", "conv4_3_down.bias", "conv5_1_down.weight", "conv5_1_down.bias", "conv5_2_down.weight", "conv5_2_down.bias", "conv5_3_down.weight", "conv5_3_down.bias". 
	size mismatch for score_dsn1.weight: copying a param with shape torch.Size([1, 21, 1, 1]) from checkpoint, the shape in current model is torch.Size([1, 64, 1, 1]).
	size mismatch for score_dsn2.weight: copying a param with shape torch.Size([1, 21, 1, 1]) from checkpoint, the shape in current model is torch.Size([1, 128, 1, 1]).
	size mismatch for score_dsn3.weight: copying a param with shape torch.Size([1, 21, 1, 1]) from checkpoint, the shape in current model is torch.Size([1, 256, 1, 1]).
	size mismatch for score_dsn4.weight: copying a param with shape torch.Size([1, 21, 1, 1]) from checkpoint, the shape in current model is torch.Size([1, 512, 1, 1]).
	size mismatch for score_dsn5.weight: copying a param with shape torch.Size([1, 21, 1, 1]) from checkpoint, the shape in current model is torch.Size([1, 512, 1, 1]).

In [22]:
test_imgs = os.listdir('C://Users//78731//Desktop//csnet//DeepCrackTest//test_img')
img_name = test_imgs
test_imgs=[os.path.join('C://Users//78731//Desktop//csnet//DeepCrackTest//test_img',k) for k in test_imgs]
test_imgs
i = 0

In [23]:
for img_data in test_imgs:
    img=Image.open(img_data)
    img = transform(img)
    img.unsqueeze_(0)
    img = img.to(DEVICE)
    output = model(img)
    save_image(output, img_name[i])
    i = i+1