# Explanations

Explain the trained models listed in `model_names` on 100 positive samples per class from each cross-validation fold (`dataset_paths_train` / `dataset_paths_test`).

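- Trained models are loaded from `models_save_dir`, either as full saved models (`*_model.h5`) or as weights (`*_weights.h5`) loaded into the architecture returned by `getNetByName`.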

- Evaluation results in 'results_dir' are used to select 100 positive samples of every class.

- Explanations are stored in subfolders under 'results_dir'.

In [2]:
import os
from tqdm import tqdm
import tensorflow as tf
import shutil

from ekman_expressions.nets import getNetByName
from ekman_expressions.explanations import save_positives_per_class, apply_lime

# Let TensorFlow grow GPU memory on demand instead of reserving all VRAM up front.
gpus = tf.config.list_physical_devices('GPU')
for gpu in gpus:
    try:
        tf.config.experimental.set_memory_growth(gpu, True)
    except RuntimeError as err:
        # Memory growth can only be set before the GPUs are initialized
        # (e.g. this cell was re-run after a model was already built);
        # report it instead of aborting the notebook.
        print(err)

In [3]:
# Architectures to explain (names understood by getNetByName)
model_names = [
    'SilNet', 'WeiNet', 'AlexNet', 'SongNet', 'InceptionV3',
    'VGG19', 'VGG16', 'ResNet50', 'ResNet101V2', 'Xception',
    'MobileNetV3Large', 'EfficientNetV2B0',
]

# K-fold cross-validation: fold k is expected at '<root>CV<k>' (training)
# and '<root>CV<k>_test' (test). Both lists may instead be set by hand.
K = 5
dataset_paths_root = '../datasets/'
dataset_paths_train = [dataset_paths_root + 'CV' + str(fold) for fold in range(1, K + 1)]
dataset_paths_test = [train_path + '_test' for train_path in dataset_paths_train]

# Directory holding the trained models
models_save_dir = '../models'

# Directory holding the evaluation results; explanations are written here too
results_dir = '../results'

# Class labels
label_names = ['Angry', 'Disgust', 'Fear', 'Happy', 'Sad', 'Surprise']

# Explanation method to apply
expl_method = 'lime'

In [None]:
# Number of positive samples copied and explained per class
n_positives = 100

# Only LIME is implemented below; fail early instead of silently creating
# empty output folders for an unknown method.
if expl_method != 'lime':
    raise ValueError(f"Unsupported explanation method {expl_method!r}; only 'lime' is implemented")


def _load_trained_model(model_name, train_path):
    """Load the model trained on fold *train_path*; return ``(model, img_size)``.

    A full saved model ``<models_save_dir>/<model_name>_<fold>_model.h5`` is
    preferred; otherwise ``..._weights.h5`` is loaded into a freshly built
    ``getNetByName`` architecture.

    Raises:
        FileNotFoundError: if neither file exists. Without this check the loop
            would reuse the model left over from the previous iteration.
    """
    # Drop Keras global state from previously loaded models so memory does
    # not keep growing while many models are loaded in this loop.
    tf.keras.backend.clear_session()

    base_path = os.path.join(models_save_dir, model_name + '_' + os.path.basename(train_path))
    full_model_path = base_path + '_model.h5'
    weights_path = base_path + '_weights.h5'

    if os.path.exists(full_model_path):
        # getNetByName is still needed for the expected input size
        _, img_size = getNetByName(model_name)
        model = tf.keras.models.load_model(full_model_path)
    elif os.path.exists(weights_path):
        model, img_size = getNetByName(model_name)
        model.load_weights(weights_path)
    else:
        raise FileNotFoundError(
            f'No trained model for {model_name} on {train_path}: '
            f'expected {full_model_path} or {weights_path}'
        )
    return model, img_size


total_steps = len(model_names) * len(dataset_paths_train) * len(label_names) * n_positives
with tqdm(total=total_steps) as progress:
    for model_name in model_names:
        # tqdm.write keeps the progress bar intact, unlike print
        tqdm.write(model_name)

        for train_path, test_path in zip(dataset_paths_train, dataset_paths_test):
            tqdm.write(test_path)

            # Results folder of this model/fold (holds the evaluation's test.json)
            results_net = os.path.join(results_dir, model_name + '_'
                                       + os.path.basename(train_path)
                                       + '_results')
            os.makedirs(results_net, exist_ok=True)

            model, img_size = _load_trained_model(model_name, train_path)

            # Evaluation results used to pick the positives
            test_file = os.path.join(results_net, 'test.json')

            # Recreate the folder with n_positives positives per class
            imgs100 = os.path.join(results_net, 'imgs_100')
            if os.path.exists(imgs100):
                shutil.rmtree(imgs100)
            os.mkdir(imgs100)
            save_positives_per_class(test_path, test_file, label_names, imgs100, n_positives=n_positives)

            # Explain each positive; sorted for a reproducible processing order
            for class_dir in sorted(os.listdir(imgs100)):
                tqdm.write(class_dir)

                # Folder with the images to explain
                path_dir = os.path.join(imgs100, class_dir, 'positives')

                # Folder for the explanation images
                exp_img_path = os.path.join(imgs100, class_dir, expl_method + '_explanations')
                os.makedirs(exp_img_path, exist_ok=True)

                # Folder for the explanation masks
                exp_mask_path = os.path.join(imgs100, class_dir, expl_method + '_masks')
                os.makedirs(exp_mask_path, exist_ok=True)

                # Folder for the LIME segmentations
                exp_seg_path = os.path.join(imgs100, class_dir, 'lime_seg')
                os.makedirs(exp_seg_path, exist_ok=True)

                for img_name in sorted(os.listdir(path_dir)):
                    # Class folder names must parse as the integer label to explain
                    apply_lime(os.path.join(path_dir, img_name), img_size, model, exp_img_path, exp_mask_path, exp_seg_path,
                               hide_color=0, num_samples=1000, label=int(class_dir),
                               th=None, top_k=None, min_accum=None, improve_background=False, pos_only=True, neg_only=False,
                               hist_stretch=True, invert=True)
                    progress.update(1)