In [None]:
import os
import glob
import pickle
import shutil
from itertools import chain

import numpy as np
from skimage import io
import scipy.io as sio
import matplotlib.pyplot as plt
from scipy import signal, ndimage

from caiman.source_extraction.volpy.spikepursuit import signal_filter

In [None]:
def flatten(x, wid):
    """Estimate a slowly varying baseline of a 1-D trace.

    The trace is cut into consecutive windows of ``wid`` samples (the last
    window may be shorter). For each window, the samples lying strictly
    between the window's 0.3 and 0.8 quantiles are averaged, and that value is
    assigned to the window's centre. The per-window values are then linearly
    interpolated back onto every sample index (``np.interp`` holds the first
    and last value constant outside the window centres).

    Parameters
    ----------
    x : 1-D array-like
        Trace to estimate the baseline of.
    wid : int
        Window length in samples; must be at least 1.

    Returns
    -------
    numpy.ndarray
        Baseline with the same length as ``x``.

    Raises
    ------
    ValueError
        If ``wid`` is smaller than 1.
    """
    if wid < 1:
        raise ValueError('wid must be a positive number of samples')
    x = np.asarray(x)
    n_samples = x.shape[0]
    if n_samples == 0:
        # Nothing to interpolate; np.interp would raise on empty sample points.
        return np.zeros(0)

    t = np.arange(n_samples)
    n_wid = int(np.ceil(n_samples / wid))
    xq = np.zeros(n_wid)
    tq = np.zeros(n_wid)
    for i in range(n_wid):
        # Python slices already exclude the stop index, so the window is
        # [i*wid, (i+1)*wid); subtracting one more would drop the last sample
        # of every window.
        window = slice(i * wid, (i + 1) * wid)
        tmp = x[window]
        lo, hi = np.quantile(tmp, [0.3, 0.8])
        kept = tmp[(tmp > lo) & (tmp < hi)]
        # The strict inequalities leave nothing for constant or very short
        # windows; fall back to the median instead of producing NaN.
        xq[i] = kept.mean() if kept.size else np.median(tmp)
        tq[i] = t[window].mean()
    return np.interp(t, tq, xq)

In [None]:
def spike_SNR(t, spikes, fr=500):
    """Signal-to-noise ratio of the spikes in a trace.

    Signal is the mean of the median-subtracted trace at the spike samples.
    Noise is a root-mean-square of the negative-going part of the filtered
    trace, taken only from samples with no spike inside a 20-sample window and
    after discarding values at or above the 0.995 quantile.

    Parameters
    ----------
    t : numpy.ndarray
        1-D trace. NOTE(review): spikes are presumably upward here -- confirm
        against the caller.
    spikes : array of int
        Sample indices of the detected spikes.
    fr : float
        Frame rate passed to ``signal_filter``.

    Returns
    -------
    float
        ``signal / noise``. The result is NaN/inf when there are no spikes or
        no positive noise samples.
    """
    centered = t - np.median(t)
    # signal_filter is called with its default mode; presumably a 30 Hz
    # high-pass -- confirm against caiman's spikepursuit.signal_filter.
    filtered = signal_filter(centered, 30, fr)

    at_spike = np.zeros(centered.shape, dtype=bool)
    at_spike[spikes] = True
    sgn = centered[at_spike].mean()

    # Smearing the spike indicator with a 20-sample box marks every sample
    # that has a spike nearby; exact zeros are the spike-free samples.
    spike_density = np.convolve(at_spike.astype(float), np.ones(20) / 20, 'same')
    spike_free = spike_density == 0

    # Keep only the magnitude of negative deflections in spike-free samples.
    deflection = -filtered * ((filtered < 0) & spike_free)
    deflection = deflection[deflection < np.quantile(deflection, 0.995)]
    n_positive = np.sum(deflection > 0)
    noise = np.sqrt(np.sum(deflection ** 2) / n_positive)
    return sgn / noise

def firing_rate(t, spikes, fr=500, wid=0.5):
    """Sliding-window spike rate.

    Counts the spikes inside a box window of ``int(wid * fr)`` samples centred
    on each sample and divides the count by ``wid``.

    Parameters
    ----------
    t : numpy.ndarray
        Trace; only its shape is used.
    spikes : array of int
        Sample indices of the detected spikes.
    fr : float
        Frame rate (samples per unit of ``wid``).
    wid : float
        Window width, in the time unit implied by ``fr``.

    Returns
    -------
    numpy.ndarray
        Spike count per window divided by ``wid``.
    """
    window_samples = int(wid * fr)
    spike_train = np.zeros(t.shape, dtype='float')
    spike_train[spikes] = 1.0
    spike_counts = np.convolve(spike_train, np.ones(window_samples), 'same')
    return spike_counts / wid

def snr_trace(t, spikes, fr=500):
    """Express a trace in units of its estimated noise level.

    The trace is median-subtracted and divided by a noise estimate: the
    root-mean-square of the negative-going part of the filtered trace, taken
    from samples with no spike inside a 20-sample window and after discarding
    values at or above the 0.99 quantile.

    The trace should have upward spikes; if upward and downward noise have the
    same spread, the orientation does not matter.

    Parameters
    ----------
    t : numpy.ndarray
        1-D trace.
    spikes : array of int
        Sample indices of the detected spikes.
    fr : float
        Frame rate passed to ``signal_filter``.

    Returns
    -------
    numpy.ndarray
        Median-subtracted trace divided by the noise estimate.
    """
    centered = t - np.median(t)
    # signal_filter is called with its default mode; presumably a 30 Hz
    # high-pass -- confirm against caiman's spikepursuit.signal_filter.
    filtered = signal_filter(centered, 30, fr)

    spike_train = np.zeros(centered.shape)
    spike_train[spikes] = 1
    # Exact zeros after box-smoothing are samples with no spike nearby.
    spike_free = np.convolve(spike_train, np.ones(20) / 20, 'same') == 0

    # Keep only the magnitude of negative deflections in spike-free samples.
    deflection = -filtered * ((filtered < 0) & spike_free)
    deflection = deflection[deflection < np.quantile(deflection, 0.99)]
    n_positive = np.sum(deflection > 0)
    noise = np.sqrt(np.sum(deflection ** 2) / n_positive)
    return centered / noise

In [None]:
def trialwise_average(value_group):
    """Pool the per-trial values of all mice for each sensor.

    Despite the name, no averaging is done here: the per-mouse sequences are
    concatenated, in mouse order, into one flat list per sensor.

    Parameters
    ----------
    value_group : dict
        ``{sensor: {mouse: sequence of per-trial values}}``.

    Returns
    -------
    dict
        ``{sensor: list of every trial value of every mouse}``.
    """
    pooled = {}
    for sensor in value_group:
        per_mouse = value_group[sensor]
        pooled[sensor] = [value for mouse in per_mouse for value in per_mouse[mouse]]
    return pooled

def animalwise_average(value_group):
    """Average the per-trial values within each mouse.

    Parameters
    ----------
    value_group : dict
        ``{sensor: {mouse: sequence of per-trial values}}``.

    Returns
    -------
    dict
        ``{sensor: [mean over trials, one entry per mouse]}``, in mouse order.
    """
    return {
        sensor: [np.mean(mice[mouse]) for mouse in mice]
        for sensor, mice in value_group.items()
    }

def trialwise_filter_average(value_group, spike_group, threshold=10):
    """Pool per-trial values across mice, keeping only trials with enough spikes.

    A trial is kept when its entry in ``spike_group`` is strictly greater than
    ``threshold``. Values and spike counts are paired by position.

    Parameters
    ----------
    value_group : dict
        ``{sensor: {mouse: sequence of per-trial values}}``.
    spike_group : dict
        ``{sensor: {mouse: sequence of per-trial spike counts}}``, with the
        same sensors and mice as ``value_group``.
    threshold : number
        Spike count a trial has to exceed to be kept.

    Returns
    -------
    dict
        ``{sensor: list of the kept trial values of every mouse}``.
    """
    kept = {}
    for sensor in value_group:
        selected = []
        for mouse in value_group[sensor]:
            enough_spikes = np.array(spike_group[sensor][mouse]) > threshold
            trial_values = value_group[sensor][mouse]
            selected.extend(
                value for value, keep in zip(trial_values, enough_spikes) if keep
            )
        kept[sensor] = selected
    return kept

In [None]:
# Mouse folder names to analyse, grouped by sensor. Each name is joined onto
# the data directory by the processing loop; fill in the lists before running.
group_key = dict(
    ASAP5=[],
    ASAP3=[],
    JEDI2P=[],
    JEDI1P=[],
)

In [None]:
# Result containers filled by the processing loop, one entry per sensor.
# Each comprehension builds a fresh, independent container for every sensor.

# {sensor: {mouse: per-trial values}}
dff_group = {sensor: {} for sensor in ('ASAP5', 'ASAP3', 'JEDI2P', 'JEDI1P')}

snr_group = {sensor: {} for sensor in ('ASAP5', 'ASAP3', 'JEDI2P', 'JEDI1P')}

spike_number_group = {sensor: {} for sensor in ('ASAP5', 'ASAP3', 'JEDI2P', 'JEDI1P')}

# {sensor: names of the trials that were rejected}
trace_removed = {sensor: [] for sensor in ('ASAP5', 'ASAP3', 'JEDI2P', 'JEDI1P')}

In [None]:
CROPPING = 600  # leading frames dropped from each trace, consistent with the VolPy cropping


class TrialAverage:
    """Per-trial spike dF/F and SNR for all recordings in one directory.

    Trials are the ``*dff.mat`` files found in ``working_dir``. Each one is
    paired with ``volpy_results/<trial>_moco_volpy.pkl`` in the same directory,
    which supplies the spike indices and the VolPy trace.

    Results are stored on the instance by ``get_average``:
    ``dff_average``, ``SNR_average`` and ``spike_number`` hold one entry per
    accepted trial; ``trace_removed`` holds the names of the rejected trials.
    """

    # Settings of the rejection test in ``remove_trace``.
    MOTION_LOWPASS_HZ = 10        # cutoff handed to signal_filter(mode='low')
    SPIKE_HEIGHT_QUANTILE = 0.2   # quantile of spike heights used as tolerance
    MAX_SLOW_PEAKS = 20           # upper bound on the allowed number of slow peaks

    def __init__(self, working_dir) -> None:
        self.working_dir = working_dir
        self.dff_average = []
        self.SNR_average = []
        self.trial_list = []
        self.spike_number = []
        self.trace_removed = []
        self.flatten_wid = 250  # frames
        self.fr = 500  # frame rate
        self.get_trials()

    def get_trials(self):
        """Store the sorted paths of all ``*dff.mat`` files in the working directory."""
        self.trial_list = sorted(glob.glob(os.path.join(self.working_dir, '*dff.mat')))

    def remove_trace(self, volpy_out):
        """Decide whether a trial should be rejected for motion artifacts.

        The VolPy trace is inverted and low-pass filtered, and its local maxima
        are compared with a tolerance: the ``SPIKE_HEIGHT_QUANTILE`` quantile
        of the unfiltered trace at the spike samples. The trial is rejected
        when more such maxima exceed the tolerance than the smaller of
        ``MAX_SLOW_PEAKS`` and the number of spikes.

        Parameters
        ----------
        volpy_out : dict
            VolPy result with the trace under ``'t'`` and the spike sample
            indices under ``'spikes'`` (both NumPy arrays).

        Returns
        -------
        bool
            True when the trial should be removed.
        """
        raw = volpy_out['t'].ravel()
        spikes = volpy_out['spikes']
        inverted = signal_filter(-raw, self.MOTION_LOWPASS_HZ, self.fr, mode='low')
        peak_heights = inverted[signal.find_peaks(inverted, height=None)[0]]
        tolerance = np.quantile(raw[spikes], self.SPIKE_HEIGHT_QUANTILE)
        n_slow_peaks = np.sum(peak_heights > tolerance)
        return n_slow_peaks > np.min([self.MAX_SLOW_PEAKS, len(spikes.ravel())])

    def get_average(self):
        """Compute spike dF/F and SNR for every trial in ``trial_list``.

        For accepted trials the spike count, the mean dF/F at the spike samples
        and the spike SNR of the flattened, inverted raw trace are appended to
        ``spike_number``, ``dff_average`` and ``SNR_average``. Rejected trials
        are reported and their names appended to ``trace_removed``.

        The result lists are rebuilt on every call, so calling this method
        again (for example when a notebook cell is re-run) does not duplicate
        trials.
        """
        self.dff_average = []
        self.SNR_average = []
        self.spike_number = []
        self.trace_removed = []

        for img_path in self.trial_list:
            trial_name = os.path.basename(img_path).split('_moco')[0]
            print(trial_name)
            img = sio.loadmat(img_path)
            # Drop the leading frames so sample indices line up with the
            # spike indices produced by VolPy.
            dff = img['dff'].ravel()[CROPPING:]
            f_trace = img['raw_trace'].ravel()[CROPPING:]
            baseline = flatten(f_trace, self.flatten_wid)
            f_trace_flat = f_trace - baseline

            spike_path = os.path.join(self.working_dir, 'volpy_results', trial_name + '_moco_volpy.pkl')
            # NOTE(review): pickle.load can execute arbitrary code; only point
            # this at result files you produced yourself.
            with open(spike_path, 'rb') as f:
                volpy_out = pickle.load(f)
            spikes = volpy_out['spikes']

            if self.remove_trace(volpy_out):
                print(trial_name + ' trace removed due to motion!')
                self.trace_removed.append(trial_name)
                continue

            self.spike_number.append(len(spikes.ravel()))
            self.dff_average.append(dff[spikes].mean())
            # The raw trace is negated before the SNR computation.
            # NOTE(review): presumably because spikes point downward in the raw
            # fluorescence -- confirm.
            self.SNR_average.append(spike_SNR(-f_trace_flat, spikes, fr=self.fr))

In [None]:
# Run the per-trial analysis for every mouse of every sensor and store the
# results in the per-sensor containers defined in the earlier cells.
data_dir = './Real analysis_traces flattened/'
for sensor, mice in group_key.items():
    # Rebuild this sensor's list from scratch so that re-running the cell does
    # not append the same removed trials a second time.
    trace_removed[sensor] = []
    for mouse in mice:
        print(mouse)
        mouse_dir = os.path.join(data_dir, mouse)
        trial_avg = TrialAverage(mouse_dir)
        trial_avg.get_average()
        dff_group[sensor][mouse] = trial_avg.dff_average
        snr_group[sensor][mouse] = trial_avg.SNR_average
        spike_number_group[sensor][mouse] = trial_avg.spike_number
        trace_removed[sensor] += trial_avg.trace_removed

In [None]:
# Pool the per-trial SNR values of all mice for each sensor. As the last
# expression of the cell, the resulting dict is what the notebook displays.
trialwise_average(snr_group)

In [None]:
# Same pooling, but only for trials whose spike count exceeds the function's
# default threshold of 10.
trialwise_filter_average(snr_group, spike_number_group)

In [None]:
# Collect the VolPy summary figures of all rejected trials in one folder so
# they can be inspected by eye.
removed_trace_path = './removed_traces'
# exist_ok=True replaces the separate existence check and raises if the path
# already exists as a regular file, instead of letting the copy overwrite it.
os.makedirs(removed_trace_path, exist_ok=True)

remove_trial_list = list(chain(*trace_removed.values()))
for trial in remove_trial_list:
    # Assumes the trial name starts with the mouse folder name followed by a space.
    mouse = trial.split(' ')[0]
    trial_figure_src = os.path.join(data_dir, mouse, 'volpy_results', trial + '_moco_summary.png')
    # copy trial_figure from src to removed_trace_path
    shutil.copy(trial_figure_src, removed_trace_path)