In [None]:
# Only emit the section heading and its introduction when at least one
# subgroup variable was requested; otherwise the section stays empty.
if len(groups_desc) > 0:
    markdown_str = [
        "## Differential feature functioning",
        "This section shows differential feature functioning (DFF) plots "
        "for all features and subgroups. The features are shown after applying "
        "transformations (if applicable) and truncation of outliers.",
    ]
    display(Markdown('\n'.join(markdown_str)))

In [None]:
# Reuse the merged features+metadata frame if an earlier notebook in this
# kernel session already defined it. Merely referencing the name raises
# NameError when it does not exist yet, in which case it is built here.
try:
    df_train_preproc_merged
except NameError:
    # join the preprocessed features with the metadata on the response ID
    df_train_preproc_merged = df_train_preproc.merge(df_train_metadata,
                                                     on='spkitemid')

# For every subgroup variable, draw a grid of point plots (one panel per
# feature) of the feature `value` against the score `sc1`, with one hue level
# per subgroup category. Each figure is saved as SVG under `figure_dir` and
# then shown either as a thumbnail or inline.
for group in groups_desc:
    display(Markdown("### DFF by {}".format(group)))

    # A minimum group size is optional. Remember whether one was configured so
    # that the "no data" message at the bottom never looks up a missing key.
    has_min_n = group in min_n_per_group

    if has_min_n:
        display(Markdown("The report only shows the results for groups with "
                         "at least {} responses in the training set.".format(min_n_per_group[group])))

        # keep only the categories that meet the configured minimum size
        category_counts = df_train_preproc_merged[group].value_counts()
        selected_categories = category_counts[category_counts >= min_n_per_group[group]].index

        df_train_preproc_selected = df_train_preproc_merged[df_train_preproc_merged[group].isin(selected_categories)].copy()
    else:
        df_train_preproc_selected = df_train_preproc_merged.copy()

    if len(df_train_preproc_selected) > 0:

        # we need to reduce col_wrap and increase width if the feature names are too long
        if longest_feature_name > 10:
            col_wrap = 2
            # adjust height to allow for wrapping really long names. We allow 0.25 in per line
            height = 2+(math.ceil(longest_feature_name/30)*0.25)
            aspect = 5/height
            # show legend near the second plot in the grid
            plot_with_legend = 1
        else:
            height = 3
            col_wrap = 3
            aspect = 1
            # show the legend near the third plot in the grid
            plot_with_legend = 2

        # reshape to long format: one row per (response, feature) pair
        selected_columns = ['spkitemid', 'sc1'] + features_used + [group]
        df_melted = pd.melt(df_train_preproc_selected[selected_columns],
                            id_vars=['spkitemid', 'sc1', group],
                            var_name='feature')
        group_values = sorted(df_melted[group].unique())
        colors = sns.color_palette("Greys", len(group_values))
        with sns.axes_style('whitegrid'), sns.plotting_context('notebook', font_scale=1.2):
            # NOTE(review): `scale` is reportedly deprecated for point plots in
            # recent seaborn releases -- confirm against the pinned version.
            p = sns.catplot(x='sc1', y='value', hue=group, hue_order=group_values,
                            col='feature', col_wrap=col_wrap, height=height, aspect=aspect,
                            scale=0.6,
                            palette=colors,
                            sharey=False, sharex=False, legend=False, kind="point",
                            data=df_melted)

            # When there are fewer panels than the preferred legend position,
            # fall back to the last panel so that a legend is always drawn.
            legend_index = min(plot_with_legend, len(p.axes) - 1)

            for i, axis in enumerate(p.axes):
                axis.set_xlabel('score')
                if i == legend_index:
                    legend = axis.legend(group_values, title=group,
                                         frameon=True, fancybox=True,
                                         ncol=1, fontsize=10,
                                         loc='upper right', bbox_to_anchor=(1.75, 1))
                    # Matplotlib renamed `legendHandles` to `legend_handles`;
                    # prefer the new attribute and fall back to the old one so
                    # the cell works with both older and newer releases.
                    handles = getattr(legend, 'legend_handles', None)
                    if handles is None:
                        handles = legend.legendHandles
                    # recolor the handles so they match the hue palette
                    for handle, color in zip(handles, colors):
                        handle.set_color(color)
                    plt.setp(legend.get_title(), fontsize='x-small')

            # wrap long feature names so that panel titles do not overlap
            for ax, cname in zip(p.axes, p.col_names):
                ax.set_title('\n'.join(wrap(str(cname), 30)))

            # tight_layout may warn about the legend placed outside the axes
            with warnings.catch_warnings():
                warnings.simplefilter('ignore')
                plt.tight_layout(h_pad=1.0)

            imgfile = join(figure_dir, '{}_dff_{}.svg'.format(experiment_id, group))
            plt.savefig(imgfile)
            if use_thumbnails:
                show_thumbnail(imgfile, next(id_generator))
            else:
                plt.show()
    elif has_min_n:
        display(Markdown("None of the groups in {} had {} or more responses.".format(group,
                                                                                    min_n_per_group[group])))
    else:
        # no threshold was configured, so the merged data itself is empty
        display(Markdown("No responses are available for {}.".format(group)))