# BiSeNet V2 구현 (BiSeNet V2 implementation)
- Framework: PyTorch

In [1]:
import os
import time
import random
import warnings
import cv2
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from glob import glob
from PIL import Image
from tqdm.notebook import tqdm

import albumentations as A
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
import torch.optim as optim
import torch.utils.model_zoo as modelzoo
import torchvision.transforms as transforms
import torchvision.models as models
from torch.utils.data import Dataset, DataLoader
from torchsummary import summary

from torchcallback import EarlyStopping, CheckPoint
from torchlosses import OhemCELoss
from torchmetrics import Metrics
from torchscheduler import PolynomialLRDecay
from torchtransform import RandomCrop, HorizontalFlip, RandomScale, ColorJitter, Compose

from model import BiSeNetV2
from cityscapes import CityscapesDataset
from train import train_step

In [2]:
# Silence every Python warning for the rest of the session. This keeps the
# training log short, but it also hides potentially useful warnings.
warnings.filterwarnings('ignore')
# IPython magic (not plain Python): render matplotlib figures inline in the notebook.
%matplotlib inline
# Default figure size (width, height) in inches for plots created afterwards.
plt.rcParams['figure.figsize'] = [10,7]

# Load Data

In [3]:
# Input resolution passed to the dataset as cropsize=(width, height), and the
# number of segmentation classes predicted by the model.
width = 1024
height = 512
num_classes = 19

# Device setting: use the GPU when one is available, otherwise fall back to the
# CPU so the notebook still runs (slowly) instead of failing at the first `.to(device)`.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

cuda


In [4]:
path = 'C:/Users/user/MY_DL/segmentation/dataset/cityscapes'

batch_size = 16

# Training: reshuffle every epoch, and drop the last partial batch so each step
# sees a full batch (the model uses BatchNorm, which needs >1 sample in train mode).
train_loader = DataLoader(
    CityscapesDataset(path=path, subset='train', cropsize=(width,height)),
    batch_size=batch_size,
    shuffle=True,
    drop_last=True,
)

# Validation: fixed order and every sample evaluated. With shuffle + drop_last,
# a different random subset was scored each epoch, which made the validation
# loss used for checkpointing noisier than necessary.
# NOTE(review): `cropsize` is also passed for the valid split; confirm the
# dataset does not random-crop validation images, or metrics will still vary.
valid_loader = DataLoader(
    CityscapesDataset(path=path, subset='valid', cropsize=(width,height)),
    batch_size=batch_size,
    shuffle=False,
    drop_last=False,
)

# Model

In [5]:
# Layer-by-layer summary from a throwaway CPU instance, so no GPU memory is spent on it.
summary(BiSeNetV2(num_classes=num_classes), (3, height, width), device='cpu')
# The instance that is actually trained lives on `device`.
model = BiSeNetV2(num_classes=num_classes)
model = model.to(device)

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 64, 256, 512]           1,728
       BatchNorm2d-2         [-1, 64, 256, 512]             128
              ReLU-3         [-1, 64, 256, 512]               0
         ConvBlock-4         [-1, 64, 256, 512]               0
            Conv2d-5         [-1, 64, 256, 512]          36,864
       BatchNorm2d-6         [-1, 64, 256, 512]             128
              ReLU-7         [-1, 64, 256, 512]               0
         ConvBlock-8         [-1, 64, 256, 512]               0
            Conv2d-9         [-1, 64, 128, 256]          36,864
      BatchNorm2d-10         [-1, 64, 128, 256]             128
             ReLU-11         [-1, 64, 128, 256]               0
        ConvBlock-12         [-1, 64, 128, 256]               0
           Conv2d-13         [-1, 64, 128, 256]          36,864
      BatchNorm2d-14         [-1, 64, 1

# Set optimizer, loss function, LR scheduler, metric, and callbacks

In [6]:
# Base learning rate and where the two callbacks write their weight files.
lr = 0.05
es_save_path, cp_save_path = './model/es_checkpoint.pt', './model/cp_checkpoint.pt'

# SGD with momentum and weight decay; the LR scheduler wraps this optimizer and
# is configured for 1000 decay steps (train_step calls it once per epoch).
optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0.9, weight_decay=5e-4)
lr_scheduler = PolynomialLRDecay(optimizer, max_decay_steps=1000)

# OHEM cross-entropy; pixels labelled 255 are passed as the ignore label.
loss_func = OhemCELoss(thresh=0.7, ignore_lb=255).to(device)
metric = Metrics(n_classes=num_classes, dim=1)

# Callbacks fed the per-epoch validation loss by train_step.
checkpoint = CheckPoint(verbose=True, path=cp_save_path)
early_stopping = EarlyStopping(patience=300, verbose=True, path=es_save_path)

In [7]:
def valid_step(model, validation_data, device=None, metric=None, loss_func=None):
    """Evaluate ``model`` on every batch of ``validation_data`` without updating it.

    The model must return a 5-tuple: the main output followed by four auxiliary
    outputs. The reported loss is the main loss plus the four auxiliary losses,
    computed exactly as in training.

    Args:
        model: segmentation network returning ``(out, s2, s3, s4, s5)``.
        validation_data: iterable of ``(images, labels)`` batches.
        device: device to move each batch to; defaults to the notebook-level ``device``.
        metric: object with ``mean_iou(outputs, labels)`` returning a scalar tensor;
            defaults to the notebook-level ``metric``.
        loss_func: criterion called as ``loss_func(logits, labels)``; defaults to the
            notebook-level ``loss_func``.

    Returns:
        ``(mean_loss, mean_miou)`` averaged over the number of batches.

    Raises:
        ValueError: if ``validation_data`` yields no batches.
    """
    # Optional dependencies fall back to the notebook globals, so existing
    # calls like `valid_step(model, valid_loader)` behave as before.
    if device is None:
        device = globals()['device']
    if metric is None:
        metric = globals()['metric']
    if loss_func is None:
        loss_func = globals()['loss_func']

    model.eval()
    total_loss, total_miou, n_batches = 0.0, 0.0, 0
    with torch.no_grad():
        for val_images, val_labels in validation_data:
            val_images, val_labels = val_images.to(device), val_labels.to(device)

            val_outputs, v_s2, v_s3, v_s4, v_s5 = model(val_images)

            val_miou = metric.mean_iou(val_outputs, val_labels)
            total_miou += val_miou.item()

            # squeeze(1) removes only a singleton channel axis (B, 1, H, W) -> (B, H, W);
            # a bare squeeze() would also remove the batch axis when B == 1.
            target = val_labels.squeeze(1)
            p_val_loss = loss_func(val_outputs, target)
            a_val_loss = sum(loss_func(aux, target) for aux in (v_s2, v_s3, v_s4, v_s5))
            val_loss = p_val_loss + a_val_loss
            total_loss += val_loss.item()
            n_batches += 1

            # Drop per-batch tensors before the next batch is moved to the device.
            del val_images, val_labels, target, val_outputs
            del v_s2, v_s3, v_s4, v_s5
            torch.cuda.empty_cache()

    if n_batches == 0:
        raise ValueError('validation_data yielded no batches')
    return total_loss / n_batches, total_miou / n_batches

def train_on_batch(model, train_data, device=None, optimizer=None, metric=None, loss_func=None):
    """Train ``model`` for one pass over ``train_data`` (despite the name, all batches).

    The model must return a 5-tuple: the main output followed by four auxiliary
    outputs. The optimized loss is the main loss plus the four auxiliary losses.

    Args:
        model: segmentation network returning ``(out, s2, s3, s4, s5)``.
        train_data: iterable of ``(images, labels)`` batches.
        device: device to move each batch to; defaults to the notebook-level ``device``.
        optimizer: optimizer over ``model``'s parameters; defaults to the
            notebook-level ``optimizer``.
        metric: object with ``mean_iou(outputs, labels)`` returning a scalar tensor;
            defaults to the notebook-level ``metric``.
        loss_func: criterion called as ``loss_func(logits, labels)``; defaults to the
            notebook-level ``loss_func``.

    Returns:
        ``(mean_loss, mean_miou)`` averaged over the number of batches.

    Raises:
        ValueError: if ``train_data`` yields no batches.
    """
    # Optional dependencies fall back to the notebook globals, so existing
    # calls like `train_on_batch(model, train_loader)` behave as before.
    if device is None:
        device = globals()['device']
    if optimizer is None:
        optimizer = globals()['optimizer']
    if metric is None:
        metric = globals()['metric']
    if loss_func is None:
        loss_func = globals()['loss_func']

    # Set once: nothing inside the loop switches the model back to eval mode.
    model.train()
    total_loss, total_miou, n_batches = 0.0, 0.0, 0
    for train_images, train_labels in train_data:
        train_images = train_images.to(device)
        train_labels = train_labels.to(device)

        optimizer.zero_grad()

        train_outputs, s2, s3, s4, s5 = model(train_images)

        miou = metric.mean_iou(train_outputs, train_labels)
        total_miou += miou.item()

        # squeeze(1) removes only a singleton channel axis (B, 1, H, W) -> (B, H, W);
        # a bare squeeze() would also remove the batch axis when B == 1.
        target = train_labels.squeeze(1)
        p_loss = loss_func(train_outputs, target)
        a_loss = sum(loss_func(aux, target) for aux in (s2, s3, s4, s5))
        loss = p_loss + a_loss
        total_loss += loss.item()

        loss.backward()
        optimizer.step()
        n_batches += 1

        # Drop per-batch tensors before the next batch is moved to the device.
        del train_images, train_labels, target, train_outputs
        del s2, s3, s4, s5
        torch.cuda.empty_cache()

    if n_batches == 0:
        raise ValueError('train_data yielded no batches')
    return total_loss / n_batches, total_miou / n_batches

def train_step(model,
               train_data,
               validation_data,
               epochs,
               learning_rate_scheduler=False,
               check_point=False,
               early_stop=False,
               last_epoch_save_path='./model/last_checkpoint.pt'):
    """Train ``model`` for up to ``epochs`` epochs, validating after each one.

    Relies on the notebook-level ``optimizer`` (LR logging), ``lr_scheduler``,
    ``checkpoint`` and ``early_stopping`` objects, plus ``train_on_batch`` and
    ``valid_step``.

    Args:
        model: network to train.
        train_data: training loader, passed to ``train_on_batch``.
        validation_data: validation loader, passed to ``valid_step``.
        epochs: maximum number of epochs.
        learning_rate_scheduler: if True, call ``lr_scheduler.step()`` once per epoch.
        check_point: if True, give each epoch's validation loss to ``checkpoint``.
        early_stop: if True, give it to ``early_stopping`` and stop when it fires.
            Mutually exclusive with ``check_point``.
        last_epoch_save_path: where the final ``state_dict`` is saved when neither
            callback is enabled.

    Returns:
        dict with ``model`` (in-memory weights of the last epoch run) and the
        per-epoch lists ``loss``, ``miou``, ``val_loss`` and ``val_miou``.

    Raises:
        ValueError: if both ``check_point`` and ``early_stop`` are enabled.
    """
    # Checked before any (long) epoch runs so a misconfiguration fails
    # immediately; `raise` rather than `assert` so it survives `python -O`.
    if check_point and early_stop:
        raise ValueError('Choose between Early Stopping and Check Point')

    loss_list, miou_list = [], []
    val_loss_list, val_miou_list = [], []

    print('Start Model Training...!')
    start_training = time.time()
    for epoch in tqdm(range(epochs)):
        init_time = time.time()

        train_loss, train_miou = train_on_batch(model, train_data)
        loss_list.append(train_loss)
        miou_list.append(train_miou)

        val_loss, val_miou = valid_step(model, validation_data)
        val_loss_list.append(val_loss)
        val_miou_list.append(val_miou)

        end_time = time.time()

        # The LR printed is the one used during this epoch (scheduler steps after).
        print(f'\n[Epoch {epoch+1}/{epochs}]'
              f'  [time: {end_time-init_time:.3f}s]'
              f'  [lr = {optimizer.param_groups[0]["lr"]}]')
        print(f'[train loss: {train_loss:.3f}]'
              f'  [train miou: {train_miou:.3f}]'
              f'  [valid loss: {val_loss:.3f}]'
              f'  [valid miou: {val_miou:.3f}]')

        if learning_rate_scheduler:
            lr_scheduler.step()

        if check_point:
            checkpoint(val_loss, model)

        if early_stop:
            early_stopping(val_loss, model)
            if early_stopping.early_stop:
                print('\n##########################\n'
                      '##### Early Stopping #####\n'
                      '##########################')
                break

    # With a callback enabled, weights are written by the callback instead.
    # NOTE(review): the returned `model` holds the last epoch's weights, not
    # necessarily the ones the callback saved; reload from its path if needed.
    if not early_stop and not check_point:
        torch.save(model.state_dict(), last_epoch_save_path)
        print('Saving model of last epoch.')

    end_training = time.time()
    print(f'\nTotal time for training is {end_training-start_training:.3f}s')

    return {
        'model': model,
        'loss': loss_list,
        'miou': miou_list,
        'val_loss': val_loss_list,
        'val_miou': val_miou_list
        }

In [None]:
# Train for up to EPOCH epochs: the LR scheduler steps once per epoch and the
# CheckPoint callback receives every validation loss (early stopping disabled).
EPOCH = 1000

history = train_step(
    model,
    epochs=EPOCH,
    train_data=train_loader,
    validation_data=valid_loader,
    early_stop=False,
    check_point=True,
    learning_rate_scheduler=True,
)

Start Model Training...!


  0%|          | 0/1000 [00:00<?, ?it/s]


[Epoch 1/1000]  [time: 771.766s]  [lr = 0.049955006749624734]
[train loss: 11.064]  [train miou: 0.037]  [valid loss: 7.820]  [valid miou: 0.078]
Validation loss decreased (inf --> 7.820).  Saving model ...

[Epoch 2/1000]  [time: 757.138s]  [lr = 0.049910008995194384]
[train loss: 8.531]  [train miou: 0.080]  [valid loss: 8.160]  [valid miou: 0.098]

[Epoch 3/1000]  [time: 756.575s]  [lr = 0.049865006731744314]
[train loss: 8.170]  [train miou: 0.114]  [valid loss: 8.015]  [valid miou: 0.124]

[Epoch 4/1000]  [time: 755.098s]  [lr = 0.049819999954299435]
[train loss: 7.919]  [train miou: 0.138]  [valid loss: 8.861]  [valid miou: 0.122]

[Epoch 5/1000]  [time: 756.390s]  [lr = 0.049774988657874136]
[train loss: 7.817]  [train miou: 0.154]  [valid loss: 7.032]  [valid miou: 0.137]
Validation loss decreased (7.820 --> 7.032).  Saving model ...

[Epoch 6/1000]  [time: 756.157s]  [lr = 0.04972997283747233]
[train loss: 7.654]  [train miou: 0.165]  [valid loss: 7.527]  [valid miou: 0.177]



[Epoch 55/1000]  [time: 769.123s]  [lr = 0.047518548776340576]
[train loss: 6.650]  [train miou: 0.353]  [valid loss: 7.259]  [valid miou: 0.282]

[Epoch 56/1000]  [time: 766.020s]  [lr = 0.047473300144500546]
[train loss: 6.579]  [train miou: 0.353]  [valid loss: 6.678]  [valid miou: 0.331]

[Epoch 57/1000]  [time: 767.205s]  [lr = 0.047428046719118405]
[train loss: 6.603]  [train miou: 0.360]  [valid loss: 6.600]  [valid miou: 0.347]

[Epoch 58/1000]  [time: 767.920s]  [lr = 0.04738278849460224]
[train loss: 6.572]  [train miou: 0.358]  [valid loss: 6.890]  [valid miou: 0.315]

[Epoch 59/1000]  [time: 769.967s]  [lr = 0.04733752546534768]
[train loss: 6.535]  [train miou: 0.358]  [valid loss: 7.340]  [valid miou: 0.287]

[Epoch 60/1000]  [time: 765.900s]  [lr = 0.04729225762573782]
[train loss: 6.556]  [train miou: 0.364]  [valid loss: 7.173]  [valid miou: 0.328]

[Epoch 61/1000]  [time: 766.019s]  [lr = 0.04724698497014323]
[train loss: 6.598]  [train miou: 0.369]  [valid loss: 7.4


[Epoch 111/1000]  [time: 756.147s]  [lr = 0.044977083588671966]
[train loss: 6.515]  [train miou: 0.402]  [valid loss: 7.403]  [valid miou: 0.360]

[Epoch 112/1000]  [time: 756.958s]  [lr = 0.0449315575455412]
[train loss: 6.471]  [train miou: 0.410]  [valid loss: 6.503]  [valid miou: 0.394]

[Epoch 113/1000]  [time: 756.710s]  [lr = 0.04488602637531396]
[train loss: 6.479]  [train miou: 0.411]  [valid loss: 6.835]  [valid miou: 0.370]

[Epoch 114/1000]  [time: 754.820s]  [lr = 0.044840490071631586]
[train loss: 6.478]  [train miou: 0.419]  [valid loss: 6.882]  [valid miou: 0.352]

[Epoch 115/1000]  [time: 758.727s]  [lr = 0.04479494862812035]
[train loss: 6.460]  [train miou: 0.418]  [valid loss: 6.655]  [valid miou: 0.380]

[Epoch 116/1000]  [time: 756.616s]  [lr = 0.044749402038391395]
[train loss: 6.517]  [train miou: 0.407]  [valid loss: 6.713]  [valid miou: 0.339]

[Epoch 117/1000]  [time: 763.527s]  [lr = 0.04470385029604071]
[train loss: 6.438]  [train miou: 0.417]  [valid los


[Epoch 166/1000]  [time: 754.917s]  [lr = 0.042465365327403026]
[train loss: 6.415]  [train miou: 0.436]  [valid loss: 6.392]  [valid miou: 0.412]

[Epoch 167/1000]  [time: 757.365s]  [lr = 0.042419547437043854]
[train loss: 6.421]  [train miou: 0.430]  [valid loss: 6.857]  [valid miou: 0.405]

[Epoch 168/1000]  [time: 756.194s]  [lr = 0.04237372404600615]
[train loss: 6.367]  [train miou: 0.445]  [valid loss: 6.725]  [valid miou: 0.382]

[Epoch 169/1000]  [time: 753.957s]  [lr = 0.04232789514701694]
[train loss: 6.401]  [train miou: 0.428]  [valid loss: 6.229]  [valid miou: 0.405]

[Epoch 170/1000]  [time: 754.586s]  [lr = 0.042282060732784864]
[train loss: 6.448]  [train miou: 0.435]  [valid loss: 7.048]  [valid miou: 0.365]

[Epoch 171/1000]  [time: 758.405s]  [lr = 0.04223622079600013]
[train loss: 6.388]  [train miou: 0.435]  [valid loss: 6.600]  [valid miou: 0.367]

[Epoch 172/1000]  [time: 756.583s]  [lr = 0.04219037532933442]
[train loss: 6.372]  [train miou: 0.431]  [valid lo


[Epoch 222/1000]  [time: 772.775s]  [lr = 0.03989088379081062]
[train loss: 6.350]  [train miou: 0.444]  [valid loss: 6.737]  [valid miou: 0.379]

[Epoch 223/1000]  [time: 772.526s]  [lr = 0.03984474612598474]
[train loss: 6.398]  [train miou: 0.443]  [valid loss: 6.130]  [valid miou: 0.398]

[Epoch 224/1000]  [time: 758.412s]  [lr = 0.039798602522851605]
[train loss: 6.373]  [train miou: 0.453]  [valid loss: 6.344]  [valid miou: 0.400]

[Epoch 225/1000]  [time: 759.722s]  [lr = 0.039752452972992985]
[train loss: 6.410]  [train miou: 0.447]  [valid loss: 6.557]  [valid miou: 0.387]

[Epoch 226/1000]  [time: 759.348s]  [lr = 0.0397062974679678]
[train loss: 6.379]  [train miou: 0.446]  [valid loss: 6.297]  [valid miou: 0.347]

[Epoch 227/1000]  [time: 757.779s]  [lr = 0.03966013599931209]
[train loss: 6.379]  [train miou: 0.448]  [valid loss: 6.095]  [valid miou: 0.441]

[Epoch 228/1000]  [time: 763.425s]  [lr = 0.0396139685585389]
[train loss: 6.328]  [train miou: 0.452]  [valid loss:


[Epoch 278/1000]  [time: 774.346s]  [lr = 0.03729778746924719]
[train loss: 6.363]  [train miou: 0.461]  [valid loss: 6.808]  [valid miou: 0.405]

[Epoch 279/1000]  [time: 774.525s]  [lr = 0.037251303627232626]
[train loss: 6.281]  [train miou: 0.469]  [valid loss: 6.512]  [valid miou: 0.403]

[Epoch 280/1000]  [time: 776.403s]  [lr = 0.03720481333763396]
[train loss: 6.347]  [train miou: 0.460]  [valid loss: 6.646]  [valid miou: 0.427]

[Epoch 281/1000]  [time: 775.348s]  [lr = 0.03715831659060003]
[train loss: 6.285]  [train miou: 0.461]  [valid loss: 7.240]  [valid miou: 0.387]

[Epoch 282/1000]  [time: 764.800s]  [lr = 0.037111813376250885]
[train loss: 6.426]  [train miou: 0.458]  [valid loss: 6.369]  [valid miou: 0.409]

[Epoch 283/1000]  [time: 759.890s]  [lr = 0.037065303684677704]
[train loss: 6.341]  [train miou: 0.461]  [valid loss: 6.723]  [valid miou: 0.386]

[Epoch 284/1000]  [time: 763.188s]  [lr = 0.03701878750594259]
[train loss: 6.361]  [train miou: 0.457]  [valid lo


[Epoch 334/1000]  [time: 756.176s]  [lr = 0.034684478895082284]
[train loss: 6.322]  [train miou: 0.461]  [valid loss: 6.550]  [valid miou: 0.412]

[Epoch 335/1000]  [time: 761.184s]  [lr = 0.03463761797140863]
[train loss: 6.285]  [train miou: 0.465]  [valid loss: 6.706]  [valid miou: 0.418]

[Epoch 336/1000]  [time: 762.931s]  [lr = 0.034590750000446875]
[train loss: 6.299]  [train miou: 0.470]  [valid loss: 6.306]  [valid miou: 0.427]

[Epoch 337/1000]  [time: 764.157s]  [lr = 0.03454387497052144]
[train loss: 6.301]  [train miou: 0.465]  [valid loss: 6.428]  [valid miou: 0.402]


In [None]:
# Plot per-epoch training vs. validation curves: loss on the left, mIoU on the right.
fig, ax = plt.subplots(1,2, figsize=(20,10))
curve_panels = (
    # (axes, title, history key, y-label, y-limits, train label, valid label)
    (ax[0], 'Loss Graph', 'loss', 'Loss', (0, 10), 'Train Loss', 'Valid Loss'),
    (ax[1], 'mIoU Score Graph', 'miou', 'mIoU', (0, 1), 'Train mIoU', 'Valid mIoU'),
)
for panel, title, key, ylabel, ylim, train_label, valid_label in curve_panels:
    panel.set_title(title, fontsize=20)
    # Training series first, then its `val_` counterpart, so colors match the original order.
    for series, label in ((key, train_label), ('val_' + key, valid_label)):
        panel.plot(np.arange(len(history[series])), history[series], label=label)
    panel.set_ylim(*ylim)
    panel.set_xlabel('Epoch')
    panel.set_ylabel(ylabel)
    panel.legend(loc='best')
fig.show()

# Test model

In [None]:
# Load test data.
# NOTE(review): `test_augment` is never passed to CityscapesDataset below, so
# test images are NOT resized by it; confirm how the dataset sizes test inputs.
test_augment = A.Compose([
    A.Resize(height=height//2,
             width=width//2)
])

# `train_loader` from the data-loading cell is reused as-is; rebuilding it here
# was redundant. The test loader keeps a stable order and every image:
# predictions are only collected, so a partial last batch is harmless.
test_loader = DataLoader(
    CityscapesDataset(path=path, subset='test'),
    batch_size=batch_size,
    shuffle=False,
    drop_last=False,
)

In [None]:
# Run inference over the whole test split, keeping inputs and main-head logits on the CPU.
# NOTE(review): every input and full logit map is held in host memory at once;
# for a large split consider storing per-pixel argmax class maps instead.
test_images_list = []
test_outputs_list = []

model.eval()
with torch.no_grad():
    for test_images in test_loader:
        test_images = test_images.to(device)
        # The model returns five outputs; only the first (main) one is kept.
        test_outputs, _, _, _, _ = model(test_images)

        # Under no_grad these tensors carry no autograd graph, so .cpu() suffices.
        test_images_list.append(test_images.cpu())
        test_outputs_list.append(test_outputs.cpu())

        del test_images, test_outputs

test_images = torch.cat(test_images_list, dim=0)
del test_images_list
test_outputs = torch.cat(test_outputs_list, dim=0)
del test_outputs_list

In [None]:
# Turn per-class score maps (one channel per class; num_classes = 19 here) into 3-channel RGB images.
# NOTE(review): the palette `rgb_array` is not defined anywhere in the visible notebook; it must exist before use.
def map_class_to_rgb(p):
    """Return the palette color for one pixel.

    ``p`` is the 1-element array handed over by ``np.apply_along_axis``; its
    only entry is the pixel's predicted class id.
    """
    class_id = p[0]
    color = rgb_array[class_id]
    return color

def mask2rgb(images, palette=None):
    """Convert per-class score maps into RGB color images.

    Args:
        images: iterable of (C, H, W) tensors holding one score channel per class.
        palette: array-like indexed by class id, one RGB row per class. Defaults to
            the notebook-level ``rgb_array`` (not defined in the visible notebook —
            it must exist, or be passed explicitly).

    Returns:
        ndarray stacking one (H, W, 3) color image per input, in input order.
    """
    if palette is None:
        palette = rgb_array
    # ndarray so the palette supports fancy indexing by a whole class map at once.
    palette = np.asarray(palette)
    rgb_img_list = []
    for img in tqdm(images):
        img = img.detach().cpu().numpy()
        # Class id per pixel: argmax over the channel axis -> (H, W).
        class_map = np.argmax(img, axis=0)
        # Vectorized palette lookup; same result as calling np.apply_along_axis
        # per pixel, but done in a single NumPy indexing operation.
        rgb_img_list.append(palette[class_map])
    return np.array(rgb_img_list)

test_result = mask2rgb(test_outputs)
# The raw logits are no longer needed once colorized; free the memory.
del test_outputs

In [None]:
# Show test inputs next to their predicted segmentation maps.
def cuda2numpy(tensors):
    """Copy a 4-D tensor to host memory as an ndarray with axis 1 moved last.

    Channels-first batches (N, C, H, W) become channels-last (N, H, W, C),
    the layout ``imshow`` expects.
    """
    host_tensor = tensors.detach().cpu()
    return host_tensor.numpy().transpose(0, 2, 3, 1)

def show_result(input_image, pred_image, ncols):
    """Plot inputs next to their color-coded predictions, one figure per sample.

    Args:
        input_image: 4-D channels-first batch tensor (converted with ``cuda2numpy``).
        pred_image: indexable sequence/array of RGB prediction images, same order.
        ncols: maximum number of samples to plot. It is clipped to the number of
            available samples, so asking for more no longer raises IndexError.
    """
    input_image = cuda2numpy(input_image)
    n_show = min(ncols, len(input_image), len(pred_image))
    for i in range(n_show):
        # Left: network input; right: predicted segmentation.
        fig, ax = plt.subplots(1,2, figsize=(20,10))
        ax[0].imshow(input_image[i])
        ax[0].axis('off')
        ax[0].set_title('Input Image')
        ax[1].imshow(pred_image[i])
        ax[1].axis('off')
        ax[1].set_title('Predicted Image')
        fig.show()

show_result(test_images, test_result, 20)

In [None]:
# Rebuild the network with the same constructor used for training, so the saved
# state_dict's keys and shapes match, then load the CheckPoint weights.
# map_location keeps loading working even if the file was saved on another device.
model = BiSeNetV2(num_classes=num_classes).to(device)
model.load_state_dict(torch.load('./model/cp_checkpoint.pt', map_location=device))

In [None]:
# Re-run inference over the test split with the reloaded checkpoint weights,
# keeping inputs and main-head logits on the CPU.
# NOTE(review): every input and full logit map is held in host memory at once;
# for a large split consider storing per-pixel argmax class maps instead.
test_images_list = []
test_outputs_list = []

model.eval()
with torch.no_grad():
    for test_images in test_loader:
        test_images = test_images.to(device)
        # The model returns five outputs; only the first (main) one is kept.
        test_outputs, _, _, _, _ = model(test_images)

        # Under no_grad these tensors carry no autograd graph, so .cpu() suffices.
        test_images_list.append(test_images.cpu())
        test_outputs_list.append(test_outputs.cpu())

        del test_images, test_outputs

test_images = torch.cat(test_images_list, dim=0)
del test_images_list
test_outputs = torch.cat(test_outputs_list, dim=0)
del test_outputs_list

In [None]:
# Colorize the reloaded model's predictions, then drop the raw logits to free memory.
test_result = mask2rgb(images=test_outputs)
del test_outputs

In [None]:
# Plot the first 70 input/prediction pairs from the reloaded model.
show_result(input_image=test_images, pred_image=test_result, ncols=70)

In [None]:
import torch
import math

class WarmupLrScheduler(torch.optim.lr_scheduler._LRScheduler):
    """LR scheduler with a warmup phase followed by a subclass-defined main phase.

    On every ``step()`` each param group's LR is ``ratio * base_lr``. While
    ``last_epoch < warmup_iter`` the ratio comes from :meth:`get_warmup_ratio`;
    afterwards it comes from :meth:`get_main_ratio`, which subclasses implement.

    Args:
        optimizer: wrapped optimizer.
        warmup_iter: number of warmup steps.
        warmup_ratio: ratio applied at step 0.
        warmup: ``'linear'`` (ratio rises linearly toward 1) or ``'exp'``
            (ratio is ``warmup_ratio ** (1 - t)`` with ``t = step / warmup_iter``).
        last_epoch: index of the last step; ``-1`` starts fresh.

    Raises:
        ValueError: if ``warmup`` is not ``'linear'`` or ``'exp'``.
    """

    _WARMUP_MODES = ('linear', 'exp')

    def __init__(
            self,
            optimizer,
            warmup_iter=500,
            warmup_ratio=5e-4,
            warmup='exp',
            last_epoch=-1,
    ):
        # Validate before super().__init__, which already performs the first step().
        if warmup not in self._WARMUP_MODES:
            raise ValueError(
                f'warmup must be one of {self._WARMUP_MODES}, got {warmup!r}')
        self.warmup_iter = warmup_iter
        self.warmup_ratio = warmup_ratio
        self.warmup = warmup
        super().__init__(optimizer, last_epoch)

    def get_lr(self):
        ratio = self.get_lr_ratio()
        return [ratio * lr for lr in self.base_lrs]

    def get_lr_ratio(self):
        if self.last_epoch < self.warmup_iter:
            return self.get_warmup_ratio()
        return self.get_main_ratio()

    def get_main_ratio(self):
        raise NotImplementedError

    def get_warmup_ratio(self):
        # Progress through warmup in [0, 1).
        alpha = self.last_epoch / self.warmup_iter
        if self.warmup == 'linear':
            return self.warmup_ratio + (1 - self.warmup_ratio) * alpha
        if self.warmup == 'exp':
            return self.warmup_ratio ** (1. - alpha)
        # Only reachable if `self.warmup` was changed after construction.
        raise ValueError(
            f'warmup must be one of {self._WARMUP_MODES}, got {self.warmup!r}')


class WarmupPolyLrScheduler(WarmupLrScheduler):
    """Warmup followed by polynomial decay ``(1 - progress) ** power`` down to 0.

    ``progress`` runs from 0 at the end of warmup to 1 at ``max_iter``; stepping
    past ``max_iter`` keeps the LR at 0.

    Args:
        optimizer: wrapped optimizer.
        power: exponent of the polynomial decay.
        max_iter: step at which the LR reaches 0.
        warmup_iter, warmup_ratio, warmup, last_epoch: see ``WarmupLrScheduler``.
    """

    def __init__(
            self,
            optimizer,
            power,
            max_iter,
            warmup_iter=500,
            warmup_ratio=5e-4,
            warmup='exp',
            last_epoch=-1,
    ):
        self.power = power
        self.max_iter = max_iter
        super().__init__(optimizer, warmup_iter, warmup_ratio, warmup, last_epoch)

    def get_main_ratio(self):
        real_iter = self.last_epoch - self.warmup_iter
        real_max_iter = self.max_iter - self.warmup_iter
        if real_max_iter <= 0:
            # No decay window is left after warmup: the schedule is already over
            # (this also avoids dividing by zero when max_iter == warmup_iter).
            return 0.
        # Clamp progress to [0, 1]: past max_iter, (1 - alpha) would be negative
        # and raising it to a fractional power yields a complex number in Python.
        alpha = min(max(real_iter / real_max_iter, 0.), 1.)
        return (1 - alpha) ** self.power

In [None]:
import matplotlib.pyplot as plt

# Visualize the warmup + polynomial-decay schedule on a throwaway optimizer.
# Dedicated names keep this demo from overwriting the notebook's trained
# `model` and its `optimizer`.
demo_iters = 100
demo_model = torch.nn.Linear(2, 1)
demo_optimizer = torch.optim.SGD(demo_model.parameters(), lr=0.01)
# warmup_iter must be well below max_iter: with the default (500) the whole
# 100-step run stays inside warmup and the poly decay never shows up.
demo_scheduler = WarmupPolyLrScheduler(optimizer=demo_optimizer,
                                       power=0.9, max_iter=demo_iters,
                                       warmup_iter=10)

lrs = []
for _ in range(demo_iters):
    # optimizer.step() before scheduler.step(), the order PyTorch expects.
    demo_optimizer.step()
    lrs.append(demo_optimizer.param_groups[0]["lr"])
    demo_scheduler.step()

plt.plot(range(demo_iters), lrs)