In [None]:
reset

In [1]:
import seaborn as sns
import matplotlib.pyplot as plt
import scipy.io
from scipy.stats import pearsonr
import math,pickle
import itertools
import glob, os
from subprocess import Popen, PIPE
import PIL
import numpy as np
import tifffile as tiff
import scipy.signal 
import scipy.linalg
sns.set_context('notebook')
sns.set_style('ticks')
%matplotlib inline
s="""0.00025699
0.00850739
0.06541583
0.0784609 
0.07641301
0.06659586
0.05790289
0.04679429
0.02320798
0.01445644
0.00695772
0.00526551
0.002995  
0.0019852 
0.00128512
0.00134175
0.00040317
0
"""
GCaMP6s = np.array(s.replace("[","").replace("]","").split(),dtype=float)

def axcov(data, maxlag=5):
    """
    Compute the autocovariance of data at lag = -maxlag:0:maxlag
    Parameters
    ----------
    data : array
        Array containing fluorescence data
    maxlag : int
        Number of lags to use in autocovariance calculation
    Returns
    -------
    axcov : array
        Autocovariances computed from -maxlag:0:maxlag
    """
    
    data = data - np.mean(data)
    T = len(data)
    bins = np.size(data)
    xcov = np.fft.fft(data, np.power(2, nextpow2(2 * bins - 1)))
    xcov = np.fft.ifft(np.square(np.abs(xcov)))
    xcov = np.concatenate([xcov[np.arange(xcov.size - maxlag, xcov.size)],
                           xcov[np.arange(0, maxlag + 1)]])
    #xcov = xcov/np.concatenate([np.arange(T-maxlag,T+1),np.arange(T-1,T-maxlag-1,-1)])
    return np.real(xcov/T)
    
def nextpow2(value):
    """
    Find exponent such that 2^exponent is equal to or greater than abs(value).
    Parameters
    ----------
    value : int
    Returns
    -------
    exponent : int
    """
    
    exponent = 0
    avalue = np.abs(value)
    while avalue > np.power(2, exponent):
        exponent += 1
    return exponent        

def estimate_time_constant(fluor, p = 2, sn = None, lags = 5, fudge_factor = 1.):
    """    
    Estimate AR model parameters through the autocovariance function    
    Inputs
    ----------
    fluor        : nparray
        One dimensional array containing the fluorescence intensities with
        one entry per time-bin.
    p            : positive integer
        order of AR system  
    sn           : float
        noise standard deviation, estimated if not provided.
    lags         : positive integer
        number of additional lags where he autocovariance is computed
    fudge_factor : float (0< fudge_factor <= 1)
        shrinkage factor to reduce bias
        
    Return
    -----------
    g       : estimated coefficients of the AR process
    """    
    

    if sn is None:
        sn = GetSn(fluor)
        
    lags += p
    xc = axcov(fluor,lags)        
    xc = xc[:,np.newaxis]
    
    A = scipy.linalg.toeplitz(xc[lags+np.arange(lags)],xc[lags+np.arange(p)]) - sn**2*np.eye(lags,p)
    g = np.linalg.lstsq(A,xc[lags+1:])[0]
    gr = np.roots(np.concatenate([np.array([1]),-g.flatten()]))
    gr = (gr+gr.conjugate())/2.
    gr[gr>1] = 0.95 + np.random.normal(0,0.01,np.sum(gr>1))
    gr[gr<0] = 0.15 + np.random.normal(0,0.01,np.sum(gr<0))
    g = np.poly(fudge_factor*gr)
    g = -g[1:]    
        
    return g.flatten()
    

def estimate_parameters(fluor, p = 2, sn = None, g = None, range_ff = [0.25,0.5], method = 'logmexp', lags = 5, fudge_factor = 1):
    """
    Estimate noise standard deviation and AR coefficients if they are not present
    p: positive integer
        order of AR system  
    sn: float
        noise standard deviation, estimated if not provided.
    lags: positive integer
        number of additional lags where he autocovariance is computed
    range_ff : (1,2) array, nonnegative, max value <= 0.5
        range of frequency (x Nyquist rate) over which the spectrum is averaged  
    method: string
        method of averaging: Mean, median, exponentiated mean of logvalues (default)
    fudge_factor: float (0< fudge_factor <= 1)
        shrinkage factor to reduce bias
    """
    
    if sn is None:
        sn = GetSn(fluor,range_ff,method)
        
    if g is None:
        if p == 0:
            g = np.array(0)
        else:
            g = estimate_time_constant(fluor,p,sn,lags,fudge_factor)

    return g,sn

def GetSn(fluor, range_ff = [0.25,0.5], method = 'logmexp'):
    """    
    Estimate noise power through the power spectral density over the range of large frequencies    
    Inputs
    ----------
    fluor    : nparray
        One dimensional array containing the fluorescence intensities with
        one entry per time-bin.
    range_ff : (1,2) array, nonnegative, max value <= 0.5
        range of frequency (x Nyquist rate) over which the spectrum is averaged  
    method   : string
        method of averaging: Mean, median, exponentiated mean of logvalues (default)
        
    Return
    -----------
    sn       : noise standard deviation
    """
    

    ff, Pxx = scipy.signal.welch(fluor)
    ind1 = ff > range_ff[0]
    ind2 = ff < range_ff[1]
    ind = np.logical_and(ind1,ind2)
    Pxx_ind = Pxx[ind]
    sn = {
        'mean': lambda Pxx_ind: np.sqrt(np.mean(Pxx_ind/2)),
        'median': lambda Pxx_ind: np.sqrt(np.median(Pxx_ind/2)),
        'logmexp': lambda Pxx_ind: np.sqrt(np.exp(np.mean(np.log(Pxx_ind/2))))
    }[method](Pxx_ind)

    return sn


def cvxpy_foopsi(fluor,  g=None, sn=None, b=None, c1=None, bas_nonneg=True,solvers=None):
    '''Solves the deconvolution problem using the cvxpy package and the ECOS/SCS library. 
    Parameters:
    -----------
    fluor: ndarray
        fluorescence trace 
    g: list of doubles
        parameters of the autoregressive model, cardinality equivalent to p        
    sn: double
        estimated noise level
    b: double
        baseline level. If None it is estimated. 
    c1: double
        initial value of calcium. If None it is estimated.  
    bas_nonneg: boolean
        should the baseline be estimated        
    solvers: tuple of two strings
        primary and secondary solvers to be used. Can be choosen between ECOS, SCS, CVXOPT    
    Returns:
    --------
    c: estimated calcium trace
    b: estimated baseline
    c1: esimtated initial calcium value
    g: esitmated parameters of the autoregressive model
    sn: estimated noise level
    sp: estimated spikes 
        
    '''
    if g is None or sn is None:        
        g,sn = estimate_parameters(fluor, p=2, sn=sn, g = g, range_ff=noise_range, method=noise_method, lags=lags, fudge_factor=1)
    try:
        import cvxpy as cvx
    except ImportError:
        raise ImportError('cvxpy solver requires installation of cvxpy.')
    if solvers is None:
        solvers=['ECOS','SCS']
        
    T = fluor.size
    
    # construct deconvolution matrix  (sp = G*c)     
    G=scipy.sparse.dia_matrix((np.ones((1,T)),[0]),(T,T))
    
    for i,gi in enumerate(g):
        G = G + scipy.sparse.dia_matrix((-gi*np.ones((1,T)),[-1-i]),(T,T))
        
    gr = np.roots(np.concatenate([np.array([1]),-g.flatten()])) 
    gd_vec = np.max(gr)**np.arange(T)  # decay vector for initial fluorescence
    gen_vec = G.dot(scipy.sparse.coo_matrix(np.ones((T,1))))                          
 
    c = cvx.Variable(T) # calcium at each time step
    constraints=[]
    cnt = 0
    if b is None:
        flag_b = True
        cnt += 1
        b =  cvx.Variable(1) # baseline value
        if bas_nonneg:
            b_lb = 0
        else:
            b_lb = np.min(fluor)            
        constraints.append(b >= b_lb)
    else:
        flag_b = False

    if c1 is None:
        flag_c1 = True
        cnt += 1
        c1 =  cvx.Variable(1) # baseline value
        constraints.append(c1 >= 0)
    else:
        flag_c1 = False    
    
    thrNoise=sn * np.sqrt(fluor.size)
    
    try:
        objective=cvx.Minimize(cvx.norm(G*c,1)) # minimize number of spikes
        constraints.append(G*c >= 0)
        constraints.append(cvx.norm(-c + fluor - b - gd_vec*c1, 2) <= thrNoise) # constraints
        prob = cvx.Problem(objective, constraints) 
        result = prob.solve(solver=solvers[0])    
        
        if  not (prob.status ==  'optimal' or prob.status == 'optimal_inaccurate'):
            raise ValueError('Problem solved suboptimally or unfeasible')            
        
#         print 'PROBLEM STATUS:' + prob.status 
#         sys.stdout.flush()
    except (ValueError,cvx.SolverError) as err:     # if solvers fail to solve the problem           
#          print(err) 
#          sys.stdout.flush()
         lam=sn/500;
         constraints=constraints[:-1]
         objective = cvx.Minimize(cvx.norm(-c + fluor - b - gd_vec*c1, 2)+lam*cvx.norm(G*c,1))
         prob = cvx.Problem(objective, constraints)
         try: #in case scs was not installed properly
             try:
#                  print('TRYING AGAIN ECOS') 
#                  sys.stdout.flush()
                 result = prob.solve(solver=solvers[0]) 
             except:
                 result = prob.solve(solver=solvers[1]) 
         except:             
#              sys.stderr.write('***** SCS solver failed, try installing and compiling SCS for much faster performance. Otherwise set the solvers in tempora_params to ["ECOS","CVXOPT"]')
#              sys.stderr.flush()
             result = prob.solve(solver='CVXOPT')
             raise
             
         if not (prob.status ==  'optimal' or prob.status == 'optimal_inaccurate'):
#             print 'PROBLEM STATUS:' + prob.status 
            raise Exception('Problem could not be solved')
            
    
    
    sp = np.squeeze(np.asarray(G*c.value))    
    c = np.squeeze(np.asarray(c.value))                
    if flag_b:    
        b = np.squeeze(b.value)        
    if flag_c1:    
        c1 = np.squeeze(c1.value)
        
    return c,b,c1,g,sn,sp

In [2]:
def constrained_foopsi_parallel(fluor):
    cc_,cb_,c1_,gn_,sn_,sp_ = cvxpy_foopsi(fluor)
    return cc_,cb_,c1_,gn_,sn_,sp_
from joblib import Parallel,delayed
noise_method = 'logmexp'
noise_range = [.25,.5]
pool = Parallel(n_jobs=11, verbose=2)
lags = 5 
fudge_factor = 1

In [3]:
directory='Emmanuel'
savedirectory='Emmanuel/Results'
os.chdir('/mnt/downloads/'+directory+'/')
p=Popen(['ls'], shell=False, stdout=PIPE, close_fds=True).stdout.readlines()
filelist=[]
for filename in p:
    if filename.startswith('fish'):
        if filename.endswith('tif\n'):
            filelist.append(filename.rstrip('\n'))
            
filelist.sort()

In [4]:
filelist

['fish100_depth4_age12.tif',
 'fish101_depth4_age12.tif',
 'fish102_depth4_age12.tif',
 'fish103_depth4_age6.tif',
 'fish104_depth4_age6.tif',
 'fish105_depth4_age6.tif',
 'fish106_depth4_age6.tif',
 'fish107_depth4_age9.tif',
 'fish108_depth4_age9 second.tif',
 'fish108_depth4_age9.tif',
 'fish110_depth4_age6 second.tif',
 'fish110_depth4_age6.tif',
 'fish111_depth4_age6.tif',
 'fish66_depth4_age6.tif',
 'fish67_depth4_age6.tif',
 'fish68_depth4_age6.tif',
 'fish71_depth4_age12.tif',
 'fish72_depth4_age6.tif',
 'fish73_depth4_age6.tif',
 'fish80_depth4_age12.tif',
 'fish81_depth4_age12.tif',
 'fish82_depth4_age12.tif',
 'fish84_depth4_age6.tif',
 'fish84_depth4_age6_second.tif',
 'fish85_depth4_age6.tif',
 'fish86_depth4_age6.tif',
 'fish89_depth4_age9.tif',
 'fish93_depth4_age12.tif',
 'fish94_depth4_age12.tif',
 'fish95_depth4_age9.tif',
 'fish96_depth4_age12.tif',
 'fish98_depth4_age12.tif',
 'fish99_depth4_age12.tif']

In [None]:
import pickle
from joblib import Parallel, delayed
counter=-1
dict_allindex=dict.fromkeys(filelist, 0)
#import skimage.external.tifffile as tiff


for i,filename in enumerate(filelist):  
    Mask = PIL.Image.open('/mnt/downloads/'+directory+'/Mask_AVG_'+filename)
    Mask=np.asarray(Mask,dtype=np.int64)    
    for value in xrange(0,Mask.max()+1):
        size=(Mask==value).sum()
        if 8 < size < 200:
            Mask[Mask==value]=counter
            counter=counter-1
        else:
            Mask[Mask==value]=0
    Mask=np.absolute(Mask)
    Mask.astype(np.uint32)
    tiff.imsave('/mnt/downloads/'+savedirectory+'/Maskb_'+filename, Mask)
    MeanFluo_ROI=tiff.imread('/mnt/downloads/'+directory+'/'+filename)
    MeanFluo_ROI=MeanFluo_ROI.swapaxes(0,2).swapaxes(0,1)
    uniq=np.unique(Mask)
    meanVals=[np.mean(MeanFluo_ROI[Mask == grp],axis=0) for grp in uniq if grp != 0]
    meanVals=np.asarray(meanVals,dtype=np.uint16)
    Centroids=[Parallel(n_jobs=-1)(delayed(scipy.ndimage.measurements.center_of_mass)(Mask,Mask,index=grp) for grp in uniq if grp !=0)]
    Centroids=np.asarray(Centroids).squeeze()    
    scipy.io.savemat('/mnt/downloads/'+savedirectory+'/'+filename+'-FluoTraces-Maskb.mat', mdict={'Centroids':Centroids,'FluoTraces':meanVals}, oned_as='column')
    if i==0 :
        Centroids_ROI=Centroids
        SegmentData=meanVals
        dict_allindex[filename]=SegmentData.shape[0]
    else:
        Centroids_ROI=np.concatenate((Centroids_ROI,Centroids),axis=0)
        SegmentData=np.concatenate((SegmentData,meanVals),axis=0)
        dict_allindex[filename]=SegmentData.shape[0]
    if meanVals.shape[0]+1 != np.unique(Mask).shape[0]:
        print "there was an error with the ROIs"
        break
    if np.absolute(counter+1) != SegmentData.shape[0]:
        print "there was an error with the counter"
        break

In [25]:
SegmentData.shape

(4679, 700)

In [None]:
scipy.io.savemat('/mnt/downloads/'+savedirectory+'/'+directory+'-FluoTraces-Maskb.mat', mdict={'Centroids':Centroids_ROI,'FluoTraces':SegmentData}, oned_as='column')
with open('/mnt/downloads/'+savedirectory+'/'+directory+'-dict_allindex.pickle', 'w') as handle:
      pickle.dump(dict_allindex, handle)    
filelist=np.zeros(len(dict_allindex.keys()),dtype=np.object)
values=np.zeros(len(dict_allindex.keys()),dtype=int)
i=0
for key in sorted(dict_allindex):
    filelist[i]=key
    values[i]=dict_allindex[key]
    i=i+1
scipy.io.savemat('/mnt/downloads/'+savedirectory+'/'+directory+'-dict_allindex_Maskb.mat',mdict={'files':filelist,'index':values}, oned_as='column')

In [None]:
aud8freq=np.zeros((8,700),dtype=np.float16);
aud8freq[0,21:21+GCaMP6s.shape[0]]=GCaMP6s;
aud8freq[0,407:407+GCaMP6s.shape[0]]=GCaMP6s;
aud8freq[0,612:612+GCaMP6s.shape[0]]=GCaMP6s;

aud8freq[1,47:47+GCaMP6s.shape[0]]=GCaMP6s;
aud8freq[1,382:382+GCaMP6s.shape[0]]=GCaMP6s;
aud8freq[1,532:532+GCaMP6s.shape[0]]=GCaMP6s;

aud8freq[2,71:71+GCaMP6s.shape[0]]=GCaMP6s;
aud8freq[2,356:356+GCaMP6s.shape[0]]=GCaMP6s;
aud8freq[2,507:507+GCaMP6s.shape[0]]=GCaMP6s;

aud8freq[3,96:96+GCaMP6s.shape[0]]=GCaMP6s;
aud8freq[3,331:331+GCaMP6s.shape[0]]=GCaMP6s;
aud8freq[3,557:557+GCaMP6s.shape[0]]=GCaMP6s;

aud8freq[4,121:121+GCaMP6s.shape[0]]=GCaMP6s;
aud8freq[4,308:308+GCaMP6s.shape[0]]=GCaMP6s;
aud8freq[4,487:487+GCaMP6s.shape[0]]=GCaMP6s;

aud8freq[5,146:146+GCaMP6s.shape[0]]=GCaMP6s;
aud8freq[5,281:281+GCaMP6s.shape[0]]=GCaMP6s;
aud8freq[5,586:586+GCaMP6s.shape[0]]=GCaMP6s;

aud8freq[6,175:175+GCaMP6s.shape[0]]=GCaMP6s;
aud8freq[6,256:256+GCaMP6s.shape[0]]=GCaMP6s;
aud8freq[6,461:461+GCaMP6s.shape[0]]=GCaMP6s;

aud8freq[7,200:200+GCaMP6s.shape[0]]=GCaMP6s;
aud8freq[7,231:231+GCaMP6s.shape[0]]=GCaMP6s;
aud8freq[7,636:636+GCaMP6s.shape[0]]=GCaMP6s;

row,col=np.where(aud8freq==0.00850739000000000)
corr=np.zeros(SegmentData.shape[0])
for x in enumerate(SegmentData):
    for y in col:
        for z in xrange(-3,6):
            w=y+z;
            temp=x[1][w:w+GCaMP6s.shape[0]]
            temp=np.corrcoef(temp,GCaMP6s)
            if temp[0,1]>0.75:
                corr[x[0]]=corr[x[0]]+1
                break
idx_corr=corr>3
SelectCorr=SegmentData[idx_corr]

In [None]:
scipy.io.savemat('/mnt/downloads/'+savedirectory+'/GCaMP6f-ganglia-FluoTraces-Maskb-SelectStrict.mat', mdict={'SelectCorr':SelectCorr,'Nb_corr':corr}, oned_as='column')

In [28]:
scipy.io.savemat('/mnt/downloads/'+savedirectory+'GCaMP6f-'+directory+'_oopsi.mat', mdict={'Sp_infer':sp_,'Ca_infer':cc_,'baseline_infer':cb_,'Ca_init':c1_,'Params_est':gn_,'Noise':sn_}, oned_as='column')

In [27]:
#SegmentData=np.transpose(SegmentData)
import scipy.io
results = pool(delayed(constrained_foopsi_parallel)(rr) for rr in SegmentData)
cc_,cb_,c1_,gn_,sn_,sp_ = zip(*results)
scipy.io.savemat('/mnt/downloads/'+savedirectory+'GCaMP6f-'+savedirectory+'_oopsi.mat', mdict={'Sp_infer':sp_,'Ca_infer':cc_,'baseline_infer':cb_,'Ca_init':c1_,'Params_est':gn_,'Noise':sn_}, oned_as='column')

  return new.astype(intype)
  return new.astype(intype)
[Parallel(n_jobs=11)]: Done  19 tasks      | elapsed:    1.0s
[Parallel(n_jobs=11)]: Done 140 tasks      | elapsed:    2.4s
[Parallel(n_jobs=11)]: Done 343 tasks      | elapsed:    4.6s
  return new.astype(intype)
[Parallel(n_jobs=11)]: Done 626 tasks      | elapsed:    7.7s
  return new.astype(intype)
[Parallel(n_jobs=11)]: Done 991 tasks      | elapsed:   12.1s
  return new.astype(intype)
  return new.astype(intype)
[Parallel(n_jobs=11)]: Done 1436 tasks      | elapsed:   17.1s
  return new.astype(intype)
[Parallel(n_jobs=11)]: Done 1963 tasks      | elapsed:   23.1s
  return new.astype(intype)
  return new.astype(intype)
[Parallel(n_jobs=11)]: Done 2570 tasks      | elapsed:   32.3s
  return new.astype(intype)
  return new.astype(intype)
[Parallel(n_jobs=11)]: Done 3259 tasks      | elapsed:   42.3s
[Parallel(n_jobs=11)]: Done 4028 tasks      | elapsed:   55.2s
[Parallel(n_jobs=11)]: Done 4679 out of 4679 | elapsed:  1.1min fin

IOError: [Errno 2] No such file or directory: '/mnt/downloads/ptf1a/resultsGCaMP6f-ptf1a/results_oopsi.mat'

In [None]:
aud5vol=np.zeros((5,550),dtype=np.float16);
aud5vol[0,51:51+GCaMP6s.shape[0]]=GCaMP6s;
aud5vol[0,301:301+GCaMP6s.shape[0]]=GCaMP6s;
aud5vol[0,376:376+GCaMP6s.shape[0]]=GCaMP6s;
aud5vol[1,76:76+GCaMP6s.shape[0]]=GCaMP6s;
aud5vol[1,276:276+GCaMP6s.shape[0]]=GCaMP6s;
aud5vol[1,451:451+GCaMP6s.shape[0]]=GCaMP6s;
aud5vol[2,101:101+GCaMP6s.shape[0]]=GCaMP6s;
aud5vol[2,251:251+GCaMP6s.shape[0]]=GCaMP6s;
aud5vol[2,426:426+GCaMP6s.shape[0]]=GCaMP6s;
aud5vol[3,126:126+GCaMP6s.shape[0]]=GCaMP6s;
aud5vol[3,226:226+GCaMP6s.shape[0]]=GCaMP6s;
aud5vol[3,401:401+GCaMP6s.shape[0]]=GCaMP6s;
aud5vol[4,151:151+GCaMP6s.shape[0]]=GCaMP6s;
aud5vol[4,201:201+GCaMP6s.shape[0]]=GCaMP6s;
aud5vol[4,351:351+GCaMP6s.shape[0]]=GCaMP6s;

row,col=np.where(aud5vol==0.00850739000000000)
corr=np.zeros(SegmentData.shape[0])
for x in enumerate(SegmentData):
    for y in col:
        for z in xrange(-3,6):
            w=y+z;
            temp=x[1][w:w+GCaMP6s.shape[0]]
            temp=np.corrcoef(temp,GCaMP6s)
            if temp[0,1]>0.75:
                corr[x[0]]=corr[x[0]]+1
                break
idx_corr=corr>3
SelectCorr=SegmentData[idx_corr]

In [None]:
from matplotlib import pyplot as plt
 
import sys
import numpy as np
import ca_source_extraction as cse
import gc, pickle

from time import time
from scipy.sparse import coo_matrix
import tifffile
import subprocess
import time as tm
from time import time
import pylab as pl
import psutil
import glob
import os
import scipy
#%%

#%% FOR LOADING ALL TIFF FILES IN A FILE AND SAVING THEM ON A SINGLE MEMORY MAPPABLE FILE
fnames_all=[]
base_folder='/mnt/downloads/Lucy/' # folder containing the demo files
for file in glob.glob(os.path.join(base_folder,'f*.tif')):
    if file.endswith(".tif"):
        fnames_all.append(file)
fnames_all.sort()
#%% stop server and remove log files
cse.utilities.stop_server() 
log_files=glob.glob('Yr*_LOG_*')
for log_file in log_files:
    os.remove(log_file)
from ipyparallel import Client
n_processes = 11
sys.stdout.flush()  
cse.utilities.stop_server()
cse.utilities.start_server()
c=Client()
dview=c[:n_processes]
import gc
import scipy.io
# n_processes = 12
p=2
gSig=[3,3]
K=500
# K=250
# cse.utilities.stop_server() # trying to stop in case it was already runnning
# cse.utilities.start_server(n_processes)
# with open('/mnt/downloads/Hindbrain-dict_allindex.pickle', 'r') as handle:
#     dict_allindex=pickle.load(handle)
a=-1
for i,name in enumerate(fnames_all):
    if i>a:
#         cse.utilities.start_server(n_processes)
        fnames=[name]
        fname_new=cse.utilities.save_memmap(fnames,base_name='Yr',resize_fact=(1,1,1))
        Yr,dims,T=cse.utilities.load_memmap(fname_new)
#         d,T=np.shape(Yr)
        Y=np.reshape(Yr,dims+(T,),order='F')
        options = cse.utilities.CNMFSetParms(Y,n_processes,p=p,gSig=gSig,K=K,tsub=3)
        Cn = cse.utilities.local_correlations(Y[:1000])    
        Yr,sn,g,psx = cse.pre_processing.preprocess_data(Yr, dview=dview,**options['preprocess_params'])
        Ain,Cin, b_in, f_in, center=cse.initialization.initialize_components(Y,sn=sn, **options['init_params'])
        if np.isfinite(Cin).max(): 
            options['temporal_params']['p'] = 0 # set this to zero for fast updating without deconvolution
            Cin,f,S,bl,c1,neurons_sn,g,YrA = cse.temporal.update_temporal_components(Yr,Ain,b_in,Cin,f_in,dview=dview,bl=None,c1=None,sn=None,g=None,**options['temporal_params'])
            #%% UPDATE SPATIAL COMPONENTS
            Ain,b_in,Cin = cse.spatial.update_spatial_components(Yr, Cin, f_in, Ain, sn=sn, dview=dview, **options['spatial_params'])
            options['temporal_params']['p'] = 0 # set this to zero for fast updating without deconvolution
            Cin,f,S,bl,c1,neurons_sn,g,YrA = cse.temporal.update_temporal_components(Yr,Ain,b_in,Cin,f_in,bl=None,c1=None,sn=None,g=None,**options['temporal_params'])
            Ain,Cin,nr_m,merged_ROIs,S_m,bl_m,c1_m,sn_m,g_m=cse.merging.merge_components(Yr,Ain,b_in,Cin,f,S,sn,options['temporal_params'], options['spatial_params'], bl=bl, c1=c1, sn=neurons_sn, g=g, thr=0.8,mx=1000, fast_merge = True)
            Ain,b_in,Cin = cse.spatial.update_spatial_components(Yr, Cin, f, Ain, dview=dview, sn=sn, **options['spatial_params'])
            options['temporal_params']['p'] = p # set it back to original value to perform full deconvolution
            Cin,f2,S,bl2,c12,neurons_sn2,g21,YrA = cse.temporal.update_temporal_components(Yr,Ain,b_in,Cin,f, dview=dview,bl=None,c1=None,sn=None,g=None,**options['temporal_params'])    
            traces=Cin+YrA
            idx_components, fitness, erfc = cse.utilities.evaluate_components(traces,N=5,robust_std=True)
            scipy.io.savemat(fnames[0][:-4]+'_output_analysis_matlab.mat',mdict={'ROIs':Ain,'DenoisedTraces':Cin, 'Noise':YrA, 'Spikes': S , 'idx_components':idx_components, 'fitness':fitness})
        #             dict_allindex[name]=Cin.shape[0]
        #         if name == fnames_all[0]:
        #             Spikes=S
        #             Calcium=Cin
        #         else:
        #             Spikes=np.vstack((Spikes,S))
        #             Calcium=np.vstack((Calcium,Cin))
#             cse.utilities.stop_server()
            log_files=glob.glob('Yr*_LOG_*')
            for log_file in log_files:
                os.remove(log_file)
            del(Y,YrA)
            del(Yr,Ain,S,Cin)
        #         np.savez('/mnt/downloads/Temp.npz',Spikes=Spikes,Calcium=Calcium,i=i,name=name)
        #             with open('/mnt/downloads/Hindbrain-dict_allindex.pickle', 'w') as handle:
        #                 pickle.dump(dict_allindex, handle)
        np.save('/mnt/downloads/Temp.npy',i)
        gc.collect()
cse.utilities.stop_server()

In [None]:
from scipy.spatial import distance
def compute_bic(kmeans,X):
    """
    Computes the BIC metric for a given clusters

    Parameters:
    -----------------------------------------
    kmeans:  List of clustering object from scikit learn

    X     :  multidimension np array of data points

    Returns:
    -----------------------------------------
    BIC value
    """
    # assign centers and labels
    centers = kmeans.cluster_centers_
    labels  = kmeans.labels_
    #number of clusters
    m = kmeans.n_clusters
    # size of the clusters
    n = np.bincount(labels)
    #size of data set
    N, d = X.shape

    #compute variance for all clusters beforehand
    cl_var = (1.0 / (N - m) / d) * sum([sum(distance.cdist(X[np.where(labels == i)], [centers[i]], 'euclidean')**2) for i in range(m)])

    const_term = 0.5 * m * np.log(N) * (d+1)

    BIC = np.sum([n[i] * np.log(n[i]) -
               n[i] * np.log(N) -
             ((n[i] * d) / 2) * np.log(2*np.pi*cl_var) -
             ((n[i] - 1) * d/ 2) for i in range(m)]) - const_term
    

    return(BIC)


In [5]:
kmin=1
kmax=200
bic=np.zeros(kmax-kmin)
for i in xrange(kmin,kmax):
    mbkmv = MiniBatchKMeans(i, max_iter=200, batch_size=10000).fit(SelectCorr)
    bic[i]=compute_bic(mbkmv,SelectCorr)

NameError: name 'MiniBatchKMeans' is not defined

In [None]:
plt.plot(bic[:30])