# Training an image colorization model

## imports

In [0]:
import tensorflow as tf
import keras
from keras.layers import Conv2D, Conv2DTranspose, MaxPooling2D, BatchNormalization, UpSampling2D
from keras.layers import Activation, Dropout, Flatten, Dense, MaxPool2D
from keras.models import Sequential, load_model
from keras.preprocessing.image import ImageDataGenerator, load_img, array_to_img, img_to_array, image
from keras.callbacks import ModelCheckpoint
# from keras.callbacks import *
from keras.utils import plot_model
from keras import backend as kb

import numpy as np
import glob
import cv2
import pickle

## Configuration

In [0]:
# Paths are placeholders ("..."): point them at your own storage before running.
# checkpoint_path: path passed to the ModelCheckpoint callback in the training cell.
# chunks_path: directory holding the pickled X*/y* training chunks. The trailing
# slash matters because file patterns are appended by plain string concatenation.
checkpoint_path = ".../variables/checkpoints"
chunks_path = ".../variables/dataset/"

## data preparation

In [0]:
# Collect the pickled chunk files: X*.pickle hold the inputs, y*.pickle the targets.
# glob.glob returns files in arbitrary (filesystem) order, so both lists are
# sorted to keep X_paths[i] and y_paths[i] referring to the same chunk.
# This assumes each X/y pair shares the same suffix (e.g. X_003 / y_003) — TODO confirm.
X_paths = sorted(glob.glob(chunks_path + 'X*.pickle'))
y_paths = sorted(glob.glob(chunks_path + 'y*.pickle'))

# Fail early with a clear message instead of crashing later during loading.
if not X_paths:
  raise FileNotFoundError("no X*.pickle chunks found in {!r}".format(chunks_path))
if len(X_paths) != len(y_paths):
  raise ValueError("found {} X chunks but {} y chunks in {!r}".format(
      len(X_paths), len(y_paths), chunks_path))

train_batches = len(X_paths)

In [0]:
# Load every (X, y) chunk pair, cast it to float16 and stack everything into
# single X / y arrays. Chunks are collected in lists and concatenated once at
# the end. Calling np.concatenate inside the loop would re-copy the growing
# array on every iteration (quadratic copying).
# NOTE: pickle.load can execute arbitrary code; only load chunk files you created.
if not X_paths:
  raise ValueError("no chunks to load: X_paths is empty")
if len(X_paths) != len(y_paths):
  raise ValueError("X_paths has {} entries but y_paths has {}".format(
      len(X_paths), len(y_paths)))

X_chunks = []
y_chunks = []
n_chunks = len(X_paths)

for i, (X_path, y_path) in enumerate(zip(X_paths, y_paths)):
  print("---- "+str(i)+"/"+str(n_chunks)+" ----")

  with open(X_path, 'rb') as f:
    X_chunks.append(pickle.load(f).astype('float16'))
  with open(y_path, 'rb') as f:
    y_chunks.append(pickle.load(f).astype('float16'))

X = np.concatenate(X_chunks)
y = np.concatenate(y_chunks)

# Drop the per-chunk copies so only the stacked arrays stay in kernel memory.
del X_chunks
del y_chunks

print(X.shape)
print(y.shape)

## model

In [0]:
# Custom losses: euclidean (L2 distance) and categorical cross-entropy.
# Neither is currently used; the model cell compiles with loss='mse'.

def euclidean_loss(y_true, y_pred):
  """Euclidean (L2) distance between y_true and y_pred, reduced over the last axis.

  For the (..., 2)-channel model output this is presumably a per-pixel
  distance in color space — TODO confirm.
  """
  return kb.sqrt(kb.sum(kb.square(y_true - y_pred), axis=-1))


def cross_entropy_loss(y_true, y_pred):
  """Categorical cross-entropy -sum(y_true * log(y_pred)) over the last axis.

  y_pred is clipped to [epsilon, 1 - epsilon] before the log. Without the
  clip, a zero predicted probability gives log(0) = -inf and a NaN loss.
  """
  y_pred = kb.clip(y_pred, kb.epsilon(), 1.0 - kb.epsilon())
  return -1.0 * kb.sum(y_true * kb.log(y_pred), axis=-1)


# Backward-compatible alias for the original (misspelled) name.
cross_entorpy_loss = cross_entropy_loss

In [0]:
# Define the colorization network:
#   conv1-3 (VGG-style, each ending in a stride-2 conv) -> SIZE/8,
#   conv4, dilated conv5-6, conv7 at SIZE/8,
#   conv8 transposed conv (x2) -> SIZE/4,
#   313-way softmax, 1x1 decoding conv to 2 channels, x4 upsampling -> SIZE.

# SIZE was never defined in this notebook, so a fresh kernel raised NameError.
# Take it from the loaded training data so the input layer always matches it.
SIZE = X.shape[1]
# The network downsamples by 8 and upsamples by 2 * 4, so the output only
# matches the input resolution when SIZE is divisible by 8 (and images are square).
assert SIZE % 8 == 0 and X.shape[2] == SIZE, \
  "expected square inputs with size divisible by 8, got {}".format(X.shape[1:3])


def _add_conv_block(model, filters, n_convs, last_strides=1, **first_kwargs):
  """Append n_convs 3x3 'same' convs, then ReLU + BatchNorm.

  Only the last conv uses `last_strides`; `first_kwargs` (e.g. input_shape)
  are passed to the first conv only.
  """
  for k in range(n_convs):
    strides = last_strides if k == n_convs - 1 else 1
    extra = first_kwargs if k == 0 else {}
    model.add(Conv2D(filters, (3,3), padding="same", strides=strides, **extra))
  model.add(Activation('relu'))
  model.add(BatchNormalization())


def _add_dilated_block(model, filters, n_convs, dilation):
  """Append n_convs dilated 3x3 'valid' convs, each preceded by explicit zero padding, then ReLU + BatchNorm.

  Padding by `dilation` on each side keeps the spatial size unchanged for a
  3x3 kernel (effective kernel size 2 * dilation + 1).
  """
  for _ in range(n_convs):
    model.add(keras.layers.ZeroPadding2D(padding=(dilation, dilation)))
    model.add(Conv2D(filters, (3,3), dilation_rate=dilation, padding="valid"))
  model.add(Activation('relu'))
  model.add(BatchNormalization())


model = Sequential()

_add_conv_block(model, 64, 2, last_strides=2, input_shape=(SIZE, SIZE, 1))  # conv 1
_add_conv_block(model, 128, 2, last_strides=2)                              # conv 2
_add_conv_block(model, 256, 3, last_strides=2)                              # conv 3
_add_conv_block(model, 512, 3)                                              # conv 4
_add_dilated_block(model, 512, 3, dilation=2)                               # conv 5
_add_dilated_block(model, 512, 3, dilation=2)                               # conv 6
_add_conv_block(model, 512, 3)                                              # conv 7

# conv 8: upsample x2 with a transposed conv, then two 4x4 convs.
# NOTE(review): unlike conv1-7 there is no ReLU/BatchNorm after conv 8, and in
# every block only the last conv is followed by a ReLU — confirm this is intended.
model.add(Conv2DTranspose(256, (4,4), padding="same", dilation_rate=1, strides=2))
model.add(Conv2D(256, (4,4), padding="same", dilation_rate=1))
model.add(Conv2D(256, (4,4), padding="same", dilation_rate=1))

# softmax over 313 channels per pixel (presumably quantized color bins — TODO confirm)
model.add(Conv2D(313, (1,1), padding="same", dilation_rate=1))
model.add(BatchNormalization())
model.add(Activation('softmax'))

# decoding: 1x1 conv down to 2 output channels, then x4 upsampling back to SIZE x SIZE
model.add(Conv2D(2, (1,1), padding="same", dilation_rate=1))
model.add(UpSampling2D((4,4)))

# 'accuracy' is not meaningful for this mse regression output. It is kept because
# the checkpoint callback monitors 'val_accuracy'; 'mae' is added as an
# interpretable error metric.
model.compile(optimizer='adam', loss='mse', metrics=['accuracy', 'mae'])
# alternative: model.compile(optimizer='adam', loss=euclidean_loss, metrics=['accuracy'])

model.summary()

## train

In [0]:
# Write the model to checkpoint_path after every epoch (save_best_only=False).
# If training is interrupted, it can then be resumed from the latest saved state.
checkpoint = ModelCheckpoint(
  checkpoint_path,
  monitor='val_accuracy',
  mode='max',
  save_best_only=False,
  verbose=1,
)
callbacks_list = [checkpoint]

# Train on all loaded chunks, holding out 10% of the samples for validation.
model.fit(
  X, y,
  batch_size=50,
  epochs=100,
  validation_split=0.1,
  callbacks=callbacks_list,
  verbose=1,
)

In [0]:
# Persist the trained model to disk under the ".model" name.
model_save_path = ".../models/image_colorization.model"
model.save(model_save_path)

In [0]:
# If training stopped early, use the following cells to reload the latest
# checkpoint and continue training.
# Load from checkpoint_path, the file the ModelCheckpoint callback writes;
# no cell in this notebook writes a "last_checkpoint.hdf5".
# keras.models.load_model (imported above) is used so the model is reloaded by
# the same Keras package that built and saved it, rather than tf.keras.
filepath = checkpoint_path

model_cont = load_model(filepath)

In [0]:
# Continue training the reloaded model for the desired number of remaining epochs.
# Use the same validation split and per-epoch checkpointing as the original
# run, so progress is not lost if training stops again. The callback is created
# here rather than reused from the training cell, because re-running that cell
# would restart training from scratch.
remaining_epochs = 20
resume_checkpoint = ModelCheckpoint(checkpoint_path, monitor='val_accuracy', verbose=1, save_best_only=False, mode='max')

model_cont.fit(X, y, epochs=remaining_epochs, validation_split=0.1, batch_size=50, callbacks=[resume_checkpoint], verbose=1)