In [1]:
from utils import db_setup_core
from sqlalchemy import select, and_, or_, Integer, cast, Numeric, func
from queries import _result_proxy_to_df
import numpy as np
from pprint import pprint

In [3]:
# Open the database: obtain the SQLAlchemy engine plus the MetaData object
# whose ``tables`` mapping the query helpers below read from.
(engine, metadata) = db_setup_core()

In [4]:
from queries import select_ifr, select_spike_times, select_waveforms, select_analog_signal_data

In [5]:
# df = select_ifr(engine, metadata, block_name="base_shock", session_names=["hamilton_07"])
# df = select_waveforms(engine, metadata, session_names=["hamilton_07"])

In [89]:
def select_discrete_data(
    engine,
    metadata,
    signal_names=None,
    block_name="pre",
    t_before=0,
    t_after=0,
    group_names=None,
    exp_names=None,
    as_df=True,
    align_to_block=False,
    exclude_excluded_recordings=True,
    fs=30000,
    limit=None,
):
    """Select discrete-signal events (one row per event timestamp).

    Parameters
    ----------
    engine, metadata
        SQLAlchemy engine and MetaData; ``metadata.tables`` must contain the
        discrete-signal, recording-session, block-time, group and experiment
        tables referenced below.
    signal_names : list of str, optional
        Keep only these ``discrete_signals.signal_name`` values.
    block_name : str
        Keep only events inside this recording-session block. ``"all"``
        disables the block restriction.
    t_before, t_after : float
        Widen the block window by this many seconds before its start and
        after its end (converted to samples with ``fs``).
    group_names, exp_names : list of str, optional
        Keep only these experimental groups / experiments. When given, the
        corresponding name column(s) are added to the output.
    as_df : bool
        Return a DataFrame (via ``_result_proxy_to_df``) instead of the raw
        SQLAlchemy result.
    align_to_block : bool
        Report ``timepoint_sample`` relative to the start sample of the block
        the event falls in.
    exclude_excluded_recordings : bool
        Drop sessions whose ``excluded`` flag is set (keeps NULL or 0).
    fs : int
        Samples per second, used to convert ``t_before`` / ``t_after``.
    limit : int, optional
        Maximum number of rows to return. ``None`` (default) returns all rows.
    """
    d_data, sesh_d_sig, d_sigs, r_sesh, rs_blocks, groups, experiments = (
        metadata.tables["discrete_signal_data"],
        metadata.tables["session_discrete_signals"],
        metadata.tables["discrete_signals"],
        metadata.tables["recording_sessions"],
        metadata.tables["recording_session_block_times"],
        metadata.tables["experimental_groups"],
        metadata.tables["experiments"],
    )

    # Subquery: one row per (session, block) with the block's sample bounds.
    stmt_block = select(
        [
            rs_blocks.c.recording_session_id,
            rs_blocks.c.block_start_samples,
            rs_blocks.c.block_end_samples,
        ]
    )
    stmt_block = stmt_block.select_from(
        rs_blocks.join(r_sesh).join(groups).join(experiments)
    )
    if block_name != "all":
        stmt_block = stmt_block.where(rs_blocks.c.block_name == block_name)
    stmt_block = stmt_block.alias("block")

    # Blocks are only needed to restrict events to a block window or to align
    # timestamps. Joining them otherwise would repeat each event once per block
    # of its session.
    join_blocks = block_name != "all" or align_to_block

    columns = [d_sigs.c.signal_name, r_sesh.c.session_name]
    if align_to_block:
        columns.append(
            (d_data.c.timepoint_sample - stmt_block.c.block_start_samples).label(
                "timepoint_sample"
            )
        )
    else:
        columns.append(d_data.c.timepoint_sample)
    if exp_names:
        columns.append(experiments.c.experiment_name)
    if group_names or exp_names:
        # Added once, even when both filters are given (no duplicate column).
        columns.append(groups.c.group_name)

    # recording_session_block_times is deliberately NOT joined directly: an
    # unconstrained join to it duplicates every event once per block.
    from_clause = (
        d_data.join(sesh_d_sig)
        .join(d_sigs)
        .join(r_sesh, r_sesh.c.id == sesh_d_sig.c.recording_session_id)
        .join(groups)
        .join(experiments)
    )
    if join_blocks:
        from_clause = from_clause.join(
            stmt_block, stmt_block.c.recording_session_id == r_sesh.c.id
        )
    stmt = select(columns).select_from(from_clause)

    if join_blocks:
        # Padding is given in seconds; timepoints are in samples.
        pad_before = t_before * fs
        pad_after = t_after * fs
        stmt = stmt.where(
            and_(
                d_data.c.timepoint_sample
                > stmt_block.c.block_start_samples - pad_before,
                d_data.c.timepoint_sample
                < stmt_block.c.block_end_samples + pad_after,
            )
        )

    if exclude_excluded_recordings:
        # Keep sessions that are not flagged as excluded.
        # NOTE(review): assumes excluded == 1 marks an excluded session.
        stmt = stmt.where(or_(r_sesh.c.excluded.is_(None), r_sesh.c.excluded == 0))
    if signal_names:
        stmt = stmt.where(d_sigs.c.signal_name.in_(signal_names))
    if group_names:
        stmt = stmt.where(groups.c.group_name.in_(group_names))
    if exp_names:
        stmt = stmt.where(experiments.c.experiment_name.in_(exp_names))

    if limit is not None:
        stmt = stmt.limit(limit)

    with engine.connect() as conn:
        res = conn.execute(stmt)
        if as_df:
            # Consume the result while the connection is still open.
            res = _result_proxy_to_df(res)
    # NOTE(review): with as_df=False the returned result outlives its
    # connection; consider returning res.fetchall() inside the block instead.
    return res

In [95]:
# Shock events for the acute citalopram group during the "base_shock" block,
# aligned to the block start and padded before the block by t_before.
df = select_discrete_data(
    engine,
    metadata,
    block_name="base_shock",
    group_names=["acute_cit"],
    t_before=10,
    align_to_block=True,
)

In [97]:
# Preview the first rows of the query result rather than dumping the whole frame.
df.head(10)

Unnamed: 0,signal_name,session_name,timepoint_sample,group_name
0,eshock,hamilton_04,349195,acute_cit
1,eshock,hamilton_04,409184,acute_cit
2,eshock,hamilton_04,529163,acute_cit
3,eshock,hamilton_04,589152,acute_cit
4,eshock,hamilton_04,709131,acute_cit
5,eshock,hamilton_04,829110,acute_cit
6,eshock,hamilton_04,889099,acute_cit
7,eshock,hamilton_04,949089,acute_cit
8,eshock,hamilton_04,1069068,acute_cit
9,eshock,hamilton_04,1129058,acute_cit
