In [None]:
import os
from tensorflow import keras

# set warnings
import warnings
warnings.simplefilter(action='ignore', category=Warning)

# import modules and components
from FEXT.commons.utils.dataloader.generators import build_tensor_dataset
from FEXT.commons.utils.dataloader.serializer import DataSerializer, ModelSerializer
from FEXT.commons.utils.models.inferencer import FeatureExtractor
from FEXT.commons.utils.validation import ModelValidation

# 1. Load data and model

In [None]:
# select and load the pretrained model, then print its summary
modelserializer = ModelSerializer()
model, parameters = modelserializer.load_pretrained_model()
model_folder = modelserializer.loaded_model_folder
model.summary(expand_nested=True)

# isolate the encoder from the autoencoder model and use it for inference.
# model.input is used instead of get_layer('input_1'): auto-generated input
# layer names depend on the Keras version (e.g. Keras 3 names them
# 'input_layer', 'input_layer_1', ...) and on how many Input layers existed
# in the session that built the model.
encoder_output = model.get_layer('fe_xt_encoder')
encoder_model = keras.Model(inputs=model.input, outputs=encoder_output.output)

### 1.1 Create the train, validation and test datasets

In [None]:
# load the paths of the preprocessed images for the selected model folder
dataserializer = DataSerializer()
processed_img_path = dataserializer.load_preprocessed_data(model_folder)

# build one tf.data dataset per split, in train -> validation -> test order
train_dataset, validation_dataset, test_dataset = (
    build_tensor_dataset(processed_img_path[split])
    for split in ('train', 'validation', 'test'))

# 2. Model performance evaluation

### 2.1 Evaluation of loss and metrics

In [None]:
validator = ModelValidation(model)

# create the subfolder for evaluation outputs; exist_ok keeps the cell
# idempotent when it is re-run
eRESULTS_PATH = os.path.join(model_folder, 'evaluation')
os.makedirs(eRESULTS_PATH, exist_ok=True)

# evaluate the model on the train and validation datasets. The datasets are
# already batched (later cells call .unbatch() on them), and Keras documents
# that batch_size must not be specified for dataset inputs, so none is passed.
train_eval = model.evaluate(train_dataset, verbose=1)
validation_eval = model.evaluate(validation_dataset, verbose=1)

# evaluate() returns [loss, *metrics]; this assumes at least one compiled
# metric and reports only the first one
for split_name, scores in (('Train', train_eval), ('Validation', validation_eval)):
    print(f'\n{split_name} dataset:')
    print(f'Loss: {scores[0]}')
    print(f'Metric: {scores[1]}')

### 2.2 Reconstruction evaluation

Compare the reconstructed images with the original pictures to qualitatively evaluate the performance of the FeXT autoencoder model.

In [None]:
# Qualitatively check the reconstructions on a single batch of images drawn from
# the train and validation datasets. unbatch().batch(n).take(1) re-batches each
# dataset so that one batch of up to n images is visualized per split.
num_visual_samples = 10
visual_checks = (('train', 'visual_evaluation_train', train_dataset),
                 ('validation', 'visual_evaluation_val', validation_dataset))
for split_name, plot_name, dataset in visual_checks:
    print(f'Visual reconstruction validation: {split_name} dataset\n')
    sample_batch = dataset.unbatch().batch(num_visual_samples).take(1)
    # the second element of each dataset pair is not needed for plotting
    for images, _ in sample_batch:
        reconstructed_images = model.predict(images, verbose=0)
        validator.visualize_reconstructed_images(images, reconstructed_images,
                                                 plot_name, eRESULTS_PATH)

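Visualize the original image and its reconstruction, together with the corresponding feature vector extracted by the encoder. For display, the raw feature tensor is reshaped to a 64x512 matrix. A 64x512 matrix holds 32,768 values, which matches an 8x8x512 tensor rather than a 4x4x512 one (only 8,192 values), so confirm this against the encoder's actual output shape.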

In [None]:
# Draw a single image ONCE and reuse the same tensor for reconstruction, feature
# extraction and plotting. Passing the lazy dataset to each predict() call and
# iterating it again for the plot re-runs the input pipeline every time, which
# would yield a different image per step if the dataset shuffles.
single_image_batch = train_dataset.unbatch().batch(1).take(1)
original_image, _ = next(iter(single_image_batch))

# reconstruct the image with the autoencoder and extract its features with the
# encoder sub-model built in section 1 (`encoder_model`)
reconstructed_image = model.predict(original_image, verbose=0)
extracted_features = encoder_model.predict(original_image, verbose=0)
# reshape the raw encoder output to a 64x512 matrix for display; this requires
# exactly 64 * 512 = 32768 values for the single image
reshaped_features = extracted_features.reshape(64, 512)

plot_name = 'visual_features_vector'
validator.visualize_features_vector(original_image, reshaped_features, reconstructed_image,
                                    plot_name, eRESULTS_PATH)

In [None]:
# extract features from the train and validation images using the encoder output.
# Both splits must be processed with the same configuration, so the loaded
# `parameters` are passed for each (the model is already given to FeatureExtractor).
extractor = FeatureExtractor(model)
train_features = extractor.extract_from_encoder(processed_img_path['train'], parameters)
val_features = extractor.extract_from_encoder(processed_img_path['validation'], parameters)