# COBAR Project 

Authors: Célia Benquet, Artur Jesslen & Léa Schmidt

## Import libraries

In [None]:
import pickle
import matplotlib.pyplot as plt
import pickle
import numpy as np
import pandas as pd 
import seaborn as sns
from scipy import signal

## Class Experiment

In [None]:
class Experiment:
    """Tracking data of one experiment (MDN, SS01540 and PR datasets) and its flies.

    Each dataset is loaded from a pickled DataFrame, split per replicate and
    restricted to the experimental stages enabled with
    ``include_stage_and_update_data``. Flies registered with ``add_fly`` are
    driven collectively by the ``compute_*`` and ``plot_*`` methods.
    """

    def __init__(self, mov_av = 30):
        """Load the data and prepare every stage.

        Args:
            mov_av: moving-average window, in samples, used by default for both
                the claw positions and the speeds.
        """
        self.load_data()
        # Stage labels, in the order prepare_data concatenates them.
        self.stage_name = ['off0', 'on0', 'off1', 'on1', 'off2', 'on2', 'off3']
        # Time step dividing per-sample displacements into speeds
        # (presumably an 80 Hz recording -- confirm).
        self.dt = 1. / 80
        # Pixel -> millimetre scale factor (presumably 32 mm over 832 px -- confirm).
        self.pixel_to_mm = 32. / 832
        self.width_mov_av_claw = mov_av
        self.width_mov_av_speed = mov_av

        # Enables every stage and builds replicates_{MDN,SS,PR}_sorted.
        self.include_stage_and_update_data()

        self.flies = []

    def add_fly(self, fly):
        """Register ``fly`` so the collective compute/plot methods include it."""
        self.flies.append(fly)

    def include_stage_and_update_data(self, off0 = True, on0  = True, off1  = True, on1  = True, off2  = True, on2  = True, off3  = True):
        """Select which stages are kept, then rebuild the per-replicate data.

        Each flag matches the stage of the same name in ``stage_name``.
        """
        self.stage_activate = [off0, on0, off1, on1, off2, on2, off3]

        self.replicates_MDN_sorted = self.prepare_data(self.data_MDN)
        self.replicates_SS_sorted = self.prepare_data(self.data_SS)
        self.replicates_PR_sorted = self.prepare_data(self.data_PR)

    def load_data(self):
        """Load the three tracking DataFrames and move their index into a column.

        NOTE(review): pickle.load can execute arbitrary code stored in the
        file; only load trusted data files.
        """
        path_MDN = 'data/MDN/U3_f'
        path_SS = 'data/SS01540/U3_f'
        path_PR = 'data/PR/U3_f'
        with open(path_MDN + '/MDN_U3_f_trackingData.pkl', 'rb') as f:
            self.data_MDN = pickle.load(f).reset_index()
        with open(path_SS + '/SS01540_U3_f_trackingData.pkl', 'rb') as f:
            self.data_SS = pickle.load(f).reset_index()
        with open(path_PR + '/PR_U3_f_trackingData.pkl', 'rb') as f:
            self.data_PR = pickle.load(f).reset_index()

    def prepare_data(self, data_raw, n_replicates = 12):
        """Split ``data_raw`` per replicate and keep only the enabled stages.

        Args:
            data_raw: tracking table with ``replicate`` and ``exp_stage`` columns.
            n_replicates: replicates are numbered 1..n_replicates.

        Returns:
            A list whose item ``k`` holds replicate ``k + 1``. Rows of the enabled
            stages are concatenated in ``stage_name`` order. A fresh 0-based index
            is used, and the previous index is kept as a column by ``reset_index``.
        """
        replicates_raw_sorted = []
        for replicate_nb in range(1, n_replicates + 1):
            replicate = data_raw[data_raw['replicate'] == replicate_nb]
            pieces = [replicate[replicate['exp_stage'] == stage]
                      for stage, stage_is_included in zip(self.stage_name, self.stage_activate)
                      if stage_is_included]
            # DataFrame.append was removed in pandas 2.0; concatenate once instead.
            # Empty pieces are skipped so concat never sees an empty/all-NA frame.
            pieces = [piece for piece in pieces if not piece.empty]
            if pieces:
                stage_sorted = pd.concat(pieces, ignore_index=True)
            else:
                stage_sorted = replicate.iloc[:0].reset_index(drop=True)
            replicates_raw_sorted.append(stage_sorted.reset_index())
        return replicates_raw_sorted

    def _for_each_fly(self, fly, method_name, *args):
        """Call ``method_name(*args)`` on the flies indexed by ``fly`` (all when None)."""
        selected = self.flies if fly is None else (self.flies[fly_nb] for fly_nb in fly)
        for fly_ in selected:
            getattr(fly_, method_name)(*args)

    # NOTE(review): the Fly class in this notebook defines no plot_* methods,
    # so the plot dispatchers below raise AttributeError unless those exist -- confirm.
    def plot_positions(self, fly = None, MDN = False, SS = False, PR = False):
        """Forward to ``plot_positions`` of the selected flies (indices in ``fly``)."""
        self._for_each_fly(fly, 'plot_positions', MDN, SS, PR)

    def compute_pos_and_speed(self, moving_average = False):
        """Recompute every fly's positions/speeds and pool the speeds per dataset.

        Sets ``speeds_MDN``, ``speeds_SS`` and ``speeds_PR`` to 2-D arrays with
        one row per replicate of every fly. All rows are truncated to the
        shortest sample count so the arrays are rectangular. With no data, each
        array is empty with shape (0, 0).
        """
        per_fly = [fly_.compute_pos_and_speed(self, moving_average) for fly_ in self.flies]
        self.speeds_MDN = self._stack_replicate_speeds([speeds[0] for speeds in per_fly])
        self.speeds_SS = self._stack_replicate_speeds([speeds[1] for speeds in per_fly])
        self.speeds_PR = self._stack_replicate_speeds([speeds[2] for speeds in per_fly])

    @staticmethod
    def _stack_replicate_speeds(per_fly):
        """Stack per-fly (n_replicates, n_samples) arrays row-wise, skipping empty ones."""
        arrays = [np.asarray(speeds) for speeds in per_fly if np.size(speeds)]
        if not arrays:
            return np.empty((0, 0))
        # Flies may have different sample counts; concatenation needs equal widths.
        n_samples = min(array.shape[1] for array in arrays)
        return np.concatenate([array[:, :n_samples] for array in arrays], axis=0)

    def compute_claws(self, moving_average = False):
        """Recompute the claw positions of every registered fly."""
        for fly_ in self.flies:
            fly_.compute_claws(self, moving_average)

    def plot_speeds(self, fly = None, MDN = False, SS = False, PR = False, display_background = True):
        """Forward to ``plot_speeds`` of the selected flies (indices in ``fly``)."""
        self._for_each_fly(fly, 'plot_speeds', MDN, SS, PR, display_background)

    def plot_claws(self, fly = None, MDN = False, SS = False, PR = False, display_background = False):
        """Forward to ``plot_claws`` of the selected flies (indices in ``fly``)."""
        self._for_each_fly(fly, 'plot_claws', MDN, SS, PR, display_background)
        

## Class Fly

In [None]:
class Fly:
    """Positions, forward speeds and claw positions of one fly.

    The fly's replicates are taken from an ``Experiment`` (already split per
    replicate and filtered by stage) for each of the MDN, SS and PR datasets.
    """

    def __init__(self, exp, id_):
        """Collect the replicates of fly ``id_`` from ``exp`` and compute its data.

        Speeds and claw positions are computed here without moving average;
        call ``compute_pos_and_speed`` / ``compute_claws`` again to smooth them.
        """
        self.id = id_

        # Replicate numbers in which this fly appears, per dataset.
        self.replicates_nb_MDN = exp.data_MDN[exp.data_MDN['fly'] == id_]['replicate'].unique()
        self.replicates_nb_SS = exp.data_SS[exp.data_SS['fly'] == id_]['replicate'].unique()
        self.replicates_nb_PR = exp.data_PR[exp.data_PR['fly'] == id_]['replicate'].unique()

        # Replicate numbers start at 1 while the sorted lists are 0-indexed.
        self.replicates_MDN = [exp.replicates_MDN_sorted[i - 1] for i in self.replicates_nb_MDN]
        self.replicates_SS = [exp.replicates_SS_sorted[i - 1] for i in self.replicates_nb_SS]
        self.replicates_PR = [exp.replicates_PR_sorted[i - 1] for i in self.replicates_nb_PR]

        self.compute_pos_and_speed(exp)

        self.claws_name = ['LFclaw', 'LMclaw', 'LHclaw', 'RFclaw', 'RMclaw', 'RHclaw']
        # Alternative key points that can be tracked instead of the claws:
        # ['LFtibiaTarsus', 'LMtibiaTarsus', 'LHtibiaTarsus', 'RFtibiaTarsus', 'RMtibiaTarsus', 'RHtibiaTarsus']
        self.compute_claws(exp)

    def compute_pos_and_speed(self, exp, moving_average = False):
        """Compute positions and forward speeds of every replicate.

        Stores ``positions_*`` and ``speeds_*`` (one entry per replicate) and
        returns a tuple ``(MDN, SS, PR)`` of 2-D arrays of shape
        (n_replicates, n_samples). Within each dataset, every replicate is cut
        to that dataset's shortest one. A dataset without replicates gives an
        empty (0, 0) array.
        """
        self.positions_MDN = self.compute_position(self.replicates_MDN)
        self.positions_SS = self.compute_position(self.replicates_SS)
        self.positions_PR = self.compute_position(self.replicates_PR)

        self.speeds_MDN = self.compute_speed(self.replicates_MDN, exp, moving_average)
        self.speeds_SS = self.compute_speed(self.replicates_SS, exp, moving_average)
        self.speeds_PR = self.compute_speed(self.replicates_PR, exp, moving_average)

        # Each dataset is truncated independently so no replicate is dropped
        # when the datasets have different replicate counts.
        return (self._truncate_to_shortest(self.speeds_MDN),
                self._truncate_to_shortest(self.speeds_SS),
                self._truncate_to_shortest(self.speeds_PR))

    @staticmethod
    def _truncate_to_shortest(series_list):
        """Stack 1-D sequences into a 2-D array after cutting them to the shortest length."""
        if not series_list:
            return np.empty((0, 0))
        n_samples = min(len(series) for series in series_list)
        return np.asarray([np.asarray(series)[:n_samples] for series in series_list])

    def compute_claws(self, exp, moving_average = False):
        """Compute the claw positions of every replicate of each dataset."""
        self.positions_claws_MDN = self.compute_position_claws(self.replicates_MDN, exp, moving_average)
        self.positions_claws_SS = self.compute_position_claws(self.replicates_SS, exp, moving_average)
        self.positions_claws_PR = self.compute_position_claws(self.replicates_PR, exp, moving_average)

    def compute_position(self, replicates):
        """Return ``[posx, posy]`` of the body centre for each replicate."""
        return [[replicate['center', 'pos' + ax] for ax in ['x', 'y']] for replicate in replicates]

    def compute_speed(self, replicates, exp, moving_average = False):
        """Return the signed forward speed of each replicate.

        The per-sample displacement of the centre is scaled by
        ``exp.pixel_to_mm / exp.dt``. It is then projected onto the fly's
        orientation: the heading is taken as ``arctan2(dx, -dy)`` in degrees,
        and the result is ``cos(heading - orientation) * magnitude``. The first
        sample is NaN because it has no previous position. With
        ``moving_average`` both the magnitude and the projection ratio are
        smoothed with a NaN-aware centred mean of ``exp.width_mov_av_speed``
        samples.
        """
        self.width_mov_av_speed = exp.width_mov_av_speed
        self.moving_average_speed = moving_average

        speeds = []
        for replicate in replicates:
            dx = self.sub_to_itself(replicate['center', 'posx'])
            dy = self.sub_to_itself(replicate['center', 'posy'])
            magnitude = np.sqrt(dx**2 + dy**2) * exp.pixel_to_mm / exp.dt
            heading = (np.arctan2(dx, -dy) * 180 / np.pi + 360) % 360
            angle = (heading - replicate['center', 'orientation'] + 180) % 360 - 180
            projection_ratio = np.cos(angle * np.pi / 180)
            if self.moving_average_speed:
                magnitude = self._moving_average(magnitude, self.width_mov_av_speed)
                projection_ratio = self._moving_average(projection_ratio, self.width_mov_av_speed)
            speeds.append(projection_ratio * magnitude)
        return speeds

    def compute_position_claws(self, replicates, exp, moving_average = False):
        """Return ``[[x, y] per claw]`` for each replicate, optionally smoothed.

        Smoothing uses a NaN-aware centred mean of ``exp.width_mov_av_claw``
        samples.
        """
        self.width_mov_av_claw = exp.width_mov_av_claw
        self.moving_average_claw = moving_average

        def coordinate(replicate, claw, ax):
            values = replicate[claw, ax]
            if moving_average:
                return self._moving_average(values, self.width_mov_av_claw)
            return values

        return [[[coordinate(replicate, claw, ax) for ax in ['x', 'y']]
                 for claw in self.claws_name]
                for replicate in replicates]

    @staticmethod
    def _moving_average(values, width):
        """Centred moving average over ``width`` samples that ignores NaN samples.

        Each output is the mean of the finite samples inside its window
        ('same'-mode convolution). This is why the NaN left by the first
        difference no longer spreads over half a window. Windows without any
        finite sample give NaN.
        """
        values = np.asarray(values, dtype=float)
        finite = np.isfinite(values)
        kernel = np.ones(width)
        total = np.convolve(np.where(finite, values, 0.0), kernel, 'same')
        count = np.convolve(finite.astype(float), kernel, 'same')
        return np.divide(total, count, out=np.full_like(total, np.nan), where=count > 0)

    def sub_to_itself(self, data_frame):
        """Return the difference between each sample and the previous one (first is NaN)."""
        return data_frame.diff()


## Data analysis

In [None]:
# Build the experiment (loads the MDN, SS and PR tracking data);
# mov_av sets the default moving-average window for claws and speeds.
exp = Experiment(mov_av=30)

In [None]:
# Keep only the stages flagged True (off0, on0, off1, on1) and rebuild the per-replicate data
exp.include_stage_and_update_data(
    off0=True, on0=True,
    off1=True, on1=True,
    off2=False, on2=False,
    off3=False,
)

In [None]:
# Create the flies with ids 0, 1 and 2 and register them in the experiment
for fly_id in (0, 1, 2):
    exp.add_fly(Fly(exp, id_=fly_id))

In [None]:
#exp.replicates_MDN_sorted[0].head()


## 2. Question 1a

## 2. Question 1b

Plot the forward/backward speed of the fly over time.

In [None]:
# Speed settings: smooth over a 100-sample window; display_background is forwarded to the plots
moving_average = True
exp.width_mov_av_speed = 100
display_background = True

In [None]:
exp.compute_pos_and_speed(moving_average=moving_average)  # recompute positions and forward speeds of all flies

#### i. MDN data

In [None]:
exp.plot_speeds(fly=[0], MDN=True, display_background=display_background)  # fly 0, MDN dataset

#### ii. SS data

In [None]:
exp.plot_speeds(fly=[0], SS=True, display_background=display_background)  # fly 0, SS dataset

#### iii. PR data

In [None]:
exp.plot_speeds(fly=[0], PR=True, display_background=display_background)  # fly 0, PR dataset

## Question 2a

In [None]:
# Claw settings: smooth over a 20-sample window; display_background is forwarded to the plots
moving_average = True
exp.width_mov_av_claw = 20
display_background = True

In [None]:
exp.compute_claws(moving_average=moving_average)  # recompute claw positions of all flies

In [None]:
exp.plot_claws(fly=[0], MDN=True, display_background=display_background)  # fly 0, MDN dataset

In [None]:
exp.plot_claws(fly=[0], SS=True, display_background=display_background)  # fly 0, SS dataset

In [None]:
exp.plot_claws(fly=[0], PR=True, display_background=display_background)  # fly 0, PR dataset

## Clustering 

In [None]:
exp.data_MDN.head(n=5)  # preview the first rows of the raw MDN tracking table

In [None]:
# Claw columns of the raw MDN tracking table, selected for clustering
data_clustering_MDN = exp.data_MDN[
    ['LFclaw', 'LMclaw', 'LHclaw', 'RFclaw', 'RMclaw', 'RHclaw']
]