# About

This notebook is a template for calculating orientation selectivity in pyramidal cells. These are the steps:

	1. Rigid motion correction for the full frame: Matlab (Kaspar's code).
	2. Average the first 50 frames (after discarding LED off time) and present to the user to draw an ROI: Matlab (Kaspar's code). 
	3. Initial trace is average over user-drawn ROI. Background trace is average over not ROI. Export traces to jupyter notebook.  
	4. De-trending: Find F0 by doing a lowpass filter (F0_t = smooth(trace', 1000/length(trace),'lowess')’;) and subtract it from the trace. Do this also for a 'background' trace.
	5. High-pass filter the de-trended traces (foreground and background) by taking an FFT and discarding 0.5% of the lowest frequencies
	6. Spike detection: 

		1. Smooth data for initial detection
		2. Finding a good threshold for spike detection:  For a particular threshold, detect spikes as peaks above the threshold that are spaced at a minimum distance of 4 points (should also be dependent on frame rate). Then do this over a range of thresholds and plot the number of spikes detected as a function of threshold. Find the inflection point (maximum slope) and choose this as your threshold. 
	7. Aligning to visual stimulus: Create vector of visual stimulus on time for each of the eight conditions. Figure out how to do this in a smart way. 
	8. Bin the spikes according to trials and orientation. Plot the orientation selectivity (with error bars)

Most of these are translated from Kaspar's Matlab code.

# Setup

## Imports

In [13]:
from PIL import Image
import math
import matplotlib.path as mplPath
import matplotlib.pyplot as plt
import numpy as np
import os
import pandas as pd
from pathlib import Path
import pickle
import pylab as pl
import random
np.random.seed(1)
import re
import scipy.io as spio
from scipy.optimize import curve_fit
from scipy.signal import medfilt
from scipy.signal import savgol_filter
import sys
import time



## Specify parameters

In [15]:
animal_ID = 402362
LED_off_removed = 1 # To monitor if LED off time has been removed from the traces

if animal_ID == 402362:
    cells = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17]
    frame_rates = [500, 500, 500, 500, 500, 500, 500, 500, 500, 500, 500, 500, 500, 400, 400, 400, 400]
elif animal_ID == 402361:
    cells = [2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25]
    frame_rates = [510, 500, 500, 500, 500, 500, 500, 500, 500, 500, 666, 500, 400, 450, 450, 500, 500,
                  500, 500, 500, 500, 450, 500, 500]

%matplotlib qt

animal_folder = 'E:\ST_Voltron\{0}'.format(animal_ID)
#animal_folder = 'E:\ST_Voltron\{0}\{1}'.format(animal_ID, 20180205)
daq_folder = 'Z:\ST-Voltron_DAQ' # Where camera TTLs and vis stim timings are stored

num_cells = len(cells)

num_preview = 100 # Number of frames to average for initial ROI drawing

binning = 4*4 # Camera setting - number of pixels binned while imaging
offset = 100 # Camera offset per pixel (in bits)

# Spike detection parameters
num_window = 20 # Size of window for detrending by piecewise linear fit = num_frames/num_window
freq_discard = 50 # Highest frequency in Hz to discard for high-pass filtering
num_thresh_test = 1000 # Number of spike detection thresholds to test between min and max
fpr = 0.2 # Desired false positive rate as a fraction of detected spikes
sta_time = 40 # Time in ms over which spike triggered average is calculated
min_spikes = 100 # Minimum number of spikes detected initially

# Visual stimulus parameters
dur = 400 # Approximate duration (greater than actual) of recordings in seconds
num_ori = 8
vis_freq = 1 # Temporal frequency of visual stimulus in Hz
vis_on = 1 # Stimulus on time in seconds
vis_off = 1 # Stimulus off time in seconds
num_vis_stim = int(dur/(vis_on + vis_off)) # Approximate (more than actual)
max_ori_pos = 3
ori_degrees = [-135, -90, -45, 0, 45, 90, 135, 180]

# Function definitions

## Data input

### File name sorting

In [17]:
# Sort the file names of frames in a natural ascending order without requiring 
# leading zeros 
# C is a list of filenames of individual frames to be sorted
# Frame number is the last run of digits in the filename

def filename_sort(animal_folder, cell_id):
    # Get list of filenames for frames (sorted)
    cell_folder = '{0}\Cell{1}'.format(animal_folder, cell_id)
    filenames = [f for f in os.listdir(cell_folder) if f.endswith('.tif')]
    num_frames = len(filenames)
    filenames = sort_nat_ascend(filenames, num_frames)
    return filenames, num_frames

def sort_nat_ascend(C, num_frames):
    
    C2 = [re.sub('\d', '0', str) for str in C] # Replace runs of digits with zeros
    C3 = [np.array(list(str)) for str in C2] # Convert each string to an array of chars
    digits = [str == '0' for str in C3] # Positions of digits 
    
    # Extract the start and end indices of the last run of digits (frame number)
    end = np.array([np.squeeze(digit.nonzero())[-1]  for digit in digits]) + 1
    start = [end[frame] - np.argmax(digits[frame][(end[frame] - 1):0:-1] == False) for frame in range(num_frames)]
    
    # Sort by numerical values of frame numbers
    frame_numbers = [int(C[frame][start[frame]:end[frame]]) for frame in range(num_frames)]
    return [str for _,str in sorted(zip(frame_numbers,C))] 
    

### ROI drawing and saving metadata

In [19]:
# This function also serves as the point at which all 
# metadata about the cell is saved. 

def get_ROI(animal_folder, cell_id, num_frames, num_preview, 
            filenames, frame_rate, offset, binning):
    
    
    # Make preview file - 100 frames evenly spaced in the session
    fnums = np.round(np.linspace(0, num_frames - 1, 
                                 num_preview)).astype(int)

    cell_folder = '{0}\Cell{1}'.format(animal_folder, cell_id)
    im = Image.open('{0}\{1}'.format(cell_folder, filenames[fnums[0]]))
    w, h = im.size
    
    preview_array = np.zeros([h, w, num_preview])

    for frame in range(num_preview):
        im = Image.open('{0}\{1}'.format(cell_folder, 
                                         filenames[fnums[frame]]))
        preview_array[:, :, frame] = np.array(im)
    preview = np.mean(preview_array, 2)
    
    # Get ROI
    %matplotlib qt
    pl.imshow(preview)
    plt.title('Click to draw polygon ROI and doubleclick when done')
    my_roi = roipoly(roicolor='r') # draw new ROI in red color
    plt.pause(20)
    mask = my_roi.getMask(preview)
    
    cols_mask = np.array(mask.nonzero())[1]
    rows_mask = np.array(mask.nonzero())[0]
    left_edge = np.min(cols_mask) - 10 if np.min(cols_mask) > 10 else 0
    right_edge = np.max(cols_mask) + 10 if np.max(cols_mask) + 10 < w else w
    top_edge = np.min(rows_mask) - 10 if np.min(rows_mask) > 10 else 0
    bottom_edge = np.max(rows_mask) + 10 if np.max(rows_mask) + 10 < h else h
    
    area = np.sum(mask)
    area_bg = w*h - area
    
    # Time vector
    time_vec = np.arange(num_frames)/frame_rate
    
    # Save data
    with open('{0}\Traces\Cell_{1}_metadata.pkl'.format(animal_folder, 
                                        cell_id), 'wb') as f:
        pickle.dump({'mask':mask, 'h':h, 'w':w, 'num_frames':num_frames, 
                     'filenames':filenames, 'area':area, 'area_bg':area_bg,
                     'frame_rate':frame_rate, 'time_vec':time_vec,
                    'offset': offset, 'binning':binning, 'top_edge': top_edge, 'bottom_edge': bottom_edge,
                   'left_edge': left_edge, 'right_edge': right_edge, 'cell_folder': cell_folder}, f)
    
    # Ask if there are more cells 
    next_cell = 1
    while (next_cell):
        next_cell = int(input('Do you want to segment another cell? Enter 1 if yes, 0 if no'))
        if next_cell:
            
            %matplotlib qt
            pl.imshow(preview)
            plt.title('Click to draw polygon ROI and doubleclick when done')
            my_roi = roipoly(roicolor='r') # draw new ROI in red color
            plt.pause(20)
            mask = my_roi.getMask(preview)

            area = np.sum(mask)
            area_bg = w*h - area
            
            prev_cell_id = cell_id
            ind = cells.index(prev_cell_id) + 1
            cell_id = int(input('Enter cell id (previous cell was {0})'.format(cell_id)))
            cells.insert(ind, cell_id)
            
            with open('{0}\Traces\Cell_{1}_metadata.pkl'.format(animal_folder, 
                                        cell_id), 'wb') as f:
                pickle.dump({'mask':mask, 'h':h, 'w':w, 'num_frames':num_frames, 
                     'filenames':filenames, 'area':area, 'area_bg':area_bg,
                     'frame_rate':frame_rate, 'time_vec':time_vec,
                    'offset': offset, 'binning':binning, 'cell_folder': cell_folder}, f)
               
    

In [21]:
# This code is from https://github.com/jdoepfert, 
#which can be used without permission

class roipoly:
    def __init__(self, fig=[], ax=[], roicolor='b'):
        if fig == []:
            fig = plt.gcf()

        if ax == []:
            ax = plt.gca()

        self.previous_point = []
        self.allxpoints = []
        self.allypoints = []
        self.start_point = []
        self.end_point = []
        self.line = None
        self.roicolor = roicolor
        self.fig = fig
        self.ax = ax
        #self.fig.canvas.draw()

        self.__ID1 = self.fig.canvas.mpl_connect(
            'motion_notify_event', self.__motion_notify_callback)
        self.__ID2 = self.fig.canvas.mpl_connect(
            'button_press_event', self.__button_press_callback)

        if sys.flags.interactive:
            plt.show(block=False)
        else:
            plt.show()

    def getMask(self, currentImage):
        ny, nx = np.shape(currentImage)
        #print(self.allxpoints)
        poly_verts = [(self.allxpoints[0], self.allypoints[0])]
        for i in range(len(self.allxpoints)-1, -1, -1):
            poly_verts.append((self.allxpoints[i], self.allypoints[i]))

        # Create vertex coordinates for each grid cell...
        # (<0,0> is at the top left of the grid in this system)
        x, y = np.meshgrid(np.arange(nx), np.arange(ny))
        x, y = x.flatten(), y.flatten()
        points = np.vstack((x,y)).T

        ROIpath = mplPath.Path(poly_verts)
        grid = ROIpath.contains_points(points).reshape((ny,nx))
        return grid

    def displayROI(self,**linekwargs):
        l = plt.Line2D(self.allxpoints +
                     [self.allxpoints[0]],
                     self.allypoints +
                     [self.allypoints[0]],
                     color=self.roicolor, **linekwargs)
        ax = plt.gca()
        ax.add_line(l)
        plt.draw()

    def displayMean(self,currentImage, **textkwargs):
        mask = self.getMask(currentImage)
        meanval = np.mean(np.extract(mask, currentImage))
        stdval = np.std(np.extract(mask, currentImage))
        string = "%.3f +- %.3f" % (meanval, stdval)
        plt.text(self.allxpoints[0], self.allypoints[0],
                 string, color=self.roicolor,
                 bbox=dict(facecolor='w', alpha=0.6), **textkwargs)

    def __motion_notify_callback(self, event):
        if event.inaxes:
            ax = event.inaxes
            x, y = event.xdata, event.ydata
            # Move line around
            if (event.button == None or event.button == 1) and self.line != None: 
                self.line.set_data([self.previous_point[0], x],
                                   [self.previous_point[1], y])
                self.fig.canvas.draw()


    def __button_press_callback(self, event):
        if event.inaxes:
            x, y = event.xdata, event.ydata
            ax = event.inaxes
            # If you press the left button, single click
            if event.button == 1 and event.dblclick == False:  
                if self.line == None: # if there is no line, create a line
                    self.line = plt.Line2D([x, x],
                                           [y, y],
                                           marker='o',
                                           color=self.roicolor)
                    self.start_point = [x,y]
                    self.previous_point =  self.start_point
                    self.allxpoints=[x]
                    self.allypoints=[y]

                    ax.add_line(self.line)
                    self.fig.canvas.draw()
                    # add a segment
                else: # if there is a line, create a segment
                    self.line = plt.Line2D([self.previous_point[0], x],
                                           [self.previous_point[1], y],
                                           marker = 'o',
                                           color=self.roicolor)
                    self.previous_point = [x,y]
                    self.allxpoints.append(x)
                    self.allypoints.append(y)

                    event.inaxes.add_line(self.line)
                    self.fig.canvas.draw()
            elif ((event.button == 1 and event.dblclick==True) or (event.button == 3 and event.dblclick==False)) and self.line != None: # close the loop and disconnect
                    self.fig.canvas.mpl_disconnect(self.__ID1) #joerg
                    self.fig.canvas.mpl_disconnect(self.__ID2) #joerg

                    self.line.set_data([self.previous_point[0],
                                    self.start_point[0]],
                                   [self.previous_point[1],
                                    self.start_point[1]])
                    ax.add_line(self.line)
                    self.fig.canvas.draw()
                    self.line = None

                    if sys.flags.interactive:
                        pass
                    else:
                        #figure has to be closed so that code can continue
                        plt.close(self.fig) 

### Loading data

In [23]:
def trace_from_movie(animal_folder, cell_id):
    
    # Load metadata
    with open('{0}\Traces\Cell_{1}_metadata.pkl'.format(animal_folder, 
                                                cell_id), 'rb') as f:
        md = pickle.load(f) 
    
    h = md['h']; w = md['w']; 
    num_frames = md['num_frames']; filenames = md['filenames']; 
    mask = md['mask']; area = md['area']; area_bg = md['area_bg']; 
    offset = md['offset']; binning = md['binning'];
    time_vec = md['time_vec']; frame_rate = md['frame_rate']; cell_folder = md['cell_folder']
    
    # Load movie tifs
    data_array = np.zeros([h, w, num_frames])
    for frame in range(num_frames):
        if(np.mod(frame, 1000) == 0):
                    print(frame)
        im = Image.open('{0}\{1}'.format(cell_folder, filenames[frame]))
        data_array[:, :, frame] = np.array(im)
    
    # Use ROI mask to get data and background traces
    print('Calculating data trace')
    data = np.sum(data_array[mask, :], 0)
    #print('Calculating background trace')
    #background = np.sum(data_array[~mask, :], 0)
    background = np.zeros(data.shape)
    
    del data_array # Free up memory

    # Go from pixel intensity to electrons
    data = np.subtract(data, area*offset*binning)*0.48
    background = np.subtract(background, area_bg*offset*binning)*0.48

    # Save data and background traces
    with open('{0}\Traces\Cell_{1}.pkl'.format(animal_folder, 
                                               cell_id), 'wb') as f:
        pickle.dump({'data':data, 'background':background}, f)

    
    

### Remove LED_off

In [25]:
def remove_LED_off(animal_folder, cell_id):
    
    # Load metadata
    with open('{0}\Traces\Cell_{1}_metadata.pkl'.format(animal_folder, 
                                                cell_id), 'rb') as f:
        md = pickle.load(f) 
    
    h = md['h']; w = md['w']; 
    num_frames = md['num_frames']; filenames = md['filenames']; 
    mask = md['mask']; area = md['area']; area_bg = md['area_bg']; 
    offset = md['offset']; binning = md['binning'];
    time_vec = md['time_vec']; frame_rate = md['frame_rate']; cell_folder = md['cell_folder']
    
    # Load data
    with open('{0}\Traces\Cell_{1}.pkl'.format(animal_folder, 
                                               cell_id), 'rb') as f:
        data_dict = pickle.load(f) 
        
    data = data_dict['data']
    background = data_dict['background']
    
    # Plot data 
    %matplotlib qt
    plt.plot(range(1, 10*frame_rate + 1), data[:10*frame_rate], 
             label = 'data')
    plt.title('Cell {0}: raw data from user drawn ROI'.format(cell_id))                                       
    plt.xlabel('Frames')
    plt.ylabel('Electrons summed over mask')
    plt.grid()
    plt.pause(2)
    
    LED_off = int(input('Cell {0}: how many frames to delete from beginning?'.format(cell_id)))
    
    # Remove LED off frames
    data_dict['data'] = data[LED_off:]
    data_dict['background'] = background[LED_off:]
    
    # Save modified traces
    with open('{0}\Traces\Cell_{1}_LED_on.pkl'.format(animal_folder, 
                                                cell_id), 'wb') as f:
        pickle.dump(data_dict, f)
    
    # Save changed metadata
    md['num_frames'] = len(filenames) - LED_off
    md['time_vec'] = np.arange(md['num_frames'])/frame_rate
    md['LED_off'] = LED_off

    with open('{0}\Traces\Cell_{1}_metadata.pkl'.format(animal_folder, 
                                                cell_id), 'wb') as f:
        pickle.dump(md, f)
       
    

## Spike detection

### Manual spike annotation

In [27]:
class manual_spikes:
    
    def __init__(self, animal_folder = '', cell_id = 2):  

        print('Hit spacebar and then click in the vicinity of spikes. Press d to delete last spike and q to exit.')
        
        # Plot the detrended trace
        marker_type = None; x_vals = 'frames'
        F0, window, self.data = detrend(animal_folder, cell_id, num_window, marker_type, x_vals)
        plt.pause(10) 
        print('Ready')

        # Find peaks (local minima in windows of three points) in the data
        dif = np.diff(self.data[:- 1], 1) # Discrete difference, ignoring last frame
        # Boolean array, true for points lower than the preceeding point
        diff_pos = dif < 0 
        # Reverse discrete difference, ignoring the first frame
        dif_rev = np.flip(np.diff(np.flip(self.data, 0)[:-1], 1), 0) 
        # Boolean array, true for points lower than the following point
        diff_rev_pos = dif_rev < 0 
        peaks = np.logical_and(diff_pos, diff_rev_pos)

        # Take care of the length of peaks being less than the length of data:
        peaks = np.insert(peaks, 0, False) 
        peaks = np.append(peaks, False)

        self.peak_pos = peaks.nonzero()[0]
        self.data_peaks = self.data[peaks]

        self.spike_frames_manual = np.zeros(0)

        self.next_click_is_spike = False
        self.spikes_are_done = False
        self.click = False
        self.approx_frame = 0
        self.approx_height = 0
        self.delete = False
        
        fig = plt.gcf()

        self.click_id = fig.canvas.mpl_connect('button_press_event', self.onclick)
        self.key_id = fig.canvas.mpl_connect('key_press_event', self.keypress)
        
        if sys.flags.interactive:
            plt.show(block=False)
        else:
            plt.show()
        

    def keypress(self, event):
        if event.key == ' ':
            self.next_click_is_spike = True
        elif event.key == 'd':
            np.delete(self.spike_frames_manual, -1)
            self.spike_pt.remove()
            self.near_pts.remove()
            print('Last spike deleted')
        elif event.key == 'q':
            self.spikes_are_done = True

    def onclick(self, event):
        ax = plt.gca()
        if self.next_click_is_spike:
            self.approx_frame = event.xdata
            self.approx_height = event.ydata
            
            self.near_peaks = self.peak_pos[np.argpartition(np.abs(self.peak_pos - self.approx_frame), 20)[:20]]
            #self.near_pts = ax.scatter(self.near_peaks, -self.data[self.near_peaks], color = 'k')
            self.spike_frame = self.near_peaks[np.argmin(np.power(self.data[self.near_peaks] + self.approx_height, 2) + 
                                               np.power(self.near_peaks - self.approx_frame, 2))]           
            #print('Spike of height {0} at frame {1}'.format(-self.data[self.spike_frame], self.spike_frame))
            self.spike_pt = ax.scatter(self.spike_frame, -self.data[self.spike_frame], color = 'r')
            self.spike_frames_manual = np.append(self.spike_frames_manual, self.spike_frame)
            self.next_click_is_spike = False      
     
    def get_spikes(self):
        #num_spikes_manual = len[self.spike_frames_manual]
        return self.spike_frames_manual
    
        

### De-trending

In [29]:
def detrend(animal_folder, cell_id, num_window, marker_type, x_vals): 
    # num_window should be about 20
    # x_vals should be 'time' or 'frames' depending on what the detrended data should be plotted against
    
    # Load metadata
    with open('{0}\Traces\Cell_{1}_metadata.pkl'.format(animal_folder, cell_id), 'rb') as f:
        md = pickle.load(f) 
    
    h = md['h']; w = md['w']; num_frames = md['num_frames']; filenames = md['filenames']; mask = md['mask'];
    area = md['area']; area_bg = md['area_bg']; offset = md['offset']; binning = md['binning'];
    time_vec = md['time_vec']; frame_rate = md['frame_rate']; LED_off = md['LED_off']
    
    # Load data
    with open('{0}\Traces\Cell_{1}_LED_on.pkl'.format(animal_folder, cell_id), 'rb') as f:
        data_dict = pickle.load(f) # Data
        
    data = data_dict['data']
    background = data_dict['background']
    
    # Fit a piecewise linear curve to the data trace (F0)
    window = int(num_frames/num_window) # Should be made dependent on frame rate
    if window%2 == 0:
        window += 1
    poly = 1 # Degree of polynomial to fit
    F0 = savgol_filter(data, window, poly)    

    # Plot F0
    % matplotlib qt
    plt.figure()
    plt.plot(time_vec, data, label = 'Raw data', color = 'orange', linewidth = 0.8)
    plt.plot(time_vec, F0, label = 'F0', color = 'b', linewidth = 2)
    plt.legend(loc = 'best')
    plt.title('Cell {1}: smoothed data with window size {0} frames'.format(window, cell_id))
    plt.xlabel('Time (s)')
    plt.ylabel('Electrons summed over mask')
    plt.savefig('{0}\Traces_2\Cell{1}_F0'.format(animal_folder, cell_id))
    
    # Calculate dF
    #data2 = np.divide(data - F0, F0)*100
    data2 = data - F0
    # Plot dF/F
    % matplotlib qt
    plt.figure()
    if x_vals == 'time':
        plt.plot(time_vec, -data2, marker = marker_type, linewidth = 0.7)
        plt.xlabel('Time (s)')
    elif x_vals == 'frames':
        plt.plot(-data2, marker = marker_type, linewidth = 0.7)
        plt.xlabel('Frames')
    #plt.yticks(-np.array(plt.yticks())[0])
    plt.title('Cell {1} detrended data (F0 subtracted, not divided) with window size {0} frames'.format(window, cell_id))
    plt.ylabel('dF (Electrons summed over mask)')
    #plt.savefig('{0}\Traces_2\Cell{1}_dFF'.format(animal_folder, cell_id))
    
    return F0, window, data2


### High-pass filter

In [31]:
def high_pass(animal_folder, cell_id, data2, freq_discard, x_vals): # freq_discard should be 5
    
    # Load metadata
    with open('{0}\Traces\Cell_{1}_metadata.pkl'.format(animal_folder, cell_id), 'rb') as f:
        md = pickle.load(f) 
    
    h = md['h']; w = md['w']; num_frames = md['num_frames']; filenames = md['filenames']; mask = md['mask'];
    area = md['area']; area_bg = md['area_bg']; offset = md['offset']; binning = md['binning'];
    time_vec = md['time_vec']; frame_rate = md['frame_rate']; LED_off = md['LED_off']
    
    data2_fft = np.fft.fft(data2)
    num_freq_discard = np.where(np.fft.fftfreq(num_frames, 1/frame_rate) == freq_discard)
    data2_fft[0:num_freq_discard] = 0 # Positive frequencies zeroed
    data2_fft[num_frames - num_freq_discard:] = 0 # Negative frequencies zeroed
    data3 = np.real(np.fft.ifft(data2_fft))

    # Plot high pass filtered trace
    % matplotlib qt
    if x_vals == 'time':
        plt.plot(time_vec, -data3, linewidth = 0.5)
        plt.xlabel('Time (s)')
    elif x_vals == 'frames':
        plt.plot(-data3, linewidth = 0.5)
        plt.xlabel('Frames')

    plt.yticks(-np.array(plt.yticks())[0])
    plt.title('Cell {1} high pass filtered data (lowest {0}% of freqs discarded)'.format(freq_discard, cell_id))
    plt.ylabel('dF (Electrons summed over mask)')
    plt.savefig('{0}\Traces_2\Cell{1}_high_pass.png'.format(animal_folder, cell_id))
    
    return data3, freqs

### Detect spikes

#### Select threshold

In [32]:
# num_test: number of thresholds to test, about 1000 is good
# fpr = desired false positive rate in percentage

def select_threshold(animal_folder, cell_id, data3, num_test, fpr, min_spike_count):

    data4 = - data3 # Just to make things easier to read

    mean = np.mean(data4)
    thresh_max = np.max(data4)
    thresh_min = 0.2*thresh_max # This will be around zero
    thresh_vals = np.arange(thresh_min, thresh_max, (thresh_max - thresh_min)/num_test)

    # Find peaks (local minima in windows of three points) in the data
    dif = np.diff(data4[:- 1], 1) # Discrete difference, ignoring last frame
    # Boolean array, true for points greater than the preceeding point
    diff_pos = dif > 0 

    # Reverse discrete difference, ignoring the first frame
    dif_rev = np.flip(np.diff(np.flip(data4, 0)[:-1], 1), 0) 
    # Boolean array, true for points greater than the following point
    diff_rev_pos = dif_rev > 0 

    peaks = np.logical_and(diff_pos, diff_rev_pos)
    troughs = np.logical_and(~diff_pos, ~diff_rev_pos)

    # Take care of the length of peaks being less than the length of data:
    peaks = np.insert(peaks, 0, False) 
    peaks = np.append(peaks, False)

    troughs = np.insert(troughs, 0, False) 
    troughs = np.append(troughs, False)

    spike_count = [np.sum(data4[peaks] > thresh) for thresh in thresh_vals]
    neg_spike_count = [np.sum(data4[troughs] < 2*mean-thresh) for thresh in thresh_vals]

    fpos = np.divide(neg_spike_count, spike_count)
    if np.min(fpos) > fpr:
        print('Minimum false positive rate is {0}%'.format(int(100*np.min(fpos))))
        thresh_opt = thresh_vals[(fpos == np.min(fpos)).nonzero()[0][0]]                      
    else:
        # Threshold at which false positive rate falls below criterion
        thresh_opt = thresh_vals[(fpos < fpr).nonzero()[0][0]] 
        
#     # If there are too few spikes for this threshold, raise threshold
#     if spike_count[(thresh_vals == thresh_opt).nonzero()[0][0]] < min_spike_count:
#         if np.max(spike_count) > min_spike_count:
#             # Highest threshold which gives at least 30 spikes
#             thresh_opt = thresh_vals[(np.array(spike_count)>min_spike_count).nonzero()[0][-1]]
#             fpr = fpos[(thresh_vals == thresh_opt).nonzero()[0][0]]
#             print('False positive rate changed to {0}% to have at least {1} spikes'.format(int(100*fpr), min_spikes))
#         else:
#             print('There are less than 30 spikes for all thresholds greater than {0}% dF/F'.format(thresh_min))
        
        
    %matplotlib qt
    plt.figure()
    #plt.plot(thresh_vals/thresh_max, spike_count)
    plt.plot(thresh_vals, spike_count, label = 'Positive spikes')
    plt.plot(thresh_vals, neg_spike_count, label = 'Negative spikes')
    plt.plot(np.ones(1000)*thresh_opt, np.arange(0, 4000, 4), 
             label = 'Selected threshold with {0}% false positive rate'.format(int(fpr*100)))
    plt.legend(loc = 'best', fontsize = 10)
    plt.title('Cell {0}: spikes detected vs detection threshold'.format(cell_id), fontsize = 10)
    plt.xlabel('Detection threshold (electrons)', fontsize = 15)
    plt.ylabel('Number of spikes', fontsize = 15)
    plt.savefig('{0}\Traces_2\Cell{1}_spikes_vs_thresh.png'.format(animal_folder, cell_id))
    
    return peaks, data4, thresh_opt, fpr

#### Find spike times

In [33]:
def plot_spikes(animal_folder, cell_id, peaks, data4, F0, thresh_opt):
    
    # Load metadata
    with open('{0}\Traces\Cell_{1}_metadata.pkl'.format(animal_folder, 
                                                cell_id), 'rb') as f:
        md = pickle.load(f) 
    
    h = md['h']; w = md['w']; 
    num_frames = md['num_frames']; filenames = md['filenames']; 
    mask = md['mask']; area = md['area']; area_bg = md['area_bg']; 
    offset = md['offset']; binning = md['binning'];
    time_vec = md['time_vec']; frame_rate = md['frame_rate']; 
    LED_off = md['LED_off']
    
    spikes_binary = np.logical_and(peaks, data4 > thresh_opt)
    spike_frames = np.squeeze(spikes_binary.nonzero())
    spike_times = time_vec[spikes_binary]
    num_spikes = np.sum(spikes_binary)

    # Make plot
    data_plot = np.divide(data4, F0)*100
    %matplotlib qt
    plt.figure()
    plt.plot(time_vec, data_plot, linewidth = 1, label = 'High pass filtered data',
             #marker = 'o'
            )
    plt.scatter(spike_times, (np.ones(spike_times.shape)*np.max(data_plot)), color = 'k')

    plt.yticks(-np.array(plt.yticks())[0])
    plt.title('Cell %d spikes' %cell_id, fontsize = 20)
    plt.xlabel('Time (s)', fontsize = 17)
    plt.ylabel('dF/F (%)', fontsize = 17)
    plt.grid()
    plt.legend(fontsize = 17)
    plt.savefig('{0}\Traces_2\Cell{1}_spikes'.format(animal_folder, cell_id))
    
    return num_spikes, spike_frames

### Spike triggered average

In [34]:
def find_sta(animal_folder, cell_id, sta_time, data2, F0, spike_frames, num_spikes, fpr):

    # Load metadata
    with open('{0}\Traces\Cell_{1}_metadata.pkl'.format(animal_folder, 
                                                        cell_id), 'rb') as f:
        md = pickle.load(f) 
    
    h = md['h']; w = md['w']; 
    num_frames = md['num_frames']; filenames = md['filenames']; 
    mask = md['mask']; area = md['area']; area_bg = md['area_bg']; 
    offset = md['offset']; binning = md['binning'];
    time_vec = md['time_vec']; frame_rate = md['frame_rate']; 
    LED_off = md['LED_off']
    
    sta_frames = int(sta_time*frame_rate/1000)
    sta_time_vec = np.linspace(-sta_time, sta_time, sta_frames*2)
    data2 = np.divide(data2, F0)*100
    sta_vals = [data2[(frame - sta_frames):(frame + sta_frames)] 
                for frame in spike_frames if frame > sta_frames and 
                frame < num_frames - sta_frames]                              
    sta = -np.mean(sta_vals, 0) # Negated

    sta_error = np.std(sta_vals, 0)


    %matplotlib qt
    plt.figure()
    plt.plot(sta_time_vec, sta, linewidth = 1, color = 'k')
    plt.fill_between(sta_time_vec, sta + sta_error, sta - sta_error, 
                     where=sta - sta_error <= sta + sta_error, 
                     facecolor='blue', alpha = 0.2, interpolate=True, 
                     label = 'Standard deviation')
    plt.title('Cell {0} spike triggered average from {1} spikes ({2}% false positive rate)'.format(cell_id, 
                                                               num_spikes, int(fpr*100), fontsize = 20))
    plt.xlabel('Time from spike (ms)', fontsize = 17)
    plt.ylabel('dF/F (%)', fontsize = 17)
    plt.yticks(-np.array(plt.yticks())[0].astype(float))
    plt.grid()
    plt.legend(loc = 'best')
    plt.grid()
    
    #plt.savefig('{0}\Traces_2\Cell{1}_sta_fpr_{2}'.format(animal_folder, cell_id, int(fpr*100)))
    
    return sta, sta_error

### Temporal filter

In [35]:
def temp_filter(animal_folder, cell_id):
    
    # Load metadata
    with open('{0}\Traces\Cell_{1}_metadata.pkl'.format(animal_folder, cell_id), 'rb') as f:
        md = pickle.load(f) 
    
    h = md['h']; w = md['w']; num_frames = md['num_frames']; filenames = md['filenames']; mask = md['mask'];
    area = md['area']; area_bg = md['area_bg']; offset = md['offset']; binning = md['binning'];
    time_vec = md['time_vec']; frame_rate = md['frame_rate']; LED_off = md['LED_off']
    
    # Load spiking data
    with open('{0}\Traces\Cell_{1}_spikes.pkl'.format(animal_folder, cell_id), 'rb') as f:
        sd = pickle.load(f)
    F0 = sd['F0']; spike_frames = sd['spike_frames']; sta = -sd['sta']
    num_spikes = sd['num_spikes']; data_detrend = sd['data_detrend']
    
    # Matched filter - convolve spike template with raw data
    
    # Matched filter parameters
    Rv = np.mean(np.multiply(data_detrend, data_detrend)) # Expected value of cross-correlation of noise
    a = 1/np.sqrt(np.multiply(sta, sta)/Rv) # Normalization coefficient
    h = a*sta/Rv # Filter kernel
    
    data_filt = np.convolve(data_detrend, h, mode = 'same')
    
    return - data_filt

### Spatial filter

In [None]:
def sp_filter(animal_folder, cell_id):
    # Load metadata
    with open('{0}\Traces\Cell_{1}_metadata.pkl'.format(animal_folder, cell_id), 'rb') as f:
        md = pickle.load(f) 
    
    h = md['h']; w = md['w']; num_frames = md['num_frames']; filenames = md['filenames']; mask = md['mask'];
    area = md['area']; area_bg = md['area_bg']; offset = md['offset']; binning = md['binning'];
    time_vec = md['time_vec']; frame_rate = md['frame_rate']; LED_off = md['LED_off']
    
    # Load spiking data
    with open('{0}\Traces\Cell_{1}_spikes.pkl'.format(animal_folder, cell_id), 'rb') as f:
        sd = pickle.load(f)
    F0 = sd['F0']; spike_frames = sd['spike_frames']; sta = sd['sta']
    num_spikes = sd['num_spikes']; data_detrend = -sd['data_detrend']

## Orientation selectivity

### Get visual stimulus and camera timings

In [36]:
# 'cell' should be between 0 and num_cells - 1
# 'marker_style' should be 'o' or None
# 'offset' is a period with inncorrect TTLs to be ignored in seconds
# 'ax' is the plotting axes handle

def onclick(event):
    print('%s click: button=%d, x=%d, y=%d, xdata=%f, ydata=%f' %
          ('double' if event.dblclick else 'single', event.button,
           event.x, event.y, event.xdata, event.ydata))

    cid = fig.canvas.mpl_connect('button_press_event', onclick)

def get_daq(animal_ID, daq_folder, cell_id):
    
    filename = '{0}\ANM{1}_cell{2}_DAQ'.format(daq_folder, 
                                               animal_ID, cell_id)
    mat = spio.loadmat(filename, squeeze_me=True)
    
    vis_stim = np.array(mat['data'][:, 0])
    camera_output = np.array(mat['data'][:, 1])
    daq_rate = mat['rate']
    daq_time_vec = mat['timeStamps']
    
    return vis_stim, camera_output, daq_rate, daq_time_vec
    
def plot_camera(animal_folder, cell_id, start_time, end_time, marker_style,
               daq_rate, daq_time_vec, camera_output):
    
    start_point = int(start_time*daq_rate)
    end_point = int(end_time*daq_rate)
    
    %matplotlib qt
    plt.figure()
    plt.plot(
            daq_time_vec[start_point:end_point],
            camera_output[start_point:end_point],
            marker = marker_style
            )
    plt.title('Cell {0} camera TTLs'.format(cell_id))
    plt.pause(5)
    offset = int(input('After what time to look for the first frame?'))
    plt.savefig('{0}\Vis_stim_info\Cell{1}_camera.png'.format(animal_folder, cell_id))
    
    return offset

    
def find_frame_times(cell_id, offset, camera_output, daq_rate):

    camera_output = camera_output[int(offset*daq_rate):]
    
    high_pts = np.array((camera_output > 0.2).nonzero())
    camera_output[high_pts] = 1
    
    low_pts = np.array((camera_output < 0.2).nonzero())
    camera_output[low_pts] = 0
    
    dif = np.diff(camera_output, 1)
    # Camera frame on times, in terms of daq data points
    frame_times = np.array((dif == 1).nonzero()) + offset*daq_rate
    
    
    %matplotlib qt
    plt.figure()
    
    plt.plot(daq_time_vec[int(offset*daq_rate):], camera_output)
    plt.scatter(daq_time_vec[frame_times], np.ones(frame_times.shape)*np.max(camera_output)*1.1, 
                color = 'k', marker = 'o')
    plt.title('Cell {0}  frames'.format(cell_id))
    
    return frame_times

def plot_vis(animal_folder, cell_id, start_time, end_time, marker_style, 
             daq_rate, vis_stim):

    start_point = int(start_time*daq_rate)
    end_point = int(end_time*daq_rate)       
    
    %matplotlib qt
    plt.figure()
    plt.plot(range(start_point, end_point), 
             vis_stim[start_point:end_point],
            marker = marker_style
            )
    plt.title('Cell {0} visual stimulus'.format(cell_id))
    plt.pause(30)
    
    plt.savefig('{0}\Vis_stim_info\Cell{1}_vis_stim.png'.format(animal_folder, cell_id))
    
    
def vis_start(cell_id, daq_rate, num_vis_stim, vis_on, vis_off, 
              frame_times, num_ori):
    
    # Ask user for visual stimulus start point
    vis_stim_start_point = int(input('Enter visual stimulus start frame'))
     
    # Find the time and orientation of first visual stimulus during imaging
    vis_stim_start_time = vis_stim_start_point/daq_rate

    # Jitter in the DAQ acquisition times is less than 5e-14 seconds, 
    #hence we can do this:
    trial_start_times = np.arange(vis_stim_start_point, 
                             num_vis_stim*((vis_on + vis_off)*daq_rate), 
                             (vis_on + vis_off)*daq_rate)

    pre_imaging_stim = (trial_start_times < frame_times[0][0]).nonzero()
    first_stim_num = pre_imaging_stim[0][-1]
    first_stim_time = trial_start_times[first_stim_num]
    # Ori goes from 0 to 7
    first_ori = np.mod(first_stim_num, num_ori - 1)
    
    return vis_stim_start_point, vis_stim_start_time, trial_start_times, first_stim_num, first_stim_time, first_ori
    
def plot_vis_start_times(cell_id, daq_time_vec, vis_stim, stim_on_time):
        
    plt.plot(daq_time_vec, vis_stim)

    stim_on = np.zeros(vis_stim.shape) + 0.15
    stim_on[stim_on_time.astype(int)] = 0.25

    %matplotlib qt
    plt.figure()
    plt.plot(daq_time_vec, stim_on)
    plt.show()
    plt.title('Cell {0} vis stim trial start times'.format(cell_id))
    plt.pause(2)
    

### Bin frames by orientation

In [37]:
# Find the camera frames corresponding to trials and ITIs for each orientation
def find_ori_frames(animal_folder, cell_id):
    
    # Load vis stim data
    with open('{0}\Vis_stim_info\Cell_{1}_vis_stim.pkl'.format(animal_folder, cell_id), 'rb') as f:
        v = pickle.load(f)
    
    daq_rate= v['daq_rate']; trial_start_times = v['trial_start_times']; frame_times = v['frame_times']
    first_stim_time = v['first_stim_time']; first_ori = v['first_ori']; first_stim_num = v['first_stim_num']
    
#     vis_stim = v['vis_stim'], camera_output = v['camera_output'], 
#     daq_time_vec = v['daq_time_vec'], offset = v['offset'], 
#     vis_stim_start_point = v['vis_stim_start_point'], vis_stim_start_time = v['vis_stim_start_time']    
        
    # Load metadata
    with open('{0}\Traces\Cell_{1}_metadata.pkl'.format(animal_folder, cell_id), 'rb') as f:
        md = pickle.load(f) 
    
    h = md['h']; w = md['w']; num_frames = md['num_frames']; filenames = md['filenames']; mask = md['mask'];
    area = md['area']; area_bg = md['area_bg']; offset = md['offset']; binning = md['binning'];
    time_vec = md['time_vec']; frame_rate = md['frame_rate']; LED_off = md['LED_off']
    
    # There are some extra TTLs as compared to actually recorded frames. 
    # These amount to about 40ms of imaging time.
    # We will assume, for now, that the first 40ms of TTLs are spurious. 
    l = len(frame_times[0])
    frame_times_trunc = frame_times[0][l - (num_frames + LED_off):]
    
    # Leave out the first trial because camera would have started in the middle
    first_trial_time = first_stim_time + (vis_on + vis_off)*daq_rate
    first_trial_num = first_stim_num + 1
    first_ori = np.mod(first_ori + 1, num_ori)
    
    first_trial_ori = np.roll(range(num_ori), first_ori)
    trial_times_ori = [trial_start_times[first_trial_ori[ori]::num_ori] for ori in range(num_ori)]
    frames_ori = np.zeros([num_ori, num_frames + LED_off])
    frames_iti_ori = np.zeros([num_ori, num_frames + LED_off])
    for ori in range(num_ori):    
        for time in trial_times_ori[ori]:
            frames_ori[ori][(frame_times_trunc > time).nonzero()[0]] = 1
            frames_ori[ori][(frame_times_trunc > time + vis_on*daq_rate).nonzero()[0]] = 0
            frames_iti_ori[ori][(frame_times_trunc > time - vis_off*daq_rate).nonzero()[0]] = 1
            frames_iti_ori[ori][(frame_times_trunc > time).nonzero()[0]] = 0
            
    # Find the number of times all orientations are repeated
    num_rep = int((trial_start_times.shape - first_trial_num)/num_ori)
    
    # Find the first camera frame for each repetition
    rep_dur = num_ori*(vis_on + vis_off)*daq_rate
    rep_times = np.arange(first_trial_time, first_trial_time + (num_rep + 2)*rep_dur, rep_dur) 
    rep_frames = np.zeros(num_rep + 1)
    for rep in range(num_rep + 1):
        try:
            rep_frames[rep] = (frame_times_trunc > rep_times[rep]).nonzero()[0][0]
        except IndexError as e:
            print(e)
            rep_frames[rep] = 0
        
    return [frames_ori, frames_iti_ori, first_trial_ori, first_trial_time, 
            first_trial_num, first_ori, num_rep, rep_frames]

### Bin spikes by orientation 

In [38]:
def find_ori_spikes(animal_folder, cell_id, frames_ori, iti_frames_ori, num_rep, rep_frames, manual):
    
    # Load metadata
    with open('{0}\Traces\Cell_{1}_metadata.pkl'.format(animal_folder, cell_id), 'rb') as f:
        md = pickle.load(f) 
    
    h = md['h']; w = md['w']; num_frames = md['num_frames']; filenames = md['filenames']; mask = md['mask'];
    area = md['area']; area_bg = md['area_bg']; offset = md['offset']; binning = md['binning'];
    time_vec = md['time_vec']; frame_rate = md['frame_rate']; LED_off = md['LED_off']
    
    # Load spiking data
    with open('{0}\Traces_2\Cell_{1}_spikes.pkl'.format(animal_folder, cell_id), 'rb') as f:
        sd = pickle.load(f)
    if manual:
        spike_frames = np.unique(sd['spike_frames_manual'].astype(int)) + LED_off
        num_spikes = sd['num_spikes_manual']
    else:
        spike_frames = sd['spike_frames'] + LED_off # Spike frames as if starting from first camera frame
        num_spikes = sd['num_spikes']
    
    
    spikes_ori = np.zeros([num_ori, num_rep])
    spikes_iti_ori = np.zeros([num_ori, num_rep])
    for ori in range(num_ori): 
        for rep in range(num_rep):
            spikes_rep = spike_frames[spike_frames > rep_frames[rep]]
            spikes_rep = spikes_rep[spikes_rep < rep_frames[rep + 1]]
            print(spikes_rep)
            spikes_ori[ori][rep] = len(frames_ori[ori][spikes_rep].nonzero()[0])
            spikes_iti_ori[ori][rep] = len(frames_iti_ori[ori][spikes_rep].nonzero()[0])
    
    return spikes_ori, spikes_iti_ori, num_spikes
            
    

### Plot orientation selectivity 

#### All cells

In [39]:
def plot_os_all_cells(animal_folder, spikes_ori, spikes_iti_ori, num_rep, normalized):
    # spikes_ori, spikes_iti_ori should be dictionaries with keys as cell_id 
    # num_rep should be a num_cellsX1 array
    
    %matplotlib qt
    plt.figure(figsize = [8, 6])
    
    labels = []
    
    for cell_id in cells:
        if cell_id == 15:
            continue
        try:
            mean_spikes_ori = np.mean(spikes_ori['{0}'.format(cell_id)], 1) # Average over repetitions
            sd_spikes_ori = np.std(spikes_ori['%d' %cell_id], 1) # Error bars
        except Exception as e:
            print(e)
            mean_spikes_ori = np.zeros(num_ori)
            sd_spikes_ori = np.zeros(num_ori)
        ori_max = np.argmax(mean_spikes_ori)
        ori_order = np.roll(range(num_ori), max_ori_pos - ori_max)
        if normalized:
            mean_spikes_ori = mean_spikes_ori/np.max(mean_spikes_ori)
            sd_spikes_ori = sd_spikes_ori/np.max(mean_spikes_ori)
            yl = 'Proportion of spikes in session' 
            
        else:
            yl = 'Number of spikes per trial'
        plt.errorbar(ori_degrees, mean_spikes_ori[ori_order], sd_spikes_ori[ori_order], marker = 'o')
        labels.append('Cell %d' %(cell_id))
    
    plt.title('Orientation selectivity from spikes, {0}-{1} repetitions'.format(np.min(num_rep), np.max(num_rep)))
    plt.xlim((-140, 350))
    plt.ylabel(yl)
    plt.xlabel('Degrees away from preferred orientation')
    plt.legend(labels, fontsize = 15)

In [40]:
def plot_os_all_cells_subth(animal_folder, data, frames_ori, frames_iti_ori, num_rep, normalized):
    # frames_ori, frames_iti_ori should be dictionaries with keys as cell_id 
    # num_rep should be a num_cellsX1 array
    # data should be detrended (dF/F), dictionary with keys as cell_id 
    
    %matplotlib qt
    plt.figure(figsize = [8, 6])
    
    labels = []
    
    for cell_id in cells:
        try:
            mean_spikes_ori = np.mean(spikes_ori['{0}'.format(cell_id)], 1) # Average over repetitions
            sd_spikes_ori = np.std(spikes_ori['%d' %cell_id], 1) # Error bars
        except Exception as e:
            print(e)
            mean_spikes_ori = np.zeros(num_ori)
            sd_spikes_ori = np.zeros(num_ori)
        ori_max = np.argmax(mean_spikes_ori)
        ori_order = np.roll(range(num_ori), max_ori_pos - ori_max)
        if normalized:
            mean_spikes_ori = mean_spikes_ori/np.max(mean_spikes_ori)
            sd_spikes_ori = sd_spikes_ori/np.max(mean_spikes_ori)
            yl = 'Proportion of spikes in session' 
            
        else:
            yl = 'Number of spikes per trial'
        plt.errorbar(ori_degrees, mean_spikes_ori[ori_order], sd_spikes_ori[ori_order], marker = 'o')
        labels.append('Cell %d' %(cell_id))
    
    plt.title('Orientation selectivity from spikes, {0}-{1} repetitions'.format(np.min(num_rep), np.max(num_rep)))
    plt.xlim((-140, 350))
    plt.ylabel(yl)
    plt.xlabel('Degrees away from preferred orientation')
    plt.legend(labels, fontsize = 15)

#### Single cell

In [41]:
def plot_os_single_cell(axes, animal_folder, cell_id, spikes_ori, spikes_iti_ori, num_rep, normalized, text, summary):
    
    try:
        mean_spikes_ori = np.mean(spikes_ori, 1) # Average over repetitions
        mean_spikes_iti = np.mean(spikes_iti_ori, 1) # Average over repetitions
        sd_spikes_ori = np.std(spikes_ori, 1) # Error bars
    except Exception as e:
        print(e)
        mean_spikes_ori = np.zeros(num_ori)
        sd_spikes_ori = np.zeros(num_ori)
        mean_spikes_iti = np.zeros(num_ori)
    if cell_id == 18:
        ori_max = 3
    else:
        ori_max = np.argmax(mean_spikes_ori)
    print('Preferred direction is {0}'.format(ori_max))
    ori_order = np.roll(range(num_ori), max_ori_pos - ori_max)
    if normalized:
        spikes_ori = np.divide(spikes_ori, np.max(spikes_ori, 0))
        mean_spikes_ori = mean_spikes_ori/np.max(mean_spikes_ori)
        mean_spikes_iti = mean_spikes_iti/np.max(mean_spikes_ori)
        sd_spikes_ori = sd_spikes_ori/np.max(mean_spikes_ori)
        yl = 'Proportion of spikes' 

    else:
        yl = 'Number of spikes'
    axes.plot(ori_degrees, mean_spikes_ori[ori_order], color = 'k', linewidth = 2, label = 'Vis stim')
    axes.plot(ori_degrees, mean_spikes_iti[ori_order], color = 'k', linestyle = '--', linewidth = 1.5, label = 'ITI')
    axes.plot(ori_degrees, spikes_ori[ori_order], color = 'grey', alpha = 0.4, linewidth = 1.5,
             label = 'Individual repetitions')
    
    if summary:
        if text:
            axes.set_title('Spikes')
    else:
        axes.set_title('Cell {0} orientation selectivity from spikes, {1} repetitions'.format(cell_id, num_rep))
        axes.set_ylabel(yl)
        axes.set_xlim((np.min(ori_degrees) - 10, np.max(ori_degrees) + 100))
        axes.set_xlabel('Degrees away from preferred orientation')
        axes.legend(loc = 'best')
        
    return ori_max

In [42]:
def plot_os_single_cell_subth(axes, animal_folder, cell_id, data, frames_ori, frames_iti_ori, rep_frames, ori_max,
                              num_rep, normalized, text, summary):
    # Data should not be negated (i.e. should be raw data detrended)
    
    frames_ori = frames_ori.astype(bool)
    frames_iti_ori = frames_iti_ori.astype(bool)
    
    subth_ori = np.zeros([num_ori, num_rep])
    subth_iti_ori = np.zeros([num_ori, num_rep])
    
    for rep in range(num_rep - 1):
        data_rep = - data[rep_frames[rep]:rep_frames[rep + 1] - 1]       
        for ori in range(num_ori):
            frames_rep = frames_ori[ori][rep_frames[rep]:rep_frames[rep + 1] - 1]
            frames_iti_rep = frames_iti_ori[ori][rep_frames[rep]:rep_frames[rep + 1] - 1]
            data_ori = data_rep[frames_rep]
            data_iti_ori = data_rep[frames_iti_rep]
            subth_ori[ori][rep] = np.mean(data_ori)
            subth_iti_ori[ori][rep] = np.mean(data_iti_ori)
            
    mean_subth_ori = np.mean(subth_ori, 1)
    mean_subth_iti = np.mean(subth_iti_ori, 1)
    baseline =  np.mean(mean_subth_iti)
    
    ori_order = np.roll(range(num_ori), max_ori_pos - ori_max)
    if normalized:
        subth_ori = np.divide(subth_ori, np.max(subth_ori, 0))
        mean_subth_ori = mean_subth_ori/np.max(mean_subth_ori)
        mean_subth_iti = mean_subth_iti/np.max(mean_subth_ori)
        yl = 'Mean dF/F in trial, normalized to preferred orientation'

    else:
        yl = 'Mean dF/F in trial'
    axes.plot(ori_degrees, mean_subth_ori[ori_order], color = 'k', linewidth = 2, label = 'Vis stim')
    axes.plot(ori_degrees, mean_subth_iti[ori_order], color = 'k', linestyle = '--', linewidth = 1.5, label = 'ITI')
    axes.plot(ori_degrees, subth_ori[ori_order], color = 'grey', alpha = 0.4, linewidth = 1.5,
             label = 'Individual repetitions')
    #labels.append('Cell %d' %(cell_id))
    
    # Make sure y axis has correct tick labels
    plt.yticks(-np.array(plt.yticks())[0].astype(float))

    if summary:
        if text:
            axes.set_title('Subthreshold')
    else:
        axes.set_title('Cell {0} orientation selectivity from subthreshold potential, {1} repetitions'.format(cell_id, num_rep))
        axes.set_ylabel(yl)
        axes.set_xlim((np.min(ori_degrees) - 10, np.max(ori_degrees) + 100))
        axes.set_xlabel('Degrees away from preferred orientation by spikes')
        axes.legend(loc = 'best')
    
    return baseline

#### Subthreshold - with baseline

In [46]:
def subth_os(animal_folder, cell_id, axes, num_rep, rep_frames, data, frames_ori, frames_iti_ori, 
             errorbars = False, ori_max = [], summary = False, title = False, normalized = False,):
   
    # axes: axis handles on which to plot figure
    # num_rep: number of repetitions for that animal
    # rep_frames: num_repX1 list of first frame in each repetition
    # data: detrended trace of cell
    # frames_ori: num_oriXnum_frames array with 1 if frame had visual stimulus of that orientation
    # frames_iti_ori: num_oriXnum_frames array with 1 if frame was in ITI just preceeding 
    #                  visual stimulus of that orientation
    # errorbars: default false, determines if error bars are plotted or spread of values over individual repetitions
    # ori_max: Location of preferred orientation in plot is ori_max
    # summary: default False, whether this figure is part of a larger summary figure (changes things that are displayed)
    # title: default False, whether to print figure title if part of summary figure
    # normalized: default False, whether absolute values or normalized to values for preferred orientation
    
    
    # Load metadata
    with open('{0}\Traces\Cell_{1}_metadata.pkl'.format(animal_folder, 
                                            cell_id), 'rb') as f:
        md = pickle.load(f) 

    h = md['h']; w = md['w']; 
    num_frames = md['num_frames']; filenames = md['filenames']; 
    mask = md['mask']; area = md['area']; area_bg = md['area_bg']; 
    offset = md['offset']; binning = md['binning'];
    time_vec = md['time_vec']; frame_rate = md['frame_rate']; #cell_folder = md['cell_folder']

    cell = cells.index(cell_id)
    rep_frames = rep_frames.astype(int)
    frames_ori = frames_ori.astype(bool)
    frames_iti_ori = frames_iti_ori.astype(bool)
    
    num_frames_ori = frame_rate*vis_on
    num_frames_iti = frame_rate*vis_off
    
    # Create arrays to hold activity binned by ori/rep
    subth_ori = np.zeros([num_ori, num_rep])
    subth_iti_ori = np.zeros([num_ori, num_rep])
    
    data_vis_on = np.zeros([num_ori, num_rep, num_frames_ori + 10])
    data_vis_off = np.zeros([num_ori, num_rep, num_frames_iti + 10])

    for rep in range(num_rep - 1):
        data_rep = - data[rep_frames[rep]:rep_frames[rep + 1] - 1]       
        for ori in range(num_ori):
            
            frames_rep = frames_ori[ori][rep_frames[rep]:rep_frames[rep + 1] - 1]
            frames_iti_rep = frames_iti_ori[ori][rep_frames[rep]:rep_frames[rep + 1] - 1]

            data_ori = data_rep[frames_rep]
            data_iti_ori = data_rep[frames_iti_rep]

            data_vis_on[ori][rep][:data_ori.shape[0]] = data_ori
            data_vis_off[ori][rep][:data_iti_ori.shape[0]] = data_iti_ori

    # Average over repetitions   
    data_vis_on_mean = np.mean(data_vis_on, 1)
    data_vis_off_mean = np.mean(data_vis_off, 1)
    
    # Average over ITI for baseline
    baseline = np.mean(data_vis_off_mean, 1)
    #baseline = np.mean(np.mean(data_vis_off_mean, 1))*np.ones(num_ori)
    print('baseline = {0}'.format(baseline))
    
    # Find the variance over repetitions
    std_vis_on = np.std(data_vis_on, 1)
    std_vis_off = np.std(data_vis_off, 1)
    
    # Get the mean and variance peak of the trace during visual stimulation
    subth_plot = np.min(data_vis_on_mean, 1) - baseline
    subth_min_pts = np.argmin(data_vis_on_mean, 1)
    subth_std = np.std([[data_vis_on[ori][rep][subth_min_pts[ori]] for rep in range(num_rep)] for ori in range(num_ori)], 
                  1)
    subth_plot_reps = [[data_vis_on[ori][rep][subth_min_pts[ori]] for rep in range(num_rep)] for ori in range(num_ori)]
   
    # Check if preferred orientation is specified by user, else calculate it
    if ori_max == []:
        ori_max = np.argmax(subth_plot)
    
    ori_order = np.roll(range(num_ori), max_ori_pos - ori_max)
    
    if normalized:
        subth_plot = np.divide(subth_plot, np.max(subth_plot, 0))
        subth_plot_reps = np.divide(subth_plot_reps, np.max(subth_plot_reps, 0))
        yl = 'Mean dF/F in trial, normalized to preferred orientation'

    else:
        yl = 'Peak dF/F in trial - baseline'
        
    if errorbars:
        axes.errorbar(ori_degrees, subth_plot, subth_std)
    else:
        axes.plot(ori_degrees, subth_plot_reps, color = 'grey', linewidth = 1.5, alpha = 0.3,
                  label = 'Individual repetitions')
        axes.plot(ori_degrees, subth_plot, color = 'k', linewidth = 2)
        
    if summary:
        if title:
            axes.set_title('Subthreshold')
    else:
        axes.set_title('Cell {0} orientation selectivity from subthreshold potential, {1} repetitions'.format(cell_id, 
                                                                                                              num_rep))
        axes.set_ylabel(yl)
        axes.set_xlim((np.min(ori_degrees) - 10, np.max(ori_degrees) + 100))
        axes.set_xticks(ori_degrees)
        axes.set_xlabel('Degrees away from preferred orientation by spikes')
        axes.legend(loc = 'best')


    

# Load data

## Get ROIs from user

In [103]:
for cell_id in [16]:
    print('Cell {0}'.format(cell_id))
    # If metadata already exists, don't do it again
    filename = '{0}\Traces\Cell_{1}_metadata.pkl'.format(animal_folder, cell_id)
    file = Path(filename)
    if file.is_file():
        continue
    else:
        # Sort filenames
        print('Sorting filenames')
        filenames, num_frames = filename_sort(animal_folder, cell_id)
        
        # Get ROI and save all metadata
        mask, h, w, area, area_bg = get_ROI(animal_folder, cell_id, num_frames, num_preview, 
                                            filenames, frame_rates[cells.index(cell_id)], offset, binning)


Cell 16
Sorting filenames
Do you want to segment another cell? Enter 1 if yes, 0 if no1
Enter cell id (previous cell was 16)17
Do you want to segment another cell? Enter 1 if yes, 0 if no0


TypeError: 'NoneType' object is not iterable

In [104]:
# Check masks
for cell_id in [16, 17]:
    with open('{0}\Traces\Cell_{1}_metadata.pkl'.format(animal_folder, cell_id), 'rb') as f:
        md = pickle.load(f) # Metadata
    mask = md['mask']
    plt.figure()
    plt.title('Cell {0}'.format(cell_id))
    pl.imshow(mask)


## Get ROI traces from movies

In [63]:
for cell_id in cells:
    print('Cell {0}'.format(cell_id))
    # If data trace already exists, don't do it again
    filename = '{0}\Traces\Cell_{1}.pkl'.format(animal_folder, cell_id)
    file = Path(filename)
    if file.is_file():
        continue
    else:
        trace_from_movie(animal_folder, cell_id)



Cell 14
0
1000
2000
3000
4000
5000
6000
7000
8000
9000
10000
11000
12000
13000
14000
15000
16000
17000
18000
19000
20000
21000
22000
23000
24000
25000
26000
27000
28000
29000
30000
31000
32000
33000
34000
35000
36000
37000
38000
39000
40000
41000
42000
43000
44000
45000
46000
47000
48000
49000
50000
51000
52000
53000
54000
55000
56000
57000
58000
59000
60000
61000
62000
63000
64000
65000
66000
67000
68000
69000
70000
71000
72000
73000
74000
75000
76000
77000
78000
79000
80000
81000
82000
83000
84000
85000
86000
87000
88000
89000
90000
91000
92000
93000
94000
95000
96000
97000
98000
99000
100000
101000
102000
103000
104000
105000
106000
107000
108000
109000
110000
111000
112000
113000
114000
115000
116000
117000
118000
119000
120000
121000
122000
123000
124000
125000
126000
127000
128000
129000
130000
131000
132000
133000
134000
135000
136000
137000
138000
139000
140000
141000
142000
143000
144000
Calculating data trace
Cell 15
0
1000
2000
3000
4000
5000
6000
7000
8000
9000
10000
11000


## Remove LED off time

In [115]:
for cell_id in [16, 17]:
    # If data trace with LED off time removed already exists, don't do it again
    filename = '{0}\Traces\Cell_{1}_LED_on.pkl'.format(animal_folder, cell_id)
    file = Path(filename)
    if file.is_file():
        continue
    else:
        print('Cell {0}'.format(cell_id))
        remove_LED_off(animal_folder, cell_id)


Cell 16
Cell 16: how many frames to delete from beginning?2100
Cell 17
Cell 17: how many frames to delete from beginning?2100


# Spike detection

## Automated

In [63]:
for cell_id in [4]:
    
    print('Cell {0}'.format(cell_id))
    marker = None
    x_vals = 'time'
    F0, window, data2 = detrend(animal_folder, cell_id, num_window, marker, x_vals)
    data3 = high_pass(animal_folder, cell_id, data2, freq_discard, 'time')
    peaks, data4, thresh_opt, cell_fpr = select_threshold(animal_folder, 
                                                cell_id, data3, num_thresh_test, 
                                                fpr, min_spikes)
    num_spikes, spike_frames = plot_spikes(animal_folder, cell_id, peaks, data4, F0, thresh_opt)
    print('{0} spikes'.format(num_spikes))
    sta, sta_error = find_sta(animal_folder, cell_id, sta_time, data2, F0, spike_frames, num_spikes, cell_fpr)
    
    # Save data 
    with open('{0}\Traces_2\Cell_{1}_spikes.pkl'.format(animal_folder, cell_id), 'wb') as f:
        pickle.dump({'F0': F0, 'num_window': num_window, 'data_detrend': data2, 
                    'data_high_pass': data3, 'freq_discard': freq_discard, 'peaks': peaks,
                    'data_high_pass_neg': data4, 'thresh_opt': thresh_opt, 
                     'num_thresh_test': num_thresh_test, 'fpr': cell_fpr, 'num_spikes': num_spikes,
                    'spike_frames': spike_frames, 'sta': sta, 'sta_error': sta_error}, f)

Cell 4
532 spikes


## Manual

In [119]:
cell_id = 17
ms = manual_spikes(animal_folder = animal_folder, cell_id = cell_id)

Hit spacebar and then click in the vicinity of spikes. Press d to delete last spike and q to exit.
Ready


In [120]:
# Save data

# If automated spike detection has already been done:
filename = '{0}\Traces_2\Cell_{1}_spikes.pkl'.format(animal_folder, cell_id)
file = Path(filename)
if file.is_file():
    with open('{0}\Traces_2\Cell_{1}_spikes.pkl'.format(animal_folder, cell_id), 'rb') as f:
        sd = pickle.load(f)
    sd['spike_frames_manual'] = ms.get_spikes()
    sd['num_spikes_manual'] = len(spike_frames_manual)
else:
    spike_frames_manual = ms.get_spikes()
    sd = {'spike_frames_manual': spike_frames_manual, 'num_spikes_manual': len(spike_frames_manual)}

with open('{0}\Traces_2\Cell_{1}_spikes.pkl'.format(animal_folder, cell_id), 'wb') as f:
    pickle.dump(sd, f)

# Spatio-temporal filtering

In [None]:
num_iter = 2
for cell_id in cells:
    
    for iter in range(num_iter):
        temp_filter(animal_folder, cell_id)
        sp_filter(animal_folder, cell_id)

# Visual stimulus alignment

## From photodiode trace

In [46]:
for cell_id in [16, 17]:
    print('Cell %d' %cell_id)
    vis_stim, camera_output, daq_rate, daq_time_vec = get_daq(animal_ID, daq_folder, cell_id)

    offset = plot_camera(animal_folder, cell_id, 0, 200, None,
                   daq_rate, daq_time_vec, camera_output)

    frame_times = find_frame_times(cell_id, offset, camera_output, daq_rate)

    plot_vis(animal_folder, cell_id, 0, 30, 'o', 
                 daq_rate, vis_stim)

    vis_stim_start_point, vis_stim_start_time, trial_start_times, first_stim_num, first_stim_time, first_ori = vis_start(cell_id, 
                                                                                daq_rate, num_vis_stim, vis_on, vis_off, 
                                                                                  frame_times, num_ori)

    plot_vis_start_times(cell_id, daq_time_vec, vis_stim, trial_start_times)
    
    # Save data 
    with open('{0}\Vis_stim_info\Cell_{1}_vis_stim.pkl'.format(animal_folder, cell_id), 'wb') as f:
        pickle.dump({'vis_stim': vis_stim, 'camera_output': camera_output, 'daq_rate': daq_rate, 
                    'daq_time_vec': daq_time_vec, 'offset': offset, 'frame_times': frame_times,
                    'vis_stim_start_point': vis_stim_start_point, 'vis_stim_start_time': vis_stim_start_time, 
                    'trial_start_times': trial_start_times, 'first_stim_num': first_stim_num, 

                     'first_stim_time': first_stim_time, 'first_ori': first_ori}, f)

Cell 16
After what time to look for the first frame?25
Enter visual stimulus start frame100828
Cell 17
After what time to look for the first frame?25
Enter visual stimulus start frame100828


# Orientation selectivity

## All cells

In [None]:
spikes_ori = {'{0}'.format(cell_id):[] for cell_id in cells}
spikes_iti_ori = {'{0}'.format(cell_id):[] for cell_id in cells}
num_rep = np.zeros(num_cells)
num_spikes = np.zeros(num_cells)

proportion = True
manual = False

for cell_id in cells:
    if cell_id == 15:
        continue
    cell = cells.index(cell_id)
    
    # Bin camera frames by orientation
    [frames_ori, frames_iti_ori, first_trial_ori, first_trial_time, 
     first_trial_num, first_ori, num_rep[cell], rep_frames] = find_ori_frames(animal_folder, cell_id)
    
    # Bin spikes by orientation
    spikes_ori['%d' %cell_id], spikes_iti_ori['%d' %cell_id], num_spikes[cell] = find_ori_spikes(animal_folder, 
                                                                                                 cell_id, frames_ori, 
                                       frames_iti_ori, int(num_rep[cell]), rep_frames, manual)
    print('Cell {0}: {1} spikes'.format(cell_id, num_spikes[cell]))

# Plot orientation selectivity
plot_os_all_cells(animal_folder, spikes_ori, spikes_iti_ori, num_rep, proportion)
plt.savefig('{0}\Spikes_OS.png'.format(animal_folder))
    

## Single cell

In [94]:
# Subthreshold potential - single cell
fig = plt.figure()
axes = plt.gca()
for cell_id in [16]:
    cell = cells.index(cell_id)
    text = True
    summary = False
    # Bin camera frames by orientation
    [frames_ori, frames_iti_ori, first_trial_ori, first_trial_time, 
     first_trial_num, first_ori, num_rep, rep_frames] = find_ori_frames(animal_folder, cell_id)
    
    data = np.divide(data_16 - F0_16, F0_16)*100

    # Plot orientation selectivity
    plot_os_single_cell_subth(axes, animal_folder, cell_id, data, frames_ori, frames_iti_ori, rep_frames.astype(int), 
                              num_rep, normalized, text, summary)

NameError: name 'data_16' is not defined

In [113]:
# Spikes - single cell
fig = plt.figure()
axes = plt.gca()
for cell_id in [17]:
    text = True
    summary = False
    normalized = False
    manual = True
    cell = cells.index(cell_id)
    
    # Bin camera frames by orientation
    [frames_ori, frames_iti_ori, first_trial_ori, first_trial_time, 
     first_trial_num, first_ori, num_rep, rep_frames] = find_ori_frames(animal_folder, cell_id)
    
    # Bin spikes by orientation
    spikes_ori, spikes_iti_ori, num_spikes = find_ori_spikes(animal_folder, cell_id, frames_ori, 
                                       frames_iti_ori, int(num_rep), rep_frames, manual)
    print('Cell {0}: {1} spikes'.format(cell_id, num_spikes))

    # Plot orientation selectivity
    plot_os_single_cell(axes, animal_folder, cell_id, spikes_ori, spikes_iti_ori, num_rep, normalized, text, summary)

[2323 2822 3946 4346]
[]
[14808 15655]
[21937 21951 21954 23871]
[26036 27658 29992 30379]
[]
[38837]
[48244 48248 48252 48254]
[51986 52398 54854]
[62429]
[]
[]
[79130]
[]
[95840]
[]
[108732 109052]
[109905 111113 111357 111363 111536 113797 114559]
[117246 118533 118730 118747 121799 121830 121989]
[]
[131227 131850 131885]
[139652]
[2323 2822 3946 4346]
[]
[14808 15655]
[21937 21951 21954 23871]
[26036 27658 29992 30379]
[]
[38837]
[48244 48248 48252 48254]
[51986 52398 54854]
[62429]
[]
[]
[79130]
[]
[95840]
[]
[108732 109052]
[109905 111113 111357 111363 111536 113797 114559]
[117246 118533 118730 118747 121799 121830 121989]
[]
[131227 131850 131885]
[139652]
[2323 2822 3946 4346]
[]
[14808 15655]
[21937 21951 21954 23871]
[26036 27658 29992 30379]
[]
[38837]
[48244 48248 48252 48254]
[51986 52398 54854]
[62429]
[]
[]
[79130]
[]
[95840]
[]
[108732 109052]
[109905 111113 111357 111363 111536 113797 114559]
[117246 118533 118730 118747 121799 121830 121989]
[]
[131227 131850 131885

# Summary figures

## Automatically detected spikes

In [146]:
first_cell = int(num_cells/2)
last_cell = num_cells

%matplotlib qt

num_rows = last_cell - first_cell
num_cols = 5

fig, axes = plt.subplots(num_rows, num_cols)

for cell in range(first_cell, last_cell):
    cell_id = cells[cell]
    print(cell_id)
    
     # Load metadata
    with open('{0}\Traces\Cell_{1}_metadata.pkl'.format(animal_folder, cell_id), 'rb') as f:
        md = pickle.load(f) 
    
    h = md['h']; w = md['w']; num_frames = md['num_frames']; filenames = md['filenames']; mask = md['mask'];
    area = md['area']; area_bg = md['area_bg']; offset = md['offset']; binning = md['binning'];
    time_vec = md['time_vec']; frame_rate = md['frame_rate']; LED_off = md['LED_off']
    
    # Load spiking data
    with open('{0}\Traces_2\Cell_{1}_spikes.pkl'.format(animal_folder, cell_id), 'rb') as f:
        sd = pickle.load(f)
    F0 = sd['F0']; spike_frames = sd['spike_frames'] + LED_off # Spike frames as if starting from first camera frame
    num_spikes = sd['num_spikes']; sta = sd['sta']
    sta_error = sd['sta_error']; data = sd['data_detrend']
    
    # Plot cell image
    fnums = np.round(np.linspace(0, num_frames - 1, num_preview)).astype(int)
    cell_folder = '{0}\Cell{1}'.format(animal_folder, cell_id)
    im = Image.open('{0}\{1}'.format(cell_folder, filenames[fnums[0]]))
    w, h = im.size
    cols_mask = np.array(mask.nonzero())[1]
    left_edge = np.min(cols_mask) - 10 if np.min(cols_mask) > 10 else 0
    right_edge = np.max(cols_mask) + 10 if np.max(cols_mask) + 10 < w else w
    preview_array = np.zeros([h, w, num_preview])

    for frame in range(num_preview):
        im = Image.open('{0}\{1}'.format(cell_folder, 
                                         filenames[fnums[frame]]))
        preview_array[:, :, frame] = np.array(im)
    preview = np.mean(preview_array[:, left_edge:right_edge, :], 2)
    axes[cell - first_cell, 0].imshow(-preview, cmap = 'Greys')
    axes[cell - first_cell, 0].set_ylabel('Cell {0}'.format(cell_id))
    
    # Plot ROI mask
    mask_show = mask[:, left_edge:right_edge]
    axes[cell - first_cell, 1].imshow(mask_show, cmap = 'Greys')
    
    # Plot STA
    sta_frames = int(sta_time*frame_rate/1000)
    sta_time_vec = np.linspace(-sta_time, sta_time, sta_frames*2)
    
    axes[cell - first_cell, 2].plot(sta_time_vec, sta, linewidth = 1, color = 'k')
    axes[cell - first_cell, 2].fill_between(sta_time_vec, sta + sta_error, sta - sta_error, 
                     where=sta - sta_error <= sta + sta_error, 
                     facecolor='blue', alpha = 0.2, interpolate=True, 
                     label = 'Standard deviation')
    
    # Plot spike OS
    normalized = False
    if cell == first_cell:
        text = True
    else:
        text = False
    summary = True
    manual = False
    
    # Bin camera frames by orientation
    [frames_ori, frames_iti_ori, first_trial_ori, first_trial_time, 
     first_trial_num, first_ori, num_rep, rep_frames] = find_ori_frames(animal_folder, cell_id)
    
    # Bin spikes by orientation
    spikes_ori, spikes_iti_ori, num_spikes = find_ori_spikes(animal_folder, cell_id, frames_ori, 
                                       frames_iti_ori, int(num_rep), rep_frames, manual)
    
    plot_os_single_cell(axes[cell - first_cell, 3], animal_folder, cell_id, spikes_ori, spikes_iti_ori, num_rep, normalized, text,
                       summary)
    
    # Plot subthreshold OS
    normalized = False
    if cell == first_cell:
        text = True
    else:
        text = False
    
    plot_os_single_cell_subth(axes[cell - first_cell, 4], animal_folder, cell_id, -data, frames_ori, frames_iti_ori, 
                              rep_frames.astype(int), 
                              num_rep, normalized, text, summary)
    

10
Preferred direction is 6
11
Preferred direction is 2
12
Preferred direction is 0
13
Preferred direction is 3
14
Preferred direction is 1
15
Preferred direction is 0
16
Preferred direction is 4
17
Preferred direction is 2
18
Preferred direction is 0


## Manually annotated spikes

In [186]:
cells = [12, 15, 16]

%matplotlib qt

num_rows = len(cells)
num_cols = 5

fig, axes = plt.subplots(num_rows, num_cols)

for cell_id in cells:
    print(cell_id)
    cell = cells.index(cell_id)
    first_cell = 0
    last_cell = len(cells) - 1
    
     # Load metadata
    with open('{0}\Traces\Cell_{1}_metadata.pkl'.format(animal_folder, cell_id), 'rb') as f:
        md = pickle.load(f) 
    
    h = md['h']; w = md['w']; num_frames = md['num_frames']; filenames = md['filenames']; mask = md['mask'];
    area = md['area']; area_bg = md['area_bg']; offset = md['offset']; binning = md['binning'];
    time_vec = md['time_vec']; frame_rate = md['frame_rate']; LED_off = md['LED_off']
    
    # Load spiking data
    with open('{0}\Traces_2\Cell_{1}_spikes.pkl'.format(animal_folder, cell_id), 'rb') as f:
        sd = pickle.load(f)
    F0 = sd['F0']; 
    # Spike frames as if starting from first camera frame, with repeats removed
    spike_frames = np.unique(sd['spike_frames_manual']).astype(int)  
    num_spikes = len(spike_frames); 
    data = sd['data_detrend']
    
    if cell_id == 18:
        data = data_18
        F0 = F0_18
    
    # Plot cell image
    fnums = np.round(np.linspace(0, num_frames - 1, num_preview)).astype(int)
    cell_folder = '{0}\Cell{1}'.format(animal_folder, cell_id)
    im = Image.open('{0}\{1}'.format(cell_folder, filenames[fnums[0]]))
    w, h = im.size
    cols_mask = np.array(mask.nonzero())[1]
    left_edge = np.min(cols_mask) - 10 if np.min(cols_mask) > 10 else 0
    right_edge = np.max(cols_mask) + 10 if np.max(cols_mask) + 10 < w else w
    preview_array = np.zeros([h, w, num_preview])

    for frame in range(num_preview):
        im = Image.open('{0}\{1}'.format(cell_folder, 
                                         filenames[fnums[frame]]))
        preview_array[:, :, frame] = np.array(im)
    preview = np.mean(preview_array[:, left_edge:right_edge, :], 2)
    axes[cell - first_cell, 0].imshow(-preview, cmap = 'Greys')
    axes[cell - first_cell, 0].set_ylabel('Cell {0}'.format(cell_id))
    
    # Plot ROI mask
    mask_show = mask[:, left_edge:right_edge]
    axes[cell - first_cell, 1].imshow(mask_show, cmap = 'Greys')
    
    # Plot STA
    sta_frames = int(sta_time*frame_rate/1000)
    sta_time_vec = np.linspace(-sta_time, sta_time, sta_frames*2)
    
    data = np.divide(data, F0)*100
    sta_vals = [data[(frame - sta_frames):(frame + sta_frames)] 
                for frame in spike_frames if frame > sta_frames and 
                frame < num_frames - sta_frames]                              
    sta = -np.mean(sta_vals, 0) # Negated

    sta_error = np.std(sta_vals, 0)
    
    axes[cell - first_cell, 2].plot(sta_time_vec, sta, linewidth = 1, color = 'k')
    axes[cell - first_cell, 2].fill_between(sta_time_vec, sta + sta_error, sta - sta_error, 
                     where=sta - sta_error <= sta + sta_error, 
                     facecolor='blue', alpha = 0.2, interpolate=True, 
                     label = 'Standard deviation')
    #axes[cell - first_cell, 2].set_yticklabels(-np.array(axes[cell - first_cell, 2].get_yticks()).astype(float))
    axes[cell - first_cell, 2].set_title('{0} spikes'.format(num_spikes))
    axes[cell - first_cell, 2].set_ylabel(' - dF/F (%)')
    if cell == last_cell:
        axes[cell - first_cell, 2].set_xlabel('Time from spike (ms)')
    
    
    # Plot spike OS
    normalized = False
    if cell == first_cell:
        text = True
    else:
        text = False
    summary = True
    manual = True
    
    # Bin camera frames by orientation
    [frames_ori, frames_iti_ori, first_trial_ori, first_trial_time, 
     first_trial_num, first_ori, num_rep, rep_frames] = find_ori_frames(animal_folder, cell_id)
    
    # Bin spikes by orientation
    spikes_ori, spikes_iti_ori, num_spikes = find_ori_spikes(animal_folder, cell_id, frames_ori, 
                                       frames_iti_ori, int(num_rep), rep_frames, manual)
    
    ori_max = plot_os_single_cell(axes[cell - first_cell, 3], animal_folder, cell_id, spikes_ori, spikes_iti_ori, num_rep, 
                        normalized, text, summary)
    if cell == last_cell:
        axes[cell - first_cell, 3].set_xlabel('Degrees away from preferred orientation')
    
    # Plot subthreshold OS
    normalized = False
    if cell == first_cell:
        text = True
    else:
        text = False
    
    plot_os_single_cell_subth(axes[cell - first_cell, 4], animal_folder, cell_id, -data, frames_ori, frames_iti_ori, 
                              rep_frames.astype(int), ori_max,
                              num_rep, normalized, text, summary)
    if cell == last_cell:
        axes[cell - first_cell, 4].set_xlabel('Degrees away from preferred orientation (by spikes)')
    

12
[ 5150  5150  5175  5203  5258  5923  5770  5749  5712  5697  5670  6775
  6822  6858  6893  6910  6924  6953  6987  7042  7172  7299  7541  7973
  7998  8041  8244  8271  8559  8639  8795  8862  8906  8917  8925  8970
  8985  9310  9456  9491  9538  9538  9538  9538  9914  9914 10432 10534
 10597 10597]
[10892 10910 10970 10970 11091 11138 11235 11315 11325 11814 11861 11861
 11849 11920 11958 12307 12374 12374 12374 12423 12495 12535 13141 13215
 13187 13205 13248 13270 13288 13317 13417 13502 13905 14107 14193 14566
 14585 14672 14753 14773 14804 14966 15072 15066 15088 15123 15168 15168
 15364 15663 15857 15904 15931 15976 15976 15998 15811 16233 16647 16720
 16758 16781 16795 16840 16840 17106 17128 17143 17167 17188 17216 17251
 17300 17692 17723 17643 17587 17587 17587 18492 18478 18466 18452 18426
 18378 18864 18864 18840 18944 18807 19369 19405 19451 19508 19478 19527
 19540 19556 19713 19733 19759 20235 20272 20405 20566 20542 20576 20590
 20616 20666 20635 21272 21326 213

[3013 3051 3078 3210 3286 3376 3484 3602 3679 3709 3768 3789 3839 3826 3848
 3877 3891 3877 3936 3936 3973 3987 4270 4510 4583 4571 4583 4555 4583 4615
 4639 4681 4784 4749 4766 4789 4825 4857 4880 5053 5174 5196 5420 5444 5455
 5478 5483 5505 5505 5533 5601 5566 5579 5593 5631 5638 5650 5667 5689 5725
 5759 6004 6311 6303 6325 6384 6384 6368 6384 6412 6384 6424 6514 6685 6651
 6673 6705 6888 7260 7230 7306 7279 7310 7379 7453 7453 7453 7482 7488 7550
 7559 7591 7599 7874]
[ 8135  8145  8193  8216  8221  8241  8296  8397  8461  8474  8487  8487
  8487  8536  8827  9027  9043  9064  9099  9114  9180  9201  9270  9370
  9426  9446  9711  9931  9951  9946  9978  9983 10010 10096 10127 10155
 10168 10185 10206 10214 10243 10259 10556 10577 10608 10796 10907 10913
 10984 11012 11048 11139 11162 11187 11187 11244 11416 11483 11495 11536
 11520 11562 11779 11753 11817 11833 11858 11910 11966 12072 12113 12127
 12697 12750 12750 12750 12811 12862 12876 12925 12960 12967 13232 13298
 13293 1345

[]
[]
[15118 15125 15884 15917]
[22672 23208 23227 23234 23274]
[]
[]
[48513 48525 48739 49014 49022]
[51372 51379 55785 55940]
[]
[66835 66858 66865 66883 70486]
[74167 74175]
[]
[]
[]
[]
[110454 110465 114099]
[]
[130278]
[132296]
[143041 143048]
[]
[]
[]
[]
[15118 15125 15884 15917]
[22672 23208 23227 23234 23274]
[]
[]
[48513 48525 48739 49014 49022]
[51372 51379 55785 55940]
[]
[66835 66858 66865 66883 70486]
[74167 74175]
[]
[]
[]
[]
[110454 110465 114099]
[]
[130278]
[132296]
[143041 143048]
[]
[]
[]
[]
[15118 15125 15884 15917]
[22672 23208 23227 23234 23274]
[]
[]
[48513 48525 48739 49014 49022]
[51372 51379 55785 55940]
[]
[66835 66858 66865 66883 70486]
[74167 74175]
[]
[]
[]
[]
[110454 110465 114099]
[]
[130278]
[132296]
[143041 143048]
[]
[]
[]
[]
[15118 15125 15884 15917]
[22672 23208 23227 23234 23274]
[]
[]
[48513 48525 48739 49014 49022]
[51372 51379 55785 55940]
[]
[66835 66858 66865 66883 70486]
[74167 74175]
[]
[]
[]
[]
[110454 110465 114099]
[]
[130278]
[132296]
[1

# Playground

In [64]:
cell_id = 16

## Load data 

### Metadata

In [65]:
with open('{0}\Traces\Cell_{1}_metadata.pkl'.format(animal_folder, 
                                            cell_id), 'rb') as f:
    md = pickle.load(f) 

h = md['h']; w = md['w']; 
num_frames = md['num_frames']; filenames = md['filenames']; 
mask = md['mask']; area = md['area']; area_bg = md['area_bg']; 
offset = md['offset']; binning = md['binning'];
time_vec = md['time_vec']; frame_rate = md['frame_rate']; #cell_folder = md['cell_folder']
LED_off = md['LED_off']

### Trace from movie (multiple cells in FOV)

In [95]:
cell_id_2 = 17

# Load movie tifs
data_array = np.zeros([h, w, num_frames])
for frame in range(num_frames):
    if(np.mod(frame, 1000) == 0):
                print(frame)
    im = Image.open('{0}\{1}'.format(cell_folder, filenames[frame]))
    data_array[:, :, frame] = np.array(im)

# Use ROI mask to get data and background traces
print('Calculating data trace')
data = np.sum(data_array[mask, :], 0)
#print('Calculating background trace')
#background = np.sum(data_array[~mask, :], 0)
background = np.zeros(data.shape)

# Go from pixel intensity to electrons
data = np.subtract(data, area*offset*binning)*0.48
background = np.subtract(background, area_bg*offset*binning)*0.48

# Save data and background traces
with open('{0}\Traces\Cell_{1}.pkl'.format(animal_folder, 
                                           cell_id), 'wb') as f:
    pickle.dump({'data':data, 'background':background}, f)
    
cell_id = cell_id_2

# Load metadata
with open('{0}\Traces\Cell_{1}_metadata.pkl'.format(animal_folder, 
                                            cell_id), 'rb') as f:
    md = pickle.load(f) 

h = md['h']; w = md['w']; 
num_frames = md['num_frames']; filenames = md['filenames']; 
mask = md['mask']; area = md['area']; area_bg = md['area_bg']; 
offset = md['offset']; binning = md['binning'];
time_vec = md['time_vec']; frame_rate = md['frame_rate']; cell_folder = md['cell_folder']

# Use ROI mask to get data and background traces
print('Calculating data trace')
data = np.sum(data_array[mask, :], 0)
#print('Calculating background trace')
#background = np.sum(data_array[~mask, :], 0)
background = np.zeros(data.shape)

# Go from pixel intensity to electrons
data = np.subtract(data, area*offset*binning)*0.48
background = np.subtract(background, area_bg*offset*binning)*0.48

# Save data and background traces
with open('{0}\Traces\Cell_{1}.pkl'.format(animal_folder, 
                                           cell_id), 'wb') as f:
    pickle.dump({'data':data, 'background':background}, f)
    

del data_array




0
1000
2000
3000
4000
5000
6000
7000
8000
9000
10000
11000


KeyboardInterrupt: 

### Vis stim timing data

In [67]:
# Load vis stim data
with open('{0}\Vis_stim_info\Cell_{1}_vis_stim.pkl'.format(animal_folder, cell_id), 'rb') as f:
    v = pickle.load(f)

daq_rate= v['daq_rate']; trial_start_times = v['trial_start_times']; frame_times = v['frame_times']
first_stim_time = v['first_stim_time']; first_ori = v['first_ori']; first_stim_num = v['first_stim_num']
daq_time_vec = v['daq_time_vec']

### Spiking data

In [57]:
with open('{0}\Traces_2\Cell_{1}_spikes.pkl'.format(animal_folder, cell_id), 'rb') as f:
    sd = pickle.load(f)

if manual:
    spike_frames = sd['spike_frames_manual'].astype(int)
    num_spikes = sd['num_spikes_manual']
else:
    F0 = sd['F0']; spike_frames = sd['spike_frames'] + LED_off # Spike frames as if starting from first camera frame
    num_spikes = sd['num_spikes']; sta = sd['sta']
    sta_error = sd['sta_error']; data = sd['data_detrend']

## Imshow cell preview

In [188]:
%matplotlib qt
plt.figure()
cell_id = 16

# Plot cell image
fnums = np.round(np.linspace(0, num_frames - 1, num_preview)).astype(int)
cell_folder = '{0}\Cell{1}'.format(animal_folder, cell_id)
im = Image.open('{0}\{1}'.format(cell_folder, filenames[fnums[0]]))
w, h = im.size
cols_mask = np.array(mask.nonzero())[1]
left_edge = np.min(cols_mask) - 10 if np.min(cols_mask) > 10 else 0
right_edge = np.max(cols_mask) + 10 if np.max(cols_mask) + 10 < w else w
preview_array = np.zeros([h, w, num_preview])

for frame in range(num_preview):
    im = Image.open('{0}\{1}'.format(cell_folder, 
                                     filenames[fnums[frame]]))
    preview_array[:, :, frame] = np.array(im)
#preview = np.mean(preview_array[:, left_edge:right_edge, :], 2)
preview = np.mean(preview_array[:, :, :], 2)
plt.imshow(-preview, cmap = 'Greys')
plt.ylabel('Cell {0}'.format(cell_id))
plt.xlabel('1 pixel = ~1um')

# # Plot ROI mask
# mask_show = mask[:, left_edge:right_edge]
# plt.imshow(mask_show, cmap = 'Greys')
# plt.xlabel('1 pixel = ~1um')


Text(0.5,0,'1 pixel = ~1um')

## Plot spikes with visual stimulus

In [60]:
manual = True

marker_type = None; x_vals = 'time'
F0, window, data = detrend(animal_folder, cell_id, num_window, marker_type, x_vals)

data_plot = np.divide(data, F0)*100

# Bin camera frames by orientation
[frames_ori, frames_iti_ori, first_trial_ori, first_trial_time, 
 first_trial_num, first_ori, num_rep, rep_frames] = find_ori_frames(animal_folder, cell_id)

%matplotlib qt
plt.figure()
plt.plot(time_vec, -data_plot, linewidth = 1,
         #marker = 'o'
        )
plt.scatter(time_vec[spike_frames], (np.ones(spike_frames.shape)*np.max(-data_plot)*1.05), color = 'k')

scale = np.max(-data_plot)*1.1
for ori in range(num_ori):
    plt.fill_between(time_vec, scale*frames_ori[ori, LED_off:], np.zeros(time_vec.shape), alpha = 0.3)

# plt.yticks(-np.array(plt.yticks())[0])
# plt.title('Cell %d spikes' %cell_id, fontsize = 20)
# plt.xlabel('Time (s)', fontsize = 17)
# plt.ylabel('dF/F (%)', fontsize = 17)


## Plot photodiode trace with inferred visual stimulus timings

In [83]:
manual = True

marker_type = None; x_vals = 'time'
F0, window, data = detrend(animal_folder, cell_id, num_window, marker_type, x_vals)

data_plot = np.divide(data, F0)*100

# Bin camera frames by orientation
[frames_ori, frames_iti_ori, first_trial_ori, first_trial_time, 
 first_trial_num, first_ori, num_rep, rep_frames] = find_ori_frames(animal_folder, cell_id)

filename = '{0}\ANM{1}_cell{2}_DAQ'.format(daq_folder, 
                                               animal_ID, cell_id)
mat = spio.loadmat(filename, squeeze_me=True)

first_frame_time = daq_time_vec[frame_times[0][0]]
vis_stim = np.array(mat['data'][:, 0])
vis_time_vec = daq_time_vec - first_frame_time

window = int(0.05*daq_rate) 
if window%2 == 0:
    window += 1
poly = 1 # Degree of polynomial to fit
F0_vis_stim = savgol_filter(vis_stim, window, poly)  

%matplotlib qt
plt.figure()
plt.plot(vis_time_vec, vis_stim, linewidth = 1, color = 'k')
plt.plot(vis_time_vec, F0_vis_stim, linewidth = 3, color = 'r')
plt.plot(vis_time_vec, np.mean(F0_vis_stim[frame_times[0][0]:frame_times[0][-1]])*np.ones(vis_time_vec.shape), linewidth = 1, color = 'blue')
scale = np.max(vis_stim)*1.1
for ori in range(num_ori):
    plt.fill_between(time_vec, scale*frames_ori[ori, LED_off:], np.zeros(time_vec.shape), alpha = 0.3)

## Plot detrended traces for two cells

In [107]:
cell_id_2 = 17

# Load data for first cell
with open('{0}\Traces\Cell_{1}_LED_on.pkl'.format(animal_folder, cell_id), 'rb') as f:
    data_dict = pickle.load(f) # Data

data = data_dict['data']

window = int(num_frames/num_window) # Should be made dependent on frame rate
if window%2 == 0:
    window += 1
poly = 1 # Degree of polynomial to fit
F0 = savgol_filter(data, window, poly)

# Load data for second cell
with open('{0}\Traces\Cell_{1}_LED_on.pkl'.format(animal_folder, cell_id_2), 'rb') as f:
    data_dict = pickle.load(f) # Data

data_2 = data_dict['data']

window = int(num_frames/num_window) # Should be made dependent on frame rate
if window%2 == 0:
    window += 1
poly = 1 # Degree of polynomial to fit
F0_2 = savgol_filter(data_2, window, poly)

plt.figure()
plt.plot(time_vec, np.divide(data - F0, F0)*100, linewidth = 1)
plt.plot(time_vec, np.divide(data_2 - F0_2, F0_2)*100 + 20, linewidth = 1)
plt.ylabel('dF/F', fontsize = 15)
plt.xlabel('Time (s)', fontsize = 15)

Text(0.5,0,'Time (s)')

## Automated spike detection and STA

In [130]:
data3 = high_pass(animal_folder, cell_id, data2, freq_discard, 'frames')

In [131]:
peaks, data4, thresh_opt, cell_fpr = select_threshold(animal_folder, 
                                                cell_id, data3, num_thresh_test, 
                                                fpr, min_spikes)

In [133]:
num_spikes, spike_frames = plot_spikes(animal_folder, cell_id, peaks, data4, F0, thresh_opt)
scale = 10
for ori in range(num_ori):
    plt.fill_between(time_vec, scale*frames_ori[ori][LED_off:], np.ones(time_vec.shape)*(-5), alpha = 0.3)

In [82]:
sta, sta_error = find_sta(animal_folder, cell_id, sta_time, data2, F0, spike_frames_manual.astype(int), num_spikes, 10)

In [135]:
data_filt = temp_filter(animal_folder, cell_id)
plt.figure()
plt.plot(time_vec, data_filt, linewidth = 0.5)
plt.title('Cell {0} data filtered once with matched filter'.format(cell_id))

Text(0.5,1,'Cell 2 data filtered once with matched filter')

In [137]:
data3 = high_pass(animal_folder, cell_id, data_filt, freq_discard, 'frames')

In [138]:
peaks, data4, thresh_opt, cell_fpr = select_threshold(animal_folder, 
                                                cell_id, data3, num_thresh_test, 
                                                fpr, min_spikes)

Minimum false positive rate is 98%


In [140]:
num_spikes, spike_frames = plot_spikes(animal_folder, cell_id, peaks, data4, F0, thresh_opt)

In [141]:
sta, sta_error = find_sta(animal_folder, cell_id, sta_time, data2, F0, spike_frames, num_spikes, cell_fpr)

In [494]:
data4 = - data3
thresh_max = np.max(data4)
thresh_min = 0.2*thresh_max # This will be around zero
thresh_vals = np.arange(thresh_min, thresh_max, (thresh_max - thresh_min)/num_thresh_test)
spike_count = [np.sum(data4[peaks] > thresh) for thresh in thresh_vals]

In [510]:
plt.figure()
dif = np.diff(spike_count)
dif_smooth = savgol_filter(dif, 501, 1)
plt.plot(thresh_vals[1:], np.diff(spike_count))
plt.plot(thresh_vals[1:], dif_smooth, linewidth = 2)

[<matplotlib.lines.Line2D at 0xedef358>]

## Subthreshold OS

In [50]:
plt.figure()
axes = plt.gca()

# Bin camera frames by orientation
[frames_ori, frames_iti_ori, first_trial_ori, first_trial_time, 
 first_trial_num, first_ori, num_rep, rep_frames] = find_ori_frames(animal_folder, cell_id)

rep_frames = rep_frames.astype(int)

# Load data
with open('{0}\Traces\Cell_{1}_LED_on.pkl'.format(animal_folder, cell_id), 'rb') as f:
    data_dict = pickle.load(f) # Data

data = data_dict['data']

window = int(num_frames/num_window) # Should be made dependent on frame rate
if window%2 == 0:
    window += 1
poly = 1 # Degree of polynomial to fit
F0 = savgol_filter(data, window, poly)  

data = np.divide(data - F0, F0)*100

subth_os(animal_folder, cell_id, axes, num_rep, rep_frames, data, frames_ori, frames_iti_ori)

baseline = [-0.22547469  0.02911245 -0.06212429  0.13524169 -0.11067327 -0.23370324
 -0.31104997 -0.15475376]


### Variants of algo

In [51]:
text = True
summary = False
normalized = False

cell = cells.index(cell_id)

fig = plt.figure()
axes = plt.gca()

ori_max = 0

# Plot orientation selectivity
frames_ori = frames_ori.astype(bool)
frames_iti_ori = frames_iti_ori.astype(bool)

subth_ori = np.zeros([num_ori, num_rep])
subth_iti_ori = np.zeros([num_ori, num_rep])

num_frames_ori = frame_rate*vis_on
num_frames_iti = frame_rate*vis_off

data_vis_on = np.zeros([num_ori, num_rep, num_frames_ori + 10])
data_vis_off = np.zeros([num_ori, num_rep, num_frames_iti + 10])

for rep in range(num_rep - 1):
    data_rep = - data[rep_frames[rep]:rep_frames[rep + 1] - 1]       
    for ori in range(num_ori):
        frames_rep = frames_ori[ori][rep_frames[rep]:rep_frames[rep + 1] - 1]
        frames_iti_rep = frames_iti_ori[ori][rep_frames[rep]:rep_frames[rep + 1] - 1]
        data_ori = data_rep[frames_rep]
        data_iti_ori = data_rep[frames_iti_rep]
        subth_ori[ori][rep] = np.mean(data_ori[:-int(0.5*data_ori.shape[0])])
        subth_iti_ori[ori][rep] = np.mean(data_iti_ori)
        data_vis_on[ori][rep][:data_ori.shape[0]] = data_ori
        data_vis_off[ori][rep][:data_iti_ori.shape[0]] = data_iti_ori

        
data_vis_on_mean = np.mean(data_vis_on, 1)
data_vis_off_mean = np.mean(data_vis_off, 1)
baseline = np.mean(data_vis_off_mean, 1)

std_vis_on = np.std(data_vis_on, 1)
std_vis_off = np.std(data_vis_off, 1)

x_vals_ori = [np.linspace(degree, degree + 20, num_frames_ori + 10) for degree in ori_degrees]
x_vals_iti = [np.linspace(degree - 20, degree, num_frames_iti + 10) for degree in ori_degrees]


mean_subth_ori = np.mean(subth_ori, 1)
mean_subth_iti = np.mean(subth_iti_ori, 1)

for ori in range(num_ori):
    for rep in range(num_rep):
        axes.plot(x_vals_ori[ori], data_vis_on[ori][rep], color = 'grey', alpha = 0.3, linewidth = 0.5)
        axes.plot(x_vals_iti[ori], data_vis_off[ori][rep], color = 'grey', alpha = 0.3, linewidth = 0.5)
    axes.plot(x_vals_ori[ori], data_vis_on_mean[ori], color = 'k')
    #axes.fill_between(x_vals_ori[ori], data_vis_on_mean[ori] + std_vis_on[ori] - baseline[ori], 
     #                 data_vis_on_mean[ori] - std_vis_on[ori] - baseline[ori],
      #              color = 'blue', alpha = 0.3)
    axes.plot(x_vals_iti[ori], data_vis_off_mean[ori], color = 'k', )
    #axes.fill_between(x_vals_iti[ori], data_vis_off_mean[ori] + std_vis_off[ori], data_vis_off_mean[ori] - std_vis_off[ori],
     #               color = 'red', alpha = 0.3)
    #plt.plot(x_vals_ori[ori], np.ones(x_vals_ori[ori].shape)*baseline[ori])
  
axes.set_xticks(ori_degrees)
    
subth_plot = np.max(data_vis_on_mean, 1) - baseline
subth_max_pts = np.argmax(data_vis_on_mean, 1)
subth_std = np.std([[data_vis_on[ori][rep][subth_max_pts[ori]] for rep in range(num_rep)] for ori in range(num_ori)], 
                  1)

#axes.errorbar(ori_degrees, subth_plot, subth_std)

ori_order = np.roll(range(num_ori), max_ori_pos - ori_max)
if normalized:
    subth_ori = np.divide(subth_ori, np.max(subth_ori, 0))
    mean_subth_ori = mean_subth_ori/np.max(mean_subth_ori)
    mean_subth_iti = mean_subth_iti/np.max(mean_subth_ori)
    yl = 'Mean dF/F in trial, normalized to preferred orientation'

else:
    yl = 'Mean dF/F in trial'
# axes.plot(ori_degrees, mean_subth_ori[ori_order], color = 'k', linewidth = 2, label = 'Vis stim')
# axes.plot(ori_degrees, mean_subth_iti[ori_order], color = 'k', linestyle = '--', linewidth = 1.5, label = 'ITI')
# axes.plot(ori_degrees, subth_ori[ori_order], color = 'grey', alpha = 0.4, linewidth = 1.5,
#          label = 'Individual repetitions')
#labels.append('Cell %d' %(cell_id))

# Make sure y axis has correct tick labels
# plt.yticks(-np.array(plt.yticks())[0].astype(float))

if summary:
    if text:
        axes.set_title('Subthreshold')
else:
    axes.set_title('Cell {0} orientation selectivity from subthreshold potential, {1} repetitions'.format(cell_id, num_rep))
    axes.set_ylabel(yl)
    axes.set_xlim((np.min(ori_degrees) - 10, np.max(ori_degrees) + 100))
    axes.set_xlabel('Degrees away from preferred orientation by spikes')
    #axes.legend(loc = 'best')



In [108]:
plt.figure()

plt.plot(time_vec, data, linewidth = 1)
plt.plot(time_vec, np.ones(time_vec.shape)*baseline_stim)
scale = np.max(data)*1.1
for ori in range(num_ori):
    plt.fill_between(time_vec, scale*frames_ori[ori, LED_off:], np.zeros(time_vec.shape), alpha = 0.3)
    
    

In [85]:
cell = cells.index(cell_id)
rep_frames = rep_frames.astype(int)
frames_ori = frames_ori.astype(bool)
frames_iti_ori = frames_iti_ori.astype(bool)

num_frames_ori = frame_rate*vis_on
num_frames_iti = frame_rate*vis_off

# Create arrays to hold activity binned by ori/rep
subth_ori = np.zeros([num_ori, num_rep])
subth_iti_ori = np.zeros([num_ori, num_rep])

data_vis_on = np.zeros([num_ori, num_rep, num_frames_ori + 10])
data_vis_off = np.zeros([num_ori, num_rep, num_frames_iti + 10])

for rep in range(num_rep - 1):
    data_rep = - data[rep_frames[rep]:rep_frames[rep + 1] - 1]       
    for ori in range(num_ori):
        frames_rep = frames_ori[ori][rep_frames[rep]:rep_frames[rep + 1] - 1]
        frames_iti_rep = frames_iti_ori[ori][rep_frames[rep]:rep_frames[rep + 1] - 1]

        data_ori = data_rep[frames_rep]
        data_iti_ori = data_rep[frames_iti_rep]

        data_vis_on[ori][rep][:data_ori.shape[0]] = data_ori
        data_vis_off[ori][rep][:data_iti_ori.shape[0]] = data_iti_ori

# Average over repetitions   
data_vis_on_mean = np.mean(data_vis_on, 1)
data_vis_off_mean = np.mean(data_vis_off, 1)

# Average over ITI for baseline
#baseline = np.mean(data_vis_off_mean, 1)
baseline = np.mean(np.mean(data_vis_off_mean, 1))*np.ones(num_ori)
print('baseline = {0}'.format(baseline))

# Find the variance over repetitions
std_vis_on = np.std(data_vis_on, 1)
std_vis_off = np.std(data_vis_off, 1)

# Get the mean and variance peak of the trace during visual stimulation
subth_plot = np.min(data_vis_on_mean, 1) - baseline
subth_min_pts = np.argmin(data_vis_on_mean, 1)
subth_std = np.std([[data_vis_on[ori][rep][subth_min_pts[ori]] for rep in range(num_rep)] for ori in range(num_ori)], 
              1)
subth_plot_reps = [[data_vis_on[ori][rep][subth_min_pts[ori]] for rep in range(num_rep)] for ori in range(num_ori)]

# Check if preferred orientation is specified by user, else calculate it
if ori_max == []:
    ori_max = np.argmax(subth_plot)

ori_order = np.roll(range(num_ori), max_ori_pos - ori_max)

baseline = [ 0.  0.  0.  0.  0.  0.  0.  0.]


In [103]:
plt.figure()
# plt.plot(frames_rep)
# plt.plot(frames_iti_rep)
for ori in range(1):
    plt.plot(frames_iti_ori[ori])
    plt.plot(frames_ori[ori])
    

In [92]:
frames_iti_ori.shape

(8, 145000)