In [1]:
import numpy as np
import matplotlib.pyplot as plt
import matplotlib
matplotlib.style.use('presentation')
from fish.image.vol import get_stack_dims, get_stack_freq, unfilter_flat
from fish.util.fileio import read_image
from fish.util.plot import proj_fuse
from fish.ephys.ephys import load, estimate_onset, chop_trials
from skimage.io import imread
from glob import glob
import thunder as td

from os.path import exists
%matplotlib inline

In [340]:
def flip_pcs(weights):
    """Sign correction (+1 / -1) for each principal component.

    A component (column of ``weights``) gets +1 when its largest value
    outweighs its most negative one in magnitude (``|max| > |min|`` along
    axis 0), and -1 otherwise (ties count as -1). Multiplying the loadings
    by the result therefore makes each component's dominant lobe positive.

    Parameters
    ----------
    weights : ndarray, 2D
        Loadings with one column per component.

    Returns
    -------
    ndarray of int
        One +1 or -1 per column.
    """
    peak = np.abs(weights.max(axis=0))
    trough = np.abs(weights.min(axis=0))
    return np.where(peak > trough, 1, -1)

def wheremax(arr):
    """Return the N-d index tuple of the first maximum element of ``arr``."""
    import numpy
    shape = arr.shape
    flat_position = numpy.argmax(arr)
    return numpy.unravel_index(flat_position, shape)

def get_max_rois(vol, sigma=(8,8)):
    """Locate the brightest smoothed pixel in every plane of a volume.

    Each plane ``vol[z]`` is Gaussian-smoothed with ``sigma`` before taking
    its arg-max, so isolated hot pixels are less likely to win.

    Parameters
    ----------
    vol : ndarray
        Plane-first volume; with the default 2-element ``sigma`` each
        ``vol[z]`` must be 2D (i.e. ``vol`` is (Z, Y, X)).
    sigma : sequence of float
        Gaussian standard deviation per in-plane axis, in pixels.

    Returns
    -------
    ndarray of int, shape (Z, vol.ndim)
        One ``(z, y, x)`` row per plane (empty (0, vol.ndim) for Z == 0).
    """
    # `scipy.ndimage.filters` is a deprecated alias namespace (SciPy >= 1.8);
    # the function lives directly in `scipy.ndimage`.
    from scipy.ndimage import gaussian_filter
    n_planes = vol.shape[0]
    z_inds = np.arange(n_planes).reshape(-1, 1)
    # Explicit reshape keeps this 2-D even when there are no planes, so hstack works.
    mx_inds = np.array(
        [wheremax(gaussian_filter(plane, sigma)) for plane in vol], dtype=int
    ).reshape(n_planes, vol.ndim - 1)
    return np.hstack([z_inds, mx_inds])

def get_valid_frames(data, rois, threshold=1):
    """Extract one time series per ROI voxel from a stack of volumes.

    Parameters
    ----------
    data : object with ``.map(func)`` and ``.toarray()``
        Presumably a thunder ``Images`` of reconstructed volumes; each
        mapped record ``v`` must be indexable by one index per axis.
    rois : array-like of int, shape (n_rois, v.ndim)
        Voxel coordinates, one row per ROI (e.g. the output of
        ``get_max_rois``).
    threshold : float
        Unused; kept so existing callers keep working.

    Returns
    -------
    ndarray
        ``toarray()`` of the per-frame ROI values, transposed — for a
        ``toarray()`` of shape (n_frames, n_rois) this is (n_rois, n_frames).
    """
    # Build a *tuple* of per-axis index arrays. The previous list-of-tuples
    # form is treated by NumPy >= 1.23 as a single fancy index on axis 0
    # rather than one index per axis, selecting the wrong data.
    index = tuple(np.asarray(rois).T)
    roi_ts = data.map(lambda v: v[index]).toarray()
    return roi_ts.T
    
def clean_trial_mean(kvp, masks, tr_len):
    """Trial-averaged time series of one voxel, ignoring masked frames.

    Parameters
    ----------
    kvp : tuple (coords, data)
        ``coords[0]`` selects the row of ``masks`` to use (the plane);
        ``data`` is the voxel's 1D time series.
    masks : indexable
        ``masks[coords[0]]`` is a 1D boolean frame mask (True = exclude),
        at least as long as the whole trials kept from ``data``.
    tr_len : int
        Frames per trial.

    Returns
    -------
    tuple (coords, ndarray of shape (tr_len,))
        ``coords`` unchanged, and the mean over trials of the unmasked
        samples at each within-trial frame. Any trailing partial trial is
        dropped. NOTE(review): a frame masked in every trial comes back as
        whatever ``numpy.ma`` leaves in the data buffer, not NaN.
    """
    from numpy.ma import array as marray
    from numpy import array
    coords, data = kvp
    # figure out which plane we are in and choose the correct temporal mask
    mask = masks[coords[0]]
    num_trials = len(data) // tr_len
    # Keep only whole trials so the reshape below cannot fail on a partial trial.
    n_keep = num_trials * tr_len
    by_trial = marray(data[:n_keep], mask=mask[:n_keep]).reshape(num_trials, tr_len)
    return coords, array(by_trial.mean(0))

def kvp_to_array(kvp, dims, baseline=0):
    """Scatter (index, value) pairs into a dense array filled with ``baseline``.

    Parameters
    ----------
    kvp : iterable of pairs
        Element ``[0]`` of each pair is an index into the output (e.g. a
        coordinate tuple), element ``[1]`` the value written there.
    dims : shape-like
        Shape of the output array.
    baseline : scalar
        Value for every position no pair writes to.

    Returns
    -------
    ndarray
        ``zeros(dims) + baseline`` with each pair written in iteration
        order (later pairs overwrite earlier ones at the same index).
    """
    import numpy
    dense = numpy.zeros(dims) + baseline
    for pair in kvp:
        index, value = pair[0], pair[1]
        dense[index] = value
    return dense

In [3]:
from os.path import sep
paths = {}

# for my data
#paths['ephys'] = '/groups/ahrens/ahrenslab/davis/data/ephys/20171108/7dpf_cy171xec43_f1_opto_2_12mw.10chFlt'
#paths['raw'] = '/groups/ahrens/ahrenslab/davis/data/spim/raw/20171108/7dpf_cy171xec43_f1_opto_2_12mw_20171108_215904/'
#paths['proc'] = paths['raw'].replace('raw', 'proc')
#paths['reg'] = paths['proc'] + 'reg/'
#paths['opto'] = paths['proc'] + 'opto_triggering/'

# for yumu's data
paths['raw'] = '/groups/ahrens/ahrenslab/YuMu/SPIM/active_datasets/20171229/fish1/20171229_1_5_hucrgeco_gfapcochr_7dpf_stimwaist_imagedwiithhigherpower_20171229_135523/raw/'
paths['proc'] = '/groups/ahrens/ahrenslab/davis/data/spim/proc/20171229/fish1/20171229_1_5_hucrgeco_gfapcochr_7dpf_stimwaist_imagedwiithhigherpower_20171229_135523/'
paths['ephys'] = '/groups/ahrens/ahrenslab/YuMu/SPIM/active_datasets/20171229/fish1/20171229_1_5_hucrgeco_gfapcochr_7dpf_stimwaist_imagedwiithhigherpower_20171229_135523/ephys/20171229_1_5_gfapcochr_hucrgeco_7dpf_stimwaist_withctrl_staticpulse_noimaging.10chFlt'
paths['opto'] = paths['proc'] + 'opto_triggering/'

# Experiment name = last directory of the processed path. The raw path cannot be
# used: for this dataset it ends in '<experiment>/raw/', so its second-to-last
# component is just 'raw', and every such experiment would get the same name.
exp_name = paths['proc'].rstrip(sep).split(sep)[-1]

# xy downsampling factor of the stored (processed) data
ds_xy = 2
mask = imread(paths['opto'] + 'mask.tif')
mask_ds = mask[:, ::ds_xy, ::ds_xy]
# Boolean mask computed once instead of on every call to `recon`.
mask_ds_bool = mask_ds.astype('bool')
# function to reconstruct linearized data, given ds_xy-fold downsampling in xy
recon = lambda v: unfilter_flat(v, mask_ds_bool)
av_window = np.load(paths['opto'] + 'av_window.npy')

# set correct aspect ratio for plotting
from fish.image.vol import get_metadata
exp_data = get_metadata(paths['raw'] + 'ch0.xml')
bidirectional_stack = exp_data['bidirectional_stack'] == 'T'
z_step = exp_data['z_step']
# .406 is presumably the raw xy pixel size in the same units as z_step — TODO confirm for this rig
aspect = (int(z_step / .406) // ds_xy, 1, 1)
dims = get_stack_dims(paths['raw'])[::-1]
# High-speed single-plane recordings store 50 timepoints per file, reported here as 50 "planes".
single_plane = bool(dims[0] == 50)

fnames = sorted(glob(paths['raw'] + 'TM*'))

num_frames = len(fnames)
# Number of timepoints per file is 50 for high speed single plane
if single_plane:
    num_frames = len(fnames) * 50

from fish.image.vol import get_stack_freq
# NOTE: despite its name, fs_im is the frame *period* in seconds (1 / stack frequency).
fs_im = 1 / get_stack_freq(paths['raw'])[0]
print('Sampling rate:  {0} Hz'.format(1 / fs_im))
print(num_frames)

# Sort condition directories by their numeric suffix, so e.g. condition_10 follows condition_2.
cond_paths = sorted(glob(paths['opto'] + 'condition*'), key=lambda p: int(p.split('_')[-1]))
conds = [int(c.split('_')[-1]) for c in cond_paths]

Sampling rate:  1.7100000000000002 Hz
7250


In [6]:
# One thunder Images object per condition directory, built from its t_*.npy files (sorted by name).
# NOTE(review): `sc` (presumably a SparkContext) is not defined in this notebook; it must already exist in the kernel.
trials = [td.images.fromlist(sorted(glob(c + '/t_*.npy')), accessor=np.load, engine=sc) for c in cond_paths]
# Number of leading trials skipped by the trial averages and PCA in later cells.
first_trial = 5
# Whole trials per condition after skipping `first_trial`; `//` ignores any trailing partial trial.
num_trials = [tr.shape[0] // len(av_window) - first_trial for tr in trials]

In [4]:
# Redundant with the first cell's import, but harmless.
from fish.ephys.ephys import estimate_onset
# Load the ephys recording; presumably indexable by channel (epdat[4] is used below) — confirm against `load`.
epdat = load(paths['ephys'])
# Ephys sampling rate (Hz), hard-coded — TODO confirm against the acquisition settings.
fs_ep = 6000

In [7]:
# Segment ephys channel 4 into trials.
chopped = chop_trials(epdat[4])
# Stimulus duration in seconds: median of np.diff(...) over samples, divided by fs_ep.
# NOTE(review): np.diff uses the last axis. If chopped[1.0] is a list of (start, stop) pairs,
# the array below is (2, n) and this gives inter-trial intervals rather than on->off durations
# (that would need axis=0). If it is a (starts, stops) pair, the array is (n, 2) and this is the
# on->off duration. Confirm against chop_trials.
stim_dur = np.median(np.diff(np.array(list(zip(*chopped[1.0]))))) / fs_ep
# Stimulus length rounded up to whole imaging frames (fs_im is the frame period).
stim_dur_frames = np.ceil(stim_dur / fs_im)
# Frame offsets within one trial window covered by the stimulus, starting where av_window == 0
# (presumably the stimulus onset).
stim_window = np.arange(stim_dur_frames).astype('int') + np.where(av_window==0)[0]
# Absolute stimulus frames for every trial, using the trial count of the first condition only.
stim_frames = np.concatenate([stim_window + t for t in range(0, len(av_window) * num_trials[0], len(av_window))])

In [239]:
# Per-condition artifact removal before trial averaging.
# Max projection over time of each condition, reconstructed to a volume.
mx_projs = [tr.max().map(recon).toarray() for tr in trials]
# Brightest smoothed pixel per plane, used as that plane's artifact ROI.
artifact_rois = [get_max_rois(mx) for mx in mx_projs]
vol_data = [tr.map(recon) for tr in trials]
# Per-plane artifact-ROI time series, one array per condition.
artifact_ts = [get_valid_frames(vol, rois) for vol, rois in zip(vol_data, artifact_rois)]
ser_rdds = [vol.toseries().tordd() for vol in vol_data]
# Frames whose artifact-ROI value exceeds `thr` are masked out of the trial mean.
# NOTE(review): thr is an absolute threshold on reconstructed intensities — confirm its scale.
thr = 1.0
tr_len = len(av_window)


def _collect_clean_means(rdd, frame_masks):
    # The mask is computed once on the driver and closed over here, so each task ships only
    # this condition's mask instead of the whole `artifact_ts` list plus a per-voxel comparison.
    return rdd.map(lambda kv: clean_trial_mean(kv, frame_masks, tr_len)).collect()


# NOTE(review): unlike `mean_responses`, these means include the first `first_trial` trials.
cleaned = [_collect_clean_means(rdd, ts > thr) for rdd, ts in zip(ser_rdds, artifact_ts)]
# Dense (t, z, y, x) trial-mean movie per condition.
cleaned_vols = [kvp_to_array(c, [*mask_ds.shape, tr_len]).transpose([3, 0, 1, 2]) for c in cleaned]

In [343]:
# Everything currently saved in the opto_triggering directory (unsorted, as returned by glob).
opto_contents = glob(paths['opto'] + '*')
opto_contents

['/groups/ahrens/ahrenslab/davis/data/spim/proc/20171229/fish1/20171229_1_5_hucrgeco_gfapcochr_7dpf_stimwaist_imagedwiithhigherpower_20171229_135523/opto_triggering/trial_mean_cleaned_condition_2',
 '/groups/ahrens/ahrenslab/davis/data/spim/proc/20171229/fish1/20171229_1_5_hucrgeco_gfapcochr_7dpf_stimwaist_imagedwiithhigherpower_20171229_135523/opto_triggering/condition_2',
 '/groups/ahrens/ahrenslab/davis/data/spim/proc/20171229/fish1/20171229_1_5_hucrgeco_gfapcochr_7dpf_stimwaist_imagedwiithhigherpower_20171229_135523/opto_triggering/mask.tif',
 '/groups/ahrens/ahrenslab/davis/data/spim/proc/20171229/fish1/20171229_1_5_hucrgeco_gfapcochr_7dpf_stimwaist_imagedwiithhigherpower_20171229_135523/opto_triggering/trial_mean_condition_2',
 '/groups/ahrens/ahrenslab/davis/data/spim/proc/20171229/fish1/20171229_1_5_hucrgeco_gfapcochr_7dpf_stimwaist_imagedwiithhigherpower_20171229_135523/opto_triggering/trial_mean_condition_1',
 '/groups/ahrens/ahrenslab/davis/data/spim/proc/20171229/fish1/2017

In [341]:
# save artifact-cleaned trial-average movies to disk, one TIFF per time point
from skimage.io import imsave
from os.path import exists
from os import makedirs, mkdir
# Label each directory with the condition number parsed from its source directory
# (`conds` is in the same order as `cleaned_vols`); `ind + 1` only matched when the
# conditions were numbered 1..n and sorted in that order.
for cond, val in zip(conds, cleaned_vols):
    tmp_path = paths['opto'] + 'trial_mean_cleaned_condition_{0}/'.format(cond)
    makedirs(tmp_path, exist_ok=True)
    for ind_t, vol in enumerate(val):
        imsave(tmp_path + 'tm_{0:05d}.tif'.format(ind_t), vol.astype('float32'))

In [338]:
# View of the first condition's cleaned trial mean with z-planes stacked along y:
# (t, z, y, x) -> (t, z*y, x). `cleaned_vol` was previously undefined (NameError).
cleaned_vol = cleaned_vols[0]
unrolled = cleaned_vol.reshape(cleaned_vol.shape[0], cleaned_vol.shape[1] * cleaned_vol.shape[2], cleaned_vol.shape[3])
# NOTE(review): `pq` is never imported in this notebook (presumably `import pyqtgraph as pq`).
pq.image(unrolled)

In [7]:
# Trial-averaged response per condition, skipping the first `first_trial` trials.
# Each voxel's series is cut to exactly num_trials[ind] whole trials before reshaping,
# so a trailing partial trial no longer makes the reshape fail.
tr_len = len(av_window)
mean_responses = [
    tr.toseries().map(
        lambda v, n=num_trials[ind]: v[tr_len * first_trial:tr_len * (first_trial + n)].reshape(n, tr_len).mean(0)
    ).toarray()
    for ind, tr in enumerate(trials)
]
# Reconstruct each time point (presumably one row of mr.T) into a volume.
mean_responses = [np.array(list(map(recon, mr.T))) for mr in mean_responses]

In [8]:
# save trial-average movies to disk, one TIFF per time point
from skimage.io import imsave
from os.path import exists
from os import makedirs, mkdir
# Label each directory with the condition number parsed from its source directory
# (`conds` is in the same order as `mean_responses`); `ind + 1` only matched when the
# conditions were numbered 1..n and sorted in that order.
for cond, val in zip(conds, mean_responses):
    tmp_path = paths['opto'] + 'trial_mean_condition_{0}/'.format(cond)
    makedirs(tmp_path, exist_ok=True)
    for ind_t, vol in enumerate(val):
        imsave(tmp_path + 'tm_{0:05d}.tif'.format(ind_t), vol.astype('float32'))

In [None]:
%%time
from factorization import PCA
for ind_c, cond in enumerate(conds):
    ser = trials[ind_c].toseries().map(lambda v: v[first_trial * len(av_window):])
    ser.cache()
    ser.count()
    pca = PCA(k=15, svd_method='em').fit(ser)
    to_flip = flip_pcs(pca[0])
    pca[0] = pca[0] * to_flip
    pca[1] = (pca[1].T * to_flip).T
    np.save(paths['opto'] + 'pca_condition_{0}.npy'.format(cond), np.array(pca))
    ser.tordd().unpersist()

In [None]:
# Reconstruct each component's weight map (one column of pca[0]) into a volume.
pca_vols = np.array([recon(weight_map) for weight_map in pca[0].T])

In [None]:
from skimage.exposure import adjust_gamma, rescale_intensity
# Per-component summary for `pca`, i.e. the last condition fitted in the PCA loop.
to_plot = range(pca[1].shape[0])
tr_len = len(av_window)
# Whole trials in the PCA time courses, derived from the data itself rather than from the
# loop variable `ind_c` left over from the PCA cell.
n_trials_pca = pca[1].shape[1] // tr_len
fig, axs = plt.subplots(nrows=len(to_plot), ncols=2, figsize=(16, len(to_plot) * 4),
                        gridspec_kw={'width_ratios': (1, 2)}, squeeze=False)
for ind, ax in enumerate(axs):
    pc_ = to_plot[ind]
    # Time course split into trials, shape (tr_len, n_trials_pca); partial trailing trial dropped.
    by_trial = pca[1][pc_][:n_trials_pca * tr_len].reshape(n_trials_pca, tr_len).T
    ax[0].plot(by_trial, color='gray', alpha=.5)
    ax[0].plot(by_trial.mean(1), color='m', linewidth=3)
    ax[0].set_ylabel('PC {0}'.format(pc_), fontsize=18)
    # Max projection along axis 0 of the component's map, with a 10-pixel border cropped.
    im = pca_vols[pc_].max(0)[10:-10, 10:-10]
    im = rescale_intensity(im, out_range=(0, 1))
    # Exponent < 1 on [0, 1] data brightens weak structure.
    ax[1].imshow(im ** .8, cmap='magma')
    ax[1].axis('off')

In [None]:
# experiment name -> list of PCA results on that experiment's trial means
pcas_group = dict()

In [None]:
%%time
# Try local factorizaton on trial average
trial_means = sorted(glob(paths['proc'] + 'opto_triggering/*trial_mean*'))
results = []
for ind, c in enumerate(trial_means):
        fnames = sorted(glob(c + '/*.tif'))
        results.append(np.array([filter_flat(imread(fn), mask_ds.astype('bool')) for fn in fnames]))

from factorization import PCA
pcas = [PCA(k=10).fit(r) for r in results]

In [None]:
# file this experiment's trial-mean PCA results under its name
pcas_group.update({exp_name: pcas})

In [None]:
def pca_summary(comps, maps, stim_dur=3):
    """Plot one row per component: time course, max projection, min projection.

    Parameters
    ----------
    comps : ndarray, shape (n_frames, n_comps)
        Component time courses, one column per component. They are plotted
        against ``av_window * fs_im`` (seconds), so ``n_frames`` must equal
        ``len(av_window)``.
    maps : ndarray, shape (n_comps, ...)
        Flattened spatial maps, one row per component, expanded with ``recon``.
    stim_dur : float, optional
        Width in seconds of the shaded stimulus span starting at t = 0
        (default 3, the previously hard-coded value).

    Returns
    -------
    matplotlib.figure.Figure

    Notes
    -----
    Reads notebook globals ``recon``, ``av_window``, ``fs_im`` and ``aspect``,
    and calls ``axs_format``, which is not defined in this notebook.
    """
    from fish.util.plot import proj_fuse
    n_comps = comps.shape[1]
    nr = n_comps
    nc = 3
    # squeeze=False keeps `axs` 2-D even for a single component, so each row is indexable.
    fig, axs = plt.subplots(nrows=nr, ncols=nc, figsize=(18, nr * 6),
                            gridspec_kw={'width_ratios': (1, 2, 2)}, squeeze=False)
    t = av_window * fs_im
    for ind, ax_ in enumerate(axs):
        map_ = recon(maps[ind])
        ax_[0].plot(t, comps[:, ind], color='orange', linewidth=3)
        ax_[0].axvspan(0, stim_dur, alpha=.2)
        axs_format(ax_[0])
        # max / min fused projections of the map, cropped to at most 1024 columns
        ax_[1].imshow(proj_fuse(map_, np.max, aspect=aspect)[:, :1024], cmap='terrain', origin='lower')
        ax_[2].imshow(proj_fuse(map_, np.min, aspect=aspect)[:, :1024], cmap='terrain_r', origin='lower')
        for ax in (ax_[1], ax_[2]):
            ax.axis('off')
    plt.subplots_adjust(wspace=0, hspace=.05)
    return fig

In [None]:
# names of all experiments collected so far, sorted
exps = sorted(pcas_group)

In [None]:
comp_range = slice(0, 4)
# One summary figure per PCA result stored for the first experiment (first 4 components).
for pca_ in pcas_group[exps[0]]:
    pca_summary(pca_[0][:, comp_range], pca_[1][comp_range, :])

In [None]:
comp_range = slice(0, 4)
# A second experiment exists only if this kernel also processed another dataset
# (pcas_group presumably accumulates across runs with different `paths`).
# Skip instead of raising IndexError on a fresh kernel.
if len(exps) > 1:
    for pca_ in pcas_group[exps[1]]:
        pca_summary(pca_[0][:, comp_range], pca_[1][comp_range, :])
else:
    print('Only one experiment in pcas_group; nothing to plot for exps[1].')