# Dependencies

In [None]:
import MURA
import cv2
import image_manipulation
from multiprocessing import Pool
from models import *
import numpy as np
import glob
import matplotlib.pyplot as plt
import preprocessing
import json
from os import path, mkdir

import keras
from argparse import ArgumentParser
import tensorflow as tf

# Model to Run
This section lets the user choose which model to run. The model directory and its `parameters.json` file must already exist before the file path is changed.

In [None]:
# All editable variables
# Directory of the model to run; parameters.json is read from this directory
model_file_path = "models/model_3"
# If preprocessing needs to run
# (when True, the preprocessing step is constructed before the images are loaded)
run_preprocessing = False

In [None]:
# Get model parameters
# Only a missing file is translated into the "not found" error below; a missing
# key or a dataset failure now surfaces as its own exception instead of being
# misreported as a missing parameters.json.
parameters_path = model_file_path + '/parameters.json'
try:
    # The context manager closes the file handle once the JSON is parsed
    with open(parameters_path, encoding='utf-8') as parameters_file:
        params = json.load(parameters_file)
except FileNotFoundError as err:
    raise Exception("No parameters.json file found in the model's directory.") from err

# Model to use
model_is_VAE = params['is_VAE']
# Model parameters
multiplier = params['multiplier']
latent_size = params['latent_size']
input_shape = params['input_shape']

# Training parameters
epochs = params['num_epochs']
batch_size = params['batch_size']
learning_rate = params['learning_rate']

# Dataset Path
image_paths = MURA.MURA_DATASET()
dataset_file_path = params['dataset_path']
all_image_paths = image_paths.get_combined_image_paths()
# Keep only the first column of the combined table
# (presumably the image path column - TODO confirm against MURA)
all_image_paths = all_image_paths.to_numpy()[:,0]

In [None]:
# Optionally build the preprocessing step from the raw image paths
if run_preprocessing:
    preprocess = preprocessing.preprocessing(
        input_path=all_image_paths,
        output_path=dataset_file_path,
    )
    # The step is only started when this code runs as the main module
    if __name__ == '__main__':
        preprocess.start()

In [None]:
# each array contains the training, validation, and testing in order
# Every split is read from <dataset_file_path>/<split>/*.png, converted to
# grayscale, and stored as a single NumPy array under its split name.
image_datasets = {'train': [],
                'valid': [],
                'test': []}
for dataset_name in image_datasets:
    # glob returns matches in filesystem-dependent order; sorting keeps the
    # image indices reproducible between runs and machines
    for image_path in sorted(glob.glob(f'{dataset_file_path}/{dataset_name}/*.png')):
        image = cv2.imread(image_path)
        # cv2.imread signals an unreadable or corrupt file by returning None
        # instead of raising, so fail here with the offending path
        if image is None:
            raise OSError(f'Could not read image file: {image_path}')
        image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
        image_datasets[dataset_name].append(image)
    image_datasets[dataset_name] = np.array(image_datasets[dataset_name])

# Model Training
This section creates and trains the model

In [None]:
if __name__ == "__main__":

    # Pick the architecture requested by parameters.json
    model = (
        VAE(False, input_shape, multiplier, latent_size)
        if model_is_VAE
        else UPAE(True, input_shape, multiplier, latent_size)
    )

    optimizer = keras.optimizers.Adam(learning_rate=learning_rate)

    # Leading None leaves the batch dimension unspecified
    batched_input_shape = (None,) + tuple(input_shape)
    model.build(input_shape=batched_input_shape)

    model.compile(
        optimizer=optimizer,
        loss='mse',
        metrics=[tf.keras.metrics.Accuracy()],
    )

    model.summary()

    # Images produced by the callback are stored inside the model directory
    # (per the original note, the folder is created automatically when missing)
    save_directory = f'{model_file_path}/callback_images'
    save_callback = SaveImageCallback(
        image_datasets['train'],
        save_directory=save_directory,
    )

    # plot_model(model, 'autoencoder_compress.png', show_shapes=True)
    # Fit on the training split; 15% of it is held out by Keras for validation
    training_options = {
        'epochs': epochs,
        'batch_size': batch_size,
        'validation_split': 0.15,
        'callbacks': [save_callback],
    }
    history_train = model.fit(image_datasets['train'], **training_options)
    
    

# Model Prediction

In [None]:
# Run the model on the test split, batching with the training batch size.
# NOTE(review): despite its name, `history_valid` holds predictions for
# image_datasets['test'], not for the 'valid' split.
# NOTE(review): `model` is only defined when the training cell ran with
# __name__ == "__main__".
history_valid = model.predict(image_datasets['test'], batch_size=batch_size)

In [None]:
# Notebook inspection: length of the prediction result
len(history_valid)

# Testing of the Model with the Test Set
This section tests the model with the current test set
TODO: 
- Get the label of each image in the test set
- Test the images
- Fit a linear regression on the abnormality scores to obtain the threshold that separates abnormal from normal images

# Saving of final reconstructed images 

TODO: Make the saving of output images more efficient

In [None]:
# Create directory in models folder for reconstructed images
from os import makedirs

# normpath drops a trailing separator first, so basename always yields the
# dataset folder name (splitting on '/' returned '' for paths ending in '/')
dataset_name = path.basename(path.normpath(dataset_file_path))
reconstructed_images_path = model_file_path + "/" + dataset_name
# exist_ok=True removes the check-then-create race and makes re-running the
# cell a no-op when the directory is already there
makedirs(reconstructed_images_path, exist_ok=True)

In [None]:
# Write one figure per entry of history_valid[0]: input image on the left,
# the model output on the right
for index, reconstruction in enumerate(history_valid[0]):
    figure, (original_axis, reconstructed_axis) = plt.subplots(1, 2, figsize=(8, 4))

    original_axis.imshow(image_datasets['test'][index])
    original_axis.set_title('Original Image')

    # Round down and cast to uint8 before displaying the model output
    new_image = np.floor(reconstruction).astype(np.uint8)
    reconstructed_axis.imshow(new_image)
    reconstructed_axis.set_title('Reconstructed Image')

    plt.savefig(f'{reconstructed_images_path}/Valid_Image_{index}.png')
    # Close the figure so open figures do not pile up across iterations
    plt.close()
    

# Saving of Model Weights

In [None]:
# Show which model directory the weights are about to be written to
print(model_file_path)

In [None]:
# Persist the model weights under <model_file_path>/model_weights.
# NOTE(review): `save_format='tf'` is a tf.keras 2.x argument - confirm it is
# still accepted by the installed Keras version.
model.save_weights(model_file_path + '/model_weights', save_format='tf')

In [None]:
# Notebook inspection: shape of the first entry of the prediction result
history_valid[0].shape

In [None]:
# Join the arrays contained in history_valid[0] along axis 1 into one array
# and display the result as a single image
new_image = np.concatenate(history_valid[0], axis=1)
plt.imshow(new_image)

In [None]:
# Notebook inspection: raw values of the first entry of the prediction result
history_valid[0]