In [1]:
import pandas as pd
import numpy as np
import ast
import pickle
from pathlib import Path
import matplotlib.pyplot as plt
import multirecording_spikeanalysis_edit as spike

cols = ['condition ', 'session_dir', 'all_subjects', 'tone_start_timestamp', 'tone_stop_timestamp']

# Load the data
df = pd.read_excel('rce_pilot_2_per_video_trial_labels.xlsx', usecols=cols, engine='openpyxl')

df2 = df.dropna() # Drop the rows missing data
df3 = df2.copy()
df3['all_subjects'] = df3['all_subjects'].apply(lambda x: ast.literal_eval(x) if isinstance(x, str) else x) # Make the 'all_subjects' column readable as a list
df4 = df3[df3['all_subjects'].apply(lambda x: len(x) < 3)] # Ignore novel sessions for now

# Initialize an empty list to collect data for the new DataFrame
new_df_data = []

for _, row in df4.iterrows():
    session_dir = row['session_dir']
    subjects = row['all_subjects']
    condition = row['condition ']

    # Split session_dir on '_subj_' and take the first part only
    # This ensures everything after '_subj_' is ignored
    base_session_dir = session_dir.split('_subj_')[0]

    for subject in subjects:
        subject_formatted = subject.replace('.', '-')
        # Append formatted subject to the base session_dir correctly
        subj_recording = f"{base_session_dir}_subj_{subject_formatted}"
        new_df_data.append({
            'session_dir': session_dir,
            'subject': subject,
            'subj_recording': subj_recording,
            'condition': condition if condition in ['rewarded', 'omission', 'both_rewarded', 'tie'] else ('win' if str(condition) == str(subject) else 'lose'),
            'tone_start_timestamp': row['tone_start_timestamp'],
            'tone_stop_timestamp': row['tone_stop_timestamp']
        })

# Convert list to DataFrame
new_df = pd.DataFrame(new_df_data)
new_df = new_df.drop_duplicates()

# Prepare timestamp_dicts from new_df
timestamp_dicts = {}
for _, row in new_df.iterrows():
    key = row['subj_recording']
    condition = row['condition']
    timestamp_start = int(row['tone_start_timestamp']) // 20
    timestamp_end = int(row['tone_stop_timestamp']) // 20
    tuple_val = (timestamp_start, timestamp_end)

    if key not in timestamp_dicts:
        timestamp_dicts[key] = {cond: [] for cond in ['rewarded', 'win', 'lose', 'omission', 'both_rewarded', 'tie']}
    timestamp_dicts[key][condition].append(tuple_val)

# Convert lists in timestamp_dicts to numpy arrays
for subj_recording in timestamp_dicts:
    for condition in timestamp_dicts[subj_recording]:
        timestamp_dicts[subj_recording][condition] = np.array(timestamp_dicts[subj_recording][condition], dtype=np.int64)
        

# Construct the path in a platform-independent way (HiPerGator or Windows)
ephys_path = Path('.') / 'export' / 'updated_phys' / 'test'

ephys_data = spike.EphysRecordingCollection(str(ephys_path))

<class 'numpy.ndarray'>
20230612_101430_standard_comp_to_training_D1_subj_1-3_t3b3L_box2_merged.rec
<class 'numpy.ndarray'>
20230617_115521_standard_comp_to_omission_D1_subj_1-1_t1b3L_box1_merged.rec
<class 'numpy.ndarray'>
Unit 92 is unsorted & has 2494 spikes
Unit 92 will be deleted
20230622_110832_standard_comp_to_both_rewarded_D1_subj_1-1_t1b3L_box1_merged.rec
<class 'numpy.ndarray'>
Unit 103 is unsorted & has 512 spikes
Unit 103 will be deleted
20230622_110832_standard_comp_to_both_rewarded_D1_subj_1-2_t3b3L_box1_merged.rec
Please assign event dictionaries to each recording
as recording.event_dict
event_dict = {event name(str): np.array[[start(ms), stop(ms)]...]
Please assign subjects to each recording as recording.subject


In [2]:
for recording in ephys_data.collection.keys():
    # Check if the recording key (without everything after subject #) is in timestamp_dicts
    start_pos = recording.find('subj_')
    # Add the length of 'subj_' and 3 additional characters to include after 'subj_'
    end_pos = start_pos + len('subj_') + 3
    # Slice the recording key to get everything up to and including the subject identifier plus three characters
    recording_key_without_suffix = recording[:end_pos]
    if recording_key_without_suffix in timestamp_dicts:
        # Assign the corresponding timestamp_dicts dictionary to event_dict
        ephys_data.collection[recording].event_dict = timestamp_dicts[recording_key_without_suffix]
        
        # Extract the subject from the recording key
        start = recording.find('subj_') + 5  # Start index after 'subj_'
        subject = recording[start:start+3]
        
        # Assign the extracted subject
        ephys_data.collection[recording].subject = subject
        
spike_analysis = spike.SpikeAnalysis_MultiRecording(ephys_data, timebin = 100, smoothing_window=250, ignore_freq = 0.5)

All set to analyze


In [3]:
import statsmodels.api as sm
from statsmodels.discrete.discrete_model import Poisson

def analyze_firing_rates(ephyscollection, event, timebin, baseline_window, pre_window=0, post_window=0, equalize=0.5):
  """
  Analyzes firing rate changes during events using Poisson GLM.

  Args:
      ephyscollection: An EphysRecordingCollection object.
      event: The event name (str) or event data (numpy array).
      timebin: Time bin size in seconds.
      baseline_window: Baseline window size in seconds.
      pre_window: Pre-event window size in seconds (default: 0).
      post_window: Post-event window size in seconds (default: 0).
      equalize: Time window size (in seconds) to equalize event durations (default: 0.5).

  Returns:
      A dictionary containing unit IDs as keys and p-values as values.
  """

  unit_pvals = {}
  for recording_name, recording in ephyscollection.collection.items():
    # Get unit spiketrains and baseline firing rates
    recording.__get_unit_spiketrains__(timebin)
    baseline_rates = recording.__calc_preevent_baseline__(baseline_window, baseline_window, 0, event)

    # Prepare data for GLM
    for unit, spiketrain in recording.unit_spiketrains.items():
      if len(spiketrain) == 0:
        continue
      event_rates = recording.__get_unit_event_firing_rates__(unit, event, equalize, pre_window, post_window)
      response = np.concatenate([*event_rates])
      design = np.ones((len(response), 2))
      design[:, 1] = np.concatenate([*baseline_rates[unit]])

      # Fit Poisson GLM
      glm_model = Poisson(response, design)
      glm_result = glm_model.fit()

      # Get p-value for baseline effect
      unit_pvals[f"{recording_name}-{unit}"] = glm_result.pvalues[1]

  return unit_pvals

# Example usage
event_data = np.load("event_times.npy")  # Assuming event data is loaded
pvals = analyze_firing_rates(ephyscollection, event_data, 0.001, 0.2)

# Print units with significant changes (adjust for significance level)
for unit, pval in pvals.items():
  if pval < 0.05:
    print(f"Unit {unit} has significant change in firing rate during event (p-value: {pval})")

FileNotFoundError: [Errno 2] No such file or directory: 'event_times.npy'