# UNet_xBD_damage_detection

Code based on:

https://github.com/DIUx-xView/xView2_second_place

# Imports

In [25]:
import os
import random

import cv2
import numpy as np
import pandas as pd
import torch
from albumentations.pytorch.functional import img_to_tensor
from skimage.measure import label
from torch.utils.data import Dataset

import json
import os
import argparse
from functools import partial
from multiprocessing.pool import Pool
from os import cpu_count

import cv2
from cv2 import fillPoly
from shapely import wkt
import numpy as np
from shapely.geometry import mapping
from tqdm import tqdm


import argparse
import os

os.environ["MKL_NUM_THREADS"] = "1"
os.environ["NUMEXPR_NUM_THREADS"] = "1"
os.environ["OMP_NUM_THREADS"] = "1"

import cv2

cv2.ocl.setUseOpenCL(False)
cv2.setNumThreads(0)

import models
from augs import SafeRotate, Lighting

from albumentations import Compose, RandomSizedCrop, HorizontalFlip, VerticalFlip, RGBShift, RandomBrightnessContrast, \
    RandomGamma, OneOf, RandomRotate90, Transpose, RandomCrop, HueSaturationValue, ImageCompression

import losses
from dataset.xview_dataset import XviewSingleDataset

from apex.parallel import DistributedDataParallel, convert_syncbn_model
from tensorboardX import SummaryWriter

from tools.config import load_config
from tools.utils import create_optimizer, AverageMeter

from apex import amp

from losses import dice_round

import numpy as np
import torch
from torch.backends import cudnn
from torch.nn import DataParallel
from torch.utils.data import DataLoader
from tqdm import tqdm
import torch.distributed as dist

import argparse
import os

from tools.xview_metric import XviewMetrics

os.environ["MKL_NUM_THREADS"] = "1"
os.environ["NUMEXPR_NUM_THREADS"] = "1"
os.environ["OMP_NUM_THREADS"] = "1"

import cv2

cv2.ocl.setUseOpenCL(False)
cv2.setNumThreads(0)

import models
from augs import Lighting, RandomSizedCropAroundBbox

from albumentations import Compose, HorizontalFlip, VerticalFlip, RGBShift, RandomBrightnessContrast, \
    RandomGamma, RandomRotate90, Transpose

import losses
from dataset.xview_dataset import XviewSingleDataset

from apex.parallel import DistributedDataParallel, convert_syncbn_model
from tensorboardX import SummaryWriter

from tools.config import load_config
from tools.utils import create_optimizer, AverageMeter

from apex import amp

from losses import dice_round

import numpy as np
import torch
from torch.backends import cudnn
from torch.nn import DataParallel
from torch.utils.data import DataLoader
from tqdm import tqdm
import torch.distributed as dist


ModuleNotFoundError: No module named 'apex'

# Dataset

In [6]:
#functions to import and normalise the data


class XviewSingleDataset(Dataset):
    def __init__(self, data_path, mode, fold=0, folds_csv='folds.csv', equibatch=False, transforms=None, normalize=None,
                 multiplier=1):
        super().__init__()
        self.data_path = data_path
        self.mode = mode

        self.names = sorted(os.listdir(os.path.join(self.data_path, "images")))
        df = pd.read_csv(folds_csv, dtype={'id': object})
        self.df = df
        self.normalize = normalize
        self.fold = fold
        self.equibatch = equibatch
        if self.mode == "train":
            ids = df[df['fold'] != fold]['id'].tolist()
            nondamage = df[(df['fold'] != fold) & (df['nondamage'] == True)]['id'].tolist()
            minor = df[(df['fold'] != fold) & (df['minor'] == True)]['id'].tolist()
            major = df[(df['fold'] != fold) & (df['major'] == True)]['id'].tolist()
            destroyed = df[(df['fold'] != fold) & (df['destroyed'] == True)]['id'].tolist()
            empty = df[(df['fold'] != fold) & (df['empty'] == True)]['id'].tolist()

            self.group_names = {
                "nondamage1": nondamage,
                "nondamage": nondamage,
                "minor": minor,
                "major": major,
                "destroyed": destroyed,
                "empty": empty,
            }
            self.group_ids = list(self.group_names.keys())
            if not self.equibatch:
                ids.extend(minor)
                ids.extend(major)
                ids.extend(destroyed)
        else:
            ids = list(set(df[df['fold'] == fold]['id'].tolist()))
        self.transforms = transforms
        self.names = ids

        if mode == "train":
            self.names = self.names * multiplier
        self.cache = {}

    def __len__(self):
        return len(self.names)

    def __getitem__(self, idx):

        if self.mode == 'train' and self.equibatch:
            group_id = self.group_ids[idx % len(self.group_ids)]
            name = random.choice(self.group_names[group_id])
        else:
            group_id = "unknown"
            name = self.names[idx]
        pre_img_path = os.path.join(self.data_path, "images", name + "_pre_disaster.png")
        post_img_path = os.path.join(self.data_path, "images", name + "_post_disaster.png")
        image_pre = cv2.imread(pre_img_path, cv2.IMREAD_COLOR)[:, :, ::-1]
        image_post = cv2.imread(post_img_path, cv2.IMREAD_COLOR)[:, :, ::-1]
        mask_pre = cv2.imread(os.path.join(self.data_path, "masks", name + "_pre_disaster.png"), cv2.IMREAD_GRAYSCALE)
        mask_post = cv2.imread(os.path.join(self.data_path, "masks", name + "_post_disaster.png"), cv2.IMREAD_GRAYSCALE)

        rectangles = self.cache.get(self.names[idx], [])
        if not rectangles:
            self.add_boxes(label(mask_post == 2).astype(np.uint8), rectangles)
        if rectangles:
            self.cache[self.names[idx]] = rectangles

        mask = np.stack([mask_pre, mask_post, mask_post], axis=-1)
        sample = self.transforms(image=image_pre, image1=image_post, mask=mask, img_name=name, rectangles=rectangles)
        image = np.concatenate([sample['image'], sample['image1']], axis=-1)
        sample['img_name'] = name
        sample['group_id'] = group_id
        mask = np.zeros((5, *sample["mask"].shape[:2]))
        for i in range(1, 5):
            mask[i - 1, sample["mask"][:, :, 1] == i] = 1
        mask[4] = sample["mask"][:, :, 0] / 255
        del sample["image1"]
        sample['original_mask'] = torch.from_numpy(np.ascontiguousarray(sample["mask"][:, :, 1]))
        sample['mask'] = torch.from_numpy(np.ascontiguousarray(mask)).float()
        sample['image'] = img_to_tensor(np.ascontiguousarray(image), self.normalize)
        return sample

    def add_boxes(self, labels, rectangles):
        max_label = np.max(labels)
        for i in range(1, max_label + 1):
            obj_mask = np.zeros_like(labels)
            obj_mask[labels == i] = 255

            contours, hierarchy = cv2.findContours(obj_mask, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
            for cnt in contours:
                points = cv2.boundingRect(cnt)
                rectangles.append(points)



class XviewSingleDatasetTest(Dataset):
    def __init__(self, data_path, transforms=None, normalize=None):
        super().__init__()
        self.data_path = data_path
        self.names = list(set([os.path.splitext(f)[0].replace("test_post_", "").replace("test_pre_", "") for f in
                               sorted(os.listdir(os.path.join(self.data_path, "images")))]))
        self.normalize = normalize
        self.transforms = transforms

    def __len__(self):
        return len(self.names)

    def __getitem__(self, idx):
        name = self.names[idx]
        pre_img_path = os.path.join(self.data_path, "images", "test_pre_" + name + ".png")
        post_img_path = os.path.join(self.data_path, "images", "test_post_" + name + ".png")

        image_pre = cv2.imread(pre_img_path, cv2.IMREAD_COLOR)[:, :, ::-1]
        image_post = cv2.imread(post_img_path, cv2.IMREAD_COLOR)[:, :, ::-1]
        image = np.concatenate([image_pre, image_post], axis=-1)
        sample = {}
        sample['img_name'] = name
        sample['image'] = img_to_tensor(np.ascontiguousarray(image), self.normalize)
        return sample


In [7]:
#This script will generate pixel masks from json files.


def generate_localization_polygon(json_path, out_dir):
    with open(json_path, "r") as f:
        annotations = json.load(f)
    h = annotations["metadata"]["height"]
    w = annotations["metadata"]["width"]
    mask_img = np.zeros((h, w), np.uint8)
    out_filename = os.path.splitext(os.path.basename(json_path))[0] + ".png"
    for feat in annotations['features']['xy']:
        feat_shape = wkt.loads(feat['wkt'])
        coords = list(mapping(feat_shape)['coordinates'][0])
        fillPoly(mask_img, [np.array(coords, np.int32)], (255))
    cv2.imwrite(os.path.join(out_dir, out_filename), mask_img)


def generate_damage_polygon(json_path, out_dir):
    with open(json_path, "r") as f:
        annotations = json.load(f)

    h = annotations["metadata"]["height"]
    w = annotations["metadata"]["width"]
    mask_img = np.zeros((h, w), np.uint8)

    damage_dict = {
        "no-damage": 1,
        "minor-damage": 2,
        "major-damage": 3,
        "destroyed": 4,
        "un-classified": 255
    }
    out_filename = os.path.splitext(os.path.basename(json_path))[0] + ".png"
    for feat in annotations['features']['xy']:
        feat_shape = wkt.loads(feat['wkt'])
        coords = list(mapping(feat_shape)['coordinates'][0])
        fillPoly(mask_img, [np.array(coords, np.int32)], damage_dict[feat['properties']['subtype']])
    cv2.imwrite(os.path.join(out_dir, out_filename), mask_img)


if __name__ == '__main__':

    parser = argparse.ArgumentParser()
    parser.add_argument('--input',
                        default="/home/selim/datasets/xview/train/",
                        help='Path to parent dataset directory "xBD"')
    args = parser.parse_args()
    out_dir = os.path.join(args.input, "masks")
    in_dir = os.path.join(args.input, "labels")
    pre_images = [os.path.join(in_dir, f) for f in os.listdir(in_dir) if '_pre_' in f]
    post_images = [os.path.join(in_dir, f) for f in os.listdir(in_dir) if '_post_' in f]

    pool = Pool(cpu_count())
    with tqdm(total=len(pre_images)) as pbar:
        for i, v in enumerate(pool.imap_unordered(partial(generate_localization_polygon, out_dir=out_dir), pre_images)):
            pbar.update()
    with tqdm(total=len(post_images)) as pbar:
        for i, v in enumerate(pool.imap_unordered(partial(generate_damage_polygon, out_dir=out_dir), post_images)):
            pbar.update()


usage: ipykernel_launcher.py [-h] [--input INPUT]
ipykernel_launcher.py: error: unrecognized arguments: -f /home/jovyan/.local/share/jupyter/runtime/kernel-e29f9431-eb41-4cee-b92f-fb6658306756.json


SystemExit: 2

  warn("To exit: use 'exit', 'quit', or Ctrl-D.", stacklevel=1)


# Localisation Training


In [11]:
#Used to train binary segmentation models. By default O0 opt level (FP32) is used for Apex due to unstable loss during training.


torch.backends.cudnn.benchmark = True

def create_train_transforms(conf):
    height = conf['crop_height']
    width = conf['crop_width']
    return Compose([
        SafeRotate(45, p=0.4, border_mode=cv2.BORDER_CONSTANT),
        OneOf([
            RandomSizedCrop(min_max_height=(int(height * 0.7), int(height * 1.3)), w2h_ratio=1., height=height,
                            width=width, p=0.8),
            RandomCrop(height=height, width=width, p=0.2)], p=1),
        HorizontalFlip(),
        VerticalFlip(),
        RandomRotate90(),
        Transpose(),
        ImageCompression(p=0.1),
        Lighting(alphastd=0.3),
        RandomBrightnessContrast(p=0.4),
        RandomGamma(p=0.4),
        OneOf([RGBShift(), HueSaturationValue()], p=0.2)
    ], additional_targets={'image1': 'image'}
    )


def create_val_transforms(conf):
    return Compose([
    ], additional_targets={'image1': 'image'})


def main():
    parser = argparse.ArgumentParser("PyTorch Xview Pipeline")
    arg = parser.add_argument
    arg('--config', metavar='CONFIG_FILE', help='path to configuration file')
    arg('--workers', type=int, default=8, help='number of cpu threads to use')
    arg('--gpu', type=str, default='0', help='List of GPUs for parallel training, e.g. 0,1,2,3')
    arg('--output-dir', type=str, default='weights/')
    arg('--resume', type=str, default='')
    arg('--fold', type=int, default=0)
    arg('--prefix', type=str, default='localization_')
    arg('--data-dir', type=str, default="/home/selim/datasets/xview/train")
    arg('--folds-csv', type=str, default='folds.csv')
    arg('--logdir', type=str, default='logs')
    arg('--zero-score', action='store_true', default=False)
    arg('--from-zero', action='store_true', default=False)
    arg('--distributed', action='store_true', default=False)
    arg('--freeze-epochs', type=int, default=1)
    arg("--local_rank", default=0, type=int)
    arg("--opt-level", default='O0', type=str)
    arg("--predictions", default="../oof_preds", type=str)
    arg("--test_every", type=int, default=1)

    args = parser.parse_args()

    if args.distributed:
        torch.cuda.set_device(args.local_rank)
        torch.distributed.init_process_group(backend='nccl', init_method='env://')
    else:
        os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'
        os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu

    cudnn.benchmark = True

    conf = load_config(args.config)
    model = models.__dict__[conf['network']](seg_classes=conf['num_classes'], backbone_arch=conf['encoder'])

    model = model.cuda()
    if args.distributed:
        model = convert_syncbn_model(model)
    mask_loss_function = losses.__dict__[conf["mask_loss"]["type"]](**conf["mask_loss"]["params"]).cuda()
    loss_functions = {"mask_loss": mask_loss_function}
    optimizer, scheduler = create_optimizer(conf['optimizer'], model)

    dice_best = 0
    start_epoch = 0
    batch_size = conf['optimizer']['batch_size']

    data_train = XviewSingleDataset(mode="train",
                                    fold=args.fold,
                                    data_path=args.data_dir,
                                    folds_csv=args.folds_csv,
                                    transforms=create_train_transforms(conf['input']),
                                    multiplier=conf["data_multiplier"],
                                    normalize=conf["input"].get("normalize", None))
    data_val = XviewSingleDataset(mode="val",
                                  fold=args.fold,
                                  data_path=args.data_dir,
                                  folds_csv=args.folds_csv,
                                  transforms=create_val_transforms(conf['input']),
                                  normalize=conf["input"].get("normalize", None)
                                  )
    train_sampler = None
    if args.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(data_train)

    train_data_loader = DataLoader(data_train, batch_size=batch_size, num_workers=args.workers,
                                   shuffle=train_sampler is None, sampler=train_sampler, pin_memory=False,
                                   drop_last=True)
    val_batch_size = 1
    val_data_loader = DataLoader(data_val, batch_size=val_batch_size, num_workers=args.workers, shuffle=False,
                                 pin_memory=False)

    os.makedirs(args.logdir, exist_ok=True)
    summary_writer = SummaryWriter(args.logdir + '/' + args.prefix + conf['encoder'])
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume, map_location='cpu')
            state_dict = checkpoint['state_dict']
            if conf['optimizer'].get('zero_decoder', False):
                for key in state_dict.copy().keys():
                    if key.startswith("module.final"):
                        del state_dict[key]
            state_dict = {k[7:]: w for k, w in state_dict.items()}
            model.load_state_dict(state_dict, strict=False)
            if not args.from_zero:
                start_epoch = checkpoint['epoch']
                if not args.zero_score:
                    dice_best = checkpoint.get('dice_best', 0)
            print("=> loaded checkpoint '{}' (epoch {})"
                  .format(args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))
    if args.from_zero:
        start_epoch = 0
    current_epoch = start_epoch

    if conf['fp16']:
        model, optimizer = amp.initialize(model, optimizer,
                                          opt_level=args.opt_level,
                                          loss_scale='dynamic')

    snapshot_name = "{}{}_{}_{}".format(args.prefix, conf['network'], conf['encoder'], args.fold)

    if args.distributed:
        model = DistributedDataParallel(model, delay_allreduce=True)
    else:
        model = DataParallel(model).cuda()
    for epoch in range(start_epoch, conf['optimizer']['schedule']['epochs']):
        if train_sampler:
            train_sampler.set_epoch(epoch)
        if epoch < args.freeze_epochs:
            print("Freezing encoder!!!")
            model.module.encoder_stages.eval()
            for p in model.module.encoder_stages.parameters():
                p.requires_grad = False
        else:
            print("Unfreezing encoder!!!")
            model.module.encoder_stages.train()
            for p in model.module.encoder_stages.parameters():
                p.requires_grad = True
        train_epoch(current_epoch, loss_functions, model, optimizer, scheduler, train_data_loader, summary_writer, conf,
                    args.local_rank)

        model = model.eval()
        if args.local_rank == 0:
            torch.save({
                'epoch': current_epoch + 1,
                'state_dict': model.state_dict(),
                'dice_best': dice_best,
            }, args.output_dir + '/' + snapshot_name + "_last")
            if epoch % args.test_every == 0:
                preds_dir = os.path.join(args.predictions, snapshot_name)
                dice_best = evaluate_val(args, val_data_loader, dice_best, model,
                                                     snapshot_name=snapshot_name,
                                                     current_epoch=current_epoch,
                                                     optimizer=optimizer, summary_writer=summary_writer,
                                                     predictions_dir=preds_dir)
        current_epoch += 1


def evaluate_val(args, data_val, dice_best, model, snapshot_name, current_epoch, optimizer, summary_writer,
                 predictions_dir):
    print("Test phase")
    model = model.eval()
    dice = validate(model, data_loader=data_val, predictions_dir=predictions_dir)
    if args.local_rank == 0:
        summary_writer.add_scalar('val/dice', float(dice), global_step=current_epoch)
        if dice > dice_best:
            if args.output_dir is not None:
                torch.save({
                    'epoch': current_epoch + 1,
                    'state_dict': model.state_dict(),
                    'dice_best': dice,

                }, args.output_dir + snapshot_name + "_best_dice")
            dice_best = dice
        torch.save({
            'epoch': current_epoch + 1,
            'state_dict': model.state_dict(),
            'dice_best': dice_best,
        }, args.output_dir + snapshot_name + "_last")
        print("dice: {}, dice_best: {}".format(dice, dice_best))
    return dice_best


def validate(net, data_loader, predictions_dir):
    os.makedirs(predictions_dir, exist_ok=True)
    preds_dir = predictions_dir + "/predictions"
    os.makedirs(preds_dir, exist_ok=True)
    dices = []
    with torch.no_grad():
        for sample in tqdm(data_loader):
            imgs = sample["image"].cuda().float()[:, :3, :, :]
            mask = sample["mask"].cuda().float()

            output = net(imgs)
            binary_pred = torch.sigmoid(output)

            for i in range(output.shape[0]):
                d = dice_round(binary_pred, mask[:, 4:, ...], t=0.5).item()
                dices.append(d)
                cv2.imwrite(os.path.join(preds_dir, "test_localization_" + sample["img_name"][i] + "_prediction.png"),
                            (binary_pred[i, 0].cpu().numpy() > 0.5) * 1)
    return np.mean(dices)


def train_epoch(current_epoch, loss_functions, model, optimizer, scheduler, train_data_loader, summary_writer, conf,
                local_rank):
    losses = AverageMeter()
    dices = AverageMeter()
    iterator = tqdm(train_data_loader)
    model.train()
    if conf["optimizer"]["schedule"]["mode"] == "epoch":
        scheduler.step(current_epoch)
    for i, sample in enumerate(iterator):
        imgs = sample["image"].cuda()[:, :3, :, :]
        masks = sample["mask"].cuda().float()
        out_mask = model(imgs)
        mask_band = 4
        with torch.no_grad():
            pred = torch.sigmoid(out_mask)
            d = dice_round(pred, masks[:, mask_band:, ...], t=0.5).item()
        dices.update(d, imgs.size(0))

        mask_loss = loss_functions["mask_loss"](out_mask, masks[:, mask_band:, ...].contiguous())
        loss = mask_loss
        losses.update(loss.item(), imgs.size(0))
        iterator.set_description(
            "epoch: {}; lr {:.7f}; Loss ({loss.avg:.4f}); dice ({dice.avg:.4f}); ".format(
                current_epoch, scheduler.get_lr()[-1], loss=losses, dice=dices))
        optimizer.zero_grad()
        if conf['fp16']:
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            loss.backward()
        torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), 1)
        optimizer.step()
        torch.cuda.synchronize()

        if conf["optimizer"]["schedule"]["mode"] in ("step", "poly"):
            scheduler.step(i + current_epoch * len(train_data_loader))

    if local_rank == 0:
        for idx, param_group in enumerate(optimizer.param_groups):
            lr = param_group['lr']
            summary_writer.add_scalar('group{}/lr'.format(idx), float(lr), global_step=current_epoch)
        summary_writer.add_scalar('train/loss', float(losses.avg), global_step=current_epoch)


if __name__ == '__main__':
    main()


usage: PyTorch Xview Pipeline [-h] [--config CONFIG_FILE] [--workers WORKERS]
                              [--gpu GPU] [--output-dir OUTPUT_DIR]
                              [--resume RESUME] [--fold FOLD]
                              [--prefix PREFIX] [--data-dir DATA_DIR]
                              [--folds-csv FOLDS_CSV] [--logdir LOGDIR]
                              [--zero-score] [--from-zero] [--distributed]
                              [--freeze-epochs FREEZE_EPOCHS]
                              [--local_rank LOCAL_RANK]
                              [--opt-level OPT_LEVEL]
                              [--predictions PREDICTIONS]
                              [--test_every TEST_EVERY]
PyTorch Xview Pipeline: error: unrecognized arguments: -f /home/jovyan/.local/share/jupyter/runtime/kernel-e29f9431-eb41-4cee-b92f-fb6658306756.json


SystemExit: 2

# Train

In [None]:
torch.backends.cudnn.benchmark = True


def create_train_transforms(conf):
    height = conf['crop_height']
    width = conf['crop_width']
    return Compose([
        RandomSizedCropAroundBbox(min_max_height=(int(height * 0.8), int(height * 1.2)), w2h_ratio=1., height=height,
                               width=width, p=1),
        HorizontalFlip(),
        VerticalFlip(),
        RandomRotate90(),
        Transpose(),
        Lighting(alphastd=0.3),
        RandomBrightnessContrast(p=0.2),
        RandomGamma(p=0.2),
        RGBShift(p=0.2)
    ], additional_targets={'image1': 'image'}
    )


def create_val_transforms(conf):
    return Compose([
    ], additional_targets={'image1': 'image'})


def main():
    parser = argparse.ArgumentParser("PyTorch Xview Pipeline")
    arg = parser.add_argument
    arg('--config', metavar='CONFIG_FILE', help='path to configuration file')
    arg('--workers', type=int, default=6, help='number of cpu threads to use')
    arg('--gpu', type=str, default='0', help='List of GPUs for parallel training, e.g. 0,1,2,3')
    arg('--output-dir', type=str, default='weights/')
    arg('--resume', type=str, default='')
    arg('--fold', type=int, default=0)
    arg('--prefix', type=str, default='damage_')
    arg('--data-dir', type=str, default="/home/selim/datasets/xview/train")
    arg('--folds-csv', type=str, default='folds.csv')
    arg('--logdir', type=str, default='logs')
    arg('--zero-score', action='store_true', default=False)
    arg('--from-zero', action='store_true', default=False)
    arg('--distributed', action='store_true', default=False)
    arg('--freeze-epochs', type=int, default=1)
    arg("--local_rank", default=0, type=int)
    arg("--opt-level", default='O1', type=str)
    arg("--predictions", default="../oof_preds", type=str)
    arg("--test_every", type=int, default=1)

    args = parser.parse_args()

    if args.distributed:
        torch.cuda.set_device(args.local_rank)
        torch.distributed.init_process_group(backend='nccl', init_method='env://')
    else:
        os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'
        os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu

    cudnn.benchmark = True

    conf = load_config(args.config)
    model = models.__dict__[conf['network']](seg_classes=conf['num_classes'], backbone_arch=conf['encoder'])

    model = model.cuda()
    if args.distributed:
        model = convert_syncbn_model(model)
    damage_loss_function = losses.__dict__[conf["damage_loss"]["type"]](**conf["damage_loss"]["params"]).cuda()
    mask_loss_function = losses.__dict__[conf["mask_loss"]["type"]](**conf["mask_loss"]["params"]).cuda()
    loss_functions = {"damage_loss": damage_loss_function, "mask_loss": mask_loss_function}
    optimizer, scheduler = create_optimizer(conf['optimizer'], model)

    dice_best = 0
    xview_best = 0
    start_epoch = 0
    batch_size = conf['optimizer']['batch_size']

    data_train = XviewSingleDataset(mode="train",
                                    fold=args.fold,
                                    data_path=args.data_dir,
                                    folds_csv=args.folds_csv,
                                    transforms=create_train_transforms(conf['input']),
                                    multiplier=conf["data_multiplier"],
                                    normalize=conf["input"].get("normalize", None))
    data_val = XviewSingleDataset(mode="val",
                                  fold=args.fold,
                                  data_path=args.data_dir,
                                  folds_csv=args.folds_csv,
                                  transforms=create_val_transforms(conf['input']),
                                  normalize=conf["input"].get("normalize", None)
                                  )
    train_sampler = None
    if args.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(data_train)

    train_data_loader = DataLoader(data_train, batch_size=batch_size, num_workers=args.workers,
                                   shuffle=train_sampler is None, sampler=train_sampler, pin_memory=False,
                                   drop_last=True)
    val_batch_size = 1
    val_data_loader = DataLoader(data_val, batch_size=val_batch_size, num_workers=args.workers, shuffle=False,
                                 pin_memory=False)

    os.makedirs(args.logdir, exist_ok=True)
    summary_writer = SummaryWriter(args.logdir + '/' + args.prefix + conf['encoder'])
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume, map_location='cpu')
            state_dict = checkpoint['state_dict']
            if conf['optimizer'].get('zero_decoder', False):
                for key in state_dict.copy().keys():
                    if key.startswith("module.final"):
                        del state_dict[key]
            state_dict = {k[7:]: w for k, w in state_dict.items()}
            model.load_state_dict(state_dict, strict=False)
            if not args.from_zero:
                start_epoch = checkpoint['epoch']
                if not args.zero_score:
                    dice_best = checkpoint.get('dice_best', 0)
                    xview_best = checkpoint.get('xview_best', 0)
            print("=> loaded checkpoint '{}' (epoch {})"
                  .format(args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))
    if args.from_zero:
        start_epoch = 0
    current_epoch = start_epoch

    if conf['fp16']:
        model, optimizer = amp.initialize(model, optimizer,
                                          opt_level=args.opt_level,
                                          loss_scale='dynamic')

    snapshot_name = "{}{}_{}_{}".format(args.prefix, conf['network'], conf['encoder'], args.fold)

    if args.distributed:
        model = DistributedDataParallel(model, delay_allreduce=True)
    else:
        model = DataParallel(model).cuda()
    for epoch in range(start_epoch, conf['optimizer']['schedule']['epochs']):
        if epoch < args.freeze_epochs:
            print("Freezing encoder!!!")
            if hasattr(model.module, 'encoder_stages1'):
                model.module.encoder_stages1.eval()
                model.module.encoder_stages2.eval()
                for p in model.module.encoder_stages1.parameters():
                    p.requires_grad = False
                for p in model.module.encoder_stages2.parameters():
                    p.requires_grad = False
            else:
                model.module.encoder_stages.eval()
                for p in model.module.encoder_stages.parameters():
                    p.requires_grad = False
        else:
            if hasattr(model.module, 'encoder_stages1'):
                print("Unfreezing encoder!!!")
                model.module.encoder_stages1.train()
                model.module.encoder_stages2.train()
                for p in model.module.encoder_stages1.parameters():
                    p.requires_grad = True
                for p in model.module.encoder_stages2.parameters():
                    p.requires_grad = True
            else:
                model.module.encoder_stages.train()
                for p in model.module.encoder_stages.parameters():
                    p.requires_grad = True
        train_epoch(current_epoch, loss_functions, model, optimizer, scheduler, train_data_loader, summary_writer, conf,
                    args.local_rank)

        model = model.eval()
        if args.local_rank == 0:
            torch.save({
                'epoch': current_epoch + 1,
                'state_dict': model.state_dict(),
                'dice_best': dice_best,
                'xview_best': xview_best,
            }, args.output_dir + '/' + snapshot_name + "_last")
            if epoch % args.test_every == 0:
                preds_dir = os.path.join(args.predictions, snapshot_name)
                dice_best, xview_best = evaluate_val(args, val_data_loader, xview_best, dice_best, model,
                                                     snapshot_name=snapshot_name,
                                                     current_epoch=current_epoch,
                                                     optimizer=optimizer, summary_writer=summary_writer,
                                                     predictions_dir=preds_dir)
        current_epoch += 1


def evaluate_val(args, data_val, xview_best, dice_best, model, snapshot_name, current_epoch, optimizer, summary_writer,
                 predictions_dir):
    print("Test phase")
    model = model.eval()
    dice, xview_score = validate(model, data_loader=data_val, predictions_dir=predictions_dir)
    if args.local_rank == 0:
        summary_writer.add_scalar('val/dice', float(dice), global_step=current_epoch)
        summary_writer.add_scalar('val/xview_score', float(xview_score), global_step=current_epoch)
        if dice > dice_best:
            if args.output_dir is not None:
                torch.save({
                    'epoch': current_epoch + 1,
                    'state_dict': model.state_dict(),
                    'dice_best': dice,
                    'xview_best': xview_score,

                }, args.output_dir + snapshot_name + "_best_dice")
            dice_best = dice
        if xview_score > xview_best:
            if args.output_dir is not None:
                torch.save({
                    'epoch': current_epoch + 1,
                    'state_dict': model.state_dict(),
                    'dice_best': dice,
                    'xview_best': xview_score,
                }, args.output_dir + snapshot_name + "_best_xview")
            xview_best = xview_score
        torch.save({
            'epoch': current_epoch + 1,
            'state_dict': model.state_dict(),
            'dice_best': dice_best,
            'xview_best': xview_best,
        }, args.output_dir + snapshot_name + "_last")
        print("dice: {}, dice_best: {}".format(dice, dice_best))
        print("xview: {}, xview_best: {}".format(xview_score, xview_best))
    return dice_best, xview_best


def validate(net, data_loader, predictions_dir):
    os.makedirs(predictions_dir, exist_ok=True)
    preds_dir = predictions_dir + "/predictions"
    os.makedirs(preds_dir, exist_ok=True)
    targs_dir = predictions_dir + "/targets"
    os.makedirs(targs_dir, exist_ok=True)
    with torch.no_grad():
        for sample in tqdm(data_loader):
            imgs = sample["image"].cuda().float()
            mask = sample["mask"].cuda().float()
            original_mask = sample["original_mask"].cuda().long().cpu().numpy()

            output = net(imgs)
            binary_pred = torch.sigmoid(output[:, 4:, ...])

            damage_preds = torch.sigmoid(output[:, :4, ...]).cpu().numpy()
            for i in range(output.shape[0]):
                damage_pred = damage_preds[i]
                first = np.zeros((1, 1024, 1024))
                first[:, :, :] = 0.1
                damage_pred = np.concatenate([first, damage_pred], axis=0)
                cv2.imwrite(os.path.join(preds_dir,
                                         "test_localization_" + sample["img_name"][i] + "_prediction.png"),
                            (binary_pred[i, 0].cpu().numpy() > 0.3) * 1)
                cv2.imwrite(os.path.join(preds_dir,
                                         "test_damage_" + sample["img_name"][i] + "_prediction.png"),
                            np.argmax(damage_pred, axis=0))
                cv2.imwrite(os.path.join(targs_dir,
                                         "test_localization_" + sample["img_name"][i] + "_target.png"),
                            mask.cpu().numpy()[i, 4])
                cv2.imwrite(
                    os.path.join(targs_dir, "test_damage_" + sample["img_name"][i] + "_target.png"),
                    original_mask[i])
    d = XviewMetrics.compute_score(preds_dir, targs_dir, "out.json")
    for k, v in d.items():
        print("{}:{}".format(k, v))
    return d["localization_f1"], d["score"]


def train_epoch(current_epoch, loss_functions, model, optimizer, scheduler, train_data_loader, summary_writer, conf,
                local_rank):
    losses = AverageMeter()
    damage_f1 = AverageMeter()
    localization_f1 = AverageMeter()
    iterator = tqdm(train_data_loader)
    model.train()
    if conf["optimizer"]["schedule"]["mode"] == "epoch":
        scheduler.step(current_epoch)
    for i, sample in enumerate(iterator):
        imgs = sample["image"].cuda()
        masks = sample["mask"].cuda().float()
        out_mask = model(imgs)
        mask_band = 4
        with torch.no_grad():
            pred = torch.sigmoid(out_mask[:, :, ...])
            d = dice_round(pred[:, mask_band:, ...], masks[:, mask_band:, ...], t=0.5).item()
            loc_f1 = 0
            for i in range(4):
                loc_f1 += 1/(dice_round(pred[:, i:i+1, ...], masks[:, i:i+1, ...], t=0.3).item() + 1e-3)
            loc_f1 = 4/loc_f1
        localization_f1.update(d, imgs.size(0))
        damage_f1.update(loc_f1, imgs.size(0))

        mask_loss = loss_functions["mask_loss"](out_mask[:, mask_band:, ...].contiguous(),
                                                masks[:, mask_band:, ...].contiguous())
        damage_loss = loss_functions["damage_loss"](out_mask[:, :mask_band, ...].contiguous(),
                                                    masks[:, :mask_band, ...].contiguous())
        loss = 0.7 * damage_loss + 0.3 * mask_loss
        losses.update(loss.item(), imgs.size(0))
        iterator.set_description(
            "epoch: {}; lr {:.7f}; Loss ({loss.avg:.4f}); Localization F1 ({dice.avg:.4f}); Damage F1 ({damage.avg:.4f}); ".format(
                current_epoch, scheduler.get_lr()[-1], loss=losses, dice=localization_f1, damage=damage_f1))
        optimizer.zero_grad()
        if conf['fp16']:
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            loss.backward()
        torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), 1)
        optimizer.step()
        torch.cuda.synchronize()
        if conf["optimizer"]["schedule"]["mode"] in ("step", "poly"):
            scheduler.step(i + current_epoch * len(train_data_loader))

    if local_rank == 0:
        for idx, param_group in enumerate(optimizer.param_groups):
            lr = param_group['lr']
            summary_writer.add_scalar('group{}/lr'.format(idx), float(lr), global_step=current_epoch)
        summary_writer.add_scalar('train/loss', float(losses.avg), global_step=current_epoch)


if __name__ == '__main__':
    main()
