# Training Tutorial

## Dependencies & Global Variables

In [1]:
# Dataset layout: image and annotation sub-directories, both relative to data_root.
data_root = 'data'
img_dir = 'setr_images'
ann_dir = 'setr_annotation_palette'

# Class names and their RGB colours; entry i of `palette` belongs to entry i of `classes`.
classes = (
    'Background',
    'Paragraph',
    'OtherText',
    'VisualFigure',
)
palette = [
    [0, 0, 0],      # Background
    [0, 0, 255],    # Paragraph
    [255, 0, 0],    # OtherText
    [0, 255, 0],    # VisualFigure
]

# Base SETR config and the Cityscapes checkpoint we fine-tune from, plus the output location.
PATH_TO_CONFIG = './configs/SETR/SETR_PUP_768x768_40k_cityscapes_bs_8.py'
PATH_TO_CHECKPOINT = './checkpoints/SETR_PUP_cityscapes_b8_40k.pth'
OUTPUT_DIR = './model_output/3_class_SETR_PUP'

In [2]:
import os
import mmcv
from mmseg.datasets.builder import DATASETS
from mmseg.datasets.custom import CustomDataset


@DATASETS.register_module()
class PrimaDataset(CustomDataset):
    """mmseg dataset for the document-layout classes defined in the first cell.

    Both images and segmentation maps use the ``.png`` suffix. Loading is
    restricted to the samples listed in the ``split`` file.

    Args:
        split: Path of the split file (relative to ``data_root`` when one is
            given). Must not be ``None``.
        **kwargs: Forwarded unchanged to ``CustomDataset`` (e.g. ``data_root``,
            ``img_dir``, ``ann_dir``, ``pipeline``).

    Raises:
        FileNotFoundError: If the resolved image directory does not exist.
        ValueError: If no split file was given.
    """

    CLASSES = classes
    PALETTE = palette

    def __init__(self, split, **kwargs):
        super().__init__(img_suffix='.png', seg_map_suffix='.png',
                         split=split, **kwargs)
        # Explicit checks rather than ``assert``: they still run under
        # ``python -O`` and say what went wrong.
        if not os.path.exists(self.img_dir):
            raise FileNotFoundError(
                f'PrimaDataset: image directory not found: {self.img_dir}')
        if self.split is None:
            raise ValueError('PrimaDataset: a split file is required, got split=None')

### Create a config file
Next, we modify the config for training. To speed up the process, we fine-tune the model from pretrained weights instead of training from scratch.

Download the config file and the checkpoint file from https://github.com/fudan-zvg/SETR#main-results and put them in './configs/SETR' and './checkpoints', respectively.

In [4]:
# Start from the exact config the downloaded checkpoint was trained with;
# the next cell adapts it to our dataset and single-GPU setup.
from mmcv import Config

cfg = Config.fromfile(filename=PATH_TO_CONFIG)

The config file contains the settings used to train the checkpoint model (Cityscapes, 19 classes). We therefore modify it to match our dataset (4 classes, our data paths and splits) and our training setup (a single GPU, so BN instead of SyncBN).

In [5]:
from mmseg.apis import set_random_seed

# Derive the number of output classes from the class tuple in the first cell,
# so every head stays consistent if classes are added or removed.
num_classes = len(classes)

# Since we use only one GPU, BN is used instead of SyncBN.
cfg.norm_cfg = dict(type='BN', requires_grad=True)

cfg.model.backbone.norm_cfg = cfg.norm_cfg
cfg.model.decode_head.norm_cfg = cfg.norm_cfg

# Modify the number of classes in the decode head and the backbone.
cfg.model.decode_head.num_classes = num_classes
cfg.model.backbone.num_classes = num_classes

# The loaded config defines a list of auxiliary heads; patch every one of
# them instead of hard-coding indices 0..3.
for aux_head in cfg.model.auxiliary_head:
    aux_head.norm_cfg = cfg.norm_cfg
    aux_head.num_classes = num_classes

# Modify dataset type and path.
cfg.dataset_type = 'PrimaDataset'
cfg.data_root = data_root

cfg.classes = classes
cfg.palette = palette

# Per-split settings: (pipeline, split file). Val and test both evaluate on
# the validation split with the test pipeline.
split_settings = {
    'train': (cfg.train_pipeline, 'splits/train.txt'),
    'val': (cfg.test_pipeline, 'splits/val.txt'),
    'test': (cfg.test_pipeline, 'splits/val.txt'),
}
for split_name, (pipeline, split_file) in split_settings.items():
    split_cfg = cfg.data[split_name]
    split_cfg.type = cfg.dataset_type
    split_cfg.data_root = cfg.data_root
    split_cfg.img_dir = img_dir
    split_cfg.ann_dir = ann_dir
    split_cfg.pipeline = pipeline
    split_cfg.split = split_file

# Fine-tune from the downloaded checkpoint.
cfg.load_from = PATH_TO_CHECKPOINT

# Set up working dir to save files and logs.
cfg.work_dir = OUTPUT_DIR

# Fixed seed for reproducibility. deterministic=False does not force
# deterministic cuDNN kernels, so runs may still differ slightly.
cfg.seed = 0
set_random_seed(cfg.seed, deterministic=False)
cfg.gpu_ids = [0]

print(f'Config:\n{cfg.pretty_text}')

Config:
norm_cfg = dict(type='BN', requires_grad=True)
model = dict(
    type='EncoderDecoder',
    backbone=dict(
        type='VisionTransformer',
        model_name='vit_large_patch16_384',
        img_size=768,
        patch_size=16,
        in_chans=3,
        embed_dim=1024,
        depth=24,
        num_heads=16,
        num_classes=4,
        drop_rate=0.0,
        norm_cfg=dict(type='BN', requires_grad=True),
        pos_embed_interp=True,
        align_corners=False),
    decode_head=dict(
        type='VisionTransformerUpHead',
        in_channels=1024,
        channels=512,
        in_index=23,
        img_size=768,
        embed_dim=1024,
        num_classes=4,
        norm_cfg=dict(type='BN', requires_grad=True),
        num_conv=4,
        upsampling_method='bilinear',
        align_corners=False,
        loss_decode=dict(
            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0),
        num_upsampe_layer=4),
    auxiliary_head=[
        dict(
           

In [6]:
# Create OUTPUT_DIR together with any missing parent directories.
# exist_ok=True keeps the cell re-runnable (os.mkdir raised FileExistsError
# on a second run) and works for any OUTPUT_DIR, not just './a/b' paths.
os.makedirs(OUTPUT_DIR, exist_ok=True)

# Save the final, modified config next to the training outputs.
cfg.dump(os.path.join(OUTPUT_DIR, 'config.py'))

### Train and Evaluation

In [6]:
from mmseg.datasets import build_dataset
from mmseg.models import build_segmentor
from mmseg.apis import train_segmentor


# Build the training dataset. With validate=True, train_segmentor builds the
# validation dataset from cfg.data.val on its own.
datasets = [build_dataset(cfg.data.train)]

# Build the segmentor.
model = build_segmentor(
    cfg.model, train_cfg=cfg.train_cfg, test_cfg=cfg.test_cfg)
# Add an attribute for visualization convenience.
model.CLASSES = datasets[0].CLASSES

# Create work_dir. Use os.path directly: the `osp` alias is never imported
# in this notebook and raised NameError on a fresh kernel.
mmcv.mkdir_or_exist(os.path.abspath(cfg.work_dir))
train_segmentor(model, datasets, cfg, distributed=False, validate=True,
                meta=dict())

2022-05-31 19:26:12,554 - mmseg - INFO - Loaded 362 images


load pre-trained weight from imagenet21k


2022-05-31 19:26:23,950 - mmseg - INFO - Loaded 91 images
2022-05-31 19:26:23,952 - mmseg - INFO - load checkpoint from ./checkpoints/SETR_PUP_cityscapes_b8_40k.pth

size mismatch for decode_head.conv_seg.weight: copying a param with shape torch.Size([19, 512, 1, 1]) from checkpoint, the shape in current model is torch.Size([4, 512, 1, 1]).
size mismatch for decode_head.conv_seg.bias: copying a param with shape torch.Size([19]) from checkpoint, the shape in current model is torch.Size([4]).
size mismatch for decode_head.conv_4.weight: copying a param with shape torch.Size([19, 256, 1, 1]) from checkpoint, the shape in current model is torch.Size([4, 256, 1, 1]).
size mismatch for decode_head.conv_4.bias: copying a param with shape torch.Size([19]) from checkpoint, the shape in current model is torch.Size([4]).
size mismatch for auxiliary_head.0.conv_seg.weight: copying a param with shape torch.Size([19, 512, 1, 1]) from checkpoint, the shape in current model is torch.Size([4, 512, 1, 1

[>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>] 91/91, 0.6 task/s, elapsed: 143s, ETA:     0s

2022-05-31 20:20:23,826 - mmseg - INFO - per class results:
Class                  IoU        Acc
Background           76.49      78.94
Paragraph            87.89      97.57
OtherText            57.22      69.81
VisualFigure         73.52      95.24
Summary:
Scope                 mIoU       mAcc       aAcc
global               73.78      85.39      88.15

2022-05-31 20:20:23,841 - mmseg - INFO - Iter(val) [4000]	mIoU: 0.7378, mAcc: 0.8539, aAcc: 0.8815
2022-05-31 20:21:01,766 - mmseg - INFO - Iter [4050/40000]	lr: 9.093e-03, eta: 7:57:57, time: 4.226, data_time: 3.471, memory: 23987, decode.loss_seg: 0.3126, decode.acc_seg: 87.4022, aux_0.loss_seg: 0.1237, aux_0.acc_seg: 87.4492, aux_1.loss_seg: 0.1239, aux_1.acc_seg: 87.5170, aux_2.loss_seg: 0.1231, aux_2.acc_seg: 87.5244, aux_3.loss_seg: 0.1221, aux_3.acc_seg: 87.5521, loss: 0.8054
2022-05-31 20:21:39,887 - mmseg - INFO - Iter [4100/40000]	lr: 9.082e-03, eta: 7:57:02, time: 0.762, data_time: 0.004, memory: 23987, decode.loss_seg: 0.3

[>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>] 91/91, 0.6 task/s, elapsed: 143s, ETA:     0s

2022-05-31 21:14:19,371 - mmseg - INFO - per class results:
Class                  IoU        Acc
Background           85.07      90.94
Paragraph            90.20      95.87
OtherText            56.83      66.76
VisualFigure         82.23      92.78
Summary:
Scope                 mIoU       mAcc       aAcc
global               78.58      86.59      91.65

2022-05-31 21:14:19,399 - mmseg - INFO - Iter(val) [8000]	mIoU: 0.7858, mAcc: 0.8659, aAcc: 0.9165
2022-05-31 21:14:57,240 - mmseg - INFO - Iter [8050/40000]	lr: 8.188e-03, eta: 7:04:45, time: 4.216, data_time: 3.463, memory: 23987, decode.loss_seg: 0.2924, decode.acc_seg: 87.0040, aux_0.loss_seg: 0.1105, aux_0.acc_seg: 87.1836, aux_1.loss_seg: 0.1090, aux_1.acc_seg: 87.1909, aux_2.loss_seg: 0.1093, aux_2.acc_seg: 87.1809, aux_3.loss_seg: 0.1070, aux_3.acc_seg: 87.2127, loss: 0.7282
2022-05-31 21:15:35,754 - mmseg - INFO - Iter [8100/40000]	lr: 8.176e-03, eta: 7:04:00, time: 0.770, data_time: 0.004, memory: 23987, decode.loss_seg: 0.2

[>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>] 91/91, 0.6 task/s, elapsed: 143s, ETA:     0s

2022-05-31 22:08:16,669 - mmseg - INFO - per class results:
Class                  IoU        Acc
Background           78.92      80.44
Paragraph            88.49      98.19
OtherText            59.99      76.51
VisualFigure         78.97      96.65
Summary:
Scope                 mIoU       mAcc       aAcc
global               76.59      87.95      89.62

2022-05-31 22:08:16,680 - mmseg - INFO - Iter(val) [12000]	mIoU: 0.7659, mAcc: 0.8795, aAcc: 0.8962
2022-05-31 22:08:54,698 - mmseg - INFO - Iter [12050/40000]	lr: 7.270e-03, eta: 6:11:41, time: 4.232, data_time: 3.476, memory: 23987, decode.loss_seg: 0.2832, decode.acc_seg: 87.3254, aux_0.loss_seg: 0.1107, aux_0.acc_seg: 87.1134, aux_1.loss_seg: 0.1059, aux_1.acc_seg: 87.3190, aux_2.loss_seg: 0.1056, aux_2.acc_seg: 87.3274, aux_3.loss_seg: 0.1042, aux_3.acc_seg: 87.4312, loss: 0.7096
2022-05-31 22:09:32,888 - mmseg - INFO - Iter [12100/40000]	lr: 7.259e-03, eta: 6:10:57, time: 0.764, data_time: 0.004, memory: 23987, decode.loss_seg: 

[>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>] 91/91, 0.6 task/s, elapsed: 143s, ETA:     0s

2022-05-31 23:02:15,840 - mmseg - INFO - per class results:
Class                  IoU        Acc
Background           83.23      86.18
Paragraph            91.10      98.13
OtherText            56.64      76.49
VisualFigure         83.14      93.97
Summary:
Scope                 mIoU       mAcc       aAcc
global               78.53      88.69      91.39

2022-05-31 23:02:15,851 - mmseg - INFO - Iter(val) [16000]	mIoU: 0.7853, mAcc: 0.8869, aAcc: 0.9139
2022-05-31 23:02:53,770 - mmseg - INFO - Iter [16050/40000]	lr: 6.340e-03, eta: 5:18:34, time: 4.223, data_time: 3.469, memory: 23987, decode.loss_seg: 0.2774, decode.acc_seg: 88.1690, aux_0.loss_seg: 0.1106, aux_0.acc_seg: 87.7637, aux_1.loss_seg: 0.1079, aux_1.acc_seg: 88.1818, aux_2.loss_seg: 0.1083, aux_2.acc_seg: 88.0527, aux_3.loss_seg: 0.1074, aux_3.acc_seg: 88.1379, loss: 0.7115
2022-05-31 23:03:31,880 - mmseg - INFO - Iter [16100/40000]	lr: 6.328e-03, eta: 5:17:51, time: 0.762, data_time: 0.004, memory: 23987, decode.loss_seg: 

[>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>] 91/91, 0.6 task/s, elapsed: 143s, ETA:     0s

2022-05-31 23:56:14,850 - mmseg - INFO - per class results:
Class                  IoU        Acc
Background           82.60      85.38
Paragraph            91.98      97.49
OtherText            59.08      79.77
VisualFigure         81.87      95.87
Summary:
Scope                 mIoU       mAcc       aAcc
global               78.88      89.63      91.37

2022-05-31 23:56:14,861 - mmseg - INFO - Iter(val) [20000]	mIoU: 0.7888, mAcc: 0.8963, aAcc: 0.9137
2022-05-31 23:56:52,751 - mmseg - INFO - Iter [20050/40000]	lr: 5.394e-03, eta: 4:25:24, time: 4.251, data_time: 3.497, memory: 23987, decode.loss_seg: 0.1291, decode.acc_seg: 90.6554, aux_0.loss_seg: 0.0583, aux_0.acc_seg: 89.9980, aux_1.loss_seg: 0.0563, aux_1.acc_seg: 90.2516, aux_2.loss_seg: 0.0564, aux_2.acc_seg: 90.2019, aux_3.loss_seg: 0.0564, aux_3.acc_seg: 90.1891, loss: 0.3565
2022-05-31 23:57:30,877 - mmseg - INFO - Iter [20100/40000]	lr: 5.382e-03, eta: 4:24:42, time: 0.762, data_time: 0.004, memory: 23987, decode.loss_seg: 

[>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>] 91/91, 0.6 task/s, elapsed: 144s, ETA:     0s

2022-06-01 00:50:17,993 - mmseg - INFO - per class results:
Class                  IoU        Acc
Background           80.59      82.52
Paragraph            91.25      98.15
OtherText            53.82      76.41
VisualFigure         81.00      96.52
Summary:
Scope                 mIoU       mAcc       aAcc
global               76.67      88.40      90.40

2022-06-01 00:50:18,003 - mmseg - INFO - Iter(val) [24000]	mIoU: 0.7667, mAcc: 0.8840, aAcc: 0.9040
2022-06-01 00:50:55,896 - mmseg - INFO - Iter [24050/40000]	lr: 4.428e-03, eta: 3:32:15, time: 4.255, data_time: 3.501, memory: 23987, decode.loss_seg: 0.1486, decode.acc_seg: 91.9390, aux_0.loss_seg: 0.0675, aux_0.acc_seg: 91.4371, aux_1.loss_seg: 0.0671, aux_1.acc_seg: 91.5057, aux_2.loss_seg: 0.0655, aux_2.acc_seg: 91.5175, aux_3.loss_seg: 0.0654, aux_3.acc_seg: 91.4748, loss: 0.4140
2022-06-01 00:51:34,006 - mmseg - INFO - Iter [24100/40000]	lr: 4.416e-03, eta: 3:31:34, time: 0.762, data_time: 0.004, memory: 23987, decode.loss_seg: 

[>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>] 91/91, 0.6 task/s, elapsed: 143s, ETA:     0s

2022-06-01 01:44:24,543 - mmseg - INFO - per class results:
Class                  IoU        Acc
Background           83.59      86.76
Paragraph            91.32      97.39
OtherText            55.52      79.79
VisualFigure         86.09      94.76
Summary:
Scope                 mIoU       mAcc       aAcc
global               79.13      89.68      91.68

2022-06-01 01:44:24,555 - mmseg - INFO - Iter(val) [28000]	mIoU: 0.7913, mAcc: 0.8968, aAcc: 0.9168
2022-06-01 01:45:02,462 - mmseg - INFO - Iter [28050/40000]	lr: 3.438e-03, eta: 2:39:05, time: 4.230, data_time: 3.476, memory: 23987, decode.loss_seg: 0.1436, decode.acc_seg: 92.1927, aux_0.loss_seg: 0.0637, aux_0.acc_seg: 91.8824, aux_1.loss_seg: 0.0624, aux_1.acc_seg: 91.9408, aux_2.loss_seg: 0.0624, aux_2.acc_seg: 91.7989, aux_3.loss_seg: 0.0613, aux_3.acc_seg: 91.9464, loss: 0.3934
2022-06-01 01:45:40,599 - mmseg - INFO - Iter [28100/40000]	lr: 3.425e-03, eta: 2:38:24, time: 0.763, data_time: 0.004, memory: 23987, decode.loss_seg: 

[>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>] 91/91, 0.6 task/s, elapsed: 144s, ETA:     0s

2022-06-01 02:38:28,264 - mmseg - INFO - per class results:
Class                  IoU        Acc
Background           80.50      83.18
Paragraph            89.93      98.50
OtherText            46.55      70.42
VisualFigure         83.36      93.74
Summary:
Scope                 mIoU       mAcc       aAcc
global               75.09      86.46      89.94

2022-06-01 02:38:28,272 - mmseg - INFO - Iter(val) [32000]	mIoU: 0.7509, mAcc: 0.8646, aAcc: 0.8994
2022-06-01 02:39:06,299 - mmseg - INFO - Iter [32050/40000]	lr: 2.413e-03, eta: 1:45:51, time: 4.253, data_time: 3.497, memory: 23987, decode.loss_seg: 0.1609, decode.acc_seg: 91.6894, aux_0.loss_seg: 0.0715, aux_0.acc_seg: 90.9660, aux_1.loss_seg: 0.0674, aux_1.acc_seg: 91.1963, aux_2.loss_seg: 0.0664, aux_2.acc_seg: 91.4076, aux_3.loss_seg: 0.0659, aux_3.acc_seg: 91.3745, loss: 0.4321
2022-06-01 02:39:44,519 - mmseg - INFO - Iter [32100/40000]	lr: 2.400e-03, eta: 1:45:10, time: 0.764, data_time: 0.004, memory: 23987, decode.loss_seg: 

[>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>] 91/91, 0.6 task/s, elapsed: 146s, ETA:     0s

2022-06-01 03:32:33,529 - mmseg - INFO - per class results:
Class                  IoU        Acc
Background           85.88      90.10
Paragraph            91.99      97.63
OtherText            58.43      76.50
VisualFigure         86.66      93.26
Summary:
Scope                 mIoU       mAcc       aAcc
global               80.74      89.37      92.63

2022-06-01 03:32:33,571 - mmseg - INFO - Iter(val) [36000]	mIoU: 0.8074, mAcc: 0.8937, aAcc: 0.9263
2022-06-01 03:33:11,488 - mmseg - INFO - Iter [36050/40000]	lr: 1.333e-03, eta: 0:52:35, time: 4.302, data_time: 3.548, memory: 23987, decode.loss_seg: 0.1133, decode.acc_seg: 92.9210, aux_0.loss_seg: 0.0549, aux_0.acc_seg: 92.1240, aux_1.loss_seg: 0.0531, aux_1.acc_seg: 92.2335, aux_2.loss_seg: 0.0522, aux_2.acc_seg: 92.3181, aux_3.loss_seg: 0.0522, aux_3.acc_seg: 92.2966, loss: 0.3257
2022-06-01 03:33:49,608 - mmseg - INFO - Iter [36100/40000]	lr: 1.319e-03, eta: 0:51:55, time: 0.762, data_time: 0.004, memory: 23987, decode.loss_seg: 

[>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>] 91/91, 0.6 task/s, elapsed: 145s, ETA:     0s

2022-06-01 04:26:37,629 - mmseg - INFO - per class results:
Class                  IoU        Acc
Background           85.57      88.99
Paragraph            92.33      97.90
OtherText            60.78      80.29
VisualFigure         86.63      94.39
Summary:
Scope                 mIoU       mAcc       aAcc
global               81.33      90.39      92.71

2022-06-01 04:26:37,666 - mmseg - INFO - Iter(val) [40000]	mIoU: 0.8133, mAcc: 0.9039, aAcc: 0.9271
