In [1]:
import glob
import os
import platform
from pathlib import Path
from typing import Dict, List

import h5py
import matplotlib.pyplot as plt
import numpy as np
from dPCA import dPCA
from joblib import Parallel, delayed
from tqdm import tqdm

from ephysvibe.structures.neuron_data import NeuronData
from ephysvibe.task import task_constants
from ephysvibe.trials import align_trials
from ephysvibe.trials import select_trials
from ephysvibe.trials.spikes import firing_rate
#import pca_tools
# Global random seed. Seeding numpy's legacy global RNG makes any code that
# draws from np.random (presumably dPCA's regularization/shuffle steps — confirm)
# reproducible across runs.
seed = 2024
np.random.seed(seed)

### Define functions

In [None]:
def get_neuron_test_fr(
    path,
    time_before_test,
    idx_start_test,
    idx_end_test,
    min_trials,
    min_neu=False,
    avgwin=50,
    n_sp_sec=5,
    norm=False,
    zscore=False,
    error_type=0,
    n_test=1,
    nonmatch=False,
    min_total_trials=20,
):
    """Load one neuron and return its test-onset-aligned firing rate grouped by sample.

    Spikes are aligned on ``"test_on_1"`` (block 1, position code 1) with
    ``align_trials.align_on``. Trials are then selected, and a moving average is
    cropped to ``[idx_start_test, idx_end_test)``. The neuron is rejected
    (``{"fr": None}``) when any of the checks below fails.

    Args:
        path: HDF5 file readable by ``NeuronData.from_python_hdf5``.
        time_before_test: forwarded to ``align_on`` as ``time_before``.
        idx_start_test, idx_end_test: column slice applied to the moving average.
        min_trials: minimum number of selected trials required for every sample
            in ``[0, 11, 15, 55, 51]`` (only sample 0 when ``min_neu`` is True).
        min_neu: if True, only sample 0 is checked against ``min_trials``.
        avgwin: moving-average window, forwarded to ``moving_average`` as ``win``.
        n_sp_sec: threshold compared against ``nanmean(fr) * 1000``
            (assumes 1 ms bins — TODO confirm).
        norm: if True, divide by the global maximum.
        zscore: if True, z-score each time point across trials (a zero std is
            replaced by 1).
        error_type: forwarded to ``align_on``.
        n_test: with ``nonmatch=True``, keep trials with at least ``n_test``
            test presentations.
        nonmatch: if True, keep trials by number of tests (non-match trials
            included). If False, keep only trials whose first test matches the
            sample. Neutral (sample 0) trials are kept in both cases.
        min_total_trials: minimum number of selected trials overall.

    Returns:
        ``{"fr": fr_samples}`` with the output of
        ``select_trials.get_sp_by_sample``, or ``{"fr": None}`` when rejected.
    """
    neu_data = NeuronData.from_python_hdf5(path)
    select_block = 1
    code = 1
    samples = [0, 11, 15, 55, 51]

    # Select trials aligned to test onset
    sp_test_on, mask_t = align_trials.align_on(
        sp_samples=neu_data.sp_samples,
        code_samples=neu_data.code_samples,
        code_numbers=neu_data.code_numbers,
        trial_error=neu_data.trial_error,
        block=neu_data.block,
        pos_code=neu_data.pos_code,
        select_block=select_block,
        select_pos=code,
        event="test_on_1",
        time_before=time_before_test,
        error_type=error_type,
    )
    test_stimuli = neu_data.test_stimuli[mask_t]
    trial_sample_id = neu_data.sample_id[mask_t]

    # Trials whose first test stimulus is the sample (match on test 1)
    mask_match = test_stimuli[:, 0] == trial_sample_id
    # Neutral-sample trials (sample_id == 0) are always kept: the per-sample
    # trial-count check below requires sample 0 to be present.
    mask_neu = trial_sample_id == 0
    # Trials with at least n_test test presentations (NaN marks a missing test)
    n_tests_shown = test_stimuli.shape[1] - np.sum(np.isnan(test_stimuli), axis=1)
    mask_ntest = n_tests_shown > (n_test - 1)

    if nonmatch:  # include nonmatch trials
        mask_match_neu = np.logical_or(mask_ntest, mask_neu)
    else:
        mask_match_neu = np.logical_or(mask_match, mask_neu)
    if np.sum(mask_match_neu) < min_total_trials:
        return {"fr": None}

    # Moving-average firing rate, cropped to the test analysis window
    sp = firing_rate.moving_average(
        sp_test_on[mask_match_neu], win=avgwin, step=1
    )[:, idx_start_test:idx_end_test]

    # Reject weakly firing neurons (an all-NaN mean also fails this test)
    if not np.nanmean(sp) * 1000 > n_sp_sec:
        return {"fr": None}

    # Require enough trials per sample
    sample_id = trial_sample_id[mask_match_neu]
    checked_samples = [0] if min_neu else samples
    for s_id in checked_samples:
        if np.sum(sample_id == s_id) < min_trials:
            return {"fr": None}

    if norm:
        sp = sp / np.max(sp)
    if zscore:
        sp_std = np.std(sp, ddof=1, axis=0)
        sp_std = np.where(sp_std == 0, 1, sp_std)  # avoid division by zero
        sp = (sp - np.mean(sp, axis=0)) / sp_std

    # Get trials grouped by sample
    fr_samples = select_trials.get_sp_by_sample(sp, sample_id, samples=samples)
    if fr_samples is None:
        return {"fr": None}
    return {"fr": fr_samples}

### Read data

In [2]:
# Root of the data tree. Set the EPHYS_BASEPATH environment variable to override
# the per-platform default; any platform other than Linux/Windows (e.g. macOS) needs it.
DEFAULT_BASEPATHS = {
    'Linux': '/envau/work/invibe/USERS/IBOS/data/Riesling/TSCM/OpenEphys/new_structure/',
    'Windows': 'C:/Users/camil/Documents/int/',
}
basepath = os.environ.get('EPHYS_BASEPATH') or DEFAULT_BASEPATHS.get(platform.system())
if basepath is None:
    raise RuntimeError(
        f"No default data path for platform {platform.system()!r}; "
        "set the EPHYS_BASEPATH environment variable."
    )
# Later cells build paths by string concatenation, so keep a trailing separator.
if not basepath.endswith(('/', '\\')):
    basepath += '/'

In [3]:
area = 'lip'
neu_path = basepath + 'session_struct/' + area + '/neurons/*neu.h5'
# glob returns files in arbitrary, filesystem-dependent order; sort so that a
# slice such as path_list[:50] selects the same neurons on every run/machine.
path_list = sorted(glob.glob(neu_path))

In [None]:
# NOTE(review): time_before_test, idx_start_test, idx_end_test, min_trials,
# avgwin and min_sp_sec are not defined in any cell shown here; define them in
# a configuration cell before running this one.
n_neurons = 50  # quick-run subset size; set to None to process every file
data = Parallel(n_jobs=-1)(
    delayed(get_neuron_test_fr)(
        path=path,
        time_before_test=time_before_test,
        idx_start_test=idx_start_test,
        idx_end_test=idx_end_test,
        min_trials=min_trials,
        avgwin=avgwin,
        n_sp_sec=min_sp_sec,
        norm=False,
        zscore=True,
    )
    for path in tqdm(path_list[:n_neurons])
)