## Visualize the predictions

This Jupyter notebook visualizes the model's predictions on the test data; in particular, it generates a histogram of the predictions for each ground-truth position.

First, let's import the required packages.

In [1]:
import pandas as pd
import numpy as np
import os
import matplotlib.pyplot as plt

In [2]:
def make_histogram(input_file, save_folder, layer_num):
    """Save one histogram figure per distinct ground-truth position.

    Reads a CSV containing ``Y_True`` and ``Y_Prediction`` columns and groups
    the rows by their ``Y_True`` value. For each group it draws the
    ground-truth value (a single-bin histogram) overlaid with a 10-bin
    histogram of the corresponding predictions. It saves the figure as
    ``NN_<layer_num>_<rounded Y_True>.png`` in ``save_folder``.

    Parameters
    ----------
    input_file : str
        Path to the prediction CSV file.
    save_folder : str
        Directory the PNG files are written to; created if it does not exist.
    layer_num : int
        Layer count of the network, used in the figure title and file name.

    Returns
    -------
    list of str
        Paths of the saved figures, in ascending order of ``Y_True``.
        Rows whose ``Y_True`` is NaN are skipped.
    """
    df = pd.read_csv(input_file)
    os.makedirs(save_folder, exist_ok=True)

    saved_files = []
    # groupby splits the frame in one pass instead of re-filtering the whole
    # frame once per unique value; it also drops NaN keys, which would
    # otherwise yield an empty "nan" figure.
    for cur_y_true, cur_df in df.groupby('Y_True'):
        rounded = np.round(cur_y_true, 3)
        fig_title = 'The histogram of prediction on {} (NN-{})'.format(rounded, layer_num)
        # NOTE(review): distinct Y_True values that round to the same 3 decimals
        # map to the same file name and overwrite each other.
        fig_name = os.path.join(save_folder, 'NN_{}_{}.png'.format(layer_num, rounded))

        fig, ax = plt.subplots()
        try:
            #----- draw histogram of Y_True and Y_Prediction
            ax.hist(cur_df['Y_True'].values, bins=1, label='Ground truth', alpha=0.5)
            ax.hist(cur_df['Y_Prediction'].values, bins=10, label='Predictions', alpha=0.5)
            ax.set_title(fig_title)
            ax.set_xlabel('Position')
            ax.set_ylabel('Sample number')
            ax.legend()
            fig.savefig(fig_name, dpi=300)
        finally:
            # Close this specific figure even if drawing/saving raises, so
            # figures do not accumulate in pyplot's global registry.
            plt.close(fig)
        saved_files.append(fig_name)
    return saved_files

Next, we run the method on the test-set predictions of each model configuration.

In [3]:
if __name__ == '__main__':

    #------------ set up parameters
    # network configurations to visualize: number of layers x hidden dimension
    layer_num = [2, 5, 10]
    hidden_dim = [32]
    # directory holding the test-set prediction CSVs, relative to the
    # current working directory
    root_dir = os.path.join('../', 'src', 'DNN_Regression', 'deepnn_results', 'test_results')

    save_folder = 'hist_results'
    # exist_ok replaces the exists()-then-makedirs() check
    os.makedirs(save_folder, exist_ok=True)

    for cur_layer in layer_num:
        for cur_hd in hidden_dim:
            # predictions on the test set for this (layers, hidden dim) model
            tst_pred_file = os.path.join(
                root_dir, 'hos_test_prediction_L{}_H{}.csv'.format(cur_layer, cur_hd))
            # NOTE(review): output file names encode only the layer count, so
            # adding more entries to hidden_dim would overwrite figures.
            make_histogram(tst_pred_file, save_folder, cur_layer)
    print('')
    print('>>>Congrats! Figures have been saved!')


>>>Congrats! Figures have been saved!
