# Evaluate the ML framework for neural networks, with explanations using LRP

In [None]:
import csv
import pickle
import time
import warnings
from collections import defaultdict, Counter
from datetime import timedelta
from pathlib import Path
import copy
import os

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch
import torch.nn as nn
import torch.nn.functional as F
from imblearn.over_sampling import SMOTE
from imblearn.pipeline import Pipeline
from imblearn.under_sampling import RandomUnderSampler
from sklearn.metrics import accuracy_score, precision_score, recall_score
from sklearn.metrics import f1_score
from sklearn.model_selection import train_test_split
from tabulate import tabulate
from tqdm import tqdm
import mlxai4cat.utils.LRP_tools as LRP
from mlxai4cat.models.neural_network import NeuralNetwork, ModifiedNeuralNetwork
from mlxai4cat.utils.nn_training import train_epoch, val_epoch
from mlxai4cat.utils.data import prepare_dataset, stratified_sampling, resampling, get_test_data_loader, get_xval_data_loaders
from mlxai4cat.utils.visualization import get_formatted_results, plot_feature_importance, plot_feature_importance_distribution, custom_palette
from mlxai4cat.models.generative import generate_catalysts_from_relevance_scores, catalyst_string_to_numpy, numpy_to_catalyst_string
from mlxai4cat.utils.LRP_tools import LRPAnalyzer

# Suppress every warning category for the rest of the session.
warnings.filterwarnings('ignore')
# Seed PyTorch's global random number generator.
torch.manual_seed(0)

In [None]:
# IPython autoreload extension, mode 2 (presumably so edits to the imported
# mlxai4cat modules are picked up without restarting the kernel).
%load_ext autoreload
%autoreload 2

### Storing information

In [None]:
# Output locations for tabular/pickled results and for figures.
storing_path = Path('../results')
figure_path = Path('../figures')
# Gates the pickle/CSV/figure exports that are wrapped in `if SAVE:` further down.
SAVE = True

# Several later cells write into these folders (some of them regardless of SAVE),
# so make sure they exist before anything is saved.
for out_dir in (storing_path, figure_path):
    out_dir.mkdir(parents=True, exist_ok=True)

## Load data

In [None]:
# Load the OCM catalyst data. The first return value is discarded; going by the names,
# X/y are the full features/labels and the *_pos / *_neg pairs are the per-class
# subsets -- TODO confirm against prepare_dataset.
_, X, y, X_pos, y_pos, X_neg, y_neg, feature_names = prepare_dataset('../data/ocm_cat_data.csv')

## Cross validation setup

In [None]:
# Candidate depths; the i-th entry of `num_neurons_per_layers` holds the layer-size
# options that belong to the i-th depth.
num_layers = [2, 3, 4]
num_neurons_per_layers = {
    1: [[36, 16, 2], [36, 32, 2], [36, 64, 2], [36, 128, 2]],
    2: [[36, 16, 16, 2], [36, 32, 32, 2], [36, 64, 64, 2], [36, 128, 128, 2]],
    3: [[36, 16, 16, 16, 2], [36, 32, 32, 32, 2], [36, 64, 64, 64, 2], [36, 128, 128, 128, 2]],
}
dropout_rates = [0, 0.1]
# Wider search ranges, currently disabled:
#dropout_rates = [0, 0.05, 0.1, 0.15, 0.2]
#lr = [1e-2, 1e-3, 1e-4, 1e-5]
#wd = [0, 1e-4, 1e-5]
lr = [1e-3]
wd = [0.5e-2]

# Cartesian product of all options; every entry is
# [depth, layer sizes, dropout rate, learning rate, weight decay].
all_combs = [
    [depth, layer_sizes, drop_p, step_size, decay]
    for depth, size_options in zip(num_layers, num_neurons_per_layers.values())
    for layer_sizes in size_options
    for drop_p in dropout_rates
    for step_size in lr
    for decay in wd
]

print(f"All combinations ({len(all_combs)}): \n{all_combs}")

#all_combs = all_combs[0:50]  # TODO: remove it was just for testing

## Neural network with resampling

### Explanations based on positive class

In [None]:
# Forwarded to get_xval_data_loaders in the training cell and used in the names of the pickled result files.
with_resampling = True

In [None]:
# This explanation is based on the positive class [0,1]
criterion = nn.BCELoss()
acc = 0
k = 5  # cross-validation folds
n = 100  # number of outer train/test splits iterated by the training cell
n_iter = 20  # maximum number of training epochs per fold
patience = 10  # epochs without validation-F1 improvement before training of a fold stops
verbose = False
# instantiate variables for ease of use later
artificial_neuron = True
# The training cell checks this flag before moving a model to CUDA, but it was never
# defined anywhere in the notebook. Default to CPU training.
use_gpu = False

### Training and nested cross-validation 

In [None]:
# Test-set metrics of the selected model, one entry per outer split
acc_mlp_g = []
precision_mlp_g = []
recall_mlp_g = []
f1_mlp_g = []


selected_models_max_g = defaultdict(list)  # architecture string -> {f1_score, hyperparams} of the winner of each split
selected_models_g = defaultdict(list)  # architecture string -> {f1_score, hyperparams} of every evaluated combination
selected_models_counts_g = Counter()  # how often each architecture won a split
selected_models = []  # (validation F1, model object) of the winner of each split
split_test_data = {}  # predictions, relevance scores and confusion information for each split

# NOTE(review): only the first five grid entries are searched, which looks like a
# debugging leftover -- use `all_combs` to search the complete grid.
search_space = all_combs[:5]

# `use_gpu` is optional configuration; stay on the CPU when no cell defines it.
train_on_gpu = torch.cuda.is_available() and globals().get('use_gpu', False)

# The loss is stateless, so one instance serves every model.
criterion = nn.BCELoss()

for rs in range(n):
    print(f"Split {rs}")
    start_time = time.time()

    ### Get the training data and the test loader for the current split
    X_train, y_train, test_loader = get_test_data_loader(X_pos, X_neg, y_pos, y_neg, rs)

    ### Iterate over the hyperparameter combinations for this split and keep the best-performing model
    max_f1_model = None  # Model with max validation F1 score
    max_f1_val = 0  # Max (fold-averaged) validation F1 score

    convergences = []
    idx_max_f1_model = 0
    for c, comb in enumerate(search_space):
        if c % 50 == 0:
            # report against the number of combinations that are actually searched
            print(f"> Combination {c}/{len(search_space)}")

        # Unpack into fresh names so the grid lists `num_layers`, `lr` and `wd` are not overwritten.
        n_layers, layer_sizes, dropout_rate, learning_rate, weight_decay = comb

        # `artificial_neuron` is the flag set in the configuration cell
        # (the previously used name `act_rel` was never defined).
        model = NeuralNetwork(n_layers, layer_sizes, dropout_rate, artificial_neuron=artificial_neuron)
        if train_on_gpu:
            model = model.cuda()

        optimizer = torch.optim.Adam(model.parameters(), weight_decay=weight_decay, lr=learning_rate)

        xval_f1_scores = []

        # NOTE(review): model and optimizer are created once per combination, so the weights
        # carry over from fold to fold and later folds are validated on samples the model has
        # already been trained on. Re-initialise per fold if independent fold estimates are wanted.
        for k_i in range(k):
            train_loader, val_loader = get_xval_data_loaders(k, k_i, X_train, y_train,
                                                             with_resampling=with_resampling,
                                                             verbose=verbose)
            early_stopping_counter = 0
            best_val_f1 = 0
            f1_val = 0
            iteration = 0

            # train for at most n_iter epochs, stopping early once the validation F1 stalls
            for iteration in tqdm(range(n_iter), desc=f"Training for max {n_iter} epochs", leave=False):

                model.train()
                train_epoch(train_loader, model, criterion, optimizer)

                model.eval()
                val_pred, val_gt = val_epoch(val_loader, model)

                f1_val = f1_score(np.array(val_gt), np.array(val_pred))

                if f1_val > best_val_f1:
                    best_val_f1 = f1_val
                    early_stopping_counter = 0
                else:
                    early_stopping_counter += 1

                if early_stopping_counter >= patience:
                    break
            xval_f1_scores.append(best_val_f1)

        # Model-selection criterion: best validation F1 of each fold, averaged over the folds
        mean_val_f1 = np.mean(xval_f1_scores)

        # Last epoch index reached in the final fold for this combination
        convergences.append(iteration)
        # Record the architecture together with the score it is selected by. The mean CV F1 is
        # stored (not the last-epoch F1 of the last fold) so that it matches `selected_models_max_g`.
        model_architecture = str(model)  # Convert the model architecture to a string
        selected_models_g[model_architecture].append({'f1_score': mean_val_f1, 'hyperparams': comb})

        # Keep this model if it is at least as good as the best one so far
        if mean_val_f1 >= max_f1_val:
            max_f1_model = copy.deepcopy(model)
            max_f1_val = mean_val_f1
            max_f1_comb = comb
            idx_max_f1_model = c

    ### Store model architecture that achieved the highest F1 score over all combinations for 1 split
    model_architecture_max = str(max_f1_model)  # Convert the model architecture to a string
    selected_models_max_g[model_architecture_max].append({"f1_score": max_f1_val, "hyperparams": max_f1_comb})
    selected_models.append((max_f1_val, max_f1_model))
    # Increase counter for the selected model architecture
    selected_models_counts_g[model_architecture_max] += 1
    max_f1_model.cpu()

    ### Evaluate the best-performing model on the test set.
    max_f1_model.eval()
    modified_model = ModifiedNeuralNetwork(max_f1_model)

    pred, gt, probs, rels, confusion_scores, confusion_idxs = modified_model.inference_with_relevance(test_loader, reweight_explanation=True,
                                                                                               relevance_on_positive_class=True)
    gt = np.array(gt)
    pred = np.array(pred)
    rels = np.array(rels)
    probs = np.array(probs)

    split_test_data[rs] = {
        'pred': pred,
        'gt': gt,
        # relevance scores of the test samples, converted to a numpy array above
        'rels': rels,
        'probs': probs,
        # dict with the keys true_pos_scores, true_neg_scores, false_pos_scores and false_neg_scores
        'confusion_scores': confusion_scores,
        # dict of sample indexes per confusion category, iterated in parallel with confusion_scores later on
        'confusion_idxs': confusion_idxs,
        # last epoch index reached for each combination of hyperparameters
        'convergence_times': convergences,
        # last epoch index reached for the model with the highest F1 score
        'convergence_time_max_f1_model': convergences[idx_max_f1_model]
    }

    acc_mlp_g.append(accuracy_score(gt, pred))
    precision_mlp_g.append(precision_score(gt, pred, zero_division=1))
    recall_mlp_g.append(recall_score(gt, pred))
    f1_mlp_g.append(f1_score(gt, pred))

    print(f"> Total computation time for split {rs}: {timedelta(seconds=(time.time() - start_time))}\n")

### Store all results

In [None]:
if SAVE:
    # file name -> object to pickle for the current resampling setting
    pickle_jobs = {
        f'all_results_resampling_{with_resampling}.pkl': split_test_data,
        f'max_f1_models_resampling_{with_resampling}.pkl': selected_models_max_g,
    }
    for pkl_name, payload in pickle_jobs.items():
        with open(storing_path / pkl_name, 'wb') as f:
            pickle.dump(payload, f)

### Display and save the metrics

In [None]:
# Summarise the per-split test metrics collected in the training cell (printed because verbose=True is passed).
df_metrics = get_formatted_results(acc_mlp_g, f1_mlp_g, precision_mlp_g, recall_mlp_g, 'Neural Networks', verbose=True)
if SAVE:
    ## SAVING ANALYSIS RESULTS
    # (the unused `file_path` placeholder string was dropped; the target name is given here)
    df_metrics.to_csv(os.path.join(storing_path, 'NN_metrics_results.csv'), index=False)


### Signed and absolute average feature importances

In [None]:
# Stack the relevance arrays of all splits and flatten them into rows of len(feature_names) values
stacked_rels = np.array([split_result['rels'] for split_result in split_test_data.values()])
analyzer = LRPAnalyzer(stacked_rels.reshape(-1, len(feature_names)), feature_names)

analyzer.calculate_mean_lrp_scores()
analyzer.calculate_mean_abs_lrp_scores()

# Figures for the signed and the absolute mean scores
signed_fig = os.path.join(figure_path, 'sorted_mean_lrp_NN_GI.png')
abs_fig = os.path.join(figure_path, 'sorted_mean_abs_lrp_NN_GI.png')
analyzer.plot_lrp_scores(signed_fig)
analyzer.plot_abs_lrp_scores(abs_fig)

# CSV exports; the signed scores file is read again in the candidate-generation section
signed_csv = os.path.join(storing_path, 'sorted_mean_lrp_NN.csv')
abs_csv = os.path.join(storing_path, 'sorted_mean_abs_lrp_NN.csv')
analyzer.save_scores_to_csv(signed_csv, abs_csv)

In [None]:
# Gather the relevance arrays of all splits and join them along axis 1 for plotting
rels_per_split = [split_result['rels'] for split_result in split_test_data.values()]
plt_rels = np.concatenate(rels_per_split, axis=1)

# Absolute relevances averaged over axis 0
mean_abs_rels = np.abs(plt_rels).mean(axis=0)
plot_feature_importance_distribution(mean_abs_rels, feature_names, 'NN', color='gray', savedir=figure_path)

## Neural network without resampling

### Explanations based on positive class

In [None]:
# Disable resampling for the second run; the flag is forwarded to get_xval_data_loaders and used in the pickle file names.
with_resampling = False

In [None]:
# This explanation is based on the positive class [0,1]
criterion = nn.BCELoss()
acc = 0
k = 5  # cross-validation folds
# NOTE(review): the resampling run uses n = 100 and n_iter = 20; the much smaller values
# here look like test settings and make the two runs hard to compare -- confirm before
# reporting results.
n = 3  # number of outer train/test splits iterated by the training cell
n_iter = 3  # maximum number of training epochs per fold
patience = 10  # epochs without validation-F1 improvement before training of a fold stops
verbose = False
# instantiate variables for ease of use later
artificial_neuron = True
# The training cell checks this flag before moving a model to CUDA, but it was never
# defined anywhere in the notebook. Default to CPU training.
use_gpu = False

### Training and nested cross-validation 

In [None]:
# Test-set metrics of the selected model, one entry per outer split (run without resampling)
acc_mlp_g_nr = []
precision_mlp_g_nr = []
recall_mlp_g_nr = []
f1_mlp_g_nr = []


selected_models_max_g_nr = defaultdict(list)  # architecture string -> {f1_score, hyperparams} of the winner of each split
selected_models_g_nr = defaultdict(list)  # architecture string -> {f1_score, hyperparams} of every evaluated combination
selected_models_counts_g_nr = Counter()  # how often each architecture won a split
selected_models_nr = []  # (validation F1, model object) of the winner of each split
split_test_data_nr = {}  # predictions, relevance scores and confusion information for each split

# NOTE(review): only the first five grid entries are searched, which looks like a
# debugging leftover -- use `all_combs` to search the complete grid.
search_space = all_combs[:5]

# `use_gpu` is optional configuration; stay on the CPU when no cell defines it.
train_on_gpu = torch.cuda.is_available() and globals().get('use_gpu', False)

# The loss is stateless, so one instance serves every model.
criterion = nn.BCELoss()

for rs in range(n):
    print(f"Split {rs}")
    start_time = time.time()

    ### Get the training data and the test loader for the current split
    X_train, y_train, test_loader = get_test_data_loader(X_pos, X_neg, y_pos, y_neg, rs)

    ### Iterate over the hyperparameter combinations for this split and keep the best-performing model
    max_f1_model_nr = None  # Model with max validation F1 score
    max_f1_val = 0  # Max (fold-averaged) validation F1 score

    convergences = []
    idx_max_f1_model = 0
    for c, comb in enumerate(search_space):
        if c % 50 == 0:
            # report against the number of combinations that are actually searched
            print(f"> Combination {c}/{len(search_space)}")

        # Unpack into fresh names so the grid lists `num_layers`, `lr` and `wd` are not overwritten.
        n_layers, layer_sizes, dropout_rate, learning_rate, weight_decay = comb

        # `artificial_neuron` is the flag set in the configuration cell
        # (the previously used name `act_rel` was never defined).
        model = NeuralNetwork(n_layers, layer_sizes, dropout_rate, artificial_neuron=artificial_neuron)
        if train_on_gpu:
            model = model.cuda()

        optimizer = torch.optim.Adam(model.parameters(), weight_decay=weight_decay, lr=learning_rate)

        xval_f1_scores = []

        # NOTE(review): model and optimizer are created once per combination, so the weights
        # carry over from fold to fold and later folds are validated on samples the model has
        # already been trained on. Re-initialise per fold if independent fold estimates are wanted.
        for k_i in range(k):
            train_loader, val_loader = get_xval_data_loaders(k, k_i, X_train, y_train,
                                                             with_resampling=with_resampling,
                                                             verbose=verbose)
            early_stopping_counter = 0
            best_val_f1 = 0
            f1_val = 0
            iteration = 0

            # train for at most n_iter epochs, stopping early once the validation F1 stalls
            for iteration in tqdm(range(n_iter), desc=f"Training for max {n_iter} epochs", leave=False):

                model.train()
                train_epoch(train_loader, model, criterion, optimizer)

                model.eval()
                val_pred, val_gt = val_epoch(val_loader, model)

                f1_val = f1_score(np.array(val_gt), np.array(val_pred))

                if f1_val > best_val_f1:
                    best_val_f1 = f1_val
                    early_stopping_counter = 0
                else:
                    early_stopping_counter += 1

                if early_stopping_counter >= patience:
                    break
            xval_f1_scores.append(best_val_f1)

        # Model-selection criterion: best validation F1 of each fold, averaged over the folds
        mean_val_f1 = np.mean(xval_f1_scores)

        # Last epoch index reached in the final fold for this combination
        convergences.append(iteration)
        # Record the architecture together with the score it is selected by. The mean CV F1 is
        # stored (not the last-epoch F1 of the last fold) so that it matches `selected_models_max_g_nr`.
        model_architecture = str(model)  # Convert the model architecture to a string
        selected_models_g_nr[model_architecture].append({'f1_score': mean_val_f1, 'hyperparams': comb})

        # Keep this model if it is at least as good as the best one so far
        if mean_val_f1 >= max_f1_val:
            max_f1_model_nr = copy.deepcopy(model)
            max_f1_val = mean_val_f1
            max_f1_comb = comb
            idx_max_f1_model = c

    ### Store model architecture that achieved the highest F1 score over all combinations for 1 split
    model_architecture_max = str(max_f1_model_nr)  # Convert the model architecture to a string
    selected_models_max_g_nr[model_architecture_max].append({"f1_score": max_f1_val, "hyperparams": max_f1_comb})
    selected_models_nr.append((max_f1_val, max_f1_model_nr))
    # Increase counter for the selected model architecture
    selected_models_counts_g_nr[model_architecture_max] += 1
    max_f1_model_nr.cpu()

    ### Evaluate the best-performing model on the test set.
    max_f1_model_nr.eval()
    modified_model = ModifiedNeuralNetwork(max_f1_model_nr)

    pred, gt, probs, rels, confusion_scores, confusion_idxs = modified_model.inference_with_relevance(test_loader, reweight_explanation=True,
                                                                                               relevance_on_positive_class=True)
    gt = np.array(gt)
    pred = np.array(pred)
    rels = np.array(rels)
    probs = np.array(probs)

    split_test_data_nr[rs] = {
        'pred': pred,
        'gt': gt,
        # relevance scores of the test samples, converted to a numpy array above
        'rels': rels,
        'probs': probs,
        # dict with the keys true_pos_scores, true_neg_scores, false_pos_scores and false_neg_scores -- TODO confirm
        'confusion_scores': confusion_scores,
        # dict of sample indexes per confusion category -- TODO confirm key names
        'confusion_idxs': confusion_idxs,
        # last epoch index reached for each combination of hyperparameters
        'convergence_times': convergences,
        # last epoch index reached for the model with the highest F1 score
        'convergence_time_max_f1_model': convergences[idx_max_f1_model]
    }

    acc_mlp_g_nr.append(accuracy_score(gt, pred))
    precision_mlp_g_nr.append(precision_score(gt, pred, zero_division=1))
    recall_mlp_g_nr.append(recall_score(gt, pred))
    f1_mlp_g_nr.append(f1_score(gt, pred))

    print(f"> Total computation time for split {rs}: {timedelta(seconds=(time.time() - start_time))}\n")

### Store all results

In [None]:
if SAVE:
    # file name -> object to pickle for the current resampling setting
    pickle_jobs = {
        f'all_results_resampling_{with_resampling}.pkl': split_test_data_nr,
        f'max_f1_models_resampling_{with_resampling}.pkl': selected_models_max_g_nr,
    }
    for pkl_name, payload in pickle_jobs.items():
        with open(storing_path / pkl_name, 'wb') as f:
            pickle.dump(payload, f)

### Display and save the metrics

In [None]:
# Summarise the per-split test metrics of the run without resampling (printed because verbose=True is passed).
df_metrics = get_formatted_results(acc_mlp_g_nr, f1_mlp_g_nr, precision_mlp_g_nr, recall_mlp_g_nr, 'Neural Networks', verbose=True)
if SAVE:
    ## SAVING ANALYSIS RESULTS
    # (the unused `file_path` placeholder string was dropped; the target name is given here)
    df_metrics.to_csv(os.path.join(storing_path, 'NN_metrics_NO_Resampling_results.csv'), index=False)


### Signed and absolute average feature importances

In [None]:
# Stack the relevance arrays of all splits and flatten them into rows of len(feature_names) values
stacked_rels_nr = np.array([split_result['rels'] for split_result in split_test_data_nr.values()])
analyzer = LRPAnalyzer(stacked_rels_nr.reshape(-1, len(feature_names)), feature_names)

analyzer.calculate_mean_lrp_scores()
analyzer.calculate_mean_abs_lrp_scores()

# Figures for the signed and the absolute mean scores
signed_fig = os.path.join(figure_path, 'sorted_mean_lrp_NN_NO_Resampling_GI.png')
abs_fig = os.path.join(figure_path, 'sorted_mean_abs_lrp_NN_NO_Resampling_GI.png')
analyzer.plot_lrp_scores(signed_fig)
analyzer.plot_abs_lrp_scores(abs_fig)

# CSV exports of both score tables
signed_csv = os.path.join(storing_path, 'sorted_mean_lrp_NN_NO_Resampling.csv')
abs_csv = os.path.join(storing_path, 'sorted_mean_abs_lrp_NN_NO_Resampling.csv')
analyzer.save_scores_to_csv(signed_csv, abs_csv)

## Create single sample visualizations

### Select a random split to visualize

In [None]:
# Sample a random split from all results, i.e., random key of split_test_data dictionary
np.random.seed(1)
rs = np.random.choice(list(split_test_data.keys()))
single_sample_split_rs = rs

print(f"Randomly selected split: {rs}\n")

# Pull out the per-category scores/indexes and the relevance array of the chosen split
confusion_scores = split_test_data[rs]['confusion_scores']
confusion_idxs = split_test_data[rs]['confusion_idxs']
rels = split_test_data[rs]['rels']
# confusion_scores is a dict, so report its categories rather than a "shape"
print('confusion score categories', list(confusion_scores.keys()))
print('rels shape', rels.shape)

### Find outlier samples, used later for plotting single sample explanations

In [None]:
def get_highest_and_lowest_scores(scores, idxs, top_k=1, bottom_k=1):
    """Return the top-k and bottom-k scores of one confusion category with their sample indexes.

    Args:
        scores: Sequence of scores, one per sample.
        idxs: Sequence of sample indexes, aligned with ``scores``.
        top_k: Number of highest-scoring samples to return.
        bottom_k: Number of lowest-scoring samples to return.

    Returns:
        Tuple ``(high_scores, high_idxs, low_scores, low_idxs)`` of numpy arrays. Both
        selections are in ascending score order, so the single highest score is the
        last element of ``high_scores``. Arrays are empty when the input is empty or
        the requested count is 0.
    """
    scores = np.asarray(scores)
    idxs = np.asarray(idxs)

    # Positions of the samples ordered from lowest to highest score
    order = np.argsort(scores)

    # Slice from an explicit start: `order[-top_k:]` would return *every* element for top_k == 0
    top = order[max(len(order) - top_k, 0):]
    bottom = order[:bottom_k]

    return scores[top], idxs[top], scores[bottom], idxs[bottom]

prob_lists = []

# Display names for the keys of `confusion_scores`
name_mapping = {
    'true_pos_scores': 'True Positives',
    'false_pos_scores': 'False Positives',
    'true_neg_scores': 'True Negatives',
    'false_neg_scores': 'False Negatives'
}

# Desired order
desired_order = [
    'High True Positives', 'Low True Positives',
    'High False Positives', 'Low False Positives',
    'High False Negatives', 'Low False Negatives',
    'High True Negatives', 'Low True Negatives'
]

temp_categories, temp_index_lists, temp_score_lists, temp_prob_lists = [], [], [], []

# Walk both dicts in parallel and record the single highest and lowest scoring sample of each category
for (scores_key, scores), (idx_key, idxs) in zip(confusion_scores.items(), confusion_idxs.items()):
    high_scores, high_idxs, low_scores, low_idxs = get_highest_and_lowest_scores(scores, idxs, top_k=1, bottom_k=1)

    label = name_mapping[scores_key]
    temp_categories += [f"High {label}", f"Low {label}"]
    temp_index_lists += [high_idxs, low_idxs]
    temp_score_lists += [high_scores, low_scores]

# Re-order the collected entries so they follow `desired_order`
order_positions = [temp_categories.index(wanted) for wanted in desired_order]
categories = [temp_categories[pos] for pos in order_positions]
index_lists = [temp_index_lists[pos] for pos in order_positions]
score_lists = [temp_score_lists[pos] for pos in order_positions]

### Visualize the relevance scores for the highest and lowest scoring samples per category

In [None]:
# figsize=(8.27, 11.69) is A4 paper size -> lets make it a bit smaller

fig, axes = plt.subplots(nrows=4, ncols=2, figsize=(8, 10.95))
axes = axes.flatten()

# One panel per confusion category. Iterate over the axes directly instead of using an
# index named `k`, which would overwrite the number of cross-validation folds.
for ax, index_list, score_list, category in zip(axes, index_lists, score_lists, categories):
    if len(index_list) == 0:
        # No test sample of the selected split fell into this category
        ax.set_title(f"{category} - no samples", fontsize=12)
        ax.axis('off')
        continue

    idx = index_list[0]
    score = score_list[0]
    # Relevance scores of this sample
    rels_single_sample = rels[idx].squeeze()

    df_single_sample = pd.DataFrame({
        'Feature': feature_names,
        'Importance Score': rels_single_sample
    }).sort_values(by='Importance Score', ascending=True)

    # Collapse all zero-relevance features into a single '[...]' row to save space.
    # Only do so when zeros exist: idxmax() on an all-False mask returns the first
    # row, which would relabel a feature that does carry relevance.
    is_zero = df_single_sample['Importance Score'].eq(0)
    if is_zero.any():
        zero_index = is_zero.idxmax()  # label of the first zero-valued row
        df_single_sample = df_single_sample[~is_zero | (df_single_sample.index == zero_index)].copy()
        df_single_sample.loc[zero_index, 'Feature'] = '[...]'

    palette = custom_palette(df_single_sample['Importance Score'])

    # plot relevances for sample
    sns.barplot(x='Importance Score', y='Feature', data=df_single_sample, palette=palette, ax=ax)
    ax.set_title(f"{category} - Sample {idx} (Score: {score:.2f})", fontsize=12)
    ax.set_xlabel('Importance Score')
    ax.set_ylabel('Features')

plt.tight_layout(rect=[0, 0, 1, 0.97])
# save figure
if SAVE:
    plt.savefig(
        figure_path / "LRP_on_pos_class_samples_per_cat.png",
        dpi=300, facecolor=(1, 1, 1, 0),
        bbox_inches='tight')
# pyplot.show() takes no figure argument; it displays all open figures
plt.show()

## Generating catalyst candidates using LRP importances for neural network

In [None]:
# Read back the LRP score table that the resampling analysis saved as 'sorted_mean_lrp_NN.csv'
feature_importance_csv = os.path.join(storing_path, 'sorted_mean_lrp_NN.csv')
df_feature_importance_nn = pd.read_csv(feature_importance_csv)

In [None]:
# Weighting factors handed to the generator for elements and supports
el_factor = 20
supp_factor = 2
print(df_feature_importance_nn)

# Scores and the matching feature names, taken from the same table so they stay aligned
importance_scores = df_feature_importance_nn['Importance Score'].to_numpy()
importance_features = df_feature_importance_nn['Feature'].to_list()

candidates = generate_catalysts_from_relevance_scores(
    importance_scores,
    importance_features,
    num_candidates=1000,
    elem_importance_factor=el_factor,
    supp_importance_factor=supp_factor,
)

### Convert candidates to numpy features, remove duplicates

In [None]:
# Encode the candidate strings as feature vectors. `df_feature_importance_rebalanced` was
# never defined; `feature_names` is used because the encoded rows are fed to the trained
# model, compared against the dataset rows and decoded again with `feature_names` below,
# so they must follow that column order.
feats = catalyst_string_to_numpy(candidates,
                                 list(feature_names),
                                 remove_duplicates=True)

# classify with a single model to see the fraction of samples that would be classified as high-yield
# (`max_f1_model` is the model selected for the last split of the resampling run)
with torch.no_grad():  # inference only, no autograd graph needed
    logits = max_f1_model(torch.from_numpy(feats).float())
    pred_probab = nn.Softmax(dim=1)(logits)
y_pred = pred_probab.argmax(1)  # Get the predicted class
#print('prediction', y_pred)
print('fraction of samples predicted as true', float(torch.sum(y_pred) / y_pred.shape[0]))

### Remove duplicates that appear in the OCM dataset

In [None]:
# All catalysts of the dataset. Kept under its own name so the `X` returned by
# prepare_dataset (whose row order matches `y`) is not overwritten.
X_known = np.concatenate([X_pos, X_neg], axis=0)

# Bring both matrices to one common dtype: the row views below compare raw records,
# which is only meaningful when both sides share the same field type.
row_value_dtype = np.result_type(feats, X_known)
row_dtype = [('', row_value_dtype)] * feats.shape[1]

# View every row as a single structured record so rows can be compared as units
feats_view = np.ascontiguousarray(feats, dtype=row_value_dtype).view(row_dtype)
X_view = np.ascontiguousarray(X_known, dtype=row_value_dtype).view(row_dtype)

# Perform set difference on the rows (result is unique and sorted)
feats_diff = np.setdiff1d(feats_view, X_view)

# Convert back to the original array format
feats_new = feats_diff.view(row_value_dtype).reshape(-1, feats.shape[1])

print("Unique generated candidates that do not appear in training set", feats_new.shape[0])

### Select the top N neural network models from the set of best models of each split in order to evaluate catalyst candidates

In [None]:
N_models = 3
# Validation F1 of the winning model of every split
selected_model_scores = np.array([val_f1 for val_f1, _ in selected_models])
# Positions of the N_models highest scores, best first
best_scores_idx = np.argsort(-selected_model_scores)[:N_models]
top_entries = [selected_models[pos] for pos in best_scores_idx]
best_scores = [entry[0] for entry in top_entries]
best_models = [entry[1] for entry in top_entries]
print("Sorted scores of best models from each split", best_scores)

### Select the top 20 candidates based on the average scores from the NN models

In [None]:
# Outputs of every selected model for the new candidates, stacked along a leading model axis
candidate_inputs = torch.from_numpy(feats_new).float()
per_model_outputs = [max_model(candidate_inputs) for max_model in best_models]
all_logits = torch.stack(per_model_outputs, dim=0)

# Average over the models before converting to class probabilities
logits = all_logits.mean(dim=0)

class_softmax = nn.Softmax(dim=1)
pred_probab = class_softmax(logits)
all_pred_probab = nn.Softmax(dim=-1)(all_logits)

# Rank candidates by the averaged probability of class 1, highest first, and keep 20
positive_probab = pred_probab[:, 1].numpy(force=True)
top_probab_idx = np.argsort(-positive_probab)
top_prob_feats = feats_new[top_probab_idx[:20]]

cand_new = numpy_to_catalyst_string(top_prob_feats, feature_names)
print('List of top 20 generated promising catalyst candidates', cand_new)