In [2]:
from __future__ import absolute_import, division, print_function, unicode_literals
import tensorflow as tf
import cv2
import numpy as np
import matplotlib.pyplot as plt
import tensorflow_datasets as tfds

from tqdm import tqdm_notebook as tqdm
from copy import copy, deepcopy
from tensorflow.keras import datasets, layers, models, Sequential
from skimage import color

from utils import *
from sklearn.model_selection import train_test_split
from keras.callbacks import ModelCheckpoint, TensorBoard
BATCH_SIZE = 16

In [17]:
load_data = tfds.load("cifar10")
train, test = load_data["train"], load_data["test"]

rgb_images = [np.array(data["image"])
                for data in train]

lab_images = [np.array(cv2.cvtColor(rgb_img, cv2.COLOR_RGB2LAB))
                for rgb_img in rgb_images]

In [26]:
def data_generator(images):
    for image in images:
        try:
            lab_image = np.array(cv2.cvtColor(image, cv2.COLOR_RGB2LAB))
            yield lab_image[:,:,0], lab_image[:,:,1], lab_image[:,:,2]
        except cv2.error:
            print("/!\\ CV2 ERROR /!\\")

In [64]:
N_BINS = 313

from tensorflow.keras.layers import *
class Model(tf.keras.Model):
    def __init__(self):
        super(Model, self).__init__()

        ############################
        #########  Conv 1  #########
        ############################

        # (batch_size, 32, 32, 1) --> (batch_size, 16, 16, 8)
        self.conv_1_1 = Conv2D(filters=8, kernel_size=3,
                               padding='same',
                               activation='relu',
                               input_shape=(32, 32, 1))
        self.conv_1_2 = Conv2D(filters=8, kernel_size=3,
                               strides=(2, 2),
                               padding='same',
                               activation='relu')
        self.bn_1 = BatchNormalization()

        ############################
        #########  Conv 2  #########
        ############################

        # (batch_size, 16, 16, 8) --> (batch_size, 8, 8, 16)
        self.conv_2_1 = Conv2D(filters=16, kernel_size=3,
                               padding='same',
                               activation='relu')
        self.conv_2_2 = Conv2D(filters=16, kernel_size=3,
                               strides=(2, 2),
                               padding='same',
                               activation='relu')
        self.bn_2 = BatchNormalization()

        ############################
        #########  Conv 3  #########
        ############################

        # (batch_size, 8, 8, 16)  --> (batch_size, 4, 4, 32)
        self.conv_3_1 = layers.Conv2D(filters=32, kernel_size=3,
                                      padding='same',
                                      activation='relu')
        self.conv_3_2 = layers.Conv2D(filters=32, kernel_size=3,
                                      padding='same',
                                      activation='relu')
        self.conv_3_3 = layers.Conv2D(filters=32, kernel_size=3,
                                      strides=(2, 2),
                                      padding='same',
                                      activation='relu')
        self.bn_3 = BatchNormalization()

        ############################
        #########  Conv 4  #########
        ############################

        # (batch_size, 4, 4, 32) --> (batch_size, 4, 4, 64)
        self.conv_4_1 = Conv2D(filters=64, kernel_size=3,
                               strides=(1, 1),
                               padding='same',
                               activation='relu')
        self.conv_4_2 = Conv2D(filters=64, kernel_size=3,
                               strides=(1, 1),
                               padding='same',
                               activation='relu')
        self.conv_4_3 = Conv2D(filters=64, kernel_size=3,
                               strides=(1, 1),
                               padding='same',
                               activation='relu')
        self.bn_4 = BatchNormalization()

        ############################
        #########  Conv 5  #########
        ############################

        # (batch_size, 4, 4, 64) --> (batch_size, 4, 4, 64)
        self.conv_5_1 = Conv2D(filters=64, kernel_size=3,
                               strides=(1, 1),
                               padding='same',
                               activation='relu',
                               dilation_rate=2)
        self.conv_5_2 = Conv2D(filters=64, kernel_size=3,
                               strides=(1, 1),
                               padding='same',
                               activation='relu',
                               dilation_rate=2)
        self.conv_5_3 = Conv2D(filters=64, kernel_size=3,
                               strides=(1, 1),
                               padding='same',
                               activation='relu',
                               dilation_rate=2)
        self.bn_5 = BatchNormalization()

        ############################
        #########  Conv 6  #########
        ############################

        # (batch_size, 4, 4, 64) --> (batch_size, 4, 4, 64)
        self.conv_6_1 = Conv2D(filters=64, kernel_size=3,
                               padding='same',
                               activation='relu',
                               dilation_rate=2)
        self.conv_6_2 = Conv2D(filters=64, kernel_size=3,
                               padding='same',
                               activation='relu',
                               dilation_rate=2)
        self.conv_6_3 = Conv2D(filters=64, kernel_size=3,
                                  padding='same',
                                  activation='relu',
                                  dilation_rate=2)
        self.bn_6 = BatchNormalization()

        ############################
        #########  Conv 7  #########
        ############################

        # (batch_size, 4, 4, 64) --> (batch_size, 4, 4, 64)
        self.conv_7_1 = Conv2D(filters=64, kernel_size=3,
                               padding='same',
                               activation='relu',
                               dilation_rate=1)
        self.conv_7_2 = Conv2D(filters=64, kernel_size=3,
                               padding='same',
                               activation='relu',
                               dilation_rate=1)
        self.conv_7_3 = Conv2D(filters=64, kernel_size=3,
                               padding='same',
                               activation='relu',
                               dilation_rate=1)
        self.bn_7 = BatchNormalization()

        ############################
        #########  Deconv  #########
        ############################

        # (batch_size, 4, 4, 64) --> (batch_size, 32, 32, 32)
        self.deconv_1_1 = Conv2DTranspose(filters=32, kernel_size=4,
                                          strides=(2, 2),
                                          padding='same',
                                          activation='relu',
                                          dilation_rate=1)
        self.deconv_1_2 = Conv2DTranspose(filters=32, kernel_size=3,
                                          strides=(2, 2),
                                          padding='same',
                                          activation='relu',
                                          dilation_rate=1)
        self.deconv_1_3 = Conv2DTranspose(filters=32, kernel_size=3,
                                          strides=(2, 2),
                                          padding='same',
                                          activation='relu',
                                          dilation_rate=1)

        ############################
        ####  Unary prediction  ####
        ############################

        # (batch_size, 32, 32, 32) --> (batch_size, 32, 32, 1)
        self.conv_a = Conv2D(filters=1,
                             kernel_size=1,
                             strides=(1, 1),
                             dilation_rate=1)
        self.conv_b = Conv2D(filters=1,
                             kernel_size=1,
                             strides=(1, 1),
                             dilation_rate=1)
        
        self.seq_layers = [self.conv_1_1, self.conv_1_2, self.bn_1,
                           self.conv_2_1, self.conv_2_2, self.bn_2,
                           self.conv_3_1, self.conv_3_2, self.conv_3_3, self.bn_3,
                           self.conv_4_1, self.conv_4_2, self.conv_4_3, self.bn_4,
                           self.conv_5_1, self.conv_5_2, self.conv_5_3, self.bn_5,
                           self.conv_6_1, self.conv_6_2, self.conv_6_3, self.bn_6,
                           self.conv_7_1, self.conv_7_2, self.conv_7_3, self.bn_7,
                           self.deconv_1_1, self.deconv_1_2, self.deconv_1_3]

    def call(self, inputs):
        x = inputs
        for layer in self.seq_layers:
            x = layer(x)
        probs_a = self.conv_a(x)
        probs_b = self.conv_b(x)
        return probs_a, probs_b
        
model = Model()

In [27]:
def train_generator():
    while True:
        gen = data_generator(train_images)
        for features, labels_a, labels_b in gen:
            inputs = []
            targets_a = []
            targets_b = []
            for i in range(BATCH_SIZE):
                #print(np.expand_dims(labels, -1)[:,:,0].shape)
                inputs.append(np.expand_dims(features, -1))
                targets_a.append(np.expand_dims(labels_a, -1))
                targets_b.append(np.expand_dims(labels_b, -1))
            yield np.array(inputs), np.array(targets_a), np.array(targets_b)

In [59]:
train_gen = train_generator()

In [61]:
#img_example=next(train_gen)[0][0,:,:,0]
#plt.imshow(img_example)

In [62]:
import datetime

class MyCustomCallback(tf.keras.callbacks.Callback):

  def on_train_batch_begin(self, batch, logs=None):
    print('Training: batch {} begins at {}'.format(batch, datetime.datetime.now().time()))

  def on_train_batch_end(self, batch, logs=None):
    print('Training: batch {} ends at {}'.format(batch, datetime.datetime.now().time()))

  def on_test_batch_begin(self, batch, logs=None):
    print('Evaluating: batch {} begins at {}'.format(batch, datetime.datetime.now().time()))

  def on_test_batch_end(self, batch, logs=None):
    print('Evaluating: batch {} ends at {}'.format(batch, datetime.datetime.now().time()))

In [63]:
model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',)

last_epoch = 0

model.fit_generator(train_gen, steps_per_epoch=120000/BATCH_SIZE, verbose=1,
                    callbacks=[MyCustomCallback()], epochs=5, initial_epoch=last_epoch)

model.save_weights("weights2.h5")

Epoch 1/5
Training: batch 0 begins at 13:21:26.691898


ValueError: Found a sample_weight array with shape (16, 32, 32, 1). In order to use timestep-wise sample weights, you should specify sample_weight_mode="temporal" in compile(). If you just mean to use sample-wise weights, make sure your sample_weight array is 1D.