In [1]:
%matplotlib qt5
import pickle
from utils.plotting import mouse_data_for_group_plot
from utils.area_under_curve_utils import RawTracesZScored
import matplotlib.pyplot as plt
from matplotlib import colors, cm
import numpy as np
import peakutils
import datetime 
import pandas as pd
from sklearn.linear_model import LinearRegression

In [2]:
def MakeDatesPretty(inputDates):
    # assumes input style YYYYMMDD_HHMMSS
    outputDates = []
    for date in inputDates: 
            x = datetime.datetime(int(date[0:4]), int(date[4:6]), int(date[6:8]))
            outputDates.append(x.strftime("%b%d")) 
    return(outputDates)

def percentage_correct_correlation(mouse, dates, peaks):
    BpodProtocol = '/Two_Alternative_Choice/'
    GeneralDirectory = 'W:/photometry_2AC/bpod_data/'
    DFfile = GeneralDirectory + mouse + BpodProtocol + 'Data_Analysis/' + mouse + '_dataframe.pkl'
    behavioural_stats = pd.read_pickle(DFfile)
    reformatted_dates = MakeDatesPretty(dates)
    percentage_correct = []
    for date_num, date in enumerate(reformatted_dates):
        points_for_day = behavioural_stats[behavioural_stats['SessionTime'].str.contains(date)]
        percentage_correct.append(100 * np.sum(points_for_day['FirstPokeCorrect'])/len(points_for_day))
        
    num_types = len(dates)
    colours = cm.viridis(np.linspace(0, 0.8, num_types))
    fig, axs = plt.subplots(1, ncols=1, figsize=(10, 8))
    fig.subplots_adjust(hspace=0.5, wspace=0.2)
    fig.suptitle('Exit centre poke', fontsize=16)
    fig.text(0.06, 0.02, mouse, fontsize=12)
    axs.title.set_text('Contralateral choice peak activity')
    axs.scatter(percentage_correct, peaks, color=colours)
    axs.set_xlabel('Percentage correct')
    axs.set_ylabel('Peak size (z-score)')
    return(peaks, percentage_correct)

def num_rewards_correlation(mouse, dates, peaks):
    BpodProtocol = '/Two_Alternative_Choice/'
    GeneralDirectory = 'W:/photometry_2AC/bpod_data/'
    DFfile = GeneralDirectory + mouse + BpodProtocol + 'Data_Analysis/' + mouse + '_dataframe.pkl'
    behavioural_stats = pd.read_pickle(DFfile)
    reformatted_dates = MakeDatesPretty(dates)
    num_rewards = []
    for date_num, date in enumerate(reformatted_dates):
        points_for_day = behavioural_stats[behavioural_stats['SessionTime'].str.contains(date)]
        num_rewards.append(np.sum(points_for_day['FirstPokeCorrect']))
    cum_num_rewards = np.cumsum(num_rewards)
    num_types = len(dates)
    colours = cm.viridis(np.linspace(0, 0.8, num_types))
    fig, axs = plt.subplots(1, ncols=1, figsize=(10, 8))
    fig.subplots_adjust(hspace=0.5, wspace=0.2)
    fig.suptitle('Exit centre poke', fontsize=16)
    fig.text(0.06, 0.02, mouse, fontsize=12)
    axs.title.set_text('Contralateral choice peak activity')
    axs.scatter(cum_num_rewards, peaks, color=colours)
    axs.set_xlabel('Number or rewards ever')
    axs.set_ylabel('Peak size (z-score)')
    return(peaks, cum_num_rewards)
        

In [4]:
def multi_day_peaks(mouse, dates, ipsi_or_contra):
    reformatted_dates = []
    for date in dates:
        year = int(date[0:4])
        month = int(date[4:6])
        day = int(date[6:])
        reformatted_dates.append(datetime.date(year, month, day))

    reform_dates = np.array(reformatted_dates, dtype='datetime64')    
    days_since_last_recording = np.concatenate([np.array([0], dtype='timedelta64[D]'),np.diff(reform_dates)]).astype(int)
    day_of_training = np.cumsum(days_since_last_recording)
    
    num_types = len(dates)
    colours = cm.viridis(np.linspace(0, 0.8, num_types))
    saving_folder = 'W:\\photometry_2AC\\processed_data\\' + mouse + '\\'
    fig, axs = plt.subplots(1, ncols=1, figsize=(10, 8))
    fig.subplots_adjust(hspace=0.5, wspace=0.2)
    fig.suptitle('Exit centre poke', fontsize=16)
    fig.text(0.06, 0.02, mouse, fontsize=12)
    all_peaks = []
    for date_num, date in enumerate(dates):
        mean_and_sem_filename = saving_folder + mouse + '_' + date + '_' + 'auc_correct_data.p'
        auc_data = pickle.load( open(mean_and_sem_filename, "rb" ))
        if ipsi_or_contra == 'contra':
            mean_auc = auc_data.contra_trials_auc
        elif ipsi_or_contra == 'ipsi':
            mean_auc = auc_data.ipsi_trials_auc
        all_peaks.append(mean_auc)
        axs.title.set_text('Contralateral choice peak activity')
        axs.scatter(day_of_training[date_num],mean_auc, color=colours[date_num])
        axs.set_xlabel('Days since start of training')
        axs.set_ylabel('Peak size (z-score)')
        axs.legend(dates,frameon=False)
    return(all_peaks)     

In [5]:
class mouseDates(object):
    def __init__(self, mouse_id, dates):
        self.mouse = mouse_id
        self.dates = dates
        
def peaks_correlations_multi_mice(mice_dates, ipsi_or_contra='contra'):
    fig, axs = plt.subplots(1, ncols=1, figsize=(10, 8))
    fig.subplots_adjust(hspace=0.5, wspace=0.2)
    fig.suptitle('Exit centre poke', fontsize=16)
    all_mice_aucs = []
    all_mice_dates = []
    for mouse_dates in mice_dates:
        mouse = mouse_dates.mouse
        dates = mouse_dates.dates
        reformatted_dates = []
        for date in dates:
            year = int(date[0:4])
            month = int(date[4:6])
            day = int(date[6:])
            reformatted_dates.append(datetime.date(year, month, day))

        reform_dates = np.array(reformatted_dates, dtype='datetime64')    
        days_since_last_recording = np.concatenate([np.array([0], dtype='timedelta64[D]'),np.diff(reform_dates)]).astype(int)
        day_of_training = np.cumsum(days_since_last_recording)

        saving_folder = 'W:\\photometry_2AC\\processed_data\\' + mouse + '\\'
        days_of_recording = []
        all_aucs = []
        for date_num, date in enumerate(dates):
            mean_and_sem_filename = saving_folder + mouse + '_' + date + '_' + 'auc_correct_data.p'
            auc_data = pickle.load( open(mean_and_sem_filename, "rb" ))
            mean_and_sem_filename = saving_folder + mouse + '_' + date + '_' + 'auc_correct_data.p'
            auc_data = pickle.load( open(mean_and_sem_filename, "rb" ))
            if ipsi_or_contra == 'contra':
                mean_auc = np.mean(auc_data.contra_trials_auc)
            elif ipsi_or_contra == 'ipsi':
                mean_auc = np.mean(auc_data.ipsi_trials_auc)
            all_aucs.append(mean_auc)
            days_of_recording.append(day_of_training[date_num])
        normalised_aucs = all_aucs/all_aucs[0] *100
        all_mice_aucs.append(normalised_aucs)
        all_mice_dates.append(days_of_recording)
        axs.scatter(days_of_recording, normalised_aucs, color='b')
    all_mice_dates_flat = [item for sublist in all_mice_dates for item in sublist]
    all_mice_aucs_flat = [item for sublist in all_mice_aucs for item in sublist]
    X = np.array(all_mice_dates_flat).reshape(-1, 1)
    Y = np.array(all_mice_aucs_flat).reshape(-1, 1)
    linear_regressor = LinearRegression()  # create object for the class
    linear_regressor.fit(X, Y)  # perform linear regression
    Y_pred = linear_regressor.predict(X)  # make predictions
    r_sq = linear_regressor.score(X,Y)
    plt.plot(X, Y_pred, lw=1, color='#746D69')
    axs.scatter(X, Y, color='#3F888F', alpha=0.8)
    print(r_sq)
        
    axs.title.set_text('Contralateral choice peak activity')
    axs.set_xlabel('Days since start of training')
    axs.set_ylabel('Percentage change since first day of recording')

In [9]:
mice_dates = []
mouse1 = 'SNL_photo16'
dates1 = ['20200210','20200213', '20200218',  '20200220', '20200224', '20200227', '20200303','20200305', '20200307',
         '20200310', '20200312']
mice_dates.append(mouseDates(mouse1, dates1))
mouse2 = 'SNL_photo17'
dates2 = ['20200204', '20200206', '20200208', '20200210', '20200212', '20200214', '20200218', '20200221', '20200224',
         '20200226', '20200228', '20200303','20200305', '20200307','20200310', '20200317']
mice_dates.append(mouseDates(mouse2, dates2))
mouse3 = 'SNL_photo18'
dates3 = ['20200223', '20200226', '20200228', '20200229','20200303', '20200305', '20200307',
         '20200310', '20200312', '20200313', '20200316']
mice_dates.append(mouseDates(mouse3, dates3))


In [10]:
peaks_correlations_multi_mice(mice_dates, 'contra')

0.25545859605730015


In [13]:
mouse = 'SNL_photo16'
dates = ['20200210','20200213', '20200218',  '20200220', '20200224', '20200227', '20200303','20200305', '20200307',
         '20200310', '20200312']
ipsi_or_contra = 'contra'
all_peaks = multi_day_peaks(mouse, dates, ipsi_or_contra)
#percentage_correct_correlation(mouse, dates, all_peaks)

In [11]:
mouse = 'SNL_photo17'
dates = ['20200204', '20200206', '20200208', '20200210', '20200212', '20200214', '20200218', '20200221', '20200224',
         '20200226', '20200228', '20200303','20200305', '20200307','20200310', '20200317']
ipsi_or_contra = 'contra'
all_peaks = multi_day_peaks(mouse, dates, ipsi_or_contra)
#all_peaks = multi_day_peaks(mouse, dates, ipsi_or_contra='ipsi')

#peaks_correlations(dates, all_peaks)
#percentage_correct_correlation(mouse, dates, all_peaks)
#num_rewards_correlation(mouse, dates, all_peaks)

In [12]:
mouse = 'SNL_photo18'
dates = ['20200223', '20200226', '20200228', '20200229','20200303', '20200305', '20200307',
         '20200310', '20200312', '20200313', '20200316']
ipsi_or_contra = 'contra'
all_peaks = multi_day_peaks(mouse, dates, ipsi_or_contra)
#all_peaks = multi_day_peaks(mouse, dates, ipsi_or_contra='ipsi')
#peaks_correlations(dates, all_peaks)
percentage_correct_correlation(mouse, dates, all_peaks)
num_rewards_correlation(mouse, dates, all_peaks)

([1.5903829497110198,
  1.5258351854234367,
  1.8508701450088065,
  1.63846638380776,
  1.7028560879570223,
  1.3484668438555063,
  1.123776120602668,
  1.3610808927193239,
  1.1831832590675613,
  1.1499446524653398,
  1.2976046217155228],
 array([  92.,  222.,  420.,  666.,  922., 1232., 1581., 1898., 2169.,
        2484., 2876.]))

In [34]:
mouse = 'SNL_photo19'
dates = ['20200221','20200224', '20200226', '20200228','20200229', '20200303',
         '20200305', '20200307','20200310','20200312']
ipsi_or_contra = 'contra'
all_peaks = multi_day_peaks(mouse, dates, ipsi_or_contra)
percentage_correct_correlation(mouse, dates, all_peaks)



([3462.2613104872667,
  4159.37911038568,
  3725.8169331519143,
  3604.344745509577,
  3544.637437991599,
  3124.4548829021733,
  2489.484462111761,
  2239.208062819633,
  2537.217175632211,
  3091.8950002690667],
 [44.875346260387815,
  58.529411764705884,
  70.50691244239631,
  69.75308641975309,
  68.55670103092784,
  83.59550561797752,
  86.0310421286031,
  87.73584905660377,
  90.51918735891648,
  nan])

In [36]:
mouse = 'SNL_photo20'
dates = ['20200224', '20200226', '20200228','20200229', '20200303', '20200303','20200305', '20200307'
         ,'20200310', '20200312']
ipsi_or_contra = 'contra'
all_peaks = multi_day_peaks(mouse, dates, ipsi_or_contra)
percentage_correct_correlation(mouse, dates, all_peaks)



([5862.070064252929,
  5073.386297590416,
  4433.040554826507,
  5029.199977031287,
  3991.587733314837,
  3991.587733314837,
  4313.553005077918,
  4172.875419017807,
  3773.097577958089,
  3432.6425003324994],
 [56.351791530944624,
  73.03822937625755,
  82.01438848920863,
  82.06388206388206,
  91.35514018691589,
  91.35514018691589,
  92.64305177111717,
  95.7683741648107,
  93.39285714285714,
  nan])

In [94]:
mouse = 'SNL_photo12'
dates = ['20200110', '20200111','20200113', '20200114','20200117', '20200118', '20200121', '20200123', '20200203', '20200205','20200212',
         '20200214', '20200218', '20200221', '20200224', '20200226', '20200303','20200304']
ipsi_or_contra = 'contra'
all_peaks = multi_day_peaks(mouse, dates, ipsi_or_contra)
percentage_correct_correlation(mouse, dates, all_peaks)

([5928.804603172931,
  4740.829983985563,
  4774.027877120693,
  4354.202478259092,
  4553.76342590866,
  3991.118902688623,
  3970.8477307038324,
  3548.0817949265847,
  3310.5508014928273,
  2964.1545029600447,
  3386.1849929841637,
  2382.3109841000787,
  2816.889811094379,
  3128.7789504668344,
  2750.9372497632517,
  3578.08318198894,
  2637.3458931306823,
  2394.1259818590106],
 [42.47787610619469,
  44.140625,
  44.0,
  48.53932584269663,
  66.52078774617068,
  66.74008810572687,
  83.87096774193549,
  85.60311284046692,
  93.5632183908046,
  94.05034324942791,
  93.44262295081967,
  96.29629629629629,
  94.75890985324948,
  96.51474530831099,
  95.12893982808023,
  92.85714285714286,
  91.96675900277009,
  95.99198396793587])

In [39]:
mouse = 'SNL_photo15'
dates = ['20200206','20200208', '20200210',  '20200213', '20200218', '20200221', '20200224','20200226','20200303','20200305', '20200307','20200310']
ipsi_or_contra = 'contra'
all_peaks = multi_day_peaks(mouse, dates, ipsi_or_contra)
percentage_correct_correlation(mouse, dates, all_peaks)

([5090.672488304103,
  5617.684968719536,
  5419.117791986009,
  4867.167465959169,
  4739.1331907609165,
  3205.018610899983,
  4249.089818358915,
  5450.093078657248,
  4525.470425104197,
  3817.995801994273,
  5306.319464833053,
  4280.415117432705],
 [52.58620689655172,
  68.06083650190114,
  79.4820717131474,
  87.41935483870968,
  89.1566265060241,
  89.1025641025641,
  89.21832884097034,
  92.16710182767623,
  94.95192307692308,
  93.28165374677003,
  93.42327150084317,
  91.72413793103448])