In [1]:
import numpy as np

In [2]:
# Seed NumPy's global RNG up front; CountsIterator draws patch positions from np.random.
RANDOM_SEED = 21
np.random.seed(RANDOM_SEED)

In [11]:
from keras.preprocessing.image import Iterator
from keras import backend as K
import matplotlib.pyplot as plt
import cv2
import os
import json
import pickle
%matplotlib inline

In [64]:
class CountsIterator(Iterator):
    """Keras ``Iterator`` yielding random image patches labelled with per-type dot counts.

    Each iterator index ``j`` maps to image ``image_ids[j // n_samples_per_image]``.
    For every index in a batch, that image and its dot annotations are loaded,
    ``n_samples_per_block`` patches of ``target_size`` are cropped at random
    positions, and each patch is labelled with the number of dots of each
    sea-lion type whose coordinates fall inside it.

    Expected layout under ``root_dir``:
        Train/<image_id>.jpg      image, read with ``cv2.imread`` (BGR order)
        TrainDots/<image_id>.pkl  pickled sequence with one entry per sea-lion
                                  type (at most ``n_sealion_types`` entries),
                                  each an iterable of ``(x, y)`` coordinates

    # Arguments
        root_dir: directory containing ``Train/`` and ``TrainDots/``.
        image_ids: sequence of image ids; ``str(image_id)`` builds the file names.
        n_samples_per_image: number of iterator indices mapped to each image.
        target_size: ``(height, width)`` of every patch.
        batch_size: requested patches per batch. The base iterator is given
            ``batch_size // n_samples_per_block`` indices per batch, so a full
            batch holds that many indices times ``n_samples_per_block`` patches.
        shuffle: forwarded to the base ``Iterator``.
        seed: forwarded to the base ``Iterator``.
        debug_dir: if set, each batch's patches are written there as
            ``patch_<i>.jpg`` (the next batch overwrites them).
        n_samples_per_block: patches cropped from each loaded image.

    # Raises
        ValueError: if ``n_samples_per_block < 1`` or
            ``batch_size < n_samples_per_block``.
    """

    # Per-channel means subtracted from B, G, R pixels by ``normalize_input``.
    # NOTE(review): presumably the ImageNet means of Caffe-style Keras
    # preprocessing -- confirm they match the backbone being fine-tuned.
    _BGR_MEAN = (103.939, 116.779, 123.68)

    def __init__(self, root_dir, image_ids,
                 n_samples_per_image=160,
                 target_size=(1024, 1024),
                 batch_size=32, shuffle=True, seed=42, debug_dir=None,
                 n_samples_per_block=4):
        if n_samples_per_block < 1:
            raise ValueError(
                "n_samples_per_block must be >= 1, got {}".format(n_samples_per_block))
        if batch_size < n_samples_per_block:
            # batch_size // n_samples_per_block would be 0: every batch would be empty.
            raise ValueError(
                "batch_size ({}) must be >= n_samples_per_block ({})".format(
                    batch_size, n_samples_per_block))

        self.n_sealion_types = 5
        self.image_ids = image_ids
        self.root_dir = root_dir
        self.debug_dir = debug_dir
        self.n_samples_per_block = n_samples_per_block
        self.n_samples_per_image = n_samples_per_image
        self.target_size = target_size
        self.n_indices = len(self.image_ids) * self.n_samples_per_image

        super(CountsIterator, self).__init__(
            self.n_indices, batch_size // self.n_samples_per_block, shuffle, seed)

    def normalize_input(self, x_bgr):
        """Subtract the per-channel BGR means from ``x_bgr`` in place.

        ``x_bgr`` must be a float array whose last axis holds B, G, R.
        Returns the same (modified) array.
        """
        for channel, mean in enumerate(self._BGR_MEAN):
            x_bgr[..., channel] -= mean
        return x_bgr

    def denormalize_input(self, x_normed):
        """Inverse of ``normalize_input``: add the BGR means back in place.

        Returns the same (modified) array.
        """
        for channel, mean in enumerate(self._BGR_MEAN):
            x_normed[..., channel] += mean
        return x_normed

    def random_transform(self, im, y=None):
        """Randomly rotate by 90 degrees and/or flip ``im``, applying the same ops to ``y``.

        # Arguments
            im: image array with spatial axes first, ``(height, width, ...)``.
            y: optional array with the same spatial layout (e.g. a mask).

        # Returns
            ``(im, y)`` after the transforms; ``y`` stays ``None`` if not given.

        The rotation is only applied to square inputs so that the output keeps
        the input shape and still fits a fixed-size batch slot.
        """
        flips = np.random.randint(0, 2, (3,))
        if flips[0] and im.shape[0] == im.shape[1]:
            im = np.rot90(im)
            if y is not None:
                y = np.rot90(y)
        if flips[1]:
            im = np.flipud(im)
            if y is not None:
                y = np.flipud(y)
        if flips[2]:
            im = np.fliplr(im)
            if y is not None:
                y = np.fliplr(y)
        return im, y

    def sample(self, im, dots):
        """Crop ``n_samples_per_block`` random patches from ``im`` and count dots in each.

        # Arguments
            im: ``(height, width, 3)`` image array.
            dots: per-type iterables of ``(x, y)`` coordinates (see class docstring).

        # Returns
            ``(batch_x, batch_y)``: float32 arrays of shape
            ``(n_samples_per_block, target_h, target_w, 3)`` and
            ``(n_samples_per_block, n_sealion_types)``.

        # Raises
            ValueError: if ``im`` is smaller than ``target_size``.
        """
        h, w = im.shape[:2]
        target_h, target_w = self.target_size
        if h < target_h or w < target_w:
            raise ValueError("image of size {}x{} is smaller than target_size {}".format(
                h, w, self.target_size))

        batch_x = np.zeros((self.n_samples_per_block, target_h, target_w, 3), dtype=np.float32)
        batch_y = np.zeros((self.n_samples_per_block, self.n_sealion_types), dtype=np.float32)
        # randint's upper bound is exclusive: +1 so patches may touch the right/bottom
        # edge, and an image exactly target_size in size is valid.
        xs = np.random.randint(0, w - target_w + 1, size=(self.n_samples_per_block,))
        ys = np.random.randint(0, h - target_h + 1, size=(self.n_samples_per_block,))
        for i, (x0, y0) in enumerate(zip(xs, ys)):
            batch_x[i] = im[y0:y0 + target_h, x0:x0 + target_w]
            batch_y[i] = self.get_counts(x0, y0, dots)
        return batch_x, batch_y

    def get_counts(self, xstart, ystart, dots):
        """Count, per sea-lion type, the dots inside the patch starting at ``(xstart, ystart)``.

        The patch covers ``xstart <= x < xstart + width`` and
        ``ystart <= y < ystart + height`` (half-open) for ``target_size``.
        Returns a list of ``n_sealion_types`` ints.
        """
        x1, y1 = xstart, ystart
        x2 = xstart + self.target_size[1]
        y2 = ystart + self.target_size[0]
        counts = [0] * self.n_sealion_types
        for type_index, type_dots in enumerate(dots):
            counts[type_index] = sum(
                1 for (x, y) in type_dots if x1 <= x < x2 and y1 <= y < y2)
        return counts

    def _load_dots(self, image_id):
        """Load the pickled per-type dot coordinates for ``image_id``."""
        path = os.path.join(self.root_dir, "TrainDots", str(image_id) + ".pkl")
        # NOTE(review): pickle.load can execute arbitrary code -- only use trusted files.
        with open(path, "rb") as pfile:
            return pickle.load(pfile)

    def _load_image(self, image_id):
        """Read the BGR training image for ``image_id``; raise IOError if it cannot be read."""
        path = os.path.join(self.root_dir, "Train", str(image_id) + ".jpg")
        im = cv2.imread(path)
        if im is None:
            # cv2.imread reports a missing/unreadable file by returning None, not raising.
            raise IOError("could not read image: {}".format(path))
        return im

    def next(self):
        """Return the next batch (``next`` protocol, as used on Python 2.x).

        # Returns
            ``(batch_x, batch_y)``. ``batch_x`` has shape
            ``(k * n_samples_per_block, target_h, target_w, 3)``, dtype
            ``K.floatx()``, BGR mean-subtracted. ``batch_y`` has shape
            ``(k * n_samples_per_block, n_sealion_types)``, dtype int32, holding
            per-type dot counts. ``k`` is the number of indices in this batch.
        """
        # Only index advancement is done under the lock; loading and cropping are not.
        # NOTE(review): assumes the Keras 2.0.x Iterator API, where index_generator
        # yields (index_array, current_index, current_batch_size) -- confirm version.
        with self.lock:
            index_array, current_index, current_batch_size = next(self.index_generator)

        per_block = self.n_samples_per_block
        n_patches = current_batch_size * per_block
        batch_x = np.zeros((n_patches, self.target_size[0], self.target_size[1], 3),
                           dtype=K.floatx())
        batch_y = np.zeros((n_patches, self.n_sealion_types), dtype=np.int32)

        # Each index selects one image; crop per_block random patches from it.
        for i, j in enumerate(index_array):
            image_id = self.image_ids[j // self.n_samples_per_image]
            dots = self._load_dots(image_id)
            im = self._load_image(image_id)

            x, y = self.sample(im, dots)
            batch_x[i * per_block:(i + 1) * per_block] = x
            batch_y[i * per_block:(i + 1) * per_block] = y

        if self.debug_dir:
            # Written before mean subtraction, i.e. with the raw image pixel values.
            for i in range(batch_x.shape[0]):
                cv2.imwrite(os.path.join(self.debug_dir, "patch_{}.jpg".format(i)), batch_x[i])

        return self.normalize_input(batch_x), batch_y

In [65]:
def load_image_ids(json_path):
    """Load a JSON list of image ids from ``json_path`` and return them as ints."""
    with open(json_path, "r") as jfile:
        return [int(iid) for iid in json.load(jfile)]


train_ids = load_image_ids("../data/sealion/train.json")
val_ids = load_image_ids("../data/sealion/val.json")

# A shared image would leak training data into validation.
assert not set(train_ids) & set(val_ids), "train and val splits overlap"

In [66]:
# Root with Train/ and TrainDots/, shared by both generators.
# Override via the SEALION_DATA_DIR environment variable instead of editing a user-specific path.
SEALION_DATA_DIR = os.environ.get("SEALION_DATA_DIR", "/home/lowik/sealion/data/sealion/")

trainPupsGenerator = CountsIterator(SEALION_DATA_DIR, train_ids)

valPupsGenerator = CountsIterator(SEALION_DATA_DIR, val_ids)

In [67]:
# Pull a single batch from the training generator for the shape checks below.
batch_x, batch_y = next(trainPupsGenerator)

In [61]:
np.shape(batch_y)

(32, 5)

In [62]:
np.shape(batch_x)

(32, 1024, 1024, 3)

In [63]:
np.ravel(batch_y)

array([ 0,  0,  0,  0,  0,  0,  0,  1,  0,  0,  0,  0,  0,  0,  0,  0,  0,
        0,  0,  0,  1,  0,  0,  0,  0,  1,  0,  0,  0,  0,  0,  0,  0,  0,
        0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
        2,  0,  5,  0,  0,  0,  0,  0,  0,  2,  1, 42, 52,  0,  0,  0,  0,
        0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
        0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  1,  0,
        1,  0,  0,  7,  0, 52,  8, 43,  0,  0,  0,  0,  0,  0,  0,  0,  0,
        0,  0,  0,  0,  0,  0,  1,  2,  0,  0,  0,  0,  0,  0,  0,  0,  0,
        0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
        0,  0,  0,  0,  0,  0,  0], dtype=int32)

In [359]:
# NOTE(review): `pups_net` (like `base_model` and `SGD` in the following cells) is not defined
# in any cell shown here, so this fails on a fresh kernel. Its output shape (32, 1, 1, 1) also
# differs from batch_y's (32, 5) -- confirm which target the network is meant to predict.
pups_net.predict(batch_x).shape

(32, 1, 1, 1)

In [360]:
# Freeze every backbone layer (trainable=False) before compiling, so fitting leaves them unchanged.
# NOTE(review): `base_model` is defined outside the cells shown here.
for frozen_layer in base_model.layers:
    frozen_layer.trainable = False

In [361]:
from keras.losses import binary_crossentropy

In [362]:
# NOTE(review): `SGD` is not imported in the cells shown here (presumably keras.optimizers.SGD).
# NOTE(review): binary_crossentropy with 'accuracy' expects targets in [0, 1], but CountsIterator
# yields integer per-type counts (up to 52 in the batch above) -- confirm the intended target.
sgd = SGD(
    lr=0.01,
    momentum=0.9,
    decay=0.0005,
    nesterov=True,
)
pups_net.compile(
    loss=binary_crossentropy,
    optimizer=sgd,
    metrics=['accuracy'],
)

In [363]:
# 3 epochs of 5 training batches each, validated on 5 batches from valPupsGenerator per epoch.
h = pups_net.fit_generator(
    trainPupsGenerator,
    steps_per_epoch=5,
    epochs=3,
    verbose=1,
    callbacks=None,
    validation_data=valPupsGenerator,
    validation_steps=5,
    class_weight=None,
    max_q_size=10,
    workers=1,
    pickle_safe=False,
    initial_epoch=0,
)

Epoch 1/3
Epoch 2/3
Epoch 3/3
