# COBAR Project 

Authors: Célia Benquet, Artur Jesslen & Léa Schmidt

## Import libraries

In [None]:
import pickle
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd 
import seaborn as sns
from scipy import signal

### Utility functions

In [None]:
def color_background(ax, stage, vertical = True):
    """Shade the extent of every experimental stage on a matplotlib axis.

    For each distinct label of *stage* (in order of first appearance) a span is
    drawn from the first to the last index carrying that label: green for labels
    starting with ``'on'`` ("Light on"), red otherwise ("Light off"). Only the
    first span of each colour gets a visible legend label; later ones are
    prefixed with underscores so matplotlib hides them from the legend.

    Args:
        ax: axis to draw on.
        stage: pandas Series of stage labels indexed by frame.
        vertical: if True shade along the x axis (``axvspan``), otherwise along
            the y axis (``axhspan``).
    """
    draw_span = ax.axvspan if vertical else ax.axhspan
    nb_light_on = 0
    nb_light_off = 0
    for stage_label in stage.unique():
        frames = stage[stage == stage_label].index
        first, last = frames[0], frames[-1]
        if not stage_label.startswith('on'):
            draw_span(first, last, facecolor='r', alpha=0.2, label="_" * nb_light_off + "Light off")
            nb_light_off += 1
            continue
        draw_span(first, last, facecolor='g', alpha=0.2, label="_" * nb_light_on + "Light on")
        nb_light_on += 1

def sub_to_itself(data_frame):
    """Return the row-to-row difference of *data_frame* (current row minus previous).

    The first row has no predecessor and is therefore NaN.
    """
    previous_rows = data_frame.shift()
    return data_frame - previous_rows

def compute_frequency(data, moving_average = False, width_ma = None, dt = 1/80):
    """Estimate an instantaneous oscillation frequency along each row of *data*.

    Every row is treated as one time series sampled every *dt*. Its strict local
    minima are located; the frequency of the interval starting at a minimum is
    ``1 / (time to the next minimum)`` and is held until the next minimum.

    Args:
        data: 2-D array-like, one trace per row.
        moving_average: if True, smooth each frequency trace with a box filter.
        width_ma: width (in samples) of the box filter; required when
            *moving_average* is True.
        dt: time between two consecutive samples.

    Returns:
        A float array with the same shape as *data*. Before smoothing, samples
        preceding the first minimum are 0, and rows with fewer than two minima
        are entirely 0.

    Raises:
        ValueError: if *moving_average* is True and *width_ma* is not >= 1.
    """
    data = np.asarray(data)
    # Always float: frequencies are fractional, whatever the input dtype.
    freqs = np.zeros(data.shape, dtype=float)
    if moving_average:
        if width_ma is None or width_ma < 1:
            raise ValueError('width_ma must be a positive integer when moving_average is True')
        ma_filter = np.ones(width_ma) / width_ma

    for i, trace in enumerate(data):
        # Strict local minima. The True padding lets the first / last sample
        # count when it is lower than its only neighbour.
        # NOTE(review): the boundary samples are not true extrema, so the first
        # interval may be a partial period; confirm this is intended.
        is_min = np.r_[True, trace[1:] < trace[:-1]] & np.r_[trace[:-1] < trace[1:], True]
        min_idx = np.flatnonzero(is_min)

        freq = np.zeros(trace.shape[0], dtype=float)
        if min_idx.size > 1:
            # One period per pair of consecutive minima. The last minimum has no
            # successor, so it keeps 0 and inherits the previous frequency below.
            freq[min_idx[:-1]] = 1.0 / (np.diff(min_idx) * dt)
        freqs[i, :] = fill_zeros_with_last(freq)

        if moving_average:
            freqs[i, :] = np.convolve(freqs[i, :], ma_filter, 'same')
    return freqs

def fill_zeros_with_last(arr):
    """Replace every zero of the 1-D array *arr* by the last non-zero value before it.

    Zeros located before the first non-zero entry take the value ``arr[0]``
    (which is itself zero in that case).
    """
    prev = np.arange(len(arr))
    # Zero entries point at index 0, so the running maximum carries forward the
    # index of the latest non-zero entry.
    prev[arr == 0] = 0
    prev = np.maximum.accumulate(prev)
    return arr[prev]

## Class Experiment

In [None]:
class Experiment:
    """Tracking data of the MDN, SS01540 and PR lines, plus statistics pooled over flies.

    The raw data of each line is loaded from a pickled DataFrame (``load_data``)
    and split into replicates re-ordered by experimental stage (``prepare_data``).
    ``Fly`` objects registered with ``add_fly`` compute per-fly quantities. The
    ``compute_*`` methods stack them into experiment-wide arrays, which the
    ``plot_*_stats`` methods summarise as mean +/- one standard deviation.

    Args:
        mov_av: width (in frames) of every moving-average filter.
    """

    # Names of the per-claw frequency stacks, in the order in which
    # ``Fly.compute_frequencies`` returns its six traces.
    _CLAW_FREQ_ATTRS = ('claws_MDN_x', 'claws_MDN_y', 'claws_SS_x',
                        'claws_SS_y', 'claws_PR_x', 'claws_PR_y')

    def __init__(self, mov_av = 30):
        self.load_data()
        self.stage_name = ['off0', 'on0', 'off1', 'on1','off2', 'on2', 'off3']
        self.claws_name = ['LFclaw', 'LMclaw', 'LHclaw', 'RFclaw', 'RMclaw', 'RHclaw']
        self.dt = 1. / 80  # time between two consecutive frames
        self.pixel_to_mm = 32. / 832
        self.width_mov_av_claw = mov_av
        self.width_mov_av_speed = mov_av
        self.width_mov_av_ang_speed = mov_av
        self.width_mov_av_claw_freq = mov_av

        # Keeps every stage and builds replicates_{MDN,SS,PR}_sorted.
        self.include_stage_and_update_data()

        self.flies = []
        self.speeds_MDN = np.array([[]])
        self.speeds_SS = np.array([[]])
        self.speeds_PR = np.array([[]])
        # One (#fly * #replicates/fly, #frames) array per claw, for each line and axis.
        self._reset_claw_frequencies()

    def _reset_claw_frequencies(self):
        """Reset every per-claw frequency stack to one empty array per claw."""
        for attr in self._CLAW_FREQ_ATTRS:
            setattr(self, attr, [np.array([[]]) for _ in self.claws_name])

    @staticmethod
    def _stack_rows(stacked, rows):
        """Append the rows of *rows* below *stacked*, or return *rows* if *stacked* is empty."""
        return np.concatenate((stacked, rows), axis=0) if stacked.size else rows

    def add_fly(self, fly):
        """Register a ``Fly`` whose results are pooled by the ``compute_*`` methods."""
        self.flies.append(fly)

    def include_stage_and_update_data(self, off0 = True, on0  = True, off1  = True, on1  = True, off2  = True, on2  = True, off3  = True):
        """Select which experimental stages are kept, then rebuild the sorted replicates.

        Each flag matches the stage of the same name in ``self.stage_name``.
        """
        self.stage_activate = [off0, on0, off1, on1, off2, on2, off3]

        self.replicates_MDN_sorted = self.prepare_data(self.data_MDN)
        self.replicates_SS_sorted = self.prepare_data(self.data_SS)
        self.replicates_PR_sorted = self.prepare_data(self.data_PR)

    def load_data(self):
        """Load the pickled tracking DataFrames of the three lines (relative paths).

        NOTE: ``pickle.load`` can execute arbitrary code; only load trusted files.
        """
        path_MDN = 'data/MDN/U3_f'
        path_SS = 'data/SS01540/U3_f'
        path_PR = 'data/PR/U3_f'
        with open(path_MDN + '/MDN_U3_f_trackingData.pkl', 'rb') as f:
            self.data_MDN = pickle.load(f).reset_index()
        with open(path_SS + '/SS01540_U3_f_trackingData.pkl', 'rb') as f:
            self.data_SS = pickle.load(f).reset_index()
        with open(path_PR + '/PR_U3_f_trackingData.pkl', 'rb') as f:
            self.data_PR = pickle.load(f).reset_index()

    def prepare_data(self, data_raw, nb_replicates = 12):
        """Split *data_raw* into replicates whose rows are grouped by included stage.

        Args:
            data_raw: DataFrame with ``'replicate'`` and ``'exp_stage'`` columns.
            nb_replicates: replicates are numbered ``1..nb_replicates``.

        Returns:
            A list whose element ``k`` holds replicate ``k + 1``. Its rows are
            ordered by stage (following ``self.stage_name``, keeping only the
            stages enabled in ``self.stage_activate``), with a fresh index saved
            as a column by ``reset_index``.
        """
        included = [stage for stage, keep in zip(self.stage_name, self.stage_activate) if keep]
        replicates_raw_sorted = []
        for replicate_nb in range(1, nb_replicates + 1):
            replicate = data_raw[data_raw['replicate'] == replicate_nb]
            chunks = [replicate[replicate['exp_stage'] == stage] for stage in included]
            if chunks:
                # pd.concat replaces DataFrame.append, which pandas 2.0 removed.
                stage_sorted = pd.concat(chunks, ignore_index=True)
            else:
                # No stage selected: keep the columns so 'exp_stage' lookups still work.
                stage_sorted = replicate.iloc[:0].reset_index(drop=True)
            replicates_raw_sorted.append(stage_sorted.reset_index())
        return replicates_raw_sorted

    def plot_positions(self, fly = None, MDN = False, SS = False, PR = False):
        """Plot trajectories of all flies, or of the flies whose indices are in *fly*."""
        if fly is None:
            for fly_ in self.flies:
                fly_.plot_positions(MDN, SS, PR)
        else:
            for fly_nb in fly:
                self.flies[fly_nb].plot_positions(MDN, SS, PR)

    def compute_pos_and_speed(self, moving_average = False):
        """Recompute every fly's speeds and stack them row-wise per line."""
        self.speeds_MDN = np.array([[]])
        self.speeds_SS = np.array([[]])
        self.speeds_PR = np.array([[]])
        for fly_ in self.flies:
            all_MDN, all_SS, all_PR = fly_.compute_pos_and_speed(self, moving_average)
            self.speeds_MDN = self._stack_rows(self.speeds_MDN, all_MDN)
            self.speeds_SS = self._stack_rows(self.speeds_SS, all_SS)
            self.speeds_PR = self._stack_rows(self.speeds_PR, all_PR)

    def compute_claws(self, moving_average = False):
        """Recompute the claw positions of every fly."""
        for fly_ in self.flies:
            fly_.compute_claws(self, moving_average)

    def compute_frequencies(self, moving_average = False):
        """Compute claw oscillation frequencies of every fly and stack them per claw.

        For each claw, the six traces returned by ``Fly.compute_frequencies``
        (MDN/SS/PR, x/y) are converted with ``compute_frequency`` and appended
        row-wise to the matching ``claws_<line>_<axis>[claw]`` array.
        """
        self._reset_claw_frequencies()
        width = self.width_mov_av_claw_freq
        for claw in range(len(self.claws_name)):
            for fly_ in self.flies:
                traces = fly_.compute_frequencies(claw)
                for attr, trace in zip(self._CLAW_FREQ_ATTRS, traces):
                    per_claw = getattr(self, attr)
                    per_claw[claw] = self._stack_rows(per_claw[claw], compute_frequency(trace, moving_average, width))

    def plot_speeds(self, fly = None, MDN = False, SS = False, PR = False, display_background = True):
        """Plot pooled speed statistics, or the per-replicate speeds of the flies in *fly*."""
        if fly is None:
            self.plot_speeds_stats(MDN, SS, PR, display_background)
        else:
            for fly_nb in fly:
                self.flies[fly_nb].plot_speeds(MDN, SS, PR, display_background)

    @staticmethod
    def _plot_mean_std(ax, data, color):
        """Plot the column-wise mean of *data* with a +/- one std band and a zero line."""
        mean, sigma = data.mean(axis=0), data.std(axis=0)
        fram = np.arange(len(mean))
        ax.plot(fram, mean, lw=2, color=color)
        ax.fill_between(fram, mean + sigma, mean - sigma, facecolor=color, alpha=0.5)
        ax.axhline(linewidth=1, color='grey')

    def plot_speeds_stats(self, MDN, SS, PR, display_background):
        """Plot mean +/- std of the pooled forward/backward speed, one figure per selected line."""
        for enabled, label, color in ((MDN, 'MDN', 'r'), (SS, 'SS', 'g'), (PR, 'PR', 'b')):
            if not enabled:
                continue
            plt.figure(figsize=(40, 10))
            ax = plt.gca()
            self._plot_mean_std(ax, getattr(self, 'speeds_' + label), color)
            ax.set_xlabel('Frame number')
            ax.set_ylabel('Speed')
            ax.set_title('Forward/backward speed ({} data)'.format(label), fontsize=16)
            if display_background:
                # Stages are taken from the first replicate of the line.
                color_background(ax, getattr(self, 'replicates_{}_sorted'.format(label))[0]['exp_stage'])
            plt.show()

    def plot_angular_speeds(self, fly = None, MDN = False, SS = False, PR = False, display_background = True):
        """Plot angular speeds of all flies, or of the flies whose indices are in *fly*."""
        if fly is None:
            for fly_ in self.flies:
                fly_.plot_angular_speeds(MDN, SS, PR, display_background)
        else:
            for fly_nb in fly:
                self.flies[fly_nb].plot_angular_speeds(MDN, SS, PR, display_background)

    def plot_claws(self, fly = None, MDN = False, SS = False, PR = False, display_background = False):
        """Plot claw positions of all flies, or of the flies whose indices are in *fly*."""
        if fly is None:
            for fly_ in self.flies:
                fly_.plot_claws(MDN, SS, PR, display_background)
        else:
            for fly_nb in fly:
                self.flies[fly_nb].plot_claws(MDN, SS, PR, display_background)

    def plot_claw_frequeny(self, MDN = False, SS = False, PR = False, display_background = False):
        """Plot pooled claw frequency statistics (misspelt name kept for existing callers)."""
        self.plot_frequency_stats(MDN, SS, PR, display_background)

    def plot_claw_frequency(self, MDN = False, SS = False, PR = False, display_background = False):
        """Plot pooled claw frequency statistics."""
        self.plot_frequency_stats(MDN, SS, PR, display_background)

    def plot_frequency_stats(self, MDN, SS, PR, display_background):
        """Plot mean +/- std claw frequency per claw: one x-axis and one y-axis figure per line."""
        for enabled, label, color in ((MDN, 'MDN', 'r'), (SS, 'SS', 'g'), (PR, 'PR', 'b')):
            if not enabled:
                continue
            for axis_name in ('x', 'y'):
                claw_data = getattr(self, 'claws_{}_{}'.format(label, axis_name))
                _, axes = plt.subplots(1, len(claw_data), figsize=(60, 10), squeeze=False)
                for i, member_data in enumerate(claw_data):
                    ax = axes[0, i]
                    self._plot_mean_std(ax, member_data, color)
                    ax.set_xlabel('Frame number')
                    ax.set_ylabel('Frequency [Hz]')
                    ax.set_title('{} frequency - {} axis ({} data)'.format(self.claws_name[i], axis_name, label), fontsize=16)
                    if display_background:
                        color_background(ax, getattr(self, 'replicates_{}_sorted'.format(label))[0]['exp_stage'])
            plt.show()

## Class Fly

In [None]:
class Fly:
    def __init__(self, exp, id_):
        self.id = id_
        
        self.replicates_nb_MDN = exp.data_MDN[exp.data_MDN['fly'] == id_]['replicate'].unique()
        self.replicates_nb_SS = exp.data_SS[exp.data_SS['fly'] == id_]['replicate'].unique()
        self.replicates_nb_PR = exp.data_PR[exp.data_PR['fly'] == id_]['replicate'].unique()
        
        self.replicates_MDN = [exp.replicates_MDN_sorted[i-1] for i in self.replicates_nb_MDN]
        self.replicates_SS = [exp.replicates_SS_sorted[i-1] for i in self.replicates_nb_SS]
        self.replicates_PR = [exp.replicates_PR_sorted[i-1] for i in self.replicates_nb_PR]
        
        self.compute_pos_and_speed(exp)
        
        self.claws_name = ['LFclaw', 'LMclaw', 'LHclaw', 'RFclaw', 'RMclaw', 'RHclaw']
        #self.claws_name = ['LFtibiaTarsus', 'LMtibiaTarsus', 'LHtibiaTarsus', 'RFtibiaTarsus', 'RMtibiaTarsus', 'RHtibiaTarsus']
        self.compute_claws(exp)
    
    def plot_positions(self, MDN = False, SS = False, PR = False):
        """Draw the trajectory of every replicate of the selected lines.

        One figure per selected line (MDN red, SS green, PR blue), with one
        subplot per replicate on a 0-832 square whose y axis is inverted.
        """
        panels = (
            (MDN, 'Position (MDN data)', 'MDN', 'r'),
            (SS, 'Position (SS data)', 'SS', 'g'),
            (PR, 'Position (PR data)', 'PR', 'b'),
        )
        for selected, suptitle, genotype, color in panels:
            if not selected:
                continue
            numbers = getattr(self, 'replicates_nb_' + genotype)
            trajectories = getattr(self, 'positions_' + genotype)
            nb_cols = len(numbers)
            plt.figure(figsize=(40, 10))
            plt.suptitle(suptitle, fontsize=32)
            for col, replicate_nb in enumerate(numbers):
                plt.subplot(1, nb_cols, col + 1)
                xs, ys = trajectories[col][0], trajectories[col][1]
                plt.plot(xs, ys, color=color)
                plt.xlim(0, 832)
                plt.ylim(0, 832)
                plt.gca().invert_yaxis()
                plt.xlabel('X-axis')
                plt.ylabel('Y-axis')
                plt.title('Fly {}: replicate {}'.format(self.id,  replicate_nb), fontsize=16)
            plt.show()
            
    def plot_speeds(self, MDN = False, SS = False, PR = False, display_background = False):
        """Plot the forward/backward speed of every replicate of the selected lines.

        One figure per selected line (MDN red, SS green, PR blue) with one
        subplot per replicate. Each title reports the speed moving-average width,
        or 'None' when speed smoothing is off. With *display_background*, the
        stages are shaded and a legend is drawn.
        """
        selection = (
            (MDN, 'MDN', 'r', 'grey'),
            (SS, 'SS', 'g', 'k'),
            (PR, 'PR', 'b', 'k'),
        )
        for enabled, genotype, color, zero_line_color in selection:
            if not enabled:
                continue
            replicate_nbs = getattr(self, 'replicates_nb_' + genotype)
            speeds = getattr(self, 'speeds_' + genotype)
            replicates = getattr(self, 'replicates_' + genotype)
            plt.figure(figsize=(40, 10))
            plt.suptitle('Forward/Backward Speed ({} data)'.format(genotype), fontsize=32)
            for i, replicate_nb in enumerate(replicate_nbs):
                plt.subplot(1, len(replicate_nbs), i + 1)
                plt.plot(range(speeds[i].shape[0]), speeds[i], color=color)
                plt.xlabel('Frame number')
                plt.ylabel('Speed (Fly longitudinal axis)')
                # Every line reports the *speed* smoothing flag (not the angular-speed one).
                plt.title('Fly {}: replicate {} - moving average width : {}'.format(self.id,  replicate_nb, self.width_mov_av_speed if self.moving_average_speed else 'None'), fontsize=16)
                plt.axhline(linewidth=1, color=zero_line_color)
                if display_background:
                    color_background(plt.gca(), replicates[i]['exp_stage'])
                    plt.legend()
            plt.show()
    
    def plot_angular_speeds(self, MDN = False, SS = False, PR = False, display_background = False):
        """Plot the angular speed (deg/s) of every replicate of the selected lines.

        One figure per selected line (MDN red, SS green, PR blue) with one
        subplot per replicate. Each title reports the angular-speed
        moving-average width, or 'None' when that smoothing is off. With
        *display_background*, the stages are shaded and a legend is drawn.
        """
        panels = (
            (MDN, 'Angular Speed (MDN data)', 'MDN', 'r', 'grey'),
            (SS, 'Angular Speed (SS data)', 'SS', 'g', 'k'),
            (PR, 'Angular Speed (PR data)', 'PR', 'b', 'k'),
        )
        for selected, suptitle, genotype, curve_color, baseline_color in panels:
            if not selected:
                continue
            numbers = getattr(self, 'replicates_nb_' + genotype)
            ang_speeds = getattr(self, 'ang_speeds_' + genotype)
            stage_sources = getattr(self, 'replicates_' + genotype)
            plt.figure(figsize=(40, 10))
            plt.suptitle(suptitle, fontsize=32)
            for col, replicate_nb in enumerate(numbers):
                plt.subplot(1, len(numbers), col + 1)
                trace = ang_speeds[col]
                plt.plot(range(trace.shape[0]), trace, color=curve_color)
                plt.xlabel('Frame number')
                plt.ylabel('Angular speed (deg/s)')
                if self.moving_average_ang_speed:
                    width_text = self.width_mov_av_ang_speed
                else:
                    width_text = 'None'
                plt.title('Fly {}: replicate {} - moving average width : {}'.format(self.id,  replicate_nb, width_text), fontsize=16)
                plt.axhline(linewidth=1, color=baseline_color)
                if display_background:
                    color_background(plt.gca(), stage_sources[col]['exp_stage'])
                    plt.legend()
            plt.show()
            
    def plot_claws(self, MDN = False, SS = False, PR = False, display_background = False):
        """Plot the claw trajectories of every replicate of the selected lines.

        Each replicate gets two stacked subplots: claw y position against frame
        number (top) and frame number against claw x position (bottom), both
        with inverted y axes. With *display_background*, the stages are shaded
        (vertically on the top panel, horizontally on the bottom one).
        """
        for enabled, genotype in ((MDN, 'MDN'), (SS, 'SS'), (PR, 'PR')):
            if not enabled:
                continue
            replicate_nbs = getattr(self, 'replicates_nb_' + genotype)
            positions = getattr(self, 'positions_claws_' + genotype)
            replicates = getattr(self, 'replicates_' + genotype)
            nb_cols = len(replicate_nbs)
            plt.figure(figsize=(40, 20))
            plt.suptitle('Claws position ({} data)'.format(genotype), fontsize=32)
            for i, replicate_nb in enumerate(replicate_nbs):
                width_text = self.width_mov_av_claw if self.moving_average_claw else 'None'

                # Top row: y position of each claw over time.
                plt.subplot(2, nb_cols, i + 1)
                for j, claw in enumerate(self.claws_name):
                    y_pos = positions[i][j][1]
                    plt.plot(range(y_pos.shape[0]), y_pos, label=claw)
                plt.gca().invert_yaxis()
                plt.xlabel('Frame number')
                plt.ylabel('Claws position (y axis)')
                plt.title('Fly {}: replicate {} - moving average width : {} (Y-axis)'.format(self.id,  replicate_nb, width_text), fontsize=16)
                if display_background:
                    color_background(plt.gca(), replicates[i]['exp_stage'])
                plt.legend()

                # Bottom row: x position of each claw, time running downwards.
                plt.subplot(2, nb_cols, i + nb_cols + 1)
                for j, claw in enumerate(self.claws_name):
                    x_pos = positions[i][j][0]
                    plt.plot(x_pos, range(x_pos.shape[0]), label=claw)
                plt.gca().invert_yaxis()
                plt.xlabel('Claws position (x axis)')
                plt.ylabel('Frame number')
                plt.title('Fly {}: replicate {} - moving average width : {} (X-axis)'.format(self.id,  replicate_nb, width_text), fontsize=16)
                if display_background:
                    color_background(plt.gca(), replicates[i]['exp_stage'], False)
                plt.legend()
            plt.show()
                
    def compute_pos_and_speed(self, exp, moving_average = False):
        """Compute centre positions, forward speeds and angular speeds of every replicate.

        The per-replicate results are stored on the instance as
        ``positions_{MDN,SS,PR}``, ``speeds_{MDN,SS,PR}`` and
        ``ang_speeds_{MDN,SS,PR}``.

        Args:
            exp: experiment object forwarded to ``compute_speed`` and
                ``compute_angular_speed`` (smoothing widths, unit conversion).
            moving_average: if True, speeds and angular speeds are smoothed
                with a moving average.

        Returns:
            Tuple ``(MDN, SS, PR)`` of forward speeds (used to compute mean and
            variance). Each element stacks the replicates of that group into an
            array with one row per replicate, truncated to the length of the
            group's shortest replicate. A group without replicates yields an
            empty array.
        """
        self.positions_MDN = self.compute_position(self.replicates_MDN)
        self.positions_SS = self.compute_position(self.replicates_SS)
        self.positions_PR = self.compute_position(self.replicates_PR)

        self.speeds_MDN = self.compute_speed(self.replicates_MDN, exp, moving_average)
        self.speeds_SS = self.compute_speed(self.replicates_SS, exp, moving_average)
        self.speeds_PR = self.compute_speed(self.replicates_PR, exp, moving_average)

        self.ang_speeds_MDN = self.compute_angular_speed(self.replicates_MDN, exp, moving_average)
        self.ang_speeds_SS = self.compute_angular_speed(self.replicates_SS, exp, moving_average)
        self.ang_speeds_PR = self.compute_angular_speed(self.replicates_PR, exp, moving_average)

        def _stack_truncated(speeds):
            # Each group is truncated independently. Pairing two groups with
            # zip() would silently drop replicates when the groups differ in size.
            if not speeds:
                return np.asarray([])
            nb_data = min(len(speed) for speed in speeds)
            return np.asarray([speed[:nb_data] for speed in speeds])

        return (_stack_truncated(self.speeds_MDN),
                _stack_truncated(self.speeds_SS),
                _stack_truncated(self.speeds_PR))
    
    def compute_claws(self, exp, moving_average = False):
        """Compute and store the claw positions of the MDN, SS and PR replicates.

        For each group the result of ``compute_position_claws`` is stored in
        ``positions_claws_MDN``, ``positions_claws_SS`` and
        ``positions_claws_PR`` respectively (computed in that order).
        """
        for group in ('MDN', 'SS', 'PR'):
            group_replicates = getattr(self, 'replicates_' + group)
            claws = self.compute_position_claws(group_replicates, exp, moving_average)
            setattr(self, 'positions_claws_' + group, claws)
        
    def compute_frequencies(self, claw = 0):
        """Collect the x/y trajectory of one claw for every replicate of each group.

        Despite its name, no frequency is computed here. The method only gathers
        the claw positions stored by ``compute_claws`` into arrays. All
        trajectories (MDN, SS and PR alike) are truncated to the length of the
        shortest replicate across the three groups so that the rows line up.

        Args:
            claw: index into ``self.claws_name`` of the claw to extract.

        Returns:
            ``(MDN_x, MDN_y, SS_x, SS_y, PR_x, PR_y)``: arrays with one row per
            replicate of the corresponding group. A group without replicates
            yields an empty array.
        """
        groups = (self.positions_claws_MDN, self.positions_claws_SS, self.positions_claws_PR)
        lengths = [len(replicate[claw][0]) for group in groups for replicate in group]
        nb_data = min(lengths) if lengths else 0

        def _axis(group, axis):
            # axis 0 -> x, axis 1 -> y (the order used by compute_position_claws).
            # Every group is handled on its own; the previous pairing with zip()
            # truncated groups of unequal size, and the PR rows were taken from
            # the MDN data by mistake.
            return np.asarray([replicate[claw][axis][:nb_data] for replicate in group])

        claws_MDN, claws_SS, claws_PR = groups
        return (_axis(claws_MDN, 0), _axis(claws_MDN, 1),
                _axis(claws_SS, 0), _axis(claws_SS, 1),
                _axis(claws_PR, 0), _axis(claws_PR, 1))
        
    def compute_position(self, replicates):
        """Return the centre-position columns ``[x, y]`` of each replicate.

        Entry ``i`` is ``[replicates[i]['center', 'posx'],
        replicates[i]['center', 'posy']]``.
        """
        positions = []
        for replicate in replicates:
            positions.append([replicate['center', 'pos' + axis] for axis in ('x', 'y')])
        return positions
    
    def compute_speed(self, replicates, exp, moving_average = False):
        """Return the signed speed of the fly centre along its orientation, per replicate.

        The frame-to-frame displacement of ``('center', 'posx'/'posy')`` is
        converted to a magnitude (scaled by ``exp.pixel_to_mm / exp.dt``) and
        projected onto the ``('center', 'orientation')`` direction via the
        cosine of the angle between them. Positive values therefore mean the
        movement points within 90 degrees of the orientation.

        When ``moving_average`` is True, the magnitude and the projection
        factor are each smoothed with a box filter of width
        ``exp.width_mov_av_speed`` (``np.convolve`` in ``'same'`` mode) before
        being multiplied. The result is then an ndarray; otherwise it keeps
        the type of the input columns.

        NOTE(review): the first difference is NaN (``sub_to_itself`` shifts by
        one frame), and with smoothing that NaN spreads into the first outputs.
        """
        # The smoothing settings are remembered on the instance.
        self.width_mov_av_speed = exp.width_mov_av_speed
        self.moving_average_speed = moving_average

        kernel = None
        if self.moving_average_speed:
            kernel = np.ones(self.width_mov_av_speed) / self.width_mov_av_speed

        forward_speeds = []
        for replicate in replicates:
            d_x, d_y = [sub_to_itself(replicate['center', 'pos' + axis]) for axis in ('x', 'y')]
            orientation = replicate['center', 'orientation']

            magnitude = np.sqrt(d_x**2 + d_y**2) * exp.pixel_to_mm / exp.dt
            heading = (np.arctan2(d_x, -d_y) * 180 / np.pi + 360) % 360
            # Angle between movement and orientation, wrapped to [-180, 180).
            relative = (heading - orientation + 180) % 360 - 180

            if kernel is None:
                forward_speeds.append(np.cos(relative * np.pi/180) * magnitude)
            else:
                smoothed_magnitude = np.convolve(magnitude, kernel, 'same')
                smoothed_ratio = np.convolve(np.cos(relative * np.pi/180), kernel, 'same')
                forward_speeds.append(smoothed_ratio * smoothed_magnitude)
        return forward_speeds
    
    def compute_angular_speed(self, replicates, exp, moving_average = False):
        """Return the frame-to-frame orientation change of the fly, per replicate.

        The difference of ``('center', 'orientation')`` between consecutive
        frames is wrapped into ``[-180, 180)``. It is expressed per frame: there
        is no division by ``exp.dt``. When ``moving_average`` is True, each
        series is smoothed with a box filter of width
        ``exp.width_mov_av_ang_speed`` (``np.convolve``, ``'same'`` mode).

        NOTE(review): the first value is NaN (no previous frame); smoothing
        propagates it into the first outputs.
        """
        self.width_mov_av_ang_speed = exp.width_mov_av_ang_speed
        self.moving_average_ang_speed = moving_average

        angular_speeds = []
        for replicate in replicates:
            delta = sub_to_itself(replicate['center', 'orientation'])
            angular_speeds.append((delta + 540) % 360 - 180)

        if not self.moving_average_ang_speed:
            return angular_speeds

        kernel = np.ones(self.width_mov_av_ang_speed) / self.width_mov_av_ang_speed
        return [np.convolve(series, kernel, 'same') for series in angular_speeds]

    def compute_position_claws(self, replicates, exp, moving_average = False):
        """Return the (optionally smoothed) x/y positions of every claw per replicate.

        The result is nested as ``result[replicate][claw][axis]``. Claws
        follow ``self.claws_name``; axis 0 is ``x`` and axis 1 is ``y``. The
        smoothing settings are stored on the instance, where the claw plots
        read them for their titles.

        Args:
            replicates: per-replicate tables indexed by ``(claw_name, 'x'|'y')``.
            exp: object providing ``width_mov_av_claw``, the moving-average width.
            moving_average: if True, each trajectory is convolved with a
                normalised box filter (``'same'`` mode, length preserved);
                otherwise the columns are returned as they are.
        """
        self.width_mov_av_claw = exp.width_mov_av_claw
        self.moving_average_claw = moving_average

        if moving_average:
            kernel = np.ones(self.width_mov_av_claw) / self.width_mov_av_claw

            def prepare(column):
                return np.convolve(column, kernel, 'same')
        else:
            def prepare(column):
                return column

        result = []
        for replicate in replicates:
            claws = []
            for claw_name in self.claws_name:
                claws.append([prepare(replicate[claw_name, axis]) for axis in ('x', 'y')])
            result.append(claws)
        return result
    

## Data analysis

In [None]:
# Create the experiment and load the data
exp = Experiment(mov_av = 30)

In [None]:
# Keep only the experimental stages flagged True below, then update the data
exp.include_stage_and_update_data(off0 = True, on0 = True, \
                                  off1 = True, on1 = True, \
                                  off2 = True, on2 = True, \
                                  off3 = True)

In [None]:
# Create the flies with ids 0, 1 and 2 and add them to the experiment
exp.add_fly(Fly(exp, id_ = 0))
exp.add_fly(Fly(exp, id_ = 1))
exp.add_fly(Fly(exp, id_ = 2))

In [None]:
# Uncomment to inspect the first sorted MDN replicate
#exp.replicates_MDN_sorted[0].head()


## 2. Question 1a

#### Fly position

#### i. MDN data

In [None]:
exp.plot_positions(fly = [0,1,2], MDN = True) # fly can be None or a list of fly ids (e.g. [0,2])

#### ii. SS data

In [None]:
exp.plot_positions(fly = [0,1,2], SS = True) # fly can be None or a list of fly ids (e.g. [0,2])

#### iii. PR data

In [None]:
exp.plot_positions(fly = [0,1,2], PR = True) # fly can be None or a list of fly ids (e.g. [0,2])

## 2. Question 1b

##### Forward / backward speed over time

In [None]:
# Forward-speed settings: smoothing on/off, moving-average window width, and
# whether to shade the light on/off stages behind the plots
moving_average = True
exp.width_mov_av_speed = 100
display_background = True 

In [None]:
# Recompute positions and speeds with the settings above
exp.compute_pos_and_speed(moving_average = moving_average)

#### i. Data statistics

In [None]:
# fly = None plots the statistics instead of individual flies
exp.plot_speeds(fly = None, MDN = True, SS = True, PR = True)

#### ii. MDN data

In [None]:
# fly can be None (statistics) or a list of fly ids (e.g. [0,2])
exp.plot_speeds(fly = [0,1,2], MDN = True, display_background = display_background)

#### iii. SS data

In [None]:
# fly can be None (statistics) or a list of fly ids (e.g. [0,2])
exp.plot_speeds(fly = [0,1,2], SS = True, display_background = display_background)

#### iv. PR data

In [None]:
# fly can be None (statistics) or a list of fly ids (e.g. [0,2])
exp.plot_speeds(fly = [0,1,2], PR = True, display_background = display_background)

##### Angular speed over time

In [None]:
# Angular-speed settings: smoothing on/off, moving-average window width, and
# whether to shade the light on/off stages behind the plots
moving_average = True
exp.width_mov_av_ang_speed = 20
display_background = True 

In [None]:
# NOTE(review): presumably forwards to Fly.compute_pos_and_speed, which also
# recomputes the forward speeds with the same moving_average flag
exp.compute_pos_and_speed(moving_average = moving_average)

#### i. MDN data

In [None]:
# fly can be None or a list of fly ids (e.g. [0,2])
exp.plot_angular_speeds(fly = [0,1,2], MDN = True, display_background = display_background)

#### ii. SS data

In [None]:
# fly can be None or a list of fly ids (e.g. [0,2])
exp.plot_angular_speeds(fly = [0,1,2], SS = True, display_background = display_background)

#### iii. PR data

In [None]:
# fly can be None or a list of fly ids (e.g. [0,2])
exp.plot_angular_speeds(fly = [0,1,2], PR = True, display_background = display_background)

## Question 2.a

#### Claw position 

In [None]:
# Claw-position settings
# NOTE(review): with moving_average = False the window width below is not
# applied (Fly.compute_position_claws only smooths when moving_average is True)
moving_average = False
exp.width_mov_av_claw = 20
display_background = True

In [None]:
exp.compute_claws(moving_average = moving_average)

In [None]:
# fly can be None or a list of fly ids (e.g. [0,2])
exp.plot_claws(fly = [0,1,2], MDN = True, display_background = display_background)

In [None]:
# fly can be None or a list of fly ids (e.g. [0,2])
exp.plot_claws(fly = [0,1,2], SS = True, display_background = display_background)

In [None]:
# fly can be None or a list of fly ids (e.g. [0,2])
exp.plot_claws(fly = [0,1,2], PR = True, display_background = display_background)

#### Claw frequency

In [None]:
# Claw positions used for the frequency analysis
# NOTE(review): with moving_average = False the width of 40 is not applied
# (Fly.compute_position_claws only smooths when moving_average is True)
moving_average = False
exp.width_mov_av_claw = 40

In [None]:
exp.compute_claws(moving_average = moving_average)

In [None]:
# Frequency settings: presumably exp.width_mov_av_claw_freq is the window of
# the moving average applied to the frequencies — confirm in Experiment
moving_average = True
exp.width_mov_av_claw_freq = 20
display_background = True

In [None]:
exp.compute_frequencies(moving_average = moving_average)

In [None]:
# NOTE(review): 'frequeny' looks like a typo of 'frequency'; the call must match
# the Experiment method name, so rename both together if fixing it
exp.plot_claw_frequeny(MDN = True, display_background = display_background)

In [None]:
exp.plot_claw_frequeny(SS = True, display_background = display_background)

In [None]:
exp.plot_claw_frequeny(PR = True, display_background = display_background)