# RGB Autoencoder 
The idea is to reuse the autoencoder architecture designed for grayscale (black-and-white) images and apply it three times: one independent instance is trained on each RGB channel of a color image.

## Installation and Imports

In [None]:
! pip install tensorflow

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import os
import tensorflow as tf

from sklearn.metrics import accuracy_score, precision_score, recall_score
from sklearn.model_selection import train_test_split
from tensorflow.keras import layers, losses
from tensorflow.keras.models import Model

## Data import and preprocessing

In [None]:
import glob
from PIL import Image

# Sort so that the file order -- and therefore the 400 / rest train/test split
# below -- is reproducible: glob's result order depends on the filesystem.
filelist = sorted(glob.glob('/Clean_apple/*.jpg'))  # image names in the directory
n = len(filelist)  # number of images in '/Clean_apple/' (472 in the original dataset)
if n == 0:
  raise FileNotFoundError("No .jpg images found in '/Clean_apple/'")

apples_R = []
apples_G = []
apples_B = []

# split the images in 3 RGB components, each resized to 256x256
for path in filelist:
  # convert('RGB') guarantees exactly three bands: grayscale JPEGs would
  # otherwise yield a single band and CMYK JPEGs four (wrong) ones.
  # The context manager closes the underlying file handle.
  with Image.open(path) as img:
    red, green, blue = img.convert('RGB').resize((256, 256)).split()
  apples_R.append(np.asarray(red))
  apples_G.append(np.asarray(green))
  apples_B.append(np.asarray(blue))

# Stack into arrays of shape (n, 256, 256), one per colour channel.
apples_R = np.asarray(apples_R)
apples_G = np.asarray(apples_G)
apples_B = np.asarray(apples_B)

## Creating the training and test sets

Red

In [None]:
# Red channel: the first 400 images train the model, the remainder test it;
# pixel values are converted to float32 and rescaled by 1/255.
x_train_R = apples_R[:400].astype('float32') / 255.
x_test_R = apples_R[400:].astype('float32') / 255.

print(x_train_R.shape)
print(x_test_R.shape)

Green

In [None]:
# Green channel: the first 400 images train the model, the remainder test it;
# pixel values are converted to float32 and rescaled by 1/255.
x_train_G = apples_G[:400].astype('float32') / 255.
x_test_G = apples_G[400:].astype('float32') / 255.

print(x_train_G.shape)
print(x_test_G.shape)

Blue

In [None]:
# Blue channel: the first 400 images train the model, the remainder test it;
# pixel values are converted to float32 and rescaled by 1/255.
x_train_B = apples_B[:400].astype('float32') / 255.
x_test_B = apples_B[400:].astype('float32') / 255.

print(x_train_B.shape)
print(x_test_B.shape)

## Models
Note: a single model definition is enough; it is instantiated three times, and each instance is trained separately on one color channel.

In [None]:
latent_dim = 2048

class Autoencoder(Model):
  """Fully-connected autoencoder for one image channel.

  The encoder flattens the input and maps it to a ReLU bottleneck of size
  ``latent_dim``; the decoder maps that code back to one sigmoid unit per
  pixel and reshapes it to ``shape``.

  Args:
    latent_dim: size of the latent (code) vector produced by the encoder.
    shape: (height, width) of the images to reconstruct. Defaults to
      (256, 256), the size the data-loading cell resizes every channel to.
  """

  def __init__(self, latent_dim, shape=(256, 256)):
    super().__init__()
    self.latent_dim = latent_dim
    self.encoder = tf.keras.Sequential([
      layers.Flatten(),
      layers.Dense(latent_dim, activation='relu'),
    ])
    self.decoder = tf.keras.Sequential([
      # One unit per pixel (256 * 256 = 65536 by default); sigmoid keeps the
      # outputs in [0, 1], the same range as the /255-scaled inputs.
      layers.Dense(int(np.prod(shape)), activation='sigmoid'),
      layers.Reshape(tuple(shape)),
    ])

  def call(self, x):
    """Encode ``x`` into the latent space and decode it back to image shape."""
    encoded = self.encoder(x)
    decoded = self.decoder(encoded)
    return decoded


The autoencoders (one per color channel):

In [None]:
# Three independent models sharing one architecture, created in R, G, B order.
autoencoder_R, autoencoder_G, autoencoder_B = [Autoencoder(latent_dim) for _ in range(3)]

## Saving checkpoints

In [None]:
# R
checkpoint_dir_R = '/Autoencoder/R' 
checkpoint_prefix_R = os.path.join(checkpoint_dir_R, "ckpt")
checkpoint_R = tf.train.Checkpoint(autoencoder=autoencoder_R)

In [None]:
# G
checkpoint_dir_G = '/Autoencoder/G' 
checkpoint_prefix_G = os.path.join(checkpoint_dir_G, "ckpt")
checkpoint_G = tf.train.Checkpoint(autoencoder=autoencoder_G)

In [None]:
# B
checkpoint_dir_B = '/Autoencoder/B' 
checkpoint_prefix_B = os.path.join(checkpoint_dir_B, "ckpt")
checkpoint_B = tf.train.Checkpoint(autoencoder=autoencoder_B)

## Training

In [None]:
# Same recipe for every channel: Adam + mean squared reconstruction error.
# A fresh loss object is created for each model on each iteration.
for _channel_model in (autoencoder_R, autoencoder_G, autoencoder_B):
  _channel_model.compile(optimizer='adam', loss=losses.MeanSquaredError())
del _channel_model

In [None]:
# R
autoencoder_R.fit(x_train_R, x_train_R,
                epochs=200,
                shuffle=True,
                validation_data=(x_test_R, x_test_R))
checkpoint_R.save(file_prefix = checkpoint_prefix_R)

In [None]:
# G
autoencoder_G.fit(x_train_G, x_train_G,
                epochs=200,
                shuffle=True,
                validation_data=(x_test_G, x_test_G))
checkpoint_G.save(file_prefix = checkpoint_prefix_G)

In [None]:
# B
autoencoder_B.fit(x_train_B, x_train_B,
                epochs=200,
                shuffle=True,
                validation_data=(x_test_B, x_test_B))
checkpoint_B.save(file_prefix = checkpoint_prefix_B)

## Results

In [None]:
# Pass each channel's test images through the matching encoder...
encoded_imgs_R = autoencoder_R.encoder(x_test_R).numpy()
encoded_imgs_G = autoencoder_G.encoder(x_test_G).numpy()
encoded_imgs_B = autoencoder_B.encoder(x_test_B).numpy()

# ...then map the latent codes back to image space with the matching decoder.
decoded_imgs_R = autoencoder_R.decoder(encoded_imgs_R).numpy()
decoded_imgs_G = autoencoder_G.decoder(encoded_imgs_G).numpy()
decoded_imgs_B = autoencoder_B.decoder(encoded_imgs_B).numpy()

In [None]:
# number of images to display
# number of images to display; capped at the test-set size so the plotting
# loops never index past the end of x_test_* / decoded_imgs_*
n = min(10, len(x_test_R))

Red autoencoder results:

In [None]:
# Top row: original red channels; bottom row: their reconstructions.
# vmin/vmax pin the colour scale to [0, 1] (the range of the /255-scaled inputs
# and of the sigmoid outputs) so both rows share one scale; without them imshow
# rescales each image to its own min/max and hides intensity errors.
# "Reds" (not the reversed "Reds_r") matches the non-reversed "Greens"/"Blues"
# maps used for the other channels: higher intensity -> stronger colour.
plt.figure(figsize=(20, 4))
for i in range(n):
  # display original
  ax = plt.subplot(2, n, i + 1)
  plt.imshow(x_test_R[i], "Reds", vmin=0, vmax=1)
  plt.title("original_R")
  ax.get_xaxis().set_visible(False)
  ax.get_yaxis().set_visible(False)

  # display reconstruction
  ax = plt.subplot(2, n, i + 1 + n)
  plt.imshow(decoded_imgs_R[i], "Reds", vmin=0, vmax=1)
  plt.title("reconstructed_R")
  ax.get_xaxis().set_visible(False)
  ax.get_yaxis().set_visible(False)
plt.show()

Green autoencoder results:

In [None]:
# Top row: original green channels; bottom row: their reconstructions.
# vmin/vmax pin the colour scale to [0, 1] (the range of the /255-scaled inputs
# and of the sigmoid outputs) so both rows share one scale; without them imshow
# rescales each image to its own min/max and hides intensity errors.
plt.figure(figsize=(20, 4))
for i in range(n):
  # display original
  ax = plt.subplot(2, n, i + 1)
  plt.imshow(x_test_G[i], "Greens", vmin=0, vmax=1)
  plt.title("original_G")
  ax.get_xaxis().set_visible(False)
  ax.get_yaxis().set_visible(False)

  # display reconstruction
  ax = plt.subplot(2, n, i + 1 + n)
  plt.imshow(decoded_imgs_G[i], "Greens", vmin=0, vmax=1)
  plt.title("reconstructed_G")
  ax.get_xaxis().set_visible(False)
  ax.get_yaxis().set_visible(False)
plt.show()

Blue autoencoder results:

In [None]:
# Top row: original blue channels; bottom row: their reconstructions.
# vmin/vmax pin the colour scale to [0, 1] (the range of the /255-scaled inputs
# and of the sigmoid outputs) so both rows share one scale; without them imshow
# rescales each image to its own min/max and hides intensity errors.
plt.figure(figsize=(20, 4))
for i in range(n):
  # display original
  ax = plt.subplot(2, n, i + 1)
  plt.imshow(x_test_B[i], "Blues", vmin=0, vmax=1)
  plt.title("original_B")
  ax.get_xaxis().set_visible(False)
  ax.get_yaxis().set_visible(False)

  # display reconstruction
  ax = plt.subplot(2, n, i + 1 + n)
  plt.imshow(decoded_imgs_B[i], "Blues", vmin=0, vmax=1)
  plt.title("reconstructed_B")
  ax.get_xaxis().set_visible(False)
  ax.get_yaxis().set_visible(False)
plt.show()

RGB results (the three channels recombined):

In [None]:
# Recombine the three channels into RGB images: originals on the top row,
# the stacked per-channel reconstructions on the bottom row.
plt.figure(figsize=(20, 4))
for i in range(n):
  panels = (
      (np.stack([x_test_R[i], x_test_G[i], x_test_B[i]], axis=-1), "original"),
      (np.stack([decoded_imgs_R[i], decoded_imgs_G[i], decoded_imgs_B[i]], axis=-1), "reconstructed"),
  )
  for row, (rgb_image, caption) in enumerate(panels):
    # row 0 -> subplot i + 1, row 1 -> subplot n + i + 1
    ax = plt.subplot(2, n, row * n + i + 1)
    plt.imshow(rgb_image)
    plt.title(caption)
    ax.get_xaxis().set_visible(False)
    ax.get_yaxis().set_visible(False)
plt.show()