In [1]:
from loren_frank_data_processing import (make_epochs_dataframe,
                                         make_neuron_dataframe)
from src.parameters import (ANIMALS, MIN_N_NEURONS, _BRAIN_AREAS)

from tqdm.auto import tqdm
import numpy as np 
import pickle
import xarray as xr
import os 
import pandas as pd

In [2]:
def get_sweep_results(epoch_key, datadir, sweep_speed_threshold=4):
    '''
    Get the decoded sweep distances of one epoch, restricted to running samples.

    Parameters
    ----------
    epoch_key : tuple
        (animal, day, epoch). `day` and `epoch` are formatted as zero-padded
        two-digit integers when building the file names.
    datadir : str
        Directory that contains the 'ThetaSweepTrajectories' folder.
    sweep_speed_threshold : float, optional
        A sample counts as running when its speed is strictly greater than
        this value (default 4). Units follow the pickled speed
        (presumably cm/s -- TODO confirm).

    Returns
    -------
    numpy.ndarray
        Values of `mental_distance_from_actual_position` at the running samples.
    '''
    animal, day, epoch = epoch_key
    sweep_dir = os.path.join(datadir, 'ThetaSweepTrajectories')

    # 1, LOAD THETA SWEEPS RESULTS
    # Open the classifier results in a context manager so the file handle is
    # released even when this function is called once per epoch in a long loop.
    with xr.open_dataset(os.path.join(sweep_dir, f'{animal}_{day:02d}_{epoch:02d}_cv_classifier_clusterless_results.nc')) as cv_classifier_clusterless_results:
        # NOTE(review): pickle.load can execute arbitrary code; only use it on
        # files produced by this project's own pipeline.
        with open(os.path.join(sweep_dir, f'{animal}_{day:02d}_{epoch:02d}_speed_info.pkl'), 'rb') as f:
            speed_dic = pickle.load(f)
        # presumably a pandas object, since `.values` is read as an attribute -- TODO confirm
        speed = speed_dic.values

        # 2, KEEP ONLY SAMPLES RECORDED WHILE RUNNING
        is_running = speed > sweep_speed_threshold

        # `.values` reads the selected data into memory before the dataset is closed
        sweeps_dist_in_running = (
            cv_classifier_clusterless_results
            .mental_distance_from_actual_position[is_running]
            .values)

    return sweeps_dist_in_running

In [3]:
# Build the epoch- and neuron-level tables for every animal in ANIMALS.
epoch_info = make_epochs_dataframe(ANIMALS)
neuron_info = make_neuron_dataframe(ANIMALS)

# Restrict to 'principal' units with more than 100 spikes whose area is listed in _BRAIN_AREAS.
is_principal = neuron_info.type == 'principal'
has_enough_spikes = neuron_info.numspikes > 100
in_selected_area = neuron_info.area.isin(_BRAIN_AREAS)
neuron_info = neuron_info.loc[is_principal & has_enough_spikes & in_selected_area]

# Number of retained neurons per (animal, day, epoch), as a one-column frame.
neurons_by_epoch = neuron_info.groupby(['animal', 'day', 'epoch'])
n_neurons = neurons_by_epoch.neuron_id.agg(len).rename('n_neurons').to_frame()

# Attach the per-epoch neuron counts to the epoch table.
epoch_info = epoch_info.join(n_neurons)

# Epoch-level selection masks: environment, animal, and minimum neuron count.
w_track_environments = ['TrackA', 'TrackB', 'WTrackA', 'WTrackB']
selected_animals = ['bon', 'fra', 'gov', 'dud', 'con', 'Cor', 'dav', 'egy', 'cha']
is_w_track = epoch_info.environment.isin(w_track_environments)
is_animal = epoch_info.index.isin(selected_animals, level='animal')

valid_epochs = is_w_track & (epoch_info.n_neurons > MIN_N_NEURONS) & is_animal

In [4]:
# NOTE(review): absolute, machine-specific path; adjust it when running elsewhere.
DATA_DIR = '/home/zilong/Desktop/replay_trajectory_paper/Processed-Data/'

# Epochs excluded from the analysis.
SKIPPED_EPOCHS = {
    ('egy', 10, 2),  # no mental_distance_from_actual_position is saved (double check later)
    ('bon', 4, 2),   # NOTE(review): reason for the exclusion is not recorded -- confirm
}

# one {'animal', 'day', 'epoch', 'median'} record per processed epoch
median_error_records = []
# raw per-epoch sweep-distance arrays (not medians, despite the name)
median_error_list = []

for epoch_key in tqdm(epoch_info[valid_epochs].index, desc='epochs'):
    animal, day, epoch = epoch_key
    print(f'{animal}, {day}, {epoch}')
    if (animal, day, epoch) in SKIPPED_EPOCHS:
        continue

    sweep_dist = get_sweep_results(epoch_key, DATA_DIR, sweep_speed_threshold=4)
    # NaN-aware median of the sweep distances for this epoch
    sweep_dist_median = np.nanmedian(sweep_dist)
    median_error_records.append({'animal': animal, 'day': day, 'epoch': epoch, 'median': sweep_dist_median})

    median_error_list.append(sweep_dist)

# Build the frame once from the collected records: DataFrame.append was
# removed in pandas 2.0, and growing a frame row by row is quadratic.
median_error = pd.DataFrame(median_error_records, columns=['animal', 'day', 'epoch', 'median'])

# turn the pandas dataframe into a pandas series indexed by (animal, day, epoch)
median_error_series = median_error.set_index(['animal', 'day', 'epoch'])['median']

epochs:   0%|          | 0/140 [00:00<?, ?it/s]

Cor, 1, 2
Cor, 1, 4
Cor, 2, 2
Cor, 2, 4
Cor, 3, 2
Cor, 3, 4
Cor, 4, 2
Cor, 4, 4
Cor, 7, 2
Cor, 8, 2
bon, 3, 2
bon, 3, 4
bon, 3, 6
bon, 4, 2
bon, 4, 4
bon, 4, 6
bon, 5, 2
bon, 5, 4
bon, 5, 6
bon, 6, 2
bon, 6, 4
bon, 6, 6
bon, 7, 2
bon, 7, 4
bon, 7, 6
bon, 8, 2
bon, 8, 4
bon, 8, 6
bon, 9, 2
bon, 9, 4
bon, 9, 6
bon, 10, 2
bon, 10, 4
bon, 10, 6
cha, 4, 2
cha, 4, 4
cha, 8, 2
cha, 8, 4
con, 1, 2
con, 1, 4
con, 2, 2
con, 2, 4
con, 3, 2
con, 3, 4
con, 4, 2
con, 4, 4
con, 4, 6
con, 5, 2
con, 5, 4
con, 5, 6
con, 6, 2
con, 6, 4
con, 6, 6
dav, 3, 2
dav, 3, 4
dav, 3, 6
dav, 4, 2
dav, 4, 4
dav, 4, 6
dav, 5, 3
dav, 6, 2
dav, 6, 4
dav, 6, 6
dav, 7, 2
dav, 7, 3
dav, 7, 5
dav, 7, 7
dav, 7, 9
dud, 2, 2
dud, 5, 2
egy, 5, 2
egy, 5, 4
egy, 5, 6
egy, 6, 4
egy, 6, 7
egy, 7, 2
egy, 7, 4
egy, 7, 6
egy, 8, 2
egy, 8, 4
egy, 8, 6
egy, 9, 2
egy, 9, 4
egy, 10, 2
egy, 10, 4
egy, 10, 6
egy, 11, 2
egy, 11, 4
egy, 11, 6
fra, 4, 2
fra, 4, 4
fra, 4, 6
fra, 5, 2
fra, 5, 4
fra, 5, 6
fra, 6, 2
fra, 6, 4
fra, 6, 6
fra, 7, 2
f

In [5]:
# Mean of the per-epoch medians and its 95% confidence interval
# (normal approximation: mean +/- 1.96 * standard error).
mean = median_error_series.mean()
std = median_error_series.std()
# mean() and std() skip NaN entries, so the sample size must count only the
# non-NaN medians; len() would also count NaN rows and shrink the interval.
n_valid = median_error_series.count()
ci = 1.96 * std / np.sqrt(n_valid)

lower_bound = mean - ci
upper_bound = mean + ci

# print the mean with one decimal place
print(f'mean: {mean:.1f}')
# print the 95% confidence interval with one decimal place
print(f'95% confidence interval: {lower_bound:.1f} - {upper_bound:.1f}')


mean: 9.0
95% confidence interval: 8.2 - 9.8


In [6]:
# Pool the per-epoch arrays in median_error_list into one flat array
median_error_list_long = np.concatenate(median_error_list)

# NaN-aware mean, standard deviation and median of the pooled values
pooled_mean = np.nanmean(median_error_list_long)
pooled_std = np.nanstd(median_error_list_long)
pooled_median = np.nanmedian(median_error_list_long)

# report each statistic with one decimal place
print(f'mean: {pooled_mean:.1f}')
print(f'std: {pooled_std:.1f}')
print(f'median: {pooled_median:.1f}')

mean: 29.3
std: 47.5
median: 8.3


In [28]:
# Display the per-epoch arrays collected in median_error_list.
# NOTE(review): this shows the whole list; a short slice such as
# median_error_list[:3] would keep the rendered output compact.
median_error_list

[array([2.45960841, 2.45960841, 2.45960841, ..., 2.45960841, 2.45960841,
        2.45960841]),
 array([2.48575754, 2.48575754, 2.48575754, ..., 1.40452863, 1.40665412,
        1.40877961]),
 array([2.38959485, 2.38959485, 2.38959485, ..., 9.2383378 , 9.22158513,
        9.20483245]),
 array([143.21936382, 143.23951703, 148.19534275, ...,   2.38175824,
          2.38175824,   2.38175824]),
 array([2.48567678, 2.48567678, 2.48567678, ..., 8.06871628, 8.0723426 ,
        8.07596891]),
 array([198.4284643 , 198.4284643 , 193.6060643 , ..., 103.29251612,
        103.28373314, 103.27495016]),
 array([2.41209179, 2.41209179, 2.41209179, ..., 7.02562531, 7.00853929,
        6.99145326]),
 array([185.81122619, 185.81122619, 185.81122619, ...,   5.33898432,
          5.3383815 ,   5.33777868]),
 array([36.36291484, 36.36291484, 36.36291484, ...,  5.57148736,
         5.57148736,  5.57148736]),
 array([0.82930199, 0.8218799 , 0.81445781, ..., 0.63771014, 0.63771014,
        0.63771014]),
 array([