In [None]:
# This notebook generates bar plots of the evaluation metrics for each group listed in the `groups_eval` variable.


In [None]:
# Metrics to plot, keyed by (metric name, prediction type). Each value holds
# the y positions of the dashed horizontal reference lines drawn on that
# metric's panel; an empty list means no reference line is drawn.
basic_metrics = {
    ('wtkappa', 'trim'): [0.7],
    ('corr', 'trim'): [0.7],
    ('DSM', 'trim_round'): [0.1, -0.1],
    ('DSM', 'trim'): [0.1, -0.1],
    ('R2', 'trim'): [],
    ('RMSE', 'trim'): [],
}

# Build the column names looked up in the per-group evaluation table:
# "<metric>.<raw|scale>_<prediction type>".
colprefix = 'raw' if not use_scaled_predictions else 'scale'
metrics = {
    '{}.{}_{}'.format(metric_name, colprefix, score_type): reference_lines
    for (metric_name, score_type), reference_lines in basic_metrics.items()
}
num_metrics = len(metrics)

# For every grouping variable, read the per-group evaluation table, optionally
# drop groups that are too small, and draw one bar-plot panel per metric.
for group in groups_eval:
    display(Markdown('### Evaluation by {}'.format(group)))

    eval_group_file = join(output_dir, '{}_eval_by_{}.{}'.format(experiment_id, group, file_format))
    df_eval_group_all = DataReader.read_from_file(eval_group_file, index_col=0)

    # Turn the index (the group levels) into a regular column named after the group.
    df_eval_group_all.index.name = group
    df_eval_group_all = df_eval_group_all.reset_index()

    # If we have a threshold for this group, apply it now. Keep "All data" in any case.
    has_threshold = group in min_n_per_group
    if has_threshold:
        min_n = min_n_per_group[group]
        display(Markdown("The report only shows the results for groups with "
                         "at least {} responses in the evaluation set.".format(min_n)))

        keep_row = ((df_eval_group_all['N'] >= min_n) |
                    (df_eval_group_all[group] == 'All data'))
        df_eval_group = df_eval_group_all[keep_row].copy()
    else:
        df_eval_group = df_eval_group_all.copy()

    # Every level except the 'All data' summary row.
    group_levels = [level for level in df_eval_group[group] if level != 'All data']

    # We only want to show the plots if we have anything other than 'All data'.
    if not group_levels:
        if has_threshold:
            display(Markdown("None of the groups in {} had {} or more responses.".format(group,
                                                                                        min_n)))
        else:
            # No threshold is configured for this group, so there is no minimum
            # to report; looking one up here used to raise a KeyError.
            display(Markdown("No groups other than 'All data' were found for {}.".format(group)))
        continue

    # Define the order of the bars: put 'All data' first and 'No info' last.
    if 'No info' in group_levels:
        bar_names = ['All data'] + [level for level in group_levels if level != 'No info'] + ['No info']
    else:
        bar_names = ['All data'] + group_levels

    fig = plt.figure()
    (figure_width,
     figure_height,
     num_rows,
     num_columns,
     wrapped_bar_names) = compute_subgroup_plot_params(bar_names, num_metrics)

    fig.set_size_inches(figure_width, figure_height)
    with sns.axes_style('white'), sns.plotting_context('notebook', font_scale=1.2):
        # Sort the metric columns so that the panel order is deterministic.
        for i, metric in enumerate(sorted(metrics)):
            df_plot = df_eval_group[[group, metric]]
            ax = fig.add_subplot(num_rows, num_columns, i + 1)

            # Dashed reference lines configured for this metric.
            for lineval in metrics[metric]:
                ax.axhline(y=float(lineval), linestyle='--', linewidth=0.5, color='black')

            sns.barplot(x=df_plot[group], y=df_plot[metric], color='grey', ax=ax, order=bar_names)
            ax.set_xticklabels(wrapped_bar_names, rotation=90)
            ax.set_xlabel('')
            ax.set_ylabel('')

            # Set the y-limits of the plots appropriately; whenever a metric has
            # negative values, also draw a dashed line at zero.
            if metric.startswith('corr') or metric.startswith('wtkappa'):
                if df_plot[metric].min() < 0:
                    y_limits = (-1.0, 1.0)
                    ax.axhline(y=0.0, linestyle='--', linewidth=0.5, color='black')
                else:
                    y_limits = (0.0, 1.0)
                ax.set_ylim(y_limits)
            elif metric.startswith('R2'):
                min_value = df_plot[metric].min()
                if min_value < 0:
                    # Leave a small margin below the lowest bar.
                    y_limits = (min_value - 0.1, 1.0)
                    ax.axhline(y=0.0, linestyle='--', linewidth=0.5, color='black')
                else:
                    y_limits = (0.0, 1.0)
                ax.set_ylim(y_limits)
            elif metric.startswith('RMSE'):
                # The upper limit is at least 1.0, with a small margin above the tallest bar.
                max_value = df_plot[metric].max()
                y_limits = (0.0, max(max_value + 0.1, 1.0))
                ax.set_ylim(y_limits)
            elif metric.startswith('DSM'):
                # The y-limits are left to matplotlib for DSM; only the zero line is added.
                if df_plot[metric].min() < 0:
                    ax.axhline(y=0.0, linestyle='--', linewidth=0.5, color='black')

            # Set the title.
            ax.set_title('{} by {}'.format(metric, group))

    # Suppress any warnings emitted while tightening the layout.
    with warnings.catch_warnings():
        warnings.simplefilter('ignore')
        plt.tight_layout(h_pad=1.0)

    imgfile = join(figure_dir, '{}_eval_by_{}.svg'.format(experiment_id, group))
    plt.savefig(imgfile)

    if use_thumbnails:
        show_thumbnail(imgfile, next(id_generator))
    else:
        plt.show()