# import + functions

import os

In [31]:
import os
import pandas as pd
import numpy as np
from matplotlib import pyplot as plt  
import json

def make_transition_df(df):
    transitions = []
    in_in_Latency = []
    for index in range(len(df.Port.values)):
        if index < len(df.Port.values)-1:
            transit = int(str(df.Port.values[index]) + str(df.Port.values[index+1]))
            transitions += [transit]
            in_in_diff = df.PokeIn_Time.values[index+1] - df.PokeIn_Time.values[index]
            in_in_Latency += [in_in_diff]
        
 
    transit_df = pd.DataFrame({'Trial_id': df.Trial_id.values[0:-1], 
                                'Transition_type': transitions,
                                'in_in_Latency': in_in_Latency,
                                '2s_Time_Filter_in_in': list((np.array(in_in_Latency) <= 2).astype(int))})
    return transit_df

def get_mean_port_port_transition_latencies(day1_path):
    raw_poke_path = os.path.join(day1_path, np.array(os.listdir(day1_path))[['.csv' in file for file in os.listdir(day1_path)]][0])
    raw_pokes = pd.read_csv(raw_poke_path)
    raw_transits = make_transition_df(raw_pokes)
    mean_port_port_transition_latencies = np.mean(raw_transits['in_in_Latency'].values)
    return mean_port_port_transition_latencies

# main

In [32]:
# pull in the data and calculate deltas
data_path = r"Z:\projects\sequence_squad\revision_data\emmett_revisions\sleep_wake_link_data\replay_to_behaviour\\"

replay_type = 'sequential_NREM_and_REM_sleep'


In [33]:
# create dfs to hold the data
predictor_matrix_general_df = pd.DataFrame()
predictor_matrix_motif_df = pd.DataFrame()
outcome_matrix_general_df = pd.DataFrame()
outcome_matrix_motif_df = pd.DataFrame()

for mir in os.listdir(data_path):
    Tracking = False
    Pokes1 = False
    Pokes2 = False
    
    current_path = os.path.join(data_path, mir)
    # only bother loading any of the data if there is replay data saved out. 
    if 'replay' in os.listdir(current_path):#
        print('-------------------------------')
        print(f"Processing {mir}")
        replay_path = os.path.join(current_path, 'replay',replay_type)
        # load replay json 
        json_path = replay_path + r'\replay_data_variables.json'
        with open(json_path, 'r') as file:
            replay_data = json.load(file)
            print(f"Loaded replay data for {mir} from {json_path} ✓ ")
            
        # load in beahvioural data
        for file in os.listdir(current_path):
            if 'day2' in file:
                day2_path = os.path.join(current_path, file)
                # i forgot to do this in the preprocessing, so i added it in here 
                mean_port_port_transition_latencies_2 = get_mean_port_port_transition_latencies(day2_path)
                if 'processed' in os.listdir(day2_path):
                    processed_poke_data_2 = pd.read_csv(os.path.join(day2_path, 'processed', 'processed_poke_data_2.csv'))
                    processed_poke_data_2['transition_times'] = mean_port_port_transition_latencies_2
                    print(f"Loaded poke data 2 for {day2_path} ✓ ")
                    Pokes2 = True
                    try:
                        processed_tracking_data_2 = pd.read_csv(os.path.join(day2_path, 'processed', 'processed_tracking_data.csv'))
                        Tracking = True
                        print(f"Tracking data 2 found for {day2_path} ✓ ")
                    except:
                        print(f"Tracking data not found for {day2_path} ✗")
                        Tracking = False
        for file in os.listdir(current_path):
            if 'day1' in file:
                day1_path = os.path.join(current_path, file)
                # i forgot to do this in the preprocessing, so i added it in here 
                mean_port_port_transition_latencies_1 = get_mean_port_port_transition_latencies(day1_path)
                if 'processed' in os.listdir(day1_path):
                    processed_poke_data_1 = pd.read_csv(os.path.join(day1_path, 'processed', 'processed_poke_data_1.csv'))
                    processed_poke_data_1['transition_times'] = mean_port_port_transition_latencies_1
                    Pokes1 = True
                    print(f"Loaded poke data 1 for {day1_path} ✓ ")
                    try:
                        processed_tracking_data_1 = pd.read_csv(os.path.join(day1_path, 'processed', 'processed_tracking_data.csv'))
                        Tracking = True
                        print(f"Tracking data 1 found for {day1_path} ✓ ")
                    except:
                        print(f"Tracking data not found for {day1_path} ✗")
                        Tracking = False
                        
        if Pokes1 * Pokes2 == True:
            #### Add replay data (predictor) to dataframes

            # add all the data to dataframes
            predictor_general_dict = {}
            predictor_general_dict['animal_id'] = mir
            for key, val in replay_data.items():
                if 'motif' not in key:
                    mean_val = np.mean(val) if np.ndim(val) > 0 else val
                    predictor_general_dict.setdefault(key, []).append(mean_val)

            predictor_motif_dict = {}
            for key, val in replay_data.items():
                predictor_motif_dict['animal_id'] = [mir] * len(replay_data['motif_event_rpm'])
                if 'motif' in key:
                    if np.ndim(replay_data[key][0]) > 0:
                        vals = [np.nanmean(value) for value in replay_data[key]]
                    else:
                        vals = val
                    predictor_motif_dict[key] = vals
                
            #concatenate to DataFrame
            predictor_matrix_general_df = pd.concat([predictor_matrix_general_df, pd.DataFrame(predictor_general_dict)], ignore_index=True)
            predictor_matrix_motif_df = pd.concat([predictor_matrix_motif_df, pd.DataFrame(predictor_motif_dict)], ignore_index=True)


            #################
            ## calculate behvaioural deltas
            delta_behavioural_data = processed_poke_data_2 - processed_poke_data_1
            if Tracking:
                delta_tracking_data = processed_tracking_data_2 - processed_tracking_data_1
                
            delta_behavioural_data['mouse_id'] = mir
            if Tracking:
                delta_tracking_data['mouse_id'] = [mir] * len(delta_tracking_data)   
                
            # add deltas to dataframes
            outcome_matrix_general_df = pd.concat([outcome_matrix_general_df, delta_behavioural_data], ignore_index=True)
            if Tracking:
                outcome_matrix_motif_df = pd.concat([outcome_matrix_motif_df, delta_tracking_data], ignore_index=True)




-------------------------------
Processing EJT136_1_3
Loaded replay data for EJT136_1_3 from Z:\projects\sequence_squad\revision_data\emmett_revisions\sleep_wake_link_data\replay_to_behaviour\\EJT136_1_3\replay\sequential_NREM_and_REM_sleep\replay_data_variables.json ✓ 
Loaded poke data 2 for Z:\projects\sequence_squad\revision_data\emmett_revisions\sleep_wake_link_data\replay_to_behaviour\\EJT136_1_3\day2_12-11-2021 ✓ 
Tracking data 2 found for Z:\projects\sequence_squad\revision_data\emmett_revisions\sleep_wake_link_data\replay_to_behaviour\\EJT136_1_3\day2_12-11-2021 ✓ 
Loaded poke data 1 for Z:\projects\sequence_squad\revision_data\emmett_revisions\sleep_wake_link_data\replay_to_behaviour\\EJT136_1_3\day1_11-11-2021 ✓ 
Tracking data 1 found for Z:\projects\sequence_squad\revision_data\emmett_revisions\sleep_wake_link_data\replay_to_behaviour\\EJT136_1_3\day1_11-11-2021 ✓ 
-------------------------------
Processing EJT136_1_4
Loaded replay data for EJT136_1_4 from Z:\projects\sequen

  vals = [np.nanmean(value) for value in replay_data[key]]


Loaded replay data for EJT269_1_4 from Z:\projects\sequence_squad\revision_data\emmett_revisions\sleep_wake_link_data\replay_to_behaviour\\EJT269_1_4\replay\sequential_NREM_and_REM_sleep\replay_data_variables.json ✓ 
-------------------------------
Processing EJT270_1_6
Loaded replay data for EJT270_1_6 from Z:\projects\sequence_squad\revision_data\emmett_revisions\sleep_wake_link_data\replay_to_behaviour\\EJT270_1_6\replay\sequential_NREM_and_REM_sleep\replay_data_variables.json ✓ 


In [34]:
outcome_matrix_general_df

Unnamed: 0,num_trials,correct_transit_rate,error_transit_rate,neutral_transit_rate,perf_score,total_transitions,transits_per_reward,seconds_per_reward,transition_times,mouse_id
0,92,-0.080249,0.075162,0.005087,-0.286907,1752,2.116936,-6.390712,-0.983713,EJT136_1_3
1,-148,0.109913,-0.084754,-0.025159,0.145115,-2435,-3.423323,2.684527,0.756373,EJT136_1_4
2,48,-0.036402,-0.002426,0.038827,0.031036,476,0.164377,-0.788421,-0.108143,EJT149_1_1
3,-149,-0.08508,0.014692,0.070388,-0.069142,-902,0.849644,3.477093,0.355008,EJT178_1_6
4,185,0.070512,-0.013032,-0.05748,0.072169,1080,-0.971027,-3.141863,-0.299464,EJT178_1_7
5,-149,-0.020067,0.007576,0.012491,-0.037825,-860,0.553224,-0.351901,-0.116723,EJT178_1_8
6,-2,-0.074117,0.014767,0.05935,-0.090783,735,2.741238,3.351847,0.07672,EJT178_2_1
7,-99,0.027252,0.009026,-0.036278,-0.024232,-1154,-0.536228,1.688874,0.229377,EJT178_2_3


In [None]:
# remebr these are deltas
# so if the value is positive, it means the value has increased from day 1 to day 2 

# 1) flip the sign of any columns which are negative in the outcome matrix (eg. higher error rate = lower performance )
negative_columns = ['error_transit_rate', 'neutral_transit_rate', 'transits_per_reward', 'seconds_per_reward']
flipped_outcome_matrix_general_df = outcome_matrix_general_df.copy()
for col in negative_columns:
    if col in flipped_outcome_matrix_general_df.columns:
        flipped_outcome_matrix_general_df[col] = -flipped_outcome_matrix_general_df[col]


# 2) zscore each column in the outcome matrices
zscore_outcome_matrix_general_df = outcome_matrix_general_df.copy()
# drop the mouse_id column if it exists
zscore_outcome_matrix_general_df = zscore_outcome_matrix_general_df.drop(columns=['mouse_id'])
#apply z-scoring to just those
zscore_outcome_matrix_general_df = zscore_outcome_matrix_general_df.apply(lambda x: (x - x.mean()) / x.std(), axis=0)

# 3) make a weighted dataframe for each column in the outcome matrix
num_trials_weight = 0 # I think this is probably irrelevant for performance
correct_transit_rate_weight = 1
error_transit_rate_weight = 1 
neutral_transit_rate_weight = 0.2
perf_score_weight = 0.8
total_transitions_weight = 0
transits_per_reward_weight = 0.5
seconds_per_reward_weight = 0.7
transition_times_weight = 1 

# make a weighted dataframe
weights_df = pd.DataFrame({
    'num_trials': [num_trials_weight]*len(zscore_outcome_matrix_general_df),
    'correct_transit_rate': [correct_transit_rate_weight]*len(zscore_outcome_matrix_general_df),
    'error_transit_rate': [error_transit_rate_weight]*len(zscore_outcome_matrix_general_df),
    'neutral_transit_rate': [neutral_transit_rate_weight]*len(zscore_outcome_matrix_general_df),
    'perf_score': [perf_score_weight]*len(zscore_outcome_matrix_general_df),
    'total_transitions': [total_transitions_weight]*len(zscore_outcome_matrix_general_df),
    'transits_per_reward': [transits_per_reward_weight]*len(zscore_outcome_matrix_general_df),
    'seconds_per_reward': [seconds_per_reward_weight]*len(zscore_outcome_matrix_general_df),
    'transition_times': [transition_times_weight]*len(zscore_outcome_matrix_general_df)
})
# make a weighted vector
weights_vector = np.array([num_trials_weight, correct_transit_rate_weight, error_transit_rate_weight, neutral_transit_rate_weight, perf_score_weight, total_transitions_weight, transits_per_reward_weight, seconds_per_reward_weight, transition_times_weight])

# normalise the weights so they sum to 1
weights_vector_normalised = weights_vector/sum(weights_vector)

# dot product the weights with the z-scored outcome matrix
overall_omnibus_deltas_per_animal = zscore_outcome_matrix_general_df.dot(weights_vector_normalised)

omnibus_delta_df = pd.DataFrame({'mouse_id': outcome_matrix_general_df['mouse_id'],
    'overall_delta' : overall_omnibus_deltas_per_animal})
# units are SD units - “SD units” just means you’re expressing a score in multiples of the standard deviation of the distribution - a zscore


In [None]:
## check above, im not convinced this is right yet

In [125]:
zscore_outcome_matrix_general_df

Unnamed: 0,num_trials,correct_transit_rate,error_transit_rate,neutral_transit_rate,perf_score,total_transitions,transits_per_reward,seconds_per_reward,transition_times
0,0.934507,-0.942933,1.650122,-0.071983,-1.974888,1.374086,1.006886,-1.855055,-1.897505
1,-0.938408,1.647524,-1.987837,-0.728567,1.379717,-1.629463,-1.883361,0.752247,1.49806
2,0.591139,-0.345625,-0.114932,0.660447,0.493904,0.458746,-0.011726,-0.245526,-0.188938
3,-0.946212,-1.008738,0.274486,1.345563,-0.283971,-0.529764,0.345764,0.97995,0.714846
4,1.660261,1.110792,-0.356222,-1.430197,0.813293,0.892026,-0.604045,-0.921666,-0.562277
5,-0.946212,-0.123116,0.112609,0.088737,-0.040798,-0.499635,0.191128,-0.120114,-0.205679
6,0.200948,-0.859396,0.276189,1.105953,-0.45201,0.64454,1.332572,0.943967,0.171801
7,-0.556022,0.521493,0.145584,-0.969953,0.064752,-0.710536,-0.377219,0.466197,0.469692


In [126]:
omnibus_delta_df

Unnamed: 0,mouse_id,overall_delta
0,EJT136_1_3,-0.688408
1,EJT136_1_4,0.327058
2,EJT149_1_1,-0.057695
3,EJT178_1_6,0.169496
4,EJT178_1_7,-0.075058
5,EJT178_1_8,-0.042229
6,EJT178_2_1,0.149085
7,EJT178_2_3,0.217752


In [111]:
np.shape(weights_vector_normalised)

(9,)

In [105]:
# make a weighted vector
weights_vector = np.array([num_trials_weight, correct_transit_rate_weight, error_transit_rate_weight, neutral_transit_rate_weight, perf_score_weight, total_transitions_weight, transits_per_reward_weight, seconds_per_reward_weight, transition_times_weight])


array([0.        , 0.19230769, 0.19230769, 0.03846154, 0.15384615,
       0.        , 0.09615385, 0.13461538, 0.19230769])

In [86]:
sum(weights_df.iloc[0])

1.0

ValueError: matrices are not aligned

In [100]:
weights_df

Unnamed: 0,0,1,2,3,4,5,6,7
num_trials,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
correct_transit_rate,0.192308,0.192308,0.192308,0.192308,0.192308,0.192308,0.192308,0.192308
error_transit_rate,0.192308,0.192308,0.192308,0.192308,0.192308,0.192308,0.192308,0.192308
neutral_transit_rate,0.038462,0.038462,0.038462,0.038462,0.038462,0.038462,0.038462,0.038462
perf_score,0.153846,0.153846,0.153846,0.153846,0.153846,0.153846,0.153846,0.153846
total_transitions,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
transits_per_reward,0.096154,0.096154,0.096154,0.096154,0.096154,0.096154,0.096154,0.096154
seconds_per_reward,0.134615,0.134615,0.134615,0.134615,0.134615,0.134615,0.134615,0.134615
transition_times,0.192308,0.192308,0.192308,0.192308,0.192308,0.192308,0.192308,0.192308


Unnamed: 0,0,1,2,3,4,5,6,7
0,-0.688408,-0.688408,-0.688408,-0.688408,-0.688408,-0.688408,-0.688408,-0.688408
1,0.327058,0.327058,0.327058,0.327058,0.327058,0.327058,0.327058,0.327058
2,-0.057695,-0.057695,-0.057695,-0.057695,-0.057695,-0.057695,-0.057695,-0.057695
3,0.169496,0.169496,0.169496,0.169496,0.169496,0.169496,0.169496,0.169496
4,-0.075058,-0.075058,-0.075058,-0.075058,-0.075058,-0.075058,-0.075058,-0.075058
5,-0.042229,-0.042229,-0.042229,-0.042229,-0.042229,-0.042229,-0.042229,-0.042229
6,0.149085,0.149085,0.149085,0.149085,0.149085,0.149085,0.149085,0.149085
7,0.217752,0.217752,0.217752,0.217752,0.217752,0.217752,0.217752,0.217752


In [95]:
np.shape(weights_df)

(8, 9)

In [96]:
np.shape(zscore_outcome_matrix_general_df)

(8, 9)

In [98]:
weights_df

Unnamed: 0,num_trials,correct_transit_rate,error_transit_rate,neutral_transit_rate,perf_score,total_transitions,transits_per_reward,seconds_per_reward,transition_times
0,0.0,0.192308,0.192308,0.038462,0.153846,0.0,0.096154,0.134615,0.192308
1,0.0,0.192308,0.192308,0.038462,0.153846,0.0,0.096154,0.134615,0.192308
2,0.0,0.192308,0.192308,0.038462,0.153846,0.0,0.096154,0.134615,0.192308
3,0.0,0.192308,0.192308,0.038462,0.153846,0.0,0.096154,0.134615,0.192308
4,0.0,0.192308,0.192308,0.038462,0.153846,0.0,0.096154,0.134615,0.192308
5,0.0,0.192308,0.192308,0.038462,0.153846,0.0,0.096154,0.134615,0.192308
6,0.0,0.192308,0.192308,0.038462,0.153846,0.0,0.096154,0.134615,0.192308
7,0.0,0.192308,0.192308,0.038462,0.153846,0.0,0.096154,0.134615,0.192308


In [97]:
zscore_outcome_matrix_general_df

Unnamed: 0,num_trials,correct_transit_rate,error_transit_rate,neutral_transit_rate,perf_score,total_transitions,transits_per_reward,seconds_per_reward,transition_times
0,0.934507,-0.942933,1.650122,-0.071983,-1.974888,1.374086,1.006886,-1.855055,-1.897505
1,-0.938408,1.647524,-1.987837,-0.728567,1.379717,-1.629463,-1.883361,0.752247,1.49806
2,0.591139,-0.345625,-0.114932,0.660447,0.493904,0.458746,-0.011726,-0.245526,-0.188938
3,-0.946212,-1.008738,0.274486,1.345563,-0.283971,-0.529764,0.345764,0.97995,0.714846
4,1.660261,1.110792,-0.356222,-1.430197,0.813293,0.892026,-0.604045,-0.921666,-0.562277
5,-0.946212,-0.123116,0.112609,0.088737,-0.040798,-0.499635,0.191128,-0.120114,-0.205679
6,0.200948,-0.859396,0.276189,1.105953,-0.45201,0.64454,1.332572,0.943967,0.171801
7,-0.556022,0.521493,0.145584,-0.969953,0.064752,-0.710536,-0.377219,0.466197,0.469692


In [80]:
# get the row vector from he weights df
weights_vector = weights_df.iloc[0].values

In [81]:
weights_vector

array([0.        , 0.19230769, 0.19230769, 0.03846154, 0.15384615,
       0.        , 0.09615385, 0.13461538, 0.19230769])

In [70]:
weights_df

Unnamed: 0,num_trials,correct_transit_rate,error_transit_rate,neutral_transit_rate,perf_score,total_transitions,transits_per_reward,seconds_per_reward,transition_times
0,0.0,0.192308,0.192308,0.038462,0.153846,0.0,0.096154,0.134615,0.192308


In [51]:
weights_df

Unnamed: 0,0,correct_transit_rate,error_transit_rate,neutral_transit_rate,num_trials,perf_score,seconds_per_reward,total_transitions,transition_times,transits_per_reward
0,,,,,,,,,,


In [47]:
weights_df

Unnamed: 0,num_trials,correct_transit_rate,error_transit_rate,neutral_transit_rate,perf_score,total_transitions,transits_per_reward,seconds_per_reward,transition_times
0,,1.0,1.0,1.0,1.0,,1.0,1.0,1.0


In [42]:
weighted_outcome_matrix_general_df

Unnamed: 0,num_trials,correct_transit_rate,error_transit_rate,neutral_transit_rate,perf_score,total_transitions,transits_per_reward,seconds_per_reward,transition_times
0,0.0,-0.942933,1.650122,-0.014397,-1.579911,0.0,0.503443,-1.298539,-1.897505
1,-0.0,1.647524,-1.987837,-0.145713,1.103774,-0.0,-0.94168,0.526573,1.49806
2,0.0,-0.345625,-0.114932,0.132089,0.395123,0.0,-0.005863,-0.171868,-0.188938
3,-0.0,-1.008738,0.274486,0.269113,-0.227177,-0.0,0.172882,0.685965,0.714846
4,0.0,1.110792,-0.356222,-0.286039,0.650635,0.0,-0.302022,-0.645166,-0.562277
5,-0.0,-0.123116,0.112609,0.017747,-0.032638,-0.0,0.095564,-0.08408,-0.205679
6,0.0,-0.859396,0.276189,0.221191,-0.361608,0.0,0.666286,0.660777,0.171801
7,-0.0,0.521493,0.145584,-0.193991,0.051802,-0.0,-0.188609,0.326338,0.469692


In [40]:
flipped_outcome_matrix_general_df

Unnamed: 0,num_trials,correct_transit_rate,error_transit_rate,neutral_transit_rate,perf_score,total_transitions,transits_per_reward,seconds_per_reward,transition_times,mouse_id
0,92,-0.080249,-0.075162,-0.005087,-0.286907,1752,-2.116936,6.390712,-0.983713,EJT136_1_3
1,-148,0.109913,0.084754,0.025159,0.145115,-2435,3.423323,-2.684527,0.756373,EJT136_1_4
2,48,-0.036402,0.002426,-0.038827,0.031036,476,-0.164377,0.788421,-0.108143,EJT149_1_1
3,-149,-0.08508,-0.014692,-0.070388,-0.069142,-902,-0.849644,-3.477093,0.355008,EJT178_1_6
4,185,0.070512,0.013032,0.05748,0.072169,1080,0.971027,3.141863,-0.299464,EJT178_1_7
5,-149,-0.020067,-0.007576,-0.012491,-0.037825,-860,-0.553224,0.351901,-0.116723,EJT178_1_8
6,-2,-0.074117,-0.014767,-0.05935,-0.090783,735,-2.741238,-3.351847,0.07672,EJT178_2_1
7,-99,0.027252,-0.009026,0.036278,-0.024232,-1154,0.536228,-1.688874,0.229377,EJT178_2_3


In [264]:
zscore_outcome_matrix_general_df

Unnamed: 0,num_trials,correct_transit_rate,error_transit_rate,neutral_transit_rate,perf_score,total_transitions,transits_per_reward,seconds_per_reward,mouse_id
0,0.760181,-0.914441,1.467404,-0.050161,-1.753122,1.236212,1.17438,-1.536257,EJT136_1_3
1,-0.866358,1.456156,-1.635537,-0.717995,1.130541,-1.48777,-1.73636,0.936223,EJT136_1_4
2,0.461982,-0.367828,-0.038072,0.694817,0.369085,0.406071,0.148545,-0.009956,EJT149_1_1
3,-0.873135,-0.974661,0.294076,1.391672,-0.299587,-0.49043,0.50857,1.152152,EJT178_1_6
4,1.390465,0.964978,-0.243877,-1.431645,0.643635,0.799022,-0.447973,-0.651133,EJT178_1_7
5,-0.873135,-0.164204,0.156006,0.113312,-0.090552,-0.463105,0.352838,0.108971,EJT178_1_8


In [258]:
outcome_matrix_general_df

Unnamed: 0,num_trials,correct_transit_rate,error_transit_rate,neutral_transit_rate,perf_score,total_transitions,transits_per_reward,seconds_per_reward,mouse_id
0,92,-0.080249,0.075162,0.005087,-0.286907,1752,2.116936,-6.390712,EJT136_1_3
1,-148,0.109913,-0.084754,-0.025159,0.145115,-2435,-3.423323,2.684527,EJT136_1_4
2,48,-0.036402,-0.002426,0.038827,0.031036,476,0.164377,-0.788421,EJT149_1_1
3,-149,-0.08508,0.014692,0.070388,-0.069142,-902,0.849644,3.477093,EJT178_1_6
4,185,0.070512,-0.013032,-0.05748,0.072169,1080,-0.971027,-3.141863,EJT178_1_7
5,-149,-0.020067,0.007576,0.012491,-0.037825,-860,0.553224,-0.351901,EJT178_1_8


In [None]:
# zscore my deltas.

# weighted average of the deltas




delta_tracking_data

Unnamed: 0,sequence_motif,mean_h_distance_from_av_mm,std_h_distance_from_av_mm,mean_dtw_distance_from_av_mm,std_dtw_distance_from_av_mm,mean_movement_speed_mm_s,std_movement_speed_mm_s,mouse_id
0,0,0.525667,0.609281,0.409804,0.034668,4.296042,2.551191,EJT178_1_7
1,0,0.410894,0.291842,-0.047228,-0.517722,10.015036,0.490464,EJT178_1_7
2,0,0.080067,0.182667,0.11779,0.087281,9.061283,1.823435,EJT178_1_7
3,0,0.093624,0.095479,-0.086953,0.122308,7.008607,1.098241,EJT178_1_7
4,0,0.660841,0.549272,-0.327533,-0.110235,6.771786,2.699511,EJT178_1_7


In [213]:
outcome_matrix_general_df

Unnamed: 0,num_trials,correct_transit_rate,error_transit_rate,neutral_transit_rate,perf_score,total_transitions,transits_per_reward,seconds_per_reward,mouse_id
0,92,-0.080249,0.075162,0.005087,-0.286907,1752,2.116936,-6.390712,EJT136_1_3
1,-148,0.109913,-0.084754,-0.025159,0.145115,-2435,-3.423323,2.684527,EJT136_1_4
2,48,-0.036402,-0.002426,0.038827,0.031036,476,0.164377,-0.788421,EJT149_1_1
3,-149,-0.08508,0.014692,0.070388,-0.069142,-902,0.849644,3.477093,EJT178_1_6
4,185,0.070512,-0.013032,-0.05748,0.072169,1080,-0.971027,-3.141863,EJT178_1_7
5,-149,-0.020067,0.007576,0.012491,-0.037825,-860,0.553224,-0.351901,EJT178_1_8


In [None]:
#improved behaviour or not? how much has behaviour changed? find a number that indicates this...? 
# increased stereotypy or not? 
# which replay features are associated with improved behaviour?


In [None]:
# stereotypy is kind of easy from the tracking data
# correct transition rate 

# take the delta of these 
# first plot them against replay features 


In [None]:
# work out what the next steps are...I need to talk to chat gpt
# essentially I want to know if replay features can predict the behavioural deltas? 
# how do I do a regression with muliple predictors and mulitple outcomes? 
# maybe I cant? Maybe the predictor (replay) variables need to be combined into a single variable? or I need to do a GLM for each? 

In [None]:
# looks like I want to use a multivariate regression 
# I will use the R squared as a meausre of predictive power of the model, and compare this to shuffled replay features. 

In [230]:
Y_df

Unnamed: 0,event_rpm,spikes_per_event,units_per_event,event_lengths
0,34.474165,192.193237,8.31723,0.288156
1,23.785714,176.121622,8.009009,0.272384
2,25.660087,138.924166,6.703741,0.112025
3,22.043532,154.452991,7.301994,0.239185
4,30.756495,177.417752,7.358306,0.173785
5,23.720815,169.163265,7.591837,0.203919


# multivariate regression - do replay features predict behavioural (poke data) deltas?

how well can it predict compared to shuffled versions? 
partial models? getting rid of some features?
which features are the most important 


In [253]:
import pandas as pd
import statsmodels.api as sm
from statsmodels.multivariate.multivariate_ols import MultivariateLS

# set up the variable names
Y = predictor_matrix_general_df
X = outcome_matrix_general_df
# drop the mouse_id and animal_id columns as they are not predictors or outcomes
Y_df = Y.drop(columns=['animal_id'])
X_df = X.drop(columns=['mouse_id'])

# drop some of the more complex to interpret columns for now
Y_df = Y_df.drop(columns = list(predictor_matrix_general_df)[5::])
X_df = X_df.drop(columns = ['error_transit_rate','neutral_transit_rate','num_trials'])

#Fit the model
model = MultivariateLS(X_df, sm.add_constant(Y_df)) # X = outcomes, Y = predictors
result = model.fit()

# 4. Summarize
print(result.summary())      # full MultivariateLS table
print(result.params)         # coefficient matrix (exog × endog)
print(result.pvalues)        # p-values per coefficient
print(result.mv_test())      # Wilks’, Pillai, etc. tests

ImportError: cannot import name 'MultivariateLS' from 'statsmodels.multivariate.multivariate_ols' (c:\miniconda\Lib\site-packages\statsmodels\multivariate\multivariate_ols.py)

general regression for inference - is replay predictive? is this prediction significant?

In [243]:

import numpy as np
import matplotlib.pyplot as plt
import statsmodels.api as sm
from statsmodels.multivariate.manova import MANOVA
from sklearn.preprocessing import StandardScaler
from sklearn.linear_model import LinearRegression
from sklearn.metrics import r2_score


# set up the variable names
X = predictor_matrix_general_df
Y = outcome_matrix_general_df
# drop the mouse_id and animal_id columns as they are not predictors or outcomes
Y_df = Y.drop(columns=['mouse_id'])
X_df = X.drop(columns=['animal_id'])

# drop some of the more complex to interpret columns for now
X_df = X_df.drop(columns = list(predictor_matrix_general_df)[5::])
Y_df = Y_df.drop(columns = ['error_transit_rate','neutral_transit_rate','num_trials'])


# behavior_names = list(Y.columns)

# STANDARDIZE PREDICTORS
scaler = StandardScaler()
Y_std = scaler.fit_transform(X_df.values)

# 1) MULTIVARIATE TEST (overall inference)
Y_const = sm.add_constant(Y_std)
mv = MANOVA(endog=X_df.values, exog=Y_const)
mv_results = mv.mv_test()
print("=== MANOVA test results ===")
print(mv_results)  # look for the overall p-value under Pillai’s trace / Wilks’ lambda


mv_res = mv.mv_test().results
stat_table = mv_res['x1']['stat']
p_wilks = stat_table.loc["Wilks' lambda", "Pr > F"]
p_pillai = stat_table.loc["Pillai's trace", "Pr > F"]

print("Omnibus p-value (Wilks'):", p_wilks)
print("Omnibus p-value (Pillai):", p_pillai)

=== MANOVA test results ===
                               Multivariate linear model
                                                                                        
----------------------------------------------------------------------------------------
           x0                   Value         Num DF Den DF        F Value        Pr > F
----------------------------------------------------------------------------------------
          Wilks' lambda                0.0000 1.0000 1.0000 6095533867062600.0000 0.0000
         Pillai's trace                1.2612 1.0000 1.0000               -4.8290 1.0000
 Hotelling-Lawley trace 4503599627370495.5000 1.0000 0.5000 2251799813685247.7500 0.0001
    Roy's greatest root 4503599627370495.0000 1.0000 1.0000 4503599627370495.0000 0.0000
----------------------------------------------------------------------------------------
                                                                                        
-------------------------

to what exect does each predictor predict each outcome?

In [None]:
# multivariate regression - do replay features predict behavioural (tracking data) deltas?
predictor_matrix_motif_df
outcome_matrix_motif_df