# About

This notebook analyses orientation selectivity in pyramidal cells. The steps are:

1. Rigid motion correction for the full frame.
2. Average `num_preview` frames (100 by default), evenly spaced across the session, and present the average to the user to draw an ROI.
3. The initial trace is the sum of the pixels in the user-drawn ROI.
4. Detrending: estimate F0 with a low-pass filter and subtract it from the trace.
5. High-pass filter the detrended traces by taking an FFT and discarding all frequencies below 50 Hz.
6. Spike detection, either manual or automated, as follows:

    1. Smooth the data for initial detection.
    2. Find a good threshold for spike detection: for a given threshold, detect spikes as peaks above the threshold that are at least 4 points apart (this minimum distance should also depend on the frame rate). Repeat this over a range of thresholds and plot the number of detected spikes as a function of threshold. Find the inflection point (maximum slope) and choose it as the threshold.
7. Align to the visual stimulus: create a vector of visual-stimulus on times for each of the eight conditions. (To do: figure out a smarter way of doing this.)
8. Bin the spikes by trial and orientation, and plot the orientation selectivity (with error bars).

Most of these steps are translated from Kaspar's MATLAB code.
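
Step 6.2 can be sketched with `scipy.signal.find_peaks`. This is a minimal illustration under stated assumptions, not the notebook's own detection code: `choose_threshold` and `min_distance_ms` are hypothetical names, the 8 ms default is simply the 4-point spacing converted at 500 Hz, and "maximum slope" is read as the steepest drop in spike count.

```python
import numpy as np
from scipy.signal import find_peaks


def choose_threshold(trace, frame_rate, num_thresh_test=1000, min_distance_ms=8.0):
    """Return (threshold, thresholds, counts) from a spike-count-vs-threshold sweep.

    Hypothetical helper: min_distance_ms stands in for the "4 points" spacing,
    converted to frames so that it scales with the frame rate.
    """
    # Minimum peak spacing in frames (4 frames at 500 Hz with the default 8 ms)
    min_distance = max(1, int(round(min_distance_ms * frame_rate / 1000)))
    thresholds = np.linspace(np.min(trace), np.max(trace), num_thresh_test)
    # Number of peaks above each threshold that are at least min_distance frames apart
    counts = np.array([len(find_peaks(trace, height=t, distance=min_distance)[0])
                       for t in thresholds])
    # The count falls as the threshold rises; take the steepest drop
    # (most negative slope) as the "maximum slope" point
    slope = np.gradient(counts, thresholds)
    best = int(np.argmin(slope))
    return thresholds[best], thresholds, counts
```

Plotting `counts` against `thresholds` gives the curve described in step 6.2, so the automatic choice can be checked by eye.
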

# Setup

## Imports

In [62]:
from PIL import Image
import math
import matplotlib.path as mplPath
import matplotlib.pyplot as plt
import numpy as np
import os
import pandas as pd
from pathlib import Path
import pickle
import pylab as pl
import random
np.random.seed(1)
import re
import scipy.io as spio
from scipy.optimize import curve_fit
from scipy.signal import medfilt
from scipy.signal import savgol_filter
import sys
import time

## Parameters

In [97]:
animal_ID = 402362

if animal_ID == 402362:
    cells = list(range(1, 50))
    cells_m = [1, 2, 4, 5, 13, 20, 21]
    # Frame rates in Hz: 500 Hz for all cells except cells 14-17, which were imaged at 400 Hz
    frame_rates = {'%d' % cell_id: 500 for cell_id in cells}
    frame_rates['14'] = frame_rates['15'] = frame_rates['16'] = frame_rates['17'] = 400
elif animal_ID == 402361:
    cells = [2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25]
    cells_m = [2, 4, 6, 7, 8, 9, 10, 11, 12, 14, 15, 16, 17]
    # Frame rates in Hz
    frame_rates = {'2': 510, '3': 500, '4': 500, '5': 500, '6': 500, '7': 500, '8': 500, '9': 500, '10': 500, 
                   '11': 500, '12': 666, '13': 500, '14': 400, '15': 450, '16': 450, '17': 500, '18': 500,
                  '19': 500, '20':500, '21':500, '22':500, '23':450, '24':500, '25':500}
num_cells = len(cells)

%matplotlib qt

raw_data_folder = 'E:\ST_Voltron\{0}'.format(animal_ID)
animal_folder = 'F:\{0}'.format(animal_ID)
daq_folder = 'Z:\ST-Voltron_DAQ' # Where camera TTLs and vis stim timings are stored
image_folder = 'C:\\Users\\singha\\Dropbox\\Voltron project\\Amrita\\{0}'.format(animal_ID) # Where figures should be saved
num_preview = 100 # Number of frames to average for initial ROI drawing

binning = 4*4 # Camera setting - number of pixels binned while imaging
offset = 100 # Camera offset per pixel (in bits)

# Spike detection parameters
window_size = 5001 # Length of window for detrending by piecewise linear fit (with sliding window) in frames
                   # Should be odd

freq_discard = 50 # Highest frequency in Hz to discard for high-pass filtering
num_thresh_test = 1000 # Number of spike detection thresholds to test between min and max
fpr = 0.2 # Desired false positive rate as a fraction of detected spikes
sta_time = 40 # Time in ms (before and after spike) over which spike triggered average is calculated
min_spikes = 10 # Minimum number of spikes detected initially

# Visual stimulus parameters
num_ori = 8
vis_on = 1 # Stimulus on time in seconds
vis_off = 1 # Stimulus off time in seconds
ori_degrees = [-135, -90, -45, 0, 45, 90, 135, 180] # Orientations

# Orientation tuning parameters
max_ori_pos = 3 # Position at which preferred orientation should be plotted
baseline_start_time = -80 # Time in ms before trial start to be included in baseline
baseline_end_time = 20 # Time in ms after trial start to be included in baseline
response_start_time = 100 # Time in ms after trial start after which response starts
response_end_time = 400 # Time in ms after trial start when response ends

# Save parameters (metadata)
for cell_id in cells:
    try:
        with open('{0}\Traces\Cell_{1}_metadata.pkl'.format(animal_folder, cell_id), 'rb') as f:
            md = pickle.load(f)
    except (FileNotFoundError, EOFError):
        md = {}
    with open('{0}\Traces\Cell_{1}_metadata.pkl'.format(animal_folder, cell_id), 'wb') as f:
        md['daq_folder'] = daq_folder
        md['num_preview'] = num_preview
        md['frame_rate'] = frame_rates['%d'%cell_id]
        md['offset'] = offset
        md['binning'] = binning
        md['window_size'] = window_size
        md['freq_discard'] = freq_discard
        md['num_thresh_test'] = num_thresh_test
        md['fpr'] = fpr
        md['sta_time'] = sta_time
        md['min_spikes'] = min_spikes
        md['num_ori'] = num_ori
        md['vis_on'] = vis_on
        md['vis_off'] = vis_off
        md['max_ori_pos'] = max_ori_pos
        md['ori_degrees'] = ori_degrees
        md['image_folder'] = image_folder
        md['baseline_start_frame'] = int(baseline_start_time*md['frame_rate']/1000)
        md['baseline_end_frame'] = int(baseline_end_time*md['frame_rate']/1000)
        md['response_start_frame'] = int(response_start_time*md['frame_rate']/1000)
        md['response_end_frame'] = int(response_end_time*md['frame_rate']/1000)
        pickle.dump(md, f)

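`freq_discard` drives the FFT high-pass step (step 5 in the About list). A minimal sketch of that step, assuming the trace is a 1-D NumPy array sampled at the cell's frame rate; `fft_highpass` is a hypothetical helper, and it discards frequencies at or below `freq_discard`, following the parameter comment ("highest frequency to discard").

```python
import numpy as np


def fft_highpass(trace, frame_rate, freq_discard=50):
    """High-pass filter by zeroing every Fourier component at or below freq_discard (Hz)."""
    trace = np.asarray(trace, dtype=float)
    spectrum = np.fft.rfft(trace)
    # Frequency (Hz) of each rfft bin for this sampling rate
    freqs = np.fft.rfftfreq(len(trace), d=1.0 / frame_rate)
    spectrum[freqs <= freq_discard] = 0  # also removes the DC component (the mean)
    return np.fft.irfft(spectrum, n=len(trace))
```

Passing `n=len(trace)` to `irfft` keeps the output the same length as the input when the trace has an odd number of frames.
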

# Rigid motion correction

# Get trace from ROI

## Execute all

In [40]:
# Get ROIs
for cell_id in [49]:
    print('Cell {0}'.format(cell_id))
    #Sort filenames
    [filenames, num_frames] = filename_sort(raw_data_folder, cell_id)
    # Get ROI 
    get_ROI(animal_folder, raw_data_folder, cell_id, num_frames, filenames)
    del filenames
    

Cell 49
Sorting filenames for cell 49
Do you want to segment another cell? Enter 1 if yes, 0 if no0


In [42]:
# Check masks
for cell_id in [49]:
    with open('{0}\Traces\Cell_{1}_metadata.pkl'.format(animal_folder, cell_id), 'rb') as f:
        md = pickle.load(f) # Metadata
    mask = md['mask']
    plt.figure()
    plt.title('Cell {0}'.format(cell_id))
    pl.imshow(mask)

In [33]:
# Get traces from ROIs (TIME CONSUMING)
for cell_id in range(18, 50):
    print('Cell {0}'.format(cell_id))
    # If data trace already exists, don't do it again
    filename = '{0}\Traces\Cell_{1}.pkl'.format(animal_folder, cell_id)
    file = Path(filename)
    if file.is_file():
        continue
    else:
        trace_from_movie(animal_folder, cell_id)

Cell 20
0
1000
2000
[... frame counter printed every 1000 frames ...]
157000

In [34]:
# Remove LED off time
for cell_id in range(20, 26):
    # If data trace with LED off time removed already exists, don't do it again
#     filename = '{0}\Traces\Cell_{1}_LED_on.pkl'.format(animal_folder, cell_id)
#     file = Path(filename)
#     if file.is_file():
#         continue
#     else:
    print('Cell {0}'.format(cell_id))
    remove_LED_off(animal_folder, cell_id)

Cell 20
LED off till 911th frame
Cell 21
LED off till 1060th frame
Cell 22
LED off till 1035th frame
Cell 23
LED off till 509th frame
Cell 24
LED off till 1614th frame
Cell 25
LED off till 628th frame


## Filename sorting

In [64]:
# Sort the file names of frames in a natural ascending order without requiring leading zeros 

def filename_sort(raw_data_folder, cell_id):
    # Get list of filenames for frames (sorted)
    print('Sorting filenames for cell %d' %cell_id)
    cell_folder = '{0}\Cell{1}'.format(raw_data_folder, cell_id)
    filenames = [f for f in os.listdir(cell_folder) if f.endswith('.tif')]
    num_frames = len(filenames)
    filenames = sort_nat_ascend(filenames, num_frames)
    return [filenames, num_frames]

def sort_nat_ascend(C, num_frames):
    # C is a list of filenames of individual frames to be sorted
    # Frame number is the last run of digits in the filename
    
    C2 = [re.sub(r'\d', '0', name) for name in C] # Replace every digit with '0'
    C3 = [np.array(list(name)) for name in C2] # Convert each string to an array of chars
    digits = [chars == '0' for chars in C3] # Boolean mask of digit positions
    
    # Extract the start and end indices of the last run of digits (frame number)
    end = np.array([np.flatnonzero(digit)[-1] for digit in digits]) + 1
    start = [end[frame] - np.argmax(digits[frame][(end[frame] - 1):0:-1] == False) for frame in range(num_frames)]
    
    # Sort by the numerical values of the frame numbers
    frame_numbers = [int(C[frame][start[frame]:end[frame]]) for frame in range(num_frames)]
    return [name for _, name in sorted(zip(frame_numbers, C))]

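The digit-mask approach in `sort_nat_ascend` can also be expressed with a regular expression. A minimal sketch, assuming (as the original comment does) that the frame number is the last run of digits in each filename; `sort_frames_by_number` is a hypothetical name.

```python
import re


def sort_frames_by_number(filenames):
    """Sort filenames by the last run of digits, e.g. 'frame2.tif' before 'frame10.tif'."""
    def frame_number(name):
        # Last run of digits in the name is taken as the frame number
        return int(re.findall(r'\d+', name)[-1])
    return sorted(filenames, key=frame_number)
```

Using a `key` function avoids building the intermediate character arrays and handles filenames with a single-digit frame number.
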
## Interactive ROI drawing

### Main function

In [65]:
# This function also serves as the point at which all 
# metadata about the cell is saved. 

def get_ROI(animal_folder, raw_data_folder, cell_id, num_frames, filenames):
    
    # Load metadata
    with open('{0}\Traces\Cell_{1}_metadata.pkl'.format(animal_folder, cell_id), 'rb') as f:
        md = pickle.load(f) 
    
    num_preview = md['num_preview']
    
    # Make a preview image: average num_preview frames evenly spaced across the session
    fnums = np.round(np.linspace(0, num_frames - 1, num_preview)).astype(int)

    md['cell_folder'] = cell_folder = '{0}\Cell{1}'.format(raw_data_folder, cell_id)
    im = Image.open('{0}\{1}'.format(cell_folder, filenames[fnums[0]]))
    w, h = im.size
    md['w'] = w
    md['h'] = h
    
    preview_array = np.zeros([h, w, num_preview])

    for frame in range(num_preview):
        im = Image.open('{0}\{1}'.format(cell_folder, filenames[fnums[frame]]))
        preview_array[:, :, frame] = np.array(im)
    preview = np.mean(preview_array, 2)
    
    # Get ROIs
    pl.imshow(preview)
    plt.title('Click to draw polygon ROI and doubleclick when done')
    
    next_cell = 1
    while (next_cell):
        my_roi = roipoly(roicolor='r') # Draw a new ROI in red
        plt.pause(12) # Give the user 12 s to draw the polygon before the mask is read
        md['mask'] = mask = my_roi.getMask(preview)

        cols_mask = np.array(mask.nonzero())[1]
        rows_mask = np.array(mask.nonzero())[0]
        md['left_edge'] = left_edge = np.min(cols_mask) - 10 if np.min(cols_mask) > 10 else 0
        md['right_edge'] = right_edge = np.max(cols_mask) + 10 if np.max(cols_mask) + 10 < w else w
        md['top_edge'] = top_edge = np.min(rows_mask) - 10 if np.min(rows_mask) > 10 else 0
        md['bottom_edge'] = bottom_edge = np.max(rows_mask) + 10 if np.max(rows_mask) + 10 < h else h

        md['area'] = area = np.sum(mask)
        md['area_bg'] = area_bg = w*h - area  
    
        # Save metadata
        md['filenames'] = filenames
        md['num_frames'] = len(filenames)
        
        with open('{0}\Traces\Cell_{1}_metadata.pkl'.format(animal_folder, 
                                            cell_id), 'wb') as f:
            pickle.dump(md, f)
        
        next_cell = int(input('Do you want to segment another cell? Enter 1 if yes, 0 if no'))
        if next_cell:
            ind = cells.index(cell_id) + 1
            cell_id = int(input('Enter cell id (previous cell was {0})'.format(cell_id)))
            if not cell_id in cells: 
                cells.insert(ind, cell_id)
            
    

### Class roipoly

In [66]:
# This code is from https://github.com/jdoepfert,
# which can be used without permission

class roipoly:
    def __init__(self, fig=[], ax=[], roicolor='b'):
        if fig == []:
            fig = plt.gcf()

        if ax == []:
            ax = plt.gca()

        self.previous_point = []
        self.allxpoints = []
        self.allypoints = []
        self.start_point = []
        self.end_point = []
        self.line = None
        self.roicolor = roicolor
        self.fig = fig
        self.ax = ax
        #self.fig.canvas.draw()

        self.__ID1 = self.fig.canvas.mpl_connect(
            'motion_notify_event', self.__motion_notify_callback)
        self.__ID2 = self.fig.canvas.mpl_connect(
            'button_press_event', self.__button_press_callback)

        if sys.flags.interactive:
            plt.show(block=False)
        else:
            plt.show()

    def getMask(self, currentImage):
        ny, nx = np.shape(currentImage)
        #print(self.allxpoints)
        poly_verts = [(self.allxpoints[0], self.allypoints[0])]
        for i in range(len(self.allxpoints)-1, -1, -1):
            poly_verts.append((self.allxpoints[i], self.allypoints[i]))

        # Create vertex coordinates for each grid cell...
        # (<0,0> is at the top left of the grid in this system)
        x, y = np.meshgrid(np.arange(nx), np.arange(ny))
        x, y = x.flatten(), y.flatten()
        points = np.vstack((x,y)).T

        ROIpath = mplPath.Path(poly_verts)
        grid = ROIpath.contains_points(points).reshape((ny,nx))
        return grid

    def displayROI(self,**linekwargs):
        l = plt.Line2D(self.allxpoints +
                     [self.allxpoints[0]],
                     self.allypoints +
                     [self.allypoints[0]],
                     color=self.roicolor, **linekwargs)
        ax = plt.gca()
        ax.add_line(l)
        plt.draw()

    def displayMean(self,currentImage, **textkwargs):
        mask = self.getMask(currentImage)
        meanval = np.mean(np.extract(mask, currentImage))
        stdval = np.std(np.extract(mask, currentImage))
        string = "%.3f +- %.3f" % (meanval, stdval)
        plt.text(self.allxpoints[0], self.allypoints[0],
                 string, color=self.roicolor,
                 bbox=dict(facecolor='w', alpha=0.6), **textkwargs)

    def __motion_notify_callback(self, event):
        if event.inaxes:
            ax = event.inaxes
            x, y = event.xdata, event.ydata
            # Move line around
            if (event.button == None or event.button == 1) and self.line != None: 
                self.line.set_data([self.previous_point[0], x],
                                   [self.previous_point[1], y])
                self.fig.canvas.draw()


    def __button_press_callback(self, event):
        if event.inaxes:
            x, y = event.xdata, event.ydata
            ax = event.inaxes
            # If you press the left button, single click
            if event.button == 1 and event.dblclick == False:  
                if self.line == None: # if there is no line, create a line
                    self.line = plt.Line2D([x, x],
                                           [y, y],
                                           marker='o',
                                           color=self.roicolor)
                    self.start_point = [x,y]
                    self.previous_point =  self.start_point
                    self.allxpoints=[x]
                    self.allypoints=[y]

                    ax.add_line(self.line)
                    self.fig.canvas.draw()
                    # add a segment
                else: # if there is a line, create a segment
                    self.line = plt.Line2D([self.previous_point[0], x],
                                           [self.previous_point[1], y],
                                           marker = 'o',
                                           color=self.roicolor)
                    self.previous_point = [x,y]
                    self.allxpoints.append(x)
                    self.allypoints.append(y)

                    event.inaxes.add_line(self.line)
                    self.fig.canvas.draw()
            elif ((event.button == 1 and event.dblclick==True) or (event.button == 3 and event.dblclick==False)) and self.line != None: # close the loop and disconnect
                    self.fig.canvas.mpl_disconnect(self.__ID1) #joerg
                    self.fig.canvas.mpl_disconnect(self.__ID2) #joerg

                    self.line.set_data([self.previous_point[0],
                                    self.start_point[0]],
                                   [self.previous_point[1],
                                    self.start_point[1]])
                    ax.add_line(self.line)
                    self.fig.canvas.draw()
                    self.line = None

                    if sys.flags.interactive:
                        pass
                    else:
                        #figure has to be closed so that code can continue
                        plt.close(self.fig) 

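The mask and cropping logic used by `roipoly.getMask` and `get_ROI` can be exercised without the interactive figure. This sketch assumes the polygon vertices are (x, y) = (column, row) pairs, as recorded by the click callbacks; `polygon_mask` and `padded_bounding_box` are hypothetical helpers, and the 10-pixel margin mirrors the one in `get_ROI`.

```python
import numpy as np
import matplotlib.path as mplPath


def polygon_mask(vertices, shape):
    """Boolean mask of pixels whose (column, row) coordinates lie inside the polygon."""
    ny, nx = shape
    x, y = np.meshgrid(np.arange(nx), np.arange(ny))
    points = np.column_stack((x.ravel(), y.ravel()))
    return mplPath.Path(vertices).contains_points(points).reshape(ny, nx)


def padded_bounding_box(mask, margin=10):
    """(left, right, top, bottom) edges of the mask, padded by margin and clipped to the image."""
    rows, cols = np.nonzero(mask)
    h, w = mask.shape
    return (int(max(cols.min() - margin, 0)), int(min(cols.max() + margin, w)),
            int(max(rows.min() - margin, 0)), int(min(rows.max() + margin, h)))
```

The `max`/`min` clipping gives the same edges as the conditional expressions in `get_ROI`.
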
## Trace from movie

In [67]:
def trace_from_movie(animal_folder, cell_id):
    
    # Load metadata
    with open('{0}\Traces\Cell_{1}_metadata.pkl'.format(animal_folder, cell_id), 'rb') as f:
        md = pickle.load(f) 
    
    h = md['h']; w = md['w']; num_frames = md['num_frames']; filenames = md['filenames']; 
    cell_folder = md['cell_folder']; mask = md['mask']; area = md['area'];
    offset = md['offset']; binning = md['binning'];
    
    # Load movie tifs
    data_array = np.zeros([h, w, num_frames])
    for frame in range(num_frames):
        if np.mod(frame, 1000) == 0:
            print(frame) # Progress indicator
        im = Image.open('{0}\{1}'.format(cell_folder, filenames[frame]))
        data_array[:, :, frame] = np.array(im)
    
    # Use the ROI mask to get the data trace
    print('Calculating data trace')
    data = np.sum(data_array[mask, :], 0)
    del data_array # Free up memory

    # Go from pixel intensity to electrons: subtract the camera offset
    # (offset per pixel * pixels per bin * ROI area), then multiply by the factor 0.48
    data = np.subtract(data, area*offset*binning)*0.48

    # Save data trace
    with open('{0}\Traces\Cell_{1}.pkl'.format(animal_folder, 
                                               cell_id), 'wb') as f:
        pickle.dump({'data':data}, f)

    

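`trace_from_movie` loads the whole movie into a `float64` array before summing, which is what makes it time- and memory-hungry. A sketch of a streaming variant that sums the ROI one frame at a time, assuming frames can be read individually (e.g. with `np.array(Image.open(path))`); `roi_trace_electrons` is a hypothetical helper, and the 0.48 factor is copied from the original code (its interpretation as electrons per count is an assumption).

```python
import numpy as np


def roi_trace_electrons(frames, mask, offset=100, binning=16, electrons_per_count=0.48):
    """Sum ROI pixels frame by frame, subtract the camera offset and convert to electrons.

    frames: any iterable of 2-D arrays, so the full movie never has to sit in memory.
    electrons_per_count: assumed meaning of the 0.48 factor used in trace_from_movie.
    """
    area = int(np.sum(mask))
    # One summed value per frame, accumulated in float64
    trace = np.array([np.sum(frame[mask], dtype=np.float64) for frame in frames])
    return (trace - area * offset * binning) * electrons_per_count
```

Passing a generator such as `(np.array(Image.open(p)) for p in paths)` keeps only one frame in memory at a time.
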
## Remove LED off time

In [68]:
def remove_LED_off(animal_folder, cell_id):
    
    # Load metadata
    with open('{0}\Traces\Cell_{1}_metadata.pkl'.format(animal_folder, cell_id), 'rb') as f:
        md = pickle.load(f) 
    frame_rate = md['frame_rate'];
    
    # Load data
    with open('{0}\Traces\Cell_{1}.pkl'.format(animal_folder, cell_id), 'rb') as f:
        data_dict = pickle.load(f)         
    data = data_dict['data']
    
    # Plot data 
    plt.figure()
    plt.plot(data[:10*frame_rate])
    plt.title('Cell {0}: hit spacebar and click on start point'.format(cell_id))                                       
    plt.xlabel('Frames')
    plt.ylabel('Electrons summed over mask')
    plt.grid()
    plt.pause(2)
    
    lo = LED_off_fn()
    plt.pause(10)
    LED_off = int(lo.get_LED_off())
    print('LED off till %dth frame' %LED_off)
    
    # Remove LED off frames
    data_dict['data'] = data[LED_off:]
    
    # Save modified traces
    with open('{0}\Traces\Cell_{1}_LED_on.pkl'.format(animal_folder, cell_id), 'wb') as f:
        pickle.dump(data_dict, f)
    
    # Save metadata
    md['LED_off'] = LED_off

    with open('{0}\Traces\Cell_{1}_metadata.pkl'.format(animal_folder, cell_id), 'wb') as f:
        pickle.dump(md, f)
       

In [69]:
class LED_off_fn:
    def __init__(self):
        self.fig = plt.gcf()
        self.ax = plt.gca()
        
        self.next_click_is_pt = False

        self.__ID1 = self.fig.canvas.mpl_connect('key_press_event', self.__key_press_callback)
        self.__ID2 = self.fig.canvas.mpl_connect('button_press_event', self.__button_press_callback)

        if sys.flags.interactive:
            plt.show(block=False)
        else:
            plt.show()

    def __key_press_callback(self, event):
        if event.key == ' ':
            self.next_click_is_pt = True

    def __button_press_callback(self, event):
        if event.inaxes and self.next_click_is_pt:
            self.LED_off = event.xdata
            self.ax.scatter(event.xdata, event.ydata, color = 'r')

    def get_LED_off(self):
        return self.LED_off

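Clicking the LED onset works, but the same index can often be found automatically when the LED-off period is clearly darker than the rest of the recording. A minimal sketch under that assumption; `find_led_on_frame` and its `fraction` parameter are hypothetical, and in practice it would be applied to the first few seconds of the trace (as `remove_LED_off` does when it plots `data[:10*frame_rate]`).

```python
import numpy as np


def find_led_on_frame(trace, fraction=0.5):
    """Index of the first frame above `fraction` of the way from the trace minimum to its maximum.

    Assumes the trace starts with the LED off and that LED-off frames are much darker.
    Returns 0 if no frame exceeds the level (e.g. a flat trace).
    """
    trace = np.asarray(trace, dtype=float)
    level = trace.min() + fraction * (trace.max() - trace.min())
    above = np.flatnonzero(trace > level)
    return int(above[0]) if above.size else 0
```

The automatic value can be used as a starting guess and still checked against the plot before trimming.
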
# Visual stimulus alignment

## Execute all

In [None]:
for cell_id in range(20, 50):
    print(cell_id)
    get_daq(animal_ID, cell_id)
    get_trial_times(animal_ID, cell_id)

20
After what time to look for the first frame?0
Hit b/p and then click on a point equal to the baseline/peak in trial, respectively
Ready
Camera frames last longer than visual stimulus!
21
After what time to look for the first frame?0
Hit b/p and then click on a point equal to the baseline/peak in trial, respectively
Ready
Camera frames last longer than visual stimulus!
22
After what time to look for the first frame?0
Hit b/p and then click on a point equal to the baseline/peak in trial, respectively
Ready
Camera frames last longer than visual stimulus!
23


## Get DAQ data

In [70]:
def get_daq(animal_ID, cell_id):
    
    # Load metadata
    with open('{0}\Traces\Cell_{1}_metadata.pkl'.format(animal_folder, cell_id), 'rb') as f:
        md = pickle.load(f) 
    
    daq_folder = md['daq_folder']; num_frames = md['num_frames']; LED_off = md['LED_off'] 
    
    filename = '{0}\ANM{1}_cell{2}_DAQ'.format(daq_folder, animal_ID, cell_id)
    mat = spio.loadmat(filename, squeeze_me=True)
    
    vis_stim = np.array(mat['data'][:, 0])
    camera_output = np.array(mat['data'][:, 1])
    daq_rate = mat['rate']
    daq_time_vec = mat['timeStamps']
    
    # There might be some extra frames before the imaging session. Those need to be removed
    # manually in order to find the first camera frame automatically
    plt.figure()
    plt.plot(daq_time_vec, camera_output)
    plt.title('Cell {0} camera TTLs'.format(cell_id))
    plt.pause(2)
    offset = int(input('After what time to look for the first frame?'))
    camera_output = camera_output[int(offset*daq_rate):]
    
    # Binarize the camera TTLs (1 above 0.2, 0 otherwise)
    camera_output = (camera_output > 0.2).astype(int)
    
    # Camera frame on times (rising edges), in terms of DAQ data points in the full recording
    dif = np.diff(camera_output, 1)
    frame_times = np.array((dif == 1).nonzero()) + int(offset*daq_rate)
    
    # There are some extra TTLs compared with the number of recorded frames,
    # amounting to about 40 ms of imaging time.
    # For now, we assume that the last ~40 ms of TTLs are spurious and keep only the first num_frames.
    frame_times = frame_times[0][:num_frames]
    md['time_vec'] = daq_time_vec[frame_times][LED_off:]

    # Save data
    with open('{0}\Vis_stim_info\Cell_{1}_vis_stim.pkl'.format(animal_folder, cell_id), 'wb') as f:
        pickle.dump({'vis_stim': vis_stim, 'camera_output': camera_output, 'daq_rate': daq_rate, 
                    'daq_time_vec': daq_time_vec, 'offset': offset, 'frame_times': frame_times,}, f)
    
    with open('{0}\Traces\Cell_{1}_metadata.pkl'.format(animal_folder, cell_id), 'wb') as f:
        pickle.dump(md, f)
 


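The TTL handling in `get_daq` reduces to binarizing the signal and finding rising edges. A minimal sketch; `rising_edges` is a hypothetical helper, the 0.2 threshold is taken from `get_daq`, and unlike `get_daq` it returns the first high sample (one sample after the `dif == 1` index).

```python
import numpy as np


def rising_edges(ttl, threshold=0.2, start=0):
    """DAQ sample indices where a TTL signal crosses `threshold` upwards.

    start: number of leading samples dropped before the call, so the returned
    indices refer to the full recording.
    """
    high = (np.asarray(ttl) > threshold).astype(int)
    # diff == 1 marks the last low sample; +1 moves to the first high sample
    return np.flatnonzero(np.diff(high) == 1) + 1 + int(start)
```
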
## Get trial times from vis stim

### Main function

In [117]:
def get_trial_times(animal_ID, cell_id):
    # Here, 'time' refers to the index in the DAQ recording
    # 'frame' refers to the camera frame number from the TTLs

    with open('{0}\Traces\Cell_{1}_metadata.pkl'.format(animal_folder, cell_id), 'rb') as f:
        md = pickle.load(f) 
    LED_off = md['LED_off']; num_frames = md['num_frames']

    # Load vis stim data
    with open('{0}\Vis_stim_info\Cell_{1}_vis_stim.pkl'.format(animal_folder, cell_id), 'rb') as f:
        v = pickle.load(f)
    daq_rate= v['daq_rate']; frame_times = v['frame_times']; vis_stim = v['vis_stim'];
    daq_time_vec = v['daq_time_vec']

    # Smooth the stimulus trace with a ~20 ms Savitzky-Golay filter (window kept odd)
    window = int(0.02*daq_rate) 
    if window%2 == 0:
        window += 1
    poly = 1 # Degree of polynomial to fit
    F0_vis_stim = savgol_filter(vis_stim, window, poly)  

    vb = vis_stim_baseline(F0_vis_stim)
    plt.pause(10)

    baseline_vis = F0_vis_stim[vb.get_baseline()]
    trial_peak_vis = F0_vis_stim[vb.get_peak()] - baseline_vis

    # Find points in the vicinity of baseline
    above_baseline = np.array(F0_vis_stim > baseline_vis - trial_peak_vis*0.2)
    below_baseline = np.array(F0_vis_stim < baseline_vis + trial_peak_vis*0.2)
    equal_baseline = np.logical_and(above_baseline, below_baseline)

    # Find trial start and end times
    dif = np.diff(equal_baseline.astype(int))
    falls = np.array((dif == -1).nonzero())[0]
    trial_start_times = [fall for fall in falls if np.sum(equal_baseline[fall - int(0.5*vis_on*daq_rate):fall]) >= 0.5*vis_on*daq_rate]
    rises = np.array((dif == 1)).nonzero()[0]
    trial_end_times = [rise for rise in rises if np.sum(equal_baseline[rise:int(0.5*vis_on*daq_rate) + rise]) >= 0.5*vis_on*daq_rate - 1]

    # Find the first trial and last trial (overlapping with camera frames)
    offset = 1 if trial_end_times[0] < trial_start_times[0] else 0
    first_iti_num = (trial_end_times > frame_times[LED_off]).nonzero()[0][0]
    first_iti_time = trial_end_times[first_iti_num]
    first_stim_num = (trial_start_times > first_iti_time).nonzero()[0][0]
    first_stim_time = trial_start_times[first_stim_num]
    try:
        last_stim_num = (trial_end_times > frame_times[-1]).nonzero()[0][0] - offset
    except IndexError:
        print('Camera frames last longer than visual stimulus!')
        last_stim_num = len(trial_end_times) - offset

    # First orientation seen during imaging, taking first orientation presented as 0:
    first_ori = np.mod(first_stim_num, num_ori) 

    # Camera frames for each repetition, sorted into orientations
    frames_ori = np.zeros([num_ori, num_frames - LED_off])
    frames_iti_ori = np.zeros([num_ori, num_frames - LED_off])
    ori = first_ori

    for trial in range(first_stim_num, last_stim_num):
        frames_ori[ori][(frame_times[LED_off:] > trial_start_times[trial]).nonzero()[0]] = 1
        frames_ori[ori][(frame_times[LED_off:] > trial_end_times[trial + offset]).nonzero()[0]] = 0
        frames_iti_ori[ori][(frame_times[LED_off:] > trial_end_times[trial + offset - 1]).nonzero()[0]] = 1
        frames_iti_ori[ori][(frame_times[LED_off:] > trial_start_times[trial]).nonzero()[0]] = 0
        ori = np.mod(ori + 1, num_ori)

    # Find the number of times all orientations are repeated
    num_rep = int((last_stim_num - first_stim_num)/num_ori)

    # Find the first camera frame for each repetition
    first_trial_ori = first_stim_num + np.roll(range(num_ori), first_ori)
    trial_start_ori = [trial_start_times[first_trial_ori[ori]::num_ori] for ori in range(num_ori)]
    rep_frames_daq = trial_start_ori[first_ori][:num_rep]
    rep_frames = [(frame_times[LED_off:] > rep_frames_daq[rep]).nonzero()[0][0] for rep in range(num_rep)]

    # Add another repetition boundary to get the frames for the last repetition
    # (frame_times holds DAQ sample indices; rep_frames holds frame indices after LED_off)
    if frame_times[-1] > trial_start_times[last_stim_num]:
        rep_frames = np.append(rep_frames, (frame_times[LED_off:] >= trial_start_times[last_stim_num]).nonzero()[0][0])
    else:
        rep_frames = np.append(rep_frames, len(frame_times) - LED_off)
    
    # Save data
    v['F0_vis_stim'] = F0_vis_stim
    v['trial_start_times'] = trial_start_times
    v['trial_end_times'] = trial_end_times
    v['frames_ori'] = frames_ori
    v['frames_iti_ori'] = frames_iti_ori
    v['num_rep'] = num_rep
    v['rep_frames'] = rep_frames
    v['frame_times'] = frame_times # Extra TTLs removed
    v['first_stim_num'] = first_stim_num
    v['first_ori'] = first_ori
    with open('{0}\Vis_stim_info\Cell_{1}_vis_stim.pkl'.format(animal_folder, cell_id), 'wb') as f:
        pickle.dump(v, f)


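The trial detection in `get_trial_times` looks for transitions out of and back into a "near baseline" state, keeping only transitions flanked by a long enough baseline period. A simplified sketch of that idea on a boolean array; `trial_edges` and `min_run` are hypothetical names, and the run-length checks differ by one sample from the `>= 0.5*vis_on*daq_rate` sums in the original.

```python
import numpy as np


def trial_edges(near_baseline, min_run):
    """Start/end indices of stimulus epochs in a boolean 'signal is near baseline' array.

    A start is the last baseline sample before the stimulus, kept only if the
    preceding min_run samples (including it) are all baseline.
    An end is the last stimulus sample, kept only if the following min_run
    samples are all baseline.
    """
    nb = np.asarray(near_baseline, dtype=int)
    dif = np.diff(nb)
    falls = np.flatnonzero(dif == -1)  # baseline -> stimulus
    rises = np.flatnonzero(dif == 1)   # stimulus -> baseline
    starts = np.array([f for f in falls
                       if f >= min_run - 1 and nb[f - min_run + 1:f + 1].all()], dtype=int)
    ends = np.array([r for r in rises
                     if nb[r + 1:r + 1 + min_run].sum() == min_run], dtype=int)
    return starts, ends
```

With `min_run = int(0.5*vis_on*daq_rate)` this corresponds to requiring half a stimulus duration of baseline around each transition.
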
### Interactive clicking functions

In [72]:
class vis_stim_baseline:
    
    def __init__(self, F0_vis_stim):  

        print('Hit b/p and then click on a point equal to the baseline/peak in trial, respectively')
        
        # Plot the smoothed trace
        plt.figure()
        plt.plot(F0_vis_stim)
        plt.pause(10) 

        self.next_click_is_baseline = False
        self.click = False
        self.baseline_time = 0
        self.approx_baseline = 0
        self.baseline_frame = 0
        
        self.next_click_is_peak = False
        self.peak_time = 0
        self.approx_peak = 0
        self.peak_frame = 0
        
        self.F0 = F0_vis_stim
        
        fig = plt.gcf()

        self.click_id = fig.canvas.mpl_connect('button_press_event', self.onclick)
        self.key_id = fig.canvas.mpl_connect('key_press_event', self.keypress)
        
        print('Ready')
        
        if sys.flags.interactive:
            plt.show(block=False)
        else:
            plt.show()
        

    def keypress(self, event):
        if event.key == 'b':
            self.next_click_is_baseline = True
        elif event.key == 'p':
            self.next_click_is_peak = True

    def onclick(self, event):
        if event.inaxes is None: # Ignore clicks outside the axes (xdata is None there)
            return
        ax = plt.gca()
        if self.next_click_is_baseline:
            self.baseline_time = event.xdata
            self.approx_baseline = event.ydata
            self.baseline_frame = int(self.baseline_time)
            ax.scatter(self.baseline_frame, self.F0[self.baseline_frame], color = 'b')
            self.next_click_is_baseline = False
        if self.next_click_is_peak:
            self.peak_time = event.xdata
            self.approx_peak = event.ydata
            self.peak_frame = int(self.peak_time)
            ax.scatter(self.peak_frame, self.F0[self.peak_frame], color = 'r')
            self.next_click_is_peak = False
     
    def get_baseline(self):
        return self.baseline_frame
    
    def get_peak(self):
        return self.peak_frame
         

## Check visual stimulus alignment

In [22]:
cell_id = 22

# Load metadata
with open('{0}\Traces\Cell_{1}_metadata.pkl'.format(animal_folder, cell_id), 'rb') as f:
    md = pickle.load(f) 
LED_off = md['LED_off']

# Load vis stim data
with open('{0}\Vis_stim_info\Cell_{1}_vis_stim.pkl'.format(animal_folder, cell_id), 'rb') as f:
    v = pickle.load(f)

daq_rate= v['daq_rate']; frame_times = v['frame_times']; daq_time_vec = v['daq_time_vec']; 
vis_stim = v['vis_stim']; camera_output = v['camera_output']; frames_ori = v['frames_ori'] 

plt.figure()
plt.plot(daq_time_vec, vis_stim, color = 'k')
#plt.plot(F0_vis_stim, color = 'r', linewidth = 2)
scale = np.max(vis_stim)
for ori in range(num_ori):
    plt.fill_between(daq_time_vec[frame_times[LED_off:]], scale*frames_ori[ori], np.zeros(len(frame_times) - LED_off), alpha = 0.3)

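The `frames_ori` matrix marks, for each orientation, the camera frames that fall strictly after a trial start and at or before its end. With sorted frame times the same matrix can be built with `np.searchsorted` instead of two full-length comparisons per trial. A sketch, assuming trials cycle through the orientations in a fixed order starting at `first_ori`; `frames_by_orientation` is a hypothetical helper.

```python
import numpy as np


def frames_by_orientation(frame_times, trial_starts, trial_ends, first_ori, num_ori=8):
    """0/1 matrix (num_ori x num_frames) marking frames inside each trial.

    All times are DAQ sample indices; frame_times must be sorted.
    """
    frame_times = np.asarray(frame_times)
    out = np.zeros((num_ori, len(frame_times)))
    for k, (start, end) in enumerate(zip(trial_starts, trial_ends)):
        ori = (first_ori + k) % num_ori  # orientations cycle in presentation order
        # Frames with start < time <= end, matching the '>' comparisons in get_trial_times
        lo = np.searchsorted(frame_times, start, side='right')
        hi = np.searchsorted(frame_times, end, side='right')
        out[ori, lo:hi] = 1
    return out
```
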
# Spike detection

## Execute all

In [41]:
ms = {}
for cell_id in range(20, 26):
    F0, dF, dF_F = detrend(animal_folder, cell_id)
    try:
        with open('{0}\Traces_2\Cell_{1}_spikes.pkl'.format(animal_folder, cell_id), 'rb') as f:
            sd = pickle.load(f)
        spike_frames_manual = sd['spike_frames_manual'] # Raises KeyError if spikes have not been marked yet
        sd['F0'] = F0
        sd['dF'] = dF
        sd['dF_F'] = dF_F
        with open('{0}\Traces_2\Cell_{1}_spikes.pkl'.format(animal_folder, cell_id), 'wb') as f:
            pickle.dump(sd, f)
            
    except (KeyError, FileNotFoundError, EOFError) as e:
        print('Cell {0}'.format(cell_id))
        ms['%d'%cell_id] = manual_spikes(animal_folder, cell_id, F0, dF, dF_F)
        

Cell 20
Hit spacebar and then click in the vicinity of spikes. Press d to delete last spike
Cell 21
Hit spacebar and then click in the vicinity of spikes. Press d to delete last spike
Cell 22
Hit spacebar and then click in the vicinity of spikes. Press d to delete last spike
Cell 23
Hit spacebar and then click in the vicinity of spikes. Press d to delete last spike
Cell 24
Hit spacebar and then click in the vicinity of spikes. Press d to delete last spike
Cell 25
Hit spacebar and then click in the vicinity of spikes. Press d to delete last spike


In [255]:
# ms is a dict of annotators keyed by cell; this uses the one for the last cell_id of the loop above
spike_frames_manual = ms['%d' % cell_id].get_spikes()
try:
    with open(r'{0}\Traces_2\Cell_{1}_spikes.pkl'.format(animal_folder, 
                                                         cell_id), 'rb') as f:
        sd = pickle.load(f)
    spike_frames_manual = sd['spike_frames_manual']  # keep saved annotations; KeyError if none yet
except KeyError:
    sd['spike_frames_manual'] = spike_frames_manual
    sd['num_spikes_manual'] = len(spike_frames_manual)
except (FileNotFoundError, EOFError):
    sd = {'spike_frames_manual': spike_frames_manual, 
          'num_spikes_manual': len(spike_frames_manual)}
# Store the detrended traces in every case; the later analysis cells read sd['dF_F']
sd['dF'] = dF; sd['dF_F'] = dF_F; sd['F0'] = F0

with open(r'{0}\Traces_2\Cell_{1}_spikes.pkl'.format(animal_folder, 
                                                     cell_id), 'wb') as f:
    pickle.dump(sd, f)
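This cell and the `s` key handler of `manual_spikes` follow the same pattern: load the per-cell pickle if it exists, add fields without replacing manual annotations that are already saved, and write it back. A hedged sketch of that pattern as one reusable function; `update_pickle` and its `overwrite` flag are hypothetical names:

```python
import pickle

def update_pickle(path, overwrite=False, **fields):
    """Hypothetical helper: merge `fields` into the dict stored at `path`.

    A missing or empty file starts a new dict. Keys that already exist are kept
    unless overwrite=True, so saved manual annotations are not replaced by accident.
    """
    try:
        with open(path, 'rb') as f:
            data = pickle.load(f)
    except (FileNotFoundError, EOFError):
        data = {}
    for key, value in fields.items():
        if overwrite or key not in data:
            data[key] = value
    with open(path, 'wb') as f:
        pickle.dump(data, f)
    return data
```

With `overwrite=True` for `F0`, `dF` and `dF_F` and the default for `spike_frames_manual`, one call would cover both branches of the cell above.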

## Detrending

In [73]:
def detrend(animal_folder, cell_id): 
    # Returns F0 (slow baseline), dF = data - F0 and dF/F in %.
    # md['window_size'] (Savitzky-Golay window, in frames) should be about 5000 frames.
    
    # Load metadata
    with open(r'{0}\Traces\Cell_{1}_metadata.pkl'.format(animal_folder, cell_id), 'rb') as f:
        md = pickle.load(f) 
    window_size = md['window_size']
    
    # Load data
    with open(r'{0}\Traces\Cell_{1}_LED_on.pkl'.format(animal_folder, cell_id), 'rb') as f:
        data_dict = pickle.load(f) # Data
        
    data = data_dict['data']
    
    # Estimate F0 with a Savitzky-Golay filter: a least-squares straight line fitted in a
    # sliding window of window_size frames, which follows slow bleaching but not spikes
    poly = 1 # Degree of polynomial to fit
    F0 = savgol_filter(data, window_size, poly)    
    
    # Calculate dF. dF is better than dF/F for spike detection: as F0 decreases through
    # bleaching, dividing by it amplifies the noise towards the end of the trace.
    dF = data - F0
    dF_F = np.divide(dF, F0)*100 # dF/F in %
    
    return F0, dF, dF_F
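`savgol_filter(data, window_size, 1)` is a Savitzky-Golay filter of order 1: at every frame it fits a straight line by least squares to the surrounding `window_size` frames. With a window much longer than a spike, F0 follows slow bleaching while brief events stay in dF. A self-contained sketch on a synthetic trace; the bleaching time constant, noise level and spike amplitude are made-up values, not taken from the recordings:

```python
import numpy as np
from scipy.signal import savgol_filter

# Synthetic raw trace (illustrative numbers only): slow bleaching plus noise
rng = np.random.default_rng(0)
frames = np.arange(20000)
true_F0 = 1000 * np.exp(-frames / 50000)
raw = true_F0 + rng.normal(0, 2, frames.size)
# Three single-frame, negative-going events standing in for spikes
spike_frames = np.array([3000, 9000, 15000])
raw[spike_frames] -= 30

window_size = 5001                        # odd window of about 5000 frames
F0 = savgol_filter(raw, window_size, 1)   # local straight-line fit = slow baseline
dF = raw - F0                             # spikes survive here, the bleaching trend does not
dF_F = dF / F0 * 100                      # dF/F in %
```

Because the same spike amplitude sits on a shrinking F0, `dF_F` at the last synthetic spike is larger in magnitude than at the first, which is the amplification the comment in `detrend` warns about.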


## Manual spike annotation

In [74]:
class manual_spikes:
    
    def __init__(self, animal_folder, cell_id, F0, dF, dF_F):

        # Flip the sign so that negative-going spikes in dF/F appear as upward peaks
        self.data = -dF_F
        self.F0 = F0
        self.dF = dF
        self.animal_folder = animal_folder
        self.cell_id = cell_id
        
        print('Hit spacebar and then click in the vicinity of spikes. Press d to delete last spike')
        
        # Load metadata
        with open(r'{0}\Traces\Cell_{1}_metadata.pkl'.format(animal_folder, cell_id), 'rb') as f:
            md = pickle.load(f) 

        time_vec = md['time_vec']
        
        # Plot the detrended trace
        fig = plt.figure()
        self.axes = plt.gca()
        
        plt.plot(self.data)
        plt.title('Cell %d: spacebar to annotate spike, d to delete spike, s to save' %cell_id)
        #plt.yticks(plt.yticks()[0], -plt.yticks()[0])
        plt.ylabel('-dF/F in % (sign flipped)')

        # Find peaks (local maxima of the sign-flipped trace, i.e. local minima of dF/F)
        # in windows of three points
        dif = np.diff(self.data, 1) # Discrete difference, ignoring last frame
        dif = np.append(dif, False)
        # Boolean array, true for points higher than the following point
        diff_pos = dif < 0 
        # Reverse discrete difference, ignoring the first frame
        dif_rev = np.flip(np.diff(np.flip(self.data, 0), 1), 0) 
        dif_rev = np.insert(dif_rev, 0, False)
        # Boolean array, true for points higher than the preceding point
        diff_rev_pos = dif_rev < 0 
        peaks = np.logical_and(diff_pos, diff_rev_pos)

        self.peak_pos = peaks.nonzero()[0]
        self.data_peaks = self.data[peaks]

        self.spike_frames_manual = np.zeros(0)
        self.spike_pts = []  # scatter markers, one per annotated spike

        self.next_click_is_spike = False
        self.spikes_are_done = False
        self.click = False
        self.approx_frame = 0
        self.approx_height = 0
        self.delete = False

        self.click_id = fig.canvas.mpl_connect('button_press_event', self.onclick)
        self.key_id = fig.canvas.mpl_connect('key_press_event', self.keypress)
        
        if sys.flags.interactive:
            plt.show(block=False)
        else:
            plt.show()
        

    def keypress(self, event):
        if event.key == ' ':
            self.next_click_is_spike = True
        elif event.key == 'd':
            if len(self.spike_frames_manual) > 0:
                # np.delete returns a new array, so the result has to be assigned back
                self.spike_frames_manual = np.delete(self.spike_frames_manual, -1)
                self.spike_pts.pop().remove()
                self.axes.figure.canvas.draw_idle()
                print('Last spike deleted')
        elif event.key == 's':
            self.spikes_are_done = True
            try:
                with open(r'{0}\Traces_2\Cell_{1}_spikes.pkl'.format(self.animal_folder, 
                                                                     self.cell_id), 'rb') as f:
                    sd = pickle.load(f)
                # Saved annotations are kept as they are; KeyError if there are none yet
                spike_frames_manual = sd['spike_frames_manual'] 
            except KeyError:
                sd['spike_frames_manual'] = self.spike_frames_manual
                sd['num_spikes_manual'] = len(self.spike_frames_manual)
                sd['dF'] = self.dF; sd['dF_F'] = - (self.data); sd['F0'] = self.F0
            except (FileNotFoundError, EOFError):
                sd = {'spike_frames_manual': self.spike_frames_manual, 
                      'num_spikes_manual': len(self.spike_frames_manual),
                      'F0': self.F0, 'dF': self.dF, 'dF_F': -self.data}

            with open(r'{0}\Traces_2\Cell_{1}_spikes.pkl'.format(self.animal_folder, 
                                                                 self.cell_id), 'wb') as f:
                pickle.dump(sd, f)

    def onclick(self, event):
        ax = self.axes
        # Ignore clicks outside the axes, where event.xdata is None
        if self.next_click_is_spike and event.xdata is not None:
            self.approx_frame = event.xdata
            self.approx_height = event.ydata
            
            # Detected peak closest (in frames) to the click
            self.near_peaks = self.peak_pos[np.argpartition(np.abs(self.peak_pos - self.approx_frame), 
                                                            1)[:1]]
            # Among the candidates, the peak closest to the click in (frame, height)
            self.spike_frame = self.near_peaks[np.argmin(np.power(self.data[self.near_peaks] - 
                                                                  self.approx_height, 2) + 
                                               np.power(self.near_peaks - self.approx_frame, 2))]
            self.spike_pts.append(ax.scatter(self.spike_frame, self.data[self.spike_frame], color = 'r'))
            #ax.scatter(self.near_peaks, self.data[self.near_peaks], color = 'k')
            self.spike_frames_manual = np.append(self.spike_frames_manual, self.spike_frame)
            self.next_click_is_spike = False
            ax.figure.canvas.draw_idle()
     
    def get_spikes(self):
        return self.spike_frames_manual
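The diff/flip construction in `__init__` marks frames that are strictly higher than both neighbours of the sign-flipped trace; the first and last frames are never marked. The same candidate set can be computed with plain slicing, sketched below; `three_point_peaks` is a hypothetical helper name:

```python
import numpy as np

def three_point_peaks(x):
    """Hypothetical helper: indices i with x[i-1] < x[i] > x[i+1].

    Endpoints are never reported, matching the candidates used for snapping clicks.
    """
    x = np.asarray(x)
    interior = (x[1:-1] > x[:-2]) & (x[1:-1] > x[2:])
    return np.nonzero(interior)[0] + 1

# Flat tops (two equal neighbouring samples) are not strict peaks
print(three_point_peaks([0, 2, 1, 3, 3, 1, 5, 0]))
```

`scipy.signal.find_peaks` offers a similar detector with extra criteria such as `height` and `distance`, which could encode a minimum spacing between spikes.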
    
        

# Orientation selectivity

## Execute all

In [95]:
for cell_id in range(20, 26):
    print(cell_id)
    try:
        spike_os(animal_folder, cell_id)
        subth_os(animal_folder, cell_id)
    except KeyError as e:
        print(e)
        print('Cell {0} not done'.format(cell_id))

20
21
22
23
24
25


## Spikes

In [75]:
def spike_os(animal_folder, cell_id):
    # Load metadata
    with open(r'{0}\Traces\Cell_{1}_metadata.pkl'.format(animal_folder, cell_id), 'rb') as f:
        md = pickle.load(f) 
    LED_off = md['LED_off']; frame_rate = md['frame_rate']; time_vec = md['time_vec']
    ori_degrees = md['ori_degrees']; max_ori_pos = md['max_ori_pos']

    # Load data
    with open(r'{0}\Traces_2\Cell_{1}_spikes.pkl'.format(animal_folder, cell_id), 'rb') as f:
        sd = pickle.load(f) # Data
    spikes = np.unique(sd['spike_frames_manual'].astype(int))

    # Load vis stim data
    with open(r'{0}\Vis_stim_info\Cell_{1}_vis_stim.pkl'.format(animal_folder, cell_id), 'rb') as f:
        v = pickle.load(f)
    daq_rate = v['daq_rate']; frame_times = v['frame_times']; vis_stim = v['vis_stim']
    frames_ori = v['frames_ori']; frames_iti_ori = v['frames_iti_ori']; num_rep = v['num_rep']
    rep_frames = v['rep_frames']; trial_start_times = v['trial_start_times']; trial_end_times = v['trial_end_times']
    daq_time_vec = v['daq_time_vec']; first_ori = v['first_ori']
    
    # Spike counts per orientation (rows) and repetition (columns), during stimulus and ITI
    num_spikes_ori = np.zeros([num_ori, num_rep])
    num_spikes_iti_ori = np.zeros([num_ori, num_rep])
    spikes_rep = {}
         
    for rep in range(num_rep):
        # Repetition rep spans frames rep_frames[rep] (inclusive) to rep_frames[rep + 1] (exclusive)
        in_rep = spikes[(spikes >= rep_frames[rep]) & (spikes < rep_frames[rep + 1])]
        spikes_rep['{0}'.format(rep)] = in_rep
        for ori in range(num_ori):
            # frames_ori[ori] is 1 on frames where orientation ori is shown, so summing it at
            # the spike frames counts the spikes that fell in that stimulus window
            num_spikes_ori[ori][rep] = np.sum(frames_ori[ori][in_rep])
            num_spikes_iti_ori[ori][rep] = np.sum(frames_iti_ori[ori][in_rep])
    
    pref_ori = np.argmax(np.mean(num_spikes_ori, 1))
    
    sd['spikes_rep'] = spikes_rep
    sd['num_spikes_ori'] = num_spikes_ori
    sd['num_spikes_iti_ori'] = num_spikes_iti_ori
    sd['pref_ori'] = pref_ori
    sd['mean_spikes_ori'] = np.mean(num_spikes_ori, 1)
    sd['mean_spikes_iti'] = np.mean(num_spikes_iti_ori, 1)
    
    with open(r'{0}\Traces_2\Cell_{1}_spikes.pkl'.format(animal_folder, cell_id), 'wb') as f:
        pickle.dump(sd, f)
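The counting in `spike_os` reduces to: take the spikes inside each repetition, then sum each orientation's 0/1 stimulus indicator at those frames. A minimal sketch on a toy layout (2 orientations, 2 repetitions of 10 frames; all numbers invented), using half-open intervals `[rep_frames[r], rep_frames[r+1])` so that a spike on a boundary frame is counted exactly once; `count_spikes` is a hypothetical name:

```python
import numpy as np

def count_spikes(spikes, frames_ori, rep_frames):
    """Hypothetical helper: spike counts per (orientation, repetition).

    spikes: frame indices of spikes
    frames_ori: (num_ori, num_frames) 0/1 array, 1 where orientation ori is on screen
    rep_frames: num_rep + 1 boundaries; repetition r spans [rep_frames[r], rep_frames[r+1])
    """
    spikes = np.unique(np.asarray(spikes, dtype=int))
    num_ori = frames_ori.shape[0]
    num_rep = len(rep_frames) - 1
    counts = np.zeros((num_ori, num_rep))
    for rep in range(num_rep):
        in_rep = spikes[(spikes >= rep_frames[rep]) & (spikes < rep_frames[rep + 1])]
        counts[:, rep] = frames_ori[:, in_rep].sum(axis=1)
    return counts

# Toy example: orientation 0 on frames 0-2 and 10-12, orientation 1 on frames 5-7 and 15-17
frames_ori = np.zeros((2, 20))
frames_ori[0, [0, 1, 2, 10, 11, 12]] = 1
frames_ori[1, [5, 6, 7, 15, 16, 17]] = 1
counts = count_spikes([1, 6, 7, 11, 16, 17, 19], frames_ori, [0, 10, 20])
pref_ori = int(np.argmax(counts.mean(axis=1)))   # orientation with the highest mean count
```

The spike at frame 19 falls in neither stimulus window, so it is not counted for any orientation.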

    

## Subthreshold

In [93]:
def subth_os(animal_folder, cell_id):
    # Load metadata
    with open(r'{0}\Traces\Cell_{1}_metadata.pkl'.format(animal_folder, cell_id), 'rb') as f:
        md = pickle.load(f) 
    LED_off = md['LED_off']; frame_rate = md['frame_rate']; time_vec = md['time_vec']
    ori_degrees = md['ori_degrees']; max_ori_pos = md['max_ori_pos']; freq_discard = md['freq_discard']
    num_frames = md['num_frames']; baseline_start_frame = md['baseline_start_frame']
    baseline_end_frame = md['baseline_end_frame']; response_start_frame = md['response_start_frame']
    response_end_frame = md['response_end_frame']

    # Load data
    with open(r'{0}\Traces_2\Cell_{1}_spikes.pkl'.format(animal_folder, cell_id), 'rb') as f:
        sd = pickle.load(f) # Data
    dF_F = sd['dF_F']
    
    # Low pass filtering by fft
    # data_fft = np.fft.fft(dF_F)
    # num_freq_discard = np.argmin(np.power((freq - freq_discard), 2))
    # data_fft[num_freq_discard:num_frames - num_freq_discard] = 0 
    # low_pass = np.real(np.fft.ifft(data_fft))

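    # Low pass filtering by median filter: a window of ~frame_rate/freq_discard frames removes
    # spikes (much shorter than the window) but keeps slower subthreshold changes.
    # medfilt needs an odd kernel size.
    kernel_size = int(frame_rate/freq_discard)
    kernel_size = kernel_size + 1 if np.mod(kernel_size, 2) == 0 else kernel_size
    low_pass = medfilt(dF_F, kernel_size)

    # Load vis stim data
    with open(r'{0}\Vis_stim_info\Cell_{1}_vis_stim.pkl'.format(animal_folder, cell_id), 'rb') as f:
        v = pickle.load(f)
    daq_rate = v['daq_rate']; frame_times = v['frame_times']; vis_stim = v['vis_stim']
    frames_ori = v['frames_ori'].astype(bool); frames_iti_ori = v['frames_iti_ori'].astype(bool); num_rep = v['num_rep']
    rep_frames = v['rep_frames']; trial_start_times = v['trial_start_times']; trial_end_times = v['trial_end_times']
    daq_time_vec = v['daq_time_vec']; first_ori = v['first_ori']
    
    # response: mean of the stimulus frames from response_start_frame to response_end_frame
    # baseline: mean of the preceding ITI frames from baseline_start_frame on, together with
    #           the first baseline_end_frame frames of the stimulus
    response = np.zeros([num_ori, num_rep])
    baseline = np.zeros([num_ori, num_rep])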
    for rep in range(num_rep):
        data_rep = low_pass[rep_frames[rep]:rep_frames[rep + 1] - 1]   
        for ori in range(num_ori):
            data_stim_ori = data_rep[frames_ori[ori][rep_frames[rep]:rep_frames[rep + 1] - 1]]
            data_iti_ori = data_rep[frames_iti_ori[ori][rep_frames[rep]:rep_frames[rep + 1] - 1]]
            response[ori][rep] = np.mean(data_stim_ori[response_start_frame:response_end_frame])
            baseline[ori][rep] = np.mean(np.concatenate((data_iti_ori[baseline_start_frame:],
                                                                 data_stim_ori[:baseline_end_frame])))   
    
    sd['response'] = response
    sd['baseline'] = baseline
    sd['subth_plot_all'] = response - baseline
    sd['subth_plot_mean'] = np.mean(response - baseline, 1)
    
    with open(r'{0}\Traces_2\Cell_{1}_spikes.pkl'.format(animal_folder, cell_id), 'wb') as f:
        pickle.dump(sd, f)
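A median filter is not a linear low-pass filter, but with a window of about `frame_rate/freq_discard` frames it removes events narrower than half the window (such as single-frame spikes) while keeping slower depolarisations, which is what `subth_os` relies on. `scipy.signal.medfilt` needs an odd kernel, hence the rounding up. A small sketch with assumed values of 1000 Hz and 50 Hz (not the recording parameters); `odd_kernel` is a hypothetical name:

```python
import numpy as np
from scipy.signal import medfilt

def odd_kernel(frame_rate, freq_discard):
    """Hypothetical helper: median-filter length of ~frame_rate/freq_discard frames, made odd."""
    k = int(frame_rate / freq_discard)
    return k + 1 if k % 2 == 0 else k

# Assumed example values
k = odd_kernel(frame_rate=1000.0, freq_discard=50.0)   # 20 frames, rounded up to 21

trace = np.zeros(200)
trace[100:150] = 5.0     # slow, 50-frame subthreshold step
trace[60] = -40.0        # single-frame spike
smoothed = medfilt(trace, k)   # spike removed, step kept
```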

    

# Plotting

## dF/F and annotated spikes

In [41]:
for cell_id in range(20, 26):
    print(cell_id)
    # Load metadata
    with open(r'{0}\Traces\Cell_{1}_metadata.pkl'.format(animal_folder, cell_id), 'rb') as f:
        md = pickle.load(f) 
    LED_off = md['LED_off']; image_folder = md['image_folder']

    with open(r'{0}\Vis_stim_info\Cell_{1}_vis_stim.pkl'.format(animal_folder, cell_id), 'rb') as f:
        v = pickle.load(f)
    daq_time_vec = v['daq_time_vec']; frame_times = v['frame_times']
    time_vec = daq_time_vec[frame_times[LED_off:]]  # time (s) of each LED-on frame
    
    # Load data
    with open(r'{0}\Traces_2\Cell_{1}_spikes.pkl'.format(animal_folder, cell_id), 'rb') as f:
        sd = pickle.load(f) # Data
    dF_F = sd['dF_F']; spikes = sd['spike_frames_manual'].astype(int)

    plt.figure(figsize = (30, 2))
    plt.plot(time_vec, dF_F, linewidth = 0.5)
    plt.scatter(time_vec[spikes], dF_F[spikes], color = 'r')
    plt.ylabel('dF/F in %')
    plt.xlabel('Time in s')
    plt.title('Cell {0}: annotated spikes'.format(cell_id))
    plt.savefig(r'{0}\Cell_{1}_spikes'.format(image_folder, cell_id))


20
21
22
23
24
25


## Orientation selectivity

In [96]:
for cell_id in range(20, 26):
    # Load metadata
    with open(r'{0}\Traces\Cell_{1}_metadata.pkl'.format(animal_folder, cell_id), 'rb') as f:
        md = pickle.load(f) 
    LED_off = md['LED_off']; frame_rate = md['frame_rate']; time_vec = md['time_vec']
    ori_degrees = md['ori_degrees']; max_ori_pos = md['max_ori_pos']; num_frames = md['num_frames']
    image_folder = md['image_folder']

    # Load data (keys written by spike_os, subth_os and the spike annotation)
    with open(r'{0}\Traces_2\Cell_{1}_spikes.pkl'.format(animal_folder, cell_id), 'rb') as f:
        sd = pickle.load(f) # Data
    num_spikes_ori = sd['num_spikes_ori']; num_spikes_iti = sd['num_spikes_iti_ori']; pref_ori = sd['pref_ori']
    mean_spikes_ori = sd['mean_spikes_ori']; mean_spikes_iti = sd['mean_spikes_iti']
    subth_plot_mean = sd['subth_plot_mean']; subth_plot_all = sd['subth_plot_all']
    num_spikes = sd['num_spikes_manual']

    # Load vis stim data
    with open(r'{0}\Vis_stim_info\Cell_{1}_vis_stim.pkl'.format(animal_folder, cell_id), 'rb') as f:
        v = pickle.load(f)
    daq_rate = v['daq_rate']; frame_times = v['frame_times']; vis_stim = v['vis_stim']
    frames_ori = v['frames_ori']; frames_iti_ori = v['frames_iti_ori']; num_rep = v['num_rep']
    rep_frames = v['rep_frames']; trial_start_times = v['trial_start_times']; trial_end_times = v['trial_end_times']
    daq_time_vec = v['daq_time_vec']; first_ori = v['first_ori']

    fig, axes = plt.subplots(nrows = 2, ncols = 1, figsize = (10, 16))

    # OS by spikes
    ori_order = np.roll(range(num_ori), max_ori_pos - pref_ori)
    axes[0].plot(ori_degrees, mean_spikes_ori[ori_order], color = 'k', linewidth = 2, label = 'Vis stim')
    axes[0].plot(ori_degrees, mean_spikes_iti[ori_order], color = 'k', linestyle = '--', linewidth = 1.5, label = 'ITI')
    rep_lines = axes[0].plot(ori_degrees, num_spikes_ori[ori_order], color = 'grey', alpha = 0.4, linewidth = 1.5)
    rep_lines[0].set_label('Individual repetitions')  # one legend entry for all repetitions
    axes[0].set_xticks(ori_degrees)
    axes[0].set_xlabel('Degrees away from preferred orientation', fontsize = 15)
    axes[0].set_ylabel('Number of spikes in trial', fontsize = 15)
    axes[0].set_title('Cell %d orientation tuning from spikes' %cell_id, fontsize = 15)
    axes[0].legend(loc = 'best')

    # OS by subthreshold
    axes[1].plot(ori_degrees, subth_plot_mean[ori_order], color = 'k', linewidth = 2)
    axes[1].plot(ori_degrees, subth_plot_all[ori_order], color = 'grey', alpha = 0.4, linewidth = 1.5)
    axes[1].set_xticks(ori_degrees)
    axes[1].set_xlabel('Degrees away from preferred orientation by spikes', fontsize = 15)
    axes[1].set_ylabel('Mean dF/F in response window - baseline dF/F', fontsize = 15)
    axes[1].set_title('Orientation tuning from subthreshold', fontsize = 15)

    plt.savefig(r'{0}\Cell_{1}_OS'.format(image_folder, cell_id))
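The x-axis alignment relies on `np.roll(range(num_ori), max_ori_pos - pref_ori)`: rolling by that shift always puts index `pref_ori` at position `max_ori_pos`, so the preferred orientation lands on the same tick for every cell. A toy check with 8 conditions, an assumed `ori_degrees` axis with 0 at index 3, and an invented tuning curve:

```python
import numpy as np

num_ori = 8                               # assumed: 8 stimulus conditions, 45 degrees apart
ori_degrees = np.arange(-135, 181, 45)    # assumed x-axis: offset from the preferred orientation
max_ori_pos = int(np.nonzero(ori_degrees == 0)[0][0])   # position where the preferred one sits

mean_spikes_ori = np.array([1., 2., 9., 4., 1., 0., 2., 1.])   # made-up tuning curve
pref_ori = int(np.argmax(mean_spikes_ori))

ori_order = np.roll(np.arange(num_ori), max_ori_pos - pref_ori)
centered = mean_spikes_ori[ori_order]     # peak now plotted at 0 degrees
```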

## Traces for all trials

In [45]:
for cell_id in range(20, 26):

    # Load metadata
    with open(r'{0}\Traces\Cell_{1}_metadata.pkl'.format(animal_folder, cell_id), 'rb') as f:
        md = pickle.load(f) 
    LED_off = md['LED_off']; frame_rate = md['frame_rate']; time_vec = md['time_vec']
    ori_degrees = md['ori_degrees']; max_ori_pos = md['max_ori_pos']; num_frames = md['num_frames']
    freq_discard = md['freq_discard']; vis_on = md['vis_on']; image_folder = md['image_folder']

    # Load data
    with open(r'{0}\Traces_2\Cell_{1}_spikes.pkl'.format(animal_folder, cell_id), 'rb') as f:
        sd = pickle.load(f) # Data
    dF_F = sd['dF_F']; pref_ori = sd['pref_ori']; spikes_rep = sd['spikes_rep']

    # Load vis stim data
    with open(r'{0}\Vis_stim_info\Cell_{1}_vis_stim.pkl'.format(animal_folder, cell_id), 'rb') as f:
        v = pickle.load(f)
    daq_rate = v['daq_rate']; frame_times = v['frame_times']; vis_stim = v['vis_stim']
    frames_ori = v['frames_ori']; frames_iti_ori = v['frames_iti_ori']; num_rep = v['num_rep']
    rep_frames = v['rep_frames']; trial_start_times = v['trial_start_times']; trial_end_times = v['trial_end_times']
    daq_time_vec = v['daq_time_vec']; first_ori = v['first_ori']; first_stim_num = v['first_stim_num']

    fig = plt.figure(figsize = (10, 16))
    axes = plt.gca()

    trial_frames = np.sum(frames_ori, 0)

    # Low pass filtering by fft
    # data_fft = np.fft.fft(dF_F)
    # num_freq_discard = np.argmin(np.power((freq - freq_discard), 2))
    # data_fft[num_freq_discard:num_frames - num_freq_discard] = 0 
    # low_pass = np.real(np.fft.ifft(data_fft))

    # Low pass filtering by median filter
    kernel_size = int(frame_rate/freq_discard)
    kernel_size = kernel_size + 1 if np.mod(kernel_size, 2) == 0 else kernel_size
    low_pass = medfilt(dF_F, kernel_size)

    data_rep = np.zeros([num_rep - 1, np.max(np.diff(rep_frames))])
    psth = np.zeros([num_rep - 1, np.max(np.diff(rep_frames))])
    
    # For PSTH, add gaussian for every spike
    sigma = 0.1*frame_rate/1000
    

    for rep in range(num_rep - 1):
        num_frames_rep = rep_frames[rep + 1] - 1 - rep_frames[rep]
        #data_rep[rep][:num_frames_rep] = dF_F[rep_frames[rep]:rep_frames[rep + 1] - 1]   
        data_rep[rep][:num_frames_rep] = low_pass[rep_frames[rep]:rep_frames[rep + 1] - 1]
        time_rep = time_vec[rep_frames[rep]:rep_frames[rep + 1] - 1]
        tr = len(time_rep)
        plt.plot(time_rep - time_rep[0], data_rep[rep][:num_frames_rep] + rep*20, color = 'k')

        plt.scatter(time_vec[spikes_rep['{0}'.format(rep)]] - time_rep[0], 
                    np.ones(spikes_rep['{0}'.format(rep)].shape) + rep*20 - 8,
                   color = 'r')
        for spike in spikes_rep['%d'%rep]:
            mu = time_vec[spike]
            psth[rep][:tr] = psth[rep][:tr] + np.exp(-np.power(time_rep - mu, 2)/2/sigma/sigma)
        
        trial_frames_rep = trial_frames[rep_frames[rep]:rep_frames[rep + 1] - 1]
        plt.fill_between(time_rep - time_rep[0], (rep*20 - 10)*np.ones(time_rep.shape), 
                         rep*20 - 10 + 20*trial_frames_rep, color = 'grey', alpha = 0.3)

    # Average over repetitions (enlarge the trace so you can see)
    mean_trace = np.mean(data_rep, 0)[:num_frames_rep]
    lower = np.min(mean_trace); upper = np.max(mean_trace); middle = (lower + upper)/2
    mean_trace = mean_trace - middle
    mean_trace = mean_trace*10/(upper - middle)
    plt.plot(time_rep - time_rep[0], mean_trace - 20, color = 'k')
    plt.fill_between(time_rep - time_rep[0], -30*np.ones(time_rep.shape), 
                         -30 + 20*trial_frames_rep, color = 'grey', alpha = 0.3)

    # Plot scalebar for average trace
    num_points_scalebar = int(vis_on*frame_rate*0.1)
    scalebar = np.ones(num_points_scalebar)*20
    plt.fill_between(np.arange(-2*num_points_scalebar, -num_points_scalebar)/frame_rate, scalebar - 48, scalebar - 32, color = 'k')
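    # The trace was rescaled so that its full range (upper - lower) spans 20 plot units,
    # so the 16-unit-tall scalebar corresponds to (upper - lower)*16/20
    plt.text(-10*num_points_scalebar/frame_rate, -25, '  {0:.1f}% \n dF/F'.format((upper - lower)*16/20))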

    # Plot PSTH (enlarge the trace so you can see)
    psth = np.sum(psth, 0)[:num_frames_rep]
    lower = np.min(psth); upper = np.max(psth); middle = (lower + upper)/2
    psth = psth - middle
    psth = psth*10/(upper - middle)
    plt.plot(time_rep - time_rep[0], psth - 40, color = 'r')
    plt.fill_between(time_rep - time_rep[0], -50*np.ones(time_rep.shape), 
                         -50 + 20*trial_frames_rep, color = 'grey', alpha = 0.3)
    
    # Plot scalebar for PSTH
    num_points_scalebar = int(vis_on*frame_rate*0.1)
    scalebar = np.ones(num_points_scalebar)*20
    plt.fill_between(np.arange(-2*num_points_scalebar, -num_points_scalebar)/frame_rate, scalebar - 68, scalebar - 52, color = 'k')
    # Same rescaling as for the average trace: the 16-unit scalebar spans (upper - lower)*16/20
    plt.text(-10*num_points_scalebar/frame_rate, -45, '{0:.1f} spi- \n kes'.format((upper - lower)*16/20))

    pref_ori_pos = pref_ori - first_ori if pref_ori >= first_ori else num_ori + 1 + pref_ori - first_ori
    x_ticks = daq_time_vec[trial_start_times[first_stim_num:first_stim_num + num_ori]]
    axes.set_xticks(x_ticks - x_ticks[0] + vis_on/2)
    ori_label_shift = pref_ori_pos - max_ori_pos - 1 if pref_ori_pos >= max_ori_pos else num_ori + pref_ori_pos - max_ori_pos 
    axes.set_xticklabels(np.roll(ori_degrees, ori_label_shift))
    axes.set_xlabel('Degrees away from preferred orientation by spikes', fontsize = 15)
    axes.set_yticks([-20, -40])
    axes.set_yticklabels(['Avg', 'PSTH'], fontsize = 12)
    axes.set_ylabel('dF/F in %, {0} reps'.format(num_rep - 1), fontsize = 15)
    axes.set_title('Cell {0}: low-pass (median) filtered data (<{1:g} Hz)'.format(cell_id, freq_discard), fontsize = 15)

    plt.savefig(r'{0}\Cell_{1}_all_trials'.format(image_folder, cell_id))
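Two pieces of arithmetic from this cell, sketched in isolation. The PSTH is a sum of unit-height Gaussians, one centred on each spike time. Both the mean trace and the PSTH are rescaled so that their full range (`upper - lower`) spans 20 plot units, so a bar 16 plot units tall, as drawn for the scalebars, corresponds to `(upper - lower)*16/20` in data units. The function names and the toy spike times below are hypothetical:

```python
import numpy as np

def gaussian_psth(time_vec, spike_times, sigma):
    """Hypothetical helper: sum of unit-height Gaussians centred on each spike time.

    time_vec, spike_times and sigma must share the same time unit.
    """
    t = np.asarray(time_vec, dtype=float)[:, None]
    s = np.asarray(spike_times, dtype=float)[None, :]
    return np.exp(-((t - s) ** 2) / (2 * sigma ** 2)).sum(axis=1)

def scalebar_value(lower, upper, bar_height):
    """Hypothetical helper: data value spanned by a bar of bar_height plot units when the
    trace's full range (upper - lower) has been rescaled to 20 plot units."""
    return bar_height * (upper - lower) / 20

dt = 0.001
time_vec = np.arange(0, 1, dt)
psth = gaussian_psth(time_vec, [0.2, 0.7], sigma=0.01)   # two well-separated toy spikes
```

Each Gaussian integrates to `sigma*sqrt(2*pi)`, so the area under this PSTH is proportional to the number of spikes.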


# Ondrej's data

In [178]:
# After this, start from Remove LED off time
cell_id = 21
file = r'{0}\Traces\Cell{1}.csv'.format(animal_folder, cell_id)
data = np.squeeze(np.array(pd.read_csv(file)))
# md is the metadata dict left in memory by a previously run cell; it is reused as a
# template, and only num_frames and area are updated for this cell
md['num_frames'] = len(data)
md['area'] = 144
with open(r'{0}\Traces\Cell_{1}_metadata.pkl'.format(animal_folder, cell_id), 'wb') as f:
    pickle.dump(md, f)
        
with open(r'{0}\Traces\Cell_{1}.pkl'.format(animal_folder, 
                                        cell_id), 'wb') as f:
    pickle.dump({'data':data}, f)
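`pd.read_csv` treats the first row as column names by default. Whether `Cell{N}.csv` has a header line is not shown here; if it does not, the first sample is silently dropped and `num_frames` ends up one frame short. A quick check on an in-memory stand-in file (the values are made up):

```python
import io

import numpy as np
import pandas as pd

csv_text = "101.5\n102.0\n99.8\n"   # stand-in for a single-column trace without a header

# Default: the first value becomes the column name and is lost from the data
with_header = np.squeeze(np.array(pd.read_csv(io.StringIO(csv_text))))
# header=None keeps every row as data; the single column is labelled 0
no_header = pd.read_csv(io.StringIO(csv_text), header=None)[0].to_numpy()
```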

# Playground

In [116]:
cell_id = 20
with open(r'{0}\Traces\Cell_{1}_metadata.pkl'.format(animal_folder, cell_id), 'rb') as f:
    md = pickle.load(f) 
LED_off = md['LED_off']; num_frames = md['num_frames']; vis_on = md['vis_on']

# Load vis stim data
with open(r'{0}\Vis_stim_info\Cell_{1}_vis_stim.pkl'.format(animal_folder, cell_id), 'rb') as f:
    v = pickle.load(f)
daq_rate = v['daq_rate']; frame_times = v['frame_times']; vis_stim = v['vis_stim']
daq_time_vec = v['daq_time_vec']

# Smooth the stimulus signal with a ~20 ms window (in DAQ samples, kept odd)
window = int(0.02*daq_rate) 
if window%2 == 0:
    window += 1
poly = 1 # Degree of polynomial to fit
F0_vis_stim = savgol_filter(vis_stim, window, poly)  

vb = vis_stim_baseline(F0_vis_stim)
plt.pause(10)

baseline_vis = F0_vis_stim[vb.get_baseline()]
trial_peak_vis = F0_vis_stim[vb.get_peak()] - baseline_vis

# Find points in the vicinity of baseline
above_baseline = np.array(F0_vis_stim > baseline_vis - trial_peak_vis*0.2)
below_baseline = np.array(F0_vis_stim < baseline_vis + trial_peak_vis*0.2)
equal_baseline = np.logical_and(above_baseline, below_baseline)

# Find trial start and end times
dif = np.diff(equal_baseline.astype(int))
falls = np.nonzero(dif == -1)[0]
# Trial starts: falls preceded by at least half a stimulus duration (0.5*vis_on) at baseline
trial_start_times = np.array([fall for fall in falls if np.sum(equal_baseline[fall - int(0.5*vis_on*daq_rate):fall]) >= 0.5*vis_on*daq_rate])
rises = np.nonzero(dif == 1)[0]
# Trial ends: rises followed by at least half a stimulus duration at baseline
trial_end_times = np.array([rise for rise in rises if np.sum(equal_baseline[rise:int(0.5*vis_on*daq_rate) + rise]) >= 0.5*vis_on*daq_rate - 1])

# Find the first trial and last trial (overlapping with camera frames)
offset = 1 if trial_end_times[0] < trial_start_times[0] else 0
first_iti_num = (trial_end_times > frame_times[LED_off]).nonzero()[0][0]
first_iti_time = trial_end_times[first_iti_num]
first_stim_num = (trial_start_times > first_iti_time).nonzero()[0][0]
first_stim_time = trial_start_times[first_stim_num]
try:
    last_stim_num = (trial_end_times > frame_times[-1]).nonzero()[0][0] - offset
except IndexError:
    print('Camera frames last longer than visual stimulus!')
    last_stim_num = len(trial_end_times) - offset

# First orientation seen during imaging, taking first orientation presented as 0:
first_ori = np.mod(first_stim_num, num_ori) 

# Camera frames for each repetition, sorted into orientations
frames_ori = np.zeros([num_ori, num_frames - LED_off])
frames_iti_ori = np.zeros([num_ori, num_frames - LED_off])
ori = first_ori

for trial in range(first_stim_num, last_stim_num):
    frames_ori[ori][(frame_times[LED_off:] > trial_start_times[trial]).nonzero()[0]] = 1
    frames_ori[ori][(frame_times[LED_off:] > trial_end_times[trial + offset]).nonzero()[0]] = 0
    frames_iti_ori[ori][(frame_times[LED_off:] > trial_end_times[trial + offset - 1]).nonzero()[0]] = 1
    frames_iti_ori[ori][(frame_times[LED_off:] > trial_start_times[trial]).nonzero()[0]] = 0
    ori = np.mod(ori + 1, num_ori)

# Find the number of times all orientations are repeated
num_rep = int((last_stim_num - first_stim_num)/num_ori)

# Find the first camera frame for each repetition
first_trial_ori = first_stim_num + np.roll(range(num_ori), first_ori)
trial_start_ori = [trial_start_times[first_trial_ori[ori]::num_ori] for ori in range(num_ori)]
rep_frames_daq = trial_start_ori[first_ori][:num_rep]
rep_frames = [(frame_times[LED_off:] > rep_frames_daq[rep]).nonzero()[0][0] for rep in range(num_rep)]

# Add one more boundary to get the frames for the last repetition 
if frame_times[-1] > trial_start_times[last_stim_num]:
    rep_frames = np.append(rep_frames, (frame_times[LED_off:] >= trial_start_times[last_stim_num]).nonzero()[0][0])
else:
    # Imaging ended before the next stimulus: the last repetition runs to the last LED-on frame
    rep_frames = np.append(rep_frames, len(frame_times) - LED_off)


Hit b/p and then click on a point equal to the baseline/peak in trial, respectively
Ready
Camera frames last longer than visual stimulus!
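The trial-boundary detection in the playground cell can be condensed into one function: samples within 20% of the peak-minus-baseline amplitude count as baseline, a fall of that indicator marks a trial start and a rise marks a trial end. A sketch on a clean synthetic square wave; `find_trials`, `tol` and `min_len` are hypothetical names, only onsets are filtered by a minimum baseline duration (for brevity), and the function returns NumPy arrays so that comparisons such as `trial_end_times > frame_times[LED_off]` are elementwise:

```python
import numpy as np

def find_trials(signal, baseline, peak, tol=0.2, min_len=1):
    """Hypothetical helper: (starts, ends) sample indices of stimulus epochs.

    A sample is 'at baseline' if it lies within tol*|peak - baseline| of baseline.
    starts: last baseline sample before the signal leaves baseline, kept only if
            the preceding min_len samples are all at baseline
    ends:   last stimulus sample before the signal returns to baseline
    """
    signal = np.asarray(signal, dtype=float)
    at_base = np.abs(signal - baseline) < tol * abs(peak - baseline)
    d = np.diff(at_base.astype(int))
    falls = np.nonzero(d == -1)[0]
    rises = np.nonzero(d == 1)[0]
    starts = np.array([f for f in falls
                       if f >= min_len and at_base[f - min_len:f].all()], dtype=int)
    return starts, rises

# Two 100-sample "stimuli" on a flat baseline
sig = np.zeros(500)
sig[100:200] = 1.0
sig[300:400] = 1.0
starts, ends = find_trials(sig, baseline=0.0, peak=1.0, min_len=50)
```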
