In [None]:
# Enable module auto-reloading and the interactive Qt matplotlib backend.
# `get_ipython` only exists inside an IPython/Jupyter kernel, so this setup is
# best-effort: outside a kernel (or if a magic fails) it is skipped.
try:
    get_ipython().run_line_magic('load_ext', 'autoreload')
    get_ipython().run_line_magic('autoreload', '2')
    get_ipython().run_line_magic('matplotlib', 'qt')
except Exception:
    # `except Exception` rather than a bare `except:` so KeyboardInterrupt and
    # SystemExit are not silently swallowed.
    pass

import logging
import os

import matplotlib.pyplot as plt
import numpy as np

logging.basicConfig(format=
                          "%(relativeCreated)12d [%(filename)s:%(funcName)20s():%(lineno)s] [%(process)d] %(message)s",
                    # filename="caiman.log",
                    level=logging.DEBUG)

import caiman as cm
from caiman.source_extraction import cnmf
from caiman.utils.utils import download_demo
from caiman.utils.visualization import inspect_correlation_pnr, nb_inspect_correlation_pnr
from caiman.motion_correction import MotionCorrect
from caiman.source_extraction.cnmf import params as params
from caiman.utils.visualization import plot_contours, nb_view_patches, nb_plot_contour
from caiman.source_extraction.cnmf.initialization import init_neurons_corr_pnr
from caiman.utils.visualization import view_quilt
import cv2
import scipy.io as sio
import pickle

# Disable OpenCV's internal threading (0 => run its functions sequentially).
# The CaImAn demos do this, presumably so OpenCV threads do not compete with
# CaImAn's own worker processes. It is best-effort in case the installed
# cv2 build rejects the call.
try:
    cv2.setNumThreads(0)
except Exception:
    # `except Exception` rather than a bare `except:` so KeyboardInterrupt and
    # SystemExit still propagate.
    pass
import bokeh.plotting as bpl
import holoviews as hv
# Render bokeh / holoviews plots inline in the notebook.
bpl.output_notebook()
hv.notebook_extension('bokeh')

In [None]:
# Movies to process, one entry per file. Raw strings keep the Windows
# backslashes literal, so no escaping or post-processing is needed.
# NOTE(review): absolute local path; adjust for the machine running this.
fnames_file = [
    r'H:\CM2scope_experimental_data\trace_fear_conditioning\train\gzc_rasgrf-ai148d-371\My_V4_Miniscope\AMF_despeckle_MC_denoised_8bit.tif',
]
print(fnames_file)

In [None]:
# Process each movie independently:
#   1. memory-map it;
#   2. compute and save summary images;
#   3. run CNMF-E (1-photon) source extraction;
#   4. evaluate the components;
#   5. export the results next to the input movie.
for fname in fnames_file:
    fnames = [fname]  # the CaImAn calls below take a list of file names
    # Stem shared by every output file. os.path.splitext strips only the
    # extension. `split(".")[0]` would cut the path at the first dot anywhere
    # in it, e.g. inside a directory name.
    out_base = os.path.splitext(fname)[0]

    #%% start a cluster for parallel processing. A pool left over from the
    # previous iteration (or a previous run of this cell) is stopped first.
    if 'dview' in locals():
        cm.stop_server(dview=dview)
    c, dview, n_processes = cm.cluster.setup_cluster(
        backend='local', n_processes=None, single_thread=False)

    # dataset dependent parameters
    frate = 9.3                      # movie frame rate
    decay_time = 0.4                 # length of a typical transient in seconds

    # motion correction parameters. These are only stored in `opts`: this loop
    # never runs motion correction itself, so `motion_correct` is informational.
    motion_correct = False    # flag for performing motion correction
    pw_rigid = True           # flag for performing piecewise-rigid motion correction (otherwise just rigid)
    gSig_filt = (5, 5)        # size of high pass spatial filtering, used in 1p data
    max_shifts = (20, 20)     # maximum allowed rigid shift
    strides = (128, 128)      # start a new patch for pw-rigid motion correction every x pixels
    overlaps = (32, 32)       # overlap between patches (size of patch strides+overlaps)
    max_deviation_rigid = 10  # maximum deviation allowed for patch with respect to rigid shifts
    border_nan = 'copy'       # replicate values along the boundaries

    num_frames_split = 100    # adjust according to the size of the data

    mc_dict = {
        'fnames': fnames,
        'fr': frate,
        'decay_time': decay_time,
        'pw_rigid': pw_rigid,
        'max_shifts': max_shifts,
        'gSig_filt': gSig_filt,
        'strides': strides,
        'overlaps': overlaps,
        'max_deviation_rigid': max_deviation_rigid,
        'border_nan': border_nan,
        'num_frames_split': num_frames_split
    }

    opts = params.CNMFParams(params_dict=mc_dict)

    # Memory-map the movie (C order) and view it as a (T, *dims) frame stack.
    fname_new = cm.save_memmap(fnames, base_name='memmap_', order='C', border_to_0=0, dview=dview)
    Yr, dims, T = cm.load_memmap(fname_new)
    images = Yr.T.reshape((T,) + dims, order='F')
    print(dims)
    print(T)
    print(Yr.shape)
    print(images.shape)

    # parameters for source extraction and deconvolution
    p = 1               # order of the autoregressive system
    K = None            # upper bound on number of components per patch, in general None
    gSig = (2, 2)       # gaussian width of a 2D gaussian kernel, which approximates a neuron
    gSiz = (7, 7)       # average diameter of a neuron, in general 4*gSig+1
    Ain = None          # possibility to seed with predetermined binary masks
    merge_thr = .7      # merging threshold, max correlation allowed
    rf = 80             # half-size of the patches in pixels. e.g., if rf=40, patches are 80x80
    stride_cnmf = 24    # amount of overlap between the patches in pixels
    #                     (keep it at least as large as gSiz, i.e. 4 times the neuron size gSig)
    tsub = 1            # downsampling factor in time for initialization,
    #                     increase if you have memory problems
    ssub = 1            # downsampling factor in space for initialization,
    #                     increase if you have memory problems
    p_ssub = 1
    p_tsub = 1
    low_rank_background = None  # None leaves background of each patch intact,
    #                     True performs global low-rank approximation if gnb>0
    gnb = -2            # number of background components (rank) if positive,
    #                     else exact ring model with following settings
    #                         gnb= 0: Return background as b and W
    #                         gnb=-1: Return full rank background B
    #                         gnb<-1: Don't return background
    nb_patch = 0        # number of background components (rank) per patch if gnb>0,
    #                     else it is set automatically
    min_corr = 0.6      # min peak value from correlation image
    min_pnr = 15        # min peak to noise ratio from PNR image
    ssub_B = 1          # additional downsampling factor in space for background
    ring_size_factor = 1.2  # radius of ring is gSiz*ring_size_factor

    bord_px = 0

    merge_parallel = True

    opts.change_params(params_dict={'method_init': 'corr_pnr',  # use this for 1 photon
                                    'K': K,
                                    'gSig': gSig,
                                    'gSiz': gSiz,
                                    'merge_thr': merge_thr,
                                    'p': p,
                                    'tsub': tsub,
                                    'ssub': ssub,
                                    'p_ssub': p_ssub,
                                    'p_tsub': p_tsub,
                                    'rf': rf,
                                    'stride': stride_cnmf,
                                    'only_init': True,    # set it to True to run CNMF-E
                                    'nb': gnb,
                                    'nb_patch': nb_patch,
                                    'method_deconvolution': 'oasis',       # could use 'cvxpy' alternatively
                                    'low_rank_background': low_rank_background,
                                    'update_background_components': False,  # sometimes setting to False improve the results
                                    'min_corr': min_corr,
                                    'min_pnr': min_pnr,
                                    'normalize_init': False,               # just leave as is
                                    'center_psf': True,                    # leave as is for 1 photon
                                    'ssub_B': ssub_B,
                                    'ring_size_factor': ring_size_factor,
                                    'del_duplicates': True,  # whether to remove duplicates from initialization
                                    'merge_parallel': merge_parallel,
                                    'method_exp': 'dilate',
                                    'border_pix': bord_px})                # number of pixels to not consider in the borders

    # Summary images use the same neuron width as CNMF-E (gSig[0]); the
    # `sigma<N>` tag in the output file names is derived from it.
    # NOTE(review): correlation_pnr_cuda / save_summary_images are not in
    # upstream CaImAn; presumably provided by the installed fork.
    cn_filter, pnr, data_max, data_noise, std = cm.summary_images.correlation_pnr_cuda(
        images, gSig=gSig[0], swap_dim=False, center_psf=True)  # change swap dim if output looks weird, it is a problem with tiffile
    sigma_tag = f'sigma{gSig[0]}'

    # Cache the summary images and parameters so they can be reloaded later.
    name = out_base + '_variables.pkl'
    with open(name, 'wb') as file:
        pickle.dump({'cn_filter': cn_filter, 'pnr': pnr, 'data_max': data_max,
                     'data_noise': data_noise, 'std': std, 'opts': opts}, file)

    # Negative correlations are clipped to 0 in the saved correlation image.
    cn_threshold = cn_filter.copy()
    cn_threshold[cn_threshold < 0] = 0
    # Earlier versions wrote the PNR image as '<stem>pnr_images_sigma2.png'
    # (missing the '_' separator that every other output uses).
    summary_images = [
        ('_correlation_images_', cn_threshold),
        ('_pnr_images_', pnr),
        ('_data_max_', data_max),
        ('_data_std_', std),
        ('_data_noise_', data_noise),
    ]
    for suffix, summary_img in summary_images:
        name = out_base + suffix + sigma_tag + '.png'
        cm.summary_images.save_summary_images(summary_img, name)

    cnm = cnmf.CNMF(n_processes=n_processes, dview=dview, Ain=Ain, params=opts)
    cnm.fit(images)

    # component quality thresholds (CaImAn `quality` params)
    # low thresholds
    rval_lowest = -0.5
    SNR_lowest = 3
    cnn_lowest = 0.8
    # high thresholds
    min_SNR = 3.5
    r_values_min = 0.8
    min_cnn_thr = 1
    cnm.params.set('quality', {'rval_lowest': rval_lowest, 'SNR_lowest': SNR_lowest, 'min_SNR': min_SNR, 'cnn_lowest': cnn_lowest,
                               'rval_thr': r_values_min, 'min_cnn_thr': min_cnn_thr,
                               'use_cnn': False, 'use_ecc': True, 'max_ecc': 1.8, 'gSig_range': (2.5, 3)})
    cnm.estimates.evaluate_components(images, cnm.params, dview=dview)

    print(' ***** ')
    print('Number of total components: ', len(cnm.estimates.C))
    print('Number of accepted components: ', len(cnm.estimates.idx_components))

    # Keep the thresholds used together with the pickled estimates.
    cnm.estimates.rval_lowest = rval_lowest
    cnm.estimates.SNR_lowest = SNR_lowest
    cnm.estimates.min_SNR = min_SNR
    cnm.estimates.r_values_min = r_values_min

    name = out_base + '_cnm_estimates.pkl'
    with open(name, 'wb') as file:
        pickle.dump(cnm.estimates, file)

    #%% plot contour plots of accepted and rejected components.
    # NOTE(review): this call presumably also fills cnm.estimates.coordinates,
    # which is exported below; keep it before the savemat call.
    cnm.estimates.plot_contours_nb(img=pnr*cn_filter, idx=cnm.estimates.idx_components, thr=0.8, cmap='hot', thr_method='nrg')

    # Raw fluorescence = denoised trace C + residual YrA. Despite the `good_`
    # prefix these cover ALL components; use idx_components to select accepted ones.
    good_caAct = cnm.estimates.C
    good_resid = cnm.estimates.YrA
    good_rawFl = (good_caAct + good_resid).astype(np.float32)

    cnm.estimates.detrend_df_f()

    sio.savemat(out_base + '_caiman_result.mat', {
        "A_neuron_good_idx": cnm.estimates.idx_components.astype(np.int32),
        "A_neuron_bad_idx": cnm.estimates.idx_components_bad.astype(np.int32),
        "C_trace": cnm.estimates.C.astype(np.float32),
        "A_neuron_sparse": cnm.estimates.A.astype(np.float32),
        "C_raw": good_rawFl,
        "detrended_trace": cnm.estimates.F_dff.astype(np.float32),
        "coordinates": cnm.estimates.coordinates,
    })

    del cnm

# The check at the top of the loop only stops the *previous* file's pool, so the
# pool started for the last file has to be released here.
if 'dview' in locals():
    cm.stop_server(dview=dview)
    del dview