In [1]:
import numpy as np
import matplotlib.pyplot as plt
from scipy.io import loadmat
from fish.image.vol import get_stack_freq
from glob import glob
%matplotlib inline

In [15]:
def h5_loader(fn, roi=None):
    from h5py import File
    with File(fn) as f:
        if roi is not None:
            return f['default'][roi]
        else:
            return f['default'][:]
        
def load_planewise(fnames, timepoint, timepoints_per_file=50):    
    cur_file = timepoint // timepoints_per_file
    cur_slice = timepoint % timepoints_per_file    
    return h5_loader(fnames[cur_file], roi=(slice(cur_slice, cur_slice + 1), slice(0, None), slice(0, None)))

def load_timerange(fnames, timerange, timepoints_per_file=50):
    from numpy import concatenate
    from collections import OrderedDict
    file_inds = timerange // timepoints_per_file
    file_slices = timerange % timepoints_per_file    
    chunks = OrderedDict()
    for ind in sorted(set(file_inds)):
        slices = file_slices[file_inds == ind]
        chunks[ind] = slice(slices.min(), slices.max() + 1)
    
    return concatenate([h5_loader(fnames[key], roi=val) for key,val in chunks.items()])

def vol_dff(data, q=20, offset=5):
    from numpy import percentile
    baseline = percentile(data, q, axis=0)
    return (data - baseline) / (baseline + offset)

def register_timerange(data, timerange, regparams):    
    from scipy.ndimage.interpolation import shift
    from scipy.signal import medfilt
    import thunder as td
    
    filtered = -medfilt(regparams[timerange,:2,-1], 2001)
    registered = td.images.fromarray(data, engine=sc)
    registered = registered.map(lambda kv: shift(kv[1], filtered[kv[0][0]], order=1), with_keys=True).toarray()
    
    return registered

In [4]:
paths = {}
paths['base'] = '/groups/ahrens/ahrenslab//YuMu/SPIM/active_datasets/20170609/fish1/20170609_1_1_th1gc6s_gfaprgeco_probe_multivel_20170609_102623/'
paths['ephys'] = paths['base'] + 'ephys/'
paths['ims'] = paths['base'] + 'raw/'
regparams = np.load(paths['base'] + 'regparams_affine_yumu.npy')
fnames = sorted(glob(paths['ims'] + 'TM*'))

trigger1 = np.array(loadmat(paths['ephys'] + 's_m_frame.mat')['s_m_frame']);
trigger2 = np.array(loadmat(paths['ephys'] + 's_frame.mat')['s_frame']);
trigger3 = np.array(loadmat(paths['ephys'] + 'm_frame.mat')['m_frame']);

In [62]:
output_dir = '/groups/ahrens/ahrenslab/davis/for_yumu/20170609_1_1_th1gc6s_gfaprgeco_probe_multivel_20170609_102623/'
# 50 planes per stack times the reported number of stacks per second
fs_im = get_stack_freq(paths['ims'])[0] * 50
window = np.arange(-1 * fs_im, 4 * fs_im, dtype='int')
np.save(output_dir + 'average_window.npy', window)
timeranges = [window + t for t in [trigger1, trigger2, trigger3]]

In [55]:
def get_triggered_averages(timeranges, fnames):
    import thunder as td
    from scipy.ndimage.filters import median_filter, gaussian_filter
    from scipy.signal import savgol_filter
    tr_kv = [(ind , val) for ind, val in enumerate(timeranges)]
    trials_raw = sc.parallelize(tr_kv, numSlices=len(tr_kv)).mapValues(lambda v: load_timerange(fnames, v))    
    trials_filt = trials_raw.mapValues(lambda v: median_filter(v, (1,3,3))).mapValues(lambda v: gaussian_filter(v, (0,2,2)))
    trials_smoothed = trials_filt.mapValues(lambda v: savgol_filter(v, 21, polyorder=3,axis=0))
    trials_dff = trials_smoothed.mapValues(lambda v: vol_dff(v))
    trial_ims = td.images.fromrdd(trials_dff)
    tr_mean = trial_ims.mean().toarray().astype('float32')
    
    return tr_mean

In [56]:
triggered_means = [get_triggered_averages(tr, fnames) for tr in timeranges]

In [60]:
from skimage.io import imsave
[imsave('/groups/ahrens/ahrenslab/davis/trigger_{0}_mean.tif'.format(1 + ind), trm) for ind, trm in enumerate(triggered_means)]

  warn('%s is a low contrast image' % fname)
  warn('%s is a low contrast image' % fname)
  warn('%s is a low contrast image' % fname)


[None, None, None]