In [None]:
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from tqdm import tqdm

In [None]:
import tensorflow as tf
import keras
import keras.models as km
import keras.layers as kl
import keras.utils as ku
from keras.preprocessing.image import ImageDataGenerator
from keras.callbacks import TensorBoard
import skimage.color as skc
from skimage.transform import resize
from skimage.io import imsave, imshow
from keras.applications.vgg19 import VGG19

In [None]:
# Full VGG19 with ImageNet weights and its classifier head (the Keras defaults, spelled out).
vgg_model = VGG19(weights='imagenet', include_top=True)

In [None]:
# Encoder = VGG19 with its last five layers dropped (in the standard Keras VGG19
# layout: the dense classifier head plus the final pooling layer), so it emits
# spatial feature maps rather than class scores.
encoder_model = km.Sequential(layers=vgg_model.layers[:-5])

In [None]:
# Print the encoder's layer table to check which VGG19 layers were kept.
encoder_model.summary()

In [None]:
# Freeze the VGG19 layers so the pretrained feature extractors stay fixed.
# The layer objects are shared with vgg_model, so this freezes them there too.
for layer in encoder_model.layers:
    layer.trainable = False

In [None]:
# Root directory of the training images (flow_from_directory reads images from
# sub-directories of this path). Set the TRAIN_PATH environment variable to run
# outside the /app container layout; the default is the original path.
TRAIN_PATH = os.environ.get('TRAIN_PATH', '/app/data/imagenet_data/train/')

In [None]:
# Rescale 8-bit pixel values into [0, 1] before they reach rgb2lab.
train_datagen = ImageDataGenerator(rescale=1 / 255)

In [None]:
# Stream 224x224 RGB batches of 128 images from TRAIN_PATH.
# class_mode=None: batches contain only images, no labels.
train = train_datagen.flow_from_directory(
    directory=TRAIN_PATH,
    class_mode=None,
    batch_size=128,
    target_size=(224, 224),
)

In [None]:
# Number of full batches; a trailing partial batch, if any, is not counted.
batches = divmod(train.n, train.batch_size)[0]
batches

In [None]:
# Sanity pass: load every full batch once so unreadable or corrupt images show
# up here, with their batch index, instead of in the middle of training.
# (Replaces a leftover debug print that dumped the whole of batch 428.)
failed_batches = []
for i in tqdm(range(batches)):
    try:
        train[i]
    except Exception as exc:
        failed_batches.append((i, repr(exc)))
failed_batches

In [None]:
def create_XY(data):
    """Split RGB images into L (model input) and ab (model target) channels.

    Args:
        data: iterable of RGB images accepted by ``skimage.color.rgb2lab``
            (in this notebook, batches from ``train``, rescaled to [0, 1]).

    Returns:
        X: array of L channels with a trailing singleton channel axis,
           shape ``(n, H, W, 1)``.
        Y: array of ab channels divided by 128 (roughly [-1, 1]),
           shape ``(n, H, W, 2)``.

    Images that fail conversion are skipped with a printed message, so ``n``
    can be smaller than ``len(data)``.
    """
    X = []
    Y = []
    for idx, img in enumerate(data):
        try:
            lab = skc.rgb2lab(img)
            l_channel = lab[:, :, 0]
            ab_channels = lab[:, :, 1:] / 128
        except Exception as exc:
            # Skip the bad image, but report why. A bare except would also
            # swallow KeyboardInterrupt/SystemExit.
            print(f'error converting image {idx}: {exc!r}')
            continue
        # Append only after both slices succeeded so X and Y stay aligned.
        X.append(l_channel)
        Y.append(ab_channels)
    X = np.array(X)
    Y = np.array(Y)
    X = X.reshape(X.shape + (1,))
    return X, Y

In [None]:
def out_vgg(X, vgg_model):
    """Run the encoder over a stack of single-channel (L) images.

    Args:
        X: array of L-channel images, shape ``(n, 224, 224, 1)`` as produced
            by ``create_XY``.
        vgg_model: Keras model taking ``(batch, 224, 224, 3)`` inputs whose
            output can be reshaped to ``(batch, 14, 14, 512)``.

    Returns:
        Array of encoder features, shape ``(n, 14, 14, 512)``; an empty
        array when ``X`` is empty.
    """
    if len(X) == 0:
        # Keep the old empty-input result instead of calling predict on nothing.
        return np.array([])
    # Replicate the L channel into three identical channels to fit VGG's RGB input.
    # NOTE(review): L is in [0, 100] and is not passed through VGG19's
    # preprocess_input; confirm this is intended.
    rgb_batch = np.stack([skc.gray2rgb(sample).reshape((224, 224, 3)) for sample in X])
    # One batched predict call instead of one call per image; per-call overhead
    # dominated the original loop.
    predictions = vgg_model.predict(rgb_batch, verbose=0)
    return predictions.reshape((-1, 14, 14, 512))

In [None]:
def run_encoder_vgg(data, vgg_model):
    """Turn a batch of RGB images into (encoder features, scaled ab targets).

    Splits the images into L / ab with ``create_XY``, then encodes the L
    channels with ``out_vgg`` using ``vgg_model``.
    """
    inputs, targets = create_XY(data)
    return out_vgg(inputs, vgg_model), targets

In [None]:
# TensorBoard logging for decoder training: graph and weight images on,
# weight histograms off.
tensorboard_callback = TensorBoard(
    log_dir='/app/vgg19_tensorboard_logs',
    write_graph=True,
    write_images=True,
    histogram_freq=0,
)

In [None]:
# Decoder: maps 14x14x512 encoder features to a 224x224x2 ab image.
# Four 2x upsamplings go 14 -> 28 -> 56 -> 112 -> 224. tanh keeps outputs in
# [-1, 1], matching ab targets scaled by 1/128.
# NOTE(review): the last UpSampling2D comes after the 2-channel output conv, so
# ab is effectively predicted at 112x112 and nearest-neighbour upsampled.
decoder_model = km.Sequential()
decoder_model.add(kl.Conv2D(256, (3, 3), activation='relu', padding='same', input_shape=(14, 14, 512)))
decoder_model.add(kl.Conv2D(128, (3, 3), activation='relu', padding='same'))
decoder_model.add(kl.Conv2D(64, (3, 3), activation='relu', padding='same'))
decoder_model.add(kl.UpSampling2D((2, 2)))
decoder_model.add(kl.Conv2D(32, (3, 3), activation='relu', padding='same'))
decoder_model.add(kl.UpSampling2D((2, 2)))
decoder_model.add(kl.Conv2D(16, (3, 3), activation='relu', padding='same'))
decoder_model.add(kl.UpSampling2D((2, 2)))
decoder_model.add(kl.Conv2D(2, (3, 3), activation='tanh', padding='same'))
decoder_model.add(kl.UpSampling2D((2, 2)))

# Colorization is regression on continuous ab values, so track mean absolute
# error. 'accuracy' would resolve to a classification metric (argmax over the
# two ab channels), which is meaningless here.
decoder_model.compile(optimizer='adam', loss='mse', metrics=['mae'])

In [None]:
# Fractional batch count; the part after the decimal point is the partial final batch.
float(train.n) / train.batch_size

In [None]:
# Train the decoder chunk by chunk. Each loaded batch (indices start..end-1) is
# encoded once, then fit for 30 epochs with 10% of it held out for validation.
# Every chunk's History object is collected in `hists`.
hists = []
start, end = 0, 30
for i in tqdm(range(start, end)):
    vgg_features, Y = run_encoder_vgg(train[i], encoder_model)
    hist = decoder_model.fit(
        vgg_features,
        Y,
        epochs=30,
        batch_size=32,
        validation_split=0.1,
        callbacks=[tensorboard_callback],
        verbose=0,
    )
    hists.append(hist)

In [None]:
# One DataFrame per training chunk (one row per epoch), stacked into one table.
# The index is the epoch number within each chunk, so it repeats across chunks.
concatted_histories = [pd.DataFrame(hist.history) for hist in hists]
df_hist = pd.concat(concatted_histories)
# to_csv does not create missing directories.
os.makedirs('/app/output', exist_ok=True)
df_hist.to_csv('/app/output/transfer_learning_histories.csv')

In [None]:
# Directory of color test images used for the colorization demo. Set the
# TEST_PATH environment variable to run outside the /app container layout;
# the default is the original path.
TEST_PATH = os.environ.get('TEST_PATH', '/app/data/imagenet_data/test/color/')

In [None]:
# os.listdir returns entries in arbitrary order; sort so every run picks the
# same 100 test files.
files = sorted(os.listdir(TEST_PATH))[:100]

In [None]:
def predict_grayscal2rgb(file_paths):
    """Colorize test images using only their lightness (L) channel.

    Each file is loaded from TEST_PATH, resized to 224x224 and converted to
    Lab; only the L channel is kept. L, replicated to three channels, goes
    through ``encoder_model``. ``decoder_model`` predicts the ab channels,
    and the recombined Lab image is converted back to RGB.

    Args:
        file_paths: file names relative to TEST_PATH.

    Returns:
        List of uint8 RGB arrays of shape ``(224, 224, 3)``, one per file.
    """
    rgb_images = []
    for file_name in tqdm(file_paths):
        img = ku.img_to_array(ku.load_img(os.path.join(TEST_PATH, file_name)))
        img = resize(img, (224, 224), anti_aliasing=True)
        img *= 1.0 / 255  # bring 0-255 pixel values into [0, 1] for rgb2lab
        lab = skc.rgb2lab(img)
        l = lab[:, :, 0]
        L = skc.gray2rgb(l).reshape((1, 224, 224, 3))
        vggpred = encoder_model.predict(L, verbose=0)
        # Drop the batch axis and undo the 1/128 scaling applied to training targets.
        ab = decoder_model.predict(vggpred, verbose=0)[0] * 128

        cur = np.zeros((224, 224, 3))
        cur[:, :, 0] = l
        cur[:, :, 1:] = ab

        rgb_img = skc.lab2rgb(cur)
        # lab2rgb returns floats in [0, 1]. Scale by 255, not 256: 1.0 * 256 = 256
        # overflows uint8 and turned white pixels black.
        rgb_img = np.clip(rgb_img * 255, 0, 255).astype(np.uint8)
        rgb_images.append(rgb_img)

    return rgb_images

In [None]:
# Colorize the selected test images from their L channel.
rgb_images = predict_grayscal2rgb(file_paths=files)

In [None]:
# Display the colorized images in a 10x10 grid. Every axis is hidden first, so
# empty cells stay blank when fewer than 100 images are available. zip()
# stops at the grid size instead of raising IndexError for extra images.
fig, ax = plt.subplots(10, 10, figsize=(10, 10))
for axis in ax.flat:
    axis.axis('off')
for axis, rgb_img in zip(ax.flat, rgb_images):
    axis.imshow(rgb_img)
fig.suptitle('Colorized test images (L channel in, predicted ab out)')
plt.show()