In [1]:
import numpy as np
from scipy.interpolate import interp1d
import pandas as pd

In [2]:
%%capture
import datajoint as dj
nda = dj.create_virtual_module('nda', 'microns_phase3_nda')

# get oracle scores

In [3]:
# Oracle scores of every unit, restricted to units that pass nda.ScanInclude.
oracle_rel = (nda.ScanUnit() * nda.Oracle()) & nda.ScanInclude()
# One row per unit: its primary key plus `oracle_score` (renamed from Oracle.pearson),
# with each unit binned into one of ten equally populated oracle-score deciles (0..9).
oracle_df = pd.DataFrame(oracle_rel.proj(oracle_score='pearson').fetch()).assign(
    oracle_decile=lambda units: pd.qcut(units['oracle_score'], q=10, labels=np.arange(10))
)

# fetch oracle traces for 1000 units (100 randomly sampled from each oracle-score decile)

In [4]:
def _oracle_frame_times(unit_key, fps, n_repeats=10, n_clips=6):
    """Build, for every oracle repeat of the unit's scan, one regular time vector per clip."""
    # Conditions presented exactly `n_repeats` times are the oracle clips; `m` (the first
    # trial_idx of each condition) lets us sort them by first presentation.
    oracle_clips_rel = dj.U('condition_hash').aggr(
        nda.Trial & unit_key, n='count(*)', m='min(trial_idx)') & f'n={n_repeats}'
    oracle_hashes = oracle_clips_rel.fetch('KEY', order_by='m ASC')
    if len(oracle_hashes) == 0:
        raise ValueError(f'no condition presented exactly {n_repeats} times for {unit_key}')

    # Each presentation of the earliest oracle clip starts one repeat block; ordering by
    # trial_idx keeps the repeats in presentation order.
    first_trials = (nda.Trial & oracle_hashes[0] & unit_key).fetch('trial_idx', order_by='trial_idx ASC')
    frame_times_set = []
    for first_clip in first_trials:
        # A repeat block is taken to be the `n_clips` consecutive trials starting at `first_clip`.
        trial_block_rel = nda.Trial & unit_key & f'trial_idx >= {first_clip} and trial_idx < {first_clip + n_clips}'
        # Ordering by condition_hash keeps the clips in the same order across repeats and scans.
        start_times, end_times = trial_block_rel.fetch(
            'start_frame_time', 'end_frame_time', order_by='condition_hash DESC')
        # Regular time grid at `fps` between each clip's start and end frame times.
        frame_times_set.append([
            np.linspace(s, e, np.round(fps * (e - s)).astype(int))
            for s, e in zip(start_times, end_times)
        ])
    return frame_times_set


def _minmax_normalize(traces):
    """Rescale `traces` to [0, 1] using the global min and max; a constant array maps to zeros."""
    shifted = traces - np.min(traces)
    value_range = np.max(shifted)
    if value_range > 0:
        return shifted / value_range
    # Zero range: avoid 0/0 = NaN.
    return np.zeros_like(shifted, dtype=float)


def fetch_trial_average_oracle_raster(unit_key, desired_fps=None, average_over_repeats=True, normalize_traces=True,
                                      n_repeats=10, n_clips=6):
    """Fetch a unit's responses to the oracle clips, resampled onto a regular time grid.

    Oracle clips are the conditions presented exactly `n_repeats` times in the unit's scan.
    For every repeat, the `n_clips` consecutive trials starting at that repeat's first oracle
    trial are resampled at `fps`. The activity trace is shifted by the unit's `ms_delay` and
    linearly interpolated onto those times.

    Args:
        unit_key (dict): uniquely identifies one functional unit; must contain the keys
            "session", "scan_idx" and "unit_id".
        desired_fps (float, optional): sampling rate of the returned traces. Defaults to the
            scan's frame rate (nda.Scan.fps).
        average_over_repeats (bool): if True, average the traces over repeats (axis 0).
        normalize_traces (bool): if True, min-max rescale the returned traces to [0, 1] using
            the min and max over the whole returned array (all zeros if it is constant).
        n_repeats (int): number of presentations that marks a condition as an oracle clip.
        n_clips (int): number of consecutive trials that make up one oracle repeat block.

    Returns:
        oracle_traces (np.ndarray): interpolated responses, n_repeats x n_clips x n_frames, or
            n_clips x n_frames when averaged. This assumes every clip resamples to the same
            number of frames; otherwise the per-clip traces cannot be stacked.
        oracle_score (float): the unit's oracle score (nda.Oracle.pearson).
        frame_times_set (list): one list per repeat of the per-clip time vectors used for
            interpolation (before the unit's delay is applied to the activity timestamps).

    Raises:
        ValueError: if the scan has no condition presented exactly `n_repeats` times.
    """
    fps = (nda.Scan & unit_key).fetch1('fps') if desired_fps is None else desired_fps
    frame_times_set = _oracle_frame_times(unit_key, fps, n_repeats=n_repeats, n_clips=n_clips)

    trace, fts, delay = ((nda.Activity & unit_key) * nda.ScanTimes * nda.ScanUnit).fetch1(
        'trace', 'frame_times', 'ms_delay')
    # Shift the frame timestamps by the unit-specific delay (ms -> s) before interpolating.
    f2a = interp1d(fts + delay / 1000, trace)
    oracle_traces = np.array([f2a(ft) for ft in frame_times_set])

    if average_over_repeats:
        oracle_traces = oracle_traces.mean(0)
    if normalize_traces:
        oracle_traces = _minmax_normalize(oracle_traces)

    oracle_score = (nda.Oracle & unit_key).fetch1('pearson')
    return oracle_traces, oracle_score, frame_times_set

In [5]:
# Draw a fixed, seeded sample of units from every oracle-score decile and fetch their
# per-repeat, unnormalized oracle traces (one array per unit in the `oracle_trace` column).
N_PER_DECILE = 100
SAMPLE_SEED = 99


def _fetch_unit_oracle_traces(row):
    """Return the unaveraged, unnormalized oracle traces for the unit described by `row`."""
    unit_key = {'session': row.session, 'scan_idx': row.scan_idx, 'unit_id': row.unit_id}
    traces, _, _ = fetch_trial_average_oracle_raster(
        unit_key,
        average_over_repeats=False,
        normalize_traces=False,
    )
    return traces


sub_dfs = []
# Iterate the decile labels produced by pd.qcut so this loop stays in sync with its `q`.
for decile in oracle_df['oracle_decile'].cat.categories:
    sub_df = (
        oracle_df.query('oracle_decile == @decile')
        .sample(N_PER_DECILE, random_state=SAMPLE_SEED)
    )
    # .assign returns a new frame instead of writing a column onto the sampled subset.
    sub_dfs.append(sub_df.assign(oracle_trace=sub_df.apply(_fetch_unit_oracle_traces, axis=1)))
oracle_traces_df = pd.concat(sub_dfs).sort_values('oracle_score')

# save files

In [6]:
# Save the per-unit oracle scores and their decile assignments.
oracle_df.to_pickle('oracle_scores.pkl')

In [7]:
# Save the sampled units' oracle traces, sorted by oracle score.
pd.to_pickle(oracle_traces_df, 'oracle_traces.pkl')