In [1]:
# System libs
import os
import time
# import math
import random
import argparse
from distutils.version import LooseVersion
# Numerical libs
import torch
import torch.nn as nn
# Our libs
from config import *
from dataset import TrainDataset
from models import models #ModelBuilder, SegmentationModule
from models import *
from utils import AverageMeter, parse_devices, setup_logger
from lib.nn import UserScatteredDataParallel, user_scattered_collate, patch_replication_callback

In [2]:
from matplotlib import pyplot as plt

In [3]:
from yacs.config import CfgNode as CN
# -----------------------------------------------------------------------------
# Config definition
# -----------------------------------------------------------------------------
# All defaults are declared as a single nested mapping. CfgNode converts every
# nested dict into a child CfgNode, so values are still read and overridden
# with attribute access (e.g. cfg.TRAIN.lr_encoder).
_C = CN({
    # directory used below for config.yaml and the per-epoch checkpoints
    "DIR": "ckpt/resnet50-upernet",

    # -------------------------------------------------------------------------
    # Dataset
    # -------------------------------------------------------------------------
    "DATASET": {
        "root_dataset": "./data/",
        "list_train": "./data/training.odgt",
        "list_val": "./data/validation.odgt",
        "num_class": 2,
        # short-edge sizes for multiscale train/test (int or tuple)
        "imgSizes": (300, 375, 450, 525, 600),
        # upper bound on the long edge of an input image
        "imgMaxSize": 1000,
        # maximum downsampling rate of the network
        "padding_constant": 8,
        # downsampling rate of the segmentation label
        "segm_downsampling_rate": 8,
        # random horizontal flip of images when train/test
        "random_flip": True,
    },

    # -------------------------------------------------------------------------
    # Model
    # -------------------------------------------------------------------------
    "MODEL": {
        # encoder architecture
        "arch_encoder": "resnet50",
        # decoder architecture
        "arch_decoder": "upernet",
        # weights used to finetune the encoder
        "weights_encoder": "",
        # weights used to finetune the decoder
        "weights_decoder": "",
        # feature channels passed from encoder to decoder
        "fc_dim": 2048,
    },

    # -------------------------------------------------------------------------
    # Training
    # -------------------------------------------------------------------------
    "TRAIN": {
        "batch_size_per_gpu": 1,
        # number of epochs to train for
        "num_epoch": 1,
        # first epoch; non-zero when continuing from a checkpoint
        "start_epoch": 0,
        # iterations per epoch (irrelevant to batch size)
        "epoch_iters": 22,

        "optim": "SGD",
        "lr_encoder": 0.02,
        "lr_decoder": 0.02,
        # exponent of the poly learning-rate decay
        "lr_pow": 0.9,
        # momentum for sgd, beta1 for adam
        "beta1": 0.9,
        # weight regularizer
        "weight_decay": 1e-4,
        # weighting of the deep supervision loss
        "deep_sup_scale": 0.4,
        # fix bn params, only under finetuning
        "fix_bn": False,
        # number of data loading workers
        "workers": 1,

        # display frequency (iterations)
        "disp_iter": 20,
        # manual seed
        "seed": 304,
    },

    # -------------------------------------------------------------------------
    # Validation
    # -------------------------------------------------------------------------
    "VAL": {
        # currently only supports 1
        "batch_size": 1,
        # output visualization during validation
        "visualize": False,
        # checkpoint to evaluate on
        "checkpoint": "epoch_20.pth",
    },

    # -------------------------------------------------------------------------
    # Testing
    # -------------------------------------------------------------------------
    "TEST": {
        # currently only supports 1
        "batch_size": 1,
        # checkpoint to test on
        "checkpoint": "epoch_20.pth",
        # folder for visualization results
        "result": "./",
    },
})

# Public handle used by the rest of the notebook.
cfg = _C

# Command-line interface of the training script, reproduced for the notebook.
# There is no real command line here, so an empty argument list is parsed
# and every option takes its default value.
parser = argparse.ArgumentParser(description="PyTorch Semantic Segmentation Training")

# Path of a YAML config file.
parser.add_argument(
    "--cfg",
    type=str,
    metavar="FILE",
    default="configuration/resnet50dilated-ppm_deepsup.yaml",
    help="path to config file",
)

# GPU selection string; the default selects a single device.
parser.add_argument("--gpus", default="0", help="gpus to use, e.g. 0-3 or 0,1,2,3")

# Everything left on the command line is collected as config overrides.
parser.add_argument(
    "opts",
    nargs=argparse.REMAINDER,
    default=None,
    help="Modify config options using the command-line",
)

args = parser.parse_args(args=[])

# Merging a YAML file is disabled in the notebook; only the key/value
# overrides collected in args.opts are applied on top of the defaults.
# cfg.merge_from_file(args.cfg)
cfg.merge_from_list(args.opts)

# Create the output directory; exist_ok avoids the separate existence check
# and makes re-running this cell harmless.
os.makedirs(cfg.DIR, exist_ok=True)

# Dump the configuration next to the checkpoints. Note this happens before
# the derived TRAIN.* values below are added, so they are not in the file.
with open(os.path.join(cfg.DIR, 'config.yaml'), 'w') as f:
    f.write(str(cfg))

# Start from checkpoint: when resuming, point both weight paths at the files
# of the requested epoch and make sure they are present.
if cfg.TRAIN.start_epoch > 0:
    cfg.MODEL.weights_encoder = os.path.join(
        cfg.DIR, 'encoder_epoch_{}.pth'.format(cfg.TRAIN.start_epoch))
    cfg.MODEL.weights_decoder = os.path.join(
        cfg.DIR, 'decoder_epoch_{}.pth'.format(cfg.TRAIN.start_epoch))
    assert os.path.exists(cfg.MODEL.weights_encoder) and \
        os.path.exists(cfg.MODEL.weights_decoder), "checkpoint does not exist!"

# Parse gpu ids: strip the 'gpu' substring from every parsed device name and
# convert the remainder to an integer id.
gpus = [int(x.replace('gpu', '')) for x in parse_devices(args.gpus)]
num_gpus = len(gpus)
cfg.TRAIN.batch_size = num_gpus * cfg.TRAIN.batch_size_per_gpu

# Derived training values: total iteration count and the running learning
# rates, which start at the configured base rates.
cfg.TRAIN.max_iters = cfg.TRAIN.epoch_iters * cfg.TRAIN.num_epoch
cfg.TRAIN.running_lr_encoder = cfg.TRAIN.lr_encoder
cfg.TRAIN.running_lr_decoder = cfg.TRAIN.lr_decoder

# Seed Python's and PyTorch's random number generators.
random.seed(cfg.TRAIN.seed)
torch.manual_seed(cfg.TRAIN.seed)

<torch._C.Generator at 0x1ed652060d0>

In [4]:
# Training dataset (the DataLoader is created in a later cell). It is built
# from the dataset root, the training list file and the whole DATASET config
# node; batch_per_gpu comes from TRAIN.batch_size_per_gpu.
dataset_train = TrainDataset(
    cfg.DATASET.root_dataset,
    cfg.DATASET.list_train,
    cfg.DATASET,
    batch_per_gpu=cfg.TRAIN.batch_size_per_gpu)

# samples: 1


In [5]:
# DataLoader over the training dataset. The loader's batch size is the number
# of GPUs rather than the number of images, and batches are assembled by the
# project's user_scattered_collate instead of the default collate function.
# NOTE(review): presumably each dataset item is already a per-GPU mini-batch
# of batch_size_per_gpu images -- confirm against TrainDataset.
loader_train = torch.utils.data.DataLoader(
        dataset_train,
        batch_size=len(gpus),  # we have modified data_parallel
        shuffle=False,  # we do not use this param
        collate_fn=user_scattered_collate,
        num_workers=cfg.TRAIN.workers,
        drop_last=True,
        pin_memory=True)

In [6]:
# Explicit iterator so that individual batches can be pulled with next().
iterator_train = iter(loader_train)

In [7]:
# Pull one batch from the training iterator.
a = next(iterator_train)

In [8]:
# Keep only the first entry of the batch (the loader batch size equals the
# number of GPUs); afterwards `a` is accessed by string key.
# NOTE(review): `a` is rebound to a different kind of object, so this cell is
# not idempotent -- re-running it without re-running the previous cell would
# index the already-unwrapped sample with 0.
a = a[0]

In [9]:
# Shape of the image tensor of the sample (1 x 3 x 450 x 450 in the recorded run).
a['img_data'].shape

torch.Size([1, 3, 450, 450])

In [10]:
# Shape of the segmentation label of the sample (1 x 450 x 450 in the recorded run).
a['seg_label'].shape

torch.Size([1, 450, 450])

In [12]:
# plt.imshow(a['img_data'][0, :, : , :].permute((1, 2, 0)))

In [None]:
# plt.imshow(a['seg_label'][0, :, :])

In [None]:
# plt.imshow(a['img_data'][1, :, : , :].permute((1, 2, 0)))

In [None]:
# plt.imshow(a['seg_label'][1, :, :])

In [13]:
# Build the encoder through the project's ModelBuilder: the architecture name
# is taken (lower-cased) from MODEL.arch_encoder, the feature width from
# MODEL.fc_dim, and the weights path from MODEL.weights_encoder.
# NOTE(review): weights_encoder is "" unless resuming -- presumably that means
# no checkpoint is loaded; confirm in ModelBuilder.build_encoder.
net_encoder = models.ModelBuilder.build_encoder(
        arch=cfg.MODEL.arch_encoder.lower(),
        fc_dim=cfg.MODEL.fc_dim,
        weights=cfg.MODEL.weights_encoder)

In [14]:
# Display the encoder's module structure.
net_encoder

Resnet(
  (conv1): Conv2d(3, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu1): ReLU(inplace=True)
  (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu2): ReLU(inplace=True)
  (conv3): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (bn3): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu3): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(128, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1)

In [15]:
# Run the encoder on the sample image. With return_feature_maps=True the
# result is an indexable collection of feature maps; the next cell prints the
# shapes of its first four entries.
code = net_encoder(a['img_data'], return_feature_maps=True)

In [16]:
# Print the shape of the first four encoder outputs, one per line.
for level in range(4):
    print(code[level].shape)

torch.Size([1, 256, 113, 113])
torch.Size([1, 512, 57, 57])
torch.Size([1, 1024, 29, 29])
torch.Size([1, 2048, 15, 15])


In [18]:
# Build the decoder through the project's ModelBuilder. The architecture name
# is read from MODEL.arch_decoder (lower-cased, mirroring the encoder cell)
# instead of being hard-coded, so the config stays the single source of
# truth; with the current config this is still 'upernet'.
# use_softmax=True is passed explicitly here.
# NOTE(review): presumably that selects the inference-style output; confirm
# in the decoder implementation.
net_decoder = models.ModelBuilder.build_decoder(
        arch=cfg.MODEL.arch_decoder.lower(),
        fc_dim=cfg.MODEL.fc_dim,
        num_class=cfg.DATASET.num_class,
        weights=cfg.MODEL.weights_decoder,
        use_softmax=True
)

In [19]:
# Display the decoder's module structure.
net_decoder

UPerNet(
  (ppm_pooling): ModuleList(
    (0): AdaptiveAvgPool2d(output_size=1)
    (1): AdaptiveAvgPool2d(output_size=2)
    (2): AdaptiveAvgPool2d(output_size=3)
    (3): AdaptiveAvgPool2d(output_size=6)
  )
  (ppm_conv): ModuleList(
    (0): Sequential(
      (0): Conv2d(2048, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (1): SynchronizedBatchNorm2d(512, eps=1e-05, momentum=0.001, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
    )
    (1): Sequential(
      (0): Conv2d(2048, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (1): SynchronizedBatchNorm2d(512, eps=1e-05, momentum=0.001, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
    )
    (2): Sequential(
      (0): Conv2d(2048, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (1): SynchronizedBatchNorm2d(512, eps=1e-05, momentum=0.001, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
    )
    (3): Sequential(
      (0): Conv2d(2048, 51

In [20]:
# Decode the encoder feature maps at the resolution of the input image.
# The target size is taken from the last two dimensions of the image tensor
# instead of a hard-coded (450, 450), so the cell keeps working when the
# dataset samples a different scale from DATASET.imgSizes.
decode = net_decoder(code, segSize=tuple(a['img_data'].shape[-2:]))

In [21]:
# Shape of the decoder output (1 x 2 x 450 x 450 in the recorded run: the
# second dimension equals DATASET.num_class and the last two equal segSize).
decode.shape

torch.Size([1, 2, 450, 450])

In [None]:
# crit = nn.CrossEntropyLoss(ignore_index=-1)