In [1]:
import tensorflow as tf
import keras
from keras.layers import Input, Conv2D, BatchNormalization, MaxPooling2D, SeparableConv2D, Conv2DTranspose
from keras.layers import Concatenate, Dense, Flatten, Reshape
from keras.models import Model
from PIL import Image
import numpy as np
import PIL.ImageOps
import random
import matplotlib.pyplot as plt

  from ._conv import register_converters as _register_converters
Using TensorFlow backend.


In [2]:
def _sep_conv_pair(x, filters):
    """Apply two stacked 3x3 ReLU SeparableConv2D layers ('same' padding)."""
    x = SeparableConv2D(kernel_size=(3,3), filters=filters, padding='same', activation='relu')(x)
    return SeparableConv2D(kernel_size=(3,3), filters=filters, padding='same', activation='relu')(x)


def _unet_encoder(x, anc):
    """Contracting path: two Conv2D+BatchNorm layers, then four max-pool + sep-conv stages.

    Returns the feature maps [full res, 1/2, 1/4, 1/8, 1/16] with anc*1, 2, 4, 8, 16
    channels; the last one is the deepest map. The first four are the skip connections.
    """
    x = Conv2D(kernel_size=(3,3), filters=anc*1, padding='same', activation='relu')(x)
    x = BatchNormalization()(x)
    x = Conv2D(kernel_size=(3,3), filters=anc*1, padding='same', activation='relu')(x)
    x = BatchNormalization()(x)
    features = [x]
    for mult in (2, 4, 8, 16):
        x = MaxPooling2D((2,2))(x)
        x = _sep_conv_pair(x, anc*mult)
        features.append(x)
    return features


def _unet_decoder(x, skips, anc):
    """Expanding path: four (2x transposed conv -> concat skip -> sep-conv pair) stages.

    ``skips`` must be ordered deepest-first (1/8, 1/4, 1/2, full resolution).
    """
    for skip, mult in zip(skips, (8, 4, 2, 1)):
        up = Conv2DTranspose(kernel_size=(2,2), strides=2, filters=anc*mult)(x)
        x = Concatenate()([skip, up])
        x = _sep_conv_pair(x, anc*mult)
    return x


def wvnet(anc=16, image_dimensions=(256, 256), encoding_length=16):
    """Build the two-stage ("W"-shaped) segmentation + reconstruction model.

    Stage 1 is a 4-level U-Net on the single-channel input that produces
    ``pixel_map`` (1 channel, sigmoid). Its deepest feature map is also pooled
    once more and collapsed by a 'valid' convolution into a 1x1 code with
    ``encoding_length`` channels.

    Stage 2 runs a second U-Net encoder on ``pixel_map``. Its decoder starts
    from the code, transposed-convolved back to the bottleneck resolution and
    concatenated with stage 2's own bottleneck features. It outputs
    ``reconstruction`` (1 channel, no activation).

    Args:
        anc: base channel count; the levels use anc*1, 2, 4, 8, 16 filters.
        image_dimensions: (height, width) of the input. Both must be positive
            multiples of 32, because the code is taken after five 2x poolings.
        encoding_length: number of channels in the 1x1 code.

    Returns:
        A keras ``Model`` with outputs ``[pixel_map, reconstruction]``.

    Raises:
        ValueError: if an image dimension is not a positive multiple of 32.
    """
    im_dim = image_dimensions
    if any(d <= 0 or d % 32 for d in im_dim[:2]):
        raise ValueError(
            'image_dimensions must be positive multiples of 32, got {}'.format(im_dim))
    k = 1  # number of pixel_map channels
    # Spatial size after five 2x poolings. A 'valid' conv with this kernel reduces
    # it to 1x1 (8x8 for the default 256x256 input, which was previously hard-coded).
    bottleneck = (im_dim[0] // 32, im_dim[1] // 32)

    inputs = Input(shape=(im_dim[0], im_dim[1], 1))
    features = _unet_encoder(inputs, anc)
    deepest = features[-1]

    # Take the deepest map down once more, then encode it to a 1x1 code.
    maxp_out1 = MaxPooling2D((2,2))(deepest)
    encoder_convolution = Conv2D(kernel_size=bottleneck, filters=encoding_length, activation='relu', name='encoder_convolution')(maxp_out1)
    flattened_encoding = Flatten()(encoder_convolution)

    # Stage-1 decoder -> segmentation output.
    decoded = _unet_decoder(deepest, features[3::-1], anc)
    pixel_map = Conv2D(kernel_size=(1,1), filters=k, activation='sigmoid', name='pixel_map')(decoded)

    # Stage 2: encode the pixel map with a second U-Net encoder.
    features_b = _unet_encoder(pixel_map, anc)
    deepest_b = features_b[-1]
    maxp_out2 = MaxPooling2D((2,2))(deepest_b)

    # Re-inject the stage-1 code at the stage-2 bottleneck.
    reshaped = Reshape(target_shape=(1, 1, encoding_length))(flattened_encoding)
    decoder_convolution1 = Conv2DTranspose(kernel_size=bottleneck, strides=1, filters=anc*16)(reshaped)
    decoder_concat1 = Concatenate()([decoder_convolution1, maxp_out2])
    decoder_convolution1x1_1 = Conv2D(kernel_size=(1,1), strides=1, filters=anc*16, activation='relu')(decoder_concat1)
    decoder_convolution2 = Conv2DTranspose(kernel_size=(2,2), strides=2, filters=anc*8)(decoder_convolution1x1_1)
    decoder_concat2 = Concatenate()([decoder_convolution2, deepest_b])
    decoder_convolution1x1_2 = Conv2D(kernel_size=(1,1), strides=1, filters=anc*16, activation='relu')(decoder_concat2)

    # Complete the second arm of the "W" to reconstruct the input image.
    decoded_b = _unet_decoder(decoder_convolution1x1_2, features_b[3::-1], anc)
    reconstructed_image = Conv2D(kernel_size=(1,1), filters=1, name='reconstruction')(decoded_b)

    model = Model(inputs=inputs, outputs=[pixel_map, reconstructed_image])
    return model

# Build with the default 16 base channels on a 256x256 single-channel input,
# then print the layer table.
model = wvnet(anc=16, image_dimensions=(256, 256))
model.summary()


__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_1 (InputLayer)            (None, 256, 256, 1)  0                                            
__________________________________________________________________________________________________
conv2d_1 (Conv2D)               (None, 256, 256, 16) 160         input_1[0][0]                    
__________________________________________________________________________________________________
batch_normalization_1 (BatchNor (None, 256, 256, 16) 64          conv2d_1[0][0]                   
__________________________________________________________________________________________________
conv2d_2 (Conv2D)               (None, 256, 256, 16) 2320        batch_normalization_1[0][0]      
__________________________________________________________________________________________________
batch_norm

In [3]:
# Path templates for the per-tile mask (target) and raw image. The `{}` is
# filled with the tile id by load_target / load_image.
path_to_target = './mask/{}_mask.bmp'
path_to_image = './raw/tile{}.pgm'

def load_target(tile, path_to_target=path_to_target):
    """Open the mask image for ``tile`` (the ``{}`` in ``path_to_target`` is replaced by it)."""
    return Image.open(path_to_target.format(tile))

def load_image(tile, path_to_image=path_to_image):
    """Open the raw input image for ``tile`` (the ``{}`` in ``path_to_image`` is replaced by it)."""
    return Image.open(path_to_image.format(tile))

def augment(image, target, rotation=[0, 90, 180, 270]):
    """Apply one identical random rotation and random flips to an image and its target.

    The angle is drawn from ``rotation``. A left-right mirror and a top-bottom
    flip are then each applied with probability 1/2. Random draws happen in
    the order: angle, mirror, flip.

    NOTE(review): PIL's ``rotate`` keeps the original size by default, so 90/270
    rotations of non-square images lose their corners.
    """
    angle = random.choice(rotation)
    pair = [image.rotate(angle), target.rotate(angle)]
    if random.choice([True, False]):
        pair = [PIL.ImageOps.mirror(im) for im in pair]
    if random.choice([True, False]):
        pair = [PIL.ImageOps.flip(im) for im in pair]
    return pair[0], pair[1]

def crop_pair(image, target, loc=(0, 0), dims=(256, 256)):
    """Cut the same box out of an image and its target.

    ``loc`` is the (left, top) corner and ``dims`` the (width, height) of the box.
    """
    left, top = loc[0], loc[1]
    box = (left, top, left + dims[0], top + dims[1])
    return image.crop(box), target.crop(box)

def select_coordinates(img_dims=(1700, 1700), crop_dims=(256, 256)):
    """Pick a random top-left corner so that a crop fits entirely inside the image.

    Args:
        img_dims: (width, height) of the source image (PIL ``Image.size`` order).
        crop_dims: (width, height) of the crop.

    Returns:
        ``(x, y)`` with ``0 <= x <= width - crop_width`` and
        ``0 <= y <= height - crop_height`` (both bounds inclusive).

    Raises:
        ValueError: if the crop is larger than the image in either dimension.
    """
    max_x = img_dims[0] - crop_dims[0]
    # The y range must use the crop *height*; it previously subtracted crop_dims[0].
    max_y = img_dims[1] - crop_dims[1]
    if max_x < 0 or max_y < 0:
        raise ValueError(
            'crop {} does not fit in image {}'.format(tuple(crop_dims), tuple(img_dims)))
    x = random.randint(0, max_x)
    y = random.randint(0, max_y)
    return x, y

def to_tensor(image, target, choose=True, coords=None, crop_dims=(256, 256)):
    """Crop an image/target pair and turn both into float arrays with a trailing channel axis.

    If ``choose`` is truthy, ``coords`` is ignored and a random top-left corner
    is drawn with ``select_coordinates``. Otherwise ``coords`` is used as the
    (left, top) corner. Pixel values are divided by 255.
    """
    loc = select_coordinates(img_dims=image.size, crop_dims=crop_dims) if choose else coords
    cropped = crop_pair(image, target, loc=loc, dims=crop_dims)
    image_arr, target_arr = (np.expand_dims(np.array(im) / 255, axis=-1) for im in cropped)
    return image_arr, target_arr
    

# Tile ids. make_batch samples random training crops from train_tiles by
# default; val_batch_gen walks a fixed crop grid over val_tiles, which does not
# overlap train_tiles.
train_tiles = ['1_24', '2_24', '3_24', '3_25', '1_25']
val_tiles = ['2_25']

def make_batch(tiles=train_tiles, batch_size=10, crop_dims=(256, 256)):
    """Assemble one training batch of random crops.

    For each sample, a tile is picked uniformly from ``tiles``. A random crop of
    size ``crop_dims`` is then taken from its image and mask via ``to_tensor``.

    Returns:
        ``(X, [Y, X])``: the stacked input crops ``X`` and target crops ``Y``.
        ``X`` is repeated as the target for the model's second (reconstruction)
        output. Assuming single-channel images, ``X`` and ``Y`` have shape
        ``(batch_size, crop_dims[1], crop_dims[0], 1)``.

    NOTE(review): ``augment`` is defined in this notebook but is not applied here.
    """
    # Read each tile from disk at most once per batch instead of once per sample.
    # The sequence of random draws (tile, then crop corner) is unchanged.
    loaded = {}
    batch_in = []
    batch_out = []
    while len(batch_in) < batch_size:
        tile = random.choice(tiles)
        if tile not in loaded:
            loaded[tile] = (load_image(tile), load_target(tile))
        image, target = loaded[tile]
        # Forward crop_dims; it was previously accepted but ignored (always 256x256).
        image_crop, target_crop = to_tensor(image, target, crop_dims=crop_dims)
        batch_in.append(image_crop)
        batch_out.append(target_crop)
    X = np.array(batch_in)
    Y = np.array(batch_out)
    return X, [Y, X]

def train_batch_gen(batch_size=10):
    """Endlessly yield fresh random training batches built by ``make_batch``."""
    while 1:
        batch = make_batch(batch_size=batch_size)
        yield batch
    
def val_batch_gen(tiles=val_tiles, batch_size=12):
    """Yield validation batches forever from a fixed grid of crops.

    Each tile in ``tiles`` is cut into a 6x6 grid of crops whose top-left
    corners are multiples of 256, giving 36 crops per tile. Crops are grouped
    into batches of ``batch_size``. With the default single tile and
    ``batch_size=12``, one pass over the grid is exactly three batches, which
    matches ``validation_steps=3`` used for training. A partially filled batch
    at the end of a pass is completed with crops from the next pass.

    Yields:
        ``(X, [Y, X])``: inputs, and targets for the (pixel_map, reconstruction)
        outputs, in the same layout as ``make_batch``.

    Raises:
        ValueError: if ``tiles`` is empty or ``batch_size < 1``. Either case
            previously looped forever without yielding.
    """
    if not tiles:
        raise ValueError('tiles must not be empty')
    if batch_size < 1:
        raise ValueError('batch_size must be >= 1, got {}'.format(batch_size))
    grid = [i * 256 for i in range(6)]
    batch_in = []
    batch_out = []
    while True:
        for tile in tiles:
            # Load each tile once per pass rather than once per crop. Cropping
            # returns new images, so the loaded pair is never modified.
            image = load_image(tile)
            target = load_target(tile)
            for x in grid:
                for y in grid:
                    image_crop, target_crop = to_tensor(image, target, choose=False, coords=(x, y))
                    batch_in.append(image_crop)
                    batch_out.append(target_crop)
                    if len(batch_in) == batch_size:
                        X = np.array(batch_in)
                        Y = np.array(batch_out)
                        batch_in = []
                        batch_out = []
                        yield X, [Y, X]

In [4]:
# One loss per model output, in Model(outputs=[pixel_map, reconstructed_image]) order:
# binary cross-entropy for the sigmoid pixel map, and MSE for the linear
# reconstruction, which is weighted down to a small auxiliary term.
model.compile(
    optimizer='adam',
    loss=['binary_crossentropy', 'mse'],
    loss_weights=[.999, .001],
    # NOTE(review): 'acc' is reported for both outputs; it is only meaningful for pixel_map.
    metrics=['acc'],
)

In [5]:
# Training draws 20 random crops per step. Validation walks the fixed crop grid
# 12 crops per step; 3 steps cover the 36 crops of the single validation tile.
# NOTE(review): the saved output below shows "Epoch 1/2", so it predates
# epochs=25. Re-run before trusting it.
# NOTE(review): newer Keras versions deprecate fit_generator in favour of fit(generator).
tg, vg = train_batch_gen(batch_size=20), val_batch_gen(batch_size=12)
history = model.fit_generator(
    tg,
    steps_per_epoch=180,
    epochs=25,
    validation_data=vg,
    validation_steps=3,
)

Epoch 1/2
Epoch 2/2


<keras.callbacks.History at 0x7efface8f588>