In [None]:
# This might contain a few more packages than you need
import pickle
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from sys import path
from oasis.functions import gen_data, gen_sinusoidal_data, deconvolve, estimate_parameters
from oasis.plotting import simpleaxis
from oasis.oasis_methods import oasisAR1, oasisAR2
from suite2p.extraction import dcnv
import h5py as h5
from numba import jit, prange
from scipy.ndimage import filters
from sklearn.metrics import roc_curve
from sklearn.metrics import roc_auc_score
from sklearn.metrics import roc_curve
from sklearn.metrics import roc_auc_score
from scipy.interpolate import interp1d

In [None]:
# If you want to run NND from Suite2p the way I have and with the same parameters, use this code:

def oasis_trace(F, v, w, t, l, s, tau, fs):
    """Spike deconvolution on a single neuron.

    Walks through the trace once, keeping a stack of "pools" in the
    caller-supplied work buffers, and merges neighbouring pools whenever the
    decayed value of the previous pool exceeds the value of the current one.

    Parameters
    ----------
    F : 1-D array
        Trace to deconvolve; only ``F.shape[0]`` samples are read.
    v, w, t, l : 1-D arrays (modified in place)
        Work buffers holding, per pool, its value, weight, start index and
        length. ``t`` is used as an index array, so it must be an integer array.
    s : 1-D array (modified in place)
        Output buffer; entries at the start index of every pool except the
        first are overwritten, all other entries are left as passed in.
    tau, fs : float
        Only their product is used: the decay exponent per sample is
        ``g = -1 / (tau * fs)``.

    Returns
    -------
    The same ``s`` object that was passed in.
    """
    n_frames = F.shape[0]
    g = -1. / (tau * fs)

    ip = 0  # index of the pool currently being filled
    for it in range(n_frames):
        # open a new pool that contains only the current sample
        v[ip], w[ip], t[ip], l[ip] = F[it], 1, it, 1

        # merge backwards for as long as the constraint is violated
        while ip > 0:
            prev = ip - 1
            if not (v[prev] * np.exp(g * l[prev]) > v[ip]):
                break
            decay = np.exp(g * l[prev])
            decay_sq = np.exp(2 * g * l[prev])
            merged_w = w[prev] + w[ip] * decay_sq
            v[prev] = (v[prev] * w[prev] + v[ip] * w[ip] * decay) / merged_w
            w[prev] = merged_w
            l[prev] = l[prev] + l[ip]
            ip = prev

        ip += 1

    # each pool after the first contributes (its value - decayed value of the pool before it)
    pool_starts = t[1:ip]
    carried_over = v[:ip-1] * np.exp(g * l[:ip-1])
    s[pool_starts] = v[1:ip] - carried_over
    return s

# Function for computing the non-regularized deconvolution,
# including the baseline operation.

def nndv(tau, fs, F, Fneu=None):
    """Baseline-correct a fluorescence trace and run ``oasis_trace`` on it.

    Parameters
    ----------
    tau : float
        Forwarded unchanged to ``oasis_trace`` as its ``tau`` argument.
    fs : float
        IGNORED. The value is overwritten with 150.0 below, so every call runs
        as if the recording were sampled at 150 Hz.
        NOTE(review): the override is kept because the existing call in this
        notebook passes ``1.0/150.`` here and relies on it being discarded;
        honouring the argument would change those results.
    F : np.ndarray
        Fluorescence trace(s) handed to ``dcnv.preprocess``. The call in this
        notebook passes a single trace shaped ``(1, n_frames)``; the result is
        flattened, so presumably only one trace per call is intended.
    Fneu : np.ndarray, optional
        Neuropil trace. When given, ``0.7 * Fneu`` is subtracted from ``F``
        before baseline correction.

    Returns
    -------
    np.ndarray
        1-D float32 array with one deconvolved value per sample of the
        flattened, baseline-corrected trace.
    """
    fs = 150.0  # sampling rate in Hz (overrides the argument, see docstring)
    neucoeff = 0.7  # neuropil coefficient
    # settings for computing and subtracting the baseline
    baseline = 'maximin'  # running max of the running min after gaussian smoothing
    sig_baseline = 5 * 10.0  # in bins, standard deviation of the smoothing gaussian
    win_baseline = 60.0  # in seconds, window in which to compute max/min filters

    # Subtract neuropil if it was supplied. The test must be `is not None`:
    # `if Fneu:` raises "truth value of an array ... is ambiguous" for arrays.
    if Fneu is not None:
        Fc = F - neucoeff * Fneu
    else:
        Fc = F

    Fc = dcnv.preprocess(
        F=Fc,
        baseline=baseline,
        win_baseline=win_baseline,
        sig_baseline=sig_baseline,
        fs=fs,
        prctile_baseline=8.0
    )

    # oasis_trace works on a single 1-D float32 trace
    Fc = Fc.ravel().astype(np.float32)
    NT = Fc.shape[0]

    # scratch buffers (pool value, weight, start index, length) and the output
    v = np.zeros(NT, dtype=np.float32)
    w = np.zeros(NT, dtype=np.float32)
    t = np.zeros(NT, dtype=np.int64)
    l = np.zeros(NT, dtype=np.float32)
    s = np.zeros(NT, dtype=np.float32)

    spikes_nndv = oasis_trace(F=Fc, v=v, w=w, t=t, l=l, s=s, tau=tau, fs=fs)
    return spikes_nndv


In [None]:
# Here is how I ran the NND algorithm on all of the results for the 80 recordings of the final dataset and saved the results.
# It is important, though, that you use the correct values for the various genotypes. You probably just want the one for tet-O!

# Load the dff data - I had it in a dictionary of the form {recordingid(tif), np.array w/ dff values}, adjust as needed.
# I resampled all recordings to exactly 150 Hz:
dff_150hz = np.load('/home/peterl/dff_resampled_dict.npy', allow_pickle=True).item()

# per-genotype values that enter the decay computation below
tau_dict = {'Emx1-s': 0.46594163866468175, 'tetO-s': 0.9622924335325923, 'Emx1-f': 0.16341948426755523, 'Cux2-f': 0.12241159191308901}

# r2c is a pandas dataframe that contains the correspondence between recordings (tid), cells, animals, and genotype.
# It must already exist in the kernel; it is not created in this cell. You probably don't need it.
# The set of recordings to process is built once here instead of on every loop iteration.
selected_tids = set(r2c.tid.unique().tolist())

nndv_150hz_dict = {}  # nndv spikes resampled dictionary - name as needed
for tid, dff in dff_150hz.items():
    # just to select the recordings to run this on; you probably don't need this, since you only have four recordings
    if tid not in selected_tids:
        continue

    # determine the genotype - you don't need that if you only have one
    gt = r2c[r2c.tid == tid].gtype.values[0]

    # set the frame rate! mine was 150.
    # NOTE(review): this is the per-frame decay factor exp(-1/(tau_dict[gt]*150)), but oasis_trace
    # uses its `tau` argument as g = -1/(tau*fs) - confirm that passing the decay factor is intended.
    tau = np.exp(-1 / (tau_dict[gt] * 150))
    print(tau)

    # the second argument is discarded by nndv, which hard-codes fs = 150.0
    nndv_spikes = nndv(tau, 1.0 / 150., dff[np.newaxis, :])
    nndv_150hz_dict[tid] = nndv_spikes

# save
np.save('/home/peterl/nndv_150hz_dict', nndv_150hz_dict)

In [None]:
# Load your dictionaries of ground-truth spikes and events (only the ground-truth dictionary is loaded in this cell):
gt_150hz_dict = np.load('/home/peterl/gt_150hz_dict.npy', allow_pickle=True).item()

In [None]:
def rebin(vector, winsize):
    """Rebin events or ground-truth spikes by summing non-overlapping windows.

    Parameters
    ----------
    vector : np.ndarray
        Values to rebin; windows are taken along the first axis.
    winsize : int
        Window width in samples (must be an integer).

    Returns
    -------
    np.ndarray
        Float array with ``vector.shape[0] // winsize`` entries, each the sum
        over one window. Samples beyond the last full window are dropped.
    """
    n_bins = vector.shape[0] // winsize
    binned = np.zeros(n_bins)
    for idx in range(n_bins):
        start = idx * winsize
        binned[idx] = np.sum(vector[start:start + winsize])
    return binned

In [None]:
# These are the things I've been comparing. Adapt to the names of your dictionaries.
# The keys are matplotlib colour codes used to draw the curves of each method.
# mlspike_150hz_dict and l0_150hz_dict must already be loaded in the kernel; they are not created in this cell.
methods = {'g': mlspike_150hz_dict.item(), 'b': nndv_150hz_dict, 'k': l0_150hz_dict.item()}

# These are binning factors (in frames at 150 Hz). I compared 33, 100, 200, 300, 500 ms bins.
factors = 5 * np.array([1, 3, 6, 9, 15])

# common false-positive-rate grid on which every ROC curve is resampled so that curves can be averaged
fpr_grid = np.linspace(0., 1., 100)

for gtype in r2c.gtype.unique():  # if you don't have different genotypes, you don't need this loop
    # all cells of this genotype
    cellids = r2c.acid[r2c.gtype == gtype].unique()

    fig, ax = plt.subplots(nrows=1, ncols=len(factors), sharex=True, sharey=True, figsize=(15, 3))

    for k, factor in enumerate(factors):
        ax[k].set_title(gtype + ' - bin:' + str(np.round(factor * 1000./150.)) + ' ms')
        for color, method in methods.items():
            # g_tpri == interpolated true positive rate, one row per cell (averaged over its recordings)
            g_tpri = np.zeros((len(cellids), 100))

            # I computed averages over cells, not recordings - you can eliminate this loop
            # if you are using 1 recording per cell.
            for j, cid in enumerate(cellids):
                # for each cell loop over its recordings
                tids = r2c.tid[r2c.acid == cid].unique()

                tpri = np.zeros((len(tids), 100))  # true positive rate for each recording (tid)
                for i, tid in enumerate(tids):
                    # rebin ground truth according to factor and normalize to max
                    gt = rebin(gt_150hz_dict[tid].ravel(), factor)
                    gt = gt / np.max(gt)

                    # Load the events you are comparing. Work on a copy: ravel() can return a view,
                    # and zeroing NaNs in place would silently modify the array stored in `methods`.
                    test = method[tid].ravel().copy()
                    test[np.isnan(test)] = 0  # if any entries are nan, set to 0
                    test = rebin(test, factor)  # rebin loaded events according to factor
                    test = test / np.max(test)  # normalize events after rebinning

                    # compute the ROC curve - does not work if ground truth is non-binary, hence gt > 0
                    fpr, tpr, _ = roc_curve(gt > 0, test)
                    interp_fn = interp1d(fpr, tpr, kind='previous')  # left-bound interpolation
                    tpri[i, :] = interp_fn(fpr_grid)  # interpolate the ROC curve for averaging

                # plot average of interpolated ROC curves for each cell (last two grid points are not drawn)
                ax[k].plot(fpr_grid[:-2], tpri.mean(axis=0)[:-2], color, alpha=1, linewidth=0.2)
                # store the per-cell average for the genotype average
                g_tpri[j, :] = tpri.mean(axis=0)

            # plot average of interpolated ROC curves for each genotype
            ax[k].plot(fpr_grid[:-2], g_tpri.mean(axis=0)[:-2], color, alpha=1)
        ax[k].set_xlabel('False Positive Rate')
        ax[k].set_ylabel('True Positive Rate')

    # Save plots. The `papertype` and `frameon` keywords were removed from matplotlib's savefig and
    # now raise TypeError; the other dropped keywords (format, bbox_inches, metadata) were at their defaults.
    fig.savefig(gtype + '_high_zoom_horizontal_line.pdf', dpi=600, facecolor='w', edgecolor='w',
                orientation='portrait', transparent=True, pad_inches=0.1)
    plt.show()
    plt.close(fig)  # release the figure before moving on to the next genotype