In [None]:
import cv2
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from torch.optim import Adam
from torch.utils.data import DataLoader
import torchvision.transforms as T

import os
from time import time
import random
import shutil
from datetime import datetime
from collections import defaultdict

from models.inceptionresnetv2 import pretrained_settings
from models.unet import unet_inceptionresnetv2, UNetUp2

pjoin = join = os.path.join

In [None]:
# Dataset location. Set the AERIAL_TRAIN_ROOT environment variable to run the
# notebook on another machine; the fallback is the path the notebook was
# originally written against.
TRAIN_ROOT = os.environ.get(
    'AERIAL_TRAIN_ROOT',
    '/home/kaliev/Downloads/AerialImageDataset_crop/train/',
)
# os.path.join keeps the sub-paths valid whether or not TRAIN_ROOT ends with '/'.
IMAGES_ROOT = os.path.join(TRAIN_ROOT, 'images')
GT_ROOT = os.path.join(TRAIN_ROOT, 'gt')

# Parent directory of the per-experiment output folders.
OUTPUT_DIR = './output/'

# Side length of the crops produced by the training Dataset.
IMAGE_SZ = 255
# Side length that `iou` crops padded predictions/targets back to.
ORIG_IMAGE_SZ = 1250

In [None]:
# Turn off OpenCV's own thread pool.
# NOTE(review): presumably to avoid thread oversubscription / hangs inside the
# multi-process DataLoader workers created further down -- confirm.
cv2.setNumThreads(0)

In [None]:
# Seed every RNG used in this notebook (Python, NumPy, PyTorch CPU and CUDA)
# with the same fixed value.
random.seed(2018)
np.random.seed(2018)
torch.manual_seed(2018)
torch.cuda.manual_seed_all(2018)

In [None]:
class Progress:
    """Minimal text progress reporter.

    A line ``"<desc> <percent>% <postfix>"`` is printed each time roughly
    ``print_freq`` (a fraction of ``total``) of additional work has been
    reported through :meth:`update`, and when ``total`` is reached.

    Args:
        total: amount of work that corresponds to 100%.
        desc: text printed at the start of every line.
        print_freq: fraction of ``total`` between two printed lines.
    """

    def __init__(self, total, desc, print_freq=0.1):
        self.total = total
        self.desc = desc
        self.prog = 0          # work reported so far
        self.last_print = 0    # work reported since the last printed line
        self.postfix = ''
        self.print_freq = print_freq

    def update(self, step):
        """Report ``step`` more units of work and print a line when due."""
        previous = self.prog
        self.prog += step
        # True exactly once: on the update that reaches (or jumps past) total.
        finished = previous < self.total <= self.prog

        crossed = (self.last_print / self.total < self.print_freq
                   and (self.last_print + step) / self.total >= self.print_freq)
        if crossed:
            self.last_print = 0
        else:
            self.last_print += step

        # A single print even when the final step also crosses the threshold
        # (previously the 100% line was emitted twice in that case).
        if finished or crossed:
            self._print()

    def _print(self):
        prog = self.prog / self.total * 100
        line = f'{self.desc} {prog:.2f}% {self.postfix}'
        print(line)

    def set_postfix(self, postfix):
        """Set the text appended to every subsequently printed line."""
        self.postfix = str(postfix)
        

class Logger:
    """Append timestamped messages to a text file, optionally echoing them.

    The file at ``log_filepath`` is opened in ``'w'`` mode (truncated) when the
    logger is created. Call :meth:`close` (or use the logger as a context
    manager) to release the file handle deterministically.
    """

    def __init__(self, log_filepath):
        self.log = open(log_filepath, 'w')

    def msg(self, msg, output_to_console=False, tq=None):
        """Write ``msg`` to the log file prefixed with an ISO timestamp.

        Args:
            msg: text to log.
            output_to_console: also show the message on the console.
            tq: optional object with a ``write(str)`` method that is used
                instead of ``print`` for the console echo.
        """
        self.log.write('%s %s\n' % (datetime.now().isoformat(), msg))
        # Flush immediately so the log survives a kernel crash or interrupt.
        self.log.flush()
        if output_to_console:
            if tq is not None:
                tq.write(msg)
            else:
                print(msg)

    def close(self):
        """Close the log file; safe to call more than once."""
        log = getattr(self, 'log', None)
        if log is not None and not log.closed:
            log.close()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()
        return False

    def __del__(self):
        # getattr-based close() also covers the case where open() failed in
        # __init__ and self.log was never set.
        self.close()
        
        
def ocv_loader(fpath):
    """Load an image with OpenCV and return it with the channel axis reversed.

    OpenCV decodes colour images in BGR order, so reversing the last axis
    yields RGB. Returns ``None`` when ``cv2.imread`` could not read the file;
    otherwise a fresh copy of the array.
    """
    bgr = cv2.imread(fpath)
    if bgr is None:
        return None
    rgb = bgr[:, :, ::-1]
    return rgb.copy()
        
        
# Input preprocessing applied to every image before it reaches the network:
# T.ToTensor() followed by channel-wise normalization with the mean/std stored
# in the encoder's `pretrained_settings`.
normalizer = T.Compose([
    T.ToTensor(),
    T.Normalize(mean=pretrained_settings['mean'], std=pretrained_settings['std']),
])


class Denormalize:
    """Invert a per-channel mean/std normalization.

    The input tensor is permuted with ``(1, 2, 0)`` (first axis moved last)
    and then scaled by ``std`` and shifted by ``mean``.
    """

    def __init__(self, mean, std):
        self.std = torch.FloatTensor(std)
        self.mean = torch.FloatTensor(mean)

    def __call__(self, img):
        # Channels go last so the per-channel statistics broadcast over them.
        channels_last = img.permute(1, 2, 0)
        rescaled = channels_last * self.std
        return rescaled + self.mean


# Inverse of `normalizer`'s mean/std step, built from the same `pretrained_settings`.
denormalizer = Denormalize(pretrained_settings['mean'], pretrained_settings['std'])


def mask_overlay(image, mask, ch=0):
    """Blend a mask into the first (red) channel of an image.

    Args:
        image: either a path to an image file, or a normalized tensor that
            `denormalizer` can convert back to a channels-last image.
        mask: NumPy array or tensor; when it has more than two dimensions,
            plane ``ch`` of the first axis is used.
        ch: index of the mask plane to overlay.

    Returns:
        uint8 array of shape (H, W, 3) in RGB order where ``255 * mask`` has
        been added (and clipped to [0, 255]) to channel 0.

    Raises:
        IOError: if ``image`` is a path that OpenCV cannot read.
    """
    if not isinstance(mask, np.ndarray):
        mask = mask.numpy()
    if mask.ndim > 2:
        mask = mask[ch, :, :]
    # Integer/bool masks would wrap around in uint8 arithmetic before np.clip
    # gets a chance to saturate, so do the blend in floating point.
    if not np.issubdtype(mask.dtype, np.floating):
        mask = mask.astype(np.float32)

    if isinstance(image, str):
        im = cv2.imread(image)
        if im is None:
            raise IOError(f'Cannot read image: {image}')
        im = cv2.resize(im, (mask.shape[1], mask.shape[0]))
        # cv2.imread returns BGR; flip to RGB so channel 0 is red, exactly as
        # in the tensor branch below.
        im = im[:, :, ::-1]
    else:
        im = denormalizer(image).numpy()
        # Clip before the cast: values slightly outside [0, 1] would wrap.
        im = np.clip(im * 255, 0, 255).astype(np.uint8)

    mask_ch = np.clip((255 * mask) + im[..., 0], 0, 255).astype(np.uint8)
    return np.dstack((mask_ch, im[..., 1], im[..., 2]))


def variable(x, volatile=False):
    """Wrap ``x`` in an autograd ``Variable`` and move it to the GPU.

    ``volatile`` is forwarded unchanged to the ``Variable`` constructor.
    """
    wrapped = Variable(x, volatile=volatile)
    return wrapped.cuda()


# func to show image
def imshow(im,figsz=(12,12),**kwargs):
    """Show ``im`` in a new matplotlib figure with the axes hidden.

    Args:
        im: image passed to ``plt.imshow``.
        figsz: figure size passed to ``plt.figure``.
        **kwargs: forwarded unchanged to ``plt.imshow`` (e.g. ``cmap``, ``vmin``).
    """
    plt.figure(figsize=figsz)
    plt.imshow(im,**kwargs)
    plt.axis('off')

In [None]:
class RandomCrop(object):
    """Cut a random ``(width, height)`` window out of an image (and its mask).

    Args:
        size: an int for a square crop, or a ``(width, height)`` pair.
    """

    def __init__(self, size):
        if isinstance(size, int):
            self.size = (int(size), int(size))
        else:
            self.size = size

    def __call__(self, img, mask=None):
        """Return ``(img_crop, mask_crop)``; ``mask_crop`` is None if no mask.

        Each axis is handled independently: an axis that is larger than the
        target gets a uniformly random offset, an axis that is not larger is
        cropped from offset 0.
        """
        h, w = img.shape[:2]
        tw, th = self.size
        if w == tw and h == th:
            return img, mask

        # np.random.randint excludes its upper bound, hence the +1: without it
        # the last valid offset could never be drawn.
        x1 = np.random.randint(0, w - tw + 1) if w > tw else 0
        y1 = np.random.randint(0, h - th + 1) if h > th else 0

        img = img[y1:y1 + th, x1:x1 + tw]
        if mask is not None:
            mask = mask[y1:y1 + th, x1:x1 + tw]
        return img, mask
    

class CenterCrop:
    """Cut the central ``(width, height)`` window out of an image (and mask).

    Args:
        size: a scalar for a square crop, or a ``(width, height)`` sequence.
    """

    def __init__(self, size):
        if isinstance(size, (list, tuple, np.ndarray)):
            self.width, self.height = size
        else:
            self.width = self.height = size

    def __call__(self, img, mask=None):
        """Return ``(img_crop, mask_crop)``; ``mask_crop`` is None if no mask."""
        # shape[:2] instead of a 3-way unpack so 2-D (single-channel) arrays
        # work too, matching RandomCrop.
        h, w = img.shape[:2]
        # Clamp at 0: a negative offset would be read by NumPy as an index
        # from the end and return an arbitrary sliver of the image.
        x1 = max((w - self.width) // 2, 0)
        y1 = max((h - self.height) // 2, 0)
        x2 = x1 + self.width
        y2 = y1 + self.height

        img = img[y1:y2, x1:x2]
        if mask is not None:
            mask = mask[y1:y2, x1:x2]
        return img, mask
    
    
# ### IMGAUG ###
import imgaug as ia
import imgaug.augmenters as iaa
    
    
# Wrap an augmenter so that it is applied with probability 0.5.
sometimes = lambda aug: iaa.Sometimes(0.5, aug)

# GEOMETRY
# Light geometric pipeline: horizontal flip and an occasional rotation.
ia_aug_geom_light = iaa.Sequential([
    iaa.Fliplr(0.5),
    sometimes(iaa.Affine(
        rotate=(-45, 45),
        order=1
    ))
], random_order=True)

# Full geometric pipeline (the one `augment` uses): flips, an occasional
# affine transform (scale / rotate / shear) and an occasional local warp.
ia_aug_geom = iaa.Sequential([
    iaa.Fliplr(0.5),
    iaa.Flipud(0.5),
    sometimes(iaa.Affine(
        scale={"x": (0.8, 1.2), "y": (0.8, 1.2)},
        rotate=(-45, 45),
        shear=(-10, 10),
        order=1
    )),
    sometimes(iaa.OneOf([
        iaa.PiecewiseAffine(scale=(0.01, 0.02)),
        iaa.PerspectiveTransform(scale=(0.01, 0.1))
    ])),
], random_order=True)

# COLOR
# Light colour pipeline: brightness shift and contrast change.
ia_aug_color_light = iaa.Sequential([
    iaa.Add((-10, 10)),
    iaa.ContrastNormalization((0.8, 1.2))
], random_order=True)

# Full colour pipeline (the one `augment` uses): brightness, contrast, and
# exactly one of noise / blur / sharpen.
ia_aug_color = iaa.Sequential([
    iaa.Add((-10, 10)),
    iaa.ContrastNormalization((0.8, 1.2)),
    iaa.OneOf([
        iaa.AdditiveGaussianNoise(loc=0, scale=(0.0, 0.05*255)),
        iaa.OneOf([
            iaa.GaussianBlur((0, 1.0)),
            iaa.AverageBlur(k=(1, 3)),
            iaa.MedianBlur(k=(1, 5)),
        ]),
        iaa.Sharpen(alpha=(0, 0.5), lightness=(0.75, 1.5)),
    ])
], random_order=True)


def augment(img, mask):
    """Augment an image / mask pair.

    The geometric pipeline is frozen with ``to_deterministic()`` so the image
    and the mask receive the same transform; the colour pipeline is applied
    to the image only. Copies of both results are returned.
    """
    frozen_geom = ia_aug_geom.to_deterministic()
    img_aug = ia_aug_color.augment_image(frozen_geom.augment_image(img))
    mask_aug = frozen_geom.augment_image(mask)
    return img_aug.copy(), mask_aug.copy()

In [None]:
def train_val_split(images_root=None, first_train_tile=6):
    """Split the image files into training and validation lists by tile number.

    File names are expected to look like ``<city><tile>-<part>.<ext>``
    (e.g. ``austin12-3.tif``). The tile number is parsed from the first digit
    of the ``<city><tile>`` prefix onwards.

    Args:
        images_root: directory to list; defaults to the notebook-level
            ``IMAGES_ROOT``.
        first_train_tile: tiles numbered below this value go to validation,
            the rest to training.

    Returns:
        ``(train_fnames, val_fnames)``, each sorted by file name so the split
        (and slices such as ``val_set[:5]``) are reproducible across runs.
        Entries whose name carries no parsable tile number (stray files,
        checkpoint folders, ...) are skipped.
    """
    if images_root is None:
        images_root = IMAGES_ROOT

    def get_city_n(s):
        # Integer formed by everything from the first digit on, or None when
        # there is no digit or the tail is not a plain number.
        for i, c in enumerate(s):
            if c.isdigit():
                try:
                    return int(s[i:])
                except ValueError:
                    return None
        return None

    train_fnames, val_fnames = [], []
    # os.listdir returns entries in arbitrary order; sort for determinism.
    for fname in sorted(os.listdir(images_root)):
        stem = os.path.splitext(fname)[0]
        city_n = '-'.join(stem.split('-')[:-1])
        tile = get_city_n(city_n)
        if tile is None:
            continue
        if tile < first_train_tile:
            val_fnames.append(fname)
        else:
            train_fnames.append(fname)
    return train_fnames, val_fnames


def prepare_data(img, mask):
    """Convert an image / mask pair of arrays into tensors.

    The image goes through `normalizer`; the mask is divided by 255, given a
    leading axis and cast to float.
    """
    img_t = normalizer(img)
    scaled_mask = torch.from_numpy(mask / 255)
    mask_t = scaled_mask.unsqueeze(0).float()
    return img_t, mask_t


class Dataset:
    """Map-style dataset yielding augmented ``(image, mask, image_path)`` crops.

    Args:
        image_sz: side length of the final square crop.
        image_list: file names present in both ``IMAGES_ROOT`` and ``GT_ROOT``.
        epoch_mul: how many times the file list is repeated per epoch.
    """

    def __init__(self, image_sz, image_list, epoch_mul=1):
        self.image_sz = image_sz
        self.image_list = image_list
        self.epoch_mul = epoch_mul

    def __getitem__(self, idx):
        # Indices beyond the list wrap around (used by epoch_mul > 1).
        idx %= len(self.image_list)
        image_path = pjoin(IMAGES_ROOT, self.image_list[idx])
        gt_path = pjoin(GT_ROOT, self.image_list[idx])
        img = ocv_loader(image_path)
        gt = ocv_loader(gt_path)
        # Check both loads before touching the arrays: indexing a None ground
        # truth would raise an unrelated TypeError before the check is reached.
        assert img is not None, f'failed to read image: {image_path}'
        assert gt is not None, f'failed to read ground truth: {gt_path}'
        gt = gt[:, :, 0]

        # Crop 1.5x larger first, augment, then take the centre.
        # NOTE(review): presumably so borders introduced by the geometric
        # augmentation fall outside the final crop -- confirm.
        pre_aug_sz = int(self.image_sz * 1.5)
        img, gt = RandomCrop(pre_aug_sz)(img, gt)
        img, gt = augment(img, gt)
        img, gt = CenterCrop(self.image_sz)(img, gt)
        img, gt = prepare_data(img, gt)

        return img, gt, image_path

    def __len__(self):
        return len(self.image_list) * self.epoch_mul

In [None]:
def fix_shape(img):
    """Pad an image on the bottom/right so each side is 31 modulo 32.

    The padding uses NumPy's ``'symmetric'`` mode and leaves a trailing
    channel axis (if any) untouched. Height and width are padded
    independently, so non-square inputs are handled as well; for square
    inputs the result is the same as padding both by the height-based amount.

    NOTE(review): the 32k-1 target presumably matches what the network
    accepts (IMAGE_SZ = 255 has the same form) -- confirm.
    """
    pad_h = 31 - img.shape[0] % 32
    pad_w = 31 - img.shape[1] % 32
    pads = [[0, pad_h], [0, pad_w]]
    if img.ndim == 3:
        pads += [[0, 0]]
    return np.pad(img, pads, 'symmetric')


def one_batch(img, gt):
    """Turn one image / mask pair into a batch of size 1 on the GPU.

    The image is wrapped with ``volatile=True``; the mask with the default.
    """
    img_t, gt_t = prepare_data(img, gt)
    img_batch = img_t.unsqueeze(0)
    gt_batch = gt_t.unsqueeze(0)
    return variable(img_batch, volatile=True), variable(gt_batch)


def calc_loss(model, criterion, file_names, return_list=False):
    """Evaluate ``criterion`` on whole (padded) images, one file at a time.

    Args:
        model: network; switched to eval mode by this function (the caller is
            responsible for switching back to train mode afterwards).
        criterion: callable ``criterion(output, target)`` returning a loss.
        file_names: names present in both ``IMAGES_ROOT`` and ``GT_ROOT``.
        return_list: return the per-file losses instead of their mean
            (same convention as ``calc_iou``).

    Returns:
        The list of per-file losses if ``return_list`` is true, otherwise
        their mean (NaN when ``file_names`` is empty).
    """
    model.eval()
    losses = []
    for fname in file_names:
        image_path = pjoin(IMAGES_ROOT, fname)
        gt_path = pjoin(GT_ROOT, fname)
        img = ocv_loader(image_path)
        gt = ocv_loader(gt_path)[:, :, 0]
        img, gt = fix_shape(img), fix_shape(gt)
        img, gt = one_batch(img, gt)
        output = model(img)
        losses.append(criterion(output, gt).data[0])
    if return_list:
        return losses
    if not losses:
        # Same value np.mean([]) would give, without the RuntimeWarning.
        return float('nan')
    return np.mean(losses)


def train(logger,
          model,
          output_dirpath,
          init_optimizer,
          lr,
          num_iters,
          criterion,
          train_dataloader,
          val_set,
          log_every=20,
          save_every=5000,
          val_every=10000,
          lr_decay_every=20000):
    """Iteration-based training loop with periodic checkpointing/validation.

    Args:
        logger: object with ``msg(text, output_to_console=...)``.
        model: network to train; its inputs are moved to the GPU by `variable`.
        output_dirpath: directory receiving ``model.pt`` / ``model_best.pt``
            (created if missing).
        init_optimizer: factory ``init_optimizer(parameters, lr=...)``.
        lr: initial learning rate; halved every ``lr_decay_every`` iterations.
        num_iters: total number of optimizer steps.
        criterion: callable ``criterion(outputs, targets)`` returning the loss.
        train_dataloader: iterable yielding ``(inputs, targets, _)`` batches.
        val_set: file names handed to `calc_loss` for validation.
        log_every: log the training loss every this many iterations.
        save_every: write ``model.pt`` every this many iterations.
        val_every: run validation every this many iterations.
        lr_decay_every: halve the learning rate every this many iterations.

    Returns:
        ``defaultdict(list)`` with ``'train'`` and ``'val'`` lists of
        ``(iteration, loss)`` tuples.

    Raises:
        ValueError: if ``train_dataloader`` yields no batches.
    """
    logger.msg(str(model))
    optimizer = init_optimizer(model.parameters(), lr=lr)
    train_hist = defaultdict(list)
    best_val_loss = float('inf')

    # Resolve the checkpoint paths up front: the KeyboardInterrupt handler
    # needs them even when no periodic checkpoint has been written yet.
    os.makedirs(output_dirpath, exist_ok=True)
    model_dirpath = join(output_dirpath, 'model.pt')
    best_model_dirpath = join(output_dirpath, 'model_best.pt')

    it = 0
    try:
        while it < num_iters:
            model.train()
            it_at_epoch_start = it
            for inputs, targets, _ in train_dataloader:
                inputs, targets = variable(inputs), variable(targets)
                outputs = model(inputs)
                loss = criterion(outputs, targets)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                train_loss = loss.data[0]
                if it % log_every == 0:
                    logger.msg(f'Stage: train, iter: {it}, lr: {lr:.6f}, loss {train_loss:.6f}',
                              output_to_console=True)
                train_hist['train'].append((it, train_loss))

                if it and it % save_every == 0:
                    torch.save(model.state_dict(), model_dirpath)

                if it and it % val_every == 0:
                    val_loss = calc_loss(model, criterion, val_set)
                    # calc_loss leaves the model in eval mode; switch back so
                    # the rest of the epoch does not train in eval mode.
                    model.train()
                    logger.msg(f'Stage: eval, iter: {it}, loss {val_loss:.6f}',
                              output_to_console=True)
                    train_hist['val'].append((it, val_loss))
                    if val_loss < best_val_loss:
                        best_val_loss = val_loss
                        # Save the current weights directly instead of copying
                        # model.pt, which only matched them when val_every
                        # happened to be a multiple of save_every.
                        torch.save(model.state_dict(), best_model_dirpath)

                if it and it % lr_decay_every == 0:
                    lr /= 2
                    # Re-creating the optimizer also resets its internal state.
                    optimizer = init_optimizer(model.parameters(), lr=lr)
                    print(f'LR => {lr:E}')

                it += 1
                if it == num_iters:
                    break

            if it == it_at_epoch_start:
                # Without this guard an empty loader would spin forever.
                raise ValueError('train_dataloader produced no batches')

        # Persist the final weights; the periodic save above never sees the
        # iterations after the last multiple of save_every.
        torch.save(model.state_dict(), model_dirpath)
    except KeyboardInterrupt:
        torch.save(model.state_dict(), model_dirpath)
        print('kb interrupt')
    return train_hist

In [None]:
# Batch size for both loaders, and how many times each file list is repeated
# per epoch (see Dataset.__len__).
batch_size = 12
epoch_mul = 1

# Lists of file names for training and validation.
train_set, val_set = train_val_split()

# NOTE: both datasets apply the same random crop + augmentation pipeline.
ds_train = Dataset(IMAGE_SZ, train_set, epoch_mul=epoch_mul)
ds_valid = Dataset(IMAGE_SZ, val_set, epoch_mul=epoch_mul)

# dl_valid is not passed to train() below, which validates on the `val_set`
# file names through calc_loss instead.
dl_train = DataLoader(ds_train, batch_size=batch_size, shuffle=True, num_workers=8)
dl_valid = DataLoader(ds_valid, batch_size=batch_size, shuffle=False, num_workers=4)

In [None]:
# Sanity check: number of training / validation files.
len(train_set), len(val_set)

In [None]:
class DiceLoss(nn.Module):
    """Soft Dice loss on ``sigmoid(output)``.

    The sums run over every element of the batch, so a single Dice score is
    computed for the whole batch. A small constant keeps the ratio defined
    when both prediction and target are empty.
    """

    def __init__(self):
        super().__init__()

    def forward(self, output, target):
        eps = 1e-6
        probs = F.sigmoid(output)
        overlap = (probs * target).sum()
        numerator = 2. * overlap + eps
        denominator = probs.sum() + target.sum() + eps
        return 1. - numerator / denominator

class BCEDiceLoss(nn.Module):
    """Weighted sum of BCE-with-logits and soft Dice.

    Args:
        w_bce: weight of the BCE term.
        w_dice: weight of the Dice term.

    Every forward pass also appends the two weighted terms to ``self.hist``
    (keys ``'bce'`` and ``'dice'``), which therefore grows with each call.
    """

    def __init__(self, w_bce=.5, w_dice=.5):
        super().__init__()
        self.dice = DiceLoss()
        self.bce = nn.BCEWithLogitsLoss()
        self.w_bce, self.w_dice = w_bce, w_dice
        self.hist = {'bce': [], 'dice': []}

    def forward(self, input, target):
        # Both terms take the raw logits: each applies the sigmoid itself.
        terms = {
            'bce': self.w_bce * self.bce(input, target),
            'dice': self.w_dice * self.dice(input, target),
        }
        self.hist['bce'].append(terms['bce'].data[0])
        self.hist['dice'].append(terms['dice'].data[0])
        return terms['bce'] + terms['dice']

In [None]:
# Number of output channels of the segmentation network.
out_ch = 1

def get_model():
    """Build the UNet (InceptionResNetV2 encoder) wrapped in nn.DataParallel."""
    model = unet_inceptionresnetv2(out_ch, up_block=UNetUp2)
    # Must match the criterion: BCEDiceLoss applies the sigmoid itself
    # (BCEWithLogitsLoss / DiceLoss). NOTE(review): `out_logits = True`
    # presumably makes the model return raw logits -- confirm in models.unet.
    model.out_logits = True  # >>> MATCH CRITERION <<<
    return nn.DataParallel(model)

# Optimizer factory handed to train(), which calls it again on every LR change.
init_optimizer = lambda parameters, lr: Adam(parameters, lr=lr)

criterion = BCEDiceLoss()

In [None]:
# Show the GPU status (shell escape) before choosing a device below.
!nvidia-smi

In [None]:
# Restrict this process to GPU 0.
# NOTE(review): this presumably only takes effect if it runs before the first
# CUDA call in this kernel -- confirm, or move it above the torch import.
os.environ['CUDA_VISIBLE_DEVICES'] = '0'

In [None]:
# Name of this run; its artefacts go to OUTPUT_DIR/<experiment_name>.
experiment_name = '2'
experiment_dir = OUTPUT_DIR + experiment_name
# Pure-Python replacement for `!mkdir -p`: portable and no shell interpolation.
os.makedirs(experiment_dir, exist_ok=True)

In [None]:
# Training log for this experiment.
logger = Logger(pjoin(experiment_dir, 'train_log.txt'))

model = get_model()
# Uncomment to resume from a previous checkpoint:
#model.load_state_dict(torch.load('./output/.../model.pt'))
# Legacy arguments of an earlier, epoch-based train() signature; kept for
# reference only -- the current train() does not accept these keys.
# train_dict = {
#     'model': model,
#     'output_dirpath': experiment_dir,
#     'init_optimizer': init_optimizer,
#     'lr': 0.001,
#     'epochs': 200,
#     'criterion': criterion,
#     'train_dataloader': dl_train,
#     'val_dataloader': dl_valid,
#     'val_every': 1,
#     'patience': 10,
#     #'metrics': {'MSE': mse},
#     'logger': logger
# }
# Arguments for the iteration-based train() defined above.
train_dict = {
    'model': model,
    'output_dirpath': experiment_dir,
    'init_optimizer': init_optimizer,
    'lr': 1e-4,
    'num_iters': 200000,
    'criterion': criterion,
    'train_dataloader': dl_train,
    'val_set': val_set,
    'logger': logger
}
train_hist = train(**train_dict)

In [None]:
def make_smooth(x, tau=0.95):
    """Exponentially smooth a sequence.

    ``out[0] = x[0]`` and ``out[i] = tau * out[i-1] + (1 - tau) * x[i]`` for
    every later element (including the last one).

    Args:
        x: sequence of numbers; it is not modified.
        tau: smoothing factor; larger values smooth more.

    Returns:
        A new list with the smoothed values (empty for empty input).
    """
    # list(x) copies lists and NumPy arrays alike, so the input stays intact.
    out = list(x)
    for i in range(1, len(out)):
        out[i] = out[i - 1] * tau + out[i] * (1 - tau)
    return out

def plot_train_hist(hist):
    """Plot the training and validation loss curves returned by `train`.

    Args:
        hist: mapping with ``'train'`` and ``'val'`` lists of
            ``(iteration, loss)`` tuples. Either list may be empty or missing
            (e.g. training was interrupted before the first validation run).

    The first 10 training points and the first validation point are left out
    of the curves; the validation minimum is computed over all points and its
    index in the validation list is shown in the title.
    """
    trn_hist = np.array(hist.get('train', []))
    val_hist = np.array(hist.get('val', []))

    plt.figure(figsize=(12,6))
    if trn_hist.size:
        plt.plot(trn_hist[10:,0], trn_hist[10:,1])
    plt.title('Train loss')
    plt.grid(True)

    plt.figure(figsize=(12,6))
    if val_hist.size:
        val_loss = val_hist[:,1]
        v_min, v_min_ep = val_loss.min(), val_loss.argmin()
        plt.plot(val_hist[1:,0], val_hist[1:,1])
        plt.hlines([v_min], 0, val_hist[-1,0])
        plt.title(f'Validation loss (min {v_min:.6f}@{v_min_ep})')
    else:
        # Nothing to index: indexing column 1 of an empty array would raise.
        plt.title('Validation loss (no validation runs recorded)')
    plt.grid(True)
    
# Plot the curves of the run that just finished.
plot_train_hist(train_hist)

In [None]:
# Rebuild the network and load the last periodic checkpoint of this experiment
# (model.pt; the best-validation weights are stored in model_best.pt).
model = get_model()
model_path = f'./output/{experiment_name}/model.pt'
model.load_state_dict(torch.load(model_path))
model.cuda()
# Trailing ';' suppresses the module repr in the cell output.
model.eval();

In [None]:
def to_crops(img, crop_sz):
    """Split a square image into non-overlapping ``crop_sz`` x ``crop_sz`` tiles.

    The image is padded on the bottom/right (NumPy ``'symmetric'`` mode) up to
    the next multiple of ``crop_sz``; an image whose side is already a
    multiple is not padded at all.

    Args:
        img: array of shape (S, S) or (S, S, C).
        crop_sz: tile side length.

    Returns:
        List of tiles in row-major order (left to right, then top to bottom).

    Raises:
        AssertionError: if the image is not square.
    """
    assert img.shape[0] == img.shape[1]
    # Distance to the next multiple of crop_sz; 0 when already aligned
    # (the previous formula added a full extra tile row/column in that case).
    pad = (-img.shape[0]) % crop_sz
    pads = [[0, pad], [0, pad]]
    if img.ndim == 3:
        pads += [[0, 0]]
    img_pad = np.pad(img, pads, 'symmetric')
    n_p = img_pad.shape[0] // crop_sz
    offsets = [k * crop_sz for k in range(n_p)]
    return [img_pad[y:y + crop_sz, x:x + crop_sz, ...]
            for y in offsets
            for x in offsets]


def pred2img(pred, ch=0):
    """Convert the model output for the first batch item into a probability map.

    The model is configured to emit raw logits (the criterion and `iou` both
    apply a sigmoid to its output), so a sigmoid -- not ``exp`` -- maps them
    to the [0, 1] range that `predshow` displays.

    Args:
        pred: output of shape (N, C, H, W); only item 0 is used.
        ch: channel to extract.

    Returns:
        NumPy array of shape (H, W) with values in [0, 1].
    """
    # torch.exp(pred) would only be appropriate for log-probability outputs.
    return torch.sigmoid(pred).data[0].cpu().numpy()[ch]


def predshow(pred, ch=0):
    """Display channel ``ch`` of a prediction as a grayscale image on a 0..1 scale."""
    pred_img = pred2img(pred, ch)
    imshow(pred_img, cmap='gray', vmin=0, vmax=1)


def iou(pred, target, thr=0.4, orig_size=None):
    """Intersection-over-union of a thresholded prediction and its target.

    Only the first item of the batch is evaluated. Both arrays are cropped to
    the top-left ``orig_size`` x ``orig_size`` window of their two trailing
    (spatial) axes so padding added by `fix_shape` does not count.

    Args:
        pred: logits of shape (N, C, H, W).
        target: target of the same shape.
        thr: probability threshold applied after the sigmoid.
        orig_size: side of the unpadded image; defaults to ``ORIG_IMAGE_SZ``.

    Returns:
        IoU as a float (0 when both prediction and target are empty).
    """
    if orig_size is None:
        orig_size = ORIG_IMAGE_SZ
    pred = torch.sigmoid(pred).data[0].cpu().numpy()
    target = target.data[0].cpu().numpy()
    pred = (pred > thr).astype(int)
    # After [0] the arrays are (C, H, W): slice the LAST two axes. Slicing the
    # first two cropped channel/height and left the padded columns in.
    pred = pred[..., :orig_size, :orig_size]
    target = target[..., :orig_size, :orig_size]
    intersection = (pred * target).sum()
    return intersection / (pred.sum() + target.sum() - intersection + 1e-12)


def calc_iou(model, file_names, return_list=False):
    """Compute the IoU of ``model`` on whole (padded) images, file by file.

    For every file a progress line is printed and both the ground truth and
    the prediction are displayed.

    Args:
        model: network; switched to eval mode by this function.
        file_names: names present in both ``IMAGES_ROOT`` and ``GT_ROOT``.
        return_list: return the per-file IoUs instead of their mean.
    """
    model.eval()
    n_files = len(file_names)
    ious = []
    for pos, fname in enumerate(file_names, start=1):
        print(f'{pos}/{n_files}')
        img = ocv_loader(pjoin(IMAGES_ROOT, fname))
        gt = ocv_loader(pjoin(GT_ROOT, fname))[:, :, 0]
        imshow(gt)
        img = fix_shape(img)
        gt = fix_shape(gt)
        img, gt = one_batch(img, gt)
        output = model(img)
        predshow(output)
        ious.append(iou(output, gt))
        # Alternative, tile-based evaluation (disabled):
        #img_crops = to_crops(img)
        #gt_crops = to_crops(gt)
        #img_ious = []
        #for img, gt in zip(img_crops, gt_crops):
        #    img, gt = one_batch(img, gt)
        #    output = model(img)
        #    img_ious.append(iou(output, gt))
        #ious.append(np.mean(img_ious))
    return ious if return_list else np.mean(ious)
            
# Disabled: IoU on five randomly chosen validation files.
# iou_val = calc_iou(model, np.random.permutation(val_set)[:5])

# iou_val

# Loss of the reloaded checkpoint on the first five validation files.
calc_loss(model, criterion, val_set[:5])

In [None]:
# NOTE(review): `output` is only ever assigned inside calc_loss / calc_iou, so
# on a fresh kernel (Restart & Run All) this cell raises NameError. It looks
# like a leftover from interactive debugging -- remove it or assign `output`
# at notebook scope first.
output