In [None]:
import constants as c

def plot_heatmaps(dfs, mod_figs, func, only_modded=False, title_addition="",
                  default_figsize=(6, 6)):
    """Helper function to plot heatmaps with individualised figure sizes.

    ``func`` is called once per dataframe as
    ``func([df], figsize=..., font_scale=1, title_addition=title_addition)``.

    :param dfs: The list of dataframes to plot; the user of each plot is read
                from the first row of its ``user`` column.
    :param mod_figs: A dictionary mapping users to the figure size to use for
                     their plot.
    :param func: The heatmap plot function to call.
    :param only_modded: Only plot users that appear in ``mod_figs``.
    :param title_addition: Additional phrase to add at the end of the title of
                           the graphs.
    :param default_figsize: Figure size for users not listed in ``mod_figs``.
    """
    for df in dfs:
        user = df.user.iloc[0]

        if only_modded and user not in mod_figs:
            continue

        figsize = mod_figs.get(user, default_figsize)
        func([df], figsize=figsize, font_scale=1, title_addition=title_addition)
        
    
def get_counts(data, display_names=None):
    """Count transitions between consecutive display names, per user.

    The dataframes in ``data`` are concatenated in order. Every pair of
    adjacent rows that belongs to the same ``user`` and has different
    ``display_name`` values adds one to ``vals[i][j]``. Here ``i``/``j`` are
    the positions of the first/second name in the returned name list. Adjacent
    rows are taken from the concatenated frame, so the last row of one
    dataframe and the first row of the next are paired if their users match.

    :param data: The list of dataframes (with ``user`` and ``display_name``
                 columns) to count over.
    :param display_names: Optional name ordering to use. It is copied and
                          never modified. Names present in the data but
                          missing from it are appended in order of first
                          appearance. When omitted or empty, the data's own
                          names are used, ordered by ``sort_resources``.
    :return: ``(display_names, vals)`` as produced by ``get_2d_vals``.
    """
    aggregated = pd.concat(data, sort=False)
    data_display_names = list(aggregated.display_name.unique())

    if display_names:
        # Work on a copy: callers (e.g. get_counts_all_resources) may pass the
        # shared module-level ``resource_order`` list, which must not grow.
        known = set(display_names)
        display_names = list(display_names) + [
            name for name in data_display_names if name not in known]
    else:
        display_names = data_display_names
        sort_resources(display_names)

    # Plain lists avoid a slow ``.iloc`` lookup for every row.
    names = aggregated.display_name.tolist()
    users = aggregated.user.tolist()

    display_name_index_pairs = [
        (display_names.index(n1), display_names.index(n2))
        for u1, u2, n1, n2 in zip(users, users[1:], names, names[1:])
        if u1 == u2 and n1 != n2
    ]

    # Each qualifying transition contributes a count of exactly one.
    return get_2d_vals(data, display_names, lambda d, i, j: 1,
                       display_name_index_pairs)


def get_counts_all_resources(data):
    """Count display-name transitions using the full ``resource_order`` ordering.

    A copy of ``resource_order`` is passed so that ``get_counts`` appending
    names found in ``data`` can never modify the shared module-level list.

    :param data: The list of dataframes to count over.
    :return: Whatever ``get_counts`` returns for this ordering.
    """
    return get_counts(data, list(resource_order))

def sort_resources(resources):
    """Reorder ``resources`` in place to follow ``resource_order``.

    Each name is ranked by its first position in ``resource_order``. Names
    that do not appear there rank as infinity, so they move to the end, and
    because ``list.sort`` is stable they keep their relative order.

    :param resources: The list of resource names to sort in place.
    """
    first_position = {}
    for position, name in enumerate(resource_order):
        # setdefault keeps the earliest index, matching list.index semantics.
        first_position.setdefault(name, position)

    resources.sort(key=lambda name: first_position.get(name, float('inf')))
    
def get_2d_vals(data, display_names, func, loop_pairs):
    """Accumulate ``func`` over index pairs into a square matrix.

    :param data: Passed through unchanged to ``func``.
    :param display_names: The row/column names; their count sets the matrix size.
    :param func: Called as ``func(data, i, j)`` for each pair; its result is
                 added to cell ``[i, j]``.
    :param loop_pairs: Iterable of ``(i, j)`` index pairs; repeats accumulate.
    :return: ``(display_names, vals)`` where ``vals`` is a float ndarray of
             shape ``(len(display_names), len(display_names))``.
    """
    size = len(display_names)
    matrix = np.zeros((size, size))

    for row, col in loop_pairs:
        matrix[row, col] += func(data, row, col)

    return display_names, matrix

# TODO update get_title to something more sensible
def get_title(title_format, title_name, title_addition, df_user, unit=""):
    title_user = title_name if title_name else df_user
    addition = " " + title_addition if title_addition else ""
    
    return title_format.format(title_user, unit, addition)

def plot_2d_values_heatmap(data, *, func, xlabel, ylabel, figsize=(10, 10),
                           font_scale=1, title_name="",
                           title_addition="", unit="Traversal",
                           transpose=False, quantile=True,
                           fig_size_inches=None, dpi=None,
                           transpose_labels=False, no_title=False,
                           tick_label_colors=None,
                           width_ratios=None, vmax=None,
                           n_ticks=c.N_TICKS, int_tick_labels=False,
                           save_name="Final Traversals"):
    """Plot the values of a 2 dimensional array based on the given function.

    :param data: The data to plot; passed to ``func``. When a title is drawn
                 and ``title_name`` is empty, the user is read from
                 ``data[0].user.iloc[0]``.
    :param func: Called as ``func(data)``; must return ``(names, vals)`` where
                 ``vals`` is a square array indexed like ``names``.
    :param xlabel: The x axis label (swapped with ``ylabel`` when
                   ``transpose_labels`` is set).
    :param ylabel: The y axis label.
    :param figsize: Figure size passed to ``plt.figure``.
    :param font_scale: Passed to ``sns.set``.
    :param title_name: Custom user for the title of the graph.
                       Leave as empty for single user dataframes.
                       Useful for when passing multiple dataframes.
    :param title_addition: Additional phrase to add at the end of the title.
    :param unit: Name of the counted unit, used in the title and colourbar label.
    :param transpose: Whether to transpose the graph.
    :param quantile: Label the colourbar with percentages of the scale maximum
                     instead of raw counts.
    :param fig_size_inches: If given, overrides the figure size via
                            ``fig.set_size_inches``.
    :param dpi: Passed to ``save_plot``.
    :param transpose_labels: Move x tick labels to the top (rotated 90 degrees)
                             and swap the axis labels. No title is set in
                             this mode.
    :param no_title: Do not set a title.
    :param tick_label_colors: If given, a second axes is added next to the
                              heatmap and passed, with all tick labels, to
                              ``set_tick_colors``; the colour scale minimum is
                              fixed at 0.
    :param width_ratios: GridSpec width ratios (only used with
                         ``tick_label_colors``).
    :param vmax: Optional colour scale maximum. In quantile mode the top
                 label gets a '+' suffix.
    :param n_ticks: Number of ticks for index tick labels and the colourbar.
    :param int_tick_labels: Label ``n_ticks`` evenly spaced rows/columns by
                            integer index instead of every row by name.
    :param save_name: Name passed to ``save_plot`` for the saved figure.
    """
    sns.set(font_scale=font_scale)
    fig = plt.figure(figsize=figsize)

    if fig_size_inches:
        fig.set_size_inches(fig_size_inches)

    names, vals = func(data)
    n_names = len(names)

    df = pd.DataFrame(data=vals, index=names, columns=names)

    if transpose:
        df = df.transpose()

    if quantile:
        cbar_label = 'Quantile of {}s'.format(unit)
    else:
        cbar_label = '{} Occurrences'.format(unit)

    if tick_label_colors:
        grid = plt.GridSpec(1, 2, width_ratios=width_ratios, hspace=0.0)

        main_ax = fig.add_subplot(grid[0, 0])
        tick_label_cbar_ax = fig.add_subplot(grid[0, 1])
        vmin = 0
    else:
        main_ax = None
        vmin = None

    if int_tick_labels:
        tick_labels = np.linspace(0, n_names - (n_names / n_ticks), n_ticks, dtype=int)
    else:
        tick_labels = True

    ax = sns.heatmap(df, xticklabels=tick_labels, yticklabels=tick_labels,
                     cmap=CMAP, cbar_kws={'label': cbar_label},
                     vmin=vmin, vmax=vmax, ax=main_ax)

    if main_ax is None:
        main_ax = ax

    if int_tick_labels:
        # Put the ticks at the same row/column indices used as their labels.
        main_ax.xaxis.set_ticks(tick_labels)
        main_ax.yaxis.set_ticks(tick_labels)

    if no_title:
        title = ""
    else:
        title = get_title("{} {} {}", title_name, title_addition,
                          data[0].user.iloc[0], unit=unit)

    if tick_label_colors:
        # For some reason get_ticklabels can't be called here, so gather the
        # labels of both axes explicitly.
        tick_texts = main_ax.get_yticklabels(which='both')
        tick_texts.extend(main_ax.get_xticklabels(which='both'))

        set_tick_colors(tick_texts, tick_label_colors, cax=tick_label_cbar_ax)

    if transpose_labels:
        main_ax.xaxis.tick_top()
        main_ax.xaxis.set_label_position('top')
        main_ax.tick_params(axis='x', labelrotation=90)
        main_ax.xaxis.set_tick_params(which="both", length=0)

        # NOTE(review): no title is applied in this branch, possibly because
        # the x tick labels now sit at the top; confirm this is intended.
        main_ax.set(ylabel=xlabel, xlabel=ylabel)
    else:
        main_ax.set(title=title, ylabel=ylabel, xlabel=xlabel)

    cbar = main_ax.collections[0].colorbar

    max_vals = int(np.amax(vals))
    m = vmax if vmax else max_vals

    if quantile:
        tick_values, ticklabels = _quantile_cbar_ticks(vals, m, max_vals, n_ticks, vmax)
    else:
        ticks = ticks_and_texts(m)

        if not vmax:
            ticks[1][-1] = ticks[1][-1][:-1]  # Remove plus sign

        tick_values, ticklabels = ticks[0], ticks[1]

    cbar.set_ticks(tick_values)
    cbar.set_ticklabels(ticklabels)

    save_plot(save_name, dpi=dpi)


def _quantile_cbar_ticks(vals, m, max_vals, n_ticks, vmax):
    """Return ``(tick_positions, tick_labels)`` for a percentage colourbar.

    :param vals: The plotted 2D array of values.
    :param m: The value at the top of the colour scale (``vmax`` or the data max).
    :param max_vals: The data maximum, truncated to an int.
    :param n_ticks: Requested tick count, capped at the number of distinct values.
    :param vmax: The caller-supplied colour scale maximum, if any.
    """
    if len(vals) == 1:
        return [m], ['100%']

    # Never ask for more ticks than there are distinct values in the data.
    n_ticks = min(n_ticks, len(np.unique(vals)))

    # Float positions, so each tick sits exactly where its percentage claims;
    # truncating to int would misplace ticks when m is not divisible.
    tick_values = np.linspace(0, m, n_ticks)

    if vmax and max_vals:
        # The scale is clipped at vmax: the top label shows what percentage of
        # the data maximum vmax represents, with '+' marking the clipped range.
        quantile_max = int((m / max_vals) * 100)
        quantiles = np.linspace(0, quantile_max, n_ticks, dtype=int)
        ticklabels = ['{}%'.format(q) + ('+' if q >= quantile_max else '')
                      for q in quantiles]
    else:
        ticklabels = ['{:.0f}%'.format(q) for q in np.linspace(0, 100, n_ticks)]

    return tick_values, ticklabels
    