In [1]:
import numpy as np
import pandas as pd
from glob import glob
import os
import scipy.io as sio
import mat73

In [2]:
def load_waveform(fname, fraction_of_max=2e-2):
    """Read the stimulus waveforms of one recording folder into a DataFrame.

    The ``.mat`` file is read with ``scipy.io.loadmat`` and the variable stored
    under the *last* key of the returned dict is used as the record array.
    The last entry of its ``waveform`` field is treated as the base waveform.
    Every waveform (base included) is

    1. cropped to the window delimited by the first and last sample at which
       ``|base| > fraction_of_max * max(|base|)``,
    2. divided by ``max(|base|)``,
    3. shifted so that its first retained sample is zero.

    Parameters
    ----------
    fname : str
        Path of the waveform ``.mat`` file.
    fraction_of_max : float, optional
        Threshold, as a fraction of the base waveform's peak absolute value,
        used to locate the crop window.

    Returns
    -------
    pandas.DataFrame
        One row per entry of the ``waveform`` field, with the ``stimulus_*``
        metadata columns, the processed ``waveform`` and, repeated on every
        row, the processed ``base_waveform``.

    Raises
    ------
    IndexError
        If no sample of the base waveform exceeds the threshold, or the crop
        window is empty.
    """
    mat_contents = sio.loadmat(fname)
    record = mat_contents[list(mat_contents)[-1]]

    raw_waveforms = [entry.flatten() for entry in record["waveform"][0]]
    raw_base = raw_waveforms[-1]

    peak = np.abs(raw_base).max()
    above_threshold = np.where(np.abs(raw_base) > fraction_of_max * peak)[0]
    # NOTE(review): the slice end is exclusive, so the last above-threshold
    # sample itself is not part of the window -- confirm this is intended.
    start, stop = above_threshold[0], above_threshold[-1]

    def _crop_and_normalise(samples):
        # Crop to the active window, scale by the base peak, zero the first sample.
        scaled = samples[start:stop] / peak
        return scaled - scaled[0]

    processed = [_crop_and_normalise(samples) for samples in raw_waveforms]
    processed_base = processed[-1]

    # Output column -> field of the MATLAB record holding a nested scalar.
    scalar_fields = {
        "stimulus_marker": "marker",
        "stimulus_sampling_rate": "samprate",
        "stimulus_resistance": "RO",
        "stimulus_capacitance": "CO",
        "stimulus_amplitude_modulation": "amp_mod",
        "stimulus_waveform_modulation": "wav_mod",
        "stimulus_value_max": "maxv",
        "stimulus_value_min": "minv",
    }

    columns = {"stimulus_fname": [entry[0] for entry in record["fname"][0]]}
    for column_name, field_name in scalar_fields.items():
        columns[column_name] = [entry[0][0] for entry in record[field_name][0]]
    columns["waveform"] = processed
    columns["base_waveform"] = [processed_base] * len(raw_waveforms)

    return pd.DataFrame(columns)


def load_lfp_data(fname, lfp_id_min=301, lfp_id_max=512):
    """Load one ``*lfp_means.mat`` file into a DataFrame of LFP traces and responses.

    The ``LfpMeans`` entry of the file is read with ``mat73.loadmat``. For the
    per-entry fields (``lfpNorm``, ``lfpMean``, ``b1lfpNorm``, ``b1lfpMean``,
    ``marker``, ``bouts``, ``ampmod``, ``wavmod``) every element except the
    last is kept; for ``lfptime`` and ``vdt`` only the last element is used.
    NOTE(review): the meaning of that last element is not visible here --
    confirm against the MATLAB export.

    A "response" is the minimum of a trace over the sample window
    ``[lfp_id_min, lfp_id_max)``. A "response modulation" is the response
    divided by the corresponding base mean response, minus one.

    Parameters
    ----------
    fname : str
        Path of the file. Its basename must consist of exactly four
        dash-separated parts: experiment date, session id, zone and a trailing
        part that is ignored.
    lfp_id_min, lfp_id_max : int, optional
        First (inclusive) and last (exclusive) sample index of the window over
        which the response minimum is taken.

    Returns
    -------
    pandas.DataFrame
        One row per kept entry of ``marker``, holding the traces, responses,
        modulations, stimulus modulation values and the recording metadata
        parsed from the file name.

    Raises
    ------
    ValueError
        If the basename does not split into exactly four dash-separated parts.
    """
    data_means = mat73.loadmat(fname)["LfpMeans"]
    # os.path.basename honours the platform's path separators, unlike a
    # hard-coded split on "/".
    experiment_date, session_id, zone, _ = os.path.basename(fname).split("-")

    lfp_means_time = data_means["lfptime"][-1]
    lfp_sampling_rate = data_means["vdt"][-1]

    # extract the lfp traces
    lfp_trace = data_means["lfpNorm"][:-1]
    mean_lfp_trace = data_means["lfpMean"][:-1]
    base_lfp_trace = data_means["b1lfpNorm"][:-1]
    base_mean_lfp_trace = data_means["b1lfpMean"][:-1]

    def process_response(list_of_traces):
        # Minimum over the response window, taken along the sample axis of
        # each transposed trace.
        return [x.T[lfp_id_min:lfp_id_max].min(axis=0) for x in list_of_traces]

    def as_rows(mean_trace):
        # Force a 2-D (first_axis, -1) shape and transpose so that iterating
        # yields one 1-D trace at a time.
        return mean_trace.reshape(mean_trace.shape[0], -1).T

    # compute the lfp responses for single trials
    lfp_response = [process_response(y) for y in lfp_trace]
    mean_lfp_response = [process_response(as_rows(y)) for y in mean_lfp_trace]
    base_lfp_response = [process_response(y) for y in base_lfp_trace]
    base_mean_lfp_response = [process_response(as_rows(y)) for y in base_mean_lfp_trace]

    def process_response_modulation(list_of_responses, base):
        # Relative change of each response with respect to its base response.
        return [list_of_responses[i] / base[i] - 1 for i in range(len(list_of_responses))]

    # compute the lfp response modulation for single trials
    lfp_response_modulation = [
        process_response_modulation(y, base) for y, base in zip(lfp_response, base_mean_lfp_response)
    ]

    # compute the lfp response modulation for the whole trial
    mean_lfp_response_modulation = [
        (np.array(mean_lfp_response[i]) / np.array(base_mean_lfp_response[i])).mean() - 1
        for i in range(len(mean_lfp_response))
    ]

    markers = data_means["marker"][:-1]
    n_rows = len(markers)  # the file-level metadata below is repeated on every row

    return pd.DataFrame(
        dict(
            stimulus_marker=[int(x) for x in markers],
            number_bouts=[int(x) for x in data_means["bouts"][:-1]],
            lfp_trace=lfp_trace,
            mean_lfp_trace=mean_lfp_trace,
            base_lfp_trace=base_lfp_trace,
            base_mean_lfp_trace=base_mean_lfp_trace,
            lfp_response=lfp_response,
            mean_lfp_response=mean_lfp_response,
            base_lfp_response=base_lfp_response,
            base_mean_lfp_response=base_mean_lfp_response,
            lfp_response_modulation=lfp_response_modulation,
            mean_lfp_response_modulation=mean_lfp_response_modulation,
            stimulus_amplitude_modulation=data_means["ampmod"][:-1],
            stimulus_waveform_modulation=data_means["wavmod"][:-1],
            lfp_sampling_rate=[lfp_sampling_rate] * n_rows,
            lfp_times=[lfp_means_time] * n_rows,
            experiment_date=[experiment_date] * n_rows,
            session_id=[session_id] * n_rows,
            zone=[zone] * n_rows,
        )
    )


def expand_data_to_single_trials(dfrow):
    """Expand one trial-averaged row into a DataFrame with one row per single trial.

    For each bout ``i`` in ``range(dfrow["number_bouts"])`` the single-trial
    entries of ``lfp_trace``, ``lfp_response`` and ``lfp_response_modulation``
    are unpacked into separate rows, and the bout-level values
    (``mean_lfp_trace``, ``base_mean_lfp_trace``, ``mean_lfp_response``,
    ``base_mean_lfp_response``) are repeated on each of those rows. All
    remaining row-level columns are then copied unchanged onto every row.
    ``base_lfp_trace`` and ``base_lfp_response`` are not carried over.

    Parameters
    ----------
    dfrow : pandas.Series or mapping
        One row of the merged LFP/waveform table; it only needs to support
        ``dfrow[column_name]`` lookups.

    Returns
    -------
    pandas.DataFrame
        One row per single trial, indexed ``0..n-1``. When there are no bouts
        the frame has zero rows but still carries the row-level columns.
    """
    num_bouts = dfrow["number_bouts"]

    def _bout_trace(traces, bout):
        # Force a 2-D (first_axis, -1) shape, then pick the trace of one bout.
        return traces.reshape(traces.shape[0], -1).T[bout]

    # Collect one frame per bout and concatenate once at the end; growing a
    # DataFrame with pd.concat inside the loop copies all previous rows on
    # every iteration and relies on concatenating with an empty frame.
    bout_frames = []
    for i in range(num_bouts):
        lfp_trace = list(dfrow["lfp_trace"][i])
        n_trials = len(lfp_trace)
        bout_frames.append(
            pd.DataFrame(
                dict(
                    lfp_trace=lfp_trace,
                    mean_lfp_trace=[_bout_trace(dfrow["mean_lfp_trace"], i)] * n_trials,
                    base_mean_lfp_trace=[_bout_trace(dfrow["base_mean_lfp_trace"], i)] * n_trials,
                    lfp_response=list(dfrow["lfp_response"][i]),
                    mean_lfp_response=[dfrow["mean_lfp_response"][i]] * n_trials,
                    base_mean_lfp_response=[dfrow["base_mean_lfp_response"][i]] * n_trials,
                    lfp_response_modulation=list(dfrow["lfp_response_modulation"][i]),
                ),
            )
        )

    if bout_frames:
        new_df = pd.concat(bout_frames, axis=0, ignore_index=True)
    else:
        # pd.concat refuses an empty list; keep the original zero-bout result.
        new_df = pd.DataFrame()

    # Row-level columns copied verbatim onto every single-trial row. The
    # ``_x`` / ``_y`` suffixes are the ones pandas adds when merging two
    # frames that share a column name.
    row_level_columns = [
        "stimulus_marker",
        "number_bouts",
        "mean_lfp_response_modulation",
        "stimulus_amplitude_modulation_x",
        "stimulus_waveform_modulation_x",
        "lfp_sampling_rate",
        "lfp_times",
        "fish_id",
        "experiment_date",
        "session_id",
        "zone",
        "paired_experiment",
        "stimulus_fname",
        "stimulus_sampling_rate",
        "stimulus_resistance",
        "stimulus_capacitance",
        "stimulus_amplitude_modulation_y",
        "stimulus_waveform_modulation_y",
        "stimulus_value_max",
        "stimulus_value_min",
        "waveform",
        "base_waveform",
    ]
    for col_name in row_level_columns:
        new_df[col_name] = [dfrow[col_name]] * new_df.shape[0]
    return new_df

In [3]:
# Build the trial-averaged table: one merged LFP/waveform frame per recording
# folder under raw/, concatenated once at the end.
# Both globs are sorted so that the row order of the saved tables does not
# depend on the filesystem's directory listing order.
recording_frames = []
for folder in sorted(glob("raw/*")):
    if not os.path.isdir(folder):
        continue
    print(folder)
    fnames = sorted(glob(f"{folder}/*lfp_means.mat"))
    waveforms_fname = glob(f"{folder}/waveform*.mat")
    if len(waveforms_fname) != 1:
        print(f"Folder {folder} does not contain a waveform file or contains more than one waveform file.")
        continue
    if not fnames:
        # Without LFP files there is no "stimulus_marker" column to merge on.
        print(f"Folder {folder} does not contain any LFP means files; skipping.")
        continue

    # The fish id is the part of the folder name before the first dash.
    fish_id = os.path.basename(folder).split("-")[0]
    waveforms = load_waveform(waveforms_fname[0])
    new_lfp_data = pd.concat([load_lfp_data(fname) for fname in fnames], axis=0, ignore_index=True)
    new_lfp_data["fish_id"] = fish_id
    new_lfp_data["paired_experiment"] = "paired" in folder
    # Attach the stimulus description to every LFP row with the same marker.
    recording_frames.append(pd.merge(new_lfp_data, waveforms, on="stimulus_marker"))

if not recording_frames:
    raise FileNotFoundError("No usable recording folders were found under 'raw/'.")

data = pd.concat(recording_frames, axis=0, ignore_index=True)

# Expand every trial-averaged row into its single trials and stack the results.
data_single_trials = data.apply(expand_data_to_single_trials, axis=1)  # type: ignore
data_single_trials = pd.concat(data_single_trials.tolist(), axis=0, ignore_index=True)

# Make sure the output folder exists before writing the pickles.
os.makedirs("processed", exist_ok=True)
data.to_pickle("processed/trial_averages.pkl")
data_single_trials.to_pickle("processed/single_trials.pkl")

raw/fish_03-20190711-separate
raw/fish_12-20200826-paired
raw/fish_08-20200714-mz
raw/fish_02-20190617-dlz
raw/fish_06-20200623-separate
raw/fish_13-20200902-paired
raw/fish_01-20190605-separate
raw/fish_10-20200722-separate
raw/fish_07-20200626-dlz
raw/fish_04-20190731_dlz
raw/fish_09-20200715-separate
raw/fish_11-20200806-paired
raw/fish_05-20190910-mz
