In [312]:
%matplotlib qt5
import sys
sys.path.insert(0, 'C:\\Users\\francescag\\Documents\\SourceTree_repos\\Python_git\\freely_moving_photometry_analysis')
from scipy.interpolate import interp1d
from utils.plotting import calculate_error_bars
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from matplotlib import cm
import os
import peakutils
import matplotlib
from matplotlib.lines import Line2D
from utils.plotting_visuals import makes_plots_pretty
import pickle 
from utils.post_processing_utils import remove_exps_after_manipulations, remove_bad_recordings
from utils.regression.linear_regression_utils import get_first_x_sessions
import pandas as pd
from utils.reaction_time_utils import get_bpod_trial_nums_per_session
from utils.plotting import calculate_error_bars

In [301]:
def get_all_mice_data(experiments_to_process, window_centre=80000, half_window=20000):
    """Load and pool contra correct / incorrect choice-aligned traces across sessions.

    For every row of ``experiments_to_process`` the file
    ``<mouse>_<date>_aligned_traces_correct_incorrect.p`` is unpickled from the
    mouse's ``for_figure`` folder. A window of columns around ``window_centre`` is cut
    out of the contra correct and contra incorrect ``sorted_traces``. Trial numbers are
    shifted by the session's first trial number so they are cumulative across sessions.

    Args:
        experiments_to_process: DataFrame with at least 'mouse_id' and 'date' columns.
            Dates must be strings because they are concatenated into file names.
        window_centre: column index the window is centred on (default 80000 == 160000/2).
        half_window: number of columns kept on each side of ``window_centre``.

    Returns:
        (correct, incorrect, correct_trial_nums, incorrect_trial_nums,
         reaction_times, time_stamps):
            correct / incorrect: windowed traces stacked over sessions (row-wise).
            correct_trial_nums / incorrect_trial_nums: cumulative trial numbers.
            reaction_times: reaction times of the contra *incorrect* trials only.
            time_stamps: time points of the window, taken from the first session.

    Raises:
        ValueError: if ``experiments_to_process`` has no rows.
    """
    if len(experiments_to_process) == 0:
        raise ValueError('experiments_to_process is empty; nothing to load')

    window = slice(int(window_centre - half_window), int(window_centre + half_window))
    # get_bpod_trial_nums_per_session depends only on (mouse, all of its dates),
    # so compute it once per mouse instead of once per session.
    session_starts_by_mouse = {}
    correct_traces, incorrect_traces = [], []
    correct_trial_nums, incorrect_trial_nums = [], []
    reaction_times = []
    time_stamps = None

    for _, experiment in experiments_to_process.iterrows():
        mouse = experiment['mouse_id']
        date = experiment['date']
        dates = experiments_to_process[experiments_to_process['mouse_id'] == mouse]['date'].values
        if mouse not in session_starts_by_mouse:
            session_starts_by_mouse[mouse] = get_bpod_trial_nums_per_session(mouse, dates)
        session_ind = np.where(dates == date)[0][0]
        session_start_trial = session_starts_by_mouse[mouse][session_ind]

        saving_folder = 'W:\\photometry_2AC\\processed_data\\for_figure\\' + mouse + '\\'
        save_filename = mouse + '_' + date + '_' + 'aligned_traces_correct_incorrect.p'
        # NOTE: pickle.load can execute arbitrary code; only load trusted processed files.
        with open(saving_folder + save_filename, "rb") as f:
            content = pickle.load(f)
        print(mouse, date)

        correct_data = content.choice_data.contra_correct_data
        incorrect_data = content.choice_data.contra_incorrect_data
        if time_stamps is None:
            # The window's time axis is taken from the first session only.
            time_stamps = correct_data.time_points[window]
        correct_traces.append(correct_data.sorted_traces[:, window])
        incorrect_traces.append(incorrect_data.sorted_traces[:, window])
        correct_trial_nums.append(correct_data.trial_nums + session_start_trial)
        incorrect_trial_nums.append(incorrect_data.trial_nums + session_start_trial)
        # NOTE(review): only incorrect-trial reaction times are pooled (as before) - confirm intent.
        reaction_times.append(incorrect_data.reaction_times)

    # Stack once at the end instead of re-allocating on every iteration.
    return (np.vstack(correct_traces), np.vstack(incorrect_traces),
            np.concatenate(correct_trial_nums), np.concatenate(incorrect_trial_nums),
            np.concatenate(reaction_times), time_stamps)



In [299]:

mouse_ids = ['SNL_photo21']
site = 'tail'

# Load the experiment log. Dates are cast to str because they are later
# concatenated into processed-data file names.
experiment_record = pd.read_csv('W:\\photometry_2AC\\experimental_record.csv')
experiment_record['date'] = experiment_record['date'].astype(str)
clean_experiments = remove_exps_after_manipulations(experiment_record, mouse_ids)

# Keep only the requested mice recorded at `site`, then drop recordings flagged as bad.
selected_rows = clean_experiments['mouse_id'].isin(mouse_ids) & (clean_experiments['recording_site'] == site)
all_experiments_to_process = clean_experiments.loc[selected_rows].reset_index(drop=True)
experiments_to_process = remove_bad_recordings(all_experiments_to_process).reset_index(drop=True)


removing SNL_photo21: ['20200829' '20200830' '20200831' '20200908' '20200910' '20200911'
 '20200915' '20200917' '20200918' '20200921' '20201007' '20201008'
 '20201009']
Int64Index([0, 11], dtype='int64')


In [302]:
# Pool contra correct/incorrect traces, cumulative trial numbers, incorrect-trial reaction
# times and the window's time stamps across every selected session.
correct, incorrect, correct_trial_nums, incorrect_trial_nums, reaction_times, time_stamps = get_all_mice_data(experiments_to_process)

SNL_photo21 20200806
SNL_photo21 20200808
SNL_photo21 20200810
SNL_photo21 20200812
SNL_photo21 20200814
SNL_photo21 20200816
SNL_photo21 20200818
SNL_photo21 20200820
SNL_photo21 20200822
SNL_photo21 20200824
SNL_photo21 20200827
SNL_photo21 20200828


In [303]:
# Median of the pooled reaction times (these come from contra incorrect trials only).
# get_peak uses it to bound how far after the event it searches for a peak.
median_reaction_time = np.median(reaction_times)

In [304]:
# Split the incorrect trials into thirds of the cumulative trial count.
# Boundaries are half-open ([0, third), [third, 2*third), [2*third, max]) so a trial
# landing exactly on a boundary is assigned to a group instead of being dropped.
max_trials = np.max(np.concatenate((incorrect_trial_nums, correct_trial_nums)))
third_of_trials = int(max_trials / 3)
early_incorrect_trials = incorrect_trial_nums[incorrect_trial_nums < third_of_trials]
mid_incorrect_trials = incorrect_trial_nums[
    (incorrect_trial_nums >= third_of_trials) & (incorrect_trial_nums < 2 * third_of_trials)]
# max_trials is the maximum over all trial numbers, so no upper bound is needed here.
late_incorrect_trials = incorrect_trial_nums[incorrect_trial_nums >= 2 * third_of_trials]

In [305]:
# Row positions in `incorrect` / `incorrect_trial_nums` of the trials in each third.
# np.isin(a, b) returns a mask over `a`, so the full array must come first; with the
# arguments swapped the mask is over the subset and the indices are just 0..k-1.
early_incorrect_inds = np.flatnonzero(np.isin(incorrect_trial_nums, early_incorrect_trials))
mid_incorrect_inds = np.flatnonzero(np.isin(incorrect_trial_nums, mid_incorrect_trials))
late_incorrect_inds = np.flatnonzero(np.isin(incorrect_trial_nums, late_incorrect_trials))

In [306]:
def find_nearest_trials(target_trials, other_trials):
    """For each trial number in ``target_trials``, find the closest one in ``other_trials``.

    Args:
        target_trials: 1-D array-like of trial numbers to match (e.g. incorrect trials).
        other_trials: 1-D array-like of candidate trial numbers (e.g. correct trials).

    Returns:
        Integer array of length ``len(target_trials)``. Element j is the index into
        ``other_trials`` of the entry with the smallest absolute difference to
        ``target_trials[j]``. Ties resolve to the lowest index, and an index may repeat.

    Raises:
        ValueError: if ``other_trials`` is empty while ``target_trials`` is not.
    """
    target_trials = np.asarray(target_trials).reshape(-1)
    other_trials = np.asarray(other_trials).reshape(-1)
    if target_trials.size == 0:
        return np.empty(0, dtype=np.intp)
    if other_trials.size == 0:
        raise ValueError('other_trials is empty; cannot find nearest trials')
    # Rows index other_trials, columns index target_trials.
    differences = target_trials.reshape(1, -1) - other_trials.reshape(-1, 1)
    return np.abs(differences).argmin(axis=0)

In [307]:
# For each incorrect trial in every third, the row index into `correct` of the correct
# trial with the nearest trial number.
early_correct_inds, mid_correct_inds, late_correct_inds = (
    find_nearest_trials(group_trials, correct_trial_nums)
    for group_trials in (early_incorrect_trials, mid_incorrect_trials, late_incorrect_trials)
)

In [308]:
def get_traces_and_mean(all_traces, inds):
    """Select rows ``inds`` of ``all_traces`` and return them with their column-wise mean.

    Args:
        all_traces: 2-D array; one row per trial.
        inds: row indices (or boolean row mask) to select.

    Returns:
        (selected rows, mean of the selected rows taken over axis 0)
    """
    selected_traces = all_traces[inds, :]
    return selected_traces, np.mean(selected_traces, axis=0)

In [309]:
def plot_trace(all_traces, inds, time_stamps, ax, color='green'):
    """Draw the mean of the selected trials, with an SEM band, onto ``ax``.

    Args:
        all_traces: 2-D array of traces; rows are selected with ``inds``.
        inds: row indices of the trials to average.
        time_stamps: x values; must match the number of columns of ``all_traces``.
        ax: matplotlib Axes to draw on.
        color: colour used for both the mean line and the shaded band.
    """
    selected, average = get_traces_and_mean(all_traces, inds)
    ax.plot(time_stamps, average, color=color)

    lower_band, upper_band = calculate_error_bars(average, selected, error_bar_method='sem')
    ax.fill_between(time_stamps, lower_band, upper_band,
                    alpha=0.5, facecolor=color, linewidth=0)

    ax.set_xlabel('time(s)')
    ax.set_ylabel('z-scored fluorescence')


In [324]:
def get_peak(all_traces, inds, median_reaction_time, samples_per_second=1000):
    """Find the first peak of the mean trace between the aligned event and the median RT.

    The search window starts at the middle column of the trace and spans
    ``samples_per_second * median_reaction_time`` columns. The first peak reported by
    ``peakutils.indexes`` is used; if none is found, the window maximum is used instead.

    Args:
        all_traces: 2-D array of traces; rows are selected with ``inds``.
        inds: row indices of the trials to average.
        median_reaction_time: reaction time used to size the search window.
        samples_per_second: columns per unit of ``median_reaction_time``. The default
            of 1000 is unchanged from before; TODO confirm it against the spacing
            of ``time_stamps``.

    Returns:
        (peak_value, peak_index), where ``peak_index`` is a column index into the full
        trace (so it can index ``time_stamps`` directly).

    Raises:
        ValueError: if the search window is empty.
    """
    traces, mean_trace = get_traces_and_mean(all_traces, inds)
    half_way = int(traces.shape[1] / 2)
    window_length = int(samples_per_second * median_reaction_time)
    trace_from_event = mean_trace[half_way:half_way + window_length]
    if trace_from_event.size == 0:
        raise ValueError('peak search window is empty; check median_reaction_time')

    peak_inds = peakutils.indexes(trace_from_event)
    if len(peak_inds) > 0:
        peak_ind = peak_inds[0]
    else:
        # No local peak detected: fall back to the global maximum of the window.
        peak_ind = np.argmax(trace_from_event)
    peak_value = trace_from_event[peak_ind]
    return peak_value, peak_ind + half_way

In [325]:
# Peak of the mean correct trace for the middle third. The indices must point into
# `correct`, so use the nearest-correct-trial indices, not the incorrect ones.
trial_peaks, trial_peak_inds = get_peak(correct, mid_correct_inds, median_reaction_time)

In [327]:
# Mean correct (green) vs incorrect (red) contra choice traces for each third of
# training, with the first peak of every mean trace marked. Peaks are computed here,
# per panel, so the figure does not depend on peak variables left over from other cells.
fig, axs = plt.subplots(1, 3, sharey=True)
panels = (
    ('Early', early_correct_inds, early_incorrect_inds),
    ('Middle', mid_correct_inds, mid_incorrect_inds),
    ('Late', late_correct_inds, late_incorrect_inds),
)
for panel_ax, (title, panel_correct_inds, panel_incorrect_inds) in zip(axs, panels):
    panel_ax.set_title(title)
    for panel_traces, panel_inds, color in ((correct, panel_correct_inds, 'green'),
                                            (incorrect, panel_incorrect_inds, 'red')):
        plot_trace(panel_traces, panel_inds, time_stamps, panel_ax, color=color)
        peak_value, peak_ind = get_peak(panel_traces, panel_inds, median_reaction_time)
        panel_ax.scatter(time_stamps[peak_ind], peak_value, color=color)
plt.tight_layout()
print(early_incorrect_inds.shape[0], mid_incorrect_inds.shape[0], late_incorrect_inds.shape[0])

146 52 14


In [321]:
# Time of the most recently computed peak; the value depends on which cell last set `trial_peak_inds`.
time_stamps[trial_peak_inds]

-1.8765617285108034

In [297]:
# Number of pooled incorrect trials.
np.shape(incorrect_trial_nums)

(212,)

In [191]:
# Inspect the cumulative trial numbers assigned to the last third.
late_incorrect_trials

array([ 7009.,  7127.,  7197.,  7391.,  7483.,  8160.,  8284.,  8345.,
        8350.,  9288.,  9657.,  9829., 10064., 10084.])

In [192]:
# Highest cumulative trial number over correct and incorrect trials (used to define the thirds).
max_trials

10107.0

In [None]:
# Save the pooled group traces for the figure scripts. `ipsi_choice` was never defined
# in this notebook; the pooled correct traces are what belong under the `correct` key.
# Renamed from `dir` to avoid shadowing the builtin.
output_dir = 'W:\\photometry_2AC\\processed_data\\for_figure\\'
file_name = 'correct_incorrect_group_data_' + site + '.npz'
np.savez(output_dir + file_name, correct=correct, incorrect=incorrect, time_stamps=time_stamps)



In [302]:
# NOTE(review): `content` is not defined by any cell above (get_all_mice_data keeps it
# local); it must be a single session's unpickled object loaded by hand. These
# assignments also overwrite the pooled `correct` / `incorrect` arrays.
contra_choice_data = content.choice_data
correct = contra_choice_data.contra_correct_data.mean_trace
incorrect = contra_choice_data.contra_incorrect_data.mean_trace

In [303]:
# Same +/-20000-sample window around the centre of the 160000-sample trace.
single_session_window = slice(int(160000/2-20000), int(160000/2+20000))
plt.plot(correct[single_session_window], color='green')
plt.plot(incorrect[single_session_window], color='red')

[<matplotlib.lines.Line2D at 0x101ca02b0>]

In [304]:
# TODO: match each incorrect trial to its nearest correct trial (by trial number) before averaging.

In [305]:
# Trial numbers of the contra incorrect and contra correct trials of the loaded session.
incorrect_trials, correct_trials = (
    content.choice_data.contra_incorrect_data.trial_nums,
    content.choice_data.contra_correct_data.trial_nums,
)

In [306]:
# Number of incorrect trials in the loaded session.
np.shape(incorrect_trials)

(94,)

In [307]:
# Pairwise gaps: row i is a correct trial, column j an incorrect trial,
# entry = incorrect_trials[j] - correct_trials[i].
differences = np.ravel(incorrect_trials)[np.newaxis, :] - np.ravel(correct_trials)[:, np.newaxis]

In [308]:
# For every incorrect trial (column), the row of the closest correct trial.
indices = np.argmin(np.abs(differences), axis=0)
# Signed gap to that nearest correct trial: residual[j] == differences[indices[j], j].
residual = differences[indices, np.arange(indices.size)]
indices.shape

(94,)

In [309]:
# Mean incorrect trace (red) vs. mean of the nearest-matched correct trials (green).
# A correct trial can appear several times in `indices` if it is nearest to several incorrect trials.
matched_window = slice(int(160000/2-20000), int(160000/2+20000))
incorrect_window = content.choice_data.contra_incorrect_data.sorted_traces[:, matched_window]
matched_correct_window = content.choice_data.contra_correct_data.sorted_traces[indices, matched_window]
a = plt.plot(np.mean(incorrect_window.T, axis=1), color='red')
b = plt.plot(np.mean(matched_correct_window.T, axis=1), color='green')

(40,)