In [None]:
from glasbey.glasbey import Glasbey
from matplotlib.colors import LinearSegmentedColormap

def get_knowledge_types_used(nodes, knowledge_types_dict):
    """Return, for each node, the sorted list of knowledge-type columns it uses.

    The candidate columns are the keys of the first entry in
    ``knowledge_types_dict``. A column counts as "used" by a node unless its
    value is a float NaN.

    Args:
        nodes: iterable of keys to look up in ``knowledge_types_dict``.
        knowledge_types_dict: mapping node -> mapping column -> value.

    Returns:
        A list with one entry per node, in ``nodes`` order. Each entry is a
        sorted list of column names, or ``[]`` when the node is missing (or
        its entry lacks one of the columns).
    """
    if not knowledge_types_dict:
        # Every lookup would miss; also avoids StopIteration from next(iter(...)).
        return [[] for _ in nodes]

    columns = next(iter(knowledge_types_dict.values())).keys()

    combs_used = []
    for n in nodes:
        try:
            node_categories = knowledge_types_dict[n]
            used = sorted(
                c for c in columns
                if not (isinstance(node_categories[c], float)
                        and np.isnan(node_categories[c]))
            )
        except KeyError:
            # Node absent from the dict, or its entry is missing a column.
            used = []
        combs_used.append(used)

    return combs_used

def get_knowledge_types_used_single(node, knowledge_types_dict, columns=None):
    """Return the sorted, cleaned knowledge-type columns used by one node.

    A column counts as "used" unless the node's value for it is a float NaN.
    Used column names are cleaned by removing '[', ']', newlines and commas,
    and are then lower-cased.

    Args:
        node: key to look up in ``knowledge_types_dict``.
        knowledge_types_dict: mapping node -> mapping column -> value.
        columns: columns to inspect. If falsy, the keys of the first entry in
            ``knowledge_types_dict`` are used.

    Returns:
        Sorted list of cleaned column names, or ``[]`` when the node (or one
        of the inspected columns) is missing, or the dict is empty.
    """
    if not columns:
        if not knowledge_types_dict:
            # Nothing to inspect; also avoids StopIteration from next(iter(...)).
            return []
        columns = next(iter(knowledge_types_dict.values())).keys()

    # Deletion table for the characters stripped from column names.
    strip_table = str.maketrans('', '', '[]\n,')

    try:
        node_categories = knowledge_types_dict[node]
        non_nan_categories = []

        for c in columns:
            value = node_categories[c]
            if isinstance(value, float) and np.isnan(value):
                continue
            # str methods return new strings: the cleaned label must be kept
            # (the previous c.replace(...) discarded its result).
            non_nan_categories.append(c.translate(strip_table).lower())

        return sorted(non_nan_categories)

    except KeyError:
        return []

    
def get_knowledge_type_categories(types):
    """Map each knowledge-type label to its broad category.

    The categories are tested in the order 'procedural', 'conceptual',
    'factual'; the first one contained in the label wins. Labels matching
    none of them are reported with a printed message and skipped.

    Args:
        types: iterable of labels (anything supporting ``in``).

    Returns:
        List of category names for the recognised labels, in input order.
    """
    known_kinds = ('procedural', 'conceptual', 'factual')
    categories = []

    for label in types:
        kind = next((k for k in known_kinds if k in label), None)
        if kind is None:
            print('{} is an invalid knowledge type!'.format(label))
            continue
        categories.append(kind)

    return categories

def get_distinct_cmap(n, lightness_range=(50, 100)):
    """Return a LinearSegmentedColormap with ``n`` distinct Glasbey colours.

    The palette is cached as a file in COLOR_PALETTES named
    "<n>_<lightness_range>_colors.txt" (the tuple's repr is embedded in the
    name). If that file exists, it is used as the Glasbey base palette.
    Otherwise a palette is generated from DEFAULT_PALETTE with the given
    ``lightness_range`` and saved to that path.

    RGB channel values are divided by MAX_RGB before building the colormap.
    """
    palette_path = os.path.join(COLOR_PALETTES, "{}_{}_colors.txt".format(n, lightness_range))
    cached = os.path.exists(palette_path)

    if cached:
        generator = Glasbey(base_palette=palette_path, no_black=True)
    else:
        generator = Glasbey(base_palette=DEFAULT_PALETTE, no_black=True,
                            lightness_range=lightness_range)

    palette = generator.generate_palette(size=n)

    if not cached:
        # Persist the freshly generated palette so later calls take the cached branch.
        generator.save_palette(palette, palette_path)

    rgb_colors = [[channel / MAX_RGB for channel in rgb]
                  for rgb in Glasbey.convert_palette_to_rgb(palette)]

    return LinearSegmentedColormap.from_list('distinct ' + str(n), rgb_colors, N=n)

def set_tick_colors(ticks, discrete_labels, cax=None, cbar_label="Knowledge Type"):
    """Colour tick labels by their knowledge-type combination and add a colorbar.

    Colours come from get_colors_from_dict on the ticks' text, using a
    lightness range of (0, 15). If no colours are returned, nothing is
    changed. Otherwise each tick is recoloured and a discrete colorbar
    labelled ``cbar_label`` is drawn (into ``cax`` if given).
    """
    tick_texts = [tick.get_text() for tick in ticks]
    colors, combos, cmap = get_colors_from_dict(tick_texts, discrete_labels,
                                                lightness_range=(0, 15))

    if colors:
        for tick, color in zip(ticks, colors):
            tick.set_color(color)
        draw_discrete_colorbar_from_cmap(combos, cmap, cbar_label, cax=cax)

def get_colors_from_dict(nodes, color_reps_dict, lightness_range=(50, 100)):
    """Assign each node a colour according to its knowledge-type combination.

    Args:
        nodes: node keys to look up in ``color_reps_dict``.
        color_reps_dict: mapping node -> mapping column -> value, as accepted
            by get_knowledge_types_used.
        lightness_range: forwarded to get_distinct_cmap.

    Returns:
        ``(colors, unique_combs, cmap)``, where ``colors`` has one RGBA colour
        per node, ``unique_combs`` is the sorted list of distinct combinations
        (as tuples; its index is the colormap index), and ``cmap`` is the
        colormap. Returns ``(None, None, None)`` when there are fewer than two
        distinct combinations.
    """
    combs_used = get_knowledge_types_used(nodes, color_reps_dict)

    # Sorted instead of bare set order: str hashing is randomised per process,
    # so set iteration order (and hence colour assignment and colorbar label
    # order) would otherwise differ between runs.
    unique_combs = sorted({tuple(c) for c in combs_used})

    # One combination carries no information worth colouring; zero would ask
    # get_distinct_cmap for an empty palette.
    if len(unique_combs) <= 1:
        return None, None, None

    cmap = get_distinct_cmap(len(unique_combs), lightness_range=lightness_range)

    # Dict lookup instead of list.index() for each node.
    index_of = {comb: i for i, comb in enumerate(unique_combs)}
    colors = [cmap(index_of[tuple(c)]) for c in combs_used]

    return colors, unique_combs, cmap

def draw_discrete_colorbar_from_cmap(labels, cmap, cbar_label, cax=None):
    """Draw a colorbar with one integer tick per entry in ``labels``.

    Args:
        labels: sequence of label groups. Tick text is produced by
            list_as_pretty.
        cmap: colormap to display. It presumably has ``len(labels)`` entries
            (as produced by get_distinct_cmap) — confirm for other callers.
        cbar_label: y-axis label of the colorbar.
        cax: optional Axes to draw the colorbar into.
    """
    # Limits at -0.5 .. len-0.5 put integer ticks 0..len-1 at the centres of
    # equal-width bands when the colormap has len(labels) colours.
    norm_nodes = mpl.colors.Normalize(vmin=-0.5, vmax=len(labels) - 0.5)
    sm_nodes = plt.cm.ScalarMappable(cmap=cmap, norm=norm_nodes)
    sm_nodes.set_array([])

    ticks_nodes = range(len(labels))

    if cax:
        cbar_nodes = plt.colorbar(sm_nodes, ticks=ticks_nodes, cax=cax)
    else:
        # sm_nodes is not attached to any Axes, so Matplotlib cannot infer
        # which Axes to steal space from. Newer versions raise instead of
        # silently falling back to gca(), so pass the current Axes explicitly.
        cbar_nodes = plt.colorbar(sm_nodes, ticks=ticks_nodes, ax=plt.gca())

    cbar_nodes.set_ticklabels(list_as_pretty(labels))

    cbar_nodes.ax.set_ylabel(cbar_label)
    

def list_as_pretty(l):
    """Render each group of labels as a single comma-separated string.

    Args:
        l: iterable of label groups (e.g. tuples of strings).

    Returns:
        List of strings, one per group. Each element is stripped of
        surrounding whitespace and the elements are joined with ", ". An empty
        (falsy) group becomes "None".
    """
    # str.join replaces the manual index loop and quadratic string +=.
    return [", ".join(part.strip() for part in group) if group else "None"
            for group in l]

def ticks_and_texts(vmax, n_ticks=5):
    """Return ``n_ticks`` evenly spaced tick values from 0 to ``vmax`` and their labels.

    Labels are ``str()`` of each value. Any value >= ``vmax`` (the last tick)
    gets a trailing "+", presumably to mark the top of a clipped colour scale.

    Returns:
        ``(tick_values, tick_texts)``: a numpy array and a list of strings.
    """
    values = np.linspace(0, vmax, n_ticks)

    labels = []
    for value in values:
        label = str(value)
        if value >= vmax:
            label += "+"
        labels.append(label)

    return values, labels

def set_red_text_for_workshops(ticklabels):
    """Colour red every tick label whose parsed date is in WORKSHOPS.

    Each label's text is converted with get_date and tested for membership
    in WORKSHOPS. Matching labels are recoloured in place.
    """
    workshop_labels = (label for label in ticklabels
                       if get_date(label.get_text()) in WORKSHOPS)
    for label in workshop_labels:
        label.set_color("red")
            

def plot_broken_y_bar(df, *, lims, xlabel, ylabel, ylabel_loc, figsize, breakline_len, width=0.8, align='center'):
    """Plot ``df`` as a bar chart with a broken (split) y-axis.

    Two vertically stacked Axes share the x-axis and show the same bars. The
    lower Axes is limited to ``lims[0]`` and the upper to ``lims[1]``.
    Diagonal marks on the left spine indicate the break. The figure is then
    passed to save_or_display with the title "<xlabel> vs <ylabel>".

    Args:
        df: object plotted via ``df.plot(kind='bar', ...)`` on both Axes.
        lims: pair of y-limits; ``lims[0]`` for the lower Axes, ``lims[1]``
            for the upper Axes.
        xlabel: x-axis label, also used in the title.
        ylabel: y-axis label, also used in the title.
        ylabel_loc: (x, y) in the lower Axes' data coordinates where the
            rotated y-label text is placed.
        figsize: forwarded to ``plt.subplots``.
        breakline_len: half-extent of the diagonal break marks, in Axes
            fraction units.
        width: bar width, forwarded to ``df.plot``.
        align: bar alignment, forwarded to ``df.plot``.
    """
    # Create two subplots to simulate a vertical axis break
    fig, (ax1, ax2) = plt.subplots(2, 1, sharex=True, figsize=figsize)

    plt.subplots_adjust(hspace=0.1)

    ax2.set_ylim(lims[0])
    ax1.set_ylim(lims[1])

    for ax in [ax1, ax2]:
        df.plot(kind='bar', ax=ax, width=width, align=align, ec='white')

        # Remove the background
        ax.patch.set_visible(False)

        # Set left spines to be visible
        ax.spines['left'].set_color('black')

        ax.spines['top'].set_visible(False)
        ax.spines['right'].set_visible(False)

        # Remove the legend if one was drawn (e.g. a Series plot has none,
        # and get_legend() then returns None).
        legend = ax.get_legend()
        if legend is not None:
            legend.remove()

    ax1.spines['bottom'].set_visible(False)
    ax1.xaxis.set_tick_params(length=0, labelsize=0, which='both')

    # Only want bottom spine of the second graph to be visible
    ax2.spines['bottom'].set_color('black')

    # A shared y-label across two subplots is awkward, so place the label as
    # rotated text at a caller-chosen position on the lower Axes.
    ax2.text(ylabel_loc[0], ylabel_loc[1], ylabel, rotation=90)

    title = xlabel + " vs " + ylabel

    ax1.set_title(title)
    ax2.set_xlabel(xlabel)

    # Break marks, adapted from the Matplotlib broken-axis example: the
    # diagonals are drawn in Axes coordinates at the bottom-left of ax1 and
    # the top-left of ax2. clip_on=False lets them extend past the Axes.
    kwargs = dict(transform=ax1.transAxes, color='k', clip_on=False)
    ax1.plot((-breakline_len, +breakline_len), (-breakline_len, +breakline_len), **kwargs)
    kwargs.update(transform=ax2.transAxes)
    ax2.plot((-breakline_len, +breakline_len), (1 - breakline_len, 1 + breakline_len), **kwargs)

    save_or_display(title)

def get_transition_counts(data, keep_reloads=False, session_threshold=TWO_HOURS):
    """Count occurrences of each ('from', 'to') transition.

    Transitions come from get_durations(data, session_threshold=...). They
    are grouped by ('from', 'to') into a DataFrame with a 'count' column.
    That frame is passed through result_of_keep_reloads together with
    ``keep_reloads``, whose exact filtering is defined elsewhere.
    """
    durations = get_durations(data, session_threshold=session_threshold)
    counts = (
        durations.groupby(['from', 'to'])
        .size()
        .to_frame(name='count')
        .reset_index()
    )
    return result_of_keep_reloads(keep_reloads, counts)

def get_transitions_to_and_from(counts):
    """Yield ``(resource, rows)`` pairs of outgoing plus incoming transitions.

    For every distinct value in the 'from' column, in groupby order, this
    yields the rows leaving that resource, followed by the rows arriving at
    it ('to' == resource) if there are any.

    NOTE(review): resources that appear only in 'to' are never yielded, and a
    self-transition row (from == to) appears twice in its group. Confirm
    whether callers rely on this.
    """
    incoming_by_resource = {resource: rows for resource, rows in counts.groupby('to')}

    for resource, outgoing in counts.groupby('from'):
        if resource in incoming_by_resource:
            yield resource, pd.concat([outgoing, incoming_by_resource[resource]])
        else:
            yield resource, outgoing