# import

In [1]:

import json
import os
import pandas as pd
import numpy as np
import ast 

def assign_to_group(mouse,expert_mice,hlesion_mice,learning_mice,var_dict):
    """Append one-hot experimental-group flags for ``mouse`` to ``var_dict``.

    The lists ``var_dict['expert']``, ``var_dict['hlesion']`` and
    ``var_dict['learning']`` each get one entry (1 for the matching group,
    0 otherwise). Groups are checked in the order expert, hlesion, learning,
    and the first match wins. If ``mouse`` is in none of the groups,
    ``var_dict`` is left unchanged.

    Returns the same (mutated) ``var_dict``.
    """
    membership = (
        ('expert', expert_mice),
        ('hlesion', hlesion_mice),
        ('learning', learning_mice),
    )
    matched_group = next(
        (group_name for group_name, members in membership if mouse in members),
        None,
    )
    if matched_group is not None:
        for group_name, _ in membership:
            var_dict[group_name] += [int(group_name == matched_group)]
    return var_dict

def get_time_span(dat_path,pp_file,mouse):
    """Return the ``time_span`` entry of a session's training-data params JSON.

    The file path is built by plain string concatenation: ``dat_path``, then
    ``pp_file``, then the Windows-style ``trainingData`` sub-folder segment,
    then ``params_<mouse>.json``.
    """
    params_file = ''.join([dat_path, pp_file, r'\trainingData\\', 'params_', mouse, '.json'])
    with open(params_file, 'r') as handle:
        return json.load(handle)['time_span']

def find_useable_mouse_paths(sleep_ppseq_path,useable_mirs,expert_mice,hlesion_mice,learning_mice,var_dict,sleep_start):
    """Collect analysis-output paths and per-session metadata for usable sessions.

    Every entry of ``sleep_ppseq_path`` whose name does not contain
    ``'sleep_time_points'`` is considered. Its session id ("mir") is the first
    three underscore-separated tokens of the entry name. Sessions listed in
    ``useable_mirs`` are kept. For each kept session this function:

    * appends group flags via ``assign_to_group``;
    * appends ``sleep_start[mouse]`` to ``var_dict['current_sleep_start']``;
    * extends ``var_dict['time_spans']`` with ``get_time_span(...)``;
    * appends the session's ``analysis_output`` folder path to the returned
      list and the session id to ``var_dict['mirs']``.

    Sessions are visited in ``os.listdir`` order, which is not guaranteed to be
    sorted.

    Returns ``(current_mouse_path, var_dict)``.
    """
    current_mouse_path = []
    for pp_file in os.listdir(sleep_ppseq_path):
        if 'sleep_time_points' in pp_file:
            continue
        # current mouse/session id
        mouse = '_'.join(pp_file.split('_')[0:3])
        if mouse not in useable_mirs:
            continue

        # assign to experimental group in var_dict
        var_dict = assign_to_group(mouse,expert_mice,hlesion_mice,learning_mice,var_dict)

        # one sleep-start value per session: append it (``list += scalar``
        # raises TypeError for numbers and splices in characters for strings)
        var_dict['current_sleep_start'].append(sleep_start[mouse])
        # NOTE(review): this extends rather than appends, so the contents of the
        # JSON 'time_span' value are flattened into the list - confirm that is
        # intended, since it can misalign time_spans with 'mirs'.
        var_dict['time_spans'] += get_time_span(sleep_ppseq_path,pp_file,mouse)

        # set path to processed files
        current_mouse_path.append(sleep_ppseq_path + pp_file + '\\analysis_output\\')
        var_dict['mirs'].append(mouse)
    return current_mouse_path,var_dict


def make_filter_masks(data,sequential_filter,nrem_filter,rem_filter,sleep_filters_on,background_only):
    """Build a boolean row mask selecting replay events by ordering and sleep state.

    Parameters
    ----------
    data : pandas.DataFrame
        Replay events. Reads ``ordering_classification`` when
        ``sequential_filter`` is set, and ``nrem_events`` / ``rem_events`` when
        ``sleep_filters_on`` is set. The state columns are compared with 1 and 0.
    sequential_filter : bool
        Keep only rows whose ``ordering_classification == 'sequential'``.
    nrem_filter, rem_filter : bool
        Used only with ``sleep_filters_on``: keep rows in NREM and/or REM
        (the union of the selected states).
    sleep_filters_on : bool
        Master switch. When False the sleep-state columns are ignored.
    background_only : bool
        Used only with ``sleep_filters_on``: keep only rows that are in neither
        REM nor NREM. Overrides ``nrem_filter`` / ``rem_filter``.

    Returns
    -------
    numpy.ndarray of bool
        One entry per row of ``data``, in positional order.
    """
    n_rows = len(data)
    keep_all = np.ones(n_rows, dtype=bool)
    keep_none = np.zeros(n_rows, dtype=bool)

    if sequential_filter:
        sequential_condition = (data.ordering_classification == 'sequential').to_numpy()
    else:
        sequential_condition = keep_all

    if not sleep_filters_on:
        state_condition = keep_all
    elif background_only:
        # background = outside BOTH states; OR-ing the two "== 0" tests would
        # let through every event that is in exactly one of REM / NREM
        state_condition = ((data.rem_events == 0) & (data.nrem_events == 0)).to_numpy()
    else:
        nrem_condition = (data.nrem_events == 1).to_numpy() if nrem_filter else keep_none
        rem_condition = (data.rem_events == 1).to_numpy() if rem_filter else keep_none
        # any selected state carries the event forward
        state_condition = nrem_condition | rem_condition

    return sequential_condition & state_condition

def _total_state_seconds(state_times):
    """Sum (second column - first column) over the rows of a state-times array; 0 if empty."""
    state_times = np.asarray(state_times)
    if len(state_times) == 0:
        return 0
    # end - start for every row (also accepts a single 1-D [start, end] pair)
    return np.sum(np.diff(state_times, axis=-1)[..., 0])


def determine_chunk_mins(chunk_time,sleep_filters_on,nrem_filter,rem_filter,background_only,path):
    """Return how many minutes of a chunk the active sleep filters select.

    Parameters
    ----------
    chunk_time : array-like
        Chunk bounds; the duration is ``chunk_time[1] - chunk_time[0]``
        (assumed to be seconds, TODO confirm).
    sleep_filters_on : bool
        When False the whole chunk duration is used.
    nrem_filter, rem_filter : bool
        Include total NREM / REM time (summed when both are set). When both
        are False (and not ``background_only``) the result is 0.
    background_only : bool
        Use the chunk duration minus all REM and NREM time.
    path : str
        Folder prefix holding ``rem_state_times.npy`` and
        ``nrem_state_times.npy``. These are presumably (n, 2) arrays of
        [start, end] rows in seconds - TODO confirm.

    Returns
    -------
    float
        Selected duration divided by 60.
    """
    if not sleep_filters_on:
        # no sleep filtering: use all chunk time
        seconds = np.diff(chunk_time)[0]
    else:
        tot_rem = _total_state_seconds(np.load(path + 'rem_state_times.npy'))
        tot_nrem = _total_state_seconds(np.load(path + 'nrem_state_times.npy'))
        if background_only:
            # all time that is neither REM nor NREM
            seconds = np.diff(chunk_time)[0] - (tot_rem + tot_nrem)
        else:
            # start at 0 so "no state selected" yields 0 minutes instead of
            # leaving the result undefined
            seconds = 0
            if nrem_filter:
                seconds += tot_nrem
            if rem_filter:
                seconds += tot_rem
    # convert to mins
    return seconds / 60

def cluster_events(start_times, end_times, threshold):
    """Greedily group events whose time windows come within ``threshold`` of each other.

    Events are visited in input order. An event joins the FIRST existing
    cluster that contains a member whose [start, end] window, widened by
    ``threshold`` on each side, overlaps the event's own window. Otherwise the
    event starts a new cluster.

    The grouping is not transitive: an event that touches two existing
    clusters joins only the earlier one, and the two are not merged.

    Returns a list of clusters, each a list of integer positions into the inputs.
    """
    clusters = []
    for event_idx in range(len(start_times)):
        event_start = start_times[event_idx]
        event_end = end_times[event_idx]
        target = next(
            (
                cluster
                for cluster in clusters
                if any(
                    event_start <= end_times[member] + threshold
                    and event_end >= start_times[member] - threshold
                    for member in cluster
                )
            ),
            None,
        )
        if target is None:
            clusters.append([event_idx])
        else:
            target.append(event_idx)
    return clusters

def relative_dict(input_dict):
    """Return a new dict with every value divided by the sum of all values.

    Raises ZeroDivisionError when the dict is non-empty and its values sum to 0.
    """
    denominator = sum(input_dict.values())
    proportions = {}
    for key in input_dict:
        proportions[key] = input_dict[key] / denominator
    return proportions

def refind_cluster_events(filtered_chunk_data,event_proximity_filter):
    """Recompute coactive cluster labels, ignoring any labels already present.

    Events are grouped with ``cluster_events``, using the ``first_spike_time``
    and ``last_spike_time`` columns and ``event_proximity_filter`` as the
    threshold. The cluster number of each row is written (as float) into the
    ``coactive_cluster_group`` column. The same DataFrame is modified in place
    and returned.
    """
    clusters = cluster_events(
        filtered_chunk_data.first_spike_time.values,
        filtered_chunk_data.last_spike_time.values,
        event_proximity_filter,
    )
    labels = np.zeros(len(filtered_chunk_data))
    for label, members in enumerate(clusters):
        labels[members] = label
    filtered_chunk_data['coactive_cluster_group'] = labels
    return filtered_chunk_data

def create_multicluster_dataframe(filtered_chunk_data):
    """Collect coactive groups with more than one event and order their motifs.

    For every ``coactive_cluster_group`` value (in order of first appearance)
    that has at least two rows:

    * the motif ids (``cluster_seq_type``) are ordered by the mean of each
      event's spike times (``cluster_spike_times``, stored as a stringified
      list and parsed with ``ast.literal_eval``) and, separately, by
      ``first_spike_time``;
    * the rows are tagged with ``new_cluster_group`` (0, 1, ... over the kept
      groups) and with within-group ranks
      ``cluster_order_first_spike_defined`` /
      ``cluster_order_mean_weighted_spikes_defined``.

    Returns
    -------
    multi_cluster_df : pandas.DataFrame
        The tagged rows of all kept groups, with a fresh 0..n-1 index. If no
        group has more than one event, this is an empty frame that still has
        the input columns plus the three new ones.
    meaned_order : list of list
        Motif ids of each kept group, sorted by mean spike time.
    fs_order : list of list
        Motif ids of each kept group, sorted by first spike time.
    """
    meaned_order = []
    fs_order = []
    cluster_frames = []
    for group in filtered_chunk_data.coactive_cluster_group.unique():
        group_mask = filtered_chunk_data.coactive_cluster_group == group
        current_cluster = filtered_chunk_data[group_mask].copy()
        if len(current_cluster) < 2:
            # single events are not coactive
            continue

        event_types = np.array(current_cluster.cluster_seq_type.values)
        # event order based on spike-time mean
        means = [np.mean(ast.literal_eval(spikes)) for spikes in current_cluster.cluster_spike_times]
        # event order based on first spike time
        fs_orders = list(current_cluster.first_spike_time.values)

        meaned_order.append(list(event_types[np.argsort(means)]))
        fs_order.append(list(event_types[np.argsort(fs_orders)]))

        current_cluster['new_cluster_group'] = [len(cluster_frames)] * len(current_cluster)
        # argsort of argsort = rank of each event within its group
        current_cluster['cluster_order_first_spike_defined'] = list(np.argsort(np.argsort(fs_orders)))
        current_cluster['cluster_order_mean_weighted_spikes_defined'] = list(np.argsort(np.argsort(means)))
        cluster_frames.append(current_cluster)

    if cluster_frames:
        # one concat instead of re-concatenating inside the loop
        multi_cluster_df = pd.concat(cluster_frames, axis=0, ignore_index=True)
    else:
        # keep the schema so downstream column lookups still work
        multi_cluster_df = filtered_chunk_data.iloc[0:0].copy()
        for column in ('new_cluster_group',
                       'cluster_order_first_spike_defined',
                       'cluster_order_mean_weighted_spikes_defined'):
            multi_cluster_df[column] = []
        multi_cluster_df = multi_cluster_df.reset_index(drop=True)
    return multi_cluster_df,meaned_order,fs_order

def logic_machine_for_pair_catagorisation(pair,dominant,other):
    """Label a consecutive pair of motif ids by its relation to the task sequence.

    Parameters
    ----------
    pair : sequence
        ``[first, second]`` motif ids. ``pair[0]`` and ``pair[-1]`` are tested
        here; ``pair_in_sequence`` also reads ``pair[1]``.
    dominant : list
        Dominant task motif ids in task order. ``pair_in_sequence`` treats this
        order as circular. Must be a list: it is concatenated with ``+``.
    other : list
        Additional task-related motif ids outside ``dominant``.

    Returns
    -------
    str or None
        One of 'ordered', 'reverse', 'repeat', 'misordered', 'task to other',
        'other to task' or 'other'. If none of the rules below returns,
        'ERROR!' is printed and the function returns ``None`` (the return value
        of ``print``). ``calculate_ordering_amounts`` and
        ``motif_by_motif_ordering`` then count the pair as "other".
    """
    # Case 1: both motifs belong to the dominant task sequence.
    # if first one in dominant check for ordering:
    if pair[0] in dominant and pair[-1] in dominant:
        if pair_in_sequence(pair,dominant):
            return('ordered')
        elif pair_in_sequence(pair,dominant[::-1]):
            return('reverse')
        elif pair[-1] == pair[0]:
            return('repeat')
        # always true inside this branch, so every remaining dominant pair is 'misordered'
        elif pair[-1] in dominant:
            return('misordered') 
    # Case 2: both motifs are task-related (dominant or other), but not both dominant.
    # Each extra motif is tried appended to the end of the dominant sequence.
    # if its not these  options then check if it could be in the extra task seqs
    elif pair[0] in  (dominant + other) and pair[-1] in  (dominant + other):
        for item in other:
            if pair[0] in  (dominant + [item]):
                if pair_in_sequence(pair,(dominant + [item])):
                    return('ordered')
                elif pair_in_sequence(pair,(dominant + [item])[::-1]):
                    return('reverse')
                elif pair[-1] == pair[0]:
                    return('repeat')
                elif pair[-1] in (dominant + [item]):
                    return('misordered')  
        # if not this then check if both are in the extra seqs (and are not a repeat):
        if pair[0] in other and pair[-1] in other:
            if not pair[-1] == pair[0]: 
                return('ordered')
    # Case 3: at least one motif is not task-related.
    else:
        # if item 1 is in but item 2 isnt then task to other 
        if pair[0] in  (dominant + other):
            if not pair[-1] in  (dominant + other):
                return('task to other')
        # if item 2 is in but item 1 isnt then other to task 
        elif not pair[0] in  (dominant + other):
            if pair[-1] in  (dominant + other):
                return('other to task')
            else:
                return('other')
    # fall-through: prints a marker and returns None
    return print('ERROR!')

def pair_in_sequence(pair, sequence):
    """Return True if ``pair`` is an adjacent step of the circular ``sequence``.

    ``sequence`` is treated as a loop, so the step from its last element back
    to its first also counts. For a single-element sequence that wrap-around
    step is ``x -> x``. An empty sequence contains no steps.

    Parameters
    ----------
    pair : sequence
        ``[first, second]``; ``pair[0]`` and ``pair[1]`` are compared.
    sequence : sequence
        Ordered ids (list or 1-D array).
    """
    if len(sequence) == 0:
        return False
    first, second = pair[0], pair[1]
    # wrap-around step: last element -> first element. Checked once, outside
    # the adjacency scan, so it also applies to single-element sequences.
    if sequence[-1] == first and sequence[0] == second:
        return True
    return any(
        current == first and following == second
        for current, following in zip(sequence[:-1], sequence[1:])
    )

def calculate_ordering_amounts(meaned_order,dominant,other_):
    """Tally how consecutive motif pairs within coactive groups are ordered.

    Every adjacent pair inside each group of ``meaned_order`` is labelled with
    ``logic_machine_for_pair_catagorisation``. Labels 'ordered', 'repeat' and
    'reverse' count as ordered. 'misordered' counts as misordered. Anything
    else, including ``None``, counts as other.

    Returns ``(ordered, misordered, other)`` counts.
    """
    tallies = {'ordered': 0, 'misordered': 0, 'other': 0}
    for cluster in meaned_order:
        for left, right in zip(cluster, cluster[1:]):
            outcome = logic_machine_for_pair_catagorisation([left, right], dominant, other_)
            if outcome in ('ordered', 'repeat', 'reverse'):
                tallies['ordered'] += 1
            elif outcome == 'misordered':
                tallies['misordered'] += 1
            else:
                tallies['other'] += 1
    return tallies['ordered'], tallies['misordered'], tallies['other']

def all_motifs_proportion_coactive(multi_cluster_df):
    """Count the rows of each motif type (1-6) in the multi-event cluster frame.

    Despite the name, the returned values are counts, not proportions. The
    frame must have ``cluster_seq_type`` and ``new_cluster_group`` columns.

    Returns a list of six ints, indexed by motif id minus 1.
    """
    seq_type_column = multi_cluster_df['cluster_seq_type']
    return [
        len(multi_cluster_df[seq_type_column == motif_id].new_cluster_group)
        for motif_id in range(1, 7)
    ]

def motif_by_motif_ordering(meaned_order,real_order,dominant,other_):
    """Per-motif tallies of ordered / misordered / other adjacent pairs.

    For each motif id 1-6 that appears in ``real_order``, every adjacent pair
    in ``meaned_order`` that contains that motif is labelled with
    ``logic_machine_for_pair_catagorisation``. 'ordered', 'repeat' and
    'reverse' count as ordered; 'misordered' counts as misordered; anything
    else counts as other. Motifs absent from ``real_order`` get the string
    ``'nan'`` in all three lists.

    Returns three lists of length 6: ``(ordered, misordered, other)``.
    """
    ordered_counts = []
    misordered_counts = []
    other_counts = []

    for motif_type in range(1, 7):
        if motif_type not in real_order:
            ordered_counts.append('nan')
            misordered_counts.append('nan')
            other_counts.append('nan')
            continue

        tally = [0, 0, 0]  # ordered, misordered, other
        for cluster in meaned_order:
            for left, right in zip(cluster, cluster[1:]):
                pair = [left, right]
                if motif_type not in pair:
                    continue
                outcome = logic_machine_for_pair_catagorisation(pair, dominant, other_)
                if outcome in ('ordered', 'repeat', 'reverse'):
                    tally[0] += 1
                elif outcome == 'misordered':
                    tally[1] += 1
                else:
                    tally[2] += 1

        ordered_counts.append(tally[0])
        misordered_counts.append(tally[1])
        other_counts.append(tally[2])

    return ordered_counts, misordered_counts, other_counts

def coactive_rate(filtered_chunk_data):
    """Summarise coactive grouping from the ``coactive_cluster_group`` labels.

    Returns
    -------
    coactive_events : int
        Number of single events that share a coactive group with at least one
        other event, i.e. the events in groups of size > 1. This is the
        numerator the caller divides by ``total_events`` to get the proportion
        of events that are coactively paired.
    coactive_len_per_chunk : list of int
        The size of every multi-event group. Entries are grouped by size, with
        sizes in order of first appearance.
    total_events : int
        Total number of events across all groups (singletons included).
    """
    from collections import Counter

    # number of events in each group (groups in order of first appearance)
    events_per_group = Counter(filtered_chunk_data.coactive_cluster_group.values).values()
    # group size -> number of groups of that size (insertion-ordered)
    groups_per_size = Counter(events_per_group)

    coactive_events = 0
    total_events = 0
    coactive_len_per_chunk = []
    for size, n_groups in groups_per_size.items():
        total_events += size * n_groups
        if size > 1:
            # count events, not groups, so the numerator matches total_events' unit
            coactive_events += size * n_groups
            coactive_len_per_chunk += [size] * n_groups

    return coactive_events,coactive_len_per_chunk,total_events

def empty_chunk_vars():
    """Return a fresh dict mapping every per-chunk metric name to an empty list.

    Each call creates new list objects, so accumulators are never shared
    between sessions. The key spellings (including 'lenghts' and
    'coacitvely') are used verbatim by the processing loop. Do not "correct"
    them here alone.
    """
    ## set chunk vars (each name listed once; the old literal repeated two keys)
    chunk_var_names = (
        "chunk_reactivations",
        "chunk_mins",
        "chunk_motif_type_reactivations",
        "mean_spikes_per_event",
        "motif_by_motif_mean_spikes_per_event",
        "mean_units_per_event",
        "motif_by_motif_mean_units_per_event",
        "chunk_event_lengths",
        "motif_event_lenghts",
        "total_single_events_coacitvely_paired",
        "coactive_lenghts",
        "overall_total_coactive_or_single_cluster_events",
        "meaned_order_task_related_ordered",
        "meaned_order_task_related_misordered",
        "meaned_order_task_related_other",
        "fs_order_task_related_ordered",
        "fs_order_task_related_misordered",
        "fs_order_task_related_other",
        "normalised_task_related_total",
        "normalised_non_task_related_total",
        "all_motifs_total_coactive",
        "meaned_ordering_all_motifs_task_related_ordered",
        "meaned_ordering_all_motifs_task_related_misordered",
        "meaned_ordering_all_motifs_task_related_other",
        "fs_ordering_all_motifs_task_related_ordered",
        "fs_ordering_all_motifs_task_related_misordered",
        "fs_ordering_all_motifs_task_related_other",
    )
    chunk_vars = {name: [] for name in chunk_var_names}
    return chunk_vars

# replay processing loop

In [2]:
# Event-selection switches for the processing loop.
# sequential_filter takes precedence: True keeps only events classified as
# 'sequential'; False keeps every event.
sequential_filter = True
# Master switch for the sleep-state filters (turn off to keep all events).
sleep_filters_on = True
# Sleep-state filters, only used when sleep_filters_on is True; both may be True.
nrem_filter = True
rem_filter = True
# With sleep_filters_on, True overrides the two filters above and keeps only
# the background (events in neither REM nor NREM).
background_only = False


## Sanity check: describe the active filtering and derive the save label.
print('this filtering gives...')
if sequential_filter != True:
    print('- all events')
    type_var = 'all_events'
else:
    print(' - only sequential events')
    type_var = 'sequential'
save_var = type_var + '_no_sleep_selected'

if sleep_filters_on == True and background_only:
    print('and only those which are not in rem/nrem')
    save_var = type_var + '_OTHER_nonsleep'
elif sleep_filters_on == True:
    print('and only those which are in')
    if nrem_filter == True:
        print(' - nrem')
    if rem_filter == True:
        print(' - rem')
    if nrem_filter == True and rem_filter == True:
        save_var = type_var + '_NREM_and_REM_sleep'
    elif nrem_filter == True:
        save_var = type_var + '_NREM_sleep'
    elif rem_filter == True:
        save_var = type_var + '_REM_sleep'
        
        



this filtering gives...
 - only sequential events
and only those which are in
 - nrem
 - rem


In [None]:
# Input/output folders on the Windows network drive. Each raw string ends in
# two backslashes, so file names are appended by plain concatenation.
sleep_ppseq_path = r"Z:\projects\sequence_squad\ppseq_finalised_publication_data\expert\postsleep\\"
# NOTE(review): out_path is not used in the part of this cell visible here.
out_path = r"Z:\projects\sequence_squad\revision_data\emmett_revisions\sleep_wake_link_data\replay_to_behaviour\\"
# sessions ("mirs") to include in this run
useable_mirs  = ['136_1_3','136_1_4','149_1_1','178_1_6','178_1_7','178_1_8','178_2_1','178_2_2','178_2_3','268_1_2','269_1_2','269_1_4','270_1_6']
# alternative session list, kept for reference:
#['ap5r_1_1','ap5r_1_2','ap5r_1_3','seq006_1_1','seq006_1_10','seq006_1_11','seq006_1_2','seq006_1_3','seq006_1_4','seq006_1_5','seq006_1_6','seq006_1_7','seq006_1_8','seq006_1_9','seq007_1_1','seq007_1_2','seq007_1_3','seq007_1_4','seq008_1_3']
# load in sleep time points (columns read below: mir, approx_sleep_start, group)
sleep_time_point_df = pd.read_csv(sleep_ppseq_path + 'sleep_time_points.csv')
# map each session id to its approximate sleep start (a later duplicate row overwrites an earlier one)
sleep_start = {}
for index,value in enumerate(sleep_time_point_df.approx_sleep_start.values):
    mouse = sleep_time_point_df.mir.values[index]
    sleep_start[mouse] = value
# define mice/sessions in each group (arrays of session ids)
expert_mice = sleep_time_point_df[sleep_time_point_df.group == 'expert'].mir.values
hlesion_mice = sleep_time_point_df[sleep_time_point_df.group == 'h_lesion'].mir.values
learning_mice = sleep_time_point_df[sleep_time_point_df.group == 'learning'].mir.values


# per-session accumulators, filled in by find_useable_mouse_paths
var_dict = {'expert':[],'hlesion':[],'learning' :[],'mirs':[],'current_sleep_start':[], 'time_spans':[]}

#Load in seq order data (per-session task sequence order; read later via the mir, seq_order and dominant_task_seqs columns)
sequence_order_df = pd.read_csv(sleep_ppseq_path+"sequence_order.csv")
# get the analysis_output path of every usable session, plus its group flags, sleep start and time spans
current_mouse_path,var_dict = find_useable_mouse_paths(sleep_ppseq_path,useable_mirs,expert_mice,hlesion_mice,learning_mice,var_dict,sleep_start)
# loop across each mouse path:
for loop_index, path in enumerate(current_mouse_path):
    # create empty chunk vars dict
    chunk_vars = empty_chunk_vars()
    
    print(f"run index: {loop_index}, processing {var_dict['mirs'][loop_index]}")

    ## loop across all chunk files
    for chunk_index, file in enumerate(os.listdir(path)):
        if 'chunk' in file:
            print(file)
            path_ = path + '\\' + file + '\\'
            chunk_time = np.load(path_ + 'chunk_time_interval.npy')
            data = pd.read_csv(path_ + 'filtered_replay_clusters_df.csv')
            
            # filter based on the sequential/rem-nrem conditions set above
            filter_mask = make_filter_masks(data,sequential_filter,nrem_filter,rem_filter,sleep_filters_on,background_only)
            filtered_chunk_data = data[filter_mask].reset_index()
            
            # how many reactivations found
            reactivations_found = len(filtered_chunk_data)
            print(reactivations_found)
            
            ####################################### chunk rate per minute: (# this one depends on rem/nrem filter... )
            mins = determine_chunk_mins(chunk_time,sleep_filters_on,nrem_filter,rem_filter,background_only,path_)
            if mins > 0:
                chunk_vars['chunk_reactivations'] += [reactivations_found] 
                chunk_vars['chunk_mins'] += [mins]     
                
            ####################################### replay rate per motif type
            all_motif_type_reactivations = []
            all_motif_type_reactivations_min = []
            all_motif_type_relative_proportion = []
            for seq_type in range(1,7):
                motif_type_reactivations = [len(np.where(filtered_chunk_data.cluster_seq_type.values == seq_type)[0])][0]
                all_motif_type_reactivations += [motif_type_reactivations]
            chunk_vars['chunk_motif_type_reactivations'] += [all_motif_type_reactivations]
            
            ##################################### av. spikes involved
            chunk_vars['mean_spikes_per_event'] += [[len(item) for item in filtered_chunk_data.cluster_spike_times]]
            # per motif
            motif_by_motif_mean_spikes_per_event = []
            for motif_number in range(1,7):
                motif_data = filtered_chunk_data[filtered_chunk_data['cluster_seq_type'] == motif_number]
                motif_by_motif_mean_spikes_per_event += [[len(item) for item in motif_data.cluster_spike_times]]
            chunk_vars['motif_by_motif_mean_spikes_per_event'] += [motif_by_motif_mean_spikes_per_event]  
                    
            ########################################## average units involved 
            chunk_vars['mean_units_per_event'] += [[len(np.unique(ast.literal_eval(item))) for item in filtered_chunk_data.cluster_neurons]]
            motif_by_motif_mean_units_per_event = []
            for motif_number in range(1,7):
                motif_data = filtered_chunk_data[filtered_chunk_data['cluster_seq_type'] == motif_number]
                motif_by_motif_mean_units_per_event += [[len(np.unique(ast.literal_eval(item))) for item in motif_data.cluster_neurons]]
            chunk_vars['motif_by_motif_mean_units_per_event'] += [motif_by_motif_mean_units_per_event]
            
            ########################################### replay length overall 
            chunk_vars['chunk_event_lengths'] += [filtered_chunk_data.event_length.values]
            
            ########################################### replay length per motif 
            motif_event_lenghts = []
            for i in range(1,7):
                motif_event_lenghts += [filtered_chunk_data[filtered_chunk_data.cluster_seq_type == i].event_length.values]
            chunk_vars['motif_event_lenghts'] += [motif_event_lenghts]
            
            ########################################### coactive rate overall
            event_proximity_filter =  0.3 #s (how close events have to be to each other to be clustered together as coacitve 
            # refind the clusters
            filtered_chunk_data  = refind_cluster_events(filtered_chunk_data,event_proximity_filter)
            # how many single evtns coactivly paired? average coactive rate? proportion of global events coactive? 
            cocative_total,coactive_len_per_chunk,overall_total = coactive_rate(filtered_chunk_data)
            chunk_vars['total_single_events_coacitvely_paired'] += [cocative_total]
            chunk_vars['coactive_lenghts'] += [coactive_len_per_chunk]
            chunk_vars['overall_total_coactive_or_single_cluster_events'] += [overall_total]
                    
                    

            # ordering of coactive?
            if not chunk_vars['total_single_events_coacitvely_paired'][chunk_index] == 0:
                multi_cluster_df,meaned_order,fs_order = create_multicluster_dataframe(filtered_chunk_data)
            else:
                multi_cluster_df,meaned_order,fs_order = [],[],[]

            # pull out sequence order for current mouse
            seq_order= ast.literal_eval(sequence_order_df[sequence_order_df.mir == var_dict['mirs'][loop_index]].seq_order.values[0])
            num_dominant_seqs = int(sequence_order_df[sequence_order_df.mir == var_dict['mirs'][loop_index]].dominant_task_seqs)
            real_order = np.array(seq_order)+1

            #deal wih the fact that the way I order the sequence messes up the order a bit
            if not len(real_order) == num_dominant_seqs:
                dominant = list(real_order[0:num_dominant_seqs])
                other_ = list(real_order[num_dominant_seqs::])
            else:
                dominant = list(real_order)
                other_ = []
                
            # orderng amounts for mean ordering - this calculated for each pair in the chunk 
            ordered,misordered,other = calculate_ordering_amounts(meaned_order,dominant,other_)
            chunk_vars['meaned_order_task_related_ordered'] += [ordered]
            chunk_vars['meaned_order_task_related_misordered'] += [misordered]
            chunk_vars['meaned_order_task_related_other'] += [other]
            
            # orderng amounts for first spike ordering
            ordered,misordered,other = calculate_ordering_amounts(fs_order,dominant,other_)
            chunk_vars['fs_order_task_related_ordered'] += [ordered]
            chunk_vars['fs_order_task_related_misordered'] += [misordered]
            chunk_vars['fs_order_task_related_other'] += [other]
            ### motif by motif:
            # does one motif appear more in coactive?
            if not chunk_vars['total_single_events_coacitvely_paired'][chunk_index] == 0:
                all_motifs_total_coactive = all_motifs_proportion_coactive(multi_cluster_df)
                chunk_vars['all_motifs_total_coactive'] += [all_motifs_total_coactive]
            else:
                chunk_vars['all_motifs_total_coactive'] += [[0,0,0,0,0,0]]
            
            # does one motif appeaer more ordered? 
            # # this is only calculated for the task related motifs as the non task dont have an order - thought other catagory still exists for times it was task to non task or other way around 
            # meaned ordering 
            all_motifs_meaned_ordering_task_related_ordered,all_motifs_meaned_ordering_task_related_misordered,all_motifs_meaned_ordering_task_related_other = motif_by_motif_ordering(meaned_order,real_order,dominant,other_)
            chunk_vars['meaned_ordering_all_motifs_task_related_ordered'] += [all_motifs_meaned_ordering_task_related_ordered]
            chunk_vars['meaned_ordering_all_motifs_task_related_misordered'] += [all_motifs_meaned_ordering_task_related_misordered]
            chunk_vars['meaned_ordering_all_motifs_task_related_other'] += [all_motifs_meaned_ordering_task_related_other]
            
            # first spike ordering 
            all_motifs_fs_task_related_ordered,all_motifs_fs_task_related_misordered,all_motifs_fs_task_related_other = motif_by_motif_ordering(fs_order,real_order,dominant,other_)
            chunk_vars['fs_ordering_all_motifs_task_related_ordered'] += [all_motifs_fs_task_related_ordered]
            chunk_vars['fs_ordering_all_motifs_task_related_misordered'] += [all_motifs_fs_task_related_misordered]
            chunk_vars['fs_ordering_all_motifs_task_related_other'] += [all_motifs_fs_task_related_other]

            ########################################### task related vs other rate
            task_seqs = np.array(seq_order)+1
            # mask each condition
            mask = np.isin(filtered_chunk_data.cluster_seq_type.values, task_seqs)
            opposite_mask = ~mask
            task_related = filtered_chunk_data[mask]
            non_task_related = filtered_chunk_data[opposite_mask]

            #  task v nontask overallrate
            chunk_vars['normalised_task_related_total'] += [len(task_related)/len(task_seqs)]
            if len(task_seqs) == 6:
                chunk_vars['normalised_non_task_related_total'] += [0]
            else:
                chunk_vars['normalised_non_task_related_total'] += [len(non_task_related)/(6-len(task_seqs))]
            
            ### extra stuff to add in:
            
            # same but motif by motif
            
            # number of spikes task related 
            # number of units task related
            
            # # coative rate 
            # task_related_number = 0
            # non_task_related_number = 0
            # for coactive_ in meaned_order:
            #     for motif_item in coactive_:
            #         if motif_item in task_seqs:
            #             task_related_number += 1
            #         else:
            #             non_task_related_number += 1
            # # make it relative:
            # task_related_number = task_related_number/len(task_seqs) 
            # non_task_related_number = non_task_related_number/(6-len(task_seqs))
            # proportion_coacitve_event_that_are_task_related = task_related_number/(task_related_number+non_task_related_number)
            # chunk_vars['proportion_coacitve_event_that_are_task_related'] = proportion_coacitve_event_that_are_task_related
            
            # # motif coactive rate for task and non task 
            # # task
            # task_events = filtered_chunk_data[filtered_chunk_data.cluster_seq_type.isin(task_seqs)]
            # task_proportion_single_events_coacitvely_paired,task_av_coactive_len_per_chunk,task_proporiton_of_events_coactive = coactive_rate(task_related)
            # # non task:
            # non_task_events = filtered_chunk_data[~filtered_chunk_data.cluster_seq_type.isin(task_seqs)]
            # nontask_proportion_single_events_coacitvely_paired,nontask_av_coactive_len_per_chunk,nontask_proporiton_of_events_coactive = coactive_rate(non_task_related)

            # replay length
            #task v non task
            # motif by motif  
            # spikes involved
            
            # save out to newly made place
            
        ###now do averages for each chunk and save out to a new file
        ########## Calculate averages across chunks/ combine across chunks for each data variable
        out_vars = {}

        # 1 overall event rate 
        if sum(chunk_vars['chunk_reactivations']) == 0:
            out_vars['event_rpm'] = 0
        else:
            out_vars['event_rpm'] = sum(chunk_vars['chunk_reactivations'])/sum(chunk_vars['chunk_mins'])

        # 2 motif by motif event rate 
        out_vars['motif_event_rpm'] = np.sum(chunk_vars['chunk_motif_type_reactivations'],axis = 0)/sum(chunk_vars['chunk_mins'])

        # 3 motif by motif proportion of all events
        out_vars['motif_relative_event_proportion'] = np.sum(chunk_vars['chunk_motif_type_reactivations'],axis = 0)/sum(np.sum(chunk_vars['chunk_motif_type_reactivations'],axis = 0))

        # 4 spikes per replay event
        chunk_spikes_per_event = chunk_vars['mean_spikes_per_event']
        spikes_per_event = [item for sublist in chunk_spikes_per_event for item in sublist]
        out_vars['spikes_per_event'] = spikes_per_event

        # 5 motif by motif spikes per replay event
        motif_by_motif_spikes_per_event = []
        for seq in range(1,7):
            motif_spikes_per_event = []
            for chunk_ in chunk_vars['motif_by_motif_mean_spikes_per_event']:
                motif_spikes_per_event +=chunk_[seq-1]
            motif_by_motif_spikes_per_event += [motif_spikes_per_event]
        out_vars['motif_by_motif_spikes_per_event'] = motif_by_motif_spikes_per_event

        # 6 units per event
        chunk_units_per_event = chunk_vars['mean_units_per_event']
        units_per_event = [item for sublist in chunk_units_per_event for item in sublist]
        out_vars['units_per_event'] = units_per_event

        # 7 motif by motif units per replay event
        motif_by_motif_units_per_event = []
        for seq in range(1,7):
            motif_units_per_event = []
            for chunk_ in chunk_vars['motif_by_motif_mean_units_per_event']:
                motif_units_per_event +=chunk_[seq-1]
            motif_by_motif_units_per_event += [motif_units_per_event]
        out_vars['motif_by_motif_units_per_event'] = motif_by_motif_units_per_event

        # 8 event lengths
        chunk_event_lengths = chunk_vars['chunk_event_lengths']
        event_lengths = [item for sublist in chunk_event_lengths for item in sublist]
        out_vars['event_lengths'] = event_lengths

        # 9 motif by motif event lengths
        motif_by_motif_event_lengths = []
        for seq in range(1,7):
            motif_event_lengths = []
            for chunk_ in chunk_vars['motif_event_lenghts']:
                motif_event_lengths +=list(chunk_[seq-1])
            motif_by_motif_event_lengths += [motif_event_lengths]
        out_vars['motif_by_motif_event_lengths'] = motif_by_motif_event_lengths

        # 10 coactive rate (proportion of single events coacitvly paired)
        if sum(chunk_vars['total_single_events_coacitvely_paired']) == 0:
            out_vars['proportion_single_events_coactivly_paired'] = 0
        else:
            out_vars['proportion_single_events_coactivly_paired'] = sum(chunk_vars['total_single_events_coacitvely_paired'])/sum(chunk_vars['overall_total_coactive_or_single_cluster_events'])


        # 11 number of motifs in each coative group
        out_vars['coactive_group_lengths'] = [item for sublist in chunk_vars['coactive_lenghts']  for item in sublist]

        # 12 meaned order
        # ordered proporiton out of all 
        ordered = sum(chunk_vars['meaned_order_task_related_ordered'])
        misordered = sum(chunk_vars['meaned_order_task_related_misordered'])
        other = sum(chunk_vars['meaned_order_task_related_other'])
        if ordered+misordered+other == 0:
            out_vars['meaned_order_overall_ordered_prop'] = 0
            # ordered proporito out of taks related (ordered/misodered)
            out_vars['meaned_order_task_related_ordered_prop'] = 0
            # other proportion
            out_vars['meaned_order_overall_other_prop'] = 0
        else:
            out_vars['meaned_order_overall_ordered_prop'] = ordered/(ordered+misordered+other)
            # ordered proporito out of taks related (ordered/misodered)
            if ordered+misordered == 0:
                out_vars['meaned_order_task_related_ordered_prop'] = 0
            else:
                out_vars['meaned_order_task_related_ordered_prop'] = ordered/(ordered+misordered)
            # other proportion
            out_vars['meaned_order_overall_other_prop'] = other/(ordered+misordered+other)

        # 13 first spike order
        # ordered proporiton out of all 
        ordered = sum(chunk_vars['fs_order_task_related_ordered'])
        misordered = sum(chunk_vars['fs_order_task_related_misordered'])
        other = sum(chunk_vars['fs_order_task_related_other'])
        if ordered+misordered+other == 0:
            out_vars['fs_order_overall_ordered_prop'] = 0
            # ordered proporito out of taks related (ordered/misodered)
            out_vars['fs_order_task_related_ordered_prop'] = 0
            # other proportion
            out_vars['fs_order_overall_other_prop'] = 0
        else:
            out_vars['fs_order_overall_ordered_prop'] = ordered/(ordered+misordered+other)
            # ordered proporito out of taks related (ordered/misodered)
            if ordered+misordered == 0:
                out_vars['fs_order_task_related_ordered_prop'] = 0
            else:
                out_vars['fs_order_task_related_ordered_prop'] = ordered/(ordered+misordered)
            # other proportion
            out_vars['fs_order_overall_other_prop'] = other/(ordered+misordered+other)

        # 14 motif by motif proportion coactive, what proprotion of each motif is coactive
        out_vars['motif_proportion_coactive'] = np.sum(chunk_vars['all_motifs_total_coactive'],axis = 0)/np.sum(chunk_vars['chunk_motif_type_reactivations'],axis = 0)

        # 15 does one motif appeaer more ordered? Proportion of moitfs that were in ordered events for coactive task related motifs
        # this is only calculated for the task related motifs as the non task dont have an order - thought other catagory still exists for times it was task to non task or other way around 
        ordered_prop = []
        for seq in range(1,7):
            ordered = []
            misordered = []
            other  = []
            for chunk in chunk_vars['meaned_ordering_all_motifs_task_related_ordered']:
                ordered += [chunk[seq-1]]
            for chunk in chunk_vars['meaned_ordering_all_motifs_task_related_misordered']:
                misordered += [chunk[seq-1]]
            for chunk in chunk_vars['meaned_ordering_all_motifs_task_related_other']:
                other += [chunk[seq-1]]
            if not 'nan' in ordered and (sum(ordered)+sum(misordered)+sum(other)) > 0:
                ordered_prop += [sum(ordered)/(sum(ordered)+sum(misordered)+sum(other))]
            else:
                ordered_prop += ['nan']
        out_vars['motif_meaned_ordering_ordered_prop_out_of_all_task_related'] = ordered_prop

        # 16 same but for first spike ordering
        ordered_prop = []
        for seq in range(1,7):
            ordered = []
            misordered = []
            other  = []
            for chunk in chunk_vars['fs_ordering_all_motifs_task_related_ordered']:
                ordered += [chunk[seq-1]]
            for chunk in chunk_vars['fs_ordering_all_motifs_task_related_misordered']:
                misordered += [chunk[seq-1]]
            for chunk in chunk_vars['fs_ordering_all_motifs_task_related_other']:
                other += [chunk[seq-1]]
            if not 'nan' in ordered and (sum(ordered)+sum(misordered)+sum(other)) > 0:
                ordered_prop += [sum(ordered)/(sum(ordered)+sum(misordered)+sum(other))]
            else:
                ordered_prop += ['nan']
        out_vars['motif_fs_ordering_ordered_prop_out_of_all_task_related'] = ordered_prop
        
        if sum(chunk_vars['normalised_non_task_related_total']) == 0:
            out_vars['Ratio_task_related_to_non_task_related_events'] = sum(chunk_vars['normalised_task_related_total'])
        elif sum(chunk_vars['normalised_task_related_total']) == 0:
            out_vars['Ratio_task_related_to_non_task_related_events'] = 0
        else:
            out_vars['Ratio_task_related_to_non_task_related_events'] = sum(chunk_vars['normalised_task_related_total'])/sum(chunk_vars['normalised_non_task_related_total'])


        ####### SAVE OUT THE DATA 

        try: 
            int(var_dict['mirs'][loop_index].split('_')[0])
            current_save_path = out_path + 'EJT' + var_dict['mirs'][loop_index] + '\\replay\\' + save_var 
        except:
            current_save_path = out_path + var_dict['mirs'][loop_index] + '\\replay\\' + save_var 
        
        ## if the path doesnt exist, make a new dir 
        if not os.path.exists(current_save_path):
            os.makedirs(current_save_path)
            
        # convert all np arrays to list
        for key, value in out_vars.items():
            if isinstance(value, np.ndarray):
                out_vars[key] = value.tolist()

        ### save out the out_vars dict
        with open(current_save_path + '\\replay_data_variables.json', 'w') as file:
            json.dump(out_vars, file)
            
        # convert all np arrays to list
        for key, value in var_dict.items():
            if isinstance(value, np.ndarray):
                var_dict[key] = value.tolist()
            
        ### save out the var_dict 
        with open(current_save_path + '\\general_mouse_info.json', 'w') as file:
            json.dump(var_dict, file)
            
    print('done!')


In [None]:
# Read the sequence-order table stored alongside the sleep PPSeq outputs.
sequence_order_df = pd.read_csv(sleep_ppseq_path + "sequence_order.csv")

# Gather the analysis-output directory of every usable mouse and, as a side
# effect, fill var_dict with group flags (expert/hlesion/learning), the sleep
# start times and the PPSeq time spans for those mice.
current_mouse_path, var_dict = find_useable_mouse_paths(
    sleep_ppseq_path=sleep_ppseq_path,
    useable_mirs=useable_mirs,
    expert_mice=expert_mice,
    hlesion_mice=hlesion_mice,
    learning_mice=learning_mice,
    var_dict=var_dict,
    sleep_start=sleep_start,
)
# loop across each mouse path: