In [2]:
import numpy as np
import pandas as pd
import hmmlearn
import matplotlib.pyplot as plt

### Section - look at the average time spent in each state over time 

In [None]:
def plot_hmm_overtime(self, hmm, variable = 'moving', labels = None, colours = None, wrapped = False, bin = 60, func = 'max', avg_window = 30, day_length = 24, lights_off = 12, title = '', grids = False):
    """
    Creates a plot with one line per hidden state, where the y-axis shows the likelihood of being in that state and the x-axis shows time in hours.
    Each line is the mean across decoded sequences at every time point, drawn with a band of mean +/- 1.96 * SD / sqrt(count).
    The plot is generated through the plotly package

    Params:
    @self = behavpy_HMM,
    @hmm = hmmlearn.hmm.MultinomialHMM, this should be a trained HMM Learn object with the correct hidden states and emission states for your dataset
    @variable = string, the column heading of the variable of interest. Default is "moving"
    @labels = list[string], the names of the different states present in the hidden markov model. If None the labels are assumed to be ['Deep sleep', 'Light sleep', 'Quiet awake', 'Full awake']
    @colours = list[string], the name of the colours you wish to represent the different states, must be the same length as labels. If None the colours are a default for 4 states (blue and red)
    It accepts a specific colour or an array of numbers that are acceptable to plotly
    @wrapped = bool, if True the time column is taken modulo one day (day_length hours) so all days are averaged onto a single day
    @bin = int, the time in seconds you want to bin the movement data to, default is 60 or 1 minute
    @func = string, when binning to the above what function should be applied to the grouped data. Default is "max" as is necessary for the "moving" variable
    @avg_window = int, the window in minutes you want the moving average to be applied to. Default is 30 mins
    @day_length = int, the length of the experimental day in hours. Used for wrapping, for the x-axis tick spacing (day_length / 4) and for the light-dark bars. Default is 24
    @lights_off = int, the hour when lights are off during the experiment, passed to the light-dark bars. Default is ZT 12
    @title = string, the title handed to the y-axis layout helper. Default is an empty string
    @grids = bool, handed to the y-axis layout helper as its grid argument. Default is False

    returns the plotly figure

    raises AssertionError if wrapped is not a bool
    raises ValueError if decoding the data produced no sequences to plot
    """
    assert isinstance(wrapped, bool)

    data = self.copy(deep = True)

    labels, colours = self._check_hmm_shape(hm = hmm, lab = labels, col = colours)

    states_list, time_list = self._hmm_decode(data, hmm, bin, variable, func)

    # avg_window is given in minutes and bin in seconds, so this is the window length in bins
    window_bins = int((avg_window * 60) / bin)

    # Build every per-sequence frame first and concatenate once; growing a frame
    # inside the loop copies all earlier rows on every iteration and starts from
    # an empty frame, which pandas warns about.
    frames = [
        hmm_pct_state(states, times, list(range(len(labels))), avg_window = window_bins)
        for states, times in zip(states_list, time_list)
    ]
    if not frames:
        raise ValueError('No decoded sequences were returned, there is nothing to plot')
    df = pd.concat(frames, ignore_index = True)

    if wrapped is True:
        df['t'] = df['t'].map(lambda t: t % (60*60*day_length))

    # seconds -> hours, then widen the x range outwards to the nearest multiples of 12 hours
    df['t'] = df['t'] / (60*60)
    t_min = int(12 * floor(df.t.min() / 12))
    t_max = int(12 * ceil(df.t.max() / 12))
    t_range = [t_min, t_max]

    fig = go.Figure()
    self._plot_ylayout(fig, yrange = [-0.025, 1.01], t0 = 0, dtick = 0.2, ylabel = 'Probability of being in state', title = title, grid = grids)
    self._plot_xlayout(fig, xrange = t_range, t0 = 0, dtick = day_length/4, xlabel = 'ZT (Hours)')

    # The grouping is the same for every state, only the aggregated column changes
    by_time = df.groupby('t')

    for c, (col, n) in enumerate(zip(colours, labels)):

        column = f'state_{c}'

        gb_df = by_time.agg(**{
                    'mean' : (column, 'mean'),
                    'SD' : (column, 'std'),
                    'count' : (column, 'count')
                })

        # NOTE: the 'SE' column holds 1.96 * SD / sqrt(count), i.e. the half width of the band, not a plain standard error
        gb_df['SE'] = (1.96*gb_df['SD']) / np.sqrt(gb_df['count'])
        gb_df['y_max'] = gb_df['mean'] + gb_df['SE']
        gb_df['y_min'] = gb_df['mean'] - gb_df['SE']
        gb_df = gb_df.reset_index()

        upper, trace, lower, _ = self._plot_line(df = gb_df, x_col = 't', name = n, marker_col = col)
        fig.add_trace(upper)
        fig.add_trace(trace)
        fig.add_trace(lower)

    # Light-Dark annotation bars
    bar_shapes, _ = circadian_bars(t_min, t_max, max_y = 1, day_length = day_length, lights_off = lights_off)
    fig.update_layout(shapes=list(bar_shapes.values()))

    return fig

### Section - Quantify the above into box plots and run some quick statistical tests

### Section - Look at the raw data visualised

In [None]:
def plot_hmm_raw(self, hmm, variable = 'moving', colours = None, num_plots = 5, bin = 60, mago_df = None, func = 'max', show_movement = False, title = ''):
    """
    Plots the raw decoded hmm states per fly, one subplot row per fly, with time in hours on the x-axis.
    If hmm is a list of hmm objects, the number of plots will equal the length of that list and the same randomly chosen fly is decoded with every model. Use this to compare hmm models.
    The id of each plotted fly is printed as it is drawn.

    Params:
    @self = behavpy_HMM,
    @hmm = a trained HMM object (it must expose transmat_), or a list of them
    @variable = string, the column heading of the variable of interest. Default is "moving"
    @colours = list[string], one colour per state, where position i is used for state i. If None and the model has 4 states the default self._colours_four is used
    @num_plots = int, how many randomly chosen flies to plot. Ignored when hmm is a list. Default is 5
    @bin = int, the bin size handed to the decoder and used to floor the mAGO interaction times. May be a list with one entry per model when hmm is a list. Default is 60 (presumably seconds, confirm against _hmm_decode)
    @mago_df = behavpy or behavpy_HMM, optional. Rows with has_interacted == 1 are merged in on id and t and the markers are recoloured 'purple' where has_responded is True and 'lime' where it is False
    @func = string, the aggregation function handed to the decoder. Default is "max"
    @show_movement = bool, if True the variable itself is overlaid as a thin black line, scaled as (value * 2) + 0.5 and shifted by one sample
    @title = string, the title handed to the y-axis layout helper. Default is an empty string

    returns the plotly figure

    raises RuntimeError if colours is None and the model does not have 4 states
    raises AssertionError if mago_df is given but is not a behavpy or behavpy_HMM
    """

    # The number of states sets the y-axis range further down, so it has to be
    # read from the model whether or not colours were supplied. When a list of
    # models is given the first one is used.
    h = hmm[0] if isinstance(hmm, list) else hmm
    states = h.transmat_.shape[0]

    if colours is None:
        if states == 4:
            colours = self._colours_four
        else:
            raise RuntimeError(f'Your trained HMM is not 4 states, please provide the {states} colours for this hmm. See doc string for more info')

    colours_index = {c : col for c, col in enumerate(colours)}

    if mago_df is not None:
        assert isinstance(mago_df, (behavpy, behavpy_HMM)), 'The mAGO dataframe is not a behavpy or behavpy_HMM class'

    if isinstance(hmm, list):
        # comparing models: one random fly, repeated once per model
        num_plots = len(hmm)
        rand_flies = [np.random.permutation(list(set(self.meta.index)))[0]] * num_plots
        h_list = hmm
        if isinstance(bin, list):
            b_list = bin
        else:
            b_list = [bin] * num_plots
    else:
        rand_flies = np.random.permutation(list(set(self.meta.index)))[:num_plots]
        h_list = [hmm] * num_plots
        b_list = [bin] * num_plots

    df_list = [self.xmv('id', fly_id) for fly_id in rand_flies]
    decoded = [self._hmm_decode(d, h, b, variable, func) for d, h, b in zip(df_list, h_list, b_list)]

    def analyse(data, df_variable):
        # data is the (states_list, time_list) pair from the decoder, only the first sequence of each is used
        states_seq = data[0][0]
        time = data[1][0]
        ids = df_variable.index.tolist()
        var = df_variable[variable].tolist()
        # shift states and variable forward by one sample, padding the front with NaN
        previous_states = np.array(states_seq[:-1], dtype = float)
        previous_states = np.insert(previous_states, 0, np.nan)
        previous_var = np.array(var[:-1], dtype = float)
        previous_var = np.insert(previous_var, 0, np.nan)
        # NOTE(review): zip stops at the shortest input, so the raw rows and the decoded sequence are assumed to line up, confirm against _hmm_decode
        all_df = pd.DataFrame(data = zip(ids, states_seq, time, var, previous_states, previous_var))
        all_df.columns = ['id','state', 't', 'var','previous_state', 'previous_var']
        # drops the first row, which has no previous state
        all_df.dropna(inplace = True)
        all_df['colour'] = all_df['previous_state'].map(colours_index)
        all_df.set_index('id', inplace = True)
        return all_df

    analysed = [analyse(i, v) for i, v in zip(decoded, df_list)]

    fig = make_subplots(
    rows= num_plots,
    cols=1,
    shared_xaxes=True,
    shared_yaxes=True,
    vertical_spacing=0.02,
    horizontal_spacing=0.02
    )

    for c, (df, b) in enumerate(zip(analysed, b_list)):
        fly_id = df.first_valid_index()
        print(f'Plotting: {fly_id}')
        if mago_df is not None:
            df2 = mago_df.xmv('id', fly_id)
            # copy the filtered rows so the column assignment below acts on its own data and not on a slice
            df2 = df2[df2['has_interacted'] == 1].copy(deep = True)
            # floor the interaction time to the start of its bin so it can be matched to the decoded 't'
            df2['t'] = df2['interaction_t'].map(lambda t, b = b: b * floor(t / b))
            df2.reset_index(inplace = True)
            df = pd.merge(df, df2, how = 'outer', on = ['id', 't'])
            # element-wise comparisons on purpose: rows without a stimulus are NaN and must match neither case
            df['colour'] = np.where(df['has_responded'] == True, 'purple', df['colour'])
            df['colour'] = np.where(df['has_responded'] == False, 'lime', df['colour'])

        # seconds -> hours
        df['t'] = df['t'] / (60*60)

        trace1 = go.Scatter(
            showlegend = False,
            y = df['previous_state'],
            x = df['t'],
            mode = 'markers+lines',
            marker = dict(
                color = df['colour'],
                ),
            line = dict(
                color = 'black',
                width = 0.5
            )
            )
        fig.add_trace(trace1, row = c+1, col= 1)

        if show_movement:
            df['var'] = np.roll((df['var'] * 2) + 0.5, 1)
            trace2 = go.Scatter(
                showlegend = False,
                y = df['var'],
                x = df['t'],
                mode = 'lines',
                marker = dict(
                    color = 'black',
                    ),
                line = dict(
                    color = 'black',
                    width = 0.75
                )
                )
            fig.add_trace(trace2, row = c+1, col= 1)

    y_range = [-0.2, states-0.8]
    self._plot_ylayout(fig, yrange = y_range, t0 = 0, dtick = False, ylabel = '', title = title)

    fig.update_yaxes(
        showgrid = False,
        linecolor = 'black',
        zeroline = False,
        ticks = 'outside',
        range = y_range,
        tick0 = 0,
        dtick = 1,
        tickwidth = 2,
        tickfont = dict(
            size = 18
        ),
        linewidth = 4
    )
    fig.update_xaxes(
        showgrid = False,
        color = 'black',
        linecolor = 'black',
        ticks = 'outside',
        tickwidth = 2,
        tickfont = dict(
            size = 18
        ),
        linewidth = 4
    )

    # y-axis title only on the middle row
    fig.update_yaxes(
        title = dict(
            text = 'Predicted State',
            font = dict(
                size = 18,
                color = 'black'
            )
        ),
        row = ceil(num_plots/2),
        col = 1
    )

    # x-axis title only on the bottom row
    fig.update_xaxes(
        title = dict(
            text = 'ZT Hours',
            font = dict(
                size = 18,
                color = 'black'
            )
        ),
        row = num_plots,
        col = 1
    )

    return fig