In [1]:
import numpy as np
import matplotlib.pyplot as plt
import cv2
import os
import json
import pickle
from collections import defaultdict
from PIL import Image
%matplotlib inline

In [None]:
# Scratch cell (never executed): displays OpenCV's INTER_NEAREST flag.
# The value is not assigned, and no other visible cell uses it.
cv2.INTER_NEAREST

In [2]:
# Pin NumPy's global random state so draws made through np.random are repeatable.
RANDOM_SEED = 21
np.random.seed(RANDOM_SEED)

In [3]:
from keras.models import Model
from keras.layers import Input, concatenate, Conv2D, MaxPooling2D, UpSampling2D, Flatten, Dense, Activation, AveragePooling2D, Cropping2D, Reshape, BatchNormalization
from keras.optimizers import Adam, SGD
from keras import backend as K
from keras.losses import mean_squared_error, binary_crossentropy
from keras.preprocessing.image import Iterator
from keras.utils.np_utils import to_categorical
from keras.layers import Reshape, BatchNormalization
from keras.callbacks import ModelCheckpoint
from keras.callbacks import ReduceLROnPlateau
from keras.engine.topology import Layer

import tensorflow as tf
# Create one TensorFlow session and register it with the Keras backend,
# so Keras runs its graph in this session.
sess = tf.Session()
K.set_session(sess)

Using TensorFlow backend.


In [4]:
from __future__ import absolute_import
from __future__ import print_function

from pkg_resources import parse_version
from keras.callbacks import Callback

In [34]:
def unet_down_block(x, n_filters, block_id, with_maxpool=True, activation="elu", crop=False):
    """Contracting-path block: two 3x3 convolutions, optionally followed by 2x2 max-pooling.

    # Arguments
        x: input tensor.
        n_filters: number of filters of both convolutions.
        block_id: value inserted into the layer names
            (``conv<id>_1``, ``conv<id>_2``, ``max_pool<id>``).
        with_maxpool: when False, no pooling layer is added.
        activation: activation passed to both convolutions.
        crop: when True the convolutions use 'valid' padding, otherwise 'same'.

    # Returns
        ``(features, pooled)`` when ``with_maxpool`` is True, otherwise only
        ``features`` (the output of the second convolution).
    """
    if crop:
        padding = 'valid'
    else:
        padding = 'same'

    features = x
    for name_template in ("conv{}_1", "conv{}_2"):
        features = Conv2D(n_filters, (3, 3),
                          activation=activation,
                          padding=padding,
                          name=name_template.format(block_id))(features)

    if with_maxpool:
        pooled = MaxPooling2D(pool_size=(2, 2),
                              name="max_pool{}".format(block_id))(features)
        return features, pooled
    return features

In [39]:
def unet_up_block(x, y, n_filters, block_id, activation="elu", crop=False):
    """Expanding-path block: upsample ``x``, merge it with skip ``y``, apply two 3x3 convolutions.

    # Arguments
        x: tensor coming from the deeper level; it is upsampled by 2 in both
            spatial dimensions.
        y: skip-connection tensor from the contracting path.
        n_filters: number of filters of both convolutions.
        block_id: value inserted into the layer names (``upsample<id>``,
            ``crop<id>``, ``concat<id>``, ``conv<id>_1``, ``conv<id>_2``).
        activation: activation passed to both convolutions.
        crop: when True the convolutions use 'valid' padding and ``y`` is
            centre-cropped to the spatial size of the upsampled ``x``;
            otherwise 'same' padding is used and ``y`` is concatenated as is.

    # Returns
        Output tensor of the second convolution.

    # Raises
        ValueError: if ``crop`` is True and ``y`` is spatially smaller than the
            upsampled ``x`` (cropping cannot make the shapes match).
    """
    padding = 'valid' if crop else 'same'
    up_x = UpSampling2D(size=(2, 2), name="upsample{}".format(block_id))(x)

    # Compute crop needed to have the same shape for up_x and y
    if crop:
        _, hx, wx, _ = up_x.shape
        _, hy, wy, _ = y.shape
        diff_h = int(hy - hx)
        diff_w = int(wy - wx)
        if diff_h < 0 or diff_w < 0:
            raise ValueError(
                "Skip tensor of block {} is smaller than the upsampled tensor "
                "(height difference {}, width difference {}).".format(
                    block_id, diff_h, diff_w))
        cropy = diff_h // 2
        cropx = diff_w // 2
        # When the difference is odd, the trailing side takes the extra pixel.
        # A symmetric (d//2, d//2) crop would leave y one pixel too large and
        # the concatenation below would fail.
        crop_y = Cropping2D(cropping=((cropy, diff_h - cropy),
                                      (cropx, diff_w - cropx)),
                            name="crop{}".format(block_id))(y)
        # Reports the leading (top / left) crop amounts.
        print("Crop: ", cropy, cropx)
    else:
        crop_y = y
    up = concatenate([up_x, crop_y], axis=-1,
                     name="concat{}".format(block_id))
    print(up.shape)
    up = Conv2D(n_filters, (3, 3),
                activation=activation,
                padding=padding,
                name="conv{}_1".format(block_id))(up)
    print(up.shape)
    up = Conv2D(n_filters, (3, 3),
                activation=activation,
                padding=padding,
                name="conv{}_2".format(block_id))(up)
    print(up.shape)
    return up

In [40]:
def get_unet(im_height, im_width, n_channels=3,
             n_filters=(64, 128, 256, 512, 1024)):
    """Build a U-Net with 'same' padding and a one-channel sigmoid output.

    The network has ``len(n_filters)`` resolution levels: every entry except
    the last one creates a down block (with max-pooling) and a mirrored up
    block, and the last entry is the bottleneck. Block ids run from 1 on the
    way down and continue on the way up, so the default five-entry
    configuration yields down blocks 1-5 and up blocks 6-9.

    # Arguments
        im_height: input height (or None to leave it unspecified).
        im_width: input width (or None to leave it unspecified).
        n_channels: number of input channels.
        n_filters: sequence with the number of filters per level, from the
            highest-resolution level to the bottleneck.

    # Returns
        A Keras ``Model`` named "unet" mapping the input image to a
        single-channel map of the same spatial size.

    # Raises
        ValueError: if ``n_filters`` is empty, or if a given height/width is
            not divisible by ``2 ** (len(n_filters) - 1)``. Such a size cannot
            be pooled and upsampled back to matching skip shapes.
    """
    n_filters = list(n_filters)
    if not n_filters:
        raise ValueError("n_filters must contain at least one entry")

    depth = len(n_filters)
    factor = 2 ** (depth - 1)
    for label, size in (("im_height", im_height), ("im_width", im_width)):
        if size is not None and size % factor != 0:
            raise ValueError(
                "{} = {} must be divisible by {} for a {}-level U-Net".format(
                    label, size, factor, depth))

    inputs = Input((im_height, im_width, n_channels))

    # Contracting path: keep every pre-pooling output for the skip connections.
    skips = []
    x = inputs
    for block_id, filters in enumerate(n_filters[:-1], start=1):
        conv, x = unet_down_block(x, filters, block_id)
        skips.append(conv)

    # Bottleneck: same block without pooling.
    x = unet_down_block(x, n_filters[-1], depth, with_maxpool=False)

    # Expanding path: consume the skips deepest-first, continuing the block ids.
    block_id = depth
    for skip, filters in zip(reversed(skips), reversed(n_filters[:-1])):
        block_id += 1
        x = unet_up_block(x, skip, filters, block_id)

    segmentation = Conv2D(1, (1, 1), activation='sigmoid', name="segmentation")(x)

    model = Model(inputs=[inputs], outputs=[segmentation], name="unet")

    return model

In [67]:
# Build a 160x240, 3-channel U-Net with narrower layers than the default.
unet = get_unet(im_height=160, im_width=240, n_channels=3,
                n_filters=[16, 32, 64, 128, 256])

(?, 20, 30, 384)
(?, 20, 30, 128)
(?, 20, 30, 128)
(?, 40, 60, 192)
(?, 40, 60, 64)
(?, 40, 60, 64)
(?, 80, 120, 96)
(?, 80, 120, 32)
(?, 80, 120, 32)
(?, 160, 240, 48)
(?, 160, 240, 16)
(?, 160, 240, 16)


In [73]:
# Dummy batch of two all-ones inputs matching the model's 160x240x3 input size.
test_batch = np.ones(shape=(2, 160, 240, 3))

In [74]:
# Smoke test: run the dummy batch through the network and display the output
# shape (one single-channel map per input, at the input resolution).
unet.predict(test_batch).shape

(2, 160, 240, 1)

In [75]:
class NonValidPatch(Exception):
    """Exception type for rejecting a patch; nothing in the visible cells raises it."""

In [81]:
class PatchIterator(Iterator):
    """Batch iterator that reads image/mask PNG pairs from disk and cuts patches.

    Every image id is expanded into ``n_samples_per_image`` indices. For each
    index drawn by the base ``Iterator`` the matching image and mask are
    loaded and ``n_consecutives_samples_per_image`` (fixed to 1) patches are
    taken from them.

    # Arguments
        root_dir: directory containing the ``train_240x160`` and
            ``train_masks_240x160`` sub-folders.
        image_ids: sequence of file stems; ``<id>.png`` and ``<id>_mask.png``
            are read from the two sub-folders.
        n_samples_per_image: how many indices each image contributes.
        target_size: ``(height, width)`` of the image patch.
        crop_size: ``(height, width)`` of the centred mask window kept from
            the patch.
        batch_size, shuffle, seed: forwarded to the base ``Iterator``.
        debug_dir: optional directory; when set it is created if needed and
            the images/masks of every batch are written there as PNG files
            (file names depend only on the position in the batch, so each
            batch overwrites the previous one).
    """

    def __init__(self, root_dir, image_ids,
                 n_samples_per_image=160,
                 target_size=(160, 240),
                 crop_size=(160, 240),
                 batch_size=4, shuffle=True, seed=42,
                 debug_dir=None):
        self.n_consecutives_samples_per_image = 1
        self.image_ids = image_ids
        self.root_dir = root_dir
        self.debug_dir = debug_dir
        self.n_samples_per_image = n_samples_per_image
        self.target_size = target_size
        self.crop_size = crop_size
        self.n_indices = len(self.image_ids) * self.n_samples_per_image
        if self.debug_dir:
            os.makedirs(self.debug_dir, exist_ok=True)
        super(PatchIterator, self).__init__(self.n_indices,
                                            batch_size//self.n_consecutives_samples_per_image,
                                            shuffle, seed)

    @staticmethod
    def _load_array(path):
        """Read an image file into a NumPy array and close the file handle."""
        with Image.open(path) as im:
            return np.array(im)

    def normalize_x(self, x):
        """Subtract 127 from the first three channels of ``x``.

        Integer input (the batches built by ``next`` are uint8) is converted
        to float32 first. Subtracting in place on an unsigned array would wrap
        every value below 127 around modulo 256. Floating-point input is
        modified in place, as before.

        # Returns
            The shifted array (a new float32 array for integer input).
        """
        if not np.issubdtype(x.dtype, np.floating):
            x = x.astype(np.float32)
        x[..., :3] -= 127
        return x

    def random_transform(self, x, y):
        """Augmentation hook applied to each (image patch, mask patch); identity here."""
        return x, y

    def sample(self, img, mask):
        """Cut random patches out of one image and its mask.

        The top-left corner is drawn uniformly, independently per axis, among
        all offsets that keep a ``target_size`` patch inside the image. An
        axis without slack uses offset 0 and draws no random number. The mask
        patch is then reduced to the centred ``crop_size`` window.

        # Arguments
            img: array of shape (height, width, channels).
            mask: array indexed as (height, width), aligned with ``img``.

        # Returns
            ``(bx, by)``: uint8 arrays of shape
            ``(n_consecutives_samples_per_image, target_h, target_w, 3)`` and
            ``(n_consecutives_samples_per_image, crop_h, crop_w)``.
        """
        h, w, _ = img.shape
        n_patches = self.n_consecutives_samples_per_image
        target_h, target_w = self.target_size
        crop_h, crop_w = self.crop_size

        bx = np.zeros((n_patches, target_h, target_w, 3), dtype=np.uint8)
        by = np.zeros((n_patches, crop_h, crop_w), dtype=np.uint8)

        # Largest valid top-left offsets (inclusive).
        max_sx = w - target_w
        max_sy = h - target_h
        # Offset of the centred mask window inside the patch.
        crop_sx = (target_w - crop_w)//2
        crop_sy = (target_h - crop_h)//2

        for i in range(n_patches):
            # randint's upper bound is exclusive, hence the +1.
            sx = np.random.randint(0, max_sx + 1) if max_sx > 0 else 0
            sy = np.random.randint(0, max_sy + 1) if max_sy > 0 else 0
            img_patch = img[sy: sy + target_h,
                            sx: sx + target_w, ...]
            mask_patch = mask[sy: sy + target_h,
                              sx: sx + target_w]
            img_patch, mask_patch = self.random_transform(img_patch, mask_patch)
            mask_patch_crop = mask_patch[crop_sy: crop_sy + crop_h,
                                         crop_sx: crop_sx + crop_w]
            bx[i, ...] = img_patch
            by[i, ...] = mask_patch_crop
        return bx, by

    def next(self):
        """For python 2.x.
        # Returns
            The next batch: ``(batch_x, batch_y)`` where ``batch_x`` is the
            float32 image batch after ``normalize_x`` and ``batch_y`` is the
            uint8 mask batch with a trailing channel axis of size 1.
        """
        # Keeps under lock only the mechanism which advances
        # the indexing of each batch.
        # The index generator is expected to yield a
        # (index_array, current_index, current_batch_size) triple.
        with self.lock:
            index_array, current_index, current_batch_size = next(self.index_generator)

        n_patches = self.n_consecutives_samples_per_image
        batch_x = np.zeros((current_batch_size * n_patches,
                            self.target_size[0],
                            self.target_size[1],
                            3),
                           dtype=np.uint8)
        batch_y = np.zeros((current_batch_size * n_patches,
                            self.crop_size[0],
                            self.crop_size[1],
                            1),
                           dtype=np.uint8)

        # For each index, we load the data and sample randomly n_consecutives_samples_per_image patches
        for i, j in enumerate(index_array):
            # Several consecutive indices map to the same image.
            index = j // self.n_samples_per_image
            image_id = self.image_ids[index]
            img = self._load_array(os.path.join(self.root_dir, "train_240x160", image_id + ".png"))
            mask = self._load_array(os.path.join(self.root_dir, "train_masks_240x160", image_id + "_mask.png"))
            #weights = np.array(Image.open(os.path.join(self.root_dir, "train_weights", image_id + ".png")))

            x, y = self.sample(img, mask) #  ,weights
            batch_x[i*n_patches:(i+1)*n_patches, ...] = x
            batch_y[i*n_patches:(i+1)*n_patches, :, :, 0] = y

        if self.debug_dir:
            for i in range(batch_x.shape[0]):
                # PIL delivers RGB while cv2.imwrite expects BGR; convert so the
                # debug images are written with the right colours.
                cv2.imwrite(os.path.join(self.debug_dir, "{:02d}_img.png".format(i)),
                            cv2.cvtColor(batch_x[i, ...], cv2.COLOR_RGB2BGR))
                cv2.imwrite(os.path.join(self.debug_dir, "{:02d}_mask.png".format(i)), batch_y[i, ...] * 255)
        return self.normalize_x(batch_x), batch_y

In [82]:
# Read the training and validation id lists from their JSON files.
with open("../data/train.json", "r") as train_file:
    train_ids = json.load(train_file)

with open("../data/val.json", "r") as val_file:
    val_ids = json.load(val_file)

In [88]:
# Root folder holding the image and mask sub-directories. It can be overridden
# through the CARVANA_DATA_DIR environment variable, so the notebook is not
# tied to one machine; the default is the original location.
DATA_ROOT = os.environ.get("CARVANA_DATA_DIR", "/home/lowik/carvana/data")

trainPatchesGenerator = PatchIterator(DATA_ROOT, train_ids)
valPatchesGenerator = PatchIterator(DATA_ROOT, val_ids)

In [89]:
# Pull a single (images, masks) batch from the training iterator for inspection.
bx, by = next(iter(trainPatchesGenerator))

In [90]:
# Display the shape of the mask batch drawn from the training iterator.
by.shape

(4, 160, 240, 1)

In [101]:
# Small constant added to numerator and denominator of the Dice ratio.
smooth = 1e-6

def dice_coef(y_true, y_pred):
    """Soft Dice coefficient computed over the whole batch.

    Both tensors are flattened and the score is

        (2 * sum(y_true * y_pred) + smooth) / (sum(y_true) + sum(y_pred) + smooth)

    ``smooth`` appears in the numerator as well as the denominator. That way
    two empty masks score 1 (perfect agreement) instead of 0, and a batch with
    an empty target still yields a loss that depends on the prediction.

    # Arguments
        y_true: ground-truth mask tensor.
        y_pred: predicted mask tensor, same shape as ``y_true``.

    # Returns
        Scalar tensor holding the Dice score.
    """
    y_true_f = K.flatten(y_true)
    y_pred_f = K.flatten(y_pred)
    intersection = K.sum(y_true_f * y_pred_f)
    return (2. * intersection + smooth) / (K.sum(y_true_f) + K.sum(y_pred_f) + smooth)


def dice_coef_loss(y_true, y_pred):
    """Loss to minimise: the negated Dice coefficient of ``y_true`` and ``y_pred``."""
    score = dice_coef(y_true, y_pred)
    return -score

In [102]:
# Nesterov-momentum SGD; the model is trained on the negated Dice coefficient
# and reports accuracy as an additional metric.
sgd = SGD(lr=1e-4,
          momentum=0.9,
          decay=1e-6,
          nesterov=True)
unet.compile(loss=dice_coef_loss,
             optimizer=sgd,
             metrics=["accuracy"])

In [103]:
# Short training run: 20 steps per epoch for 2 epochs, validating on 10 batches.
#
# NOTE(review): the earlier `class_weight=[0.2, 0.8]` argument was dropped.
# Keras documents `class_weight` as a dict mapping class indices to weights.
# A list is not a dict and appears to be silently ignored, so it gave the
# false impression that foreground pixels were up-weighted. Per-pixel
# weighting for a segmentation target has to be implemented in the loss (or
# through sample weights) instead.
h = unet.fit_generator(trainPatchesGenerator, 20, epochs=2,
                       verbose=1,
                       validation_data=valPatchesGenerator, validation_steps=10,
                       max_q_size=1, workers=1, pickle_safe=False,
                       initial_epoch=0)

Epoch 1/2
Epoch 2/2
