In [1]:
import numpy as np
import pandas as pd
from scipy.interpolate import interp1d
from caveclient import CAVEclient
import datajoint as dj

# initialize data sources

In [2]:
# Client for the 'minnie65_phase3_v1' datastack.
client = CAVEclient('minnie65_phase3_v1')
# Pin the materialization version to 1181 so the table queries below are tied to
# one fixed version rather than whatever the client would pick by default.
client.materialize.version = 1181

In [3]:
%%capture
# Virtual DataJoint module 'nda' bound to the 'microns_phase3_nda' schema; its tables
# (nda.Scan, nda.Trial, nda.Activity, ...) are used by the oracle-trace function below.
# The %%capture magic above silences this cell's output.
nda = dj.create_virtual_module('nda', 'microns_phase3_nda')

# get tables

In [4]:
# Nucleus detection table; its `id` column is exposed as `nucleus_id`.
nuc_df = (
    client.materialize
    .query_table('nucleus_detection_v0')
    .rename(columns={'id': 'nucleus_id'})
)

# Manual coregistration table; `target_id` is renamed to `nucleus_id` so that it
# shares a join column with `nuc_df`.
coreg_df = (
    client.materialize
    .query_table('coregistration_manual_v4')
    .rename(columns={'target_id': 'nucleus_id'})
)

In [5]:
# Oracle scores table, read from a local pickle. The file is not created here: the
# code that generates 'oracle_scores.pkl' lives in the Fig 3 panel f notebook.
# The cells below rely on it having an `oracle_score` column.
# NOTE: pickle files can execute code on load; only read a file you generated yourself.
oracle_df = pd.read_pickle('oracle_scores.pkl')

# get presynaptic neuron

In [6]:
# nucleus_id of the presynaptic neuron whose outgoing synapses are analysed below
pre_nucleus_id = 294657

In [7]:
# Row(s) of the nucleus table for the presynaptic neuron, reduced to the two id columns.
pre_df = nuc_df.loc[
    nuc_df['nucleus_id'] == pre_nucleus_id,
    ['nucleus_id', 'pt_root_id'],
]

In [8]:
# Root id of the presynaptic neuron (first match in the nucleus table; raises
# IndexError if the nucleus id is not present).
pre_root_id = nuc_df.loc[nuc_df['nucleus_id'] == pre_nucleus_id, 'pt_root_id'].to_numpy()[0]

# get postsynaptic partners

In [9]:
# Every synapse made by the presynaptic neuron, minus those where the pre- and
# postsynaptic root ids are the same.
post_syn_df = client.materialize.query_table(
    'synapses_pni_2',
    filter_equal_dict={'pre_pt_root_id': pre_root_id},
).loc[lambda syn: syn['pre_pt_root_id'] != syn['post_pt_root_id']]

# Attach a nucleus id to each synapse by matching the postsynaptic root id against
# the nucleus table (inner join, so synapses without a matching nucleus are dropped).
post_syn_nuc_df = post_syn_df.merge(
    nuc_df.rename(columns={'nucleus_id': 'post_nucleus_id'}),
    left_on='post_pt_root_id',
    right_on='pt_root_id',
).loc[:, ['pre_pt_root_id', 'post_pt_root_id', 'size', 'post_nucleus_id']]

# One row per postsynaptic nucleus with the total `size` of all its synapses.
post_df = (
    post_syn_nuc_df
    .groupby('post_nucleus_id', as_index=False)
    .agg(summed_size=('size', 'sum'))
    .rename(columns={'post_nucleus_id': 'nucleus_id'})
)

# match to functional data

In [10]:
# Presynaptic neuron -> functional unit(s) -> oracle score.
# Both merges omit `on`, so pandas joins on every column name the two frames share.
pre_func_df = (
    pre_df
    .merge(coreg_df)
    .loc[:, ['nucleus_id', 'session', 'scan_idx', 'unit_id']]
    .merge(oracle_df)
)

In [11]:
# Postsynaptic neurons -> functional units -> oracle scores, then:
#   1. keep, for each nucleus, only the unit with the highest oracle score
#      (sort descending, then drop_duplicates keeps the first row per nucleus_id);
#   2. order the surviving rows from largest to smallest summed synapse size.
post_func_df = (
    post_df
    .merge(coreg_df)
    .loc[:, ['nucleus_id', 'session', 'scan_idx', 'unit_id', 'summed_size']]
    .merge(oracle_df)
    .sort_values('oracle_score', ascending=False)
    .drop_duplicates('nucleus_id')
    .sort_values('summed_size', ascending=False)
)

# get oracle traces

In [12]:
def fetch_trial_average_oracle_raster(unit_key, desired_fps=None, average_over_repeats=True, normalize_traces=True):
    """Fetch a unit's responses to the repeated ("oracle") clips, resampled onto a regular time grid.

    Args:
        unit_key             (dict):   dictionary to uniquely identify a functional unit (must contain the keys: "session", "scan_idx", "unit_id")
        desired_fps          (float):  sampling rate used to build the time grid of every clip; if None, the `fps` stored in nda.Scan for this scan is used
        average_over_repeats (bool):   if True, average the responses over the repeat axis
        normalize_traces     (bool):   if True, min-max scale the responses to [0, 1] over the (clips, frames) axes

    Returns:
        oracle_traces   (array):  interpolated responses; clips x frames when `average_over_repeats` is True, otherwise repeats x clips x frames
        oracle_score    (float):  the `pearson` value stored in nda.Oracle for this unit
        frame_times_set (list):   one entry per repeat, each a list holding one array of sample times per clip

    Notes:
        Oracle clips are the conditions that occur exactly 10 times in the scan; each repeat is
        assumed to be a block of 6 consecutive trials starting at an occurrence of the first clip.
        Normalisation always runs over the last two axes (clips, frames). With averaging this is a
        single global min/max; without averaging each repeat is scaled independently. A trace with
        zero range is returned as zeros instead of NaN.
    """
    if desired_fps is None:
        fps = (nda.Scan & unit_key).fetch1('fps')  # get frame rate of scan
    else:
        fps = desired_fps

    oracle_rel = (dj.U('condition_hash').aggr(nda.Trial & unit_key, n='count(*)', m='min(trial_idx)') & 'n=10')  # get oracle clips
    oracle_hashes = oracle_rel.fetch('KEY', order_by='m ASC')  # get oracle clip hashes sorted temporally

    frame_times_set = []
    # iterate over oracle repeats (10 repeats); ordering by trial_idx keeps the repeat axis chronological
    first_clip_idxs = (nda.Trial & oracle_hashes[0] & unit_key).fetch('trial_idx', order_by='trial_idx ASC')
    for first_clip in first_clip_idxs:
        # uses the trial_idx of the first clip to grab the subsequent 5 clips (trial_block)
        trial_block_rel = (nda.Trial & unit_key & f'trial_idx >= {first_clip} and trial_idx < {first_clip+6}')
        # grabs start time and end time of each clip in trial_block; ordering by condition_hash keeps clip order the same across scans
        start_times, end_times = trial_block_rel.fetch('start_frame_time', 'end_frame_time', order_by='condition_hash DESC')
        # generate time vector between start and end times according to the chosen frame rate
        frame_times = [np.linspace(s, e, np.round(fps * (e - s)).astype(int)) for s, e in zip(start_times, end_times)]
        frame_times_set.append(frame_times)

    # fetch trace, delay and frame times for interpolation
    trace, fts, delay = ((nda.Activity & unit_key) * nda.ScanTimes * nda.ScanUnit).fetch1('trace', 'frame_times', 'ms_delay')
    f2a = interp1d(fts + delay/1000, trace)  # create trace interpolator with unit specific time delay (ms -> s)
    oracle_traces = np.array([f2a(ft) for ft in frame_times_set])  # interpolate oracle times to match the activity trace

    if average_over_repeats:
        oracle_traces = oracle_traces.mean(0)

    if normalize_traces:
        # Address the (clips, frames) axes from the end so the same statistic is used
        # whether or not the leading repeat axis has been averaged away.
        clip_frame_axes = (-2, -1)
        oracle_traces = oracle_traces - np.min(oracle_traces, axis=clip_frame_axes, keepdims=True)
        peak = np.max(oracle_traces, axis=clip_frame_axes, keepdims=True)
        # where=peak > 0 leaves flat traces at 0 rather than dividing by zero
        oracle_traces = np.divide(oracle_traces, peak, out=np.zeros_like(oracle_traces), where=peak > 0)

    oracle_score = (nda.Oracle & unit_key).fetch1('pearson')  # fetch oracle score
    return oracle_traces, oracle_score, frame_times_set

In [13]:
# Lookup key (session, scan_idx, unit_id) of the presynaptic unit: the first row of pre_func_df.
# NOTE(review): unlike post_func_df, pre_func_df is not ranked by oracle_score, so if the
# nucleus is matched in several scans the row used here is simply the first the merge
# produced - confirm this is intended.
pre_unit_key = (
    pre_func_df[['session', 'scan_idx', 'unit_id']]
    .head(1)
    .to_dict(orient='records')[0]
)

In [14]:
# One lookup key (session, scan_idx, unit_id) per postsynaptic unit, in post_func_df row order.
post_unit_keys = post_func_df.loc[:, ['session', 'scan_idx', 'unit_id']].to_dict('records')

In [15]:
# Trial-averaged, normalised oracle response of the presynaptic unit, sampled at 6.3 fps.
# Only the first returned element (the traces) is kept.
pre_trace = fetch_trial_average_oracle_raster(pre_unit_key, desired_fps=6.3)[0]

In [16]:
# Trial-averaged, normalised oracle responses of every postsynaptic unit, stacked along a
# new leading axis in post_unit_keys order. All units use the same desired_fps as the
# presynaptic trace; np.stack requires the per-unit arrays to share one shape.
post_traces = np.stack([
    fetch_trial_average_oracle_raster(unit_key, desired_fps=6.3)[0]
    for unit_key in post_unit_keys
])

# save files

In [17]:
# Write the presynaptic unit table (pre_func_df) to 'pre_df.pkl' in the working directory.
pre_func_df.to_pickle('pre_df.pkl')

In [18]:
# Write the postsynaptic unit table (post_func_df, sorted by summed synapse size) to 'post_df.pkl'.
post_func_df.to_pickle('post_df.pkl')

In [19]:
# Write the presynaptic oracle trace array to 'pre_trace.npy'.
np.save('pre_trace.npy', pre_trace)

In [20]:
# Write the stacked postsynaptic oracle traces to 'post_traces.npy'; the first axis follows
# the row order of post_func_df (the order in which post_unit_keys was built).
np.save('post_traces.npy', post_traces)