In [1]:
import pandas as pd
from ephys_queries import db_setup_core
from exporters import (
    RecordingsExporter, NeuronsExporter, 
    DistanceExporter, EventsExporter, SpikesExporter,
    EEGExporter
)
from drn_interactions.load import get_group_names
from dotenv import load_dotenv
from pathlib import Path
from tqdm.notebook import tqdm
from pathlib import Path

In [2]:
# Load DB connection settings from a local .env file into the process environment,
# then create the database handles that the exporters below are constructed with.
# NOTE(review): presumably `engine`/`metadata` are SQLAlchemy core objects — confirm in ephys_queries.
load_dotenv()
engine, metadata = db_setup_core()

In [3]:
class ExportRunner:
    """Export session- and block-level datasets to gzip-compressed parquet files.

    Session-level tables (``recordings``, ``neurons``, ``distances``) are written to
    ``outdir``; per-block tables (``spiketimes``, ``stft``, ``eeg_band_ts`` and, for
    footshock blocks, ``events``) are written to ``outdir / <block>``. Every file is
    named ``<datatype>.parquet.gzip``.

    Args:
        outdir: Root output directory (``str`` or ``Path``; stored as ``Path``).
        blocks_to_export: Iterable of block names. Repeated names are exported once.
        footshock_blocks: Blocks for which discrete event data are also exported.
        engine: Database engine for the exporters. When ``None``, the notebook-level
            ``engine`` global is used (looked up when ``run`` is called).
        metadata: Database metadata for the exporters. When ``None``, the notebook-level
            ``metadata`` global is used (looked up when ``run`` is called).
    """

    def __init__(
        self,
        outdir: Path,
        blocks_to_export,
        footshock_blocks=("base_shock", "chal_shock"),
        engine=None,
        metadata=None,
    ):
        self.outdir = Path(outdir)
        self.footshock_blocks = footshock_blocks
        self.blocks_to_export = blocks_to_export
        self.engine = engine
        self.metadata = metadata

    def is_footshock(self, block):
        """Return True if ``block`` is one of the configured footshock blocks."""
        return block in self.footshock_blocks

    def _db_handles(self):
        """Return ``(engine, metadata)``, falling back to the notebook globals when unset."""
        db_engine = self.engine if self.engine is not None else engine
        db_metadata = self.metadata if self.metadata is not None else metadata
        return db_engine, db_metadata

    def _unique_blocks(self):
        """Blocks to export in first-seen order; a repeat would only overwrite the same files."""
        return list(dict.fromkeys(self.blocks_to_export))

    def run(self):
        """Export the session-level tables, then every requested block."""
        db_engine, db_metadata = self._db_handles()

        # Session-level tables; each is saved as soon as it has been computed.
        session_exporters = {
            "recordings": RecordingsExporter,
            "neurons": NeuronsExporter,
            "distances": DistanceExporter,
        }
        for datatype, exporter_cls in session_exporters.items():
            self.save_parquet(datatype, exporter_cls(db_engine, db_metadata).processed_data)

        for block in tqdm(self._unique_blocks(), desc="blocks"):
            block_datasets = {
                "spiketimes": SpikesExporter(db_engine, db_metadata, block=block).processed_data
            }
            eeg_exporter = EEGExporter(db_engine, db_metadata, block=block)
            block_datasets["stft"] = eeg_exporter.processed_data
            block_datasets["eeg_band_ts"] = eeg_exporter.get_band_ts()
            if self.is_footshock(block):
                block_datasets["events"] = EventsExporter(
                    db_engine, db_metadata, block=block
                ).processed_data
            for datatype, df in block_datasets.items():
                self.save_parquet(datatype, df, block_name=block)

    def save_parquet(self, datatype, df, block_name=None):
        """Write ``df`` to ``<outdir>[/<block_name>]/<datatype>.parquet.gzip``.

        Creates the target directory (and parents) if needed.

        Returns:
            The ``Path`` of the written file.
        """
        target_dir = Path(self.outdir)
        if block_name is not None:
            target_dir = target_dir / block_name
        target_dir.mkdir(exist_ok=True, parents=True)
        fname = target_dir / f"{datatype}.parquet.gzip"
        df.to_parquet(fname, compression="gzip")
        return fname

In [4]:
# Output root for the exported parquet files.
# NOTE(review): machine-specific absolute path — consider making it configurable
# (e.g. via the .env file) or relative to the repository.
outdir = Path(r"C:\Users\roryl\repos\DRN Interactions\data")

# Each block is listed once: the previous list repeated "pre", "chal" and "way",
# which re-ran those exports and overwrote identical files.
BLOCKS_TO_EXPORT = (
    "pre",
    "base_shock",
    "post_base_shock",
    "chal",
    "chal_shock",
    "post_chal_shock",
    "way",
)

ExportRunner(outdir=outdir, blocks_to_export=BLOCKS_TO_EXPORT).run()

  0%|          | 0/10 [00:00<?, ?it/s]

In [None]:
# RecordingsExporter(engine, metadata).processed_data
# NeuronsExporter(engine, metadata).processed_data
# DistanceExporter(engine, metadata).processed_data
# EventsExporter(engine, metadata, block="base_shock").processed_data
# SpikesExporter(engine, metadata, block="base_shock").processed_data
# EEGExporter(engine, metadata, block="base_shock").processed_data

In [2]:
# NOTE(review): this cell repeats the earlier setup cell; running it rebinds `engine`
# and `metadata` to whatever a second `db_setup_core()` call returns.
load_dotenv()
engine, metadata = db_setup_core()

In [7]:
class NeuronsExporter:
    """Per-neuron waveform and firing-property summary (standalone version).

    NOTE(review): a later cell redefines ``NeuronsExporter`` as an ``Exporter`` subclass,
    which shadows this class once that cell has run. The ``load_*`` and feature helpers
    used below are not imported in the visible cells — confirm where they come from.

    Args:
        engine: Database engine passed to the ``load_*`` helpers.
        metadata: Database metadata passed to the ``load_*`` helpers.
    """

    def __init__(self, engine, metadata):
        self.engine = engine
        self.metadata = metadata
        self._raw_data = None  # cache for ``raw_data``; None = not loaded yet
        self._processed_data = None  # cache for ``processed_data``; None = not computed yet

    @property
    def raw_data(self):
        """Neuron table from ``load_neurons`` with ``id`` renamed to ``neuron_id`` (cached)."""
        if self._raw_data is None:
            # Use the handles given to this instance rather than notebook globals.
            self._raw_data = load_neurons(self.engine, self.metadata).rename(
                columns={"id": "neuron_id"}
            )
        return self._raw_data

    @property
    def processed_data(self):
        """Waveform width, peak asymmetry, mean firing rate and CV2 of ISI per neuron (cached).

        Firing statistics are computed from spikes in the "pre" block; the result is
        inner-merged (on shared columns) with ``raw_data``.
        """
        if self._processed_data is None:
            df_spikes = load_spiketimes(self.engine, self.metadata, block_name="pre")
            waveforms = load_waveforms(self.engine, self.metadata)
            peaks = waveform_peaks_by_neuron(
                waveforms,
                neuron_col="neuron_id",
                index_col="waveform_index",
                value_col="waveform_value",
            ).dropna()
            width = waveform_width_by_neuron(peaks, peak_names=["initiation", "ahp"])
            peak_asym = peak_asymmetry_by_neuron(peaks, peak_names=["initiation", "ahp"])
            mfr = mean_firing_rate_by(
                df_spikes, spiketimes_col="spiketimes", spiketrain_col="neuron_id"
            )
            cv_isi = cv2_isi_by(
                df_spikes, spiketimes_col="spiketimes", spiketrain_col="neuron_id"
            )
            self._processed_data = (
                width.merge(peak_asym).merge(mfr).merge(cv_isi).merge(self.raw_data)
            )
        return self._processed_data


In [12]:
class Exporter:
    """Base class for lazily computed, cached dataset exporters.

    Subclasses implement ``_get_raw_data`` (fetch) and may override ``_process_data``
    (transform; identity by default). Both results are computed on first access of the
    ``raw_data`` / ``processed_data`` properties and cached on the instance.

    Args:
        engine: Database engine made available to subclasses as ``self.engine``.
        metadata: Database metadata made available to subclasses as ``self.metadata``.
    """

    def __init__(self, engine, metadata):
        self.engine = engine
        self.metadata = metadata
        # None means "not computed yet"; a loader that legitimately returns None
        # would therefore be re-run on every access.
        self._raw_data = None
        self._processed_data = None

    def _get_raw_data(self):
        """Fetch the unprocessed dataset. Subclasses must implement this."""
        raise NotImplementedError(
            f"{type(self).__name__} must implement _get_raw_data()"
        )

    def _process_data(self, raw_data):
        """Transform ``raw_data`` into the exported form (identity by default)."""
        return raw_data

    @property
    def raw_data(self):
        """The raw dataset, loaded via ``_get_raw_data`` on first access and cached."""
        if self._raw_data is None:
            self._raw_data = self._get_raw_data()
        return self._raw_data

    @property
    def processed_data(self):
        """``_process_data(raw_data)``, computed on first access and cached."""
        if self._processed_data is None:
            self._processed_data = self._process_data(self.raw_data)
        return self._processed_data


class NeuronsExporter(Exporter):
    """Export per-neuron waveform and firing-property summaries.

    Raw data: the table from ``load_neurons`` with ``id`` renamed to ``neuron_id``.
    Processed data: waveform width, peak asymmetry, mean firing rate and CV2 of the
    ISI (spikes from the "pre" block), inner-merged on shared columns with the raw table.

    NOTE(review): the ``load_*`` and feature helpers are not imported in the visible
    cells — confirm where they come from.
    """

    def _get_raw_data(self):
        # Use this instance's handles rather than the notebook-level globals.
        return load_neurons(self.engine, self.metadata).rename(columns={"id": "neuron_id"})

    def _process_data(self, raw_data):
        df_spikes = load_spiketimes(self.engine, self.metadata, block_name="pre")
        waveforms = load_waveforms(self.engine, self.metadata)
        peaks = waveform_peaks_by_neuron(
            waveforms,
            neuron_col="neuron_id",
            index_col="waveform_index",
            value_col="waveform_value",
        ).dropna()
        width = waveform_width_by_neuron(peaks, peak_names=["initiation", "ahp"])
        peak_asym = peak_asymmetry_by_neuron(peaks, peak_names=["initiation", "ahp"])
        mfr = mean_firing_rate_by(
            df_spikes, spiketimes_col="spiketimes", spiketrain_col="neuron_id"
        )
        cv_isi = cv2_isi_by(
            df_spikes, spiketimes_col="spiketimes", spiketrain_col="neuron_id"
        )
        return width.merge(peak_asym).merge(mfr).merge(cv_isi).merge(raw_data)


class DistanceExporter(Exporter):
    """Export pairwise electrode distances between neurons recorded in the same session.

    Raw data: the table returned by ``load_neurons`` (its ``id``, ``channel`` and
    ``session_name`` columns are used). Processed data: one row per unordered pair of
    neurons within a session, with columns ``neuron1``, ``neuron2``, ``channel_n1``,
    ``channel_n2``, ``distance`` and ``session_name``.
    """

    _OUTPUT_COLUMNS = [
        "neuron1", "neuron2", "channel_n1", "channel_n2", "distance", "session_name",
    ]

    @staticmethod
    def _distance_between_chans(ch1, ch2):
        """
        Calculate the distance (um) between two channels on a Cambridge NeuroTech
        32-channel P-series probe.

        Electrode spec:
            2 shanks 250um apart
            Each shank has 16 channels in two columns of 8, spaced 22.5um apart
            Contacts are placed 25um above each other

        Channels 0-15 are treated as shank 1 and 16-31 as shank 2; even and odd
        channels are treated as different columns.
        """
        import math  # not imported at notebook level

        # Shank: cross-shank pairs are offset horizontally by the shank spacing.
        shank_1 = 1 if ch1 <= 15 else 2
        shank_2 = 1 if ch2 <= 15 else 2
        width = 250 if shank_1 != shank_2 else 0

        # Column offset only applies within a shank (cross-shank keeps 250um).
        col_1 = 1 if ch1 % 2 == 0 else 2
        col_2 = 1 if ch2 % 2 == 0 else 2
        width = 22.5 if (col_1 != col_2) and (width == 0) else width

        # Vertical offset from the within-shank channel index.
        # NOTE(review): with two columns of 8, the row might be (ch % 16) // 2; this
        # multiplies the raw within-shank index by 25um — confirm against the probe map.
        ch1t = ch1 - 16 if ch1 > 15 else ch1
        ch2t = ch2 - 16 if ch2 > 15 else ch2
        height = abs(ch1t - ch2t) * 25

        # Same-column, same-shank pairs return the vertical offset unchanged.
        return math.hypot(height, width) if width else height

    def _get_raw_data(self):
        return load_neurons(self.engine, self.metadata)

    def _process_data(self, raw_data):
        """Build the per-session neuron-pair distance table from ``raw_data``."""
        from itertools import combinations  # not imported at notebook level

        per_session = []
        # sort=False keeps sessions in order of first appearance, like ``.unique()``.
        for session_name, neurons in raw_data.groupby("session_name", sort=False):
            if len(neurons) < 2:
                continue  # a single neuron has no pairs
            channels = neurons[["id", "channel"]]
            pairs = pd.DataFrame(
                list(combinations(neurons["id"], r=2)), columns=["neuron1", "neuron2"]
            )
            by_comb = (
                pairs.merge(channels, left_on="neuron1", right_on="id")
                .drop("id", axis=1)
                .rename(columns={"channel": "channel_n1"})
                .merge(channels, left_on="neuron2", right_on="id")
                .drop("id", axis=1)
                .rename(columns={"channel": "channel_n2"})
            )
            # The helper is a staticmethod, so it must be reached through ``self``;
            # a comprehension also avoids ``apply``'s special-casing of empty frames.
            by_comb["distance"] = [
                self._distance_between_chans(ch1, ch2)
                for ch1, ch2 in zip(by_comb["channel_n1"], by_comb["channel_n2"])
            ]
            by_comb["session_name"] = session_name
            per_session.append(by_comb)

        if not per_session:
            # pd.concat([]) raises; return an empty table with the expected columns.
            return pd.DataFrame(columns=self._OUTPUT_COLUMNS)
        # Each per-session frame is indexed from 0; renumber to avoid duplicate labels.
        return pd.concat(per_session, ignore_index=True)


class RecordingsExporter(Exporter):
    """Export recording-session metadata for the configured groups.

    The processed table is the raw session table with ``id`` renamed to ``session_id``.
    """

    def _get_raw_data(self):
        groups = get_group_names()
        return select_recording_sessions(self.engine, self.metadata, group_names=groups)

    def _process_data(self, raw_data):
        column_map = {"id": "session_id"}
        renamed = raw_data.rename(columns=column_map)
        return renamed

class BlockExporter(Exporter):
    """Exporter whose data are restricted to a single experimental block.

    Args:
        engine: Database engine, forwarded to ``Exporter``.
        metadata: Database metadata, forwarded to ``Exporter``.
        block: Block name, e.g. "pre" or "base_shock".
    """

    def __init__(self, engine, metadata, block):
        # Exporter's parameter is ``metadata``; passing ``meta=`` raised TypeError.
        super().__init__(engine=engine, metadata=metadata)
        self.block = block
        is_pre = block == "pre"
        # "pre" gets no lead-in and is not aligned to the block start; every other block
        # uses a 600 lead-in and is aligned.
        # NOTE(review): units of t_before are not visible here (presumably seconds) — confirm.
        self.t_before = 0 if is_pre else 600
        self.align_to_block = not is_pre


class EventsExporter(BlockExporter):
    """Export discrete (event) data for one block, adding event times in seconds."""

    # Sampling rate used to convert sample indices to seconds (samples per second).
    FS = 30000

    def _get_raw_data(self):
        # Use this instance's handles rather than the notebook-level globals.
        return select_discrete_data(
            self.engine,
            self.metadata,
            group_names=get_group_names(),
            block_name=self.block,
            align_to_block=self.align_to_block,
        )

    def _process_data(self, raw_data):
        """Add ``event_s`` = ``timepoint_sample`` / ``FS``."""
        return raw_data.assign(event_s=lambda x: x["timepoint_sample"].divide(self.FS))

class SpikesExporter(BlockExporter):
    """Per-block spike-time exporter. NOTE(review): not implemented yet.

    Both hooks raise ``NotImplementedError`` so accidental use fails with a clear
    message. Previously, ``_process_data(self)`` raised TypeError because it is called
    with ``raw_data``.
    """

    # Same value as EventsExporter.FS; presumably samples per second. Unused until implemented.
    FS = 30000

    def _get_raw_data(self):
        raise NotImplementedError(f"{type(self).__name__}._get_raw_data is not implemented")

    def _process_data(self, raw_data):
        raise NotImplementedError(f"{type(self).__name__}._process_data is not implemented")

class EEGExporter(BlockExporter):
    """Per-block EEG exporter. NOTE(review): not implemented yet.

    Both hooks raise ``NotImplementedError`` so accidental use fails with a clear
    message. Previously, ``_process_data(self)`` raised TypeError because it is called
    with ``raw_data``.

    NOTE(review): ``ExportRunner.run`` calls ``get_band_ts()`` on ``EEGExporter``
    instances; this definition does not provide that method.
    """

    def _get_raw_data(self):
        raise NotImplementedError(f"{type(self).__name__}._get_raw_data is not implemented")

    def _process_data(self, raw_data):
        raise NotImplementedError(f"{type(self).__name__}._process_data is not implemented")

class EEGTimeSeries(BlockExporter):
    """Per-block EEG time-series exporter. NOTE(review): not implemented yet.

    Both hooks raise ``NotImplementedError`` so accidental use fails with a clear
    message. Previously, ``_process_data(self)`` raised TypeError because it is called
    with ``raw_data``.
    """

    def _get_raw_data(self):
        raise NotImplementedError(f"{type(self).__name__}._get_raw_data is not implemented")

    def _process_data(self, raw_data):
        raise NotImplementedError(f"{type(self).__name__}._process_data is not implemented")

In [18]:
# NOTE(review): `r` is not used in the visible cells. The name RecordingsExporter resolves
# to whichever definition ran last (the notebook cell or the `exporters` import).
r = RecordingsExporter(engine, metadata)


In [None]:
# NOTE(review): if this (never-executed) cell is run, it rebinds SpikesExporter to an
# empty class that takes no constructor arguments. That shadows both the imported
# exporter and the BlockExporter-based definition, so `SpikesExporter(engine, metadata,
# block=...)` calls would then fail.
class SpikesExporter:
    """Empty placeholder class."""
    ...

In [16]:
# Build the neuron exporter. For the notebook-defined versions, nothing is loaded until
# `processed_data` is first accessed.
neuron_exporter = NeuronsExporter(engine, metadata)

In [17]:
# First access runs the loaders and computes the per-neuron features; the result is
# cached on `neuron_exporter`.
neurons = neuron_exporter.processed_data

In [6]:
# Display the per-neuron summary table (pandas truncates the rendered view).
neurons

Unnamed: 0,neuron_id,waveform_width,peak_asymmetry,mean_firing_rate,cv2_isi
0,1,16.0,0.065574,0.601693,0.467315
1,4,9.0,0.037344,2.105512,0.824797
2,5,17.0,0.070539,1.778190,0.394575
3,6,21.0,0.087137,0.832163,0.428821
4,8,70.0,0.343137,1.690391,0.305202
...,...,...,...,...,...
561,2627,8.0,0.033333,0.916395,1.032022
562,2628,19.0,0.078189,0.429663,1.250412
563,2629,17.0,0.068826,2.536273,0.840159
564,2630,67.0,0.233449,10.256485,0.573768
