In [None]:
import os

import joblib as jb
import matplotlib.pyplot as plt
import numpy as np

from tensorflow import keras


from keras.models import Model, load_model
from keras.layers import Dense, Input, Conv2D, MaxPool2D, UpSampling2D
from keras import backend as k

%matplotlib inline

In [None]:
# Load the preprocessed dragonfly image arrays and pool train + test into a
# single array; it is re-split into training/validation sets further down.
data_dir = r'D:\Linnaeus_models\dragon'
trainpath = data_dir + r'\train\dragon_train.npy'
testpath = data_dir + r'\test\dragon_test.npy'
dragons = np.concatenate([np.load(trainpath), np.load(testpath)], axis=0)
print(f'Number of images: {len(dragons)}')

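## Autoencoding with a CNN
   * Test the denoising abilities of a convolutional autoencoder on the dragonfly images. (The model below is trained only to reconstruct its clean input; no noise is added during training.)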

In [None]:
# Split into a training set (first 17000 images) and a validation set (the
# rest). Reshaping for the network happens in the next cell.
n_train = 17000
x_train, x_val = dragons[:n_train], dragons[n_train:]

In [None]:
# Append a trailing single-channel axis so the arrays line up with the
# model's (height, width, 1) input shape.
x_train = x_train[..., np.newaxis]
x_val = x_val[..., np.newaxis]

In [None]:
def create_unsupervised_model(weights_path=None, weights_name=None, shape=(256, 256, 1),
                              filters=(48, 32, 16, 8)):
    """Build and compile a convolutional autoencoder for image reconstruction.

    The encoder applies, for every entry in ``filters``, a 3x3 ReLU ``Conv2D``
    followed by a 2x2 ``MaxPool2D``. The decoder mirrors it with 3x3 ReLU
    ``Conv2D`` + 2x ``UpSampling2D`` stages in reverse filter order. A final
    linear (no activation) 3x3 ``Conv2D`` maps back to the input channel
    count. The model is compiled with the Adam optimizer and MSE loss.

    Args:
        weights_path: Prefix of a saved weights file, typically a directory
            ending in a path separator. It is concatenated directly with
            ``weights_name``; no separator is inserted.
        weights_name: File name of the saved weights. If only one of
            ``weights_path`` / ``weights_name`` is given, it is used on its
            own as the full weights location.
        shape: Input shape ``(height, width, channels)``. Known spatial
            dimensions must be divisible by ``2 ** len(filters)`` so that the
            decoder output has the same size as the input.
        filters: Number of filters in each encoder stage, outermost first.
            The default reproduces the original 48-32-16-8 architecture with
            layers created in the same order, so existing weight files still
            load.

    Returns:
        A compiled ``keras.Model`` mapping inputs of ``shape`` to
        reconstructions of the same shape.

    Raises:
        ValueError: if a known spatial dimension of ``shape`` is not divisible
            by ``2 ** len(filters)``.
    """
    filters = tuple(filters)

    # Each pooling stage halves the resolution ('same' padding rounds up) and
    # each upsampling stage doubles it, so non-divisible sizes would yield an
    # output that no longer matches the input. Fail early instead of at fit().
    factor = 2 ** len(filters)
    for dim in shape[:-1]:
        if dim is not None and dim % factor != 0:
            raise ValueError(
                f'spatial dimensions of shape={shape} must be divisible by '
                f'{factor} (2 ** number of pooling stages)')

    input_dim = Input(shape=shape)

    # Encoded representation of the input: each stage halves the resolution.
    encoded = input_dim
    for n_filters in filters:
        encoded = Conv2D(n_filters, (3, 3), activation='relu', padding='same')(encoded)
        encoded = MaxPool2D((2, 2), padding='same')(encoded)

    # Reconstruction of the input: mirror the encoder, doubling the resolution.
    decoded = encoded
    for n_filters in reversed(filters):
        decoded = Conv2D(n_filters, (3, 3), activation='relu', padding='same')(decoded)
        decoded = UpSampling2D((2, 2))(decoded)

    # Linear output layer back to the input's channel count (1 for the default
    # shape), so the reconstruction can be compared to the input under MSE.
    out_channels = shape[-1] if shape[-1] is not None else 1
    decoded = Conv2D(out_channels, (3, 3), padding='same')(decoded)

    # Define model input and output, then compile.
    model = Model(input_dim, decoded)
    model.compile(optimizer='Adam', loss='mse')

    # Previously, passing only one of the two weight arguments silently skipped
    # loading. Now whichever parts are provided form the weights location.
    weights_file = ''.join(str(part) for part in (weights_path, weights_name) if part)
    if weights_file:
        model.load_weights(weights_file)

    return model

In [None]:
# Build and compile a fresh autoencoder for 256x256 single-channel images.
model = create_unsupervised_model(shape=(256, 256, 1))
# Uncomment to print the layer-by-layer architecture:
# model.summary()

In [None]:
# Train the autoencoder to reconstruct its own input (targets == inputs).
train_history = model.fit(x_train, x_train, epochs=10, batch_size=100,
                          validation_data=(x_val, x_val))

# Save the model trained in this cell. This must be `model`: `loaded_model` only
# exists after the reload cell below, so dumping it here raised NameError on a
# fresh kernel (or saved a stale model from earlier kernel state).
# NOTE(review): pickling a Keras model ties the file to the installed
# Keras/TensorFlow versions; consider `model.save(...)` for durable checkpoints.
model_path = r'D:\Linnaeus_models\dragon_reconstruction_v3.pkl'
jb.dump(model, model_path)

### Reload unsupervised model to continue training

In [None]:
# Reload the model pickled at the v3 checkpoint path.
# NOTE: jb.load unpickles, so only load files you created yourself.
v3_path = r'D:\Linnaeus_models\dragon_reconstruction_v3.pkl'
loaded_model = jb.load(v3_path)

In [None]:
# Resume training the reloaded autoencoder for 16 more epochs
# (inputs double as reconstruction targets).
train_history = loaded_model.fit(
    x_train,
    x_train,
    epochs=16,
    batch_size=100,
    validation_data=(x_val, x_val),
)

# Persist the result as v4: the weights on their own, plus the pickled model.
v4_prefix = r'D:\Linnaeus_models\dragon_reconstruction_v4'
loaded_model.save_weights(v4_prefix + '_weights')
model_path = v4_prefix + '.pkl'
jb.dump(loaded_model, model_path)

In [None]:
# Reconstruct the validation images and cache them to disk so later cells can
# reload the predictions without re-running inference. The predict call was
# previously commented out, so `preds` was undefined on a fresh kernel.
preds = loaded_model.predict(x_val)
jb.dump(preds, './unsupervised_val_images.pkl')

In [None]:
# Reload the cached validation-set reconstructions.
preds_path = './unsupervised_val_images.pkl'
preds = jb.load(preds_path)

In [None]:
# Drop the trailing channel axis that was added for the network.
reverted_x = x_val.squeeze(axis=-1)

In [None]:
# Show validation image #2 (channel axis already removed) in grayscale.
fig, ax = plt.subplots()
ax.imshow(np.reshape(reverted_x[2], (256, 256)), cmap='gray')

In [None]:
# Show validation image #10 as fed to the network, for side-by-side comparison
# with its reconstruction. Squeezing the channel axis (instead of a hard-coded
# reshape to 256x256) works for any image size.
fig, ax = plt.subplots()
ax.imshow(np.squeeze(x_val[10], axis=-1), cmap='gray')
ax.set_title('Validation image 10 (input)');

In [None]:
# Show the autoencoder's reconstruction of validation image #10. Squeezing the
# channel axis (instead of a hard-coded reshape to 256x256) works for any
# image size.
fig, ax = plt.subplots()
ax.imshow(np.squeeze(preds[10], axis=-1), cmap='gray')
ax.set_title('Validation image 10 (autoencoder reconstruction)');

In [None]:
# Save `loaded_model` under the v3 checkpoint name, but never clobber an
# existing file. After the continued-training cell, `loaded_model` holds the
# further-trained weights that were already saved as v4, so overwriting v3
# would silently replace that checkpoint with a different model.
model_path = r'D:\Linnaeus_models\dragon_reconstruction_v3.pkl'
if os.path.exists(model_path):
    print(f'{model_path} already exists; not overwriting it.')
else:
    jb.dump(loaded_model, model_path)

In [None]:
# Past models: reload an earlier (v1) checkpoint to inspect its architecture

In [None]:
# Reload the earliest (v1) autoencoder checkpoint so its architecture can be
# visualized below. NOTE: jb.load unpickles, so only load trusted files.
model_name = r'D:\Linnaeus_models\dragon_reconstruction_v1.pkl'
new_model = jb.load(filename=model_name)

In [None]:
# Render the reloaded model's layer graph to the image file model.png.
# NOTE: `import keras` rebinds the name `keras` (bound to `tensorflow.keras` in
# the first import cell) to the top-level `keras` package for the rest of the
# notebook.
import keras
import pydot
from keras.utils import plot_model
# Point Keras' visualization utilities at the pydot module imported here.
# Presumably a workaround for Keras not picking up pydot on its own -- TODO confirm.
keras.utils.vis_utils.pydot = pydot
plot_model(new_model, to_file='model.png')

In [None]:
from keras.utils.vis_utils import model_to_dot

In [None]:
# Render the reloaded model's architecture inline as an SVG graph.
from IPython.display import SVG

dot_graph = model_to_dot(new_model)
svg_bytes = dot_graph.create(prog='dot', format='svg')
SVG(svg_bytes)