# Dilated HarDNet

In [None]:
import sys
sys.path.insert(0, '.')
import os
import logging
import random
import time
import math
import torch
import numpy as np 
import torch.nn as nn
import torch.nn.functional as F

from torch.nn import init
from torch.utils import data
from datetime import datetime, timedelta
from collections import OrderedDict

from torchsummary import summary

import torch.distributed as dist

 
!pip install tensorboardX
from tensorboardX import SummaryWriter

Collecting tensorboardX
  Downloading tensorboardX-2.5-py2.py3-none-any.whl (125 kB)
[K     |████████████████████████████████| 125 kB 5.5 MB/s 
Installing collected packages: tensorboardX
Successfully installed tensorboardX-2.5


In [None]:
# Ask CUDA to launch kernels synchronously so a failing op is reported where it
# happens. This is a debugging aid and typically slows training.
os.environ.update({'CUDA_LAUNCH_BLOCKING': '1'})

In [None]:
# Data loading / logging cadence.
batch_size = 8
n_workers = 2
print_interval = 10
val_interval = 500

# Cityscapes evaluation (train-id) classes.
n_classes = 19

# Also part of the checkpoint file names built when restoring state.
model_arch = 'Linear_HarDNblock_dilated'

# NOTE(review): bn_mom is not referenced by any of the cells shown here.
bn_mom = 0.1


In [None]:
# Setup device: the default CUDA device when available, otherwise the CPU.
_device_name = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(_device_name)

---
## Architecture

### Blocks

In [None]:
class DilatedConv(nn.Module):
    """Two-branch grouped convolution mixing a plain and a dilated conv.

    The input is split in half along the channel dim (dim 1):

    * the first half goes through a ``kernel_size`` conv (dilation 1,
      padding ``kernel_size // 2``);
    * the second half goes through a 3x3 conv with ``dilation``
      (padding ``dilation``).

    Both branches use ``groups=2`` and ``stride``, and each emits
    ``out_planes // 2`` channels. The two outputs are concatenated back along dim 1.

    Each half must be divisible by the group count, so ``in_planes`` and
    ``out_planes`` should be multiples of 4.
    """

    def __init__(self, in_planes, out_planes, kernel_size=3, dilation=1, stride=1, bias=False):
        super().__init__()
        num_splits = 2
        assert(out_planes % num_splits == 0)
        conv_in_planes = in_planes // num_splits
        conv_out_planes = out_planes // num_splits
        groups = 2
        # Branch 0: ordinary conv with the requested kernel size.
        conv_1 = nn.Conv2d(conv_in_planes, conv_out_planes, kernel_size, padding=kernel_size//2, dilation=1, stride=stride, groups=groups, bias=bias)
        # Branch 1: 3x3 dilated conv. padding == dilation yields the same spatial size
        # as branch 0 for odd kernel sizes, which torch.cat below requires.
        conv_n = nn.Conv2d(conv_in_planes, conv_out_planes, 3, padding=dilation, dilation=dilation, stride=stride, groups=groups, bias=bias)
        self.convs = nn.ModuleList([conv_1, conv_n])
        self.num_splits = num_splits
        self.init_weight()

    def forward(self, x):
        chunks = torch.tensor_split(x, self.num_splits, dim=1)
        return torch.cat([conv(chunk) for conv, chunk in zip(self.convs, chunks)], dim=1)

    def init_weight(self):
        """Kaiming-normal (a=1) init for both branch convs; zero any bias."""
        # The convs live inside ``self.convs`` (an nn.ModuleList). Iterating
        # ``self.children()`` only yields that list, so the previous loop never
        # initialized anything.
        for ly in self.convs:
            if isinstance(ly, nn.Conv2d):
                nn.init.kaiming_normal_(ly.weight, a=1)
                if ly.bias is not None:
                    nn.init.constant_(ly.bias, 0)

In [None]:
class ConvX(nn.Module):
    """Conv -> BatchNorm -> ReLU unit.

    When ``dilation == 1``, a bias-free ``nn.Conv2d`` with padding ``kernel // 2`` is used.
    Any other dilation uses the split ``DilatedConv``.
    """

    def __init__(self, in_planes, out_planes, kernel=3, stride=1, dilation=1):
        super().__init__()
        if dilation != 1:
            conv = DilatedConv(in_planes, out_planes, kernel_size=kernel, stride=stride,
                               dilation=dilation, bias=False)
        else:
            conv = nn.Conv2d(in_planes, out_planes, kernel_size=kernel, stride=stride,
                             padding=kernel // 2, bias=False)
        self.conv = conv
        self.bn = nn.BatchNorm2d(out_planes)
        self.relu = nn.ReLU(inplace=True)
        self.init_weight()

    def forward(self, x):
        y = self.conv(x)
        y = self.bn(y)
        return self.relu(y)

    def init_weight(self):
        """Kaiming-normal (a=1) init for direct Conv2d children; zero any bias."""
        plain_convs = (m for m in self.children() if isinstance(m, nn.Conv2d))
        for conv in plain_convs:
            nn.init.kaiming_normal_(conv.weight, a=1)
            if conv.bias is not None:
                nn.init.constant_(conv.bias, 0)

In [None]:
class HarDBlock(nn.Module):
    """Harmonic dense block (HarDNet wiring) built from ``ConvX`` layers.

    Layer ``k`` (1-based) reads the outputs of layers ``k - 2**i`` for every power of
    two ``2**i`` dividing ``k``. Index 0 is the block input, which is linked when ``k``
    is itself a power of two. The linked tensors are concatenated along channels.

    Layer ``k`` emits ``growth_rate * grmul**m`` channels, where ``m`` counts the
    ``i >= 1`` with ``2**i`` dividing ``k``, rounded to a multiple of 4.

    The block output concatenates:

    * the outputs of all odd-numbered layers;
    * the last layer;
    * the input, when ``keepBase`` is set.

    ``get_out_ch()`` reports that width.
    """

    def __init__(self, in_channels, growth_rate, grmul, n_layers, dilation=1, keepBase=False):
        super().__init__()
        self.in_channels = in_channels
        self.growth_rate = growth_rate
        self.grmul = grmul
        self.links = []
        # forward() re-emits the block input when keepBase is set, so it must be
        # counted here as well (previously it was not).
        self.out_channels = in_channels if keepBase else 0
        self.keepBase = keepBase
        self.layers = nn.ModuleList([])
        for i in range(n_layers):
            out_ch, in_ch, el_links = self.get_links(i + 1)
            self.links.append(el_links)
            self.layers.append(ConvX(in_ch, out_ch, dilation=dilation))
            # Layer i+1 lands at data index i+1; forward() keeps odd indices and the last.
            if (i % 2 == 0) or (i == n_layers - 1):
                self.out_channels += out_ch

    def __out_ch(self, layer_id):
        """Output width of layer ``layer_id`` (1-based), rounded to a multiple of 4."""
        out_ch = self.growth_rate
        for i in range(1, int(math.log2(layer_id)) + 1):
            if layer_id % 2**i == 0:
                out_ch = out_ch * self.grmul
        return (int(out_ch + 3) // 4) * 4

    def get_links(self, layer_id):
        """Return ``(out_ch, in_ch, links)`` for layer ``layer_id`` (1-based).

        ``links`` lists the source indices (0 = block input) in concatenation order.
        ``in_ch`` is the summed width of those sources.
        """
        in_ch = 0
        links_ = []
        for i in range(int(math.log2(layer_id))):
            diff = 2**i
            if (layer_id % diff == 0) and layer_id - diff > 0:
                in_ch += self.__out_ch(layer_id - diff)
                links_.append(layer_id - diff)
        if math.log2(layer_id).is_integer():
            in_ch += self.in_channels
            links_.append(0)
        return self.__out_ch(layer_id), in_ch, links_

    def forward(self, x):
        data = [x]
        for layer, links in zip(self.layers, self.links):
            inputs = [data[link] for link in links]
            in_ = inputs[0] if len(inputs) == 1 else torch.cat(inputs, dim=1)
            data.append(layer(in_))
        last = len(data) - 1
        out = [t for i, t in enumerate(data)
               if (i % 2 == 1) or (self.keepBase and i == 0) or (i == last)]
        return torch.cat(out, dim=1)

    def get_out_ch(self):
        return self.out_channels

In [None]:
class CatBottleneck(nn.Module):
    """Cascade of ``ConvX`` layers whose outputs are concatenated along channels.

    Layer 0 is a 1x1 ``ConvX`` to ``out_planes // 2``. Each later layer consumes the
    previous layer's output and halves the width again, except the last layer, which
    keeps its input width. The concatenation therefore has ``out_planes`` channels
    when ``out_planes`` is divisible by ``2 ** (block_num - 1)``.

    With ``stride == 2``, layer 0's output is downsampled twice:

    * by a depthwise stride-2 conv + BN (``avd_layer``) before it enters layer 1;
    * by a 3x3 average pool (``skip``) for its own slot in the concatenation.

    Raises:
        AssertionError: if ``block_num <= 1``.
    """

    def __init__(self, in_planes, out_planes, block_num=4, stride=1, dilation=1):
        super(CatBottleneck, self).__init__()
        # Use a real message string: the old ``print(...)`` evaluated to None, so the
        # AssertionError carried no text.
        assert block_num > 1, "block number should be larger than 1."
        self.conv_list = nn.ModuleList()
        self.stride = stride
        if stride == 2:
            self.avd_layer = nn.Sequential(
                nn.Conv2d(out_planes//2, out_planes//2, kernel_size=3, stride=2, padding=1, groups=out_planes//2, bias=False),
                nn.BatchNorm2d(out_planes//2),
            )
            self.skip = nn.AvgPool2d(kernel_size=3, stride=2, padding=1)
            # avd_layer / skip do the downsampling, so layer 1 itself runs at stride 1.
            stride = 1

        for idx in range(block_num):
            if idx == 0:
                blk = ConvX(in_planes, out_planes // 2, kernel=1)
            elif idx == 1 and block_num == 2:
                blk = ConvX(out_planes // 2, out_planes // 2, stride=stride, dilation=dilation)
            elif idx == 1 and block_num > 2:
                blk = ConvX(out_planes // 2, out_planes // 4, stride=stride, dilation=dilation)
            elif idx < block_num - 1:
                blk = ConvX(out_planes // 2 ** idx, out_planes // 2 ** (idx + 1), dilation=dilation)
            else:
                blk = ConvX(out_planes // 2 ** idx, out_planes // 2 ** idx, dilation=dilation)
            self.conv_list.append(blk)

    def forward(self, x):
        out1 = self.conv_list[0](x)
        out = self.avd_layer(out1) if self.stride == 2 else out1
        out_list = []
        for conv in self.conv_list[1:]:
            out = conv(out)
            out_list.append(out)

        if self.stride == 2:
            out1 = self.skip(out1)
        out_list.insert(0, out1)
        return torch.cat(out_list, dim=1)

In [None]:
class BiSeNetOutput(nn.Module):
    """Prediction head with no upsampling.

    A 3x3 ``ConvX`` maps to ``mid_chan`` channels, then a bias-free 1x1 conv maps to
    ``n_classes`` channels.
    """

    def __init__(self, in_chan, mid_chan, n_classes, *args, **kwargs):
        super().__init__()
        self.conv = ConvX(in_chan, mid_chan, kernel=3, stride=1)
        self.conv_out = nn.Conv2d(mid_chan, n_classes, kernel_size=1, bias=False)
        self.init_weight()

    def forward(self, x):
        hidden = self.conv(x)
        return self.conv_out(hidden)

    def init_weight(self):
        """Kaiming-normal (a=1) init for direct Conv2d children; zero any bias."""
        plain_convs = [m for m in self.children() if isinstance(m, nn.Conv2d)]
        for conv in plain_convs:
            nn.init.kaiming_normal_(conv.weight, a=1)
            if conv.bias is not None:
                nn.init.constant_(conv.bias, 0)

In [None]:
class Decoder(nn.Module):
    """Three-level top-down decoder producing class logits at the ``x4`` resolution.

    ``channels`` is ``(channels4, channels8, channels16)``, the widths of the three input
    features (presumably 1/4, 1/8 and coarser scale; the names suggest this).

    Forward pass:

    1. ``x16`` and ``x8`` are projected to 128 channels.
    2. The projected ``x16`` is bilinearly resized onto ``x8`` and summed with it,
       then refined to 64 channels.
    3. That result is resized onto ``x4``, concatenated with an 8-channel ``x4``
       projection, and refined to 64 channels.
    4. A 1x1 conv produces ``num_classes`` logits.
    """

    def __init__(self, num_classes, channels):
        super().__init__()
        c4, c8, c16 = channels
        self.head16 = ConvX(c16, 128, kernel=1)
        self.head8 = ConvX(c8, 128, kernel=1)
        self.head4 = ConvX(c4, 8, kernel=1)
        self.conv8 = ConvX(128, 64, kernel=3, stride=1, dilation=1)
        self.conv4 = ConvX(64 + 8, 64, kernel=3, stride=1, dilation=1)
        self.classifier = nn.Conv2d(64, num_classes, 1)

    def forward(self, x):
        feat4, feat8, feat16 = x
        feat16 = self.head16(feat16)
        feat8 = self.head8(feat8)
        feat4 = self.head4(feat4)
        up16 = F.interpolate(feat16, size=feat8.shape[-2:], mode='bilinear', align_corners=False)
        fused8 = self.conv8(feat8 + up16)
        up8 = F.interpolate(fused8, size=feat4.shape[-2:], mode='bilinear', align_corners=False)
        fused4 = self.conv4(torch.cat((up8, feat4), dim=1))
        return self.classifier(fused4)

In [None]:
class LinearHarDBlockDilated(nn.Module):
    """Single-path HarDNet segmentation network that uses dilation in place of some downsampling.

    The stem is two stride-2 ``ConvX`` layers (3 -> 32 -> 64 channels, 1/4 resolution).
    Then, for every entry of ``dilations``:

    * if the dilation differs from the previous entry and the running scale is <= 16,
      a stride-2 3x3 ``ConvX`` downsamples (the scale doubles);
    * a 4-layer ``HarDBlock`` with that dilation follows;
    * if the block's width is below ``target_planes``, a 1x1 ``ConvX`` follows.

    Features are tapped after:

    * the first entry;
    * the last entry;
    * every entry that ends a dilation stage while the scale is still below 16.

    Each position is tapped once. The first tap feeds ``detail_out`` (1-channel logits).
    The remaining taps feed ``Decoder``, which unpacks exactly three of them.

    Args:
        dilations: dilation for each HarDBlock.
        target_planes: width threshold for adding the 1x1 ``ConvX``.
        n_classes: number of segmentation classes.

    ``forward`` returns ``(logits, detail_logits)``, both resized to the input's H x W.
    """

    def __init__(self, dilations: list, target_planes=256, n_classes=19):
        super(LinearHarDBlockDilated, self).__init__()
        self.conv0 = ConvX(3, 32, kernel=3, stride=2)
        self.conv1 = ConvX(32, 64, kernel=3, stride=2)
        scale = 4
        planes = 64
        blocks = []
        # Positions in ``blocks`` (not in ``dilations``) whose outputs forward() taps.
        self.outs = []
        channels = []
        prev_dilation = dilations[0]
        for i, dilation in enumerate(dilations):
            if prev_dilation != dilation and scale <= 16:
                scale *= 2
                blocks.append(ConvX(planes, planes, 3, 2))
            blk = HarDBlock(planes, growth_rate=40, grmul=1.7, n_layers=4, dilation=dilation)
            blocks.append(blk)
            out_planes = blk.get_out_ch()
            if out_planes < target_planes:
                # NOTE(review): input and output widths are equal here, so this 1x1
                # ConvX never changes the channel count; confirm whether it was meant
                # to project to target_planes.
                blocks.append(ConvX(out_planes, out_planes, kernel=1))
            # Track the real width even when no 1x1 layer is added. Previously `planes`
            # kept the pre-block width in that case, mis-sizing the next block.
            planes = out_planes
            ends_stage = (i + 1 < len(dilations) and scale < 16
                          and dilation != dilations[i + 1])
            if i == 0 or i == len(dilations) - 1 or ends_stage:
                channels.append(out_planes)
                # Store the position of the last module appended for this entry.
                # Previously the dilation index `i` was stored, but forward() compares
                # against positions in self.blocks, which also holds the downsampling
                # and 1x1 layers, so the wrong features were tapped.
                self.outs.append(len(blocks) - 1)
            prev_dilation = dilation
        self.blocks = nn.ModuleList(blocks)
        self.detail_out = BiSeNetOutput(channels[0], 64, 1)
        self.decoder = Decoder(n_classes, channels[1:])

    def forward(self, x):
        out = self.conv1(self.conv0(x))
        taps = set(self.outs)
        store_out = []
        for idx, module in enumerate(self.blocks):
            out = module(out)
            if idx in taps:
                store_out.append(out)

        out_detail = self.detail_out(store_out[0])
        out = self.decoder(store_out[1:])
        size = x.shape[-2:]
        out_detail = F.interpolate(out_detail, size=size, mode='bilinear', align_corners=False)
        out = F.interpolate(out, size=size, mode='bilinear', align_corners=False)
        return out, out_detail

---
## Loss & Optimizer

In [None]:
class OhemCELoss(nn.Module):
    """Online hard-example-mining cross-entropy.

    Per-pixel CE losses are computed (pixels labelled ``ignore_lb`` contribute 0) and
    sorted in descending order. The returned value is:

    * if the loss at sorted index ``n_min`` exceeds ``-log(thresh)``: the mean of every
      loss above that threshold;
    * otherwise: the mean of the ``n_min`` largest losses;
    * if there are at most ``n_min`` pixels: the mean of all of them.
    """

    def __init__(self, thresh, n_min, ignore_lb=255, *args, **kwargs):
        super(OhemCELoss, self).__init__()
        # Kept on CPU and moved to the loss' device in forward(), so this no longer
        # depends on a global ``device`` being defined.
        self.thresh = -torch.log(torch.tensor(thresh, dtype=torch.float))
        self.n_min = n_min
        self.ignore_lb = ignore_lb
        self.criteria = nn.CrossEntropyLoss(ignore_index=ignore_lb, reduction='none')

    def forward(self, logits, labels):
        loss = self.criteria(logits, labels).view(-1)
        loss, _ = torch.sort(loss, descending=True)
        thresh = self.thresh.to(loss.device)
        # Guard the index: with <= n_min pixels, loss[self.n_min] would raise IndexError.
        if loss.numel() > self.n_min and loss[self.n_min] > thresh:
            loss = loss[loss > thresh]
        else:
            loss = loss[:self.n_min]
        return torch.mean(loss)

In [None]:
# Legacy tensor type used by the loss kernels to live on `device`.
if device.type == 'cuda':
    float_tensor_type = torch.cuda.FloatTensor
else:
    float_tensor_type = torch.FloatTensor

In [None]:
def dice_loss_func(input, target, smooth=1.):
    """Soft Dice loss computed per sample and averaged over the batch.

    Both tensors are flattened to ``(N, -1)`` (N = size of dim 0), and
    ``1 - (2*|X*Y| + smooth) / (|X| + |Y| + smooth)`` is averaged over N.
    ``reshape`` is used so non-contiguous inputs work (``view`` would raise).

    Args:
        input: predicted probabilities, batch on dim 0.
        target: targets with the same number of elements per sample.
        smooth: additive smoothing term (defaults to the previous constant 1.0).
    """
    n = input.size(0)
    iflat = input.reshape(n, -1)
    tflat = target.reshape(n, -1)
    intersection = (iflat * tflat).sum(1)
    dice = (2. * intersection + smooth) / (iflat.sum(1) + tflat.sum(1) + smooth)
    return (1 - dice).mean()


class DetailAggregateLoss(nn.Module):
    """Boundary ("detail") loss against multi-scale Laplacian edges of the label map.

    Edge targets are built as follows:

    1. Take the 3x3 Laplacian of ``gtmasks`` at strides 1, 2 and 4 and clamp at 0.
    2. Upsample the strided maps (nearest) to full size.
    3. Binarize each map at 0.1.
    4. Fuse them with the 1x1 ``fuse_kernel`` (initialized to 0.6/0.3/0.1) and
       binarize again.

    ``boundary_logits`` are bilinearly resized to the target size when their width
    differs. ``forward`` returns ``(bce_loss, dice_loss)``.

    Both kernels are created with the notebook-global ``float_tensor_type``.
    ``laplacian_kernel`` is a plain attribute, not a buffer, so ``.to()`` will not move it.
    """

    def __init__(self, *args, **kwargs):
        super(DetailAggregateLoss, self).__init__()

        # Fixed (non-trainable) Laplacian edge detector.
        self.laplacian_kernel = torch.tensor(
            [-1, -1, -1, -1, 8, -1, -1, -1, -1],
            dtype=torch.float32).reshape(1, 1, 3, 3).requires_grad_(False).type(float_tensor_type)

        # Learnable weights fusing the full, 1/2 and 1/4 scale edge maps.
        self.fuse_kernel = torch.nn.Parameter(torch.tensor([[6./10], [3./10], [1./10]],
            dtype=torch.float32).reshape(1, 3, 1, 1).type(float_tensor_type))

    @staticmethod
    def _binarize(t, thresh=0.1):
        """In place: values > thresh become 1, all others 0. Returns ``t``."""
        t[t > thresh] = 1
        t[t <= thresh] = 0
        return t

    def forward(self, boundary_logits, gtmasks):
        # Assumes gtmasks is (N, H, W) label ids -- TODO confirm with the caller.
        masks = gtmasks.unsqueeze(1).type(float_tensor_type)

        boundary_targets = F.conv2d(masks, self.laplacian_kernel, padding=1).clamp(min=0)
        self._binarize(boundary_targets)
        full_size = boundary_targets.shape[2:]

        targets_x2 = F.conv2d(masks, self.laplacian_kernel, stride=2, padding=1).clamp(min=0)
        targets_x4 = F.conv2d(masks, self.laplacian_kernel, stride=4, padding=1).clamp(min=0)

        targets_x4_up = F.interpolate(targets_x4, full_size, mode='nearest')
        targets_x2_up = F.interpolate(targets_x2, full_size, mode='nearest')
        self._binarize(targets_x2_up)
        self._binarize(targets_x4_up)

        # (N, 3, 1, H, W) -> (N, 3, H, W), then fuse the three scales with a 1x1 conv.
        pyramids = torch.stack((boundary_targets, targets_x2_up, targets_x4_up), dim=1).squeeze(2)
        target_pyramid = self._binarize(F.conv2d(pyramids, self.fuse_kernel))

        if boundary_logits.shape[-1] != boundary_targets.shape[-1]:
            boundary_logits = F.interpolate(
                boundary_logits, boundary_targets.shape[2:], mode='bilinear', align_corners=True)

        bce_loss = F.binary_cross_entropy_with_logits(boundary_logits, target_pyramid)
        dice_loss = dice_loss_func(torch.sigmoid(boundary_logits), target_pyramid)
        return bce_loss, dice_loss

    def get_params(self):
        """Return this loss' trainable parameters, each exactly once.

        ``self.parameters()`` already recurses into submodules. The previous loop over
        ``named_modules()`` therefore listed submodule parameters twice, which
        torch.optim rejects.
        """
        return list(self.parameters())

In [None]:
class Optimizer(object):
    """SGD wrapper with exponential warmup followed by polynomial ("poly") decay.

    For the first ``warmup_steps`` iterations, the learning rate grows geometrically
    from ``warmup_start_lr`` to ``lr0``. Afterwards it is ``lr0 * (1 - p) ** power``,
    where ``p`` goes from 0 to 1 between ``warmup_steps`` and ``max_iter``. ``p`` is
    clamped to [0, 1], so the rate stays a real number and reaches 0 after ``max_iter``.

    The model's parameters and ``loss.get_params()`` form two SGD parameter groups.
    Both use ``momentum`` and weight decay ``wd``.
    """

    def __init__(self, model, loss, lr0, momentum, wd, warmup_steps,
                 warmup_start_lr, max_iter, power, *args, **kwargs):
        self.warmup_steps = warmup_steps
        self.warmup_start_lr = warmup_start_lr
        self.lr0 = lr0
        self.lr = self.lr0
        self.max_iter = float(max_iter)
        self.power = power
        self.it = 0
        loss_nowd_params = loss.get_params()
        param_list = [
                {'params': model.parameters()},
                {'params': loss_nowd_params}]
        self.optim = torch.optim.SGD(
                param_list,
                lr = lr0,
                momentum = momentum,
                weight_decay = wd)
        # Guard against warmup_steps == 0, which previously divided by zero.
        if self.warmup_steps > 0:
            self.warmup_factor = (self.lr0/self.warmup_start_lr)**(1./self.warmup_steps)
        else:
            self.warmup_factor = 1.0

    def get_lr(self):
        """Learning rate for the current iteration ``self.it``."""
        if self.warmup_steps > 0 and self.it <= self.warmup_steps:
            return self.warmup_start_lr * (self.warmup_factor ** self.it)
        decay_iters = max(self.max_iter - self.warmup_steps, 1)
        # Without clamping, iterations past max_iter raise a negative base to a
        # fractional power, and Python returns a complex number.
        progress = min(max((self.it - self.warmup_steps) / decay_iters, 0.0), 1.0)
        return self.lr0 * (1 - progress) ** self.power

    def step(self):
        self.lr = self.get_lr()
        for pg in self.optim.param_groups:
            if pg.get('lr_mul', False):
                pg['lr'] = self.lr * 10
            else:
                pg['lr'] = self.lr
        if self.optim.defaults.get('lr_mul', False):
            self.optim.defaults['lr'] = self.lr * 10
        else:
            self.optim.defaults['lr'] = self.lr
        self.it += 1
        self.optim.step()
        if self.it == self.warmup_steps+2:
            # `logger` is the notebook-global logger created in a later cell.
            logger.info('==> warmup done, start to implement poly lr strategy')

    def get_state(self):
        return {
            'warmup_steps': self.warmup_steps,
            'warmup_start_lr': self.warmup_start_lr,
            'lr0': self.lr0,
            'lr': self.lr,
            'max_iter': self.max_iter,
            'power': self.power,
            'it': self.it,
            'optim_state': self.optim.state_dict(),
            'warmup_factor': self.warmup_factor
        }

    def load_state(self, state):
        self.warmup_steps = state.get('warmup_steps')
        self.warmup_start_lr = state.get('warmup_start_lr')
        self.lr0 = state.get('lr0')
        self.lr = state.get('lr')
        self.max_iter = state.get('max_iter')
        self.power = state.get('power')
        self.it = state.get('it')
        self.optim.load_state_dict(state.get('optim_state'))
        self.warmup_factor = state.get('warmup_factor')

    def zero_grad(self):
        self.optim.zero_grad()

In [None]:
# OHEM / loss settings.
score_thres = 0.7          # probability threshold handed to OhemCELoss
n_img_per_gpu = 8
# NOTE(review): RandomCrop unpacks its size as (W, H), so (512, 1024) produces
# 512-wide x 1024-tall crops, while cfg_data treats cropsize[0] as img_rows.
# Confirm which order is intended.
cropsize = (512, 1024)
# Minimum number of pixels OhemCELoss averages per batch.
n_min = (n_img_per_gpu * cropsize[0] * cropsize[1]) // 32
ignore_idx = 255

---
## Dataset

In [None]:
# Label-id -> trainId table; CityScapes.__init__ opens it as ./cityscapes_info.json.
!wget https://raw.githubusercontent.com/MichaelFan01/STDC-Seg/master/cityscapes_info.json

--2022-04-06 06:38:32--  https://raw.githubusercontent.com/MichaelFan01/STDC-Seg/master/cityscapes_info.json
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.110.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 7412 (7.2K) [text/plain]
Saving to: ‘cityscapes_info.json’


2022-04-06 06:38:32 (65.5 MB/s) - ‘cityscapes_info.json’ saved [7412/7412]



In [None]:
from PIL import Image
import PIL.ImageEnhance as ImageEnhance
import random
import numpy as np


class RandomCrop(object):
    """Randomly crop image and label together to ``size``, given as ``(W, H)``.

    If the image is smaller than the crop in either dimension, image (bilinear) and
    label (nearest) are first upscaled, keeping the aspect ratio. The factor is the
    smallest one that brings both sides to at least the crop size.
    """

    def __init__(self, size, *args, **kwargs):
        self.size = size

    def __call__(self, im_lb):
        im = im_lb['im']
        lb = im_lb['lb']
        assert im.size == lb.size
        W, H = self.size
        w, h = im.size

        if (W, H) == (w, h):
            return dict(im=im, lb=lb)
        if w < W or h < H:
            # Use the larger ratio so *both* sides reach the crop size. Choosing the
            # ratio by `w < h` could shrink the image instead, leaving a crop window
            # outside the image that PIL fills with zeros.
            scale = max(float(W) / w, float(H) / h)
            w, h = int(scale * w + 1), int(scale * h + 1)
            im = im.resize((w, h), Image.BILINEAR)
            lb = lb.resize((w, h), Image.NEAREST)
        sw, sh = random.random() * (w - W), random.random() * (h - H)
        crop = int(sw), int(sh), int(sw) + W, int(sh) + H
        return dict(
                im = im.crop(crop),
                lb = lb.crop(crop)
                    )


class HorizontalFlip(object):
    """Mirror image and label left-right together.

    The pair is flipped unless ``random.random()`` exceeds ``p``. When it is not
    flipped, the input dict is returned unchanged (same object).
    """

    def __init__(self, p=0.5, *args, **kwargs):
        self.p = p

    def __call__(self, im_lb):
        if random.random() > self.p:
            return im_lb
        flipped = {}
        for key in ('im', 'lb'):
            flipped[key] = im_lb[key].transpose(Image.FLIP_LEFT_RIGHT)
        return flipped


class RandomScale(object):
    """Resize image and label by a single factor chosen at random from ``scales``.

    The image is resized bilinearly and the label with nearest-neighbour (so ids
    are preserved). Each new side is ``int(old_side * factor)``.
    """

    def __init__(self, scales=(1, ), *args, **kwargs):
        self.scales = scales

    def __call__(self, im_lb):
        im, lb = im_lb['im'], im_lb['lb']
        width, height = im.size
        factor = random.choice(self.scales)
        new_size = (int(width * factor), int(height * factor))
        return dict(im=im.resize(new_size, Image.BILINEAR),
                    lb=lb.resize(new_size, Image.NEAREST))


class ColorJitter(object):
    """Randomly jitter brightness, contrast and saturation of the image.

    The label is left untouched. Each amount ``a > 0`` enables a factor drawn
    uniformly from ``[max(1 - a, 0), 1 + a]``. ``None`` or a value <= 0 disables that
    property; previously this left the attribute unset and ``__call__`` raised
    AttributeError.
    """

    def __init__(self, brightness=None, contrast=None, saturation=None, *args, **kwargs):
        self.brightness = self._range(brightness)
        self.contrast = self._range(contrast)
        self.saturation = self._range(saturation)

    @staticmethod
    def _range(amount):
        """Return the ``[low, high]`` factor range for ``amount``, or None when disabled."""
        if amount is None or amount <= 0:
            return None
        return [max(1 - amount, 0), 1 + amount]

    def __call__(self, im_lb):
        im = im_lb['im']
        lb = im_lb['lb']
        # Draw the enabled factors first, in the original order, so the RNG stream
        # matches the previous implementation when all three are enabled.
        r_brightness, r_contrast, r_saturation = [
            None if rng is None else random.uniform(rng[0], rng[1])
            for rng in (self.brightness, self.contrast, self.saturation)
        ]
        if r_brightness is not None:
            im = ImageEnhance.Brightness(im).enhance(r_brightness)
        if r_contrast is not None:
            im = ImageEnhance.Contrast(im).enhance(r_contrast)
        if r_saturation is not None:
            im = ImageEnhance.Color(im).enhance(r_saturation)
        return dict(im = im,
                    lb = lb,
                )


class MultiScale(object):
    """Return copies of ``img``, one per ratio in ``scales``, resized bilinearly.

    Each output is ``(int(W * ratio), int(H * ratio))``, in the order of ``scales``.
    """

    def __init__(self, scales):
        self.scales = scales

    def __call__(self, img):
        W, H = img.size
        # Build the list directly instead of running a comprehension only for the
        # side effect of .append().
        return [img.resize((int(W * ratio), int(H * ratio)), Image.BILINEAR)
                for ratio in self.scales]


class Compose(object):
    """Chain callables: each one receives the previous one's result."""

    def __init__(self, do_list):
        self.do_list = do_list

    def __call__(self, im_lb):
        result = im_lb
        for step in self.do_list:
            result = step(result)
        return result

In [None]:
import torch
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms

import os.path as osp
import os
from PIL import Image
import numpy as np
import json



class CityScapes(Dataset):
    """Cityscapes ``gtFine`` segmentation dataset.

    File layout:

    * images: ``<rootpth>/leftImg8bit/<mode>/<city>/*_leftImg8bit.png``
    * labels: ``<rootpth>/gtFine/<mode>/<city>/*_gtFine_labelIds.png``

    Raw label ids are mapped to train ids with ``./cityscapes_info.json`` (each entry's
    ``id`` maps to its ``trainId``).

    Args:
        rootpth: dataset root directory.
        cropsize: size passed to ``RandomCrop``, which unpacks it as ``(W, H)``.
        mode: ``'train'``, ``'val'``, ``'test'`` or ``'trainval'``. The augmentations
            (color jitter, flip, random scale, crop) run only for ``'train'`` and
            ``'trainval'``.
        randomscale: candidate factors for ``RandomScale``.
            NOTE(review): the default contains 0.675 where the 0.125-step pattern
            suggests 0.625. Confirm before relying on the default.

    ``__getitem__`` returns ``(img, label)``:

    * ``img``: an ImageNet-normalized tensor;
    * ``label``: an int64 array with a leading singleton axis, holding train ids.
    """

    def __init__(self, rootpth, cropsize=(640, 480), mode='train',
    randomscale=(0.125, 0.25, 0.375, 0.5, 0.675, 0.75, 0.875, 1.0, 1.25, 1.5), *args, **kwargs):
        super(CityScapes, self).__init__(*args, **kwargs)
        assert mode in ('train', 'val', 'test', 'trainval')
        self.mode = mode
        print('self.mode', self.mode)
        self.ignore_lb = 255

        with open('./cityscapes_info.json', 'r') as fr:
            labels_info = json.load(fr)
        self.lb_map = {el['id']: el['trainId'] for el in labels_info}

        ## parse img directory
        self.imgs = {}
        imgnames = []
        impth = osp.join(rootpth, 'leftImg8bit', mode)
        folders = os.listdir(impth)
        for fd in folders:
            fdpth = osp.join(impth, fd)
            im_names = os.listdir(fdpth)
            names = [el.replace('_leftImg8bit.png', '') for el in im_names]
            impths = [osp.join(fdpth, el) for el in im_names]
            imgnames.extend(names)
            self.imgs.update(dict(zip(names, impths)))

        ## parse gt directory
        self.labels = {}
        gtnames = []
        gtpth = osp.join(rootpth, 'gtFine', mode)
        folders = os.listdir(gtpth)
        for fd in folders:
            fdpth = osp.join(gtpth, fd)
            lbnames = os.listdir(fdpth)
            lbnames = [el for el in lbnames if 'labelIds' in el]
            names = [el.replace('_gtFine_labelIds.png', '') for el in lbnames]
            lbpths = [osp.join(fdpth, el) for el in lbnames]
            gtnames.extend(names)
            self.labels.update(dict(zip(names, lbpths)))

        self.imnames = imgnames
        self.len = len(self.imnames)
        print('self.len', self.mode, self.len)
        assert set(imgnames) == set(gtnames)
        assert set(self.imnames) == set(self.imgs.keys())
        assert set(self.imnames) == set(self.labels.keys())

        ## pre-processing
        self.to_tensor = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
            ])
        self.trans_train = Compose([
            ColorJitter(
                brightness = 0.5,
                contrast = 0.5,
                saturation = 0.5),
            HorizontalFlip(),
            RandomScale(randomscale),
            RandomCrop(cropsize)
            ])

    def __getitem__(self, idx):
        fn  = self.imnames[idx]
        impth = self.imgs[fn]
        lbpth = self.labels[fn]
        img = Image.open(impth).convert('RGB')
        label = Image.open(lbpth)
        if self.mode == 'train' or self.mode == 'trainval':
            im_lb = dict(im = img, lb = label)
            im_lb = self.trans_train(im_lb)
            img, label = im_lb['im'], im_lb['lb']
        img = self.to_tensor(img)
        label = np.array(label).astype(np.int64)[np.newaxis, :]
        label = self.convert_labels(label)
        return img, label

    def __len__(self):
        return self.len

    def convert_labels(self, label):
        """Map raw Cityscapes ids in ``label`` to train ids and return a new array.

        Every mask is computed against the *original* ids. The previous in-place loop
        could remap a pixel twice when its new train id equalled a raw id visited
        later; it only avoided this through the iteration order of ``lb_map``.
        Values without an entry in ``lb_map`` are left unchanged.
        """
        converted = label.copy()
        for raw_id, train_id in self.lb_map.items():
            converted[label == raw_id] = train_id
        return converted

In [None]:
# Mount Google Drive: the dataset root (dspth) and run directories used below
# live under /content/drive/MyDrive.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
dspth = '/content/drive/MyDrive/RnD/datasets/'
cfg_data = {
    'dataset': 'cityscapes',
    'train_split': 'train',
    'val_split': 'val',
    'img_rows': cropsize[0],
    'img_cols': cropsize[1],
    'path': dspth
}
randomscale = (0.125, 0.25, 0.375, 0.5, 0.625, 0.75, 0.875, 1.0, 1.125, 1.25, 1.375, 1.5)
ds = CityScapes(cfg_data['path'], cropsize=cropsize, mode='train', randomscale=randomscale)
# Training batches must be shuffled. No sampler is used here (a DistributedSampler
# would shuffle on its own), so with shuffle=False every epoch would visit the
# images in the same directory-listing order.
dl = DataLoader(ds,
                batch_size = batch_size,
                shuffle = True,
                num_workers = n_workers,
                pin_memory = False,
                drop_last = True)
# Validation keeps a fixed order; no augmentation is applied in 'val' mode.
dsval = CityScapes(cfg_data['path'], mode='val', randomscale=randomscale)
dlval = DataLoader(dsval,
                batch_size = 2,
                shuffle = False,
                num_workers = n_workers,
                drop_last = False)

# Label value ignored by the losses (matches CityScapes.ignore_lb).
ignore_idx = 255

self.mode train
self.len train 2975
self.mode val
self.len val 500


---
# Preparing for training

In [None]:
# Running training state.
best_iou = -100.0          # sentinel lower than any real mIoU
flag = True
loss_all = loss_n = 0      # ints are immutable, so sharing the literal is safe

In [None]:
def get_logger(logdir):
    """Return the ``"DR_test"`` logger writing INFO+ records to a fresh file in *logdir*.

    The file is named ``run_<YYYY_MM_DD_HH_MM_SS>.log``, and ``logdir`` must already
    exist. Handlers left on the logger by an earlier call are closed and removed
    first. Otherwise re-running the cell would write every record once per call.
    """
    logger = logging.getLogger("DR_test")
    for old_handler in list(logger.handlers):
        logger.removeHandler(old_handler)
        old_handler.close()
    # Same text as str(datetime.now()) up to the seconds, with '-', ' ' and ':' -> '_'.
    ts = datetime.now().strftime("%Y_%m_%d_%H_%M_%S")
    file_path = os.path.join(logdir, "run_{}.log".format(ts))
    hdlr = logging.FileHandler(file_path)
    hdlr.setFormatter(logging.Formatter("%(asctime)s %(levelname)s %(message)s"))
    logger.addHandler(hdlr)
    logger.setLevel(logging.INFO)
    return logger

In [None]:
# Vendor only the `ptsemseg` package (loader, metrics, augmentations) from the
# FCHarDNet repo into the working directory, then delete the clone.
!git clone https://github.com/PingoLH/FCHarDNet.git
!cp -r FCHarDNet/ptsemseg ./
!rm -rf FCHarDNet

Cloning into 'FCHarDNet'...
remote: Enumerating objects: 130, done.[K
remote: Counting objects: 100% (13/13), done.[K
remote: Compressing objects: 100% (12/12), done.[K
remote: Total 130 (delta 2), reused 7 (delta 1), pack-reused 117[K
Receiving objects: 100% (130/130), 9.10 MiB | 23.00 MiB/s, done.
Resolving deltas: 100% (50/50), done.


In [None]:
from ptsemseg.loader import get_loader
from ptsemseg.metrics import runningScore, averageMeter
from ptsemseg.augmentations import get_composed_augmentations

In [None]:
# Each launch of this model variant creates a timestamped run directory under
# model_modification_path.
base_path = "/content/drive/MyDrive/RnD/runs/HarDNet"
model_modification = 'Dilated_4959'
model_modification_path = os.path.join(
    base_path,
    model_modification,
)

In [None]:
# One run directory per launch, named after the launch time (whole seconds).
run_started = datetime.fromtimestamp(int(time.time()))
logdir = os.path.join(model_modification_path, str(run_started))
# Keep the writer before get_logger: the FileHandler needs logdir to exist, and
# SummaryWriter presumably creates it (nothing else in this cell does).
writer = SummaryWriter(log_dir=logdir)

print(f"RUNDIR: {logdir}")

logger = get_logger(logdir)

RUNDIR: /content/drive/MyDrive/RnD/runs/HarDNet/Dilated_4959/2022-04-06 06:39:25


In [None]:
# Setup seeds
torch.manual_seed(1337)
torch.cuda.manual_seed(1337)
np.random.seed(1337)
random.seed(1337)

In [None]:
running_metrics_val = runningScore(n_classes)

dilations = [1, 1, 1, 2, 2, 2, 4, 4, 4, 8, 8]
model = LinearHarDBlockDilated(dilations)


def weights_init(m):
    """Re-initialize every Conv2d weight with Xavier-normal; biases are untouched.

    Applied to the whole model, this overrides the kaiming init done inside ConvX
    and the other modules.
    """
    if isinstance(m, nn.Conv2d):
        nn.init.xavier_normal_(m.weight)


# Initialize on CPU (same RNG stream as before), then move the model to `device`.
# nn.DataParallel expects the wrapped module's parameters on device_ids[0], and
# the model was previously never moved there.
model.apply(weights_init)
model = model.to(device)
model = torch.nn.DataParallel(model, device_ids=range(torch.cuda.device_count()))

In [None]:
# optimizer init data
momentum = 0.9
weight_decay = 5e-4
lr_start = 1e-2
power = 0.9
warmup_steps = 1000
warmup_start_lr = 1e-5
epoch_iteration = len(ds) // batch_size
max_epoch = 484
max_iter = max_epoch * epoch_iteration

In [None]:
# Fresh start: train a window of at most 6 epochs, never past max_epoch.
start_epoch, it = 0, 0
local_max_epoch = min(start_epoch + 6, max_epoch)

In [None]:
# OHEM cross-entropy shares one configuration between training and validation.
_ohem_kwargs = dict(thresh=score_thres, n_min=n_min, ignore_lb=ignore_idx)
criteria_ffm = OhemCELoss(**_ohem_kwargs)       # applied to the main output (out1)
boundary_loss_func = DetailAggregateLoss()      # applied to the detail output (out3)
criteria_val = OhemCELoss(**_ohem_kwargs)       # separate instance used for validation
val_loss_meter = averageMeter()

In [None]:
# Optimizer over the unwrapped network (model.module) together with the
# boundary-loss module; schedule knobs come from the hyper-parameter cell
# (warm-up steps / start lr, max_iter, power).
optim = Optimizer(model=model.module, loss=boundary_loss_func,
                  lr0=lr_start, momentum=momentum, wd=weight_decay,
                  warmup_steps=warmup_steps, warmup_start_lr=warmup_start_lr,
                  max_iter=max_iter, power=power)

In [None]:
# Per-print-interval histories of the total loss and the two boundary terms
# (three independent lists).
loss_avg, loss_boundery_bce, loss_boundery_dice = [], [], []

---
## Restore state

In [None]:
# Locate the newest "<model_arch>_<dataset>_checkpoint.pkl" and
# "..._best_model.pkl" among this variant's run directories. Run directories
# are named "YYYY-MM-DD HH:MM:SS" (see the run-dir cell), so a reverse
# lexicographic sort visits the newest run first.
model_name = "{}_{}".format(model_arch, cfg_data['dataset'])  # loop-invariant
runs = sorted(os.listdir(model_modification_path), reverse=True)
best_path = None
last_path = None
for run in runs:
    tmp_base = os.path.join(model_modification_path, run)
    if not os.path.isdir(tmp_base):
        continue  # ignore stray files next to the run directories
    checkpoint = os.path.join(tmp_base, model_name + '_checkpoint.pkl')
    best = os.path.join(tmp_base, model_name + '_best_model.pkl')
    if not last_path and os.path.exists(checkpoint):
        last_path = checkpoint
    if not best_path and os.path.exists(best):
        best_path = best
    if last_path and best_path:
        break

if last_path is None:
    print("No checkpoint found under {}; nothing to restore.".format(model_modification_path))

In [None]:
# Restore model/optimizer state from the newest checkpoint and the best mIoU
# recorded so far, then reset the counters the training loop reads.
if last_path is None:
    raise FileNotFoundError(
        "no '*_checkpoint.pkl' found under {}".format(model_modification_path))

# map_location lets a checkpoint written on GPU be loaded on the current device.
loaded = torch.load(last_path, map_location=device)
model_state = loaded.get('model_state')
optimizer_state = loaded.get('optimizer_state')
start_epoch = loaded.get('epoch') + 1
# Next window of at most 12 epochs, never beyond max_epoch.
local_max_epoch = min(start_epoch + 12, max_epoch)

# Keep the current best_iou if no best-model file exists yet.
if best_path is not None:
    best_iou_arrc = torch.load(best_path, map_location=device)
    best_iou = best_iou_arrc.get('best_iou')

it = 0  # per-epoch iteration counter used by the training loop (was mistyped as `i`)
flag = True
loss_all = 0
loss_n = 0

model.load_state_dict(model_state)
optim.load_state(optimizer_state)

In [None]:
# When re-run after the training cell (flag is False), read the last validation
# record in this run's log. If it belongs to epoch local_max_epoch - 1, the
# current window finished, so move on to the next 12-epoch window.
log_names = [name for name in os.listdir(logdir) if name.endswith('.log')]
# Log names embed a sortable timestamp (run_YYYY_MM_DD_HH_MM_SS.log): max() is the newest.
file_ = os.path.join(logdir, max(log_names)) if log_names else None
if not flag and file_ is not None and os.path.isfile(file_):
    with open(file_, "r") as f:
        log_lines = f.readlines()
    # NOTE(review): assumes one validation writes 24 lines
    # ("Epoch .. Iter .. Val Loss" + 4 score lines + 19 class IoUs) — confirm.
    if len(log_lines) >= 24:
        str_ = log_lines[-24]
        epoch_pos, end__ = str_.find('Epoch'), str_.find(' Iter')
        # Check the raw find() result; the old `find(...) + 6 > -1` was always true.
        if epoch_pos > -1 and end__ > -1:
            st__ = epoch_pos + len('Epoch ')
            try:
                logged_epoch = int(str_[st__:end__])
            except ValueError:
                logged_epoch = None  # not a validation line: leave the window unchanged
            if logged_epoch is not None and local_max_epoch - 1 == logged_epoch:
                start_epoch = local_max_epoch
                local_max_epoch = min(local_max_epoch + 12, max_epoch)


start_epoch, local_max_epoch, best_iou

(12, 24, 0.2444844925521009)

## Training

In [None]:
# Main training loop: train over `dl`, log averaged losses every print_interval
# iterations, validate / checkpoint every val_interval iterations and at the end
# of each epoch, and keep the best-mIoU model.
st = glob_st = time.time()
flag = False
# Batches processed since glob_st. `it` is reset every epoch, so the ETA needs
# its own running count.
glob_it = 0
for epoch_id in range(start_epoch, local_max_epoch):
    for images, labels in dl:
        it += 1
        glob_it += 1

        # Re-enter train mode every step: validation below switches to eval().
        model.train()

        images = images.to(device)
        labels = labels.to(device)
        labels = torch.squeeze(labels, 1)

        optim.zero_grad()

        out_main, out_detail = model(images)

        loss_ffm = criteria_ffm(out_main, labels)
        boundery_bce_loss, boundery_dice_loss = boundary_loss_func(out_detail, labels)
        loss = loss_ffm + boundery_bce_loss + boundery_dice_loss

        loss.backward()
        optim.step()

        loss_avg.append(loss.item())
        loss_boundery_bce.append(boundery_bce_loss.item())
        loss_boundery_dice.append(boundery_dice_loss.item())

        if (it + 1) % print_interval == 0:
            loss_mean = sum(loss_avg) / len(loss_avg)
            lr = optim.lr
            ed = time.time()
            t_intv, glob_t_intv = ed - st, ed - glob_st
            # Remaining iterations of the whole schedule, extrapolated from the
            # mean time per batch this session (the old formula divided by the
            # per-epoch counter `it`, which resets every epoch).
            done_iters = epoch_id * epoch_iteration + it
            eta = int(max(max_iter - done_iters, 0) * (glob_t_intv / glob_it))
            eta = str(timedelta(seconds=eta))

            loss_boundery_bce_avg = sum(loss_boundery_bce) / len(loss_boundery_bce)
            loss_boundery_dice_avg = sum(loss_boundery_dice) / len(loss_boundery_dice)
            msg = ', '.join([
                'epoch: {epoch}/{max_epoch}',  # comma was missing: fields were fused
                'it: {it}/{max_it}',
                'lr: {lr:4f}',
                'loss: {loss:.4f}',
                'boundery_bce_loss: {boundery_bce_loss:.4f}',
                'boundery_dice_loss: {boundery_dice_loss:.4f}',
                'eta: {eta}',
                'time: {time:.4f}',
            ]).format(
                epoch=epoch_id,
                max_epoch=max_epoch,
                it=it + 1,
                max_it=epoch_iteration,
                lr=lr,
                loss=loss_mean,
                boundery_bce_loss=loss_boundery_bce_avg,
                boundery_dice_loss=loss_boundery_dice_avg,
                time=t_intv,
                eta=eta,
            )

            logger.info(msg)
            print("loss/train_loss", loss.item(), it + 1)
            loss_avg = []
            loss_boundery_bce = []
            loss_boundery_dice = []
            st = ed

        # NOTE(review): `it` was already incremented above, so `(it + 1)` fires
        # one batch before the epoch's last batch — confirm this is intended.
        if ((it + 1) % val_interval == 0 and it + 10 < epoch_iteration) or (it + 1) % epoch_iteration == 0:
            print('validation')
            torch.cuda.empty_cache()
            model.eval()
            loss_all = 0
            loss_n = 0
            with torch.no_grad():
                for i_val, (images_val, labels_val) in enumerate(dlval):
                    if (i_val + 1) % 50 == 0:
                        print(i_val + 1)

                    images_val = images_val.to(device)
                    labels_val = labels_val.to(device)
                    labels_val = torch.squeeze(labels_val, 1)

                    outputs = model(images_val)[0]  # main head only
                    val_loss = criteria_val(outputs, labels_val)

                    pred = outputs.data.max(1)[1].cpu().numpy()
                    gt = labels_val.data.cpu().numpy()

                    running_metrics_val.update(gt, pred)
                    val_loss_meter.update(val_loss.item())

            print("loss/val_loss", val_loss_meter.avg, it + 1)
            logger.info("Epoch %3d Iter %d Val Loss: %.4f" % (epoch_id, it + 1, val_loss_meter.avg))

            score, class_iou = running_metrics_val.get_scores()
            for k, v in score.items():
                print(k, v)
                logger.info("{}: {}".format(k, v))
                print("val_metrics/{}".format(k), v, it + 1)

            for k, v in class_iou.items():
                logger.info("{}: {}".format(k, v))
                print("val_metrics/cls_{}".format(k), v, it + 1)

            val_loss_meter.reset()
            running_metrics_val.reset()

            # Always refresh the resumable checkpoint.
            state = {
                "epoch": epoch_id,
                "iteration": it + 1,
                "model_state": model.state_dict(),
                "optimizer_state": optim.get_state(),
            }
            save_path = os.path.join(
                writer.file_writer.get_logdir(),
                "{}_{}_checkpoint.pkl".format(model_arch, cfg_data['dataset']),
            )
            torch.save(state, save_path)

            # Keep a separate copy of the best model by mean IoU.
            if score["Mean IoU : \t"] >= best_iou:
                best_iou = score["Mean IoU : \t"]
                state = {
                    "epoch": epoch_id,
                    "iteration": it + 1,
                    "model_state": model.state_dict(),
                    "best_iou": best_iou,
                }
                save_path = os.path.join(
                    writer.file_writer.get_logdir(),
                    "{}_{}_best_model.pkl".format(model_arch, cfg_data['dataset']),
                )
                torch.save(state, save_path)
            torch.cuda.empty_cache()
    it = 0

loss/train_loss 1.891932725906372 10
loss/train_loss 1.7107865810394287 20
loss/train_loss 2.4063501358032227 30
loss/train_loss 1.8365590572357178 40
loss/train_loss 1.7762398719787598 50
loss/train_loss 2.1504526138305664 60
loss/train_loss 1.9234713315963745 70
loss/train_loss 1.776898980140686 80
loss/train_loss 1.8373000621795654 90
loss/train_loss 2.1039223670959473 100
loss/train_loss 1.6998591423034668 110
loss/train_loss 2.2428207397460938 120
loss/train_loss 1.9891433715820312 130
loss/train_loss 2.0237367153167725 140
loss/train_loss 2.2747228145599365 150
loss/train_loss 1.9942841529846191 160
loss/train_loss 2.074845314025879 170
loss/train_loss 2.5347585678100586 180
loss/train_loss 1.9456737041473389 190
loss/train_loss 2.2065742015838623 200
loss/train_loss 2.0525262355804443 210
loss/train_loss 2.0738396644592285 220
loss/train_loss 2.134902000427246 230
loss/train_loss 2.004265308380127 240
loss/train_loss 2.2432117462158203 250
loss/train_loss 1.6663174629211426 260


In [None]:
# with torch.no_grad():
#     for (images_val, labels_val, _) in valloader:
#         images_val = images_val.to(device)
#         labels_val = labels_val.to(device)

#         outputs = model(images_val)
#         outputs = output_val_upsample(outputs)
#         val_loss = loss_fn(input=outputs, target=labels_val)

#         pred = outputs.data.max(1)[1].cpu().numpy()
#         gt = labels_val.data.cpu().numpy()

#         running_metrics_val.update(gt, pred)
#         val_loss_meter.update(val_loss.item())

# writer.add_scalar("loss/val_loss", val_loss_meter.avg, i + 1)
# logger.info("Iter %d Val Loss: %.4f" % (i + 1, val_loss_meter.avg))

# score, class_iou = running_metrics_val.get_scores()