In [None]:
import os
import glob
import h5py
import pickle
import shutil

import numpy as np
from skimage import io
import matplotlib.pyplot as plt
from scipy import signal, ndimage

from roifile import ImagejRoi
import caiman as cm
import caiman.paths
from caiman.source_extraction.volpy import utils
from caiman.source_extraction.volpy.volparams import volparams
from caiman.source_extraction.volpy.volpy import VOLPY
from caiman.source_extraction.volpy.spikepursuit import signal_filter

In [None]:
#working_dir = '/Volumes/CLab/hour_long_recording/record3/moco_aff_crop/'
working_dir = '/Volumes/CLab/hour_long_recording_jedi/recording1/moco_aff_crop/'
# sorted so files are always processed in the same order
img_list = sorted(glob.glob(os.path.join(working_dir, '*.tif')))

# all VolPy outputs (summary figures, pickled estimates) go here
save_dir = os.path.join(working_dir, 'volpy_results')
os.makedirs(save_dir, exist_ok=True)

In [None]:
# ---- spatial initialization -------------------------------------------------
# None -> initialize from the ROIs; to reuse weights check reuse weights block
weights = None

# ---- acquisition / filtering ------------------------------------------------
fr = 500                     # frame rate handed to VolPy as 'fr' (presumably Hz)
hp_freq_pb = 1 / 3           # high-pass cutoff used to remove photobleaching
hp_freq = 10                 # passed to VolPy as 'hp_freq' (high-pass cutoff)
flip_signal = True           # Important!! True for Voltron indicator, False for others

# ---- ROI / background ---------------------------------------------------------
context_size = 1             # pixels around the ROI censored from the background PCA
visualize_ROI = False        # show the ROI inside its context region
ridge_bg = 0.01              # ridge regularizer for background removal (larger = stronger)
weight_update = 'ridge'      # 'ridge' or 'NMF' weight update

# ---- spike detection ----------------------------------------------------------
template_size = 0.01         # half window length of the spike template (VolPy default is 20 ms)
clip = 100                   # maximum number of spikes used to build the template
threshold_method = 'adaptive_threshold'   # 'adaptive_threshold' or 'simple'
min_spikes = 4               # minimum number of spikes to be found (was 50 for Jiannis' data)
pnorm = 0.5                  # how many spikes the adaptive threshold keeps
threshold = 3                # only used by the 'simple' method; higher -> fewer spikes
# note: was hard-coded for current use case; currently NOT passed to VolPy
# (its entry in the opts dict of the batch loop is commented out)
desired_fp = 10**(-4)

# ---- subthreshold / iteration / plotting ------------------------------------
sub_freq = 20                # frequency for subthreshold extraction
n_iter = 1                   # passed to VolPy as 'n_iter'
do_plot = False              # plot spike/template details of the last iteration

In [None]:
def smooth(x, axis=0, wid=5):
    """Centered moving average of ``x`` along ``axis`` with a ``wid``-sample window.

    Uses a cumulative-sum difference, which is much faster than ``np.convolve``
    for long signals. The first ``(wid - 1) // 2`` and the last
    ``wid - 1 - (wid - 1) // 2`` samples along ``axis`` have no full window and
    are copied from ``x`` unchanged.

    Parameters
    ----------
    x : ndarray
        Input array of any shape.
    axis : int, default 0
        Axis to smooth along (negative values allowed).
    wid : int, default 5
        Window length in samples; ``wid < 2`` returns ``x`` itself.

    Returns
    -------
    ndarray
        Array with the same shape as ``x``. Integer input is promoted to
        float64 so the averages are not truncated; inexact input keeps its
        dtype. If ``x`` is shorter than ``wid`` along ``axis`` the result is
        an unmodified copy.
    """
    x = np.asarray(x)
    if wid < 2:
        return x
    # All slicing below is along axis 0, so move the requested axis there.
    x0 = np.moveaxis(x, axis, 0)
    cumsum_vec = np.cumsum(np.insert(x0, 0, 0, axis=0), axis=0)
    ma_vec = (cumsum_vec[wid:] - cumsum_vec[:-wid]) / wid
    # Promote integers to float so assigning the averages does not truncate them.
    out_dtype = x.dtype if np.issubdtype(x.dtype, np.inexact) else np.float64
    y = x0.astype(out_dtype)  # astype copies: the caller's array is never modified
    start_ind = (wid - 1) // 2
    end_ind = wid - 1 - start_ind  # >= 1 for wid >= 2, so the negative stop is valid
    y[start_ind:-end_ind] = ma_vec
    return np.moveaxis(y, 0, axis)

In [None]:
def remove_moving_frames(img, mean_img, thr=0.7, smooth_wid=10, blur_sigma=200):
    """Drop frames whose temporally smoothed content decorrelates from ``mean_img``.

    Every frame of a moving-average-smoothed copy of ``img`` is correlated
    (Pearson) with ``mean_img``. Frames with correlation below ``thr`` are
    flagged. The flag vector is then Gaussian-blurred in time, and every frame
    whose blurred value is non-zero is removed. With the default
    ``truncate=4.0`` of ``scipy.ndimage.gaussian_filter1d``, that also removes
    roughly ``4 * blur_sigma`` frames on each side of a flagged frame.

    Parameters
    ----------
    img : ndarray
        Movie indexed as ``img[frame, :, :]``.
    mean_img : ndarray
        Reference image with as many pixels as one frame.
    thr : float, default 0.7
        Correlation below which a frame is flagged as moving.
    smooth_wid : int, default 10
        Moving-average window (frames) applied before computing correlations.
    blur_sigma : float, default 200
        Std (frames) of the temporal Gaussian that widens the rejection mask.

    Returns
    -------
    img : ndarray
        Copy of ``img`` without the removed frames.
    current_to_raw_mapping : dict
        Maps each frame index of the returned ``img`` to its index in the input.
    """
    img_smoothed = smooth(img, axis=0, wid=smooth_wid)
    mean_flat = mean_img.ravel()
    # cross-correlation between the mean image and every (smoothed) frame
    corr_series = np.zeros((img.shape[0], ))
    for i, frame in enumerate(img_smoothed):
        corr_series[i] = np.corrcoef(frame.ravel(), mean_flat)[0, 1]

    moving_frame_mask = corr_series < thr
    # Blur the mask to remove the boundary effect. Builtin `float` is used
    # because the `np.float` alias was removed in NumPy 1.24.
    blurred_mask = ndimage.gaussian_filter1d(moving_frame_mask.astype(float), sigma=blur_sigma)
    removed_frames = np.flatnonzero(blurred_mask > 0.0)
    kept_frames = np.delete(np.arange(img.shape[0]), removed_frames)
    img = np.delete(img, removed_frames, axis=0)
    current_to_raw_mapping = dict(enumerate(kept_frames))

    return img, current_to_raw_mapping

In [None]:
img_list.sort()

# One local cluster serves every file in the batch. A cluster left over from an
# earlier run of a cell is stopped first so its workers are not leaked.
if 'dview' in locals():
    cm.stop_server(dview=dview)
c, dview, n_processes = cm.cluster.setup_cluster(
    backend='local', n_processes=None, single_thread=False)

try:
    for img_path in img_list:
        ## read movie, drop moving frames, read ROI
        img = io.imread(img_path)
        mean_img = img.mean(axis=0)
        img, current_to_raw_mapping = remove_moving_frames(img, mean_img)
        # NOTE(review): assumes every movie name ends in an 11-character suffix
        # (including '.tif') that is replaced by '_moco_mask.h5' - TODO confirm.
        ROI_path = img_path[:-11] + '_moco_mask.h5'
        img_id = os.path.basename(img_path).replace('.tif', '')
        print(img_id)
        with h5py.File(ROI_path, 'r') as fl:
            ROI = fl['cell_mask'][()]
        T, d1, d2 = img.shape

        ## write the trimmed movie to a caiman memory-mapped file
        # pixels flattened in Fortran order, stored as (pixels, frames)
        img_reshape = img.reshape(T, d1*d2, order='F')
        mmap_path = os.path.join(
            working_dir,
            caiman.paths.memmap_frames_filename(img_id[:15], [d1, d2], T, 'C'))
        fp = np.memmap(mmap_path, dtype='float32', mode='w+', shape=(d1*d2, T), order='C')
        fp[:] = img_reshape.T
        fp.flush()
        del fp

        ## assemble VolPy parameters (a single ROI, so index == [0])
        ROIs = np.expand_dims(ROI.T, axis=0)
        index = list(range(len(ROIs)))     # index of ROIs to be used for spike extraction
        opts_dict = {'fnames': mmap_path,
                     'ROIs': ROIs,
                     'fr': fr,
                     'index': index,
                     'weights': weights,
                     'template_size': template_size,
                     'context_size': context_size,
                     'visualize_ROI': visualize_ROI,
                     'flip_signal': flip_signal,
                     'hp_freq': hp_freq,
                     'hp_freq_pb': hp_freq_pb,
                     'clip': clip,
                     'threshold_method': threshold_method,
                     'min_spikes': min_spikes,
                     'pnorm': pnorm,
                     #'desired_fp': desired_fp,
                     'threshold': threshold,
                     'do_plot': do_plot,
                     'ridge_bg': ridge_bg,
                     'sub_freq': sub_freq,
                     'weight_update': weight_update,
                     'n_iter': n_iter}
        opts = volparams(params_dict=opts_dict)

        ## run VolPy on the shared cluster
        vpy = VOLPY(n_processes=n_processes, dview=dview, params=opts)
        vpy.fit(n_processes=n_processes, dview=dview)

        ## visualize and save results
        print(np.where(vpy.estimates['locality'])[0])    # neurons that pass locality test
        utils.view_components(vpy.estimates, mean_img, [0], save_path=os.path.join(save_dir, img_id + '_summary.png'))
        # VolPy ran on the trimmed movie: map spike frame indices back to raw frames
        spike_locs = vpy.estimates['spikes'].copy().ravel()
        for i, loc in enumerate(spike_locs):
            spike_locs[i] = current_to_raw_mapping[loc]
        vpy.estimates['spikes'] = spike_locs
        with open(os.path.join(save_dir, img_id + '_volpy.pkl'), 'wb') as f:
            pickle.dump(vpy.estimates, f)
finally:
    # always release the worker processes, even if a file fails
    cm.stop_server(dview=dview)
    del dview

# Analysis of the 4-minute recordings

In [None]:
# Select the dataset by (un)commenting; only the last assignment takes effect.
#working_dir = '/Volumes/CLab/hour_long_recording/ASAP5_4min/moco_aff_crop'
working_dir = '/Volumes/CLab/hour_long_recording_jedi/jedi_4min/moco_aff_crop'

# sorted so files are always processed in the same order
img_list = sorted(glob.glob(os.path.join(working_dir, '*.tif')))

# all VolPy outputs (summary figures, pickled estimates) go here
save_dir = os.path.join(working_dir, 'volpy_results')
os.makedirs(save_dir, exist_ok=True)

In [None]:
# ---- spatial initialization -------------------------------------------------
# None -> initialize from the ROIs; to reuse weights check reuse weights block
weights = None

# ---- acquisition / filtering ------------------------------------------------
fr = 500                     # frame rate handed to VolPy as 'fr' (presumably Hz)
hp_freq_pb = 1 / 3           # high-pass cutoff used to remove photobleaching
hp_freq = 10                 # passed to VolPy as 'hp_freq' (high-pass cutoff)
flip_signal = True           # Important!! True for Voltron indicator, False for others

# ---- ROI / background ---------------------------------------------------------
context_size = 1             # pixels around the ROI censored from the background PCA
visualize_ROI = False        # show the ROI inside its context region
ridge_bg = 0.01              # ridge regularizer for background removal (larger = stronger)
weight_update = 'ridge'      # 'ridge' or 'NMF' weight update

# ---- spike detection ----------------------------------------------------------
template_size = 0.01         # half window length of the spike template (VolPy default is 20 ms)
clip = 100                   # maximum number of spikes used to build the template
threshold_method = 'adaptive_threshold'   # 'adaptive_threshold' or 'simple'
min_spikes = 5               # minimum number of spikes to be found (was 50 for Jiannis' data)
pnorm = 0.5                  # how many spikes the adaptive threshold keeps
threshold = 3                # only used by the 'simple' method; higher -> fewer spikes
# note: was hard-coded for current use case; currently NOT passed to VolPy
# (its entry in the opts dict of the batch loop is commented out)
desired_fp = 10**(-4)

# ---- subthreshold / iteration / plotting ------------------------------------
sub_freq = 20                # frequency for subthreshold extraction
n_iter = 1                   # passed to VolPy as 'n_iter'
do_plot = False              # plot spike/template details of the last iteration

In [None]:
img_list.sort()

# One local cluster serves every file in the batch. A cluster left over from an
# earlier run of a cell is stopped first so its workers are not leaked.
if 'dview' in locals():
    cm.stop_server(dview=dview)
c, dview, n_processes = cm.cluster.setup_cluster(
    backend='local', n_processes=None, single_thread=False)

try:
    for img_path in img_list:
        ## read movie, drop moving frames, read ROI
        img = io.imread(img_path)
        mean_img = img.mean(axis=0)
        img, current_to_raw_mapping = remove_moving_frames(img, mean_img)
        # the mask sits next to the movie: <name>.tif -> <name>_mask.h5
        ROI_path = os.path.splitext(img_path)[0] + '_mask.h5'
        img_id = os.path.basename(img_path).replace('.tif', '')
        print(img_id)
        with h5py.File(ROI_path, 'r') as fl:
            ROI = fl['cell_mask'][()]
        T, d1, d2 = img.shape

        ## write the trimmed movie to a caiman memory-mapped file
        # pixels flattened in Fortran order, stored as (pixels, frames)
        img_reshape = img.reshape(T, d1*d2, order='F')
        mmap_path = os.path.join(
            working_dir,
            caiman.paths.memmap_frames_filename(img_id[:15], [d1, d2], T, 'C'))
        fp = np.memmap(mmap_path, dtype='float32', mode='w+', shape=(d1*d2, T), order='C')
        fp[:] = img_reshape.T
        fp.flush()
        del fp

        ## assemble VolPy parameters (a single ROI, so index == [0])
        ROIs = np.expand_dims(ROI.T, axis=0)
        index = list(range(len(ROIs)))     # index of ROIs to be used for spike extraction
        opts_dict = {'fnames': mmap_path,
                     'ROIs': ROIs,
                     'fr': fr,
                     'index': index,
                     'weights': weights,
                     'template_size': template_size,
                     'context_size': context_size,
                     'visualize_ROI': visualize_ROI,
                     'flip_signal': flip_signal,
                     'hp_freq': hp_freq,
                     'hp_freq_pb': hp_freq_pb,
                     'clip': clip,
                     'threshold_method': threshold_method,
                     'min_spikes': min_spikes,
                     'pnorm': pnorm,
                     #'desired_fp': desired_fp,
                     'threshold': threshold,
                     'do_plot': do_plot,
                     'ridge_bg': ridge_bg,
                     'sub_freq': sub_freq,
                     'weight_update': weight_update,
                     'n_iter': n_iter}
        opts = volparams(params_dict=opts_dict)

        ## run VolPy on the shared cluster
        vpy = VOLPY(n_processes=n_processes, dview=dview, params=opts)
        vpy.fit(n_processes=n_processes, dview=dview)

        ## visualize and save results
        print(np.where(vpy.estimates['locality'])[0])    # neurons that pass locality test
        utils.view_components(vpy.estimates, mean_img, [0], save_path=os.path.join(save_dir, img_id + '_summary.png'))
        # VolPy ran on the trimmed movie: map spike frame indices back to raw frames
        spike_locs = vpy.estimates['spikes'].copy().ravel()
        for i, loc in enumerate(spike_locs):
            spike_locs[i] = current_to_raw_mapping[loc]
        vpy.estimates['spikes'] = spike_locs
        with open(os.path.join(save_dir, img_id + '_volpy.pkl'), 'wb') as f:
            pickle.dump(vpy.estimates, f)
finally:
    # always release the worker processes, even if a file fails
    cm.stop_server(dview=dview)
    del dview

# Dev: one-off re-mapping of spike indices in previously saved results

In [None]:
# One-off migration for result pickles saved BEFORE the batch loops began mapping
# spike indices back to raw frames. The loops now save already-remapped indices,
# so running this on their output would remap a second time. The cell is
# therefore opt-in; otherwise "Run All" would corrupt freshly saved results.
REMAP_LEGACY_PICKLES = False

if REMAP_LEGACY_PICKLES:
    for img_path in img_list:
        img_id = os.path.basename(img_path).replace('.tif', '')
        pkl_path = os.path.join(save_dir, img_id + '_volpy.pkl')
        with open(pkl_path, 'rb') as f:
            estimates = pickle.load(f)
        if estimates.get('spikes_remapped', False):
            continue  # already converted by an earlier run of this cell

        ## recompute the trimmed -> raw frame mapping the same way as the batch loop
        img = io.imread(img_path)
        mean_img = img.mean(axis=0)
        img, current_to_raw_mapping = remove_moving_frames(img, mean_img)

        spike_locs = estimates['spikes'].copy().ravel()
        for i, loc in enumerate(spike_locs):
            spike_locs[i] = current_to_raw_mapping[loc]
        estimates['spikes'] = spike_locs
        estimates['spikes_remapped'] = True  # makes re-running this cell a no-op

        # the read handle is closed before the same file is rewritten
        with open(pkl_path, 'wb') as f:
            pickle.dump(estimates, f)
    

In [None]:
from scipy.io import savemat

# Select the dataset by (un)commenting; only the last assignment takes effect.
#working_dir = '/Users/ykhao/Downloads/hour_long_recording/moco/'
working_dir = '/Volumes/CLab/hour_long_recording/record3/moco_aff_crop/'
save_dir = os.path.join(working_dir, 'volpy_results')

pkl_list = sorted(glob.glob(os.path.join(save_dir, '*.pkl')))
if not pkl_list:
    print(f'No .pkl files found in {save_dir}')

for pkl in pkl_list:
    # pickles are this notebook's own VolPy outputs (pickle.load is unsafe on untrusted files)
    with open(pkl, 'rb') as f:
        estimates = pickle.load(f)
    img_id = os.path.basename(pkl).replace('_volpy.pkl', '')
    # Every key of the estimates dict becomes a MATLAB variable.
    # To export only the traces use: {'t': estimates['t'][0], 't_rec': estimates['t_rec'][0]}
    savemat(os.path.join(save_dir, img_id + '.mat'), estimates)

In [None]:
working_dir = '/Users/ykhao/Downloads/mouse_2p/test_volpy/'
save_dir = os.path.join(working_dir, 'volpy_results')

pkl_list = sorted(glob.glob(os.path.join(save_dir, '*.pkl')))


def spike_snr(t, spikes, fr=500, hp_freq=30, exclusion_wid=20):
    """Signal-to-noise ratio of a trace around its detected spikes.

    Signal is the mean of the median-subtracted trace ``t`` at the ``spikes``
    frames. Noise is the RMS of the negative excursions of a high-passed copy
    (``signal_filter(t, hp_freq, fr, order=5)``), counted only on frames outside
    an ``exclusion_wid``-sample window centred on each spike (``np.convolve``
    with mode 'same').

    Parameters
    ----------
    t : 1-D ndarray
        Trace (e.g. ``estimates['t'][0]``).
    spikes : array of int
        Frame indices of the detected spikes.
    fr, hp_freq : float
        Sampling rate and high-pass cutoff passed to ``signal_filter``.
    exclusion_wid : int
        Width (frames) of the window around each spike excluded from the noise.

    Returns
    -------
    float
        ``sgn / noise``, or nan when there are no spikes or no noise samples.
    """
    t = t - np.median(t)
    t_hp = signal_filter(t, hp_freq, fr, order=5)
    spike_mask = np.zeros(t.shape)
    spike_mask[spikes] = 1
    if not np.any(spike_mask > 0):
        return np.nan
    sgn = np.mean(t[spike_mask > 0])

    near_spike = np.convolve(spike_mask, np.ones(exclusion_wid) / exclusion_wid, 'same')
    far_from_spike = near_spike == 0

    ff1 = -t_hp * (t_hp < 0) * far_from_spike
    n_noise = np.sum(ff1 > 0)
    if n_noise == 0:
        return np.nan
    noise = np.sqrt(np.sum(ff1 ** 2) / n_noise)
    return sgn / noise


snr_by_file = {}
for pkl in pkl_list:
    print(pkl)
    with open(pkl, 'rb') as f:
        estimates = pickle.load(f)
    snr = spike_snr(estimates['t'][0], estimates['spikes'][0])
    snr_by_file[os.path.basename(pkl)] = snr
    print(snr)

In [None]:
from matplotlib.path import Path
import matplotlib.patches as patches

def coords2path(coords):
    """Build a closed matplotlib ``Path`` polygon from an ordered list of vertices.

    Parameters
    ----------
    coords : array-like, shape (n_vertices, 2)
        Polygon vertices in drawing order, e.g. ImageJ ROI coordinates. Plain
        lists/tuples are accepted as well as ndarrays.

    Returns
    -------
    matplotlib.path.Path
        One MOVETO, ``n_vertices - 1`` LINETOs and a CLOSEPOLY. The first
        vertex is repeated at the end as the CLOSEPOLY vertex.

    Raises
    ------
    ValueError
        If ``coords`` contains no vertices.
    """
    coords = np.asarray(coords)
    if len(coords) == 0:
        raise ValueError('coords2path needs at least one vertex')
    codes = [Path.MOVETO] + [Path.LINETO] * (len(coords) - 1) + [Path.CLOSEPOLY]
    closed_coords = np.append(coords, coords[:1], axis=0)
    return Path(closed_coords, codes)

def path2mask(path, brain_shape):
    """Rasterize ``path`` into a boolean mask of shape ``brain_shape[:2]``.

    The pixel in row ``y``, column ``x`` is True when
    ``path.contains_points`` reports the point ``(x, y)`` as inside, using
    ``radius=0.5``.
    """
    n_rows, n_cols = brain_shape[0], brain_shape[1]
    col_grid, row_grid = np.meshgrid(np.arange(n_cols), np.arange(n_rows))
    # one (x, y) = (column, row) pair per pixel, in row-major pixel order
    pixel_xy = np.column_stack((col_grid.ravel(), row_grid.ravel()))

    inside = np.reshape(path.contains_points(pixel_xy, radius=0.5), col_grid.shape)
    mask = np.zeros(col_grid.shape, dtype=bool)
    mask[inside] = True
    return mask

    # ROI_path = img_path.replace('.tif', '_mask.roi')
    # img_id = os.path.basename(img_path).replace('.tif', '')
    # print(img_id)
    # roi = ImagejRoi.fromfile(ROI_path)
    # coords = roi.coordinates()
    # path = coords2path(coords)
    # ROI = path2mask(path, mean_img.shape)
    # ROI = ROI.T

In [None]:
# def remove_moving_frames(img, ROI, threshold=2):
#     '''
#     Remove frames with large std from the image
#     '''
#     ROI = ROI.T > 0
#     assert img.shape[1:] == ROI.shape
#     img_tmp = img.copy()
#     img_tmp = img_tmp.reshape(img_tmp.shape[0], -1)
#     img_tmp = img_tmp[:, ROI.ravel()]
#     t = img_tmp.mean(-1)
#     std_t = np.array([t[i:i+200].std() for i in range(0, t.shape[0], 200)])
#     # repeat std_t for every 200 frames
#     std_t = np.repeat(std_t, 200)
#     # high pass filter of std_t
#     # std_t = signal_filter(std_t, freq=0.1, fr=5) # detrend
#     return img[std_t < std_t.mean() + threshold*std_t.std()]

# def remove_moving_frames(img, ROI, threshold=12):
#     '''
#     Remove frames with large std from the image
#     '''
#     ROI = ROI.T > 0
#     assert img.shape[1:] == ROI.shape
#     img_tmp = img.copy()
#     img_tmp = img_tmp.reshape(img_tmp.shape[0], -1)
#     img_tmp = img_tmp[:, ROI.ravel()]
#     t = -img_tmp.mean(-1)
#     t = t - np.median(t)
#     t = signal_filter(t, freq=1, fr=500) # detrend
#     pks = signal.find_peaks(t, height=None)[0]
#     pks_snr = t[pks] / t.std()
#     clipping_positions = pks[pks_snr > threshold]

#     bw = np.zeros_like(t)
#     bw[clipping_positions] = 1
#     bw = np.convolve(bw, np.ones(100), mode='same')
#     bw = bw == 0.0
#     return img[bw]