# Training with a frozen ResNet-152 backbone

## Setup and imports

In [0]:
# Mount Google Drive: the code, dataset and output paths used below all live
# under /content/drive/My Drive/.
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [0]:
%cd /content/drive/My Drive/DeepLearningX/gitclone/light-weight-refinenet/src

/content/drive/My Drive/DeepLearningX/gitclone/light-weight-refinenet/src


In [0]:
# general libs
import argparse
import logging
import os
import random
import re
import sys
import time

# misc
import cv2
import numpy as np

# pytorch libs
import torch
import torch.nn as nn
from sklearn.metrics import confusion_matrix

# custom libs
from util import *

## Methods for training

In [0]:
# defining methods for training
sys.path.append("..")

from torchvision import transforms
from torch.utils.data import DataLoader, random_split
# Custom libraries
from datasets import NYUDataset as Dataset
from datasets import Pad, RandomCrop, RandomMirror, ResizeShorterScale, ToTensor, Normalise

def create_segmenter(
    net, pretrained, num_classes
    ):
    """Create a Light-Weight RefineNet segmenter.

    Args:
      net (str or int) : backbone identifier; one of '50', '101', '152'
        (ResNet depth) or 'Mob'. It is compared as a string, so integer
        depths such as 152 also work.
      pretrained (bool) : forwarded to the model builder; passed as
        ``imagenet=`` for the ResNet variants and as ``pretrained=`` for 'Mob'.
      num_classes (int) : number of output classes of the segmentation head.

    Returns:
      nn.Module : the segmentation network.

    Raises:
      ValueError : if ``net`` is not a supported identifier.
    """
    net = str(net)
    # Model modules are imported lazily per branch so that each backbone only
    # needs its own module to be importable.
    if net in ('50', '101', '152'):
        from models.resnet import rf_lw50, rf_lw101, rf_lw152
        builders = {'50': rf_lw50, '101': rf_lw101, '152': rf_lw152}
        return builders[net](num_classes, imagenet=pretrained)
    if net == 'Mob':
        # `mbv2` used to be referenced without any import (NameError).
        # NOTE(review): assumes mbv2 lives in models/mobilenet.py as in the
        # upstream light-weight-refinenet repo — confirm.
        from models.mobilenet import mbv2
        return mbv2(num_classes, pretrained=pretrained)
    raise ValueError("{} is not supported".format(net))

def create_loaders(
    train_dir, val_dir, train_list, val_list,
    shorter_side, crop_size, low_scale, high_scale,
    normalise_params, batch_size, num_workers, ignore_label
    ):
    """Build the training and validation data loaders.

    Args:
      train_dir (str) : path to the root directory of the training set.
      val_dir (str) : path to the root directory of the validation set.
      train_list (str) : path to the training list.
      val_list (str) : path to the validation list.
      shorter_side (int) : parameter of the shorter_side resize transformation.
      crop_size (int) : square crop to apply during the training.
      low_scale (float) : lowest scale ratio for augmentations.
      high_scale (float) : highest scale ratio for augmentations.
      normalise_params (list / tuple) : img_scale, img_mean, img_std.
      batch_size (int) : training batch size.
      num_workers (int) : number of workers to parallelise data loading operations.
      ignore_label (int) : label to pad segmentation masks with.

    Returns:
      (train_loader, val_loader) : the training loader (shuffled, last
        incomplete batch dropped) and the validation loader (batch size 1,
        fixed order).
    """
    # Training: resize/pad/mirror/crop augmentation, then normalise and convert.
    # The pad values equal 255 * [0.485, 0.456, 0.406].
    train_transform = transforms.Compose([
        ResizeShorterScale(shorter_side, low_scale, high_scale),
        Pad(crop_size, [123.675, 116.28, 103.53], ignore_label),
        RandomMirror(),
        RandomCrop(crop_size),
        Normalise(*normalise_params),
        ToTensor(),
    ])
    # Validation: no augmentation.
    val_transform = transforms.Compose([
        Normalise(*normalise_params),
        ToTensor(),
    ])

    def _make_dataset(list_path, root_dir, trn_transform):
        # Both splits share the same validation-time transform.
        return Dataset(data_file=list_path,
                       data_dir=root_dir,
                       transform_trn=trn_transform,
                       transform_val=val_transform)

    trainset = _make_dataset(train_list, train_dir, train_transform)
    valset = _make_dataset(val_list, val_dir, None)
    logger.info(" Created train set = {} examples, val set = {} examples"
                .format(len(trainset), len(valset)))

    loader_kwargs = dict(num_workers=num_workers, pin_memory=True)
    train_loader = DataLoader(trainset, batch_size=batch_size, shuffle=True,
                              drop_last=True, **loader_kwargs)
    val_loader = DataLoader(valset, batch_size=1, shuffle=False,
                            **loader_kwargs)
    return train_loader, val_loader

def create_optimisers(
    lr_enc, lr_dec,
    mom_enc, mom_dec,
    wd_enc, wd_dec,
    param_enc, param_dec,
    optim_dec
    ):
    """Create the optimisers for the encoder and the decoder.

    The encoder always uses SGD; the decoder uses SGD or Adam.

    Args:
      lr_enc, lr_dec (float) : learning rates.
      mom_enc, mom_dec (float) : SGD momentum (mom_dec is unused for Adam).
      wd_enc, wd_dec (float) : weight decay.
      param_enc, param_dec (iterable) : parameters of encoder / decoder.
      optim_dec (str) : decoder optimiser, 'sgd' or 'adam'.

    Returns:
      (optim_enc, optim_dec) : the two torch optimisers.

    Raises:
      ValueError : if ``optim_dec`` is neither 'sgd' nor 'adam'. Previously
        the string itself was returned, which failed later at ``.zero_grad()``.
    """
    if optim_dec not in ('sgd', 'adam'):
        raise ValueError(
            "optim_dec must be 'sgd' or 'adam', got {!r}".format(optim_dec))
    enc_optimiser = torch.optim.SGD(param_enc, lr=lr_enc, momentum=mom_enc,
                                    weight_decay=wd_enc)
    if optim_dec == 'sgd':
        dec_optimiser = torch.optim.SGD(param_dec, lr=lr_dec,
                                        momentum=mom_dec, weight_decay=wd_dec)
    else:
        dec_optimiser = torch.optim.Adam(param_dec, lr=lr_dec,
                                         weight_decay=wd_dec, eps=1e-3)
    return enc_optimiser, dec_optimiser

def load_ckpt(
    ckpt_path, ckpt_dict
    ):
    """Restore state from a checkpoint file if one exists.

    Args:
      ckpt_path (str) : path to the checkpoint file.
      ckpt_dict (dict) : maps checkpoint keys to objects exposing
        ``load_state_dict`` (e.g. ``{'segmenter': segmenter}``); keys absent
        from the checkpoint are skipped.

    Returns:
      (best_val, epoch_start) : values stored in the checkpoint, or 0 for each
        one that is missing (both 0 when the file does not exist).
    """
    best_val = epoch_start = 0
    # Check the path we were given (this used to test the global CKPT_PATH).
    if not os.path.exists(ckpt_path):
        return best_val, epoch_start
    ckpt = torch.load(ckpt_path)
    for key, obj in ckpt_dict.items():
        if key in ckpt:
            obj.load_state_dict(ckpt[key])
    best_val = ckpt.get('best_val', 0)
    epoch_start = ckpt.get('epoch_start', 0)
    logger.info(" Found checkpoint at {} with best_val {:.4f} at epoch {}".
        format(
            ckpt_path, best_val, epoch_start
        ))
    return best_val, epoch_start

def train_segmenter(
    segmenter, train_loader, optim_enc, optim_dec,
    epoch, segm_crit, freeze_bn
    ):
    """Run one training epoch of the segmenter.

    Args:
      segmenter (nn.Module) : segmentation network
      train_loader (DataLoader) : training data iterator
      optim_enc (optim) : optimiser for encoder
      optim_dec (optim) : optimiser for decoder
      epoch (int) : current epoch (used in log messages only)
      segm_crit (nn.Loss) : segmentation criterion; it receives
        log-probabilities and integer masks
      freeze_bn (bool) : if True, BatchNorm2d layers are kept in eval mode
    """
    train_loader.dataset.set_stage('train')
    segmenter.train()
    if freeze_bn:
        for module in segmenter.modules():
            if isinstance(module, nn.BatchNorm2d):
                module.eval()
    # Explicit class axis: the logits are 4-D (bilinear interpolation below
    # requires it), so dim=1 matches the former implicit default without the
    # deprecation warning.
    log_softmax = nn.LogSoftmax(dim=1)
    batch_time = AverageMeter()
    losses = AverageMeter()
    for i, sample in enumerate(train_loader):
        start = time.time()
        images = sample['image'].cuda().float()
        masks = sample['mask'].cuda().long()
        # Forward pass, then upsample the logits to the mask resolution.
        logits = segmenter(images)
        logits = nn.functional.interpolate(logits, size=masks.size()[1:],
                                           mode='bilinear', align_corners=False)
        loss = segm_crit(log_softmax(logits), masks)
        # Backpropagate and update encoder and decoder.
        optim_enc.zero_grad()
        optim_dec.zero_grad()
        loss.backward()
        optim_enc.step()
        optim_dec.step()
        losses.update(loss.item())
        batch_time.update(time.time() - start)
        if i % PRINT_EVERY == 0:
            logger.info(' Train epoch: {} [{}/{}]\t'
                        'Avg. Loss: {:.3f}\t'
                        'Avg. Time: {:.3f}'.format(
                            epoch, i, len(train_loader),
                            losses.avg, batch_time.avg
                        ))

def compute_iu(cm):
    """Compute per-class intersection-over-union from a confusion matrix.

    IU_i = cm[i, i] / (sum(cm[:, i]) + sum(cm[i, :]) - cm[i, i]).
    A class with a zero denominator (absent from both rows and columns) is
    assigned an IU of 1. The result does not depend on whether rows hold
    ground truth or predictions.

    Args:
      cm (array-like) : square (n_classes x n_classes) confusion matrix.

    Returns:
      np.ndarray : float IU vector of length n_classes.
    """
    cm = np.asarray(cm)
    intersection = np.diag(cm)
    union = cm.sum(axis=0) + cm.sum(axis=1) - intersection
    iu = np.ones(cm.shape[0])
    # Only divide where the union is non-empty; other entries keep the 1.
    np.divide(intersection, union, out=iu, where=union > 0)
    return iu

def validate(
    segmenter, val_loader, epoch, num_classes=-1
    ):
    """Evaluate the segmenter on the validation set and log the mean IoU.

    The mean IoU is also appended to a text file on Google Drive whose name
    is built from the globals ENC and FREEZED.

    Args:
      segmenter (nn.Module) : segmentation network
      val_loader (DataLoader) : validation data iterator; only the first
        sample of each batch is scored, so batch_size=1 is expected.
      epoch (int) : current epoch (used in log messages only)
      num_classes (int) : number of classes to score; must be positive.
        Ground-truth pixels with label >= num_classes (e.g. the ignore
        label) are excluded.

    Returns:
      Mean IoU (float)

    Raises:
      ValueError : if num_classes is not positive.
    """
    if num_classes <= 0:
        raise ValueError(
            "num_classes must be positive, got {}".format(num_classes))
    val_loader.dataset.set_stage('val')
    segmenter.eval()
    labels = np.arange(num_classes)
    cm = np.zeros((num_classes, num_classes), dtype=int)
    with torch.no_grad():
        for sample in val_loader:
            images = sample['image'].float().cuda()
            target = sample['mask']
            logits = segmenter(images)
            # Resize class scores (H, W, C) to the mask size; cv2 takes
            # (width, height). Then pick the arg-max class per pixel.
            scores = logits[0, :num_classes].cpu().numpy().transpose(1, 2, 0)
            pred = cv2.resize(scores,
                              target.size()[1:][::-1],
                              interpolation=cv2.INTER_CUBIC).argmax(axis=2).astype(np.uint8)
            gt = target[0].cpu().numpy().astype(np.uint8)
            valid = gt < num_classes  # drop ignore/out-of-range labels
            if valid.any():
                # `labels` fixes the shape to num_classes x num_classes.
                # Without it sklearn shrinks the matrix to the labels present,
                # and a 1x1 result was broadcast-added into every cell of `cm`.
                cm += confusion_matrix(gt[valid], pred[valid], labels=labels)

    ious = compute_iu(cm)
    logger.info(" IoUs: {}".format(ious))
    miou = np.mean(ious)

    miou_path = '/content/drive/My Drive/DeepLearningX/models/ResNet/mious_res{}_{}.txt'.format(ENC, FREEZED)
    with open(miou_path, 'a') as f:
        f.write("{}\n".format(miou))

    logger.info(' Val epoch: {}\tMean IoU: {:.3f}'.format(
                                epoch, miou))
    return miou

def main():
    """Train the segmenter, or only evaluate it when EVALUATE is True.

    All settings come from the module-level constants of the configuration
    cell. Training runs len(NUM_CLASSES) stages. Each stage builds its own
    data loaders and optimisers and trains for NUM_SEGM_EPOCHS[stage] epochs.
    After every epoch the whole segmenter module is saved to Google Drive,
    named by the global epoch counter so later stages do not overwrite
    earlier stages' files.

    Returns:
      The mean IoU from ``validate`` when EVALUATE is True, otherwise None.

    Raises:
      ValueError : when training and a per-stage setting has fewer entries
        than there are stages.
    """
    logging.basicConfig(level=logging.INFO)
    # The other functions in this notebook log through this global logger.
    global logger
    logger = logging.getLogger(__name__)

    num_stages = len(NUM_CLASSES)
    if not EVALUATE:
        # Fail fast instead of raising IndexError after whole stages of training.
        per_stage = {
            'TRAIN_LIST': TRAIN_LIST, 'VAL_LIST': VAL_LIST,
            'SHORTER_SIDE': SHORTER_SIDE, 'CROP_SIZE': CROP_SIZE,
            'LOW_SCALE': LOW_SCALE, 'HIGH_SCALE': HIGH_SCALE,
            'BATCH_SIZE': BATCH_SIZE, 'FREEZE_BN': FREEZE_BN,
            'NUM_SEGM_EPOCHS': NUM_SEGM_EPOCHS, 'VAL_EVERY': VAL_EVERY,
            'LR_ENC': LR_ENC, 'LR_DEC': LR_DEC,
            'MOM_ENC': MOM_ENC, 'MOM_DEC': MOM_DEC,
            'WD_ENC': WD_ENC, 'WD_DEC': WD_DEC,
        }
        too_short = [name for name, values in per_stage.items()
                     if len(values) < num_stages]
        if too_short:
            raise ValueError(
                "Per-stage settings need {} entries each; too short: {}".format(
                    num_stages, ', '.join(too_short)))

    ## Set random seeds ##
    torch.backends.cudnn.deterministic = True
    torch.manual_seed(RANDOM_SEED)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(RANDOM_SEED)
    np.random.seed(RANDOM_SEED)
    random.seed(RANDOM_SEED)

    ## Generate Segmenter ##
    segmenter = nn.DataParallel(
        create_segmenter(ENC, ENC_PRETRAINED, NUM_CLASSES[0])
        ).cuda()
    logger.info(" Loaded Segmenter {}, ImageNet-Pre-Trained={}, #PARAMS={:3.2f}M"
                .format(ENC, ENC_PRETRAINED, compute_params(segmenter) / 1e6))

    ## Restore if any ## (at checkpoint)
    # best_val is currently unused: best-checkpoint saving is not enabled.
    best_val, epoch_start = load_ckpt(CKPT_PATH, {'segmenter' : segmenter})

    ## Criterion ##
    # NLLLoss is the drop-in replacement for the deprecated NLLLoss2d.
    segm_crit = nn.NLLLoss(ignore_index=IGNORE_LABEL).cuda()

    # Parameter names matching this pattern go to the encoder optimiser.
    enc_param_re = re.compile(".*conv1.*|.*bn1.*|.*layer.*")
    model_path_fmt = "/content/drive/My Drive/DeepLearningX/models/ResNet/model_res{}_{}_{}"

    logger.info(" Training Process Starts")
    for task_idx in range(num_stages):
        start = time.time()
        torch.cuda.empty_cache()
        ## Create dataloaders ##
        train_loader, val_loader = create_loaders(TRAIN_DIR,
                                                  VAL_DIR,
                                                  TRAIN_LIST[task_idx],
                                                  VAL_LIST[task_idx],
                                                  SHORTER_SIDE[task_idx],
                                                  CROP_SIZE[task_idx],
                                                  LOW_SCALE[task_idx],
                                                  HIGH_SCALE[task_idx],
                                                  NORMALISE_PARAMS,
                                                  BATCH_SIZE[task_idx],
                                                  NUM_WORKERS,
                                                  IGNORE_LABEL)
        if EVALUATE:
            return validate(segmenter, val_loader, 0, num_classes=NUM_CLASSES[task_idx])

        logger.info(" Training Stage {}".format(str(task_idx)))
        ## Optimisers ##
        enc_params = []
        dec_params = []
        for name, param in segmenter.named_parameters():
            if enc_param_re.match(name):
                enc_params.append(param)
                logger.info(" Enc. parameter: {}".format(name))
            else:
                dec_params.append(param)
                logger.info(" Dec. parameter: {}".format(name))
        optim_enc, optim_dec = create_optimisers(LR_ENC[task_idx], LR_DEC[task_idx],
                                                 MOM_ENC[task_idx], MOM_DEC[task_idx],
                                                 WD_ENC[task_idx], WD_DEC[task_idx],
                                                 enc_params, dec_params, OPTIM_DEC)
        for epoch in range(NUM_SEGM_EPOCHS[task_idx]):
            # Global epoch index across stages (and across resumes).
            global_epoch = epoch_start
            train_segmenter(segmenter, train_loader,
                            optim_enc, optim_dec,
                            global_epoch, segm_crit,
                            FREEZE_BN[task_idx])
            if (epoch + 1) % (VAL_EVERY[task_idx]) == 0:
                validate(segmenter, val_loader, global_epoch, NUM_CLASSES[task_idx])
            epoch_start += 1

            # Named by the global epoch: identical to the per-stage epoch in
            # stage 0, but later stages no longer overwrite earlier files.
            torch.save(segmenter, model_path_fmt.format(ENC, FREEZED, global_epoch))

        logger.info("Stage {} finished, time spent {:.3f}min".format(
            task_idx, (time.time() - start) / 60.))

# if __name__ == '__main__':
#     logging.basicConfig(level=logging.INFO)
#     main()

## Configurations and training

In [0]:
# DATASET PARAMETERS
N_STAGES = 3  # number of training stages; every per-stage list has this length
TRAIN_DIR = "/content/drive/My Drive/DeepLearningX/TrainData-People/Train/"
VAL_DIR = "/content/drive/My Drive/DeepLearningX/TrainData-People/Validation/"
TRAIN_LIST = ["/content/drive/My Drive/DeepLearningX/TrainData-People/Train/train.txt"] * N_STAGES
VAL_LIST = ["/content/drive/My Drive/DeepLearningX/TrainData-People/Validation/validation.txt"] * N_STAGES
SHORTER_SIDE = [350] * N_STAGES
CROP_SIZE = [500] * N_STAGES
NORMALISE_PARAMS = [1./255,  # SCALE
                    np.array([0.485, 0.456, 0.406]).reshape((1, 1, 3)),  # MEAN
                    np.array([0.229, 0.224, 0.225]).reshape((1, 1, 3))]  # STD
BATCH_SIZE = [10] * N_STAGES
NUM_WORKERS = 16
NUM_CLASSES = [2] * N_STAGES
LOW_SCALE = [0.5] * N_STAGES
HIGH_SCALE = [2.0] * N_STAGES
IGNORE_LABEL = 255  # mask value ignored by the loss and by validation IoU

# ENCODER PARAMETERS
ENC = '152'  # which backbone to train: '50', '101', '152' or 'Mob'
ENC_PRETRAINED = True  # pre-trained on ImageNet or randomly initialised

# GENERAL
EVALUATE = False  # True: only run validation with the current weights
FREEZE_BN = [True] * N_STAGES
NUM_SEGM_EPOCHS = [100] * N_STAGES
PRINT_EVERY = 100  # log training progress every N batches
RANDOM_SEED = 42
SNAPSHOT_DIR = './ckpt/'
CKPT_PATH = './ckpt/checkpoint.pth.tar'
VAL_EVERY = [1] * N_STAGES  # validate every N epochs

# OPTIMISERS' PARAMETERS
# The encoder is frozen by setting its lr, momentum and weight decay to 0.
# Each list needs one entry per stage: a single-entry list makes main() raise
# IndexError as soon as stage 1 starts.
LR_ENC = [0] * N_STAGES  # fine-tuning alternative: [5e-4, 2.5e-4, 1e-4]
LR_DEC = [5e-3, 2.5e-3, 1e-3]
MOM_ENC = [0] * N_STAGES  # fine-tuning alternative: [0.9] * 3
MOM_DEC = [0.9] * N_STAGES
WD_ENC = [0] * N_STAGES  # fine-tuning alternative: [1e-5] * 3
WD_DEC = [1e-5] * N_STAGES
OPTIM_DEC = 'sgd'  # decoder optimiser: 'sgd' or 'adam'
FREEZED = 'freezed'  # tag used in output file names

In [0]:
# Run the full training schedule (or only validation when EVALUATE is True)
# with the configuration cell above.
main()

INFO:__main__: Loaded Segmenter 152, ImageNet-Pre-Trained=True, #PARAMS=61.95M
INFO:__main__: Training Process Starts
INFO:__main__: Created train set = 699 examples, val set = 100 examples
INFO:__main__: Training Stage 0
INFO:__main__: Enc. parameter: module.conv1.weight
INFO:__main__: Enc. parameter: module.bn1.weight
INFO:__main__: Enc. parameter: module.bn1.bias
INFO:__main__: Enc. parameter: module.layer1.0.conv1.weight
INFO:__main__: Enc. parameter: module.layer1.0.bn1.weight
INFO:__main__: Enc. parameter: module.layer1.0.bn1.bias
INFO:__main__: Enc. parameter: module.layer1.0.conv2.weight
INFO:__main__: Enc. parameter: module.layer1.0.bn2.weight
INFO:__main__: Enc. parameter: module.layer1.0.bn2.bias
INFO:__main__: Enc. parameter: module.layer1.0.conv3.weight
INFO:__main__: Enc. parameter: module.layer1.0.bn3.weight
INFO:__main__: Enc. parameter: module.layer1.0.bn3.bias
INFO:__main__: Enc. parameter: module.layer1.0.downsample.0.weight
INFO:__main__: Enc. parameter: module.laye