### 1. Import required libraries

In [None]:
import os

from utils import preprocess_data, color_image

import tensorflow as tf
from tensorflow.keras.layers import Conv2D, UpSampling2D
from tensorflow.keras.models import Sequential
from tensorflow.keras.preprocessing.image import ImageDataGenerator

import mlflow
import mlflow.tensorflow

import matplotlib.pyplot as plt
%matplotlib inline

import warnings
warnings.filterwarnings("ignore")

### 2. Configure MLflow tracking

In [None]:
# MLflow configuration: runs are recorded in a local SQLite tracking store.
mlflow.set_tracking_uri("sqlite:///mlflow.db")
mlflow.set_experiment('autoencoders-image-coloring')

# mlflow.start_run() raises if a run is already active (e.g. when this cell is
# re-executed, or a previous session stopped before end_run() was reached).
# Close any such stale run so the cell can be re-run safely.
if mlflow.active_run() is not None:
    mlflow.end_run()

run = mlflow.start_run()
mlflow.set_tag('model', 'basic_cnn')
print("run_id: {}; status: {}".format(run.info.run_id, run.info.status))

# Autolog the Keras training into the active run; every_n_iter controls how
# often metrics are logged.
# NOTE(review): the accepted autolog arguments differ across MLflow versions —
# confirm `every_n_iter` against the installed release.
mlflow.tensorflow.autolog(every_n_iter=5)

### 3. Loading the dataset (color images)

In [None]:
# Dataset location and the size every image is resized to.
path = "<path_to_color_images>"
IMG_WIDTH, IMG_HEIGHT = 256, 256

# The generator only rescales pixel values from [0, 255] to [0, 1].
train_datagen = ImageDataGenerator(rescale=1 / 255)

# Options for streaming images from `path`. Images are resized to target_size;
# class_mode=None yields bare image batches with no labels.
# Keras' flow_from_directory looks for images inside subdirectories of `path`,
# not in `path` itself.
loader_options = {
    'target_size': (IMG_WIDTH, IMG_HEIGHT),
    'batch_size': 340,
    'class_mode': None,
}
train = train_datagen.flow_from_directory(path, **loader_options)

### 4. Data Preprocessing

In [None]:
# Turn the colour image batches into model inputs (X) and targets (Y) with the
# project helper, which converts RGB to Lab.
# Presumably X is the L channel and Y the a/b channels, matching the model's
# 1-channel input and 2-channel output — confirm in utils.preprocess_data.
X, Y = preprocess_data(train)

# Sanity-check the shapes (one per line).
print(X.shape, Y.shape, sep="\n")

### 5. Convolutional Neural Network model

In [None]:
# Convolutional autoencoder: one input channel in, two output channels out, at
# the same spatial resolution as the input.
model = Sequential([
    # Encoder: three stride-2 convolutions reduce 256x256 to 32x32 (for the
    # configured input size) while widening the feature maps.
    # NOTE(review): Keras input_shape is (height, width, channels); the order
    # below is only harmless while IMG_WIDTH == IMG_HEIGHT.
    Conv2D(64, (3, 3), activation='relu', padding='same', strides=2,
           input_shape=(IMG_WIDTH, IMG_HEIGHT, 1)),
    Conv2D(128, (3, 3), activation='relu', padding='same'),
    Conv2D(128, (3, 3), activation='relu', padding='same', strides=2),
    Conv2D(256, (3, 3), activation='relu', padding='same'),
    Conv2D(256, (3, 3), activation='relu', padding='same', strides=2),
    Conv2D(512, (3, 3), activation='relu', padding='same'),
    Conv2D(512, (3, 3), activation='relu', padding='same'),
    Conv2D(256, (3, 3), activation='relu', padding='same'),

    # Decoder: three 2x upsamplings restore the input resolution. The 2-filter
    # tanh convolution produces values in [-1, 1], and nearest-neighbour
    # upsampling keeps them in that range.
    Conv2D(128, (3, 3), activation='relu', padding='same'),
    UpSampling2D((2, 2)),
    Conv2D(64, (3, 3), activation='relu', padding='same'),
    UpSampling2D((2, 2)),
    Conv2D(32, (3, 3), activation='relu', padding='same'),
    Conv2D(16, (3, 3), activation='relu', padding='same'),
    Conv2D(2, (3, 3), activation='tanh', padding='same'),
    UpSampling2D((2, 2)),
])

model.summary()

### 6. Model configuration

In [None]:
# Optimizer: RMSprop with Keras' default hyperparameters.
opt = tf.keras.optimizers.RMSprop()

# Loss: per-pixel mean squared error between predicted and target channels.
loss = 'mse'

# Compile the model. The targets are continuous values (the network ends in a
# tanh layer), so 'accuracy' is not meaningful here. Keras would resolve it to
# a categorical accuracy over the two output channels. Mean absolute error is
# tracked instead as an interpretable companion to the MSE loss.
model.compile(optimizer=opt, loss=loss, metrics=['mae'])

### 7. Training the model

In [None]:
# Training hyperparameters.
EPOCHS, BATCH_SIZE = 30, 16

# Keras takes the validation split (20%) from the last samples of X/Y, before
# any shuffling.
history = model.fit(
    X,
    Y,
    batch_size=BATCH_SIZE,
    epochs=EPOCHS,
    validation_split=0.2,
)

### 8. Training metrics

In [None]:
# Training vs. validation loss per epoch, to spot over- or under-fitting.
# The explicit Figure/Axes interface is used; the original passed a bool as an
# artist label via plt.subplot(label=True).
fig, ax = plt.subplots()
ax.plot(history.history["loss"], label="loss")
ax.plot(history.history["val_loss"], label="val_loss")
ax.set(title="Training history", xlabel="Epoch", ylabel="Loss")
ax.legend()
plt.show()

### 9. Saving the model

In [None]:
# Persist the trained model (replace the placeholder file name first). The
# `.h5` suffix makes Keras write the HDF5 format. Create ./models beforehand so
# the save does not fail on a fresh checkout where the directory is missing.
os.makedirs('./models', exist_ok=True)
model.save('./models/<model_name.h5>')

# Close the active MLflow run, then re-read it from the tracking store to
# report its final status.
mlflow.end_run()
run = mlflow.get_run(run.info.run_id)
print("run_id: {}; status: {}".format(run.info.run_id, run.info.status))

### 10. Coloring an image

In [None]:
# Both paths are placeholders to fill in before running.
saved_model_path = './models/<saved_model>'
image_path = '<path_of_image_to_be_colored>'

# Reload the trained model from disk and colorize one image with the project
# helper.
model = tf.keras.models.load_model(saved_model_path)
color_image(model, image_path)