# import

In [2]:

import json
import os
import pandas as pd
import numpy as np

def assign_to_group(mouse,expert_mice,hlesion_mice,learning_mice,var_dict):
    """Append one-hot experimental-group flags for ``mouse`` to ``var_dict``.

    Groups are tested in the order expert, hlesion, learning; the first group
    whose member collection contains ``mouse`` gets a 1 and the other two get a
    0 (appended to ``var_dict['expert']``, ``['hlesion']``, ``['learning']``).
    If ``mouse`` belongs to no group, ``var_dict`` is left untouched.

    Returns:
        The same ``var_dict`` object (its lists are extended in place).
    """
    group_members = (
        ('expert', expert_mice),
        ('hlesion', hlesion_mice),
        ('learning', learning_mice),
    )
    matched = next((name for name, members in group_members if mouse in members), None)
    if matched is None:
        return var_dict
    for name, _ in group_members:
        var_dict[name] += [int(name == matched)]
    return var_dict

def get_time_span(dat_path,pp_file,mouse):
    """Return the ``time_span`` entry of a session's PPSeq training parameters.

    Reads ``<dat_path>/<pp_file>/trainingData/params_<mouse>.json``.

    Args:
        dat_path: folder that contains the session folders (a trailing
            separator is optional).
        pp_file: name of the session folder inside ``dat_path``.
        mouse: mouse/session id used in the params file name.

    Returns:
        Whatever JSON value is stored under ``'time_span'``.

    Raises:
        FileNotFoundError: if the params file does not exist.
        KeyError: if the params file has no ``'time_span'`` key.
    """
    # os.path.join instead of hard-coded backslashes so the path also resolves
    # off Windows (on Windows the resulting path is equivalent to the old one)
    params_path = os.path.join(dat_path, pp_file, 'trainingData', 'params_' + mouse + '.json')
    # JSON is UTF-8 by specification; don't depend on the platform's locale encoding
    with open(params_path, 'r', encoding='utf-8') as file:
        params = json.load(file)
    return params['time_span']

def find_useable_mouse_paths(sleep_ppseq_path,useable_mirs,expert_mice,hlesion_mice,learning_mice,var_dict,sleep_start):
    """Collect the analysis-output folder of every useable post-sleep session.

    Scans ``sleep_ppseq_path`` for session folders, skipping any entry whose
    name contains ``'sleep_time_points'``. The mouse/session id ("mir") is the
    first three ``_``-separated parts of the folder name. For each mir in
    ``useable_mirs`` it appends one entry per session to ``var_dict``:

    - ``expert`` / ``hlesion`` / ``learning``: one-hot group flags via
      ``assign_to_group`` (nothing is appended there if the mir is in no group),
    - ``current_sleep_start``: ``sleep_start[mir]``,
    - ``time_spans``: the ``time_span`` value from the session's params JSON,
    - ``mirs``: the mir itself.

    Returns:
        ``(current_mouse_path, var_dict)``. ``current_mouse_path`` holds one
        ``<session>/analysis_output/`` path (with trailing separator) per
        processed session, in the same order as ``var_dict['mirs']``.
    """
    current_mouse_path = []
    for run_index, pp_file in enumerate(os.listdir(sleep_ppseq_path)):
        if 'sleep_time_points' in pp_file:
            continue
        # current mouse
        mouse = '_'.join(pp_file.split('_')[0:3])
        if mouse not in useable_mirs:
            continue

        # print out progress
        print(f"run index: {run_index}, processing {mouse}")

        # assign to experimental group in var_dict
        var_dict = assign_to_group(mouse,expert_mice,hlesion_mice,learning_mice,var_dict)

        # load in sleep start time and time span. Wrap each in a list so every session
        # adds exactly one entry: `list += scalar` raises TypeError, and `+=` with a
        # string/list splices in its elements, misaligning these lists with 'mirs'.
        var_dict['current_sleep_start'] += [sleep_start[mouse]]
        var_dict['time_spans'] += [get_time_span(sleep_ppseq_path,pp_file,mouse)]

        # set path to processed files; the trailing '' keeps a trailing separator,
        # which callers rely on when appending file names by concatenation
        current_mouse_path += [os.path.join(sleep_ppseq_path, pp_file, 'analysis_output', '')]
        var_dict['mirs'] += [mouse]
    return current_mouse_path,var_dict


def make_filter_masks(data,sequential_filter,nrem_filter,rem_filter,sleep_filters_on,background_only):
    """Build a boolean row mask selecting replay events that pass the active filters.

    Args:
        data: events DataFrame with ``ordering_classification``, ``nrem_events``
            and ``rem_events`` columns.
        sequential_filter: keep only rows whose ``ordering_classification`` is
            ``'sequential'``; otherwise keep all orderings.
        nrem_filter, rem_filter: with ``sleep_filters_on``, keep events flagged
            (``== 1``) as NREM and/or REM. Both are OR-ed together: an event in
            any selected state is kept.
        sleep_filters_on: master switch; when False, sleep state is ignored.
        background_only: with ``sleep_filters_on``, overrides the two state
            filters and keeps only events in neither REM nor NREM.

    Returns:
        Boolean ``pd.Series`` aligned with ``data.index``.
    """
    n_events = len(data)

    if sequential_filter:
        sequential_condition = (data.ordering_classification == 'sequential').to_numpy()
    else:
        sequential_condition = np.ones(n_events, dtype=bool)

    if not sleep_filters_on:
        state_condition = np.ones(n_events, dtype=bool)
    elif background_only:
        # background = outside BOTH states. The previous version OR-ed
        # (rem == 0) with (nrem == 0), which let every NREM-only and
        # REM-only event through as "background".
        state_condition = ((data.rem_events == 0) & (data.nrem_events == 0)).to_numpy()
    else:
        # any selected state carries the event forward
        state_condition = np.zeros(n_events, dtype=bool)
        if nrem_filter:
            state_condition |= (data.nrem_events == 1).to_numpy()
        if rem_filter:
            state_condition |= (data.rem_events == 1).to_numpy()

    return pd.Series(sequential_condition & state_condition, index=data.index)

def determine_chunk_mins(chunk_time,sleep_filters_on,nrem_filter,rem_filter,background_only,path):
    """Return how many minutes of a chunk are eligible under the active filters.

    Args:
        chunk_time: the chunk's ``[start, end]`` interval; only
            ``np.diff(chunk_time)[0]`` is used. Presumably seconds (the result
            is divided by 60) — TODO confirm.
        sleep_filters_on: when False, the whole chunk duration counts.
        nrem_filter, rem_filter: with sleep filters on, which state durations
            to add up.
        background_only: with sleep filters on, use chunk duration minus all
            REM and NREM time instead.
        path: chunk folder path ending in a separator; must contain
            ``rem_state_times.npy`` and ``nrem_state_times.npy`` when
            ``sleep_filters_on`` is True.

    Returns:
        Eligible time divided by 60. It is 0 when sleep filters are on and
        neither state is selected (previously this raised UnboundLocalError).
    """
    chunk_duration = np.diff(chunk_time)[0]
    # if sleep_filters_on is false, use all chunk time
    if not sleep_filters_on:
        return chunk_duration / 60

    # load in state times; presumably (n, 2) arrays of [start, end] rows — TODO confirm
    rem_state_times = np.load(path + 'rem_state_times.npy')
    nrem_state_times = np.load(path + 'nrem_state_times.npy')
    tot_rem = np.diff(rem_state_times, axis=-1).sum() if len(rem_state_times) > 0 else 0
    tot_nrem = np.diff(nrem_state_times, axis=-1).sum() if len(nrem_state_times) > 0 else 0

    if background_only:
        # background: everything that is neither REM nor NREM
        mins = chunk_duration - (tot_rem + tot_nrem)
    else:
        # sum of the selected states (0 if none is selected)
        mins = 0
        if nrem_filter:
            mins += tot_nrem
        if rem_filter:
            mins += tot_rem
    # convert to mins
    return mins / 60

def empty_chunk_vars():
    """Return a fresh accumulator dict for per-chunk replay metrics.

    Every value is a brand-new empty list, so dicts from separate calls never
    share state. Further metrics (binned rates, coactive frequencies,
    task/non-task splits) were drafted but are not tracked yet.
    """
    keys = (
        "chunk_rpm",
        "chunk_motif_type_reactivations",
        "chunk_motif_type_reactivations_min",
        "chunk_motif_type_relative_proportion",
        "chunk_event_lengths",
        "motif_event_lenghts",  # (sic) callers use this exact spelling
    )
    return {key: [] for key in keys}

In [3]:
# load in sleep ppseq data - the replay data

# replay rate 
# replay length 
# motif rate
# decay rate? 


## then run this for a load of mice, make a new file for the plots


## For the oscillation analysis: check the LFP preprocessing file and make sure its time alignment makes sense.
## Then do some LFP processing and start the oscillation analysis.

# replay processing loop

In [4]:
# Replay filter settings.
# The sequential filter is applied first: True keeps only 'sequential' events, False keeps all events.
sequential_filter = True
## Master switch for the sleep-state filters below (turn it off to keep events from every state).
sleep_filters_on = True
# Sleep-state filters applied on top of the sequential filter; both can be True at the same time.
nrem_filter = True
rem_filter = True
# Set True (together with sleep_filters_on) to override the two filters above and keep only
# background events, i.e. those outside REM/NREM.
background_only = False


## Sanity check: report the active filters and derive the output sub-folder name (save_var).
print('this filtering gives...')
if sequential_filter:
    print(' - only sequential events')
    type_var = 'sequential'
else:
    print('- all events')
    type_var = 'all_events'
# default name, kept when no sleep-state filter applies
save_var = type_var + '_no_sleep_selected'

if sleep_filters_on and background_only:
    print('and only those which are not in rem/nrem')
    save_var = type_var + '_OTHER_nonsleep'
elif sleep_filters_on:
    print('and only those which are in')
    if nrem_filter:
        print(' - nrem')
    if rem_filter:
        print(' - rem')
    if nrem_filter and rem_filter:
        save_var = type_var + '_NREM_and_REM_sleep'
    elif nrem_filter:
        save_var = type_var + '_NREM_sleep'
    elif rem_filter:
        save_var = type_var + '_REM_sleep'
        
        



this filtering gives...
 - only sequential events
and only those which are in
 - nrem
 - rem


In [5]:
sleep_ppseq_path = r"Z:\projects\sequence_squad\organised_data\ppseq_data\finalised_output\striatum\paper_submission\post_sleep\\"
# output folder for this filter combination (save_var comes from the filter-config cell); not written to yet
out_path = r"Z:\projects\sequence_squad\revision_data\emmett_revisions\sleep_wake_link_data\behaviour_to_replay\processed_data\\" + save_var + '\\'

# mouse/session ids ("mirs") to process
useable_mirs = ['178_1_7','149_1_1']

# load in sleep time points
sleep_time_point_df = pd.read_csv(sleep_ppseq_path + 'sleep_time_points.csv')
# approximate sleep-onset time of each mir (a later row for the same mir wins)
sleep_start = dict(zip(sleep_time_point_df.mir.values, sleep_time_point_df.approx_sleep_start.values))
# define mice/sessions in each group
expert_mice = sleep_time_point_df[sleep_time_point_df.group == 'expert'].mir.values
hlesion_mice = sleep_time_point_df[sleep_time_point_df.group == 'h_lesion'].mir.values
learning_mice = sleep_time_point_df[sleep_time_point_df.group == 'learning'].mir.values

# per-session records, filled by find_useable_mouse_paths
var_dict = {'expert':[],'hlesion':[],'learning' :[],'mirs':[],'current_sleep_start':[], 'time_spans':[]}

# get all the relevant path names and some other data
current_mouse_path,var_dict = find_useable_mouse_paths(sleep_ppseq_path,useable_mirs,expert_mice,hlesion_mice,learning_mice,var_dict,sleep_start)
# loop across each mouse path:
for index, path in enumerate(current_mouse_path):
    # Keep `mouse` in step with `path` (paths and mirs are appended together by
    # find_useable_mouse_paths). Later cells look up this session's awake sequence
    # order by `mouse`. Previously it was left over from building sleep_start, so
    # it held the last mir in sleep_time_points.csv rather than this session.
    mouse = var_dict['mirs'][index]
    # create empty chunk vars dict
    chunk_vars = empty_chunk_vars()
    # NOTE: debugging break — only the first session is processed by the cells below
    break

run index: 3, processing 149_1_1
run index: 7, processing 178_1_7


In [24]:

## loop across all chunk files of the current session (`path` is set in the cell above)
for file in os.listdir(path):
    if 'chunk' in file:
        print(file)
        path_ = path + '\\' + file + '\\'
        chunk_time = np.load(path_ + 'chunk_time_interval.npy')
        data = pd.read_csv(path_ + 'filtered_replay_clusters_df.csv')

        # filter based on the sequential/rem-nrem conditions set above
        filter_mask = make_filter_masks(data,sequential_filter,nrem_filter,rem_filter,sleep_filters_on,background_only)
        filtered_chunk_data = data[filter_mask].reset_index()

        # how many reactivations found
        reactivations_found = len(filtered_chunk_data)
        print(reactivations_found)

        ####################################### chunk rate per minute (depends on the rem/nrem filter)
        mins = determine_chunk_mins(chunk_time,sleep_filters_on,nrem_filter,rem_filter,background_only,path_)
        if mins > 0:
            chunk_vars['chunk_rpm'] += [reactivations_found/mins]

        ####################################### replay rate per motif type (motif ids 1-6)
        seq_types = filtered_chunk_data.cluster_seq_type.values
        all_motif_type_reactivations = []
        all_motif_type_reactivations_min = []
        all_motif_type_relative_proportion = []
        for seq_type in range(1,7):
            motif_type_reactivations = int(np.sum(seq_types == seq_type))
            # NaN rather than ZeroDivisionError when the chunk has no eligible time / no events
            motif_type_reactivations_min = motif_type_reactivations / mins if mins > 0 else np.nan
            relative_motif_proportion = motif_type_reactivations / reactivations_found if reactivations_found > 0 else np.nan
            all_motif_type_reactivations += [motif_type_reactivations]
            all_motif_type_reactivations_min += [motif_type_reactivations_min]
            # store the proportion (the raw count was stored here before)
            all_motif_type_relative_proportion += [relative_motif_proportion]
        # append one entry per chunk (like chunk_rpm) instead of overwriting the previous chunk
        chunk_vars['chunk_motif_type_reactivations'] += [all_motif_type_reactivations]
        chunk_vars['chunk_motif_type_reactivations_min'] += [all_motif_type_reactivations_min]
        chunk_vars['chunk_motif_type_relative_proportion'] += [all_motif_type_relative_proportion]

        ########################################### replay length overall
        chunk_vars['chunk_event_lengths'] += [filtered_chunk_data.event_length.values]

        ########################################### replay length per motif
        motif_event_lenghts = [filtered_chunk_data[filtered_chunk_data.cluster_seq_type == i].event_length.values for i in range(1,7)]
        chunk_vars['motif_event_lenghts'] += [motif_event_lenghts]

        ########################################### coactive rate overall
        # TODO

        ########################################### task related vs other rate
        # load task related seqs (+1 presumably maps 0-based ids onto the 1-based
        # cluster_seq_type labels — TODO confirm)
        task_seqs = np.load(path_ + 'task_order_seqs.npy')+1
        # mask each condition
        mask = np.isin(filtered_chunk_data.cluster_seq_type.values, task_seqs)
        task_related = filtered_chunk_data[mask]
        non_task_related = filtered_chunk_data[~mask]

        # TODO: rate, replay length and coactive rate for task vs non-task events,
        # motif coactive rate, then save everything out to out_path

        # NOTE: debugging break — only the first chunk file is processed
        break




chunk1_8300to9300
577


In [11]:
def cluster_events(start_times, end_times, threshold):
    """Group events whose time windows come within ``threshold`` of each other.

    Events ``a`` and ``b`` are linked when ``start_a <= end_b + threshold`` and
    ``end_a >= start_b - threshold``. Clusters are the connected components of
    these links, so a chain of linked events forms a single cluster.

    Events are swept in start-time order. The result therefore does not depend
    on row order (rows may be sorted by motif type rather than time), and an
    event that bridges two groups merges them.

    Assumes ``threshold >= 0`` and ``end >= start`` for every event.

    Args:
        start_times, end_times: per-event start/end times (same length, indexable).
        threshold: maximum gap between two windows for them to be linked.

    Returns:
        List of clusters, each a list of ascending event indices. Clusters are
        ordered by their smallest index.
    """
    order = sorted(range(len(start_times)), key=lambda i: start_times[i])
    clusters = []
    current = []
    current_max_end = None
    for i in order:
        # with start-sorted events, i links to the open cluster iff it starts within
        # threshold of the latest end seen so far in that cluster
        if current and start_times[i] <= current_max_end + threshold:
            current.append(i)
            current_max_end = max(current_max_end, end_times[i])
        else:
            if current:
                clusters.append(current)
            current = [i]
            current_max_end = end_times[i]
    if current:
        clusters.append(current)
    clusters = [sorted(cluster) for cluster in clusters]
    clusters.sort(key=lambda cluster: cluster[0])
    return clusters

def relative_dict(input_dict):
    """Return a new dict with each value of *input_dict* divided by the sum of all values.

    Key order is preserved. An empty dict yields an empty dict; a non-empty dict
    whose values sum to zero raises ZeroDivisionError.
    """
    total = sum(input_dict.values())
    return {key: input_dict[key] / total for key in input_dict}

def refind_cluster_events(filtered_chunk_data,event_proximity_filter):
    """Discard the existing coactive grouping and recompute it from event timing.

    Each event's ``[first_spike_time, last_spike_time]`` window is passed to
    ``cluster_events`` with ``event_proximity_filter`` as the linking threshold
    (same units as the spike times). Each event's cluster number is written, as
    a float, to the ``coactive_cluster_group`` column, replacing any existing
    values.

    Returns:
        The same DataFrame object, modified in place.
    """
    groups = cluster_events(
        filtered_chunk_data.first_spike_time.values,
        filtered_chunk_data.last_spike_time.values,
        event_proximity_filter,
    )
    labels = np.zeros(len(filtered_chunk_data))
    for label, members in enumerate(groups):
        labels[members] = label
    filtered_chunk_data['coactive_cluster_group'] = labels
    return filtered_chunk_data

########## coactive rate
event_proximity_filter =  0.3 #s - how close events have to be to each other to be clustered together as coactive
# refind the clusters: overwrites the 'coactive_cluster_group' column of filtered_chunk_data from event timing
filtered_chunk_data  = refind_cluster_events(filtered_chunk_data,event_proximity_filter)

# TODO: how many coactive?

# TODO: average number in coactive?

# TODO: ordering of coactive?




In [None]:


# work out how many coactive events there are in this chunk:
from collections import Counter

# number of events in each coactive cluster, keyed by coactive_cluster_group label
cluster_sizes = Counter(filtered_chunk_data.coactive_cluster_group.values)
# histogram {cluster size: number of clusters of that size}, with keys in first-seen
# order; one pass instead of re-counting the whole column per cluster (was O(n * k))
current_coactive_freqs_chunk = dict(Counter(cluster_sizes.values()))

# one entry per truly coactive cluster (size > 1), holding that cluster's size
avs =[]
for item in current_coactive_freqs_chunk:
    if item > 1:
        avs += current_coactive_freqs_chunk[item] * [item]
# TODO: av_coactive_len_per_chunk += [np.mean(avs)] once that accumulator is defined





In [None]:


# normalise the cluster-size histogram so its values sum to 1
current_coactive_freqs_chunk = relative_dict(current_coactive_freqs_chunk)

coactive_freqs_keys = list(current_coactive_freqs_chunk.keys())
rel_coactive_freqs = list(current_coactive_freqs_chunk.values())
# append each cluster size's proportion to its running list.
# NOTE(review): coactive_freqs_chunk and the chunk_task_* / chunk_nontask_* lists below
# are not defined in any visible cell; initialise them before running this.
for index, item in enumerate(rel_coactive_freqs):
    num = int(coactive_freqs_keys[index])
    coactive_freqs_chunk.setdefault(num, []).append(item)


# split events by whether their motif is one of the task-related sequences
is_task_event = filtered_chunk_data.cluster_seq_type.isin(task_seqs)
task_events = filtered_chunk_data[is_task_event]
non_task_events = filtered_chunk_data[~is_task_event]

chunk_task_num_spikes.extend(task_events.num_spikes)
chunk_nontask_num_spikes.extend(non_task_events.num_spikes)
chunk_task_e_len.extend(task_events.event_length)
chunk_nontask_e_len.extend(non_task_events.event_length)


# 5 ##############################################################################

############################################## split into multi clusters and process
# imported here because ast.literal_eval is used below; the cell's other `import ast`
# comes after this point, so on a fresh kernel this raised NameError
import ast

# empty template with the expected columns; kept as-is if no coactive cluster has 2+ events
multi_cluster_df = pd.DataFrame({'cluster_seq_type':[],
    'num_spikes':[],
    'num_neurons':[],
    'first_spike_time':[],
    'event_length':[],
    'last_spike_time':[],
    'cluster_spike_times':[],
    'cluster_neurons':[],
    'spike_plotting_order':[],
    'coactive_cluster_group':[],
    'new_cluster_group':[],
    'cluster_order_first_spike_defined':[],
    'cluster_order_mean_weighted_spikes_defined':[],
    'pairs_mean_ordering':[],
    'catagories_mean_ordering':[],
    'pairs_fs_ordering':[],
    'catagories_fs_ordering':[],
    'real_sequence_order':[]})
meaned_order = []  # per multi-event cluster: its motif types ordered by mean spike time
fs_order = []      # per multi-event cluster: its motif types ordered by first spike time
event_times = []   # per multi-event cluster: first-spike times of its events (row order)
multi_cluster_parts = []
count = 0          # running id of multi-event clusters, stored as new_cluster_group
for group in filtered_chunk_data.coactive_cluster_group.unique():
    # .copy(): the column assignments below must go to an independent frame, not a
    # slice of filtered_chunk_data (avoids chained assignment / SettingWithCopyWarning)
    current_cluster = filtered_chunk_data[filtered_chunk_data.coactive_cluster_group == group].copy()
    if len(current_cluster) < 2:
        continue  # single events are not coactive

    event_types = list(current_cluster.cluster_seq_type.values)
    # cluster_spike_times is stored as the string repr of a list: mean spike time per event
    means = [np.mean(ast.literal_eval(events)) for events in current_cluster.cluster_spike_times]
    # first spike time per event
    fs_orders = list(current_cluster.first_spike_time.values)

    # order by mean time:
    meaned_order += [list(np.array(event_types)[np.argsort(means)])]
    # order by first spike:
    fs_order += [list(np.array(event_types)[np.argsort(fs_orders)])]

    event_times += [fs_orders]

    current_cluster['new_cluster_group'] = [count]*len(current_cluster)
    # argsort of argsort = rank of each event within its cluster
    current_cluster['cluster_order_first_spike_defined'] = list(np.argsort(np.argsort(fs_orders)))
    current_cluster['cluster_order_mean_weighted_spikes_defined'] = list(np.argsort(np.argsort(means)))

    multi_cluster_parts += [current_cluster]
    count += 1

if multi_cluster_parts:
    # Concatenate once (instead of once per cluster) and always renumber rows 0..n-1.
    # The category step indexes `pairs` positionally with this index; previously a
    # lone multi-event cluster kept filtered_chunk_data's index.
    multi_cluster_df = pd.concat(multi_cluster_parts, axis=0, ignore_index=True)

############################################## Load in seq order data

awake_PP_path = r"Z:\projects\sequence_squad\organised_data\ppseq_data\finalised_output\striatum\awake\\"

# find this session's awake PPSeq folder (first three '_'-separated parts of the name = mir)
for index_,M_I_R in enumerate(os.listdir(awake_PP_path)):
    if not M_I_R == 'not_suitable':
        mir = '_'.join(M_I_R.split('_')[0:3])
        if mir == mouse:
            c_path = awake_PP_path + M_I_R + r"\analysis_output\reordered_recolored\\"

sequence_order_df = pd.read_csv(awake_PP_path+"sequence_order.csv")

import ast
# select the current session's row once
mouse_seq_row = sequence_order_df[sequence_order_df.mir == mouse]
if mouse_seq_row.empty:
    raise ValueError(f"no entry for {mouse} in {awake_PP_path}sequence_order.csv")
# seq_order is stored as the string repr of a list
seq_order = ast.literal_eval(mouse_seq_row.seq_order.values[0])
# .iloc[0]: calling int() on a one-element Series is deprecated in pandas
num_dominant_seqs = int(mouse_seq_row.dominant_task_seqs.iloc[0])

############################################## calculate catagory breakdown
# Classify how the motifs inside each multi-event coactive cluster are ordered relative
# to this session's awake sequence order (seq_order). Only runs when multi_cluster_df
# has more than one coactive cluster.
# NOTE(review): catagorize_seqs, conactinate_nth_items and the accumulators
# chunk_summed_amounts / chunk_ordered_sum / chunk_coactive_total are not defined in
# any visible cell; they must come from elsewhere before this runs.

if len(multi_cluster_df.coactive_cluster_group.unique()) > 1:

    # +1 presumably maps 0-based seq ids onto the 1-based cluster_seq_type labels — TODO confirm
    real_order = list(np.array(seq_order)+1)

    # # mean ordering first (uses meaned_order: motif types sorted by mean spike time)
    if len(real_order) > 3: # 3 will always be ordered so exclude
        relative_amounts,amounts,pair_outcomes,pairs = catagorize_seqs(real_order,num_dominant_seqs,meaned_order)
        # presumably one total per category, in the order of the labels list below — confirm against catagorize_seqs
        summed_amounts = [sum(items) for items in conactinate_nth_items(amounts)]
    #     labels = ['ordered','reverse','repeat','misordered','other_to_task','task_to_other','other']
    #     fig, ax = plt.subplots()
    #     ax.bar(labels,summed_amounts)
    #     ax.set_title('catagory occurances (seqs ordered by mean spike time)')

    #     SaveFig('catagory occurances_1___chunk'+ str(index_+1) + '.png',chunk_path)

        all_pair_outcomes_todf = []
        all_pairs_todf = []
        for group in multi_cluster_df.new_cluster_group.unique():
            # indexes pairs / pair_outcomes positionally with multi_cluster_df's index, so this
            # assumes one entry per multi_cluster_df row, in row order, with a 0..n-1 index — TODO confirm
            group_pairs = np.array(pairs)[multi_cluster_df[multi_cluster_df.new_cluster_group == group].index.values]
            group_pair_outcomes = np.array(pair_outcomes)[multi_cluster_df[multi_cluster_df.new_cluster_group == group].index.values]
            all_pairs = []
            all_pair_outcomes = []
            # drops the group's last entry — presumably because k events form k-1 consecutive pairs; confirm
            for index,pair_ in enumerate(group_pairs[0:-1]):
                all_pairs += [pair_]
                all_pair_outcomes += [group_pair_outcomes[index]]

            # every row of the group gets the group's full pair / outcome lists
            all_pair_outcomes_todf  += [all_pair_outcomes] * len(multi_cluster_df[multi_cluster_df.new_cluster_group == group])
            all_pairs_todf += [all_pairs] * len(multi_cluster_df[multi_cluster_df.new_cluster_group == group])

        multi_cluster_df['pairs_mean_ordering'] = all_pairs_todf
        multi_cluster_df['catagories_mean_ordering'] = all_pair_outcomes_todf

        multi_cluster_df['real_sequence_order'] = [real_order]*len(multi_cluster_df)

        # category totals normalised to proportions for this chunk
        chunk_summed_amounts += [list(np.array(summed_amounts)/sum(summed_amounts))]

        # first 3 / first 4 category totals (per the labels list above: ordered, reverse,
        # repeat / plus misordered) — confirm against catagorize_seqs
        chunk_ordered_sum += sum(summed_amounts[0:3])
        chunk_coactive_total += sum(summed_amounts[0:4])
    else:
        # 3 or fewer sequences in the awake order: skipped (see comment on the if above)
        print('only 3 seqs')

    
    
#                             print(chunk_summed_amounts)
    

In [7]:
# display the filtered (and re-clustered) replay events of the current chunk as cell output
filtered_chunk_data

Unnamed: 0.3,index,Unnamed: 0.2,Unnamed: 0.1,Unnamed: 0,cluster_seq_type,num_spikes,num_neurons,first_spike_time,event_length,last_spike_time,cluster_spike_times,cluster_neurons,spike_plotting_order,coactive_cluster_group,ordering_classification,rem_events,nrem_events
0,9,9,9,9,1,7,78.9772,78.9622,0.0150,78.9772,"[78.9769, 78.9622, 78.9755, 78.9757, 78.977, 7...","[79.0, 149.0, 163.0, 163.0, 206.0, 208.0, 208.0]",[147. 129. 72. 72. 145. 148. 148.],8.0,sequential,0,1
1,10,10,10,10,1,15,93.5430,93.4474,0.0956,93.5430,"[93.4474, 93.4562, 93.4626, 93.4476, 93.4564, ...","[40.0, 40.0, 40.0, 42.0, 42.0, 42.0, 48.0, 48....",[140. 140. 140. 143. 143. 143. 142. 142. 115. ...,9.0,sequential,0,1
2,17,17,17,17,1,7,114.2586,114.2036,0.0550,114.2586,"[114.2374, 114.2375, 114.2511, 114.2512, 114.2...","[6.0, 6.0, 6.0, 6.0, 6.0, 206.0, 210.0]",[138. 138. 138. 138. 138. 145. 146.],15.0,sequential,0,1
3,18,18,18,18,1,11,115.0579,115.0205,0.0374,115.0579,"[115.0206, 115.0207, 115.0207, 115.0208, 115.0...","[6.0, 6.0, 6.0, 6.0, 6.0, 7.0, 7.0, 149.0, 149...",[138. 138. 138. 138. 138. 137. 137. 129. 129. ...,16.0,sequential,0,1
4,21,21,21,21,1,5,119.7683,119.7589,0.0094,119.7683,"[119.7589, 119.768, 119.7682, 119.7683, 119.7679]","[21.0, 40.0, 42.0, 45.0, 48.0]",[116. 140. 143. 139. 142.],18.0,sequential,0,1
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
572,1911,1911,1911,1911,6,9,962.7765,962.7655,0.0110,962.7765,"[962.7655, 962.7667, 962.7715, 962.7764, 962.7...","[15.0, 60.0, 60.0, 60.0, 61.0, 61.0, 61.0, 148...",[ 96. 105. 105. 105. 106. 106. 106. 182. 182.],1078.0,sequential,0,1
573,1913,1913,1913,1913,6,9,968.9051,968.8036,0.1015,968.9051,"[968.8868, 968.8956, 968.8036, 968.8827, 968.8...","[15.0, 15.0, 101.0, 170.0, 170.0, 170.0, 170.0...",[ 96. 96. 117. 97. 97. 97. 97. 242. 242.],978.0,sequential,0,1
574,1914,1914,1914,1914,6,5,969.7686,969.7467,0.0219,969.7686,"[969.7686, 969.7686, 969.7583, 969.7526, 969.7...","[68.0, 70.0, 200.0, 204.0, 212.0]",[100. 99. 101. 248. 111.],980.0,sequential,0,1
575,1916,1916,1916,1916,6,20,982.2133,982.1402,0.0731,982.2133,"[982.2066, 982.2067, 982.2069, 982.1887, 982.1...","[8.0, 10.0, 10.0, 62.0, 68.0, 70.0, 82.0, 82.0...",[212. 213. 213. 175. 100. 99. 108. 108. 108. ...,128.0,sequential,0,1


In [None]:







                        
                        #2################################

#                             current_sleep_start = sleep_start[mouse] - 400
                            chunk_number = int(file.split('_')[0][-1])
                            start_offset = ([0]+list(np.cumsum(np.diff(time_spans))))[chunk_number-1]


                            # take away cumulative chunk offset - this gives time in terms of chunk
                            f_spike_times = filtered_chunk_data.first_spike_time.values - start_offset
                            # add on ephys time that chunk started - so its in ephys timestamps 
                            f_spike_times = f_spike_times + chunk_time[0]

                            # now make relative to sleep start time
                            f_spike_times_relative_to_so = f_spike_times - current_sleep_start 
                            # do the same but for rem and nrem start

                            # filter out anything that happened before sleep onset
                            f_spike_times_relative_to_so = f_spike_times_relative_to_so[f_spike_times_relative_to_so > 0]

                            ## calculate rate over time:
                            time_data = pd.Series(f_spike_times_relative_to_so)
                            if len(time_data) > 0:
#                                 # Calculate the number of bins required # 5 minute bins
#                                 num_bins = int((time_data.max() - time_data.min()) // 40 + 1)
#                                 # Create bins and count the occurrences in each bin
#                                 chunk_event_rate, chunk_relative_time_bins = np.histogram(time_data, bins=num_bins)
#                                 #remove extra final bin and convert to mins
#                                 chunk_relative_time_bins = chunk_relative_time_bins[0:-1]/60

                                # Calculate the number of bins required # 20s bins
                            #     num_bins = int((time_data.max() - time_data.min()) // 40 + 1)
                                if time_data.max() - time_data.min() > 19:
                                    num_bins = int((time_data.max() - time_data.min())//20)
                                    # Create bins and count the occurrences in each bin
                                    chunk_event_rate, chunk_relative_time_bins = np.histogram(time_data, bins=num_bins)
                                    #remove extra final bin and convert to mins
                                    chunk_relative_time_bins = chunk_relative_time_bins[0:-1]/60


                                    chunk_binned_rate += [list((chunk_event_rate*3).astype(float))] # *3 because its per 20s so we want it per minute )
                                    chunk_bins_relative_so += [list(chunk_relative_time_bins.astype(float))]

                        
                        #3########################################################

                        chunk_event_lens += list(filtered_chunk_data.event_length.values)

                        #4 ################################################# coactive stuff -300ms = coactive
                        event_proximity_filter =  0.3 #s (how close events have to be to each other to be clustered together as coacitve 

                        task_seqs = np.load(current_data_path + 'task_order_seqs.npy')+1
            
                        for motif_type in filtered_chunk_data.cluster_seq_type:
                            if motif_type in task_seqs:
                                task_related += 1
                            else:
                                non_task_related += 1
        
                        total_events += len(filtered_chunk_data.cluster_seq_type)

                        # normalise by number of each type: 
#                         if (6-len(task_seqs)) == 0:
#                             chunk_total_nontask_task_related_events += [[non_task_related,(task_related/len(task_seqs))]]
#                         else:
#                             chunk_total_nontask_task_related_events += [[non_task_related/(6-len(task_seqs)),(task_related/len(task_seqs))]]

                        chunk_mid_time_post_onset += [((sum(chunk_time)/2)-current_sleep_start)]

                        ### ignore the origonal clusterg rosp and remake them: 
                        start_times = filtered_chunk_data.first_spike_time.values
                        end_times = filtered_chunk_data.last_spike_time.values

                        clustered_events = cluster_events(start_times, end_times,event_proximity_filter)

                        cluster_group = np.zeros(len(filtered_chunk_data))
                        for index,cluster in enumerate(clustered_events):
                            for item in cluster:
                                cluster_group[item] = int(index)
                        filtered_chunk_data['coactive_cluster_group'] = cluster_group

                        # work out how mnay coacitve in chunk: 
                        current_coactive_freqs_chunk = {}
                        for cluster in filtered_chunk_data.coactive_cluster_group.unique():
                            num = list(filtered_chunk_data.coactive_cluster_group.values).count(cluster)
                            if num in current_coactive_freqs_chunk:
                                current_coactive_freqs_chunk[num] += 1
                            else:
                                current_coactive_freqs_chunk[num] = 1

                        avs =[]
                        for item in current_coactive_freqs_chunk:
                            avs += current_coactive_freqs_chunk[item] * [item]
                        av_coactive_len_per_chunk += [np.mean(avs)]
                        if mouse in expert_mice:
                            chunk_expert += [1]
                        elif mouse in hlesion_mice:
                            chunk_expert += [2]
                        elif mouse in learning_mice:
                            chunk_expert += [3]


                        # make it relative:
                        current_coactive_freqs_chunk = relative_dict(current_coactive_freqs_chunk)

                        coactive_freqs_keys = list(current_coactive_freqs_chunk.keys())
                        rel_coactive_freqs = list(current_coactive_freqs_chunk.values())
                        for index,item in enumerate(rel_coactive_freqs):
                            num = int(coactive_freqs_keys[index])
                            if num in coactive_freqs_chunk:
                                coactive_freqs_chunk[num] += [item]
                            else:
                                coactive_freqs_chunk[num] = [item]


                        task_events = filtered_chunk_data[filtered_chunk_data.cluster_seq_type.isin(task_seqs)]
                        non_task_events = filtered_chunk_data[~filtered_chunk_data.cluster_seq_type.isin(task_seqs)]

                        chunk_task_num_spikes+=list(task_events.num_spikes)
                        chunk_nontask_num_spikes+=list(non_task_events.num_spikes)
                        chunk_task_e_len+=list(task_events.event_length)
                        chunk_nontask_e_len+=list(non_task_events.event_length)


                        # 5 ##############################################################################

                        ############################################## split into multi clusters and process

                        multi_cluster_df = pd.DataFrame({'cluster_seq_type':[],
                         'num_spikes':[],
                         'num_neurons':[],
                         'first_spike_time':[],
                         'event_length':[],
                         'last_spike_time':[],
                         'cluster_spike_times':[],
                         'cluster_neurons':[],
                         'spike_plotting_order':[],
                         'coactive_cluster_group':[],
                         'new_cluster_group':[],
                         'cluster_order_first_spike_defined':[],
                         'cluster_order_mean_weighted_spikes_defined':[],
                         'pairs_mean_ordering':[],
                         'catagories_mean_ordering':[],
                         'pairs_fs_ordering':[],
                         'catagories_fs_ordering':[],
                         'real_sequence_order':[]})
                        meaned_order = []
                        fs_order = []
                        event_times = []
                        multi_cluster_df
                        count = 0
                        for i,group in enumerate(filtered_chunk_data.coactive_cluster_group.unique()):
                            group_mask = filtered_chunk_data.coactive_cluster_group == group
                            current_cluster = filtered_chunk_data[group_mask]
                            if len(current_cluster) > 1:
                                means = []
                                event_types = []
                                fs_orders = []
                                for index,events in enumerate(current_cluster.cluster_spike_times):
                                    event_types += [current_cluster.cluster_seq_type.values[index]]
                                    # calculate event order based on spike time weighted mean
                                    means += [np.mean(ast.literal_eval(events))]
                                    # calculate order based on first spike time:
                                    fs_orders += [current_cluster.first_spike_time.values[index]]

                                # order by mean time:    
                                meaned_order += [list(np.array(event_types)[np.argsort(means)])]
                                # order by first spike:
                                fs_order += [list(np.array(event_types)[np.argsort(fs_orders)])]

                                event_times += [fs_orders]

                                current_cluster['new_cluster_group'] =  [count]*len(current_cluster)
                                current_cluster['cluster_order_first_spike_defined'] =  list(np.argsort(np.argsort(fs_orders)))
                                current_cluster['cluster_order_mean_weighted_spikes_defined'] =  list(np.argsort(np.argsort(means)))

                                if count == 0:
                                    multi_cluster_df = current_cluster.copy()
                                else:
                                    # Concatenate the DataFrames vertically (row-wise)
                                    multi_cluster_df = pd.concat([multi_cluster_df, current_cluster], axis=0)
                                    # Reset the index if needed
                                    multi_cluster_df = multi_cluster_df.reset_index(drop=True)

                                count += 1

                        ############################################## Load in seq order data 

                        awake_PP_path = r"Z:\projects\sequence_squad\organised_data\ppseq_data\finalised_output\striatum\awake\\"

                        for index_,M_I_R in enumerate(os.listdir(awake_PP_path)):
                            if not M_I_R == 'not_suitable':
                                mir = '_'.join(M_I_R.split('_')[0:3])
                                if mir == mouse:
                                    c_path = awake_PP_path + M_I_R + r"\analysis_output\reordered_recolored\\" 

                        sequence_order_df = pd.read_csv(awake_PP_path+"sequence_order.csv")

                        import ast
                        seq_order= ast.literal_eval(sequence_order_df[sequence_order_df.mir == mouse].seq_order.values[0])
                        num_dominant_seqs = int(sequence_order_df[sequence_order_df.mir == mouse].dominant_task_seqs)

                        ############################################## calculate catagory breakdown

                        if len(multi_cluster_df.coactive_cluster_group.unique()) > 1:

                            real_order = list(np.array(seq_order)+1)

                            # # mean ordering first : 
                            if len(real_order) > 3: # 3 will always be ordered so exclude
                                relative_amounts,amounts,pair_outcomes,pairs = catagorize_seqs(real_order,num_dominant_seqs,meaned_order)
                                summed_amounts = [sum(items) for items in conactinate_nth_items(amounts)]
                            #     labels = ['ordered','reverse','repeat','misordered','other_to_task','task_to_other','other']
                            #     fig, ax = plt.subplots()
                            #     ax.bar(labels,summed_amounts)
                            #     ax.set_title('catagory occurances (seqs ordered by mean spike time)')

                            #     SaveFig('catagory occurances_1___chunk'+ str(index_+1) + '.png',chunk_path)

                                all_pair_outcomes_todf = []
                                all_pairs_todf = []
                                for group in multi_cluster_df.new_cluster_group.unique():
                                    group_pairs = np.array(pairs)[multi_cluster_df[multi_cluster_df.new_cluster_group == group].index.values]
                                    group_pair_outcomes = np.array(pair_outcomes)[multi_cluster_df[multi_cluster_df.new_cluster_group == group].index.values]
                                    all_pairs = []
                                    all_pair_outcomes = []
                                    for index,pair_ in enumerate(group_pairs[0:-1]):
                                        all_pairs += [pair_]
                                        all_pair_outcomes += [group_pair_outcomes[index]]

                                    all_pair_outcomes_todf  += [all_pair_outcomes] * len(multi_cluster_df[multi_cluster_df.new_cluster_group == group])
                                    all_pairs_todf += [all_pairs] * len(multi_cluster_df[multi_cluster_df.new_cluster_group == group])

                                multi_cluster_df['pairs_mean_ordering'] = all_pairs_todf
                                multi_cluster_df['catagories_mean_ordering'] = all_pair_outcomes_todf

                                multi_cluster_df['real_sequence_order'] = [real_order]*len(multi_cluster_df)

                                chunk_summed_amounts += [list(np.array(summed_amounts)/sum(summed_amounts))]

                                chunk_ordered_sum += sum(summed_amounts[0:3])
                                chunk_coactive_total += sum(summed_amounts[0:4])
                            else:
                                print('only 3 seqs')

                            
                            
#                             print(chunk_summed_amounts)
                            
        
                # outside of chunk loop ################################################
                
                # changed how i do this, now task freq is worke dout by adding up instances across all chunks and lookig at the proportion rather than averageing across chunks 
                if (6-len(task_seqs)) == 0:
                    chunk_total_nontask_task_related_events += [[non_task_related,(task_related/len(task_seqs))]]
                else:
                    chunk_total_nontask_task_related_events += [[non_task_related/(6-len(task_seqs)),(task_related/len(task_seqs))]]      

                ### add to animal vars
                #1
                reactivations_per_min += [np.mean(chunk_rpm)]
                if np.mean(chunk_rpm) < 3:
                    print('!!!!!')
                #2
                event_rate_binned +=[chunk_binned_rate]
                er_bins_relative_to_so +=[chunk_bins_relative_so]
                #3
                event_lens += [chunk_event_lens]


                #4 #########    
                relative = []
                totals = [sum(item) for item in chunk_total_nontask_task_related_events]
                for i,item in enumerate(chunk_total_nontask_task_related_events):
                    relative += [list(np.array(item)/totals[i])]

                all_total_events += [total_events]

                num_task_order_seqs = len(np.load(current_data_path+ 'task_order_seqs.npy')+1)

                rel_task_nontask += [[np.mean(conactinate_nth_items(relative)[1]),np.mean(conactinate_nth_items(relative)[0])]]

                chunks_task_nontask += conactinate_nth_items(relative)[1]

                for item in coactive_freqs_chunk:
                    if mouse in expert_mice:
                        if item in e_coactive_freqs_counts:
                            e_coactive_freqs_counts[item] += [np.mean(coactive_freqs_chunk[item])]
                        else:
                            e_coactive_freqs_counts[item] = [np.mean(coactive_freqs_chunk[item])]
                    elif mouse in hlesion_mice:
                        if item in hl_coactive_freqs_counts:
                            hl_coactive_freqs_counts[item] += [np.mean(coactive_freqs_chunk[item])]
                        else:
                            hl_coactive_freqs_counts[item] = [np.mean(coactive_freqs_chunk[item])]
                    elif mouse in learning_mice:
                        if item in l_coactive_freqs_counts:
                            l_coactive_freqs_counts[item] += [np.mean(coactive_freqs_chunk[item])]
                        else:
                            l_coactive_freqs_counts[item] = [np.mean(coactive_freqs_chunk[item])]



                task_nontask_num_spikes+= [[np.mean(chunk_task_num_spikes),np.mean(chunk_nontask_num_spikes)]]
                task_nontask_e_len+= [[np.mean(chunk_task_e_len),np.mean(chunk_nontask_e_len)]]

                #5 #############

                if len(chunk_summed_amounts) > 0:
                    c_summed_amounts = []
                    for item in conactinate_nth_items(chunk_summed_amounts):
                        c_summed_amounts +=[np.mean(item)]
                    mouse_summed_amounts += [c_summed_amounts]
                else:
                    mouse_summed_amounts += [[]]
                    
                    
                ordered_sum += [chunk_ordered_sum]
                ordered_misordered_total += [chunk_coactive_total]


In [None]:
# Locate the PPSeq sleep output folder for the current recording and collect
# the paths of its chunk sub-folders into `chunk_paths`.
# NOTE(review): `mir` and `sleep_ppseq_path` are not defined in this cell;
# `mir` appears to be left over from the last iteration of the awake-folder
# loop in the previous cell, so this cell depends on notebook state.

# The sleep folders are matched without the "EJT" tag (presumably they are
# named without it — TODO confirm). Strip exactly that tag: the previous
# `mir.split('T')[-1]` cut at the *last* 'T' anywhere in the ID.
if 'EJT' in mir:
    c_mir = mir.partition('EJT')[2]
else:
    c_mir = mir

full_sleep_path = None
matching_sleep_files = []
# Sort so the choice below does not depend on arbitrary os.listdir order.
for ppsleep_file in sorted(os.listdir(sleep_ppseq_path)):
    if c_mir in ppsleep_file:
        matching_sleep_files.append(ppsleep_file)
        full_sleep_path = os.path.join(sleep_ppseq_path, ppsleep_file, 'analysis_output')

if full_sleep_path is None:
    # FileNotFoundError subclasses Exception, so existing handlers still catch it.
    raise FileNotFoundError(f"no sleep file found for {mir}")

if len(matching_sleep_files) > 1:
    # Substring matching can hit several folders; the last one in sorted order
    # is used (as the original loop kept the last match), so make it visible.
    print(f"warning: {len(matching_sleep_files)} sleep files match {c_mir}: "
          f"{matching_sleep_files}; using {matching_sleep_files[-1]}")
print(f"sleep file found: {full_sleep_path}")

# Every entry whose name contains 'chunk', sorted lexically so the chunk order
# is reproducible across runs/filesystems.
chunk_paths = [
    os.path.join(full_sleep_path, name)
    for name in sorted(os.listdir(full_sleep_path))
    if 'chunk' in name
]