# COBAR Project 

Authors: Célia Benquet, Artur Jesslen & Léa Schmidt

## Import libraries

In [None]:
import pickle
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd 
import seaborn as sns
from scipy import signal

### Utility functions

In [None]:
def color_background(ax, stage, vertical = True):
    """Shade every experimental stage on ``ax`` with a translucent band.

    For each distinct value of ``stage`` one span is drawn from the first to
    the last index label carrying that value. Stages whose name starts with
    ``'on'`` are drawn in green and labelled "Light on"; every other stage is
    drawn in red and labelled "Light off".

    Parameters
    ----------
    ax : object providing ``axvspan`` and ``axhspan`` (e.g. a matplotlib Axes)
        Target on which the bands are drawn.
    stage : pandas.Series
        Stage name of each sample; its index supplies the band limits.
    vertical : bool
        If True the bands are drawn with ``axvspan``, otherwise with
        ``axhspan``.
    """
    stage_names = stage.unique()

    # Resolve the (first, last) index label of every stage before drawing.
    limits = []
    for name in stage_names:
        labels = stage[stage == name].index
        limits.append((labels[0], labels[-1]))

    draw_span = ax.axvspan if vertical else ax.axhspan
    drawn = {'on': 0, 'off': 0}
    for (start, stop), name in zip(limits, stage_names):
        if name.startswith('on'):
            kind, colour, text = 'on', 'g', "Light on"
        else:
            kind, colour, text = 'off', 'r', "Light off"
        # Only the first band of each kind gets a plain label; later ones are
        # prefixed with underscores (matplotlib's legend skips such labels).
        draw_span(start, stop, facecolor=colour, alpha=0.2, label="_" * drawn[kind] + text)
        drawn[kind] += 1

## Class Experiment

In [None]:
class Experiment:
    """Tracking data of the three datasets (MDN, SS, PR) plus the flies studied.

    On construction the three pickled tracking tables are loaded and split
    into one DataFrame per replicate, with rows grouped by experimental stage
    in the order given by ``stage_name``. ``Fly`` objects are registered with
    ``add_fly`` and the ``compute_claws`` / ``plot_claws`` methods fan out to
    them.
    """

    # Replicates are selected by the values 1..N_REPLICATES of the
    # 'replicate' column (see prepare_data).
    N_REPLICATES = 12

    def __init__(self, mov_av = 30):
        """Load the data and prepare the per-replicate tables.

        Parameters
        ----------
        mov_av : int
            Width (in samples) stored for the claw, speed and angular-speed
            moving averages.
        """
        self.load_data()
        self.stage_name = ['off0', 'on0', 'off1', 'on1','off2', 'on2', 'off3']
        # NOTE(review): presumably the frame period in seconds (80 Hz) - confirm.
        self.dt = 1. / 80
        # NOTE(review): presumably 32 mm imaged over 832 pixels - confirm.
        self.pixel_to_mm = 32. / 832
        self.width_mov_av_claw = mov_av
        self.width_mov_av_speed = mov_av
        self.width_mov_av_ang_speed = mov_av

        # Sets self.stage_activate (all stages on) and builds the three
        # replicates_*_sorted lists; no need to prepare the data a second time.
        self.include_stage_and_update_data()

        self.flies = []
        self.speeds_MDN = np.array([[]])
        self.speeds_SS = np.array([[]])
        self.speeds_PR = np.array([[]])

    def add_fly(self, fly):
        """Register ``fly`` so that compute_claws / plot_claws include it."""
        self.flies.append(fly)

    def include_stage_and_update_data(self, off0 = True, on0  = True, off1  = True, on1  = True, off2  = True, on2  = True, off3  = True):
        """Choose which stages are kept and rebuild the per-replicate tables.

        Each flag corresponds to the stage of the same name in
        ``stage_name``; a False flag drops that stage's rows.
        """
        self.stage_activate = [off0, on0, off1, on1, off2, on2, off3]

        self.replicates_MDN_sorted = self.prepare_data(self.data_MDN)
        self.replicates_SS_sorted = self.prepare_data(self.data_SS)
        self.replicates_PR_sorted = self.prepare_data(self.data_PR)

    def load_data(self):
        """Read the three pickled tracking tables into data_MDN / data_SS / data_PR.

        Paths are relative to the current working directory.
        """
        path_MDN = 'data/MDN/U3_f'
        path_SS = 'data/SS01540/U3_f'
        path_PR = 'data/PR/U3_f'
        # NOTE: pickle.load executes arbitrary code from the file; only use
        # these loaders on trusted data files.
        with open(path_MDN + '/MDN_U3_f_trackingData.pkl', 'rb') as f:
            self.data_MDN = pickle.load(f).reset_index()
        with open(path_SS + '/SS01540_U3_f_trackingData.pkl', 'rb') as f:
            self.data_SS = pickle.load(f).reset_index()
        with open(path_PR + '/PR_U3_f_trackingData.pkl', 'rb') as f:
            self.data_PR = pickle.load(f).reset_index()

    def prepare_data(self, data_raw):
        """Split ``data_raw`` into one stage-ordered DataFrame per replicate.

        Parameters
        ----------
        data_raw : pandas.DataFrame
            Table with at least the columns 'replicate' and 'exp_stage'.

        Returns
        -------
        list of pandas.DataFrame
            One entry per replicate number 1..N_REPLICATES (entry ``k`` holds
            replicate ``k + 1``). Rows of the enabled stages are concatenated
            in ``stage_name`` order and re-indexed from 0; the new positions
            are also exposed as an 'index' column by ``reset_index``.
        """
        enabled_stages = [stage for stage, is_included
                          in zip(self.stage_name, self.stage_activate) if is_included]
        replicates_raw_sorted = []
        for replicate_nb in range(1, self.N_REPLICATES + 1):
            replicate = data_raw[data_raw['replicate'] == replicate_nb]
            pieces = [replicate[replicate['exp_stage'] == stage] for stage in enabled_stages]
            if pieces:
                # DataFrame.append was removed in pandas 2.0; concat is the
                # supported equivalent.
                stage_sorted = pd.concat(pieces, ignore_index = True)
            else:
                # Every stage disabled: keep returning an empty table.
                stage_sorted = pd.DataFrame([])
            replicates_raw_sorted.append(stage_sorted.reset_index())
        return replicates_raw_sorted

    def compute_claws(self, moving_average = False):
        """Recompute the claw positions of every registered fly.

        ``moving_average`` is forwarded to ``Fly.compute_claws``.
        """
        for fly_ in self.flies:
            fly_.compute_claws(self, moving_average)

    def plot_claws(self, fly = None, MDN = False, SS = False, PR = False, display_background = False):
        """Plot claw positions for all flies or for a selection.

        Parameters
        ----------
        fly : None or iterable of int
            None plots every registered fly; otherwise each value is used as
            a position in ``self.flies``.
        MDN, SS, PR : bool
            Which datasets to plot (forwarded to ``Fly.plot_claws``).
        display_background : bool
            Forwarded to ``Fly.plot_claws``.
        """
        if fly is None:
            selected = self.flies
        else:
            selected = [self.flies[fly_nb] for fly_nb in fly]
        for fly_ in selected:
            fly_.plot_claws(MDN, SS, PR, display_background)
        

## Class Fly

In [None]:
class Fly:
    """Claw trajectories of one fly across the MDN, SS and PR datasets.

    For each dataset the fly keeps the replicate numbers it appears in, the
    matching stage-sorted replicate tables taken from the experiment, and the
    claw positions extracted from them, indexed as
    ``positions_claws_X[replicate][claw][axis]`` with axis 0 = 'x' and
    axis 1 = 'y'.
    """

    def __init__(self, exp, id_):
        """Collect the replicates of fly ``id_`` from ``exp`` and extract claws.

        Parameters
        ----------
        exp : Experiment
            Provides ``data_*``, ``replicates_*_sorted`` and
            ``width_mov_av_claw``.
        id_ : value compared against the 'fly' column of the datasets.
        """
        self.id = id_

        self.replicates_nb_MDN = exp.data_MDN[exp.data_MDN['fly'] == id_]['replicate'].unique()
        self.replicates_nb_SS = exp.data_SS[exp.data_SS['fly'] == id_]['replicate'].unique()
        self.replicates_nb_PR = exp.data_PR[exp.data_PR['fly'] == id_]['replicate'].unique()

        # Replicate numbers start at 1 while the sorted lists start at 0.
        self.replicates_MDN = [exp.replicates_MDN_sorted[i-1] for i in self.replicates_nb_MDN]
        self.replicates_SS = [exp.replicates_SS_sorted[i-1] for i in self.replicates_nb_SS]
        self.replicates_PR = [exp.replicates_PR_sorted[i-1] for i in self.replicates_nb_PR]

        self.claws_name = ['LFclaw', 'LMclaw', 'LHclaw', 'RFclaw', 'RMclaw', 'RHclaw']
        # Alternative landmarks:
        # ['LFtibiaTarsus', 'LMtibiaTarsus', 'LHtibiaTarsus', 'RFtibiaTarsus', 'RMtibiaTarsus', 'RHtibiaTarsus']
        self.compute_claws(exp)

    def sub_to_itself(self, data_frame):
        """Return ``data_frame`` minus itself shifted by one row.

        The first row of the result is NaN because it has no predecessor.
        """
        # shift() already returns a new object, so no explicit copy is needed.
        return data_frame.subtract(data_frame.shift())

    def plot_claws(self, MDN = False, SS = False, PR = False, display_background = False):
        """Plot the claw positions of the requested datasets.

        One figure is produced per selected dataset, with one column per
        replicate: the top row shows the y position against the frame number,
        the bottom row shows the frame number against the x position.

        Parameters
        ----------
        MDN, SS, PR : bool
            Select the datasets to plot.
        display_background : bool
            If True, the experimental stages are shaded with
            ``color_background``.
        """
        if MDN:
            self._plot_claws_dataset('MDN', self.replicates_nb_MDN, self.positions_claws_MDN,
                                     self.replicates_MDN, display_background)
        if SS:
            self._plot_claws_dataset('SS', self.replicates_nb_SS, self.positions_claws_SS,
                                     self.replicates_SS, display_background)
        if PR:
            self._plot_claws_dataset('PR', self.replicates_nb_PR, self.positions_claws_PR,
                                     self.replicates_PR, display_background)

    def _plot_claws_dataset(self, dataset_name, replicates_nb, positions_claws, replicates, display_background):
        """Draw and show the claw-position figure of a single dataset."""
        nb_columns = len(replicates_nb)
        width_label = self.width_mov_av_claw if self.moving_average_claw else 'None'

        plt.figure(figsize=(40, 20))
        plt.suptitle('Claws position ({} data)'.format(dataset_name), fontsize=32)
        for i, replicate_nb in enumerate(replicates_nb):
            # Top row: y position as a function of the frame number.
            plt.subplot(2, nb_columns, i + 1)
            for claw_name, (_, claw_y) in zip(self.claws_name, positions_claws[i]):
                plt.plot(range(claw_y.shape[0]), claw_y, label=claw_name)
            plt.gca().invert_yaxis()
            plt.xlabel('Frame number')
            plt.ylabel('Claws position (y axis)')
            plt.title('Fly {}: replicate {} - moving average width : {} (Y-axis)'.format(self.id,  replicate_nb, width_label), fontsize=16)
            if display_background:
                color_background(plt.gca(), replicates[i]['exp_stage'])
            plt.legend()

            # Bottom row: frame number as a function of the x position, so
            # the stages are shaded as horizontal bands.
            plt.subplot(2, nb_columns, i + nb_columns + 1)
            for claw_name, (claw_x, _) in zip(self.claws_name, positions_claws[i]):
                plt.plot(claw_x, range(claw_x.shape[0]), label=claw_name)
            plt.gca().invert_yaxis()
            plt.xlabel('Claws position (x axis)')
            plt.ylabel('Frame number')
            plt.title('Fly {}: replicate {} - moving average width : {} (X-axis)'.format(self.id,  replicate_nb, width_label), fontsize=16)
            if display_background:
                color_background(plt.gca(), replicates[i]['exp_stage'], False)
            plt.legend()
        plt.show()

    def compute_claws(self, exp, moving_average = False):
        """(Re)extract the claw positions of the three datasets.

        Parameters
        ----------
        exp : Experiment
            Supplies ``width_mov_av_claw``.
        moving_average : bool
            If True the positions are smoothed with a moving average.
        """
        self.positions_claws_MDN = self.compute_position_claws(self.replicates_MDN, exp, moving_average)
        self.positions_claws_SS = self.compute_position_claws(self.replicates_SS, exp, moving_average)
        self.positions_claws_PR = self.compute_position_claws(self.replicates_PR, exp, moving_average)

    def compute_position_claws(self, replicates, exp, moving_average = False):
        """Extract the x / y traces of every claw for each replicate.

        Also records ``exp.width_mov_av_claw`` and ``moving_average`` on the
        fly; they are read back when building the plot titles.

        Parameters
        ----------
        replicates : list of pandas.DataFrame
            Tables addressable as ``replicate[claw_name, axis]``.
        exp : Experiment
            Supplies the moving-average width.
        moving_average : bool
            If True each trace is convolved with a box filter of that width
            (``mode='same'``, so the output is not shorter than the input).

        Returns
        -------
        list
            ``result[replicate][claw][axis]`` with axis 0 = 'x', 1 = 'y'.
        """
        self.width_mov_av_claw = exp.width_mov_av_claw
        self.moving_average_claw = moving_average

        if moving_average:
            ma_filter = np.ones(self.width_mov_av_claw) / self.width_mov_av_claw

            def extract(trace):
                return np.convolve(trace, ma_filter, 'same')
        else:
            def extract(trace):
                return trace

        return [[[extract(replicate[claws, ax]) for ax in ['x', 'y']]
                 for claws in self.claws_name]
                for replicate in replicates]
    

## Data analysis

In [None]:
# Create the experiment: this loads the MDN, SS01540 and PR tracking data from
# disk and stores 30 as the moving-average window width (in samples)
exp = Experiment(mov_av = 30)

In [None]:
# Keep only the stages flagged True, then rebuild the per-replicate data
exp.include_stage_and_update_data(
    off0=True, on0=True,
    off1=True, on1=True,
    off2=True, on2=True,
    off3=True,
)

In [None]:
# Register one Fly for each of the ids 0, 1 and 2
for fly_id in (0, 1, 2):
    exp.add_fly(Fly(exp, id_=fly_id))

In [None]:
#exp.replicates_MDN_sorted[0].head()


## Question 2.a

In [None]:
# Settings for the claw-position plots.
# moving_average: when True the claw traces are smoothed before plotting.
moving_average = False
# Width (in samples) of the smoothing window; only used when moving_average
# is True.
exp.width_mov_av_claw = 20
# When True the experimental stages are shaded behind the traces.
display_background = True

In [None]:
# Recompute the claw positions of every registered fly with the settings above
exp.compute_claws(moving_average = moving_average)

In [None]:
# fly can be None (all registered flies) or a list of fly positions (e.g. [0, 2])
exp.plot_claws(fly = [0], MDN = True, display_background = display_background)

In [None]:
# fly can be None (all registered flies) or a list of fly positions (e.g. [0, 2])
exp.plot_claws(fly = [0], SS = True, display_background = display_background)

In [None]:
# fly can be None (all registered flies) or a list of fly positions (e.g. [0, 2])
exp.plot_claws(fly = [0], PR = True, display_background = display_background)

In [None]:
# Stack, over every registered fly and every MDN replicate, the x and y traces
# of one claw (one row per replicate). Used to compute mean and variance.
# NOTE(review): the original comment said "member 0" while the code reads the
# claw at index 1 of Fly.claws_name - confirm which claw is intended.
claw_index = 1
claws_MDN_x = []
claws_MDN_y = []
for fly_ in exp.flies:
    for claw_MDN in fly_.positions_claws_MDN:
        claws_MDN_x.append(np.asarray(claw_MDN[claw_index][0]))
        claws_MDN_y.append(np.asarray(claw_MDN[claw_index][1]))

if claws_MDN_x:
    # Replicates may differ in length: truncate every trace to the shortest
    # one so that the rows of all flies stack into a rectangular array.
    nb_data_MDN = min(len(trace) for trace in claws_MDN_x)
    all_claws_MDN_x = np.asarray([trace[:nb_data_MDN] for trace in claws_MDN_x])
    all_claws_MDN_y = np.asarray([trace[:nb_data_MDN] for trace in claws_MDN_y])
else:
    # No fly or no MDN replicate: keep the empty 2-D placeholders.
    all_claws_MDN_x = np.array([[]])
    all_claws_MDN_y = np.array([[]])


In [None]:
def find_local_max(data):
    """Return a boolean mask marking the strict local maxima of ``data``.

    A sample is a local maximum when it is strictly greater than each of its
    existing neighbours; the first and last samples are therefore compared
    with their single neighbour only.

    Parameters
    ----------
    data : 1-D array-like of numbers

    Returns
    -------
    numpy.ndarray of bool, same length as ``data``.
    """
    # Work on a plain array so slices compare by position, not by label.
    data = np.asarray(data)
    if data.size == 0:
        return np.zeros(0, dtype=bool)
    # The previous comparisons used '<', which flagged local minima.
    greater_than_previous = np.r_[True, data[1:] > data[:-1]]
    greater_than_next = np.r_[data[:-1] > data[1:], True]
    return greater_than_previous & greater_than_next

In [None]:
def find_frequency(local_max, dt = 1/80):
    """Estimate an instantaneous frequency from a mask of detected peaks.

    Parameters
    ----------
    local_max : 1-D boolean array-like
        True at the samples where a peak was detected.
    dt : float
        Time between two consecutive samples.

    Returns
    -------
    time : numpy.ndarray
        Time of every peak except the last one.
    freq : numpy.ndarray
        ``1 / (time to the next peak)`` for each entry of ``time``; both
        arrays have the same length (empty when fewer than two peaks exist).
    """
    peak_time = np.where(np.asarray(local_max, dtype=bool))[0] * dt
    # np.diff gives the gap to the next peak without wrapping around from the
    # last peak back to the first one (which produced a negative period).
    period = np.diff(peak_time)
    return peak_time[:-1], 1 / period

In [None]:
def get_frequency(data, dt = 1/80):
    """Estimate an instantaneous frequency from the local maxima of ``data``.

    A sample counts as a local maximum when it is strictly greater than each
    of its existing neighbours.

    Parameters
    ----------
    data : 1-D array-like of numbers
    dt : float
        Time between two consecutive samples.

    Returns
    -------
    time : numpy.ndarray
        Time of every local maximum except the last one.
    freq : numpy.ndarray
        ``1 / (time to the next maximum)`` for each entry of ``time``; both
        arrays have the same length (empty when fewer than two maxima exist).
    """
    data = np.asarray(data)
    if data.size == 0:
        return np.zeros(0), np.zeros(0)
    # '>' selects maxima; the previous '<' comparisons selected minima.
    local_max = np.r_[True, data[1:] > data[:-1]] & np.r_[data[:-1] > data[1:], True]
    time = np.where(local_max)[0] * dt
    # np.diff avoids the wrap-around of np.roll, which appended a bogus
    # negative period for the last maximum.
    period = np.diff(time)
    freq = 1 / period
    return time[:-1], freq

In [None]:
# Detect the peaks of the first stacked MDN x trace, convert them into a
# frequency estimate and plot that estimate against time
time, freq = find_frequency(find_local_max(all_claws_MDN_x[0]))
plt.plot(time, freq)

## Clustering 

In [None]:
# MDN claw positions of the first registered fly, indexed [replicate][claw][axis]
data = exp.flies[0].positions_claws_MDN

In [None]:
# Number of samples in the x trace of the first claw of the first replicate
print(len(data[0][0][0]))

In [None]:
# Display the claw traces of the first MDN replicate of the first registered fly
exp.flies[0].positions_claws_MDN[0]