In [29]:
import matplotlib.pyplot as plt
import numpy as np
from sklearn.decomposition import PCA

from ephysvibe.structures.neuron_data import NeuronData
from ephysvibe.structures.population_data import PopulationData
from ephysvibe.task import task_constants
from ephysvibe.trials import select_trials
from ephysvibe.trials.spikes import firing_rate

In [5]:
def check_number_of_trials(xdict, samples, min_ntr):
    """Tell whether every requested entry of ``xdict`` holds enough trials.

    Parameters
    ----------
    xdict : mapping
        Maps a sample key to an array-like exposing ``.shape``; the first
        axis is taken as the trial count.
    samples : iterable
        Keys of ``xdict`` to check. A key missing from ``xdict`` raises
        ``KeyError``.
    min_ntr : int
        Smallest acceptable first-axis length.

    Returns
    -------
    bool
        ``True`` if no checked entry has fewer than ``min_ntr`` rows (also
        when ``samples`` is empty), ``False`` otherwise.
    """
    return all(not (xdict[name].shape[0] < min_ntr) for name in samples)

In [20]:
def preproc_for_pca(
    neu: NeuronData,
    time_before_son: str,
    time_before_t1on: str,
    sp_son: str,
    sp_t1on: str,
    mask_son: str,
    min_ntr: int,
    start_sample: int,
    end_sample: int,
    start_test: int,
    end_test: int,
    avgwin: int = 100,
    step: int = 1,
    zscore=False,
    no_match=False,
):
    """Build the per-sample firing-rate arrays of one neuron, or return None.

    Two smoothed firing-rate windows (one per alignment) are cut out and
    concatenated along the second axis. The trials can optionally be limited
    to those whose first test stimulus differs from the sample, and the rates
    can optionally be z-scored per column. The result is grouped by sample id
    with ``select_trials.get_sp_by_sample``.

    Parameters
    ----------
    neu : NeuronData
        Neuron to process. Attributes are read by name through the ``str``
        arguments below, plus ``neu.sample_id`` and, when ``no_match`` is
        set, ``neu.test_stimuli``.
    time_before_son, time_before_t1on : str
        Names of the attributes of ``neu`` added to the window bounds to
        obtain column indices (presumably the time stored before each
        event -- TODO confirm).
    sp_son, sp_t1on : str
        Names of the attributes of ``neu`` handed to
        ``firing_rate.moving_average`` for the first and second window.
    mask_son : str
        Name of the attribute of ``neu`` used to index ``neu.sample_id`` and
        the rows of ``neu.test_stimuli``.
    min_ntr : int
        Minimum number of trials required for each of the samples
        "0", "11", "15", "51" and "55".
    start_sample, end_sample : int
        Bounds of the first window, relative to ``time_before_son``.
    start_test, end_test : int
        Bounds of the second window, relative to ``time_before_t1on``.
    avgwin : int
        Forwarded to ``firing_rate.moving_average`` as ``win``.
    step : int
        Forwarded to ``firing_rate.moving_average`` as ``step``; the window
        bounds are also divided by it.
    zscore : bool
        If true, subtract the across-trial mean of every column and divide
        by its sample standard deviation (``ddof=1``); zero deviations are
        replaced by 1.
    no_match : bool
        If true, drop the trials where ``neu.test_stimuli[..., 0]`` equals
        the sample id.

    Returns
    -------
    The object returned by ``select_trials.get_sp_by_sample``, or ``None``
    when fewer than two trials remain or when any required sample has fewer
    than ``min_ntr`` trials.
    """
    # Column ranges of the two windows in the smoothed traces.
    offset_son = getattr(neu, time_before_son)
    son_window = slice(
        int((offset_son + start_sample) / step),
        int((offset_son + end_sample) / step),
    )
    offset_t1on = getattr(neu, time_before_t1on)
    test_window = slice(
        int((offset_t1on + start_test) / step),
        int((offset_t1on + end_test) / step),
    )

    spikes_son = getattr(neu, sp_son)
    spikes_t1on = getattr(neu, sp_t1on)

    # Smooth each alignment, keep only its window, then put them side by side.
    rates_son = firing_rate.moving_average(spikes_son, win=avgwin, step=step)[
        :, son_window
    ]
    rates_t1on = firing_rate.moving_average(spikes_t1on, win=avgwin, step=step)[
        :, test_window
    ]
    fr = np.concatenate([rates_son, rates_t1on], axis=1)

    # NOTE(review): the mask is applied to sample_id/test_stimuli but not to
    # fr, which assumes the spike attributes already hold only the masked
    # trials -- confirm against NeuronData.
    trial_mask = getattr(neu, mask_son)
    sample_id = neu.sample_id[trial_mask]

    if no_match:
        # Keep the trials whose first test stimulus is not the sample.
        keep = np.logical_not(neu.test_stimuli[trial_mask, 0] == sample_id)
        fr = fr[keep]
        sample_id = sample_id[keep]

    if len(fr) < 2:
        return None

    if zscore:
        spread = np.std(fr, ddof=1, axis=0)
        # Constant columns would divide by zero; leave them merely centred.
        spread = np.where(spread == 0, 1, spread)
        centred = fr - np.mean(fr, axis=0).reshape(1, -1)
        fr = centred / spread.reshape(1, -1)

    fr = np.array(fr, dtype=np.float32)
    by_sample = select_trials.get_sp_by_sample(fr, sample_id)

    required = ["0", "11", "15", "51", "55"]
    if check_number_of_trials(by_sample, required, min_ntr):
        return by_sample
    return None

In [30]:
def compute_pca(x, n_comp=50):
    """Fit a PCA on the columns of ``x`` and project ``x`` onto its components.

    The model is fitted on ``x.T``: every column of ``x`` is one observation
    and every row one feature.

    Parameters
    ----------
    x : array-like of shape (n_features, n_observations)
        Data matrix.
    n_comp : int, default 50
        Number of components, passed to ``PCA(n_components=...)``.

    Returns
    -------
    model : sklearn.decomposition.PCA
        The fitted estimator.
    pc_s : ndarray of shape (n_comp, n_observations)
        Score of every observation on every component.
    """
    model = PCA(n_components=n_comp).fit(x.T)
    # PCA removes the per-feature mean (``model.mean_``) before extracting the
    # components. The same mean must be removed before projecting, otherwise
    # every row of the scores is shifted by the constant ``components_ @ mean_``.
    x_centered = np.asarray(x) - model.mean_[:, np.newaxis]
    pc_s = model.components_ @ x_centered
    return model, pc_s

In [21]:
# Analysis configuration. The "preprocessing" entries are unpacked as keyword
# arguments in the popu.execute_function(preproc_for_pca, ...) call below, so
# their names match the parameters of preproc_for_pca.
args = dict(
    preprocessing=dict(
        # Minimum number of trials required for every sample.
        min_ntr=15,
        # Bounds of the first window, relative to time_before_son.
        start_sample=-200,
        end_sample=850,
        # Bounds of the second window, relative to time_before_t1on.
        start_test=-400,
        end_test=500,
        # Step of the moving average; the window bounds are divided by it.
        step=1,
        # Names of the neuron attributes that preproc_for_pca reads with getattr.
        time_before_son="time_before_son_in",
        time_before_t1on="time_before_t1on_in",
        sp_son="sp_son_in",
        sp_t1on="sp_t1on_in",
        mask_son="mask_son_in",
        # Optional processing switches.
        zscore=False,
        no_match=False,
    ),
    # workspace
    workspace=dict(output="", path=""),
)

In [22]:
# Number of columns contributed by one condition: the sample window followed
# by the test window, both expressed in units of `step`.
preproc_cfg = args["preprocessing"]
sample_span = preproc_cfg["end_sample"] - preproc_cfg["start_sample"]
test_span = preproc_cfg["end_test"] - preproc_cfg["start_test"]
trial_duration = int((sample_span + test_span) / preproc_cfg["step"])

In [23]:
# Location of the population HDF5 file to analyse.
population_path = "/envau/work/invibe/USERS/IBOS/data/Riesling/TSCM/OpenEphys/population/lip/2024_08_28_12_23_36/population.h5"
popu = PopulationData.from_python_hdf5(population_path)

In [24]:

# Apply preproc_for_pca through the population object and keep only the
# results that are not None (preproc_for_pca returns None when there are too
# few trials).
list_data = [
    idata
    for idata in popu.execute_function(
        preproc_for_pca,
        **args["preprocessing"],
        ret_df=False,
    )
    if idata is not None
]

100%|██████████| 530/530 [00:24<00:00, 21.74it/s]


In [25]:
# Number of preprocessing results kept after dropping the None entries.
len(list_data)

502

In [26]:
# Build the (cells x time) matrix fed to the PCA: for every kept cell, the
# trial-averaged response to each sample is laid out one after another.
# The order of this list fixes the order of the conditions along the columns.
sample_order = ["0", "11", "15", "51", "55"]

# float32 matches the dtype produced by preproc_for_pca. The previous float16
# buffer kept only ~3 significant digits and overflows above 65504.
fr_cells = np.empty(
    (len(list_data), trial_duration * len(sample_order)), dtype=np.float32
)
for i, idata in enumerate(list_data):
    fr_cells[i] = np.concatenate(
        [np.mean(idata[i_sample], axis=0) for i_sample in sample_order]
    )

In [31]:
# Number of principal components requested from the PCA.
n_comp = 200
model, pc_s = compute_pca(fr_cells, n_comp=n_comp)

In [None]:

# Plot the first six principal components over time, one line per sample.

# Split every component into (component, condition, time); the number of
# components is taken from pc_s itself rather than from an outside variable.
reshape_pc_s = pc_s.reshape(pc_s.shape[0], -1, trial_duration)

# Position of each sample along the condition axis, i.e. the order in which
# the conditions were concatenated when fr_cells was built.
# NOTE(review): the labels are also used as keys of task_constants.PALETTE_B1
# -- confirm the palette is keyed by these strings.
samples = {"0": 0, "11": 1, "15": 2, "51": 3, "55": 4}

# Each condition is the sample-aligned window followed by the test-aligned
# window. part1 is the length of the first one, so the two alignments are
# drawn as separate line segments instead of being joined.
preproc_cfg = args["preprocessing"]
part1 = int(
    (preproc_cfg["end_sample"] - preproc_cfg["start_sample"]) / preproc_cfg["step"]
)

time_axis = np.arange(trial_duration) - 200
# NOTE(review): marker positions kept from the original figure; presumably
# sample onset, sample offset and test onset -- confirm.
event_times = [0, 458, 1258]

f, ax = plt.subplots(3, 2, figsize=(15, 10), sharex=True)
for i_pc, i_ax in enumerate(ax.flat):
    for i_sample, i_row in samples.items():
        color = task_constants.PALETTE_B1[i_sample]
        trace = reshape_pc_s[i_pc, i_row]
        i_ax.plot(time_axis[:part1], trace[:part1], color=color, label=i_sample)
        i_ax.plot(time_axis[part1:], trace[part1:], color=color)
    # Draw the markers over the full height reached by the data of this panel.
    y_low, y_high = i_ax.get_ylim()
    i_ax.vlines(
        event_times, y_low, y_high, color='k', linestyle='--', linewidth=0.5
    )
    i_ax.set(ylabel=f"PC{i_pc + 1}")

for i_ax in ax[-1]:
    i_ax.set(xlabel='Time(ms)')
ax[0, 0].legend();
