# Dependencies

In [None]:
import MURA
import cv2
import image_manipulation
from multiprocessing import Pool
from models import VAE, UPAE
import numpy as np
import glob
import matplotlib.pyplot as plt
import preprocessing
import json
from os import path, mkdir

import keras
from argparse import ArgumentParser
import tensorflow as tf

# Model to Run
This section lets the user choose which model configuration to run. The model directory and its `parameters.json` file must already exist before `model_file_path` is changed to point at them.

In [None]:
# --- User-editable settings ---
# Directory of the model to run; it must already contain parameters.json.
model_file_path = "models/model_1"
# Set to True to regenerate the processed dataset before loading it.
run_preprocessing = False

In [None]:
# Get model parameters from <model_file_path>/parameters.json.
# Only a missing file is reported as "no parameters.json"; a malformed file
# (json.JSONDecodeError), a missing key (KeyError) or a dataset-loading error
# is allowed to propagate with its own message rather than being masked.
parameters_file = model_file_path + '/parameters.json'
try:
    # `with` guarantees the file handle is closed once the JSON is parsed.
    with open(parameters_file) as params_file:
        params = json.load(params_file)
except FileNotFoundError as err:
    raise FileNotFoundError(
        "No parameters.json file found in the model's directory."
    ) from err

# Model parameters
multiplier = params['multiplier']
latent_size = params['latent_size']
input_shape = params['input_shape']

# Training parameters
epochs = params['num_epochs']
batch_size = params['batch_size']
learning_rate = params['learning_rate']

# Dataset Path
image_paths = MURA.MURA_DATASET()
dataset_file_path = params['dataset_path']
all_image_paths = image_paths.get_combined_image_paths()
# Keep only the first column (the image path) as a NumPy array.
all_image_paths = all_image_paths.to_numpy()[:, 0]

In [None]:
# Optionally rebuild the processed dataset from the raw MURA image paths,
# writing the output to the dataset_path configured in parameters.json.
if run_preprocessing:
    preprocess = preprocessing.preprocessing(
        input_path=all_image_paths,
        output_path=dataset_file_path,
    )
    # Start the job only from the main process.
    # NOTE(review): presumably this guard exists because preprocessing spawns
    # worker processes that re-import this module; confirm.
    if __name__ == '__main__':
        preprocess.start()

In [None]:
# Load every split's PNG images into one NumPy array per split, keyed by
# split name ('train', 'valid', 'test').
# Note: cv2.imread returns pixels in BGR channel order.
image_datasets = {}
for dataset_name in ('train', 'valid', 'test'):
    # glob.glob returns files in arbitrary, filesystem-dependent order; sorting
    # keeps image indices (e.g. the saved Valid_Image_<x>.png) stable across runs.
    split_paths = sorted(glob.glob(f'{dataset_file_path}/{dataset_name}/*.png'))
    split_images = []
    for image_path in split_paths:
        image = cv2.imread(image_path)
        # cv2.imread signals an unreadable file by returning None; letting it
        # through would silently produce an object-dtype array.
        if image is None:
            raise IOError(f"Could not read image: {image_path}")
        split_images.append(image)
    image_datasets[dataset_name] = np.array(split_images)

# Model Training
This section creates and trains the model.

In [None]:
if __name__ == "__main__":

    # Command-line switch: --u selects the uncertainty-aware UPAE,
    # otherwise a plain VAE is built.
    parser = ArgumentParser()
    parser.add_argument('--u', dest='u', action='store_true')  # use uncertainty
    # parse_known_args tolerates extra argv entries (presumably so the cell
    # also runs inside a notebook kernel that passes its own arguments).
    opt, unknown = parser.parse_known_args()

    optimizer = keras.optimizers.Adam(learning_rate=learning_rate)

    model_class = UPAE if opt.u else VAE
    model = model_class(opt.u, input_shape, multiplier, latent_size)

    # NOTE(review): the build shape is hard-coded to 64x64x3 instead of being
    # derived from `input_shape`; confirm it matches parameters.json.
    model.build(input_shape=(None, 64, 64, 3))

    model.compile(optimizer=optimizer, metrics=[tf.keras.metrics.Accuracy()])

    # Train on the training split only (no validation data is passed to fit).
    history_train = model.fit(
        image_datasets['train'],
        epochs=epochs,
        batch_size=batch_size,
    )
    

# Model Prediction

In [None]:
# Run the trained model on the validation split. Despite its name,
# `history_valid` holds the model's predicted outputs (reconstructions),
# not a Keras History object.
valid_inputs = image_datasets['valid']
history_valid = model.predict(x=valid_inputs, batch_size=batch_size)

# Plots Creation

In [None]:
# Plot the per-epoch training loss recorded by model.fit.
# Only the training curve is available: history_valid comes from
# model.predict and therefore carries no per-epoch loss history.
# NOTE(review): the key (including its trailing ': ') must match the metric
# name the model logs during training; confirm against the model's train_step.
train_loss = history_train.history['binary_crossentropy: ']

plt.plot(train_loss, label='train')
plt.title('Training Reconstruction Loss')
plt.ylabel('Reconstruction Loss')
plt.xlabel('Epochs')
plt.legend()
plt.show()

# Saving of final reconstructed images 

In [None]:
# Create <model_file_path>/<dataset folder name> to hold reconstructed images.
from os import makedirs

# normpath strips a trailing separator, so "data/set/" still yields "set"
# rather than "" (which would have sent the images into the model directory).
dataset_name = path.basename(path.normpath(dataset_file_path))
reconstructed_images_path = path.join(model_file_path, dataset_name)
# exist_ok avoids the check-then-create race and also creates missing parents.
makedirs(reconstructed_images_path, exist_ok=True)

In [None]:
# For every validation prediction, save a side-by-side figure of the original
# image and its reconstruction as Valid_Image_<x>.png.
# NOTE: cv2.imread loads BGR while imshow assumes RGB, so colour images would
# show red/blue swapped (no visible effect on grayscale images).
for x, reconstruction in enumerate(history_valid):
    fig, axs = plt.subplots(1, 2, figsize=(8, 4))
    left_ax, right_ax = axs
    left_ax.imshow(image_datasets['valid'][x])
    left_ax.set_title('Original Image')
    # Round down to whole pixel values and clamp into the uint8 display range.
    new_image = np.clip(np.floor(reconstruction), 0, 255).astype(np.uint8)
    right_ax.imshow(new_image)
    right_ax.set_title('Reconstructed Image')
    plt.savefig(f'{reconstructed_images_path}/Valid_Image_{x}.png')
    plt.close()

# Saving of Model Weights

In [None]:
# Persist the trained weights inside the model's directory.
weights_file = f'{model_file_path}/model_weights.h5'
model.save_weights(weights_file)