In [2]:
import pickle
from math import ceil, floor

import numpy as np
import pandas as pd
from hmmlearn import hmm

### First we need to initialise our categorical HMM from hmmlearn
### https://hmmlearn.readthedocs.io/en/latest/api.html

In [None]:
# Settings for the example model; they mirror the defaults used by hmm_train below.
n_states = 4            # number of hidden states
hmm_iterations = 100    # maximum number of EM iterations
tol = 50                # EM stops once the log-likelihood gain falls below this value
verbose = False         # set to True to print per-iteration convergence reports

# params = 'ste' lets training update the start (s), transition (t) and emission (e) probabilities
model = hmm.CategoricalHMM(n_components = n_states, n_iter = hmm_iterations, tol = tol, params = 'ste', verbose = verbose)

In [None]:
def reshape(dataframe):
    """
    Placeholder for a function that reshapes `dataframe` into the input format needed for training.

    The body has not been written yet; calling it raises NotImplementedError so the gap is
    obvious instead of the cell failing with a syntax error.
    """
    raise NotImplementedError('reshape has not been implemented yet')

## Idea: set them the task of creating a single function that initialises, transforms, and trains on the data

In [None]:
def hmm_train(dataframe, states, observables, var_column, file_name, trans_probs = None, emiss_probs = None, start_probs = None, iterations = 10, hmm_iterations = 100, tol = 50, t_column = 't', bin_time = 60, test_size = 10, verbose = False):
    """
    Trains a hmmlearn CategoricalHMM on one column of a time series dataframe and keeps the best scoring model.

    The data is binned to bin_time seconds, split by id into train and test sequences, and a new model is
    fitted `iterations` times. After every fit the model is scored on the test sequences and pickled to
    file_name if it beats the model currently stored there (or if the file does not exist yet).
    A file that already exists before the first iteration is therefore treated as the model to beat.

    There must be no NaNs in the training data

    The transition matrix of every newly saved model is printed as it is saved
    Final trained model probability matrices will be printed to terminal at the end of the run time

    Params:
    @dataframe = pandas DataFrame (or subclass), must contain t_column and var_column and give an 'id' column after reset_index()
    @states = list of sting(s), names of hidden states for the model to train to
    @observables = list of string(s), names of the observable states for the model to train to.
    The length must be the same number as the different categories in you movement column.
    @var_column = string, name for the column containing the variable of choice to train the model
    @file_name = string, name of the .pkl file the resultant trained model will be saved to, must end in '.pkl'
    @trans_probs = numpy array, transtion probability matrix with shape 'len(states) x len(states)', 0's restrict the model from training any tranisitons between those states,
    entries equal to the string 'rand' are replaced by a new random number each iteration. If None hmmlearn initialises the matrix
    @emiss_probs = numpy array, emission probability matrix with shape 'len(states) x len(observables)', 0's and 'rand' same as above
    @start_probs = numpy array, starting probabilities with one entry per state (1D, or a single row or column), 0's and 'rand' same as above
    @iterations = int, number of models to fit, each from a new (randomised) starting point, default is 10
    @hmm_iterations = int, argument to be passed to hmmlearn, number of iterations of parameter updating without reaching tol before it stops, default is 100
    @tol = int, convergence threshold, EM will stop if the gain in log-likelihood is below this value, default is 50
    @t_column = string, name for the column containing the time series data, default is 't'
    @bin_time = int, the time in seconds the data will be binned to before the training begins, default is 60 (i.e 1 min)
    @test_size = int, percentage of the ids held back as the test set used to compare models, default is 10
    @verbose = (bool, optional), argument for hmmlearn, whether per-iteration convergence reports are printed to terminal

    raises TypeError if file_name is not a string ending in '.pkl'
    raises ValueError if iterations is below 1, the training column contains NaNs, or the split leaves the train or test set empty

    returns the best scoring trained hmmlearn CategoricalHMM object (as re-loaded from file_name)
    """

    if not isinstance(file_name, str) or not file_name.endswith('.pkl'):
        raise TypeError('enter a file name and type (.pkl) for the hmm object to be saved under')

    if iterations < 1:
        raise ValueError('iterations must be 1 or more')

    n_states = len(states)
    n_obs = len(observables)

    # work on a copy so the caller's dataframe is never modified
    hmm_df = dataframe.copy(deep = True)

    if hmm_df[var_column].isnull().any():
        raise ValueError(f'The column "{var_column}" contains NaNs, remove them before training')

    if var_column == 'beam_crosses':
        # NOTE(review): 'active' is created but the model is still trained on the raw var_column - confirm which one is intended
        hmm_df['active'] = np.where(hmm_df[var_column] == 0, 0, 1)
    elif var_column == 'moving':
        # booleans become the integer categories 0 / 1
        hmm_df[var_column] = np.where(hmm_df[var_column] == True, 1, 0)

    gb = _bin_to_sequences(hmm_df, t_var = t_column, mov_var = var_column, bin_secs = bin_time)

    # split runs into test and train lists
    sequences = list(gb)
    n_test = round(len(sequences) * (test_size / 100))
    if n_test < 1 or n_test >= len(sequences):
        raise ValueError(f'A test_size of {test_size}% of {len(sequences)} id(s) leaves the train or test set empty')

    order = np.random.permutation(len(sequences))
    test = [sequences[k] for k in order[:n_test]]
    train = [sequences[k] for k in order[n_test:]]

    len_seq_train = [len(ar) for ar in train]
    len_seq_test = [len(ar) for ar in test]

    # hmmlearn takes one long column of observations plus the length of each sequence
    seq_train = np.concatenate(train, 0).reshape(-1, 1)
    seq_test = np.concatenate(test, 0).reshape(-1, 1)

    for i in range(iterations):
        print(f"Iteration {i+1} of {iterations}")

        h = hmm.CategoricalHMM(n_components = n_states, n_iter = hmm_iterations, tol = tol, params = 'ste', verbose = verbose)

        # init_params lists the matrices hmmlearn initialises itself; any matrix set here is left out
        init_params = ''

        if start_probs is None:
            init_params += 's'
        else:
            # flatten to a single row first so startprob_ ends up as a 1D array with one entry per state
            h.startprob_ = _fill_and_normalise(np.asarray(start_probs, dtype = object).reshape(1, -1))[0]

        if trans_probs is None:
            init_params += 't'
        else:
            h.transmat_ = _fill_and_normalise(trans_probs)

        if emiss_probs is None:
            init_params += 'e'
        else:
            h.emissionprob_ = _fill_and_normalise(emiss_probs)

        h.init_params = init_params
        h.n_features = n_obs # number of emission states

        # call the fit function on the dataset input
        h.fit(seq_train, len_seq_train)

        # Boolean output of if the run converged, i.e. the last gain in log-likelihood was below tol
        history = h.monitor_.history
        converged = len(history) >= 2 and (history[-1] - history[-2]) < h.monitor_.tol
        print("True Convergence:" + str(converged))
        print("Final log liklihood score:" + str(h.score(seq_train, len_seq_train)))

        # SECURITY: pickle.load can execute arbitrary code, only use a file_name whose contents you trust
        try:
            with open(file_name, "rb") as file:
                h_old = pickle.load(file)
        except FileNotFoundError:
            h_old = None

        if h_old is None:
            # nothing saved yet, so this model is the best so far by default
            with open(file_name, "wb") as file:
                pickle.dump(h, file)
        elif h.score(seq_test, len_seq_test) > h_old.score(seq_test, len_seq_test):
            print('New Matrix:')
            df_t = pd.DataFrame(h.transmat_, index = states, columns = states)
            print(tabulate(df_t, headers = 'keys', tablefmt = "github") + "\n")
            with open(file_name, "wb") as file:
                pickle.dump(h, file)

    # the file now holds the best model seen, re-load it so the returned object matches what is saved
    with open(file_name, "rb") as file:
        h = pickle.load(file)

    #print tables of trained emission probabilties, not accessible as objects for the user
    _hmm_table(start_prob = h.startprob_, trans_prob = h.transmat_, emission_prob = h.emissionprob_, state_names = states, observable_names = observables)
    return h


def _bin_to_sequences(data, t_var, mov_var, bin_secs, stat = 'max'):
    """
    Bins the time to the given integer and returns a pandas Series holding one numpy array of the movement column per id
    """
    data = data.reset_index()
    # the time step is taken from the first two rows only, it is assumed to be the same everywhere
    t_delta = data[t_var].iloc[1] - data[t_var].iloc[0]
    if t_delta == bin_secs:
        # already at the requested resolution
        return data.groupby('id')[mov_var].apply(np.array)

    data[t_var] = np.floor(data[t_var] / bin_secs) * bin_secs
    binned = data.groupby(['id', t_var]).agg(**{
        mov_var : (mov_var, stat)
    })
    binned = binned.reset_index(level = 1)
    return binned.groupby('id')[mov_var].apply(np.array)


def _fill_and_normalise(probs):
    """
    Replaces every 'rand' entry with a new random number in the range 0-1 and scales each row to sum to 1
    """
    rows = np.atleast_2d(np.asarray(probs, dtype = object))
    filled = np.array([[np.random.random() if isinstance(y, str) and y == 'rand' else y for y in x] for x in rows], dtype = np.float64)
    totals = filled.sum(axis = 1, keepdims = True)
    if np.any(totals == 0):
        raise ValueError('Every row of a probability matrix needs at least one non-zero entry')
    return filled / totals


In [None]:
def _hmm_table(start_prob, trans_prob, emission_prob, state_names, observable_names):
    """
    Prints the start, transition and emission probabilities of a HMM as three github-style tables.

    Params:
    @start_prob = array, starting probabilities, shown as a single row with one column per state
    @trans_prob = array, transition probabilities, rows and columns are both labelled with state_names
    @emission_prob = array, emission probabilities, rows are labelled with state_names and columns with observable_names
    @state_names = list of string(s), names of the hidden states
    @observable_names = list of string(s), names of the observable states

    returns None
    """
    def _emit(frame):
        # every table uses the same format and is followed by a blank line
        print(tabulate(frame, headers = 'keys', tablefmt = "github") + "\n")

    # turn the start probabilities into one row so each state gets its own column
    start_frame = pd.DataFrame(start_prob).T
    start_frame.columns = state_names
    print("Starting probabilty table: ")
    _emit(start_frame)

    print("Transition probabilty table: ")
    _emit(pd.DataFrame(trans_prob, index = state_names, columns = state_names))

    print("Emission probabilty table: ")
    _emit(pd.DataFrame(emission_prob, index = state_names, columns = observable_names))

In [None]:
def hmm_display(hmm, states, observables):
    """
    Prints to screen the starting, transition, and emission probability tables for a given hmmlearn hmm object

    Params:
    @hmm = trained hmmlearn HMM object, must have the attributes startprob_, transmat_, and emissionprob_
    @states = list of string(s), names of the hidden states, used to label the tables
    @observables = list of string(s), names of the observable states, used to label the emission table

    returns None
    """
    # this is a plain function, not a method, so the table helper is called directly rather than through self
    _hmm_table(start_prob = hmm.startprob_, trans_prob = hmm.transmat_, emission_prob = hmm.emissionprob_, state_names = states, observable_names = observables)

In [None]:
def plot_hmm_overtime(self, hmm, variable = 'moving', labels = None, colours = None, wrapped = False, bin = 60, func = 'max', avg_window = 30, day_length = 24, lights_off = 12, title = '', grids = False):
    """
    Creates a plot with one line per hidden state, the y-axis shows the probability of being in that state and the x-axis shows time in hours.
    Each line is the mean across the decoded sequences with a band of 1.96 * SD / sqrt(count) either side.
    The plot is generated through the plotly package

    Params:
    @self = behavpy_HMM,
    @hmm = hmmlearn HMM object, this should be a trained HMM Learn object with the correct hidden states and emission states for your dataset
    @variable = string, the column heading of the variable of interest. Default is "moving"
    @labels = list[string], the names of the different states present in the hidden markov model, checked / filled in by self._check_hmm_shape
    @colours = list[string], the name of the colours you wish to represent the different states, checked / filled in by self._check_hmm_shape
    @wrapped = bool, if True the time is wrapped to a single day of day_length hours
    @bin = int, the time in seconds you want to bin the movement data to, default is 60 or 1 minute
    @func = string, when binning to the above what function should be applied to the grouped data. Default is "max" as is necessary for the "moving" variable
    @avg_window = int, the window in minutes you want the moving average to be applied to. Default is 30 mins
    @day_length = int, the length of the experimental day in hours, used for wrapping, the x-axis ticks, and the light-dark bars. Default is 24
    @lights_off = int, the hour when lights are off during the experiment. Default is ZT 12
    @title = string, the title of the plot. Default is ''
    @grids = bool, passed to the y-axis layout as its grid argument. Default is False

    returns a plotly figure object
    """
    assert isinstance(wrapped, bool)

    # decode from a copy so the original data can't be altered
    data = self.copy(deep = True)

    labels, colours = self._check_hmm_shape(hm = hmm, lab = labels, col = colours)

    states_list, time_list = self._hmm_decode(data, hmm, bin, variable, func)

    # avg_window is in minutes, convert it to a number of bins
    window = int((avg_window * 60)/bin)
    # collect the per-sequence frames and concatenate once, growing a dataframe inside the loop is quadratic
    frames = [hmm_pct_state(l, t, list(range(len(labels))), avg_window = window) for l, t in zip(states_list, time_list)]
    df = pd.concat(frames, ignore_index = True) if frames else pd.DataFrame()

    if wrapped is True:
        df['t'] = df['t'].map(lambda t: t % (60*60*day_length))

    # seconds to hours, the x-axis limits are then rounded outwards to the nearest 12 hours
    df['t'] = df['t'] / (60*60)
    t_min = int(12 * floor(df.t.min() / 12))
    t_max = int(12 * ceil(df.t.max() / 12))
    t_range = [t_min, t_max]

    fig = go.Figure()
    self._plot_ylayout(fig, yrange = [-0.025, 1.01], t0 = 0, dtick = 0.2, ylabel = 'Probability of being in state', title = title, grid = grids)
    self._plot_xlayout(fig, xrange = t_range, t0 = 0, dtick = day_length/4, xlabel = 'ZT (Hours)')

    # the grouping by time point is the same for every state, so only do it once
    by_time = df.groupby('t')

    for c, (col, n) in enumerate(zip(colours, labels)):

        column = f'state_{c}'

        gb_df = by_time.agg(**{
                    'mean' : (column, 'mean'),
                    'SD' : (column, 'std'),
                    'count' : (column, 'count')
                })

        # 'SE' holds 1.96 standard errors, i.e. the half width of the band drawn around the mean
        gb_df['SE'] = (1.96*gb_df['SD']) / np.sqrt(gb_df['count'])
        gb_df['y_max'] = gb_df['mean'] + gb_df['SE']
        gb_df['y_min'] = gb_df['mean'] - gb_df['SE']
        gb_df = gb_df.reset_index()

        upper, trace, lower, _ = self._plot_line(df = gb_df, x_col = 't', name = n, marker_col = col)
        fig.add_trace(upper)
        fig.add_trace(trace)
        fig.add_trace(lower)

    # Light-Dark annotaion bars
    bar_shapes, min_bar = circadian_bars(t_min, t_max, max_y = 1, day_length = day_length, lights_off = lights_off)
    fig.update_layout(shapes=list(bar_shapes.values()))

    return fig

In [None]:
def plot_hmm_raw(self, hmm, variable = 'moving', colours = None, num_plots = 5, bin = 60, mago_df = None, func = 'max', show_movement = False, title = ''):
    """
    Plots the raw decoded hmm states per fly, one subplot per fly (total = num_plots).
    If hmm is a list of hmm objects, the number of plots will equal the length of that list and every
    plot shows the same randomly chosen fly. Use this to compare hmm models.

    Params:
    @self = behavpy_HMM,
    @hmm = hmmlearn HMM object or list of them, trained model(s) used to decode the data
    @variable = string, the column heading of the variable of interest. Default is "moving"
    @colours = list[string], one colour per hidden state. If None and the hmm has 4 states the default self._colours_four is used
    @num_plots = int, number of randomly chosen flies to plot, ignored when hmm is a list. Default is 5
    @bin = int or list[int], the time in seconds the data is binned to, a list (one entry per model) is only used when hmm is a list. Default is 60
    @mago_df = behavpy or behavpy_HMM, optional mAGO dataframe, when given the points at an interaction are coloured purple (responded) or lime (did not respond)
    @func = string, the function applied to the grouped data when binning. Default is "max"
    @show_movement = bool, if True the variable itself is overlaid on each plot as a black line. Default is False
    @title = string, the title of the plot. Default is ''

    raises RuntimeError if colours is None and the hmm does not have 4 states

    returns a plotly figure object
    """

    # the number of hidden states is needed for the y-axis range whether or not colours are provided
    if isinstance(hmm, list):
        h = hmm[0]
    else:
        h = hmm
    states = h.transmat_.shape[0]

    if colours is None:
        if states == 4:
            colours = self._colours_four
        else:
            raise RuntimeError(f'Your trained HMM is not 4 states, please provide the {h.transmat_.shape[0]} colours for this hmm. See doc string for more info')

    # lookup from state number to its colour
    colours_index = {c : col for c, col in enumerate(colours)}

    if mago_df is not None:
        assert isinstance(mago_df, behavpy) or isinstance(mago_df, behavpy_HMM), 'The mAGO dataframe is not a behavpy or behavpy_HMM class'

    if isinstance(hmm, list):
        num_plots = len(hmm)
        # one random fly repeated, so each model is shown decoding the same data
        rand_flies = [np.random.permutation(list(set(self.meta.index)))[0]] * num_plots
        h_list = hmm
        if isinstance(bin, list):
            b_list = bin
        else:
            b_list = [bin] * num_plots
    else:
        rand_flies = np.random.permutation(list(set(self.meta.index)))[:num_plots]
        h_list = [hmm] * num_plots
        b_list = [bin] * num_plots

    df_list = [self.xmv('id', id) for id in rand_flies]
    decoded = [self._hmm_decode(d, h, b, variable, func) for d, h, b in zip(df_list, h_list, b_list)]

    def analyse(data, df_variable):
        """ Builds a dataframe of state, time, and variable for one fly, along with the state and variable of the previous time point """
        states = data[0][0]
        time = data[1][0]
        id = df_variable.index.tolist()
        var = df_variable[variable].tolist()
        # shift by one time point, the first row has no previous value so it gets NaN and is dropped below
        previous_states = np.array(states[:-1], dtype = float)
        previous_states = np.insert(previous_states, 0, np.nan)
        previous_var = np.array(var[:-1], dtype = float)
        previous_var = np.insert(previous_var, 0, np.nan)
        all_df = pd.DataFrame(data = zip(id, states, time, var, previous_states, previous_var))
        all_df.columns = ['id','state', 't', 'var','previous_state', 'previous_var']
        all_df.dropna(inplace = True)
        all_df['colour'] = all_df['previous_state'].map(colours_index)
        all_df.set_index('id', inplace = True)
        return all_df

    analysed = [analyse(i, v) for i, v in zip(decoded, df_list)]

    fig = make_subplots(
    rows= num_plots, 
    cols=1,
    shared_xaxes=True, 
    shared_yaxes=True, 
    vertical_spacing=0.02,
    horizontal_spacing=0.02
    )

    for c, (df, b) in enumerate(zip(analysed, b_list)):
        id = df.first_valid_index()
        print(f'Plotting: {id}')
        if mago_df is not None:
            df2 = mago_df.xmv('id', id)
            df2 = df2[df2['has_interacted'] == 1]
            # bin the interaction times the same way as the decoded data so the two can be merged on t
            df2['t'] = df2['interaction_t'].map(lambda t:  b * floor(t / b))
            df2.reset_index(inplace = True)
            df = pd.merge(df, df2, how = 'outer', on = ['id', 't'])
            df['colour'] = np.where(df['has_responded'] == True, 'purple', df['colour'])
            df['colour'] = np.where(df['has_responded'] == False, 'lime', df['colour'])
            df['t'] = df['t'].map(lambda t: t / (60*60))
        
        else:
            df['t'] = df['t'].map(lambda t: t / (60*60))

        trace1 = go.Scatter(
            showlegend = False,
            y = df['previous_state'],
            x = df['t'],
            mode = 'markers+lines', 
            marker = dict(
                color = df['colour'],
                ),
            line = dict(
                color = 'black',
                width = 0.5
            )
            )
        fig.add_trace(trace1, row = c+1, col= 1)

        if show_movement:
            # rescale the variable and shift it by one row so it is drawn alongside the previous_state trace
            df['var'] = np.roll((df['var'] * 2) + 0.5, 1)
            trace2 = go.Scatter(
                showlegend = False,
                y = df['var'],
                x = df['t'],
                mode = 'lines', 
                marker = dict(
                    color = 'black',
                    ),
                line = dict(
                    color = 'black',
                    width = 0.75
                )
                )
            fig.add_trace(trace2, row = c+1, col= 1)

    y_range = [-0.2, states-0.8]
    self._plot_ylayout(fig, yrange = y_range, t0 = 0, dtick = False, ylabel = '', title = title)

    fig.update_yaxes(
        showgrid = False,
        linecolor = 'black',
        zeroline = False,
        ticks = 'outside',
        range = y_range, 
        tick0 = 0, 
        dtick = 1,
        tickwidth = 2,
        tickfont = dict(
            size = 18
        ),
        linewidth = 4
    )
    fig.update_xaxes(
        showgrid = False,
        color = 'black',
        linecolor = 'black',
        ticks = 'outside',
        tickwidth = 2,
        tickfont = dict(
            size = 18
        ),
        linewidth = 4
    )

    # the y-axis title is only put on the middle subplot
    fig.update_yaxes(
        title = dict(
            text = 'Predicted State',
            font = dict(
                size = 18,
                color = 'black'
            )
        ),
        row = ceil(num_plots/2),
        col = 1
    )

    # the x-axis title is only put on the bottom subplot
    fig.update_xaxes(
        title = dict(
            text = 'ZT Hours',
            font = dict(
                size = 18,
                color = 'black'
            )
        ),
        row = num_plots,
        col = 1
    )

    return fig