In [1]:
import os
from google.colab import drive
# Mount Google Drive at /content/drive; every dataset/model path used in this
# notebook lives under that mount point. force_remount=True asks for a fresh
# mount even when one is already present.
drive.mount('/content/drive', force_remount=True)

Go to this URL in a browser: https://accounts.google.com/o/oauth2/auth?client_id=947318989803-6bn6qk8qdgf4n4g3pfee6491hc0brc4i.apps.googleusercontent.com&redirect_uri=urn%3aietf%3awg%3aoauth%3a2.0%3aoob&response_type=code&scope=email%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdocs.test%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdrive%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdrive.photos.readonly%20https%3a%2f%2fwww.googleapis.com%2fauth%2fpeopleapi.readonly

Enter your authorization code:
··········
Mounted at /content/drive


In [3]:
# Sanity check: number of entries in the dataset directory.
print(len(os.listdir('/content/drive/My Drive/Colab Notebooks/colordata/')))

8166


In [4]:
import os
import random
import numpy as np
import cv2
import keras
import tensorflow as tf
from tqdm import tqdm
from keras import Input, Model
from keras.applications.inception_resnet_v2 import InceptionResNetV2, preprocess_input, decode_predictions
from keras.layers import UpSampling2D, RepeatVector, Reshape, concatenate
from keras.layers.convolutional import Conv2D
from keras.callbacks import ModelCheckpoint
from keras.backend.tensorflow_backend import set_session
from keras.utils import generic_utils

# Module-level InceptionResNetV2 including its top (classification) layers.
# It is created without weights (weights=None); loadWeight() loads them from a
# local .h5 file, and resnetEmbedding() uses this model's predictions.
resnet = InceptionResNetV2(weights=None, include_top=True)
def getImgList(basePath = '/content/drive/My Drive/Colab Notebooks/colordata/'):
    """Return the entry names found in ``basePath``, sorted ascending.

    ``os.listdir`` yields entries in arbitrary order, while the callers in
    this notebook split the list by position (``[:8000]`` for training,
    ``[8000:8083]`` for testing). Sorting makes that positional split
    identical on every call and in every session, so test images cannot
    silently end up in the training slice.

    Args:
        basePath: Directory to list.

    Returns:
        list of str: entry names (not full paths) in sorted order.
    """
    return sorted(os.listdir(basePath))

def loadWeight(basePath = '/content/drive/My Drive/Colab Notebooks/model/'):
    """Load the InceptionResNetV2 weight file into the module-level ``resnet``.

    Args:
        basePath: Directory prefix of the weight file. It is concatenated
            verbatim with the file name, so it must end with a path separator.
    """
    weightFile = 'inception_resnet_v2_weights_tf_dim_ordering_tf_kernels.h5'
    weightPath = basePath + weightFile
    resnet.load_weights(weightPath)
    print("Resnet Weight loaded success")

def convImg(dataList, channels):
    """Resize images to 256x256 and convert them from BGR to LAB.

    Args:
        dataList: Sequence of images accepted by ``cv2.resize`` /
            ``cv2.cvtColor`` with ``COLOR_BGR2LAB`` (presumably BGR arrays as
            returned by ``cv2.imread`` - confirm against caller).
        channels: Number of channels in the result. When it equals 1 only the
            first LAB channel (L) is kept; otherwise the full LAB image is
            kept, so the value must match that channel count for the final
            reshape to succeed.

    Returns:
        numpy.ndarray of shape ``(len(dataList), 256, 256, channels)``.
    """
    def _toLab(bgr):
        # Resize first, then change colour space, as a single expression.
        lab = cv2.cvtColor(cv2.resize(bgr, (256, 256)), cv2.COLOR_BGR2LAB)
        return lab[:, :, 0] if channels == 1 else lab

    stacked = np.array([_toLab(bgr) for bgr in dataList])
    return stacked.reshape(len(dataList), 256, 256, channels)

def resnetEmbedding(dataList):
    """Run the module-level ``resnet`` on grey versions of the given images.

    Each image is resized to 299x299, converted to greyscale and back to a
    three-channel image (so colour information is discarded), passed through
    ``preprocess_input`` and stacked into one float batch.

    Args:
        dataList: Sequence of BGR images (as expected by
            ``cv2.COLOR_BGR2GRAY``).

    Returns:
        The output of ``resnet.predict`` for the whole batch.
    """
    def _prepare(bgr):
        resized = cv2.resize(bgr, (299, 299))
        grey = cv2.cvtColor(resized, cv2.COLOR_BGR2GRAY)
        return preprocess_input(cv2.cvtColor(grey, cv2.COLOR_GRAY2BGR))

    batch = np.array([_prepare(bgr) for bgr in dataList], dtype=float)
    return resnet.predict(batch)

def preReadData(dataList, basePath = '/content/drive/My Drive/Colab Notebooks/colordata/'):
    """Read every listed image into memory up front.

    Loading everything once works around the Colab Drive read issue noted by
    the original author, instead of hitting Drive during training.

    ``cv2.imread`` does not raise on failure; it returns ``None``. Such
    entries are skipped and reported here, because a stored ``None`` would
    only crash later inside ``cv2.resize`` with an unrelated-looking error.
    Note that skipping shifts the position of every following image.

    Args:
        dataList: Iterable of file names relative to ``basePath``.
        basePath: Directory holding the images. Defaults to the dataset
            directory used throughout this notebook.

    Returns:
        list of images as returned by ``cv2.imread``, in input order, without
        the unreadable entries.
    """
    allData = []
    skipped = []
    for item in tqdm(dataList):
        img = cv2.imread(os.path.join(basePath, item))
        if img is None:
            skipped.append(item)
            continue
        allData.append(img)
    if skipped:
        print("Skipped {} unreadable file(s), e.g. {}".format(len(skipped), skipped[:5]))
    return allData

def getData(allData, batch_size, train=True):
    """Endlessly yield ``([L, embedding], ab)`` batches for the colour model.

    The first 8000 images form the training split and images 8000..8082 the
    validation split. Each batch is converted to LAB via ``convImg``; the L
    channel (with a trailing channel axis added) and the ``resnetEmbedding``
    output are the inputs, and the remaining two channels divided by 128 are
    the target.

    Args:
        allData: Sequence of images (as produced by ``preReadData``).
        batch_size: Positive number of images per batch. The last batch of a
            pass may be smaller.
        train: Select the training split when true, otherwise the validation
            split.

    Yields:
        tuple ``([x, embeddingImg], y)``.

    Raises:
        ValueError: If ``batch_size`` is not positive or the selected split is
            empty. Without this check the endless loop would never yield and
            ``next()`` would hang. As this is a generator, the error surfaces
            on the first ``next()`` call.
    """
    if batch_size <= 0:
        raise ValueError("batch_size must be positive, got {}".format(batch_size))

    split = allData[:8000] if train else allData[8000:8083]
    if len(split) == 0:
        raise ValueError(
            "No images in the {} split ({} image(s) supplied)".format(
                'training' if train else 'validation', len(allData)))

    while True:
        for start in range(0, len(split), batch_size):
            batch = split[start:start + batch_size]
            img = convImg(batch, 3)
            embeddingImg = resnetEmbedding(batch)
            x = img[:, :, :, 0]
            x = x.reshape(x.shape + (1,))  # add the trailing channel axis
            y = img[:, :, :, 1:] / 128
            yield ([x, embeddingImg], y)


Using TensorFlow backend.


In [5]:
def colorNet():
    """Build and compile the two-input colourisation network.

    Inputs are a ``(256, 256, 1)`` image and a 1000-element embedding vector.
    Three stride-2 convolutions bring the image down to 32x32; the embedding
    is repeated 32*32 times, reshaped to ``(32, 32, 1000)`` and concatenated
    on the channel axis. Three 2x upsamplings restore 256x256 with two output
    channels.

    Layers are created in exactly the same order as before so that weight
    files saved from the earlier definition still load.

    Returns:
        A compiled Keras ``Model`` (adam optimizer, mse loss, accuracy metric)
        with inputs ``[encoderInput, embedInput]``.
    """
    def conv(filters, kernel=3, strides=1):
        # Every convolution in this network shares activation and padding.
        return Conv2D(filters, (kernel, kernel), activation='relu',
                      padding='same', strides=strides)

    embedInput = Input(shape=(1000,))
    encoderInput = Input(shape=(256, 256, 1,))

    # Encoder: (filters, strides) per layer.
    encoderSpec = (
        (64, 2), (128, 1),
        (128, 2), (256, 1),
        (256, 2), (512, 1),
        (512, 1), (512, 1),
    )
    encoded = encoderInput
    for filters, strides in encoderSpec:
        encoded = conv(filters, strides=strides)(encoded)

    # Fusion: tile the embedding over the 32x32 grid and append it as channels.
    tiled = RepeatVector(32 * 32)(embedInput)
    tiled = Reshape((32, 32, 1000))(tiled)
    fused = concatenate([encoded, tiled], axis=3)

    # Decoder.
    decoded = conv(256, kernel=1)(fused)
    decoded = conv(128)(decoded)
    decoded = UpSampling2D((2, 2))(decoded)
    decoded = conv(64)(decoded)
    decoded = UpSampling2D((2, 2))(decoded)
    for filters in (32, 16, 2):
        decoded = conv(filters)(decoded)
    output = UpSampling2D((2, 2))(decoded)

    net = Model(inputs=[encoderInput, embedInput], outputs=output)
    net.compile(optimizer='adam', loss='mse', metrics=['accuracy'])
    return net

# Path of the colourisation model's weight file on the mounted drive; the
# training cell writes to it and this cell resumes from it.
modelPath = '/content/drive/My Drive/Colab Notebooks/model/colorNet.hdf5'
model = colorNet()
# Unused Keras checkpoint callback, kept for reference; weights are saved
# manually with model.save_weights in the training loop instead.
#checkpoint = ModelCheckpoint(modelPath, monitor='loss', verbose=1, save_best_only=True)
# Resume from previously saved weights when a checkpoint file exists.
if os.path.exists(modelPath):
    model.load_weights(modelPath)
    print("Check point loaded!")

Check point loaded!


In [None]:
# Load the embedding network's weights and cache the whole dataset in memory.
loadWeight()
allData = preReadData(getImgList())
r_epochs = 0          # number of epochs started so far
epoch_length = 160    # batches per epoch (160 batches x 50 images = 8000)
num_epochs = 200
dataGen = getData(allData,50)
# np.inf rather than np.Inf: the capitalised alias was removed in NumPy 2.0.
bestLoss = np.inf
iter_num = 0
# Per-batch [loss, accuracy] of the current epoch; fully overwritten each epoch.
losses = np.zeros((epoch_length, 2))
for epoch_num in range(num_epochs):
    progbar = generic_utils.Progbar(epoch_length)
    print('Epoch {}/{}'.format(epoch_num + 1, num_epochs))
    r_epochs = epoch_num + 1
    for iter_num in range(1, epoch_length + 1):
        X,Y = next(dataGen)
        modelLoss = model.train_on_batch(X,Y)
        losses[iter_num - 1, 0] = modelLoss[0]
        losses[iter_num - 1, 1] = modelLoss[1]

        progbar.update(iter_num, [('loss', np.mean(losses[:iter_num, 0])), ('acc', np.mean(losses[:iter_num, 1]))])

    # Compare the mean loss of the whole epoch (the value the progress bar
    # reports) instead of the last batch alone, and remember the best value.
    # Previously bestLoss was never updated, so every epoch overwrote the
    # saved weights regardless of quality.
    epochLoss = np.mean(losses[:, 0])
    if epochLoss < bestLoss:
        bestLoss = epochLoss
        model.save_weights(modelPath)
    iter_num = 0


In [13]:
def test():
    """Colourise the held-out images (entries 8000..8082) and write JPEGs.

    For each readable image the L channel from ``convImg`` is combined with
    the model's predicted two channels (scaled by 128), converted from LAB to
    BGR and written to the output directory as ``test_img<i>.jpg``.

    NOTE(review): ``resnetEmbedding`` relies on the module-level ``resnet``,
    which is created with ``weights=None``; ``loadWeight()`` must have been
    run beforehand - confirm the cell order on a fresh kernel.
    """
    outputPath = '/content/drive/My Drive/Colab Notebooks/test/'
    basePath = '/content/drive/My Drive/Colab Notebooks/colordata/'
    testImg = getImgList()[8000:8083]

    allData = []
    for item in testImg:
        img = cv2.imread(basePath + item)
        if img is None:
            # cv2.imread returns None instead of raising on unreadable files.
            print("Skipping unreadable file: " + item)
            continue
        allData.append(img)
    if not allData:
        print("No test images found, nothing to do")
        return

    # cv2.imwrite fails without raising when the target directory is missing.
    if not os.path.isdir(outputPath):
        os.makedirs(outputPath)

    greyImg = convImg(allData,1)
    imgEmbed = resnetEmbedding(allData)
    result = model.predict([greyImg,imgEmbed])
    for i in range(len(result)):
        combine = np.zeros((256, 256, 3))
        combine[:, :, 0] = greyImg[i][:, :, 0]
        combine[:, :, 1:] = result[i] * 128
        # Clamp before the uint8 cast: predictions above 255/128 would
        # otherwise wrap around and show up as wrong colours.
        imgCopy = np.uint8(np.clip(combine, 0, 255))
        cv2.imwrite(outputPath + 'test_img'+str(i)+'.jpg',cv2.cvtColor(imgCopy, cv2.COLOR_LAB2BGR))
test()