In [22]:
#Importing Libraries
import os
import cv2
import numpy
import keras
import skimage.io
import skimage.color
import matplotlib.pyplot
import tensorflow
import tensorboard


In [None]:
# Enable on-demand GPU memory allocation instead of letting TensorFlow
# reserve all device memory up front.
gpus = tensorflow.config.experimental.list_physical_devices('GPU')
print(f"Number of GPUs: {len(gpus)}")
for gpu in gpus:
    try:
        tensorflow.config.experimental.set_memory_growth(gpu, True)
    except RuntimeError as err:
        # Raised once the TensorFlow runtime is already initialized (e.g. when
        # this cell is re-run after training started); the setting can no
        # longer change, so report it instead of aborting the notebook.
        print(f"Could not enable memory growth on {gpu.name}: {err}")


In [24]:
#Getting Images
# Load every training image, resized to 256x256, into one float array X.
# The listing is sorted so the train/test split in the next cell is
# reproducible (os.listdir order is arbitrary and filesystem-dependent).
# Hidden files (e.g. .DS_Store) and sub-directories are skipped because
# load_img cannot decode them.
train_dir = 'Dataset/Train/'
X = []
for imagename in sorted(os.listdir(train_dir)):
    image_path = os.path.join(train_dir, imagename)
    if imagename.startswith('.') or not os.path.isfile(image_path):
        continue
    X.append(tensorflow.keras.utils.img_to_array(
        tensorflow.keras.utils.load_img(image_path, target_size=(256, 256))))

if not X:
    raise ValueError(f"No training images found in {train_dir!r}.")

X = numpy.array(X, dtype=float)





In [25]:
# Set up train and test data
# The first 95% of the loaded images form the training set; pixel values
# are rescaled from 0-255 to 0-1.
train_fraction = 0.95
split = int(train_fraction*len(X))
Xtrain = (1.0/255) * X[:split]

In [26]:

#set up Test data
# Everything after `split` is held out for testing, rescaled to 0-1.
Xtest = (1.0/255) * X[split:]

In [27]:
#CNN model

def create_model():
    """Build and compile the colorization CNN.

    The network takes a single-channel 256x256 image and produces a
    256x256x2 map. Three stride-2 convolutions shrink the spatial size
    256 -> 32, and three UpSampling2D layers restore it to 256. The last
    convolution uses tanh, so outputs lie in [-1, 1].

    Returns:
        A compiled ``keras.models.Sequential`` model (RMSprop, MSE loss).
    """
    model = keras.models.Sequential()

    #Input Layer
    model.add(keras.layers.Conv2D(64, (3, 3), input_shape=(256, 256, 1), activation='relu', padding='same'))

    #Hidden Layers
    # Encoder: every second convolution halves the resolution (strides=2).
    model.add(keras.layers.Conv2D(64, (3, 3), activation='relu', padding='same', strides=2))
    model.add(keras.layers.Conv2D(128, (3, 3), activation='relu', padding='same'))
    model.add(keras.layers.Conv2D(128, (3, 3), activation='relu', padding='same', strides=2))
    model.add(keras.layers.Conv2D(256, (3, 3), activation='relu', padding='same'))
    model.add(keras.layers.Conv2D(256, (3, 3), activation='relu', padding='same', strides=2))
    model.add(keras.layers.Conv2D(512, (3, 3), activation='relu', padding='same'))
    model.add(keras.layers.Conv2D(256, (3, 3), activation='relu', padding='same'))
    model.add(keras.layers.Conv2D(128, (3, 3), activation='relu', padding='same'))
    # Decoder: each UpSampling2D doubles the resolution back towards 256x256.
    model.add(keras.layers.UpSampling2D((2, 2)))
    model.add(keras.layers.Conv2D(64, (3, 3), activation='relu', padding='same'))
    model.add(keras.layers.UpSampling2D((2, 2)))
    model.add(keras.layers.Conv2D(32, (3, 3), activation='relu', padding='same'))
    # Two output channels squashed to [-1, 1] by tanh.
    model.add(keras.layers.Conv2D(2, (3, 3), activation='tanh', padding='same'))
    model.add(keras.layers.UpSampling2D((2, 2)))

    #Compiling the CNN
    # 'accuracy' is kept because later cells read history['accuracy'], but it
    # is not a meaningful measure for this continuous MSE regression; 'mae'
    # reports an interpretable error in the same units as the targets.
    model.compile(optimizer='rmsprop', loss='mse', metrics=['accuracy', 'mae'])

    return model

In [28]:
# Function to save the last completed epoch
def save_epoch(epoch, path='Dataset/output/last_epoch.txt'):
    """Write ``epoch`` to ``path`` as plain text, replacing any previous value.

    Args:
        epoch: value to record. The EpochSaver callback passes the epoch index
            Keras reports to ``on_epoch_end``.
        path: destination file. Missing parent directories are created, so the
            first epoch of a fresh run does not fail with FileNotFoundError.
    """
    directory = os.path.dirname(path)
    if directory:
        os.makedirs(directory, exist_ok=True)
    with open(path, 'w') as f:
        f.write(str(epoch))


In [29]:
# Custom callback to save the epoch
class EpochSaver(tensorflow.keras.callbacks.Callback):
    """Keras callback that records each finished epoch via ``save_epoch``."""

    def on_epoch_end(self, epoch, logs=None):
        # `epoch` is the index Keras passes for the epoch that just ended;
        # the per-epoch metrics in `logs` are not needed here.
        finished_epoch = epoch
        save_epoch(finished_epoch)


In [30]:






# Generate training data
def image_a_b_gen(batch_size, datagen, Xtrain):
    """Yield (lightness, chroma) training pairs from augmented RGB batches.

    Each batch drawn from ``datagen.flow`` is converted to CIELAB. The input
    is channel 0 with a trailing channel axis added. The target is channels
    1-2 divided by 128.
    """
    augmented_batches = datagen.flow(Xtrain, batch_size=batch_size)
    for rgb_batch in augmented_batches:
        lab = skimage.color.rgb2lab(rgb_batch)
        lightness = lab[:, :, :, 0]
        chroma = lab[:, :, :, 1:] / 128
        yield lightness.reshape(lightness.shape + (1,)), chroma


In [31]:

# Image transformer
# Random augmentation applied to each training batch: rotations up to 20
# degrees, shear intensity 0.2, zoom in [0.8, 1.2], and random horizontal flips.
datagen = tensorflow.keras.preprocessing.image.ImageDataGenerator(
    rotation_range=20,
    shear_range=0.2,
    zoom_range=0.2,
    horizontal_flip=True,
)

In [32]:
# Load or create model
# Resume from the checkpoint written by an earlier run when it exists;
# otherwise build a fresh, compiled model.
model_path = 'Dataset/output/model_checkpoint.h5'
model = (
    tensorflow.keras.models.load_model(model_path)
    if os.path.exists(model_path)
    else create_model()
)


In [33]:
# Read the last completed epoch
def read_start_epoch(path='Dataset/output/last_epoch.txt'):
    """Return the epoch index training should resume from.

    The epoch file holds the 0-based index Keras reported for the last epoch
    that finished. ``model.fit(initial_epoch=...)`` expects the index of the
    next epoch to run, so the stored value is incremented by one. Without
    this, the last finished epoch would be trained (and logged) twice on
    every resume.

    Returns 0 when the file does not exist or does not contain an integer.
    """
    if not os.path.exists(path):
        return 0
    with open(path, 'r') as f:
        content = f.read()
    try:
        return int(content) + 1
    except ValueError:
        print(f"Ignoring unreadable epoch file {path!r}; starting from epoch 0.")
        return 0


# NOTE: despite its name, `last_epoch` holds the epoch to resume at
# (last finished epoch + 1), which is what `initial_epoch` expects.
last_epoch = read_start_epoch()


In [34]:
# Define callbacks
# TensorBoard logging for this run. NOTE: this rebinds the name
# `tensorboard`, shadowing the `tensorboard` module imported in the first cell.
tensorboard = tensorflow.keras.callbacks.TensorBoard(log_dir="Dataset/output/beta_run")

# Write the full model (not only its weights) to `model_path` after every
# epoch, regardless of whether the loss improved.
checkpoint_callback = tensorflow.keras.callbacks.ModelCheckpoint(
    filepath=model_path,
    monitor='loss',
    mode='min',
    save_best_only=False,
    save_weights_only=False,
    verbose=1,
)

# Records each finished epoch so an interrupted run can be resumed.
epoch_saver = EpochSaver()


In [None]:
# Train model
batch_size = 10
training_batches = image_a_b_gen(batch_size, datagen, Xtrain)
trainedmodel = model.fit(
    training_batches,
    epochs=500,
    steps_per_epoch=30,
    # Resume at `last_epoch`, read from disk in an earlier cell (0 on a fresh start).
    initial_epoch=last_epoch,
    callbacks=[tensorboard, checkpoint_callback, epoch_saver],
)

In [None]:

# Summarize history for model accuracy
# Plot against the actual epoch indices: on a resumed run the History only
# covers epochs from `initial_epoch` onward.
matplotlib.pyplot.plot(trainedmodel.epoch, trainedmodel.history['accuracy'])
matplotlib.pyplot.title('model accuracy')
matplotlib.pyplot.ylabel('accuracy')
matplotlib.pyplot.xlabel('epoch')
# fit() received no validation data, so only the training curve exists.
matplotlib.pyplot.legend(['train'], loc='upper left')
matplotlib.pyplot.show()


In [None]:

# Summarize history for model loss
# Plot against the actual epoch indices: on a resumed run the History only
# covers epochs from `initial_epoch` onward.
matplotlib.pyplot.plot(trainedmodel.epoch, trainedmodel.history['loss'])
matplotlib.pyplot.title('model loss')
matplotlib.pyplot.ylabel('loss')
matplotlib.pyplot.xlabel('epoch')
# fit() received no validation data, so only the training curve exists.
matplotlib.pyplot.legend(['train'], loc='upper left')
matplotlib.pyplot.show()

In [38]:
# Save model
# Architecture goes to JSON and weights to HDF5 so they can be reloaded
# separately. Create the folder first: open() does not create missing
# directories.
os.makedirs("Dataset/Model", exist_ok=True)

model_json = model.to_json()
with open("Dataset/Model/model.json", "w") as json_file:
    json_file.write(model_json)
model.save_weights("Dataset/Model/model.h5")

In [39]:
# load json and create model
# Rebuild the architecture from JSON; the `with` block closes the file even
# if reading fails.
with open('Dataset/Model/model.json', 'r') as json_file:
    loaded_model_json = json_file.read()
loaded_model = keras.models.model_from_json(loaded_model_json)
# load weights into new model
loaded_model.load_weights("Dataset/Model/model.h5")

In [None]:
# Print the layer-by-layer structure of the reloaded model.
loaded_model.summary()

In [None]:
# Test images
# A model rebuilt from JSON is not compiled, so compile it before evaluate().
loaded_model.compile(optimizer='rmsprop', loss='mse', metrics = ['accuracy'])
# Convert the held-out images to CIELAB once and slice both inputs and
# targets from it. Input is channel 0 with a trailing channel axis; target is
# channels 1-2 divided by 128, matching the training generator.
test_lab = skimage.color.rgb2lab(1.0/255*X[split:])
Xtest = test_lab[:, :, :, 0]
Xtest = Xtest.reshape(Xtest.shape + (1,))
Ytest = test_lab[:, :, :, 1:] / 128
print(loaded_model.evaluate(Xtest, Ytest, batch_size=10))

In [None]:

# Directory containing images
directory = 'Dataset/Test/'
# Extensions are matched case-insensitively (so "photo.JPG" is included).
image_extensions = ('.jpg', '.jpeg', '.png')

# List candidate files once, sorted, so the count and the loop below agree
# and the figure order is reproducible.
image_files = sorted(
    filename for filename in os.listdir(directory)
    if filename.lower().endswith(image_extensions)
)
num_images = len(image_files)

print("Number of images in folder : ", num_images)

# Ensure at least 1 image is found
if num_images == 0:
    raise ValueError("No images found in the directory.")

print('Output of the Model')

# Read every image OpenCV can decode, convert BGR -> RGB and resize to the
# 256x256 size the model expects.
originals = []
titles = []
for filename in image_files:
    img = cv2.imread(os.path.join(directory, filename))

    # Check if image is read correctly
    if img is None:
        print(f"Couldn't read image {filename}. Skipping.")
        continue

    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    originals.append(cv2.resize(img, (256, 256)))
    titles.append(filename)

if not originals:
    raise ValueError("None of the images in the directory could be read.")

# Model input: CIELAB channel 0 of each image with a trailing channel axis.
colorize = numpy.array(originals, dtype=float)
colorize = skimage.color.rgb2lab(1.0/255 * colorize)[:, :, :, 0]
colorize = colorize.reshape(colorize.shape + (1,))

# Test model. The model was trained against a/b channels divided by 128, so
# scale its predictions back up.
output = loaded_model.predict(colorize)
output *= 128

# One row per readable image: input on the left, colorization on the right.
# squeeze=False keeps `ax` 2-D even when there is only one row.
fig, ax = matplotlib.pyplot.subplots(
    len(originals), 2, figsize=(16, len(originals) * 5), squeeze=False)

# Output colorizations
for row, (original, lightness, ab) in enumerate(zip(originals, colorize, output)):
    cur = numpy.zeros((256, 256, 3))
    cur[:, :, 0] = lightness[:, :, 0]
    cur[:, :, 1:] = ab
    resImage = skimage.color.lab2rgb(cur)

    ax[row, 0].imshow(original, interpolation='nearest')
    ax[row, 0].set_title(titles[row])
    ax[row, 1].imshow(resImage, interpolation='nearest')
    ax[row, 1].set_title('Colorized')

matplotlib.pyplot.show()