In [None]:
import time

import matplotlib.image as mpimg
import matplotlib.pyplot as plt
%matplotlib inline
%load_ext autoreload
%autoreload 2
import numpy as np

import os, sys

from preprocessing import *
import image_preprocessing

from skimage import color

import keras

from skimage.color import rgb2hsv

from keras.layers import Conv2D, Dense, Flatten, MaxPooling2D, LeakyReLU, Dropout

In [None]:
# Training data layout: input images and their ground-truth masks live in
# two sub-directories of the same root.
ROOT_DIR = 'training/'
IMAGE_DIR = '{}images/'.format(ROOT_DIR)
GT_DIR = '{}groundtruth/'.format(ROOT_DIR)

# Size argument handed to patch_image / patch_groundtruth.
PATCH_SIZE = 10

# Size argument handed to image_to_neighborhoods.
WINDOW_SIZE = 71

In [None]:
# os.listdir returns entries in arbitrary order; sort them so the image order
# (and therefore which images `patched_imgs[:N]` selects for training) is the
# same on every run and every machine.
files = sorted(os.listdir(IMAGE_DIR))

# Ground-truth masks are looked up under the same file names as the images.
# Shapes noted by the original author: (400, 400, 3) per image and (400, 400)
# per mask — TODO confirm against load_image.
imgs = np.stack([load_image(IMAGE_DIR + file) for file in files])
gt_imgs = np.stack([load_image(GT_DIR + file) for file in files])

# Patch-level versions of the images and masks; the resulting shapes depend on
# patch_image / patch_groundtruth, which are defined in `preprocessing`.
patched_imgs = np.stack([patch_image(img, PATCH_SIZE) for img in imgs])
patched_gts = np.stack([patch_groundtruth(gt, PATCH_SIZE) for gt in gt_imgs])

In [None]:
# Number of patches along one image side (integer division; assumes square
# images whose side is a multiple of PATCH_SIZE — TODO confirm).
PATCHED_SIZE = imgs.shape[1] // PATCH_SIZE
# Cell count of the PATCHED_SIZE x PATCHED_SIZE patch grid; predictions are
# later reshaped to that grid, i.e. presumably one window per patch.
WINDOWS_PER_IMAGE = PATCHED_SIZE ** 2

In [None]:
N = 1 # Number of images used for training (the first N entries of patched_imgs)

# Slope argument passed to every LeakyReLU layer of the network.
leakyness = 0.1

In [None]:
# One batch of neighbourhood windows per training image ...
windows_per_image = [
    image_to_neighborhoods(patched, WINDOW_SIZE, True)
    for patched in patched_imgs[:N]
]
# ... merged along the first axis into a single sample array.
windows = np.vstack(windows_per_image)

# Flatten the ground-truth patches of the same images into one label per
# window, and check that samples and labels line up.
window_labels = patched_gts[:N].ravel()
assert len(window_labels) == len(windows)

In [None]:
# Display the label array's shape (one entry per window at this point).
window_labels.shape

In [None]:
# One-hot encode the labels for the two-unit output layer.
# * Recomputed from `patched_gts` rather than transforming `window_labels` in
#   place, so re-running this cell cannot one-hot encode twice.
# * num_classes is pinned to 2: otherwise the width is inferred from the
#   largest label present, which yields a single column when the selected
#   images contain no positive patch.
# * `keras.utils.to_categorical` is the public entry point; the
#   `keras.utils.np_utils` module path is not available in newer Keras.
window_labels = keras.utils.to_categorical(np.ravel(patched_gts[:N]), num_classes=2)

In [None]:
# Convolutional classifier that assigns each window to one of two classes.
window_cnn = keras.models.Sequential([

    # Four feature-extraction stages, each:
    # convolution -> LeakyReLU -> 2x2 max-pooling -> dropout.
    Conv2D(32, (5, 5), strides=(1, 1), input_shape=windows.shape[1:]),
    LeakyReLU(leakyness),
    MaxPooling2D(2),
    Dropout(0.25),

    Conv2D(64, (3, 3), strides=(1, 1)),
    LeakyReLU(leakyness),
    MaxPooling2D(2),
    Dropout(0.25),

    Conv2D(128, (3, 3), strides=(1, 1)),
    LeakyReLU(leakyness),
    MaxPooling2D(2),
    Dropout(0.25),

    Conv2D(256, (3, 3), strides=(1, 1)),
    LeakyReLU(leakyness),
    MaxPooling2D(2),
    Dropout(0.25),

    # Flatten BEFORE the fully connected head. With Dense placed first, Keras
    # applies it only along the channel axis of the 4-D feature map, so the
    # hidden layer never combines information across spatial positions.
    Flatten(),
    Dense(128),
    LeakyReLU(leakyness),

    # Softmax yields a probability distribution over the two mutually
    # exclusive classes, which is what categorical cross-entropy on one-hot
    # labels expects (sigmoid would score each unit independently).
    Dense(2, activation='softmax'),
])

In [None]:
#window_cnn.summary()

In [None]:
# Adam with a 1e-3 step size minimising categorical cross-entropy on the
# one-hot labels; accuracy is reported alongside the loss.
window_cnn.compile(
    optimizer=keras.optimizers.Adam(lr=0.001),
    loss=keras.losses.categorical_crossentropy,
    metrics=['accuracy'],
)

In [None]:
nw = windows.shape[0]
def batch_generator(batch_size=200):
    """Yield random (windows, labels) mini-batches forever.

    Each batch is drawn without replacement from the module-level `windows`
    and `window_labels` arrays; successive batches are drawn independently.

    Args:
        batch_size: Number of samples per batch (default 200, the value that
            was previously hard-coded). Clamped to the number of available
            windows so that sampling without replacement cannot fail on a
            small training set.

    Yields:
        tuple: (windows[indices], window_labels[indices]) for the drawn indices.
    """
    while True:
        # Read the size on every draw so a re-built `windows` array is
        # honoured even if the module-level `nw` was not refreshed.
        n_available = windows.shape[0]
        indices = np.random.choice(n_available, min(batch_size, n_available), replace=False)
        yield (windows[indices], window_labels[indices])

In [None]:
def smart_generator(n_windows, n_epochs, data=None, labels=None):
    """Yield `n_epochs` disjoint random batches of equal size.

    A random subset of `n_epochs * (n_windows // n_epochs)` distinct indices is
    drawn once and cut into consecutive chunks, so no sample appears in more
    than one batch. The generator is finite: it stops after `n_epochs` batches.

    Args:
        n_windows: Number of samples to draw indices from.
        n_epochs: Number of batches to produce.
        data: Sample array to index; defaults to the module-level `windows`.
        labels: Label array to index; defaults to the module-level
            `window_labels`.

    Yields:
        tuple: (data[batch_idx], labels[batch_idx]) for each chunk of indices.

    Raises:
        ValueError: If `n_windows < n_epochs`, i.e. a batch would be empty.
    """
    data = windows if data is None else data
    labels = window_labels if labels is None else labels

    ws_per_epoch = n_windows // n_epochs
    if ws_per_epoch == 0:
        raise ValueError('n_windows must be at least n_epochs to form non-empty batches')

    indices = np.random.choice(n_windows, n_epochs * ws_per_epoch, replace=False)
    # Step over the drawn indices, not over n_windows: when n_windows is not a
    # multiple of n_epochs the latter walks past the end of `indices` and
    # produces a trailing empty batch.
    for start in range(0, len(indices), ws_per_epoch):
        batch_idx = indices[start:start + ws_per_epoch]
        yield (data[batch_idx], labels[batch_idx])

In [None]:
# Number of training epochs passed to `fit`.
epochs = 10

In [None]:
# Wall-clock timer around the training call.
start = time.perf_counter()

# Train on the full in-memory arrays; no batch_size is passed, so Keras uses
# its default (the trailing comment keeps an earlier batch_size experiment).
window_cnn.fit(windows, window_labels, epochs=epochs)#, batch_size=1600)
# Earlier generator-based variant, kept for reference:
#window_cnn.fit_generator(smart_generator(windows.shape[0], epochs), steps_per_epoch=windows.shape[0] // epochs, epochs=10)

# Last expression of the cell: elapsed seconds, shown as the cell output.
time.perf_counter() - start

In [None]:
#window_cnn.save('backup')

# Test on a previously unseen image

In [None]:
# Index of the image used for the qualitative check below. Training uses
# patched_imgs[:N], so index N is the first image the network has not seen;
# index 0 would only show how well the training image was fitted.
assert N < len(imgs), 'no held-out image left: N covers the whole training set'
i = N

In [None]:
# Show the selected input image.
plt.imshow(imgs[i])

In [None]:
# Extract the windows of the selected image with the same arguments that were
# used to build the training windows.
window_t = image_to_neighborhoods(patched_imgs[i], WINDOW_SIZE, True)

In [None]:
# Raw network outputs: one row per window, two scores per row (one per output unit).
preds = window_cnn.predict(window_t)

In [None]:
# Plot all output scores in ascending order to inspect how they are distributed.
plt.plot(np.sort(np.ravel(preds)))

In [None]:
# Hard decision per window: 1 when the second score beats the first, else 0.
# Note: this overwrites the raw scores, so re-run the predict cell before
# re-running this one.
preds = (preds[:, 1] > preds[:, 0]).astype(int)

In [None]:
# Lay the per-window decisions out on the PATCHED_SIZE x PATCHED_SIZE patch grid.
f = preds.reshape(PATCHED_SIZE, PATCHED_SIZE)

In [None]:
# Patch-level ground truth of the selected image, in greyscale.
plt.imshow(patched_gts[i], cmap='Greys_r')

In [None]:
# Patch-level prediction of the network for the same image.
plt.imshow(f, cmap='Greys_r')

In [None]:
# Ground truth (left) and prediction (right) side by side for direct comparison.
plt.imshow(np.hstack([patched_gts[i], f]), cmap='Greys_r')

# Test data set

In [None]:
TEST_DIR = 'test_set_images/'
# os.listdir returns entries in arbitrary order; sort them so the order of
# `test_imgs` (and of every array derived from it) is reproducible and can be
# mapped back to `test_files`.
test_files = sorted(os.listdir(TEST_DIR))

# Each test image sits in its own folder: <TEST_DIR>/<name>/<name>.png
test_imgs = [load_image(TEST_DIR + file + '/' + file + '.png') for file in test_files]

In [None]:
# Patch every test image with the same PATCH_SIZE as the training images and
# stack the results (np.stack requires all patched images to share one shape).
patched_tests = np.stack([patch_image(im, PATCH_SIZE) for im in test_imgs])

In [None]:
# Windows for the TEST images. The source must be `patched_tests`; building
# them from `patched_imgs` would silently run the predictions on the training
# images instead.
test_windows = np.vstack([image_to_neighborhoods(im, WINDOW_SIZE, True) for im in patched_tests])

In [None]:
# Display the shape of the stacked test windows as a sanity check.
test_windows.shape

In [None]:
# Raw network outputs for every test window (two scores per window).
predictions = window_cnn.predict(test_windows)