# VolPy pipeline for processing voltage imaging data 
The processing pipeline includes motion correction, memory mapping, segmentation, denoising and source extraction. The demo shows how to construct the params, MotionCorrect and VOLPY objects and call the relevant functions. 


In [1]:
from base64 import b64encode
import cv2
import glob
import h5py
import imageio
from IPython import get_ipython
from IPython.display import HTML, display, clear_output
import logging
import matplotlib.pyplot as plt
import numpy as np
import os
import tensorflow as tf

# Set OpenCV's thread count to 0 (per the OpenCV docs this makes its functions
# run sequentially). This is best effort: a failure here must not stop the
# notebook, but KeyboardInterrupt/SystemExit should still propagate, so catch
# Exception rather than using a bare except.
try:
    cv2.setNumThreads(0)
except Exception:
    pass

# When running under IPython/Jupyter, enable autoreload so edited modules are
# picked up without restarting the kernel. Outside IPython `__IPYTHON__` is
# undefined and the resulting NameError is deliberately ignored.
try:
    if __IPYTHON__:
        get_ipython().run_line_magic('load_ext', 'autoreload')
        get_ipython().run_line_magic('autoreload', '2')
        #get_ipython().run_line_magic('matplotlib', 'qt')
except NameError:
    pass

import caiman as cm
from caiman.motion_correction import MotionCorrect
from caiman.utils.utils import download_demo, download_model
from caiman.source_extraction.volpy import utils
from caiman.source_extraction.volpy.volparams import volparams
from caiman.source_extraction.volpy.volpy import VOLPY
from caiman.source_extraction.volpy.mrcnn import visualize, neurons
import caiman.source_extraction.volpy.mrcnn.model as modellib
from caiman.summary_images import local_correlations_movie_offline
from caiman.summary_images import mean_image
from caiman.paths import caiman_datadir

# Configure the 'caiman' logger: ERROR level, one handler writing either to
# `logfile` or to the default stream.
logfile = None # Replace with a path if you want to log to a file
logger = logging.getLogger('caiman')
logger.setLevel(logging.ERROR)
logfmt = logging.Formatter('%(relativeCreated)12d [%(filename)s:%(funcName)20s():%(lineno)s] [%(process)d] %(message)s')

# Re-running this cell used to attach one more handler each time, so every
# message was emitted repeatedly. Remove only the handlers this cell installed
# earlier (identified by the marker attribute) and leave any others untouched.
for _stale_handler in [h for h in logger.handlers
                       if getattr(h, '_volpy_notebook_handler', False)]:
    logger.removeHandler(_stale_handler)
    _stale_handler.close()

if logfile is not None:
    handler = logging.FileHandler(logfile)
else:
    handler = logging.StreamHandler()
handler._volpy_notebook_handler = True  # marker used by the cleanup above
handler.setFormatter(logfmt)
logger.addHandler(handler)


# ---- Dataset parameters -----------------------------------------------------
fr = 147                 # sample rate of the movie
fnames = []              # movie file list; filled in once the session paths are known
ROIs = None              # regions of interest
index = None             # index of neurons
weights = None           # spatial weights; to reuse previous ones call
                         # opts.change_params(params_dict={'weights':vpy.estimates['weights']})

# ---- Motion correction parameters -------------------------------------------
pw_rigid = False         # flag for piecewise-rigid motion correction
gSig_filt = (3, 3)       # filter size; change this one if the algorithm does not work
max_shifts = (5, 5)      # maximum allowed rigid shift
strides = (48, 48)       # pw-rigid: start a new patch every x pixels
overlaps = (24, 24)      # overlap between patches (patch size is strides + overlaps)
max_deviation_rigid = 3  # maximum deviation of a patch with respect to the rigid shifts
border_nan = 'copy'

# Collect everything into the dictionary consumed by volparams.
opts_dict = dict(
    fnames=fnames,
    fr=fr,
    index=index,
    ROIs=ROIs,
    weights=weights,
    pw_rigid=pw_rigid,
    max_shifts=max_shifts,
    gSig_filt=gSig_filt,
    strides=strides,
    overlaps=overlaps,
    max_deviation_rigid=max_deviation_rigid,
    border_nan=border_nan,
    use_cuda=False,
)

opts = volparams(params_dict=opts_dict)

## Set-up File Directory

In [2]:
# Locate the Suite2P output of one session, read the movie dimensions from the
# tif header, derive the memmap / ROI file names used by the following cells,
# and push the updated file list into the VolPy options.
from pathlib import Path
import tifffile

Subject = 'H1564'
Date ='2025-03-31'
Sequence = ''
S2PDirectory = f'/HwangLab/Suite2P/{Date}/{Subject}/{Subject}-{Date}-2P{Sequence}/plane1'

tif_file = f'{S2PDirectory}/data.tif'
# Only the series shape is read here; the pixel data is not loaded.
with tifffile.TiffFile(tif_file) as tiff:
    dims = tiff.series[0].shape

print(f"Tiff file dimenions: {dims}")

fnames = [tif_file]
parent, stem = os.path.split(fnames[0])
save_base_name = f'{parent}/data'
# dims[0] is passed as the frame count and dims[1:] as the frame dimensions.
fname_tot = cm.paths.memmap_frames_filename(save_base_name, dims[1:], dims[0], order='F')
fname_new = mmap_file = fname_tot
parent, stem = os.path.split(fname_tot)
# stem[:-5] drops the last five characters of the memmap file name
# (presumably the '.mmap' extension -- verify against memmap_frames_filename).
path_ROIs = f'{parent}/{stem[:-5]}_mrcnn_ROIs.hdf5'
print(f'ROI file: {path_ROIs}')
print(f'mmap file: {fname_tot}')

# Same option set as at construction time, now with the real file list.
opts_dict = dict(
    fnames=fnames,
    fr=fr,
    index=index,
    ROIs=ROIs,
    weights=weights,
    pw_rigid=pw_rigid,
    max_shifts=max_shifts,
    gSig_filt=gSig_filt,
    strides=strides,
    overlaps=overlaps,
    max_deviation_rigid=max_deviation_rigid,
    border_nan=border_nan,
    use_cuda=False,
)

opts.change_params(params_dict=opts_dict);

In [3]:
#### SKIP THIS IF MEMMAP ALREADY EXISTS
# Build the memory-mapped copy of the tif movie. The work is only done when the
# target file does not exist yet, so re-running the cell is cheap.
if not os.path.exists(fname_tot):
#if 1:   (debug toggle: force regeneration)
    # Load the whole movie into memory and cast it to float16.
    # NOTE(review): the memmap below is written as float16 as well; confirm that
    # the CaImAn version in use reads this file with the same dtype.
    m_reg = cm.load(tif_file).astype(np.float16)
    print(f'Done reading, dimension: {m_reg.shape}')
#    parent, stem = os.path.split(fnames[0])
#    save_base_name = parent +'/'+stem[:-4]

    order = 'F'
    dims = m_reg.shape
    print(dims)
    # Recompute the file name from the loaded movie's shape; dims[0] is passed
    # as the frame count and dims[1:] as the frame dimensions.
    fname_tot = cm.paths.memmap_frames_filename(save_base_name, dims[1:], dims[0], order)
    # The on-disk array is 2-D: (pixels per frame, frames), Fortran-ordered.
    big_mov = np.memmap(fname_tot, mode='w+', dtype=np.float16,
                shape=cm.mmapping.prepare_shape((np.prod(dims[1:]), dims[0])), order=order)
    # Move axis 0 of the movie to the end, then flatten the two leading axes in
    # Fortran order so that each column of big_mov holds one frame.
    big_mov[:] = np.reshape(m_reg.transpose(1, 2, 0), (np.prod(dims[1:]), dims[0]), order='F')
    # Make sure the data reaches disk before the mapping is released.
    big_mov.flush()
    del big_mov

## Start Multiprocessing before ROI Detection

In [3]:
# Start a multiprocessing cluster for the ROI-detection stage. If a cluster
# handle named `dview` already exists, stop that one first.
if 'dview' in locals():
    cm.stop_server(dview=dview)
c, dview, n_processes = cm.cluster.setup_cluster(
    backend='multiprocessing',
    n_processes=None,
    single_thread=False,
)

## ROI Detection

In [5]:
# Mean image of the memory-mapped movie, z-scored while ignoring NaNs.
img = mean_image(mmap_file, window = 1000, dview=dview)

img = (img-np.nanmean(img))/np.nanstd(img)

# Correlation image: maximum over the per-window local-correlation images.
gaussian_blur = False        # Use gaussian blur when there is too much noise in the video
Cn = local_correlations_movie_offline(mmap_file, fr=fr, window=fr*4, 
                                      stride=fr*4, winSize_baseline=fr, 
                                      remove_baseline=True, gaussian_blur=gaussian_blur,
                                      dview=dview).max(axis=0)
# NOTE(review): unlike the mean image, Cn is normalized with mean/std rather
# than nanmean/nanstd; confirm Cn cannot contain NaN, otherwise img_corr
# becomes all-NaN.
img_corr = (Cn-np.mean(Cn))/np.std(Cn)
# Three planes: mean, mean, correlation.
summary_images = np.stack([img, img, img_corr], axis=0).astype(np.float16)

np.save(S2PDirectory+'/summary_images.npy', summary_images)
# Save summary images which could be further used in the VolPy GUI.
# Strip the extension with splitext: the previous `tif_file[:-5]` removed five
# characters from a path ending in the four-character '.tif', which dropped the
# last letter of the base name ('dat_summary_images.tif').
cm.movie(summary_images).save(os.path.splitext(tif_file)[0] + '_summary_images.tif')

fig, axs = plt.subplots(1, 2)
axs[0].imshow(summary_images[0]); axs[1].imshow(summary_images[2])
axs[0].set_title('mean image'); axs[1].set_title('corr image')

In [5]:
## Use ROI detected by Suite2P
# Turn the per-ROI pixel lists stored in Suite2P's stat.npy into a stack of
# masks (one plane per ROI, same shape and dtype as a summary image) and save
# the stack to `path_ROIs`.
print(S2PDirectory)
# stat.npy is a pickled object array: only load it from trusted Suite2P output,
# since allow_pickle=True can execute arbitrary code from a crafted file.
stat = np.load(S2PDirectory+'/stat.npy', allow_pickle=True)

summary_images = np.load(S2PDirectory +'/summary_images.npy')
print(summary_images[0].shape)

# Preallocate the whole stack so it keeps its (n_rois, y, x) shape even when
# `stat` is empty, and fill each mask with a single vectorized assignment
# instead of a Python loop over individual pixels.
mask_template = summary_images[0]
ROIs = np.zeros((len(stat),) + mask_template.shape, dtype=mask_template.dtype)
for cell, roi_stat in enumerate(stat):
    ROIs[cell, roi_stat['ypix'], roi_stat['xpix']] = True

fig, axs = plt.subplots(1, 2)
axs[0].imshow(summary_images[0]); 
axs[1].imshow(ROIs.sum(0))
axs[0].set_title('mean image'); axs[1].set_title('masks')
print(f"Number of ROIS:{len(ROIs)}")

cm.movie(ROIs).save(path_ROIs)

## Clean-up Multiprocessing Before Spike Extraction

In [6]:
# Recycle the cluster before spike extraction to release memory held by the
# workers: stop the current one (if any) and start a new one whose workers are
# replaced after every task (maxtasksperchild=1).
if 'dview' in locals():
    cm.stop_server(dview=dview)
c, dview, n_processes = cm.cluster.setup_cluster(
    backend='multiprocessing',
    n_processes=None,
    single_thread=False,
    maxtasksperchild=1,
)

## Trace Denoising and Spike Extraction

In [7]:
import datetime

# Load the ROI masks written by the ROI cell.
with h5py.File(path_ROIs, 'r') as fl:
    ROIs = fl['mov'][()]

# ---- Parameters for trace denoising and spike extraction --------------------
index = list(range(len(ROIs)))   # index of neurons (overridden per ROI in the loop below)
weights = None                   # None: initialise from the ROIs; see the reuse-weights note to reuse previous ones

template_size = 0.02             # half size of the spike-template window, default is 20 ms
context_size = 70                # pixels around the ROI censored from the background PCA;
                                 # 70 for 2X imaging sessions, use 35 for 1X imaging sessions
visualize_ROI = False            # show the ROI inside its context region
flip_signal = True               # Important!! True for Voltron indicator, False for others
hp_freq_pb = 1 / 3               # high-pass filter parameter used to remove photobleaching
clip = 100                       # maximum number of spikes used to form the spike template
threshold_method = 'simple'      # 'adaptive_threshold' or 'simple'
min_spikes= 10                   # minimal number of spikes to be found
pnorm = 0.5                      # controls how many spikes the adaptive threshold method selects
threshold = 2                    # simple method only; increase it to find fewer spikes
do_plot = False                  # plot spike details / template of the last iteration
ridge_bg= 0.01                   # ridge regularisation strength for background removal; larger is stronger
sub_freq = 20                    # frequency for subthreshold extraction
weight_update = 'ridge'          # 'ridge' or 'NMF' for the weight update
n_iter = 2                       # iterations alternating between spike-time and spatial-filter estimation

# Process one ROI at a time and store each result in its own .npy file, so that
# the estimates of finished ROIs are already on disk while later ones run.
for i in range(len(ROIs)):
    index = [i]

    opts_dict = dict(
        fnames=fname_new,
        ROIs=ROIs,
        index=index,
        weights=weights,
        template_size=template_size,
        context_size=context_size,
        visualize_ROI=visualize_ROI,
        flip_signal=flip_signal,
        hp_freq_pb=hp_freq_pb,
        clip=clip,
        threshold_method=threshold_method,
        min_spikes=min_spikes,
        pnorm=pnorm,
        threshold=threshold,
        do_plot=do_plot,
        ridge_bg=ridge_bg,
        sub_freq=sub_freq,
        weight_update=weight_update,
        n_iter=n_iter,
    )
    opts.change_params(params_dict=opts_dict);

    n_processes = 1
    vpy = VOLPY(n_processes=n_processes, dview=dview, params=opts)
    print(f'Setting_ROI{i}_{datetime.datetime.now()}')
    vpy.fit(n_processes=n_processes, dview=dview)

    # Keep the full mask stack next to the estimates, then write them out.
    vpy.estimates['ROIs'] = ROIs
    save_name = f'volpy_{os.path.split(fname_tot)[1][:-5]}_{threshold_method}_{i}'
    np.save(os.path.join(S2PDirectory, save_name), vpy.estimates)
print('Done')

## Visualization

In [7]:
# Reload the saved estimates of one ROI (selected by `i`) and show its spatial
# footprint and traces on top of the correlation image.
i=0
volpy_name = f'volpy_{os.path.split(fname_tot)[1][:-5]}_{threshold_method}_{i}'

# The file holds a pickled dict wrapped in a 0-d array; .item() unwraps it.
vpy = np.load(f'{S2PDirectory}/{volpy_name}.npy', allow_pickle=True).item()
# Only components with a positive 'locality' entry are displayed.
idx = np.nonzero(vpy['locality'] > 0)[0]
print(idx)
utils.view_components(vpy, img_corr, idx)

## Spike Gaussian Convolution Per ROI

In [8]:
import numpy as np
import matplotlib.pyplot as plt
from scipy.signal import fftconvolve

def event_times_to_rate(event_times, time_window, kernel_width, time_step=1):
    """
    Convert event times to a smoothed event rate using a Gaussian kernel.

    The events are binned on a regular grid covering ``time_window`` and the
    resulting histogram is convolved (via FFT) with a normalized Gaussian.

    Parameters:
        event_times (array-like): Sequence of event times.
        time_window (tuple): Start and end time of the observation window.
        kernel_width (float): Standard deviation of the Gaussian kernel,
            in the same units as the event times. Must be positive.
        time_step (float): Time resolution of the rate function. Must be
            positive.

    Returns:
        times (np.ndarray): Array of time points (left edge of each bin).
        rates (np.ndarray): Event rate at each time point.

    Raises:
        ValueError: If ``kernel_width`` or ``time_step`` is not positive.
    """
    if kernel_width <= 0:
        raise ValueError("kernel_width must be positive")
    if time_step <= 0:
        raise ValueError("time_step must be positive")

    # Generate time points
    times = np.arange(time_window[0], time_window[1], time_step)

    # Build a symmetric, odd-length kernel centred exactly on zero and truncated
    # at (at least) 4 standard deviations. The previous
    # np.arange(-hw, hw, time_step) excluded the right end point, so the kernel
    # was off-centre and the smoothed rate was skewed and shifted relative to
    # the events.
    half_taps = int(np.ceil(4 * kernel_width / time_step))
    kernel_times = np.arange(-half_taps, half_taps + 1) * time_step
    kernel = np.exp(-0.5 * (kernel_times / kernel_width) ** 2)
    kernel /= kernel.sum()  # Normalize kernel to preserve total event count

    # Create a histogram of event times (one bin per time point)
    event_histogram, _ = np.histogram(event_times, bins=len(times), range=time_window)

    # Convolve the histogram with the Gaussian kernel; mode='same' keeps the
    # output aligned with `times`.
    rates = fftconvolve(event_histogram, kernel, mode='same') / time_step

    return times, rates

# Collect, for every ROI, the smoothed firing rate, spike times, SNR and
# low-spike flag from the per-ROI VolPy result files, then save the aggregates.
with h5py.File(path_ROIs, 'r') as fl:
    ROIs = fl['mov'][()]  # load ROIs (only the count is needed here)

FiringRate = []
Spikes_time = []
snr = []
low = []
cell = []
# Gaussian smoothing width handed to event_times_to_rate
# (presumably 10 ms expressed in frames, if `fr` is in frames per second).
kernel_width = 0.01*fr
for i in range(len(ROIs)):
    volpy_name = f"volpy_{os.path.split(fname_tot)[1][:-5]}_{threshold_method}_{i}"

    # Each file holds the estimates dict of a single ROI, hence the [0] below.
    vpy = np.load(S2PDirectory + '/' + volpy_name + '.npy', allow_pickle=True).item()
    event_times = vpy['spikes'][0]
    time_window = (0, len(vpy['t'][0]))

    #print(vpy['snr'][0], vpy['low_spikes'][0])

    times, rates = event_times_to_rate(event_times, time_window, kernel_width)
    FiringRate.append(rates)
    Spikes_time.append(event_times)
    snr.append(vpy['snr'][0])
    low.append(vpy['low_spikes'][0])
    # Flag for ROIs whose mean spike rate exceeds 0.05 (currently not saved).
    cell.append(vpy['num_spikes'][0][-1]/len(vpy['t'][0])*fr>0.05)

# Scale by the sample rate (presumably converting events/frame to events/s).
FiringRates = np.array(FiringRate)*fr
# Build the 1-D object array explicitly. np.array(list, dtype=object) silently
# produces a 2-D object array whenever all spike trains happen to have the same
# length, which would change the layout of SpikesTimes.npy between sessions.
SpikesTimes = np.empty(len(Spikes_time), dtype=object)
for roi_idx, spike_train in enumerate(Spikes_time):
    SpikesTimes[roi_idx] = spike_train
SNRs = np.array(snr)
Sparse = np.array(low)
# cells = np.array(cell)
# iscell = np.vstack((cells*1,SNRs)).T

np.save(S2PDirectory + '/rates.npy', FiringRates )
np.save(S2PDirectory + '/SpikesTimes.npy', SpikesTimes)
np.save(S2PDirectory + '/SNR.npy', SNRs)
np.save(S2PDirectory + '/Sparse.npy', Sparse)

In [9]:
# Shut the cluster down, then delete every file in the working directory whose
# name matches '*_LOG_*'.
cm.stop_server(dview=dview)
for stale_log in glob.glob('*_LOG_*'):
    os.remove(stale_log)