In [1]:
import numpy as np
import cv2

In [2]:
def load_init_model(model_dir="model"):
    """Load the Caffe colorization network and attach its cluster-center weights.

    Reads ``colorization_deploy_v2.prototxt``, ``colorization_release_v2.caffemodel``
    and ``pts_in_hull.npy`` from ``model_dir``. The cluster centers from the
    ``.npy`` file are written into the ``class8_ab`` layer as 1x1 convolution
    weights, and a constant scale is written into ``conv8_313_rh``.

    Args:
        model_dir: directory holding the three model files. Defaults to
            ``"model"`` (relative to the current working directory), which
            matches the previously hard-coded paths.

    Returns:
        The configured ``cv2.dnn`` network, ready for ``setInput``/``forward``.
    """
    import os

    model = cv2.dnn.readNetFromCaffe(
        os.path.join(model_dir, "colorization_deploy_v2.prototxt"),
        os.path.join(model_dir, "colorization_release_v2.caffemodel"),
    )
    pts = np.load(os.path.join(model_dir, "pts_in_hull.npy"))
    # add the cluster centers as 1x1 convolutions to the model
    class8 = model.getLayerId("class8_ab")
    conv8 = model.getLayerId("conv8_313_rh")
    # 2 * 313 values laid out as a (out=2, in=313, 1, 1) convolution kernel
    pts = pts.transpose().reshape(2, 313, 1, 1)
    model.getLayer(class8).blobs = [pts.astype("float32")]
    # NOTE(review): 2.606 is a fixed scale for all 313 channels of conv8_313_rh;
    # its origin is not documented here (it matches OpenCV's colorization sample).
    model.getLayer(conv8).blobs = [np.full([1, 313], 2.606, dtype="float32")]
    return model

In [3]:
def _colorize_bgr(image, model, input_size=(224, 224)):
    """Colorize a BGR ``uint8`` image with ``model`` and return a BGR ``uint8`` image.

    The network is fed the mean-centred L channel resized to ``input_size``.
    Its predicted ab channels are resized back to the input's height/width,
    combined with the full-resolution L channel, and converted back to BGR.
    """
    scaled = image.astype("float32") / 255.0
    lab = cv2.cvtColor(scaled, cv2.COLOR_BGR2LAB)
    resized = cv2.resize(lab, input_size)  # resize to model dimensions
    L_small = cv2.split(resized)[0]  # extract L from LAB
    L_small -= 50  # mean centering
    model.setInput(cv2.dnn.blobFromImage(L_small))
    ab = model.forward()[0, :, :, :].transpose((1, 2, 0))
    # cv2.resize takes (width, height)
    ab = cv2.resize(ab, (image.shape[1], image.shape[0]))
    # combine the original full-resolution L with the predicted a and b
    L_full = cv2.split(lab)[0]
    colorized = np.concatenate((L_full[:, :, np.newaxis], ab), axis=2)
    colorized = cv2.cvtColor(colorized, cv2.COLOR_LAB2BGR)
    colorized = np.clip(colorized, 0, 1)
    # convert from float range (0, 1) to uint8 (0, 255)
    return (255 * colorized).astype("uint8")


def prepare_and_predict_image(img, model, show=True):
    """Colorize the image file at ``img`` with ``model`` and optionally display it.

    The image is converted to Lab, and its L channel is scaled to the model's
    input size and passed through the model. The predicted ab channels are
    combined with the original L channel and converted back to BGR at the
    original resolution.

    Args:
        img: path of the image file to read with ``cv2.imread``.
        model: network returned by ``load_init_model``.
        show: if True (default), show the original and colorized images in
            OpenCV windows, wait for a key press, then close the windows.

    Returns:
        The colorized image as a BGR ``uint8`` array, with the same height and
        width as the input.

    Raises:
        OSError: if ``cv2.imread`` cannot read ``img``.
    """
    image = cv2.imread(img)
    # cv2.imread reports failure by returning None rather than raising
    if image is None:
        raise OSError(
            f"cv2.imread could not read {img!r} (missing file or unsupported format)"
        )
    colorized = _colorize_bgr(image, model)
    if show:
        cv2.imshow("Original", image)
        cv2.imshow("Colorized", colorized)
        cv2.waitKey(0)
        # waitKey does not close the windows; release them explicitly
        cv2.destroyAllWindows()
    return colorized

In [4]:
# Build the colorization network once; the prediction cell below reuses this `model`.
model = load_init_model()

In [5]:
# Other sample images tried during development; assign one to `testimg` to try it.
alternative_test_images = [
    "data/testdata/Validate/7Vizcm.jpg",
    "data/testdata/Validate/landscape.jpeg",
    "data/testdata/Validate/9KfZez.jpg",
    "data/testdata/Validate/1QejlL.jpg",
    "data/testdata/Train/11Se02.jpg",
    "data/testdata/Train/1PFDZe.jpg",
]
testimg = "data/testdata/Validate/landscape2.jpg"

# trailing `;` stops Jupyter from echoing any return value as cell output
prepare_and_predict_image(testimg, model);