# LinkNet

In [53]:
import torch
from torch import optim
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torch.autograd import Function
from torchvision import models
import torch.nn.functional as F
from functools import partial

import data_loader
from torchsummary import summary
from torch.utils.tensorboard import SummaryWriter
import logging
from tqdm import tqdm

In [27]:
# Sizes handed to data_loader.RoadDataset when the datasets are built
# (presumably tile side lengths in pixels -- confirm against data_loader).
INPUT_SIZE = 256
OUTPUT_SIZE = 256

# Dataset location. Alternative root kept by the author for another machine:
#   root_path = 'D://Data/massachusetts-roads-dataset/'
root_path = '/home/renyan/ossdata/massachusetts-roads-dataset/'
road_path = f"{root_path}tiff_select2_parts_16/"

# Training hyper-parameters: batch size, number of epochs, learning rate.
BATCH_SIZE, EPOCH_NUM, LR = 4, 20, 0.0002

## Model

In [31]:
class LinkNet34(nn.Module):
    """LinkNet-style encoder/decoder segmentation network on a ResNet-34 backbone.

    The encoder reuses the stem and the four residual stages of
    ``torchvision.models.resnet34``. Each ``DecoderBlock`` doubles the spatial
    size of its input, and the outputs of decoder 4, 3 and 2 are added to the
    encoder features e3, e2 and e1 respectively (additive skip connections).

    ``forward`` ends with ``torch.sigmoid``, so the network returns
    probabilities in (0, 1), not logits. Pair it with a loss that expects
    probabilities (e.g. ``nn.BCELoss``) and do not apply another sigmoid to
    its output.

    Args:
        num_classes: number of output channels; also stored as ``n_classes``.
        pretrained: forwarded to ``models.resnet34``. Pass ``False`` to build
            the backbone with random weights (e.g. when weights cannot be
            downloaded or a checkpoint will be loaded afterwards).
            NOTE(review): newer torchvision releases prefer a ``weights=``
            argument over ``pretrained=`` -- confirm against the installed
            version.
    """

    def __init__(self, num_classes = 1, pretrained = True):
        super(LinkNet34, self).__init__()
        self.nonlinearity = partial(F.relu,inplace=True)
        self.n_classes = num_classes

        # Channel widths of the four ResNet-34 stages.
        filters = [64, 128, 256, 512]
        resnet = models.resnet34(pretrained = pretrained)

        # Encoder: ResNet stem followed by the four residual stages.
        self.firstconv = resnet.conv1
        self.firstbn = resnet.bn1
        self.firstrelu = resnet.relu
        self.firstmaxpool = resnet.maxpool
        self.encoder1 = resnet.layer1
        self.encoder2 = resnet.layer2
        self.encoder3 = resnet.layer3
        self.encoder4 = resnet.layer4

        # Decoder: each block maps to the channel width of the encoder stage
        # it is added to in forward().
        self.decoder4 = DecoderBlock(filters[3], filters[2])
        self.decoder3 = DecoderBlock(filters[2], filters[1])
        self.decoder2 = DecoderBlock(filters[1], filters[0])
        self.decoder1 = DecoderBlock(filters[0], filters[0])

        # Output head. For an input of side n the three layers produce sides
        # 2n + 1 (3x3 transposed conv, stride 2), 2n - 1 (3x3 conv, no
        # padding) and 2n (2x2 conv, padding 1), i.e. one more doubling.
        self.finaldeconv1 = nn.ConvTranspose2d(filters[0], 32, 3, stride=2)
        self.finalrelu1 = self.nonlinearity
        self.finalconv2 = nn.Conv2d(32, 32, 3)
        self.finalrelu2 = self.nonlinearity
        self.finalconv3 = nn.Conv2d(32, num_classes, 2, padding=1)

    def forward(self, x):
        """Return per-pixel probabilities (sigmoid already applied) for ``x``."""
        # Encoder
        x = self.firstconv(x)
        x = self.firstbn(x)
        x = self.firstrelu(x)
        x = self.firstmaxpool(x)
        e1 = self.encoder1(x)
        e2 = self.encoder2(e1)
        e3 = self.encoder3(e2)
        e4 = self.encoder4(e3)

        # Decoder with additive skip connections
        d4 = self.decoder4(e4) + e3
        d3 = self.decoder3(d4) + e2
        d2 = self.decoder2(d3) + e1
        d1 = self.decoder1(d2)

        # Output head
        out = self.finaldeconv1(d1)
        out = self.finalrelu1(out)
        out = self.finalconv2(out)
        out = self.finalrelu2(out)
        out = self.finalconv3(out)

        return torch.sigmoid(out)

In [4]:
class DecoderBlock(nn.Module):
    """Bottleneck up-sampling block: 1x1 conv -> stride-2 transposed conv -> 1x1 conv.

    The first 1x1 convolution reduces the channels to ``in_channels // 4``,
    the 3x3 transposed convolution (stride 2, padding 1, output_padding 1)
    doubles height and width, and the last 1x1 convolution expands to
    ``n_filters`` channels. Every convolution is followed by batch
    normalisation and an in-place ReLU.
    """

    def __init__(self, in_channels, n_filters):
        super().__init__()
        squeezed = in_channels // 4
        activation = partial(F.relu, inplace=True)
        self.nonlinearity = activation

        # Stage 1: channel reduction.
        self.conv1 = nn.Conv2d(in_channels, squeezed, 1)
        self.norm1 = nn.BatchNorm2d(squeezed)
        self.relu1 = activation

        # Stage 2: 2x spatial up-sampling at the reduced width.
        self.deconv2 = nn.ConvTranspose2d(
            squeezed, squeezed, 3, stride=2, padding=1, output_padding=1
        )
        self.norm2 = nn.BatchNorm2d(squeezed)
        self.relu2 = activation

        # Stage 3: channel expansion to the requested width.
        self.conv3 = nn.Conv2d(squeezed, n_filters, 1)
        self.norm3 = nn.BatchNorm2d(n_filters)
        self.relu3 = activation

    def forward(self, x):
        """Run the three conv/norm/ReLU stages in order and return the result."""
        reduced = self.relu1(self.norm1(self.conv1(x)))
        upsampled = self.relu2(self.norm2(self.deconv2(reduced)))
        return self.relu3(self.norm3(self.conv3(upsampled)))

In [29]:
# Build a throw-away LinkNet34 on the GPU and print torchsummary's per-layer
# table (layer type, output shape, parameter count) for one 3x256x256 input.
tnet = LinkNet34().cuda()
summary(tnet, (3, 256, 256))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 64, 128, 128]           9,408
       BatchNorm2d-2         [-1, 64, 128, 128]             128
              ReLU-3         [-1, 64, 128, 128]               0
         MaxPool2d-4           [-1, 64, 64, 64]               0
            Conv2d-5           [-1, 64, 64, 64]          36,864
       BatchNorm2d-6           [-1, 64, 64, 64]             128
              ReLU-7           [-1, 64, 64, 64]               0
            Conv2d-8           [-1, 64, 64, 64]          36,864
       BatchNorm2d-9           [-1, 64, 64, 64]             128
             ReLU-10           [-1, 64, 64, 64]               0
       BasicBlock-11           [-1, 64, 64, 64]               0
           Conv2d-12           [-1, 64, 64, 64]          36,864
      BatchNorm2d-13           [-1, 64, 64, 64]             128
             ReLU-14           [-1, 64,

## Dice + BCE loss function

In [6]:
class dice_bce_loss(nn.Module):
    """Sum of binary cross-entropy and soft Dice loss.

    Both terms expect ``y_pred`` to hold probabilities in [0, 1]
    (``nn.BCELoss`` is used, not the logits variant).

    Note the argument order when calling the loss: ``loss(y_true, y_pred)``.

    Args:
        batch: if True the Dice coefficient is computed once over the whole
            batch; otherwise it is computed per sample (all non-batch
            dimensions are summed) and the per-sample scores are averaged.
        smooth: constant added to the numerator and denominator of the Dice
            ratio. The default 0.0 keeps the original behaviour, which yields
            NaN when both ``y_true`` and ``y_pred`` sum to zero; pass a small
            positive value to avoid that.
    """

    def __init__(self, batch=True, smooth=0.0):
        super(dice_bce_loss, self).__init__()
        self.batch = batch
        self.smooth = smooth
        self.bce_loss = nn.BCELoss()

    def soft_dice_coeff(self, y_true, y_pred):
        """Return the (mean) soft Dice coefficient of ``y_pred`` against ``y_true``."""
        smooth = self.smooth
        if self.batch:
            i = torch.sum(y_true)
            j = torch.sum(y_pred)
            intersection = torch.sum(y_true * y_pred)
        else:
            # Sum over every non-batch dimension; works for any rank >= 2,
            # not only the 4-D (N, C, H, W) case.
            i = y_true.flatten(1).sum(1)
            j = y_pred.flatten(1).sum(1)
            intersection = (y_true * y_pred).flatten(1).sum(1)
        score = (2. * intersection + smooth) / (i + j + smooth)
        # IoU variant, kept for reference:
        # score = (intersection + smooth) / (i + j - intersection + smooth)
        return score.mean()

    def soft_dice_loss(self, y_true, y_pred):
        """Return ``1 - soft_dice_coeff``."""
        return 1 - self.soft_dice_coeff(y_true, y_pred)

    def forward(self, y_true, y_pred):
        """Return ``BCE(y_pred, y_true) + soft_dice_loss(y_true, y_pred)``.

        Implemented as ``forward`` (rather than overriding ``__call__``) so
        that ``nn.Module`` hooks keep working; calling the instance behaves
        as before.
        """
        bce = self.bce_loss(y_pred, y_true)
        dice = self.soft_dice_loss(y_true, y_pred)
        return bce + dice

## Data

In [7]:
# Build the training and validation datasets from the same directory. The two
# calls differ only in the second positional argument (True vs. False) --
# presumably a train/validation switch inside RoadDataset; confirm in data_loader.
train_dataset = data_loader.RoadDataset(road_path, True, INPUT_SIZE, OUTPUT_SIZE)
val_dataset = data_loader.RoadDataset(road_path, False, INPUT_SIZE, OUTPUT_SIZE)

Read 7056 images
Read 224 images


In [8]:
# Training batches are reshuffled every epoch so SGD does not see the samples
# in the same fixed order each time; validation keeps a fixed order so the
# scores are comparable between runs.
train_loader = DataLoader(train_dataset, batch_size = BATCH_SIZE, shuffle = True)
val_loader = DataLoader(val_dataset, batch_size = BATCH_SIZE, shuffle = False)

## Evaluation

In [54]:
class DiceCoeff(Function):
    """Dice coefficient of a single example with a hand-written backward pass.

    Written in the static ``torch.autograd.Function`` style: call it as
    ``DiceCoeff.apply(input, target)``. Instantiating a Function and calling
    ``forward`` on the instance is the deprecated legacy style.
    """

    @staticmethod
    def forward(ctx, input, target):
        """Return ``(2 * <input, target> + eps) / (sum(input) + sum(target) + eps)``."""
        eps = 0.0001
        inter = torch.dot(input.reshape(-1), target.reshape(-1))
        union = torch.sum(input) + torch.sum(target) + eps

        # Keep what backward() needs: the target tensor plus the scalars.
        ctx.save_for_backward(target)
        ctx.inter = inter
        ctx.union = union
        ctx.eps = eps
        ctx.input_shape = input.shape

        return (2 * inter.float() + eps) / union.float()

    # forward() has a single output, so backward() receives a single gradient.
    @staticmethod
    def backward(ctx, grad_output):
        """Return the gradient w.r.t. ``input``; ``target`` gets no gradient."""
        (target,) = ctx.saved_tensors
        grad_input = None

        # Only compute the gradient when `input` actually requires one.
        if ctx.needs_input_grad[0]:
            # d/dx_i [(2I + eps) / U] = (2 * t_i * U - (2I + eps)) / U^2
            numerator = 2 * target.reshape(ctx.input_shape) * ctx.union \
                        - (2 * ctx.inter + ctx.eps)
            grad_input = grad_output * numerator / (ctx.union * ctx.union)

        return grad_input, None


def dice_coeff(input, target):
    """Return the mean Dice coefficient over a batch as a tensor of shape (1,).

    ``input`` and ``target`` are iterated in lockstep along their first
    dimension; each pair is scored with ``DiceCoeff`` and the scores are
    averaged. The result lives on the same device as ``input``.

    Raises:
        ValueError: if there is no (input, target) pair to score.
    """
    total = torch.zeros(1, device=input.device)
    n_pairs = 0
    for prediction, truth in zip(input, target):
        total = total + DiceCoeff.apply(prediction, truth)
        n_pairs += 1
    if n_pairs == 0:
        raise ValueError('dice_coeff() needs at least one (input, target) pair')
    return total / n_pairs

def eval_net(net, loader, device):
    """Evaluate ``net`` on ``loader`` and return the mean per-batch score.

    For ``net.n_classes > 1`` the score is the cross-entropy between the
    network output and the masks. Otherwise the network output is treated as
    probabilities (``LinkNet34.forward`` already ends with a sigmoid),
    thresholded at 0.5 and scored with ``dice_coeff``.

    Args:
        net: model exposing ``n_classes``.
        loader: iterable of batches whose first two items are images and masks;
            must support ``len()``.
        device: device the batches are moved to.

    Returns:
        Sum of the per-batch scores divided by ``len(loader)``.

    Raises:
        ZeroDivisionError: if ``loader`` is empty.
    """
    # Remember the caller's mode so it can be restored afterwards instead of
    # unconditionally switching the network to training mode.
    was_training = net.training
    # eval() switches off batch-norm statistics updates and dropout.
    net.eval()
    mask_type = torch.float32 if net.n_classes == 1 else torch.long
    n_val = len(loader)  # number of batches
    tot = 0

    try:
        for batch in loader:
            imgs, true_masks = batch[0], batch[1]
            imgs = imgs.to(device=device, dtype=torch.float32)
            true_masks = true_masks.to(device=device, dtype=mask_type)

            # No gradients are needed for evaluation; skipping them saves
            # time and memory.
            with torch.no_grad():
                mask_pred = net(imgs)

            if net.n_classes > 1:
                tot += F.cross_entropy(mask_pred, true_masks).item()
            else:
                # The network already returns probabilities. Applying sigmoid
                # again would squash them into (0.5, 0.73) and make every
                # pixel pass the 0.5 threshold.
                pred = (mask_pred > 0.5).float()
                tot += dice_coeff(pred, true_masks).item()
    finally:
        net.train(was_training)

    return tot / n_val

## Train

In [55]:
def train_net(net, device, train_dataset, val_dataset, epochs = EPOCH_NUM, lr = LR, save_cp = True,
             batch_size = BATCH_SIZE, dir_checkpoint = None):
    """Train ``net`` with SGD, logging to TensorBoard and validating periodically.

    Roughly ten times per epoch the weights and gradients are written as
    histograms, ``eval_net`` is run on the validation set, the result is fed
    to a ``ReduceLROnPlateau`` scheduler, and the current batch with its
    masks is written as images.

    Args:
        net: model exposing ``n_classes`` and ``firstconv``. For
            ``n_classes == 1`` it must return probabilities (as
            ``LinkNet34.forward`` does, which ends with a sigmoid).
        device: device the batches are moved to.
        train_dataset: dataset yielding (image, mask, ...) items for training.
        val_dataset: dataset yielding (image, mask, ...) items for validation.
        epochs: number of passes over ``train_dataset``.
        lr: initial SGD learning rate.
        save_cp: if True, save ``net.state_dict()`` after every epoch.
        batch_size: batch size of both data loaders.
        dir_checkpoint: checkpoint directory. When None, a module-level
            ``DIR_CHECKPOINT`` is used if one is defined, else ``'checkpoints/'``.
    """
    import os  # local import: the notebook's import cell does not bring `os` into scope

    if dir_checkpoint is None:
        dir_checkpoint = globals().get('DIR_CHECKPOINT', 'checkpoints/')

    # Training samples are reshuffled every epoch; validation order stays fixed.
    train_loader = DataLoader(train_dataset, batch_size = batch_size, shuffle = True)
    val_loader = DataLoader(val_dataset, batch_size = batch_size, shuffle = False)

    # Number of validation / training samples (images, not batches).
    n_val = len(val_dataset)
    n_train = len(train_dataset)

    writer = SummaryWriter(comment=f'LR_{lr}_BS_{batch_size}')
    global_step = 0

    logging.info(f'''Starting training:
        Epochs:          {epochs}
        Batch size:      {batch_size}
        Learning rate:   {lr}
        Training size:   {n_train}
        Validation size: {n_val}
        Checkpoints:     {save_cp}
        Device:          {device.type}
    ''')
    # Alternative optimizer tried earlier:
    # optimizer = optim.RMSprop(net.parameters(), lr=lr, weight_decay=1e-8, momentum=0.9)
    optimizer = optim.SGD(net.parameters(), lr=lr, weight_decay=1e-8, momentum=0.9)
    # Lower the learning rate once the validation score stops improving for
    # `patience` consecutive evaluations. The score is a loss (minimise) for
    # multi-class nets and a Dice coefficient (maximise) otherwise.
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min' if net.n_classes > 1 else 'max', patience=2)

    if net.n_classes > 1:
        # NOTE(review): CrossEntropyLoss expects raw logits; confirm that a
        # multi-class `net` does not already apply an activation.
        criterion = nn.CrossEntropyLoss()
    else:
        # The network output is already a probability, so plain BCE is the
        # matching loss; BCEWithLogitsLoss would apply a second sigmoid.
        criterion = nn.BCELoss()

    # Evaluate / log roughly ten times per epoch.
    log_interval = n_train // (10 * batch_size) + 1

    try:
        for epoch in range(epochs):
            net.train()
            # Debug aid: shows whether the first conv weights change between epochs.
            print(net.firstconv.weight[0,0,0])

            with tqdm(total = n_train, desc=f'Epoch {epoch + 1}/{epochs}', unit='img') as pbar:
                for batch in train_loader:

                    imgs = batch[0]
                    true_masks = batch[1]

                    imgs = imgs.to(device = device, dtype = torch.float32)
                    mask_type = torch.float32 if net.n_classes == 1 else torch.long
                    true_masks = true_masks.to(device = device, dtype = mask_type)

                    masks_pred = net(imgs)
                    loss = criterion(masks_pred, true_masks)
                    writer.add_scalar('Loss/train', loss.item(), global_step)

                    pbar.set_postfix(**{'loss (batch)': loss.item()})

                    # One parameter update per batch.
                    optimizer.zero_grad()
                    loss.backward()
                    # Clip gradient values to guard against exploding gradients.
                    nn.utils.clip_grad_value_(net.parameters(), 1)
                    optimizer.step()

                    # Advance the progress bar by the number of images in this batch.
                    pbar.update(imgs.shape[0])
                    global_step += 1

                    if global_step % log_interval == 0:
                        for tag, value in net.named_parameters():
                            tag = tag.replace('.', '/')
                            writer.add_histogram('weights/' + tag, value.data.cpu().numpy(), global_step)
                            # Parameters that received no gradient have grad None.
                            if value.grad is not None:
                                writer.add_histogram('grads/' + tag, value.grad.data.cpu().numpy(), global_step)
                        val_score = eval_net(net, val_loader, device)
                        scheduler.step(val_score)
                        writer.add_scalar('learning_rate', optimizer.param_groups[0]['lr'], global_step)

                        if net.n_classes > 1:
                            logging.info('Validation cross entropy: {}'.format(val_score))
                            writer.add_scalar('Loss/test', val_score, global_step)
                        else:
                            logging.info('Validation Dice Coeff: {}'.format(val_score))
                            writer.add_scalar('Dice/test', val_score, global_step)

                        writer.add_images('images', imgs, global_step)
                        if net.n_classes == 1:
                            writer.add_images('masks/true', true_masks, global_step)
                            # masks_pred already holds probabilities; threshold it directly.
                            for threshold in (0.5, 0.4, 0.3, 0.2, 0.1):
                                writer.add_images(f'masks/pred_{threshold}', masks_pred > threshold, global_step)

            if save_cp:
                if not os.path.isdir(dir_checkpoint):
                    os.makedirs(dir_checkpoint, exist_ok=True)
                    logging.info('Created checkpoint directory')
                torch.save(net.state_dict(),
                           os.path.join(dir_checkpoint, f'unet_epoch{epoch + 1}.pth'))
                logging.info(f'Checkpoint {epoch + 1} saved.')
                # Prune the checkpoint written five epochs ago, but keep every
                # tenth one. The original condition
                # `exists & (epoch - 4)//10 != 0` was parsed as
                # `(exists & ((epoch - 4)//10)) != 0` because of operator
                # precedence. NOTE(review): "keep every tenth checkpoint" is
                # the assumed intent -- confirm.
                stale = os.path.join(dir_checkpoint, f'unet_epoch{epoch - 4}.pth')
                if (epoch - 4) % 10 != 0 and os.path.exists(stale):
                    os.remove(stale)
                    logging.info(f'Checkpoint {epoch - 4} deleted.')
    finally:
        # Flush and close the TensorBoard writer even if training is interrupted.
        writer.close()

In [56]:
# Use the GPU when CUDA is available, otherwise fall back to the CPU, and
# build the model (default arguments) on that device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
net = LinkNet34().to(device)

In [57]:
# for tag, value in net.named_parameters():
#     print(tag)

In [None]:
# Train with train_net's defaults: EPOCH_NUM epochs, learning rate LR,
# batch size BATCH_SIZE, checkpoint saving enabled.
train_net(net, device, train_dataset, val_dataset)

Epoch 1/20:   0%|          | 0/7056 [00:00<?, ?img/s]

tensor([ 0.0054, -0.0069,  0.0079,  0.0379,  0.0491,  0.0307,  0.0254],
       device='cuda:0', grad_fn=<SelectBackward>)


Epoch 1/20:  29%|██▉       | 2036/7056 [06:24<13:19,  6.28img/s, loss (batch)=0.9]    

In [17]:
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# net = net().to(device)
# optimizer = torch.optim.Adam(params = net.parameters(), lr = LR)
# criterion= dice_bce_loss()
# epochs = EPOCH_NUM
# batch_size = BATCH_SIZE
# global_step = 0

# for epoch in range(epochs):
#     train_epoch_loss = 0
    
#     with tqdm(total = n_train, desc=f'Epoch {epoch + 1}/{epochs}', unit='img') as pbar:
#         for i, (imgs, masks) in enumerate(iter(train_loader)):
            
#             imgs = imgs.to(device = device, dtype = torch.float32)
#             masks = masks.to(device = device, dtype = torch.float32)
#             preds = net(img)
#             loss = criterion(masks, preds)
            
#             optimizer.zero_grad()
#             loss.backward()
#             nn.utils.clip_grad_value_(net.parameters(), 1)
#             optimizer.step()
#             train_epoch_loss += loss.item()
            
#             pbar.update(imgs.shape[0])
#             global_step += 1
            
#             pbar.set_postfix(**{'loss (batch)': loss.item()})
#             writer.add_scalar('Loss/train', loss.item(), global_step)
            
#             if global_step % (len(train_dataset) // (10 * batch_size) + 1) == 0:
#                     for tag, value in net.named_parameters():
#                         tag = tag.replace('.', '/')
#                         writer.add_histogram('weights/' + tag, value.data.cpu().numpy(), global_step)
#                         writer.add_histogram('grads/' + tag, value.grad.data.cpu().numpy(), global_step)
#                     val_score = eval_net(net, val_loader, device)
#                     scheduler.step(val_score)
#                     writer.add_scalar('learning_rate', optimizer.param_groups[0]['lr'], global_step)

#                     if net.n_classes > 1:
#                         logging.info('Validation cross entropy: {}'.format(val_score))
#                         writer.add_scalar('Loss/test', val_score, global_step)
#                     else:
#                         logging.info('Validation Dice Coeff: {}'.format(val_score))
#                         writer.add_scalar('Dice/test', val_score, global_step)

#                     writer.add_images('images', imgs, global_step)
#                     if net.n_classes == 1:
#                         writer.add_images('masks/true', true_masks, global_step)
#                         writer.add_images('masks/pred_0.5', torch.sigmoid(masks_pred) > 0.5, global_step)
#                         writer.add_images('masks/pred_0.4', torch.sigmoid(masks_pred) > 0.4, global_step)
#                         writer.add_images('masks/pred_0.3', torch.sigmoid(masks_pred) > 0.3, global_step)
#                         writer.add_images('masks/pred_0.2', torch.sigmoid(masks_pred) > 0.2, global_step)
#                         writer.add_images('masks/pred_0.1', torch.sigmoid(masks_pred) > 0.1, global_step)
            
#         train_epoch_loss /= len(train_loader)