# Autoencoder for apple images
*Grayscale version*

## Installation and Imports 

In [None]:
! pip install tensorflow

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import os
import tensorflow as tf

from sklearn.metrics import accuracy_score, precision_score, recall_score
from sklearn.model_selection import train_test_split
from tensorflow.keras import layers, losses
from tensorflow.keras.models import Model

## Data import

In [None]:
import glob
from PIL import Image

# Sorted so the image order -- and therefore any train/test split taken by
# slicing `apples` -- is reproducible; glob's result order depends on the
# file system.
filelist = sorted(glob.glob('/Clean_apple/*.jpg'))  # image paths in the directory
if not filelist:
  raise FileNotFoundError("no .jpg images found matching '/Clean_apple/*.jpg'")


def _load_gray(fname, size=(256, 256)):
  """Load the image at *fname*, convert it to grayscale ('L') and resize it to *size*."""
  # The context manager closes the underlying file handle immediately instead
  # of leaving it open until the Image object is garbage-collected.
  with Image.open(fname) as img:
    return np.array(img.convert('L').resize(size))


# One grayscale array per file, stacked into a single array.
apples = np.array([_load_gray(fname) for fname in filelist])

In [None]:
n = 400  # number of training images; all remaining images form the test set

# Fail fast with a clear message: with n or fewer images the test split would
# be empty, which breaks validation during training and any later indexing
# into x_test.
if len(apples) <= n:
  raise ValueError(
      f"need more than {n} images for a train/test split, found {len(apples)}")

x_train, x_test = apples[:n], apples[n:]

# Convert to float32 and rescale 0-255 pixel values into [0, 1].
x_train = x_train.astype('float32') / 255.
x_test = x_test.astype('float32') / 255.

print(x_train.shape)
print(x_test.shape)

## Model 

In [None]:
latent_dim = 2048  # width of the bottleneck (latent) vector


class Autoencoder(Model):
  """Fully connected autoencoder for images of shape `image_shape`.

  The encoder flattens each image and maps it to a `latent_dim`-wide ReLU
  code; the decoder maps that code back to one value per pixel through a
  sigmoid (so outputs lie in (0, 1)) and reshapes it to `image_shape`, so the
  reconstruction has the same shape as the input.

  Args:
    latent_dim: number of units in the bottleneck Dense layer.
    image_shape: shape of a single input image (without the batch axis).
      Defaults to (256, 256), the size the images are resized to when loaded.
  """

  def __init__(self, latent_dim, image_shape=(256, 256)):
    super().__init__()
    self.latent_dim = latent_dim
    self.image_shape = tuple(image_shape)
    # Pixels per image, derived from the shape instead of a hard-coded 65536
    # so the Dense output size and the Reshape target can never disagree.
    num_pixels = int(np.prod(image_shape))
    self.encoder = tf.keras.Sequential([
      layers.Flatten(),
      layers.Dense(latent_dim, activation='relu'),
    ])
    self.decoder = tf.keras.Sequential([
      layers.Dense(num_pixels, activation='sigmoid'),
      layers.Reshape(tuple(image_shape)),
    ])

  def call(self, x):
    """Encode *x* into the latent space and return its reconstruction."""
    encoded = self.encoder(x)
    decoded = self.decoder(encoded)
    return decoded


autoencoder = Autoencoder(latent_dim)

In [None]:
# Checkpoint setup: the Checkpoint object tracks the autoencoder model so its
# state can be written out with checkpoint.save(file_prefix=checkpoint_prefix).
checkpoint = tf.train.Checkpoint(autoencoder=autoencoder)
checkpoint_dir = '/Autoencoder'
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")  # path prefix for saved checkpoint files

In [None]:
# Reconstruction objective: pixel-wise mean squared error, minimised with Adam.
autoencoder.compile(loss=losses.MeanSquaredError(), optimizer='adam')

## Training

In [None]:
# Train the network to reproduce its input: the images double as targets,
# and the held-out images are used (also as their own targets) for validation.
autoencoder.fit(
    x=x_train,
    y=x_train,
    epochs=200,
    shuffle=True,
    validation_data=(x_test, x_test),
)
# Persist the trained model state under the configured checkpoint prefix.
checkpoint.save(file_prefix=checkpoint_prefix)

In [None]:
# Latent codes for the held-out images, then their reconstructions decoded
# from those codes; both converted from tensors to NumPy arrays.
encoded_imgs = np.asarray(autoencoder.encoder(x_test))
decoded_imgs = np.asarray(autoencoder.decoder(encoded_imgs))

## Display results

In [None]:
# Number of image pairs to show; capped by the test-set size so a small test
# split cannot cause an IndexError.
n = min(10, len(x_test))


def _show_image(ax, image, title):
  """Draw *image* in grayscale on *ax* with *title* and hidden axes."""
  # vmin/vmax pin the colour scale to [0, 1] -- the range of the /255-scaled
  # inputs and of the sigmoid decoder outputs -- so originals and
  # reconstructions share one intensity scale instead of each being
  # auto-stretched independently.
  ax.imshow(image, cmap='gray', vmin=0.0, vmax=1.0)
  ax.set_title(title)
  ax.get_xaxis().set_visible(False)
  ax.get_yaxis().set_visible(False)


plt.figure(figsize=(20, 4))
for i in range(n):
  # Top row: original test images.
  _show_image(plt.subplot(2, n, i + 1), x_test[i], "original")
  # Bottom row: the autoencoder's reconstructions of the same images.
  _show_image(plt.subplot(2, n, i + 1 + n), decoded_imgs[i], "reconstructed")
plt.show()
