# Omission LFP Analysis

Brief 1-2 sentence description of notebook.

In [1]:
import warnings
warnings.filterwarnings('ignore')

In [2]:
import os
import glob
from collections import defaultdict
import re

In [3]:
# Imports of all used packages and libraries
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.cm as cm
import seaborn as sns
from scipy import stats
import itertools
from scipy.stats import linregress

In [4]:
import spikeinterface.extractors as se
import spikeinterface.preprocessing as sp
from spectral_connectivity import Multitaper, Connectivity
import spectral_connectivity

## Inputs & Data

Explanation of each input and where it comes from.

In [5]:
# Inputs and Required data loading
# input varaible names are in all caps snake case
# Whenever an input changes or is used for processing 
# the vairables are all lower in snake case

In [6]:
BBOX_TO_ANCHOR=(1.5, 0.9)
LOC='upper right'

In [7]:
ALL_BANDS = ["theta", "beta", "gamma"]
BAND_TO_FREQ = {"theta": {"low_freq": 4, "high_freq": 12}, "beta": {"low_freq": 13, "high_freq": 30}, "gamma": {"low_freq": 30, "high_freq": 70}}

In [8]:
# variables for LFP extraction
FREQ_MIN=0.5
FREQ_MAX=300
NOTCH_FREQ=60
ORIGINAL_SAMPLE_RATE = 20000
RESAMPLE_RATE=1000
TRIAL_DURATION=10

In [9]:
INPUT_VARIABLE = 1

TIME_HALFBANDWIDTH_PRODUCT = 2
TIME_WINDOW_DURATION = 3
TIME_WINDOW_STEP = 1.5 

TRIAL_TIME_STAMP_DURATION = 1000*10

In [10]:
BIN_TO_COLOR = {0: {"baseline": "lightblue", "trial": "blue"}, 1: {"baseline": "lightgreen", "trial": "green"}, 2: {"baseline": "lightcoral", "trial": "red"}}
TRIAL_OR_BASELINE_TO_STYLE = {'baseline': "--", "trial": "-"}
BIN_TO_VELOCITY = {0: "0 to 2.5cm/s", 1: "2.5 to 5cm/s", 2: "5cm/s+"}

In [11]:
NUM_LINES = 3

In [12]:
# Generate colors from the "Blues" colormap
LOSING_COLORS = cm.Oranges(np.linspace(0.5, 1, NUM_LINES))
# Generate colors from the "Blues" colormap
WINNING_COLORS = cm.Blues(np.linspace(0.5, 1, NUM_LINES))
# Generate colors from the "Blues" colormap
REWARDED_COLORS = cm.Greens(np.linspace(0.5, 1, NUM_LINES))
# Generate colors from the "Blues" colormap
OMISSION_COLORS = cm.Reds(np.linspace(0.5, 1, NUM_LINES))

In [13]:
BASELINE_OUTCOME_TO_COLOR = {'lose': "orange",
 'lose_baseline': LOSING_COLORS[0],
 'omission': "red",
 'omission_baseline': "hotpink",
 'rewarded': "green",
 'rewarded_baseline': REWARDED_COLORS[0],
 'win': "blue",
 'win_baseline': WINNING_COLORS[0]}

In [14]:
BASELINE_OUTCOME_TO_COLOR = {'lose_trial': "orange",
 'lose_baseline': LOSING_COLORS[0],
 'omission_trial': "red",
 'omission_baseline': "hotpink",
 'rewarded_trial': "green",
 'rewarded_baseline': REWARDED_COLORS[0],
 'win_trial': "blue",
 'win_baseline': WINNING_COLORS[0]}

In [15]:
COMPETITIVE_OUTCOME_TO_COLOR = {'lose_comp': "orange", 
'lose_non_comp': "yellow",
'omission': "red",
'rewarded': "green",
'win_comp': "blue", 
'win_non_comp': WINNING_COLORS[0]}

In [16]:
TRIAL_OR_BASELINE_TO_STYLE = {'baseline': "--", "trial": "-"}

In [17]:
CHANNEL_MAPPING_DF = pd.read_excel("../../channel_mapping.xlsx")
CHANNEL_MAPPING_DF["Subject"] = CHANNEL_MAPPING_DF["Subject"].astype(str)

TONE_TIMESTAMP_DF = pd.read_excel("../../rce_tone_timestamp.xlsx", index_col=0)
OUTPUT_DIR = r"./proc" # where data is saved should always be shown in the inputs


In [18]:
all_session_dir = ['/scratch/back_up/reward_competition_extention/data/pilot/20221214_125409_om_and_comp_6_1_and_6_3.rec',
'/scratch/back_up/reward_competition_extention/data/pilot/20221215_145401_comp_amd_om_6_1_and_6_3.rec',
'/scratch/back_up/reward_competition_extention/data/standard/2023_06_12/20230612_101430_standard_comp_to_training_D1_subj_1-4_and_1-3.rec',
'/scratch/back_up/reward_competition_extention/data/omission/2023_06_17/20230617_115521_standard_comp_to_omission_D1_subj_1-1_and_1-2.rec',
'/scratch/back_up/reward_competition_extention/data/omission/2023_06_18/20230618_100636_standard_comp_to_omission_D2_subj_1-4_and_1-1.rec',
'/scratch/back_up/reward_competition_extention/data/omission/2023_06_19/20230619_115321_standard_comp_to_omission_D3_subj_1-2_and_1-4.rec',
'/scratch/back_up/reward_competition_extention/data/omission/2023_06_20/20230620_114347_standard_comp_to_omission_D4_subj_1-2_and_1-1.rec',
'/scratch/back_up/reward_competition_extention/data/omission/2023_06_21/20230621_111240_standard_comp_to_omission_D5_subj_1-4_and_1-2.rec']

In [19]:
all_session_dir

['/scratch/back_up/reward_competition_extention/data/pilot/20221214_125409_om_and_comp_6_1_and_6_3.rec',
 '/scratch/back_up/reward_competition_extention/data/pilot/20221215_145401_comp_amd_om_6_1_and_6_3.rec',
 '/scratch/back_up/reward_competition_extention/data/standard/2023_06_12/20230612_101430_standard_comp_to_training_D1_subj_1-4_and_1-3.rec',
 '/scratch/back_up/reward_competition_extention/data/omission/2023_06_17/20230617_115521_standard_comp_to_omission_D1_subj_1-1_and_1-2.rec',
 '/scratch/back_up/reward_competition_extention/data/omission/2023_06_18/20230618_100636_standard_comp_to_omission_D2_subj_1-4_and_1-1.rec',
 '/scratch/back_up/reward_competition_extention/data/omission/2023_06_19/20230619_115321_standard_comp_to_omission_D3_subj_1-2_and_1-4.rec',
 '/scratch/back_up/reward_competition_extention/data/omission/2023_06_20/20230620_114347_standard_comp_to_omission_D4_subj_1-2_and_1-1.rec',
 '/scratch/back_up/reward_competition_extention/data/omission/2023_06_21/20230621_111

## Outputs

Describe each output that the notebook creates. 

- Is it a plot or is it data?

- How valuable is the output and why is it valuable or useful?

## Functions

- Ideally functions are defined here first and then data is processed using the functions
    - function names are short and in snake case all lowercase
    - a function name should be unique but does not have to describe the function
    - doc strings describe functions not function names

In [20]:
def generate_pairs(lst):
    pairs = []
    n = len(lst)
    for i in range(n):
        for j in range(i+1, n):
            pairs.append((lst[i], lst[j]))
    return pairs

In [21]:
def nested_dict():
    return defaultdict(dict)

triple_nested_dict = defaultdict(nested_dict)

## Processing

Describe what is done to the data here and how inputs are manipulated to generate outputs. 

In [22]:
# As much code and as many cells as required
# includes EDA and playing with data
# GO HAM!

In [23]:
CHANNEL_MAPPING_DF

Unnamed: 0,Cohort,Subject,eib_mPFC,eib_vHPC,eib_BLA,eib_LH,eib_MD,spike_interface_mPFC,spike_interface_vHPC,spike_interface_BLA,spike_interface_LH,spike_interface_MD
0,1,6.1,,15,14,13,31,21.0,15.0,14.0,13.0,16.0
1,1,6.2,,15,14,13,31,,,,,
2,1,6.3,,15,14,13,31,,,,,
3,1,6.4,,15,14,13,31,,,,,
4,2,1.1,,16,17,18,19,5.0,31.0,30.0,29.0,28.0
5,2,1.2,,31,30,29,28,10.0,31.0,30.0,29.0,28.0
6,2,1.3,,15,14,13,12,9.0,31.0,30.0,29.0,28.0
7,2,1.4,,15,14,13,12,15.0,31.0,30.0,29.0,28.0


### Getting the subject IDs from the file name

In [24]:
all_trials_df = TONE_TIMESTAMP_DF.dropna(subset="condition").reset_index(drop=True)

In [25]:
all_trials_df.head()

Unnamed: 0,time,state,recording_dir,recording_file,din,time_stamp_index,video_file,video_frame,video_number,subject_info,condition,competition_closeness,Unnamed: 13
0,6310663.0,1.0,20221202_134600_omission_and_competition_subje...,20221202_134600_omission_and_competition_subje...,dio_ECU_Din1,1390826.0,20221202_134600_omission_and_competition_subje...,1734.0,1.0,6_1_top_2_base_3,rewarded,,
1,7910662.0,1.0,20221202_134600_omission_and_competition_subje...,20221202_134600_omission_and_competition_subje...,dio_ECU_Din1,2990825.0,20221202_134600_omission_and_competition_subje...,3728.0,1.0,6_1_top_2_base_3,rewarded,,
2,9710660.0,1.0,20221202_134600_omission_and_competition_subje...,20221202_134600_omission_and_competition_subje...,dio_ECU_Din1,4790823.0,20221202_134600_omission_and_competition_subje...,5972.0,1.0,6_1_top_2_base_3,rewarded,,
3,11310658.0,1.0,20221202_134600_omission_and_competition_subje...,20221202_134600_omission_and_competition_subje...,dio_ECU_Din1,6390821.0,20221202_134600_omission_and_competition_subje...,7966.0,1.0,6_1_top_2_base_3,omission,,
4,12810657.0,1.0,20221202_134600_omission_and_competition_subje...,20221202_134600_omission_and_competition_subje...,dio_ECU_Din1,7890820.0,20221202_134600_omission_and_competition_subje...,9836.0,1.0,6_1_top_2_base_3,rewarded,,


- Original timestamps are based on ephys recordings at 20kHz. The LFP will be at 1kHz, so we will need to divide all the timestamps by 20

In [26]:
all_trials_df["resampled_index"] = all_trials_df["time_stamp_index"] // (ORIGINAL_SAMPLE_RATE // RESAMPLE_RATE)

In [27]:
all_trials_df["recording_dir"].unique()

array(['20221202_134600_omission_and_competition_subject_6_1_and_6_2',
       '20221203_154800_omission_and_competition_subject_6_4_and_6_1',
       '20221214_125409_om_and_comp_6_1_and_6_3',
       '20221215_145401_comp_amd_om_6_1_and_6_3',
       '20230612_101430_standard_comp_to_training_D1_subj_1-4_and_1-3',
       '20230617_115521_standard_comp_to_omission_D1_subj_1-1_and_1-2',
       '20230618_100636_standard_comp_to_omission_D2_subj_1-4_and_1-1',
       '20230619_115321_standard_comp_to_omission_D3_subj_1-2_and_1-4',
       '20230620_114347_standard_comp_to_omission_D4_subj_1-2_and_1-1',
       '20230621_111240_standard_comp_to_omission_D5_subj_1-4_and_1-2'],
      dtype=object)

- Getting a list of all the subjects through the recording name

In [28]:
all_trials_df["all_subjects"] = all_trials_df["recording_dir"].apply(lambda x: ["{}.{}".format(tup[0],tup[1]) for tup in re.findall(r'(\d+)-(\d+)', x.replace("_", "-"))[1:]])

In [29]:
all_trials_df["all_subjects"].head()

0    [6.1, 6.2]
1    [6.1, 6.2]
2    [6.1, 6.2]
3    [6.1, 6.2]
4    [6.1, 6.2]
Name: all_subjects, dtype: object

- Getting the current subject of the recording through the ending of the recording name file

In [30]:
all_trials_df["subject_info"].head()

0    6_1_top_2_base_3
1    6_1_top_2_base_3
2    6_1_top_2_base_3
3    6_1_top_2_base_3
4    6_1_top_2_base_3
Name: subject_info, dtype: object

In [31]:
all_trials_df["current_subject"] = all_trials_df["subject_info"].apply(lambda x: ".".join(x.replace("-","_").split("_")[:2]))

In [32]:
all_trials_df.head()

Unnamed: 0,time,state,recording_dir,recording_file,din,time_stamp_index,video_file,video_frame,video_number,subject_info,condition,competition_closeness,Unnamed: 13,resampled_index,all_subjects,current_subject
0,6310663.0,1.0,20221202_134600_omission_and_competition_subje...,20221202_134600_omission_and_competition_subje...,dio_ECU_Din1,1390826.0,20221202_134600_omission_and_competition_subje...,1734.0,1.0,6_1_top_2_base_3,rewarded,,,69541.0,"[6.1, 6.2]",6.1
1,7910662.0,1.0,20221202_134600_omission_and_competition_subje...,20221202_134600_omission_and_competition_subje...,dio_ECU_Din1,2990825.0,20221202_134600_omission_and_competition_subje...,3728.0,1.0,6_1_top_2_base_3,rewarded,,,149541.0,"[6.1, 6.2]",6.1
2,9710660.0,1.0,20221202_134600_omission_and_competition_subje...,20221202_134600_omission_and_competition_subje...,dio_ECU_Din1,4790823.0,20221202_134600_omission_and_competition_subje...,5972.0,1.0,6_1_top_2_base_3,rewarded,,,239541.0,"[6.1, 6.2]",6.1
3,11310658.0,1.0,20221202_134600_omission_and_competition_subje...,20221202_134600_omission_and_competition_subje...,dio_ECU_Din1,6390821.0,20221202_134600_omission_and_competition_subje...,7966.0,1.0,6_1_top_2_base_3,omission,,,319541.0,"[6.1, 6.2]",6.1
4,12810657.0,1.0,20221202_134600_omission_and_competition_subje...,20221202_134600_omission_and_competition_subje...,dio_ECU_Din1,7890820.0,20221202_134600_omission_and_competition_subje...,9836.0,1.0,6_1_top_2_base_3,rewarded,,,394541.0,"[6.1, 6.2]",6.1


- Labeling the trial as a winner or loser if the winner matches the subject id or not

In [33]:
all_trials_df["trial_outcome"] = all_trials_df.apply(
    lambda x: "win" if str(x["condition"]).strip() == str(x["current_subject"]) 
             else ("lose" if str(x["condition"]) in x["all_subjects"] 
                   else x["condition"]), axis=1)

In [34]:
all_trials_df.head()

Unnamed: 0,time,state,recording_dir,recording_file,din,time_stamp_index,video_file,video_frame,video_number,subject_info,condition,competition_closeness,Unnamed: 13,resampled_index,all_subjects,current_subject,trial_outcome
0,6310663.0,1.0,20221202_134600_omission_and_competition_subje...,20221202_134600_omission_and_competition_subje...,dio_ECU_Din1,1390826.0,20221202_134600_omission_and_competition_subje...,1734.0,1.0,6_1_top_2_base_3,rewarded,,,69541.0,"[6.1, 6.2]",6.1,rewarded
1,7910662.0,1.0,20221202_134600_omission_and_competition_subje...,20221202_134600_omission_and_competition_subje...,dio_ECU_Din1,2990825.0,20221202_134600_omission_and_competition_subje...,3728.0,1.0,6_1_top_2_base_3,rewarded,,,149541.0,"[6.1, 6.2]",6.1,rewarded
2,9710660.0,1.0,20221202_134600_omission_and_competition_subje...,20221202_134600_omission_and_competition_subje...,dio_ECU_Din1,4790823.0,20221202_134600_omission_and_competition_subje...,5972.0,1.0,6_1_top_2_base_3,rewarded,,,239541.0,"[6.1, 6.2]",6.1,rewarded
3,11310658.0,1.0,20221202_134600_omission_and_competition_subje...,20221202_134600_omission_and_competition_subje...,dio_ECU_Din1,6390821.0,20221202_134600_omission_and_competition_subje...,7966.0,1.0,6_1_top_2_base_3,omission,,,319541.0,"[6.1, 6.2]",6.1,omission
4,12810657.0,1.0,20221202_134600_omission_and_competition_subje...,20221202_134600_omission_and_competition_subje...,dio_ECU_Din1,7890820.0,20221202_134600_omission_and_competition_subje...,9836.0,1.0,6_1_top_2_base_3,rewarded,,,394541.0,"[6.1, 6.2]",6.1,rewarded


In [35]:
competition_closeness_map = {k: "non_comp" if "only" in str(k).lower() else "comp" if type(k) is str else np.nan for k in all_trials_df["competition_closeness"].unique()}

In [36]:
competition_closeness_map

{nan: nan,
 'Subj 1 Only': 'non_comp',
 'Subj 2 blocking Subj 1': 'comp',
 'Subj 1 then Subj 2': 'comp',
 'Subj 1 blocking Subj 2': 'comp',
 'Subj 2 Only': 'non_comp',
 'Subj 2 then Subj 1': 'comp',
 'Close Call': 'comp'}

In [37]:
all_trials_df["competition_closeness"] = all_trials_df["competition_closeness"].map(competition_closeness_map)

In [38]:
all_trials_df["competition_closeness"]

0           NaN
1           NaN
2           NaN
3           NaN
4           NaN
         ...   
698    non_comp
699    non_comp
700        comp
701        comp
702        comp
Name: competition_closeness, Length: 703, dtype: object

In [39]:
all_trials_df["competition_closeness"] = all_trials_df.apply(lambda x: "_".join([str(x["trial_outcome"]), str(x["competition_closeness"])]).strip("nan").strip("_"), axis=1)

### Extracting the LFP

In [40]:
recording_name_to_all_ch_lfp = {}
# Going through all the recording sessions 
for session_dir in all_session_dir:
    # Going through all the recordings in each session
    for recording_path in glob.glob(os.path.join(session_dir, "*.rec")):
        try:
            recording_basename = os.path.splitext(os.path.basename(recording_path))[0]
            # checking to see if the recording has an ECU component
            # if it doesn't, then the next one be extracted
            current_recording = se.read_spikegadgets(recording_path, stream_id="ECU")
            current_recording = se.read_spikegadgets(recording_path, stream_id="trodes")
            print(recording_basename)
            # Preprocessing the LFP
            current_recording = sp.bandpass_filter(current_recording, freq_min=FREQ_MIN, freq_max=FREQ_MAX)
            current_recording = sp.notch_filter(current_recording, freq=NOTCH_FREQ)
            current_recording = sp.resample(current_recording, resample_rate=RESAMPLE_RATE)
            current_recording = sp.zscore(current_recording)
            recording_name_to_all_ch_lfp[recording_basename] = current_recording
        except:
            pass



20221214_125409_om_and_comp_6_1_top_1_base_2_vs_6_3
20221215_145401_comp_amd_om_6_1_top_4_base_3
20230612_101430_standard_comp_to_training_D1_subj_1-4_t4b2L_box1_merged
20230612_101430_standard_comp_to_training_D1_subj_1-3_t3b3L_box2_merged
20230617_115521_standard_comp_to_omission_D1_subj_1-2_t2b2L_box2_merged
20230617_115521_standard_comp_to_omission_D1_subj_1-1_t1b3L_box1_merged
20230618_100636_standard_comp_to_omission_D2_subj_1_4_t4b3L_box1_merged
20230618_100636_standard_comp_to_omission_D2_subj_1_1_t1b2L_box2_merged
20230619_115321_standard_comp_to_omission_D3_subj_1-4_t3b3L_box2_merged
20230620_114347_standard_comp_to_omission_D4_subj_1-1_t1b2L_box_2_merged
20230620_114347_standard_comp_to_omission_D4_subj_1-2_t3b3L_box_1_merged
20230621_111240_standard_comp_to_omission_D5_subj_1-4_t3b3L_box1_merged


- Filtering for all trials that have labels

In [41]:
all_trials_df = all_trials_df[all_trials_df["recording_file"].isin(recording_name_to_all_ch_lfp.keys())].reset_index(drop=True)

In [42]:
all_trials_df.head()

Unnamed: 0,time,state,recording_dir,recording_file,din,time_stamp_index,video_file,video_frame,video_number,subject_info,condition,competition_closeness,Unnamed: 13,resampled_index,all_subjects,current_subject,trial_outcome
0,4359951.0,1.0,20221214_125409_om_and_comp_6_1_and_6_3,20221214_125409_om_and_comp_6_1_top_1_base_2_v...,dio_ECU_Din1,1408048.0,20221214_125409_om_and_comp_6_1_and_6_3.1.vide...,1405.0,1.0,6_1_top_1_base_2_vs_6_3,rewarded,rewarded,,70402.0,"[6.1, 6.3]",6.1,rewarded
1,5959954.0,1.0,20221214_125409_om_and_comp_6_1_and_6_3,20221214_125409_om_and_comp_6_1_top_1_base_2_v...,dio_ECU_Din1,3008051.0,20221214_125409_om_and_comp_6_1_and_6_3.1.vide...,3002.0,1.0,6_1_top_1_base_2_vs_6_3,rewarded,rewarded,,150402.0,"[6.1, 6.3]",6.1,rewarded
2,7759946.0,1.0,20221214_125409_om_and_comp_6_1_and_6_3,20221214_125409_om_and_comp_6_1_top_1_base_2_v...,dio_ECU_Din1,4808043.0,20221214_125409_om_and_comp_6_1_and_6_3.1.vide...,4798.0,1.0,6_1_top_1_base_2_vs_6_3,rewarded,rewarded,,240402.0,"[6.1, 6.3]",6.1,rewarded
3,9359945.0,1.0,20221214_125409_om_and_comp_6_1_and_6_3,20221214_125409_om_and_comp_6_1_top_1_base_2_v...,dio_ECU_Din1,6408042.0,20221214_125409_om_and_comp_6_1_and_6_3.1.vide...,6395.0,1.0,6_1_top_1_base_2_vs_6_3,omission,omission,,320402.0,"[6.1, 6.3]",6.1,omission
4,10859943.0,1.0,20221214_125409_om_and_comp_6_1_and_6_3,20221214_125409_om_and_comp_6_1_top_1_base_2_v...,dio_ECU_Din1,7908040.0,20221214_125409_om_and_comp_6_1_and_6_3.1.vide...,7892.0,1.0,6_1_top_1_base_2_vs_6_3,rewarded,rewarded,,395402.0,"[6.1, 6.3]",6.1,rewarded


In [43]:
all_trials_df["trial_outcome"].unique()

array(['rewarded', 'omission', 'win', 'lose'], dtype=object)

# Extracting the LFP

In [44]:
all_trials_df.head()

Unnamed: 0,time,state,recording_dir,recording_file,din,time_stamp_index,video_file,video_frame,video_number,subject_info,condition,competition_closeness,Unnamed: 13,resampled_index,all_subjects,current_subject,trial_outcome
0,4359951.0,1.0,20221214_125409_om_and_comp_6_1_and_6_3,20221214_125409_om_and_comp_6_1_top_1_base_2_v...,dio_ECU_Din1,1408048.0,20221214_125409_om_and_comp_6_1_and_6_3.1.vide...,1405.0,1.0,6_1_top_1_base_2_vs_6_3,rewarded,rewarded,,70402.0,"[6.1, 6.3]",6.1,rewarded
1,5959954.0,1.0,20221214_125409_om_and_comp_6_1_and_6_3,20221214_125409_om_and_comp_6_1_top_1_base_2_v...,dio_ECU_Din1,3008051.0,20221214_125409_om_and_comp_6_1_and_6_3.1.vide...,3002.0,1.0,6_1_top_1_base_2_vs_6_3,rewarded,rewarded,,150402.0,"[6.1, 6.3]",6.1,rewarded
2,7759946.0,1.0,20221214_125409_om_and_comp_6_1_and_6_3,20221214_125409_om_and_comp_6_1_top_1_base_2_v...,dio_ECU_Din1,4808043.0,20221214_125409_om_and_comp_6_1_and_6_3.1.vide...,4798.0,1.0,6_1_top_1_base_2_vs_6_3,rewarded,rewarded,,240402.0,"[6.1, 6.3]",6.1,rewarded
3,9359945.0,1.0,20221214_125409_om_and_comp_6_1_and_6_3,20221214_125409_om_and_comp_6_1_top_1_base_2_v...,dio_ECU_Din1,6408042.0,20221214_125409_om_and_comp_6_1_and_6_3.1.vide...,6395.0,1.0,6_1_top_1_base_2_vs_6_3,omission,omission,,320402.0,"[6.1, 6.3]",6.1,omission
4,10859943.0,1.0,20221214_125409_om_and_comp_6_1_and_6_3,20221214_125409_om_and_comp_6_1_top_1_base_2_v...,dio_ECU_Din1,7908040.0,20221214_125409_om_and_comp_6_1_and_6_3.1.vide...,7892.0,1.0,6_1_top_1_base_2_vs_6_3,rewarded,rewarded,,395402.0,"[6.1, 6.3]",6.1,rewarded


In [45]:
CHANNEL_MAPPING_DF

Unnamed: 0,Cohort,Subject,eib_mPFC,eib_vHPC,eib_BLA,eib_LH,eib_MD,spike_interface_mPFC,spike_interface_vHPC,spike_interface_BLA,spike_interface_LH,spike_interface_MD
0,1,6.1,,15,14,13,31,21.0,15.0,14.0,13.0,16.0
1,1,6.2,,15,14,13,31,,,,,
2,1,6.3,,15,14,13,31,,,,,
3,1,6.4,,15,14,13,31,,,,,
4,2,1.1,,16,17,18,19,5.0,31.0,30.0,29.0,28.0
5,2,1.2,,31,30,29,28,10.0,31.0,30.0,29.0,28.0
6,2,1.3,,15,14,13,12,9.0,31.0,30.0,29.0,28.0
7,2,1.4,,15,14,13,12,15.0,31.0,30.0,29.0,28.0


- Adding all the brain region to ch information

In [46]:
channel_map_and_all_trials_df = all_trials_df.merge(CHANNEL_MAPPING_DF, left_on="current_subject", right_on="Subject", how="left")

- Linking up all LFP calculations with all the trials

In [47]:
channel_map_and_all_trials_df["all_ch_lfp"] = channel_map_and_all_trials_df["recording_file"].map(recording_name_to_all_ch_lfp)

In [48]:
channel_map_and_all_trials_df.head()

Unnamed: 0,time,state,recording_dir,recording_file,din,time_stamp_index,video_file,video_frame,video_number,subject_info,...,eib_vHPC,eib_BLA,eib_LH,eib_MD,spike_interface_mPFC,spike_interface_vHPC,spike_interface_BLA,spike_interface_LH,spike_interface_MD,all_ch_lfp
0,4359951.0,1.0,20221214_125409_om_and_comp_6_1_and_6_3,20221214_125409_om_and_comp_6_1_top_1_base_2_v...,dio_ECU_Din1,1408048.0,20221214_125409_om_and_comp_6_1_and_6_3.1.vide...,1405.0,1.0,6_1_top_1_base_2_vs_6_3,...,15,14,13,31,21.0,15.0,14.0,13.0,16.0,ZScoreRecording: 32 channels - 1 segments - 1....
1,5959954.0,1.0,20221214_125409_om_and_comp_6_1_and_6_3,20221214_125409_om_and_comp_6_1_top_1_base_2_v...,dio_ECU_Din1,3008051.0,20221214_125409_om_and_comp_6_1_and_6_3.1.vide...,3002.0,1.0,6_1_top_1_base_2_vs_6_3,...,15,14,13,31,21.0,15.0,14.0,13.0,16.0,ZScoreRecording: 32 channels - 1 segments - 1....
2,7759946.0,1.0,20221214_125409_om_and_comp_6_1_and_6_3,20221214_125409_om_and_comp_6_1_top_1_base_2_v...,dio_ECU_Din1,4808043.0,20221214_125409_om_and_comp_6_1_and_6_3.1.vide...,4798.0,1.0,6_1_top_1_base_2_vs_6_3,...,15,14,13,31,21.0,15.0,14.0,13.0,16.0,ZScoreRecording: 32 channels - 1 segments - 1....
3,9359945.0,1.0,20221214_125409_om_and_comp_6_1_and_6_3,20221214_125409_om_and_comp_6_1_top_1_base_2_v...,dio_ECU_Din1,6408042.0,20221214_125409_om_and_comp_6_1_and_6_3.1.vide...,6395.0,1.0,6_1_top_1_base_2_vs_6_3,...,15,14,13,31,21.0,15.0,14.0,13.0,16.0,ZScoreRecording: 32 channels - 1 segments - 1....
4,10859943.0,1.0,20221214_125409_om_and_comp_6_1_and_6_3,20221214_125409_om_and_comp_6_1_top_1_base_2_v...,dio_ECU_Din1,7908040.0,20221214_125409_om_and_comp_6_1_and_6_3.1.vide...,7892.0,1.0,6_1_top_1_base_2_vs_6_3,...,15,14,13,31,21.0,15.0,14.0,13.0,16.0,ZScoreRecording: 32 channels - 1 segments - 1....


- Getting the LFP indexes for each trial before and after the tone

In [49]:
channel_map_and_all_trials_df["resampled_index"] = channel_map_and_all_trials_df["resampled_index"].astype(int)

In [50]:
# Getting the LFP index for the trial portion
trial_channel_map_and_all_trials_df = channel_map_and_all_trials_df.copy()
trial_channel_map_and_all_trials_df["trial_or_baseline"] = "trial"
trial_channel_map_and_all_trials_df["trial_or_baseline_entire_lfp_index"] = trial_channel_map_and_all_trials_df["resampled_index"].apply(lambda x: (x, x+RESAMPLE_RATE*TRIAL_DURATION,))

# Getting the LFP index for the baseline portion
baseline_channel_map_and_all_trials_df = channel_map_and_all_trials_df.copy()
baseline_channel_map_and_all_trials_df["trial_or_baseline"] = "baseline"
baseline_channel_map_and_all_trials_df["trial_or_baseline_entire_lfp_index"] = baseline_channel_map_and_all_trials_df["resampled_index"].apply(lambda x: (x-RESAMPLE_RATE*TRIAL_DURATION, x))

- Combining the dataframe for the trial and baseline portions

In [51]:
channel_map_and_all_trials_df = pd.concat([trial_channel_map_and_all_trials_df, baseline_channel_map_and_all_trials_df])

In [52]:
channel_map_and_all_trials_df["trial_or_baseline_entire_lfp_index"].head()

0      (70402, 80402)
1    (150402, 160402)
2    (240402, 250402)
3    (320402, 330402)
4    (395402, 405402)
Name: trial_or_baseline_entire_lfp_index, dtype: object

- Getting the LFP for each brain region

In [53]:
channel_columns = sorted([col for col in channel_map_and_all_trials_df.columns if "spike" in col])

In [54]:
channel_columns

['spike_interface_BLA',
 'spike_interface_LH',
 'spike_interface_MD',
 'spike_interface_mPFC',
 'spike_interface_vHPC']

In [55]:
for col in channel_columns:
    print(col)
    channel_map_and_all_trials_df["{}_trace".format(col.strip("spike_interface").strip("_"))] = channel_map_and_all_trials_df.apply(lambda x: 
x["all_ch_lfp"].get_traces(channel_ids=[str(int(x[col]))], start_frame=x["trial_or_baseline_entire_lfp_index"][0], end_frame=x["trial_or_baseline_entire_lfp_index"][-1] ).T[0][:RESAMPLE_RATE*TRIAL_DURATION], axis=1)

spike_interface_BLA
spike_interface_LH
spike_interface_MD
spike_interface_mPFC
spike_interface_vHPC


## Power calculation for each trial for each brain region

In [56]:
raise ValueError()

ValueError: 

In [57]:
trace_columns = [col for col in channel_map_and_all_trials_df.columns if "trace" in col]

In [58]:
trace_columns

['BLA_trace', 'LH_trace', 'MD_trace', 'mPFC_trace', 'vHPC_trace']

## Power correlation between brain regions calculation

- Combining the trial/baseline and outcome label for coloring

In [59]:
channel_map_and_all_trials_df["outcome_and_trial_or_baseline"] = channel_map_and_all_trials_df.apply(lambda x: "_".join([x["trial_outcome"], x["trial_or_baseline"]]), axis=1)

In [60]:
from statsmodels.tsa.stattools import grangercausalitytests



In [61]:
channel_map_and_all_trials_df.columns

Index(['time', 'state', 'recording_dir', 'recording_file', 'din',
       'time_stamp_index', 'video_file', 'video_frame', 'video_number',
       'subject_info', 'condition', 'competition_closeness', 'Unnamed: 13',
       'resampled_index', 'all_subjects', 'current_subject', 'trial_outcome',
       'Cohort', 'Subject', 'eib_mPFC', 'eib_vHPC', 'eib_BLA', 'eib_LH',
       'eib_MD', 'spike_interface_mPFC', 'spike_interface_vHPC',
       'spike_interface_BLA', 'spike_interface_LH', 'spike_interface_MD',
       'all_ch_lfp', 'trial_or_baseline', 'trial_or_baseline_entire_lfp_index',
       'BLA_trace', 'LH_trace', 'MD_trace', 'mPFC_trace', 'vHPC_trace',
       'outcome_and_trial_or_baseline'],
      dtype='object')

In [62]:
trace_columns = [col for col in channel_map_and_all_trials_df.columns if "trace" in col]

In [63]:
trace_columns

['BLA_trace', 'LH_trace', 'MD_trace', 'mPFC_trace', 'vHPC_trace']

In [64]:
brain_region_pairs = generate_pairs(trace_columns)

In [65]:
brain_region_pairs

[('BLA_trace', 'LH_trace'),
 ('BLA_trace', 'MD_trace'),
 ('BLA_trace', 'mPFC_trace'),
 ('BLA_trace', 'vHPC_trace'),
 ('LH_trace', 'MD_trace'),
 ('LH_trace', 'mPFC_trace'),
 ('LH_trace', 'vHPC_trace'),
 ('MD_trace', 'mPFC_trace'),
 ('MD_trace', 'vHPC_trace'),
 ('mPFC_trace', 'vHPC_trace')]

In [None]:
grangercausalitytests(df[['column1', 'column2']], maxlag=[3])

In [72]:
channel_map_and_all_trials_df[region_1].iloc[0]

array([ 0.02984513,  0.19697787,  0.42579055, ..., -1.5957198 ,
       -1.6693377 , -1.484298  ], dtype=float32)

In [None]:
def granger

In [77]:
for region_1, region_2 in brain_region_pairs:
    pair_base_name = "{}_{}".format(region_1.strip("trace").strip("_"), region_2.strip("trace").strip("_"))
    print(pair_base_name)
    try:
        
        # granger_value = grangercausalitytests(channel_map_and_all_trials_df[[region_1, region_2]], maxlag=[3])
        channel_map_and_all_trials_df["{}_granger".format(pair_base_name)] = channel_map_and_all_trials_df.apply(lambda row: grangercausalitytests(np.array([row[region_1], row[region_2]]).T, maxlag=[3]), axis=1)
        print()
    except Exception as e: 
        print(e)
    break

BLA_LH

Granger Causality
number of lags (no zero) 3
ssr based F test:         F=628.0128, p=0.0000  , df_denom=9990, df_num=3
ssr based chi2 test:   chi2=1885.3585, p=0.0000  , df=3
likelihood ratio test: chi2=1727.1794, p=0.0000  , df=3
parameter F test:         F=628.0128, p=0.0000  , df_denom=9990, df_num=3

Granger Causality
number of lags (no zero) 3
ssr based F test:         F=558.7703, p=0.0000  , df_denom=9990, df_num=3
ssr based chi2 test:   chi2=1677.4856, p=0.0000  , df=3
likelihood ratio test: chi2=1550.7416, p=0.0000  , df=3
parameter F test:         F=558.7703, p=0.0000  , df_denom=9990, df_num=3

Granger Causality
number of lags (no zero) 3
ssr based F test:         F=426.7804, p=0.0000  , df_denom=9990, df_num=3
ssr based chi2 test:   chi2=1281.2382, p=0.0000  , df=3
likelihood ratio test: chi2=1205.5382, p=0.0000  , df=3
parameter F test:         F=426.7804, p=0.0000  , df_denom=9990, df_num=3

Granger Causality
number of lags (no zero) 3
ssr based F test:         F=4

In [76]:
channel_map_and_all_trials_df[""]

0      [0.029845133, 0.19697787, 0.42579055, 0.869488...
1      [-0.5013982, -0.55511945, 0.105452806, 0.65858...
2      [-1.7310177, -1.96182, -1.9240162, -1.9379439,...
3      [-0.60287166, -0.23279203, 0.055710915, -0.240...
4      [-0.49741888, -0.11142183, -0.027855458, -0.44...
                             ...                        
632    [1.2422042, 2.207114, 2.5181334, 2.0459833, 1....
633    [-1.4220709, -1.6881236, -1.6899972, -1.588822...
634    [0.17049861, 0.19485556, 0.31101945, 0.7026042...
635    [0.2192125, 0.09742778, -0.41406804, -0.861861...
636    [-1.28155, -1.3433791, -1.1297874, -0.9424264,...
Name: BLA_trace, Length: 1274, dtype: object

In [75]:
channel_map_and_all_trials_df

Unnamed: 0,time,state,recording_dir,recording_file,din,time_stamp_index,video_file,video_frame,video_number,subject_info,...,spike_interface_MD,all_ch_lfp,trial_or_baseline,trial_or_baseline_entire_lfp_index,BLA_trace,LH_trace,MD_trace,mPFC_trace,vHPC_trace,outcome_and_trial_or_baseline
0,4359951.0,1.0,20221214_125409_om_and_comp_6_1_and_6_3,20221214_125409_om_and_comp_6_1_top_1_base_2_v...,dio_ECU_Din1,1408048.0,20221214_125409_om_and_comp_6_1_and_6_3.1.vide...,1405.0,1.0,6_1_top_1_base_2_vs_6_3,...,16.0,ZScoreRecording: 32 channels - 1 segments - 1....,trial,"(70402, 80402)","[0.029845133, 0.19697787, 0.42579055, 0.869488...","[-1.2876818, -1.2059243, -1.2297702, -1.073068...","[-0.6062626, -0.34118676, -0.020996109, 0.4094...","[-1.2641292, -1.3847351, -1.3154984, -0.913478...","[-0.78455, -0.7881, -0.6674, -0.5029167, -0.30...",rewarded_trial
1,5959954.0,1.0,20221214_125409_om_and_comp_6_1_and_6_3,20221214_125409_om_and_comp_6_1_top_1_base_2_v...,dio_ECU_Din1,3008051.0,20221214_125409_om_and_comp_6_1_and_6_3.1.vide...,3002.0,1.0,6_1_top_1_base_2_vs_6_3,...,16.0,ZScoreRecording: 32 channels - 1 segments - 1....,trial,"(150402, 160402)","[-0.5013982, -0.55511945, 0.105452806, 0.65858...","[-0.9027399, -1.2536162, -0.5382374, 0.2078005...","[-1.0498054, -1.2781382, -0.76898247, -0.47503...","[-1.0474851, -1.7934552, -1.9832981, -1.876092...","[-1.4981, -1.7051834, -1.5123, -1.2365834, -1....",rewarded_trial
2,7759946.0,1.0,20221214_125409_om_and_comp_6_1_and_6_3,20221214_125409_om_and_comp_6_1_top_1_base_2_v...,dio_ECU_Din1,4808043.0,20221214_125409_om_and_comp_6_1_and_6_3.1.vide...,4798.0,1.0,6_1_top_1_base_2_vs_6_3,...,16.0,ZScoreRecording: 32 channels - 1 segments - 1....,trial,"(240402, 250402)","[-1.7310177, -1.96182, -1.9240162, -1.9379439,...","[-2.3266842, -2.7286592, -2.3266842, -1.972401...","[-1.5878308, -1.7689222, -1.5852063, -1.658692...","[-1.8693924, -1.8693924, -1.404836, -1.1167219...","[-1.6412834, -1.6341833, -1.45905, -1.3939667,...",rewarded_trial
3,9359945.0,1.0,20221214_125409_om_and_comp_6_1_and_6_3,20221214_125409_om_and_comp_6_1_top_1_base_2_v...,dio_ECU_Din1,6408042.0,20221214_125409_om_and_comp_6_1_and_6_3.1.vide...,6395.0,1.0,6_1_top_1_base_2_vs_6_3,...,16.0,ZScoreRecording: 32 channels - 1 segments - 1....,trial,"(320402, 330402)","[-0.60287166, -0.23279203, 0.055710915, -0.240...","[-2.759318, -2.1631691, -1.7475681, -2.1529496...","[-2.4670427, -2.4145525, -2.2360857, -2.493288...","[-2.5952616, -2.0860364, -1.5879785, -1.748786...","[-0.98808336, -0.8756667, -0.65675, -0.5845666...",omission_trial
4,10859943.0,1.0,20221214_125409_om_and_comp_6_1_and_6_3,20221214_125409_om_and_comp_6_1_top_1_base_2_v...,dio_ECU_Din1,7908040.0,20221214_125409_om_and_comp_6_1_and_6_3.1.vide...,7892.0,1.0,6_1_top_1_base_2_vs_6_3,...,16.0,ZScoreRecording: 32 channels - 1 segments - 1....,trial,"(395402, 405402)","[-0.49741888, -0.11142183, -0.027855458, -0.44...","[0.14988889, 0.5859293, 0.7596641, 0.27252525,...","[-0.59314007, 0.1338502, 0.47503698, -0.002624...","[0.5672947, 0.8576424, 0.92911255, 0.39531955,...","[0.7798167, 0.7526, 0.57865, 0.29228333, 0.196...",rewarded_trial
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
632,29281401.0,1.0,20230621_111240_standard_comp_to_omission_D5_s...,20230621_111240_standard_comp_to_omission_D5_s...,dio_ECU_Din1,26088357.0,20230621_111240_standard_comp_to_omission_D5_s...,26040.0,2.0,1-4_t3b3L_box1,...,28.0,ZScoreRecording: 32 channels - 1 segments - 1....,baseline,"(1294417, 1304417)","[1.2422042, 2.207114, 2.5181334, 2.0459833, 1....","[-0.72688836, -0.42565534, 0.3045073, 0.124422...","[-0.8399434, -0.62995756, 0.14635377, -0.00636...","[-1.5326667, -1.3743334, -0.7758333, -1.095666...","[-0.57296234, -0.91625625, -0.85098207, -0.863...",lose_baseline
633,30881425.0,1.0,20230621_111240_standard_comp_to_omission_D5_s...,20230621_111240_standard_comp_to_omission_D5_s...,dio_ECU_Din1,27688381.0,20230621_111240_standard_comp_to_omission_D5_s...,27636.0,2.0,1-4_t3b3L_box1,...,28.0,ZScoreRecording: 32 channels - 1 segments - 1....,baseline,"(1374419, 1384419)","[-1.4220709, -1.6881236, -1.6899972, -1.588822...","[-0.03601699, -0.36016992, -0.5697233, -0.6450...","[-0.3404316, -0.71267927, -0.9640259, -1.03083...","[-0.09183333, -0.47183332, -0.817, -1.0735, -1...","[-0.47142473, -0.82438886, -1.0153763, -1.1217...",lose_baseline
634,32281440.0,1.0,20230621_111240_standard_comp_to_omission_D5_s...,20230621_111240_standard_comp_to_omission_D5_s...,dio_ECU_Din1,29088396.0,20230621_111240_standard_comp_to_omission_D5_s...,29033.0,2.0,1-4_t3b3L_box1,...,28.0,ZScoreRecording: 32 channels - 1 segments - 1....,baseline,"(1444419, 1454419)","[0.17049861, 0.19485556, 0.31101945, 0.7026042...","[1.1918349, 1.3784684, 1.4897937, 1.3948398, 1...","[1.3521816, 1.4985354, 1.5748938, 1.4794457, 1...","[1.5706667, 1.4028333, 1.1653334, 1.1843333, 1...","[-0.58505017, -0.44483155, -0.2465914, -0.0531...",lose_baseline
635,34481464.0,1.0,20230621_111240_standard_comp_to_omission_D5_s...,20230621_111240_standard_comp_to_omission_D5_s...,dio_ECU_Din1,31288420.0,20230621_111240_standard_comp_to_omission_D5_s...,31230.0,2.0,1-4_t3b3L_box1,...,28.0,ZScoreRecording: 32 channels - 1 segments - 1....,baseline,"(1554421, 1564421)","[0.2192125, 0.09742778, -0.41406804, -0.861861...","[-0.68432283, -1.01175, -1.0313957, -0.5893689...","[-0.65541035, -0.9385731, -1.0531108, -0.85585...","[1.7036667, 1.2951666, 0.836, 0.95316666, 1.16...","[-0.43757886, -0.3747222, -0.27801973, -0.2054...",lose_baseline


In [67]:
granger_value

NameError: name 'granger_value' is not defined

- Filtering out the outliers

In [None]:
band_to_power_correlation = defaultdict(dict)
for band in ALL_BANDS:
    # Getting all the pairs of brain regions
    band_averaged_columns = [col for col in channel_map_and_all_trials_df.columns if "averaged_{}".format(band) in col]
    band_to_power_correlation[band]["brain_region_pairs"] = generate_pairs(band_averaged_columns)
    print(band_to_power_correlation[band]["brain_region_pairs"])

    # Removing rows that are outliers
    filtered_df = channel_map_and_all_trials_df.copy()
    
    for col in band_averaged_columns:
        # filtered_df = filtered_df[filtered_df[col] <= 3]
        # Assuming data is a 1D numpy array
        Q1 = np.percentile(filtered_df[col], 25)
        Q3 = np.percentile(filtered_df[col], 75)
        IQR = Q3 - Q1
        band_to_power_correlation[band]["outlier_removed_df"] = filtered_df[(filtered_df[col] >= Q1 - 1.5 * IQR) & (filtered_df[col] <= Q3 + 1.5 * IQR)]


    
    # Getting the mean and standard deviation
    

In [None]:
band_to_power_correlation = defaultdict(dict)
for band in ALL_BANDS:
    # Getting all the pairs of brain regions
    band_averaged_columns = [col for col in channel_map_and_all_trials_df.columns if "averaged_{}".format(band) in col]
    band_to_power_correlation[band]["brain_region_pairs"] = generate_pairs(band_averaged_columns)
    print(band_to_power_correlation[band]["brain_region_pairs"])

    # Removing rows that are outliers
    filtered_df = channel_map_and_all_trials_df.copy()
    
    for col in band_averaged_columns:
        # filtered_df = filtered_df[filtered_df[col] <= 3]
        # Assuming data is a 1D numpy array
        Q1 = np.percentile(filtered_df[col], 25)
        Q3 = np.percentile(filtered_df[col], 75)
        IQR = Q3 - Q1
        filtered_df = filtered_df[(filtered_df[col] >= Q1 - 1.5 * IQR) & (filtered_df[col] <= Q3 + 1.5 * IQR)]
    band_to_power_correlation[band]["outlier_removed_df"] = filtered_df

In [None]:
channel_map_and_all_trials_df.shape

In [None]:
band_to_power_correlation[band]["outlier_removed_df"].shape

- Plotting all of the conditions

In [None]:
for band in ALL_BANDS:
    for region_1, region_2 in band_to_power_correlation[band]["brain_region_pairs"]:
        region_1_basename = region_1.split("_")[0]
        region_2_basename = region_2.split("_")[0]
        x = band_to_power_correlation[band]["outlier_removed_df"][region_1]
        y = band_to_power_correlation[band]["outlier_removed_df"][region_2]
        
        # Perform linear regression to get the slope, intercept and r-value (correlation coefficient)
        slope, intercept, r_value, p_value, std_err = linregress(x, y)
        
        # Create a line of best fit using the slope and intercept
        line = slope * x + intercept
        
        # Create scatter plot
        sns.scatterplot(x=x, y=y, data=band_to_power_correlation[band]["outlier_removed_df"], hue='outcome_and_trial_or_baseline', palette=BASELINE_OUTCOME_TO_COLOR)
        
        # Plot line of best fit
        plt.plot(x, line, color='red')
        
        # Add R² value to the plot
        plt.text(0.1, 0.9, f'R = {r_value:.2f}', transform=plt.gca().transAxes)
        
        # Add labels and legend
        plt.title("Power correlation of Z-scored {} band LFP: {} and {}".format(band, region_2_basename, region_1_basename))
        plt.xlabel('{} {} power of Z-scored LFP'.format(band, region_1_basename))
        plt.ylabel('{} {} power of Z-scored LFP'.format(band, region_2_basename))
        plt.legend(loc="lower right")
        plt.tight_layout()
        plt.savefig("./proc/power_correlation/zscored/{}/all_condition_{}_{}_power_correlation_of_zscored_{}_lfp.png".format(band, region_1_basename, region_2_basename, band))
        # Display the plot
        plt.show()




In [None]:
raise ValueError()

In [None]:
channel_map_and_all_trials_df = filtered_df

In [None]:
channel_map_and_all_trials_df["trial_outcome"].unique()

In [None]:
channel_map_and_all_trials_df["trial_or_baseline"]

In [None]:
for band in ALL_BANDS:
    band_df = band_to_power_correlation[band]["outlier_removed_df"]
    band_to_power_correlation[band]["region_pair_to_outcome_to_r2"] = defaultdict(nested_dict)
    for outcome in band_df["trial_outcome"].unique():
        outcome_df = band_df[band_df["trial_outcome"] == outcome]
        for region_1, region_2 in brain_region_pairs:
            region_1_basename = region_1.split("_")[0]
            region_2_basename = region_2.split("_")[0]
            
            x = outcome_df[region_1]
            y = outcome_df[region_2]
            
            # Perform linear regression to get the slope, intercept and r-value (correlation coefficient)
            slope, intercept, r_value, p_value, std_err = linregress(x, y)
            # Square the r value to get the r squared value
            r2_value = r_value**2
            band_to_power_correlation[band]["region_pair_to_outcome_to_r2"]["{}_{}".format(region_1.split("_")[0], region_2.split("_")[0])][outcome]["r"] = r_value
            band_to_power_correlation[band]["region_pair_to_outcome_to_r2"]["{}_{}".format(region_1.split("_")[0], region_2.split("_")[0])][outcome]["std"] = std_err
            
            # Create a line of best fit using the slope and intercept
            line = slope * x + intercept
            
            # Create scatter plot
            sns.scatterplot(x=x, y=y, data=outcome_df, hue='outcome_and_trial_or_baseline', palette=BASELINE_OUTCOME_TO_COLOR, style='outcome_and_trial_or_baseline', markers=['^', 'o'])
            
            # Plot line of best fit
            plt.plot(x, line, color='red')
            
            # Add R² value to the plot
            plt.text(0.1, 0.9, f'R = {r_value:.2f}', transform=plt.gca().transAxes)
            
            # Add labels and legend
            plt.title("Power Correlation of Z-scored {} LFP: {} and {}".format(band, region_2_basename, region_1_basename))
            plt.xlabel('{} {} Power of Z-scored LFP'.format(region_1_basename, band))
            plt.ylabel('{} {} Power of Z-scored LFP'.format(region_2_basename, band))
            plt.legend(loc="lower right")
            plt.tight_layout()
            plt.savefig("./proc/power_correlation/zscored/{}/{}_{}_{}_power_correlation_of_zscored_{}_lfp.png".format(band, outcome, region_1_basename, region_2_basename, band))
            # Display the plot
            plt.show()

In [None]:
for band in ALL_BANDS:
    # Convert the nested dictionary to a DataFrame
    data = []
    for group_name, group_data in band_to_power_correlation[band]['region_pair_to_outcome_to_r2'].items():
        for bar_name, bar_dict in group_data.items():
            data.append({"Group": group_name, "Bar": bar_name, "r": bar_dict["r"], "std": bar_dict["std"]})
    df = pd.DataFrame(data)
    
    # Create the bar plot using seaborn
    # sns.catplot(
    #     data=df, 
    #     x='Group', 
    #     y='r2', 
    #     hue='Bar', 
    #     kind='bar', 
    #     height=4, 
    #     aspect=2,
    #     legend=False,
    #     # yerr=df['std'].values,  # This line adds the SEM bars
    #     # capsize=0.1  # This line adds caps on the error bars
    # )
    
    # Create barplot
    ax = sns.barplot(x='Group', y='r', hue='Bar', data=df, ci=None)
    
    # Adding error bars
    groups = df['Group'].unique()
    bars_per_group = df['Bar'].nunique()
    bar_width = 0.8 / bars_per_group
    x_positions = []
    
    for i, group in enumerate(groups):
        num_bars = df[df['Group'] == group].shape[0]
        group_positions = np.linspace(i - bar_width*(num_bars-1)/2, i + bar_width*(num_bars-1)/2, num_bars)
        x_positions.extend(group_positions)
    
    for i, (r2, sem) in enumerate(zip(df['r'], df['std'])):
        plt.errorbar(x_positions[i], r2, yerr=sem, fmt='none', color='black', capsize=5)
    
    
    plt.xticks(rotation=90)
    plt.xlabel("Brain region pairs")
    plt.ylabel("Power correlation r")
    plt.legend(title="Trial Conditions")
    plt.title("{} Power correlations".format(band))
    plt.tight_layout()
    plt.grid()
    
    plt.savefig("./proc/power_correlation/zscored/all_zscored_{}_lfp_power_correlation.png".format(band))
    # Show the plot
    plt.show()

In [None]:

# Convert the nested dictionary to a DataFrame
data = []
for group_name, group_data in region_pair_to_outcome_to_r2.items():
    for bar_name, bar_dict in group_data.items():
        data.append({"Group": group_name, "Bar": bar_name, "r": bar_dict["r"], "std": bar_dict["std"]})
df = pd.DataFrame(data)

# Create the bar plot using seaborn
# sns.catplot(
#     data=df, 
#     x='Group', 
#     y='r2', 
#     hue='Bar', 
#     kind='bar', 
#     height=4, 
#     aspect=2,
#     legend=False,
#     # yerr=df['std'].values,  # This line adds the SEM bars
#     # capsize=0.1  # This line adds caps on the error bars
# )

# Create barplot
ax = sns.barplot(x='Group', y='r', hue='Bar', data=df, ci=None)

# Adding error bars
groups = df['Group'].unique()
bars_per_group = df['Bar'].nunique()
bar_width = 0.8 / bars_per_group
x_positions = []

for i, group in enumerate(groups):
    num_bars = df[df['Group'] == group].shape[0]
    group_positions = np.linspace(i - bar_width*(num_bars-1)/2, i + bar_width*(num_bars-1)/2, num_bars)
    x_positions.extend(group_positions)

for i, (r2, sem) in enumerate(zip(df['r'], df['std'])):
    plt.errorbar(x_positions[i], r2, yerr=sem, fmt='none', color='black', capsize=5)


plt.xticks(rotation=90)
plt.xlabel("Brain region pairs")
plt.ylabel("Power correlation r")
plt.legend(title="Trial Conditions")
plt.title("Power correlations")
plt.tight_layout()
plt.grid()

plt.savefig("./proc/power_correlation/zscored/all_zscored_lfp_power_correlation.png")
# Show the plot
plt.show()